diff --git a/.gitignore b/.gitignore index 6d90763c1..50dfc2302 100755 --- a/.gitignore +++ b/.gitignore @@ -62,10 +62,10 @@ cellacdc/metrics/* !cellacdc/metrics/CV.py !cellacdc/metrics/combine_metrics_example.py !cellacdc/metrics/channel_indipendent_metric_example.py -cellacdc/models/beno -cellacdc/models/Simone_cellpose +cellacdc/segmenters/beno +cellacdc/segmenters/Simone_cellpose cellacdc/timon_tests -cellacdc/models/test_segm_model +cellacdc/segmenters/test_segm_model cellacdc/trackers/example cellacdc/test_qt_app.py cellacdc/test_qthread.py @@ -98,7 +98,7 @@ UserManual/* # Ignore models folder but keep the folder with placeholder.txt dummy file models/* !models/placeholder.txt -cellacdc/models/*/model/* +cellacdc/segmenters/*/model/* # Hide placeholder.txt dummy file (probably works only on Linux) !models/.placeholder.txt diff --git a/MANIFEST.in b/MANIFEST.in index edbfa4229..da7028e54 100755 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -26,8 +26,8 @@ prune publications prine notebooks prune FijiMacros -exclude cellacdc/models/YeastMate/detectron2 -exclude cellacdc/models/YeastMate/pycocotools +exclude cellacdc/segmenters/YeastMate/detectron2 +exclude cellacdc/segmenters/YeastMate/pycocotools exclude requirements.txt exclude not_installed_requirements.txt diff --git a/README.rst b/README.rst index 3aee6ba7c..bc12fe19d 100644 --- a/README.rst +++ b/README.rst @@ -192,6 +192,57 @@ Alternatively, you can also use **Cite this repository** button on the right ribbon of the GitHub page. +Using Cell-ACDC from a script +============================= + +Cell-ACDC can be launched from a Python script or notebook with a napari-style +API. Install the GUI dependencies first: + +.. code-block:: bash + + pip install "cellacdc[gui]" + +Build a unified :class:`ExperimentData` object from arrays or from a path, then +pass it to the viewer: + +.. code-block:: python + + import cellacdc + import numpy as np + + image = np.zeros((100, 128, 128), dtype=np.uint16) # T, Y, X + data = cellacdc.ExperimentData.from_arrays(image, axes="tyx") + viewer = cellacdc.Viewer(data) + cellacdc.run() + +The convenience helper mirrors ``napari.imshow`` and returns both the viewer +and the data object: + +.. code-block:: python + + data = cellacdc.ExperimentData.from_arrays(image, axes="tyx") + viewer, data = cellacdc.imshow(data) + cellacdc.run() + +Path-based loading: + +.. code-block:: python + + data = cellacdc.ExperimentData.from_path("/path/to/experiment") + viewer, data = cellacdc.imshow(data) + cellacdc.run() + +Optional ``labels`` can be supplied when creating data from arrays. When no +``workspace`` path is given, Cell-ACDC uses a temporary folder so segmentation +outputs can still be saved from the GUI. + +In a Jupyter notebook with ``%gui qt``, ``cellacdc.run()`` is a no-op because +IPython already runs the Qt event loop. + +For view-only inspection of arrays without segmentation, use +``cellacdc.plot.imshow`` instead. + + **IMPORTANT**: when citing Cell-ACDC make sure to also cite the paper of the segmentation models and trackers you used! See `here `__ diff --git a/cellacdc/QtScoped.py b/cellacdc/QtScoped.py index 86a194b0d..d533723ba 100644 --- a/cellacdc/QtScoped.py +++ b/cellacdc/QtScoped.py @@ -2,62 +2,72 @@ from qtpy.QtWidgets import QAbstractSlider, QStyle + def SliderNoAction(): if PYQT6: return QAbstractSlider.SliderAction.SliderNoAction.value else: return QAbstractSlider.SliderAction.SliderNoAction + def SliderSingleStepAdd(): if PYQT6: return QAbstractSlider.SliderAction.SliderSingleStepAdd.value else: return QAbstractSlider.SliderAction.SliderSingleStepAdd + def SliderSingleStepSub(): if PYQT6: return QAbstractSlider.SliderAction.SliderSingleStepSub.value else: return QAbstractSlider.SliderAction.SliderSingleStepSub + def SliderPageStepAdd(): if PYQT6: return QAbstractSlider.SliderAction.SliderPageStepAdd.value else: return QAbstractSlider.SliderAction.SliderPageStepAdd + def SliderPageStepSub(): if PYQT6: return QAbstractSlider.SliderAction.SliderPageStepAdd.value else: return QAbstractSlider.SliderAction.SliderPageStepAdd + def SliderToMinimum(): if PYQT6: return QAbstractSlider.SliderAction.SliderPageStepAdd.value else: return QAbstractSlider.SliderAction.SliderPageStepAdd + def SliderToMaximum(): if PYQT6: return QAbstractSlider.SliderAction.SliderPageStepAdd.value else: return QAbstractSlider.SliderAction.SliderPageStepAdd + def SliderMove(): if PYQT6: return QAbstractSlider.SliderAction.SliderMove.value else: return QAbstractSlider.SliderAction.SliderMove + def QStyleCC_ScrollBar(): if PYQT6: return QStyle.ComplexControl.CC_ScrollBar else: return QStyle.CC_ScrollBar + def QStyleSC_ScrollBarSubLine(): if PYQT6: return QStyle.SubControl.SC_ScrollBarSubLine else: - return QStyle.SC_ScrollBarSubLine \ No newline at end of file + return QStyle.SC_ScrollBarSubLine diff --git a/cellacdc/__init__.py b/cellacdc/__init__.py index c4fcaf553..ba2ae6b34 100755 --- a/cellacdc/__init__.py +++ b/cellacdc/__init__.py @@ -3,44 +3,45 @@ import subprocess + def is_conda_env(): python_exec_path = sys.exec_prefix is_conda_python = ( - python_exec_path.find('conda') != -1 - or python_exec_path.find('mambaforge') != -1 - or python_exec_path.find('miniforge') != -1 + python_exec_path.find("conda") != -1 + or python_exec_path.find("mambaforge") != -1 + or python_exec_path.find("miniforge") != -1 ) if not is_conda_python: return False - + stdout = subprocess.DEVNULL try: - args = ['conda', '-V'] - is_conda_present = subprocess.check_call( - args, shell=True, stdout=stdout) == 0 + args = ["conda", "-V"] + is_conda_present = subprocess.check_call(args, shell=True, stdout=stdout) == 0 return True except Exception as err: pass - + try: - args = ['conda -V'] - is_conda_present = subprocess.check_call( - args, shell=True, stdout=stdout) == 0 + args = ["conda -V"] + is_conda_present = subprocess.check_call(args, shell=True, stdout=stdout) == 0 return True except Exception as err: return False - + return True + def import_torch(): if is_conda_env(): - return - + return + try: import torch except ModuleNotFoundError: return + import_torch() @@ -60,80 +61,90 @@ def import_torch(): from typing import Iterable -KNOWN_EXTENSIONS = ( - '.tif', '.npz', '.npy', '.h5', '.json', '.csv', '.txt' -) +KNOWN_EXTENSIONS = (".tif", ".npz", ".npy", ".h5", ".json", ".csv", ".txt") IMAGE_EXTENSIONS = ( - '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', + ".tif", + ".tiff", + ".png", + ".jpg", + ".jpeg", + ".bmp", + ".gif", ) VIDEO_EXTENSIONS = ( - '.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', + ".mp4", + ".avi", + ".mov", + ".mkv", + ".webm", + ".flv", ) -def _warn_ask_install_package( - commands: Iterable[str], note_txt='', caller='SpotMAX' - ): - open_str = '='*100 - sep_str = '-'*100 - commands_txt = '\n'.join([f' {command}' for command in commands]) + +def _warn_ask_install_package(commands: Iterable[str], note_txt="", caller="SpotMAX"): + open_str = "=" * 100 + sep_str = "-" * 100 + commands_txt = "\n".join([f" {command}" for command in commands]) text = ( - f'{caller} needs to run the following commands{note_txt}:\n\n' - f'{commands_txt}\n\n' + f"{caller} needs to run the following commands{note_txt}:\n\n{commands_txt}\n\n" ) question = ( - 'How do you want to proceed?: ' - '1) Run the commands now. ' - 'q) Quit, I will run the commands myself (1/q): ' + "How do you want to proceed?: " + "1) Run the commands now. " + "q) Quit, I will run the commands myself (1/q): " ) print(open_str) print(text) - + message_on_exit = ( - '[WARNING]: Execution aborted. Run the following commands before ' - f'running spotMAX again:\n\n{commands_txt}\n' + "[WARNING]: Execution aborted. Run the following commands before " + f"running spotMAX again:\n\n{commands_txt}\n" ) msg_on_invalid = ( - '$answer is not a valid answer. ' + "$answer is not a valid answer. " 'Type "1" to run the commands now or "q" to quit.' ) try: while True: answer = input(question) - if answer == 'q': + if answer == "q": print(open_str) exit(message_on_exit) - elif answer == '1': + elif answer == "1": break else: print(sep_str) - print(msg_on_invalid.replace('$answer', answer)) + print(msg_on_invalid.replace("$answer", answer)) print(sep_str) except Exception as err: traceback.print_exc() print(open_str) print(message_on_exit) + def _run_pip_commands(commands: Iterable[str]): import subprocess + for command in commands: try: - subprocess.check_call([sys.executable, '-m', *command.split()]) + subprocess.check_call([sys.executable, "-m", *command.split()]) except Exception as err: pass - + + try: import requests except Exception as err: import traceback + traceback.print_exc() - print('We detected a corrupted library, fixing it now...') + print("We detected a corrupted library, fixing it now...") commands = ( - 'pip uninstall -y charset-normalizer', - 'pip install --upgrade charset-normalizer' + "pip uninstall -y charset-normalizer", + "pip install --upgrade charset-normalizer", ) _warn_ask_install_package( - commands, note_txt=' (fixing charset-normalizer package)', - caller='Cell-ACDC' + commands, note_txt=" (fixing charset-normalizer package)", caller="Cell-ACDC" ) _run_pip_commands(commands) @@ -141,18 +152,16 @@ def _run_pip_commands(commands: Iterable[str]): import sympy except Exception as err: import traceback + traceback.print_exc() - print('Since Cell-ACDC v1.7.2, the sympy library is required.') - commands = ( - 'pip install --upgrade sympy', - ) + print("Since Cell-ACDC v1.7.2, the sympy library is required.") + commands = ("pip install --upgrade sympy",) _warn_ask_install_package( - commands, - note_txt=' (installing sympy)', - caller='Cell-ACDC' + commands, note_txt=" (installing sympy)", caller="Cell-ACDC" ) _run_pip_commands(commands) + def user_data_dir(): r""" Get OS specific data directory path for Cell-ACDC. @@ -174,32 +183,31 @@ def user_data_dir(): os_path = os.getenv("XDG_DATA_HOME", "~/.local/share") os_path = os.path.expanduser(os_path) - return os.path.join(os_path, 'Cell_ACDC') + return os.path.join(os_path, "Cell_ACDC") + cellacdc_path = os.path.dirname(os.path.abspath(__file__)) -debug_true_filepath = os.path.join(cellacdc_path, '.debug_true') -qrc_resources_path = os.path.join(cellacdc_path, 'qrc_resources.py') -qrc_resources_light_path = os.path.join(cellacdc_path, 'qrc_resources_light.py') -qrc_resources_dark_path = os.path.join(cellacdc_path, 'qrc_resources_dark.py') -old_temp_path = os.path.join(cellacdc_path, 'temp') -tooltips_rst_filepath = os.path.join( - cellacdc_path, "docs", "source", "tooltips.rst" -) +debug_true_filepath = os.path.join(cellacdc_path, ".debug_true") +qrc_resources_path = os.path.join(cellacdc_path, "qrc_resources.py") +qrc_resources_light_path = os.path.join(cellacdc_path, "qrc_resources_light.py") +qrc_resources_dark_path = os.path.join(cellacdc_path, "qrc_resources_dark.py") +old_temp_path = os.path.join(cellacdc_path, "temp") +tooltips_rst_filepath = os.path.join(cellacdc_path, "docs", "source", "tooltips.rst") user_data_folderpath = user_data_dir() user_profile_path_txt = os.path.join( - user_data_folderpath, 'acdc_user_profile_location.txt' + user_data_folderpath, "acdc_user_profile_location.txt" ) user_home_path = str(pathlib.Path.home()) -user_profile_path = os.path.join(user_home_path, 'acdc-appdata') +user_profile_path = os.path.join(user_home_path, "acdc-appdata") if os.path.exists(user_profile_path_txt): try: - with open(user_profile_path_txt, 'r') as txt: - user_profile_path = fr'{txt.read()}' + with open(user_profile_path_txt, "r") as txt: + user_profile_path = rf"{txt.read()}" except Exception as e: pass -qrc_resources_user_path = os.path.join(user_profile_path, 'qrc_resources.py') +qrc_resources_user_path = os.path.join(user_profile_path, "qrc_resources.py") try: os.makedirs(user_profile_path, exist_ok=True) @@ -213,23 +221,23 @@ def user_data_dir(): # print(f'User profile path: "{user_profile_path}"') import site + sitepackages = site.getsitepackages() -site_packages = [p for p in sitepackages if p.endswith('-packages')][0] +site_packages = [p for p in sitepackages if p.endswith("-packages")][0] cellacdc_path = os.path.dirname(os.path.abspath(__file__)) cellacdc_installation_path = os.path.dirname(cellacdc_path) if cellacdc_installation_path != site_packages: IS_CLONED = True - settings_folderpath = os.path.join(cellacdc_installation_path, '.acdc-settings') + settings_folderpath = os.path.join(cellacdc_installation_path, ".acdc-settings") else: IS_CLONED = False - settings_folderpath = os.path.join(user_profile_path, '.acdc-settings') + settings_folderpath = os.path.join(user_profile_path, ".acdc-settings") + +fiji_location_filepath = os.path.join(settings_folderpath, "fiji_location.txt") +bioio_sample_data_folderpath = os.path.join(user_profile_path, "acdc_dataStruct_temp") -fiji_location_filepath = os.path.join(settings_folderpath, 'fiji_location.txt') -bioio_sample_data_folderpath = os.path.join( - user_profile_path, 'acdc_dataStruct_temp' -) def copytree(src, dst): os.makedirs(dst, exist_ok=True) @@ -241,6 +249,7 @@ def copytree(src, dst): elif os.path.isfile(src_filepath): shutil.copy2(src_filepath, dst_filepath) + if not os.path.exists(settings_folderpath): os.makedirs(settings_folderpath, exist_ok=True) if os.path.exists(old_temp_path): @@ -248,66 +257,65 @@ def copytree(src, dst): copytree(old_temp_path, settings_folderpath) shutil.rmtree(old_temp_path) except Exception as e: - print('*'*60) + print("*" * 60) print( - '[WARNING]: could not copy settings from previous location. ' - f'Please manually copy the folder "{old_temp_path}" to "{settings_folderpath}"') - print('^'*60) + "[WARNING]: could not copy settings from previous location. " + f'Please manually copy the folder "{old_temp_path}" to "{settings_folderpath}"' + ) + print("^" * 60) import pandas as pd + # Disable pandas 3.0 strict string dtype to maintain backward compatibility # with code that assigns non-string values to DataFrames -if hasattr(pd.options, 'future') and hasattr(pd.options.future, 'infer_string'): +if hasattr(pd.options, "future") and hasattr(pd.options.future, "infer_string"): pd.options.future.infer_string = False -settings_csv_path = os.path.join(settings_folderpath, 'settings.csv') +settings_csv_path = os.path.join(settings_folderpath, "settings.csv") if not os.path.exists(settings_csv_path): - df_settings = pd.DataFrame( - {'setting': [], 'value': []}).set_index('setting') + df_settings = pd.DataFrame({"setting": [], "value": []}).set_index("setting") df_settings.to_csv(settings_csv_path) # Get color scheme if not os.path.exists(settings_csv_path): - scheme = 'light' + scheme = "light" try: - df_settings = pd.read_csv(settings_csv_path, index_col='setting') + df_settings = pd.read_csv(settings_csv_path, index_col="setting") except Exception as err: # Overwrite corrupted setttings file - df_settings = pd.DataFrame( - {'setting': [], 'value': []}).set_index('setting') + df_settings = pd.DataFrame({"setting": [], "value": []}).set_index("setting") df_settings.to_csv(settings_csv_path) - -if 'colorScheme' not in df_settings.index: - scheme = 'light' + +if "colorScheme" not in df_settings.index: + scheme = "light" else: - scheme = df_settings.at['colorScheme', 'value'] + scheme = df_settings.at["colorScheme", "value"] -does_qrc_resources_exists = ( - os.path.exists(qrc_resources_path) - or os.path.exists(qrc_resources_user_path) +does_qrc_resources_exists = os.path.exists(qrc_resources_path) or os.path.exists( + qrc_resources_user_path ) + def _copy_qrc_resources_file( - src_qrc_resources_scheme_path: os.PathLike, - dst_qrc_resources_path: os.PathLike, - user_dst_qrc_resources_path: os.PathLike = qrc_resources_user_path - ): + src_qrc_resources_scheme_path: os.PathLike, + dst_qrc_resources_path: os.PathLike, + user_dst_qrc_resources_path: os.PathLike = qrc_resources_user_path, +): try: shutil.copyfile(src_qrc_resources_scheme_path, dst_qrc_resources_path) return True except Exception as err: - # Copy to user folder because copying to cell-acdc location failed - # possibly PermissionError --> return False to stop application + # Copy to user folder because copying to cell-acdc location failed + # possibly PermissionError --> return False to stop application # and prompt the user to restart Cell-ACDC - shutil.copyfile( - src_qrc_resources_scheme_path, user_dst_qrc_resources_path - ) + shutil.copyfile(src_qrc_resources_scheme_path, user_dst_qrc_resources_path) return False + # Set default qrc resources if not does_qrc_resources_exists: - if scheme == 'light': + if scheme == "light": qrc_resources_scheme_path = qrc_resources_light_path else: qrc_resources_scheme_path = qrc_resources_dark_path @@ -323,23 +331,23 @@ def _copy_qrc_resources_file( # Replace 'from PyQt5' with 'from qtpy' in qrc_resources.py file try: save_qrc = False - with open(qrc_resources_path, 'r') as qrc_py: + with open(qrc_resources_path, "r") as qrc_py: text = qrc_py.read() - if text.find('from PyQt5') != -1: - text = text.replace('from PyQt5', 'from qtpy') + if text.find("from PyQt5") != -1: + text = text.replace("from PyQt5", "from qtpy") save_qrc = True if save_qrc: - with open(qrc_resources_path, 'w') as qrc_py: + with open(qrc_resources_path, "w") as qrc_py: qrc_py.write(text) except Exception as err: raise err try: - # Import qrc_resources explicitly so that "from . import acdc_qrc_resources" imports - # the variable defined here. Use importlib in case qrc_resouces.py is in + # Import qrc_resources explicitly so that "from . import acdc_qrc_resources" imports + # the variable defined here. Use importlib in case qrc_resouces.py is in # user folder qrc_resouces_spec = importlib.util.spec_from_file_location( - 'qrc_resources', qrc_resources_path + "qrc_resources", qrc_resources_path ) acdc_qrc_resources = importlib.util.module_from_spec(qrc_resouces_spec) qrc_resouces_spec.loader.exec_module(acdc_qrc_resources) @@ -347,31 +355,35 @@ def _copy_qrc_resources_file( # Cellacdc in the cli might not have qtpy --> ignore error pass + def try_input_install_package(pkg_name, install_command, question=None): if question is None: - question = 'Do you want to install it now ([y]/n)? ' + question = "Do you want to install it now ([y]/n)? " try: - answer = input(f'\n{question}') + answer = input(f"\n{question}") return answer except Exception as err: raise ModuleNotFoundError( f'The module "{pkg_name}" is not installed. ' - f'Install it with the command `{install_command}`.' + f"Install it with the command `{install_command}`." ) + try: # Force PyQt6 if available try: from PyQt6 import QtCore + os.environ["QT_API"] = "pyqt6" except Exception as e: pass from qtpy import QtCore import pyqtgraph import matplotlib + GUI_INSTALLED = True except Exception as e: - GUI_INSTALLED = False + GUI_INSTALLED = False import pandas as pd @@ -379,109 +391,116 @@ def try_input_install_package(pkg_name, install_command, question=None): pd.set_option("display.max_columns", 20) pd.set_option("display.max_rows", 200) -pd.set_option('display.expand_frame_repr', False) +pd.set_option("display.expand_frame_repr", False) + +open_printl_str = "*" * 100 +close_printl_str = "=" * 100 -open_printl_str = '*'*100 -close_printl_str = '='*100 def printl(*objects, pretty=False, is_decorator=False, idx=1, **kwargs): - timestap = datetime.now().strftime('%H:%M:%S') + timestap = datetime.now().strftime("%H:%M:%S") currentframe = inspect.currentframe() outerframes = inspect.getouterframes(currentframe) - idx = idx+1 if is_decorator else idx + idx = idx + 1 if is_decorator else idx callingframe = outerframes[idx].frame callingframe_info = inspect.getframeinfo(callingframe) filepath = callingframe_info.filename - fileinfo_str = ( - f'File "{filepath}", line {callingframe_info.lineno} - {timestap}:' - ) + fileinfo_str = f'File "{filepath}", line {callingframe_info.lineno} - {timestap}:' if pretty: print(open_printl_str) print(fileinfo_str) for o, object in enumerate(objects): - text = str(object) + text = str(object) pprint(text, **kwargs) print(close_printl_str) else: - sep = kwargs.get('sep', ', ') + sep = kwargs.get("sep", ", ") text = sep.join([str(object) for object in objects]) - text = f'{open_printl_str}\n{fileinfo_str}\n{text}\n{close_printl_str}' + text = f"{open_printl_str}\n{fileinfo_str}\n{text}\n{close_printl_str}" print(text) + parent_path = os.path.dirname(cellacdc_path) -html_path = os.path.join(cellacdc_path, '_html') -models_path = os.path.join(cellacdc_path, 'models') -promptable_models_path = os.path.join(cellacdc_path, 'promptable_models') -data_path = os.path.join(parent_path, 'data') -resources_folderpath = os.path.join(cellacdc_path, 'resources') -resources_filepath = os.path.join(cellacdc_path, 'resources_light.qrc') -logs_path = os.path.join(user_profile_path, '.acdc-logs') -acdc_fiji_path = os.path.join(user_profile_path, 'acdc-fiji') -acdc_ffmpeg_path = os.path.join(user_profile_path, 'acdc-ffmpeg') -resources_path = os.path.join(cellacdc_path, 'resources_light.qrc') -models_list_file_path = os.path.join(settings_folderpath, 'custom_models_paths.ini') -promptable_models_list_file_path = os.path.join( - settings_folderpath, 'custom_promptable_models_paths.ini' +html_path = os.path.join(cellacdc_path, "_html") +segmenters_path = os.path.join(cellacdc_path, "segmenters") +segmenters_promptable_path = os.path.join(cellacdc_path, "segmenters_promptable") +data_path = os.path.join(parent_path, "data") +resources_folderpath = os.path.join(cellacdc_path, "resources") +resources_filepath = os.path.join(cellacdc_path, "resources_light.qrc") +logs_path = os.path.join(user_profile_path, ".acdc-logs") +acdc_fiji_path = os.path.join(user_profile_path, "acdc-fiji") +acdc_ffmpeg_path = os.path.join(user_profile_path, "acdc-ffmpeg") +resources_path = os.path.join(cellacdc_path, "resources_light.qrc") +segmenters_list_file_path = os.path.join(settings_folderpath, "custom_models_paths.ini") +segmenters_promptable_list_file_path = os.path.join( + settings_folderpath, "custom_promptable_models_paths.ini" ) +models_path = segmenters_path +promptable_models_path = segmenters_promptable_path +models_list_file_path = segmenters_list_file_path +promptable_models_list_file_path = segmenters_promptable_list_file_path favourite_func_metrics_csv_path = os.path.join( - settings_folderpath, 'favourite_func_metrics.csv' + settings_folderpath, "favourite_func_metrics.csv" +) +recentPaths_path = os.path.join(settings_folderpath, "recentPaths.csv") +preproc_recipes_path = os.path.join(settings_folderpath, "preprocessing_recipes") +combine_channels_recipes_path = os.path.join(settings_folderpath, "combine_channels") +segm_recipes_path = os.path.join(settings_folderpath, "segmentation_recipes") +user_manual_url = "https://github.com/SchmollerLab/Cell_ACDC/blob/main/UserManual/Cell-ACDC_User_Manual.pdf" +github_home_url = "https://github.com/SchmollerLab/Cell_ACDC" +data_structure_docs_url = ( + "https://cell-acdc.readthedocs.io/en/latest/data-structure.html" ) -recentPaths_path = os.path.join(settings_folderpath, 'recentPaths.csv') -preproc_recipes_path = os.path.join(settings_folderpath, 'preprocessing_recipes') -combine_channels_recipes_path = os.path.join(settings_folderpath, 'combine_channels') -segm_recipes_path = os.path.join(settings_folderpath, 'segmentation_recipes') -user_manual_url = 'https://github.com/SchmollerLab/Cell_ACDC/blob/main/UserManual/Cell-ACDC_User_Manual.pdf' -github_home_url = 'https://github.com/SchmollerLab/Cell_ACDC' -data_structure_docs_url = 'https://cell-acdc.readthedocs.io/en/latest/data-structure.html' moth_bud_tot_selected_columns_filepath = os.path.join( - settings_folderpath, 'mother_bud_total_columns_selection.json' + settings_folderpath, "mother_bud_total_columns_selection.json" ) saved_measurements_selections_folderpath = os.path.join( - settings_folderpath, 'saved_measurements_selections' + settings_folderpath, "saved_measurements_selections" ) -# Use to get the acdc_output file name from `segm_filename` as +# Use to get the acdc_output file name from `segm_filename` as # `m = re.sub(segm_re_pattern, '_acdc_output', segm_filename)` -segm_re_pattern = r'_segm(?!.*_segm)' +segm_re_pattern = r"_segm(?!.*_segm)" try: from setuptools_scm import get_version - __version__ = get_version(root='..', relative_to=__file__) + + __version__ = get_version(root="..", relative_to=__file__) except Exception as e: try: from ._version import version as __version__ except ImportError: __version__ = "not-installed" -__author__ = 'Francesco Padovani and Benedikt Mairhoermann' +__author__ = "Francesco Padovani and Benedikt Mairhoermann" -cite_url = 'https://bmcbiol.biomedcentral.com/articles/10.1186/s12915-022-01372-6' -issues_url = 'https://github.com/SchmollerLab/Cell_ACDC/issues' +cite_url = "https://bmcbiol.biomedcentral.com/articles/10.1186/s12915-022-01372-6" +issues_url = "https://github.com/SchmollerLab/Cell_ACDC/issues" # Initialize variables that need to be globally accessible base_cca_dict = { - 'cell_cycle_stage': 'G1', - 'generation_num': 2, - 'relative_ID': -1, - 'relationship': 'mother', - 'emerg_frame_i': -1, - 'division_frame_i': -1, - 'is_history_known': False, - 'corrected_on_frame_i': -1, - 'will_divide': 0, - 'daughter_disappears_before_division': 0, - 'disappears_before_division': 0 + "cell_cycle_stage": "G1", + "generation_num": 2, + "relative_ID": -1, + "relationship": "mother", + "emerg_frame_i": -1, + "division_frame_i": -1, + "is_history_known": False, + "corrected_on_frame_i": -1, + "will_divide": 0, + "daughter_disappears_before_division": 0, + "disappears_before_division": 0, } cca_df_colnames = list(base_cca_dict.keys()) base_cca_tree_dict = { - 'Cell_ID_tree': -1, - 'generation_num_tree': 1, - 'parent_ID_tree': -1, - 'root_ID_tree': -1, - 'sister_ID_tree': -1 + "Cell_ID_tree": -1, + "generation_num_tree": 1, + "parent_ID_tree": -1, + "root_ID_tree": -1, + "sister_ID_tree": -1, } lineage_tree_cols = list(base_cca_tree_dict.keys()) @@ -494,47 +513,36 @@ def printl(*objects, pretty=False, is_decorator=False, idx=1, **kwargs): # 'sister_ID_tree' # ] -lineage_tree_cols_std_val = [ - -1, - -1, - -1, - -1, - -1 -] +lineage_tree_cols_std_val = [-1, -1, -1, -1, -1] default_annot_df = { - 'is_cell_dead': False, - 'is_cell_excluded': False, + "is_cell_dead": False, + "is_cell_excluded": False, } -base_acdc_df = { - **default_annot_df, - 'was_manually_edited': 0 -} +base_acdc_df = {**default_annot_df, "was_manually_edited": 0} base_acdc_df_cols = list(base_acdc_df.keys()) -sorted_cols = ['time_seconds', 'time_minutes', 'time_hours'] -sorted_cols = [ - *sorted_cols, *cca_df_colnames, *lineage_tree_cols, *base_acdc_df_cols -] +sorted_cols = ["time_seconds", "time_minutes", "time_hours"] +sorted_cols = [*sorted_cols, *cca_df_colnames, *lineage_tree_cols, *base_acdc_df_cols] cca_df_colnames_with_tree = [*cca_df_colnames, *lineage_tree_cols] all_non_metrics_cols = [*base_acdc_df_cols, *cca_df_colnames, *lineage_tree_cols] -is_linux = sys.platform.startswith('linux') -is_mac = sys.platform == 'darwin' +is_linux = sys.platform.startswith("linux") +is_mac = sys.platform == "darwin" is_win = sys.platform.startswith("win") -is_win64 = (is_win and (os.environ["PROCESSOR_ARCHITECTURE"] == "AMD64")) -is_mac_arm64 = is_mac and platform.machine() == 'arm64' +is_win64 = is_win and (os.environ["PROCESSOR_ARCHITECTURE"] == "AMD64") +is_mac_arm64 = is_mac and platform.machine() == "arm64" if is_linux and GUI_INSTALLED: from pathlib import Path acdc_exec_path = shutil.which("acdc") - logo_path = os.path.join(resources_folderpath, 'logo_square_v2.png') + logo_path = os.path.join(resources_folderpath, "logo_square_v2.png") txt = f""" [Desktop Entry] Name=Cell-ACDC @@ -562,10 +570,9 @@ def printl(*objects, pretty=False, is_decorator=False, idx=1, **kwargs): # Make the .desktop file executable (equivalent to chmod +x) import stat + mode = os.stat(desktop_file).st_mode - os.chmod( - desktop_file, mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH - ) + os.chmod(desktop_file, mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) # 🔄 Refresh the desktop database try: @@ -576,85 +583,82 @@ def printl(*objects, pretty=False, is_decorator=False, idx=1, **kwargs): stderr=subprocess.PIPE, ) exit( - 'Cell-ACDC had to update the desktop database. ' - 'Please re-start the software, thanks!' + "Cell-ACDC had to update the desktop database. " + "Please re-start the software, thanks!" ) except FileNotFoundError: - print("⚠️ 'update-desktop-database' not found. It’s part of the 'desktop-file-utils' package.") + print( + "⚠️ 'update-desktop-database' not found. It’s part of the 'desktop-file-utils' package." + ) except subprocess.CalledProcessError as e: print(f"⚠️ Error updating desktop database:\n{e.stderr.decode()}") yeaz_weights_filenames = [ - 'unet_weights_batchsize_25_Nepochs_100_SJR0_10.hdf5', - 'weights_budding_BF_multilab_0_1.hdf5' + "unet_weights_batchsize_25_Nepochs_100_SJR0_10.hdf5", + "weights_budding_BF_multilab_0_1.hdf5", ] yeaz_v2_weights_filenames = [ - 'weights_budding_BF_multilab_0_1', - 'weights_budding_PhC_multilab_0_1', - 'weights_fission_multilab_0_2' + "weights_budding_BF_multilab_0_1", + "weights_budding_PhC_multilab_0_1", + "weights_fission_multilab_0_2", ] segment_anything_weights_filenames = [ - 'sam_vit_h_4b8939.pth', - 'sam_vit_l_0b3195.pth', - 'sam_vit_b_01ec64.pth' + "sam_vit_h_4b8939.pth", + "sam_vit_l_0b3195.pth", + "sam_vit_b_01ec64.pth", ] sam2_weights_filenames = [ - 'sam2.1_hiera_large.pt', - 'sam2.1_hiera_base_plus.pt', - 'sam2.1_hiera_small.pt', - 'sam2.1_hiera_tiny.pt' + "sam2.1_hiera_large.pt", + "sam2.1_hiera_base_plus.pt", + "sam2.1_hiera_small.pt", + "sam2.1_hiera_tiny.pt", ] -deepsea_weights_filenames = [ - 'segmentation.pth', - 'tracker.pth' -] +deepsea_weights_filenames = ["segmentation.pth", "tracker.pth"] yeastmate_weights_filenames = [ - 'yeastmate_advanced.yaml', - 'yeastmate_weights.pth', - 'yeastmate.yaml' + "yeastmate_advanced.yaml", + "yeastmate_weights.pth", + "yeastmate.yaml", ] -tapir_weights_filenames = [ - 'tapir_checkpoint.npy' -] +tapir_weights_filenames = ["tapir_checkpoint.npy"] graphLayoutBkgrColor = (235, 235, 235) -darkBkgrColor = [255-v for v in graphLayoutBkgrColor] +darkBkgrColor = [255 - v for v in graphLayoutBkgrColor] + def _critical_exception_gui(self, func_name): from . import widgets, html_utils + result = None traceback_str = traceback.format_exc() - - if hasattr(self, 'is_error_state') and self.is_error_state: + + if hasattr(self, "is_error_state") and self.is_error_state: printl(traceback_str) return - - if hasattr(self, 'logger'): + + if hasattr(self, "logger"): self.logger.error(traceback_str) else: printl(traceback_str) - + try: self.cleanUpOnError() except Exception as e: pass - + msg = widgets.myMessageBox(wrapText=False, showCentered=False) - if hasattr(self, 'logs_path'): - msg.addShowInFileManagerButton( - self.logs_path, txt='Show log file...' - ) - if not hasattr(self, 'log_path'): - log_path = 'NULL' + if hasattr(self, "logs_path"): + msg.addShowInFileManagerButton(self.logs_path, txt="Show log file...") + if not hasattr(self, "log_path"): + log_path = "NULL" else: log_path = self.log_path - + self.is_error_state = True msg.setDetailedText(traceback_str, visible=True) href = f'GitHub page' @@ -669,15 +673,16 @@ def _critical_exception_gui(self, func_name): here: """) - msg.critical(self, 'Critical error', err_msg, commands=(log_path,)) - + msg.critical(self, "Critical error", err_msg, commands=(log_path,)) + + def exception_handler_cli(func): @wraps(func) def inner_function(self, *args, **kwargs): try: - if func.__code__.co_argcount==1 and func.__defaults__ is None: + if func.__code__.co_argcount == 1 and func.__defaults__ is None: result = func(self) - elif func.__code__.co_argcount>1 and func.__defaults__ is None: + elif func.__code__.co_argcount > 1 and func.__defaults__ is None: result = func(self, *args) else: result = func(self, *args, **kwargs) @@ -688,24 +693,28 @@ def inner_function(self, *args, **kwargs): else: raise err return result + return inner_function + def exec_time(func): @wraps(func) def inner_function(self, *args, **kwargs): t0 = time.perf_counter() - if func.__code__.co_argcount==1 and func.__defaults__ is None: + if func.__code__.co_argcount == 1 and func.__defaults__ is None: result = func(self) - elif func.__code__.co_argcount>1 and func.__defaults__ is None: + elif func.__code__.co_argcount > 1 and func.__defaults__ is None: result = func(self, *args) else: result = func(self, *args, **kwargs) t1 = time.perf_counter() - s = f'{func.__name__} execution time = {(t1-t0)*1000:.3f} ms' + s = f"{func.__name__} execution time = {(t1 - t0) * 1000:.3f} ms" printl(s, is_decorator=True) return result + return inner_function + def _exception_handler_clean_progress(self): try: if self.progressWin is not None: @@ -714,8 +723,10 @@ def _exception_handler_clean_progress(self): except AttributeError: pass + def exception_handler(func): """Decorator to handle class methods exceptions and show a critical error message.""" + @wraps(func) def inner_function(self, *args, **kwargs): try: @@ -723,10 +734,7 @@ def inner_function(self, *args, **kwargs): except TypeError as e: # Only handle the specific Qt slot error msg = str(e) - if ( - "takes 1 positional argument but 2 were given" in msg - and len(args) > 0 - ): + if "takes 1 positional argument but 2 were given" in msg and len(args) > 0: try: # Remove only the last argument (assumed to be from Qt) filtered_args = args[:-1] @@ -741,8 +749,10 @@ def inner_function(self, *args, **kwargs): _exception_handler_clean_progress(self) result = _critical_exception_gui(self, func.__name__) return result + return inner_function + def disableWindow(func): @wraps(func) def inner_function(self, *args, **kwargs): @@ -766,44 +776,69 @@ def inner_function(self, *args, **kwargs): finally: self.setDisabled(False) self.activateWindow() + return inner_function + def ignore_exception(func): @wraps(func) def inner_function(self, *args, **kwargs): try: - if func.__code__.co_argcount==1 and func.__defaults__ is None: + if func.__code__.co_argcount == 1 and func.__defaults__ is None: result = func(self) - elif func.__code__.co_argcount>1 and func.__defaults__ is None: + elif func.__code__.co_argcount > 1 and func.__defaults__ is None: result = func(self, *args) else: result = func(self, *args, **kwargs) except Exception as e: pass return result + return inner_function -error_below = f"\n{'*'*50} ERROR {'*'*50}\n" -error_close = f"\n{'^'*(len(error_below)-1)}" -error_up_str = '^'*100 -error_up_str = f'\n{error_up_str}' -error_down_str = '^'*100 -error_down_str = f'\n{error_down_str}' +error_below = f"\n{'*' * 50} ERROR {'*' * 50}\n" +error_close = f"\n{'^' * (len(error_below) - 1)}" -binary_file_extensions = ( - ".png", ".pdf" -) +error_up_str = "^" * 100 +error_up_str = f"\n{error_up_str}" +error_down_str = "^" * 100 +error_down_str = f"\n{error_down_str}" + +binary_file_extensions = (".png", ".pdf") default_index_cols = ( - 'experiment_folderpath', - 'experiment_foldername', - 'Position_n', - 'frame_i', - 'Cell_ID' + "experiment_folderpath", + "experiment_foldername", + "Position_n", + "frame_i", + "Cell_ID", ) -single_pos_index_cols = ( - 'experiment_folderpath', - 'Position_n' -) \ No newline at end of file +single_pos_index_cols = ("experiment_folderpath", "Position_n") + +_SCRIPT_API_EXPORTS = { + "Viewer": ("cellacdc.viewer", "Viewer"), + "ExperimentData": ("cellacdc.data_source", "ExperimentData"), + "current_viewer": ("cellacdc.viewer", "current_viewer"), + "imshow": ("cellacdc.viewer", "imshow"), + "run": ("cellacdc._event_loop", "run"), + "get_qapp": ("cellacdc._event_loop", "get_qapp"), + "quit_app": ("cellacdc._event_loop", "quit_app"), +} + +__all__ = list(_SCRIPT_API_EXPORTS) + + +def __getattr__(name: str): + if name in _SCRIPT_API_EXPORTS: + module_name, attr_name = _SCRIPT_API_EXPORTS[name] + import importlib + + module = importlib.import_module(module_name) + return getattr(module, attr_name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return sorted(set(globals()) | set(_SCRIPT_API_EXPORTS)) diff --git a/cellacdc/__main__.py b/cellacdc/__main__.py index 30d670d16..a394d4638 100755 --- a/cellacdc/__main__.py +++ b/cellacdc/__main__.py @@ -6,148 +6,99 @@ import numpy as np import site + sitepackages = site.getsitepackages() -site_packages = [p for p in sitepackages if p.endswith('site-packages')][0] +site_packages = [p for p in sitepackages if p.endswith("site-packages")][0] cellacdc_path = os.path.dirname(os.path.abspath(__file__)) cellacdc_installation_path = os.path.dirname(cellacdc_path) if cellacdc_installation_path != site_packages: - # Running developer version. Delete cellacdc folder from site_packages + # Running developer version. Delete cellacdc folder from site_packages # if present from a previous installation of cellacdc from PyPi - cellacdc_path_pypi = os.path.join(site_packages, 'cellacdc') + cellacdc_path_pypi = os.path.join(site_packages, "cellacdc") if os.path.exists(cellacdc_path_pypi): import shutil + try: shutil.rmtree(cellacdc_path_pypi) except Exception as err: print(err) print( - '[ERROR]: Previous Cell-ACDC installation detected. ' - f'Please, manually delete this folder and re-start the software ' + "[ERROR]: Previous Cell-ACDC installation detected. " + f"Please, manually delete this folder and re-start the software " f'"{cellacdc_path_pypi}". ' - 'Thank you for you patience!' + "Thank you for you patience!" ) exit() - print('*'*60) + print("*" * 60) input( - '[WARNING]: Cell-ACDC had to clean-up and older installation. ' - 'Please, re-start the software. Thank you for your patience! ' - '(Press any key to exit). ' + "[WARNING]: Cell-ACDC had to clean-up and older installation. " + "Please, re-start the software. Thank you for your patience! " + "(Press any key to exit). " ) exit() from cellacdc import _run + def run(): from cellacdc.config import parser_args - PARAMS_PATH = parser_args['params'] - - if parser_args['version'] or parser_args['info']: - from cellacdc.myutils import get_info_version_text + PARAMS_PATH = parser_args["params"] + + if parser_args["version"] or parser_args["info"]: + from cellacdc.utils import get_info_version_text + info_txt = get_info_version_text() print(info_txt) exit() - if parser_args['reset']: - from cellacdc.myutils import reset_settings + if parser_args["reset"]: + from cellacdc.utils import reset_settings + reset_info_txt = reset_settings() print(reset_info_txt) exit() - + if PARAMS_PATH: _run.run_cli(PARAMS_PATH) else: run_gui() + def main(): # Keep compatibility with users that installed older versions # where the entry point was main() run() -def run_gui(): - from ._run import ( - _setup_gui_libraries, - _setup_symlink_app_name_macos, - _setup_numpy, - download_model_params, - _exit_on_setup - ) - - _setup_symlink_app_name_macos() - - requires_exit = _setup_gui_libraries(exit_at_end=False) - - _setup_numpy() - - download_model_params() - - if requires_exit: - _exit_on_setup() - - from qtpy import QtGui, QtWidgets, QtCore - - if os.name == 'nt': - try: - # Set taskbar icon in windows - import ctypes - myappid = 'schmollerlab.cellacdc.pyqt.v1' # arbitrary string - ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(myappid) - except Exception as e: - pass - - # Needed by pyqtgraph with display resolution scaling - try: - QtWidgets.QApplication.setAttribute( - QtCore.Qt.HighDpiScaleFactorRoundingPolicy.PassThrough - ) - except Exception as e: - pass - import pyqtgraph as pg - # Interpret image data as row-major instead of col-major - pg.setConfigOption('imageAxisOrder', 'row-major') - try: - import numba - pg.setConfigOption("useNumba", True) - except Exception as e: - pass - - try: - import cupy as cp - pg.setConfigOption("useCupy", True) - except Exception as e: - pass +def run_gui(): + app, splashScreen = _run.setup_gui_runtime(splashscreen=True) - # Create the application - app, splashScreen = _run._setup_app(splashscreen=True) + from cellacdc import utils, printl - from cellacdc import myutils, printl - - print('Launching application...') + print("Launching application...") from cellacdc._main import mainWin - + if not splashScreen.isVisible(): splashScreen.show() - + win = mainWin(app) try: - myutils.check_matplotlib_version(qparent=win) + utils.check_matplotlib_version(qparent=win) except Exception as e: pass - version, success = myutils.read_version( - logger=win.logger.info, return_success=True - ) + version, success = utils.read_version(logger=win.logger.info, return_success=True) if not success: - error = myutils.check_install_package( - 'setuptools_scm', pypi_name='setuptools-scm' + error = utils.check_install_package( + "setuptools_scm", pypi_name="setuptools-scm" ) if error: win.logger.info(error) else: - version = myutils.read_version(logger=win.logger.info) + version = utils.read_version(logger=win.logger.info) win.setVersion(version) win.launchWelcomeGuide() win.show() @@ -155,14 +106,16 @@ def run_gui(): win.welcomeGuide.showPage(win.welcomeGuide.welcomeItem) except AttributeError: pass - win.logger.info('**********************************************') - win.logger.info(f'Welcome to Cell-ACDC v{version}') - win.logger.info('**********************************************') - win.logger.info('----------------------------------------------') - win.logger.info('NOTE: If application is not visible, it is probably minimized\n' - 'or behind some other open windows.') - win.logger.info('----------------------------------------------') + win.logger.info("**********************************************") + win.logger.info(f"Welcome to Cell-ACDC v{version}") + win.logger.info("**********************************************") + win.logger.info("----------------------------------------------") + win.logger.info( + "NOTE: If application is not visible, it is probably minimized\n" + "or behind some other open windows." + ) + win.logger.info("----------------------------------------------") splashScreen.close() # splashScreenApp.quit() # modernWin.show() - app.exec_() \ No newline at end of file + app.exec_() diff --git a/cellacdc/_base_widgets.py b/cellacdc/_base_widgets.py index 667881183..b5874bc47 100644 --- a/cellacdc/_base_widgets.py +++ b/cellacdc/_base_widgets.py @@ -1,39 +1,3 @@ -from qtpy.QtWidgets import QDialog -from . import printl -from qtpy.QtCore import ( - Qt, QEventLoop -) +from .components.base import QBaseDialog -class QBaseDialog(QDialog): - def __init__(self, parent=None): - super().__init__(parent) - - def exec_(self, resizeWidthFactor=None): - if resizeWidthFactor is not None: - self.show() - self.resize(int(self.width()*resizeWidthFactor), self.height()) - self.show(block=True) - - def show(self, block=False): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show() - - try: - self.setEnabled(True) - except Exception as err: - pass - - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - - def keyPressEvent(self, event) -> None: - if event.key() == Qt.Key_Escape: - event.ignore() - return - - super().keyPressEvent(event) +__all__ = ["QBaseDialog"] diff --git a/cellacdc/_core.py b/cellacdc/_core.py index 91cea1ab2..ee676a8f0 100644 --- a/cellacdc/_core.py +++ b/cellacdc/_core.py @@ -11,49 +11,52 @@ from . import printl time_units_formats = { - 'min': 'minutes', - 'hour': 'hours', - 'second': 'seconds', - 'minutes': 'minutes', - 'seconds': 'seconds', - 'hours': 'hours', - 'H': 'hours', - 'd': 'days', - 'M': 'minutes', - 'S': 'seconds', + "min": "minutes", + "hour": "hours", + "second": "seconds", + "minutes": "minutes", + "seconds": "seconds", + "hours": "hours", + "H": "hours", + "d": "days", + "M": "minutes", + "S": "seconds", } time_units_converters = { - 'seconds -> minutes': lambda x: x/60, - 'seconds -> hours': lambda x: x/3600, - 'seconds -> days': lambda x: x/3600/24, - 'minutes -> hours': lambda x: x/60, - 'minutes -> seconds': lambda x: x*60, - 'minutes -> days': lambda x: x/60/24, - 'hours -> minutes': lambda x: x*60, - 'hours -> seconds': lambda x: x*3600, - 'hours -> days': lambda x: x/24, - 'days -> minutes': lambda x: x*24*60, - 'days -> seconds': lambda x: x*24*3600, - 'days -> hours': lambda x: x*24*3600, + "seconds -> minutes": lambda x: x / 60, + "seconds -> hours": lambda x: x / 3600, + "seconds -> days": lambda x: x / 3600 / 24, + "minutes -> hours": lambda x: x / 60, + "minutes -> seconds": lambda x: x * 60, + "minutes -> days": lambda x: x / 60 / 24, + "hours -> minutes": lambda x: x * 60, + "hours -> seconds": lambda x: x * 3600, + "hours -> days": lambda x: x / 24, + "days -> minutes": lambda x: x * 24 * 60, + "days -> seconds": lambda x: x * 24 * 3600, + "days -> hours": lambda x: x * 24 * 3600, } length_unit_converters = { - 'nm -> μm': lambda x: x/1000, - 'mm -> μm': lambda x: x*1e3, - 'cm -> μm': lambda x: x*1e4, - 'μm -> nm': lambda x: x*1000, - 'μm -> mm': lambda x: x/1e3, - 'μm -> cm': lambda x: x/1e4, - 'μm -> μm': lambda x: x, + "nm -> μm": lambda x: x / 1000, + "mm -> μm": lambda x: x * 1e3, + "cm -> μm": lambda x: x * 1e4, + "μm -> nm": lambda x: x * 1000, + "μm -> mm": lambda x: x / 1e3, + "μm -> cm": lambda x: x / 1e4, + "μm -> μm": lambda x: x, } + def convert_length(value, from_unit, to_unit): - key = f'{from_unit} -> {to_unit}' + key = f"{from_unit} -> {to_unit}" return length_unit_converters[key](value) + def round_to_significant(n, n_significant=1): - return round(n, n_significant-int(floor(log10(abs(n))))-1) + return round(n, n_significant - int(floor(log10(abs(n)))) - 1) + def convert_time_units(x, from_unit, to_unit): try: @@ -65,6 +68,7 @@ def convert_time_units(x, from_unit, to_unit): except Exception as e: return + def _calc_rotational_vol(obj, PhysicalSizeY=1, PhysicalSizeX=1, logger=None): """Given the region properties of a 2D object (from skimage.measure.regionprops). calculate the rotation volume as described in the Supplementary information of @@ -103,18 +107,20 @@ def _calc_rotational_vol(obj, PhysicalSizeY=1, PhysicalSizeX=1, logger=None): try: if is3Dobj: # For 3D objects we use a max projection for the rotation - obj_lab = obj.image.max(axis=0).astype(np.uint32)*obj.label + obj_lab = obj.image.max(axis=0).astype(np.uint32) * obj.label obj = regionprops(obj_lab)[0] - vox_to_fl = float(PhysicalSizeY)*pow(float(PhysicalSizeX), 2) + vox_to_fl = float(PhysicalSizeY) * pow(float(PhysicalSizeX), 2) rotate_ID_img = skimage_rotate( - obj.image.astype(np.single), -(obj.orientation*180/np.pi), - resize=True, order=3 + obj.image.astype(np.single), + -(obj.orientation * 180 / np.pi), + resize=True, + order=3, ) - radii = np.sum(rotate_ID_img, axis=1)/2 - vol_vox = np.sum(np.pi*(radii**2)) + radii = np.sum(rotate_ID_img, axis=1) / 2 + vol_vox = np.sum(np.pi * (radii**2)) if vox_to_fl is not None: - return vol_vox, float(vol_vox*vox_to_fl) + return vol_vox, float(vol_vox * vox_to_fl) else: return vol_vox, vol_vox except Exception as e: @@ -124,17 +130,24 @@ def _calc_rotational_vol(obj, PhysicalSizeY=1, PhysicalSizeX=1, logger=None): printl(traceback.format_exc()) return np.nan, np.nan -def _initialize_single_image(image, is_rgb=False, isZstack=False, img_shape=None, # in use, pylint cant detect it - timelapse=False, img_ndim=None, frame_index_out=None, # assumes that the order of dimesions is t, z, c, h, w - add_rgb=True, ): # for some reason doesnt move axis.... + +def _initialize_single_image( + image, + is_rgb=False, + isZstack=False, + img_shape=None, # in use, pylint cant detect it + timelapse=False, + img_ndim=None, + frame_index_out=None, # assumes that the order of dimesions is t, z, c, h, w + add_rgb=True, +): # for some reason doesnt move axis.... # See cellpose.gui.io._initialize_images if img_shape is None: img_shape = image.shape if img_ndim is None: img_ndim = len(img_shape) - - if is_rgb: # enforce 3 channels if RGB, assuming rgb is last axis + if is_rgb: # enforce 3 channels if RGB, assuming rgb is last axis # move channel axis to the end if it is not already # image = np.moveaxis(image, input_channel_axis, -1) # img_shape = list(image) @@ -143,18 +156,18 @@ def _initialize_single_image(image, is_rgb=False, isZstack=False, img_shape=None if img_shape[-1] == 3: pass elif img_shape[-1] < 3: - shape_to_concat = (img_shape[0], img_shape[1], 3-img_shape[-1]) - to_concat = np.zeros(shape_to_concat,dtype=type(image[0,0,0])) + shape_to_concat = (img_shape[0], img_shape[1], 3 - img_shape[-1]) + to_concat = np.zeros(shape_to_concat, dtype=type(image[0, 0, 0])) image = np.concatenate((image, to_concat), axis=-1) - elif img_shape[-1]<5 and img_shape[-1]>2: - image = image[:,:,:3] - + elif img_shape[-1] < 5 and img_shape[-1] > 2: + image = image[:, :, :3] + image = image.astype(np.float32) if is_rgb: # Compute min and max per channel (last axis) - img_min = image.min(axis=tuple(range(image.ndim-1)), keepdims=True) - img_max = image.max(axis=tuple(range(image.ndim-1)), keepdims=True) + img_min = image.min(axis=tuple(range(image.ndim - 1)), keepdims=True) + img_max = image.max(axis=tuple(range(image.ndim - 1)), keepdims=True) else: # Compute min and max over all channels img_min = image.min() @@ -172,7 +185,7 @@ def _initialize_single_image(image, is_rgb=False, isZstack=False, img_shape=None to_concat = np.zeros(shape_to_concat, dtype=type(image[0, 0, 0])) image = image[..., np.newaxis] # add a new axis for channels image = np.concatenate([image, to_concat], axis=-1) - + if is_rgb or add_rgb: axis_for_channels = -3 image = np.moveaxis(image, -1, axis_for_channels) @@ -182,4 +195,4 @@ def _initialize_single_image(image, is_rgb=False, isZstack=False, img_shape=None # z x W x H x c -> z x c x W x H # W x H x c -> c x W x H image = image.astype(np.float32) - return frame_index_out, image \ No newline at end of file + return frame_index_out, image diff --git a/cellacdc/_debug.py b/cellacdc/_debug.py index a0f3e7292..892703351 100644 --- a/cellacdc/_debug.py +++ b/cellacdc/_debug.py @@ -5,39 +5,45 @@ from . import printl, core + def split_segm_masks_mother_bud_line(lab, obj, obj_bud, ref_p1, ref_p2): import matplotlib.pyplot as plt - + lab = np.zeros_like(lab) lab[obj.slice][obj.image] = obj.label lab[obj_bud.slice][obj_bud.image] = obj_bud.label - + (x_ref_0, y_ref_0), (x_ref1, y_ref1) = ref_p1, ref_p2 - - plt.imshow(lab) - plt.plot([x_ref_0, x_ref1], [y_ref_0, y_ref1], 'r') + + plt.imshow(lab) + plt.plot([x_ref_0, x_ref1], [y_ref_0, y_ref1], "r") plt.show() - - import pdb; pdb.set_trace() + + import pdb + + pdb.set_trace() + def print_all_callers(): currentframe = inspect.currentframe() outerframes = inspect.getouterframes(currentframe, 2) - outerframes_format = '\n' + outerframes_format = "\n" for frame in outerframes: - outerframes_format = f'{outerframes_format} * {frame.function}\n' + outerframes_format = f"{outerframes_format} * {frame.function}\n" printl(outerframes_format) + def _debug_lineage_tree(guiWin): posData = guiWin.data[guiWin.pos_i] - columns = set() + columns = set() for frame_i in range(len(posData.allData_li)): - acdc_df = posData.allData_li[frame_i]['acdc_df'] + acdc_df = posData.allData_li[frame_i]["acdc_df"] if acdc_df is not None: columns.update(acdc_df.reset_index().columns) printl(f"Columns in acdc_df: {columns}") from pandasgui import show as pgshow + if guiWin.lineage_tree is not None and guiWin.lineage_tree.lineage_list is not None: lin_tree_df = pd.DataFrame() for i, df in enumerate(guiWin.lineage_tree.lineage_list): @@ -49,16 +55,13 @@ def _debug_lineage_tree(guiWin): if not isinstance(lin_tree_df.index, pd.RangeIndex): lin_tree_df = lin_tree_df.reset_index() - lin_tree_df = (lin_tree_df - .set_index(["frame_i", "Cell_ID"]) - .sort_index() - ) + lin_tree_df = lin_tree_df.set_index(["frame_i", "Cell_ID"]).sort_index() if "level_0" in lin_tree_df.columns: - lin_tree_df=lin_tree_df.drop(columns="level_0") + lin_tree_df = lin_tree_df.drop(columns="level_0") acdc_df = pd.DataFrame() posData = guiWin.data[guiWin.pos_i] - df_li = [posData.allData_li[i]['acdc_df'] for i in range(len(posData.allData_li))] + df_li = [posData.allData_li[i]["acdc_df"] for i in range(len(posData.allData_li))] for i, df in enumerate(df_li): if df is None: continue @@ -67,10 +70,7 @@ def _debug_lineage_tree(guiWin): df["frame_i"] = i acdc_df = pd.concat([acdc_df, df]) - acdc_df = (acdc_df - .set_index(["frame_i", "Cell_ID"]) - .sort_index() - ) + acdc_df = acdc_df.set_index(["frame_i", "Cell_ID"]).sort_index() # for key, value in guiWin.lineage_tree.family_dict.items(): if guiWin.lineage_tree is not None and guiWin.lineage_tree.lineage_list is not None: @@ -82,25 +82,24 @@ def _debug_lineage_tree(guiWin): family_df = family_df.set_index("family_name") families = pd.concat([families, family_df]) if "level_0" in families.columns: - families=families.drop(columns="level_0") + families = families.drop(columns="level_0") # lin_tree_dict_df = (lin_tree_dict_df # .set_index(["family_name", "frame_i", "Cell_ID"]) # .sort_index() # ) - + # for i, df in enumerate([acdc_df, lin_tree_df, families, lin_tree_dict_df]): # printl(f"Columns: {df.columns} for df {i}" ) # if (df.columns == df.index.name).any(): # printl(f"Index name: {df.index.name} for df {i}!!!" ) if "level_0" in acdc_df.columns: - acdc_df=acdc_df.drop(columns="level_0") - + acdc_df = acdc_df.drop(columns="level_0") if guiWin.lineage_tree is not None and guiWin.lineage_tree.lineage_list is not None: pgshow(acdc_df, lin_tree_df, families) else: pgshow(acdc_df) - # printl(posData.tracked_lost_centroids) \ No newline at end of file + # printl(posData.tracked_lost_centroids) diff --git a/cellacdc/_deprecated/filters.py b/cellacdc/_deprecated/filters.py index b3ffba969..4dedd27fb 100644 --- a/cellacdc/_deprecated/filters.py +++ b/cellacdc/_deprecated/filters.py @@ -5,7 +5,7 @@ from cellacdc import html_utils -from . import GUI_INSTALLED, core, myutils +from . import GUI_INSTALLED, core, utils from . import preprocess if GUI_INSTALLED: @@ -245,12 +245,12 @@ def filter(self, img): if sigma1_yx>0: filtered1 = skimage.filters.gaussian(img, sigma=sigmas1) else: - filtered1 = myutils.img_to_float(img) + filtered1 = utils.img_to_float(img) if sigma2_yx>0: filtered2 = skimage.filters.gaussian(img, sigma=sigmas2) else: - filtered2 = myutils.img_to_float(img) + filtered2 = utils.img_to_float(img) resultFiltered = filtered1 - filtered2 return resultFiltered diff --git a/cellacdc/_event_loop.py b/cellacdc/_event_loop.py new file mode 100644 index 000000000..a3537e90d --- /dev/null +++ b/cellacdc/_event_loop.py @@ -0,0 +1,111 @@ +"""Qt event loop helpers for script and notebook usage.""" + +from __future__ import annotations + +import os +import sys +from typing import TYPE_CHECKING +from warnings import warn + +if TYPE_CHECKING: + from qtpy.QtWidgets import QApplication + +_APP_REF = None +_IPYTHON_WAS_HERE_FIRST = "IPython" in sys.modules + + +def _ipython_has_eventloop() -> bool: + ipy_module = sys.modules.get("IPython") + if not ipy_module: + return False + + shell = ipy_module.get_ipython() # type: ignore[attr-defined] + if not shell: + return False + + return shell.active_eventloop == "qt" + + +def _pycharm_has_eventloop(app: QApplication) -> bool: + in_pycharm = "PYCHARM_HOSTED" in os.environ + in_event_loop = getattr(app, "_in_event_loop", False) + return in_pycharm and in_event_loop + + +def get_qapp(*, splashscreen: bool = False): + """Get or create the Qt QApplication used by Cell-ACDC.""" + global _APP_REF + + from qtpy.QtWidgets import QApplication + + app = QApplication.instance() + if app is None: + from cellacdc._run import setup_gui_runtime + + app, _splash = setup_gui_runtime(splashscreen=splashscreen) + _APP_REF = app + elif _APP_REF is None: + _APP_REF = app + + return app + + +def quit_app() -> None: + """Close open viewers and quit if Cell-ACDC started the QApplication.""" + from qtpy.QtWidgets import QApplication + + from cellacdc.viewer import Viewer + + for viewer in list(Viewer._instances): + viewer.close() + + QApplication.closeAllWindows() + + app = QApplication.instance() + if app is None: + return + + if ( + QApplication.applicationName() == "Cell-ACDC" + and not _ipython_has_eventloop() + ): + QApplication.quit() + + +def run(*, force: bool = False, max_loop_level: int = 1, _func_name: str = "run"): + """Start the Qt event loop.""" + if _ipython_has_eventloop(): + return + + from qtpy.QtWidgets import QApplication + + app = QApplication.instance() + + if app is not None and _pycharm_has_eventloop(app): + return + + if app is None: + raise RuntimeError( + "No Qt app has been created. Create one with " + "`cellacdc.get_qapp()` or `cellacdc.Viewer()`." + ) + + if not app.topLevelWidgets() and not force: + warn( + f"Refusing to run a QApplication with no topLevelWidgets. " + f"To run the app anyway, use `{_func_name}(force=True)`.", + stacklevel=2, + ) + return + + if app.thread().loopLevel() >= max_loop_level: + loops = app.thread().loopLevel() + warn( + f"A QApplication is already running with {loops} event loop(s). " + f"To enter another event loop, use " + f"`{_func_name}(max_loop_level={loops + 1})`.", + stacklevel=2, + ) + return + + app.exec_() diff --git a/cellacdc/_get_app_palette.py b/cellacdc/_get_app_palette.py index a6e3b543b..281b1eb1c 100644 --- a/cellacdc/_get_app_palette.py +++ b/cellacdc/_get_app_palette.py @@ -1,17 +1,27 @@ from qtpy import QtGui, QtWidgets, QtCore -print(f'Using Qt version {QtCore.__version__}') +print(f"Using Qt version {QtCore.__version__}") from pprint import pprint app = QtWidgets.QApplication([]) -app.setStyle(QtWidgets.QStyleFactory.create('Fusion')) +app.setStyle(QtWidgets.QStyleFactory.create("Fusion")) app.setPalette(app.style().standardPalette()) roles = ( - 'Window', 'WindowText', 'Base', 'AlternateBase', 'ToolTipBase', - 'ToolTipText', 'Text', 'Button', 'ButtonText', 'BrightText', - 'Link', 'Highlight', 'HighlightedText' + "Window", + "WindowText", + "Base", + "AlternateBase", + "ToolTipBase", + "ToolTipText", + "Text", + "Button", + "ButtonText", + "BrightText", + "Link", + "Highlight", + "HighlightedText", ) colors = {} @@ -21,4 +31,4 @@ rgba = app.palette().color(colorRole).getRgb() colors[role] = rgba -pprint(colors, sort_dicts=False) \ No newline at end of file +pprint(colors, sort_dicts=False) diff --git a/cellacdc/_main.py b/cellacdc/_main.py index f220a3b91..04b372e94 100644 --- a/cellacdc/_main.py +++ b/cellacdc/_main.py @@ -12,48 +12,80 @@ from qtpy import QtCore, QtWidgets from qtpy.QtWidgets import ( - QMainWindow, QVBoxLayout, QPushButton, QLabel, QAction, - QMenu, QHBoxLayout, QFileDialog, QGroupBox, QCheckBox, QSplashScreen + QMainWindow, + QVBoxLayout, + QPushButton, + QLabel, + QAction, + QMenu, + QHBoxLayout, + QFileDialog, + QGroupBox, + QCheckBox, + QSplashScreen, ) from qtpy.QtCore import ( - Qt, QProcess, Signal, Slot, QTimer, QSize, - QSettings, QUrl, QCoreApplication + Qt, + QProcess, + Signal, + Slot, + QTimer, + QSize, + QSettings, + QUrl, + QCoreApplication, ) from qtpy.QtGui import ( - QFontDatabase, QIcon, QDesktopServices, QFont, QColor, - QPalette, QGuiApplication, QPixmap + QFontDatabase, + QIcon, + QDesktopServices, + QFont, + QColor, + QPalette, + QGuiApplication, + QPixmap, ) import qtpy.compat from . import ( - dataPrep, segm, gui, dataStruct, load, help, myutils, - cite_url, html_utils, widgets, apps, dataReStruct + dataPrep, + segm, + gui, + dataStruct, + load, + help, + utils, + cite_url, + html_utils, + widgets, + apps, + dataReStruct, ) from .help import about -from .utils import concat as utilsConcat -from .utils import convert as utilsConvert -from .utils import rename as utilsRename -from .utils import align as utilsAlign -from .utils import compute as utilsCompute -from .utils import repeat as utilsRepeat -from .utils import toImageJroi as utilsToImageJroi -from .utils.resize import util as utilsResizePositionsUtil -from .utils import fromImageJroiToSegm as utilsFromImageJroi -from .utils import toObjCoords as utilsToObjCoords -from .utils import acdcToSymDiv as utilsSymDiv -from .utils import trackSubCellObjects as utilsTrackSubCell -from .utils import createConnected3Dsegm as utilsConnected3Dsegm -from .utils import countObjects as utilsCountObjectsInSegm -from .utils import fucciPreprocess as utilsFucciPreprocess -from .utils import customPreprocess as utilsCustomPreprocess -from .utils import combineChannels as utilsCombineChannels -from .utils import filterObjFromCoordsTable as utilsFilterObjsFromTable -from .utils import stack2Dinto3Dsegm as utilsStack2Dto3D -from .utils import computeMultiChannel as utilsComputeMultiCh -from .utils import applyTrackFromTable as utilsApplyTrackFromTab -from .utils import applyTrackFromTrackMateXML as utilsApplyTrackFromTrackMate -from .utils import fillHolesInSegm -from .utils import generateMothBudTotalTable as utilsGenerateMothBudTotTable +from .tools import concat as utilsConcat +from .tools import convert as utilsConvert +from .tools import rename as utilsRename +from .tools import align as utilsAlign +from .tools import compute as utilsCompute +from .tools import repeat as utilsRepeat +from .tools import toImageJroi as utilsToImageJroi +from .tools.resize import util as utilsResizePositionsUtil +from .tools import fromImageJroiToSegm as utilsFromImageJroi +from .tools import toObjCoords as utilsToObjCoords +from .tools import acdcToSymDiv as utilsSymDiv +from .tools import trackSubCellObjects as utilsTrackSubCell +from .tools import createConnected3Dsegm as utilsConnected3Dsegm +from .tools import countObjects as utilsCountObjectsInSegm +from .tools import fucciPreprocess as utilsFucciPreprocess +from .tools import customPreprocess as utilsCustomPreprocess +from .tools import combineChannels as utilsCombineChannels +from .tools import filterObjFromCoordsTable as utilsFilterObjsFromTable +from .tools import stack2Dinto3Dsegm as utilsStack2Dto3D +from .tools import computeMultiChannel as utilsComputeMultiCh +from .tools import applyTrackFromTable as utilsApplyTrackFromTab +from .tools import applyTrackFromTrackMateXML as utilsApplyTrackFromTrackMate +from .tools import fillHolesInSegm +from .tools import generateMothBudTotalTable as utilsGenerateMothBudTotTable from .info import utilsInfo from . import is_win, is_linux, settings_folderpath, issues_url, is_mac from . import settings_csv_path @@ -63,15 +95,14 @@ from . import exception_handler from . import user_profile_path from . import cellacdc_path -from . config import parser_args +from .config import parser_args try: import spotmax from spotmax import _run as spotmaxRun + spotmax_filepath = os.path.dirname(os.path.abspath(spotmax.__file__)) - spotmax_logo_path = os.path.join( - spotmax_filepath, 'resources', 'spotMAX_logo.svg' - ) + spotmax_logo_path = os.path.join(spotmax_filepath, "resources", "spotMAX_logo.svg") SPOTMAX_INSTALLED = True except Exception as e: # traceback.print_exc() @@ -79,6 +110,7 @@ traceback.print_exc() SPOTMAX_INSTALLED = False + def restart(): QCoreApplication.quit() process = QtCore.QProcess() @@ -86,7 +118,8 @@ def restart(): # process.setStandardOutputFile(QProcess.nullDevice()) status = process.startDetached() if status: - print('Restarting Cell-ACDC...') + print("Restarting Cell-ACDC...") + class mainWin(QMainWindow): def __init__(self, app, parent=None): @@ -95,22 +128,20 @@ def __init__(self, app, parent=None): scheme = self.getColorScheme() self.welcomeGuide = None self._do_restart = False - + super().__init__(parent) self.setWindowTitle("Cell-ACDC") self.setWindowIcon(QIcon(":icon.ico")) self.setAcceptDrops(True) - + self.checkUserDataFolderPath = True - logger, logs_path, log_path, log_filename = myutils.setupLogger( - module='main' - ) + logger, logs_path, log_path, log_filename = utils.setupLogger(module="main") self.logger = logger self.log_path = log_path self.log_filename = log_filename - self.logs_path = logs_path - + self.logs_path = logs_path + if not is_linux: self.loadFonts() @@ -125,19 +156,23 @@ def __init__(self, app, parent=None): mainLayout = QVBoxLayout() mainLayout.addStretch() - welcomeLabel = QLabel(html_utils.paragraph( - 'Welcome to Cell-ACDC!', - center=True, font_size='18px' - )) + welcomeLabel = QLabel( + html_utils.paragraph( + "Welcome to Cell-ACDC!", center=True, font_size="18px" + ) + ) # padding: top, left, bottom, right welcomeLabel.setStyleSheet("padding:0px 0px 5px 0px;") mainLayout.addWidget(welcomeLabel) - label = QLabel(html_utils.paragraph( - 'Press any of the following buttons
' - 'to launch the respective module', - center=True, font_size='14px' - )) + label = QLabel( + html_utils.paragraph( + "Press any of the following buttons
" + "to launch the respective module", + center=True, + font_size="14px", + ) + ) # padding: top, left, bottom, right label.setStyleSheet("padding:0px 0px 10px 0px;") mainLayout.addWidget(label) @@ -145,16 +180,16 @@ def __init__(self, app, parent=None): mainLayout.addStretch() iconSize = 26 - + modulesButtonsGroupBox = QGroupBox() - modulesButtonsGroupBox.setTitle('Modules') + modulesButtonsGroupBox.setTitle("Modules") modulesButtonsGroupBoxLayout = QVBoxLayout() modulesButtonsGroupBox.setLayout(modulesButtonsGroupBoxLayout) - + dataStructButton = widgets.setPushButton( - ' 0. Create data structure from microscopy/image file(s)... ' + " 0. Create data structure from microscopy/image file(s)... " ) - dataStructButton.setIconSize(QSize(iconSize,iconSize)) + dataStructButton.setIconSize(QSize(iconSize, iconSize)) font = QFont() font.setPixelSize(13) dataStructButton.setFont(font) @@ -162,9 +197,9 @@ def __init__(self, app, parent=None): self.dataStructButton = dataStructButton modulesButtonsGroupBoxLayout.addWidget(dataStructButton) - dataPrepButton = QPushButton(' 1. Launch data prep module...') - dataPrepButton.setIcon(QIcon(':prep.svg')) - dataPrepButton.setIconSize(QSize(iconSize,iconSize)) + dataPrepButton = QPushButton(" 1. Launch data prep module...") + dataPrepButton.setIcon(QIcon(":prep.svg")) + dataPrepButton.setIconSize(QSize(iconSize, iconSize)) font = QFont() font.setPixelSize(13) dataPrepButton.setFont(font) @@ -172,16 +207,16 @@ def __init__(self, app, parent=None): self.dataPrepButton = dataPrepButton modulesButtonsGroupBoxLayout.addWidget(dataPrepButton) - segmButton = QPushButton(' 2. Launch segmentation module...') - segmButton.setIcon(QIcon(':segment.svg')) - segmButton.setIconSize(QSize(iconSize,iconSize)) + segmButton = QPushButton(" 2. Launch segmentation module...") + segmButton.setIcon(QIcon(":segment.svg")) + segmButton.setIconSize(QSize(iconSize, iconSize)) segmButton.setFont(font) segmButton.clicked.connect(self.launchSegm) self.segmButton = segmButton modulesButtonsGroupBoxLayout.addWidget(segmButton) - guiButton = QPushButton(' 3. Launch GUI...') - guiButton.setIcon(QIcon(':logo.svg')) + guiButton = QPushButton(" 3. Launch GUI...") + guiButton.setIcon(QIcon(":logo.svg")) guiButton.setIconSize(QSize(iconSize, iconSize)) guiButton.setFont(font) guiButton.clicked.connect(self.launchGui) @@ -189,25 +224,25 @@ def __init__(self, app, parent=None): modulesButtonsGroupBoxLayout.addWidget(guiButton) if SPOTMAX_INSTALLED: - spotmaxButton = QPushButton(' 4. Launch SpotMAX...') + spotmaxButton = QPushButton(" 4. Launch SpotMAX...") spotmaxButton.setIcon(QIcon(spotmax_logo_path)) - spotmaxButton.setIconSize(QSize(iconSize,iconSize)) + spotmaxButton.setIconSize(QSize(iconSize, iconSize)) spotmaxButton.setFont(font) self.spotmaxButton = spotmaxButton spotmaxButton.clicked.connect(self.launchSpotmaxGui) modulesButtonsGroupBoxLayout.addWidget(spotmaxButton) - + mainLayout.addWidget(modulesButtonsGroupBox) mainLayout.addSpacing(10) - + controlsButtonsGroupBox = QGroupBox() - controlsButtonsGroupBox.setTitle('Controls') + controlsButtonsGroupBox.setTitle("Controls") controlsButtonsGroupBoxLayout = QVBoxLayout() controlsButtonsGroupBox.setLayout(controlsButtonsGroupBoxLayout) - - showAllWindowsButton = QPushButton(' Restore open windows') - showAllWindowsButton.setIcon(QIcon(':eye.svg')) - showAllWindowsButton.setIconSize(QSize(iconSize,iconSize)) + + showAllWindowsButton = QPushButton(" Restore open windows") + showAllWindowsButton.setIcon(QIcon(":eye.svg")) + showAllWindowsButton.setIconSize(QSize(iconSize, iconSize)) showAllWindowsButton.setFont(font) self.showAllWindowsButton = showAllWindowsButton showAllWindowsButton.clicked.connect(self.showAllWindows) @@ -217,20 +252,17 @@ def __init__(self, app, parent=None): font.setPixelSize(13) closeLayout = QHBoxLayout() - restartButton = QPushButton( - QIcon(":reload.svg"), - ' Restart Cell-ACDC' - ) + restartButton = QPushButton(QIcon(":reload.svg"), " Restart Cell-ACDC") restartButton.setFont(font) restartButton.setIconSize(QSize(iconSize, iconSize)) restartButton.clicked.connect(self.close) self.restartButton = restartButton self.restartButton.hide() closeLayout.addWidget(restartButton) - + closeLayout.addWidget(showAllWindowsButton) - closeButton = QPushButton(QIcon(":close.svg"), ' Close application') + closeButton = QPushButton(QIcon(":close.svg"), " Close application") closeButton.setIconSize(QSize(iconSize, iconSize)) self.closeButton = closeButton # closeButton.setIconSize(QSize(24,24)) @@ -239,9 +271,9 @@ def __init__(self, app, parent=None): closeLayout.addWidget(closeButton) controlsButtonsGroupBoxLayout.addLayout(closeLayout) - + mainLayout.addWidget(controlsButtonsGroupBox) - + mainContainer.setLayout(mainLayout) self.guiWins = [] @@ -250,81 +282,87 @@ def __init__(self, app, parent=None): self._version = None self.progressWin = None self.forceClose = False - + def addStatusBar(self, scheme): self.statusbar = self.statusBar() # Permanent widget - label = QLabel('Dark mode') + label = QLabel("Dark mode") widget = QtWidgets.QWidget() layout = QHBoxLayout() widget.setLayout(layout) layout.addWidget(label) - self.darkModeToggle = widgets.Toggle(label_text='Dark mode') + self.darkModeToggle = widgets.Toggle(label_text="Dark mode") self.darkModeToggle.ignoreEvent = False - self.darkModeToggle.warnMessageBox = True - if scheme == 'dark': + self.darkModeToggle.warnMessageBox = True + if scheme == "dark": self.darkModeToggle.ignoreEvent = True self.darkModeToggle.setChecked(True) self.darkModeToggle.toggled.connect(self.onDarkModeToggled) layout.addWidget(self.darkModeToggle) self.statusBarLayout = layout self.statusbar.addWidget(widget) - + def getColorScheme(self): from ._palettes import get_color_scheme + return get_color_scheme() - + def onDarkModeToggled(self, checked): if self.darkModeToggle.ignoreEvent: self.darkModeToggle.ignoreEvent = False return from ._palettes import getPaletteColorScheme - scheme = 'dark' if checked else 'light' + + scheme = "dark" if checked else "light" load.rename_qrc_resources_file(scheme) if not os.path.exists(settings_csv_path): - df_settings = pd.DataFrame( - {'setting': [], 'value': []}).set_index('setting') + df_settings = pd.DataFrame({"setting": [], "value": []}).set_index( + "setting" + ) else: - df_settings = pd.read_csv(settings_csv_path, index_col='setting') - df_settings.at['colorScheme', 'value'] = scheme + df_settings = pd.read_csv(settings_csv_path, index_col="setting") + df_settings.at["colorScheme", "value"] = scheme df_settings.to_csv(settings_csv_path) if self.darkModeToggle.warnMessageBox: _warnings.warnRestartCellACDCcolorModeToggled( - scheme, app_name='Cell-ACDC', parent=self + scheme, app_name="Cell-ACDC", parent=self ) self.darkModeToggle.warnMessageBox = True self.setStatusBarRestartCellACDC() self.darkModeToggle.setDisabled(True) - + def setStatusBarRestartCellACDC(self): - self.statusBarLayout.addWidget(QLabel(html_utils.paragraph( - 'Restart Cell-ACDC for the change to take effect', - font_color='red' - ))) - + self.statusBarLayout.addWidget( + QLabel( + html_utils.paragraph( + "Restart Cell-ACDC for the change to take effect", + font_color="red", + ) + ) + ) + def checkConfigFiles(self): - print('Loading configuration files...') + print("Loading configuration files...") paths_to_check = [ - gui.favourite_func_metrics_csv_path, - # gui.custom_annot_path, - gui.shortcut_filepath, - os.path.join(settings_folderpath, 'recentPaths.csv'), - load.last_entries_metadata_path, - load.additional_metadata_path, - load.last_selected_measurements_ini_path + gui.favourite_func_metrics_csv_path, + # gui.custom_annot_path, + gui.shortcut_filepath, + os.path.join(settings_folderpath, "recentPaths.csv"), + load.last_entries_metadata_path, + load.additional_metadata_path, + load.last_selected_measurements_ini_path, ] for path in paths_to_check: load.remove_duplicates_file(path) - - def dragEnterEvent(self, event) -> None: - ... - + + def dragEnterEvent(self, event) -> None: ... + def log(self, text): self.logger.info(text) - + if self.progressWin is None: return - + self.progressWin.log(text) def setVersion(self, version): @@ -352,18 +390,19 @@ def loadFonts(self): def launchWelcomeGuide(self, checked=False): if not os.path.exists(settings_csv_path): - idx = ['showWelcomeGuide'] - values = ['Yes'] + idx = ["showWelcomeGuide"] + values = ["Yes"] self.df_settings = pd.DataFrame( - {'setting': idx, 'value': values}).set_index('setting') + {"setting": idx, "value": values} + ).set_index("setting") self.df_settings.to_csv(settings_csv_path) - self.df_settings = pd.read_csv(settings_csv_path, index_col='setting') - if 'showWelcomeGuide' not in self.df_settings.index: - self.df_settings.at['showWelcomeGuide', 'value'] = 'Yes' + self.df_settings = pd.read_csv(settings_csv_path, index_col="setting") + if "showWelcomeGuide" not in self.df_settings.index: + self.df_settings.at["showWelcomeGuide", "value"] = "Yes" self.df_settings.to_csv(settings_csv_path) show = ( - self.df_settings.at['showWelcomeGuide', 'value'] == 'Yes' + self.df_settings.at["showWelcomeGuide", "value"] == "Yes" or self.sender() is not None ) if not show: @@ -374,7 +413,7 @@ def launchWelcomeGuide(self, checked=False): self.welcomeGuide.showPage(self.welcomeGuide.welcomeItem) def setColorsAndText(self): - self.moduleLaunchedColor = '#f1dd00' + self.moduleLaunchedColor = "#f1dd00" self.moduleLaunchedQColor = QColor(self.moduleLaunchedColor) defaultColor = self.guiButton.palette().button().color().name() self.defaultButtonPalette = self.guiButton.palette() @@ -384,12 +423,8 @@ def setColorsAndText(self): self.defaultTextDataPrepButton = self.dataPrepButton.text() self.defaultTextSegmButton = self.segmButton.text() self.moduleLaunchedPalette = self.guiButton.palette() - self.moduleLaunchedPalette.setColor( - QPalette.Button, self.moduleLaunchedQColor - ) - self.moduleLaunchedPalette.setColor( - QPalette.ButtonText, QColor(0, 0, 0) - ) + self.moduleLaunchedPalette.setColor(QPalette.Button, self.moduleLaunchedQColor) + self.moduleLaunchedPalette.setColor(QPalette.ButtonText, QColor(0, 0, 0)) def createMenuBar(self): menuBar = self.menuBar() @@ -397,12 +432,12 @@ def createMenuBar(self): self.recentPathsMenu = QMenu("&Recent paths", self) # On macOS an empty menu would not appear --> add dummy action - self.recentPathsMenu.addAction('dummy macos') + self.recentPathsMenu.addAction("dummy macos") menuBar.addMenu(self.recentPathsMenu) utilsMenu = menuBar.addMenu("&Utilities") - convertMenu = utilsMenu.addMenu('Convert file formats') + convertMenu = utilsMenu.addMenu("Convert file formats") convertMenu.addAction(self.npzToNpyAction) convertMenu.addAction(self.npzToTiffAction) convertMenu.addAction(self.TiffToNpzAction) @@ -411,35 +446,33 @@ def createMenuBar(self): convertMenu.addAction(self.fromImageJroiAction) convertMenu.addAction(self.toObjsCoordsAction) - segmMenu = utilsMenu.addMenu('Segmentation') + segmMenu = utilsMenu.addMenu("Segmentation") segmMenu.addAction(self.createConnected3Dsegm) segmMenu.addAction(self.stack2Dto3DsegmAction) segmMenu.addAction(self.filterObjsFromTableAction) segmMenu.addAction(self.fillHolesInSegmAction) - trackingMenu = utilsMenu.addMenu('Tracking and lineage') + trackingMenu = utilsMenu.addMenu("Tracking and lineage") trackingMenu.addAction(self.trackSubCellFeaturesAction) trackingMenu.addAction(self.applyTrackingFromTableAction) trackingMenu.addAction(self.applyTrackingFromTrackMateXMLAction) - trackingMenu.addAction(self.toSymDivAction) - + trackingMenu.addAction(self.toSymDivAction) + self.trackingMenu = trackingMenu - measurementsMenu = utilsMenu.addMenu('Measurements') + measurementsMenu = utilsMenu.addMenu("Measurements") measurementsMenu.addAction(self.calcMetricsAcdcDf) measurementsMenu.addAction(self.countObjectsInSegmAction) - measurementsMenu.addAction(self.combineMetricsMultiChannelAction) - measurementsMenu.addAction(self.generateMothBudTotTableAction) - - concatMenu = utilsMenu.addMenu('Concatenate') - concatMenu.addAction(self.concatAcdcDfsAction) + measurementsMenu.addAction(self.combineMetricsMultiChannelAction) + measurementsMenu.addAction(self.generateMothBudTotTableAction) + + concatMenu = utilsMenu.addMenu("Concatenate") + concatMenu.addAction(self.concatAcdcDfsAction) if SPOTMAX_INSTALLED: - concatMenu.addAction(self.concatSpotmaxDfsAction) + concatMenu.addAction(self.concatSpotmaxDfsAction) + + dataPrepMenu = utilsMenu.addMenu("Image and segmentation files preprocessing") - dataPrepMenu = utilsMenu.addMenu( - 'Image and segmentation files preprocessing' - ) - dataPrepMenu.addAction(self.batchConverterAction) dataPrepMenu.addAction(self.repeatDataPrepAction) dataPrepMenu.addAction(self.alignAction) @@ -447,17 +480,17 @@ def createMenuBar(self): dataPrepMenu.addAction(self.fucciPreprocessAction) dataPrepMenu.addAction(self.customPreprocessAction) dataPrepMenu.addAction(self.combineChannelsAction) - + utilsMenu.addAction(self.renameAction) self.utilsMenu = utilsMenu utilsMenu.addSeparator() - utilsHelpAction = utilsMenu.addAction('Help...') + utilsHelpAction = utilsMenu.addAction("Help...") utilsHelpAction.triggered.connect(self.showUtilsHelp) - + menuBar.addMenu(utilsMenu) - + self.settingsMenu = QMenu("&Settings", self) self.settingsMenu.addAction(self.changeUserProfileFolderPathAction) self.settingsMenu.addAction(self.openUserProfileFolderAction) @@ -485,12 +518,11 @@ def createMenuBar(self): if SPOTMAX_INSTALLED: helpMenu.addAction(self.updateSPOTMAXAction) - utilsMenu.addAction(self.debugAction) - self.debugAction.setVisible(parser_args['debug']) + self.debugAction.setVisible(parser_args["debug"]) menuBar.addMenu(helpMenu) - + def showUtilsHelp(self): treeInfo = {} for action in self.utilsMenu.actions(): @@ -502,34 +534,35 @@ def showUtilsHelp(self): ) else: treeInfo = self._addActionToTree(action, treeInfo) - + self.utilsHelpWin = apps.TreeSelectorDialog( - title='Utilities help', + title="Utilities help", infoTxt="Double click on a utility's name to get help about it
", - parent=self, multiSelection=False, widthFactor=2, heightFactor=1.5 + parent=self, + multiSelection=False, + widthFactor=2, + heightFactor=1.5, ) self.utilsHelpWin.addTree(treeInfo) self.utilsHelpWin.sigItemDoubleClicked.connect(self._showUtilHelp) self.utilsHelpWin.exec_() - + def resetUserProfileFolderPath(self): from . import user_profile_path, user_home_path - + if os.path.samefile(user_profile_path, user_home_path): msg = widgets.myMessageBox() txt = html_utils.paragraph( - 'The user profile data is already in the default folder.' + "The user profile data is already in the default folder." ) - msg.warning(self, 'Reset user profile data', txt) + msg.warning(self, "Reset user profile data", txt) return - + acdc_folders = load.get_all_acdc_folders(user_profile_path) - acdc_folders_format = [ - f'   {folder}' for folder in acdc_folders - ] - acdc_folders_format = '
'.join(acdc_folders_format) - - txt = (f""" + acdc_folders_format = [f"   {folder}" for folder in acdc_folders] + acdc_folders_format = "
".join(acdc_folders_format) + + txt = f""" Current user profile path:

{user_profile_path}

The user profile contains the following Cell-ACDC folders:

@@ -537,86 +570,87 @@ def resetUserProfileFolderPath(self): After clicking "Ok" you Cell-ACDC will migrate the user profile data to the following folder:

{user_home_path}.
- """) - + """ + txt = html_utils.paragraph(txt) - + msg = widgets.myMessageBox(wrapText=False) msg.information( - self, 'Reset default user profile folder path', txt, - buttonsTexts=('Cancel', 'Ok') + self, + "Reset default user profile folder path", + txt, + buttonsTexts=("Cancel", "Ok"), ) if msg.cancel: - self.logger.info('Resetting user profile folder path cancelled.') + self.logger.info("Resetting user profile folder path cancelled.") return - - + new_user_profile_path = user_home_path - + self.startMigrateUserProfileWorker( user_profile_path, new_user_profile_path, acdc_folders ) - - def changeUserProfileFolderPath(self): + + def changeUserProfileFolderPath(self): acdc_folders = load.get_all_acdc_folders(user_profile_path) - acdc_folders_format = [ - f'   {folder}' for folder in acdc_folders - ] - acdc_folders_format = '
'.join(acdc_folders_format) - - txt = (f""" + acdc_folders_format = [f"   {folder}" for folder in acdc_folders] + acdc_folders_format = "
".join(acdc_folders_format) + + txt = f""" Current user profile path:

{user_profile_path}

The user profile contains the following Cell-ACDC folders:

{acdc_folders_format}

After clicking "Ok" you will be asked to select the folder where you want to migrate the user profile data.
- """) - + """ + txt = html_utils.paragraph(txt) - + msg = widgets.myMessageBox(wrapText=False) msg.information( - self, 'Change user profile folder path', txt, - buttonsTexts=('Cancel', 'Ok') + self, "Change user profile folder path", txt, buttonsTexts=("Cancel", "Ok") ) if msg.cancel: - self.logger.info('Changing user profile folder path cancelled.') + self.logger.info("Changing user profile folder path cancelled.") return from qtpy.compat import getexistingdirectory + new_user_profile_path = getexistingdirectory( parent=self, - caption='Select folder for user profile data', - basedir=user_profile_path + caption="Select folder for user profile data", + basedir=user_profile_path, ) if not new_user_profile_path: - self.logger.info('Changing user profile folder path cancelled.') + self.logger.info("Changing user profile folder path cancelled.") return - + if os.path.samefile(user_profile_path, new_user_profile_path): msg = widgets.myMessageBox() txt = html_utils.paragraph( - 'The user profile data is already in the selected folder.' + "The user profile data is already in the selected folder." ) - msg.warning(self, 'Change user profile data folder', txt) + msg.warning(self, "Change user profile data folder", txt) return - + self.startMigrateUserProfileWorker( user_profile_path, new_user_profile_path, acdc_folders ) - + def startMigrateUserProfileWorker(self, src_path, dst_path, acdc_folders): self.progressWin = apps.QDialogWorkerProgress( - title='Migrate user profile data', parent=self, - pbarDesc='Migrating user profile data...', - showInnerPbar=True + title="Migrate user profile data", + parent=self, + pbarDesc="Migrating user profile data...", + showInnerPbar=True, ) self.progressWin.sigClosed.connect(self.progressWinClosed) self.progressWin.show(self.app) - + from . import workers - self.workerName = 'Migrating user profile data' + + self.workerName = "Migrating user profile data" self._thread = QtCore.QThread() self.migrateWorker = workers.MigrateUserProfileWorker( src_path, dst_path, acdc_folders @@ -625,27 +659,21 @@ def startMigrateUserProfileWorker(self, src_path, dst_path, acdc_folders): self.migrateWorker.finished.connect(self._thread.quit) self.migrateWorker.finished.connect(self.migrateWorker.deleteLater) self._thread.finished.connect(self._thread.deleteLater) - + self.migrateWorker.progress.connect(self.workerProgress) self.migrateWorker.critical.connect(self.workerCritical) self.migrateWorker.finished.connect(self.migrateWorkerFinished) - - self.migrateWorker.signals.initProgressBar.connect( - self.workerInitProgressbar - ) - self.migrateWorker.signals.progressBar.connect( - self.workerUpdateProgressbar - ) - self.migrateWorker.signals.sigInitInnerPbar.connect( - self.workerInitInnerPbar - ) + + self.migrateWorker.signals.initProgressBar.connect(self.workerInitProgressbar) + self.migrateWorker.signals.progressBar.connect(self.workerUpdateProgressbar) + self.migrateWorker.signals.sigInitInnerPbar.connect(self.workerInitInnerPbar) self.migrateWorker.signals.sigUpdateInnerPbar.connect( self.workerUpdateInnerPbar ) - + self._thread.started.connect(self.migrateWorker.run) self._thread.start() - + def workerInitProgressbar(self, totalIter): self.progressWin.mainPbar.setValue(0) if totalIter == 1: @@ -654,16 +682,16 @@ def workerInitProgressbar(self, totalIter): def workerUpdateProgressbar(self, step): self.progressWin.mainPbar.update(step) - + def workerInitInnerPbar(self, totalIter): self.progressWin.innerPbar.setValue(0) if totalIter == 1: totalIter = 0 self.progressWin.innerPbar.setMaximum(totalIter) - + def workerUpdateInnerPbar(self, step): self.progressWin.innerPbar.update(step) - + def migrateWorkerFinished(self, worker): self.workerFinished() msg = widgets.myMessageBox(wrapText=False) @@ -671,27 +699,34 @@ def migrateWorkerFinished(self, worker): To make this change effective, please restart Cell-ACDC.

Thanks! """) - self.statusBarLayout.addWidget(QLabel(html_utils.paragraph( - 'Restart Cell-ACDC for the change to take effect', - font_color='red' - ))) - msg.information(self, 'Restart Cell-ACDC', txt) - + self.statusBarLayout.addWidget( + QLabel( + html_utils.paragraph( + "Restart Cell-ACDC for the change to take effect", + font_color="red", + ) + ) + ) + msg.information(self, "Restart Cell-ACDC", txt) + def _showUtilHelp(self, item): if item.parent() is None: return utilityName = item.text(0) infoText = html_utils.paragraph(utilsInfo[utilityName]) - runUtilityButton = widgets.playPushButton('Run utility...') + runUtilityButton = widgets.playPushButton("Run utility...") msg = widgets.myMessageBox(showCentered=False, wrapText=False) msg.information( - self.utilsHelpWin, f'"{utilityName}" help', infoText, - buttonsTexts=(runUtilityButton, 'Close'), showDialog=False + self.utilsHelpWin, + f'"{utilityName}" help', + infoText, + buttonsTexts=(runUtilityButton, "Close"), + showDialog=False, ) runUtilityButton.utilityName = utilityName runUtilityButton.clicked.connect(self._runUtility) msg.exec_() - + def _runUtility(self): self.utilsHelpWin.ok_cb() utilityName = self.sender().utilityName @@ -708,15 +743,15 @@ def _runUtility(self): else: action.trigger() break - + def _addActionToTree(self, action, treeInfo, parentMenu=None): if action.isSeparator(): return treeInfo - + text = action.text() if text not in utilsInfo: return treeInfo - + if parentMenu is None: treeInfo[text] = [] elif parentMenu.title() not in treeInfo: @@ -726,115 +761,109 @@ def _addActionToTree(self, action, treeInfo, parentMenu=None): return treeInfo def createActions(self): - self.changeUserProfileFolderPathAction = QAction( - 'Change user profile path...' - ) + self.changeUserProfileFolderPathAction = QAction("Change user profile path...") self.resetUserProfileFolderPathAction = QAction( - 'Reset default user profile path' + "Reset default user profile path" ) - self.npzToNpyAction = QAction('Convert .npz file(s) to .npy...') - self.npzToTiffAction = QAction('Convert .npz file(s) to .tif...') - self.TiffToNpzAction = QAction('Convert .tif file(s) to _segm.npz...') - self.h5ToNpzAction = QAction('Convert .h5 file(s) to _segm.npz...') + self.npzToNpyAction = QAction("Convert .npz file(s) to .npy...") + self.npzToTiffAction = QAction("Convert .npz file(s) to .tif...") + self.TiffToNpzAction = QAction("Convert .tif file(s) to _segm.npz...") + self.h5ToNpzAction = QAction("Convert .h5 file(s) to _segm.npz...") self.toImageJroiAction = QAction( - 'Convert Cell-ACDC segmentation file(s) (segm.npz) to ImageJ ROIs...' + "Convert Cell-ACDC segmentation file(s) (segm.npz) to ImageJ ROIs..." ) self.fromImageJroiAction = QAction( - 'Convert ImageJ ROIs to Cell-ACDC segmentation file(s) (segm.npz)...' + "Convert ImageJ ROIs to Cell-ACDC segmentation file(s) (segm.npz)..." ) self.toObjsCoordsAction = QAction( - 'Convert .npz segmentation file(s) to object coordinates (CSV)...' + "Convert .npz segmentation file(s) to object coordinates (CSV)..." ) - + self.fucciPreprocessAction = QAction( - 'Combine FUCCI channels and enhance nuclear signal...' - ) - + "Combine FUCCI channels and enhance nuclear signal..." + ) + self.customPreprocessAction = QAction( - 'Setup and run custom image preprocessing...' + "Setup and run custom image preprocessing..." ) self.combineChannelsAction = QAction( - 'Combine and manipulate channels and/or segmentation files...' + "Combine and manipulate channels and/or segmentation files..." ) - + self.countObjectsInSegmAction = QAction( - 'Count objects in segmentation mask and save to CSV file...' + "Count objects in segmentation mask and save to CSV file..." ) - + self.createConnected3Dsegm = QAction( - 'Create connected 3D segmentation mask from z-slices segmentation...' - ) - self.fillHolesInSegmAction = QAction( - 'Fill holes in segmentation masks...' + "Create connected 3D segmentation mask from z-slices segmentation..." ) + self.fillHolesInSegmAction = QAction("Fill holes in segmentation masks...") self.filterObjsFromTableAction = QAction( - 'Filter segmented objects using a table of coordinates (e.g., centroids)...' - ) + "Filter segmented objects using a table of coordinates (e.g., centroids)..." + ) self.stack2Dto3DsegmAction = QAction( - 'Stack 2D segmentation objects into 3D objects...' - ) + "Stack 2D segmentation objects into 3D objects..." + ) self.trackSubCellFeaturesAction = QAction( - 'Track and/or count sub-cellular objects (assign same ID as the ' - 'cell they belong to)...' - ) + "Track and/or count sub-cellular objects (assign same ID as the " + "cell they belong to)..." + ) self.applyTrackingFromTableAction = QAction( - 'Apply tracking info from tabular data...' + "Apply tracking info from tabular data..." ) self.applyTrackingFromTrackMateXMLAction = QAction( - 'Apply tracking info from TrackMate XML file...' + "Apply tracking info from TrackMate XML file..." ) self.batchConverterAction = QAction( - 'Create required data structure from image files...' + "Create required data structure from image files..." ) self.repeatDataPrepAction = QAction( - 'Re-apply data prep steps to selected channels...' + "Re-apply data prep steps to selected channels..." ) # self.TiffToHDFAction = QAction('Convert .tif file(s) to .h5py...') self.concatAcdcDfsAction = QAction( - 'Concatenate acdc output tables from multiple Positions and experiments...' + "Concatenate acdc output tables from multiple Positions and experiments..." ) if SPOTMAX_INSTALLED: self.concatSpotmaxDfsAction = QAction( - 'Concatenate spotMAX output tables from multiple Positions and experiments...' + "Concatenate spotMAX output tables from multiple Positions and experiments..." ) self.calcMetricsAcdcDf = QAction( - 'Compute measurements for one or more experiments...' + "Compute measurements for one or more experiments..." ) self.combineMetricsMultiChannelAction = QAction( - 'Combine measurements from multiple segmentation files...' + "Combine measurements from multiple segmentation files..." ) self.generateMothBudTotTableAction = QAction( - 'Generate mothers, buds, and total cell table...' + "Generate mothers, buds, and total cell table..." ) self.toSymDivAction = QAction( - 'Add lineage tree table to one or more experiments...' + "Add lineage tree table to one or more experiments..." ) - self.renameAction = QAction('Rename files by appending additional text...') - self.alignAction = QAction('Align or revert alignment...') + self.renameAction = QAction("Rename files by appending additional text...") + self.alignAction = QAction("Align or revert alignment...") - self.arboretumAction = QAction( - 'View lineage tree in napari-arboretum...' - ) + self.arboretumAction = QAction("View lineage tree in napari-arboretum...") self.resizeImagesAction = QAction( - 'Resize images (downscale or upscale) in one or more experiments...' - ) - self.welcomeGuideAction = QAction('Welcome Guide') - self.userManualAction = QAction('User documentation...') - self.aboutAction = QAction('About Cell-ACDC') - self.citeAction = QAction('Cite us...') - self.contributeAction = QAction('Contribute...') - self.showLogsAction = QAction('Show log files...') - self.openUserProfileFolderAction = QAction('Open user profile path...') - self.openSettingsFolderAction = QAction('Open settings folder...') - self.updateACDCAction = QAction('Update Cell-ACDC...') - self.updateSPOTMAXAction = QAction('Update SpotMAX...') - + "Resize images (downscale or upscale) in one or more experiments..." + ) + self.welcomeGuideAction = QAction("Welcome Guide") + self.userManualAction = QAction("User documentation...") + self.aboutAction = QAction("About Cell-ACDC") + self.citeAction = QAction("Cite us...") + self.contributeAction = QAction("Contribute...") + self.showLogsAction = QAction("Show log files...") + self.openUserProfileFolderAction = QAction("Open user profile path...") + self.openSettingsFolderAction = QAction("Open settings folder...") + self.updateACDCAction = QAction("Update Cell-ACDC...") + self.updateSPOTMAXAction = QAction("Update SpotMAX...") + if SPOTMAX_INSTALLED: - self.aboutSmaxAction = QAction('About SpotMAX') - - self.debugAction = QAction('Daje de mac') + self.aboutSmaxAction = QAction("About SpotMAX") + + self.debugAction = QAction("Daje de mac") def connectActions(self): self.changeUserProfileFolderPathAction.triggered.connect( @@ -846,39 +875,26 @@ def connectActions(self): self.alignAction.triggered.connect(self.launchAlignUtil) self.concatAcdcDfsAction.triggered.connect(self.launchConcatUtil) if SPOTMAX_INSTALLED: - self.concatSpotmaxDfsAction.triggered.connect( - self.launchConcatSpotmaxUtil - ) + self.concatSpotmaxDfsAction.triggered.connect(self.launchConcatSpotmaxUtil) self.npzToNpyAction.triggered.connect(self.launchConvertFormatUtil) self.npzToTiffAction.triggered.connect(self.launchConvertFormatUtil) self.TiffToNpzAction.triggered.connect(self.launchConvertFormatUtil) self.h5ToNpzAction.triggered.connect(self.launchConvertFormatUtil) - self.fromImageJroiAction.triggered.connect( - self.launchFromImageJroiToSegmUtil - ) + self.fromImageJroiAction.triggered.connect(self.launchFromImageJroiToSegmUtil) self.resizeImagesAction.triggered.connect(self.launchResizeUtil) self.toImageJroiAction.triggered.connect(self.launchToImageJroiUtil) - self.toObjsCoordsAction.triggered.connect( - self.launchToObjectsCoordsUtil - ) - - self.fucciPreprocessAction.triggered.connect( - self.launchFucciPreprocessUtil - ) - - self.customPreprocessAction.triggered.connect( - self.launchCustomPreprocessUtil - ) + self.toObjsCoordsAction.triggered.connect(self.launchToObjectsCoordsUtil) + + self.fucciPreprocessAction.triggered.connect(self.launchFucciPreprocessUtil) + + self.customPreprocessAction.triggered.connect(self.launchCustomPreprocessUtil) + + self.combineChannelsAction.triggered.connect(self.launchCombineChannelsUtil) - self.combineChannelsAction.triggered.connect( - self.launchCombineChannelsUtil - ) - - self.countObjectsInSegmAction.triggered.connect( self.launchCountObjectsInSegmActionUtil ) - + self.createConnected3Dsegm.triggered.connect( self.launchConnected3DsegmActionUtil ) @@ -888,9 +904,7 @@ def connectActions(self): self.stack2Dto3DsegmAction.triggered.connect( self.launchStack2Dto3DsegmActionUtil ) - self.fillHolesInSegmAction.triggered.connect( - self.launchFillHolesActionUtil - ) + self.fillHolesInSegmAction.triggered.connect(self.launchFillHolesActionUtil) self.trackSubCellFeaturesAction.triggered.connect( self.launchTrackSubCellFeaturesUtil ) @@ -899,14 +913,10 @@ def connectActions(self): ) self.generateMothBudTotTableAction.triggered.connect( self.launchGenerateMothBudTotTableUtil - ) - - self.batchConverterAction.triggered.connect( - self.launchImageBatchConverter - ) - self.repeatDataPrepAction.triggered.connect( - self.launchRepeatDataPrep - ) + ) + + self.batchConverterAction.triggered.connect(self.launchImageBatchConverter) + self.repeatDataPrepAction.triggered.connect(self.launchRepeatDataPrep) self.welcomeGuideAction.triggered.connect(self.launchWelcomeGuide) self.toSymDivAction.triggered.connect(self.launchToSymDicUtil) self.calcMetricsAcdcDf.triggered.connect(self.launchCalcMetricsUtil) @@ -915,16 +925,14 @@ def connectActions(self): if SPOTMAX_INSTALLED: self.aboutSmaxAction.triggered.connect(self.showAboutSmax) - self.userManualAction.triggered.connect(myutils.browse_docs) + self.userManualAction.triggered.connect(utils.browse_docs) self.contributeAction.triggered.connect(self.showContribute) self.citeAction.triggered.connect( partial(QDesktopServices.openUrl, QUrl(cite_url)) ) self.recentPathsMenu.aboutToShow.connect(self.populateOpenRecent) self.showLogsAction.triggered.connect(self.showLogFiles) - self.openUserProfileFolderAction.triggered.connect( - self.openUserProfileFolder - ) + self.openUserProfileFolderAction.triggered.connect(self.openUserProfileFolder) self.openSettingsFolderAction.triggered.connect(self.openSettingsFolder) self.updateACDCAction.triggered.connect(self.launchUpdateACDC) if SPOTMAX_INSTALLED: @@ -935,35 +943,40 @@ def connectActions(self): self.applyTrackingFromTrackMateXMLAction.triggered.connect( self.launchApplyTrackingFromTrackMateXML ) - + self.debugAction.triggered.connect(self._debug) - + def openSettingsFolder(self): from . import settings_folderpath - myutils.showInExplorer(settings_folderpath) - + + utils.showInExplorer(settings_folderpath) + def openUserProfileFolder(self): from . import user_profile_path - myutils.showInExplorer(user_profile_path) - + + utils.showInExplorer(user_profile_path) + def showLogFiles(self): - logs_path = myutils.get_logs_path() - myutils.showInExplorer(logs_path) - + logs_path = utils.get_logs_path() + utils.showInExplorer(logs_path) + def launchUpdateSpotmax(self): - res = myutils.update_package(self, 'spotmax',) + res = utils.update_package( + self, + "spotmax", + ) if res: - self.showUpdateInfo('spotMAX') + self.showUpdateInfo("spotMAX") else: - self.showNoUpdateInfo('spotMAX') + self.showNoUpdateInfo("spotMAX") def launchUpdateACDC(self): - res = myutils.update_package(self, 'cellacdc') + res = utils.update_package(self, "cellacdc") if res: - self.showUpdateInfo('Cell-ACDC') + self.showUpdateInfo("Cell-ACDC") else: - self.showNoUpdateInfo('Cell-ACDC') - + self.showNoUpdateInfo("Cell-ACDC") + def showNoUpdateInfo(self, package_name): msg = widgets.myMessageBox() txt = html_utils.paragraph(f""" @@ -971,7 +984,7 @@ def showNoUpdateInfo(self, package_name): It is recommended to install git for a better update experience.
Download Git """) - msg.information(self, f'No update for {package_name} performed', txt) + msg.information(self, f"No update for {package_name} performed", txt) def showUpdateInfo(self, package_name): msg = widgets.myMessageBox() @@ -979,18 +992,18 @@ def showUpdateInfo(self, package_name): {package_name} has been updated.
Please restart the application for the changes to take effect. """) - msg.information(self, f'Update {package_name}', txt) + msg.information(self, f"Update {package_name}", txt) def populateOpenRecent(self): # Step 0. Remove the old options from the menu self.recentPathsMenu.clear() # Step 1. Read recent Paths - recentPaths_path = os.path.join(settings_folderpath, 'recentPaths.csv') + recentPaths_path = os.path.join(settings_folderpath, "recentPaths.csv") if os.path.exists(recentPaths_path): - df = pd.read_csv(recentPaths_path, index_col='index') - if 'opened_last_on' in df.columns: - df = df.sort_values('opened_last_on', ascending=False) - recentPaths = df['path'].to_list() + df = pd.read_csv(recentPaths_path, index_col="index") + if "opened_last_on" in df.columns: + df = df.sort_values("opened_last_on", ascending=False) + recentPaths = df["path"].to_list() else: recentPaths = [] # Step 2. Dynamically create the actions @@ -999,7 +1012,7 @@ def populateOpenRecent(self): if not os.path.exists(path): continue action = QAction(path, self) - action.triggered.connect(partial(myutils.showInExplorer, path)) + action.triggered.connect(partial(utils.showInExplorer, path)) actions.append(action) # Step 3. Add the actions to the menu self.recentPathsMenu.addActions(actions) @@ -1011,12 +1024,13 @@ def showContribute(self): def showAbout(self): self.aboutWin = about.QDialogAbout(parent=self) self.aboutWin.show() - + def showAboutSmax(self): from spotmax.dialogs import AboutSpotMAXDialog + win = AboutSpotMAXDialog(parent=self) win.exec_() - + def getSelectedPosPath(self, utilityName): msg = widgets.myMessageBox() txt = html_utils.paragraph(""" @@ -1024,27 +1038,23 @@ def getSelectedPosPath(self, utilityName): to select one position folder that contains timelapse data. """) - msg.information( - self, f'{utilityName}', txt, - buttonsTexts=('Cancel', 'Ok') - ) + msg.information(self, f"{utilityName}", txt, buttonsTexts=("Cancel", "Ok")) if msg.cancel: - print(f'{utilityName} aborted by the user.') + print(f"{utilityName} aborted by the user.") return - - mostRecentPath = myutils.getMostRecentPath() + + mostRecentPath = utils.getMostRecentPath() exp_path = QFileDialog.getExistingDirectory( - self, 'Select Position_n folder', - mostRecentPath + self, "Select Position_n folder", mostRecentPath ) if not exp_path: - print(f'{utilityName} aborted by the user.') + print(f"{utilityName} aborted by the user.") return - - myutils.addToRecentPaths(exp_path) + + utils.addToRecentPaths(exp_path) baseFolder = os.path.basename(exp_path) - isPosFolder = re.search(r'Position_(\d+)$', baseFolder) is not None - isImagesFolder = baseFolder == 'Images' + isPosFolder = re.search(r"Position_(\d+)$", baseFolder) is not None + isImagesFolder = baseFolder == "Images" if isImagesFolder: posPath = os.path.dirname(exp_path) posFolders = [os.path.basename(posPath)] @@ -1054,48 +1064,44 @@ def getSelectedPosPath(self, utilityName): posFolders = [os.path.basename(posPath)] exp_path = os.path.dirname(exp_path) else: - posFolders = myutils.get_pos_foldernames(exp_path) + posFolders = utils.get_pos_foldernames(exp_path) if not posFolders: msg = widgets.myMessageBox() - msg.addShowInFileManagerButton( - exp_path, txt='Show selected folder...' - ) + msg.addShowInFileManagerButton(exp_path, txt="Show selected folder...") _ls = "\n".join(os.listdir(exp_path)) - msg.setDetailedText(f'Files present in the folder:\n{_ls}') + msg.setDetailedText(f"Files present in the folder:\n{_ls}") txt = html_utils.paragraph(f""" The selected folder:

{exp_path}

does not contain any valid Position folders.
""") msg.warning( - self, 'Not valid folder', txt, - buttonsTexts=('Cancel', 'Try again') + self, "Not valid folder", txt, buttonsTexts=("Cancel", "Try again") ) if msg.cancel: - print(f'{utilityName} aborted by the user.') + print(f"{utilityName} aborted by the user.") return if len(posFolders) > 1: win = apps.QDialogCombobox( - 'Select position folder', posFolders, 'Select position folder', - 'Positions: ', parent=self + "Select position folder", + posFolders, + "Select position folder", + "Positions: ", + parent=self, ) win.exec_() posPath = os.path.join(exp_path, win.selectedItemText) else: posPath = os.path.join(exp_path, posFolders[0]) - + return posPath - def getSelectedExpPaths( - self, utilityName, - exp_folderpath=None, - custom_txt=None - ): + def getSelectedExpPaths(self, utilityName, exp_folderpath=None, custom_txt=None): # self._debug() - + if exp_folderpath is None: - self.logger.info('Asking to select experiment folders...') + self.logger.info("Asking to select experiment folders...") msg = widgets.myMessageBox() if custom_txt: txt = html_utils.paragraph(custom_txt) @@ -1106,54 +1112,51 @@ def getSelectedExpPaths( Next, you will be able to choose specific Positions from each selected experiment. """) - msg.information( - self, f'{utilityName}', txt, - buttonsTexts=('Cancel', 'Ok') - ) + msg.information(self, f"{utilityName}", txt, buttonsTexts=("Cancel", "Ok")) if msg.cancel: - self.logger.info(f'{utilityName} aborted by the user.') + self.logger.info(f"{utilityName} aborted by the user.") return - + expPaths = {} - mostRecentPath = myutils.getMostRecentPath() + mostRecentPath = utils.getMostRecentPath() warn_exp_already_selected = True while True: if exp_folderpath is None: exp_path = qtpy.compat.getexistingdirectory( - parent=self, - caption='Select experiment folder containing Position_n folders', + parent=self, + caption="Select experiment folder containing Position_n folders", basedir=mostRecentPath, # options=QFileDialog.DontUseNativeDialog ) if not exp_path: break - myutils.addToRecentPaths(exp_path) - else: + utils.addToRecentPaths(exp_path) + else: exp_path = exp_folderpath selected_path = exp_path baseFolder = os.path.basename(exp_path) - isPosFolder = myutils.is_pos_folderpath(exp_path) - isImagesFolder = baseFolder == 'Images' + isPosFolder = utils.is_pos_folderpath(exp_path) + isImagesFolder = baseFolder == "Images" if isImagesFolder: posPath = os.path.dirname(exp_path) posFolders = [os.path.basename(posPath)] exp_path = os.path.dirname(posPath) - selected_exp_paths = {exp_path:posFolders} + selected_exp_paths = {exp_path: posFolders} elif isPosFolder: posPath = exp_path posFolders = [os.path.basename(posPath)] exp_path = os.path.dirname(exp_path) - selected_exp_paths = {exp_path:posFolders} + selected_exp_paths = {exp_path: posFolders} else: self.logger.info(f'Scanning selected folder "{exp_path}"...') selected_exp_paths = path.get_posfolderpaths_walk(exp_path) if not selected_exp_paths: cancel = self.warnNoValidExpPaths(exp_path) if cancel: - self.logger.info(f'{utilityName} aborted by the user.') + self.logger.info(f"{utilityName} aborted by the user.") return continue - + is_multi_pos = False for exp_path, pos_folders in selected_exp_paths.items(): if exp_path in expPaths: @@ -1162,68 +1165,67 @@ def getSelectedExpPaths( selected_path, exp_path ) if not proceed: - self.logger.info(f'{utilityName} aborted by the user.') + self.logger.info(f"{utilityName} aborted by the user.") return warn_exp_already_selected = False expPaths[exp_path].extend(pos_folders) else: expPaths[exp_path] = pos_folders - + if len(pos_folders) > 1 and not is_multi_pos: is_multi_pos = True - + mostRecentPath = exp_path msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph(""" Do you want to select additional experiment folders? """) noButton, yesButton = msg.question( - self, 'Select additional experiments?', txt, - buttonsTexts=('No', 'Yes') + self, "Select additional experiments?", txt, buttonsTexts=("No", "Yes") ) if msg.clickedButton == noButton: break - + if not expPaths: - self.logger.info(f'{utilityName} aborted by the user.') + self.logger.info(f"{utilityName} aborted by the user.") return if len(expPaths) > 1 or is_multi_pos: infoPaths = self.getInfoPosStatus(expPaths, utilityName) selectPosWin = apps.selectPositionsMultiExp( - expPaths, - infoPaths=infoPaths, - parent=self + expPaths, infoPaths=infoPaths, parent=self ) selectPosWin.exec_() if selectPosWin.cancel: - self.logger.info(f'{utilityName} aborted by the user.') + self.logger.info(f"{utilityName} aborted by the user.") return selectedExpPaths = selectPosWin.selectedPaths else: selectedExpPaths = expPaths - + return selectedExpPaths - + def warnNoValidExpPaths(self, selected_path): msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph(""" The selected folder does not contain any valid experiment folders. """) - command = selected_path.replace('\\', os.sep) - command = selected_path.replace('/', os.sep) + command = selected_path.replace("\\", os.sep) + command = selected_path.replace("/", os.sep) msg.warning( - self, 'No valid folders found', txt, - buttonsTexts=('Cancel', 'Try again'), - commands=(command,), - path_to_browse=selected_path + self, + "No valid folders found", + txt, + buttonsTexts=("Cancel", "Try again"), + commands=(command,), + path_to_browse=selected_path, ) return msg.cancel - + def warnExpPathAlreadySelected(self, selected_path, exp_path): - selected_text = myutils.to_relative_path(selected_path) - exp_text = myutils.to_relative_path(exp_path) + selected_text = utils.to_relative_path(selected_path) + exp_text = utils.to_relative_path(exp_path) txt = html_utils.paragraph(f""" The experiment folder of the selected path was already previously selected.

Are you adding Position folders one by one? If yes, you do not @@ -1238,27 +1240,28 @@ def warnExpPathAlreadySelected(self, selected_path, exp_path): """) msg = widgets.myMessageBox(wrapText=False) msg.warning( - self, 'Folder already selected!', txt, - buttonsTexts=('Cancel', 'Yes'), - path_to_browse=selected_path + self, + "Folder already selected!", + txt, + buttonsTexts=("Cancel", "Yes"), + path_to_browse=selected_path, ) return not msg.cancel - + def _debug(self): try: from . import _q_debug + _q_debug.q_debug(self) except Exception as err: raise err - + def askRestartAcdc(self): - txt = html_utils.paragraph( - 'Are you sure you want to restart Cell-ACDC?
' - ) + txt = html_utils.paragraph("Are you sure you want to restart Cell-ACDC?
") msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'Restart?', txt, buttonsTexts=('Cancel', 'Yes')) + msg.warning(self, "Restart?", txt, buttonsTexts=("Cancel", "Yes")) return msg.cancel - + def keyPressEvent(self, event): modifiers = QGuiApplication.keyboardModifiers() ctrl_shift = modifiers == Qt.ControlModifier | Qt.ShiftModifier @@ -1270,190 +1273,202 @@ def keyPressEvent(self, event): self.close() return return super().keyPressEvent(event) - + def launchApplyTrackingFromTrackMateXML(self): - posPath = self.getSelectedPosPath('Apply tracking info from tabular data') + posPath = self.getSelectedPosPath("Apply tracking info from tabular data") if posPath is None: return - - title = 'Apply tracking info from TrackMate XML file utility' - infoText = 'Launching apply tracking info from from TrackMate XML data...' + + title = "Apply tracking info from TrackMate XML file utility" + infoText = "Launching apply tracking info from from TrackMate XML data..." self.applyTrackMateXMLWin = ( utilsApplyTrackFromTrackMate.ApplyTrackingInfoFromTrackMateUtil( - self.app, title, infoText, parent=self, - callbackOnFinished=self.applyTrackingFromTackmateXMLFinished + self.app, + title, + infoText, + parent=self, + callbackOnFinished=self.applyTrackingFromTackmateXMLFinished, ) ) self.applyTrackMateXMLWin.show() func = partial( - self._runApplyTrackingFromTrackMateXML, posPath, - self.applyTrackMateXMLWin + self._runApplyTrackingFromTrackMateXML, posPath, self.applyTrackMateXMLWin ) QTimer.singleShot(200, func) - + def _runApplyTrackingFromTrackMateXML(self, posPath, win): success = win.run(posPath) if not success: self.logger.info( - 'Apply tracking info from TrackMate XML cancelled by the user.' + "Apply tracking info from TrackMate XML cancelled by the user." ) - win.close() - + win.close() + def launchApplyTrackingFromTableUtil(self): - posPath = self.getSelectedPosPath('Apply tracking info from tabular data') + posPath = self.getSelectedPosPath("Apply tracking info from tabular data") if posPath is None: return - - title = 'Apply tracking info from tabular data utility' - infoText = 'Launching apply tracking info from tabular data...' - self.applyTrackWin = ( - utilsApplyTrackFromTab.ApplyTrackingInfoFromTableUtil( - self.app, title, infoText, parent=self, - callbackOnFinished=self.applyTrackingFromTableFinished - ) + + title = "Apply tracking info from tabular data utility" + infoText = "Launching apply tracking info from tabular data..." + self.applyTrackWin = utilsApplyTrackFromTab.ApplyTrackingInfoFromTableUtil( + self.app, + title, + infoText, + parent=self, + callbackOnFinished=self.applyTrackingFromTableFinished, ) self.applyTrackWin.show() - func = partial( - self._runApplyTrackingFromTableUtil, posPath, self.applyTrackWin - ) + func = partial(self._runApplyTrackingFromTableUtil, posPath, self.applyTrackWin) QTimer.singleShot(200, func) def _runApplyTrackingFromTableUtil(self, posPath, win): success = win.run(posPath) if not success: self.logger.info( - 'Apply tracking info from tabular data cancelled by the user.' + "Apply tracking info from tabular data cancelled by the user." ) - win.close() - + win.close() + def applyTrackingFromTackmateXMLFinished(self): msg = widgets.myMessageBox(showCentered=False, wrapText=False) txt = html_utils.paragraph( - 'Apply tracking info from TrackMate XML data completed.' + "Apply tracking info from TrackMate XML data completed." ) - msg.information(self, 'Process completed', txt) - self.logger.info('Apply tracking info from TrackMate XML data completed.') + msg.information(self, "Process completed", txt) + self.logger.info("Apply tracking info from TrackMate XML data completed.") self.applyTrackMateXMLWin.close() - + def applyTrackingFromTableFinished(self): msg = widgets.myMessageBox(showCentered=False, wrapText=False) - txt = html_utils.paragraph( - 'Apply tracking info from tabular data completed.' - ) - msg.information(self, 'Process completed', txt) - self.logger.info('Apply tracking info from tabular data completed.') + txt = html_utils.paragraph("Apply tracking info from tabular data completed.") + msg.information(self, "Process completed", txt) + self.logger.info("Apply tracking info from tabular data completed.") self.applyTrackWin.close() - + def launchNapariUtil(self, action): - myutils.check_install_package('napari', parent=self) + utils.check_install_package("napari", parent=self) if action == self.arboretumAction: self._launchArboretum() def _launchArboretum(self): - myutils.check_install_package('napari_arboretum', parent=self) + utils.check_install_package("napari_arboretum", parent=self) from cellacdc.napari_utils import arboretum - - posPath = self.getSelectedPosPath('napari-arboretum') + + posPath = self.getSelectedPosPath("napari-arboretum") if posPath is None: return - title = 'napari-arboretum utility' - infoText = 'Launching napari-arboretum to visualize lineage tree...' + title = "napari-arboretum utility" + infoText = "Launching napari-arboretum to visualize lineage tree..." self.arboretumWindow = arboretum.NapariArboretumDialog( posPath, self.app, title, infoText, parent=self ) self.arboretumWindow.show() - + def launchToObjectsCoordsUtil(self): self.logger.info(f'Launching utility "{self.sender().text()}"') selectedExpPaths = self.getSelectedExpPaths( - 'From _segm.npz to objects coordinates (CSV)' + "From _segm.npz to objects coordinates (CSV)" ) if selectedExpPaths is None: return - - title = 'Convert _segm.npz file(s) to objects coordinates (CSV)' - infoText = 'Launching to to objects coordinates process...' + + title = "Convert _segm.npz file(s) to objects coordinates (CSV)" + infoText = "Launching to to objects coordinates process..." progressDialogueTitle = ( - 'Converting _segm.npz file(s) to to objects coordinates (CSV)' + "Converting _segm.npz file(s) to to objects coordinates (CSV)" ) self.toObjCoordsWin = utilsToObjCoords.toObjCoordsUtil( - selectedExpPaths, self.app, title, infoText, progressDialogueTitle, - parent=self + selectedExpPaths, + self.app, + title, + infoText, + progressDialogueTitle, + parent=self, ) self.toObjCoordsWin.show() - + def launchFromImageJroiToSegmUtil(self): self.logger.info(f'Launching utility "{self.sender().text()}"') - myutils.check_install_package('roifile', parent=self) + utils.check_install_package("roifile", parent=self) import roifile - selectedExpPaths = self.getSelectedExpPaths( - 'From ImageJ ROIs to _segm.npz' - ) + selectedExpPaths = self.getSelectedExpPaths("From ImageJ ROIs to _segm.npz") if selectedExpPaths is None: return - - title = 'Convert ImageJ ROIs to _segm.npz file(s)' - infoText = 'Launching ImageJ ROIs conversion process...' - progressDialogueTitle = 'Converting ImageJ ROIs to _segm.npz file(s)' + + title = "Convert ImageJ ROIs to _segm.npz file(s)" + infoText = "Launching ImageJ ROIs conversion process..." + progressDialogueTitle = "Converting ImageJ ROIs to _segm.npz file(s)" self.toImageJroiWin = utilsFromImageJroi.fromImageJRoiToSegmUtil( - selectedExpPaths, self.app, title, infoText, progressDialogueTitle, - parent=self + selectedExpPaths, + self.app, + title, + infoText, + progressDialogueTitle, + parent=self, ) - self.toImageJroiWin.show() - + self.toImageJroiWin.show() + def launchResizeUtil(self): self.logger.info(f'Launching utility "{self.sender().text()}"') - - selectedExpPaths = self.getSelectedExpPaths( - 'From _segm.npz to ImageJ ROIs' - ) + + selectedExpPaths = self.getSelectedExpPaths("From _segm.npz to ImageJ ROIs") if selectedExpPaths is None: return - - title = 'Resize images' - infoText = 'Launching resizing images process...' - progressDialogueTitle = 'Resize images' + + title = "Resize images" + infoText = "Launching resizing images process..." + progressDialogueTitle = "Resize images" self.resizeUtilWin = utilsResizePositionsUtil.ResizePositionsUtil( - selectedExpPaths, self.app, title, infoText, progressDialogueTitle, - parent=self + selectedExpPaths, + self.app, + title, + infoText, + progressDialogueTitle, + parent=self, ) self.resizeUtilWin.show() - + def launchToImageJroiUtil(self): self.logger.info(f'Launching utility "{self.sender().text()}"') - myutils.check_install_package('roifile', parent=self) + utils.check_install_package("roifile", parent=self) import roifile - selectedExpPaths = self.getSelectedExpPaths( - 'From _segm.npz to ImageJ ROIs' - ) + selectedExpPaths = self.getSelectedExpPaths("From _segm.npz to ImageJ ROIs") if selectedExpPaths is None: return - - title = 'Convert _segm.npz file(s) to ImageJ ROIs' - infoText = 'Launching to ImageJ ROIs process...' - progressDialogueTitle = 'Converting _segm.npz file(s) to ImageJ ROIs' + + title = "Convert _segm.npz file(s) to ImageJ ROIs" + infoText = "Launching to ImageJ ROIs process..." + progressDialogueTitle = "Converting _segm.npz file(s) to ImageJ ROIs" self.toImageJroiWin = utilsToImageJroi.toImageRoiUtil( - selectedExpPaths, self.app, title, infoText, progressDialogueTitle, - parent=self + selectedExpPaths, + self.app, + title, + infoText, + progressDialogueTitle, + parent=self, ) self.toImageJroiWin.show() - + def launchGenerateMothBudTotTableUtil(self): self.logger.info(f'Launching utility "{self.sender().text()}"') - - title = 'Generate mothers, buds, and total cell table' - infoText = 'Launching generate mothers, buds, and total cell table...' + + title = "Generate mothers, buds, and total cell table" + infoText = "Launching generate mothers, buds, and total cell table..." self.genMothBudTotalTableWin = ( utilsGenerateMothBudTotTable.GenerateMothBudTotalUtil( - self.app, title, infoText, parent=self, - callbackOnFinished=self.generateMothBudTotTableFinished + self.app, + title, + infoText, + parent=self, + callbackOnFinished=self.generateMothBudTotTableFinished, ) ) self.genMothBudTotalTableWin.show() @@ -1461,74 +1476,82 @@ def launchGenerateMothBudTotTableUtil(self): self._runGenerateMothBudTotTableUtil, self.genMothBudTotalTableWin ) QTimer.singleShot(200, func) - + def _runGenerateMothBudTotTableUtil(self, win): success = win.run() if not success: self.logger.info( - 'Generating mothers, buds, and total cell table cancelled by the user.' + "Generating mothers, buds, and total cell table cancelled by the user." ) - win.close() - + win.close() + def generateMothBudTotTableFinished(self): msg = widgets.myMessageBox(showCentered=False, wrapText=False) txt = html_utils.paragraph( - 'Generating mothers, buds, and total cell table completed.' - ) - msg.information(self, 'Process completed', txt) - self.logger.info( - 'Generating mothers, buds, and total cell table completed.' + "Generating mothers, buds, and total cell table completed." ) + msg.information(self, "Process completed", txt) + self.logger.info("Generating mothers, buds, and total cell table completed.") self.genMothBudTotalTableWin.close() - + def launchCombineMeatricsMultiChanneliUtil(self): self.logger.info(f'Launching utility "{self.sender().text()}"') selectedExpPaths = self.getSelectedExpPaths( - 'Combine measurements from multiple channels' + "Combine measurements from multiple channels" ) if selectedExpPaths is None: return - - title = 'Compute measurements from multiple channels' - infoText = 'Launching compute measurements from multiple channels process...' - progressDialogueTitle = 'Compute measurements from multiple channels' + + title = "Compute measurements from multiple channels" + infoText = "Launching compute measurements from multiple channels process..." + progressDialogueTitle = "Compute measurements from multiple channels" self.multiChannelWin = utilsComputeMultiCh.ComputeMetricsMultiChannel( - selectedExpPaths, self.app, title, infoText, progressDialogueTitle, - parent=self + selectedExpPaths, + self.app, + title, + infoText, + progressDialogueTitle, + parent=self, ) self.multiChannelWin.show() - + def launchFucciPreprocessUtil(self): self.logger.info(f'Launching utility "{self.sender().text()}"') - selectedExpPaths = self.getSelectedExpPaths( - 'Combine FUCCI channels' - ) + selectedExpPaths = self.getSelectedExpPaths("Combine FUCCI channels") if selectedExpPaths is None: return - - title = 'Combine FUCCI channels' - infoText = 'Launching Combine FUCCI channels process...' - progressDialogueTitle = 'Combining FUCCI channels' + + title = "Combine FUCCI channels" + infoText = "Launching Combine FUCCI channels process..." + progressDialogueTitle = "Combining FUCCI channels" self.fucciPreprocessWin = utilsFucciPreprocess.FucciPreprocessUtil( - selectedExpPaths, self.app, title, infoText, progressDialogueTitle, - parent=self + selectedExpPaths, + self.app, + title, + infoText, + progressDialogueTitle, + parent=self, ) self.fucciPreprocessWin.show() - + def launchCustomPreprocessUtil(self): self.logger.info(f'Launching utility "{self.sender().text()}"') selectedExpPaths = self.getSelectedExpPaths( - 'Pre-process images with custom recipe' + "Pre-process images with custom recipe" ) if selectedExpPaths is None: return - - title = 'Pre-process images with custom recipe' - infoText = 'Launching Pre-process images with custom recipe process...' - progressDialogueTitle = 'Pre-process images' + + title = "Pre-process images with custom recipe" + infoText = "Launching Pre-process images with custom recipe process..." + progressDialogueTitle = "Pre-process images" self.customPreprocessWin = utilsCustomPreprocess.CustomPreprocessUtil( - selectedExpPaths, self.app, title, infoText, progressDialogueTitle, - parent=self + selectedExpPaths, + self.app, + title, + infoText, + progressDialogueTitle, + parent=self, ) self.customPreprocessWin.show() @@ -1542,147 +1565,175 @@ def launchCombineChannelsUtil(self): recepies will be applied to all of them. """ selectedExpPaths = self.getSelectedExpPaths( - 'Combine and manipulate channels and/or segmentation files', - custom_txt=custom_txt + "Combine and manipulate channels and/or segmentation files", + custom_txt=custom_txt, ) if selectedExpPaths is None: return - - title = 'Combine and manipulate channels and/or segmentation files' - infoText = 'Launching combine and manipulate channels utility...' - progressDialogueTitle = 'Combine and manipulate channels and/or segmentation files' + + title = "Combine and manipulate channels and/or segmentation files" + infoText = "Launching combine and manipulate channels utility..." + progressDialogueTitle = ( + "Combine and manipulate channels and/or segmentation files" + ) self.CombineChannelsWin = utilsCombineChannels.CombineChannelsUtil( - selectedExpPaths, self.app, title, infoText, progressDialogueTitle, - parent=self + selectedExpPaths, + self.app, + title, + infoText, + progressDialogueTitle, + parent=self, ) self.CombineChannelsWin.show() - + def launchConnected3DsegmActionUtil(self): self.logger.info(f'Launching utility "{self.sender().text()}"') selectedExpPaths = self.getSelectedExpPaths( - 'Create connected 3D segmentation mask' + "Create connected 3D segmentation mask" ) if selectedExpPaths is None: return - - title = 'Create connected 3D segmentation mask' - infoText = 'Launching connected 3D segmentation mask creation process...' - progressDialogueTitle = 'Creating connected 3D segmentation mask' + + title = "Create connected 3D segmentation mask" + infoText = "Launching connected 3D segmentation mask creation process..." + progressDialogueTitle = "Creating connected 3D segmentation mask" self.connected3DsegmWin = utilsConnected3Dsegm.CreateConnected3Dsegm( - selectedExpPaths, self.app, title, infoText, progressDialogueTitle, - parent=self + selectedExpPaths, + self.app, + title, + infoText, + progressDialogueTitle, + parent=self, ) self.connected3DsegmWin.show() - + def launchCountObjectsInSegmActionUtil(self): self.logger.info(f'Launching utility "{self.sender().text()}"') selectedExpPaths = self.getSelectedExpPaths( - 'Create connected 3D segmentation mask' + "Create connected 3D segmentation mask" ) if selectedExpPaths is None: return - - title = 'Count objects in segmentation mask' - infoText = 'Launching count objects in segmentation masks process...' - progressDialogueTitle = 'Counting objects in segmentation mask' + + title = "Count objects in segmentation mask" + infoText = "Launching count objects in segmentation masks process..." + progressDialogueTitle = "Counting objects in segmentation mask" self.connected3DsegmWin = utilsCountObjectsInSegm.CountObjectsInsegm( - selectedExpPaths, self.app, title, infoText, progressDialogueTitle, - parent=self + selectedExpPaths, + self.app, + title, + infoText, + progressDialogueTitle, + parent=self, ) self.connected3DsegmWin.show() - + def launchFillHolesActionUtil(self): self.logger.info(f'Launching utility "{self.sender().text()}"') - selectedExpPaths = self.getSelectedExpPaths( - 'Fill holes in segmentation masks' - ) + selectedExpPaths = self.getSelectedExpPaths("Fill holes in segmentation masks") if selectedExpPaths is None: return - title = 'Fill holes in segmentation masks' - infoText = 'Launching fill holes in segmentation masks process...' - progressDialogueTitle = 'Filling holes in segmentation masks' + title = "Fill holes in segmentation masks" + infoText = "Launching fill holes in segmentation masks process..." + progressDialogueTitle = "Filling holes in segmentation masks" self.fillHolesWin = fillHolesInSegm.fillHolesInSegm( - selectedExpPaths, self.app, title, infoText, progressDialogueTitle, - parent=self + selectedExpPaths, + self.app, + title, + infoText, + progressDialogueTitle, + parent=self, ) self.fillHolesWin.show() def launchFilterObjsFromTableActionUtil(self): self.logger.info(f'Launching utility "{self.sender().text()}"') selectedExpPaths = self.getSelectedExpPaths( - 'Create connected 3D segmentation mask' + "Create connected 3D segmentation mask" ) if selectedExpPaths is None: return - - title = 'Filter segmented objects from coordinates' - infoText = 'Launching Filter segmented objects from coordinates process...' - progressDialogueTitle = 'Filtering objects' + + title = "Filter segmented objects from coordinates" + infoText = "Launching Filter segmented objects from coordinates process..." + progressDialogueTitle = "Filtering objects" self.filterObjsFromTableWin = ( - utilsFilterObjsFromTable.FilterObjsFromCoordsTable( - selectedExpPaths, self.app, title, infoText, - progressDialogueTitle, parent=self + utilsFilterObjsFromTable.FilterObjsFromCoordsTable( + selectedExpPaths, + self.app, + title, + infoText, + progressDialogueTitle, + parent=self, ) ) self.filterObjsFromTableWin.show() - + def launchStack2Dto3DsegmActionUtil(self): self.logger.info(f'Launching utility "{self.sender().text()}"') selectedExpPaths = self.getSelectedExpPaths( - 'Create 3D segmentation mask from 2D' + "Create 3D segmentation mask from 2D" ) if selectedExpPaths is None: return - + SizeZwin = apps.NumericEntryDialog( - title='Number of z-slices', - instructions='Enter number of z-slices required', - currentValue=1, parent=self, - stretch=True + title="Number of z-slices", + instructions="Enter number of z-slices required", + currentValue=1, + parent=self, + stretch=True, ) SizeZwin.exec_() if SizeZwin.cancel: return - - title = 'Create stacked 3D segmentation mask' - infoText = 'Launching stacked 3D segmentation mask creation process...' - progressDialogueTitle = 'Creating stacked 3D segmentation mask' + + title = "Create stacked 3D segmentation mask" + infoText = "Launching stacked 3D segmentation mask creation process..." + progressDialogueTitle = "Creating stacked 3D segmentation mask" self.stack2DsegmWin = utilsStack2Dto3D.Stack2DsegmTo3Dsegm( - selectedExpPaths, self.app, title, infoText, progressDialogueTitle, - SizeZwin.value, parent=self + selectedExpPaths, + self.app, + title, + infoText, + progressDialogueTitle, + SizeZwin.value, + parent=self, ) self.stack2DsegmWin.show() def launchTrackSubCellFeaturesUtil(self): self.logger.info(f'Launching utility "{self.sender().text()}"') - selectedExpPaths = self.getSelectedExpPaths( - 'Track sub-cellular objects' - ) + selectedExpPaths = self.getSelectedExpPaths("Track sub-cellular objects") if selectedExpPaths is None: return - + win = apps.TrackSubCellObjectsDialog() win.exec_() if win.cancel: return - - title = 'Track sub-cellular objects' - infoText = 'Launching sub-cellular objects tracker...' - progressDialogueTitle = 'Tracking sub-cellular objects' + + title = "Track sub-cellular objects" + infoText = "Launching sub-cellular objects tracker..." + progressDialogueTitle = "Tracking sub-cellular objects" self.trackSubCellObjWin = utilsTrackSubCell.TrackSubCellFeatures( - selectedExpPaths, self.app, title, infoText, progressDialogueTitle, - win.trackSubCellObjParams, parent=self + selectedExpPaths, + self.app, + title, + infoText, + progressDialogueTitle, + win.trackSubCellObjParams, + parent=self, ) self.trackSubCellObjWin.show() - def launchCalcMetricsUtil(self): self.logger.info(f'Launching utility "{self.sender().text()}"') - selectedExpPaths = self.getSelectedExpPaths('Compute measurements utility') + selectedExpPaths = self.getSelectedExpPaths("Compute measurements utility") if selectedExpPaths is None: return - + self._lauchCalcMetricsUtil(selectedExpPaths) def _lauchCalcMetricsUtil(self, selectedExpPaths): @@ -1690,10 +1741,10 @@ def _lauchCalcMetricsUtil(self, selectedExpPaths): selectedExpPaths, self.app, parent=self ) self.calcMeasWin.show() - + def launchToSymDicUtil(self): self.logger.info(f'Launching utility "{self.sender().text()}"') - selectedExpPaths = self.getSelectedExpPaths('Lineage tree utility') + selectedExpPaths = self.getSelectedExpPaths("Lineage tree utility") if selectedExpPaths is None: return @@ -1701,18 +1752,18 @@ def launchToSymDicUtil(self): selectedExpPaths, self.app, parent=self ) self.toSymDivWin.show() - + def getInfoPosStatus(self, expPaths, utilityName): - if 'spotmax' in utilityName.lower(): - caller = 'SpotMAX' + if "spotmax" in utilityName.lower(): + caller = "SpotMAX" else: - caller = 'Cell-ACDC' + caller = "Cell-ACDC" infoPaths = {} for exp_path, posFoldernames in expPaths.items(): posFoldersInfo = {} for pos in posFoldernames: pos_path = os.path.join(exp_path, pos) - status = myutils.get_pos_status(pos_path, caller=caller) + status = utils.get_pos_status(pos_path, caller=caller) posFoldersInfo[pos] = status infoPaths[exp_path] = posFoldersInfo return infoPaths @@ -1722,9 +1773,7 @@ def launchRenameUtil(self): if isUtilnabled: self.sender().setDisabled(True) self.renameWin = utilsRename.renameFilesWin( - parent=self, - actionToEnable=self.sender(), - mainWin=self + parent=self, actionToEnable=self.sender(), mainWin=self ) self.renameWin.show() self.renameWin.main() @@ -1735,7 +1784,7 @@ def launchRenameUtil(self): def launchConvertFormatUtil(self, checked=False): s = self.sender().text() - m = re.findall(r'Convert \.(\w+) file\(s\) to (.*)\.(\w+)...', s) + m = re.findall(r"Convert \.(\w+) file\(s\) to (.*)\.(\w+)...", s) from_, info, to = m[0] isConvertEnabled = self.sender().isEnabled() if isConvertEnabled: @@ -1743,8 +1792,10 @@ def launchConvertFormatUtil(self, checked=False): self.convertWin = utilsConvert.convertFileFormatWin( parent=self, actionToEnable=self.sender(), - mainWin=self, from_=from_, to=to, - info=info + mainWin=self, + from_=from_, + to=to, + info=info, ) self.convertWin.show() self.convertWin.main() @@ -1752,50 +1803,49 @@ def launchConvertFormatUtil(self, checked=False): geometry = self.convertWin.saveGeometry() self.convertWin.setWindowState(Qt.WindowActive) self.convertWin.restoreGeometry(geometry) - + def launchImageBatchConverter(self): self.batchConverterWin = utilsConvert.ImagesToPositions(parent=self) self.batchConverterWin.show() - + def launchRepeatDataPrep(self): self.batchConverterWin = utilsRepeat.repeatDataPrepWindow(parent=self) self.batchConverterWin.show() def launchDataStruct(self, checked=False): self.dataStructButton.setPalette(self.moduleLaunchedPalette) - self.dataStructButton.setText( - '0. Creating data structure running...' - ) + self.dataStructButton.setText("0. Creating data structure running...") QTimer.singleShot(100, self._showDataStructWin) def _showDataStructWin(self): msg = widgets.myMessageBox(wrapText=False, showCentered=False) - bioformats_url = 'https://www.openmicroscopy.org/bio-formats/' - bioformats_href = html_utils.href_tag( - 'Bio-Formats', bioformats_url - ) - - bioio_url = 'https://bioio-devs.github.io/bioio/' - bioio_href = html_utils.href_tag('BioIO', bioio_url) - - aicsimageio_url = 'https://allencellmodeling.github.io/aicsimageio/#' - aicsimageio_href = html_utils.href_tag('AICSImageIO', aicsimageio_url) - - acdc_fiji_macros_url = 'https://cell-acdc.readthedocs.io/en/latest/data-structure-fiji.html' + bioformats_url = "https://www.openmicroscopy.org/bio-formats/" + bioformats_href = html_utils.href_tag("Bio-Formats", bioformats_url) + + bioio_url = "https://bioio-devs.github.io/bioio/" + bioio_href = html_utils.href_tag("BioIO", bioio_url) + + aicsimageio_url = "https://allencellmodeling.github.io/aicsimageio/#" + aicsimageio_href = html_utils.href_tag("AICSImageIO", aicsimageio_url) + + acdc_fiji_macros_url = ( + "https://cell-acdc.readthedocs.io/en/latest/data-structure-fiji.html" + ) acdc_fiji_macros_href = html_utils.href_tag( - 'Cell-ACDC Fiji macros guide', acdc_fiji_macros_url + "Cell-ACDC Fiji macros guide", acdc_fiji_macros_url ) - + conda_important_admon = html_utils.to_admonition( f""" Java can be installed only using conda! If you are not using conda and the file format of your files requires Bio-Formats,
you will need to use the provided ImageJ/Fiji macros.
See this guide for more information: {acdc_fiji_macros_href} - """, 'important' + """, + "important", ) - + issues_href = f'GitHub page' txt = html_utils.paragraph(f""" To process microscopy files, Cell-ACDC uses the {bioio_href} library.

@@ -1822,27 +1872,24 @@ def _showDataStructWin(self): # useAICSImageIO = QPushButton( # QIcon(':AICS_logo.svg'), ' Use AICSImageIO ', msg # ) - useBioFormatsButton = QPushButton( - QIcon(':ome.svg'), ' Use BioIO ', msg - ) + useBioFormatsButton = QPushButton(QIcon(":ome.svg"), " Use BioIO ", msg) restructButton = QPushButton( - QIcon(':folders.svg'), ' Re-structure image files ', msg + QIcon(":folders.svg"), " Re-structure image files ", msg ) buttons = [useBioFormatsButton, restructButton] if is_mac: useFijiMacroButton = QPushButton( - QIcon(':fiji-logo.svg'), ' Use Fiji Macro ', msg + QIcon(":fiji-logo.svg"), " Use Fiji Macro ", msg ) buttons.insert(1, useFijiMacroButton) msg.question( - self, 'How to structure files', txt, - buttonsTexts=('Cancel', *buttons) + self, "How to structure files", txt, buttonsTexts=("Cancel", *buttons) ) if msg.cancel: - self.logger.info('Creating data structure process aborted by the user.') + self.logger.info("Creating data structure process aborted by the user.") self.restoreDefaultButtons() return - + useBioFormats = msg.clickedButton == useBioFormatsButton useFijiMacro = False if is_mac: @@ -1852,63 +1899,61 @@ def _showDataStructWin(self): self.dataStructWin = dataStruct.createDataStructWin( parent=self, version=self._version ) - if self.dataStructWin.bioformats_backend == 'python-bioformats': + if self.dataStructWin.bioformats_backend == "python-bioformats": self.dataStructButton.setPalette(self.defaultButtonPalette) self.dataStructButton.setText( - '0. Restart Cell-ACDC to enable module 0 again.') + "0. Restart Cell-ACDC to enable module 0 again." + ) self.dataStructButton.setToolTip( - 'Due to an interal limitation of the Java Virtual Machine\n' - 'moduel 0 can be launched only once.\n' - 'To use it again close and reopen Cell-ACDC' + "Due to an interal limitation of the Java Virtual Machine\n" + "moduel 0 can be launched only once.\n" + "To use it again close and reopen Cell-ACDC" ) self.dataStructButton.setDisabled(True) - + self.dataStructWin.show() self.dataStructWin.main() - if self.dataStructWin.bioformats_backend != 'python-bioformats': + if self.dataStructWin.bioformats_backend != "python-bioformats": self.restoreDefaultButtons() elif useFijiMacro: self.runFijiMacroWorkflow() if msg.clickedButton == restructButton: self.progressWin = apps.QDialogWorkerProgress( - title='Re-structure image files log', parent=self, - pbarDesc='Re-structuring image files running...' + title="Re-structure image files log", + parent=self, + pbarDesc="Re-structuring image files running...", ) self.progressWin.sigClosed.connect(self.progressWinClosed) self.progressWin.show(self.app) - self.workerName = 'Re-structure image files' + self.workerName = "Re-structure image files" success = dataReStruct.run(self) if not success: self.progressWin.workerFinished = True self.progressWin.close() self.restoreDefaultButtons() - self.logger.info('Re-structuring files NOT completed.') - + self.logger.info("Re-structuring files NOT completed.") + def runFijiMacroWorkflow(self): self.progressWin = apps.QDialogWorkerProgress( - title='Initialising Fiji', - parent=self, - pbarDesc='Initialising Fiji...' + title="Initialising Fiji", parent=self, pbarDesc="Initialising Fiji..." ) self.progressWin.show(self.app) self.progressWin.mainPbar.setMaximum(0) - + QTimer.singleShot(100, self._runFijiMacroWindow) - + def _runFijiMacroWindow(self): - self.dataStructWin = ( - dataStruct.InitFijiMacro(self) - ) + self.dataStructWin = dataStruct.InitFijiMacro(self) self.dataStructWin.run() self.progressWin.workerFinished = True self.progressWin.close() self.restoreDefaultButtons() self.progressWin = None - + def progressWinClosed(self): self.progressWin = None self._gc_collect() - + def workerInitProgressbar(self, totalIter): if self.progressWin is None: return @@ -1917,52 +1962,48 @@ def workerInitProgressbar(self, totalIter): if totalIter == 1: totalIter = 0 self.progressWin.mainPbar.setMaximum(totalIter) - + def workerFinished(self, worker=None): msg = widgets.myMessageBox(showCentered=False, wrapText=False) - txt = html_utils.paragraph( - f'{self.workerName} process finished.' - ) - msg.information(self, 'Process finished', txt) + txt = html_utils.paragraph(f"{self.workerName} process finished.") + msg.information(self, "Process finished", txt) if self.progressWin is not None: self.progressWin.workerFinished = True self.progressWin.close() - + self.restoreDefaultButtons() - + @exception_handler def workerCritical(self, error): if self.progressWin is not None: self.progressWin.workerFinished = True self.progressWin.close() - raise error - + raise error + def workerUpdateProgressbar(self, step): if self.progressWin is None: return self.progressWin.mainPbar.update(step) - - def workerProgress(self, text, loggerLevel='INFO'): + + def workerProgress(self, text, loggerLevel="INFO"): if self.progressWin is not None: self.progressWin.logConsole.append(text) self.logger.log(getattr(logging, loggerLevel), text) def restoreDefaultButtons(self): self.dataStructButton.setText( - '0. Create data structure from microscopy/image file(s)...' + "0. Create data structure from microscopy/image file(s)..." ) self.dataStructButton.setPalette(self.defaultButtonPalette) def launchDataPrep(self, checked=False): - dataPrepWin = dataPrep.dataPrepWin( - mainWin=self, version=self._version - ) + dataPrepWin = dataPrep.dataPrepWin(mainWin=self, version=self._version) dataPrepWin.sigClose.connect(self.dataPrepClosed) dataPrepWin.show() self.dataPrepWins.append(dataPrepWin) - + def dataPrepClosed(self, dataPrepWin): try: self.dataPrepWins.remove(dataPrepWin) @@ -1977,13 +2018,15 @@ def launchSegm(self, checked=False): defaultText = self.defaultTextSegmButton if c != self.moduleLaunchedColor: self.segmButton.setPalette(self.moduleLaunchedPalette) - self.segmButton.setText('Segmentation is running. ' - 'Check progress in the terminal/console') + self.segmButton.setText( + "Segmentation is running. Check progress in the terminal/console" + ) self.segmWin = segm.segmWin( buttonToRestore=(self.segmButton, defaultColor, defaultText), - mainWin=self, version=self._version + mainWin=self, + version=self._version, ) - self.segmWin.sigClosed.connect(self.segmWinClosed) + self.segmWin.sigClosed.connect(self.segmWinClosed) self.segmWin.show() self.segmWin.main() else: @@ -1996,63 +2039,67 @@ def segmWinClosed(self): self._gc_collect() def launchGui(self, checked=False): - self.logger.info('Opening GUI...') + self.logger.info("Opening GUI...") guiWin = gui.guiWin( - self.app, mainWin=self, version=self._version, - launcherSlot=self.launchGui + self.app, mainWin=self, version=self._version, launcherSlot=self.launchGui ) self.guiWins.append(guiWin) guiWin.sigClosed.connect(self.guiClosed) guiWin.run() - + def launchSpotmaxGui(self, checked=False): from spotmax import icon_path, logo_path # logoDialog = apps.LogoDialog(logo_path, icon_path, parent=self) - + splashScreen = QSplashScreen() splashScreen.setPixmap(QPixmap(logo_path)) splashScreen.show() QTimer.singleShot(300, partial(self._launchSpotMaxGui, splashScreen)) - + def _launchSpotMaxGui(self, splashScreen): - self.logger.info('Launching spotMAX...') + self.logger.info("Launching spotMAX...") spotmaxWin = spotmaxRun.run_gui( - app=self.app, mainWin=self, launcherSlot=self.launchSpotmaxGui, - + app=self.app, + mainWin=self, + launcherSlot=self.launchSpotmaxGui, ) spotmaxWin.sigClosed.connect(self.spotmaxGuiClosed) self.spotmaxWins.append(spotmaxWin) splashScreen.close() - + def spotmaxGuiClosed(self, spotmaxWin): self.spotmaxWins.remove(spotmaxWin) self._gc_collect() - + def guiClosed(self, guiWin): try: self.guiWins.remove(guiWin) except ValueError: pass self._gc_collect() - + def _gc_collect(self): QTimer.singleShot(100, gc.collect) def launchAlignUtil(self, checked=False): self.logger.info(f'Launching utility "{self.sender().text()}"') selectedExpPaths = self.getSelectedExpPaths( - 'Align frames in X and Y with phase cross-correlation' + "Align frames in X and Y with phase cross-correlation" ) if selectedExpPaths is None: return - - title = 'Align frames' - infoText = 'Aligning frames in X and Y with phase cross-correlation...' - progressDialogueTitle = 'Align frames' + + title = "Align frames" + infoText = "Aligning frames in X and Y with phase cross-correlation..." + progressDialogueTitle = "Align frames" self.alignWindow = utilsAlign.alignWin( - selectedExpPaths, self.app, title, infoText, progressDialogueTitle, - parent=self + selectedExpPaths, + self.app, + title, + infoText, + progressDialogueTitle, + parent=self, ) self.alignWindow.show() @@ -2061,40 +2108,48 @@ def launchConcatUtil(self, checked=False, exp_folderpath=None): f'Launching utility "Concatenate tables from multipe positions"' ) selectedExpPaths = self.getSelectedExpPaths( - 'Concatenate acdc_output files', exp_folderpath=exp_folderpath + "Concatenate acdc_output files", exp_folderpath=exp_folderpath ) if selectedExpPaths is None: return - - title = 'Concatenate acdc_output files' - infoText = 'Launching concatenate acdc_output files process...' - progressDialogueTitle = 'Concatenate acdc_output files' + + title = "Concatenate acdc_output files" + infoText = "Launching concatenate acdc_output files process..." + progressDialogueTitle = "Concatenate acdc_output files" self.concatWindow = utilsConcat.ConcatWin( - selectedExpPaths, self.app, title, infoText, progressDialogueTitle, - parent=self + selectedExpPaths, + self.app, + title, + infoText, + progressDialogueTitle, + parent=self, ) self.concatWindow.show() - + def launchConcatSpotmaxUtil(self, checked=False, exp_folderpath=None): self.logger.info( f'Launching utility "Concatenate tables from multipe positions"' ) selectedExpPaths = self.getSelectedExpPaths( - 'Concatenate spotMAX output files', + "Concatenate spotMAX output files", exp_folderpath=exp_folderpath, ) if selectedExpPaths is None: return - - title = 'Concatenate spotMAX output files' - infoText = 'Launching concatenate spotMAX output files process...' - progressDialogueTitle = 'Concatenate spotMAX output files' + + title = "Concatenate spotMAX output files" + infoText = "Launching concatenate spotMAX output files process..." + progressDialogueTitle = "Concatenate spotMAX output files" self.concatWindow = utilsConcat.ConcatWin( - selectedExpPaths, self.app, title, infoText, progressDialogueTitle, - parent=self + selectedExpPaths, + self.app, + title, + infoText, + progressDialogueTitle, + parent=self, ) self.concatWindow.show() - + def showEvent(self, event): self.showAllWindows() # self.setFocus() @@ -2102,55 +2157,56 @@ def showEvent(self, event): if not self.checkUserDataFolderPath: return self.checkMigrateUserDataFolderPath() - + def checkMigrateUserDataFolderPath(self): from . import user_home_path + user_home_acdc_folders = load.get_all_acdc_folders(user_home_path) if not user_home_acdc_folders: self.checkUserDataFolderPath = False return - if 'doNotAskMigrate' in self.df_settings.index: - if str(self.df_settings.at['doNotAskMigrate', 'value']) == 'Yes': + if "doNotAskMigrate" in self.df_settings.index: + if str(self.df_settings.at["doNotAskMigrate", "value"]) == "Yes": return - + msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph( - 'Starting from version 1.4.0, Cell-ACDC default user profile path ' - f'has been changed to {user_profile_path}

' - 'Since you have some profile data saved in the old path, Cell-ACDC ' - 'can now migrate everything to the new folder.

' - 'Do you want to migrate now?
' + "Starting from version 1.4.0, Cell-ACDC default user profile path " + f"has been changed to {user_profile_path}

" + "Since you have some profile data saved in the old path, Cell-ACDC " + "can now migrate everything to the new folder.

" + "Do you want to migrate now?
" ) acdc_folders_format = [ - f'   {os.path.join(user_home_path, folder)}' + f"   {os.path.join(user_home_path, folder)}" for folder in user_home_acdc_folders ] - acdc_folders_format = '
'.join(acdc_folders_format) + acdc_folders_format = "
".join(acdc_folders_format) detailsText = ( - f'Folders found in the previous location:

{acdc_folders_format}' + f"Folders found in the previous location:

{acdc_folders_format}" ) - doNotAskAgainCheckbox = QCheckBox('Do not ask again') + doNotAskAgainCheckbox = QCheckBox("Do not ask again") msg.warning( - self, 'Migrate old user profile', txt, - buttonsTexts=('Cancel', 'Yes'), + self, + "Migrate old user profile", + txt, + buttonsTexts=("Cancel", "Yes"), detailsText=detailsText, - widgets=doNotAskAgainCheckbox + widgets=doNotAskAgainCheckbox, ) if doNotAskAgainCheckbox.isChecked(): - self.df_settings.at['doNotAskMigrate', 'value'] = 'Yes' + self.df_settings.at["doNotAskMigrate", "value"] = "Yes" self.df_settings.to_csv(settings_csv_path) if msg.cancel: - self.logger.info( - 'Migrating old user profile cancelled.' - ) + self.logger.info("Migrating old user profile cancelled.") self.checkUserDataFolderPath = False return self.startMigrateUserProfileWorker( user_home_path, user_profile_path, user_home_acdc_folders ) self.checkUserDataFolderPath = False - + def showAllWindows(self): openModules = self.getOpenModules() for win in openModules: @@ -2168,32 +2224,32 @@ def show(self): super().show() h = self.dataPrepButton.geometry().height() f = 1.5 - self.dataStructButton.setMinimumHeight(int(h*f)) - self.dataPrepButton.setMinimumHeight(int(h*f)) - self.segmButton.setMinimumHeight(int(h*f)) - self.guiButton.setMinimumHeight(int(h*f)) - if hasattr(self, 'spotmaxButton'): - self.spotmaxButton.setMinimumHeight(int(h*f)) - self.showAllWindowsButton.setMinimumHeight(int(h*f)) - self.restartButton.setMinimumHeight(int(int(h*f))) - self.closeButton.setMinimumHeight(int(int(h*f))) + self.dataStructButton.setMinimumHeight(int(h * f)) + self.dataPrepButton.setMinimumHeight(int(h * f)) + self.segmButton.setMinimumHeight(int(h * f)) + self.guiButton.setMinimumHeight(int(h * f)) + if hasattr(self, "spotmaxButton"): + self.spotmaxButton.setMinimumHeight(int(h * f)) + self.showAllWindowsButton.setMinimumHeight(int(h * f)) + self.restartButton.setMinimumHeight(int(int(h * f))) + self.closeButton.setMinimumHeight(int(int(h * f))) # iconWidth = int(self.closeButton.iconSize().width()*1.3) # self.closeButton.setIconSize(QSize(iconWidth, iconWidth)) self.setColorsAndText() self.readSettings() if self.app.toggle_dark_mode: - self.darkModeToggle.warnMessageBox = False + self.darkModeToggle.warnMessageBox = False self.darkModeToggle.setChecked(True) def saveWindowGeometry(self): - settings = QSettings('schmollerlab', 'acdc_main') + settings = QSettings("schmollerlab", "acdc_main") settings.setValue("geometry", self.saveGeometry()) def readSettings(self): - settings = QSettings('schmollerlab', 'acdc_main') - if settings.value('geometry') is not None: + settings = QSettings("schmollerlab", "acdc_main") + if settings.value("geometry") is not None: self.restoreGeometry(settings.value("geometry")) - + def getOpenModules(self): c2 = self.segmButton.palette().button().color().name() launchedColor = self.moduleLaunchedColor @@ -2209,7 +2265,6 @@ def getOpenModules(self): openModules.extend(self.spotmaxWins) return openModules - def checkOpenModules(self): openModules = self.getOpenModules() @@ -2218,11 +2273,11 @@ def checkOpenModules(self): msg = widgets.myMessageBox() warn_txt = html_utils.paragraph( - 'There are still other Cell-ACDC windows open.

' - 'Are you sure you want to close everything?' + "There are still other Cell-ACDC windows open.

" + "Are you sure you want to close everything?" ) _, yesButton = msg.warning( - self, 'Modules still open!', warn_txt, buttonsTexts=('Cancel', 'Yes') + self, "Modules still open!", warn_txt, buttonsTexts=("Cancel", "Yes") ) return msg.clickedButton == yesButton, openModules @@ -2253,9 +2308,9 @@ def closeEvent(self, event): restart() except Exception as e: traceback.print_exc() - print('-----------------------------------------') - print('Failed to restart Cell-ACDC. Please restart manually') + print("-----------------------------------------") + print("Failed to restart Cell-ACDC. Please restart manually") else: - self.logger.info('**********************************************') - self.logger.info(f'Cell-ACDC closed. {myutils.get_salute_string()}') - self.logger.info('**********************************************') + self.logger.info("**********************************************") + self.logger.info(f"Cell-ACDC closed. {utils.get_salute_string()}") + self.logger.info("**********************************************") diff --git a/cellacdc/_palettes.py b/cellacdc/_palettes.py index 5de708312..e8ed7273e 100644 --- a/cellacdc/_palettes.py +++ b/cellacdc/_palettes.py @@ -7,134 +7,147 @@ if GUI_INSTALLED: from qtpy import QtGui + def _highlight_rgba(): scheme = get_color_scheme() - if scheme == 'light': + if scheme == "light": return (207, 235, 155, 255) else: return (141, 196, 39, 255) + def _highlighted_text(): return (0, 0, 0, 255) + def base_color(): scheme = get_color_scheme() - if scheme == 'light': - return '#4d4d4d' + if scheme == "light": + return "#4d4d4d" else: - return '#d9d9d9' + return "#d9d9d9" + def _light_colors(): colors = { - 'Window': (239, 239, 239, 255), - 'WindowText': (0, 0, 0, 255), - 'Base': (255, 255, 255, 255), - 'AlternateBase': (247, 247, 247, 255), - 'ToolTipBase': (255, 255, 220, 255), - 'ToolTipText': (0, 0, 0, 255), - 'Text': (0, 0, 0, 255), - 'Button': (239, 239, 239, 255), - 'ButtonText': (0, 0, 0, 255), - 'BrightText': (255, 255, 255, 255), - 'Link': (0, 0, 255, 255), - 'Highlight': _highlight_rgba(), - 'HighlightedText': _highlighted_text() + "Window": (239, 239, 239, 255), + "WindowText": (0, 0, 0, 255), + "Base": (255, 255, 255, 255), + "AlternateBase": (247, 247, 247, 255), + "ToolTipBase": (255, 255, 220, 255), + "ToolTipText": (0, 0, 0, 255), + "Text": (0, 0, 0, 255), + "Button": (239, 239, 239, 255), + "ButtonText": (0, 0, 0, 255), + "BrightText": (255, 255, 255, 255), + "Link": (0, 0, 255, 255), + "Highlight": _highlight_rgba(), + "HighlightedText": _highlighted_text(), } return colors + def _get_highligth_header_background_rgba(): scheme = get_color_scheme() - if scheme == 'light': - window_rgba = _light_colors()['Window'] - return tuple([val-40 for val in window_rgba]) + if scheme == "light": + window_rgba = _light_colors()["Window"] + return tuple([val - 40 for val in window_rgba]) else: - window_rgba = _dark_colors()['Window'] - return tuple([val+40 for val in window_rgba]) + window_rgba = _dark_colors()["Window"] + return tuple([val + 40 for val in window_rgba]) + def _get_highligth_text_background_rgba(): scheme = get_color_scheme() - if scheme == 'light': - window_rgba = _light_colors()['Window'] - return tuple([val-20 for val in window_rgba]) + if scheme == "light": + window_rgba = _light_colors()["Window"] + return tuple([val - 20 for val in window_rgba]) else: - window_rgba = _dark_colors()['Window'] - return tuple([val+20 for val in window_rgba]) + window_rgba = _dark_colors()["Window"] + return tuple([val + 20 for val in window_rgba]) + def text_float_rgba(): scheme = get_color_scheme() - if scheme == 'light': - text_rgba = _light_colors()['Text'] - return tuple([val/255 for val in text_rgba]) + if scheme == "light": + text_rgba = _light_colors()["Text"] + return tuple([val / 255 for val in text_rgba]) else: - text_rgba = _dark_colors()['Text'] - return tuple([val/255 for val in text_rgba]) + text_rgba = _dark_colors()["Text"] + return tuple([val / 255 for val in text_rgba]) + def get_disabled_colors(): scheme = get_color_scheme() - if scheme == 'light': + if scheme == "light": return _light_disabled_colors() else: return _dark_disabled_colors() + def _light_disabled_colors(): disabled_colors = { - 'ButtonText': (150, 150, 150, 255), - 'WindowText': (128, 128, 128, 255), - 'Text': (150, 150, 150, 255), - 'Light': (255, 255, 255, 255), - 'Button': (230, 230, 230, 255), + "ButtonText": (150, 150, 150, 255), + "WindowText": (128, 128, 128, 255), + "Text": (150, 150, 150, 255), + "Light": (255, 255, 255, 255), + "Button": (230, 230, 230, 255), # 'Window': (200, 200, 200, 255), # 'Highlight': (0, 0, 0, 255), # 'HighlightedText': (0, 0, 0, 255), - } return disabled_colors + def _dark_disabled_colors(): disabled_colors = { - 'ButtonText': (150, 150, 150, 255), - 'WindowText': (128, 128, 128, 255), - 'Text': (128, 128, 128, 255), - 'Light': (53, 53, 53, 255), - 'Button': (70, 70, 70, 255), + "ButtonText": (150, 150, 150, 255), + "WindowText": (128, 128, 128, 255), + "Text": (128, 128, 128, 255), + "Light": (53, 53, 53, 255), + "Button": (70, 70, 70, 255), # 'Window': (0, 0, 0, 255), } return disabled_colors + def text_pen_color(): scheme = get_color_scheme() - if scheme == 'light': - return '#4d4d4d' + if scheme == "light": + return "#4d4d4d" else: - return '#d9d9d9' + return "#d9d9d9" + def _dark_colors(): colors = { - 'Window': (50, 50, 50, 255), - 'WindowText': (240, 240, 240, 255), - 'Base': (36, 36, 36, 255), - 'AlternateBase': (43, 43, 43, 255), - 'ToolTipBase': (255, 255, 220, 255), - 'ToolTipText': (0, 0, 0, 255), - 'Text': (240, 240, 240, 255), - 'Button': (50, 50, 50, 255), - 'ButtonText': (240, 240, 240, 255), - 'BrightText': (75, 75, 75, 255), - 'Link': (48, 140, 198, 255), - 'Highlight': _highlight_rgba(), - 'HighlightedText': _highlighted_text() + "Window": (50, 50, 50, 255), + "WindowText": (240, 240, 240, 255), + "Base": (36, 36, 36, 255), + "AlternateBase": (43, 43, 43, 255), + "ToolTipBase": (255, 255, 220, 255), + "ToolTipText": (0, 0, 0, 255), + "Text": (240, 240, 240, 255), + "Button": (50, 50, 50, 255), + "ButtonText": (240, 240, 240, 255), + "BrightText": (75, 75, 75, 255), + "Link": (48, 140, 198, 255), + "Highlight": _highlight_rgba(), + "HighlightedText": _highlighted_text(), } return colors + def getPainterColor(): scheme = get_color_scheme() - if scheme == 'light': - return _light_colors()['Text'] + if scheme == "light": + return _light_colors()["Text"] else: - return _dark_colors()['Text'] + return _dark_colors()["Text"] -def getPaletteColorScheme(palette: 'QtGui.QPalette', scheme='light'): - if scheme == 'light': + +def getPaletteColorScheme(palette: "QtGui.QPalette", scheme="light"): + if scheme == "light": colors = _light_colors() disabled_colors = _light_disabled_colors() else: @@ -149,96 +162,107 @@ def getPaletteColorScheme(palette: 'QtGui.QPalette', scheme='light'): palette.setColor(ColorGroup, colorRole, QtGui.QColor(*rgba)) return palette + def get_color_scheme(): if not os.path.exists(settings_csv_path): - return 'light' - df_settings = pd.read_csv(settings_csv_path, index_col='setting') - if 'colorScheme' not in df_settings.index: - return 'light' + return "light" + df_settings = pd.read_csv(settings_csv_path, index_col="setting") + if "colorScheme" not in df_settings.index: + return "light" else: - return df_settings.at['colorScheme', 'value'] - + return df_settings.at["colorScheme", "value"] + + def lineedit_background_hex(): scheme = get_color_scheme() - if scheme == 'light': - return r'{background:#ffffff;}' + if scheme == "light": + return r"{background:#ffffff;}" else: - return r'{background:#242424;}' + return r"{background:#242424;}" + def lineedit_invalid_entry_stylesheet(): return ( # 'background: #FEF9C3;' - 'border-radius: 4px;' - 'border: 1.5px solid red;' - 'padding: 1px 0px 1px 0px' + "border-radius: 4px;border: 1.5px solid red;padding: 1px 0px 1px 0px" ) -def lineedit_warning_stylesheet(): + +def lineedit_warning_stylesheet(): scheme = get_color_scheme() - if scheme == 'light': - stylesheet = 'background: #FEF9C3;' + if scheme == "light": + stylesheet = "background: #FEF9C3;" else: - stylesheet = 'background: #FEF9C3; color: black' + stylesheet = "background: #FEF9C3; color: black" return stylesheet -def setToolTipStyleSheet(app, scheme='light'): - if scheme == 'dark': - app.setStyleSheet(r"QToolTip {" + +def setToolTipStyleSheet(app, scheme="light"): + if scheme == "dark": + app.setStyleSheet( + r"QToolTip {" "color: #e6e6e6; background-color: #3c3c3c; border: 1px solid white;" - "}" + "}" ) else: - app.setStyleSheet(r"QToolTip {" + app.setStyleSheet( + r"QToolTip {" "color: #141414; background-color: #ffffff; border: 1px solid black;" - "}" + "}" ) + def green(): scheme = get_color_scheme() - if scheme == 'light': - return '#CFEB9B' + if scheme == "light": + return "#CFEB9B" else: - return '#607a2f' + return "#607a2f" + def TreeWidgetStyleSheet(): scheme = get_color_scheme() - if scheme == 'light': - styleSheet = (""" + if scheme == "light": + styleSheet = """ QTreeWidget::item:hover {background-color:#E6E6E6; color:black;} QTreeWidget::item:selected {background-color:#CFEB9B; color:black;} QTreeView { selection-background-color: #CFEB9B; show-decoration-selected: 1; } - """) + """ else: - styleSheet = (""" + styleSheet = """ QTreeWidget::item:hover {background-color:#E6E6E6; color:black;} QTreeWidget::item:selected {background-color:#8dc427; color:black;} QTreeView { selection-background-color: #8dc427; show-decoration-selected: 1; } - """) + """ return styleSheet + def ListWidgetStyleSheet(): styleSheet = TreeWidgetStyleSheet() - styleSheet = styleSheet.replace('QTreeWidget', 'QListWidget') - styleSheet = styleSheet.replace('QTreeView', 'QListView') + styleSheet = styleSheet.replace("QTreeWidget", "QListWidget") + styleSheet = styleSheet.replace("QTreeView", "QListView") return styleSheet + def QProgressBarColor(): styleSheet = TreeWidgetStyleSheet() - hex = re.findall(r'selection-background-color: (#[A-Za-z0-9]+)', styleSheet)[0] - return QtGui.QColor(hex) + hex = re.findall(r"selection-background-color: (#[A-Za-z0-9]+)", styleSheet)[0] + return QtGui.QColor(hex) + def QProgressBarHighlightedTextColor(): return QtGui.QColor(0, 0, 0, 255) + def moduleLaunchedButtonRgb(self): scheme = get_color_scheme() - if scheme == 'light': - return (241,221,0) + if scheme == "light": + return (241, 221, 0) else: - return (241,221,0) \ No newline at end of file + return (241, 221, 0) diff --git a/cellacdc/_process.py b/cellacdc/_process.py index 4b8f11e24..027f6f4cd 100644 --- a/cellacdc/_process.py +++ b/cellacdc/_process.py @@ -7,29 +7,37 @@ import argparse ap = argparse.ArgumentParser( - prog='Cell-ACDC process', description='Used to spawn a separate process', - formatter_class=argparse.RawTextHelpFormatter + prog="Cell-ACDC process", + description="Used to spawn a separate process", + formatter_class=argparse.RawTextHelpFormatter, ) ap.add_argument( - '-c', '--command', required=True, type=str, metavar='COMMAND', - help='String of commands separated by comma.' + "-c", + "--command", + required=True, + type=str, + metavar="COMMAND", + help="String of commands separated by comma.", ) ap.add_argument( - '-l', '--log_filepath', - default='', + "-l", + "--log_filepath", + default="", type=str, - metavar='LOG_FILEPATH', - help=('Path of an additional log file') + metavar="LOG_FILEPATH", + help=("Path of an additional log file"), ) -def worker(*commands): - subprocess.run(list(commands)) # [sys.executable, r'spotmax\test.py']) -if __name__ == '__main__': +def worker(*commands): + subprocess.run(list(commands)) # [sys.executable, r'spotmax\test.py']) + + +if __name__ == "__main__": args = vars(ap.parse_args()) - command = args['command'] - commands = command.split(',') + command = args["command"] + commands = command.split(",") commands = [command.lstrip() for command in commands] process = multiprocessing.Process(target=worker, args=commands) process.start() diff --git a/cellacdc/_profile/spline_to_obj/model.py b/cellacdc/_profile/spline_to_obj/model.py index 6583cba24..7a91a759b 100644 --- a/cellacdc/_profile/spline_to_obj/model.py +++ b/cellacdc/_profile/spline_to_obj/model.py @@ -18,40 +18,35 @@ pwd_path = os.path.dirname(os.path.abspath(__file__)) + class Model: def __init__(self): pass def fit(self): # Read data - filename = '1_exec_time_space_size_step_10.csv' + filename = "1_exec_time_space_size_step_10.csv" df_path = os.path.join(pwd_path, filename) df = pd.read_csv(df_path) - + # Define predictor and response variables - X_train = df[['bbox_area', 'exec_time']] - y_train = df['space_size'] + X_train = df[["bbox_area", "exec_time"]] + y_train = df["space_size"] # Scale the data scaler = StandardScaler().fit(X_train) - X_scaled = pd.DataFrame( - scaler.transform(X_train), columns=X_train.columns - ) + X_scaled = pd.DataFrame(scaler.transform(X_train), columns=X_train.columns) self.scaler = scaler # Define regression model and fit - model = tree.DecisionTreeRegressor() # LinearRegression() + model = tree.DecisionTreeRegressor() # LinearRegression() reg = model.fit(X_scaled, y_train) self.model = model def predict(self, bbox_area, max_exec_time=150): - X_pred = pd.DataFrame({ - 'bbox_area': [bbox_area], 'exec_time': [max_exec_time] - }) - X_scaled = pd.DataFrame( - self.scaler.transform(X_pred), columns=X_pred.columns - ) + X_pred = pd.DataFrame({"bbox_area": [bbox_area], "exec_time": [max_exec_time]}) + X_scaled = pd.DataFrame(self.scaler.transform(X_pred), columns=X_pred.columns) y_pred = self.model.predict(X_scaled) pred_space_size = y_pred[0] return pred_space_size diff --git a/cellacdc/_profile/spline_to_obj/profile_skimage_draw_polygon.py b/cellacdc/_profile/spline_to_obj/profile_skimage_draw_polygon.py index 44fbd6559..00d83a5d7 100644 --- a/cellacdc/_profile/spline_to_obj/profile_skimage_draw_polygon.py +++ b/cellacdc/_profile/spline_to_obj/profile_skimage_draw_polygon.py @@ -12,7 +12,7 @@ pwd_path = os.path.dirname(os.path.abspath(__file__)) -img = np.zeros((1000,1000), dtype=np.uint8) +img = np.zeros((1000, 1000), dtype=np.uint8) dfs = [] keys = [] @@ -20,25 +20,27 @@ space_size_step = 10 square_side_range_min, square_side_range_max = 10, 600 -for space_size in tqdm(np.arange(10,1001,10), ncols=100): +for space_size in tqdm(np.arange(10, 1001, 10), ncols=100): bbox_areas = [] exec_times = [] space = np.linspace(0, 1, space_size) - for side in tqdm(np.arange(square_side_range_min,square_side_range_min+1,2), ncols=100): + for side in tqdm( + np.arange(square_side_range_min, square_side_range_min + 1, 2), ncols=100 + ): img[:] = 0 - half_side = int(side/2) + half_side = int(side / 2) - left = 500-half_side - right = 500+half_side - - anchors_xx = [left,right,right,left,left] - anchors_yy = [left,left,right,right,left] + left = 500 - half_side + right = 500 + half_side + + anchors_xx = [left, right, right, left, left] + anchors_yy = [left, left, right, right, left] bbox_area = side**2 bbox_areas.append(bbox_area) - tck, u = scipy.interpolate.splprep( + tck, u = scipy.interpolate.splprep( [anchors_xx, anchors_yy], s=0, k=3, per=False ) xi, yi = scipy.interpolate.splev(space, tck) @@ -47,7 +49,7 @@ rr, cc = skimage.draw.polygon(yi, xi, shape=img.shape) t1 = time.perf_counter() - exec_times.append((t1-t0)*1000) + exec_times.append((t1 - t0) * 1000) img[rr, cc] = 2 @@ -57,19 +59,21 @@ img[rr, cc] = 1 - df = pd.DataFrame({'bbox_area': bbox_areas, 'exec_time': exec_times}).set_index('bbox_area') + df = pd.DataFrame({"bbox_area": bbox_areas, "exec_time": exec_times}).set_index( + "bbox_area" + ) dfs.append(df) keys.append(space_size) -final_df = pd.concat(dfs, keys=keys, names=['space_size', 'bbox_area']) +final_df = pd.concat(dfs, keys=keys, names=["space_size", "bbox_area"]) df_filename = ( - f'side_range_{square_side_range_min}-{square_side_range_max}_' - f'space_size_step_{space_size_step}.csv' + f"side_range_{square_side_range_min}-{square_side_range_max}_" + f"space_size_step_{space_size_step}.csv" ) final_df.to_csv(os.path.join(pwd_path, df_filename)) -plt.plot(xi, yi, c='r') +plt.plot(xi, yi, c="r") plt.imshow(img) -plt.show() \ No newline at end of file +plt.show() diff --git a/cellacdc/_run.py b/cellacdc/_run.py index 3e615cf9c..78f0aa84f 100644 --- a/cellacdc/_run.py +++ b/cellacdc/_run.py @@ -4,51 +4,54 @@ from importlib import import_module import traceback from tqdm import tqdm -from . import config, myutils +from . import config, utils -def _install_tables(parent_software='Cell-ACDC'): + +def _install_tables(parent_software="Cell-ACDC"): from . import try_input_install_package, is_conda_env + try: import tables + return False except Exception as e: - if parent_software == 'Cell-ACDC': - issues_url = 'https://github.com/SchmollerLab/Cell_ACDC/issues' + if parent_software == "Cell-ACDC": + issues_url = "https://github.com/SchmollerLab/Cell_ACDC/issues" note_txt = ( - 'If the installation fails, you can still use Cell-ACDC, but we ' - 'highly recommend you report the issue (see link below) and we ' - 'will be very happy to help. Thank you for your patience!' + "If the installation fails, you can still use Cell-ACDC, but we " + "highly recommend you report the issue (see link below) and we " + "will be very happy to help. Thank you for your patience!" ) else: - issues_url = 'https://github.com/SchmollerLab/Cell_ACDC/issues' + issues_url = "https://github.com/SchmollerLab/Cell_ACDC/issues" note_txt = ( - 'If the installation fails, report the issue (see link below) and we ' - 'will be very happy to help. Thank you for your patience!' + "If the installation fails, report the issue (see link below) and we " + "will be very happy to help. Thank you for your patience!" ) while True: txt = ( - f'{parent_software} needs to install a library called `tables`.\n\n' - f'{note_txt}\n\n' - f'Report issue here: {issues_url}\n' + f"{parent_software} needs to install a library called `tables`.\n\n" + f"{note_txt}\n\n" + f"Report issue here: {issues_url}\n" ) - print('-'*60) + print("-" * 60) print(txt) - conda_prefix, pip_prefix = myutils.get_pip_conda_prefix() - conda_list, pip_list = myutils.get_pip_conda_prefix(list_return=True) + conda_prefix, pip_prefix = utils.get_pip_conda_prefix() + conda_list, pip_list = utils.get_pip_conda_prefix(list_return=True) - conda_txt = f'{conda_prefix} pytables' - pip_text = f'{pip_prefix} --upgrade tables' + conda_txt = f"{conda_prefix} pytables" + pip_text = f"{pip_prefix} --upgrade tables" - conda_list = conda_list + ['pytables'] - pip_list = pip_list + ['--upgrade', 'tables'] + conda_list = conda_list + ["pytables"] + pip_list = pip_list + ["--upgrade", "tables"] if is_conda_env(): command_txt = conda_txt alt_command_txt = pip_text cmd_args = [command_txt] alt_cmd_args1 = conda_list alt_cmd_args2 = pip_list - pkg_mng = 'conda' - alt_pkg_mng = 'pip' + pkg_mng = "conda" + alt_pkg_mng = "pip" shell = True alt_shell = False else: @@ -57,63 +60,65 @@ def _install_tables(parent_software='Cell-ACDC'): cmd_args = pip_list alt_cmd_args1 = conda_list alt_cmd_args2 = [alt_command_txt] - pkg_mng = 'pip' - alt_pkg_mng = 'conda' + pkg_mng = "pip" + alt_pkg_mng = "conda" shell = False alt_shell = True - - answer = try_input_install_package('tables', command_txt) - - if answer.lower() == 'y' or not answer: + + answer = try_input_install_package("tables", command_txt) + + if answer.lower() == "y" or not answer: import subprocess, traceback + try: subprocess.check_call(cmd_args, shell=shell) break except Exception as err: traceback.print_exc() - print('-'*100) + print("-" * 100) print( - f'[WARNING]: Installation with command `{cmd_args}` ' - f'failed. Trying with `{alt_cmd_args1}`...' + f"[WARNING]: Installation with command `{cmd_args}` " + f"failed. Trying with `{alt_cmd_args1}`..." ) - print('-'*100) - + print("-" * 100) + try: subprocess.check_call(alt_cmd_args1, shell=shell) break except Exception as err: traceback.print_exc() - print('-'*100) + print("-" * 100) print( - f'[WARNING]: Installation of `tables` with ' - f'{pkg_mng} failed. Trying with {alt_pkg_mng}...' + f"[WARNING]: Installation of `tables` with " + f"{pkg_mng} failed. Trying with {alt_pkg_mng}..." ) - print('-'*100) - + print("-" * 100) + # import pdb; pdb.set_trace() try: subprocess.check_call(alt_cmd_args2, shell=alt_shell) break except Exception as err: import traceback + traceback.print_exc() - print('*'*60) - if parent_software == 'Cell-ACDC': - msg_type = '[WARNING]' + print("*" * 60) + if parent_software == "Cell-ACDC": + msg_type = "[WARNING]" log_func = print else: - msg_type = '[ERROR]' + msg_type = "[ERROR]" log_func = exit - + log_func( - f'{msg_type}: Installation of `tables` failed. ' - 'Please report the issue here (**including the error ' - f'message above**): {issues_url}' + f"{msg_type}: Installation of `tables` failed. " + "Please report the issue here (**including the error " + f"message above**): {issues_url}" ) - print('^'*60) + print("^" * 60) finally: break - elif answer.lower() == 'n': + elif answer.lower() == "n": raise e else: print( @@ -123,45 +128,51 @@ def _install_tables(parent_software='Cell-ACDC'): return True + def _setup_symlink_app_name_macos(): - """On Mac generate a symlink from the Python path defined in the shebang - of the `acdc` binary called Cell-ACDC and modify the shebang to run - the acdc binary from the symlink. This will correctly display Cell-ACDC + """On Mac generate a symlink from the Python path defined in the shebang + of the `acdc` binary called Cell-ACDC and modify the shebang to run + the acdc binary from the symlink. This will correctly display Cell-ACDC in the menubar instead of Python. - """ + """ from . import is_mac, printl + if not is_mac: return - + import subprocess + acdc_binary_path = os.path.dirname(sys.executable) - symlink = os.path.join(acdc_binary_path, 'Cell-ACDC') + symlink = os.path.join(acdc_binary_path, "Cell-ACDC") if os.path.exists(symlink): return - - for acdc_exec_name in ('acdc', 'cellacdc'): + + for acdc_exec_name in ("acdc", "cellacdc"): acdc_exec_path = os.path.join(acdc_binary_path, acdc_exec_name) try: - with open(acdc_exec_path, 'r') as bin: + with open(acdc_exec_path, "r") as bin: acdc_exec_text = bin.read() - shebang = acdc_exec_text.split('\n')[0][2:] + shebang = acdc_exec_text.split("\n")[0][2:] if not os.path.exists(symlink): - command = f'ln -s {shebang} {symlink}' + command = f"ln -s {shebang} {symlink}" subprocess.check_call(command, shell=True) acdc_exec_text = acdc_exec_text.replace(shebang, symlink) - with open(acdc_exec_path, 'w') as bin: + with open(acdc_exec_path, "w") as bin: bin.write(acdc_exec_text) except Exception as err: printl(traceback.format_exc()) - print('[WARNING]: Failed at creating Cell-ACDC symlink') + print("[WARNING]: Failed at creating Cell-ACDC symlink") + -def _setup_gui_libraries(caller_name='Cell-ACDC', exit_at_end=True): +def _setup_gui_libraries(caller_name="Cell-ACDC", exit_at_end=True): from . import try_input_install_package, is_conda_env + warn_restart = False - + # Force PyQt6 if available try: from PyQt6 import QtCore + os.environ["QT_API"] = "pyqt6" except Exception as e: pass @@ -169,32 +180,34 @@ def _setup_gui_libraries(caller_name='Cell-ACDC', exit_at_end=True): try: import qtpy except ModuleNotFoundError as e: - conda_prefix, pip_prefix = myutils.get_pip_conda_prefix() - conda_list, pip_list = myutils.get_pip_conda_prefix(list_return=True) - - command_txt = f'{pip_prefix} --upgrade qtpy' - + conda_prefix, pip_prefix = utils.get_pip_conda_prefix() + conda_list, pip_list = utils.get_pip_conda_prefix(list_return=True) + + command_txt = f"{pip_prefix} --upgrade qtpy" + txt = ( - f'{caller_name} needs to install the package `qtpy`.\n\n' - f'You can let {caller_name} install it now, or you can abort ' - f'and install it manually with the command `{command_txt}`.' + f"{caller_name} needs to install the package `qtpy`.\n\n" + f"You can let {caller_name} install it now, or you can abort " + f"and install it manually with the command `{command_txt}`." ) - print('-'*60) + print("-" * 60) print(txt) while True: from .config import parser_args - if parser_args['yes']: - answer = 'y' + + if parser_args["yes"]: + answer = "y" else: - answer = try_input_install_package('qtpy', command_txt) - - if answer.lower() == 'y' or not answer: + answer = try_input_install_package("qtpy", command_txt) + + if answer.lower() == "y" or not answer: import subprocess - cmd = pip_list + ['-U', 'qtpy'] + + cmd = pip_list + ["-U", "qtpy"] subprocess.check_call(cmd) break - elif answer.lower() == 'n': + elif answer.lower() == "n": raise e else: print( @@ -202,508 +215,585 @@ def _setup_gui_libraries(caller_name='Cell-ACDC', exit_at_end=True): 'Type "y" for "yes", or "n" for "no".' ) except ImportError as e: - # Ignore that qtpy is installed but there is no PyQt bindings --> this + # Ignore that qtpy is installed but there is no PyQt bindings --> this # is handled in the next block pass - + from . import is_mac_arm64 - default_qt = 'PyQt5' if is_mac_arm64 else 'PyQt6' - - try: # no need to handle no_cli, acdc is run with -y flag + + default_qt = "PyQt5" if is_mac_arm64 else "PyQt6" + + try: # no need to handle no_cli, acdc is run with -y flag from qtpy.QtCore import Qt except Exception as e: traceback.print_exc() txt = ( - f'{caller_name} needs to install a GUI library (default library is ' - f'`{default_qt}`).\n\n' + f"{caller_name} needs to install a GUI library (default library is " + f"`{default_qt}`).\n\n" 'You can install it now or you can close (press "n") and install\n' - 'a compatible GUI library with one of ' - 'the following commands:\n\n' - f' * {pip_prefix} PyQt6==6.6.0 PyQt6-Qt6==6.6.0\n' - f' * {pip_prefix} PyQt5 (or `conda install pyqt`)\n' - f' * {pip_prefix} PySide2\n' - f' * {pip_prefix} PySide6\n\n' - f'Note: If `{default_qt}` installation fails, you could try installing any ' - 'of the other libraries.\n' + "a compatible GUI library with one of " + "the following commands:\n\n" + f" * {pip_prefix} PyQt6==6.6.0 PyQt6-Qt6==6.6.0\n" + f" * {pip_prefix} PyQt5 (or `conda install pyqt`)\n" + f" * {pip_prefix} PySide2\n" + f" * {pip_prefix} PySide6\n\n" + f"Note: If `{default_qt}` installation fails, you could try installing any " + "of the other libraries.\n" ) - print('-'*60) + print("-" * 60) print(txt) - pip_command = f'{pip_prefix} -U PyQt6==6.6.0 PyQt6-Qt6==6.6.0' + pip_command = f"{pip_prefix} -U PyQt6==6.6.0 PyQt6-Qt6==6.6.0" if is_mac_arm64: - commnad_txt = f'{conda_prefix} pyqt' - pkg_name = 'pyqt' + commnad_txt = f"{conda_prefix} pyqt" + pkg_name = "pyqt" else: commnad_txt = pip_command - pkg_name = 'PyQt6' + pkg_name = "PyQt6" while True: from .config import parser_args - if parser_args['yes']: - answer = 'y' + + if parser_args["yes"]: + answer = "y" else: answer = try_input_install_package(pkg_name, commnad_txt) - if answer.lower() == 'y' or not answer: + if answer.lower() == "y" or not answer: import subprocess + if is_mac_arm64 and is_conda_env(): - subprocess.check_call( - [f'{conda_prefix} pyqt'], shell=True - ) + subprocess.check_call([f"{conda_prefix} pyqt"], shell=True) else: - pip_args = pip_list + ['-U', 'PyQt6==6.6.0', 'PyQt6-Qt6==6.6.0'] + pip_args = pip_list + ["-U", "PyQt6==6.6.0", "PyQt6-Qt6==6.6.0"] subprocess.check_call(pip_args) warn_restart = True break - elif answer.lower() == 'n': + elif answer.lower() == "n": raise e else: print( f'"{answer}" is not a valid answer. ' 'Type "y" for "yes", or "n" for "no".' ) - + try: import pyqtgraph - version = pyqtgraph.__version__.split('.') + + version = pyqtgraph.__version__.split(".") pg_major, pg_minor, pg_patch = [int(val) for val in version] # if pg_major < 1: # raise ModuleNotFoundError('pyqtgraph must be upgraded') if pg_minor < 13: - raise ModuleNotFoundError('pyqtgraph must be upgraded') + raise ModuleNotFoundError("pyqtgraph must be upgraded") if pg_minor == 13 and pg_patch < 7: - raise ModuleNotFoundError('pyqtgraph must be upgraded') + raise ModuleNotFoundError("pyqtgraph must be upgraded") except ModuleNotFoundError: import subprocess + subprocess.check_call( - [sys.executable, '-m', 'pip', 'install', '-U', 'pyqtgraph'] + [sys.executable, "-m", "pip", "install", "-U", "pyqtgraph"] ) warn_restart = True - + try: import seaborn except ModuleNotFoundError: import subprocess - subprocess.check_call( - [sys.executable, '-m', 'pip', 'install', '-U', 'seaborn'] - ) + + subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "seaborn"]) warn_restart = True - + if not warn_restart: return warn_restart - + if not exit_at_end: return warn_restart - + _exit_on_setup(caller_name=caller_name) - + return warn_restart -def _exit_on_setup(caller_name='Cell-ACDC'): - print('*'*60) - note_text = ( - f'[NOTE]: {caller_name} had to install the required libraries. ' - ) + +def _exit_on_setup(caller_name="Cell-ACDC"): + print("*" * 60) + note_text = f"[NOTE]: {caller_name} had to install the required libraries. " note_text = ( - f'{note_text}' - 'Please, re-start the software. Thank you for your patience! ' + f"{note_text}Please, re-start the software. Thank you for your patience! " ) - + from .config import parser_args - if parser_args['yes']: + + if parser_args["yes"]: print(note_text) else: - note_text = ( - f'{note_text}' - '(Press any key to exit). ' - ) + note_text = f"{note_text}(Press any key to exit). " input(note_text) - + exit() - + + def download_model_params(): print("Downloading specified models...") from .config import parser_args - if parser_args['cpModelsDownload'] or parser_args['AllModelsDownload']: - print('[INFO]: Downloading Cellpose models...') + + if parser_args["cpModelsDownload"] or parser_args["AllModelsDownload"]: + print("[INFO]: Downloading Cellpose models...") from cellpose import models + model_names = ["cyto", "cyto2", "cyto3", "nuclei"] try: # download size model weights from cellpose.models import size_model_path, model_path + for model_name in model_names: - print(f'[INFO]: Downloading {model_name} model weights...') + print(f"[INFO]: Downloading {model_name} model weights...") try: size_model_path(model_name) model_path(model_name) except Exception as e: - print( - f'[WARNING]: Failed to download {model_name} model weights. ' - ) + print(f"[WARNING]: Failed to download {model_name} model weights. ") print(e) pass - + from cellpose.denoise import MODEL_NAMES + for model_name in MODEL_NAMES: - print(f'[INFO]: Downloading {model_name} model weights...') + print(f"[INFO]: Downloading {model_name} model weights...") try: model_path(model_name) except Exception as e: - print( - f'[WARNING]: Failed to download {model_name} model weights. ' - ) - if model_name in ["oneclick_per_cyto2", - "oneclick_seg_cyto2", - "oneclick_rec_cyto2", - "oneclick_per_nuclei", - "oneclick_seg_nuclei", - "oneclick_rec_nuclei"]: - print(f' This model is not available for download. ') + print(f"[WARNING]: Failed to download {model_name} model weights. ") + if model_name in [ + "oneclick_per_cyto2", + "oneclick_seg_cyto2", + "oneclick_rec_cyto2", + "oneclick_per_nuclei", + "oneclick_seg_nuclei", + "oneclick_rec_nuclei", + ]: + print(f" This model is not available for download. ") print(e) pass except Exception as e: - print( - '[WARNING]: Failed to download Cellpose model weights. ' - ) + print("[WARNING]: Failed to download Cellpose model weights. ") print(e) pass - if parser_args['StarDistModelsDownload'] or parser_args['AllModelsDownload']: - print('[INFO]: Downloading StarDist models...') + if parser_args["StarDistModelsDownload"] or parser_args["AllModelsDownload"]: + print("[INFO]: Downloading StarDist models...") try: - from cellacdc.models import STARDIST_MODELS + from cellacdc.segmenters import STARDIST_MODELS from stardist.models import StarDist2D, StarDist3D + for model_type in [StarDist2D, StarDist3D]: for model_name in STARDIST_MODELS: - print(f'[INFO]: Downloading {model_name} model weights...') + print(f"[INFO]: Downloading {model_name} model weights...") try: model_type.from_pretrained(model_name) except Exception as e: print( - f'[WARNING]: Failed to download {model_name} model weights. ' + f"[WARNING]: Failed to download {model_name} model weights. " ) print(e) pass except Exception as e: - print( - '[WARNING]: Failed to download StarDist model weights. ' - ) + print("[WARNING]: Failed to download StarDist model weights. ") print(e) pass - if parser_args['YeaZModelsDownload'] or parser_args['AllModelsDownload']: - print('[INFO]: Downloading YeaZ models...') - from cellacdc.myutils import _download_yeaz_models + if parser_args["YeaZModelsDownload"] or parser_args["AllModelsDownload"]: + print("[INFO]: Downloading YeaZ models...") + from cellacdc.utils import _download_yeaz_models + try: _download_yeaz_models() except Exception as e: - print( - '[WARNING]: Failed to download YeaZ model weights. ' - ) + print("[WARNING]: Failed to download YeaZ model weights. ") print(e) pass - if parser_args['DeepSeaModelsDownload'] or parser_args['AllModelsDownload']: - print('[INFO]: Downloading DeepSea models...') - from cellacdc.myutils import _download_deepsea_models + if parser_args["DeepSeaModelsDownload"] or parser_args["AllModelsDownload"]: + print("[INFO]: Downloading DeepSea models...") + from cellacdc.utils import _download_deepsea_models + try: _download_deepsea_models() except Exception as e: - print( - '[WARNING]: Failed to download DeepSea model weights. ' - ) + print("[WARNING]: Failed to download DeepSea model weights. ") print(e) pass - if parser_args['TrackastraModelsDownload'] or parser_args['AllModelsDownload']: - print('[INFO]: Downloading TrackAstra models...') - # from cellacdc.myutils import _download_trackastra_models + if parser_args["TrackastraModelsDownload"] or parser_args["AllModelsDownload"]: + print("[INFO]: Downloading TrackAstra models...") + # from cellacdc.utils import _download_trackastra_models from trackastra.model import Trackastra + try: from cellacdc.trackers.Trackastra import get_pretrained_model_names + model_names = get_pretrained_model_names() for model_name in model_names: - print(f'[INFO]: Downloading {model_name} model weights...') + print(f"[INFO]: Downloading {model_name} model weights...") try: Trackastra.from_pretrained(model_name) except Exception as e: - print( - f'[WARNING]: Failed to download {model_name} model weights. ' - ) + print(f"[WARNING]: Failed to download {model_name} model weights. ") print(e) pass except Exception as e: - print( - '[WARNING]: Failed to download TrackAstra model weights. ' - ) + print("[WARNING]: Failed to download TrackAstra model weights. ") print(e) pass - + + +def setup_gui_runtime(*, splashscreen=False): + """Shared Qt/pyqtgraph/model-download setup for CLI and script API.""" + _setup_symlink_app_name_macos() + + requires_exit = _setup_gui_libraries(exit_at_end=False) + + _setup_numpy() + + download_model_params() + + if requires_exit: + _exit_on_setup() + + from qtpy import QtWidgets, QtCore + + if os.name == "nt": + try: + import ctypes + + myappid = "schmollerlab.cellacdc.pyqt.v1" + ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(myappid) + except Exception: + pass + + try: + QtWidgets.QApplication.setAttribute( + QtCore.Qt.HighDpiScaleFactorRoundingPolicy.PassThrough + ) + except Exception: + pass + + import pyqtgraph as pg + + pg.setConfigOption("imageAxisOrder", "row-major") + try: + import numba # noqa: F401 + + pg.setConfigOption("useNumba", True) + except Exception: + pass + + try: + import cupy # noqa: F401 + + pg.setConfigOption("useCupy", True) + except Exception: + pass + + return _setup_app(splashscreen=splashscreen) + + def _setup_app(splashscreen=False, icon_path=None, logo_path=None, scheme=None): from qtpy import QtCore + if QtCore.QCoreApplication.instance() is not None: return QtCore.QCoreApplication.instance(), None - + from qtpy import QtWidgets + # Handle high resolution displays: - if hasattr(QtCore.Qt, 'AA_EnableHighDpiScaling'): + if hasattr(QtCore.Qt, "AA_EnableHighDpiScaling"): QtWidgets.QApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling, True) - if hasattr(QtCore.Qt, 'AA_UseHighDpiPixmaps'): + if hasattr(QtCore.Qt, "AA_UseHighDpiPixmaps"): QtWidgets.QApplication.setAttribute(QtCore.Qt.AA_UseHighDpiPixmaps, True) - if hasattr(QtCore.Qt, 'AA_PluginApplication'): + if hasattr(QtCore.Qt, "AA_PluginApplication"): QtWidgets.QApplication.setAttribute(QtCore.Qt.AA_PluginApplication, False) # Check OS dark or light mode from qtpy.QtWidgets import QApplication, QStyleFactory from qtpy.QtGui import QPalette, QIcon from . import settings_csv_path, resources_folderpath, is_linux - - app = QApplication(['Cell-ACDC']) + + app = QApplication(["Cell-ACDC"]) app.setApplicationName("Cell-ACDC") - app.setStyle(QStyleFactory.create('Fusion')) + app.setStyle(QStyleFactory.create("Fusion")) is_OS_dark_mode = app.palette().color(QPalette.Window).getHsl()[2] < 100 app.toggle_dark_mode = False if is_OS_dark_mode: - # Switch to dark mode if scheme was never selected by user and OS is + # Switch to dark mode if scheme was never selected by user and OS is # dark mode import pandas as pd - df_settings = pd.read_csv(settings_csv_path, index_col='setting') - if 'colorScheme' not in df_settings.index: + + df_settings = pd.read_csv(settings_csv_path, index_col="setting") + if "colorScheme" not in df_settings.index: app.toggle_dark_mode = True - + if icon_path is None: - icon_path = os.path.join(resources_folderpath, 'icon_v2.ico') + icon_path = os.path.join(resources_folderpath, "icon_v2.ico") app.setWindowIcon(QIcon(icon_path)) if is_linux: app.setDesktopFileName("cell-acdc") - + if logo_path is None: - logo_path = os.path.join(resources_folderpath, 'logo_v2.png') - + logo_path = os.path.join(resources_folderpath, "logo_v2.png") + from qtpy import QtWidgets, QtGui splashScreen = None if splashscreen: + class SplashScreen(QtWidgets.QSplashScreen): def __init__(self, logo_path, icon_path): super().__init__() pixmap = QtGui.QPixmap(logo_path) - pixmap = pixmap.scaledToWidth( - 300, QtCore.Qt.SmoothTransformation - ) + pixmap = pixmap.scaledToWidth(300, QtCore.Qt.SmoothTransformation) self.setPixmap(pixmap) self.setWindowIcon(QIcon(icon_path)) self.setWindowFlags( - QtCore.Qt.WindowStaysOnTopHint - | QtCore.Qt.SplashScreen + QtCore.Qt.WindowStaysOnTopHint + | QtCore.Qt.SplashScreen | QtCore.Qt.FramelessWindowHint ) - + def mousePressEvent(self, a0: QtGui.QMouseEvent) -> None: pass - + def showEvent(self, event): self.raise_() - + # Launch splashscreen splashScreen = SplashScreen(logo_path, icon_path) - splashScreen.show() - + splashScreen.show() + from ._palettes import getPaletteColorScheme, setToolTipStyleSheet from ._palettes import get_color_scheme from . import qrc_resources_path from . import acdc_qrc_resources from . import printl - + # Check if there are new icons --> replace qrc_resources.py if scheme is None: scheme = get_color_scheme() - if scheme == 'light': + if scheme == "light": from . import qrc_resources_light_path as qrc_resources_scheme_path - qrc_resources_scheme = import_module('cellacdc.qrc_resources_light') + + qrc_resources_scheme = import_module("cellacdc.qrc_resources_light") qrc_resource_data_scheme = qrc_resources_scheme.qt_resource_data else: from . import qrc_resources_dark_path as qrc_resources_scheme_path - qrc_resources_scheme = import_module('cellacdc.qrc_resources_dark') + + qrc_resources_scheme = import_module("cellacdc.qrc_resources_dark") qrc_resource_data_scheme = qrc_resources_scheme.qt_resource_data - + qrc_resource_version_required = 1 try: qrc_resource_version_required = qrc_resources_scheme.version except Exception as err: pass - + current_qrc_resource_version = 1 try: current_qrc_resource_version = acdc_qrc_resources.version except Exception as err: pass - + is_copy_qrc_required = ( qrc_resource_data_scheme != acdc_qrc_resources.qt_resource_data or qrc_resource_version_required != current_qrc_resource_version ) - + if is_copy_qrc_required: from . import _copy_qrc_resources_file, _warnings, qrc_resources_path - _copy_qrc_resources_file( - qrc_resources_scheme_path, qrc_resources_path - ) + + _copy_qrc_resources_file(qrc_resources_scheme_path, qrc_resources_path) _warnings.warnRestartAcdcIconsUpdated() exit() - + from . import load + scheme = get_color_scheme() palette = getPaletteColorScheme(app.palette(), scheme=scheme) - app.setPalette(palette) + app.setPalette(palette) # load.rename_qrc_resources_file(scheme) # setToolTipStyleSheet(app, scheme=scheme) - + return app, splashScreen + def run_segm_workflow(workflow_params, logger, log_path): - logger.info('Initializing segmentation and tracking kernel...') + logger.info("Initializing segmentation and tracking kernel...") from cellacdc import cli + from cellacdc.workflow.adapters import ( + runnable_config_from_segm_kernel, + sync_segm_kernel_from_context, + ) + from cellacdc.workflow.pipelines.batch import run_segm_batch + kernel = cli.SegmKernel(logger, log_path, is_cli=True) kernel.init_args_from_params(workflow_params, logger.info) ch_filepaths = kernel.parse_paths(workflow_params) stop_frame_nums = kernel.parse_stop_frame_numbers(workflow_params) pbar = tqdm(total=len(ch_filepaths), ncols=100) - for ch_filepath, stop_frame_n in zip(ch_filepaths, stop_frame_nums): - logger.info(f'\nProcessing "{ch_filepath}"...') - kernel.run(ch_filepath, stop_frame_n) - pbar.update() + run_segm_batch( + kernel._workflow_ctx, + ch_filepaths, + stop_frame_nums, + runnable_config_from_segm_kernel(kernel), + progress=pbar, + ) + sync_segm_kernel_from_context(kernel, kernel._workflow_ctx) pbar.close() + def run_measurements_workflow(workflow_params, logger, log_path): - logger.info('Initializing measurements kernel...') + logger.info("Initializing measurements kernel...") from cellacdc import cli + from cellacdc.workflow.pipelines.batch import run_measurements_batch + from cellacdc.workflow.runnable import RunnableConfig + kernel = cli.ComputeMeasurementsKernel(logger, log_path, is_cli=True) ch_filepaths = kernel.parse_paths(workflow_params) stop_frame_nums = kernel.parse_stop_frame_numbers(workflow_params) - end_filename_segm = workflow_params['measurements']['end_filename_segm'] - kernel.set_metrics_from_workflow_config_params( - workflow_params['measurements'] - ) + end_filename_segm = workflow_params["measurements"]["end_filename_segm"] + kernel.set_metrics_from_workflow_config_params(workflow_params["measurements"]) pbar = tqdm(total=len(ch_filepaths), ncols=100) - for ch_filepath, stop_frame_n in zip(ch_filepaths, stop_frame_nums): - logger.info(f'\nProcessing "{ch_filepath}"...') - kernel.run( - img_path=ch_filepath, - stop_frame_n=stop_frame_n, - end_filename_segm=end_filename_segm, - ) - pbar.update() + run_measurements_batch( + kernel, + ch_filepaths, + stop_frame_nums, + end_filename_segm, + RunnableConfig(logger_func=logger.info), + progress=pbar, + ) pbar.close() + def run_cli(ini_filepath): - from cellacdc import myutils - logger, logs_path, log_path, log_filename = myutils.setupLogger( - module='cli', logs_path=None + from cellacdc import utils + from cellacdc.workflow.pipelines.full_workflow import build_full_workflow_graph + from cellacdc.workflow.runnable import RunnableConfig + from cellacdc.workflow.state import FullWorkflowState + + logger, logs_path, log_path, log_filename = utils.setupLogger( + module="cli", logs_path=None ) - + download_model_params() - + logger.info(f'Reading workflow file "{ini_filepath}"...') from cellacdc import load + workflow_params = load.read_segm_workflow_from_config(ini_filepath) - workflow_type = workflow_params['workflow']['type'] - - if workflow_type == 'segmentation and/or tracking': - run_segm_workflow(workflow_params, logger, log_path) - - if 'measurements' in workflow_params.keys(): - logger.info('Loading measurements workflow...') - meas_workflow_params = load.read_measurements_workflow_from_config( - ini_filepath - ) - run_measurements_workflow(meas_workflow_params, logger, log_path) - - logger.info('**********************************************') - logger.info(f'Cell-ACDC command-line closed. {myutils.get_salute_string()}') - logger.info('**********************************************') - - -def _setup_numpy(caller_name='Cell-ACDC'): + workflow_type = workflow_params["workflow"]["type"] + run_segm = workflow_type == "segmentation and/or tracking" + run_measurements = "measurements" in workflow_params + + meas_params = None + if run_measurements: + logger.info("Loading measurements workflow...") + meas_params = load.read_measurements_workflow_from_config(ini_filepath) + + workflow_ctx = type("WorkflowCliContext", (), {"logger": logger, "log_path": log_path})() + graph = build_full_workflow_graph(workflow_ctx).compile() + graph.invoke( + FullWorkflowState( + segm_params=workflow_params, + measurements_params=meas_params, + run_segm=run_segm, + run_measurements=run_measurements, + ), + RunnableConfig(logger_func=logger.info), + ) + + logger.info("**********************************************") + logger.info(f"Cell-ACDC command-line closed. {utils.get_salute_string()}") + logger.info("**********************************************") + + +def _setup_numpy(caller_name="Cell-ACDC"): import urllib.request import json import re - + from . import try_input_install_package - + numpy_versions = [] url = "https://pypi.org/pypi/numba/json" try: with urllib.request.urlopen(url) as response: data = json.load(response) requires_dist = data["info"].get("requires_dist", []) - numpy_versions = [ - req for req in requires_dist if "numpy" in req.lower() - ] + numpy_versions = [req for req in requires_dist if "numpy" in req.lower()] except urllib.error.URLError as e: print(f"Could not update np: {e}") return - + if not numpy_versions: print( - f'[WARNING]: Could not find NumPy version requirements for Numba. ' - 'Please, install the latest version of NumPy manually.' + f"[WARNING]: Could not find NumPy version requirements for Numba. " + "Please, install the latest version of NumPy manually." ) return - + numpy_versions_txt = numpy_versions[0] - - max_version = re.findall(r'<=?(\d+\.\d+)', numpy_versions_txt) - min_version = re.findall(r'>=?(\d+\.\d+)', numpy_versions_txt) + + max_version = re.findall(r"<=?(\d+\.\d+)", numpy_versions_txt) + min_version = re.findall(r">=?(\d+\.\d+)", numpy_versions_txt) if max_version: max_version = max_version[0] else: - max_version = '' - + max_version = "" + if min_version: min_version = min_version[0] else: - min_version = '' - + min_version = "" + import numpy + installed_numpy_version = numpy.__version__ - is_numpy_version_within_range = myutils.is_pkg_version_within_range( - installed_numpy_version, - min_version=min_version, - max_version=max_version + is_numpy_version_within_range = utils.is_pkg_version_within_range( + installed_numpy_version, min_version=min_version, max_version=max_version ) - + if is_numpy_version_within_range: return - - conda_prefix, pip_prefix = myutils.get_pip_conda_prefix() - conda_list, pip_list = myutils.get_pip_conda_prefix(list_return=True) + + conda_prefix, pip_prefix = utils.get_pip_conda_prefix() + conda_list, pip_list = utils.get_pip_conda_prefix(list_return=True) command_txt = f'{pip_prefix} --upgrade "{numpy_versions_txt}"' - + txt = ( - f'{caller_name} needs to upgrade the package `numpy`.\n\n' - f'The current version is {installed_numpy_version}, but it needs to be ' - f'between {min_version} and {max_version}.\n\n' - f'You can let {caller_name} install it now, or you can abort ' - f'and install it manually with the command `{command_txt}`.' + f"{caller_name} needs to upgrade the package `numpy`.\n\n" + f"The current version is {installed_numpy_version}, but it needs to be " + f"between {min_version} and {max_version}.\n\n" + f"You can let {caller_name} install it now, or you can abort " + f"and install it manually with the command `{command_txt}`." ) - print('-'*60) + print("-" * 60) print(txt) - + while True: from .config import parser_args - if parser_args['yes']: - answer = 'y' + + if parser_args["yes"]: + answer = "y" else: - answer = try_input_install_package('qtpy', command_txt) - - if answer.lower() == 'y' or not answer: + answer = try_input_install_package("qtpy", command_txt) + + if answer.lower() == "y" or not answer: import subprocess - cmd = pip_list + ['-U', numpy_versions_txt] + + cmd = pip_list + ["-U", numpy_versions_txt] subprocess.check_call(cmd) break - elif answer.lower() == 'n': - raise ModuleNotFoundError(f'Numba requires {numpy_versions_txt} ') + elif answer.lower() == "n": + raise ModuleNotFoundError(f"Numba requires {numpy_versions_txt} ") else: print( f'"{answer}" is not a valid answer. ' 'Type "y" for "yes", or "n" for "no".' - ) \ No newline at end of file + ) diff --git a/cellacdc/_types.py b/cellacdc/_types.py index 5e0894523..b2d0dd6dd 100644 --- a/cellacdc/_types.py +++ b/cellacdc/_types.py @@ -3,67 +3,78 @@ from typing import Union, Tuple, Any, List import numpy as np + class NotGUIParam: not_a_param = True + ChannelsDict = dict[str, List[np.ndarray]] + class RescaleIntensitiesInRangeHow: - values = ['percentage', 'image', 'absolute'] + values = ["percentage", "image", "absolute"] + class BaSiCpyResizeModes: - values = ['jax', 'skimage', 'skimage_dask'] + values = ["jax", "skimage", "skimage_dask"] + class BaSiCpyFittingModes: - values = ['ladmap', 'approximate'] + values = ["ladmap", "approximate"] + class BaSiCpyTimelapse: values = ["True", "False", "additive", "multiplicative"] + class Vector: - """Class used to define model parameter as a vector that will use the + """Class used to define model parameter as a vector that will use the cellacdc.widgets.VectorLineEdit widget in the automatic GUI. """ + @staticmethod def cast_dtype(value: Any) -> Union[Tuple[float], int, float]: if isinstance(value, str): - value = value.lstrip('(').rstrip(')') - value = value.lstrip('[').rstrip(']') - values = value.split(',') + value = value.lstrip("(").rstrip(")") + value = value.lstrip("[").rstrip("]") + values = value.split(",") values = tuple([float(val) for val in values]) return values elif isinstance(value, (int, float)): return value - - raise TypeError(f'Could not convert {value} {(type(value))} to Vector') - + + raise TypeError(f"Could not convert {value} {(type(value))} to Vector") + def __call__(self, value: Any) -> Union[Tuple[float], int, float]: return self.cast_dtype(value) - + + class FolderPath: - """Class used to define model parameter as a folder path control with a + """Class used to define model parameter as a folder path control with a browse button to select a folder in the automatic GUI. """ + def cast_dtype(self, value: Any) -> Union[Tuple[float], int, float]: return str(value) - + def __call__(self, value: Any) -> str: return self.cast_dtype(value) + class SecondChannelImage: pass + def is_optional(field): - return ( - typing.get_origin(field) is Union and - type(None) in typing.get_args(field) - ) + return typing.get_origin(field) is Union and type(None) in typing.get_args(field) + def is_second_channel_type(field): if is_optional(field): field = typing.get_args(field)[0] - - return getattr(field, '__name__', None) == 'SecondChannelImage' # avoid union + + return getattr(field, "__name__", None) == "SecondChannelImage" # avoid union + def is_widget_not_required(ArgSpec): try: @@ -71,21 +82,22 @@ def is_widget_not_required(ArgSpec): return True except Exception as err: pass - + try: - # If a parameter if None, python initializes it to + # If a parameter if None, python initializes it to # typing.Optional and we need to access the first type ArgSpec.type.__args__[0]().not_a_param return True except Exception as err: pass - + return False + def to_str(*args): if len(args) == 2: value = args[1] else: value = args[0] - - return str(value) \ No newline at end of file + + return str(value) diff --git a/cellacdc/_view_all_buttons.py b/cellacdc/_view_all_buttons.py index 3428725a4..d51b0700a 100644 --- a/cellacdc/_view_all_buttons.py +++ b/cellacdc/_view_all_buttons.py @@ -1,13 +1,17 @@ import sys -SCHEME = 'dark' +SCHEME = "dark" FLAT = False from qtpy.QtGui import QIcon from qtpy.QtCore import Qt, QSize from qtpy.QtWidgets import ( - QApplication, QPushButton, QStyleFactory, QWidget, QGridLayout, - QCheckBox + QApplication, + QPushButton, + QStyleFactory, + QWidget, + QGridLayout, + QCheckBox, ) from cellacdc import widgets, _run @@ -21,9 +25,9 @@ # Distribute icons over a 16:9 grid nicons = len(buttons_names) -ncols = round((nicons / 16*9)**(1/2)) +ncols = round((nicons / 16 * 9) ** (1 / 2)) nrows = nicons // ncols -left_nicons = nicons % ncols +left_nicons = nicons % ncols if left_nicons > 0: nrows += 1 @@ -52,19 +56,21 @@ max_height = max([button.sizeHint().height() for button in buttons]) for button in buttons: - button.setMinimumHeight(max_height*2) + button.setMinimumHeight(max_height * 2) + def setDisabled(checked): for button in buttons: button.setDisabled(checked) - -checkbox = QCheckBox('Disable buttons') + + +checkbox = QCheckBox("Disable buttons") checkbox.toggled.connect(setDisabled) -layout.addWidget(checkbox, i, j+1) +layout.addWidget(checkbox, i, j + 1) -layout.setRowStretch(i+1, 1) -layout.setColumnStretch(j+2, 1) +layout.setRowStretch(i + 1, 1) +layout.setColumnStretch(j + 2, 1) splashScreen.close() win.show() -app.exec_() \ No newline at end of file +app.exec_() diff --git a/cellacdc/_view_all_icons.py b/cellacdc/_view_all_icons.py index eca265fea..569480eb9 100644 --- a/cellacdc/_view_all_icons.py +++ b/cellacdc/_view_all_icons.py @@ -2,14 +2,18 @@ import os import shutil -SCHEME = 'dark' +SCHEME = "dark" FLAT = True from qtpy.QtGui import QIcon from qtpy.QtCore import Qt, QSize from qtpy.QtWidgets import ( - QApplication, QPushButton, QStyleFactory, QWidget, QGridLayout, - QCheckBox + QApplication, + QPushButton, + QStyleFactory, + QWidget, + QGridLayout, + QCheckBox, ) from cellacdc import _run @@ -22,13 +26,13 @@ # Distribute icons over a 16:9 grid nicons = len(svg_aliases) -ncols = round((nicons / 16*9)**(1/2)) +ncols = round((nicons / 16 * 9) ** (1 / 2)) nrows = nicons // ncols -left_nicons = nicons % ncols +left_nicons = nicons % ncols if left_nicons > 0: nrows += 1 -if hasattr(Qt, 'AA_UseHighDpiPixmaps'): +if hasattr(Qt, "AA_UseHighDpiPixmaps"): QApplication.setAttribute(Qt.AA_UseHighDpiPixmaps, True) win = QWidget() @@ -42,10 +46,10 @@ if idx == nicons: break alias = svg_aliases[idx] - icon = QIcon(f':{alias}') + icon = QIcon(f":{alias}") button = QPushButton(alias) button.setIcon(icon) - button.setIconSize(QSize(32,32)) + button.setIconSize(QSize(32, 32)) button.setCheckable(True) if FLAT: button.setFlat(True) @@ -53,14 +57,16 @@ buttons.append(button) idx += 1 + def setDisabled(checked): for button in buttons: button.setDisabled(checked) - -checkbox = QCheckBox('Disable buttons') + + +checkbox = QCheckBox("Disable buttons") checkbox.toggled.connect(setDisabled) -layout.addWidget(checkbox, i, j+1) +layout.addWidget(checkbox, i, j + 1) splashScreen.close() win.showMaximized() -app.exec_() \ No newline at end of file +app.exec_() diff --git a/cellacdc/_warnings.py b/cellacdc/_warnings.py index e313b3472..86f49672e 100644 --- a/cellacdc/_warnings.py +++ b/cellacdc/_warnings.py @@ -2,16 +2,18 @@ from functools import partial import re -from cellacdc import html_utils, myutils +from cellacdc import html_utils, utils from . import issues_url from . import urls from . import error_below, error_close + def warnTooManyItems(mainWin, numItems, qparent): from . import widgets + mainWin.logger.info( - '[WARNING]: asking user what to do with too many graphical items...' + "[WARNING]: asking user what to do with too many graphical items..." ) msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph(f""" @@ -24,118 +26,111 @@ def warnTooManyItems(mainWin, numItems, qparent): """) _, stayHighResButton, switchToLowResButton = msg.warning( - qparent, 'Too many objects', txt, + qparent, + "Too many objects", + txt, buttonsTexts=( - 'Cancel', 'Stay on high resolution', - widgets.reloadPushButton(' Switch to low resolution ') - ) + "Cancel", + "Stay on high resolution", + widgets.reloadPushButton(" Switch to low resolution "), + ), ) - return msg.cancel, msg.clickedButton==switchToLowResButton + return msg.cancel, msg.clickedButton == switchToLowResButton -def warnRestartCellACDCcolorModeToggled( - scheme, app_name='Cell-ACDC', parent=None - ): + +def warnRestartCellACDCcolorModeToggled(scheme, app_name="Cell-ACDC", parent=None): from . import widgets + msg = widgets.myMessageBox(wrapText=False) - txt = ( - 'In order for the change to take effect, ' - f'please restart {app_name}' - ) - if scheme == 'dark': + txt = f"In order for the change to take effect, please restart {app_name}" + if scheme == "dark": issues_href = f'GitHub page' note_txt = ( - 'NOTE: dark mode is a recent feature so if you see ' - 'if you see anything odd,
' - 'please, report it by opening an issue ' - f'on our {issues_href}.

' - 'Thanks!' + "NOTE: dark mode is a recent feature so if you see " + "if you see anything odd,
" + "please, report it by opening an issue " + f"on our {issues_href}.

" + "Thanks!" ) - txt = f'{txt}

{note_txt}' + txt = f"{txt}

{note_txt}" txt = html_utils.paragraph(txt) - msg.information(parent, f'Restart {app_name}', txt) + msg.information(parent, f"Restart {app_name}", txt) + class DataTypeWarning(RuntimeWarning): def __init__(self, message): self._message = message - + def __str__(self): return repr(self._message) + def warn_image_overflow_dtype(input_dtype, max_value, inferred_dtype): import warnings + warnings.warn( - f'The input image has data type {input_dtype}. Since it is neither ' - f'8-bit, 16-bit, nor 32-bit the data was inferred as {inferred_dtype} ' - f'from the max value of the image of {max_value}.', - DataTypeWarning + f"The input image has data type {input_dtype}. Since it is neither " + f"8-bit, 16-bit, nor 32-bit the data was inferred as {inferred_dtype} " + f"from the max value of the image of {max_value}.", + DataTypeWarning, ) + def warn_cca_integrity(txt, category, qparent, go_to_frame_callback=None): from . import widgets from qtpy.QtWidgets import QCheckBox - + preamble = html_utils.paragraph( - 'WARNING: integrity of cell cycle annotations ' - 'might be compromised:' + "WARNING: integrity of cell cycle annotations " + "might be compromised:" ) - - msg_text = f'{preamble}{txt}' - - stopSpecificMessageCheckbox = QCheckBox( - 'Stop warning with this specific message' - ) - stopCategoryCheckbox = QCheckBox( - f'Stop warning about "{category}"' - ) - disableAllWarningsCheckbox = QCheckBox( - 'Disable all warnings' - ) - + + msg_text = f"{preamble}{txt}" + + stopSpecificMessageCheckbox = QCheckBox("Stop warning with this specific message") + stopCategoryCheckbox = QCheckBox(f'Stop warning about "{category}"') + disableAllWarningsCheckbox = QCheckBox("Disable all warnings") + checkboxes = ( - stopSpecificMessageCheckbox, - stopCategoryCheckbox, - disableAllWarningsCheckbox + stopSpecificMessageCheckbox, + stopCategoryCheckbox, + disableAllWarningsCheckbox, ) - + msg = widgets.myMessageBox(wrapText=False) - if go_to_frame_callback is not None and txt.find('At frame n.') != -1: - frame_n = re.findall(r'At frame n. (\d+)', txt)[0] - goToFrameButton = widgets.NavigatePushButton(f'Go to frame n. {frame_n}') + if go_to_frame_callback is not None and txt.find("At frame n.") != -1: + frame_n = re.findall(r"At frame n. (\d+)", txt)[0] + goToFrameButton = widgets.NavigatePushButton(f"Go to frame n. {frame_n}") goToFrameButton = msg.addButton(goToFrameButton) goToFrameButton.disconnect() - goToFrameButton.clicked.connect( - partial(go_to_frame_callback, int(frame_n)) - ) - - msg.warning( - qparent, 'Annotations integrity warning', msg_text, - widgets=checkboxes - ) - + goToFrameButton.clicked.connect(partial(go_to_frame_callback, int(frame_n))) + + msg.warning(qparent, "Annotations integrity warning", msg_text, widgets=checkboxes) + if stopSpecificMessageCheckbox.isChecked(): return txt - + if stopCategoryCheckbox.isChecked(): return category - + if disableAllWarningsCheckbox.isChecked(): - return 'disable_all' - - return '' + return "disable_all" + + return "" + -def warn_installing_different_cellpose_version( - requested_version, installed_version - ): +def warn_installing_different_cellpose_version(requested_version, installed_version): from cellacdc import widgets - if not myutils.is_gui_running(): + + if not utils.is_gui_running(): print( - f'[WARNING]: You requested to install `Cellpose {requested_version}` ' - f'but you already have `Cellpose {installed_version}`.\n\n' - f'If you proceed, Cell-ACDC will *uninstall* `{installed_version}` ' - f'and will install `{requested_version}`.' + f"[WARNING]: You requested to install `Cellpose {requested_version}` " + f"but you already have `Cellpose {installed_version}`.\n\n" + f"If you proceed, Cell-ACDC will *uninstall* `{installed_version}` " + f"and will install `{requested_version}`." ) return False - + note_text = """ You can still proceed and let Cell-ACDC take care of uninstalling/installing the right versions every time you request it. @@ -153,15 +148,14 @@ def warn_installing_different_cellpose_version( {html_utils.to_note(note_text)} """) msg = widgets.myMessageBox(wrapText=False) - msg.warning( - None, 'Cellpose already installed', txt, - buttonsTexts=('Cancel', 'Ok') - ) + msg.warning(None, "Cellpose already installed", txt, buttonsTexts=("Cancel", "Ok")) return msg.cancel + def warn_download_bioformats_jar_failed(jar_dst_filepath, qparent=None): from cellacdc import widgets - href = html_utils.href_tag('here', urls.bioformats_download_page) + + href = html_utils.href_tag("here", urls.bioformats_download_page) txt = html_utils.paragraph(f""" [WARNING]: Download of bioformats_package.jar failed.

@@ -170,26 +164,34 @@ def warn_download_bioformats_jar_failed(jar_dst_filepath, qparent=None): """) msg = widgets.myMessageBox(wrapText=False) msg.warning( - qparent, 'Download of bioformats failed', txt, - commands=(jar_dst_filepath,), - path_to_browse=os.path.dirname(jar_dst_filepath) + qparent, + "Download of bioformats failed", + txt, + commands=(jar_dst_filepath,), + path_to_browse=os.path.dirname(jar_dst_filepath), ) return msg.cancel + def warn_segment_for_lost_IDs_first_frame(qparent=None): from cellacdc import widgets + txt = html_utils.paragraph(f""" The segmentation for lost IDs is not available on the first frame.

Thank you for your patience! """) msg = widgets.myMessageBox(wrapText=False) msg.warning( - qparent, 'Not available on first frame', txt, + qparent, + "Not available on first frame", + txt, ) return msg.cancel + def warnPromptSegmentPointsLayerNotInit(qparent=None): from cellacdc import widgets + txt = html_utils.paragraph(f""" The points layer was not initialized!

To initialize it, please, deactivate and reactivate the @@ -198,12 +200,16 @@ def warnPromptSegmentPointsLayerNotInit(qparent=None): """) msg = widgets.myMessageBox(wrapText=False) msg.warning( - qparent, 'Points layer not initialized', txt, + qparent, + "Points layer not initialized", + txt, ) return msg.cancel + def warnPromptSegmentModelNotInit(qparent=None): from cellacdc import widgets + txt = html_utils.paragraph(f""" Promptable model was not initialized!

To initialize it, please, click on the {numNewCells} new object(s) will ' - 'appear (highlighted in green on left image).

' - - f'However, in the previous frame (frame n. {frame_i}) there are ' - f'{G1_text} in G1 available.

' - - 'Note that cells must be in G1 in the previous frame too, ' - 'because if they are in G1
' - 'only at current frame, assigning a bud to it would result in no ' - 'G1 phase at all between current
' - 'and previous cell cycle.
' - - 'You can either cancel the operation and annotate division on previous ' - 'frames or continue.

' - - 'If you continue the new cell will be annotated as a ' - 'cell in G1 with unknown history.

' - - 'Do you want to continue?
' + f"In the next frame {numNewCells} new object(s) will " + "appear (highlighted in green on left image).

" + f"However, in the previous frame (frame n. {frame_i}) there are " + f"{G1_text} in G1 available.

" + "Note that cells must be in G1 in the previous frame too, " + "because if they are in G1
" + "only at current frame, assigning a bud to it would result in no " + "G1 phase at all between current
" + "and previous cell cycle.
" + "You can either cancel the operation and annotate division on previous " + "frames or continue.

" + "If you continue the new cell will be annotated as a " + "cell in G1 with unknown history.

" + "Do you want to continue?
" ) - + msg = widgets.myMessageBox(wrapText=False) _, yesButton = msg.warning( - qparent, 'No cells in G1!', text, - buttonsTexts=('Cancel', 'Continue anyway (new cells will start in G1)') + qparent, + "No cells in G1!", + text, + buttonsTexts=("Cancel", "Continue anyway (new cells will start in G1)"), ) return msg.clickedButton == yesButton - + def log_pytorch_not_installed(): print(error_below) print( - 'PyTorch is not installed. See here how to install it ' - f'{urls.install_pytorch}' + f"PyTorch is not installed. See here how to install it {urls.install_pytorch}" ) print(error_close) + def warnExportToVideo(qparent=None): from cellacdc import widgets + txt = html_utils.paragraph(f""" Exporting to video will start now.

During this process, the GUI will automatically update the images @@ -271,16 +279,17 @@ def warnExportToVideo(qparent=None): """) msg = widgets.myMessageBox(wrapText=False) msg.warning( - qparent, 'Export to video is starting', txt, - buttonsTexts=('Cancel', 'Ok') + qparent, "Export to video is starting", txt, buttonsTexts=("Cancel", "Ok") ) return msg.cancel + def warnDivisionAnnotationCannotBeUndone(ID, relID, issue_frame_i, qparent=None): from cellacdc import widgets + txt = html_utils.paragraph(f""" Cell division annotation cannot be undone because Cell ID {relID} - is in 'S' phase at frame n. {issue_frame_i+1}.

+ is in 'S' phase at frame n. {issue_frame_i + 1}.

By undoing division annotation, Cell ID {relID} would be restored as relative of Cell ID {ID}, but this cannot be done.

The only solution is to go to frame n. {issue_frame_i} and reset the @@ -288,13 +297,13 @@ def warnDivisionAnnotationCannotBeUndone(ID, relID, issue_frame_i, qparent=None) Thank you for your patience! """) msg = widgets.myMessageBox(wrapText=False) - msg.warning( - qparent, 'Division annotation cannot be undone', txt - ) + msg.warning(qparent, "Division annotation cannot be undone", txt) return msg.cancel + def warnCannotAddRemovePointsProjection(qparent=None): from cellacdc import widgets + txt = html_utils.paragraph(f""" Points cannot be added or removed in a projection!

Please, switch to "single z-slice" mode (bottom of the image on @@ -302,56 +311,56 @@ def warnCannotAddRemovePointsProjection(qparent=None): Thank you for your patience. """) msg = widgets.myMessageBox(wrapText=False) - msg.warning(qparent, 'WARNING: Editing points in projection', txt) + msg.warning(qparent, "WARNING: Editing points in projection", txt) + def warnRestartAcdcIconsUpdated(qparent=None): from cellacdc import widgets + txt = ( - 'Cell-ACDC had to update the GUI icons. ' - 'Please re-start the application.\n\n' - 'Thank you for your patience!' + "Cell-ACDC had to update the GUI icons. " + "Please re-start the application.\n\n" + "Thank you for your patience!" ) - print('*'*100) + print("*" * 100) print(txt) - print('^'*100) - html_txt = html_utils.paragraph(txt.replace('\n', '
')) + print("^" * 100) + html_txt = html_utils.paragraph(txt.replace("\n", "
")) msg = widgets.myMessageBox(wrapText=False) - msg.information(qparent, 'GUI icons updated', txt) + msg.information(qparent, "GUI icons updated", txt) + def warnMissingCca(missing_cca_items, qparent=None): from cellacdc import widgets, printl + mainText = html_utils.paragraph(f""" Some objects have missing cell cycle annotations!

Please, fix them before saving again, thanks!

See below the list of object IDs without annotations. """) - + details_txt_list = [] for cca_df, posData, frame_i in missing_cca_items: - txt = ( - f'{posData.pos_foldername}:

' - ) - indent = '  ' + txt = f"{posData.pos_foldername}:

" + indent = "  " if frame_i is not None: - txt = (f'{txt}' - f' - Frame n. {frame_i+1}:
' - ) - indent = '    ' + txt = f"{txt} - Frame n. {frame_i + 1}:
" + indent = "    " missing_IDs = cca_df[cca_df.isnull().any(axis=1)].index.to_list() for missing_ID in missing_IDs: - txt = (f'{txt}' - f'{indent}* ID: {missing_ID}
' - ) - + txt = f"{txt}{indent}* ID: {missing_ID}
" + details_txt_list.append(txt) - - detailsText = '
'.join(details_txt_list) + + detailsText = "
".join(details_txt_list) msg = widgets.myMessageBox(wrapText=False) _, ignoreButton = msg.warning( - qparent, 'Missing cell cycle annotations', mainText, + qparent, + "Missing cell cycle annotations", + mainText, detailsText, - buttonsTexts=('Cancel', 'Ignore'), + buttonsTexts=("Cancel", "Ignore"), add_do_not_show_again_checkbox=True, ) doNotShowAgain = msg.doNotShowAgainCheckbox.isChecked() - return msg.clickedButton == ignoreButton, doNotShowAgain \ No newline at end of file + return msg.clickedButton == ignoreButton, doNotShowAgain diff --git a/cellacdc/acdc_bioio_bioformats/__init__.py b/cellacdc/acdc_bioio_bioformats/__init__.py index e248fa8a2..cbabf0854 100644 --- a/cellacdc/acdc_bioio_bioformats/__init__.py +++ b/cellacdc/acdc_bioio_bioformats/__init__.py @@ -5,35 +5,33 @@ conda_prefix = os.environ.get("CONDA_PREFIX") if conda_prefix is not None: if is_win64: - os.environ["JAVA_HOME"] = rf'{conda_prefix}\Library' + os.environ["JAVA_HOME"] = rf"{conda_prefix}\Library" else: os.environ["JAVA_HOME"] = conda_prefix - - print('Setting JAVA_HOME:', os.environ["JAVA_HOME"]) + + print("Setting JAVA_HOME:", os.environ["JAVA_HOME"]) EXTENSION_PACKAGE_MAPPER = { - '.czi': 'bioio-czi', - '.dv': 'bioio-dv', - '.r3d': 'bioio-dv', - '.lif': 'bioio-lif', - '.nd2': 'bioio-nd2', - '.tif': 'bioio-tifffile', - '.tiff': 'bioio-tifffile', - '.ome.tiff': 'bioio-ome-tiff', - '.zarr': 'bioio-ome-zarr', - '.sldy': 'bioio-sldy', - '.dir': 'bioio-sldy', + ".czi": "bioio-czi", + ".dv": "bioio-dv", + ".r3d": "bioio-dv", + ".lif": "bioio-lif", + ".nd2": "bioio-nd2", + ".tif": "bioio-tifffile", + ".tiff": "bioio-tifffile", + ".ome.tiff": "bioio-ome-tiff", + ".zarr": "bioio-ome-zarr", + ".sldy": "bioio-sldy", + ".dir": "bioio-sldy", } EXTENSION_BIOIMAGE_KWARGS_MAPPER = { - '.czi': {'use_aicspylibczi': True}, + ".czi": {"use_aicspylibczi": True}, } EXTENSION_METADATA_ATTR_MAPPER = { - '.czi': { - 'TimeIncrement': 'standard_metadata.timelapse_interval.total_seconds()' - } + ".czi": {"TimeIncrement": "standard_metadata.timelapse_interval.total_seconds()"} } from .reader import ImageReader, get_omexml_metadata, OMEXML, Metadata -from . import _utils \ No newline at end of file +from . import _utils diff --git a/cellacdc/acdc_bioio_bioformats/_init_reader.py b/cellacdc/acdc_bioio_bioformats/_init_reader.py index ae79af9e5..f5ca2e710 100644 --- a/cellacdc/acdc_bioio_bioformats/_init_reader.py +++ b/cellacdc/acdc_bioio_bioformats/_init_reader.py @@ -6,21 +6,21 @@ try: ap.add_argument( - '-f', - '--filepath', - required=True, - type=str, - metavar='FILEPATH', - help='Filepath of a raw microscopy file to test.' + "-f", + "--filepath", + required=True, + type=str, + metavar="FILEPATH", + help="Filepath of a raw microscopy file to test.", ) args = vars(ap.parse_args()) - raw_filepath = args['filepath'] + raw_filepath = args["filepath"] with bioformats.ImageReader(raw_filepath, qparent=None) as reader: print(reader) except Exception as err: args = vars(ap.parse_args()) - uuid4 = args['uuid'] - - bioformats._utils.dump_exception(err, uuid4) \ No newline at end of file + uuid4 = args["uuid"] + + bioformats._utils.dump_exception(err, uuid4) diff --git a/cellacdc/acdc_bioio_bioformats/_read_metadata.py b/cellacdc/acdc_bioio_bioformats/_read_metadata.py index 87e3b8bab..2873e3220 100644 --- a/cellacdc/acdc_bioio_bioformats/_read_metadata.py +++ b/cellacdc/acdc_bioio_bioformats/_read_metadata.py @@ -14,32 +14,28 @@ try: ap.add_argument( - '-f', - '--filepath', - required=True, - type=str, - metavar='FILEPATH', - help='Filepath of a raw microscopy file to test.' + "-f", + "--filepath", + required=True, + type=str, + metavar="FILEPATH", + help="Filepath of a raw microscopy file to test.", ) args = vars(ap.parse_args()) - raw_filepath = args['filepath'] + raw_filepath = args["filepath"] metadataXML = bioformats.get_omexml_metadata(raw_filepath) metadata = bioformats.OMEXML().init_from_metadata(metadataXML) os.makedirs(bioio_sample_data_folderpath, exist_ok=True) - metadataXML_filepath = os.path.join( - bioio_sample_data_folderpath, 'metadataXML.txt' - ) + metadataXML_filepath = os.path.join(bioio_sample_data_folderpath, "metadataXML.txt") metadataXML.to_file(metadataXML_filepath) - metadata_filepath = os.path.join( - bioio_sample_data_folderpath, 'metadata.txt' - ) + metadata_filepath = os.path.join(bioio_sample_data_folderpath, "metadata.txt") metadata.to_file(metadata_filepath) except Exception as err: args = vars(ap.parse_args()) - uuid4 = args['uuid'] - - bioformats._utils.dump_exception(err, uuid4) \ No newline at end of file + uuid4 = args["uuid"] + + bioformats._utils.dump_exception(err, uuid4) diff --git a/cellacdc/acdc_bioio_bioformats/_read_sample_data.py b/cellacdc/acdc_bioio_bioformats/_read_sample_data.py index bdea8020e..59da04ad1 100644 --- a/cellacdc/acdc_bioio_bioformats/_read_sample_data.py +++ b/cellacdc/acdc_bioio_bioformats/_read_sample_data.py @@ -13,61 +13,61 @@ try: ap.add_argument( - '-f', - '--filepath', - required=True, - type=str, - metavar='FILEPATH', - help='Filepath of a raw microscopy file to test.' + "-f", + "--filepath", + required=True, + type=str, + metavar="FILEPATH", + help="Filepath of a raw microscopy file to test.", ) ap.add_argument( - '-c', - '--SizeC', - required=True, - type=int, - metavar='SIZEC', - help='Number of channels in the microscopy file.' + "-c", + "--SizeC", + required=True, + type=int, + metavar="SIZEC", + help="Number of channels in the microscopy file.", ) ap.add_argument( - '-t', - '--SizeT', - required=True, - type=int, - metavar='SIZET', - help='Number of timepoints in the microscopy file.' + "-t", + "--SizeT", + required=True, + type=int, + metavar="SIZET", + help="Number of timepoints in the microscopy file.", ) ap.add_argument( - '-z', - '--SizeZ', - required=True, - type=int, - metavar='SIZEZ', - help='Number of z-slices in a single z-stack.' + "-z", + "--SizeZ", + required=True, + type=int, + metavar="SIZEZ", + help="Number of z-slices in a single z-stack.", ) - + ap.add_argument( - '-a', - '--all', - action='store_true', - help='Whether to read entire position into RAM or not.' + "-a", + "--all", + action="store_true", + help="Whether to read entire position into RAM or not.", ) args = vars(ap.parse_args()) - raw_filepath = args['filepath'] + raw_filepath = args["filepath"] + + SizeC = args["SizeC"] + SizeT = args["SizeT"] + SizeZ = args["SizeZ"] - SizeC = args['SizeC'] - SizeT = args['SizeT'] - SizeZ = args['SizeZ'] - - lazy_load = not args['all'] + lazy_load = not args["all"] if SizeT >= 4: sampleSizeT = 4 else: - sampleSizeT = SizeT + sampleSizeT = SizeT if SizeZ > 20: sampleSizeZ = 20 else: @@ -75,17 +75,17 @@ allChannelsData = [] with bioformats.ImageReader(raw_filepath, lazy_load=lazy_load) as reader: - numIter = SizeC*sampleSizeT*sampleSizeZ + numIter = SizeC * sampleSizeT * sampleSizeZ pbar = tqdm(total=numIter, ncols=100, leave=False) - + for c in range(SizeC): imgData_tz = [] - for t in range(sampleSizeT): + for t in range(sampleSizeT): imgData_z = [] for z in range(sampleSizeZ): imgData = reader.read(c=c, z=z, t=t, rescale=False) - imgData_z.append(imgData) - pbar.update() + imgData_z.append(imgData) + pbar.update() imgData_z = np.array(imgData_z, dtype=imgData.dtype) imgData_z = np.squeeze(imgData_z) imgData_tz.append(imgData_z) @@ -95,13 +95,11 @@ os.makedirs(bioio_sample_data_folderpath, exist_ok=True) for c, channel_data in enumerate(allChannelsData): - filepath = os.path.join( - bioio_sample_data_folderpath, f"sample_channel_{c}.npy" - ) + filepath = os.path.join(bioio_sample_data_folderpath, f"sample_channel_{c}.npy") np.save(filepath, channel_data) except Exception as err: args = vars(ap.parse_args()) - uuid4 = args['uuid'] - - bioformats._utils.dump_exception(err, uuid4) \ No newline at end of file + uuid4 = args["uuid"] + + bioformats._utils.dump_exception(err, uuid4) diff --git a/cellacdc/acdc_bioio_bioformats/_save_data.py b/cellacdc/acdc_bioio_bioformats/_save_data.py index 0ac7c1e1a..96a239085 100644 --- a/cellacdc/acdc_bioio_bioformats/_save_data.py +++ b/cellacdc/acdc_bioio_bioformats/_save_data.py @@ -6,7 +6,7 @@ import h5py from cellacdc import bioio_sample_data_folderpath -from cellacdc import myutils +from cellacdc import utils from cellacdc import acdc_bioio_bioformats as bioformats import argparse @@ -15,173 +15,181 @@ try: ap.add_argument( - '-f', - '--filepath', - required=True, - type=str, - metavar='FILEPATH', - help='Filepath of the raw microscopy file.' + "-f", + "--filepath", + required=True, + type=str, + metavar="FILEPATH", + help="Filepath of the raw microscopy file.", ) ap.add_argument( - '-d', - '--do_save_channels', + "-d", + "--do_save_channels", type=str, - required=True, - metavar='DO_SAVE_CHANNELS', - help='Whether to save the channel or not.' + required=True, + metavar="DO_SAVE_CHANNELS", + help="Whether to save the channel or not.", ) ap.add_argument( - '-c', - '--channel_names', + "-c", + "--channel_names", type=str, - required=True, - metavar='CHANNEL_NAMES', - help='List of channel names.' + required=True, + metavar="CHANNEL_NAMES", + help="List of channel names.", ) ap.add_argument( - '-s', - '--series_idx', - required=True, - type=int, - metavar='SERIES_IDX', - help='Index of the Position in the microscopy file.' + "-s", + "--series_idx", + required=True, + type=int, + metavar="SERIES_IDX", + help="Index of the Position in the microscopy file.", ) ap.add_argument( - '-i', - '--images_path', - required=True, - type=str, - metavar='IMAGE_PATH', - help='Images folder path.' + "-i", + "--images_path", + required=True, + type=str, + metavar="IMAGE_PATH", + help="Images folder path.", ) ap.add_argument( - '-p', - '--filename_no_ext', - required=True, - type=str, - metavar='FILENAME_NO_EXT', - help='Name of the file without extension.' + "-p", + "--filename_no_ext", + required=True, + type=str, + metavar="FILENAME_NO_EXT", + help="Name of the file without extension.", ) ap.add_argument( - '-pos', - '--pos_idx_str', - required=True, - type=str, - metavar='POS_IDX_STR', - help='String index of the Position padded with required zeros.' + "-pos", + "--pos_idx_str", + required=True, + type=str, + metavar="POS_IDX_STR", + help="String index of the Position padded with required zeros.", ) ap.add_argument( - '-t', - '--SizeT', - required=True, - type=int, - metavar='SIZET', - help='Number of timepoints in the microscopy file.' + "-t", + "--SizeT", + required=True, + type=int, + metavar="SIZET", + help="Number of timepoints in the microscopy file.", ) ap.add_argument( - '-z', - '--SizeZ', - required=True, - type=int, - metavar='SIZEZ', - help='Number of z-slices in a single z-stack.' + "-z", + "--SizeZ", + required=True, + type=int, + metavar="SIZEZ", + help="Number of z-slices in a single z-stack.", ) ap.add_argument( - '-time_increment', - '--time_increment', - type=float, - required=True, - metavar='TIME_INCREMENT', - help='Time between consecutive frames in seconds.' + "-time_increment", + "--time_increment", + type=float, + required=True, + metavar="TIME_INCREMENT", + help="Time between consecutive frames in seconds.", ) ap.add_argument( - '-zyx', - '--zyx_physical_sizes', + "-zyx", + "--zyx_physical_sizes", type=str, - required=True, - metavar='ZYX_PHYSICAL_SIZES', - help='Physical sizes in z, y, x dimensions.' + required=True, + metavar="ZYX_PHYSICAL_SIZES", + help="Physical sizes in z, y, x dimensions.", ) ap.add_argument( - '-to_h5', - '--to_h5', - action='store_true', - help='Whether to save with h5 file format.' + "-to_h5", + "--to_h5", + action="store_true", + help="Whether to save with h5 file format.", ) ap.add_argument( - '-r', - '--time_range_to_save', + "-r", + "--time_range_to_save", type=str, - required=True, - metavar='TIME_RANGE_TO_SAVE', - help='Start and end frame to save.' + required=True, + metavar="TIME_RANGE_TO_SAVE", + help="Start and end frame to save.", ) - + ap.add_argument( - '-a', - '--all', - action='store_true', - help='Whether to read entire position into RAM or not.' + "-a", + "--all", + action="store_true", + help="Whether to read entire position into RAM or not.", ) args = vars(ap.parse_args()) - raw_filepath = args['filepath'] - do_save_channels_li = args['do_save_channels'].split() - do_save_channels = [val=='True' for val in do_save_channels_li] - channel_names = args['channel_names'].split() - series = args['series_idx'] - images_path = args['images_path'] - filename_no_ext = args['filename_no_ext'] - SizeT = args['SizeT'] - SizeZ = args['SizeZ'] - TimeIncrement = args['time_increment'] - s0p = args['pos_idx_str'] - - lazy_load = not args['all'] - - zyx_physical_sizes_li = args['zyx_physical_sizes'].split() + raw_filepath = args["filepath"] + do_save_channels_li = args["do_save_channels"].split() + do_save_channels = [val == "True" for val in do_save_channels_li] + channel_names = args["channel_names"].split() + series = args["series_idx"] + images_path = args["images_path"] + filename_no_ext = args["filename_no_ext"] + SizeT = args["SizeT"] + SizeZ = args["SizeZ"] + TimeIncrement = args["time_increment"] + s0p = args["pos_idx_str"] + + lazy_load = not args["all"] + + zyx_physical_sizes_li = args["zyx_physical_sizes"].split() zyx_physical_sizes = [float(val) for val in zyx_physical_sizes_li] PhysicalSizeZ, PhysicalSizeY, PhysicalSizeX = zyx_physical_sizes - to_h5 = args['to_h5'] + to_h5 = args["to_h5"] - time_range_to_save_li = args['time_range_to_save'].split() + time_range_to_save_li = args["time_range_to_save"].split() timeRangeToSave = [int(val) for val in time_range_to_save_li] with bioformats.ImageReader(raw_filepath, lazy_load=lazy_load) as reader: iter = enumerate(zip(channel_names, do_save_channels)) - pbar = tqdm( - total=len(channel_names), - ncols=100, - desc='Saving channels' - ) + pbar = tqdm(total=len(channel_names), ncols=100, desc="Saving channels") for c, (chName, saveCh) in iter: if not saveCh: pbar.update() continue bioformats._utils.saveImgDataChannel( - reader, series, images_path, filename_no_ext, s0p, - chName, c, {}, SizeT, SizeZ, TimeIncrement, PhysicalSizeZ, - PhysicalSizeY, PhysicalSizeX, to_h5, - timeRangeToSave - ) + reader, + series, + images_path, + filename_no_ext, + s0p, + chName, + c, + {}, + SizeT, + SizeZ, + TimeIncrement, + PhysicalSizeZ, + PhysicalSizeY, + PhysicalSizeX, + to_h5, + timeRangeToSave, + ) pbar.update() pbar.close() except Exception as err: args = vars(ap.parse_args()) - uuid4 = args['uuid'] - - bioformats._utils.dump_exception(err, uuid4) \ No newline at end of file + uuid4 = args["uuid"] + + bioformats._utils.dump_exception(err, uuid4) diff --git a/cellacdc/acdc_bioio_bioformats/_save_data_single_channel.py b/cellacdc/acdc_bioio_bioformats/_save_data_single_channel.py index a96984a83..ec6943959 100644 --- a/cellacdc/acdc_bioio_bioformats/_save_data_single_channel.py +++ b/cellacdc/acdc_bioio_bioformats/_save_data_single_channel.py @@ -6,7 +6,7 @@ import h5py from cellacdc import bioio_sample_data_folderpath -from cellacdc import myutils +from cellacdc import utils from cellacdc import acdc_bioio_bioformats as bioformats import argparse @@ -15,172 +15,186 @@ try: ap.add_argument( - '-f', - '--filepath', - required=True, - type=str, - metavar='FILEPATH', - help='Filepath of the raw microscopy file.' + "-f", + "--filepath", + required=True, + type=str, + metavar="FILEPATH", + help="Filepath of the raw microscopy file.", ) ap.add_argument( - '-d', - '--do_save_channels', + "-d", + "--do_save_channels", type=str, - required=True, - metavar='DO_SAVE_CHANNELS', - help='Whether to save the channel or not.' + required=True, + metavar="DO_SAVE_CHANNELS", + help="Whether to save the channel or not.", ) ap.add_argument( - '-c', - '--channel_name', - type=str, - required=True, - metavar='CHANNEL_NAMES', - help='Channel name' + "-c", + "--channel_name", + type=str, + required=True, + metavar="CHANNEL_NAMES", + help="Channel name", ) ap.add_argument( - '-ch_idx', - '--ch_idx', - required=True, - type=int, - metavar='CH_IDX', - help='Index of the channel.' + "-ch_idx", + "--ch_idx", + required=True, + type=int, + metavar="CH_IDX", + help="Index of the channel.", ) ap.add_argument( - '-z', - '--SizeZ', - required=True, - type=int, - metavar='SIZEZ', - help='Number of z-slices in a single z-stack.' + "-z", + "--SizeZ", + required=True, + type=int, + metavar="SIZEZ", + help="Number of z-slices in a single z-stack.", ) ap.add_argument( - '-s', - '--series_idx', - required=True, - type=int, - metavar='SERIES_IDX', - help='Index of the Position in the microscopy file.' + "-s", + "--series_idx", + required=True, + type=int, + metavar="SERIES_IDX", + help="Index of the Position in the microscopy file.", ) ap.add_argument( - '-i', - '--images_path', - required=True, - type=str, - metavar='IMAGE_PATH', - help='Images folder path.' + "-i", + "--images_path", + required=True, + type=str, + metavar="IMAGE_PATH", + help="Images folder path.", ) ap.add_argument( - '-p', - '--filename_no_ext', - required=True, - type=str, - metavar='FILENAME_NO_EXT', - help='Name of the file without extension.' + "-p", + "--filename_no_ext", + required=True, + type=str, + metavar="FILENAME_NO_EXT", + help="Name of the file without extension.", ) ap.add_argument( - '-pos', - '--pos_idx_str', - required=True, - type=str, - metavar='POS_IDX_STR', - help='String index of the Position padded with required zeros.' + "-pos", + "--pos_idx_str", + required=True, + type=str, + metavar="POS_IDX_STR", + help="String index of the Position padded with required zeros.", ) ap.add_argument( - '-t', - '--SizeT', - required=True, - type=int, - metavar='SIZET', - help='Number of timepoints in the microscopy file.' + "-t", + "--SizeT", + required=True, + type=int, + metavar="SIZET", + help="Number of timepoints in the microscopy file.", ) ap.add_argument( - '-time_increment', - '--time_increment', - type=float, - required=True, - metavar='TIME_INCREMENT', - help='Time between consecutive frames in seconds.' + "-time_increment", + "--time_increment", + type=float, + required=True, + metavar="TIME_INCREMENT", + help="Time between consecutive frames in seconds.", ) ap.add_argument( - '-zyx', - '--zyx_physical_sizes', + "-zyx", + "--zyx_physical_sizes", type=str, - required=True, - metavar='ZYX_PHYSICAL_SIZES', - help='Physical sizes in z, y, x dimensions.' + required=True, + metavar="ZYX_PHYSICAL_SIZES", + help="Physical sizes in z, y, x dimensions.", ) ap.add_argument( - '-to_h5', - '--to_h5', - action='store_true', - help='Whether to save with h5 file format.' + "-to_h5", + "--to_h5", + action="store_true", + help="Whether to save with h5 file format.", ) ap.add_argument( - '-r', - '--time_range_to_save', + "-r", + "--time_range_to_save", type=str, - required=True, - metavar='TIME_RANGE_TO_SAVE', - help='Start and end frame to save.' + required=True, + metavar="TIME_RANGE_TO_SAVE", + help="Start and end frame to save.", ) - + ap.add_argument( - '-a', - '--all', - action='store_true', - help='Whether to read entire position into RAM or not.' + "-a", + "--all", + action="store_true", + help="Whether to read entire position into RAM or not.", ) args = vars(ap.parse_args()) - raw_filepath = args['filepath'] - do_save_channels_li = args['do_save_channels'].split() - do_save_channels = [val=='True' for val in do_save_channels_li] - - channel_name = args['channel_name'] - ch_idx = args['ch_idx'] - series = args['series_idx'] - images_path = args['images_path'] - filename_no_ext = args['filename_no_ext'] - SizeT = args['SizeT'] - SizeZ = args['SizeZ'] - TimeIncrement = args['time_increment'] - s0p = args['pos_idx_str'] - - lazy_load = not args['all'] - - zyx_physical_sizes_li = args['zyx_physical_sizes'].split() + raw_filepath = args["filepath"] + do_save_channels_li = args["do_save_channels"].split() + do_save_channels = [val == "True" for val in do_save_channels_li] + + channel_name = args["channel_name"] + ch_idx = args["ch_idx"] + series = args["series_idx"] + images_path = args["images_path"] + filename_no_ext = args["filename_no_ext"] + SizeT = args["SizeT"] + SizeZ = args["SizeZ"] + TimeIncrement = args["time_increment"] + s0p = args["pos_idx_str"] + + lazy_load = not args["all"] + + zyx_physical_sizes_li = args["zyx_physical_sizes"].split() zyx_physical_sizes = [float(val) for val in zyx_physical_sizes_li] PhysicalSizeZ, PhysicalSizeY, PhysicalSizeX = zyx_physical_sizes - to_h5 = args['to_h5'] + to_h5 = args["to_h5"] - time_range_to_save_li = args['time_range_to_save'].split() + time_range_to_save_li = args["time_range_to_save"].split() timeRangeToSave = [int(val) for val in time_range_to_save_li] with bioformats.ImageReader(raw_filepath, lazy_load=lazy_load) as reader: - print(f'Saving channel {ch_idx+1}/{len(do_save_channels)} ({channel_name})...') + print( + f"Saving channel {ch_idx + 1}/{len(do_save_channels)} ({channel_name})..." + ) bioformats._utils.saveImgDataChannel( - reader, series, images_path, filename_no_ext, s0p, - channel_name, 0, {}, SizeT, SizeZ, TimeIncrement, PhysicalSizeZ, - PhysicalSizeY, PhysicalSizeX, to_h5, - timeRangeToSave + reader, + series, + images_path, + filename_no_ext, + s0p, + channel_name, + 0, + {}, + SizeT, + SizeZ, + TimeIncrement, + PhysicalSizeZ, + PhysicalSizeY, + PhysicalSizeX, + to_h5, + timeRangeToSave, ) except Exception as err: args = vars(ap.parse_args()) - uuid4 = args['uuid'] - - bioformats._utils.dump_exception(err, uuid4) \ No newline at end of file + uuid4 = args["uuid"] + + bioformats._utils.dump_exception(err, uuid4) diff --git a/cellacdc/acdc_bioio_bioformats/_utils.py b/cellacdc/acdc_bioio_bioformats/_utils.py index 56c149630..d18604f07 100644 --- a/cellacdc/acdc_bioio_bioformats/_utils.py +++ b/cellacdc/acdc_bioio_bioformats/_utils.py @@ -11,121 +11,112 @@ import numpy as np import h5py -from cellacdc import myutils, bioio_sample_data_folderpath +from cellacdc import utils, bioio_sample_data_folderpath + def setup_argparser(): ap = argparse.ArgumentParser( - prog='Cell-ACDC process', - description='Used to spawn a separate process', - formatter_class=argparse.RawTextHelpFormatter + prog="Cell-ACDC process", + description="Used to spawn a separate process", + formatter_class=argparse.RawTextHelpFormatter, ) ap.add_argument( - '-uuid', - '--uuid4', - required=False, - type=str, - metavar='UUID4', - help='String ID to use to store error for current session.', - default='42' + "-uuid", + "--uuid4", + required=False, + type=str, + metavar="UUID4", + help="String ID to use to store error for current session.", + default="42", ) return ap + def removeInvalidCharacters(chName_in): # Remove invalid charachters chName = "".join( - c if c.isalnum() or c=='_' or c=='' else '_' for c in chName_in + c if c.isalnum() or c == "_" or c == "" else "_" for c in chName_in ) - trim_ = chName.endswith('_') + trim_ = chName.endswith("_") while trim_: chName = chName[:-1] - trim_ = chName.endswith('_') + trim_ = chName.endswith("_") + -def getFilename( - filenameNOext, s0p, appendTxt, series, ext, - return_basename=False - ): +def getFilename(filenameNOext, s0p, appendTxt, series, ext, return_basename=False): # Do not allow dots in the filename since it breaks stuff here and there - filenameNOext = filenameNOext.replace('.', '_') - basename = f'{filenameNOext}_s{s0p}_' - filename = f'{basename}{appendTxt}{ext}' + filenameNOext = filenameNOext.replace(".", "_") + basename = f"{filenameNOext}_s{s0p}_" + filename = f"{basename}{appendTxt}{ext}" if return_basename: return filename, basename else: return filename + def saveImgDataChannel( - reader, - series: int, - images_path: os.PathLike, - filenameNOext: str, - s0p: str, - chName: str, - ch_idx: int, - idxs: dict, - SizeT: int, - SizeZ: int, - TimeIncrement: float, - PhysicalSizeZ: float, - PhysicalSizeY: float, - PhysicalSizeX: float, - to_h5: bool, - timeRangeToSave: Tuple[int, int], - ): + reader, + series: int, + images_path: os.PathLike, + filenameNOext: str, + s0p: str, + chName: str, + ch_idx: int, + idxs: dict, + SizeT: int, + SizeZ: int, + TimeIncrement: float, + PhysicalSizeZ: float, + PhysicalSizeY: float, + PhysicalSizeX: float, + to_h5: bool, + timeRangeToSave: Tuple[int, int], +): savedSizeT = timeRangeToSave[1] - timeRangeToSave[0] + 1 if to_h5: - filename = getFilename( - filenameNOext, s0p, chName, series, '.h5' - ) + filename = getFilename(filenameNOext, s0p, chName, series, ".h5") tempDir = tempfile.mkdtemp() tempFilepath = os.path.join(tempDir, filename) - print('==========================================================') + print("==========================================================") print(f'.h5 tempfile: "{tempFilepath}"') - print('==========================================================') - h5f = h5py.File(tempFilepath, 'w') + print("==========================================================") + h5f = h5py.File(tempFilepath, "w") # Read SizeX and SizeY from the shape of one image - imgData = reader.read( - c=ch_idx, z=0, t=0, series=series, rescale=False - ) + imgData = reader.read(c=ch_idx, z=0, t=0, series=series, rescale=False) shape = (savedSizeT, SizeZ, *imgData.shape) - chunks = (1,1,*imgData.shape) + chunks = (1, 1, *imgData.shape) imgData_ch = h5f.create_dataset( - 'data', shape, dtype=imgData.dtype, - chunks=chunks, shuffle=False + "data", shape, dtype=imgData.dtype, chunks=chunks, shuffle=False ) else: - filename = getFilename( - filenameNOext, s0p, chName, series, '.tif' - ) + filename = getFilename(filenameNOext, s0p, chName, series, ".tif") imgData_ch = [] - framesRange = range(timeRangeToSave[0]-1, timeRangeToSave[1]) + framesRange = range(timeRangeToSave[0] - 1, timeRangeToSave[1]) filePath = os.path.join(images_path, filename) - dimsIdx = {'c': ch_idx} + dimsIdx = {"c": ch_idx} numFrames = len(framesRange) - num_imgs = numFrames*SizeZ + num_imgs = numFrames * SizeZ pbar = tqdm( - total=num_imgs, - ncols=100, - desc=f'Reading image (z 0/{SizeZ}, t 0/{numFrames})' + total=num_imgs, ncols=100, desc=f"Reading image (z 0/{SizeZ}, t 0/{numFrames})" ) for out_t, t in enumerate(framesRange): imgData_z = [] - dimsIdx['t'] = t + dimsIdx["t"] = t for z in range(SizeZ): pbar.set_description( - f'Reading image (z {z+1}/{SizeZ}, t {out_t+1}/{numFrames})' + f"Reading image (z {z + 1}/{SizeZ}, t {out_t + 1}/{numFrames})" ) - dimsIdx['z'] = z + dimsIdx["z"] = z idx = None imgData = reader.read( - c=ch_idx, z=z, t=t, series=series, rescale=False, - index=idx + c=ch_idx, z=z, t=t, series=series, rescale=False, index=idx ) if to_h5: imgData_ch[out_t, z] = imgData else: imgData_z.append(imgData) - + pbar.update() if not to_h5: @@ -135,8 +126,9 @@ def saveImgDataChannel( if not to_h5: imgData_ch = np.squeeze(np.array(imgData_ch, dtype=imgData.dtype)) - myutils.to_tiff( - filePath, imgData_ch, + utils.to_tiff( + filePath, + imgData_ch, SizeT=savedSizeT, SizeZ=SizeZ, TimeIncrement=TimeIncrement, @@ -149,25 +141,25 @@ def saveImgDataChannel( shutil.move(tempFilepath, filePath) shutil.rmtree(tempDir) + def dump_exception(err, error_id): import pickle - error_path = os.path.join( - bioio_sample_data_folderpath, f'error_{error_id}.pkl' - ) - with open(error_path, 'wb') as file: + + error_path = os.path.join(bioio_sample_data_folderpath, f"error_{error_id}.pkl") + with open(error_path, "wb") as file: pickle.dump(err, file) + def check_raise_exception(error_id): import pickle - error_path = os.path.join( - bioio_sample_data_folderpath, f'error_{error_id}.pkl' - ) + + error_path = os.path.join(bioio_sample_data_folderpath, f"error_{error_id}.pkl") if not os.path.exists(error_path): return - + with open(error_path, "rb") as file: err = pickle.load(file) - + os.remove(error_path) - - raise err \ No newline at end of file + + raise err diff --git a/cellacdc/acdc_bioio_bioformats/install.py b/cellacdc/acdc_bioio_bioformats/install.py index 8f8ddd045..3be680800 100644 --- a/cellacdc/acdc_bioio_bioformats/install.py +++ b/cellacdc/acdc_bioio_bioformats/install.py @@ -2,65 +2,63 @@ import re -from cellacdc import myutils +from cellacdc import utils from . import EXTENSION_PACKAGE_MAPPER -pkg_regex = r'[a-zA-Z0-9_\-]+' +pkg_regex = r"[a-zA-Z0-9_\-]+" -def _check_install_bioio_bioformats(qparent=None): - myutils.check_install_package( - 'scyjava', - installer='conda', + +def _check_install_bioio_bioformats(qparent=None): + utils.check_install_package( + "scyjava", + installer="conda", is_cli=qparent is None, - exact_version='1.10.2', - parent=qparent + exact_version="1.10.2", + parent=qparent, ) - - myutils.check_install_package( - 'bioio-bioformats', - installer='pip', + + utils.check_install_package( + "bioio-bioformats", + installer="pip", is_cli=qparent is None, - min_version='1.0.0', - max_version='2.0.0', + min_version="1.0.0", + max_version="2.0.0", include_higher_version=False, include_lower_version=True, - parent=qparent + parent=qparent, ) - + return True -def _check_install_extra_format_dependency( - image_filepath: os.PathLike, - qparent=None - ): - - if image_filepath.endswith('.ome.tiff'): - ext = '.ome.tiff' + +def _check_install_extra_format_dependency(image_filepath: os.PathLike, qparent=None): + + if image_filepath.endswith(".ome.tiff"): + ext = ".ome.tiff" else: _, ext = os.path.splitext(image_filepath) package_name = EXTENSION_PACKAGE_MAPPER.get(ext) - + if package_name is None: _check_install_bioio_bioformats(qparent=qparent) return - - myutils.check_install_package( + + utils.check_install_package( package_name, - installer='pip', + installer="pip", is_cli=qparent is None, parent=qparent, ) + def install_reader_dependencies( - image_filepath: os.PathLike, - exception: Exception, - qparent=None - ): + image_filepath: os.PathLike, exception: Exception, qparent=None +): try: success = _check_install_extra_format_dependency( image_filepath, qparent=qparent ) - + except Exception as err: - raise exception \ No newline at end of file + raise exception diff --git a/cellacdc/acdc_bioio_bioformats/reader.py b/cellacdc/acdc_bioio_bioformats/reader.py index 03a823bd8..a7fc4d2e9 100644 --- a/cellacdc/acdc_bioio_bioformats/reader.py +++ b/cellacdc/acdc_bioio_bioformats/reader.py @@ -4,157 +4,169 @@ import numpy as np from .. import printl -from ..myutils import safe_get_or_call +from ..utils import safe_get_or_call from . import install, EXTENSION_PACKAGE_MAPPER from . import EXTENSION_BIOIMAGE_KWARGS_MAPPER from . import EXTENSION_METADATA_ATTR_MAPPER + def set_reader(image_filepath, **kwargs): - if 'reader' in kwargs: + if "reader" in kwargs: return kwargs - + _, ext = os.path.splitext(image_filepath) if ext in EXTENSION_PACKAGE_MAPPER: - all_kwargs = { - **kwargs, - **EXTENSION_BIOIMAGE_KWARGS_MAPPER.get(ext, {}) - } + all_kwargs = {**kwargs, **EXTENSION_BIOIMAGE_KWARGS_MAPPER.get(ext, {})} return all_kwargs - + try: import bioio_bioformats - kwargs['reader'] = bioio_bioformats.Reader + + kwargs["reader"] = bioio_bioformats.Reader except ImportError: from bioio_base.exceptions import UnsupportedFileFormatError + raise UnsupportedFileFormatError( - 'Bioformats', 'Bioformats reader is not installed' + "Bioformats", "Bioformats reader is not installed" ) - + return kwargs + class ImageReader: def __init__( - self, image_filepath: os.PathLike, qparent=None, lazy_load=True, - **kwargs - ): + self, image_filepath: os.PathLike, qparent=None, lazy_load=True, **kwargs + ): from bioio import BioImage from bioio_base.exceptions import UnsupportedFileFormatError - + self._image_filepath = image_filepath - + # Capture BioImage error and install required dependencies try: kwargs = set_reader(image_filepath, **kwargs) self._bioioimage = BioImage(image_filepath, **kwargs) except UnsupportedFileFormatError as err: - install.install_reader_dependencies( - image_filepath, err, - qparent=qparent - ) + install.install_reader_dependencies(image_filepath, err, qparent=qparent) kwargs = set_reader(image_filepath, **kwargs) self._bioioimage = BioImage(image_filepath, **kwargs) - + self._is_lazy_load = lazy_load - + if lazy_load: return - + self.img_data = self._bioioimage.data - + def read(self, c=0, z=0, t=0, rescale=False, index=None, series=0): if self._bioioimage.current_scene_index != series: self._bioioimage.set_scene(series) if not self._is_lazy_load: self.img_data = self._bioioimage.data - + if self._is_lazy_load: lazy_img = self._bioioimage.get_image_dask_data("YX", T=t, C=c, Z=z) return lazy_img.compute() - + return self.img_data[t, c, z] - + def __enter__(self): return self - + def __exit__(self, exc_type, exc_value, traceback): return + class Metadata: def __init__(self): pass - + def to_file(self, filepath): - with open(filepath, 'w') as file: + with open(filepath, "w") as file: file.write(str(self)) - + def init_from_image_filepath(self, image_filepath, qparent=None): self.image_filepath = image_filepath self.qparent = qparent - + with ImageReader(image_filepath, qparent=qparent) as bioio_image: self.metadata = bioio_image._bioioimage.metadata - + return self def init_from_file(self, filepath): - with open(filepath, 'r') as file: + with open(filepath, "r") as file: self.metadata = file.read() - + def __str__(self): return str(self.metadata) + class Channel: pass + class Node: def __init__(self, image_filepath, bioimage_class): _, ext = os.path.splitext(image_filepath) try: self._node = { - 'TimeIncrement': bioimage_class.time_interval.total_seconds(), - 'TimeIncrementUnit': 's' + "TimeIncrement": bioimage_class.time_interval.total_seconds(), + "TimeIncrementUnit": "s", } except Exception as err: self._node = {} - + if ext not in EXTENSION_METADATA_ATTR_MAPPER: return - + name_expression_mapper = EXTENSION_METADATA_ATTR_MAPPER[ext] for name, expression in name_expression_mapper.items(): try: self._node[name] = safe_get_or_call(bioimage_class, expression) except Exception as err: self._node[name] = None - + def get(self, name): value = self._node.get(name) if value is None: raise ValueError(f"Node '{name}' not found in metadata.") - + return value -class Pixels: + +class Pixels: def Channel(self, c: int): channel = Channel() channel.Name = self.channel_names[c] return channel + def get_omexml_metadata(image_filepath, qparent=None): return Metadata().init_from_image_filepath(image_filepath, qparent=None) + class PhysicalPixelSizes: def __init__(self, PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ): self.X = PhysicalSizeX self.Y = PhysicalSizeY self.Z = PhysicalSizeZ + class BioImageMetadata: def __init__( - self, SizeT, SizeC, SizeZ, SizeY, SizeX, - PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ, - channel_names, image_count - ): + self, + SizeT, + SizeC, + SizeZ, + SizeY, + SizeX, + PhysicalSizeX, + PhysicalSizeY, + PhysicalSizeZ, + channel_names, + image_count, + ): self.shape = (SizeT, SizeC, SizeZ, SizeY, SizeX) self.physical_pixel_sizes = PhysicalPixelSizes( PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ @@ -162,99 +174,107 @@ def __init__( self.channel_names = channel_names self.scenes = list(range(image_count)) + class OMEXML: def __init__(self): self.qparent = None - + def init_from_metadata(self, metadata: Metadata): self.image_filepath = metadata.image_filepath self.qparent = metadata.qparent - + image_filepath = self.image_filepath qparent = self.qparent - + with ImageReader(image_filepath, qparent=qparent) as bioio_image: self.bioimage = bioio_image._bioioimage - + self._init_Pixels(image_filepath) - + return self - + def _init_Pixels(self, image_filepath): self.Pixels = Pixels() self.Pixels.node = Node(image_filepath, self.bioimage) - self.Pixels.channel_names = self.bioimage.channel_names - + self.Pixels.channel_names = self.bioimage.channel_names + def __str__(self): self.image() txt = ( - f'Image: {self.image_filepath}\n' - f'Channels: {self.Pixels.channel_names}\n' - f'SizeC: {self.Pixels.SizeC}\n' - f'SizeT: {self.Pixels.SizeT}\n' - f'SizeZ: {self.Pixels.SizeZ}\n' - f'SizeY: {self.Pixels.SizeY}\n' - f'SizeX: {self.Pixels.SizeX}\n' - f'PhysicalSizeX: {self.bioimage.physical_pixel_sizes.X}\n' - f'PhysicalSizeY: {self.bioimage.physical_pixel_sizes.Y}\n' - f'PhysicalSizeZ: {self.bioimage.physical_pixel_sizes.Z}\n' - f'Image count: {self.get_image_count()}' + f"Image: {self.image_filepath}\n" + f"Channels: {self.Pixels.channel_names}\n" + f"SizeC: {self.Pixels.SizeC}\n" + f"SizeT: {self.Pixels.SizeT}\n" + f"SizeZ: {self.Pixels.SizeZ}\n" + f"SizeY: {self.Pixels.SizeY}\n" + f"SizeX: {self.Pixels.SizeX}\n" + f"PhysicalSizeX: {self.bioimage.physical_pixel_sizes.X}\n" + f"PhysicalSizeY: {self.bioimage.physical_pixel_sizes.Y}\n" + f"PhysicalSizeZ: {self.bioimage.physical_pixel_sizes.Z}\n" + f"Image count: {self.get_image_count()}" ) return txt - + def to_file(self, filepath): - with open(filepath, 'w') as file: + with open(filepath, "w") as file: file.write(str(self)) - + def init_from_file(self, filepath, image_filepath): - with open(filepath, 'r') as file: + with open(filepath, "r") as file: txt = file.read() - + keys_dtype_kwarg_mapper = { - 'Image': (str, 'image_filepath', ''), - 'Channels': (eval, 'channel_names', ['ch0']), - 'SizeC': (int, 'SizeC', 1), - 'SizeT': (int, 'SizeT', 1), - 'SizeZ': (int, 'SizeZ', 1), - 'SizeY': (int, 'SizeY', 1), - 'SizeX': (int, 'SizeX', 1), - 'PhysicalSizeX': (float, 'PhysicalSizeX', 1.0), - 'PhysicalSizeY': (float, 'PhysicalSizeY', 1.0), - 'PhysicalSizeZ': (float, 'PhysicalSizeZ', 1.0), - 'Image count': (int, 'image_count', 1.0), + "Image": (str, "image_filepath", ""), + "Channels": (eval, "channel_names", ["ch0"]), + "SizeC": (int, "SizeC", 1), + "SizeT": (int, "SizeT", 1), + "SizeZ": (int, "SizeZ", 1), + "SizeY": (int, "SizeY", 1), + "SizeX": (int, "SizeX", 1), + "PhysicalSizeX": (float, "PhysicalSizeX", 1.0), + "PhysicalSizeY": (float, "PhysicalSizeY", 1.0), + "PhysicalSizeZ": (float, "PhysicalSizeZ", 1.0), + "Image count": (int, "image_count", 1.0), } for key, (dtype, kwarg, default) in keys_dtype_kwarg_mapper.items(): - value = re.search(f'{key}: (.+)', txt).group(1) + value = re.search(f"{key}: (.+)", txt).group(1) print(key, value, type(value)) try: setattr(self, kwarg, dtype(value)) except Exception as err: setattr(self, kwarg, default) - + self.bioimage = BioImageMetadata( - self.SizeT, self.SizeC, self.SizeZ, self.SizeY, self.SizeX, - self.PhysicalSizeX, self.PhysicalSizeY, self.PhysicalSizeZ, - self.channel_names, self.image_count + self.SizeT, + self.SizeC, + self.SizeZ, + self.SizeY, + self.SizeX, + self.PhysicalSizeX, + self.PhysicalSizeY, + self.PhysicalSizeZ, + self.channel_names, + self.image_count, ) - + self._init_Pixels(image_filepath) - + return self - - def image(self): + + def image(self): SizeT, SizeC, SizeZ, SizeY, SizeX = self.bioimage.shape - + self.Pixels.SizeY = SizeY self.Pixels.SizeX = SizeX self.Pixels.SizeZ = SizeZ self.Pixels.SizeT = SizeT self.Pixels.SizeC = SizeC - + self.Pixels.PhysicalSizeX = self.bioimage.physical_pixel_sizes.X self.Pixels.PhysicalSizeY = self.bioimage.physical_pixel_sizes.Y self.Pixels.PhysicalSizeZ = self.bioimage.physical_pixel_sizes.Z - + return self - + def get_image_count(self): - return len(self.bioimage.scenes) \ No newline at end of file + return len(self.bioimage.scenes) diff --git a/cellacdc/acdc_regex.py b/cellacdc/acdc_regex.py index fc355f4db..28f79c03b 100644 --- a/cellacdc/acdc_regex.py +++ b/cellacdc/acdc_regex.py @@ -1,46 +1,53 @@ import re -RE_SPLIT_SPACES_IGNORE_QUOTES = re.compile(r'''((?:[^ "']|"[^"]*"|'[^']*')+)''') +RE_SPLIT_SPACES_IGNORE_QUOTES = re.compile(r"""((?:[^ "']|"[^"]*"|'[^']*')+)""") -def float_regex(allow_negative=True, left_chars='', include_nan=False): - pattern = r'[-+]?[0-9]*\.?[0-9]*[eE]?[\-+]?[0-9]+' + +def float_regex(allow_negative=True, left_chars="", include_nan=False): + pattern = r"[-+]?[0-9]*\.?[0-9]*[eE]?[\-+]?[0-9]+" if left_chars: - pattern = fr'{left_chars}{pattern}' + pattern = rf"{left_chars}{pattern}" if not allow_negative: - pattern.replace('[-+]?', '[+]?') + pattern.replace("[-+]?", "[+]?") if include_nan: - nan_pattern = r'NAN|Nan|NaN|nan' - pattern = fr'{nan_pattern}|{pattern}' + nan_pattern = r"NAN|Nan|NaN|nan" + pattern = rf"{nan_pattern}|{pattern}" return pattern -def to_alphanumeric(text, replacing_char='_'): - return re.sub(r'[^\w\-.]', '_', text) + +def to_alphanumeric(text, replacing_char="_"): + return re.sub(r"[^\w\-.]", "_", text) + def get_function_names(text, include_class_methods=True): if include_class_methods: - pattern = r'\bdef\s+([a-zA-Z_]\w*)\s*\(' + pattern = r"\bdef\s+([a-zA-Z_]\w*)\s*\(" else: - pattern = r'\ndef\s+([a-zA-Z_]\w*)\s*\(' + pattern = r"\ndef\s+([a-zA-Z_]\w*)\s*\(" return re.findall(pattern, text) + def is_alphanumeric_filename(text, allow_space=True): if allow_space: - pattern = r'^[\w\-_. ]+$' + pattern = r"^[\w\-_. ]+$" else: - pattern = r'^[\w\-_.]+$' - is_single_or_no_dot = len(re.findall(r'\.', text)) <= 1 + pattern = r"^[\w\-_.]+$" + is_single_or_no_dot = len(re.findall(r"\.", text)) <= 1 return bool(re.match(pattern, text)) and is_single_or_no_dot + def get_non_alphanumeric_characters(text): - return re.findall(r'[^\w\-.]', text) - -if __name__ == '__main__': + return re.findall(r"[^\w\-.]", text) + + +if __name__ == "__main__": import re - s = '0.5, 2.5, nan, NaN' - expr = fr'{float_regex(include_nan=True)}' - m = re.findall(expr, s.replace(' ', '')) + + s = "0.5, 2.5, nan, NaN" + expr = rf"{float_regex(include_nan=True)}" + m = re.findall(expr, s.replace(" ", "")) print(m) - - s = 'ciao_ciao_-yessa' - - print(is_alphanumeric_filename(s)) \ No newline at end of file + + s = "ciao_ciao_-yessa" + + print(is_alphanumeric_filename(s)) diff --git a/cellacdc/annotate.py b/cellacdc/annotate.py index ec09b3610..8cefd95fc 100644 --- a/cellacdc/annotate.py +++ b/cellacdc/annotate.py @@ -11,129 +11,134 @@ from PIL import Image, ImageFont, ImageDraw from qtpy.QtGui import QFont import pyqtgraph as pg - pg.setConfigOption('imageAxisOrder', 'row-major') - + + pg.setConfigOption("imageAxisOrder", "row-major") + from . import plot -INVERTIBLE_COLOR_NAMES = [ - 'label', 'S_phase_mother', 'G1_phase' -] -FONT_FAMILY = 'Helvetica' +INVERTIBLE_COLOR_NAMES = ["label", "S_phase_mother", "G1_phase"] +FONT_FAMILY = "Helvetica" font_path = os.path.join( - cellacdc_path, 'resources', 'fonts', f'{FONT_FAMILY}-Regular.ttf') + cellacdc_path, "resources", "fonts", f"{FONT_FAMILY}-Regular.ttf" +) font_bold_path = os.path.join( - cellacdc_path, 'resources', 'fonts', f'{FONT_FAMILY}-Bold.ttf' + cellacdc_path, "resources", "fonts", f"{FONT_FAMILY}-Bold.ttf" ) + def get_obj_text_label_annot( - obj, acdc_df: pd.DataFrame, is_tree_annot: bool, add_num_zslices: bool - ) -> str: + obj, acdc_df: pd.DataFrame, is_tree_annot: bool, add_num_zslices: bool +) -> str: if is_tree_annot and acdc_df is not None: try: - annot_label = acdc_df.at[obj.label, 'Cell_ID_tree'] + annot_label = acdc_df.at[obj.label, "Cell_ID_tree"] except Exception as err: # print(traceback.format_exc()) annot_label = obj.label else: annot_label = obj.label - + if not add_num_zslices: return str(annot_label) - - num_z_slices = np.sum(np.any(obj.image, axis=(1,2))) - return f'{annot_label} ({num_z_slices})' -def get_obj_text_cca_annot( - obj, acdc_df: pd.DataFrame, is_tree_annot: bool - ) -> str: + num_z_slices = np.sum(np.any(obj.image, axis=(1, 2))) + return f"{annot_label} ({num_z_slices})" + + +def get_obj_text_cca_annot(obj, acdc_df: pd.DataFrame, is_tree_annot: bool) -> str: ID = obj.label try: cca_df_obj = acdc_df.loc[ID] except Exception as e: return str(ID), None - + try: - ccs = cca_df_obj['cell_cycle_stage'] + ccs = cca_df_obj["cell_cycle_stage"] except Exception as err: - return str(ID), None + return str(ID), None try: - generation_num = int(cca_df_obj['generation_num']) + generation_num = int(cca_df_obj["generation_num"]) except Exception as e: return str(ID), None - - generation_num = 'ND' if generation_num==-1 else generation_num + + generation_num = "ND" if generation_num == -1 else generation_num if is_tree_annot: try: - generation_num = cca_df_obj['generation_num_tree'] + generation_num = cca_df_obj["generation_num_tree"] except Exception as e: generation_num = generation_num - txt = f'{ccs}-{generation_num}' + txt = f"{ccs}-{generation_num}" - is_history_known = cca_df_obj['is_history_known'] + is_history_known = cca_df_obj["is_history_known"] if not is_history_known: - txt = f'{txt}?' + txt = f"{txt}?" return txt, cca_df_obj + def get_obj_text_annot_opts( - obj, acdc_df: pd.DataFrame, is_cca_annot: bool, is_new_obj: bool, - add_num_zslices: bool, is_label_tree_annot: bool, - is_gen_num_tree_annot: bool, frame_i: int - ) -> dict: + obj, + acdc_df: pd.DataFrame, + is_cca_annot: bool, + is_new_obj: bool, + add_num_zslices: bool, + is_label_tree_annot: bool, + is_gen_num_tree_annot: bool, + frame_i: int, +) -> dict: if acdc_df is None or not is_cca_annot: bold = False if is_new_obj: - color_name = 'new_object' + color_name = "new_object" else: - color_name = 'label' + color_name = "label" text = get_obj_text_label_annot( obj, acdc_df, is_label_tree_annot, add_num_zslices ) else: - text, cca_df_obj = get_obj_text_cca_annot( - obj, acdc_df, is_gen_num_tree_annot - ) + text, cca_df_obj = get_obj_text_cca_annot(obj, acdc_df, is_gen_num_tree_annot) if cca_df_obj is None: if is_new_obj: - color_name = 'new_object' + color_name = "new_object" else: - color_name = 'label' - opts = {'text': text, 'color_name': color_name, 'bold': False} + color_name = "label" + opts = {"text": text, "color_name": color_name, "bold": False} return opts - - ccs = cca_df_obj['cell_cycle_stage'] - relationship = cca_df_obj['relationship'] - is_bud = relationship == 'bud' - emerg_frame_i = int(cca_df_obj['emerg_frame_i']) + + ccs = cca_df_obj["cell_cycle_stage"] + relationship = cca_df_obj["relationship"] + is_bud = relationship == "bud" + emerg_frame_i = int(cca_df_obj["emerg_frame_i"]) bud_emerged_now = (emerg_frame_i == frame_i) and is_bud bold = bud_emerged_now # Check if it will divide to use orange instead of red bud_will_divide = False - if ccs == 'S' and is_bud: - bud_will_divide = cca_df_obj['will_divide'] > 0 + if ccs == "S" and is_bud: + bud_will_divide = cca_df_obj["will_divide"] > 0 if bud_will_divide: - color_name = 'bud_will_divide' - elif ccs == 'S': - if relationship == 'mother': - color_name = 'S_phase_mother' + color_name = "bud_will_divide" + elif ccs == "S": + if relationship == "mother": + color_name = "S_phase_mother" else: - color_name = 'S_phase_bud' - elif ccs == 'G1': - color_name = 'G1_phase' - - opts = {'text': text, 'color_name': color_name, 'bold': bold} + color_name = "S_phase_bud" + elif ccs == "G1": + color_name = "G1_phase" + + opts = {"text": text, "color_name": color_name, "bold": bold} return opts + class TextAnnotationsImageItem(pg.ImageItem): def __init__(self, **kargs): super().__init__(**kargs) - + def initFonts(self, fontSize): self.fontSize = fontSize self.fontBold = ImageFont.truetype(font_path, fontSize) @@ -143,80 +148,81 @@ def initFonts(self, fontSize): ) self.highlighterItem.initFonts(fontSize) self.highlighterItem.initSymbols(range(10)) - + def initSizes(self): pass - + def init(self, image_shape): shape = (*image_shape, 4) self.pilImage = Image.fromarray(np.zeros(shape, dtype=np.uint8)) self.pilDraw = ImageDraw.Draw(self.pilImage) - + def clearImage(self): - self.pilDraw.rectangle([(0,0), self.pilDraw.im.size], fill=(0,0,0,0)) - + self.pilDraw.rectangle([(0, 0), self.pilDraw.im.size], fill=(0, 0, 0, 0)) + def clearData(self): self.clearImage() self.setOpacity(1.0) self.highlighterItem.setData([], []) self.texts = [] self.annotData = [] - + def update(self): pass - + def appendData(self, data, text): self.annotData.append(data) self.texts.append(text) - + def highlightObject(self, obj): self.highlighterItem.texts = self.texts self.highlighterItem.highlightObject(obj) - + def grayOutAnnotations(self, IDsToSkip=None): self.setOpacity(0.3) - + def addObjAnnot(self, pos, draw=True, **objOpts): - if objOpts['bold']: + if objOpts["bold"]: font = self.fontBold else: font = self.fontRegular - - text = objOpts['text'] - color = self._colors[objOpts['color_name']] - self.pilDraw.text(pos, text, color, font=font, anchor='mm') - return objOpts - + + text = objOpts["text"] + color = self._colors[objOpts["color_name"]] + self.pilDraw.text(pos, text, color, font=font, anchor="mm") + return objOpts + def draw(self): super().setImage(np.array(self.pilImage)) def setColors(self, colors): self._colors = colors.copy() self.highlighterItem.setColors(colors) - + def initSymbols(self, allIDs): pass def colors(self): return self._colors + class TextAnnotationsScatterItem(pg.ScatterPlotItem): def __init__(self, *args, anchor=(0.5, 0.5), **kargs): super().__init__(*args, **kargs) - self.initFonts(kargs.get('size', 10)) + self.initFonts(kargs.get("size", 10)) self.texts = [] self.annotData = [] self._anchor = anchor - + def clearData(self): self.setData([], []) self.annotData = [] self.texts = [] - + def appendData(self, data, text): self.annotData.append(data) self.texts.append(text) - + def draw(self): super().setData(self.annotData) @@ -228,31 +234,31 @@ def initFonts(self, fontSize): self.fontRegular = QFont(FONT_FAMILY.lower()) self.fontRegular.setPixelSize(fontSize) - + def init(self, *args): pass def initSymbols(self, allIDs, onlyIDs=False): - annotTexts = ['?'] + annotTexts = ["?"] for ID in allIDs: annotTexts.append(str(ID)) if not onlyIDs: - annotTexts.append(f'{ID}?') - + annotTexts.append(f"{ID}?") + if not onlyIDs: for gen_num in range(20): - annotTexts.append(f'G1-{gen_num}') - annotTexts.append(f'G1-{gen_num}?') - annotTexts.append(f'S-{gen_num}') - annotTexts.append(f'S-{gen_num}?') - - if hasattr(self, 'symbolsBold'): + annotTexts.append(f"G1-{gen_num}") + annotTexts.append(f"G1-{gen_num}?") + annotTexts.append(f"S-{gen_num}") + annotTexts.append(f"S-{gen_num}?") + + if hasattr(self, "symbolsBold"): # Symbols already created in prev. session --> add missing ones self.addSymbols(annotTexts) else: # Symbols never created --> create now self.createSymbols(annotTexts) - + def addSymbols(self, annotTexts, includeBold=True): for text in annotTexts: if includeBold: @@ -275,11 +281,11 @@ def createSymbols(self, annotTexts, includeBold=True): ) self.scalesRegular = scalesRegular self.initSizes(includeBold=includeBold) - + def initSizes(self, includeBold=True): - if not hasattr(self, 'scalesBold'): + if not hasattr(self, "scalesBold"): includeBold = False - + if includeBold: self.sizesBold = plot.get_symbol_sizes( self.scalesBold, self.symbolsBold, self.fontSize @@ -287,7 +293,7 @@ def initSizes(self, includeBold=True): self.sizesRegular = plot.get_symbol_sizes( self.scalesRegular, self.symbolsRegular, self.fontSize ) - + def setColors(self, colors): self._colors = colors.copy() self._brushes = {} @@ -295,10 +301,10 @@ def setColors(self, colors): for name, color in self._colors.items(): self._brushes[name] = pg.mkBrush(color) self._pens[name] = pg.mkPen(color[:3], width=1) - + def pens(self): return self._pens - + def brushes(self): return self._brushes @@ -314,7 +320,7 @@ def getObjTextAnnotSymbol(self, text, bold=False, initSizes=True): symbols = self.symbolsRegular font = self.fontRegular scales = self.scalesRegular - + symbol = symbols.get(text) if symbol is not None: return symbol @@ -329,12 +335,12 @@ def getObjTextAnnotSymbol(self, text, bold=False, initSizes=True): return symbol def grayOutAnnotations(self, IDsToSkip=None): - brushes = [self._brushes['grayed'] for _ in range(len(self.data))] - pens = [self._pens['grayed'] for _ in range(len(self.data))] + brushes = [self._brushes["grayed"] for _ in range(len(self.data))] + pens = [self._pens["grayed"] for _ in range(len(self.data))] if IDsToSkip is not None: pointItems = self.points() for idx, objData in enumerate(self.data): - ID = objData['data'] + ID = objData["data"] doNotGray = IDsToSkip.get(ID, False) if not doNotGray: continue @@ -350,30 +356,28 @@ def highlightObject(self, obj): ID = obj.label objIdx = None for idx, objData in enumerate(self.data): - if ID == objData['data']: + if ID == objData["data"]: objIdx = idx break if objIdx is None: - objOpts = { - 'text': str(ID), 'bold': True, 'color_name': 'new_object' - } + objOpts = {"text": str(ID), "bold": True, "color_name": "new_object"} yc, xc = obj.centroid[-2:] pos = (int(xc), int(yc)) self.addObjAnnot(pos, draw=True, **objOpts) return - + pointItem = self.points()[objIdx] symbol = self.getObjTextAnnotSymbol(str(ID), bold=True) pointItem.setSymbol(symbol) - pointItem.setBrush(self._brushes['new_object']) - pointItem.setPen(self._pens['new_object']) + pointItem.setBrush(self._brushes["new_object"]) + pointItem.setPen(self._pens["new_object"]) def removeHighlightObject(self, obj): ID = obj.label objIdx = None for idx, objData in enumerate(self.data): - if ID == objData['data']: + if ID == objData["data"]: objIdx = idx break if objIdx is None: @@ -384,28 +388,28 @@ def removeHighlightObject(self, obj): default_symbol = self.getObjTextAnnotSymbol(str(ID), bold=False) pointItem.setSymbol(default_symbol) - pointItem.setBrush(self._brushes['label']) - pointItem.setPen(self._pens['label']) - + pointItem.setBrush(self._brushes["label"]) + pointItem.setPen(self._pens["label"]) + def modifyPosAnchor(self, pointOpts, anchor, symbol): if anchor is None: return pointOpts - + xa, ya = anchor if (xa, ya) == (0.5, 0.5): return pointOpts - + br = symbol.boundingRect() - xf = br.width()*(anchor[0]-0.5) - yf = br.height()*(anchor[1]-0.5) - x, y = pointOpts['pos'] - pointOpts['pos'] = (x-xf, y-yf) - - return pointOpts - - def addObjAnnot(self, pos, draw=False, anchor=None, **objOpts): - text = objOpts['text'] - bold = objOpts['bold'] + xf = br.width() * (anchor[0] - 0.5) + yf = br.height() * (anchor[1] - 0.5) + x, y = pointOpts["pos"] + pointOpts["pos"] = (x - xf, y - yf) + + return pointOpts + + def addObjAnnot(self, pos, draw=False, anchor=None, **objOpts): + text = objOpts["text"] + bold = objOpts["bold"] symbol = self.getObjTextAnnotSymbol(text, bold) if bold: @@ -413,20 +417,21 @@ def addObjAnnot(self, pos, draw=False, anchor=None, **objOpts): else: size = self.sizesRegular[text] - color_name = objOpts['color_name'] + color_name = objOpts["color_name"] pointOpts = {} - pointOpts['brush'] = self._brushes[color_name] - pointOpts['pen'] = self._pens[color_name] - pointOpts['symbol'] = symbol - pointOpts['size'] = size - pointOpts['pos'] = tuple(pos) + pointOpts["brush"] = self._brushes[color_name] + pointOpts["pen"] = self._pens[color_name] + pointOpts["symbol"] = symbol + pointOpts["size"] = size + pointOpts["pos"] = tuple(pos) pointOpts = self.modifyPosAnchor(pointOpts, anchor, symbol) if draw: self.addPoints([pointOpts]) - - return pointOpts + + return pointOpts + class TextAnnotations: def __init__(self): @@ -436,25 +441,23 @@ def __init__(self): self._isLabelTreeAnnotation = False self._isGenNumTreeAnnotation = False self._isGenNumTreeAnnotation = False - + def initFonts(self, fontSize): self.fontSize = fontSize - + def initItem(self, *args): self.item.init(*args) - + def clear(self): self.item.clear() - if hasattr(self.item, 'highlighterItem'): + if hasattr(self.item, "highlighterItem"): self.item.highlighterItem.setData([], []) - + def invertBlackAndWhite(self): - invertedColors = { - name:color[:3] for name, color in self.item.colors().items() - } + invertedColors = {name: color[:3] for name, color in self.item.colors().items()} for color_name in INVERTIBLE_COLOR_NAMES: color = self.item.colors()[color_name] - invertedColors[color_name] = tuple([255-val for val in color[:3]]) + invertedColors[color_name] = tuple([255 - val for val in color[:3]]) self.setColors(**invertedColors) @@ -463,106 +466,109 @@ def createItems(self, isHighResolution, allIDs, pxMode=False): if isHighResolution: self._createHighResolutionItems(allIDs, pxMode=pxMode) else: - self._createLowResolutionItem() - + self._createLowResolutionItem() + def _createLowResolutionItem(self): self.item = TextAnnotationsImageItem() self.setFontSize(self.fontSize, []) - + def _createHighResolutionItems(self, allIDs, pxMode=False): - self.item = TextAnnotationsScatterItem( - size=self.fontSize, pxMode=pxMode - ) + self.item = TextAnnotationsScatterItem(size=self.fontSize, pxMode=pxMode) self.setFontSize(self.fontSize, allIDs) - + def setFontSize(self, fontSize, allIDs): self.fontSize = fontSize self.item.initFonts(self.fontSize) self.item.initSymbols(allIDs) - + def changeFontSize(self, fontSize): self.fontSize = fontSize self.item.initFonts(fontSize) self.item.initSizes() - + def changeResolution(self, mode, allIDs, ax, img_shape): self.removeFromPlotItem(ax) - highRes = True if mode == 'high' else False + highRes = True if mode == "high" else False self.createItems(highRes, allIDs, pxMode=self._pxMode) self.initItem(img_shape) self.item.setColors(self.colors()) self.item.clearData() self.addToPlotItem(ax) - + def addToPlotItem(self, ax): ax.addItem(self.item) - if hasattr(self.item, 'highlighterItem'): + if hasattr(self.item, "highlighterItem"): ax.addItem(self.item.highlighterItem) def removeFromPlotItem(self, ax): ax.removeItem(self.item) - if hasattr(self.item, 'highlighterItem'): + if hasattr(self.item, "highlighterItem"): ax.removeItem(self.item.highlighterItem) - + def addObjAnnotation(self, obj, color_name, text, bold): objOpts = { - 'text': text, - 'bold': bold, - 'color_name': color_name, + "text": text, + "bold": bold, + "color_name": color_name, } yc, xc = obj.centroid[-2:] pos = (int(xc), int(yc)) objData = self.item.addObjAnnot(pos, draw=True, **objOpts) - self.item.appendData(objData, objOpts['text']) - + self.item.appendData(objData, objOpts["text"]) + def setAnnotations(self, **kwargs): if self.isDisabled(): return - + self.item.clearData() - - labelsToSkip = kwargs.get('labelsToSkip') - posData = kwargs['posData'] - delROIsIDs = kwargs.get('delROIsIDs', []) - isObjVisibleFunc = kwargs.get('isVisibleCheckFunc') - highlightedID = kwargs.get('highlightedID') - annotateLost = kwargs.get('annotateLost') - getCurrentZfunc = kwargs.get('getCurrentZfunc') - getObjCentroidFunc = kwargs.get('getObjCentroidFunc') + + labelsToSkip = kwargs.get("labelsToSkip") + posData = kwargs["posData"] + delROIsIDs = kwargs.get("delROIsIDs", []) + isObjVisibleFunc = kwargs.get("isVisibleCheckFunc") + highlightedID = kwargs.get("highlightedID") + annotateLost = kwargs.get("annotateLost") + getCurrentZfunc = kwargs.get("getCurrentZfunc") + getObjCentroidFunc = kwargs.get("getObjCentroidFunc") currentZ = getCurrentZfunc(checkIfProj=True) isCcaAnnot = self.isCcaAnnot() isAnnotateNumZslices = self.isAnnotateNumZslices() isLabelTreeAnnotation = self.isLabelTreeAnnotation() isGenNumTreeAnnotation = self.isGenNumTreeAnnotation() - - acdc_df = posData.allData_li[posData.frame_i]['acdc_df'] + + acdc_df = posData.allData_li[posData.frame_i]["acdc_df"] if posData.cca_df is not None and acdc_df is not None: cols = posData.cca_df.columns idx = posData.cca_df.index.intersection(acdc_df.index) acdc_df.loc[idx, cols] = posData.cca_df - + if acdc_df is None and posData.cca_df is not None: acdc_df = posData.cca_df - + for obj in posData.rp: if labelsToSkip is not None: if labelsToSkip.get(obj.label, False): continue - + if not isObjVisibleFunc(obj.bbox): continue - + if obj.label in delROIsIDs: continue isNewObject = obj.label in posData.new_IDs - + objOpts = get_obj_text_annot_opts( - obj, acdc_df, isCcaAnnot, isNewObject, - isAnnotateNumZslices, isLabelTreeAnnotation, - isGenNumTreeAnnotation, posData.frame_i + obj, + acdc_df, + isCcaAnnot, + isNewObject, + isAnnotateNumZslices, + isLabelTreeAnnotation, + isGenNumTreeAnnotation, + posData.frame_i, ) - + yc, xc = getObjCentroidFunc(obj.centroid) try: rp_zslice = posData.zSlicesRp[currentZ] @@ -570,59 +576,58 @@ def setAnnotations(self, **kwargs): yc, xc = obj_2d.centroid except Exception as err: pass - + pos = (int(xc), int(yc)) - + objData = self.item.addObjAnnot(pos, draw=False, **objOpts) - objData['data'] = obj.label - self.item.appendData(objData, objOpts['text']) + objData["data"] = obj.label + self.item.appendData(objData, objOpts["text"]) if posData.trackedLostIDs and annotateLost: - prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] + prev_rp = posData.allData_li[posData.frame_i - 1]["regionprops"] if prev_rp is None: self.item.draw() return - + for obj in prev_rp: if obj.label not in posData.trackedLostIDs: continue if obj.label in delROIsIDs: continue - + if not isObjVisibleFunc(obj.bbox): continue objOpts = { - 'text': f'{obj.label}', - 'color_name': 'tracked_lost_object', - 'bold': False, + "text": f"{obj.label}", + "color_name": "tracked_lost_object", + "bold": False, } yc, xc = obj.centroid[-2:] pos = (int(xc), int(yc)) objData = self.item.addObjAnnot(pos, draw=False, **objOpts) - self.item.appendData(objData, objOpts['text']) - + self.item.appendData(objData, objOpts["text"]) if posData.lost_IDs and annotateLost: - prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] + prev_rp = posData.allData_li[posData.frame_i - 1]["regionprops"] if prev_rp is None: self.item.draw() return for obj in prev_rp: if obj.label not in posData.lost_IDs: continue - + if obj.label in delROIsIDs: continue - + if not isObjVisibleFunc(obj.bbox): continue - + objOpts = { - 'text': f'{obj.label}?', - 'color_name': 'lost_object', - 'bold': False, + "text": f"{obj.label}?", + "color_name": "lost_object", + "bold": False, } yc, xc = getObjCentroidFunc(obj.centroid) try: @@ -630,52 +635,58 @@ def setAnnotations(self, **kwargs): except Exception as err: printl("""WARNING: Could not annotate lost object, failed to get position. Skipping annotation.""") - # Sometimes xc or yc can be nan, causing an error when + # Sometimes xc or yc can be nan, causing an error when # converting to int --> skip annotation in this case continue objData = self.item.addObjAnnot(pos, draw=False, **objOpts) - self.item.appendData(objData, objOpts['text']) + self.item.appendData(objData, objOpts["text"]) self.item.draw() - + def highlightObject(self, obj): self.item.highlightObject(obj) - + def removeHighlightObject(self, obj): self.item.removeHighlightObject(obj) - + def grayOutAnnotations(self, IDsToSkip=None): self.item.grayOutAnnotations(IDsToSkip=IDsToSkip) def isDisabled(self): _isEnabled = self._isLabelAnnot or self._isCcaAnnot - return (not _isEnabled) - + return not _isEnabled + def setColors( - self, label, bud_will_divide, S_phase_mother, G1_phase, - lost_object, tracked_lost_object, **kwargs - ): + self, + label, + bud_will_divide, + S_phase_mother, + G1_phase, + lost_object, + tracked_lost_object, + **kwargs, + ): alpha = 200 if len(G1_phase) == 3: G1_phase = (*G1_phase, 220) else: G1_phase = tuple(G1_phase) colors = { - 'label': (*label, alpha), - 'bud_will_divide': (*bud_will_divide, alpha), - 'S_phase_mother': (*S_phase_mother, alpha), - 'G1_phase': G1_phase, - 'new_object': (255,0,0,255), - 'lost_object': (*lost_object, alpha), - 'tracked_lost_object': (*tracked_lost_object, alpha), - 'grayed': (100,100,100,75), - 'highlight': (255,0,0,200), - 'S_phase_bud': (255,0,0,220), - 'green': (0,255,0,220) + "label": (*label, alpha), + "bud_will_divide": (*bud_will_divide, alpha), + "S_phase_mother": (*S_phase_mother, alpha), + "G1_phase": G1_phase, + "new_object": (255, 0, 0, 255), + "lost_object": (*lost_object, alpha), + "tracked_lost_object": (*tracked_lost_object, alpha), + "grayed": (100, 100, 100, 75), + "highlight": (255, 0, 0, 200), + "S_phase_bud": (255, 0, 0, 220), + "green": (0, 255, 0, 220), } self.item.setColors(colors) self._colors = colors - + def colors(self): return self._colors @@ -684,7 +695,7 @@ def setLabelAnnot(self, isLabelAnnot): def setCcaAnnot(self, isCcaAnnot): self._isCcaAnnot = isCcaAnnot - + def isCcaAnnot(self): return self._isCcaAnnot @@ -693,24 +704,24 @@ def isLabelAnnot(self): def setAnnotateNumZslices(self, isAnnotateNumZslices): self._isAnnotateNumZslices = isAnnotateNumZslices - + def isAnnotateNumZslices(self): return self._isAnnotateNumZslices - + def setLabelTreeAnnotationsEnabled(self, isTreeAnnotations): self._isLabelTreeAnnotation = isTreeAnnotations - + def setGenNumTreeAnnotationsEnabled(self, isTreeAnnotations): self._isGenNumTreeAnnotation = isTreeAnnotations - + def isLabelTreeAnnotation(self): return self._isLabelTreeAnnotation def isGenNumTreeAnnotation(self): return self._isGenNumTreeAnnotation - + def setPxMode(self, mode): self.item.setPxMode(mode) - + def update(self): self.item.update() diff --git a/cellacdc/apps.py b/cellacdc/apps.py index 5bc4553c0..f26cf6f89 100755 --- a/cellacdc/apps.py +++ b/cellacdc/apps.py @@ -1,19722 +1,3 @@ -import os -import sys -import re -from typing import Literal, Callable, Dict, Iterable, List, Tuple -import datetime -import pathlib -from collections import defaultdict -import zipfile -from heapq import nlargest -import matplotlib -import matplotlib.pyplot as plt -from matplotlib.lines import Line2D -from matplotlib.patches import Rectangle, Circle, PathPatch, Path -import numpy as np -import scipy.interpolate -try: - import tkinter as tk -except Exception as err: - pass +"""Compatibility shim; implementation lives in dialogs/.""" -import cv2 -import traceback -from itertools import combinations, permutations -from collections import namedtuple -from natsort import natsorted -# from MyWidgets import Slider, Button, MyRadioButtons -from skimage.measure import label, regionprops -from functools import partial -import skimage.filters -import skimage.measure -import skimage.morphology -import skimage.exposure -import skimage.draw -import skimage.registration -import skimage.color -import skimage.segmentation -from matplotlib.backends.backend_tkagg import ( - FigureCanvasTkAgg, NavigationToolbar2Tk -) -import matplotlib.pyplot as plt -import seaborn as sns -import pandas as pd -import math -import time -import sympy as sp -import json -import html - -import pyqtgraph as pg -pg.setConfigOption('imageAxisOrder', 'row-major') - -from qtpy import QtCore -from qtpy.QtGui import ( - QIcon, QFontMetrics, QKeySequence, QFont, QRegularExpressionValidator, - QCursor, QKeyEvent, QPixmap, QFont, QPalette, QMouseEvent, QColor -) -from qtpy.QtCore import ( - Qt, QSize, QEvent, Signal, QEventLoop, QTimer, QRegularExpression -) -from qtpy.QtWidgets import ( - QFileDialog, QApplication, QMainWindow, QMenu, QLabel, QToolBar, - QScrollBar, QWidget, QVBoxLayout, QLineEdit, QPushButton, - QHBoxLayout, QDialog, QFormLayout, QListWidget, QAbstractItemView, - QButtonGroup, QCheckBox, QSizePolicy, QComboBox, QSlider, QGridLayout, - QSpinBox, QToolButton, QTableView, QTextBrowser, QDoubleSpinBox, - QScrollArea, QFrame, QProgressBar, QGroupBox, QRadioButton, - QDockWidget, QMessageBox, QStyle, QPlainTextEdit, QSpacerItem, - QTreeWidget, QTreeWidgetItem, QTextEdit, QSplashScreen, QAction, - QListWidgetItem, QActionGroup, QHeaderView, QStyledItemDelegate -) -import qtpy.compat - -from . import exception_handler -from . import load, prompts, core, measurements, html_utils -from . import is_mac, is_win, is_linux, settings_folderpath, config -from . import preproc_recipes_path, segm_recipes_path, combine_channels_recipes_path -from . import is_conda_env -from . import printl -from . import colors -from . import issues_url -from . import myutils -from . import qutils -from . import _palettes -from . import base_cca_dict -from . import widgets -from . import user_profile_path, promptable_models_path, models_path -from . import features -from . import _core -from . import _types -from . import plot -from . import urls -from .acdc_regex import float_regex, is_alphanumeric_filename, to_alphanumeric -from . import _base_widgets -from . import io -from . import cca_functions -from . import path - -POSITIVE_FLOAT_REGEX = float_regex(allow_negative=False) -TREEWIDGET_STYLESHEET = _palettes.TreeWidgetStyleSheet() -LISTWIDGET_STYLESHEET = _palettes.ListWidgetStyleSheet() -BACKGROUND_RGBA = _palettes.get_disabled_colors()['Button'] - -font = QFont() -font.setPixelSize(12) -italicFont = QFont() -italicFont.setPixelSize(12) -italicFont.setItalic(True) - -class ArgWidget: - def __init__(self, name, type, widget, defaultVal, valueSetter, valueGetter, changeSig=None): - self.name = name - self.type = type - self.widget = widget - self.defaultVal = defaultVal - self.valueSetter = valueSetter - self.valueGetter = valueGetter - if changeSig is not None: - self.changeSig = changeSig - - -def addCustomModelMessages(QParent=None): - modelFilePath = None - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - txt = html_utils.paragraph(""" - Do you already have the acdcSegment.py file for your code - or do you need instructions on how to set-up your custom model?
- """) - infoButton = widgets.infoPushButton(' I need instructions') - browseButton = widgets.browseFileButton(' I have the model, let me select it') - msg.information( - QParent, 'Add custom model', txt, - buttonsTexts=('Cancel', infoButton, browseButton), - showDialog=False - ) - browseButton.clicked.disconnect() - browseButton.clicked.connect(msg.buttonCallBack) - msg.exec_() - if msg.cancel: - return - if msg.clickedButton == infoButton: - txt = myutils.get_add_custom_model_instructions() - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - msg.information( - QParent, 'Custom model instructions', txt, buttonsTexts=('Ok',), - path_to_browse=models_path, - browse_button_text='Open models folder...' - ) - else: - homePath = pathlib.Path.home() - modelFilePath = QFileDialog.getOpenFileName( - QParent, 'Select the acdcSegment.py file of your model', - str(homePath), 'acdcSegment.py file (*.py);;All files (*)' - )[0] - if not modelFilePath: - return - - return modelFilePath - -def addCustomPromptModelMessages(QParent=None): - modelFilePath = None - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - txt = html_utils.paragraph(""" - Do you already have the acdcPromptSegment.py file for your code - or do you need instructions on how to set-up your custom model?
- """) - infoButton = widgets.infoPushButton(' I need instructions') - browseButton = widgets.browseFileButton(' I have the model, let me select it') - msg.information( - QParent, 'Add custom promptable model', txt, - buttonsTexts=('Cancel', infoButton, browseButton), - showDialog=False - ) - browseButton.clicked.disconnect() - browseButton.clicked.connect(msg.buttonCallBack) - msg.exec_() - if msg.cancel: - return - if msg.clickedButton == infoButton: - txt = myutils.get_add_custom_prompt_model_instructions() - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - msg.information( - QParent, 'Custom promptable model instructions', - txt, buttonsTexts=('Ok',), - path_to_browse=promptable_models_path, - browse_button_text='Open promptable models folder...' - ) - else: - homePath = pathlib.Path.home() - modelFilePath = QFileDialog.getOpenFileName( - QParent, 'Select the acdcPromptSegment.py file of your model', - str(homePath), 'acdcPromptSegment.py file (*.py);;All files (*)' - )[0] - if not modelFilePath: - return - - return modelFilePath - -class QBaseDialog(_base_widgets.QBaseDialog): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - -class customAnnotationDialog(QDialog): - sigDeleteSelecAnnot = Signal(object) - - def __init__(self, savedCustomAnnot, parent=None, state=None): - self.cancel = True - self.loop = None - self.clickedButton = None - self.savedCustomAnnot = savedCustomAnnot - - self.internalNames = measurements.get_all_acdc_df_colnames( - include_custom=False - ) - - super().__init__(parent) - - self.setWindowTitle('Custom annotation') - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - - layout = widgets.FormLayout() - - row = 0 - typeCombobox = QComboBox() - typeCombobox.addItems([ - 'Single time-point', - 'Multiple time-points', - 'Multiple values class' - ]) - if state is not None: - typeCombobox.setCurrentText(state['type']) - self.typeCombobox = typeCombobox - body_txt = (""" - Single time-point annotation: use this to annotate - an event that happens on a single frame in time - (e.g. cell division). -

- Multiple time-points annotation: use this to annotate - an event that has a duration, i.e., a start frame and a stop - frame (e.g. cell cycle phase).

- Multiple values class annotation: use this to annotate a class - that has multiple values. An example could be a cell cycle stage - that can have different values, such as 2-cells division - or 4-cells division. - """) - typeInfoTxt = (f'{html_utils.paragraph(body_txt)}') - self.typeWidget = widgets.formWidget( - typeCombobox, addInfoButton=True, labelTextLeft='Type: ', - parent=self, infoTxt=typeInfoTxt - ) - layout.addFormWidget(self.typeWidget, row=row) - typeCombobox.currentTextChanged.connect(self.warnType) - - row += 1 - nameInfoTxt = (""" - Name of the column that will be saved in the acdc_output.csv - file.

- Valid charachters are letters and numbers separate by underscore - or dash only.

- Additionally, some names are reserved because they are used - by Cell-ACDC for standard measurements.

- Internally reserved names: - """) - self.nameInfoTxt = (f'{html_utils.paragraph(nameInfoTxt)}') - self.nameWidget = widgets.formWidget( - widgets.alphaNumericLineEdit(), addInfoButton=True, - labelTextLeft='Name: ', parent=self, infoTxt=self.nameInfoTxt - ) - self.nameWidget.infoButton.disconnect() - self.nameWidget.infoButton.clicked.connect(self.showNameInfo) - if state is not None: - self.nameWidget.widget.setText(state['name']) - self.nameWidget.widget.textChanged.connect(self.checkName) - layout.addFormWidget(self.nameWidget, row=row) - - row += 1 - self.nameInfoLabel = QLabel() - layout.addWidget( - self.nameInfoLabel, row, 0, 1, 2, alignment=Qt.AlignCenter - ) - - row += 1 - spacing = QSpacerItem(10, 10) - layout.addItem(spacing, row, 0) - - row += 1 - symbolInfoTxt = (""" - Symbol that will be drawn on the annotated cell at - the requested time frame. - """) - symbolInfoTxt = (f'{html_utils.paragraph(symbolInfoTxt)}') - self.symbolWidget = widgets.formWidget( - widgets.pgScatterSymbolsCombobox(), addInfoButton=True, - labelTextLeft='Symbol: ', parent=self, infoTxt=symbolInfoTxt - ) - if state is not None: - self.symbolWidget.widget.setCurrentText(state['symbol']) - layout.addFormWidget(self.symbolWidget, row=row) - - row += 1 - shortcutInfoTxt = (""" - Shortcut that you can use to activate/deactivate annotation - of this event.

Leave empty if you don't need a shortcut. - """) - shortcutInfoTxt = (f'{html_utils.paragraph(shortcutInfoTxt)}') - self.shortcutWidget = widgets.formWidget( - widgets.ShortcutLineEdit(), addInfoButton=True, - labelTextLeft='Shortcut: ', parent=self, infoTxt=shortcutInfoTxt - ) - if state is not None: - self.shortcutWidget.widget.setText(state['shortcut']) - layout.addFormWidget(self.shortcutWidget, row=row) - - row += 1 - descInfoTxt = (""" - Description will be used as the tool tip that will be - displayed when you hover with th mouse cursor on the toolbar button - specific for this annotation - """) - descInfoTxt = (f'{html_utils.paragraph(descInfoTxt)}') - self.descWidget = widgets.formWidget( - QPlainTextEdit(), addInfoButton=True, - labelTextLeft='Description: ', parent=self, infoTxt=descInfoTxt - ) - if state is not None: - self.descWidget.widget.setPlainText(state['description']) - layout.addFormWidget(self.descWidget, row=row) - - row += 1 - optionsGroupBox = QGroupBox('Additional options') - optionsLayout = QGridLayout() - toggle = widgets.Toggle() - toggle.setChecked(True) - self.keepActiveToggle = toggle - toggleLabel = QLabel('Keep tool active after using it: ') - colorButtonLabel = QLabel('Symbol color: ') - self.hideAnnotTooggle = widgets.Toggle() - self.hideAnnotTooggle.setChecked(True) - hideAnnotTooggleLabel = QLabel( - 'Hide annotation when button is not active: ' - ) - self.colorButton = widgets.myColorButton(color=(255, 0, 0)) - self.colorButton.clicked.disconnect() - self.colorButton.clicked.connect(self.selectColor) - - optionsLayout.setColumnStretch(0, 1) - optRow = 0 - optionsLayout.addWidget(toggleLabel, optRow, 1) - optionsLayout.addWidget(toggle, optRow, 2) - optRow += 1 - optionsLayout.addWidget(hideAnnotTooggleLabel, optRow, 1) - optionsLayout.addWidget(self.hideAnnotTooggle, optRow, 2) - optionsLayout.setColumnStretch(3, 1) - optRow += 1 - optionsLayout.addWidget(colorButtonLabel, optRow, 1) - optionsLayout.addWidget(self.colorButton, optRow, 2) - - optionsGroupBox.setLayout(optionsLayout) - layout.addWidget(optionsGroupBox, row, 1, alignment=Qt.AlignCenter) - optionsInfoButton = QPushButton(self) - optionsInfoButton.setCursor(Qt.WhatsThisCursor) - optionsInfoButton.setIcon(QIcon(":info.svg")) - optionsInfoButton.clicked.connect(self.showOptionsInfo) - layout.addWidget(optionsInfoButton, row, 3, alignment=Qt.AlignRight) - - row += 1 - layout.addItem(QSpacerItem(5, 5), row, 0) - - row += 1 - noteText = ( - 'NOTE: you can change these options later with
' - 'RIGHT-click on the associated left-side toolbar button.
' - ) - noteLabel = QLabel(html_utils.paragraph(noteText, font_size='11px')) - layout.addWidget(noteLabel, row, 1, 1, 3) - - buttonsLayout = QHBoxLayout() - - self.loadSavedAnnotButton = widgets.OpenFilePushButton( - ' Load annotation... ' - ) - if not savedCustomAnnot: - self.loadSavedAnnotButton.setDisabled(True) - self.okButton = widgets.okPushButton(' Ok ') - cancelButton = widgets.cancelPushButton('Cancel') - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(self.loadSavedAnnotButton) - buttonsLayout.addWidget(self.okButton) - - cancelButton.clicked.connect(self.cancelCallBack) - self.cancelButton = cancelButton - self.loadSavedAnnotButton.clicked.connect(self.loadSavedAnnot) - self.okButton.clicked.connect(self.ok_cb) - self.okButton.setFocus() - - mainLayout = QVBoxLayout() - - noteTxt = (""" - Custom annotations will be saved in the acdc_output.csv
- file as a column with the name you write in the field Name
- """) - noteTxt = (f'{html_utils.paragraph(noteTxt, font_size="15px")}') - noteLabel = QLabel(noteTxt) - noteLabel.setAlignment(Qt.AlignCenter) - mainLayout.addWidget(noteLabel) - - mainLayout.addLayout(layout) - mainLayout.addStretch(1) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - def checkName(self, text): - if not text: - txt = 'Name cannot be empty' - self.nameInfoLabel.setText( - html_utils.paragraph( - txt, font_size='11px', font_color='red' - ) - ) - return - for name in self.internalNames: - if name.find(text) != -1: - txt = ( - f'"{text}" cannot be part of the name, ' - 'because reserved.' - ) - self.nameInfoLabel.setText( - html_utils.paragraph( - txt, font_size='11px', font_color='red' - ) - ) - break - else: - self.nameInfoLabel.setText('') - - def loadSavedAnnot(self): - items = list(self.savedCustomAnnot.keys()) - self.selectAnnotWin = widgets.QDialogListbox( - 'Load annotation parameters', - 'Select annotation to load:', items, - additionalButtons=('Delete selected annnotations', ), - parent=self, multiSelection=False - ) - for button in self.selectAnnotWin._additionalButtons: - button.disconnect() - button.clicked.connect(self.deleteSelectedAnnot) - self.selectAnnotWin.exec_() - if self.selectAnnotWin.cancel: - return - if self.selectAnnotWin.listBox.count() == 0: - return - if not self.selectAnnotWin.selectedItemsText: - self.warnNoItemsSelected() - return - selectedName = self.selectAnnotWin.selectedItemsText[-1] - selectedAnnot = self.savedCustomAnnot[selectedName] - self.typeCombobox.setCurrentText(selectedAnnot['type']) - self.nameWidget.widget.setText(selectedAnnot['name']) - self.symbolWidget.widget.setCurrentText(selectedAnnot['symbol']) - self.shortcutWidget.widget.setText(selectedAnnot['shortcut']) - self.descWidget.widget.setPlainText(selectedAnnot['description']) - self.colorButton.setColor(selectedAnnot['symbolColor']) - keySequence = widgets.macShortcutToWindows(selectedAnnot['shortcut']) - if keySequence: - self.shortcutWidget.widget.keySequence = widgets.KeySequenceFromText(keySequence) - - def warnNoItemsSelected(self): - msg = widgets.myMessageBox(parent=self) - msg.setIcon(iconName='SP_MessageBoxWarning') - msg.setWindowTitle('Delete annotation?') - msg.addText('You didn\'t select any annotation!') - msg.addButton(' Ok ') - msg.exec_() - - def deleteSelectedAnnot(self): - msg = widgets.myMessageBox(parent=self) - msg.setIcon(iconName='SP_MessageBoxWarning') - msg.setWindowTitle('Delete annotation?') - msg.addText('Are you sure you want to delete the selected annotations?') - msg.addButton('Yes') - cancelButton = msg.addButton(' Cancel ') - msg.exec_() - if msg.clickedButton == cancelButton: - return - for item in self.selectAnnotWin.listBox.selectedItems(): - name = item.text() - self.savedCustomAnnot.pop(name) - self.sigDeleteSelecAnnot.emit(self.selectAnnotWin.listBox.selectedItems()) - items = list(self.savedCustomAnnot.keys()) - self.selectAnnotWin.listBox.clear() - self.selectAnnotWin.listBox.addItems(items) - - def selectColor(self): - color = self.colorButton.color() - self.colorButton.origColor = color - self.colorButton.colorDialog.setCurrentColor(color) - self.colorButton.colorDialog.setWindowFlags( - Qt.Window | Qt.WindowStaysOnTopHint - ) - self.colorButton.colorDialog.open() - w = self.width() - left = self.pos().x() - colorDialogTop = self.colorButton.colorDialog.pos().y() - self.colorButton.colorDialog.move(w+left+10, colorDialogTop) - - def warnType(self, currentText): - if currentText == 'Single time-point': - return - - self.typeCombobox.setCurrentIndex(0) - - txt = (""" - Unfortunately, the only annotation type that is available so far is - Single time-point.

- We are working on implementing the other types too, so stay tuned!

- Thank you for your patience! - """) - txt = (f'{html_utils.paragraph(txt)}') - msg = widgets.myMessageBox() - msg.setIcon(iconName='SP_MessageBoxWarning') - msg.setWindowTitle(f'Feature not implemented yet') - msg.addText(txt) - msg.addButton(' Ok ') - msg.exec_() - - def showOptionsInfo(self): - info = (""" - Keep tool active after using it: Choose whether the tool - should stay active or not after annotating.

- Hide annotation when button is not active: Choose whether - annotation on the cell/object should be visible only if the - button is active or also when it is not active.
- NOTE: annotations are always stored no matter whether - they are visible or not.

- Symbol color: Choose color of the symbol that will be used - to label annotated cell/object. - """) - info = (f'{html_utils.paragraph(info)}') - msg = widgets.myMessageBox() - msg.setIcon() - msg.setWindowTitle(f'Additional options info') - msg.addText(info) - msg.addButton(' Ok ') - msg.exec_() - - def ok_cb(self, checked=True): - self.cancel = False - self.clickedButton = self.okButton - self.close() - - def cancelCallBack(self, checked=True): - self.cancel = True - self.clickedButton = self.cancelButton - self.close() - - def showNameInfo(self): - msg = widgets.myMessageBox() - listView = widgets.readOnlyQList(msg) - listView.addItems(self.internalNames) - # listView.setSelectionMode(QAbstractItemView.SelectionMode.NoSelection) - msg.information( - self, 'Annotation Name info', self.nameInfoTxt, - widgets=listView - ) - - def closeEvent(self, event): - if self.clickedButton is None or self.clickedButton==self.cancelButton: - # cancel button or closed with 'x' button - self.cancel = True - return - - if self.clickedButton==self.okButton and not self.nameWidget.widget.text(): - msg = QMessageBox() - msg.critical( - self, 'Empty name', 'The name cannot be empty!', msg.Ok - ) - event.ignore() - self.cancel = True - return - - if self.clickedButton==self.okButton and self.nameInfoLabel.text(): - msg = widgets.myMessageBox() - listView = widgets.listWidget(msg) - listView.addItems(self.internalNames) - listView.setSelectionMode(QAbstractItemView.SelectionMode.NoSelection) - name = self.nameWidget.widget.text() - txt = ( - f'"{name}" cannot be part of the name, ' - 'because it is reserved for standard measurements ' - 'saved by Cell-ACDC.

' - 'Internally reserved names:' - ) - msg.critical( - self, 'Not a valid name', html_utils.paragraph(txt), - widgets=listView - ) - event.ignore() - self.cancel = True - return - - self.toolTip = ( - f'Name: {self.nameWidget.widget.text()}\n\n' - f'Type: {self.typeWidget.widget.currentText()}\n\n' - f'Usage: activate the button and RIGHT-CLICK on cell to annotate\n\n' - f'Description: {self.descWidget.widget.toPlainText()}\n\n' - f'SHORTCUT: "{self.shortcutWidget.widget.text()}"' - ) - - symbol = self.symbolWidget.widget.currentText() - self.symbol = re.findall(r"\'(.+)\'", symbol)[0] - - self.state = { - 'type': self.typeWidget.widget.currentText(), - 'name': self.nameWidget.widget.text(), - 'symbol': self.symbolWidget.widget.currentText(), - 'shortcut': self.shortcutWidget.widget.text(), - 'description': self.descWidget.widget.toPlainText(), - 'keepActive': self.keepActiveToggle.isChecked(), - 'isHideChecked': self.hideAnnotTooggle.isChecked(), - 'symbolColor': self.colorButton.color() - } - - if self.loop is not None: - self.loop.exit() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - super().show() - if block: - self.loop = QEventLoop() - self.loop.exec_() - -class _PointsLayerAppearanceGroupbox(QGroupBox): - def __init__(self, *args): - super().__init__(*args) - - self.setTitle('Points appearance') - - layout = widgets.FormLayout() - - '----------------------------------------------------------------------' - row = 0 - symbolInfoTxt = (""" - Symbol used to draw the points. - """) - symbolInfoTxt = (f'{html_utils.paragraph(symbolInfoTxt)}') - self.symbolWidget = widgets.formWidget( - widgets.pgScatterSymbolsCombobox(), addInfoButton=True, - labelTextLeft='Symbol: ', parent=self, infoTxt=symbolInfoTxt, - stretchWidget=False - ) - layout.addFormWidget(self.symbolWidget, row=row) - '----------------------------------------------------------------------' - - '----------------------------------------------------------------------' - row += 1 - self.colorButton = widgets.myColorButton(color=(255, 0, 0)) - self.colorWidget = widgets.formWidget( - self.colorButton, stretchWidget=True, - labelTextLeft='Colour: ', parent=self - ) - layout.addFormWidget(self.colorWidget, align=Qt.AlignLeft, row=row) - self.colorButton.clicked.disconnect() - self.colorButton.clicked.connect(self.selectColor) - '----------------------------------------------------------------------' - - '----------------------------------------------------------------------' - row += 1 - self.sizeSpinBox = widgets.SpinBox() - self.sizeSpinBox.setValue(5) - self.sizeWidget = widgets.formWidget( - self.sizeSpinBox, stretchWidget=True, - labelTextLeft='Size: ', parent=self - ) - layout.addFormWidget(self.sizeWidget, row=row) - '----------------------------------------------------------------------' - - '----------------------------------------------------------------------' - row += 1 - zHeightTooltip = ( - 'If "Z-depth" is greater than 1, the points will be annotated ' - 'in all the z-slices in the range `z - (Z-depth/2) < z < z + (Z-depth/2)`\n' - 'where `z` is the center z-slice of the added point.' - ) - self.zHeightSpinBox = widgets.OddSpinBox() - self.zHeightSpinBox.setValue(1) - self.zHeightSpinBox.setMinimum(1) - self.zHeightWidget = widgets.formWidget( - self.zHeightSpinBox, stretchWidget=True, - labelTextLeft='Z-depth: ', parent=self, - toolTip=zHeightTooltip - ) - layout.addFormWidget(self.zHeightWidget, row=row) - '----------------------------------------------------------------------' - - '----------------------------------------------------------------------' - row += 1 - shortcutInfoTxt = (""" - Shortcut that you can use to hide/show points. - """) - shortcutInfoTxt = (f'{html_utils.paragraph(shortcutInfoTxt)}') - self.shortcutWidget = widgets.formWidget( - widgets.ShortcutLineEdit(), addInfoButton=True, - labelTextLeft='Shortcut: ', parent=self, infoTxt=shortcutInfoTxt - ) - layout.addFormWidget(self.shortcutWidget, row=row) - '----------------------------------------------------------------------' - - self.setLayout(layout) - - def restoreState(self, state): - self.shortcutWidget.widget.setText(state['shortcut']) - self.colorButton.setColor(state['color']) - self.symbolWidget.widget.setCurrentText(state['symbol']) - self.sizeSpinBox.setValue(state['pointSize']) - self.zHeightSpinBox.setValue(state['zHeight']) - - def selectColor(self): - color = self.colorButton.color() - self.colorButton.origColor = color - self.colorButton.colorDialog.setCurrentColor(color) - self.colorButton.colorDialog.setWindowFlags( - Qt.Window | Qt.WindowStaysOnTopHint - ) - self.colorButton.colorDialog.open() - w = self.width() - left = self.pos().x() - colorDialogTop = self.colorButton.colorDialog.pos().y() - self.colorButton.colorDialog.move(w+left+10, colorDialogTop) - - def state(self): - r,g,b,a = self.colorButton.color().getRgb() - _state = { - 'symbol': self.symbolWidget.widget.currentText(), - 'color': (r,g,b), - 'pointSize': self.sizeSpinBox.value(), - 'zHeight': self.zHeightSpinBox.value(), - 'shortcut': self.shortcutWidget.widget.text() - } - return _state - -class AddPointsLayerDialog(QBaseDialog): - sigClosed = Signal() - sigCriticalReadTable = Signal(str) - sigLoadedTable = Signal(object, str) - sigCheckClickEntryTableEndnameExists = Signal(str, bool) - - def __init__( - self, - channelNames=None, - imagesPath='', - SizeT=1, - hideCentroidsSection=False, - hideWeightedCentroidsSection=False, - hideFromTableSection=False, - hideManualEntrySection=False, - hideWithMouseClicksSection=False, - parent=None, - ): - self.cancel = True - super().__init__(parent) - - self._parent = parent - - self.imagesPath = imagesPath - - self.setWindowTitle('Add points layer') - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - - mainLayout = QVBoxLayout() - - scrollArea = widgets.ScrollArea() - typeGroupbox = QGroupBox('Points to draw') - typeLayout = QGridLayout() - typeGroupbox.setLayout(typeLayout) - typeLayout.addItem(QSpacerItem(10,1), 0, 0) - typeLayout.setColumnStretch(0, 0) - typeLayout.setColumnStretch(2, 1) - vSpacing = 15 - - row = 0 - - sections = ( - ('addCentroidsSection', hideCentroidsSection), - ('addWeightedCentroidsSection', hideWeightedCentroidsSection), - ('addFromTableSection', hideFromTableSection), - ('addManualEntrySection', hideManualEntrySection), - ('addWithMouseClicksSection', hideWithMouseClicksSection) - ) - radioButtonChecked = False - for section, hideSection in sections: - addFunc = getattr(self, section) - row, sectionWidgets = addFunc( - row, typeLayout, - imagesPath=imagesPath, - SizeT=SizeT, - channelNames=channelNames - ) - if not hideSection: - spacer = QSpacerItem(1, vSpacing) - typeLayout.addItem(spacer, row, 0) - row += 1 - if not radioButtonChecked: - sectionWidgets[0].setChecked(True) - radioButtonChecked = True - continue - - for widget in sectionWidgets: - widget.setVisible(False) - - self.scrollArea = scrollArea - scrollArea.setWidget(typeGroupbox) - - self.appearanceGroupbox = _PointsLayerAppearanceGroupbox() - self.appearanceGroupbox.sizeSpinBox.setValue(3) - - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - self.buttonsLayout = buttonsLayout - - mainLayout.addWidget(scrollArea) - mainLayout.addSpacing(20) - _layout = QHBoxLayout() - _layout.addWidget(self.appearanceGroupbox) - _layout.addStretch(1) - mainLayout.addLayout(_layout) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - self.setFont(font) - - def addCentroidsSection(self, row, layout, **kwargs): - sectionWidgets = [] - self.centroidsRadiobutton = QRadioButton('Centroids') - layout.addWidget(self.centroidsRadiobutton, row, 0, 1, 2) - sectionWidgets.append(self.centroidsRadiobutton) - - self.centroidsRadiobutton.setChecked(True) - return row + 1, sectionWidgets - - def addWeightedCentroidsSection( - self, row, layout, channelNames=None, **kwargs - ): - if channelNames is None: - channelNames = [] - - sectionWidgets = [] - - self.weightedCentroidsRadiobutton = QRadioButton('Weighted centroids') - layout.addWidget(self.weightedCentroidsRadiobutton, row, 0, 1, 2) - sectionWidgets.append(self.weightedCentroidsRadiobutton) - - row += 1 - label = QLabel('Weighing channel: ') - label.setEnabled(False) - layout.addWidget(label, row, 1) - sectionWidgets.append(label) - - self.channelNameForWeightedCentr = widgets.QCenteredComboBox() - if channelNames: - self.channelNameForWeightedCentr.addItems(channelNames) - self.channelNameForWeightedCentr.setDisabled(True) - layout.addWidget(self.channelNameForWeightedCentr, row, 2) - sectionWidgets.append(self.channelNameForWeightedCentr) - - self.weightedCentroidsRadiobutton.toggled.connect(label.setEnabled) - self.weightedCentroidsRadiobutton.toggled.connect( - self.channelNameForWeightedCentr.setEnabled - ) - - return row + 1, sectionWidgets - - def addFromTableSection( - self, row, layout, imagesPath='', SizeT=1, **kwargs - ): - sectionWidgets = [] - - self.fromTableRadiobutton = QRadioButton('From table') - layout.addWidget(self.fromTableRadiobutton, row, 0, 1, 2) - sectionWidgets.append(self.fromTableRadiobutton) - self.fromTableRadiobutton.widgets = [] - - row += 1 - self.tablePath = widgets.ElidingLineEdit() - self.tablePath.label = QLabel('Table file path: ') - layout.addWidget(self.tablePath.label, row, 1) - layout.addWidget(self.tablePath, row, 2) - self.fromTableRadiobutton.widgets.append(self.tablePath) - sectionWidgets.append(self.tablePath.label) - sectionWidgets.append(self.tablePath) - - browseButton = widgets.browseFileButton( - start_dir=imagesPath, ext={'Table': ['.csv', '.h5']} - ) - layout.addWidget(browseButton, row, 3) - browseButton.sigPathSelected.connect(self.tablePathSelected) - self.browseTableButton = browseButton - self.fromTableRadiobutton.widgets.append(browseButton) - sectionWidgets.append(browseButton) - - row += 1 - self.xColName = widgets.QCenteredComboBox() - self.xColName.addItem('None') - self.xColName.label = QLabel('X coord. column: ') - layout.addWidget(self.xColName.label, row, 1) - layout.addWidget(self.xColName, row, 2) - self.xColName.currentTextChanged.connect(self.checkColNameX) - self.fromTableRadiobutton.widgets.append(self.xColName) - sectionWidgets.append(self.xColName.label) - sectionWidgets.append(self.xColName) - - row += 1 - self.yColName = widgets.QCenteredComboBox() - self.yColName.addItem('None') - self.yColName.label = QLabel('Y coord. column: ') - layout.addWidget(self.yColName.label, row, 1) - layout.addWidget(self.yColName, row, 2) - self.yColName.currentTextChanged.connect(self.checkColNameY) - self.fromTableRadiobutton.widgets.append(self.yColName) - sectionWidgets.append(self.yColName.label) - sectionWidgets.append(self.yColName) - - row += 1 - self.zColName = widgets.QCenteredComboBox() - self.zColName.addItem('None') - self.zColName.label = QLabel('Z coord. column: ') - layout.addWidget(self.zColName.label, row, 1) - layout.addWidget(self.zColName, row, 2) - self.zColName.currentTextChanged.connect(self.checkColNameZ) - self.fromTableRadiobutton.widgets.append(self.zColName) - sectionWidgets.append(self.zColName.label) - sectionWidgets.append(self.zColName) - - row += 1 - self.tColName = widgets.QCenteredComboBox() - self.tColName.addItem('None') - self.tColName.label = QLabel('Frame index column: ') - layout.addWidget(self.tColName.label, row, 1) - layout.addWidget(self.tColName, row, 2) - self.fromTableRadiobutton.widgets.append(self.tColName) - sectionWidgets.append(self.tColName.label) - sectionWidgets.append(self.tColName) - - if SizeT == 1: - self.tColName.clear() - self.tColName.addItem('None') - self.tColName.label.setVisible(False) - self.tColName.setVisible(False) - - self.fromTableRadiobutton.toggled.connect(self.enableRadioButtonWidgets) - self.enableRadioButtonWidgets(False, sender=self.fromTableRadiobutton) - - return row + 1, sectionWidgets - - def addManualEntrySection(self, row, layout, SizeT=1, **kwargs): - sectionWidgets = [] - - self.manualEntryRadiobutton = QRadioButton('Manual entry') - layout.addWidget(self.manualEntryRadiobutton, row, 0, 1, 2) - self.manualEntryRadiobutton.widgets = [] - sectionWidgets.append(self.manualEntryRadiobutton) - - row += 1 - self.manualXspinbox = widgets.NumericCommaLineEdit() - self.manualXspinbox.label = QLabel('X coords: ') - layout.addWidget(self.manualXspinbox.label, row, 1) - layout.addWidget(self.manualXspinbox, row, 2) - self.manualEntryRadiobutton.widgets.append(self.manualXspinbox) - sectionWidgets.append(self.manualXspinbox.label) - sectionWidgets.append(self.manualXspinbox) - - row += 1 - self.manualYspinbox = widgets.NumericCommaLineEdit() - self.manualYspinbox.label = QLabel('Y coords: ') - layout.addWidget(self.manualYspinbox.label, row, 1) - layout.addWidget(self.manualYspinbox, row, 2) - self.manualEntryRadiobutton.widgets.append(self.manualYspinbox) - sectionWidgets.append(self.manualYspinbox.label) - sectionWidgets.append(self.manualYspinbox) - - row += 1 - self.manualZspinbox = widgets.NumericCommaLineEdit() - self.manualZspinbox.label = QLabel('Z coords: ') - layout.addWidget(self.manualZspinbox.label, row, 1) - layout.addWidget(self.manualZspinbox, row, 2) - self.manualEntryRadiobutton.widgets.append(self.manualZspinbox) - sectionWidgets.append(self.manualZspinbox.label) - sectionWidgets.append(self.manualZspinbox) - - row += 1 - self.manualTspinbox = widgets.NumericCommaLineEdit() - self.manualTspinbox.label = QLabel('Frame numbers: ') - layout.addWidget(self.manualTspinbox.label, row, 1) - layout.addWidget(self.manualTspinbox, row, 2) - self.manualEntryRadiobutton.widgets.append(self.manualTspinbox) - sectionWidgets.append(self.manualTspinbox.label) - sectionWidgets.append(self.manualTspinbox) - - if SizeT == 1: - self.manualTspinbox.setVisible(False) - self.manualTspinbox.label.setVisible(False) - - self.manualEntryRadiobutton.toggled.connect(self.enableRadioButtonWidgets) - self.enableRadioButtonWidgets(False, sender=self.manualEntryRadiobutton) - - return row + 1, sectionWidgets - - def addWithMouseClicksSection(self, row, layout, imagesPath='', **kwargs): - sectionWidgets = [] - - self.clickEntryIsLoadedDf = None - - self.clickEntryRadiobutton = QRadioButton('Add points with mouse clicks') - layout.addWidget(self.clickEntryRadiobutton, row, 0, 1, 2) - self.clickEntryRadiobutton.widgets = [] - sectionWidgets.append(self.clickEntryRadiobutton) - - row += 1 - self.snapToMaxToggle = widgets.Toggle() - self.snapToMaxToggle.label = QLabel('Snap to closest maximum: ') - layout.addWidget(self.snapToMaxToggle.label, row, 1) - layout.addWidget( - self.snapToMaxToggle, row, 2, alignment=Qt.AlignCenter - ) - sectionWidgets.append(self.snapToMaxToggle.label) - sectionWidgets.append(self.snapToMaxToggle) - - self.snapToMaxInfoButton = widgets.infoPushButton() - layout.addWidget(self.snapToMaxInfoButton, row, 3) - sectionWidgets.append(self.snapToMaxInfoButton) - - self.snapToMaxInfoButton.clicked.connect(self.showSnapToMaxButton) - self.clickEntryRadiobutton.widgets.append(self.snapToMaxToggle) - self.clickEntryRadiobutton.widgets.append(self.snapToMaxInfoButton) - - row += 1 - self.autoPilotToggle = widgets.Toggle() - self.autoPilotToggle.label = QLabel('Use auto-pilot: ') - layout.addWidget(self.autoPilotToggle.label, row, 1) - layout.addWidget( - self.autoPilotToggle, row, 2, alignment=Qt.AlignCenter - ) - sectionWidgets.append(self.autoPilotToggle.label) - sectionWidgets.append(self.autoPilotToggle) - self.autoPilotInfoButton = widgets.infoPushButton() - layout.addWidget(self.autoPilotInfoButton, row, 3) - sectionWidgets.append(self.autoPilotInfoButton) - - self.autoPilotInfoButton.clicked.connect(self.showAutoPilotInfo) - self.clickEntryRadiobutton.widgets.append(self.autoPilotToggle) - self.clickEntryRadiobutton.widgets.append(self.autoPilotInfoButton) - - row += 1 - self.clickEntryTableEndname = widgets.alphaNumericLineEdit() - self.clickEntryTableEndname.setText('points_added_by_clicking') - self.clickEntryTableEndname.setAlignment(Qt.AlignCenter) - self.clickEntryTableEndname.label = QLabel('Table endname: ') - loadButton = widgets.browseFileButton( - start_dir=imagesPath, ext={'CSV': '.csv'} - ) - layout.addWidget(loadButton, row, 3) - sectionWidgets.append(loadButton) - - loadButton.sigPathSelected.connect(self.loadClickEntryTable) - self.loadButton = loadButton - self.clickEntryLoadTableButton = loadButton - layout.addWidget(self.clickEntryTableEndname.label, row, 1) - layout.addWidget(self.clickEntryTableEndname, row, 2) - self.clickEntryRadiobutton.widgets.append(self.clickEntryTableEndname) - self.clickEntryTableEndname.editingFinished.connect( - self.emitCheckClickEntryTableEndnameExists - ) - sectionWidgets.append(self.clickEntryTableEndname) - sectionWidgets.append(self.clickEntryTableEndname.label) - - row += 1 - instructionsText = html_utils.paragraph( - '
Left-click to annotate a new point with a new id.

' - 'Right-click to annotate a point with the same id

' - 'Same click used to delete objects to annotate
' - 'a point with id = 0 (negative prompt)

' - 'Click on point to delete it', - font_size='11px' - ) - self.instructionsLabel = QLabel(instructionsText) - self.instructionsLabel.label = QLabel('Instructions') - layout.addWidget(self.instructionsLabel.label, row, 1) - layout.addWidget(self.instructionsLabel, row, 2) - self.clickEntryRadiobutton.widgets.append(self.instructionsLabel) - sectionWidgets.append(self.instructionsLabel) - sectionWidgets.append(self.instructionsLabel.label) - - self.clickEntryRadiobutton.toggled.connect(self.enableRadioButtonWidgets) - self.clickEntryRadiobutton.toggled.connect( - self.emitCheckClickEntryTableEndnameExists - ) - self.enableRadioButtonWidgets(False, sender=self.clickEntryRadiobutton) - - return row + 1, sectionWidgets - - def emitCheckClickEntryTableEndnameExists(self, *args, **kwargs): - if not self.clickEntryRadiobutton.isChecked(): - return - self.clickEntryIsLoadedDf = None - tableEndName = self.clickEntryTableEndname.text() - self.sigCheckClickEntryTableEndnameExists.emit( - tableEndName, False - ) - - def loadClickEntryTable(self, csv_path): - self.clickEntryIsLoadedDf = None - posData = load.loadData(csv_path, 'points') - posData.getBasenameAndChNames(qparent=self) - basename = posData.basename - filename = os.path.basename(csv_path) - filename, ext = os.path.splitext(filename) - if not basename.endswith('_'): - basename = f'{basename}_' - - endname = filename[len(basename):] - self.clickEntryTableEndname.setText(endname) - self.sigCheckClickEntryTableEndnameExists.emit( - endname, True - ) - - def showAutoPilotInfo(self): - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph(""" - With Auto-pilot mode active, Cell-ACDC will automatically zoom on - to an object
- to allow you clicking on the points you want to add.

- You can then go to the next object by pressing the - Enter key or go back to the
- previous object by pressing Backspace. - """) - msg.information(self, 'Auto-pilot info', txt) - - def showSnapToMaxButton(self): - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph(""" - With mode active, Cell-ACDC will - automatically add the point
- to the closest maximum within the point footprint (defined in - the appearance settings). - """) - msg.information(self, 'Snap to closest maximum info', txt) - - def closeEvent(self, event): - self.sigClosed.emit() - - def enableRadioButtonWidgets(self, enabled, sender=None): - if sender is None: - sender = self.sender() - for widget in sender.widgets: - widget.setDisabled(not enabled) - try: - widget.label.setDisabled(not enabled) - except: - pass - - def _readTable(self, path): - return load.load_df_points_layer(path) - - def tryAutoFillColNames(self, df): - if 'x' in df.columns: - self.xColName.setCurrentText('x') - - if 'y' in df.columns: - self.yColName.setCurrentText('y') - - if 'z' in df.columns: - self.zColName.setCurrentText('z') - - if 'frame_i' in df.columns: - self.tColName.setCurrentText('frame_i') - - def tablePathSelected(self, path): - self.tablePath.setText(path) - try: - df = self._readTable(path) - self.xColName.addItems(df.columns) - self.yColName.addItems(df.columns) - self.zColName.addItems(df.columns) - self.tColName.addItems(df.columns) - self.tryAutoFillColNames(df) - self.sigLoadedTable.emit(df, os.path.basename(path)) - self.browseTableButton.confirmAction() - except Exception as e: - traceback_format = traceback.format_exc() - self.sigCriticalReadTable.emit(traceback_format) - self.criticalReadTable(path, traceback_format) - self.tablePath.setText('') - - - def criticalLenMismatchManualEntry(self): - txt = html_utils.paragraph(f""" - X coords and Y coords must have the same length. - """) - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - msg.critical(self, f'X and Y have different length', txt) - - def criticalColNameIsNone(self, axis): - txt = html_utils.paragraph(f""" - The "{axis.upper()} coord. column" cannot be "None" - """) - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - msg.critical(self, f'{axis.upper()} coord. is None', txt) - - def criticalReadTable(self, path, traceback_format): - txt = html_utils.paragraph(f""" - Something went wrong when reading the table from the - following path:

- {path}

- See the error message below. - """) - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - detailsText = traceback_format - msg.critical( - self, 'Error when reading table', txt, detailsText=detailsText) - - def criticalEmptyTablePath(self): - txt = html_utils.paragraph(f""" - The table file path cannot be empty. - """) - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - msg.critical(self, 'Table file path is empty', txt) - - def state(self): - _state = self.appearanceGroupbox.state() - return _state - - def _checkSelectedColName(self, colName, label): - labelsToCheck = ['z', 'y', 'x'] - labelsToCheck.remove(label) - for labelToCheck in labelsToCheck: - if colName.find(labelToCheck) != -1: - break - else: - return True - - txt = html_utils.paragraph(f""" - Are you sure that the {label.upper()} coord. column should contain - the letter {labelToCheck}? - """) - - msg = widgets.myMessageBox(wrapText=False) - _, noButton, yesButton = msg.warning( - self, 'Check column name', txt, - buttonsTexts=('Cancel', 'No, let me correct it', 'Yes, I am') - ) - if msg.cancel or msg.clickedButton == noButton: - return False - return True - - def checkColNameX(self, text): - accepted = self._checkSelectedColName(text, 'x') - if accepted: - return - self.xColName.setCurrentText('None') - - def checkColNameY(self, text): - accepted = self._checkSelectedColName(text, 'y') - if accepted: - return - self.yColName.setCurrentText('None') - - def checkColNameZ(self, text): - accepted = self._checkSelectedColName(text, 'z') - if accepted: - return - self.zColName.setCurrentText('None') - - def ok_cb(self): - self.pointsData = {} - self.loadedDfInfo = None - self.loadedDf = None - self.weighingChannel = '' - if self.fromTableRadiobutton.isChecked(): - tablePath = self.tablePath.text() - if not tablePath: - self.criticalEmptyTablePath() - return - - try: - df = self._readTable(tablePath) - tColName = self.tColName.currentText() - xColName = self.xColName.currentText() - yColName = self.yColName.currentText() - zColName = self.zColName.currentText() - - self.loadedDfInfo = { - 'filepath': tablePath, - 't': tColName, - 'z': zColName, - 'y': yColName, - 'x': xColName - } - - self._df_to_pointsData( - df, tColName, zColName, yColName, xColName - ) - - except Exception as e: - traceback_format = traceback.format_exc() - self.sigCriticalReadTable.emit(traceback_format) - self.criticalReadTable(tablePath, traceback_format) - return - - if self.xColName.currentText() == 'None': - self.criticalColNameIsNone('x') - return - if self.yColName.currentText() == 'None': - self.criticalColNameIsNone('y') - return - - self.layerType = os.path.basename(self.tablePath.text()) - self.layerTypeIdx = 2 - elif self.centroidsRadiobutton.isChecked(): - self.layerType = 'Centroids' - self.layerTypeIdx = 0 - elif self.weightedCentroidsRadiobutton.isChecked(): - channel = self.channelNameForWeightedCentr.currentText() - self.weighingChannel = channel - self.layerType = f'Centroids weighted by channel {channel}' - self.layerTypeIdx = 1 - elif self.manualEntryRadiobutton.isChecked(): - xx = self.manualXspinbox.values() - yy = self.manualYspinbox.values() - if len(xx) != len(yy): - self.criticalLenMismatchManualEntry() - return - zz = self.manualZspinbox.values() - tt = [t+1 for t in self.manualTspinbox.values()] - df = pd.DataFrame({'x': xx, 'y': yy, 'id': np.arange(1, len(xx)+1)}) - if tt: - df['t'] = tt - tCol = 't' - else: - tCol = 'None' - if zz: - df['z'] = zz - zCol = 'z' - else: - zCol = 'None' - - self._df_to_pointsData(df, tCol, zCol, 'y', 'x') - - self.layerType = 'Manual entry' - self.layerTypeIdx = 3 - elif self.clickEntryRadiobutton.isChecked(): - self.layerType = ('Click to annotate point') - self.description = ( - 'Left-click to add a point, click on point to delete it.\n' - 'With auto-pilot you can navigate through object with Up/Down arrows.' - ) - self.clickEntryTableEndnameText = self.clickEntryTableEndname.text() - self.layerTypeIdx = 4 - - self.cancel = False - symbol = self.appearanceGroupbox.symbolWidget.widget.currentText() - self.symbol = re.findall(r"\'(.+)\'", symbol)[0] - self.symbolText = symbol - self.color = self.appearanceGroupbox.colorButton.color() - self.pointSize = self.appearanceGroupbox.sizeSpinBox.value() - self.zHeight = self.appearanceGroupbox.zHeightSpinBox.value() - shortcutWidget = self.appearanceGroupbox.shortcutWidget - self.shortcut = shortcutWidget.widget.text() - self.keySequence = shortcutWidget.widget.keySequence - self.close() - - def _df_to_pointsData(self, df, tColName, zColName, yColName, xColName): - self.pointsData = load.loaded_df_to_points_data( - df, tColName, zColName, yColName, xColName - ) - - def showEvent(self, event) -> None: - if self._parent is None: - screen = self.screen() - else: - screen = self._parent.screen() - screenWidth = screen.size().width() - screenHeight = screen.size().height() - - maxHeight = screenHeight - 100 - - buttonHeight = self.buttonsLayout.okButton.minimumSizeHint().height() - height = ( - self.scrollArea.minimumHeightNoScrollbar() - + self.appearanceGroupbox.sizeHint().height() - + buttonHeight + 70 - ) - width = self.scrollArea.minimumWidthNoScrollbar() + 50 - - height = min(height, maxHeight) - - self.resize(width, height) - - screenLeft = screen.geometry().x() - screenTop = screen.geometry().y() - w, h = self.width(), self.height() - left = int(screenLeft + screenWidth/2 - w/2) - top = int(screenTop + screenHeight/2 - h/2 - 20) - - self.move(left, top) - -class EditPointsLayerAppearanceDialog(QBaseDialog): - sigClosed = Signal() - - def __init__(self, parent=None): - self.cancel = True - super().__init__(parent) - - self._parent = parent - - self.setWindowTitle('Custom annotation') - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - - mainLayout = QVBoxLayout() - - self.appearanceGroupbox = _PointsLayerAppearanceGroupbox() - - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addWidget(self.appearanceGroupbox) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - self.setFont(font) - - def restoreState(self, state): - self.appearanceGroupbox.restoreState(state) - - def closeEvent(self, event): - super().closeEvent(event) - self.sigClosed.emit() - - def state(self): - _state = self.appearanceGroupbox.state() - return _state - - def ok_cb(self): - self.cancel = False - symbol = self.appearanceGroupbox.symbolWidget.widget.currentText() - self.symbol = re.findall(r"\'(.+)\'", symbol)[0] - self.color = self.appearanceGroupbox.colorButton.color() - self.pointSize = self.appearanceGroupbox.sizeSpinBox.value() - self.zHeight = self.appearanceGroupbox.zHeightSpinBox.value() - shortcutWidget = self.appearanceGroupbox.shortcutWidget - self.shortcut = shortcutWidget.widget.text() - self.keySequence = shortcutWidget.widget.keySequence - self.close() - -class filenameDialog(QDialog): - def __init__( - self, ext='.npz', basename='', title='Insert file name', - hintText='', existingNames='', parent=None, allowEmpty=True, - helpText='', defaultEntry='', resizeOnShow=True, - additionalButtons=None, addDoNotSaveButton=False - ): - self.cancel = True - super().__init__(parent) - - self.resizeOnShow = resizeOnShow - - if hintText.find('segmentation') != -1: - if helpText: - helpText = (f'{helpText}') - helpText_loc = (""" - With Cell-ACDC you can create as many segmentation files - as you want.

- If you plan to create only one file then you can leave the - text entry empty.
- Cell-ACDC will save the segmentation file with the filename - ending with _segm.npz.

- However, we recommend to insert some text that will easily - allow you to identify what is the segmentation file about.

- For example, if you are about to segment the channel - phase_contr, you could write - phase_contr.
- Cell-ACDC will then save the file with the - filename ending with _segm_phase_contr.npz.

- This way you can create multiple segmentation files, - for example one for each channel or one for each segmentation model.

- Note that the numerical features and annotations will be saved - in a CSV file ending with the same text as the segmentation file,
- e.g., ending with _acdc_output_phase_contr.csv. - """) - helpText = (f'{helpText}{html_utils.paragraph(helpText_loc)}') - - self.isSegmFile = basename.endswith('_segm') - self.allowEmpty = allowEmpty - self.basename = basename - if ext and not ext.startswith('.'): - ext = f'.{ext}' - self.ext = ext - - self.setWindowTitle(title) - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - - layout = QVBoxLayout() - entryLayout = QGridLayout() - buttonsLayout = QHBoxLayout() - - hintLabel = QLabel(hintText) - - basenameLabel = QLabel(basename) - - self.lineEdit = widgets.alphaNumericLineEdit(onlyWarn=True) - self.lineEdit.setAlignment(Qt.AlignCenter) - defaultEntry = to_alphanumeric(defaultEntry) - defaultEntry = defaultEntry.replace('.', '_') - self.lineEdit.setText(defaultEntry) - - extLabel = QLabel(ext) - - self.filenameLabel = QLabel() - self.filenameLabel.setText(f'{basename}{ext}') - - entryLayout.addWidget(basenameLabel, 0, 1) - entryLayout.addWidget(self.lineEdit, 0, 2) - entryLayout.addWidget(extLabel, 0, 3) - entryLayout.addWidget( - self.filenameLabel, 1, 1, 1, 3, alignment=Qt.AlignCenter - ) - # entryLayout.setColumnStretch(0, 1) - entryLayout.setColumnStretch(2, 1) - - self.warningInvalidCharLabel = QLabel() - - okButton = widgets.okPushButton('Ok') - cancelButton = widgets.cancelPushButton('Cancel') - self.okButton = okButton - - buttonsLayout.addStretch() - buttonsLayout.addWidget(cancelButton) - - if addDoNotSaveButton: - doNotSaveButton = widgets.noPushButton('Do not save') - doNotSaveButton.clicked.connect(self.doNotSave_cb) - buttonsLayout.addWidget(doNotSaveButton) - self.doNotSave = False - - buttonsLayout.addSpacing(20) - if helpText: - helpButton = widgets.helpPushButton('Help...') - helpButton.clicked.connect(partial(self.showHelp, helpText)) - buttonsLayout.addWidget(helpButton) - if additionalButtons is not None: - for button in additionalButtons: - buttonsLayout.addWidget(button) - buttonsLayout.addWidget(okButton) - - cancelButton.clicked.connect(self.close) - okButton.clicked.connect(self.ok_cb) - self.lineEdit.textChanged.connect(self.updateFilename) - self.lineEdit.sigInvalidCharactersEntered.connect( - self.warnInvalidCharactersEntered - ) - - self.existingNames = [] - if existingNames: - self.existingNames = existingNames - # self.lineEdit.editingFinished.connect(self.checkExistingNames) - - layout.addWidget(hintLabel) - layout.addSpacing(20) - layout.addLayout(entryLayout) - layout.addSpacing(10) - layout.addWidget(self.warningInvalidCharLabel) - layout.addStretch(1) - layout.addSpacing(20) - layout.addLayout(buttonsLayout) - - self.setLayout(layout) - self.setFont(font) - - if defaultEntry: - self.updateFilename(defaultEntry) - - def doNotSave_cb(self): - msg = widgets.myMessageBox() - txt = html_utils.paragraph( - 'Are you sure you do not want to save the file?' - ) - noButton, yesButton = msg.warning( - self, 'Do not save?', txt, buttonsTexts=('No', 'Yes') - ) - if msg.clickedButton == noButton: - return - - self.doNotSave = True - self.cancel = False - self.close() - - def showHelp(self, text): - text = html_utils.paragraph(text) - msg = widgets.myMessageBox(wrapText=False) - msg.information(self, 'Filename help', text) - - def _text(self): - return self.lineEdit.text() - - def warnInvalidCharactersEntered(self, characters: set[str]): - statement = 'is not a valid character' - if len(characters) > 1: - statement = 'are not valid characters' - - characters_str = ''.join(characters) - characters_str = html.escape(characters_str) - warning_text = html_utils.span(f""" - WARNING: "{characters_str}" {statement}.
- """) - warning_text = ( - f'{warning_text}' - 'Valid characters are letters, numbers, underscore, and dash.' - ) - self.warningInvalidCharLabel.setText(warning_text) - - def checkExistingNames(self): - is_existing = ( - self._text() in self.existingNames - or self.filenameLabel.text() in self.existingNames - ) - if not is_existing: - return True - - filename = self.filenameLabel.text() - msg = widgets.myMessageBox() - txt = html_utils.paragraph( - 'The following file

' - f'{filename}

' - 'is already existing.

' - 'Do you want to overwrite the existing file?' - ) - noButton, yesButton = msg.warning( - self, 'File name existing', txt, buttonsTexts=('No', 'Yes') - ) - return msg.clickedButton == yesButton - - def updateFilename(self, text): - if self.lineEdit.invalidCharacters(): - return - - if not text: - self.filenameLabel.setText(f'{self.basename}{self.ext}') - else: - text = text.replace(' ', '_') - if self.basename: - if self.basename.endswith('_'): - self.filenameLabel.setText(f'{self.basename}{text}{self.ext}') - else: - self.filenameLabel.setText(f'{self.basename}_{text}{self.ext}') - else: - self.filenameLabel.setText(f'{text}{self.ext}') - - self.warningInvalidCharLabel.setText('') - - def checkEmptyText(self): - if self.allowEmpty: - return True - - if self._text(): - return True - - msg = widgets.myMessageBox() - msg.critical( - self, 'Empty text', - html_utils.paragraph('Text entry field cannot be empty') - ) - return False - - def checkSegmFilename(self): - if not self.isSegmFile: - return True - - if 'segm' not in self._text(): - return True - - msg = widgets.myMessageBox() - txt = html_utils.paragraph( - 'The text appended to the filename cannot contain the text ' - '"segm".

' - 'Sorry, that would confuse me. Thank you for your patience!' - ) - msg.critical( - self, 'Cannot use "segm" in filename', txt - ) - return False - - def ok_cb(self, checked=True): - if self.warningInvalidCharLabel.text(): - return - - valid = self.checkExistingNames() - if not valid: - return - - valid = self.checkEmptyText() - if not valid: - return - - valid = self.checkSegmFilename() - if not valid: - return - - self.filename = self.filenameLabel.text() - self.entryText = self._text() - self.cancel = False - self.close() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - super().show() - if self.resizeOnShow: - self.lineEdit.setMinimumWidth(self.lineEdit.width()*2) - self.okButton.setDefault(True) - if block: - self.loop = QEventLoop() - self.loop.exec_() - - -class wandToleranceWidget(QFrame): - def __init__(self, parent=None): - super().__init__(parent) - - self.slider = widgets.sliderWithSpinBox(title='Tolerance') - self.slider.setMaximum(255) - self.slider._layout.setColumnStretch(2, 21) - - self.setLayout(self.slider.layout) - -class TrackSubCellObjectsDialog(QBaseDialog): - def __init__(self, basename='', parent=None): - self.cancel = True - super().__init__(parent=parent) - - self.setWindowTitle('Track sub-cellular objects parameters') - - mainLayout = QVBoxLayout() - entriesLayout = widgets.FormLayout() - - row = 0 - infoTxt = html_utils.paragraph(""" - Select behaviour with untracked objects:

- NOTE: this utility always create new files. - Original segmentation masks
are not modified
. - """) - options = ( - 'Delete sub-cellular objects that do not belong to any cell', - 'Delete cells that do not have any sub-cellular object', - 'Delete both cells and sub-cellular objects without an assignment', - 'Only track the objects and keep all the non-tracked objects' - ) - combobox = widgets.QCenteredComboBox() - combobox.addItems(options) - self.optionsWidget = widgets.formWidget( - combobox, addInfoButton=True, labelTextLeft='Tracking mode: ', - infoTxt=infoTxt - ) - entriesLayout.addFormWidget(self.optionsWidget, row=row) - - row += 1 - infoTxt = html_utils.paragraph(""" - Re-label sub-cellular objects before assigning them to the cell.

- Activate this option if you have merged sub-cellular objects - that must be separated, or the segmentation is a boolean mask - (i.e., semantic segmentation). - """) - self.relabelSubObjLab = widgets.formWidget( - widgets.Toggle(), addInfoButton=True, stretchWidget=False, - labelTextLeft='Re-label sub-cellular objects before tracking: ', - infoTxt=infoTxt - ) - entriesLayout.addFormWidget(self.relabelSubObjLab, row=row) - - row += 1 - IoAtext = html_utils.paragraph(""" - Enter a minimum percentage (0-1) of the sub-cellular object's area
- that MUST overlap with the parent cell to be considered belonging to a cell: - """) - spinbox = widgets.CenteredDoubleSpinbox() - spinbox.setMaximum(1) - spinbox.setValue(0.5) - spinbox.setSingleStep(0.1) - self.IoAwidget = widgets.formWidget( - spinbox, addInfoButton=True, labelTextLeft='IoA threshold: ', - infoTxt=IoAtext - ) - entriesLayout.addFormWidget(self.IoAwidget, row=row) - - row += 1 - infoTxt = html_utils.paragraph(""" - The third segmentation file is the result of subtracting the - sub-cellular objects from the parent objects

- This is useful if, for example, you need to compute measurements - only from the cytoplasm (i.e., the sub-cellular object is the nucleus). - """) - self.createThirdSegmWidget = widgets.formWidget( - widgets.Toggle(), addInfoButton=True, stretchWidget=False, - labelTextLeft='Create third segmentation: ', infoTxt=infoTxt - ) - entriesLayout.addFormWidget(self.createThirdSegmWidget, row=row) - - row += 1 - infoTxt = html_utils.paragraph(""" - Text to append at the end of the third segmentation file.

- The third segmentation file is the result of subtracting the - sub-cellular objects from the parent objects

- This is useful if, for example, you need to compute measurements - only from the cytoplasm (i.e., the sub-cellular object is the nucleus). - """) - lineEdit = widgets.alphaNumericLineEdit() - lineEdit.setText('difference') - lineEdit.setAlignment(Qt.AlignCenter) - self.appendTextWidget = widgets.formWidget( - lineEdit, addInfoButton=True, labelTextLeft='Text to append: ', - infoTxt=infoTxt - ) - entriesLayout.addFormWidget(self.appendTextWidget, row=row) - self.appendTextWidget.setDisabled(True) - - - self.createThirdSegmWidget.widget.toggled.connect( - self.createThirdSegmToggled - ) - - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addLayout(entriesLayout) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - self.setFont(font) - - def createThirdSegmToggled(self, checked): - self.appendTextWidget.setDisabled(not checked) - - def ok_cb(self): - self.cancel = False - if self.createThirdSegmWidget.widget.isChecked(): - if not self.appendTextWidget.widget.text(): - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - txt = html_utils.paragraph( - 'When creating the third segmentation file, ' - 'the name to append cannot be empty!' - ) - msg.critical(self, 'Empty name', txt) - return - - self.trackSubCellObjParams = { - 'how': self.optionsWidget.widget.currentText(), - 'IoA': self.IoAwidget.widget.value(), - 'createThirdSegm': self.createThirdSegmWidget.widget.isChecked(), - 'relabelSubObjLab': self.relabelSubObjLab.widget.isChecked(), - 'thirdSegmAppendedText': self.appendTextWidget.widget.text() - } - self.close() - -class SetMeasurementsDialog(QBaseDialog): - sigClosed = Signal() - sigCancel = Signal() - sigRestart = Signal() - - def __init__( - self, loadedChNames, notLoadedChNames, isZstack, isSegm3D, - favourite_funcs=None, parent=None, allPos_acdc_df_cols=None, - acdc_df_path=None, posData=None, addCombineMetricCallback=None, - allPosData=None, is_concat=False, isSingleSelection=False, - state=None - ): - super().__init__(parent=parent) - - self.checkBoxedGroup = QButtonGroup() - self.checkBoxedGroup.setExclusive(isSingleSelection) - - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - - self.cancel = True - - self.delExistingCols = False - self.okClicked = False - self.is_concat = is_concat - self.allPos_acdc_df_cols = allPos_acdc_df_cols - self.acdc_df_path = acdc_df_path - self.allPosData = allPosData - self.doNotWarn = False - - self.setWindowTitle('Set measurements') - # self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - - layout = QVBoxLayout() - - searchLayout = QHBoxLayout() - - searchLineEdit = widgets.SearchLineEdit() - searchLayout.addStretch(5) - searchLayout.addWidget(searchLineEdit) - searchLayout.setStretch(1, 3) - - mainScrollArea = widgets.ScrollArea() - mainScrollAreaWidget = QWidget() - mainScrollArea.setWidget(mainScrollAreaWidget) - - groupsLayout = QGridLayout() - self.groupsLayout = groupsLayout - - mainScrollAreaWidget.setLayout(groupsLayout) - - buttonsLayout = QHBoxLayout() - - self.chNameGroupboxes = [] - self.all_metrics = [] - - col = 0 - for col, chName in enumerate(loadedChNames): - channelGBox = widgets.channelMetricsQGBox( - isZstack, chName, isSegm3D, favourite_funcs=favourite_funcs, - posData=posData, is_concat=is_concat - ) - channelGBox.chName = chName - groupsLayout.addWidget(channelGBox, 0, col, 3, 1) - self.chNameGroupboxes.append(channelGBox) - channelGBox.sigDelClicked.connect(self.delMixedChannelCombineMetric) - channelGBox.sigCheckboxToggled.connect(self.channelCheckboxToggled) - groupsLayout.setColumnStretch(col, 5) - self.all_metrics.extend([c.text() for c in channelGBox.checkBoxes]) - - current_col = col+1 - for col, chName in enumerate(notLoadedChNames): - channelGBox = widgets.channelMetricsQGBox( - isZstack, chName, isSegm3D, favourite_funcs=favourite_funcs, - posData=posData, is_concat=is_concat - ) - channelGBox.setChecked(False) - channelGBox.chName = chName - groupsLayout.addWidget(channelGBox, 0, current_col, 3, 1) - self.chNameGroupboxes.append(channelGBox) - groupsLayout.setColumnStretch(current_col, 5) - channelGBox.sigDelClicked.connect(self.delMixedChannelCombineMetric) - channelGBox.sigCheckboxToggled.connect(self.channelCheckboxToggled) - current_col += 1 - self.all_metrics.extend([c.text() for c in channelGBox.checkBoxes]) - - current_col += 1 - - if posData is None: - isTimelapse = False - else: - isTimelapse = posData.SizeT>1 - size_metrics_desc = measurements.get_size_metrics_desc( - isSegm3D, isTimelapse - ) - if not isSegm3D: - size_metrics_desc = { - key:val for key,val in size_metrics_desc.items() - if not key.endswith('_3D') - } - - row = 0 - sizeMetricsQGBox = widgets._metricsQGBox( - size_metrics_desc, 'Physical measurements', - favourite_funcs=favourite_funcs, isZstack=isZstack, - addCalcForEachZsliceToggle=isSegm3D - ) - self.all_metrics.extend([c.text() for c in sizeMetricsQGBox.checkBoxes]) - self.sizeMetricsQGBox = sizeMetricsQGBox - for sizeCheckbox in sizeMetricsQGBox.checkBoxes: - sizeCheckbox.toggled.connect(self.sizeMetricToggled) - groupsLayout.addWidget(sizeMetricsQGBox, row, current_col) - groupsLayout.setRowStretch(0, 1) - groupsLayout.setColumnStretch(current_col, 3) - row += 1 - - props_info_txt_mapper = measurements.get_props_info_txt_mapper( - isSegm3D=isSegm3D - ) - rp_desc = props_info_txt_mapper - regionPropsQGBox = widgets._metricsQGBox( - rp_desc, 'Morphological properties', - favourite_funcs=favourite_funcs, isZstack=isZstack - ) - self.regionPropsQGBox = regionPropsQGBox - for rpCheckbox in regionPropsQGBox.checkBoxes: - rpCheckbox.toggled.connect(self.rpMetricToggled) - groupsLayout.addWidget(regionPropsQGBox, row, current_col) - groupsLayout.setRowStretch(1, 2) - self.all_metrics.extend([c.text() for c in regionPropsQGBox.checkBoxes]) - row += 1 - - # Custom metrics that are channel indipendent - self.chIndipendCustomeMetricsQGBox = None - out = measurements.ch_indipend_custom_metrics_desc( - isZstack, isSegm3D=isSegm3D, - ) - ch_indipend_custom_metrics_desc = out - if ch_indipend_custom_metrics_desc: - self.chIndipendCustomeMetricsQGBox = widgets._metricsQGBox( - ch_indipend_custom_metrics_desc, - 'Channel indipendent custom measurements', - favourite_funcs=favourite_funcs, isZstack=isZstack, - parent=self - ) - groupsLayout.addWidget( - self.chIndipendCustomeMetricsQGBox, row, current_col - ) - groupsLayout.setRowStretch(1, 1) - row += 1 - - desc, equations = measurements.combine_mixed_channels_desc( - isSegm3D=isSegm3D, posData=posData, available_cols=self.all_metrics - ) - self.mixedChannelsCombineMetricsQGBox = None - if desc: - self.mixedChannelsCombineMetricsQGBox = widgets._metricsQGBox( - desc, 'Mixed channels combined measurements', - favourite_funcs=favourite_funcs, isZstack=isZstack, - equations=equations, addDelButton=True - ) - self.mixedChannelsCombineMetricsQGBox.sigDelClicked.connect( - self.delMixedChannelCombineMetric - ) - groupsLayout.addWidget( - self.mixedChannelsCombineMetricsQGBox, row, current_col - ) - groupsLayout.setRowStretch(1, 1) - if not self.is_concat: - self.setDisabledMetricsRequestedForCombined(False) - self.mixedChannelsCombineMetricsQGBox.toggled.connect( - self.setDisabledMetricsRequestedForCombined - ) - for combCheckbox in self.mixedChannelsCombineMetricsQGBox.checkBoxes: - combCheckbox.toggled.connect( - self.setDisabledMetricsRequestedForCombined - ) - else: - for combCheckbox in self.mixedChannelsCombineMetricsQGBox.checkBoxes: - combCheckbox.toggled.connect( - self.mixedChannelsMetricToggled - ) - row += 1 - - self.last_row = row - self.last_col = current_col - - okButton = widgets.okPushButton(' Ok ') - cancelButton = widgets.cancelPushButton('Cancel') - if addCombineMetricCallback is not None: - addCombineMetricButton = widgets.addPushButton( - 'Add combined measurement...' - ) - addCombineMetricButton.clicked.connect(addCombineMetricCallback) - self.okButton = okButton - - loadLastSelButton = widgets.reloadPushButton('Load last selection...') - self.deselectAllButton = QPushButton('Deselect all') - self.deselectAllButton.setIcon(QIcon(':deselect_all.svg')) - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(self.deselectAllButton) - buttonsLayout.addSpacing(20) - - if addCombineMetricCallback is not None: - buttonsLayout.addWidget(addCombineMetricButton) - buttonsLayout.addSpacing(20) - - saveCurrentSelectionButton = widgets.savePushButton( - 'Save current selection...' - ) - saveCurrentSelectionButton.clicked.connect( - self.saveCurrentSelectionClicked - ) - - buttonsLayout.addWidget(saveCurrentSelectionButton) - - loadSavedSelectionButton = widgets.OpenFilePushButton( - 'Load saved selection...' - ) - loadSavedSelectionButton.clicked.connect(self.loadSavedSelectionClicked) - buttonsLayout.addWidget(loadSavedSelectionButton) - - buttonsLayout.addWidget(loadLastSelButton) - - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(okButton) - - self.okButton = okButton - - layout.addLayout(searchLayout) - layout.addSpacing(10) - # layout.addLayout(groupsLayout) - layout.addWidget(mainScrollArea) - layout.addLayout(buttonsLayout) - - self.setLayout(layout) - - if state is not None: - self.setState(state) - - searchLineEdit.textEdited.connect(self.searchAndHighlight) - self.deselectAllButton.clicked.connect(self.deselectAll) - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.close) - loadLastSelButton.clicked.connect(self.loadLastSelection) - - self.addCheckboxesToGroup() - - for channelGBox in self.chNameGroupboxes: - for checkbox in channelGBox.checkBoxes: - self.channelCheckboxToggled(checkbox) - - def allMetricsDict(self): - all_metrics = { - 'standard': {}, - 'regionprop': [], - 'size': [], - 'mixed_channels': [] - } - for chNameGroupbox in self.chNameGroupboxes: - channel_name = chNameGroupbox.chName - for checkBox in chNameGroupbox.checkBoxes: - if channel_name not in all_metrics['standard']: - all_metrics['standard'][channel_name] = [] - all_metrics['standard'][channel_name].append(checkBox.text()) - - for checkBox in self.regionPropsQGBox.checkBoxes: - all_metrics['regionprop'].append(checkBox.text()) - - for checkBox in self.sizeMetricsQGBox.checkBoxes: - all_metrics['size'].append(checkBox.text()) - - if self.chIndipendCustomeMetricsQGBox is not None: - checkBoxes = self.chIndipendCustomeMetricsQGBox.checkBoxes - for checkBox in checkBoxes: - all_metrics['ch_indipend_custom_metric'].append(checkBox.text()) - - if self.mixedChannelsCombineMetricsQGBox is None: - return - - checkBoxes = self.mixedChannelsCombineMetricsQGBox.checkBoxes - for checkBox in checkBoxes: - all_metrics['mixed_channels'].append(checkBox.text()) - - return all_metrics - - def searchAndHighlight(self, text): - for chNameGroupbox in self.chNameGroupboxes: - for groupbox in chNameGroupbox.groupboxes: - groupbox.highlightCheckboxesFromSearchText(text) - - self.regionPropsQGBox.highlightCheckboxesFromSearchText(text) - self.sizeMetricsQGBox.highlightCheckboxesFromSearchText(text) - - if self.chIndipendCustomeMetricsQGBox is not None: - self.chIndipendCustomeMetricsQGBox.highlightCheckboxesFromSearchText( - text - ) - - if self.mixedChannelsCombineMetricsQGBox is None: - return - - self.mixedChannelsCombineMetricsQGBox.highlightCheckboxesFromSearchText( - text - ) - - def selectedMetricNameAndGroup(self): - for chNameGroupbox in self.chNameGroupboxes: - for checkBox in chNameGroupbox.checkBoxes: - if checkBox.isChecked(): - return checkBox.text(), {'standard': chNameGroupbox.chName} - - for checkBox in self.regionPropsQGBox.checkBoxes: - if checkBox.isChecked(): - return checkBox.text(), 'regionprop' - - for checkBox in self.sizeMetricsQGBox.checkBoxes: - if checkBox.isChecked(): - return checkBox.text(), 'size' - - if self.chIndipendCustomeMetricsQGBox is not None: - checkBoxes = self.chIndipendCustomeMetricsQGBox.checkBoxes - for checkBox in checkBoxes: - if checkBox.isChecked(): - return checkBox.text(), 'ch_indipend_custom_metric' - - if self.mixedChannelsCombineMetricsQGBox is None: - return - - checkBoxes = self.mixedChannelsCombineMetricsQGBox.checkBoxes - for checkBox in checkBoxes: - if checkBox.isChecked(): - return checkBox.text(), 'mixed_channels' - - def selectedMetricGroup(self): - for chNameGroupbox in self.chNameGroupboxes: - for checkBox in chNameGroupbox.checkBoxes: - if checkBox.isChecked(): - return checkBox.text() - - for checkBox in self.regionPropsQGBox.checkBoxes: - if checkBox.isChecked(): - return checkBox.text() - - for checkBox in self.sizeMetricsQGBox.checkBoxes: - if checkBox.isChecked(): - return checkBox.text() - - if self.chIndipendCustomeMetricsQGBox is not None: - checkBoxes = self.chIndipendCustomeMetricsQGBox.checkBoxes - for checkBox in checkBoxes: - if checkBox.isChecked(): - return checkBox.text() - - if self.mixedChannelsCombineMetricsQGBox is None: - return - - checkBoxes = self.mixedChannelsCombineMetricsQGBox.checkBoxes - for checkBox in checkBoxes: - if checkBox.isChecked(): - return checkBox.text() - - def addCheckboxesToGroup(self): - for chNameGroupbox in self.chNameGroupboxes: - for checkBox in chNameGroupbox.checkBoxes: - self.checkBoxedGroup.addButton(checkBox) - - for checkBox in self.regionPropsQGBox.checkBoxes: - self.checkBoxedGroup.addButton(checkBox) - - for checkBox in self.sizeMetricsQGBox.checkBoxes: - self.checkBoxedGroup.addButton(checkBox) - - if self.chIndipendCustomeMetricsQGBox is not None: - checkBoxes = self.chIndipendCustomeMetricsQGBox.checkBoxes - for checkBox in checkBoxes: - self.checkBoxedGroup.addButton(checkBox) - - if self.mixedChannelsCombineMetricsQGBox is None: - return - - checkBoxes = self.mixedChannelsCombineMetricsQGBox.checkBoxes - for checkBox in checkBoxes: - self.checkBoxedGroup.addButton(checkBox) - - def channelCheckboxToggled(self, checkbox): - # Make sure to automatically check the requested cell_vol metric for - # concentration metrics - if checkbox.text().find('concentration_') == -1: - return - - if self.is_concat: - # When this dialogue is used in concatenate pos utility we do not - # need to check that certain metrics are present - return - - pattern = r'.+_from_vol_([a-z]+)(_3D)?(_?[A-Za-z0-9]*)' - repl = r'cell_vol_\1\2' - cell_vol_metric_name = re.sub(pattern, repl, checkbox.text()) - for sizeCheckbox in self.sizeMetricsQGBox.checkBoxes: - if sizeCheckbox.text() == cell_vol_metric_name: - break - else: - # Make sure to not check for similarly named custom metrics - return - - if checkbox.isChecked(): - sizeCheckbox.setChecked(True) - sizeCheckbox.isRequired = True - else: - # Do not enable cell vol checkbox is any of the other - # concentration metrics requiring it is checked - unit = cell_vol_metric_name[9:] - is3D = unit.endswith('3D') - for channelGBox in self.chNameGroupboxes: - if not channelGBox.isChecked(): - continue - for _checkbox in channelGBox.checkBoxes: - if _checkbox.text().find(f'_from_vol_{unit}') == -1: - continue - if not is3D and _checkbox.text().find(f'{unit}_3D') != -1: - # Metric is 3D but the cell_vol is not - continue - if _checkbox.isChecked(): - return - sizeCheckbox.isRequired = False - - def rpMetricToggled(self, checked): - pass - - def mixedChannelsMetricToggled(self, checked): - pass - - def sizeMetricToggled(self, checked): - """Method called when a checkbox of a size metric is toggled. - Check if the size value is required and explain why it cannot be - unchecked. - - Parameters - ---------- - checked : bool - State of the checkbox toggled - """ - checkbox = self.sender() - - if self.is_concat: - # When this dialogue is used in concatenate pos utility we do not - # need to check that certain metrics are present - return - - if not hasattr(checkbox, 'isRequired'): - return - - if not checkbox.isRequired: - return - - if checkbox.isChecked(): - return - - checkbox.setChecked(True) - - if self.doNotWarn: - return - - linked_autoBkgr_metric = checkbox.text().replace('cell', '_autoBkgr_from') - linked_dataPrepBkgr_metric = checkbox.text().replace( - 'cell', '_dataPrepBkgr_from' - ) - txt = html_utils.paragraph(f""" - This physical measurement cannot be unchecked - because it is required - by the {linked_autoBkgr_metric} and - {linked_dataPrepBkgr_metric} measurements - that you requested to save.

- - Thank you for you patience! - """) - msg = widgets.myMessageBox(showCentered=False) - msg.warning(self, 'Physical measurement required', txt) - - def deselectAll(self): - self.doNotWarn = True - for chNameGroupbox in self.chNameGroupboxes: - for gb in chNameGroupbox.groupboxes: - gb.checkAll(None, False) - cgb = getattr(chNameGroupbox, 'customMetricsQGBox', None) - if cgb is not None: - cgb.checkAll(None, False) - - self.sizeMetricsQGBox.checkAll(None, False) - self.regionPropsQGBox.checkAll(None, False) - if self.chIndipendCustomeMetricsQGBox is not None: - self.chIndipendCustomeMetricsQGBox.checkAll(None, False) - - if self.mixedChannelsCombineMetricsQGBox is not None: - self.mixedChannelsCombineMetricsQGBox.checkAll(None, False) - self.doNotWarn = False - - def delMixedChannelCombineMetric(self, colname_to_del, hlayout): - cp = measurements.read_saved_user_combine_config() - for section in cp.sections(): - cp.remove_option(section, colname_to_del) - measurements.save_common_combine_metrics(cp) - - for i in range(hlayout.count()): - item = hlayout.itemAt(i) - w = item.widget() - if w is None: - continue - w.hide() - - if self.allPosData is not None: - for posData in self.allPosData: - _config = posData.combineMetricsConfig - for section in _config.sections(): - _config.remove_option(section, colname_to_del) - posData.saveCombineMetrics() - - def setState(self, state): - self.doNotWarn = True - for chNameGroupbox in self.chNameGroupboxes: - measurementsInfo = state.get(chNameGroupbox.title()) - if not measurementsInfo: - chNameGroupbox.setChecked(False) - else: - for checkBox in chNameGroupbox.checkBoxes: - colname = checkBox.text() - checkBox.setChecked(measurementsInfo[colname]) - - measurementsInfo = state.get(self.sizeMetricsQGBox.title()) - if not measurementsInfo: - self.sizeMetricsQGBox.setChecked(False) - else: - for checkBox in self.sizeMetricsQGBox.checkBoxes: - checked = checkBox.isChecked() - colname = checkBox.text() - checkBox.setChecked(measurementsInfo[colname]) - - measurementsInfo = state.get(self.regionPropsQGBox.title()) - if not measurementsInfo: - self.regionPropsQGBox.setChecked(False) - else: - self.regionPropsToSave = [] - for checkBox in self.regionPropsQGBox.checkBoxes: - checked = checkBox.isChecked() - colname = checkBox.text() - checkBox.setChecked(measurementsInfo[colname]) - - if self.chIndipendCustomeMetricsQGBox is not None: - measurementsInfo = state.get( - self.chIndipendCustomeMetricsQGBox.title() - ) - if not measurementsInfo: - self.chIndipendCustomeMetricsQGBox.setChecked(False) - else: - checkBoxes = self.chIndipendCustomeMetricsQGBox.checkBoxes - for checkBox in checkBoxes: - checked = checkBox.isChecked() - colname = checkBox.text() - key = self.chIndipendCustomeMetricsQGBox.title() - checkBox.setChecked(measurementsInfo[colname]) - - if self.mixedChannelsCombineMetricsQGBox is not None: - measurementsInfo = state.get( - self.mixedChannelsCombineMetricsQGBox.title() - ) - if not measurementsInfo: - self.mixedChannelsCombineMetricsQGBox.setChecked(False) - else: - checkBoxes = self.mixedChannelsCombineMetricsQGBox.checkBoxes - for checkBox in checkBoxes: - checked = checkBox.isChecked() - colname = checkBox.text() - key = self.mixedChannelsCombineMetricsQGBox.title() - checkBox.setChecked(measurementsInfo[colname]) - - self.doNotWarn = False - - def state(self): - state = { - self.sizeMetricsQGBox.title(): {}, - self.regionPropsQGBox.title(): {} - } - for chNameGroupbox in self.chNameGroupboxes: - state[chNameGroupbox.title()] = {} - if not chNameGroupbox.isChecked(): - # Channel unchecked - continue - else: - for checkBox in chNameGroupbox.checkBoxes: - colname = checkBox.text() - state[chNameGroupbox.title()][colname] = checkBox.isChecked() - - if not self.sizeMetricsQGBox.isChecked(): - pass - else: - for checkBox in self.sizeMetricsQGBox.checkBoxes: - checked = checkBox.isChecked() - colname = checkBox.text() - state[self.sizeMetricsQGBox.title()][colname] = checked - - if not self.regionPropsQGBox.isChecked(): - pass - else: - self.regionPropsToSave = [] - for checkBox in self.regionPropsQGBox.checkBoxes: - checked = checkBox.isChecked() - colname = checkBox.text() - state[self.regionPropsQGBox.title()][colname] = checked - - if self.chIndipendCustomeMetricsQGBox is not None: - state[self.chIndipendCustomeMetricsQGBox.title()] = {} - if self.chIndipendCustomeMetricsQGBox.isChecked(): - checkBoxes = self.chIndipendCustomeMetricsQGBox.checkBoxes - for checkBox in checkBoxes: - checked = checkBox.isChecked() - key = self.chIndipendCustomeMetricsQGBox.title() - colname = checkBox.text() - state[key][colname] = checked - - if self.mixedChannelsCombineMetricsQGBox is not None: - state[self.mixedChannelsCombineMetricsQGBox.title()] = {} - if self.mixedChannelsCombineMetricsQGBox.isChecked(): - checkBoxes = self.mixedChannelsCombineMetricsQGBox.checkBoxes - for checkBox in checkBoxes: - checked = checkBox.isChecked() - key = self.mixedChannelsCombineMetricsQGBox.title() - colname = checkBox.text() - state[key][colname] = checked - - return state - - def restoreState(self, state): - for chNameGroupbox in self.chNameGroupboxes: - _state = state.get(chNameGroupbox.title()) - if _state is None or not _state: - continue - for checkBox in chNameGroupbox.checkBoxes: - isChecked = _state.get(checkBox.text()) - if isChecked is None: - continue - checkBox.setChecked(isChecked) - - _state = state.get(self.sizeMetricsQGBox.title()) - if _state is None or not _state: - pass - else: - for checkBox in self.sizeMetricsQGBox.checkBoxes: - isChecked = _state.get(checkBox.text()) - if isChecked is None: - continue - checkBox.setChecked(isChecked) - - _state = state.get(self.regionPropsQGBox.title()) - if _state is None or not _state: - pass - else: - for checkBox in self.regionPropsQGBox.checkBoxes: - isChecked = _state.get(checkBox.text()) - if isChecked is None: - continue - checkBox.setChecked(isChecked) - - if self.chIndipendCustomeMetricsQGBox is not None: - _state = state.get(self.chIndipendCustomeMetricsQGBox.title()) - if _state is None or not _state: - pass - else: - for checkBox in self.chIndipendCustomeMetricsQGBox.checkBoxes: - isChecked = _state.get(checkBox.text()) - if isChecked is None: - continue - checkBox.setChecked(isChecked) - - if self.mixedChannelsCombineMetricsQGBox is not None: - _state = state.get(self.mixedChannelsCombineMetricsQGBox.title()) - if _state is None or not _state: - pass - else: - for checkBox in self.mixedChannelsCombineMetricsQGBox.checkBoxes: - isChecked = _state.get(checkBox.text()) - if isChecked is None: - continue - checkBox.setChecked(isChecked) - - def currentSelectionMapper(self): - current_selected_meas = defaultdict(dict) - - for chNameGroupbox in self.chNameGroupboxes: - if not chNameGroupbox.isChecked(): - continue - - chName = chNameGroupbox.chName - for checkBox in chNameGroupbox.checkBoxes: - if not checkBox.isChecked(): - continue - - current_selected_meas[chName][checkBox.text()] = 'Yes' - - size_selected_meas = current_selected_meas.get( - self.sizeMetricsQGBox.title() - ) - if self.sizeMetricsQGBox.isChecked(): - for checkBox in self.sizeMetricsQGBox.checkBoxes: - if not checkBox.isChecked(): - continue - - section = self.sizeMetricsQGBox.title() - current_selected_meas[section][checkBox.text()] = 'Yes' - - size_selected_meas = current_selected_meas.get( - self.regionPropsQGBox.title() - ) - if self.regionPropsQGBox.isChecked(): - for checkBox in self.regionPropsQGBox.checkBoxes: - if not checkBox.isChecked(): - continue - - section = self.regionPropsQGBox.title() - current_selected_meas[section][checkBox.text()] = 'Yes' - - if self.chIndipendCustomeMetricsQGBox is not None: - if self.chIndipendCustomeMetricsQGBox.isChecked(): - for checkBox in self.chIndipendCustomeMetricsQGBox.checkBoxes: - if not checkBox.isChecked(): - continue - - section = self.chIndipendCustomeMetricsQGBox.title() - current_selected_meas[section][checkBox.text()] = 'Yes' - - if self.mixedChannelsCombineMetricsQGBox is not None: - if self.mixedChannelsCombineMetricsQGBox.isChecked(): - for checkBox in self.mixedChannelsCombineMetricsQGBox.checkBoxes: - if not checkBox.isChecked(): - continue - - section = self.mixedChannelsCombineMetricsQGBox.title() - current_selected_meas[section][checkBox.text()] = 'Yes' - - return current_selected_meas - - def saveCurrentSelectionClicked(self): - current_selection_mapper = self.currentSelectionMapper() - defaultEntry = '_and_'.join(current_selection_mapper.keys()) - defaultEntry = defaultEntry.replace(' ', '_').lower() - saved_selections = io.get_saved_measurements_selections() - win = filenameDialog( - basename='', - ext='', - hintText='Insert a name for the current selection:', - existingNames=saved_selections, - allowEmpty=False, - defaultEntry=defaultEntry - ) - win.exec_() - if win.cancel: - return - - filename = win.filename - ini_filepath = io.save_measurements_selections( - filename, current_selection_mapper) - - msg = widgets.myMessageBox(wrapText=False, showCentered=False) - txt = html_utils.paragraph(f""" - Done!

- Current selection saved with name {filename} at - the following path: - """) - msg.information( - self, 'Selection saved', txt, - commands=(ini_filepath,), - path_to_browse=os.path.dirname(ini_filepath), - ) - - def loadSavedSelectionClicked(self): - self.doNotWarn = True - - saved_selections = io.get_saved_measurements_selections() - - selectNameWin = widgets.QDialogListbox( - 'Choose selection to load', - 'Choose selection to load:\n', - saved_selections, - multiSelection=False, - parent=self - ) - selectNameWin.exec_() - if selectNameWin.cancel: - return - - selection_mapper = ( - io.read_measurements_selections(selectNameWin.selectedItemsText[0]) - ) - - self.setCurrentSelectionFromMapper(selection_mapper) - - self.doNotWarn = False - - def saveLastSelection(self): - last_selected_meas = self.currentSelectionMapper() - load.write_last_selected_set_measurements(last_selected_meas) - - def setCurrentSelectionFromMapper(self, selection_mapper): - for chNameGroupbox in self.chNameGroupboxes: - chName = chNameGroupbox.chName - chSelectedMeas = selection_mapper.get(chName) - if chSelectedMeas is None: - chNameGroupbox.setChecked(False) - continue - - chNameGroupbox.setChecked(True) - for checkBox in chNameGroupbox.checkBoxes: - checked = chSelectedMeas.get(checkBox.text()) - if checked is not None: - checkBox.setChecked(True) - else: - checkBox.setChecked(False) - - size_selected_meas = selection_mapper.get( - self.sizeMetricsQGBox.title() - ) - if size_selected_meas is None: - self.sizeMetricsQGBox.setChecked(False) - else: - self.sizeMetricsQGBox.setChecked(True) - for checkBox in self.sizeMetricsQGBox.checkBoxes: - checked = size_selected_meas.get(checkBox.text()) - if checked is not None: - checkBox.setChecked(True) - else: - checkBox.setChecked(False) - - size_selected_meas = selection_mapper.get( - self.regionPropsQGBox.title() - ) - if size_selected_meas is None: - self.regionPropsQGBox.setChecked(False) - else: - self.regionPropsQGBox.setChecked(True) - for checkBox in self.regionPropsQGBox.checkBoxes: - checked = size_selected_meas.get(checkBox.text()) - if checked is not None: - checkBox.setChecked(True) - else: - checkBox.setChecked(False) - - if self.chIndipendCustomeMetricsQGBox is not None: - ch_indip_custom_metrics = selection_mapper.get( - self.chIndipendCustomeMetricsQGBox.title() - ) - if size_selected_meas is None: - self.chIndipendCustomeMetricsQGBox.setChecked(False) - else: - self.chIndipendCustomeMetricsQGBox.setChecked(True) - for checkBox in self.chIndipendCustomeMetricsQGBox.checkBoxes: - checked = size_selected_meas.get(checkBox.text()) - if checked is not None: - checkBox.setChecked(True) - else: - checkBox.setChecked(False) - - if self.mixedChannelsCombineMetricsQGBox is not None: - ch_indip_custom_metrics = selection_mapper.get( - self.mixedChannelsCombineMetricsQGBox.title() - ) - if size_selected_meas is None: - self.mixedChannelsCombineMetricsQGBox.setChecked(False) - else: - self.mixedChannelsCombineMetricsQGBox.setChecked(True) - for checkBox in self.mixedChannelsCombineMetricsQGBox.checkBoxes: - checked = size_selected_meas.get(checkBox.text()) - if checked is not None: - checkBox.setChecked(True) - else: - checkBox.setChecked(False) - - def loadLastSelection(self): - self.doNotWarn = True - last_selected_meas = load.read_last_selected_set_measurements() - last_selected_meas = dict(last_selected_meas) - - self.setCurrentSelectionFromMapper(last_selected_meas) - - self.doNotWarn = False - - def setDisabledMetricsRequestedForCombined(self, checked): - checkbox = self.sender() - - if self.is_concat: - # When this dialogue is used in concatenate pos utility we do not - # need to check that certain metrics are present - return - - # Set checked and disable those metrics that are requested for - # combined measurements - allCheckboxes = [] - - for chNameGroupbox in self.chNameGroupboxes: - for chCheckBox in chNameGroupbox.checkBoxes: - chCheckBox.setDisabled(False) - allCheckboxes.append(chCheckBox) - - for sizeCheckBox in self.sizeMetricsQGBox.checkBoxes: - sizeCheckBox.setDisabled(False) - allCheckboxes.append(chCheckBox) - - for rpCheckBox in self.regionPropsQGBox.checkBoxes: - rpCheckBox.setDisabled(False) - allCheckboxes.append(chCheckBox) - - if not self.mixedChannelsCombineMetricsQGBox.isChecked(): - return - - for cb in allCheckboxes: - metricName = cb.text() - for combCheckbox in self.mixedChannelsCombineMetricsQGBox.checkBoxes: - equation = combCheckbox.equation - if equation.find(metricName) == -1: - continue - elif combCheckbox.isChecked(): - cb.setChecked(True) - cb.setDisabled(True) - cb.setToolTip( - 'This metric cannot be removed because it is required ' - f'by the combined measurement "{combCheckbox.text()}"' - ) - - def keyPressEvent(self, a0: QKeyEvent) -> None: - state = self.state() - return super().keyPressEvent(a0) - - def closeEvent(self, event): - if self.cancel: - self.sigCancel.emit() - super().closeEvent(event) - - def restart(self): - self.cancel = False - self.close() - self.sigRestart.emit() - - def setDisabledNotExistingMeasurements(self, existing_colnames): - self.existing_colnames = existing_colnames - for chNameGroupbox in self.chNameGroupboxes: - for checkBox in chNameGroupbox.checkBoxes: - colname = checkBox.text() - if colname in existing_colnames: - checkBox.setChecked(True) - continue - - checkBox.setChecked(False) - checkBox.setDisabled(True) - self.setNotExistingMeasurementTooltip(checkBox) - - for checkBox in self.sizeMetricsQGBox.checkBoxes: - colname = checkBox.text() - if colname in existing_colnames: - checkBox.setChecked(True) - continue - checkBox.setChecked(False) - checkBox.setDisabled(True) - self.setNotExistingMeasurementTooltip(checkBox) - - for checkBox in self.regionPropsQGBox.checkBoxes: - prop_name = checkBox.text() - for existing_col in existing_colnames: - if prop_name == existing_col: - checkBox.setChecked(True) - break - m = re.match(fr'{prop_name}-\d', existing_col) - if m is not None: - checkBox.setChecked(True) - break - else: - checkBox.setChecked(False) - checkBox.setDisabled(True) - self.setNotExistingMeasurementTooltip(checkBox) - - if self.mixedChannelsCombineMetricsQGBox is None: - return - - for combCheckbox in self.mixedChannelsCombineMetricsQGBox.checkBoxes: - colname = combCheckbox.text() - if colname in existing_colnames: - combCheckbox.setChecked(True) - continue - combCheckbox.setChecked(False) - combCheckbox.setDisabled(True) - self.setNotExistingMeasurementTooltip(combCheckbox) - - def addNonMeasurementColumns(self, colnames): - additionalCols = measurements.get_non_measurements_cols( - colnames, self.all_metrics - ) - if not additionalCols: - return - self.nonMeasurementsGroupbox = widgets.CheckboxesGroupBox( - additionalCols, title='Additional columns', checkable=True - ) - self.groupsLayout.addWidget( - self.nonMeasurementsGroupbox, 0, self.last_col+1, self.last_row+1, 1 - ) - - - def setNotExistingMeasurementTooltip(self, checkBox): - checkBox.setToolTip( - 'Measurement is disabled because it is not present in selected ' - 'acdc_output tables, hence it cannot be addded to concatenated ' - 'table. ' - ) - - def ok_cb(self): - for chNameGroupbox in self.chNameGroupboxes: - chNameGroupbox.calcForEachZsliceRequested = ( - chNameGroupbox.isCalcForEachZsliceRequested() - ) - - self.sizeMetricsQGBox.calcForEachZsliceRequested = ( - self.sizeMetricsQGBox.isCalcForEachZsliceRequested() - ) - - if self.allPos_acdc_df_cols is None: - self.saveLastSelection() - self.cancel = False - self.close() - self.sigClosed.emit() - return - - self.okClicked = True - existing_colnames = self.allPos_acdc_df_cols - unchecked_existing_colnames = [] - unchecked_existing_rps = [] - for chNameGroupbox in self.chNameGroupboxes: - for checkBox in chNameGroupbox.checkBoxes: - colname = checkBox.text() - is_existing = colname in existing_colnames - if not chNameGroupbox.isChecked() and is_existing: - unchecked_existing_colnames.append(colname) - continue - if not checkBox.isChecked() and is_existing: - unchecked_existing_colnames.append(colname) - - for checkBox in self.sizeMetricsQGBox.checkBoxes: - colname = checkBox.text() - is_existing = colname in existing_colnames - if not self.sizeMetricsQGBox.isChecked() and is_existing: - unchecked_existing_colnames.append(colname) - continue - - if not checkBox.isChecked() and is_existing: - unchecked_existing_colnames.append(colname) - for checkBox in self.regionPropsQGBox.checkBoxes: - colname = checkBox.text() - is_existing = any([col == colname for col in existing_colnames]) - if not self.regionPropsQGBox.isChecked() and is_existing: - unchecked_existing_rps.append(colname) - continue - - if not checkBox.isChecked() and is_existing: - unchecked_existing_rps.append(colname) - - if unchecked_existing_colnames or unchecked_existing_rps: - cancel, self.delExistingCols = self.warnUncheckedExistingMeasurements( - unchecked_existing_colnames, unchecked_existing_rps - ) - self.existingUncheckedColnames = unchecked_existing_colnames - self.existingUncheckedRps = unchecked_existing_rps - if cancel: - return - - self.saveLastSelection() - self.cancel = False - self.close() - self.sigClosed.emit() - - def warnUncheckedExistingMeasurements( - self, unchecked_existing_colnames, unchecked_existing_rps - ): - msg = widgets.myMessageBox() - msg.setWidth(500) - msg.addShowInFileManagerButton(self.acdc_df_path) - txt = html_utils.paragraph( - 'You chose to not save some measurements that are ' - 'already present in the saved acdc_output.csv ' - 'file.

' - 'Do you want to delete these measurements or ' - 'keep them?

' - 'Existing measurements not selected:' - ) - listView = widgets.readOnlyQList(msg) - items = unchecked_existing_colnames.copy() - items.extend(unchecked_existing_rps) - listView.addItems(items) - _, delButton, keepButton = msg.warning( - self, 'Unchecked existing measurements', txt, - widgets=listView, buttonsTexts=('Cancel', 'Delete', 'Keep') - ) - return msg.cancel, msg.clickedButton == delButton - - def show(self, block=False): - super().show(block=False) - self.deselectAllButton.setMinimumHeight(self.okButton.height()) - screenWidth = self.screen().size().width() - screenHeight = self.screen().size().height() - screenLeft = self.screen().geometry().x() - screenTop = self.screen().geometry().y() - h = screenHeight-200 - minColWith = screenWidth/5 - w = minColWith*(self.last_col+1) - xLeft = int((screenWidth-w)/2) - if w > screenWidth: - self.move(screenLeft+10, screenTop+50) - self.resize(screenWidth-20, h) - else: - self.move(screenLeft+xLeft, screenTop+50) - self.resize(int(w), h) - super().show(block=block) - -class QDialogMetadataXML(QDialog): - def __init__( - self, title='Metadata', - LensNA=1.0, rawFilename='test', SizeT=1, SizeZ=1, SizeC=1, SizeS=1, - TimeIncrement=1.0, TimeIncrementUnit='s', - PhysicalSizeX=1.0, PhysicalSizeY=1.0, PhysicalSizeZ=1.0, - PhysicalSizeUnit='μm', ImageName='', chNames=None, emWavelens=None, - parent=None, rawDataStruct=None, sampleImgData=None, - rawFilePath=None - ): - self.cancel = True - self.trust = False - self.overWrite = False - rawFilename = os.path.splitext(rawFilename)[0] - self.rawFilename = self.removeInvalidCharacters(rawFilename) - self.rawFilePath = rawFilePath - self.sampleImgData = sampleImgData - self.ImageName = ImageName - self.rawDataStruct = rawDataStruct - self.readSampleImgDataAgain = False - self.requestedReadingSampleImageDataAgain = False - self.imageViewer = None - super().__init__(parent) - self.setWindowTitle(title) - font = QFont() - font.setPixelSize(12) - self.setFont(font) - - mainLayout = QVBoxLayout() - entriesLayout = QGridLayout() - self.channelNameLayouts = ( - QVBoxLayout(), QVBoxLayout(), QVBoxLayout(), QVBoxLayout() - ) - self.channelEmWLayouts = ( - QVBoxLayout(), QVBoxLayout(), QVBoxLayout(), QVBoxLayout() - ) - buttonsLayout = QGridLayout() - - infoLabel = QLabel() - infoTxt = ( - 'Confirm/Edit the metadata below.' - ) - infoLabel.setText(infoTxt) - # padding: top, left, bottom, right - infoLabel.setStyleSheet("font-size:12pt; padding:0px 0px 5px 0px;") - mainLayout.addWidget(infoLabel, alignment=Qt.AlignCenter) - - noteLabel = QLabel() - noteLabel.setText( - f'NOTE: If you are not sure about some of the entries ' - 'you can try to click "Ok".\n' - 'If they are wrong you will get ' - 'an error message later when trying to read the data.' - ) - noteLabel.setAlignment(Qt.AlignCenter) - mainLayout.addWidget(noteLabel, alignment=Qt.AlignCenter) - - row = 0 - to_tif_radiobutton = QRadioButton(".tif") - to_tif_radiobutton.setChecked(True) - to_h5_radiobutton = QRadioButton(".h5") - to_h5_radiobutton.setToolTip( - '.h5 is highly recommended for big datasets to avoid memory issues.\n' - 'As a rule of thumb, if the single position, single channel file\n' - 'is larger than 1/5 of the available RAM we recommend using .h5 format' - ) - self.to_h5_radiobutton = to_h5_radiobutton - txt = 'File format: ' - label = QLabel(txt) - fileFormatLayout = QHBoxLayout() - fileFormatLayout.addStretch(1) - fileFormatLayout.addWidget(to_tif_radiobutton) - fileFormatLayout.addStretch(1) - fileFormatLayout.addWidget(to_h5_radiobutton) - fileFormatLayout.addStretch(1) - entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) - entriesLayout.addLayout(fileFormatLayout, row, 1) - to_h5_radiobutton.toggled.connect(self.updateFileFormat) - - row += 1 - self.SizeS_SB = QSpinBox() - self.SizeS_SB.setAlignment(Qt.AlignCenter) - self.SizeS_SB.setMinimum(1) - self.SizeS_SB.setMaximum(2147483647) - self.SizeS_SB.setValue(SizeS) - txt = 'Number of positions (SizeS): ' - label = QLabel(txt) - entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) - entriesLayout.addWidget(self.SizeS_SB, row, 1) - - if rawDataStruct == 0: - row += 1 - self.SizeS_SB.setValue(1) - self.SizeS_SB.setDisabled(True) - self.posSelector = widgets.ExpandableListBox() - positions = ['All positions'] - positions.extend([f'Position_{i+1}' for i in range(SizeS)]) - self.posSelector.addItems(positions) - txt = 'Positions to save: ' - label = QLabel(txt) - entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) - entriesLayout.addWidget(self.posSelector, row, 1) - self.SizeS_SB.valueChanged.connect(self.SizeSvalueChanged) - - row += 1 - self.LensNA_DSB = QDoubleSpinBox() - self.LensNA_DSB.setAlignment(Qt.AlignCenter) - self.LensNA_DSB.setSingleStep(0.1) - self.LensNA_DSB.setValue(LensNA) - txt = 'Numerical Aperture Objective Lens: ' - label = QLabel(txt) - entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) - entriesLayout.addWidget(self.LensNA_DSB, row, 1) - - row += 1 - self.SizeT_SB = QSpinBox() - self.SizeT_SB.setAlignment(Qt.AlignCenter) - self.SizeT_SB.setMinimum(1) - self.SizeT_SB.setMaximum(2147483647) - self.SizeT_SB.setValue(SizeT) - txt = 'Number of frames (SizeT): ' - label = QLabel(txt) - entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) - entriesLayout.addWidget(self.SizeT_SB, row, 1) - self.SizeT_SB.valueChanged.connect(self.hideShowTimeIncrement) - - row += 1 - self.timeRangeToSaveWidget = widgets.RangeSelector(integers=True) - self.timeRangeToSaveWidget.setRange(1, SizeT) - txt = 'Time range to save: ' - label = QLabel(txt) - self.timeRangeToSaveWidget.label = label - entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) - entriesLayout.addWidget(self.timeRangeToSaveWidget, row, 1) - - row += 1 - self.SizeZ_SB = QSpinBox() - self.SizeZ_SB.setAlignment(Qt.AlignCenter) - self.SizeZ_SB.setMinimum(1) - self.SizeZ_SB.setMaximum(2147483647) - self.SizeZ_SB.setValue(SizeZ) - txt = 'Number of z-slices in the z-stack (SizeZ): ' - label = QLabel(txt) - entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) - entriesLayout.addWidget(self.SizeZ_SB, row, 1) - self.SizeZ_SB.valueChanged.connect(self.hideShowPhysicalSizeZ) - - row += 1 - self.TimeIncrement_DSB = widgets.FloatLineEdit( - allowNegative=False, warningValues={1.0} - ) - self.TimeIncrement_DSB.setValue(TimeIncrement) - self.TimeIncrement_DSB.setMinimum(0.0) - txt = 'Frame interval: ' - label = QLabel(txt) - self.TimeIncrement_Label = label - entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) - entriesLayout.addWidget(self.TimeIncrement_DSB, row, 1) - - self.TimeIncrementUnit_CB = QComboBox() - unitItems = [ - 'ms', 'seconds', 'minutes', 'hours' - ] - currentTxt = [unit for unit in unitItems - if unit.startswith(TimeIncrementUnit)] - self.TimeIncrementUnit_CB.addItems(unitItems) - if currentTxt: - self.TimeIncrementUnit_CB.setCurrentText(currentTxt[0]) - entriesLayout.addWidget( - self.TimeIncrementUnit_CB, row, 2, alignment=Qt.AlignLeft - ) - - row += 1 - self.PhysicalSizeX_DSB = QDoubleSpinBox() - self.PhysicalSizeX_DSB.setAlignment(Qt.AlignCenter) - self.PhysicalSizeX_DSB.setMaximum(2147483647.0) - self.PhysicalSizeX_DSB.setSingleStep(0.001) - self.PhysicalSizeX_DSB.setDecimals(7) - self.PhysicalSizeX_DSB.setValue(PhysicalSizeX) - txt = 'Pixel width (X): ' - label = QLabel(txt) - entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) - entriesLayout.addWidget(self.PhysicalSizeX_DSB, row, 1) - - self.PhysicalSizeUnit_CB = QComboBox() - unitItems = [ - 'nm', 'μm', 'mm', 'cm' - ] - currentTxt = [unit for unit in unitItems - if unit.startswith(PhysicalSizeUnit)] - self.PhysicalSizeUnit_CB.addItems(unitItems) - if currentTxt: - self.PhysicalSizeUnit_CB.setCurrentText(currentTxt[0]) - else: - self.PhysicalSizeUnit_CB.setCurrentText(unitItems[1]) - entriesLayout.addWidget( - self.PhysicalSizeUnit_CB, row, 2, alignment=Qt.AlignLeft - ) - self.PhysicalSizeUnit_CB.currentTextChanged.connect(self.updatePSUnit) - - row += 1 - self.PhysicalSizeY_DSB = QDoubleSpinBox() - self.PhysicalSizeY_DSB.setAlignment(Qt.AlignCenter) - self.PhysicalSizeY_DSB.setMaximum(2147483647.0) - self.PhysicalSizeY_DSB.setSingleStep(0.001) - self.PhysicalSizeY_DSB.setDecimals(7) - self.PhysicalSizeY_DSB.setValue(PhysicalSizeY) - txt = 'Pixel height (Y): ' - label = QLabel(txt) - entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) - entriesLayout.addWidget(self.PhysicalSizeY_DSB, row, 1) - - self.PhysicalSizeYUnit_Label = QLabel() - self.PhysicalSizeYUnit_Label.setStyleSheet( - 'font-size:13px; padding:5px 0px 2px 0px;' - ) - unit = self.PhysicalSizeUnit_CB.currentText() - self.PhysicalSizeYUnit_Label.setText(unit) - entriesLayout.addWidget(self.PhysicalSizeYUnit_Label, row, 2) - - row += 1 - self.PhysicalSizeZ_DSB = QDoubleSpinBox() - self.PhysicalSizeZ_DSB.setAlignment(Qt.AlignCenter) - self.PhysicalSizeZ_DSB.setMaximum(2147483647.0) - self.PhysicalSizeZ_DSB.setSingleStep(0.001) - self.PhysicalSizeZ_DSB.setDecimals(7) - self.PhysicalSizeZ_DSB.setValue(PhysicalSizeZ) - txt = 'Voxel depth (Z): ' - self.PSZlabel = QLabel(txt) - entriesLayout.addWidget(self.PSZlabel, row, 0, alignment=Qt.AlignRight) - entriesLayout.addWidget(self.PhysicalSizeZ_DSB, row, 1) - - self.PhysicalSizeZUnit_Label = QLabel() - # padding: top, left, bottom, right - self.PhysicalSizeZUnit_Label.setStyleSheet( - 'font-size:13px; padding:5px 0px 2px 0px;' - ) - unit = self.PhysicalSizeUnit_CB.currentText() - self.PhysicalSizeZUnit_Label.setText(unit) - entriesLayout.addWidget(self.PhysicalSizeZUnit_Label, row, 2) - - if SizeZ == 1: - self.PSZlabel.hide() - self.PhysicalSizeZ_DSB.hide() - self.PhysicalSizeZUnit_Label.hide() - - row += 1 - self.SizeC_SB = QSpinBox() - self.SizeC_SB.setAlignment(Qt.AlignCenter) - self.SizeC_SB.setMinimum(1) - self.SizeC_SB.setMaximum(2147483647) - self.SizeC_SB.setValue(SizeC) - txt = 'Number of channels (SizeC): ' - label = QLabel(txt) - entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) - entriesLayout.addWidget(self.SizeC_SB, row, 1) - self.SizeC_SB.valueChanged.connect(self.addRemoveChannels) - - row += 1 - for j, layout in enumerate(self.channelNameLayouts): - entriesLayout.addLayout(layout, row, j) - - self.chNames_QLEs = [] - self.saveChannels_QCBs = [] - self.filename_QLabels = [] - self.showChannelDataButtons = [] - - ext = 'h5' if self.to_h5_radiobutton.isChecked() else 'tif' - for c in range(SizeC): - chName_QLE = QLineEdit() - chName_QLE.setStyleSheet('') - chName_QLE.setAlignment(Qt.AlignCenter) - chName_QLE.textChanged.connect(self.checkChNames) - if chNames is not None: - chName_QLE.setText(chNames[c]) - else: - chName_QLE.setText(f'channel_{c}') - filename = f'' - - txt = f'Channel {c} name: ' - label = QLabel(txt) - - filenameDescLabel = QLabel(f'e.g., filename for channel {c}: ') - - chName = chName_QLE.text() - chName = self.removeInvalidCharacters(chName) - rawFilename = self.elidedRawFilename() - filenameLabel = QLabel(f""" -

{rawFilename}_{chName}.{ext}

- """) - filenameLabel.setToolTip(f'{self.rawFilename}_{chName}.{ext}') - - checkBox = QCheckBox('Save this channel') - checkBox.setChecked(True) - checkBox.stateChanged.connect(self.saveCh_checkBox_cb) - - self.channelNameLayouts[0].addWidget(label, alignment=Qt.AlignRight) - self.channelNameLayouts[0].addWidget( - filenameDescLabel, alignment=Qt.AlignRight - ) - self.channelNameLayouts[1].addWidget(chName_QLE) - self.channelNameLayouts[1].addWidget( - filenameLabel, alignment=Qt.AlignCenter - ) - - self.channelNameLayouts[2].addWidget(checkBox) - if c == 0 and ImageName: - addImageName_QCB = QCheckBox('Include image name') - addImageName_QCB.stateChanged.connect(self.addImageName_cb) - self.addImageName_QCB = addImageName_QCB - self.channelNameLayouts[2].addWidget(addImageName_QCB) - else: - self.addImageName_QCB = QCheckBox('dummy') - self.addImageName_QCB.hide() - self.channelNameLayouts[2].addWidget(QLabel()) - - showChannelDataButton = QPushButton() - showChannelDataButton.setIcon(QIcon(":eye-plus.svg")) - showChannelDataButton.clicked.connect(self.showChannelData) - self.channelNameLayouts[3].addWidget(showChannelDataButton) - if self.sampleImgData is None: - showChannelDataButton.setDisabled(True) - - self.chNames_QLEs.append(chName_QLE) - self.saveChannels_QCBs.append(checkBox) - self.filename_QLabels.append(filenameLabel) - self.showChannelDataButtons.append(showChannelDataButton) - - self.checkChNames() - - row += 1 - for j, layout in enumerate(self.channelEmWLayouts): - entriesLayout.addLayout(layout, row, j) - - self.emWavelens_DSBs = [] - for c in range(SizeC): - row += 1 - emWavelen_DSB = QDoubleSpinBox() - emWavelen_DSB.setAlignment(Qt.AlignCenter) - emWavelen_DSB.setMaximum(2147483647.0) - emWavelen_DSB.setSingleStep(0.001) - emWavelen_DSB.setDecimals(2) - if emWavelens is not None: - emWavelen_DSB.setValue(emWavelens[c]) - else: - emWavelen_DSB.setValue(500.0) - - txt = f'Channel {c} emission wavelength: ' - label = QLabel(txt) - self.channelEmWLayouts[0].addWidget(label, alignment=Qt.AlignRight) - self.channelEmWLayouts[1].addWidget(emWavelen_DSB) - self.emWavelens_DSBs.append(emWavelen_DSB) - - unit = QLabel('nm') - unit.setStyleSheet('font-size:13px; padding:5px 0px 2px 0px;') - self.channelEmWLayouts[2].addWidget(unit) - - entriesLayout.setContentsMargins(0, 15, 0, 0) - - if rawDataStruct is None or rawDataStruct!=-1: - okButton = widgets.okPushButton(' Ok ') - elif rawDataStruct==1: - okButton = QPushButton(' Load next position ') - buttonsLayout.addWidget(okButton, 0, 1) - - self.trustButton = None - self.overWriteButton = None - if rawDataStruct==1: - trustButton = QPushButton( - ' Trust metadata reader\n for all next positions ') - trustButton.setToolTip( - "If you didn't have to manually modify metadata entries\n" - "it is very likely that metadata from the metadata reader\n" - "will be correct also for all the next positions.\n\n" - "Click this button to stop showing this dialog and use\n" - "the metadata from the reader\n" - "(except for channel names, I will use the manually entered)" - ) - buttonsLayout.addWidget(trustButton, 1, 1) - self.trustButton = trustButton - - overWriteButton = QPushButton( - ' Use the above metadata\n for all the next positions ') - overWriteButton.setToolTip( - "If you had to manually modify metadata entries\n" - "AND you know they will be the same for all next positions\n" - "you can click this button to stop showing this dialog\n" - "and use the same metadata for all the next positions." - ) - buttonsLayout.addWidget(overWriteButton, 1, 2) - self.overWriteButton = overWriteButton - - trustButton.clicked.connect(self.ok_cb) - overWriteButton.clicked.connect(self.ok_cb) - - cancelButton = widgets.cancelPushButton('Cancel') - buttonsLayout.addWidget(cancelButton, 0, 2) - buttonsLayout.setColumnStretch(0, 1) - buttonsLayout.setColumnStretch(3, 1) - buttonsLayout.setContentsMargins(0, 10, 0, 0) - - mainLayout.addLayout(entriesLayout) - mainLayout.addLayout(buttonsLayout) - mainLayout.addStretch(1) - - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.cancel_cb) - - self.hideShowTimeIncrement(SizeT) - self.readSampleImgDataAgain = False - - self.setLayout(mainLayout) - # self.setModal(True) - - def saveCh_checkBox_cb(self, state): - self.checkChNames() - idx = self.saveChannels_QCBs.index(self.sender()) - LE = self.chNames_QLEs[idx] - idx *= 2 - LE.setDisabled(state==0) - label = self.channelNameLayouts[0].itemAt(idx).widget() - if state == 0: - label.setStyleSheet('color: gray; font-size: 10pt') - else: - label.setStyleSheet('color: black; font-size: 10pt') - - label = self.channelNameLayouts[0].itemAt(idx+1).widget() - if state == 0: - label.setStyleSheet('color: gray; font-size: 10pt') - else: - label.setStyleSheet('color: black; font-size: 10pt') - - label = self.channelNameLayouts[1].itemAt(idx+1).widget() - if state == 0: - label.setStyleSheet('color: gray; font-size: 10pt') - else: - label.setStyleSheet('color: black; font-size: 10pt') - - def addImageName_cb(self, state): - for idx in range(self.SizeC_SB.value()): - self.updateFilename(idx) - - def setInvalidChName_StyleSheet(self, LE): - LE.setStyleSheet( - 'border-radius: 4px;' - 'border: 1.5px solid red;' - 'padding: 1px 0px 1px 0px' - ) - - def removeInvalidCharacters(self, chName): - # Remove invalid charachters - chName = "".join( - c if c.isalnum() or c=='_' or c=='' else '_' for c in chName - ) - trim_ = chName.endswith('_') - while trim_: - chName = chName[:-1] - trim_ = chName.endswith('_') - return chName - - def updateFileFormat(self, is_h5): - for idx in range(len(self.chNames_QLEs)): - self.updateFilename(idx) - - def SizeSvalueChanged(self, SizeS): - positions = ['All positions'] - positions.extend([f'Position_{i+1}' for i in range(SizeS)]) - self.posSelector.setItems(positions) - - def elidedRawFilename(self): - n = 31 - idx = int((n-3)/2) - if len(self.rawFilename) > 21: - elidedText = f'{self.rawFilename[:idx]}...{self.rawFilename[-idx:]}' - else: - elidedText = self.rawFilename - return elidedText - - def updateFilename(self, idx): - chName = self.chNames_QLEs[idx].text() - chName = self.removeInvalidCharacters(chName) - if self.rawDataStruct == 2: - rawFilename = f'{self.rawFilename}_s{idx+1}' - else: - rawFilename = self.rawFilename - - ext = 'h5' if self.to_h5_radiobutton.isChecked() else 'tif' - - rawFilename = self.elidedRawFilename() - - filenameLabel = self.filename_QLabels[idx] - if self.addImageName_QCB.isChecked(): - self.ImageName = self.removeInvalidCharacters(self.ImageName) - filename = (f""" -

- {rawFilename}_{self.ImageName}_{chName}.{ext} -

- """) - fullFilename = f'{self.rawFilename}_{self.ImageName}_{chName}.{ext}' - else: - filename = (f""" -

- {rawFilename}_{chName}.{ext} -

- """) - fullFilename = f'{self.rawFilename}_{chName}.{ext}' - filenameLabel.setToolTip(fullFilename) - filenameLabel.setText(filename) - - def checkChNames(self, text=''): - if self.sender() in self.chNames_QLEs: - idx = self.chNames_QLEs.index(self.sender()) - self.updateFilename(idx) - elif self.sender() in self.saveChannels_QCBs: - idx = self.saveChannels_QCBs.index(self.sender()) - self.updateFilename(idx) - - - areChNamesValid = True - if len(self.chNames_QLEs) == 1: - LE1 = self.chNames_QLEs[0] - saveCh = self.saveChannels_QCBs[0].isChecked() - if not saveCh: - LE1.setStyleSheet('') - return areChNamesValid - - s1 = LE1.text() - if not s1: - self.setInvalidChName_StyleSheet(LE1) - areChNamesValid = False - else: - LE1.setStyleSheet('') - return areChNamesValid - - for LE1, LE2 in combinations(self.chNames_QLEs, 2): - s1 = LE1.text() - s2 = LE2.text() - LE1_idx = self.chNames_QLEs.index(LE1) - LE2_idx = self.chNames_QLEs.index(LE2) - saveCh1 = self.saveChannels_QCBs[LE1_idx].isChecked() - saveCh2 = self.saveChannels_QCBs[LE2_idx].isChecked() - if not s1 or not s2 or s1==s2: - if not s1 and saveCh1: - self.setInvalidChName_StyleSheet(LE1) - areChNamesValid = False - else: - LE1.setStyleSheet('') - if not s2 and saveCh2: - self.setInvalidChName_StyleSheet(LE2) - areChNamesValid = False - else: - LE2.setStyleSheet('') - if s1 == s2 and saveCh1 and saveCh2: - self.setInvalidChName_StyleSheet(LE1) - self.setInvalidChName_StyleSheet(LE2) - areChNamesValid = False - else: - LE1.setStyleSheet('') - LE2.setStyleSheet('') - return areChNamesValid - - def hideShowTimeIncrement(self, value): - if self.TimeIncrement_DSB.isVisible() and value == 1: - self.readSampleImgDataAgain = True - - if not self.TimeIncrement_DSB.isVisible() and value > 1: - self.readSampleImgDataAgain = True - - if value > 1: - self.TimeIncrement_DSB.show() - self.TimeIncrementUnit_CB.show() - self.TimeIncrement_Label.show() - self.timeRangeToSaveWidget.show() - self.timeRangeToSaveWidget.label.show() - self.timeRangeToSaveWidget.setRange(1, value) - else: - self.TimeIncrement_DSB.hide() - self.TimeIncrementUnit_CB.hide() - self.TimeIncrement_Label.hide() - self.timeRangeToSaveWidget.hide() - self.timeRangeToSaveWidget.label.hide() - - def hideShowPhysicalSizeZ(self, value): - if value > 1: - self.PSZlabel.show() - self.PhysicalSizeZ_DSB.show() - self.PhysicalSizeZUnit_Label.show() - else: - self.PSZlabel.hide() - self.PhysicalSizeZ_DSB.hide() - self.PhysicalSizeZUnit_Label.hide() - self.readSampleImgDataAgain = True - - def updatePSUnit(self, unit): - self.PhysicalSizeYUnit_Label.setText(unit) - self.PhysicalSizeZUnit_Label.setText(unit) - - def warnRestart(self): - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - txt = html_utils.paragraph(""" - Since you manually changed some of the metadata, this dialogue will now restart
- because it needs to read the image data again.

- Thank you for your patience. - """) - msg.warning(self, 'Restart required', txt) - - def showChannelData(self, checked=False, idx=None): - if self.readSampleImgDataAgain: - # User changed SizeZ, SizeT, or SizeC --> we need to read sample - # image again - del self.sampleImgData - self.requestedReadingSampleImageDataAgain = True - self.sampleImgData = None - self.warnRestart() - self.getValues() - self.cancel = False - self.close() - return - - if idx is None: - idx = self.showChannelDataButtons.index(self.sender()) - dimsOrder = 'ctz' - imgData = self.sampleImgData[dimsOrder][idx] - posData = myutils.utilClass() - posData.frame_i = 0 - sampleSizeT = 4 if self.SizeT_SB.value() >= 4 else self.SizeT_SB.value() - posData.SizeT = sampleSizeT - SizeZ = self.SizeZ_SB.value() - posData.SizeZ = 20 if SizeZ>20 else SizeZ - posData.filename = f'{self.rawFilename}_C={idx}' - posData.segmInfo_df = pd.DataFrame({ - 'filename': [posData.filename]*sampleSizeT, - 'frame_i': range(sampleSizeT), - 'which_z_proj_gui': ['single z-slice']*sampleSizeT, - 'z_slice_used_gui': [int(posData.SizeZ/2)]*sampleSizeT - }).set_index(['filename', 'frame_i']) - path_li = os.path.normpath(self.rawFilePath).split(os.sep) - posData.relPath = f'{f"{os.sep}".join(path_li[-3:1])}' - posData.relPath = f'{posData.relPath}{os.sep}{posData.filename}' - if sampleSizeT == 1: - posData.img_data = [imgData] # single frame data - else: - posData.img_data = imgData - - if self.imageViewer is not None: - self.imageViewer.close() - - self.imageViewer = imageViewer( - posData=posData, isSigleFrame=False, enableOverlay=False - ) - self.imageViewer.channelIndex = idx - self.imageViewer.update_img() - self.imageViewer.sigClosed.connect(self.imageViewerClosed) - self.imageViewer.show() - - def imageViewerClosed(self): - self.imageViewer = None - - def addRemoveChannels(self, value): - self.readSampleImgDataAgain = True - currentSizeC = len(self.chNames_QLEs) - DeltaChannels = abs(value-currentSizeC) - ext = 'h5' if self.to_h5_radiobutton.isChecked() else 'tif' - if value > currentSizeC: - for c in range(currentSizeC, currentSizeC+DeltaChannels): - chName_QLE = QLineEdit() - chName_QLE.setStyleSheet('') - chName_QLE.setAlignment(Qt.AlignCenter) - chName_QLE.setText(f'channel_{c}') - chName_QLE.textChanged.connect(self.checkChNames) - - txt = f'Channel {c} name: ' - label = QLabel(txt) - - filenameDescLabel = QLabel( - f'e.g., filename for channel {c}: ' - ) - - chName = chName_QLE.text() - rawFilename = self.elidedRawFilename() - filenameLabel = QLabel(f""" -

{rawFilename}_{chName}.{ext}

- """) - filenameLabel.setToolTip(f'{self.rawFilename}_{chName}.{ext}') - - checkBox = QCheckBox('Save this channel') - checkBox.setChecked(True) - checkBox.stateChanged.connect(self.saveCh_checkBox_cb) - - self.channelNameLayouts[0].addWidget(label, alignment=Qt.AlignRight) - self.channelNameLayouts[0].addWidget( - filenameDescLabel, alignment=Qt.AlignRight - ) - self.channelNameLayouts[1].addWidget(chName_QLE) - self.channelNameLayouts[1].addWidget( - filenameLabel, alignment=Qt.AlignCenter - ) - - self.channelNameLayouts[2].addWidget(checkBox) - self.channelNameLayouts[2].addWidget(QLabel()) - - showChannelDataButton = QPushButton() - showChannelDataButton.setIcon(QIcon(":eye-plus.svg")) - showChannelDataButton.clicked.connect(self.showChannelData) - self.channelNameLayouts[3].addWidget(showChannelDataButton) - if self.sampleImgData is None: - showChannelDataButton.setDisabled(True) - - self.chNames_QLEs.append(chName_QLE) - self.saveChannels_QCBs.append(checkBox) - self.filename_QLabels.append(filenameLabel) - self.showChannelDataButtons.append(showChannelDataButton) - - emWavelen_DSB = QDoubleSpinBox() - emWavelen_DSB.setAlignment(Qt.AlignCenter) - emWavelen_DSB.setMaximum(2147483647.0) - emWavelen_DSB.setSingleStep(0.001) - emWavelen_DSB.setDecimals(2) - emWavelen_DSB.setValue(500.0) - unit = QLabel('nm') - unit.setStyleSheet('font-size:13px; padding:5px 0px 2px 0px;') - - txt = f'Channel {c} emission wavelength: ' - label = QLabel(txt) - self.channelEmWLayouts[0].addWidget(label, alignment=Qt.AlignRight) - self.channelEmWLayouts[1].addWidget(emWavelen_DSB) - self.channelEmWLayouts[2].addWidget(unit) - self.emWavelens_DSBs.append(emWavelen_DSB) - else: - for c in range(currentSizeC, currentSizeC+DeltaChannels): - idx = (c-1)*2 - label1 = self.channelNameLayouts[0].itemAt(idx).widget() - label2 = self.channelNameLayouts[0].itemAt(idx+1).widget() - chName_QLE = self.channelNameLayouts[1].itemAt(idx).widget() - filename_L = self.channelNameLayouts[1].itemAt(idx+1).widget() - checkBox = self.channelNameLayouts[2].itemAt(idx).widget() - dummyLabel = self.channelNameLayouts[2].itemAt(idx+1).widget() - showButton = self.showChannelDataButtons[-1] - showButton.clicked.disconnect() - - self.channelNameLayouts[0].removeWidget(label1) - self.channelNameLayouts[0].removeWidget(label2) - self.channelNameLayouts[1].removeWidget(chName_QLE) - self.channelNameLayouts[1].removeWidget(filename_L) - self.channelNameLayouts[2].removeWidget(checkBox) - self.channelNameLayouts[2].removeWidget(dummyLabel) - self.channelNameLayouts[3].removeWidget(showButton) - - self.chNames_QLEs.pop(-1) - self.saveChannels_QCBs.pop(-1) - self.filename_QLabels.pop(-1) - self.showChannelDataButtons.pop(-1) - - label = self.channelEmWLayouts[0].itemAt(c-1).widget() - emWavelen_DSB = self.channelEmWLayouts[1].itemAt(c-1).widget() - unit = self.channelEmWLayouts[2].itemAt(c-1).widget() - self.channelEmWLayouts[0].removeWidget(label) - self.channelEmWLayouts[1].removeWidget(emWavelen_DSB) - self.channelEmWLayouts[2].removeWidget(unit) - self.emWavelens_DSBs.pop(-1) - - self.adjustSize() - - def ok_cb(self, event): - areChNamesValid = self.checkChNames() - if not areChNamesValid: - err_msg = html_utils.paragraph( - 'Channel names cannot be empty or equal to each other.' - '

' - 'Insert a unique text for each channel name.' - ) - msg = widgets.myMessageBox() - msg.critical( - self, 'Invalid channel names', err_msg - ) - return - - self.getValues() - self.convertUnits() - - if self.sender() == self.trustButton: - self.trust = True - elif self.sender() == self.overWriteButton: - self.overWrite = True - - self.cancel = False - self.close() - - def getValues(self): - self.LensNA = self.LensNA_DSB.value() - self.SizeT = self.SizeT_SB.value() - self.SizeZ = self.SizeZ_SB.value() - self.SizeC = self.SizeC_SB.value() - self.SizeS = self.SizeS_SB.value() - self.timeRangeToSave = self.timeRangeToSaveWidget.range() - self.TimeIncrement = self.TimeIncrement_DSB.value() - self.PhysicalSizeX = self.PhysicalSizeX_DSB.value() - self.PhysicalSizeY = self.PhysicalSizeY_DSB.value() - self.PhysicalSizeZ = self.PhysicalSizeZ_DSB.value() - self.to_h5 = self.to_h5_radiobutton.isChecked() - if hasattr(self, 'posSelector'): - self.selectedPos = self.posSelector.selectedItemsText() - else: - self.selectedPos = ['All Positions'] - self.chNames = [] - if hasattr(self, 'addImageName_QCB'): - self.addImageName = self.addImageName_QCB.isChecked() - else: - self.addImageName = False - self.saveChannels = [] - for LE, QCB in zip(self.chNames_QLEs, self.saveChannels_QCBs): - s = LE.text() - s = "".join(c if c.isalnum() or c=='_' or c=='' else '_' for c in s) - trim_ = s.endswith('_') - while trim_: - s = s[:-1] - trim_ = s.endswith('_') - self.chNames.append(s) - self.saveChannels.append(QCB.isChecked()) - self.emWavelens = [DSB.value() for DSB in self.emWavelens_DSBs] - - def convertUnits(self): - timeUnit = self.TimeIncrementUnit_CB.currentText() - if timeUnit == 'ms': - self.TimeIncrement /= 1000 - elif timeUnit == 'minutes': - self.TimeIncrement *= 60 - elif timeUnit == 'hours': - self.TimeIncrement *= 3600 - - PhysicalSizeUnit = self.PhysicalSizeUnit_CB.currentText() - if timeUnit == 'nm': - self.PhysicalSizeX /= 1000 - self.PhysicalSizeY /= 1000 - self.PhysicalSizeZ /= 1000 - elif timeUnit == 'mm': - self.PhysicalSizeX *= 1000 - self.PhysicalSizeY *= 1000 - self.PhysicalSizeZ *= 1000 - elif timeUnit == 'cm': - self.PhysicalSizeX *= 1e4 - self.PhysicalSizeY *= 1e4 - self.PhysicalSizeZ *= 1e4 - - def cancel_cb(self, event): - self.cancel = True - self.close() - - def exec_(self): - self.show(block=True) - - def setSize(self): - h = self.SizeS_SB.height() - self.TimeIncrement_DSB.setMinimumHeight(h) - - def show(self, block=False): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show() - self.setSize() - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - -class CellACDCTrackerParamsWin(QDialog): - def __init__(self, parent=None): - self.cancel = True - super().__init__(parent) - - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - self.setWindowTitle('Cell-ACDC tracker parameters') - - paramsLayout = QGridLayout() - paramsBox = QGroupBox() - - row = 0 - label = QLabel(html_utils.paragraph( - 'Minimum overlap between objects' - )) - paramsLayout.addWidget(label, row, 0) - maxOverlapSpinbox = QDoubleSpinBox() - maxOverlapSpinbox.setAlignment(Qt.AlignCenter) - maxOverlapSpinbox.setMinimum(0) - maxOverlapSpinbox.setMaximum(1) - maxOverlapSpinbox.setSingleStep(0.1) - maxOverlapSpinbox.setValue(0.4) - self.maxOverlapSpinbox = maxOverlapSpinbox - paramsLayout.addWidget(maxOverlapSpinbox, row, 1) - infoButton = widgets.infoPushButton() - infoButton.clicked.connect(self.showInfo) - paramsLayout.addWidget(infoButton, row, 2) - paramsLayout.setColumnStretch(0, 0) - paramsLayout.setColumnStretch(1, 1) - paramsLayout.setColumnStretch(2, 0) - - cancelButton = widgets.cancelPushButton('Cancel') - okButton = widgets.okPushButton(' Ok ') - cancelButton.clicked.connect(self.cancel_cb) - okButton.clicked.connect(self.ok_cb) - - buttonsLayout = QHBoxLayout() - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(okButton) - - layout = QVBoxLayout() - infoText = html_utils.paragraph('Cell-ACDC tracker parameters') - infoLabel = QLabel(infoText) - layout.addWidget(infoLabel, alignment=Qt.AlignCenter) - layout.addSpacing(10) - paramsBox.setLayout(paramsLayout) - layout.addWidget(paramsBox) - layout.addSpacing(20) - layout.addLayout(buttonsLayout) - layout.addStretch(1) - self.setLayout(layout) - self.setFont(font) - - def showInfo(self): - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph( - 'Cell-ACDC tracker computes the percentage of overlap between ' - 'all the objects
at frame n and all the ' - 'objects in previous frame n-1.

' - 'All objects with overlap less than ' - 'Minimum overlap between objects
are considered ' - 'new objects.

' - 'Set this value to 0 if you want to force tracking of ALL the ' - 'objects
in the previous frame (e.g., if cells move a lot ' - 'between frames)' - ) - msg.information(self, 'Cell-ACDC tracker info', txt) - - def ok_cb(self, checked=False): - self.cancel = False - self.params = {'IoA_thresh': self.maxOverlapSpinbox.value()} - self.close() - - def cancel_cb(self, event): - self.cancel = True - self.close() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - super().show() - self.resize(int(self.width()*1.3), self.height()) - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - -class BayesianTrackerParamsWin(QDialog): - def __init__( - self, segmShape, parent=None, channels=None, - currentChannelName=None - ): - self.cancel = True - super().__init__(parent) - - self.channels = channels - self.currentChannelName = currentChannelName - - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - self.setWindowTitle('Bayesian tracker parameters') - - paramsLayout = QGridLayout() - paramsBox = QGroupBox() - - row = 0 - this_path = os.path.dirname(os.path.abspath(__file__)) - default_model_path = os.path.join( - this_path, 'trackers', 'BayesianTracker', - 'model', 'cell_config.json' - ) - label = QLabel(html_utils.paragraph('Model path')) - paramsLayout.addWidget(label, row, 0) - modelPathLineEdit = QLineEdit() - start_dir = '' - if os.path.exists(default_model_path): - start_dir = os.path.dirname(default_model_path) - modelPathLineEdit.setText(default_model_path) - self.modelPathLineEdit = modelPathLineEdit - paramsLayout.addWidget(modelPathLineEdit, row, 1) - browseButton = widgets.browseFileButton( - title='Select Bayesian Tracker model file', - ext={'JSON Config': ('.json',)}, - start_dir=start_dir - ) - browseButton.sigPathSelected.connect(self.onPathSelected) - paramsLayout.addWidget(browseButton, row, 2, alignment=Qt.AlignLeft) - - if self.channels is not None: - row += 1 - label = QLabel(html_utils.paragraph('Intensity image channel: ')) - paramsLayout.addWidget(label, row, 0) - items = ['None', *self.channels] - self.channelCombobox = widgets.QCenteredComboBox() - self.channelCombobox.addItems(items) - paramsLayout.addWidget(self.channelCombobox, row, 1) - if self.currentChannelName is not None: - self.channelCombobox.setCurrentText(self.currentChannelName) - - row += 1 - label = QLabel(html_utils.paragraph('Features')) - paramsLayout.addWidget(label, row, 0) - selectFeaturesButton = widgets.setPushButton('Select features') - paramsLayout.addWidget(selectFeaturesButton, row, 1) - self.features = [] - selectFeaturesButton.clicked.connect(self.selectFeatures) - - row += 1 - label = QLabel(html_utils.paragraph('Verbose')) - paramsLayout.addWidget(label, row, 0) - verboseToggle = widgets.Toggle() - verboseToggle.setChecked(True) - self.verboseToggle = verboseToggle - paramsLayout.addWidget(verboseToggle, row, 1, alignment=Qt.AlignCenter) - - row += 1 - label = QLabel(html_utils.paragraph('Run optimizer')) - paramsLayout.addWidget(label, row, 0) - optimizeToggle = widgets.Toggle() - optimizeToggle.setChecked(True) - self.optimizeToggle = optimizeToggle - paramsLayout.addWidget(optimizeToggle, row, 1, alignment=Qt.AlignCenter) - - row += 1 - label = QLabel(html_utils.paragraph('Max search radius')) - paramsLayout.addWidget(label, row, 0) - maxSearchRadiusSpinbox = QSpinBox() - maxSearchRadiusSpinbox.setAlignment(Qt.AlignCenter) - maxSearchRadiusSpinbox.setMinimum(1) - maxSearchRadiusSpinbox.setMaximum(2147483647) - maxSearchRadiusSpinbox.setValue(50) - self.maxSearchRadiusSpinbox = maxSearchRadiusSpinbox - self.maxSearchRadiusSpinbox.setDisabled(True) - paramsLayout.addWidget(maxSearchRadiusSpinbox, row, 1) - - row += 1 - Z, Y, X = segmShape - label = QLabel(html_utils.paragraph('Tracking volume')) - paramsLayout.addWidget(label, row, 0) - volumeLineEdit = QLineEdit() - defaultVol = f' (0, {X}), (0, {Y}) ' - if Z > 1: - defaultVol = f'{defaultVol}, (0, {Z}) ' - volumeLineEdit.setText(defaultVol) - volumeLineEdit.setAlignment(Qt.AlignCenter) - self.volumeLineEdit = volumeLineEdit - paramsLayout.addWidget(volumeLineEdit, row, 1) - - row += 1 - label = QLabel(html_utils.paragraph('Interactive mode step size')) - paramsLayout.addWidget(label, row, 0) - stepSizeSpinbox = QSpinBox() - stepSizeSpinbox.setAlignment(Qt.AlignCenter) - stepSizeSpinbox.setMinimum(1) - stepSizeSpinbox.setMaximum(2147483647) - stepSizeSpinbox.setValue(100) - self.stepSizeSpinbox = stepSizeSpinbox - paramsLayout.addWidget(stepSizeSpinbox, row, 1) - - row += 1 - label = QLabel(html_utils.paragraph('Update method')) - paramsLayout.addWidget(label, row, 0) - updateMethodCombobox = QComboBox() - updateMethodCombobox.addItems(['EXACT', 'APPROXIMATE']) - self.updateMethodCombobox = updateMethodCombobox - self.updateMethodCombobox.currentTextChanged.connect(self.methodChanged) - paramsLayout.addWidget(updateMethodCombobox, row, 1) - - cancelButton = widgets.cancelPushButton('Cancel') - okButton = widgets.okPushButton(' Ok ') - cancelButton.clicked.connect(self.cancel_cb) - okButton.clicked.connect(self.ok_cb) - - buttonsLayout = QHBoxLayout() - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(okButton) - - layout = QVBoxLayout() - infoText = html_utils.paragraph('Bayesian Tracker parameters') - infoLabel = QLabel(infoText) - layout.addWidget(infoLabel, alignment=Qt.AlignCenter) - layout.addSpacing(10) - paramsBox.setLayout(paramsLayout) - layout.addWidget(paramsBox) - - url = 'https://btrack.readthedocs.io/en/latest/index.html' - moreInfoText = html_utils.paragraph( - 'Find more info on the Bayesian Tracker\'s ' - f'home page' - ) - moreInfoLabel = QLabel(moreInfoText) - moreInfoLabel.setOpenExternalLinks(True) - layout.addWidget(moreInfoLabel, alignment=Qt.AlignCenter) - - layout.addSpacing(20) - layout.addLayout(buttonsLayout) - layout.addStretch(1) - self.setLayout(layout) - self.setFont(font) - - def selectFeatures(self): - features = measurements.get_btrack_features() - selectWin = widgets.QDialogListbox( - 'Select features', - 'Select features to use for tracking:\n', - features, multiSelection=True, parent=self, - includeSelectionHelp=True - ) - for i in range(selectWin.listBox.count()): - item = selectWin.listBox.item(i) - if item.text() in self.features: - item.setSelected(True) - selectWin.exec_() - if selectWin.cancel: - return - self.features = selectWin.selectedItemsText - - def methodChanged(self, method): - if method == 'APPROXIMATE': - self.maxSearchRadiusSpinbox.setDisabled(False) - else: - self.maxSearchRadiusSpinbox.setDisabled(True) - - def onPathSelected(self, path): - self.modelPathLineEdit.setText(path) - - def ok_cb(self, checked=False): - self.cancel = False - try: - m = re.findall(r'\((\d+), *(\d+)\)', self.volumeLineEdit.text()) - if len(m) < 2: - raise - self.volume = tuple([(int(start), int(end)) for start, end in m]) - if len(self.volume) == 2: - self.volume = (self.volume[0], self.volume[1], (-1e5, 1e5)) - except Exception as e: - self.warnNotAcceptedVolume() - return - - if not os.path.exists(self.modelPathLineEdit.text()): - self.warnNotVaidPath() - return - - self.intensityImageChannel = None - self.verbose = self.verboseToggle.isChecked() - self.max_search_radius = self.maxSearchRadiusSpinbox.value() - self.update_method = self.updateMethodCombobox.currentText() - self.model_path = os.path.normpath(self.modelPathLineEdit.text()) - self.params = { - 'model_path': self.model_path, - 'verbose': self.verbose, - 'volume': self.volume, - 'max_search_radius': self.max_search_radius, - 'update_method': self.update_method, - 'step_size': self.stepSizeSpinbox.value(), - 'optimize': self.optimizeToggle.isChecked(), - 'features': self.features - } - if self.channels is not None: - if self.channelCombobox.currentText() != 'None': - self.intensityImageChannel = self.channelCombobox.currentText() - self.close() - - def warnNotVaidPath(self): - url = 'https://github.com/lowe-lab-ucl/segment-classify-track/tree/main/models' - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph( - 'The model configuration file path

' - f'{self.modelPathLineEdit.text()}

' - 'does not exist.

' - 'You can find some pre-configured models ' - f'here.' - ) - msg.critical( - self, 'Invalid volume', txt - ) - - def warnNotAcceptedVolume(self): - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph( - f'{self.volumeLineEdit.text()} is not a valid volume!

' - 'Valid volume is for example (0, 2048), (0, 2048)
' - 'for 2D segmentation or (0, 2048), (0, 2048), (0, 2048)
' - 'for 3D segmentation.' - ) - msg.critical( - self, 'Invalid volume', txt - ) - - def cancel_cb(self, event): - self.cancel = True - self.close() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - super().show() - self.resize(int(self.width()*1.3), self.height()) - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - -class DeltaTrackerParamsWin(QDialog): - - def __init__(self, posData=None, parent=None): - self.cancel = True - super().__init__(parent) - - self.setWindowFlags(Qt.Dialog | Qt.WindowStaysOnTopHint) - self.setWindowTitle('Delta tracker parameters') - - paramsLayout = QGridLayout() - paramsBox = QGroupBox() - - row = 0 - this_path = os.path.dirname(os.path.abspath(__file__)) - default_model_path = this_path - - label = QLabel(html_utils.paragraph('Original Images path')) - paramsLayout.addWidget(label, row, 0) - modelPathLineEdit = QLineEdit() - start_dir = '' - if os.path.exists(default_model_path): - start_dir = os.path.dirname(default_model_path) - modelPathLineEdit.setText(default_model_path) - self.modelPathLineEdit = modelPathLineEdit - paramsLayout.addWidget(modelPathLineEdit, row, 1) - browseButton = widgets.browseFileButton( - title='Select Original Images', - ext={'TIFF': ('.tif',)}, - start_dir=start_dir - ) - if posData is not None: - modelPathLineEdit.setText(posData.imgPath) - browseButton.sigPathSelected.connect(self.onPathSelected) - paramsLayout.addWidget(browseButton, row, 2, alignment=Qt.AlignLeft) - - row += 1 - label = QLabel(html_utils.paragraph('Model Type')) - paramsLayout.addWidget(label, row, 0) - updateMethodCombobox = QComboBox() - updateMethodCombobox.addItems(['2D', 'mothermachine']) - self.model_type = '2D' - self.updateMethodCombobox = updateMethodCombobox - self.updateMethodCombobox.currentTextChanged.connect(self.methodChanged) - paramsLayout.addWidget(updateMethodCombobox, row, 1) - - row += 1 - label = QLabel(html_utils.paragraph('Single Mother Machine Chamber?')) - paramsLayout.addWidget(label, row, 0) - chamberToggle = widgets.Toggle() - chamberToggle.setChecked(True) - self.chamberToggle = chamberToggle - paramsLayout.addWidget(chamberToggle, row, 1, alignment=Qt.AlignCenter) - - row += 1 - label = QLabel(html_utils.paragraph('Verbose')) - paramsLayout.addWidget(label, row, 0) - verboseToggle = widgets.Toggle() - verboseToggle.setChecked(True) - self.verboseToggle = verboseToggle - paramsLayout.addWidget(verboseToggle, row, 1, alignment=Qt.AlignCenter) - - row += 1 - label = QLabel(html_utils.paragraph('Legacy Save (.mat)')) - paramsLayout.addWidget(label, row, 0) - legacyToggle = widgets.Toggle() - legacyToggle.setChecked(False) - self.legacyToggle = legacyToggle - paramsLayout.addWidget(legacyToggle, row, 1, alignment=Qt.AlignCenter) - - row += 1 - label = QLabel(html_utils.paragraph('Pickle (.pkl)')) - paramsLayout.addWidget(label, row, 0) - pickleToggle = widgets.Toggle() - pickleToggle.setChecked(False) - self.pickleToggle = pickleToggle - paramsLayout.addWidget(pickleToggle, row, 1, alignment=Qt.AlignCenter) - - row += 1 - label = QLabel(html_utils.paragraph('Movie (.mp4) *only for 2D images')) - paramsLayout.addWidget(label, row, 0) - movieToggle = widgets.Toggle() - movieToggle.setChecked(False) - self.movieToggle = movieToggle - paramsLayout.addWidget(movieToggle, row, 1, alignment=Qt.AlignCenter) - - cancelButton = widgets.cancelPushButton('Cancel') - okButton = widgets.okPushButton(' Ok ') - cancelButton.clicked.connect(self.cancel_cb) - okButton.clicked.connect(self.ok_cb) - - buttonsLayout = QHBoxLayout() - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(okButton) - - layout = QVBoxLayout() - infoText = html_utils.paragraph('Delta Tracker parameters') - infoLabel = QLabel(infoText) - layout.addWidget(infoLabel, alignment=Qt.AlignCenter) - layout.addSpacing(10) - paramsBox.setLayout(paramsLayout) - layout.addWidget(paramsBox) - - url = 'https://delta.readthedocs.io/en/latest/' - moreInfoText = html_utils.paragraph( - 'Find more info on Delta Tracker\'s ' - f'home page' - ) - moreInfoLabel = QLabel(moreInfoText) - moreInfoLabel.setOpenExternalLinks(True) - layout.addWidget(moreInfoLabel, alignment=Qt.AlignCenter) - - layout.addSpacing(20) - layout.addLayout(buttonsLayout) - layout.addStretch(1) - self.setLayout(layout) - self.setFont(font) - - def methodChanged(self, method): - if method == 'mothermachine': - self.model_type = 'mothermachine' - - def onPathSelected(self, path): - self.modelPathLineEdit.setText(path) - - def ok_cb(self, checked=False): - self.cancel = False - - if not os.path.exists(self.modelPathLineEdit.text()): - self.warnNotVaidPath() - return - - self.verbose = self.verboseToggle.isChecked() - self.legacy = self.legacyToggle.isChecked() - self.pickle = self.pickleToggle.isChecked() - self.movie = self.movieToggle.isChecked() - self.chamber = self.chamberToggle.isChecked() - self.model_path = os.path.normpath(self.modelPathLineEdit.text()) - self.params = { - 'original_images_path': self.model_path, - 'verbose': self.verbose, - 'legacy': self.legacy, - 'pickle': self.pickle, - 'movie': self.movie, - 'model_type': self.model_type, - 'single mothermachine chamber': self.chamber - } - self.close() - - def cancel_cb(self, event): - self.cancel = True - self.close() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - super().show() - self.resize(int(self.width()*1.3), self.height()) - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - -class QDialogWorkerProgress(QDialog): - sigClosed = Signal(bool) - - def __init__( - self, title='Progress', infoTxt='', - showInnerPbar=False, pbarDesc='', - parent=None - ): - self.workerFinished = False - self.aborted = False - self.clickCount = 0 - super().__init__(parent) - - abort_text = 'Option+Command+C' if is_mac else 'Ctrl+Alt+C' - self.abort_text = abort_text - - self.setWindowTitle(f'{title} ({abort_text} to abort)') - self.setWindowFlags(Qt.Window) - - mainLayout = QVBoxLayout() - pBarLayout = QGridLayout() - - if infoTxt: - infoLabel = QLabel(infoTxt) - mainLayout.addWidget(infoLabel, alignment=Qt.AlignCenter) - - self.progressLabel = QLabel(pbarDesc) - - self.mainPbar = widgets.ProgressBarWithETA(self) - self.mainPbar.setValue(0) - pBarLayout.addWidget(self.mainPbar, 0, 0) - pBarLayout.addWidget(self.mainPbar.ETA_label, 0, 1) - - self.innerPbar = widgets.ProgressBarWithETA(self) - self.innerPbar.setValue(0) - pBarLayout.addWidget(self.innerPbar, 1, 0) - pBarLayout.addWidget(self.innerPbar.ETA_label, 1, 1) - if showInnerPbar: - self.innerPbar.show() - else: - self.innerPbar.hide() - - self.logConsole = widgets.QLogConsole() - - mainLayout.addWidget(self.progressLabel) - mainLayout.addLayout(pBarLayout) - mainLayout.addWidget(self.logConsole) - - self.setLayout(mainLayout) - # self.setModal(True) - - def keyPressEvent(self, event): - isCtrlAlt = event.modifiers() == (Qt.ControlModifier | Qt.AltModifier) - if isCtrlAlt and event.key() == Qt.Key_C: - doAbort = self.askAbort() - if doAbort: - self.aborted = True - self.workerFinished = True - self.close() - - def askAbort(self): - msg = widgets.myMessageBox() - txt = html_utils.paragraph(f""" - Aborting with {self.abort_text} to abort is - not safe.

- The system status cannot be predicted and - it will require a restart.

- Are you sure you want to abort? - """) - yesButton, noButton = msg.critical( - self, 'Are you sure you want to abort?', txt, - buttonsTexts=('Yes', 'No') - ) - return msg.clickedButton == yesButton - - def closeEvent(self, event): - if not self.workerFinished: - event.ignore() - return - - self.sigClosed.emit(self.aborted) - - def log(self, text): - self.logConsole.append(text) - - def show(self, app): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - QDialog.show(self) - screen = app.primaryScreen() - screenWidth = screen.size().width() - screenHeight = screen.size().height() - parentGeometry = self.parent().geometry() - mainWinLeft, mainWinWidth = parentGeometry.left(), parentGeometry.width() - mainWinTop, mainWinHeight = parentGeometry.top(), parentGeometry.height() - mainWinCenterX = int(mainWinLeft+mainWinWidth/2) - mainWinCenterY = int(mainWinTop+mainWinHeight/2) - - width = int(screenWidth/3) - width = width if self.width() < width else self.width() - height = int(screenHeight/3) - left = int(mainWinCenterX - width/2) - left = left if left >= 0 else 0 - top = int(mainWinCenterY - height/2) - - self.setGeometry(left, top, width, height) - -class QDialogCombobox(QDialog): - def __init__( - self, title, ComboBoxItems, informativeText, - CbLabel='Select value: ', parent=None, - defaultChannelName=None, iconPixmap=None, centeredCombobox=False - ): - self.cancel = True - self.selectedItemText = '' - self.selectedItemIdx = None - super().__init__(parent=parent) - self.setWindowTitle(title) - - mainLayout = QVBoxLayout() - infoLayout = QHBoxLayout() - topLayout = QHBoxLayout() - bottomLayout = QHBoxLayout() - - self.mainLayout = mainLayout - - if iconPixmap is not None: - label = QLabel() - # padding: top, left, bottom, right - # label.setStyleSheet("padding:5px 0px 12px 0px;") - label.setPixmap(iconPixmap) - infoLayout.addWidget(label) - - if informativeText: - infoLabel = QLabel(informativeText) - infoLayout.addWidget(infoLabel, alignment=Qt.AlignCenter) - - if CbLabel: - label = QLabel(CbLabel) - topLayout.addWidget(label, alignment=Qt.AlignRight) - - if centeredCombobox: - combobox = widgets.QCenteredComboBox() - else: - combobox = QComboBox() - combobox.addItems(ComboBoxItems) - if defaultChannelName is not None and defaultChannelName in ComboBoxItems: - combobox.setCurrentText(defaultChannelName) - self.ComboBox = combobox - topLayout.addWidget(combobox) - topLayout.setContentsMargins(0, 10, 0, 0) - - okButton = widgets.okPushButton('Ok') - - cancelButton = widgets.cancelPushButton('Cancel') - - bottomLayout.addStretch(1) - bottomLayout.addWidget(cancelButton) - bottomLayout.addSpacing(20) - bottomLayout.addWidget(okButton) - bottomLayout.setContentsMargins(0, 10, 0, 0) - - mainLayout.addLayout(infoLayout) - mainLayout.addLayout(topLayout) - mainLayout.addLayout(bottomLayout) - self.setLayout(mainLayout) - - # self.setModal(True) - - # Connect events - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.close) - self.loop = None - - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - self.setFont(font) - - def ok_cb(self, checked=False): - self.cancel = False - self.selectedItemText = self.ComboBox.currentText() - self.selectedItemIdx = self.ComboBox.currentIndex() - self.close() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - QDialog.show(self) - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - -class MultiTimePointFilePattern(QBaseDialog): - def __init__(self, fileName, folderPath, readPatternFunc=None, parent=None): - super().__init__(parent) - - self.setWindowTitle('File name pattern') - self.cancel = True - self.additionalChannelWidgets = {} - - mainLayout = QVBoxLayout() - self.readPatternFunc = readPatternFunc - - infoText = html_utils.paragraph(""" - The image files for each time-point must be named with the following pattern:

- position_channel_timepoint -

- For example a file with name "pos1_GFP_1.tif" would be the first time-point of the channell GFP
- and position called pos1.

- The Position number will be determined by alphabetically sorting - all the image files.

- Please, provide the channel names below. - Optionally, you can provide a basename
- that will be pre-pended to the name of all created files.

- You can also provide a folder path containing the segmentation masks file.
- These files MUST be named exactly as the raw files. -
- """) - - noteLayout = QHBoxLayout() - noteText = html_utils.paragraph(""" - Channels do not need to have the same number of frames, - however, Cell-ACDC will place
- the frames at the right frame number - (given by timepoint number at the end
- of the filename) and it will fill missing frames with zeros. - """) - noteLayout.addWidget( - QLabel(html_utils.to_admonition(noteText)), - # alignment=(Qt.AlignTop | Qt.AlignRight) - ) - - mainLayout.addWidget(QLabel(infoText)) - mainLayout.addLayout(noteLayout) - noteLayout.setStretch(0,0) - noteLayout.setStretch(1,1) - - label = QLabel(html_utils.paragraph( - f'Sample file name: {fileName}' - )) - mainLayout.addWidget(label, alignment=Qt.AlignCenter) - mainLayout.addSpacing(5) - - channelName = '' - posName = '' - frameNumber = None - if readPatternFunc is not None: - posName, frameNumber, channelName = readPatternFunc(fileName) - - formLayout = QGridLayout() - - ncols = 3 - self.vLayouts = [QVBoxLayout() for _ in range(ncols)] - for j, l in enumerate(self.vLayouts): - formLayout.addLayout(l, 0, j) - - row = 0 - items = QLabel('Position name: '), widgets.ReadOnlyLineEdit(), QLabel() - label, self.posNameEntry, button = items - self.posNameEntry.setAlignment(Qt.AlignCenter) - self.posNameEntry.setText(str(posName)) - for j, w in enumerate(items): - self.vLayouts[j].addWidget(w) - - row += 1 - items = ( - QLabel('Frame number name: '), widgets.ReadOnlyLineEdit(), QLabel() - ) - self.frameNumberEntry = items[1] - self.frameNumberEntry.setText(str(frameNumber)) - self.frameNumberEntry.setAlignment(Qt.AlignCenter) - for j, w in enumerate(items): - self.vLayouts[j].addWidget(w) - - row += 1 - self.channelNameLE = widgets.alphaNumericLineEdit() - items = ( - QLabel('Channel_1 name: '), self.channelNameLE, - widgets.addPushButton(' Add channel') - ) - self.addChannelButton = items[2] - self.addChannelButton._row = row - self.channelNameLE.setAlignment(Qt.AlignCenter) - self.channelNameLE.setText(channelName) - for j, w in enumerate(items): - self.vLayouts[j].addWidget(w) - - row += 1 - items = ( - QLabel('Basename (optional): '), widgets.alphaNumericLineEdit(), - QLabel() - ) - label, self.baseNameLE, button = items - self.baseNameLE.setAlignment(Qt.AlignCenter) - for j, w in enumerate(items): - self.vLayouts[j].addWidget(w) - - row += 1 - items = QLabel('File will be saved as: '), QLineEdit(), QLabel() - label, self.relPathEntry, button = items - self.relPathEntry.setAlignment(Qt.AlignCenter) - for j, w in enumerate(items): - self.vLayouts[j].addWidget(w) - - row += 1 - items = ( - QLabel('Segmentation masks folder path: '), - widgets.ElidingLineEdit(), - widgets.browseFileButton( - 'Browse...', - title='Select folder containing segmentation masks', - start_dir=folderPath, openFolder=True - ) - ) - label, self.segmFolderPathEntry, button = items - button.sigPathSelected.connect(self.segmFolderpathSelected) - self.segmFolderPathEntry.setAlignment(Qt.AlignCenter) - for j, w in enumerate(items): - self.vLayouts[j].addWidget(w) - - self.formLayout = formLayout - - self.updateRelativePath() - - self.channelNameLE.textChanged.connect(self.updateRelativePath) - self.baseNameLE.textChanged.connect(self.updateRelativePath) - self.addChannelButton.clicked.connect(self.addChannel) - - mainLayout.addLayout(formLayout) - - buttonsLayout = widgets.CancelOkButtonsLayout() - showInFileManagerButton = widgets.showInFileManagerButton( - myutils.get_open_filemaneger_os_string() - ) - buttonsLayout.insertWidget(3, showInFileManagerButton) - func = partial(myutils.showInExplorer, folderPath) - showInFileManagerButton.clicked.connect(func) - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - mainLayout.addStretch() - - self.setLayout(mainLayout) - - self.setFont(font) - - def segmFolderpathSelected(self, path): - self.segmFolderPathEntry.setText(path) - - def addChannel(self): - self.addChannelButton._row += 1 - row = self.addChannelButton._row - - channel_idx = len(self.additionalChannelWidgets) - items = ( - QLabel(f'Channel_{channel_idx+1} name: '), - widgets.alphaNumericLineEdit(), - widgets.subtractPushButton('Remove channel') - ) - label, lineEdit, button = items - lineEdit.setAlignment(Qt.AlignCenter) - button.clicked.connect(self.removeChannel) - button._row = row - for j, w in enumerate(items): - self.vLayouts[j].insertWidget(row, w) - - self.additionalChannelWidgets[row] = items - lineEdit.setFocus() - - def removeChannel(self): - row = self.sender()._row - for j, w in enumerate(self.additionalChannelWidgets[row]): - self.vLayouts[j].removeWidget(w) - - self.additionalChannelWidgets.pop(row) - self.addChannelButton._row -= 1 - - def checkChannelNames(self): - allChannels = [self.channelNameLE.text()] - allChannels.extend( - [w[1].text() for w in self.additionalChannelWidgets.values()] - ) - for ch1, ch2 in combinations(allChannels, 2): - if ch1 == ch2: - break - if not ch1 or not ch2: - break - else: - # Channel names are fine - return allChannels - - msg = widgets.myMessageBox(wrapText=False, showCentered=False) - txt = html_utils.paragraph(""" - Some channel names are empty or not different from each other. - """) - msg.critical(self, 'Select two or more items', txt) - return None - - def updateRelativePath(self, text=''): - posName = self.posNameEntry.text() - frameNumber = self.frameNumberEntry.text() - channelName = self.channelNameLE.text() - basename = self.baseNameLE.text() - if basename: - filename = f'{basename}_{posName}_{channelName}.tif' - else: - filename = f'{posName}_{channelName}.tif' - relPath = f'...{os.sep}Position_1{os.sep}Images{os.sep}{filename}' - self.relPathEntry.setText(relPath) - - def ok_cb(self): - allChannels = self.checkChannelNames() - if allChannels is None: - return - self.allChannels = allChannels - self.basename = self.baseNameLE.text() - self.segmFolderPath = self.segmFolderPathEntry.text() - self.cancel = False - self.close() - - def showEvent(self, event) -> None: - self.channelNameLE.setFocus() - -class OrderableListWidgetDialog(QBaseDialog): - def __init__( - self, items, title='Select items', infoTxt='', helpText='', - parent=None - ): - super().__init__(parent) - - self.selectedItemsText = [] - - self.cancel = True - self.setWindowTitle(title) - - mainLayout = QVBoxLayout() - self.helpText = helpText - - if infoTxt: - mainLayout.addWidget(QLabel(html_utils.paragraph(infoTxt))) - - self.listWidget = widgets.OrderableList() - self.listWidget.addItems(items) - - buttonsLayout = widgets.CancelOkButtonsLayout() - if helpText: - helpButton = widgets.helpPushButton('Help...') - buttonsLayout.insertWidget(3, helpButton) - helpButton.clicked.connect(self.showHelp) - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addWidget(self.listWidget) - mainLayout.addSpacing(10) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - def showHelp(self): - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - txt = html_utils.paragraph(self.helpText) - msg.information(self, 'Select tables help', txt) - - def ok_cb(self): - self.cancel = False - self.selectedItemsText = [None]*len(self.listWidget.selectedItems()) - for itemW in self.listWidget.selectedItems(): - idx = int(itemW._nrWidget.currentText()) - 1 - if idx >= len(self.selectedItemsText): - idx = len(self.selectedItemsText) - 1 - self.selectedItemsText[idx] = itemW._text - self.close() - - -class QDialogAutomaticThresholding(QBaseDialog): - def __init__(self, parent=None, isSegm3D=True): - super().__init__(parent) - - self.cancel = True - - self.setWindowTitle('Automatic thresholding parameters') - - layout = QVBoxLayout() - formLayout = QGridLayout() - buttonsLayout = QHBoxLayout() - - row = 0 - self.sigmaGaussSpinbox = QDoubleSpinBox() - self.sigmaGaussSpinbox.setValue(1) - self.sigmaGaussSpinbox.setMaximum(2**31) - self.sigmaGaussSpinbox.setAlignment(Qt.AlignCenter) - formLayout.addWidget( - QLabel('Gaussian filter sigma (0 to ignore): '), row, 0, - alignment=Qt.AlignRight - ) - formLayout.addWidget(self.sigmaGaussSpinbox, row, 1, 1, 2) - - row += 1 - self.threshMethodCombobox = QComboBox() - self.threshMethodCombobox.addItems([ - 'Isodata', 'Li', 'Mean', 'Minimum', 'Otsu', 'Triangle', 'Yen' - ]) - formLayout.addWidget( - QLabel('Thresholding algorithm: '), row, 0, - alignment=Qt.AlignRight - ) - formLayout.addWidget(self.threshMethodCombobox, row, 1, 1, 2) - - self.segment3Dcheckbox = None - if isSegm3D: - row += 1 - formLayout.addWidget( - QLabel('Segment 3D volume: '), row, 0, alignment=Qt.AlignRight - ) - group = QButtonGroup() - group.setExclusive(True) - self.segment3Dcheckbox = QRadioButton('Yes') - segmentSliceBySliceCheckbox = QRadioButton('No, segment slice-by-slice') - group.addButton(self.segment3Dcheckbox) - group.addButton(segmentSliceBySliceCheckbox) - formLayout.addWidget(self.segment3Dcheckbox, row, 1) - formLayout.addWidget(segmentSliceBySliceCheckbox, row, 2) - self.segment3Dcheckbox.setChecked(True) - - okButton = widgets.okPushButton('Ok') - cancelButton = widgets.cancelPushButton('Cancel') - helpButton = widgets.helpPushButton('Help...') - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(helpButton) - buttonsLayout.addWidget(okButton) - - layout.addLayout(formLayout) - layout.addSpacing(20) - layout.addLayout(buttonsLayout) - - okButton.clicked.connect(self.ok_cb) - helpButton.clicked.connect(self.help_cb) - cancelButton.clicked.connect(self.close) - - self.setLayout(layout) - self.setFont(font) - - self.configPars = self.loadLastSelection() - - - def help_cb(self): - import webbrowser - url = 'https://scikit-image.org/docs/stable/auto_examples/applications/plot_thresholding.html' - webbrowser.open(url) - - def ok_cb(self): - self.cancel = False - self.gaussSigma = self.sigmaGaussSpinbox.value() - threshMethod = self.threshMethodCombobox.currentText().lower() - self.threshMethod = f'threshold_{threshMethod}' - self.segment_kwargs = { - 'gauss_sigma': self.gaussSigma, - 'threshold_method': self.threshMethod, - 'segment_3D_volume': False - } - self.reduceMemoryUsage = False - if self.segment3Dcheckbox is not None: - doSegm3D = self.segment3Dcheckbox.isChecked() - self.segment_kwargs['segment_3D_volume'] = doSegm3D - self.close() - - def loadLastSelection(self): - self.ini_path = os.path.join( - settings_folderpath, 'last_params_segm_models.ini' - ) - if not os.path.exists(self.ini_path): - return - - configPars = config.ConfigParser() - configPars.read(self.ini_path) - - if 'thresholding.segment' not in configPars.sections(): - return - - section = configPars['thresholding.segment'] - self.sigmaGaussSpinbox.setValue(float(section['gauss_sigma'])) - - threshold_method = section['threshold_method'] - Method = threshold_method[10:].capitalize() - self.threshMethodCombobox.setCurrentText(Method) - if self.segment3Dcheckbox is None: - return - self.segment3Dcheckbox.setChecked(section.getboolean('segment_3D_volume')) - -class GenerateMotherBudTotalTableSelectColumnsDialog(QBaseDialog): - def __init__(self, df: pd.DataFrame, parent=None): - super().__init__(parent) - - self.setWindowTitle('Select columns to combine into the output table') - - self.cancel = True - - self.columns = core.natsort_acdc_columns(df.columns) - self.operations = ( - 'Sum mother and bud', - 'Copy column from mother', - ) - - self.mainLayout = QVBoxLayout() - - instructionsText = html_utils.paragraph(""" - Select which columns and how you want to combine them - into the output table.
- """) - self.mainLayout.addWidget(QLabel(instructionsText)) - - settingsLayout = QGridLayout() - - row = 0 - settingsLayout.addWidget(widgets.QHLine(), row, 0, 1, 2) - - row += 1 - settingsLayout.addWidget( - QLabel('Copy all non-selected columns from mother cell'), row, 0 - ) - self.copyAllColsToggle = widgets.Toggle() - settingsLayout.addWidget( - self.copyAllColsToggle, row, 1, alignment=Qt.AlignLeft - ) - - row += 1 - settingsLayout.addWidget(widgets.QHLine(), row, 0, 1, 2) - - self.mainLayout.addLayout(settingsLayout) - - scrollArea = widgets.ScrollArea() - scrollArea.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn) - scrollWidget = QWidget() - scrollArea.setWidget(scrollWidget) - self.centralLayout = QGridLayout() - scrollWidget.setLayout(self.centralLayout) - - self.centralLayout.addWidget(QLabel('Grouping columns'), 0, 0) - self.centralLayout.addWidget(QLabel('Column'), 0, 1) - self.centralLayout.addWidget(QLabel('Operation'), 0, 2) - self.centralLayout.setRowStretch(0, 0) - - self.groupingColsListWidget = widgets.listWidget( - isMultipleSelection=True, - ) - self.groupingColsListWidget.addItems(self.columns) - self.centralLayout.addWidget(self.groupingColsListWidget, 1, 0, 2, 1) - - selector = widgets.ComboBox(self) - selector.addItems(self.columns) - operationCombobox = widgets.ComboBox(self) - operationCombobox.addItems(self.operations) - self.addSelectorButton = widgets.addPushButton() - - dummyButton = widgets.delPushButton() - dummyButton.setRetainSizeWhenHidden(True) - dummyButton.hide() - self.centralLayout.addWidget(dummyButton, 1, 4) - - self.centralLayout.addWidget(selector, 1, 1) - self.centralLayout.addWidget(operationCombobox, 1, 2) - self.centralLayout.addWidget(self.addSelectorButton, 1, 3) - - self.centralLayout.setRowStretch(1, 1) - self.centralLayout.setRowStretch(2, 1) - - self.selectors = {1: (selector, operationCombobox)} - - buttonsLayout = widgets.CancelOkButtonsLayout() - - saveSelectionButton = widgets.savePushButton( - 'Save current selection' - ) - buttonsLayout.insertWidget(3, saveSelectionButton) - - loadDefaultColsButton = widgets.reloadPushButton( - 'Load default summable columns' - ) - buttonsLayout.insertWidget(4, loadDefaultColsButton) - - loadPreviousSelButton = widgets.OpenFilePushButton( - 'Load previous selection' - ) - buttonsLayout.insertWidget(5, loadPreviousSelButton) - - saveSelectionButton.clicked.connect(self.saveSelection) - loadDefaultColsButton.clicked.connect(self.loadDefaultCols) - loadPreviousSelButton.clicked.connect(self.loadPreviousSelection) - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - self.mainLayout.addWidget(scrollArea) - self.mainLayout.addSpacing(20) - self.mainLayout.addLayout(buttonsLayout) - - self.addSelectorButton.clicked.connect(self.addSelector) - selector.currentTextChanged.connect( - self.selectorTextChanged - ) - - self.setLayout(self.mainLayout) - self.setFont(font) - - def saveSelection(self): - saved_selections = io.get_saved_moth_bud_tot_selections() - existing_names = set(saved_selections.keys()) - win = filenameDialog( - basename='', - ext='', - hintText='Insert a name for the current selection:', - existingNames=existing_names, - allowEmpty=False, - defaultEntry='mother_bud_total_columns_selection' - ) - win.exec_() - if win.cancel: - return - - name = win.filename - saved_selections[name] = self.selectedOptions() - io.save_moth_bud_tot_selected_options(saved_selections) - - msg = widgets.myMessageBox(wrapText=False, showCentered=False) - txt = html_utils.paragraph(f""" - Current selection saved with name {name}. - """) - msg.information(self, 'Selection saved', txt) - - def loadDefaultCols(self): - from . import single_pos_index_cols - - grouping_cols = [ - col for col in single_pos_index_cols if col in self.columns - ] - self.groupingColsListWidget.setSelectedItems(grouping_cols) - - column_operation_mapper = { - col: 'Sum mother and bud' - for col in cca_functions.default_summable_columns - } - column_operation_mapper = { - col: op for col, op in column_operation_mapper.items() - if col in self.columns and op in self.operations - } - self.addSelectors( - len(column_operation_mapper), - callback_on_finished=partial( - self.setSelectorValues, column_operation_mapper - ) - ) - - def loadPreviousSelection(self): - saved_selections = io.get_saved_moth_bud_tot_selections() - if not saved_selections: - msg = widgets.myMessageBox(wrapText=False, showCentered=False) - txt = html_utils.paragraph(""" - There are no saved selections. - """) - msg.warning(self, 'No saved selections', txt) - return - - existing_names = natsorted(saved_selections.keys(), key=str.casefold) - - selectNameWin = widgets.QDialogListbox( - 'Choose selection to load', - 'Choose selection to load:\n', - existing_names, - multiSelection=False, - parent=self - ) - selectNameWin.exec_() - if selectNameWin.cancel: - return - - self.loadOptions(saved_selections[selectNameWin.selectedItemsText[0]]) - - def resetSelectors(self, callback_on_finished=None): - self.callback_on_finished = callback_on_finished - QTimer.singleShot(1, self._removeLastSelector) - - def _removeLastSelector(self): - if len(self.selectors) == 1: - if self.callback_on_finished is not None: - self.callback_on_finished() - return - - lastRow = max(self.selectors.keys()) - lastSelector, _ = self.selectors[lastRow] - self.removeSelector(sender=lastSelector.delButton) - QTimer.singleShot(1, self._removeLastSelector) - - def addSelectors(self, number, callback_on_finished=None): - self.callback_on_finished = callback_on_finished - QTimer.singleShot(1, partial(self._addSelectorRecursive, number)) - - def _addSelectorRecursive(self, number): - if len(self.selectors) == number: - if self.callback_on_finished is not None: - self.callback_on_finished() - return - - self.addSelector() - QTimer.singleShot(1, partial(self._addSelectorRecursive, number)) - - def loadOptions(self, options: dict): - if len(self.selectors) > 1: - self.resetSelectors( - callback_on_finished=partial(self.loadOptions, options) - ) - return - - self.copyAllColsToggle.setChecked( - options.get('do_copy_all_nonselected_columns', False) - ) - self.groupingColsListWidget.setSelectedItems( - options.get('grouping_columns', []) - ) - column_operation_mapper = options.get('column_operation_mapper', {}) - column_operation_mapper = { - col: op for col, op in column_operation_mapper.items() - if col in self.columns and op in self.operations - } - if len(column_operation_mapper) > 1: - self.addSelectors( - len(column_operation_mapper), - callback_on_finished=partial( - self.setSelectorValues, column_operation_mapper - ) - ) - return - - self.setSelectorValues(column_operation_mapper) - - def setSelectorValues(self, column_operation_mapper): - for i, (col, op) in enumerate(column_operation_mapper.items()): - selector, operationCombobox = self.selectors[i+1] - selector.setCurrentText(col) - operationCombobox.setCurrentText(op) - - def resetSelectorsStyles(self): - for selector, _ in self.selectors.values(): - selector.setStyleSheet('') - - def selectorTextChanged(self, text): - self.resetSelectorsStyles() - selector = self.sender() - for other_selector, _ in self.selectors.values(): - if other_selector == selector: - continue - - if selector.currentText() != other_selector.currentText(): - continue - - self.setWarningStyleSelector(selector) - self.setWarningStyleSelector(other_selector) - - def addSelector(self): - row = len(self.selectors) + 1 - - selector = widgets.ComboBox(self) - selector.addItems(self.columns) - selector.setCurrentIndex(len(self.selectors)) - operationCombobox = widgets.ComboBox(self) - operationCombobox.addItems(self.operations) - delButton = widgets.delPushButton() - selector.delButton = delButton - delButton._row = row - - self.selectors[row] = (selector, operationCombobox) - - self.centralLayout.addWidget(selector, row, 1) - self.centralLayout.addWidget(operationCombobox, row, 2) - self.centralLayout.addWidget(delButton, row, 3) - - self.centralLayout.removeWidget(self.addSelectorButton) - self.centralLayout.addWidget(self.addSelectorButton, row, 4) - - delButton.clicked.connect(self.removeSelector) - - self.centralLayout.removeWidget(self.groupingColsListWidget) - rowSpan = self.centralLayout.rowCount() - self.centralLayout.addWidget( - self.groupingColsListWidget, 1, 0, rowSpan, 1 - ) - self.centralLayout.setRowStretch(rowSpan, 1) - - selector.currentTextChanged.connect( - self.selectorTextChanged - ) - - def removeSelector(self, checked=False, sender=None): - if sender is None: - delButton = self.sender() - else: - delButton = sender - - selector, operationCombobox = self.selectors.pop(delButton._row) - - self.centralLayout.removeWidget(selector) - self.centralLayout.removeWidget(operationCombobox) - self.centralLayout.removeWidget(delButton) - - resorted_selectors = {} - for i, (row, (sel, op)) in enumerate(self.selectors.items()): - if i == 0: - resorted_selectors[i+1] = (sel, op) - continue - - delButton = sel.delButton - delButton._row = i+1 - self.centralLayout.removeWidget(sel) - self.centralLayout.removeWidget(op) - self.centralLayout.removeWidget(delButton) - self.centralLayout.addWidget(sel, i+1, 1) - self.centralLayout.addWidget(op, i+1, 2) - self.centralLayout.addWidget(delButton, i+1, 3) - - resorted_selectors[i+1] = (sel, op) - - last_row = i+1 - col = 4 if last_row > 1 else 3 - self.centralLayout.removeWidget(self.addSelectorButton) - self.centralLayout.addWidget(self.addSelectorButton, i+1, col) - - self.selectors = resorted_selectors - - def sizeHint(self): - width = super().sizeHint().width() - height = super().sizeHint().height() - groupingColsWidth = widgets.get_min_width_for_no_scrollbar( - self.groupingColsListWidget - ) - width += groupingColsWidth - return QSize(width, height) - - def checkDuplicatedSelectedColumns(self): - for selector, _ in self.selectors.values(): - selector.setStyleSheet('background-color: none') - for other_selector, _ in self.selectors.values(): - if other_selector == selector: - continue - - if other_selector.currentText() != selector.currentText(): - continue - - self.warnDuplicatedSelectedColumns(selector, other_selector) - return False - - return True - - def setWarningStyleSelector(self, selector): - popup = selector.view() - palette = popup.palette() - text_color = palette.color(palette.ColorRole.Text) - warningStyleSheet = (f""" - QComboBox {{ - color: black; - background-color: orange; /* main area */ - }} - QComboBox QAbstractItemView {{ - background-color: {text_color.name()}; - }} - """) - selector.setStyleSheet(warningStyleSheet) - - def warnDuplicatedSelectedColumns(self, selector1, selector2): - self.setWarningStyleSelector(selector1) - self.setWarningStyleSelector(selector2) - - msg = widgets.myMessageBox(wrapText=False, showCentered=False) - txt = html_utils.paragraph(f""" - The following column has been selected more than once - (highlighted in orange).

- {selector1.currentText()}

- Please, select each column only once.

- Thank you for your patience! - """) - msg.warning(self, 'Duplicated selection', txt) - - - def checkGroupingColumnsNotSelected(self): - if self.groupingColsListWidget.selectedItems(): - return True - - return self.warnGroupingColumnsNotSelected() - - def warnGroupingColumnsNotSelected(self): - msg = widgets.myMessageBox(wrapText=False, showCentered=False) - txt = html_utils.paragraph(f""" - Are you sure you do not want to select any grouping column?

- Grouping columns are those needed to identify each unique - Position folder. - """) - _, noButton, yesButton = msg.question( - self, 'No grouping columns selected?', txt, - buttonsTexts=( - 'Cancel', - 'No, let me select grouping columns', - 'Yes, I do not need grouping columns' - ) - ) - return msg.clickedButton == yesButton - - def selectedOptions(self): - selected_options = { - 'grouping_columns': self.groupingColsListWidget.selectedItemsText(), - 'column_operation_mapper': { - selector.currentText(): operationCombobox.currentText() - for selector, operationCombobox in self.selectors.values() - }, - 'do_copy_all_nonselected_columns': self.copyAllColsToggle.isChecked() - } - return selected_options - - def ok_cb(self): - proceed = self.checkDuplicatedSelectedColumns() - if not proceed: - return - - proceed = self.checkGroupingColumnsNotSelected() - if not proceed: - return - - self.selected_options = self.selectedOptions() - - self.cancel = False - self.close() - -class ApplyTrackTableSelectColumnsDialog(QBaseDialog): - def __init__(self, df, parent=None): - super().__init__(parent) - - self.setWindowTitle('Select columns containing tracking info') - - self.cancel = True - self.mainLayout = QVBoxLayout() - - options = ( - '"Frame index", "Tracked IDs" and "Segmentation mask IDs"
', - '"Frame index", "Tracked IDs", "X coord. centroid", and "Y coord. centroid"' - ) - self.instructionsText = html_utils.paragraph( - f""" - Select which columns contain the tracking information.

- You must choose one of the following combinations:
- {html_utils.to_list(options)} - Optionally, you can provide the column name containing the parent ID.
- This will allow you to load lineage information into Cell-ACDC. - """ - ) - self.mainLayout.addWidget(QLabel(self.instructionsText)) - - formLayout = QFormLayout() - - self.frameIndexCombobox = widgets.QCenteredComboBox() - self.frameIndexCombobox.addItems(df.columns) - self.frameIndexCheckbox = QCheckBox('1st frame is index 1') - frameIndexLayout = QHBoxLayout() - frameIndexLayout.addWidget(self.frameIndexCombobox) - frameIndexLayout.addWidget(self.frameIndexCheckbox) - frameIndexLayout.setStretch(0, 2) - frameIndexLayout.setStretch(1, 0) - formLayout.addRow( - 'Frame index: ', frameIndexLayout - ) - - self.trackedIDsCombobox = widgets.QCenteredComboBox() - self.trackedIDsCombobox.addItems(df.columns) - formLayout.addRow('Tracked IDs: ', self.trackedIDsCombobox) - - items = df.columns.to_list() - items.insert(0, 'None') - self.maskIDsCombobox = widgets.QCenteredComboBox() - self.maskIDsCombobox.addItems(items) - formLayout.addRow('Segmentation mask IDs: ', self.maskIDsCombobox) - - self.xCentroidCombobox = widgets.QCenteredComboBox() - self.xCentroidCombobox.addItems(items) - formLayout.addRow('X coord. centroid: ', self.xCentroidCombobox) - - self.yCentroidCombobox = widgets.QCenteredComboBox() - self.yCentroidCombobox.addItems(items) - formLayout.addRow('Y coord. centroid: ', self.yCentroidCombobox) - - self.parentIDcombobox = widgets.QCenteredComboBox() - self.parentIDcombobox.addItems(items) - formLayout.addRow('Parent ID (optional): ', self.parentIDcombobox) - - deleteUntrackedLayout = QHBoxLayout() - self.deleteUntrackedIDsToggle = widgets.Toggle() - deleteUntrackedLayout.addStretch(1) - deleteUntrackedLayout.addWidget(self.deleteUntrackedIDsToggle) - deleteUntrackedLayout.addStretch(1) - formLayout.addRow('Delete untracked IDs: ', deleteUntrackedLayout) - - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - self.mainLayout.addSpacing(30) - self.mainLayout.addLayout(formLayout) - self.mainLayout.addSpacing(20) - self.mainLayout.addLayout(buttonsLayout) - - self.setLayout(self.mainLayout) - self.setFont(font) - - def ok_cb(self): - self.cancel = False - self.frameIndexCol = self.frameIndexCombobox.currentText() - self.trackedIDsCol = self.trackedIDsCombobox.currentText() - self.maskIDsCol = self.maskIDsCombobox.currentText() - self.xCentroidCol = self.xCentroidCombobox.currentText() - self.yCentroidCol = self.yCentroidCombobox.currentText() - self.deleteUntrackedIDs = self.deleteUntrackedIDsToggle.isChecked() - if self.maskIDsCol == 'None': - if self.xCentroidCol == 'None' or self.yCentroidCol == 'None': - self.warnInvalidSelection() - return - else: - self.xCentroidCol = 'None' - self.yCentroidCol = 'None' - self.parentIDcol = self.parentIDcombobox.currentText() - self.isFirstFrameOne = self.frameIndexCheckbox.isChecked() - self.close() - - def warnInvalidSelection(self): - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - msg.warning( - self, 'Invalid selection', html_utils.paragraph( - f'Invalid selection
{self.instructionsText}' - ) - ) - - -class SelectPromptableModelDialog(QBaseDialog): - def __init__(self, parent=None): - self.cancel = True - super().__init__(parent) - - self.setWindowTitle('Select model for segmentation') - - mainLayout = QVBoxLayout() - - label = QLabel(html_utils.paragraph( - 'Select model to use for segmentation: ' - )) - mainLayout.addWidget(label, alignment=Qt.AlignCenter) - - listBox = widgets.listWidget() - models = myutils.get_list_of_promptable_models() - listBox.addItems(models) - listBox.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection) - listBox.setCurrentRow(0) - listBox.itemDoubleClicked.connect(self.ok_cb) - - self.listBox = listBox - - mainLayout.addWidget(listBox) - - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - def ok_cb(self): - self.cancel = False - self.model_name = self.listBox.currentItem().text() - self.close() - - -class QDialogSelectModel(QDialog): - def __init__( - self, parent=None, addSkipSegmButton=False, customFirst='' - ): - self.cancel = True - super().__init__(parent) - self.setWindowTitle('Select model') - - mainLayout = QVBoxLayout() - topLayout = QVBoxLayout() - bottomLayout = QHBoxLayout() - - self.mainLayout = mainLayout - - label = QLabel(html_utils.paragraph( - 'Select model to use for segmentation: ' - )) - # padding: top, left, bottom, right - label.setStyleSheet("padding:0px 0px 3px 0px;") - topLayout.addWidget(label, alignment=Qt.AlignCenter) - - listBox = widgets.listWidget() - models = myutils.get_list_of_models() - - if customFirst: - try: - idx = models.index(customFirst) - models.insert(0, models.pop(idx)) - except ValueError: - print(f'Warning: {customFirst} not found in models list.') - pass - - listBox.setFont(font) - listBox.addItems(models) - addCustomModelItem = QListWidgetItem('Add custom model...') - addCustomModelItem.setFont(italicFont) - listBox.addItem(addCustomModelItem) - listBox.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection) - listBox.setCurrentRow(0) - self.listBox = listBox - listBox.itemDoubleClicked.connect(self.ok_cb) - topLayout.addWidget(listBox) - - cancelButton = widgets.cancelPushButton('Cancel') - okButton = widgets.okPushButton(' Ok ') - okButton.setShortcut(Qt.Key_Enter) - - bottomLayout.addStretch(1) - bottomLayout.addWidget(cancelButton) - bottomLayout.addSpacing(20) - if addSkipSegmButton: - skipSegmButton = widgets.SkipPushButton('Skip segmentation') - bottomLayout.addWidget(skipSegmButton) - skipSegmButton.clicked.connect(self.skipSegm) - bottomLayout.addWidget(okButton) - bottomLayout.setContentsMargins(0, 10, 0, 0) - - mainLayout.addLayout(topLayout) - mainLayout.addLayout(bottomLayout) - self.setLayout(mainLayout) - - # Connect events - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.cancel_cb) - - self.setStyleSheet(LISTWIDGET_STYLESHEET) - - def skipSegm(self): - self.cancel = False - self.selectedModel = 'skip_segmentation' - self.close() - - def keyPressEvent(self, event: QKeyEvent) -> None: - if event.key() == Qt.Key_Escape: - event.ignore() - return - - super().keyPressEvent(event) - - def ok_cb(self, event): - self.clickedButton = self.sender() - self.cancel = False - item = self.listBox.currentItem() - model = item.text() - if model == 'Add custom model...': - modelFilePath = addCustomModelMessages(self) - if modelFilePath is None: - return - myutils.store_custom_model_path(modelFilePath) - modelName = os.path.basename(os.path.dirname(modelFilePath)) - item = QListWidgetItem(modelName) - self.listBox.addItem(item) - self.listBox.setCurrentItem(item) - elif model == 'Automatic thresholding': - self.selectedModel = 'thresholding' - self.close() - else: - self.selectedModel = model - self.close() - - def cancel_cb(self, event): - self.cancel = True - self.selectedModel = None - self.close() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show() - - horizontal_sb = self.listBox.horizontalScrollBar() - while horizontal_sb.isVisible(): - self.resize(self.height(), self.width() + 10) - - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - -class ViewTextDialog(QBaseDialog): - def __init__(self, text, parent=None): - super().__init__(parent) - - mainLayout = QVBoxLayout() - - textViewWidget = QTextEdit() - textViewWidget.setReadOnly(True) - - textViewWidget.setText(text) - - buttonsLayout = QHBoxLayout() - okButton = widgets.okPushButton('Ok') - - okButton.clicked.connect(self.close) - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(okButton) - - mainLayout.addWidget(textViewWidget) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - self.setFont(font) - -class startStopFramesDialog(QBaseDialog): - def __init__( - self, SizeT, currentFrameNum=0, parent=None, - windowTitle='Select frame range to segment' - ): - super().__init__(parent=parent) - - self.setWindowTitle(windowTitle) - - self.cancel = True - - layout = QVBoxLayout() - buttonsLayout = QHBoxLayout() - - self.selectFramesGroupbox = widgets.selectStartStopFrames( - SizeT, currentFrameNum=currentFrameNum, parent=parent - ) - - okButton = widgets.okPushButton('Ok') - cancelButton = widgets.cancelPushButton('Cancel') - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(okButton) - - layout.addWidget(self.selectFramesGroupbox) - layout.addLayout(buttonsLayout) - self.setLayout(layout) - - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.close) - - self.setFont(font) - - def ok_cb(self): - if self.selectFramesGroupbox.warningLabel.text(): - return - else: - self.startFrame = self.selectFramesGroupbox.startFrame_SB.value() - self.stopFrame = self.selectFramesGroupbox.stopFrame_SB.value() - self.cancel = False - self.close() - - def show(self, block=False): - super().show(block=False) - - self.resize(int(self.width()*1.5), self.height()) - - if block: - super().show(block=True) - -class QDialogAppendTextFilename(QDialog): - def __init__(self, filename, ext, parent=None, font=None): - super().__init__(parent) - self.cancel = True - filenameNOext, _ = os.path.splitext(filename) - self.filenameNOext = filenameNOext - if ext.find('.') == -1: - ext = f'.{ext}' - self.ext = ext - - self.setWindowTitle('Append text to file name') - - mainLayout = QVBoxLayout() - formLayout = QFormLayout() - buttonsLayout = QHBoxLayout() - - if font is not None: - self.setFont(font) - - self.LE = QLineEdit() - self.LE.setAlignment(Qt.AlignCenter) - formLayout.addRow('Appended text', self.LE) - self.LE.textChanged.connect(self.updateFinalFilename) - - self.finalName_label = QLabel( - f'Final file name: "{filenameNOext}_{ext}"' - ) - # padding: top, left, bottom, right - self.finalName_label.setStyleSheet( - 'font-size:13px; padding:5px 0px 0px 0px;' - ) - - okButton = widgets.okPushButton('Ok') - okButton.setShortcut(Qt.Key_Enter) - - cancelButton = widgets.cancelPushButton('Cancel') - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(okButton) - - buttonsLayout.setContentsMargins(0, 10, 0, 0) - - mainLayout.addLayout(formLayout) - mainLayout.addWidget(self.finalName_label, alignment=Qt.AlignCenter) - mainLayout.addLayout(buttonsLayout) - - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.close) - - self.formLayout = formLayout - - self.setLayout(mainLayout) - # self.setModal(True) - - def updateFinalFilename(self, text): - finalFilename = f'{self.filenameNOext}_{text}{self.ext}' - self.finalName_label.setText(f'Final file name: "{finalFilename}"') - - def ok_cb(self, event): - if not self.LE.text(): - err_msg = ( - 'Appended name cannot be empty!' - ) - msg = QMessageBox() - msg.critical( - self, 'Empty name', err_msg, msg.Ok - ) - return - self.cancel = False - self.close() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show() - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - -class QDialogEntriesWidget(QDialog): - def __init__(self, entriesLabels, defaultTxts, winTitle='Input', - parent=None, font=None): - self.cancel = True - self.entriesTxt = [] - self.entriesLabels = entriesLabels - self.QLEs = [] - super().__init__(parent) - self.setWindowTitle(winTitle) - - mainLayout = QVBoxLayout() - formLayout = QFormLayout() - buttonsLayout = QHBoxLayout() - - if font is not None: - self.setFont(font) - - for label, txt in zip(entriesLabels, defaultTxts): - LE = QLineEdit() - LE.setAlignment(Qt.AlignCenter) - LE.setText(txt) - formLayout.addRow(label, LE) - self.QLEs.append(LE) - - okButton = widgets.okPushButton('Ok') - okButton.setShortcut(Qt.Key_Enter) - - cancelButton = widgets.cancelPushButton('Cancel') - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(okButton) - - buttonsLayout.setContentsMargins(0, 10, 0, 0) - - mainLayout.addLayout(formLayout) - mainLayout.addLayout(buttonsLayout) - - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.close) - - self.formLayout = formLayout - - self.setLayout(mainLayout) - # self.setModal(True) - - def ok_cb(self, event): - self.cancel = False - self.entriesTxt = [self.formLayout.itemAt(i, 1).widget().text() - for i in range(len(self.entriesLabels))] - self.close() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show() - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - -class QDialogMetadata(QDialog): - def __init__( - self, SizeT, SizeZ, TimeIncrement, - PhysicalSizeZ, PhysicalSizeY, PhysicalSizeX, - ask_SizeT, ask_TimeIncrement, ask_PhysicalSizes, - parent=None, font=None, imgDataShape=None, posData=None, - singlePos=False, askSegm3D=True, additionalValues=None, - forceEnableAskSegm3D=False, SizeT_metadata=None, - SizeZ_metadata=None, basename='' - ): - self.cancel = True - self.ask_TimeIncrement = ask_TimeIncrement - self.ask_PhysicalSizes = ask_PhysicalSizes - self.askSegm3D = askSegm3D - self.imgDataShape = imgDataShape - self.posData = posData - self._additionalValues = additionalValues - self.SizeT_metadata = SizeT_metadata - self.SizeZ_metadata = SizeZ_metadata - super().__init__(parent) - self.setWindowTitle('Image properties') - - mainLayout = QVBoxLayout() - gridLayout = QGridLayout() - # formLayout = QFormLayout() - buttonsLayout = QGridLayout() - - if imgDataShape is not None: - label = QLabel( - html_utils.paragraph( - f'Image data shape = {imgDataShape}
' - ) - ) - mainLayout.addWidget(label, alignment=Qt.AlignCenter) - - row = 0 - self.basenameLineEdit = None - if basename: - gridLayout.addWidget( - QLabel('Basename (read-only)'), row, 0, alignment=Qt.AlignRight - ) - self.basenameLineEdit = QLineEdit() - self.basenameLineEdit.setReadOnly(True) - self.basenameLineEdit.setText(basename) - minWidth = ( - self.basenameLineEdit.fontMetrics() - .boundingRect(basename).width() + 10 - ) - self.basenameLineEdit.setMinimumWidth(minWidth) - self.basenameLineEdit.setAlignment(Qt.AlignCenter) - gridLayout.addWidget(self.basenameLineEdit, row, 1) - row += 1 - - gridLayout.addWidget( - QLabel('Number of frames (SizeT)'), row, 0, alignment=Qt.AlignRight - ) - self.SizeT_SpinBox = QSpinBox() - self.SizeT_SpinBox.setMinimum(1) - self.SizeT_SpinBox.setMaximum(2147483647) - SizeTinfoButton = widgets.infoPushButton() - self.allowEditSizeTcheckbox = QCheckBox('Let me edit it') - if ask_SizeT: - self.SizeT_SpinBox.setValue(SizeT) - SizeTinfoButton.hide() - self.allowEditSizeTcheckbox.hide() - else: - self.SizeT_SpinBox.setValue(1) - self.SizeT_SpinBox.setDisabled(True) - SizeTinfoButton.show() - SizeTinfoButton.clicked.connect(self.showWhySizeTisGrayed) - self.allowEditSizeTcheckbox.show() - self.allowEditSizeTcheckbox.toggled.connect(self.allowEditSizeT) - self.SizeT_SpinBox.setAlignment(Qt.AlignCenter) - self.SizeT_SpinBox.valueChanged.connect(self.TimeIncrementShowHide) - gridLayout.addWidget(self.SizeT_SpinBox, row, 1) - gridLayout.addWidget(SizeTinfoButton, row, 2) - gridLayout.setColumnStretch(2,0) - gridLayout.addWidget(self.allowEditSizeTcheckbox, row, 3) - gridLayout.setColumnStretch(3,0) - - row += 1 - gridLayout.addWidget( - QLabel('Number of z-slices (SizeZ)'), row, 0, alignment=Qt.AlignRight - ) - self.SizeZ_SpinBox = QSpinBox() - self.SizeZ_SpinBox.setMinimum(1) - self.SizeZ_SpinBox.setMaximum(2147483647) - self.SizeZ_SpinBox.setValue(SizeZ) - self.SizeZ_SpinBox.setAlignment(Qt.AlignCenter) - self.SizeZ_SpinBox.valueChanged.connect(self.SizeZvalueChanged) - gridLayout.addWidget(self.SizeZ_SpinBox, row, 1) - - row += 1 - self.TimeIncrementLabel = QLabel('Time interval (s)') - gridLayout.addWidget( - self.TimeIncrementLabel, row, 0, alignment=Qt.AlignRight - ) - self.TimeIncrementSpinBox = widgets.FloatLineEdit() - self.TimeIncrementSpinBox.setValue(TimeIncrement) - gridLayout.addWidget(self.TimeIncrementSpinBox, row, 1) - - if SizeT == 1 or not ask_TimeIncrement: - self.TimeIncrementSpinBox.hide() - self.TimeIncrementLabel.hide() - - row += 1 - self.PhysicalSizeZLabel = QLabel('Physical Size Z (um/pixel)') - gridLayout.addWidget( - self.PhysicalSizeZLabel, row, 0, alignment=Qt.AlignRight - ) - self.PhysicalSizeZSpinBox = widgets.FloatLineEdit() - self.PhysicalSizeZSpinBox.setValue(PhysicalSizeZ) - gridLayout.addWidget(self.PhysicalSizeZSpinBox, row, 1) - - if SizeZ==1 or not ask_PhysicalSizes: - self.PhysicalSizeZSpinBox.hide() - self.PhysicalSizeZLabel.hide() - - row += 1 - self.PhysicalSizeYLabel = QLabel('Physical Size Y (um/pixel)') - gridLayout.addWidget( - self.PhysicalSizeYLabel, row, 0, alignment=Qt.AlignRight - ) - self.PhysicalSizeYSpinBox = widgets.FloatLineEdit() - self.PhysicalSizeYSpinBox.setValue(PhysicalSizeY) - gridLayout.addWidget(self.PhysicalSizeYSpinBox, row, 1) - - if not ask_PhysicalSizes: - self.PhysicalSizeYSpinBox.hide() - self.PhysicalSizeYLabel.hide() - - row += 1 - self.PhysicalSizeXLabel = QLabel('Physical Size X (um/pixel)') - gridLayout.addWidget( - self.PhysicalSizeXLabel, row, 0, alignment=Qt.AlignRight - ) - self.PhysicalSizeXSpinBox = widgets.FloatLineEdit() - self.PhysicalSizeXSpinBox.setValue(PhysicalSizeX) - gridLayout.addWidget(self.PhysicalSizeXSpinBox, row, 1) - - if not ask_PhysicalSizes: - self.PhysicalSizeXSpinBox.hide() - self.PhysicalSizeXLabel.hide() - - row += 1 - self.isSegm3Dtoggle = widgets.Toggle() - if posData is not None: - self.isSegm3Dtoggle.setChecked(posData.getIsSegm3D()) - disableToggle = ( - # Disable toggle if not force enable and if - # segm data was found (we cannot change the shape of - # loaded segmentation in the GUI) - posData.segmFound is not None - and posData.segmFound - and not forceEnableAskSegm3D - ) - if disableToggle: - self.isSegm3Dtoggle.setDisabled(True) - self.isSegm3DLabel = QLabel('Work with 3D segmentation masks (z-stack)') - gridLayout.addWidget( - self.isSegm3DLabel, row, 0, alignment=Qt.AlignRight - ) - gridLayout.addWidget( - self.isSegm3Dtoggle, row, 1, alignment=Qt.AlignCenter - ) - self.infoButtonSegm3D = QPushButton(self) - self.infoButtonSegm3D.setCursor(Qt.WhatsThisCursor) - self.infoButtonSegm3D.setIcon(QIcon(":info.svg")) - gridLayout.addWidget( - self.infoButtonSegm3D, row, 2, alignment=Qt.AlignLeft - ) - self.infoButtonSegm3D.clicked.connect(self.infoSegm3D) - if SizeZ == 1 or not askSegm3D: - self.isSegm3DLabel.hide() - self.isSegm3Dtoggle.hide() - self.infoButtonSegm3D.hide() - - self.SizeZvalueChanged(SizeZ) - - self.additionalFieldsWidgets = [] - addFieldButton = widgets.addPushButton('Add custom field') - addFieldInfoButton = widgets.infoPushButton() - addFieldInfoButton.clicked.connect(self.showAddFieldInfo) - addFieldButton.clicked.connect(self.addField) - addFieldLayout = QHBoxLayout() - addFieldLayout.addStretch(1) - addFieldLayout.addWidget(addFieldButton) - addFieldLayout.addWidget(addFieldInfoButton) - addFieldLayout.addStretch(1) - - if singlePos: - okTxt = 'Apply only to this Position' - else: - okTxt = 'Ok for loaded Positions' - okButton = widgets.okPushButton(okTxt) - okButton.setToolTip( - 'Save metadata only for current positionh' - ) - okButton.setShortcut(Qt.Key_Enter) - self.okButton = okButton - - if ask_TimeIncrement or ask_PhysicalSizes: - okAllButton = QPushButton('Apply to ALL Positions') - okAllButton.setToolTip( - 'Update existing Physical Sizes, Time interval, cell volume (fl), ' - 'cell area (um^2), and time (s) for all the positions ' - 'in the experiment folder.' - ) - self.okAllButton = okAllButton - - selectButton = QPushButton('Select the Positions to be updated') - selectButton.setToolTip( - 'Ask to select positions then update existing Physical Sizes, ' - 'Time interval, cell volume (fl), cell area (um^2), and time (s)' - 'for selected positions.' - ) - self.selectButton = selectButton - else: - self.okAllButton = None - self.selectButton = None - okButton.setText('Ok') - - cancelButton = widgets.cancelPushButton('Cancel') - - buttonsLayout.setColumnStretch(0, 1) - buttonsLayout.addWidget(okButton, 0, 1) - if ask_TimeIncrement or ask_PhysicalSizes: - buttonsLayout.addWidget(okAllButton, 0, 2) - buttonsLayout.addWidget(selectButton, 1, 1) - buttonsLayout.addWidget(cancelButton, 1, 2) - else: - buttonsLayout.addWidget(cancelButton, 0, 2) - buttonsLayout.setColumnStretch(3, 1) - - gridLayout.setColumnMinimumWidth(1, 100) - mainLayout.addLayout(gridLayout) - mainLayout.addSpacing(10) - mainLayout.addLayout(addFieldLayout) - # mainLayout.addLayout(formLayout) - mainLayout.addSpacing(20) - mainLayout.addStretch(1) - mainLayout.addLayout(buttonsLayout) - self.mainLayout = mainLayout - - okButton.clicked.connect(self.ok_cb) - if ask_TimeIncrement or ask_PhysicalSizes: - okAllButton.clicked.connect(self.ok_cb) - selectButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.cancel_cb) - - self.addAdditionalValues(additionalValues) - - self.setLayout(mainLayout) - self.setFont(font) - # self.setModal(True) - - def showWhySizeTisGrayed(self): - txt = html_utils.paragraph(f""" - The "Number of frames" field is grayed-out because you loaded multiple Positions.

- Cell-ACDC cannot load multiple time-lapse Positions, - so it is assuming you are loading NON time-lapse data.

- To load time-lapse data, load one Position at a time.

- Note that you can still edit the number of frames if you need to correct it.
- However, you can only edit the metadata, then the loading process will be stopped. - """) - msg = widgets.myMessageBox(wrapText=False, showCentered=False) - msg.information( - self, 'Why is the number of frames grayed out?', txt - ) - - def addAdditionalValues(self, values): - if values is None: - return - - for i, (name, value) in enumerate(values.items()): - self.addField() - nameWidget = self.additionalFieldsWidgets[i]['nameWidget'] - valueWidget = self.additionalFieldsWidgets[i]['valueWidget'] - nameWidget.setText(str(name).strip('__')) - valueWidget.setText(str(value)) - - def addField(self): - nameWidget = QLineEdit() - nameWidget.setAlignment(Qt.AlignCenter) - valueWidget = QLineEdit() - valueWidget.setAlignment(Qt.AlignCenter) - removeButton = widgets.delPushButton() - - fieldLayout = QGridLayout() - fieldLayout.addWidget(QLabel('Name'), 0, 0) - fieldLayout.addWidget(nameWidget, 1, 0) - fieldLayout.addWidget(QLabel('Value'), 0, 1) - fieldLayout.addWidget(valueWidget, 1, 1) - fieldLayout.addWidget(removeButton, 1, 2) - - self.additionalFieldsWidgets.append({ - 'nameWidget': nameWidget, - 'valueWidget': valueWidget, - 'removeButton': removeButton, - 'layout': fieldLayout - }) - - idx = len(self.additionalFieldsWidgets)-1 - removeButton.clicked.connect(partial(self.removeField, idx)) - - row = self.mainLayout.count()-3 - self.mainLayout.insertLayout(row, fieldLayout) - - def removeField(self, idx): - widgets = self.additionalFieldsWidgets[idx] - - layoutToRemove = widgets['layout'] - for row in range(layoutToRemove.rowCount()): - for col in range(layoutToRemove.columnCount()): - item = layoutToRemove.itemAtPosition(row, col) - if item is not None: - widget = item.widget() - layoutToRemove.removeWidget(widget) - - self.additionalFieldsWidgets.pop(idx) - - self.mainLayout.removeItem(layoutToRemove) - - def showAddFieldInfo(self): - msg = widgets.myMessageBox() - txt = html_utils.paragraph(""" - Add a field (name and value) that will be saved to the - metadata.csv file and as a column in the - acdc_output.csv table.

- Example: a strain name or the replicate number. - """) - msg.information(self, 'Add field info', txt) - - def infoSegm3D(self): - txt = ( - 'Cell-ACDC supports both 2D and 3D segmentation. If your data ' - 'also have a time dimension, then you can choose to segment ' - 'a specific z-slice (2D segmentation mask per frame) or all of them ' - '(3D segmentation mask per frame)

' - 'In any case, if you choose to activate 3D segmentation then the ' - 'segmentation mask will have the same number of z-slices ' - 'of the image data.

' - 'Additionally, in the model parameters window, you will be able ' - 'to choose if you want to segment the entire 3D volume at once ' - 'or use the 2D model on each z-slice, one by one.

' - 'NOTE: if the toggle is disabled it means you already ' - 'loaded segmentation data and the shape cannot be changed now.
' - 'if you need to start with a blank segmentation, ' - 'use the "Create a new segmentation file" button instead of the ' - '"Load folder" button.' - '
' - ) - msg = widgets.myMessageBox() - msg.setIcon() - msg.setWindowTitle(f'3D segmentation info') - msg.addText(html_utils.paragraph(txt)) - msg.addButton(' Ok ') - msg.exec_() - - def SizeZvalueChanged(self, val): - if len(self.imgDataShape) < 3: - return - - if val > 1 and self.imgDataShape is not None: - maxSizeZ = self.imgDataShape[-3] - self.SizeZ_SpinBox.setMaximum(maxSizeZ) - else: - self.SizeZ_SpinBox.setMaximum(2147483647) - - if val > 1: - if self.ask_PhysicalSizes: - self.PhysicalSizeZSpinBox.show() - self.PhysicalSizeZLabel.show() - if self.askSegm3D: - self.isSegm3DLabel.show() - self.isSegm3Dtoggle.show() - self.infoButtonSegm3D.show() - else: - self.PhysicalSizeZSpinBox.hide() - self.PhysicalSizeZLabel.hide() - self.isSegm3DLabel.hide() - self.isSegm3Dtoggle.hide() - self.infoButtonSegm3D.hide() - - self.checkSegmDataShape() - - def checkSegmDataShape(self): - if self.posData is None: - return - - if self.isSegm3Dtoggle.isEnabled(): - return - - SizeT = self.SizeT_SpinBox.value() - SizeZ = self.SizeZ_SpinBox.value() - segm_data_ndim = self.posData.segm_data.ndim - isSegm3D = False - if segm_data_ndim == 4: - # Segm data is 4D so it must be 3D over time - isSegm3D = True - elif segm_data_ndim == 3 and SizeZ > 1 and SizeT == 1: - # Segm data is 3D while SizeT == 1 and SizeZ > 1 - # --> also segm is 3D z-stack - isSegm3D = True - - self.isSegm3Dtoggle.setDisabled(False) - self.isSegm3Dtoggle.setChecked(isSegm3D) - self.isSegm3Dtoggle.setDisabled(True) - - def TimeIncrementShowHide(self, val): - self.checkSegmDataShape() - if not self.ask_TimeIncrement: - return - - if val > 1: - self.TimeIncrementSpinBox.show() - self.TimeIncrementLabel.show() - else: - self.TimeIncrementSpinBox.hide() - self.TimeIncrementLabel.hide() - - def allowEditSizeT(self, checked): - if checked: - self.SizeT_SpinBox.setDisabled(False) - if self.SizeT_metadata is not None: - self.SizeT_SpinBox.setValue(self.SizeT_metadata) - else: - self.SizeT_SpinBox.setDisabled(True) - self.SizeT_SpinBox.setValue(1) - - def warnEditingMetadata(self, Size, Size_metadata, which_dim): - txt = html_utils.paragraph(f""" - The number of {which_dim} in the saved metadata is {Size_metadata}, - but you are requesting to change it to {Size}.

- Are you sure you want to proceed? - """) - msg = widgets.myMessageBox(wrapText=False, showCentered=False) - _, noButton, yesButton = msg.warning( - self, 'WARNING: Edinting saved metadata', txt, - buttonsTexts=('Cancel', 'No', 'Yes, edit the metadata') - ) - return msg.clickedButton == yesButton - - def ok_cb(self, checked=False): - self.cancel = False - self.SizeT = self.SizeT_SpinBox.value() - self.SizeZ = self.SizeZ_SpinBox.value() - - if self.SizeT_metadata is not None: - if self.SizeT != self.SizeT_metadata: - proceed = self.warnEditingMetadata( - self.SizeT, self.SizeT_metadata, 'frames' - ) - if not proceed: - return - - if self.SizeZ_metadata is not None: - if self.SizeZ != self.SizeZ_metadata: - proceed = self.warnEditingMetadata( - self.SizeZ, self.SizeZ_metadata, 'z-slices' - ) - if not proceed: - return - - - self.isSegm3D = self.isSegm3Dtoggle.isChecked() - - self.TimeIncrement = self.TimeIncrementSpinBox.value() - self.PhysicalSizeX = self.PhysicalSizeXSpinBox.value() - self.PhysicalSizeY = self.PhysicalSizeYSpinBox.value() - self.PhysicalSizeZ = self.PhysicalSizeZSpinBox.value() - self._additionalValues = { - f"__{field['nameWidget'].text()}":field['valueWidget'].text() - for field in self.additionalFieldsWidgets - } - proceed = self.checkShapeMismatchMetadata() - if not proceed: - return - - if self.posData is not None and self.sender() != self.okButton: - exp_path = self.posData.exp_path - pos_foldernames = myutils.get_pos_foldernames(exp_path) - if self.sender() == self.selectButton: - select_folder = load.select_exp_folder() - select_folder.pos_foldernames = pos_foldernames - select_folder.QtPrompt( - self, pos_foldernames, allow_cancel=False, toggleMulti=True - ) - pos_foldernames = select_folder.selected_pos - for pos in pos_foldernames: - images_path = os.path.join(exp_path, pos, 'Images') - ls = myutils.listdir(images_path) - search = [file for file in ls if file.find('metadata.csv')!=-1] - metadata_df = None - if search: - fileName = search[0] - metadata_csv_path = os.path.join(images_path, fileName) - metadata_df = pd.read_csv( - metadata_csv_path - ).set_index('Description') - if metadata_df is not None: - metadata_df.at['TimeIncrement', 'values'] = self.TimeIncrement - metadata_df.at['PhysicalSizeZ', 'values'] = self.PhysicalSizeZ - metadata_df.at['PhysicalSizeY', 'values'] = self.PhysicalSizeY - metadata_df.at['PhysicalSizeX', 'values'] = self.PhysicalSizeX - metadata_df.to_csv(metadata_csv_path) - - search = [file for file in ls if file.find('acdc_output.csv')!=-1] - acdc_df = None - if search: - fileName = search[0] - acdc_df_path = os.path.join(images_path, fileName) - acdc_df = pd.read_csv(acdc_df_path) - yx_pxl_to_um2 = self.PhysicalSizeY*self.PhysicalSizeX - vox_to_fl = self.PhysicalSizeY*(self.PhysicalSizeX**2) - if 'cell_vol_fl' not in acdc_df.columns: - continue - acdc_df['cell_vol_fl'] = acdc_df['cell_vol_vox']*vox_to_fl - acdc_df['cell_area_um2'] = acdc_df['cell_area_pxl']*yx_pxl_to_um2 - acdc_df['time_seconds'] = acdc_df['frame_i']*self.TimeIncrement - try: - acdc_df.to_csv(acdc_df_path, index=False) - except PermissionError: - err_msg = html_utils.paragraph( - 'The below file is open in another app ' - '(Excel maybe?).

' - f'{acdc_df_path}

' - 'Close file and then press "Ok".' - ) - msg = widgets.myMessageBox() - msg.critical(self, 'Permission denied', err_msg) - acdc_df.to_csv(acdc_df_path, index=False) - - elif self.sender() == self.selectButton: - pass - - self.close() - - def checkShapeMismatchMetadata(self): - valid4D = True - valid3D = True - valid2D = True - if self.imgDataShape is None: - self.close() - elif len(self.imgDataShape) == 4: - T, Z, Y, X = self.imgDataShape - valid4D = self.SizeT == T and self.SizeZ == Z - elif len(self.imgDataShape) == 3: - TorZ, Y, X = self.imgDataShape - valid3D = self.SizeT == TorZ or self.SizeZ == TorZ - elif len(self.imgDataShape) == 2: - valid2D = self.SizeT == 1 and self.SizeZ == 1 - - valid = all([valid4D, valid3D, valid2D]) - if valid: - return True - - if not valid4D: - txt = (f""" - You loaded 4D data, hence the number of frames MUST be - {T}
and the number of z-slices MUST be {Z}.

- What do you want to do? - """) - if not valid3D: - txt = (f""" - You loaded 3D data, hence either the number of frames or - the number of z-slices is {TorZ}.

- However, if the number of frames is greater than 1 then the
- number of z-slices MUST be 1, and vice-versa.

- What do you want to do? - """) - - if not valid2D: - txt = (f""" - You loaded 2D data, hence the number of frames MUST be 1 - and the number of z-slices MUST be 1.

- What do you want to do? - """) - - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph(txt) - - continueButton = widgets.okPushButton('Continue anyway') - correctButton = widgets.editPushButton('Let me correct') - - msg.warning( - self, 'Shape-metadata mismatch', txt, - buttonsTexts=(continueButton, correctButton) - ) - if msg.cancel or msg.clickedButton == correctButton: - return False - - return True - - def cancel_cb(self, event): - self.cancel = True - self.close() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show() - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - -class QCropZtool(QBaseDialog): - sigClose = Signal() - sigZvalueChanged = Signal(str, int) - sigReset = Signal() - sigCrop = Signal(int, int) - - def __init__( - self, SizeZ, cropButtonText='Apply crop', parent=None, - addDoNotShowAgain=False, title='Select z-slices' - ): - super().__init__(parent) - - self.cancel = True - - self.setWindowFlags(Qt.Tool | Qt.WindowStaysOnTopHint) - - self.SizeZ = SizeZ - self.numDigits = len(str(self.SizeZ)) - - self.setWindowTitle(title) - - layout = QGridLayout() - buttonsLayout = QHBoxLayout() - - self.lowerZscrollbar = widgets.ScrollBarWithNumericControl() - self.lowerZscrollbar.setMaximum(SizeZ) - self.lowerZscrollbar.setMinimum(1) - self.lowerZscrollbar.setValue(1) - - self.upperZscrollbar = widgets.ScrollBarWithNumericControl() - self.upperZscrollbar.setMaximum(SizeZ) - self.upperZscrollbar.setValue(SizeZ) - - cancelButton = widgets.cancelPushButton('Cancel') - cropButton = widgets.okPushButton(cropButtonText) - buttonsLayout.addWidget(cropButton) - buttonsLayout.addWidget(cancelButton) - - row = 0 - layout.addWidget( - QLabel('Lower z-slice '), row, 0, alignment=Qt.AlignRight - ) - layout.addWidget(self.lowerZscrollbar, row, 1) - - row += 1 - layout.setRowStretch(row, 5) - - row += 1 - layout.addWidget( - QLabel('Upper z-slice '), row, 0, alignment=Qt.AlignRight - ) - layout.addWidget(self.upperZscrollbar, row, 1) - - row += 1 - if addDoNotShowAgain: - self.doNotShowAgainCheckbox = QCheckBox('Do not ask again') - layout.addWidget( - self.doNotShowAgainCheckbox, row, 1, alignment=Qt.AlignLeft - ) - row += 1 - - layout.addLayout(buttonsLayout, row, 1, alignment=Qt.AlignRight) - - layout.setColumnStretch(0, 0) - layout.setColumnStretch(1, 10) - - self.setLayout(layout) - - # resetButton.clicked.connect(self.emitReset) - cropButton.clicked.connect(self.emitCrop) - cancelButton.clicked.connect(self.close) - self.lowerZscrollbar.sigValueChanged.connect(self.ZvalueChanged) - self.upperZscrollbar.sigValueChanged.connect(self.ZvalueChanged) - - def emitReset(self): - self.sigReset.emit() - - def emitCrop(self): - self.cancel = False - low_z = self.lowerZscrollbar.value() - 1 - high_z = self.upperZscrollbar.value() - 1 - self.sigCrop.emit(low_z, high_z) - self.close() - - def updateScrollbars(self, lower_z, upper_z): - self.lowerZscrollbar.setValue(lower_z+1) - self.upperZscrollbar.setValue(upper_z+1) - - def ZvalueChanged(self, value): - which = 'lower' if self.sender() == self.lowerZscrollbar else 'upper' - if which == 'lower' and value > self.upperZscrollbar.value()-1: - self.lowerZscrollbar.setValue(self.upperZscrollbar.value()-1) - return - if which == 'upper' and value < self.lowerZscrollbar.value()+1: - self.upperZscrollbar.setValue(self.lowerZscrollbar.value()+1) - return - - z_slice_n = value - 1 - self.sigZvalueChanged.emit(which, z_slice_n) - - def showEvent(self, event): - self.resize(int(self.width()*1.5), self.height()) - - def closeEvent(self, event): - super().closeEvent(event) - self.sigClose.emit() - -class randomWalkerDialog(QDialog): - def __init__(self, mainWindow): - super().__init__(mainWindow) - self.cancel = True - self.mainWindow = mainWindow - - if mainWindow is not None: - posData = self.mainWindow.data[self.mainWindow.pos_i] - items = [posData.filename] - else: - items = ['test'] - try: - posData = self.mainWindow.data[self.mainWindow.pos_i] - items.extend(list(posData.ol_data_dict.keys())) - except Exception as e: - pass - - self.keys = items - - self.setWindowTitle('Random walker segmentation') - - self.colors = [self.mainWindow.RWbkgrColor, - self.mainWindow.RWforegrColor] - - mainLayout = QVBoxLayout() - paramsLayout = QGridLayout() - buttonsLayout = QHBoxLayout() - - self.mainWindow.clearAllItems() - - row = 0 - paramsLayout.addWidget(QLabel('Background threshold:'), row, 0) - row += 1 - self.bkgrThreshValLabel = QLabel('0.05') - paramsLayout.addWidget(self.bkgrThreshValLabel, row, 1) - self.bkgrThreshSlider = QSlider(Qt.Horizontal) - self.bkgrThreshSlider.setMinimum(1) - self.bkgrThreshSlider.setMaximum(100) - self.bkgrThreshSlider.setValue(5) - self.bkgrThreshSlider.setTickPosition(QSlider.TickPosition.TicksBelow) - self.bkgrThreshSlider.setTickInterval(10) - paramsLayout.addWidget(self.bkgrThreshSlider, row, 0) - - row += 1 - foregrQSLabel = QLabel('Foreground threshold:') - # padding: top, left, bottom, right - foregrQSLabel.setStyleSheet("font-size:13px; padding:5px 0px 0px 0px;") - paramsLayout.addWidget(foregrQSLabel, row, 0) - row += 1 - self.foregrThreshValLabel = QLabel('0.95') - paramsLayout.addWidget(self.foregrThreshValLabel, row, 1) - self.foregrThreshSlider = QSlider(Qt.Horizontal) - self.foregrThreshSlider.setMinimum(1) - self.foregrThreshSlider.setMaximum(100) - self.foregrThreshSlider.setValue(95) - self.foregrThreshSlider.setTickPosition(QSlider.TickPosition.TicksBelow) - self.foregrThreshSlider.setTickInterval(10) - paramsLayout.addWidget(self.foregrThreshSlider, row, 0) - - # Parameters link label - row += 1 - url1 = 'https://scikit-image.org/docs/dev/auto_examples/segmentation/plot_random_walker_segmentation.html' - url2 = 'https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.random_walker' - htmlTxt1 = f'here' - htmlTxt2 = f'here' - seeHereLabel = QLabel() - seeHereLabel.setText(f'See {htmlTxt1} and {htmlTxt2} for details ' - 'about Random walker segmentation.') - seeHereLabel.setTextFormat(Qt.RichText) - seeHereLabel.setTextInteractionFlags(Qt.TextBrowserInteraction) - seeHereLabel.setOpenExternalLinks(True) - font = QFont() - font.setPixelSize(12) - seeHereLabel.setFont(font) - seeHereLabel.setStyleSheet("padding:12px 0px 0px 0px;") - paramsLayout.addWidget(seeHereLabel, row, 0, 1, 2) - - computeButton = QPushButton('Compute segmentation') - closeButton = QPushButton('Close') - - buttonsLayout.addWidget(computeButton, alignment=Qt.AlignRight) - buttonsLayout.addWidget(closeButton, alignment=Qt.AlignLeft) - - paramsLayout.setContentsMargins(0, 10, 0, 0) - buttonsLayout.setContentsMargins(0, 10, 0, 0) - - mainLayout.addLayout(paramsLayout) - mainLayout.addLayout(buttonsLayout) - - self.bkgrThreshSlider.sliderMoved.connect(self.bkgrSliderMoved) - self.foregrThreshSlider.sliderMoved.connect(self.foregrSliderMoved) - computeButton.clicked.connect(self.computeSegmAndPlot) - closeButton.clicked.connect(self.close) - - self.setLayout(mainLayout) - - self.getImage() - self.plotMarkers() - - def getImage(self): - img = self.mainWindow.getDisplayedImg1() - self.img = img/img.max() - self.imgRGB = (skimage.color.gray2rgb(self.img)*255).astype(np.uint8) - - def setSize(self): - x = self.pos().x() - y = self.pos().y() - h = self.size().height() - w = self.size().width() - if w < 400: - w = 400 - self.setGeometry(x, y, w, h) - - def plotMarkers(self): - imgMin, imgMax = self.computeMarkers() - - img = self.img - - imgRGB = self.imgRGB.copy() - R, G, B = self.colors[0] - imgRGB[:, :, 0][img < imgMin] = R - imgRGB[:, :, 1][img < imgMin] = G - imgRGB[:, :, 2][img < imgMin] = B - R, G, B = self.colors[1] - imgRGB[:, :, 0][img > imgMax] = R - imgRGB[:, :, 1][img > imgMax] = G - imgRGB[:, :, 2][img > imgMax] = B - - self.mainWindow.img1.setImage(imgRGB) - - def computeMarkers(self): - bkgrThresh = self.bkgrThreshSlider.sliderPosition()/100 - foregrThresh = self.foregrThreshSlider.sliderPosition()/100 - img = self.img - self.markers = np.zeros(img.shape, np.uint8) - imgRange = img.max() - img.min() - imgMin = img.min() + imgRange*bkgrThresh - imgMax = img.min() + imgRange*foregrThresh - self.markers[img < imgMin] = 1 - self.markers[img > imgMax] = 2 - return imgMin, imgMax - - def computeSegm(self, checked=True): - self.mainWindow.storeUndoRedoStates(False) - self.mainWindow.titleLabel.setText( - 'Randomly walking around... ', color='w') - img = self.img - img = skimage.exposure.rescale_intensity(img) - t0 = time.time() - lab = skimage.segmentation.random_walker(img, self.markers, mode='bf') - lab = skimage.measure.label(lab>1) - t1 = time.time() - if len(np.unique(lab)) > 2: - lab = skimage.morphology.remove_small_objects(lab, min_size=5) - posData = self.mainWindow.data[self.mainWindow.pos_i] - posData.lab = lab - return t1-t0 - - def computeSegmAndPlot(self): - deltaT = self.computeSegm() - - posData = self.mainWindow.data[self.mainWindow.pos_i] - - self.mainWindow.update_rp() - self.mainWindow.tracking(enforce=True) - self.mainWindow.updateAllImages() - self.mainWindow.warnEditingWithCca_df('Random Walker segmentation') - txt = f'Random Walker segmentation computed in {deltaT:.3f} s' - print('-----------------') - print(txt) - print('=================') - # self.mainWindow.titleLabel.setText(txt, color='g') - - def bkgrSliderMoved(self, intVal): - self.bkgrThreshValLabel.setText(f'{intVal/100:.2f}') - self.plotMarkers() - - def foregrSliderMoved(self, intVal): - self.foregrThreshValLabel.setText(f'{intVal/100:.2f}') - self.plotMarkers() - - def closeEvent(self, event): - self.mainWindow.segmModel = '' - self.mainWindow.updateAllImages() - -class FutureFramesAction_QDialog(QDialog): - def __init__( - self, frame_i, last_tracked_i, change_txt, - applyTrackingB=False, parent=None, - addApplyAllButton=False - ): - self.decision = None - self.last_tracked_i = last_tracked_i - super().__init__(parent) - self.setWindowTitle('Future frames action?') - - mainLayout = QVBoxLayout() - txtLayout = QVBoxLayout() - doNotShowLayout = QVBoxLayout() - buttonsLayout = QVBoxLayout() - - txt = html_utils.paragraph( - 'You already visited/checked future frames ' - f'{frame_i+1}-{last_tracked_i+1}.

' - f'The requested "{change_txt}" change might result in
' - 'NON-correct segmentation/tracking for those frames.
' - ) - - txtLabel = QLabel(txt) - txtLabel.setAlignment(Qt.AlignCenter) - txtLayout.addWidget(txtLabel, alignment=Qt.AlignCenter) - - options = [ - f'Apply the "{change_txt}" only to current frame and re-initialize
' - 'the future frames to the segmentation file present
' - 'on the hard drive.', - 'Apply only to this frame and keep the future frames as they are.', - 'Apply the change to ALL visited/checked future frames.' - ] - if addApplyAllButton: - options.append('Apply to ALL future frames including unvisited ones.') - if applyTrackingB: - options.append('Repeat ONLY tracking for all future frames (RECOMMENDED)') - - infoTxt = html_utils.paragraph( - f'Choose one of the following options:' - f'{html_utils.to_list(options, ordered=True)}' - ) - - infotxtLabel = QLabel(infoTxt) - txtLayout.addWidget(infotxtLabel, alignment=Qt.AlignCenter) - - noteLayout = QHBoxLayout() - noteTxt = html_utils.paragraph( - 'Only changes applied to current frame can be undone.
' - 'Changes applied to future frames CANNOT be UNDONE
' - ) - noteLayout.addWidget( - QLabel(html_utils.paragraph('NOTE:')), alignment=Qt.AlignTop - ) - noteTxtLabel = QLabel(noteTxt) - noteLayout.addWidget(noteTxtLabel) - noteLayout.addStretch(1) - txtLayout.addSpacing(10) - txtLayout.addLayout(noteLayout) - - # Do not show this message again checkbox - doNotShowCheckbox = QCheckBox( - 'Remember my choice and do not show this message again') - doNotShowLayout.addWidget(doNotShowCheckbox) - doNotShowLayout.setContentsMargins(50, 0, 0, 10) - self.doNotShowCheckbox = doNotShowCheckbox - - apply_and_reinit_b = widgets.reloadPushButton( - ' 1. Apply only to this frame and re-initialize future frames' - ) - - self.apply_and_reinit_b = apply_and_reinit_b - buttonsLayout.addWidget(apply_and_reinit_b) - - apply_and_NOTreinit_b = widgets.currentPushButton( - ' 2. Apply only to this frame and keep future frames as they are' - ) - self.apply_and_NOTreinit_b = apply_and_NOTreinit_b - buttonsLayout.addWidget(apply_and_NOTreinit_b) - - apply_to_all_visited_b = widgets.futurePushButton( - ' 3. Apply to all future VISITED frames' - ) - self.apply_to_all_visited_b = apply_to_all_visited_b - buttonsLayout.addWidget(apply_to_all_visited_b) - - - if addApplyAllButton: - apply_to_all_b = QPushButton( - ' 4. Apply to ALL future frames (including unvisted)' - ) - apply_to_all_b.setIcon(QIcon(':arrow_future_all.svg')) - self.apply_to_all_b = apply_to_all_b - buttonsLayout.addWidget(apply_to_all_b) - - self.applyTrackingButton = None - if applyTrackingB: - n = '5' if addApplyAllButton else '4' - applyTrackingButton = QPushButton( - f' {n}. Repeat ONLY tracking for all future frames' - ) - applyTrackingButton.setIcon(QIcon(':repeat-tracking.svg')) - self.applyTrackingButton = applyTrackingButton - buttonsLayout.addWidget(applyTrackingButton) - - buttonsLayout.setContentsMargins(20, 0, 20, 0) - - self.formLayout = QFormLayout() - - ButtonsGroup = QButtonGroup(self) - ButtonsGroup.addButton(apply_and_reinit_b) - ButtonsGroup.addButton(apply_and_NOTreinit_b) - ButtonsGroup.addButton(apply_to_all_visited_b) - if addApplyAllButton: - ButtonsGroup.addButton(apply_to_all_b) - if applyTrackingB: - ButtonsGroup.addButton(applyTrackingButton) - - mainLayout.addLayout(txtLayout) - mainLayout.addLayout(doNotShowLayout) - mainLayout.addLayout(buttonsLayout) - mainLayout.addLayout(self.formLayout) - mainLayout.addStretch(1) - self.mainLayout = mainLayout - self.setLayout(mainLayout) - - # Connect events - ButtonsGroup.buttonClicked.connect(self.buttonClicked) - self.ButtonsGroup = ButtonsGroup - - # self.setModal(True) - - def buttonClicked(self, button): - if button == self.apply_and_reinit_b: - self.decision = 'apply_and_reinit' - self.endFrame_i = None - elif button == self.apply_and_NOTreinit_b: - self.decision = 'apply_and_NOTreinit' - self.endFrame_i = None - elif button == self.apply_to_all_visited_b: - self.decision = 'apply_to_all_visited' - self.endFrame_i = self.last_tracked_i - elif button == self.applyTrackingButton: - self.decision = 'only_tracking' - self.endFrame_i = self.last_tracked_i - elif button == self.apply_to_all_b: - self.decision = 'apply_to_all' - self.endFrame_i = self.last_tracked_i - self.close() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show() - for button in self.ButtonsGroup.buttons(): - button.setMinimumHeight(int(button.height()*1.2)) - if hasattr(self, 'apply_to_all_b'): - iconHeight = self.apply_to_all_b.iconSize().height() - self.apply_to_all_b.setIconSize(QSize(iconHeight*2, iconHeight)) - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - -class ComputeMetricsErrorsDialog(QBaseDialog): - def __init__( - self, errorsDict, log_path='', parent=None, - log_type='custom_metrics' - ): - super().__init__(parent) - - self.errorsDict = errorsDict - - layout = QGridLayout() - - self.setWindowTitle('Errors summary') - - label = QLabel(self) - standardIcon = getattr(QStyle, 'SP_MessageBoxWarning') - icon = self.style().standardIcon(standardIcon) - pixmap = icon.pixmap(60, 60) - label.setPixmap(pixmap) - layout.addWidget(label, 0, 0, alignment=Qt.AlignTop) - - if log_type == 'custom_metrics': - infoText = (""" - When computing custom metrics the following metrics - were ignored because they raised an error.

- """) - elif log_type == 'standard_metrics': - infoText = (""" - Some or all of the standard metrics were NOT saved - because Cell-ACDC encoutered the following errors.

- """) - elif log_type == 'region_props': - rp_url = 'https://scikit-image.org/docs/0.18.x/api/skimage.measure.html#skimage.measure.regionprops' - rp_href = f'skimage.measure.regionprops' - infoText = (f""" - Region properties were NOT saved because Cell-ACDC - encoutered the following errors.
- Region properties are calculated using the scikit-image - function called {rp_href}.

- """) - elif log_type == 'missing_annot': - infoText = (""" - The following Positions were SKIPPED because they did - not have cell cycle annotations.

- To add lineage tree information you first need to do the - cell cycle analysis in module 3 "Main GUI".

- """) - else: - infoText = (""" - Process raised the errors listed below.

- """) - - github_issues_href = f'here' - noteText = (f""" - NOTE: If you need help understanding these errors you can - open an issue on our github page {github_issues_href}. - """) - - infoLabel = QLabel(html_utils.paragraph(f'{infoText}{noteText}')) - infoLabel.setOpenExternalLinks(True) - layout.addWidget(infoLabel, 0, 1) - - scrollArea = QScrollArea() - scrollAreaWidget = QWidget() - textLayout = QVBoxLayout() - for func_name, traceback_format in errorsDict.items(): - nameLabel = QLabel(f'{func_name}: ') - errorMessage = f'\n{traceback_format}' - errorLabel = QLabel(errorMessage) - errorLabel.setTextInteractionFlags( - Qt.TextSelectableByMouse | Qt.TextSelectableByKeyboard - ) - # errorLabel.setStyleSheet("background-color: white") - errorLabel.setFrameShape(QFrame.Shape.Panel) - errorLabel.setFrameShadow(QFrame.Shadow.Sunken) - textLayout.addWidget(nameLabel) - textLayout.addWidget(errorLabel) - textLayout.addStretch(1) - - scrollAreaWidget.setLayout(textLayout) - scrollArea.setWidget(scrollAreaWidget) - - layout.addWidget(scrollArea, 1, 1) - - buttonsLayout = QHBoxLayout() - showLogButton = widgets.showInFileManagerButton('Show log file...') - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(showLogButton) - - copyButton = widgets.copyPushButton('Copy error message') - copyButton.clicked.connect(self.copyErrorMessage) - buttonsLayout.addWidget(copyButton) - self.copyButton = copyButton - self.copyButton.text = 'Copy error message' - self.copyButton.icon = self.copyButton.icon() - - okButton = widgets.okPushButton(' Ok ') - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(okButton) - - showLogButton.clicked.connect(partial(myutils.showInExplorer, log_path)) - okButton.clicked.connect(self.close) - layout.setVerticalSpacing(10) - layout.addLayout(buttonsLayout, 2, 1) - - self.setLayout(layout) - self.setFont(font) - - def copyErrorMessage(self): - cb = QApplication.clipboard() - cb.clear(mode=cb.Clipboard) - copiedText = '' - for _, traceback_format in self.errorsDict.items(): - errorBlock = f'{"="*30}\n{traceback_format}{"*"*30}' - copiedText = f'{copiedText}{errorBlock}' - cb.setText(copiedText, mode=cb.Clipboard) - print('Error message copied.') - self.copyButton.setIcon(QIcon(':okButton.svg')) - self.copyButton.setText(' Copied to clipboard!') - QTimer.singleShot(2000, self.restoreCopyButton) - - def restoreCopyButton(self): - self.copyButton.setText(self.copyButton.text) - self.copyButton.setIcon(self.copyButton.icon) - - def showEvent(self, a0) -> None: - self.copyButton.setFixedWidth(self.copyButton.width()) - return super().showEvent(a0) - -class PostProcessSegmParams(QGroupBox): - valueChanged = Signal(object) - editingFinished = Signal() - - def __init__( - self, title, posData, - useSliders=False, - parent=None, - maxSize=None, - force_postprocess_2D=False - ): - QGroupBox.__init__(self, title, parent) - SizeZ = posData.SizeZ - self.isSegm3D = posData.isSegm3D - self.channelName = posData.user_ch_name - self.useSliders = useSliders - self.force_postprocess_2D = force_postprocess_2D - if maxSize is None: - maxSize=2147483647 - - layout = QGridLayout() - - self.controlWidgets = [] - - row = 0 - label = QLabel("Minimum area (pixels) ") - layout.addWidget(label, row, 0, alignment=Qt.AlignRight) - - minSize_SB = widgets.PostProcessSegmWidget( - 1, 1000, 10, useSliders, label=label - ) - - txt = ( - 'Area is the total number of pixels in the segmented object.' - ) - - layout.addWidget(minSize_SB, row, 1) - infoButton = widgets.infoPushButton() - infoButton.clicked.connect(self.showInfo) - infoButton.tooltip = txt - infoButton.name = 'area' - infoButton.desc = f'less than "{label.text()}"' - layout.addWidget(infoButton, row, 2) - self.minSize_SB = minSize_SB - self.controlWidgets.append(minSize_SB) - - # minSize_SB.disableThisCheckbox = QCheckBox('Disable this filter') - # layout.addWidget(minSize_SB.disableThisCheckbox, row, 3) - - row += 1 - label = QLabel("Minimum solidity (0-1) ") - layout.addWidget(label, row, 0, alignment=Qt.AlignRight) - minSolidity_DSB = widgets.PostProcessSegmWidget( - 0, 1.0, 0.5, useSliders, isFloat=True, normalize=True, - label=label - ) - minSolidity_DSB.setValue(0.5) - minSolidity_DSB.setSingleStep(0.1) - self.controlWidgets.append(minSolidity_DSB) - - txt = ( - 'Solidity is a measure of convexity. A solidity of 1 means ' - 'that the shape is fully convex (i.e., equal to the convex hull). ' - 'As solidity approaches 0 the object is more concave.
' - 'Write 0 for ignoring this parameter.' - ) - - layout.addWidget(minSolidity_DSB, row, 1) - infoButton = widgets.infoPushButton() - infoButton.clicked.connect(self.showInfo) - infoButton.tooltip = txt - infoButton.name = 'solidity' - infoButton.desc = f'less than "{label.text()}"' - layout.addWidget(infoButton, row, 2) - self.minSolidity_DSB = minSolidity_DSB - - row += 1 - label = QLabel("Max elongation (1=circle) ") - layout.addWidget(label, row, 0, alignment=Qt.AlignRight) - maxElongation_DSB = widgets.PostProcessSegmWidget( - 0, 100, 3, useSliders, isFloat=True, normalize=False, - label=label - ) - maxElongation_DSB.setDecimals(1) - maxElongation_DSB.setSingleStep(1.0) - - txt = ( - 'Elongation is the ratio between major and minor axis lengths. ' - 'An elongation of 1 is like a circle.
' - 'Write 0 for ignoring this parameter.' - ) - - layout.addWidget(maxElongation_DSB, row, 1) - infoButton = widgets.infoPushButton() - infoButton.clicked.connect(self.showInfo) - infoButton.tooltip = txt - infoButton.name = 'elongation' - infoButton.desc = f'greater than "{label.text()}"' - layout.addWidget(infoButton, row, 2) - self.maxElongation_DSB = maxElongation_DSB - self.controlWidgets.append(maxElongation_DSB) - - if self.isSegm3D: - row += 1 - label = QLabel("Minimum number of z-slices ") - layout.addWidget(label, row, 0, alignment=Qt.AlignRight) - minObjSizeZ_SB = widgets.PostProcessSegmWidget( - 0, SizeZ, 3, useSliders, isFloat=False, normalize=False, - label=label - ) - - txt = ( - 'Minimum number of z-slices per object.' - ) - - layout.addWidget(minObjSizeZ_SB, row, 1) - infoButton = widgets.infoPushButton() - infoButton.clicked.connect(self.showInfo) - infoButton.tooltip = txt - infoButton.name = 'number of z-slices' - infoButton.desc = f'less than "{label.text()}"' - layout.addWidget(infoButton, row, 2) - self.minObjSizeZ_SB = minObjSizeZ_SB - self.controlWidgets.append(minObjSizeZ_SB) - else: - self.minObjSizeZ_SB = widgets.NoneWidget() - - row += 1 - addCustomFeatureLayout = QHBoxLayout() - self.addCustomFeaturesButton = widgets.setPushButton( - 'Select custom features for post-processing...', - ) - addCustomFeatureLayout.addWidget(self.addCustomFeaturesButton) - addCustomFeatureLayout.addStretch(1) - self.selectedFeaturesDialog = SelectFeaturesRangeDialog( - posData=posData, parent=self, - force_postprocess_2D=force_postprocess_2D - ) - self.selectedFeaturesDialog.hide() - self.addCustomFeaturesButton.clicked.connect( - self.selectedFeaturesDialog.show - ) - self.selectedFeaturesDialog.sigValueChanged.connect(self.onValueChanged) - - layout.addLayout(addCustomFeatureLayout, row, 0, 1, 2) - - layout.setColumnStretch(1, 2) - # layout.setRowStretch(row+1, 1) - - self.setLayout(layout) - - for widget in self.controlWidgets: - widget.valueChanged.connect(self.onValueChanged) - widget.editingFinished.connect(self.onEditingFinished) - - def selectedFeaturesRange(self): - return self.selectedFeaturesDialog.groupbox.selectedFeaturesRange() - - def groupedFeatures(self): - return self.selectedFeaturesDialog.groupbox.groupedFeatures() - - def restoreDefault(self): - self.minSolidity_DSB.setValue(0.5) - self.minSize_SB.setValue(10) - self.maxElongation_DSB.setValue(3) - self.minObjSizeZ_SB.setValue(3) - self.selectedFeaturesDialog.groupbox.resetFields() - - def restoreFromKwargs(self, kwargs): - for name, value in kwargs.items(): - if name == 'min_solidity': - self.minSolidity_DSB.setValue(value) - elif name == 'min_area': - self.minSize_SB.setValue(value) - elif name == 'max_elongation': - self.maxElongation_DSB.setValue(value) - elif name == 'min_obj_no_zslices': - self.minObjSizeZ_SB.setValue(value) - - def kwargs(self): - kwargs = { - 'min_solidity': self.minSolidity_DSB.value(), - 'min_area': self.minSize_SB.value(), - 'max_elongation': self.maxElongation_DSB.value(), - 'min_obj_no_zslices': self.minObjSizeZ_SB.value() - } - return kwargs - - def onValueChanged(self, value): - self.valueChanged.emit(value) - - def onEditingFinished(self): - self.editingFinished.emit() - - def showInfo(self): - title = f'{self.sender().text()} info' - tooltip = self.sender().tooltip - name = self.sender().name - desc = self.sender().desc - txt = (f""" - The post-processing step is applied to the output of the - segmentation model.

- During this step, Cell-ACDC will remove all the objects with {name} - {desc}.

- {tooltip} - """) - if self.isCheckable(): - note = f"""" - You can deactivate this step by un-checking the checkbox - called "Post-processing parameters". - """ - txt = f'{txt}{note}' - msg = widgets.myMessageBox(showCentered=False) - msg.information(self, title, html_utils.paragraph(txt)) - -class PostProcessSegmDialog(QBaseDialog): - sigClosed = Signal() - sigValueChanged = Signal(object, object) - sigEditingFinished = Signal() - sigApplyToAllFutureFrames = Signal(object, object, object) - - def __init__( - self, posData, - mainWin=None, - useSliders=True, - maxSize=None - ): - super().__init__(mainWin) - self.cancel = True - self.mainWin = mainWin - self.isTimelapse = False - self.isMultiPos = False - if mainWin is not None: - self.isMultiPos = len(self.mainWin.data) > 1 - self.isTimelapse = self.mainWin.data[self.mainWin.pos_i].SizeT > 1 - - self.setWindowTitle('Post-processing segmentation parameters') - self.setWindowFlags(Qt.Tool | Qt.WindowStaysOnTopHint) - - mainLayout = QVBoxLayout() - buttonsLayout = QHBoxLayout() - - self.postProcessGroupbox = PostProcessSegmParams( - 'Post-processing parameters', posData, - useSliders=useSliders, - maxSize=maxSize, - parent=mainWin - ) - - self.postProcessGroupbox.valueChanged.connect(self.valueChanged) - self.postProcessGroupbox.editingFinished.connect(self.onEditingFinished) - - if self.isTimelapse: - applyAllButton = widgets.futurePushButton( - 'Apply to all frames...' - ) - applyAllButton.clicked.connect(self.applyAll_cb) - applyButton = widgets.okPushButton( - 'Apply', isDefault=False - ) - applyButton.clicked.connect(self.apply_cb) - elif self.isMultiPos: - applyAllButton = widgets.futurePushButton( - 'Apply to all Positions...' - ) - applyAllButton.clicked.connect(self.applyAll_cb) - applyButton = widgets.okPushButton('Apply', isDefault=False) - applyButton.clicked.connect(self.apply_cb) - else: - applyAllButton = widgets.okPushButton('Apply', isDefault=False) - applyAllButton.clicked.connect(self.ok_cb) - applyButton = None - - cancelButton = widgets.cancelPushButton('Cancel') - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - if applyButton is not None: - buttonsLayout.addWidget(applyButton) - buttonsLayout.addWidget(applyAllButton) - - emitEditingFinishedButton = widgets.okPushButton() - buttonsLayout.addWidget(emitEditingFinishedButton) - emitEditingFinishedButton.hide() - buttonsLayout.setContentsMargins(0,10,0,0) - - mainLayout.addWidget(self.postProcessGroupbox) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - cancelButton.clicked.connect(self.cancel_cb) - - if mainWin is not None: - self.setPosData() - - def keyPressEvent(self, event) -> None: - return super().keyPressEvent(event) - - def setPosData(self): - if self.mainWin is None: - return - - self.mainWin.storeUndoRedoStates(False) - self.posData = self.mainWin.data[self.mainWin.pos_i] - # self.img.setCurrentPosIndex(self.pos_i) - # self.img.minMaxValuesMapper = self.mainWin.img1.minMaxValuesMapper - self.origLab = self.posData.lab.copy() - self.origRp = skimage.measure.regionprops(self.origLab) - self.origObjs = {obj.label:obj for obj in self.origRp} - - def valueChanged(self, value): - lab, delObjs = self.apply() - self.sigValueChanged.emit(lab, delObjs) - - def apply(self, origLab=None): - self.mainWin.warnEditingWithCca_df( - 'post-processing segmentation mask', update_images=False - ) - ccaAnnotRemoved = self.mainWin.removeCcaAnnotationsCurrentFrame() - if ccaAnnotRemoved: - self.mainWin.updateAllImages() - - - if origLab is None: - origLab = self.origLab.copy() - - lab, delIDs = core.post_process_segm( - origLab, return_delIDs=True, **self.postProcessGroupbox.kwargs() - ) - - if self.postProcessGroupbox.selectedFeaturesRange(): - lab, custom_delIDs = features.custom_post_process_segm( - self.posData, - self.postProcessGroupbox.groupedFeatures(), - lab, - self.posData.img_data[self.posData.frame_i], - self.posData.frame_i, - self.posData.filename, - self.posData.user_ch_name, - self.postProcessGroupbox.selectedFeaturesRange(), - return_delIDs=True - ) - delIDs.extend(custom_delIDs) - - delObjs = {delID:self.origObjs[delID] for delID in delIDs} - return lab, delObjs - - def onEditingFinished(self): - self.sigEditingFinished.emit() - - def ok_cb(self): - self.cancel = False - self.apply() - self.onEditingFinished() - self.close() - - def apply_cb(self): - self.cancel = False - self.apply() - self.onEditingFinished() - - def applyAll_cb(self): - self.cancel = False - self.sigApplyToAllFutureFrames.emit( - self.postProcessGroupbox.kwargs(), - self.postProcessGroupbox.groupedFeatures(), - self.postProcessGroupbox.selectedFeaturesRange() - ) - self.close() - - def cancel_cb(self): - self.cancel = True - self.close() - - def undoChanges(self): - if self.mainWin is not None: - self.posData.lab = self.origLab - self.mainWin.update_rp() - self.mainWin.updateAllImages() - - # Undo if changes were applied to all future frames - if hasattr(self, 'origSegmData'): - if self.isTimelapse: - current_frame_i = self.posData.frame_i - for frame_i in range(self.posData.segmSizeT): - self.posData.frame_i = frame_i - origLab = self.origSegmData[frame_i] - lab = self.posData.allData_li[frame_i]['labels'] - if lab is None: - # Non-visited frame modify segm_data - self.posData.segm_data[frame_i] = origLab - else: - self.posData.allData_li[frame_i]['labels'] = origLab.copy() - self.posData.lab = origLab.copy() - self.mainWin.update_rp() - # Get the rest of the stored metadata based on the new lab - self.mainWin.get_data() - self.mainWin.store_data() - # Back to current frame - self.posData.frame_i = current_frame_i - self.mainWin.get_data() - self.mainWin.updateAllImages() - elif self.isMultiPos: - current_pos_i = self.mainWin.pos_i - # Apply to all future frames or future positions - for pos_i, posData in enumerate(self.mainWin.data): - self.mainWin.pos_i = pos_i - origLab = self.origSegmData[pos_i] - self.posData.allData_li[0]['labels'] = lab.copy() - # Get the rest of the stored metadata based on the new lab - self.mainWin.get_data() - self.mainWin.store_data() - # Back to current pos and current frame - self.mainWin.pos_i = current_pos_i - self.mainWin.get_data() - self.mainWin.updateAllImages() - - def show(self, block=False): - # self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show(block=False) - self.resize(int(self.width()*1.5), self.height()) - super().show(block=block) - - def closeEvent(self, event): - self.sigClosed.emit() - if self.cancel: - self.undoChanges() - super().closeEvent(event) - -class imageViewer(QMainWindow): - """Main Window.""" - sigClosed = Signal() - sigHoveringImage = Signal(object, object) - - def __init__( - self, parent=None, posData=None, button_toUncheck=None, - spinBox=None, linkWindow=None, enableOverlay=False, - isSigleFrame=False, enableMirroredCursor=False - ): - self.button_toUncheck = button_toUncheck - self.parent = parent - self.posData = posData - self.spinBox = spinBox - self.linkWindow = linkWindow - self.enableMirroredCursor = enableMirroredCursor - self.isSigleFrame = isSigleFrame - self.minMaxValuesMapper = None - """Initializer.""" - super().__init__(parent) - - if posData is None: - posData = self.parent.data[self.parent.pos_i] - self.posData = posData - self.enableOverlay = enableOverlay - - self.gui_createActions() - self.gui_createMenuBar() - self.gui_createToolBars() - - self.gui_createStatusBar() - - self.gui_createGraphics() - - self.gui_connectImgActions() - - self.gui_createImgWidgets() - self.gui_connectActions() - - self.gui_setSingleFrameMode(self.isSigleFrame) - - self.setupMirroredCursor() - - mainContainer = QWidget() - self.setCentralWidget(mainContainer) - - mainLayout = QGridLayout() - mainLayout.addWidget(self.graphLayout, 0, 0, 1, 1) - mainLayout.addLayout(self.img_Widglayout, 1, 0) - - mainContainer.setLayout(mainLayout) - - self.frame_i = posData.frame_i - self.num_frames = posData.SizeT - - version = myutils.read_version() - self.setWindowTitle(f"Cell-ACDC v{version} - {posData.relPath}") - - def gui_createActions(self): - # File actions - self.exitAction = QAction("&Exit", self) - - # Toolbar actions - self.prevAction = QAction("Previous frame", self) - self.nextAction = QAction("Next Frame", self) - self.jumpForwardAction = QAction("Jump to 10 frames ahead", self) - self.jumpBackwardAction = QAction("Jump to 10 frames back", self) - self.prevAction.setShortcut("left") - self.nextAction.setShortcut("right") - self.jumpForwardAction.setShortcut("up") - self.jumpBackwardAction.setShortcut("down") - self.addAction(self.nextAction) - self.addAction(self.prevAction) - self.addAction(self.jumpBackwardAction) - self.addAction(self.jumpForwardAction) - if self.enableOverlay: - self.overlayButton = widgets.rightClickToolButton(parent=self) - self.overlayButton.setIcon(QIcon(":overlay.svg")) - self.overlayButton.setCheckable(True) - - def gui_createMenuBar(self): - menuBar = self.menuBar() - # File menu - fileMenu = QMenu("&File", self) - menuBar.addMenu(fileMenu) - # fileMenu.addAction(self.newAction) - fileMenu.addAction(self.exitAction) - - def gui_createToolBars(self): - toolbarSize = 30 - - editToolBar = QToolBar("Edit", self) - editToolBar.setIconSize(QSize(toolbarSize, toolbarSize)) - self.addToolBar(editToolBar) - - self.editToolBar = editToolBar - - if self.enableOverlay: - editToolBar.addWidget(self.overlayButton) - - if self.linkWindow: - # Insert a spacing - editToolBar.addWidget(QLabel(' ')) - self.linkWindowCheckbox = QCheckBox("Link to main GUI") - self.linkWindowCheckbox.setChecked(True) - editToolBar.addWidget(self.linkWindowCheckbox) - - if self.enableMirroredCursor: - self.showMirroredCursorCheckbox = QCheckBox( - 'Show mirrored cursor from main window' - ) - self.showMirroredCursorCheckbox.setChecked(True) - editToolBar.addWidget(self.showMirroredCursorCheckbox) - - def setupMirroredCursor(self): - self.cursor = pg.ScatterPlotItem( - symbol='+', pxMode=True, pen=pg.mkPen('k', width=1), - brush=pg.mkBrush('w'), size=16, tip=None - ) - self.Plot.addItem(self.cursor) - - def gui_connectActions(self): - self.exitAction.triggered.connect(self.close) - self.prevAction.triggered.connect(self.prev_frame) - self.nextAction.triggered.connect(self.next_frame) - self.jumpForwardAction.triggered.connect(self.skip10ahead_frames) - self.jumpBackwardAction.triggered.connect(self.skip10back_frames) - if self.enableOverlay: - self.overlayButton.toggled.connect(self.overlay_cb) - self.overlayButton.sigRightClick.connect(self.showOverlayContextMenu) - - def gui_setSingleFrameMode(self, isSingleFrame: bool): - if not isSingleFrame: - return - - self.framesScrollBar.setDisabled(True) - self.framesScrollBar.setVisible(False) - self.frameLabel.hide() - self.t_label.hide() - self.prevAction.triggered.disconnect() - self.nextAction.triggered.disconnect() - self.jumpForwardAction.triggered.disconnect() - self.jumpBackwardAction.triggered.disconnect() - self.editToolBar.setVisible(False) - - def showOverlayContextMenu(self, event): - if not self.overlayButton.isChecked(): - return - - if self.parent is not None: - self.overlayContextMenu.exec_(QCursor.pos()) - - def gui_createStatusBar(self): - self.statusbar = self.statusBar() - # Temporary message - self.statusbar.showMessage("Ready", 3000) - # Permanent widget - self.wcLabel = QLabel(f"") - self.statusbar.addPermanentWidget(self.wcLabel) - - def gui_createGraphics(self): - self.graphLayout = pg.GraphicsLayoutWidget() - - # Plot Item container for image - self.Plot = pg.PlotItem() - self.Plot.invertY(True) - self.Plot.setAspectLocked(True) - self.Plot.hideAxis('bottom') - self.Plot.hideAxis('left') - self.graphLayout.addItem(self.Plot, row=1, col=1) - - # Image Item - self.img = widgets.BaseImageItem() - self.img.setEnableAutoLevels(True) - self.Plot.addItem(self.img) - - #Image histogram - self.imgGrad = widgets.myHistogramLUTitem(isViewer=True) - self.imgGrad.gradient.showMenu = self.showLutItemOverlayContextMenu - self.imgGrad.vb.raiseContextMenu = lambda x: None - self.imgGrad.setImageItem(self.img) - self.graphLayout.addItem(self.imgGrad, row=1, col=0) - - # Current frame text - self.frameLabel = pg.LabelItem(justify='center', color='w', size='14pt') - self.frameLabel.setText(' ') - self.graphLayout.addItem(self.frameLabel, row=2, col=0, colspan=2) - - if not self.enableOverlay: - return - - def gui_createOverlayItems(self): - self.createOverlayChannelsActions() - self.overlayLayersItems = {} - for ch in self.posData.chNames: - if ch == self.parent.user_ch_name: - continue - overlayItems = self.getOverlayItems(ch) - imageItem, lutItem, alphaScrollbar = overlayItems - lutItem.vb.raiseContextMenu = lambda x: None - lutItem.gradient.showMenu = self.showLutItemOverlayContextMenu - lutItem.overlayColorButton.sigColorChanging.connect( - self.updateOlColors - ) - self.addAlphaScrollbar(ch, imageItem, alphaScrollbar) - self.overlayLayersItems[ch] = overlayItems - self.Plot.addItem(imageItem) - - def createOverlayChannelsActions(self): - self.overlayLutItemAdditionalActions = [] - separator = QAction(self) - separator.setSeparator(True) - self.overlayLutItemAdditionalActions.append(separator) - section = self.imgGrad.gradient.menu.addSection( - 'Select channel to adjust: ' - ) - self.overlayLutItemAdditionalActions.append(section) - self.imgGrad.gradient.menu.removeAction(section) - - self.overlayChNamesActionGroup = QActionGroup(self) - self.overlayChNamesActionGroup.setExclusive(True) - for chName in self.posData.chNames: - action = QAction(chName, self) - action.setCheckable(True) - if chName == self.parent.user_ch_name: - action.setChecked(True) - self.overlayChNamesActionGroup.addAction(action) - self.overlayChNamesActionGroup.triggered.connect( - self.chNameGradientActionClicked - ) - - def chNameGradientActionClicked(self, action): - # Action triggered from lutItem - self.checkedOverlayChName = action.text() - if action.text() == self.posData.user_ch_name: - self.setOverlayItemsVisible('', False) - else: - self.setOverlayItemsVisible(action.text(), True) - - def showLutItemOverlayContextMenu(self, event): - lutItem = self.currentLutItem - - for action in self.overlayLutItemAdditionalActions: - try: - lutItem.gradient.menu.removeAction(action) - except Exception as e: - pass - - for action in self.overlayChNamesActionGroup.actions(): - try: - lutItem.gradient.menu.removeAction(action) - except Exception as e: - pass - - if self.overlayButton.isChecked(): - for action in self.overlayLutItemAdditionalActions: - lutItem.gradient.menu.addAction(action) - - for action in self.overlayChNamesActionGroup.actions(): - if action.text() == self.posData.user_ch_name: - lutItem.gradient.menu.addAction(action) - continue - for filename in self.posData.ol_data: - if filename.endswith(action.text()): - lutItem.gradient.menu.addAction(action) - break - if filename.endswith(f'{action.text()}_aligned'): - lutItem.gradient.menu.addAction(action) - break - - try: - # Convert QPointF to QPoint - lutItem.gradient.menu.popup(event.screenPos().toPoint()) - except AttributeError: - lutItem.gradient.menu.popup(event.screenPos()) - - - def gui_connectImgActions(self): - self.img.hoverEvent = self.gui_hoverEventImg - - def gui_createImgWidgets(self): - if self.posData is None: - posData = self.parent.data[self.parent.pos_i] - else: - posData = self.posData - self.img_Widglayout = QGridLayout() - - # Frames scrollbar - self.framesScrollBar = QScrollBar(Qt.Horizontal) - # self.framesScrollBar.setFixedHeight(20) - self.framesScrollBar.setMinimum(1) - self.framesScrollBar.setMaximum(posData.SizeT) - t_label = QLabel('frame ') - _font = QFont() - _font.setPixelSize(12) - t_label.setFont(_font) - self.img_Widglayout.addWidget( - t_label, 0, 0, alignment=Qt.AlignRight) - self.img_Widglayout.addWidget( - self.framesScrollBar, 0, 1, 1, 20) - self.t_label = t_label - self.framesScrollBar.valueChanged.connect(self.framesScrollBarMoved) - - # z-slice scrollbar - self.zSliceScrollBar = QScrollBar(Qt.Horizontal) - # self.zSliceScrollBar.setFixedHeight(20) - self.zSliceScrollBar.setMaximum(self.posData.SizeZ-1) - _z_label = QLabel('z-slice ') - _font = QFont() - _font.setPixelSize(12) - _z_label.setFont(_font) - self.z_label = _z_label - self.img_Widglayout.addWidget(_z_label, 1, 0, alignment=Qt.AlignCenter) - self.img_Widglayout.addWidget(self.zSliceScrollBar, 1, 1, 1, 20) - - if self.posData.SizeZ == 1: - self.zSliceScrollBar.setDisabled(True) - self.zSliceScrollBar.setVisible(False) - _z_label.setVisible(False) - - self.img_Widglayout.setContentsMargins(100, 0, 50, 0) - self.zSliceScrollBar.valueChanged.connect(self.update_z_slice) - - if self.enableOverlay: - self.setOverlayColors() - self.gui_createOverlayItems() - self.createOverlayContextMenu() - - self.img.alphaScrollbar = self.addAlphaScrollbar( - self.parent.user_ch_name, self.img - ) - - def getOverlayItems(self, channelName): - imageItem = pg.ImageItem() - imageItem.setOpacity(0.5) - - lutItem = widgets.myHistogramLUTitem(isViewer=True) - - lutItem.setImageItem(imageItem) - lutItem.vb.raiseContextMenu = lambda x: None - initColor = self.overlayRGBs.pop(0) - self.parent.initColormapOverlayLayerItem(initColor, lutItem) - lutItem.addOverlayColorButton(initColor, channelName) - lutItem.initColor = initColor - lutItem.hide() - - alphaScrollBar = self.addAlphaScrollbar(channelName, imageItem) - return imageItem, lutItem, alphaScrollBar - - def setMirroredCursorPos(self, x, y): - if not self.enableMirroredCursor: - return - - if not self.showMirroredCursorCheckbox.isChecked(): - return - - self.cursor.setData([x], [y]) - - def setOverlayColors(self): - self.overlayRGBs = [ - (255, 255, 0), - (252, 72, 254), - (49, 222, 134), - (22, 108, 27) - ] - cmap = matplotlib.colormaps['gist_rainbow'] - self.overlayRGBs.extend( - [tuple([round(c*255) for c in cmap(i)][:3]) - for i in np.linspace(0,1,8)] - ) - - def setOpacityOverlayLayersItems(self, value, imageItem=None): - if imageItem is None: - imageItem = self.sender().imageItem - alpha = value/self.sender().maximum() - else: - alpha = value - imageItem.setOpacity(alpha) - - def overlay_cb(self, checked): - if checked: - if self.posData.ol_data is None: - selectedChannels = self.askSelectOverlayChannel() - if selectedChannels is None: - self.overlayButton.toggled.disconnect() - self.overlayButton.setChecked(False) - self.overlayButton.toggled.connect(self.overlay_cb) - return - success = self.parent.loadOverlayData(selectedChannels) - if not success: - return False - lastChannel = selectedChannels[-1] - self.checkedOverlayChName = lastChannel - imageItem = self.overlayLayersItems[lastChannel][0] - self.setOpacityOverlayLayersItems(0.5, imageItem=imageItem) - self.img.setOpacity(0.5) - self.setCheckedOverlayContextMenusActions(selectedChannels) - else: - self.checkedOverlayChName = ( - self.parent.imgGrad.checkedChannelname - ) - selectedChannels = self.parent.checkedOverlayChannels - self.setCheckedOverlayContextMenusActions(selectedChannels) - self.setOverlayItemsVisible(self.checkedOverlayChName, True) - else: - self.img.setOpacity(1.0) - self.setOverlayItemsVisible('', False) - for items in self.overlayLayersItems.values(): - imageItem = items[0] - imageItem.clear() - self.update_img() - - def createOverlayContextMenu(self): - ch_names = [ - ch for ch in self.posData.chNames - if ch != self.posData.user_ch_name - ] - self.overlayContextMenu = QMenu() - self.overlayContextMenu.addSeparator() - self.checkedOverlayChannels = set() - for chName in ch_names: - action = QAction(chName, self.overlayContextMenu) - action.setCheckable(True) - action.toggled.connect(self.overlayChannelToggled) - self.overlayContextMenu.addAction(action) - - def setCheckedOverlayContextMenusActions(self, channelNames): - for action in self.overlayContextMenu.actions(): - if action.text() not in channelNames: - continue - action.setChecked(True) - self.checkedOverlayChannels.add(action.text()) - - def overlayChannelToggled(self, checked): - # Action toggled from overlayButton context menu - channelName = self.sender().text() - if checked: - posData = self.posData - if channelName not in posData.loadedFluoChannels: - self.parent.loadOverlayData([channelName], addToExisting=True) - self.setOverlayItemsVisible(channelName, True) - self.checkedOverlayChannels.add(channelName) - self.updateOlColors(None) - else: - self.checkedOverlayChannels.remove(channelName) - imageItem = self.overlayLayersItems[channelName][0] - imageItem.clear() - try: - channelToShow = next(iter(self.checkedOverlayChannels)) - self.setOverlayItemsVisible(channelToShow, True) - except StopIteration: - self.setOverlayItemsVisible('', False) - self.update_img() - - def updateOlColors(self, button): - lutItem = self.overlayLayersItems[self.checkedOverlayChName][1] - rgb = lutItem.overlayColorButton.color().getRgb()[:3] - self.parent.initColormapOverlayLayerItem(rgb, lutItem) - lutItem.overlayColorButton.setColor(rgb) - - def addAlphaScrollbar(self, channelName, imageItem, alphaScrollBar=None): - if alphaScrollBar is None: - alphaScrollBar = QScrollBar(Qt.Horizontal) - label = QLabel(f'Alpha {channelName}') - label.setFont(font) - label.hide() - alphaScrollBar.imageItem = imageItem - alphaScrollBar.label = label - alphaScrollBar.setFixedHeight(self.parent.h) - alphaScrollBar.hide() - alphaScrollBar.setMinimum(0) - alphaScrollBar.setMaximum(40) - alphaScrollBar.setValue(20) - alphaScrollBar.setToolTip( - f'Control the alpha value of the overlaid channel {channelName}.\n' - 'alpha=0 results in NO overlay,\n' - 'alpha=1 results in only fluorescence data visible' - ) - self.img_Widglayout.addWidget( - alphaScrollBar.label, 2, 0, alignment=Qt.AlignRight - ) - self.img_Widglayout.addWidget(alphaScrollBar, 2, 1, 1, 20) - sp = alphaScrollBar.label.sizePolicy() - sp.setRetainSizeWhenHidden(True) - alphaScrollBar.label.setSizePolicy(sp) - - sp = alphaScrollBar.sizePolicy() - sp.setRetainSizeWhenHidden(True) - alphaScrollBar.setSizePolicy(sp) - - alphaScrollBar.valueChanged.connect(self.setOpacityOverlayLayersItems) - return alphaScrollBar - - def setOverlayItemsVisible(self, channelName, visible): - if visible: - self.imgGrad.hide() - self.img.alphaScrollbar.hide() - self.img.alphaScrollbar.label.hide() - try: - self.graphLayout.removeItem(self.imgGrad) - except Exception as e: - pass - itemsToShow = None - for name, items in self.overlayLayersItems.items(): - _, lutItem, alphaSB = items - if name == channelName: - itemsToShow = items - else: - lutItem.hide() - alphaSB.hide() - alphaSB.label.hide() - try: - self.graphLayout.removeItem(lutItem) - except Exception as e: - pass - - if itemsToShow is None: - self.graphLayout.addItem(self.imgGrad, row=1, col=0) - self.imgGrad.show() - self.currentLutItem = self.imgGrad - self.img.alphaScrollbar.show() - self.img.alphaScrollbar.label.show() - else: - _, lutItem, alphaSB = itemsToShow - lutItem.show() - alphaSB.show() - alphaSB.label.show() - self.currentLutItem = lutItem - self.graphLayout.addItem(lutItem, row=1, col=0) - else: - if self.overlayButton.isChecked(): - self.img.alphaScrollbar.show() - self.img.alphaScrollbar.label.show() - else: - self.img.alphaScrollbar.hide() - self.img.alphaScrollbar.label.hide() - for name, items in self.overlayLayersItems.items(): - _, lutItem, alphaSB = items - lutItem.hide() - alphaSB.hide() - alphaSB.label.hide() - try: - self.graphLayout.removeItem(lutItem) - except Exception as e: - pass - self.graphLayout.addItem(self.imgGrad, row=1, col=0) - self.imgGrad.show() - self.currentLutItem = self.imgGrad - - def framesScrollBarMoved(self, frame_n): - self.frame_i = frame_n-1 - self.t_label.setText( - f'frame n. {self.frame_i+1}/{self.num_frames}' - ) - if self.spinBox is not None: - self.spinBox.setValue(frame_n) - self.update_img() - - def gui_hoverEventImg(self, event): - # Update x, y, value label bottom right - try: - x, y = event.pos() - xdata, ydata = int(x), int(y) - _img = self.img.image - Y, X = _img.shape - if xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y: - val = _img[ydata, xdata] - self.wcLabel.setText(f'(x={x:.2f}, y={y:.2f}, value={val:.2f})') - else: - self.wcLabel.setText(f'') - except Exception as e: - self.wcLabel.setText(f'') - - emitHovering = ( - self.enableMirroredCursor and - self.showMirroredCursorCheckbox.isChecked() - ) - if emitHovering: - if event.isExit(): - x, y = None, None - else: - x, y = event.pos() - self.sigHoveringImage.emit(x, y) - self.cursor.setData([], []) - - def next_frame(self): - if self.frame_i < self.num_frames-1: - self.frame_i += 1 - else: - self.frame_i = 0 - self.update_img() - - def prev_frame(self): - if self.frame_i > 0: - self.frame_i -= 1 - else: - self.frame_i = self.num_frames-1 - self.update_img() - - def skip10ahead_frames(self): - if self.frame_i < self.num_frames-10: - self.frame_i += 10 - else: - self.frame_i = 0 - self.update_img() - - def skip10back_frames(self): - if self.frame_i > 9: - self.frame_i -= 10 - else: - self.frame_i = self.num_frames-1 - self.update_img() - - def update_z_slice(self, z): - if self.posData is None: - posData = self.parent.data[self.parent.pos_i] - else: - posData = self.posData - idx = (posData.filename, posData.frame_i) - posData.segmInfo_df.at[idx, 'z_slice_used_gui'] = z - - self.z_label.setText(f'z-slice {z+1:02}/{posData.SizeZ}') - self.img.setCurrentZsliceIndex(z) - self.update_img() - - def getImage(self): - posData = self.posData - frame_i = self.frame_i - if posData.SizeZ > 1: - idx = (posData.filename, frame_i) - z = posData.segmInfo_df.at[idx, 'z_slice_used_gui'] - zProjHow = posData.segmInfo_df.at[idx, 'which_z_proj_gui'] - img = posData.img_data[frame_i] - if zProjHow == 'single z-slice': - self.zSliceScrollBar.setSliderPosition(z) - self.z_label.setText(f'z-slice {z+1:02}/{posData.SizeZ}') - img = img[z].copy() - elif zProjHow == 'max z-projection': - img = img.max(axis=0).copy() - elif zProjHow == 'mean z-projection': - img = img.mean(axis=0).copy() - elif zProjHow == 'median z-proj.': - img = np.median(img, axis=0).copy() - else: - img = posData.img_data[frame_i].copy() - return img - - def update_img(self): - self.frameLabel.setText( - f'Current frame = {self.frame_i+1}/{self.num_frames}' - ) - if self.parent is None: - img = self.getImage() - else: - img = self.parent.getImage(frame_i=self.frame_i, raw=True) - - self.img.setCurrentFrameIndex(self.frame_i) - self.img.setImage(img) - self.framesScrollBar.setSliderPosition(self.frame_i+1) - - if not self.enableOverlay: - return - - if not self.overlayButton.isChecked(): - return - - self.setOverlayImages(frame_i=self.frame_i) - - def askSelectOverlayChannel(self): - ch_names = [ - ch for ch in self.posData.chNames - if ch != self.posData.user_ch_name - ] - selectFluo = widgets.QDialogListbox( - 'Select channel', - 'Select channel names to overlay:\n', - ch_names, multiSelection=True, parent=self - ) - selectFluo.exec_() - if selectFluo.cancel: - return - - return selectFluo.selectedItemsText - - def setOverlayImages(self, frame_i=None): - posData = self.posData - for filename in posData.ol_data: - chName = myutils.get_chname_from_basename( - filename, posData.basename, remove_ext=False - ) - if chName not in self.checkedOverlayChannels: - continue - - imageItem = self.overlayLayersItems[chName][0] - ol_img = self.parent.getOlImg(filename, frame_i=frame_i) - imageItem.setImage(ol_img) - - def closeEvent(self, event): - if self.button_toUncheck is not None: - self.button_toUncheck.setChecked(False) - self.sigClosed.emit() - - def show(self, left=None, top=None): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - QMainWindow.show(self) - try: - self.framesScrollBar.setFixedHeight(self.parent.h) - except Exception as e: - pass - try: - self.zSliceScrollBar.setFixedHeight(self.parent.h) - except Exception as e: - pass - - try: - self.img.alphaScrollbar.setFixedHeight(self.parent.h) - except Exception as e: - pass - if left is not None and top is not None: - self.setGeometry(left, top, 850, 800) - -class TreeSelectorDialog(QBaseDialog): - sigItemDoubleClicked = Signal(object) - - def __init__( - self, title='Tree selector', infoTxt='', parent=None, - multiSelection=True, widthFactor=None, heightFactor=None, - expandOnDoubleClick=False, isTopLevelSelectable=True, - allItemsExpanded=True, allowNoSelection=True - ): - super().__init__(parent) - - self.setWindowTitle(title) - - self.cancel = True - self.widthFactor = widthFactor - self.heightFactor = heightFactor - self.allItemsExpanded = allItemsExpanded - self.mainLayout = QVBoxLayout() - self._isTopLevelSelectable = isTopLevelSelectable - self.allowNoSelection = allowNoSelection - - if infoTxt: - self.mainLayout.addWidget(QLabel(html_utils.paragraph(infoTxt))) - - self.treeWidget = widgets.TreeWidget(multiSelection=multiSelection) - self.treeWidget.setExpandsOnDoubleClick(expandOnDoubleClick) - self.treeWidget.setHeaderHidden(True) - self.mainLayout.addWidget(self.treeWidget) - - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - self.mainLayout.addSpacing(20) - self.mainLayout.addLayout(buttonsLayout) - - self.buttonsLayout = buttonsLayout - - self.setLayout(self.mainLayout) - - self.treeWidget.itemClicked.connect(self.onItemClicked) - self.treeWidget.itemDoubleClicked.connect(self.onItemDoubleClicked) - - def onItemDoubleClicked(self, item): - self.sigItemDoubleClicked.emit(item) - - def onItemClicked(self, item): - if self._isTopLevelSelectable: - return - if item.parent() is None: - item.setSelected(False) - - def addTree(self, tree: dict): - for topLevel, children in tree.items(): - topLevelItem = widgets.TreeWidgetItem(self.treeWidget) - topLevelItem.setText(0, topLevel) - self.treeWidget.addTopLevelItem(topLevelItem) - childrenItems = [widgets.TreeWidgetItem([c]) for c in children] - topLevelItem.addChildren(childrenItems) - if not self.allItemsExpanded: - continue - topLevelItem.setExpanded(True) - - def resizeVertical(self): - if not self.isVisible(): - self.show() - - currentTreeWidgetHeight = self.treeWidget.height() - treeWidgetHeight = 0 - for i in range(self.treeWidget.topLevelItemCount()): - topLevelItem = self.treeWidget.topLevelItem(i) - rect = self.treeWidget.visualItemRect(topLevelItem) - treeWidgetHeight += rect.height() - for j in range(topLevelItem.childCount()): - childItem = topLevelItem.child(j) - rect = self.treeWidget.visualItemRect(childItem) - treeWidgetHeight += rect.height() - - deltaHeight = treeWidgetHeight - currentTreeWidgetHeight + 10 - self.resize(self.width(), self.height() + deltaHeight) - self.move(self.x(), 20) - - def setCurrentItem(self, itemText: dict): - if not itemText: - return - for i in range(self.treeWidget.topLevelItemCount()): - topLevelItem = self.treeWidget.topLevelItem(i) - topLevelName = topLevelItem.text(0) - childText = itemText.get(topLevelName) - if childText is None: - continue - for j in range(topLevelItem.childCount()): - childItem = topLevelItem.child(j) - childItemText = childItem.text(0) - if childItemText == childText: - childItem.setSelected(True) - topLevelItem.setExpanded(True) - self.treeWidget.scrollToItem(topLevelItem) - break - - def selectedItems(self): - self._selectedItems = {} - for i in range(self.treeWidget.topLevelItemCount()): - topLevelItem = self.treeWidget.topLevelItem(i) - topLevelName = topLevelItem.text(0) - for j in range(topLevelItem.childCount()): - childItem = topLevelItem.child(j) - if not childItem.isSelected(): - continue - if topLevelName not in self._selectedItems: - self._selectedItems[topLevelName] = [childItem.text(0)] - else: - self._selectedItems[topLevelName].append(childItem.text(0)) - return self._selectedItems - - def warnSelectionIsEmpty(self): - txt = html_utils.paragraph(""" - You did not select anything :(.

- Please press Cancel to exit without selecting items. - Thanks! - """) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'Selection is empty', txt) - - def ok_cb(self): - if not self.allowNoSelection and not self.selectedItems(): - self.warnSelectionIsEmpty() - return - self.cancel = False - self.close() - - def showEvent(self, event) -> None: - super().showEvent(event) - if self.widthFactor is not None: - self.resize(int(self.width()*self.widthFactor), self.height()) - if self.heightFactor is not None: - self.resize(self.width(), int(self.height()*self.heightFactor)) - -class TreesSelectorDialog(QBaseDialog): - def __init__( - self, trees, groupsDescr=None, title='Trees selector', - infoTxt='', parent=None - ): - super().__init__(parent) - - self.setWindowTitle(title) - - self.cancel = True - self.mainLayout = QVBoxLayout() - - if infoTxt: - self.mainLayout.addWidget(QLabel(html_utils.paragraph(infoTxt))) - - self.treeWidgets = {} - self.setLayout(self.mainLayout) - - createdGroupLayouts = {} - for treeName, tree in trees.items(): - if groupsDescr is None: - groupName = '' - else: - groupName = groupsDescr.get(treeName, 'Group info missing') - groupLayout = createdGroupLayouts.get(groupName, None) - if groupLayout is None: - self.mainLayout.addWidget(QLabel(html_utils.paragraph(groupName))) - groupBox = QGroupBox() - self.mainLayout.addWidget(groupBox) - groupLayout = QVBoxLayout() - groupBox.setLayout(groupLayout) - createdGroupLayouts[groupName] = groupLayout - else: - groupLayout.addSpacing(10) - groupLayout.addWidget(QLabel(html_utils.paragraph(treeName))) - treeWidget = widgets.TreeWidget(multiSelection=True) - treeWidget.setHeaderHidden(True) - for topLevel, children in tree.items(): - topLevelItem = widgets.TreeWidgetItem(treeWidget) - topLevelItem.setText(0, topLevel) - treeWidget.addTopLevelItem(topLevelItem) - childrenItems = [widgets.TreeWidgetItem([c]) for c in children] - topLevelItem.addChildren(childrenItems) - topLevelItem.setExpanded(True) - self.treeWidgets[treeName] = treeWidget - groupLayout.addWidget(treeWidget) - self.mainLayout.addSpacing(20) - - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - self.mainLayout.addSpacing(10) - self.mainLayout.addLayout(buttonsLayout) - - def ok_cb(self): - self.cancel = False - self.selectedItems = {} - for treeName, treeWidget in self.treeWidgets.items(): - for i in range(treeWidget.topLevelItemCount()): - topLevelItem = treeWidget.topLevelItem(i) - for j in range(topLevelItem.childCount()): - childItem = topLevelItem.child(j) - if not childItem.isSelected(): - continue - if treeName not in self.selectedItems: - self.selectedItems[treeName] = [childItem.text(0)] - else: - self.selectedItems[treeName].append(childItem.text(0)) - self.close() - - -class MultiListSelector(QBaseDialog): - def __init__( - self, lists: dict, groupsDescr: dict=None, - title='Lists selector', infoTxt='', parent=None - ): - super().__init__(parent) - - self.setWindowTitle(title) - - self.cancel = True - mainLayout = QVBoxLayout() - - if infoTxt: - mainLayout.addWidget(QLabel(html_utils.paragraph(infoTxt))) - - self.listWidgets = {} - createdGroupLayouts = {} - for listName, listItems in lists.items(): - if groupsDescr is None: - groupName = '' - else: - groupName = groupsDescr.get(listName, 'Group info missing') - groupLayout = createdGroupLayouts.get(listName, None) - if groupLayout is None: - mainLayout.addWidget(QLabel(html_utils.paragraph(groupName))) - groupBox = QGroupBox() - mainLayout.addWidget(groupBox) - groupLayout = QVBoxLayout() - groupBox.setLayout(groupLayout) - createdGroupLayouts[groupName] = groupLayout - else: - groupLayout.addSpacing(10) - groupLayout.addWidget(QLabel(html_utils.paragraph(listName))) - listWidget = widgets.listWidget() - listWidget.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) - listWidget.addItems(listItems) - groupLayout.addWidget(listWidget) - mainLayout.addSpacing(20) - self.listWidgets[listName] = listWidget - - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addSpacing(10) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - def ok_cb(self): - self.cancel = False - self.selectedItems = {} - for listName, listWidget in self.listWidgets.items(): - if not listWidget.selectedItems(): - continue - self.selectedItems[listName] = [ - item.text() for item in listWidget.selectedItems() - ] - self.close() - -class selectPositionsMultiExp(QBaseDialog): - def __init__(self, expPaths: dict, infoPaths: dict=None, parent=None): - super().__init__(parent=parent) - - self.expPaths = expPaths - self.cancel = True - - mainLayout = QVBoxLayout() - - self.setWindowTitle('Select Positions to process') - - infoTxt = html_utils.paragraph( - 'Select one or more Positions to process

' - 'Click on experiment path to select all positions
' - 'Ctrl+Click to select multiple items
' - 'Shift+Click to select a range of items
', - center=True - ) - infoLabel = QLabel(infoTxt) - - self.treeWidget = QTreeWidget() - self.treeWidget.setSelectionMode( - QAbstractItemView.SelectionMode.ExtendedSelection - ) - self.treeWidget.setHeaderHidden(True) - self.treeWidget.setFont(font) - for exp_path, positions in expPaths.items(): - pathLevels = exp_path.split(os.sep) - posFoldersInfo = None - if infoPaths is not None: - posFoldersInfo = infoPaths.get(exp_path) - if len(pathLevels) > 4: - itemText = os.path.join(*pathLevels[-4:]) - itemText = f'...{itemText}' - else: - itemText = exp_path - exp_path_item = QTreeWidgetItem([itemText]) - exp_path_item.setToolTip(0, exp_path) - exp_path_item.full_path = exp_path - self.treeWidget.addTopLevelItem(exp_path_item) - postions_items = [] - for pos in positions: - if posFoldersInfo is not None: - status = posFoldersInfo.get(pos, '') - else: - status = '' - pos_item_text = f'{pos}{status}' - pos_item = QTreeWidgetItem(exp_path_item, [pos_item_text]) - pos_item.posFoldername = pos - postions_items.append(pos_item) - exp_path_item.addChildren(postions_items) - exp_path_item.setExpanded(True) - - self.treeWidget.itemClicked.connect(self.selectAllChildren) - - buttonsLayout = QHBoxLayout() - cancelButton = widgets.cancelPushButton('Cancel') - okButton = widgets.okPushButton(' Ok ') - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(okButton) - - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.close) - - mainLayout.addWidget(infoLabel, alignment=Qt.AlignCenter) - mainLayout.addWidget(self.treeWidget) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - self.setStyleSheet(TREEWIDGET_STYLESHEET) - - def selectAllChildren(self, item, col): - if item.parent() is not None: - return - - for i in range(item.childCount()): - item.child(i).setSelected(True) - - def ok_cb(self): - if not self.treeWidget.selectedItems(): - msg = widgets.myMessageBox(wrapText=False) - txt = 'You did not select any experiment/Position folder!' - msg.warning(self, 'Empty selection!', html_utils.paragraph(txt)) - return - - self.cancel = False - self.selectedPaths = {} - for item in self.treeWidget.selectedItems(): - if item.parent() is None: - continue - parent = item.parent() - exp_path = parent.full_path - pos_folder = item.posFoldername - if exp_path not in self.selectedPaths: - self.selectedPaths[exp_path] = [] - self.selectedPaths[exp_path].append(pos_folder) - - self.close() - - def showEvent(self, event): - self.resize(int(self.width()*2), self.height()) - - -class editCcaTableWidget(QDialog): - sigApplyChangesFutureFrames = Signal(object, int) - - def __init__( - self, cca_df, SizeT, title='Edit cell cycle annotations', - parent=None, current_frame_i=0 - ): - self.inputCca_df = cca_df - self.cancel = True - self.SizeT = SizeT - self.cca_df = None - self.current_frame_i = current_frame_i - - super().__init__(parent) - self.setWindowTitle(title) - - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - - # Layouts - mainLayout = QVBoxLayout() - headerLayout = QGridLayout() - tableLayout = QGridLayout() - buttonsLayout = QHBoxLayout() - self.scrollArea = QScrollArea() - self.viewBox = QWidget() - - # Header labels - col = 0 - row = 0 - IDsLabel = QLabel('Cell ID') - AC = Qt.AlignCenter - IDsLabel.setAlignment(AC) - headerLayout.addWidget(IDsLabel, 0, col, alignment=AC) - - col += 1 - ccsLabel = QLabel('Cell cycle stage') - ccsLabel.setAlignment(Qt.AlignCenter) - headerLayout.addWidget(ccsLabel, 0, col, alignment=AC) - - col += 1 - relIDLabel = QLabel('Relative ID') - relIDLabel.setAlignment(Qt.AlignCenter) - headerLayout.addWidget(relIDLabel, 0, col, alignment=AC) - - col += 1 - genNumLabel = QLabel('Generation number') - genNumLabel.setAlignment(Qt.AlignCenter) - headerLayout.addWidget(genNumLabel, 0, col, alignment=AC) - genNumColWidth = genNumLabel.sizeHint().width() - - col += 1 - relationshipLabel = QLabel('Relationship') - relationshipLabel.setAlignment(Qt.AlignCenter) - headerLayout.addWidget(relationshipLabel, 0, col, alignment=AC) - - col += 1 - emergFrameLabel = QLabel('Emerging frame num.') - emergFrameLabel.setAlignment(Qt.AlignCenter) - headerLayout.addWidget(emergFrameLabel, 0, col, alignment=AC) - - col += 1 - divitionFrameLabel = QLabel('Division frame num.') - divitionFrameLabel.setAlignment(Qt.AlignCenter) - headerLayout.addWidget(divitionFrameLabel, 0, col, alignment=AC) - - col += 1 - historyKnownLabel = QLabel('Is history known?') - historyKnownLabel.setAlignment(Qt.AlignCenter) - headerLayout.addWidget(historyKnownLabel, 0, col, alignment=AC) - - self.headerLayout = headerLayout - - tableLayout.setHorizontalSpacing(20) - self.tableLayout = tableLayout - - # Add buttons - cancelButton = widgets.cancelPushButton('Cancel') - moreInfoButton = widgets.helpPushButton('More info...') - moreInfoButton.setIcon(QIcon(':info.svg')) - applyToFutureFramesbutton = widgets.futurePushButton( - 'Apply changes to future frames...' - ) - okButton = widgets.okPushButton('Ok') - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(moreInfoButton) - buttonsLayout.addWidget(applyToFutureFramesbutton) - buttonsLayout.addWidget(okButton) - - # Scroll area properties - self.scrollArea.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded) - self.scrollArea.setHorizontalScrollBarPolicy(Qt.ScrollBarAsNeeded) - self.scrollArea.setFrameStyle(QFrame.Shape.NoFrame) - self.scrollArea.setWidgetResizable(True) - - # Add layouts - self.viewBox.setLayout(tableLayout) - self.scrollArea.setWidget(self.viewBox) - mainLayout.addLayout(headerLayout) - mainLayout.addWidget(self.scrollArea) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - # Populate table Layout - IDs = cca_df.index - self.IDs = IDs.to_list() - relIDsOptions = [str(ID) for ID in IDs] - relIDsOptions.insert(0, '-1') - self.IDlabels = [] - self.ccsComboBoxes = [] - self.genNumSpinBoxes = [] - self.relIDComboBoxes = [] - self.relationshipComboBoxes = [] - self.emergFrameSpinBoxes = [] - self.divisFrameSpinBoxes = [] - self.emergFrameSpinPrevValues = [] - self.divisFrameSpinPrevValues = [] - self.historyKnownCheckBoxes = [] - for row, ID in enumerate(IDs): - col = 0 - IDlabel = QLabel(f'{ID}') - IDlabel.setAlignment(Qt.AlignCenter) - tableLayout.addWidget(IDlabel, row+1, col, alignment=AC) - self.IDlabels.append(IDlabel) - - col += 1 - ccsComboBox = QComboBox() - ccsComboBox.setFocusPolicy(Qt.StrongFocus) - ccsComboBox.installEventFilter(self) - ccsComboBox.addItems(['G1', 'S/G2/M']) - ccsValue = cca_df.at[ID, 'cell_cycle_stage'] - if ccsValue == 'S': - ccsValue = 'S/G2/M' - - try: - ccsComboBox.setCurrentText(ccsValue) - except Exception as err: - printl(ccsValue) - printl(cca_df) - raise err - tableLayout.addWidget(ccsComboBox, row+1, col, alignment=AC) - self.ccsComboBoxes.append(ccsComboBox) - ccsComboBox.activated.connect(self.clearComboboxFocus) - - col += 1 - relIDComboBox = QComboBox() - relIDComboBox.setFocusPolicy(Qt.StrongFocus) - relIDComboBox.installEventFilter(self) - relIDComboBox.addItems(relIDsOptions) - relIDComboBox.setCurrentText(str(cca_df.at[ID, 'relative_ID'])) - tableLayout.addWidget(relIDComboBox, row+1, col) - self.relIDComboBoxes.append(relIDComboBox) - relIDComboBox.currentIndexChanged.connect(self.setRelID) - relIDComboBox.activated.connect(self.clearComboboxFocus) - - col += 1 - genNumSpinBox = widgets.SpinBox() - genNumSpinBox.setFocusPolicy(Qt.StrongFocus) - genNumSpinBox.installEventFilter(self) - genNumSpinBox.setValue(2) - genNumSpinBox.setMaximum(2147483647) - genNumSpinBox.setAlignment(Qt.AlignCenter) - genNumSpinBox.setFixedWidth(int(genNumColWidth*2/3)) - genNumSpinBox.setValue(int(cca_df.at[ID, 'generation_num'])) - tableLayout.addWidget(genNumSpinBox, row+1, col, alignment=AC) - self.genNumSpinBoxes.append(genNumSpinBox) - - col += 1 - relationshipComboBox = QComboBox() - relationshipComboBox.setFocusPolicy(Qt.StrongFocus) - relationshipComboBox.installEventFilter(self) - relationshipComboBox.addItems(['mother', 'bud']) - relationshipComboBox.setCurrentText( - str(cca_df.at[ID, 'relationship']) - ) - tableLayout.addWidget(relationshipComboBox, row+1, col) - self.relationshipComboBoxes.append(relationshipComboBox) - relationshipComboBox.currentIndexChanged.connect( - self.relationshipChanged_cb) - relationshipComboBox.activated.connect(self.clearComboboxFocus) - - col += 1 - emergFrameSpinBox = widgets.SpinBox() - emergFrameSpinBox.setFocusPolicy(Qt.StrongFocus) - emergFrameSpinBox.installEventFilter(self) - emergFrameSpinBox.setMaximum(SizeT) - emergFrameSpinBox.setMinimum(-1) - emergFrameSpinBox.setValue(-1) - emergFrameSpinBox.setAlignment(Qt.AlignCenter) - emergFrameSpinBox.setFixedWidth(int(genNumColWidth*2/3)) - emergFrame_i = cca_df.at[ID, 'emerg_frame_i'] - val = emergFrame_i+1 if emergFrame_i>=0 else -1 - emergFrameSpinBox.setValue(val) - tableLayout.addWidget(emergFrameSpinBox, row+1, col, alignment=AC) - self.emergFrameSpinBoxes.append(emergFrameSpinBox) - self.emergFrameSpinPrevValues.append(emergFrameSpinBox.value()) - emergFrameSpinBox.valueChanged.connect(self.skip0emergFrame) - - - col += 1 - divisFrameSpinBox = widgets.SpinBox() - divisFrameSpinBox.setFocusPolicy(Qt.StrongFocus) - divisFrameSpinBox.installEventFilter(self) - divisFrameSpinBox.setMinimum(-1) - divisFrameSpinBox.setMaximum(SizeT) - divisFrameSpinBox.setValue(-1) - divisFrameSpinBox.setAlignment(Qt.AlignCenter) - divisFrameSpinBox.setFixedWidth(int(genNumColWidth*2/3)) - divisFrame_i = int(cca_df.at[ID, 'division_frame_i']) - val = divisFrame_i+1 if divisFrame_i>=0 else -1 - divisFrameSpinBox.setValue(val) - tableLayout.addWidget(divisFrameSpinBox, row+1, col, alignment=AC) - self.divisFrameSpinBoxes.append(divisFrameSpinBox) - self.divisFrameSpinPrevValues.append(divisFrameSpinBox.value()) - divisFrameSpinBox.valueChanged.connect(self.skip0divisFrame) - - col += 1 - HistoryCheckBox = QCheckBox() - HistoryCheckBox.setChecked(bool(cca_df.at[ID, 'is_history_known'])) - tableLayout.addWidget(HistoryCheckBox, row+1, col, alignment=AC) - self.historyKnownCheckBoxes.append(HistoryCheckBox) - - self.setLayout(mainLayout) - - # Connect to events - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.cancel_cb) - moreInfoButton.clicked.connect(self.moreInfo) - applyToFutureFramesbutton.clicked.connect(self.applyToFutureFrames) - - # self.setModal(True) - - def getChanges(self): - newCcaDf = self.getCca_df() - changes = {} - for row in newCcaDf.itertuples(): - ID = row.Index - for col in newCcaDf.columns: - inputValue = self.inputCca_df.at[ID, col] - newValue = getattr(row, col) - if newValue == inputValue: - continue - - if ID not in changes: - changes[ID] = {col: (inputValue, newValue)} - else: - changes[ID][col] = (inputValue, newValue) - return changes - - def applyToFutureFrames(self): - txt = 'Enter up to which frame you want to apply the changes
' - win = NumericEntryDialog( - title='Stop frame', instructions=txt, parent=self, minValue=1, - maxValue=self.SizeT, currentValue=self.current_frame_i - ) - win.exec_() - if win.cancel: - return - - stop_frame_i = win.value - changes = self.getChanges() - changes_format = myutils.format_cca_manual_changes(changes) - detailsText = ( - f'Changes that will be applied from frame n. {self.current_frame_i+1}' - f' to frame n. {stop_frame_i+1}:\n\n{changes_format}' - ) - txt = html_utils.paragraph(""" -Use this feature with caution!

-Before propagating to future frames carefully inspect what changes -will be applied (see below).

-""") - msg = widgets.myMessageBox(wrapText=False) - msg.setDetailedText(detailsText, visible=True) - msg.warning( - self, 'Caution!', txt, buttonsTexts=('Yes, I am sure', 'Cancel') - ) - if msg.cancel: - return - - self.sigApplyChangesFutureFrames.emit(changes, stop_frame_i) - - def moreInfo(self, checked=True): - desc = myutils.get_cca_colname_desc() - msg = widgets.myMessageBox(parent=self) - msg.setWindowTitle('Cell cycle annotations info') - msg.setWidth(400) - msg.setIcon() - for col, txt in desc.items(): - msg.addText(html_utils.paragraph(f'{col}: {txt}')) - msg.addButton(' Ok ') - msg.exec_() - - def setRelID(self, itemIndex): - idx = self.relIDComboBoxes.index(self.sender()) - relID = self.sender().currentText() - IDofRelID = self.IDs[idx] - relIDidx = self.IDs.index(int(relID)) - relIDComboBox = self.relIDComboBoxes[relIDidx] - relIDComboBox.setCurrentText(str(IDofRelID)) - - def skip0emergFrame(self, value): - idx = self.emergFrameSpinBoxes.index(self.sender()) - prevVal = self.emergFrameSpinPrevValues[idx] - if value == 0 and value > prevVal: - self.sender().setValue(1) - self.emergFrameSpinPrevValues[idx] = 1 - elif value == 0 and value < prevVal: - self.sender().setValue(-1) - self.emergFrameSpinPrevValues[idx] = -1 - - def skip0divisFrame(self, value): - idx = self.divisFrameSpinBoxes.index(self.sender()) - prevVal = self.divisFrameSpinPrevValues[idx] - if value == 0 and value > prevVal: - self.sender().setValue(1) - self.divisFrameSpinPrevValues[idx] = 1 - elif value == 0 and value < prevVal: - self.sender().setValue(-1) - self.divisFrameSpinPrevValues[idx] = -1 - - def relationshipChanged_cb(self, itemIndex): - idx = self.relationshipComboBoxes.index(self.sender()) - ccs = self.sender().currentText() - if ccs == 'bud': - self.ccsComboBoxes[idx].setCurrentText('S/G2/M') - self.genNumSpinBoxes[idx].setValue(0) - - def getCca_df(self): - ccsValues = [var.currentText() for var in self.ccsComboBoxes] - ccsValues = [val if val=='G1' else 'S' for val in ccsValues] - genNumValues = [var.value() for var in self.genNumSpinBoxes] - relIDValues = [int(var.currentText()) for var in self.relIDComboBoxes] - relatValues = [var.currentText() for var in self.relationshipComboBoxes] - emergFrameValues = [ - var.value()-1 if var.value()>0 else -1 - for var in self.emergFrameSpinBoxes - ] - divisFrameValues = [ - var.value()-1 if var.value()>0 else -1 - for var in self.divisFrameSpinBoxes - ] - historyValues = [ - var.isChecked() for var in self.historyKnownCheckBoxes - ] - check_rel = [ID == relID for ID, relID in zip(self.IDs, relIDValues)] - - # Buds in S phase must have 0 as number of cycles - check_buds_S = [ - ccs=='S' and rel_ship=='bud' and not numc==0 - for ccs, rel_ship, numc - in zip(ccsValues, relatValues, genNumValues) - ] - - # Mother cells must have at least 1 as number of cycles if history known - check_mothers = [ - rel_ship=='mother' and not numc>=1 - if is_history_known else False - for rel_ship, numc, is_history_known - in zip(relatValues, genNumValues, historyValues) - ] - - # Buds cannot be in G1 - check_buds_G1 = [ - ccs=='G1' and rel_ship=='bud' for ccs, rel_ship - in zip(ccsValues, relatValues) - ] - - # The number of cells in S phase must be half mothers and half buds - num_moth_S = len([ - 0 for ccs, rel_ship in zip(ccsValues, relatValues) - if ccs=='S' and rel_ship=='mother' - ]) - num_bud_S = len([ - 0 for ccs, rel_ship in zip(ccsValues, relatValues) - if ccs=='S' and rel_ship=='bud' - ]) - - # Cells in S phase cannot have -1 as relative's ID - check_relID_S = [ - ccs=='S' and relID==-1 - for ccs, relID in zip(ccsValues, relIDValues) - ] - - # Mother cells with unknown history at emergence is recommended to have - # generation number = 2 (easier downstream analysis) - check_unknown_mothers = [ - rel_ship=='mother' and not is_history_known and gen_num!=2 - and (emerg_frame_i == self.current_frame_i or self.current_frame_i==0) - for rel_ship, is_history_known, gen_num, emerg_frame_i - in zip(relatValues, historyValues, genNumValues, emergFrameValues) - ] - - if any(check_rel): - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph(""" - Some cells are mother or bud of itself!

- Make sure that the relative ID is different from the Cell ID. - """) - msg.critical(self, 'Some IDs are equal to relative ID', txt) - return None - elif any(check_unknown_mothers): - txt = html_utils.paragraph(""" - We recommend to set generation number to 2 for mother cells - with unknown history
- that just appeared
(i.e., first cell cycle in the video).

- While it is allowed to insert any number, knowing that these - cells start at generation number 2
- makes downstream analysis easier.

- What do you want to do? - """) - correctButtonText = ' Fine, let me correct. ' - keepButtonText = ' Keep the generation number that I chose. ' - buttonsTexts = (correctButtonText, keepButtonText) - msg = widgets.myMessageBox(wrapText=False, showCentered=False) - msg.warning(self, 'Recommendation', txt, buttonsTexts=buttonsTexts) - if msg.cancel or msg.clickedButton == correctButtonText: - return None - elif any(check_buds_S): - msg = widgets.myMessageBox(wrapText=False) - title = ( - 'Bud in S/G2/M not in 0 Generation number' - ) - txt = html_utils.paragraph( - 'Some buds ' - 'in S phase do not have 0 as Generation number!
' - 'Buds in S phase must have 0 as "Generation number"' - ) - msg.critical(self, title, txt) - return None - elif any(check_mothers): - msg = widgets.myMessageBox(wrapText=False) - title = ( - 'Mother not in >=1 Generation number' - ) - txt = html_utils.paragraph( - 'Some mother cells do not have >=1 as "Generation number"!
' - 'Mothers MUST have >1 "Generation number"' - ) - msg.critical(self, title, txt) - return None - elif any(check_buds_G1): - msg = widgets.myMessageBox(wrapText=False) - title = ( - 'Buds in G1!' - ) - txt = html_utils.paragraph( - 'Some buds are in G1 phase!

' - 'Buds MUST be in S/G2/M phase' - ) - msg.critical(self, title, txt) - return None - elif num_moth_S != num_bud_S: - msg = widgets.myMessageBox(wrapText=False) - title = ( - 'Number of mothers-buds mismatch!' - ) - txt = html_utils.paragraph( - f'There are {num_moth_S} mother cells in "S/G2/M" phase,' - f'but there are {num_bud_S} bud cells.

' - 'The number of mothers and buds in "S/G2/M" ' - 'phase must be equal!' - ) - msg.critical(self, title, txt) - return None - elif any(check_relID_S): - msg = widgets.myMessageBox(wrapText=False) - title = ( - 'Relative\'s ID of cells in S/G2/M = -1' - ) - txt = html_utils.paragraph( - 'Some cells are in "S/G2/M" phase but have -1 as Relative\'s ID!
' - 'Cells in "S/G2/M" phase must have an existing ' - 'ID as Relative\'s ID!' - ) - msg.critical(self, title, txt) - return None - - corrected_on_frame_i = self.inputCca_df['corrected_on_frame_i'] - cca_df = pd.DataFrame({ - 'cell_cycle_stage': ccsValues, - 'generation_num': genNumValues, - 'relative_ID': relIDValues, - 'relationship': relatValues, - 'emerg_frame_i': emergFrameValues, - 'division_frame_i': divisFrameValues, - 'is_history_known': historyValues, - 'corrected_on_frame_i': corrected_on_frame_i, - 'will_divide': self.inputCca_df['will_divide'], - }, index=self.IDs - ) - cca_df.index.name = 'Cell_ID' - - # Add missing columns - for column, default in base_cca_dict.items(): - if column in cca_df.columns: - continue - - value = self.inputCca_df.get(column, default=default) - cca_df[column] = value - - # Check that every pair of cells in S are relative of each other - proceed = self.check_ID_rel_ID_mismatches(cca_df) - if not proceed: - return None - - d = dict.fromkeys(cca_df.select_dtypes(np.int64).columns, np.int32) - cca_df = cca_df.astype(d) - return cca_df - - def check_ID_rel_ID_mismatches(self, cca_df): - ID_rel_ID_mismatches = [] - for row in cca_df.itertuples(): - if row.cell_cycle_stage == 'G1': - continue - - ID = row.Index - relID = row.relative_ID - relID_of_relID = cca_df.at[relID, 'relative_ID'] - - if relID_of_relID != ID: - ID_rel_ID_mismatches.append((ID, relID, relID_of_relID)) - - if not ID_rel_ID_mismatches: - return True - - items = [ - f'Cell ID {ID} has relative ID = {relID}, ' - f'while cell ID {relID} has relative ID = {relID_of_relID}' - for ID, relID, relID_of_relID in ID_rel_ID_mismatches - ] - title = '`ID-relative_ID` mismatches' - txt = html_utils.paragraph( - f'`ID-relative_ID` mismatches:' - f'{html_utils.to_list(items)}' - ) - msg = widgets.myMessageBox(wrapText=False) - msg.critical(self, title, txt) - return False - - def ok_cb(self, checked): - cca_df = self.getCca_df() - if cca_df is None: - return - self.cca_df = cca_df - self.cancel = False - self.close() - - def cancel_cb(self, checked): - self.cancel = True - self.close() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show() - ncols = self.tableLayout.columnCount() - maxLabelWidth = max([ - self.headerLayout.itemAt(j).widget().sizeHint().width() - for j in range(ncols) - ]) - minWidth = (maxLabelWidth+5)*ncols - self.setMinimumWidth(minWidth) - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def eventFilter(self, object, event): - # Disable wheel scroll on widgets to allow scroll only on scrollarea - if event.type() == QEvent.Type.Wheel: - event.ignore() - return True - return False - - def clearComboboxFocus(self): - self.sender().clearFocus() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - -class askStopFrameSegm(QDialog): - def __init__( - self, user_ch_file_paths, user_ch_name, parent=None - ): - self.parent = parent - self.cancel = True - - super().__init__(parent) - self.setWindowTitle('Enter stop frame') - - self.visualizeWindows = [] - - mainLayout = QVBoxLayout() - buttonsLayout = QHBoxLayout() - - # Message - infoTxt = html_utils.paragraph(""" - Enter a stop frame number when to stop - segmentation for each Position loaded: - """) - infoLabel = QLabel(infoTxt, self) - infoLabel.setAlignment(Qt.AlignCenter) - # padding: top, left, bottom, right - infoLabel.setStyleSheet("padding:0px 0px 8px 0px;") - - self.dataDict = {} - - exp_path_pos_mapper = path.get_exp_path_pos_foldernames_mapper( - user_ch_file_paths - ) - - columnsLayout = QHBoxLayout() - mainScrollArea = widgets.ScrollArea() - mainScrollAreaWidget = QWidget() - mainScrollAreaWidget.setLayout(columnsLayout) - mainScrollArea.setWidget(mainScrollAreaWidget) - self.mainScrollArea = mainScrollArea - - # Form layout widget - self.spinBoxes = [] - self.tab_idx = 0 - iter_items = exp_path_pos_mapper.items() - self.groupboxScrollAreas = [] - - for col, (exp_path, pos_folders_files) in enumerate(iter_items): - groupboxScrollArea = widgets.ScrollArea() - self.groupboxScrollAreas.append(groupboxScrollArea) - groupbox = QGroupBox() - groupbox.setCheckable(False) - groupbox.setToolTip(exp_path) - groupboxLayout = QFormLayout() - groupbox.setLayout(groupboxLayout) - groupboxScrollArea.setWidget(groupbox) - columnsLayout.addWidget(groupboxScrollArea) - pos_folders = pos_folders_files['pos_foldernames'] - filenames = pos_folders_files['filenames'] - for i, pos_foldername in enumerate(pos_folders): - img_filename = filenames[i] - images_path = os.path.join(exp_path, pos_foldername, 'Images') - img_path = os.path.join(images_path, img_filename) - spinBox = widgets.mySpinBox() - spinBox.sigTabEvent.connect(self.keyTabEventSpinbox) - posData = load.loadData(img_path, user_ch_name, QParent=parent) - posData.getBasenameAndChNames(qparent=self) - posData.buildPaths() - posData.loadOtherFiles( - load_segm_data=False, - load_metadata=True, - loadSegmInfo=True, - ) - spinBox.setMaximum(posData.SizeT) - stopFrameNum = posData.readLastUsedStopFrameNumber() - if stopFrameNum is None: - spinBox.setValue(posData.SizeT) - else: - spinBox.setValue(stopFrameNum) - spinBox.setAlignment(Qt.AlignCenter) - visualizeButton = widgets.viewPushButton('Visualize') - visualizeButton.clicked.connect(self.visualize_cb) - formLabel = QLabel(html_utils.paragraph(f'{pos_foldername} ')) - layout = QHBoxLayout() - layout.addWidget(formLabel, alignment=Qt.AlignRight) - layout.addWidget(spinBox) - layout.addWidget(visualizeButton) - self.dataDict[visualizeButton] = (spinBox, posData) - groupboxLayout.addRow(layout) - spinBox.idx = i - self.spinBoxes.append(spinBox) - - fm = QFontMetrics(self.font()) - elidedTitle = fm.elidedText( - exp_path, Qt.ElideLeft, groupbox.sizeHint().width() - ) - groupbox.setTitle(elidedTitle) - - mainLayout.addWidget(infoLabel, alignment=Qt.AlignCenter) - mainLayout.addWidget(mainScrollArea) - - okButton = widgets.okPushButton('Ok') - okButton.setShortcut(Qt.Key_Enter) - - cancelButton = widgets.cancelPushButton('Cancel') - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(okButton) - - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.close) - - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - # # self.setModal(True) - - def keyTabEventSpinbox(self, event, sender): - self.tab_idx += 1 - if self.tab_idx >= len(self.spinBoxes): - self.tab_idx = 0 - focusSpinbox = self.spinBoxes[self.tab_idx] - focusSpinbox.setFocus() - - def saveStopFrameNumbers(self): - for spinBox, posData in self.dataDict.values(): - posData.metadata_df.at['stop_frame_num', 'values'] = spinBox.value() - posData.metadataToCsv() - - def ok_cb(self, event): - self.cancel = False - try: - self.saveStopFrameNumbers() - except Exception as err: - printl(traceback.format_exc()) - self.stopFrames = [ - spinBox.value() for spinBox, posData in self.dataDict.values() - ] - self.close() - - def closeEvent(self, event): - for window in self.visualizeWindows: - window.close() - - def visualize_cb(self, checked=True): - self.setDisabled(True) - spinBox, posData = self.dataDict[self.sender()] - print('Loading image data...') - posData.loadImgData() - posData.frame_i = spinBox.value()-1 - win = plot.imshow( - posData.img_data, - lut='gray', - figure_title=posData.relPath, - block=False - ) - self.visualizeWindows.append(win) - self.setDisabled(False) - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - screenSize = self.screen().size() - maxWidth = screenSize.width() - 50 - maxHeight = screenSize.height() - 100 - width, height = 0, 0 - for scrollArea in self.groupboxScrollAreas: - width += scrollArea.minimumWidthNoScrollbar() - scrollAreaHeight = scrollArea.minimumHeightNoScrollbar() - if scrollAreaHeight > height: - height = scrollAreaHeight - - width += 70 - height += ( - self.sizeHint().height() - - self.mainScrollArea.sizeHint().height() - ) - - if width > maxWidth: - width = maxWidth - - if height > maxHeight: - height = maxHeight - - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show() - self.resize(width, height) - self.move(25, 50) - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - -class QLineEditDialog(QDialog): - def __init__( - self, title='Entry messagebox', msg='Entry value', - defaultTxt='', parent=None, allowedValues=None, - warnLastFrame=False, isInteger=False, isFloat=False, - stretchEntry=True, allowEmpty=True, allowedTextEntries=None, - allowText=False, lastVisitedFrame=None, allowList=False - ): - QDialog.__init__(self, parent) - - self.loop = None - self.cancel = True - self.assignNewID = False - self.allowedValues = allowedValues - self.warnLastFrame = warnLastFrame - self.isFloat = isFloat - self.allowEmpty = allowEmpty - self.isInteger = isInteger - self.allowedTextEntries = allowedTextEntries - self.allowText = allowText - self.lastVisitedFrame = lastVisitedFrame - if allowedValues and warnLastFrame: - self.maxValue = max(allowedValues) - - self.setWindowTitle(title) - - # Layouts - mainLayout = QVBoxLayout() - LineEditLayout = QVBoxLayout() - buttonsLayout = QHBoxLayout() - - # Widgets - if not msg.startswith(' np.iinfo(np.uint32).max: - self.entryWidget.setText(str(np.iinfo(np.uint32).max)) - except Exception as e: - text = text.replace(newChar, '') - self.entryWidget.setText(text) - return - - if self.allowedValues is not None: - currentVal = self.value() - if self.allowList: - currentVal = currentVal[-1] - if currentVal not in self.allowedValues: - self.notValidLabel.setText(f'{currentVal} not existing!') - else: - self.notValidLabel.setText('') - - def warnValLessLastFrame(self, val): - msg = widgets.myMessageBox() - warn_txt = html_utils.paragraph(f""" - WARNING: saving until a frame number below the last visited - frame ({self.lastVisitedFrame}) will result in LOSS of information - about any edit or annotation you did on frames - {val+1}-{self.lastVisitedFrame}.

- Are you sure you want to proceed? - """) - msg.warning( - self, 'WARNING: Potential loss of information', warn_txt, - buttonsTexts=('Cancel', 'Yes, I am sure.') - ) - return msg.cancel - - def warnValMoreLastVisitedFrame(self, val): - msg = widgets.myMessageBox() - warn_txt = html_utils.paragraph(f""" - The last visited/validated frame is {self.lastVisitedFrame} - .

- Are you sure you want to save until frame n. {val}?
- """) - msg.warning( - self, 'Saving past last visited frame', warn_txt, - buttonsTexts=('Cancel', 'Yes, I am sure.') - ) - return msg.cancel - - def ok_cb(self, event): - if not self.allowEmpty and not self.entryWidget.text(): - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - msg.critical( - self, 'Empty text', - html_utils.paragraph('Text entry field cannot be empty') - ) - return - if self.allowedTextEntries is not None: - if self.entryWidget.text() not in self.allowedTextEntries: - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - txt = html_utils.paragraph( - f'"{self.entryWidget.text()}" is not a valid entry.

' - 'Valid entries are:
' - f'{html_utils.to_list(self.allowedTextEntries)}' - ) - msg.critical(self, 'Not a valid entry', txt) - return - - if self.allowedValues: - if self.notValidLabel.text(): - return - - val = self.value() - - if self.warnLastFrame and self.lastVisitedFrame is not None: - if val < self.lastVisitedFrame: - cancel = self.warnValLessLastFrame(val) - if cancel: - return - - if self.lastVisitedFrame is not None: - if val > self.lastVisitedFrame: - cancel = self.warnValMoreLastVisitedFrame(val) - if cancel: - return - - self.cancel = False - try: - self.EntryID = int(val) - except Exception as err: - self.EntryID = val - - self.enteredValue = val - self.close() - - def cancel_cb(self, event): - self.cancel = True - self.close() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show() - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - -class FindIDDialog(QLineEditDialog): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.okButton.setIcon(QIcon(':magnGlass.svg')) - self.okButton.setText(' Find ') - -class NumericEntryDialog(QBaseDialog): - def __init__( - self, title='Entry a value', currentValue=0, - instructions='Entry value', parent=None, - maxValue=None, minValue=None, stretch=False - ): - super().__init__(parent=parent) - self.setWindowTitle(title) - self.cancel = False - mainLayout = QVBoxLayout() - entryLayout = QHBoxLayout() - cancelOkLayout = widgets.CancelOkButtonsLayout() - cancelOkLayout.okButton.clicked.connect(self.ok_cb) - cancelOkLayout.cancelButton.clicked.connect(self.close) - - instructionsLabel = QLabel(html_utils.paragraph(instructions)) - mainLayout.addWidget(instructionsLabel) - - if type(currentValue) == int: - self.entryWidget = widgets.SpinBox() - self.entryWidget.setValue(currentValue) - self.valueGetter = 'value' - if maxValue is not None: - self.entryWidget.setMaximum(maxValue) - if minValue is not None: - self.entryWidget.setMinimum(minValue) - - if stretch: - entryLayout.addWidget(self.entryWidget) - else: - entryLayout.addStretch(1) - entryLayout.addWidget(self.entryWidget) - entryLayout.addStretch(1) - - mainLayout.addLayout(entryLayout) - mainLayout.addSpacing(20) - mainLayout.addLayout(cancelOkLayout) - - self.setLayout(mainLayout) - - def ok_cb(self): - self.cancel = False - self.value = getattr(self.entryWidget, self.valueGetter)() - self.close() - -class EditIDDialog(QDialog): - def __init__( - self, clickedID, IDs, - entryID=None, - doNotShowAgain=False, - parent=None, - nextUniqueID=1, - allIDs=None, - addPropagateCheckbox=False - ): - self.assignNewID = False - self.IDs = IDs - self.clickedID = clickedID - self.cancel = True - self.how = None - self.mergeWithExistingID = True - self.doNotAskAgainExistingID = doNotShowAgain - self.allIDs = allIDs - if allIDs is None: - self.allIDs = set(self.IDs) - self.nextUniqueID = nextUniqueID - - super().__init__(parent) - self.setWindowTitle("Edit ID") - mainLayout = QVBoxLayout() - - VBoxLayout = QVBoxLayout() - msg = QLabel(f'Replace ID {clickedID} with:') - _font = QFont() - _font.setPixelSize(12) - msg.setFont(_font) - # padding: top, left, bottom, right - msg.setStyleSheet("padding:0px 0px 3px 0px;") - VBoxLayout.addWidget(msg, alignment=Qt.AlignCenter) - - entryWidget = QLineEdit() - entryWidget.setFont(_font) - entryWidget.setAlignment(Qt.AlignCenter) - self.entryWidget = entryWidget - VBoxLayout.addWidget(entryWidget) - if entryID is not None: - entryWidget.setText(str(entryID)) - entryWidget.selectAll() - - VBoxLayout.addWidget( - QLabel(f'Next unique ID = {nextUniqueID}'), alignment=Qt.AlignCenter - ) - - VBoxLayout.addWidget(widgets.QHLine()) - - self.warnExistingIDLabel = QLabel() - self.warnExistingIDLabel.setStyleSheet('color: red') - VBoxLayout.addWidget( - self.warnExistingIDLabel, alignment=Qt.AlignCenter - ) - - note = QLabel( - 'NOTE: To replace multiple IDs at once\n' - 'write "(old ID, new ID), (old ID, new ID)" etc.' - ) - note.setFont(_font) - note.setAlignment(Qt.AlignCenter) - # padding: top, left, bottom, right - note.setStyleSheet("padding:12px 0px 0px 0px;") - VBoxLayout.addWidget(note, alignment=Qt.AlignCenter) - mainLayout.addLayout(VBoxLayout) - - self.propagateCheckbox = None - if addPropagateCheckbox: - mainLayout.addSpacing(10) - self.propagateCheckbox = QCheckBox('Apply to future frames') - mainLayout.addWidget(self.propagateCheckbox) - - buttonsLayout = QHBoxLayout() - okButton = widgets.okPushButton('Ok') - cancelButton = widgets.cancelPushButton('Cancel') - applyNewIDButton = widgets.AssignNewIDButton('Assign new, unique ID') - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(applyNewIDButton) - buttonsLayout.addWidget(okButton) - - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - # Connect events - self.prevText = '' - entryWidget.textChanged[str].connect(self.onTextChanged) - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.cancel_cb) - applyNewIDButton.clicked.connect(self.assignNewIDclicked) - - # self.setModal(True) - - def onTextChanged(self, text): - self.warnExistingIDLabel.setText('') - try: - ID = int(text) - if ID in self.allIDs: - self.warnExistingIDLabel.setText( - f'WARNING: ID {ID} was already used' - ) - except Exception as err: - pass - - # Get inserted char - idx = self.entryWidget.cursorPosition() - if idx == 0: - return - - newChar = text[idx-1] - - # Do nothing if user is deleting text - if idx == 0 or len(text) uint32_max: - text = self.entryWidget.text() - text = f'{text[:m.start()]}{uint32_max}{text[m.end():]}' - self.entryWidget.setText(text) - - # Automatically close ( bracket - if newChar == '(': - text += ')' - self.entryWidget.setText(text) - self.prevText = text - - def _warnExistingID(self, existingID, newID): - warn_msg = html_utils.paragraph(f""" - ID {existingID} is already existing.

- How do you want to proceed?
- """) - msg = widgets.myMessageBox() - doNotAskAgainCheckbox = QCheckBox('Remember my choice and do not ask again') - swapButton = widgets.reloadPushButton(f'Swap {newID} with {existingID}') - mergeButton = widgets.mergePushButton(f'Merge {newID} with {existingID}') - msg.warning( - self, 'Existing ID', warn_msg, - buttonsTexts=('Cancel', mergeButton, swapButton), - widgets=doNotAskAgainCheckbox - ) - if msg.cancel: - return False - self.doNotAskAgainExistingID = doNotAskAgainCheckbox.isChecked() - self.mergeWithExistingID = msg.clickedButton == mergeButton - return True - - def assignNewIDclicked(self): - self.cancel = False - self.how = None - self.assignNewID = True - self.close() - - def ok_cb(self, event): - txt = self.entryWidget.text() - valid = False - - # Check validity of inserted text - try: - ID = int(txt) - how = [(self.clickedID, ID)] - if ID in self.IDs and not self.doNotAskAgainExistingID: - proceed = self._warnExistingID(self.clickedID, ID) - if not proceed: - return - valid = True - else: - valid = True - except ValueError: - pattern = r'\((\d+),\s*(\d+)\)' - fa = re.findall(pattern, txt) - if fa: - how = [(int(g[0]), int(g[1])) for g in fa] - valid = True - else: - valid = False - - if not valid: - err_msg = html_utils.paragraph( - 'You entered invalid text. Valid text is either a single integer' - f' ID that will be used to replace ID {self.clickedID} ' - 'or a list of elements enclosed in parenthesis separated by a comma
' - 'such as (5, 10), (8, 27) to replace ID 5 with ID 10 and ID 8 with ID 27' - ) - msg = widgets.myMessageBox() - msg.warning( - self, 'Invalid entry', err_msg - ) - return - - self.cancel = False - self.how = how - self.doPropagateFutureFrames = False - if self.propagateCheckbox is not None: - self.doPropagateFutureFrames = self.propagateCheckbox.isChecked() - self.close() - - - def cancel_cb(self, event): - self.cancel = True - self.close() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show() - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - -class QtSelectItems(QDialog): - def __init__( - self, title, items, informativeText, - CbLabel='Select value: ', parent=None, - showInFileManagerPath=None - ): - self.cancel = True - self.selectedItemsText = '' - self.selectedItemsIdx = None - self.showInFileManagerPath = showInFileManagerPath - self.items = items - super().__init__(parent) - self.setWindowTitle(title) - - mainLayout = QVBoxLayout() - topLayout = QHBoxLayout() - self.topLayout = topLayout - bottomLayout = QHBoxLayout() - - stretchRow = 0 - if informativeText: - infoLabel = QLabel(informativeText) - mainLayout.addWidget(infoLabel, alignment=Qt.AlignCenter) - stretchRow = 1 - - label = QLabel(CbLabel) - topLayout.addWidget(label, alignment=Qt.AlignRight) - - combobox = QComboBox(self) - combobox.addItems(items) - self.ComboBox = combobox - topLayout.addWidget(combobox) - - okButton = widgets.okPushButton('Ok') - cancelButton = widgets.cancelPushButton('Cancel') - if showInFileManagerPath is not None: - txt = myutils.get_open_filemaneger_os_string() - showInFileManagerButton = widgets.showInFileManagerButton(txt) - - bottomLayout.addStretch(1) - bottomLayout.addWidget(cancelButton) - bottomLayout.addSpacing(20) - if showInFileManagerPath is not None: - bottomLayout.addWidget(showInFileManagerButton) - bottomLayout.addWidget(okButton) - - multiPosButton = QPushButton('Multiple selection') - multiPosButton.setCheckable(True) - self.multiPosButton = multiPosButton - bottomLayout.addWidget(multiPosButton, alignment=Qt.AlignLeft) - - listBox = widgets.listWidget() - listBox.addItems(items) - listBox.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) - listBox.setCurrentRow(0) - listBox.setFont(font) - topLayout.addWidget(listBox) - listBox.hide() - self.ListBox = listBox - - mainLayout.addLayout(topLayout) - mainLayout.addSpacing(20) - mainLayout.addLayout(bottomLayout) - - self.setLayout(mainLayout) - self.mainLayout = mainLayout - self.topLayout = topLayout - - # self.setModal(True) - - # Connect events - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.close) - multiPosButton.toggled.connect(self.toggleMultiSelection) - if showInFileManagerPath is not None: - showInFileManagerButton.clicked.connect(self.showInFileManager) - - self.setFont(font) - - def setSelectedItems(self, selectedItemsText): - if self.multiPosButton.isChecked(): - for i in range(self.ListBox.count()): - item = self.ListBox.item(i) - if item.text() in selectedItemsText: - item.setSelected(True) - else: - idx = self.items.index(selectedItemsText[0]) - self.ComboBox.setCurrentIndex(idx) - - def showInFileManager(self): - selectedTexts, _ = self.getSelectedItems() - folder = selectedTexts[0].split('(')[0].strip() - path = os.path.join(self.showInFileManagerPath, folder) - if os.path.exists(path) and os.path.isdir(path): - showPath = path - else: - showPath = self.showInFileManagerPath - myutils.showInExplorer(showPath) - - def toggleMultiSelection(self, checked): - if checked: - self.multiPosButton.setText('Single selection') - self.ComboBox.hide() - self.ListBox.show() - # Show 10 items - n = self.ListBox.count() - if n > 10: - h = sum([self.ListBox.sizeHintForRow(i) for i in range(10)]) - else: - h = sum([self.ListBox.sizeHintForRow(i) for i in range(n)]) - self.ListBox.setMinimumHeight(h+5) - self.ListBox.setFocusPolicy(Qt.StrongFocus) - self.ListBox.setFocus() - self.ListBox.setCurrentRow(0) - self.mainLayout.setStretchFactor(self.topLayout, 2) - else: - self.multiPosButton.setText('Multiple selection') - self.ListBox.hide() - self.ComboBox.show() - self.resize(self.width(), self.singleSelectionHeight) - - def getSelectedItems(self): - if self.multiPosButton.isChecked(): - selectedItems = self.ListBox.selectedItems() - selectedItemsText = [item.text() for item in selectedItems] - selectedItemsText = natsorted(selectedItemsText) - selectedItemsIdx = [ - self.items.index(txt) for txt in selectedItemsText - ] - else: - selectedItemsText = [self.ComboBox.currentText()] - selectedItemsIdx = [self.ComboBox.currentIndex()] - return selectedItemsText, selectedItemsIdx - - def ok_cb(self, event): - self.cancel = False - self.selectedItemsText, self.selectedItemsIdx = self.getSelectedItems() - self.close() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show() - self.singleSelectionHeight = self.height() - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - - -class manualSeparateGui(QMainWindow): - def __init__( - self, lab, ID, img, fontSize='12pt', IDcolor=[255, 255, 0], - parent=None, loop=None, drawMode='threepoints_arc' - ): - super().__init__(parent) - self.loop = loop - self.cancel = True - self.drawMode = drawMode - self._parent = parent - self.lab = lab.copy() - self.lab[lab!=ID] = 0 - self.ID = ID - self.img = skimage.exposure.equalize_adapthist(img/img.max()) - self.IDcolor = IDcolor - self.countClicks = 0 - self.prevLabs = [] - self.prevAllCutsCoords = [] - self.labelItemsIDs = [] - self.undoIdx = 0 - self.fontSize = fontSize - self.AllCutsCoords = [] - self.setWindowTitle("Split object") - # self.setGeometry(Left, Top, 850, 800) - - self.gui_createActions() - self.gui_createMenuBar() - self.gui_createToolBars() - - self.gui_createStatusBar() - - self.gui_createGraphics() - self.gui_connectImgActions() - - self.gui_createImgWidgets() - self.gui_connectActions() - - self.updateImg() - self.zoomToObj() - - mainContainer = QWidget() - self.setCentralWidget(mainContainer) - - mainLayout = QGridLayout() - mainLayout.addWidget(self.graphLayout, 0, 0, 1, 1) - mainLayout.addLayout(self.img_Widglayout, 1, 0) - - mainContainer.setLayout(mainLayout) - - self.setWindowModality(Qt.WindowModal) - - def centerWindow(self): - parent = self._parent - if parent is not None: - # Center the window on main window - mainWinGeometry = parent.geometry() - mainWinLeft = mainWinGeometry.left() - mainWinTop = mainWinGeometry.top() - mainWinWidth = mainWinGeometry.width() - mainWinHeight = mainWinGeometry.height() - mainWinCenterX = int(mainWinLeft + mainWinWidth/2) - mainWinCenterY = int(mainWinTop + mainWinHeight/2) - winGeometry = self.geometry() - winWidth = winGeometry.width() - winHeight = winGeometry.height() - winLeft = int(mainWinCenterX - winWidth/2) - winRight = int(mainWinCenterY - winHeight/2) - self.move(winLeft, winRight) - - def gui_createActions(self): - # File actions - self.exitAction = QAction("&Exit", self) - self.helpAction = QAction('Help', self) - self.undoAction = QAction(QIcon(":undo.svg"), "Undo (Ctrl+Z)", self) - self.undoAction.setEnabled(False) - self.undoAction.setShortcut("Ctrl+Z") - - self.okAction = QAction(QIcon(":applyCrop.svg"), "Happy with that", self) - self.cancelAction = QAction(QIcon(":cancel.svg"), "Cancel", self) - - self.drawModesActionGroup = QActionGroup(self) - - self.threePointsArcAction = QAction( - QIcon(":threepoints_arc.svg"), 'Separate with three-points arc', - self - ) - self.threePointsArcAction.setCheckable(True) - self.threePointsArcAction.drawMode = 'threepoints_arc' - self.drawModesActionGroup.addAction(self.threePointsArcAction) - - self.freeHandAction = QAction( - QIcon(":freehand.svg"), 'Separate with freehand line', self - ) - self.freeHandAction.setCheckable(True) - self.freeHandAction.drawMode = 'freehand' - self.drawModesActionGroup.addAction(self.freeHandAction) - - if self.drawMode == 'threepoints_arc': - self.threePointsArcAction.setChecked(True) - elif self.drawMode == 'freehand': - self.freeHandAction.setChecked(True) - - self.swapIDsAction = QAction( - QIcon(":reload.svg"), "Swap IDs", self - ) - self.swapIDsAction.setToolTip( - 'Swap the two displayed IDs\n\n' - 'Shortcut: "S"' - ) - self.swapIDsAction.setShortcut('S') - - def state(self): - return { - 'is_overlay_active': self.overlayButton.isChecked(), - 'is_three_points_active': self.threePointsArcAction.isChecked(), - 'is_free_hand_active': self.freeHandAction.isChecked() - } - - def show(self, block=False): - super().show() - if not block: - return - self.loop = QEventLoop(self) - self.loop.exec_() - - def setState(self, state): - if state is None: - return - self.overlayButton.setChecked(state.get('is_overlay_active', False)) - self.threePointsArcAction.setChecked( - state.get('is_three_points_active', True) - ) - self.freeHandAction.setChecked(state.get('is_free_hand_active', False)) - - def gui_storeDrawMode(self): - self.drawMode = self.sender().drawMode - - def gui_createMenuBar(self): - menuBar = self.menuBar() - # style = "QMenuBar::item:selected { background: white; }" - # menuBar.setStyleSheet(style) - # File menu - fileMenu = QMenu("&File", self) - menuBar.addMenu(fileMenu) - - menuBar.addAction(self.helpAction) - fileMenu.addAction(self.exitAction) - - def gui_createToolBars(self): - toolbarSize = 30 - - editToolBar = QToolBar("Edit", self) - editToolBar.setIconSize(QSize(toolbarSize, toolbarSize)) - self.addToolBar(editToolBar) - - editToolBar.addAction(self.okAction) - editToolBar.addAction(self.cancelAction) - - editToolBar.addAction(self.undoAction) - - self.overlayButton = QToolButton(self) - self.overlayButton.setIcon(QIcon(":overlay.svg")) - self.overlayButton.setCheckable(True) - self.overlayButton.setToolTip( - 'Overlay channel\'s image' - ) - editToolBar.addWidget(self.overlayButton) - - editToolBar.addAction(self.threePointsArcAction) - editToolBar.addAction(self.freeHandAction) - - editToolBar.addAction(self.swapIDsAction) - - self.warnLabel = QLabel() - editToolBar.addWidget(self.warnLabel) - - - def gui_connectActions(self): - self.exitAction.triggered.connect(self.close) - self.helpAction.triggered.connect(self.help) - self.okAction.triggered.connect(self.ok_cb) - self.cancelAction.triggered.connect(self.close) - self.undoAction.triggered.connect(self.undo) - self.overlayButton.toggled.connect(self.toggleOverlay) - self.imgGrad.sigLookupTableChanged.connect(self.histLUT_cb) - self.swapIDsAction.triggered.connect(self.swapIDs) - - def gui_createStatusBar(self): - self.statusbar = self.statusBar() - # Temporary message - self.statusbar.showMessage("Ready", 3000) - # Permanent widget - self.wcLabel = QLabel(f"") - self.statusbar.addPermanentWidget(self.wcLabel) - - def gui_createGraphics(self): - self.graphLayout = pg.GraphicsLayoutWidget() - - # Plot Item container for image - self.ax = pg.PlotItem() - self.ax.invertY(True) - self.ax.setAspectLocked(True) - self.ax.hideAxis('bottom') - self.ax.hideAxis('left') - self.graphLayout.addItem(self.ax, row=1, col=1) - - # Image Item - self.imgItem = pg.ImageItem(np.zeros((512,512))) - self.ax.addItem(self.imgItem) - - #Image histogram - self.imgGrad = widgets.myHistogramLUTitem() - - # Curvature items - self.hoverLinSpace = np.linspace(0, 1, 1000) - self.hoverLinePen = pg.mkPen(color=(200, 0, 0, 255*0.5), - width=2, style=Qt.DashLine) - self.hoverCurvePen = pg.mkPen(color=(200, 0, 0, 255*0.5), width=3) - self.lineHoverPlotItem = pg.PlotDataItem(pen=self.hoverLinePen) - self.curvHoverPlotItem = pg.PlotDataItem(pen=self.hoverCurvePen) - self.curvAnchors = pg.ScatterPlotItem( - symbol='o', size=9, - brush=pg.mkBrush((255,0,0,50)), - pen=pg.mkPen((255,0,0), width=2), - hoverable=True, hoverPen=pg.mkPen((255,0,0), width=3), - hoverBrush=pg.mkBrush((255,0,0)) - ) - self.ax.addItem(self.curvAnchors) - self.ax.addItem(self.curvHoverPlotItem) - self.ax.addItem(self.lineHoverPlotItem) - - self.freeHandItem = widgets.PlotCurveItem( - pen=pg.mkPen(color='r', width=2) - ) - self.ax.addItem(self.freeHandItem) - - def gui_createImgWidgets(self): - self.img_Widglayout = QGridLayout() - self.img_Widglayout.setContentsMargins(50, 0, 50, 0) - - alphaScrollBar_label = QLabel('Overlay alpha ') - alphaScrollBar = QScrollBar(Qt.Horizontal) - alphaScrollBar.setFixedHeight(20) - alphaScrollBar.setMinimum(0) - alphaScrollBar.setMaximum(40) - alphaScrollBar.setValue(12) - alphaScrollBar.setToolTip( - 'Control the alpha value of the overlay.\n' - 'alpha=0 results in NO overlay,\n' - 'alpha=1 results in only labels visible' - ) - alphaScrollBar.sliderMoved.connect(self.alphaScrollBarMoved) - self.alphaScrollBar = alphaScrollBar - self.alphaScrollBar_label = alphaScrollBar_label - self.img_Widglayout.addWidget( - alphaScrollBar_label, 0, 0, alignment=Qt.AlignCenter - ) - self.img_Widglayout.addWidget(alphaScrollBar, 0, 1, 1, 20) - self.alphaScrollBar.hide() - self.alphaScrollBar_label.hide() - - def gui_connectImgActions(self): - self.imgItem.hoverEvent = self.gui_hoverEventImg - self.imgItem.mousePressEvent = self.gui_mousePressEventImg - self.imgItem.mouseMoveEvent = self.gui_mouseDragEventImg - self.imgItem.mouseReleaseEvent = self.gui_mouseReleaseEventImg - - def gui_hoverEventImg(self, event): - # Update x, y, value label bottom right - try: - x, y = event.pos() - xdata, ydata = int(x), int(y) - _img = self.lab - Y, X = _img.shape - if xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y: - val = _img[ydata, xdata] - self.wcLabel.setText(f'(x={x:.2f}, y={y:.2f}, ID={val:.0f})') - else: - self.wcLabel.setText(f'') - except Exception as e: - self.wcLabel.setText(f'') - - if event.isExit(): - return - - self.drawHoverEvent(*event.pos()) - - def gui_mousePressEventImg(self, event): - right_click = event.button() == Qt.MouseButton.RightButton - left_click = event.button() == Qt.MouseButton.LeftButton - - dragImg = (left_click) - - if dragImg: - pg.ImageItem.mousePressEvent(self.imgItem, event) - - if not right_click: - return - - self.drawPressEvent(event) - - def gui_mouseDragEventImg(self, event): - pass - - def gui_mouseReleaseEventImg(self, event): - if self.countClicks == 0: - return - if self.freeHandAction.isChecked(): - self.countClicks = 0 - xx, yy = self.freeHandItem.getData() - self.setSplitCurveCoords(xx, yy) - self.splitObjectAlongCurve() - self.freeHandItem.setData([], []) - self.curvAnchors.setData([], []) - - def getSpline(self, xx, yy): - tck, u = scipy.interpolate.splprep([xx, yy], s=0, k=2) - xi, yi = scipy.interpolate.splev(self.hoverLinSpace, tck) - return xi, yi - - def drawPressEvent(self, event): - if self.freeHandAction.isChecked(): - self.countClicks = 1 - x, y = event.pos().x(), event.pos().y() - self.curvAnchors.addPoints([x], [y]) - elif self.threePointsArcAction.isChecked(): - self.threePointsArcPressEvent(event) - - def drawHoverEvent(self, x, y): - if self.freeHandAction.isChecked(): - self.freeHandHoverEvent(x, y) - elif self.threePointsArcAction.isChecked(): - self.threePointsArcHoverEvent(x, y) - - def freeHandHoverEvent(self, x, y): - if self.countClicks == 0: - return - self.freeHandItem.addPoint(int(x), int(y)) - _xx, _yy = self.freeHandItem.getData() - xx = [_xx[0], x] - yy = [_yy[0], y] - self.curvAnchors.setData(xx, yy) - - def threePointsArcHoverEvent(self, x, y): - if self.countClicks == 1: - self.lineHoverPlotItem.setData([self.x0, x], [self.y0, y]) - elif self.countClicks == 2: - xx = [self.x0, x, self.x1] - yy = [self.y0, y, self.y1] - xi, yi = self.getSpline(xx, yy) - self.curvHoverPlotItem.setData(xi, yi) - elif self.countClicks == 0: - self.curvHoverPlotItem.setData([], []) - self.lineHoverPlotItem.setData([], []) - self.curvAnchors.setData([], []) - - def threePointsArcPressEvent(self, event): - if self.countClicks == 0: - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - self.x0, self.y0 = xdata, ydata - self.curvAnchors.addPoints([xdata], [ydata]) - self.countClicks = 1 - elif self.countClicks == 1: - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - self.x1, self.y1 = xdata, ydata - self.curvAnchors.addPoints([xdata], [ydata]) - self.countClicks = 2 - elif self.countClicks == 2: - self.countClicks = 0 - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - xx = [self.x0, xdata, self.x1] - yy = [self.y0, ydata, self.y1] - xi, yi = self.getSpline(xx, yy) - yy, xx = np.round(yi).astype(int), np.round(xi).astype(int) - self.setSplitCurveCoords(xx, yy) - self.splitObjectAlongCurve() - - def setSplitCurveCoords(self, xx, yy): - self.storeUndoState() - xxCurve, yyCurve = [], [] - for i, (r0, c0) in enumerate(zip(yy, xx)): - if i == len(yy)-1: - break - r1 = yy[i+1] - c1 = xx[i+1] - rr, cc, _ = skimage.draw.line_aa(r0, c0, r1, c1) - # rr, cc = skimage.draw.line(r0, c0, r1, c1) - nonzeroMask = self.lab[rr, cc]>0 - xxCurve.extend(cc[nonzeroMask]) - yyCurve.extend(rr[nonzeroMask]) - self.AllCutsCoords.append((yyCurve, xxCurve)) - for rr, cc in self.AllCutsCoords: - self.lab[rr, cc] = 0 - self.lab = skimage.morphology.remove_small_objects(self.lab, 5) - - def histLUT_cb(self, LUTitem): - if self.overlayButton.isChecked(): - overlay = self.getOverlay() - self.imgItem.setImage(overlay) - - def swapIDs(self, checked=False): - if len(self.rp) == 1: - self.warnLabel.setText( - html_utils.paragraph( - 'WARNING: Split the object before swapping IDs', - font_color='red' - ) - ) - return - - self.warnLabel.setText('') - - obj1 = self.rp[0] - obj2 = self.rp[1] - - self.lab[obj1.slice][obj1.image] = obj2.label - self.lab[obj2.slice][obj2.image] = obj1.label - - self.updateImg() - - def updateImg(self): - self.updateLookuptable() - rp = skimage.measure.regionprops(self.lab) - self.rp = rp - - if self.overlayButton.isChecked(): - overlay = self.getOverlay() - self.imgItem.setImage(overlay) - else: - self.imgItem.setImage(self.lab) - - # Draw ID on centroid of each label - for labelItemID in self.labelItemsIDs: - self.ax.removeItem(labelItemID) - self.labelItemsIDs = [] - for obj in rp: - labelItemID = widgets.myLabelItem() - labelItemID.setText( - f'{obj.label}', color='r', size=f'{self.fontSize}px' - ) - y, x = obj.centroid - w, h = labelItemID.rect().right(), labelItemID.rect().bottom() - labelItemID.setPos(x-w/2, y-h/2) - self.labelItemsIDs.append(labelItemID) - self.ax.addItem(labelItemID) - - def zoomToObj(self): - # Zoom to object - lab_mask = (self.lab>0).astype(np.uint8) - rp = skimage.measure.regionprops(lab_mask) - obj = rp[0] - min_row, min_col, max_row, max_col = obj.bbox - xRange = min_col-10, max_col+10 - yRange = max_row+10, min_row-10 - self.ax.setRange(xRange=xRange, yRange=yRange) - - def storeUndoState(self): - self.prevLabs.append(self.lab.copy()) - self.prevAllCutsCoords.append(self.AllCutsCoords.copy()) - self.undoIdx += 1 - self.undoAction.setEnabled(True) - - def undo(self): - self.undoIdx -= 1 - self.lab = self.prevLabs[self.undoIdx] - self.AllCutsCoords = self.prevAllCutsCoords[self.undoIdx] - self.updateImg() - if self.undoIdx == 0: - self.undoAction.setEnabled(False) - self.prevLabs = [] - self.prevAllCutsCoords = [] - - def splitObjectAlongCurve(self): - self.lab = skimage.measure.label(self.lab, connectivity=1) - - # Relabel largest object with original ID - rp = skimage.measure.regionprops(self.lab) - areas = [obj.area for obj in rp] - IDs = [obj.label for obj in rp] - maxAreaIdx = areas.index(max(areas)) - maxAreaID = IDs[maxAreaIdx] - if self.ID not in self.lab: - self.lab[self.lab==maxAreaID] = self.ID - else: - tempID = self.lab.max() + 1 - self.lab[self.lab==maxAreaID] = tempID - self.lab[self.lab==self.ID] = maxAreaID - self.lab[self.lab==tempID] = self.ID - - # Keep only the two largest objects - larger_areas = nlargest(2, areas) - larger_ids = [rp[areas.index(area)].label for area in larger_areas] - for obj in rp: - if obj.label not in larger_ids: - self.lab[tuple(obj.coords.T)] = 0 - - rp = skimage.measure.regionprops(self.lab) - - if self._parent is not None: - self._parent.setBrushID() - # Use parent window setBrushID function for all other IDs - for obj in rp: - if self._parent is None: - break - if obj.label == self.ID: - continue - posData = self._parent.data[self._parent.pos_i] - posData.brushID += 1 - self.lab[obj.slice][obj.image] = posData.brushID - - # Replace 0s on the cutting curve with IDs - self.cutLab = self.lab.copy() - for rr, cc in self.AllCutsCoords: - for y, x in zip(rr, cc): - top_row = self.cutLab[y+1, x-1:x+2] - bot_row = self.cutLab[y-1, x-1:x+1] - left_col = self.cutLab[y-1, x-1] - right_col = self.cutLab[y:y+2, x+1] - allNeigh = list(top_row) - allNeigh.extend(bot_row) - allNeigh.append(left_col) - allNeigh.extend(right_col) - newID = max(allNeigh) - self.lab[y,x] = newID - - self.rp = skimage.measure.regionprops(self.lab) - self.updateImg() - - def updateLookuptable(self): - # Lookup table - self.cmap = colors.getFromMatplotlib('viridis') - self.lut = self.cmap.getLookupTable(0,1,self.lab.max()+1) - self.lut[0] = [25,25,25] - self.lut[self.ID] = self.IDcolor - if self.overlayButton.isChecked(): - self.imgItem.setLookupTable(None) - else: - self.imgItem.setLookupTable(self.lut) - - def keyPressEvent(self, ev): - if ev.key() == Qt.Key_Escape: - self.countClicks = 0 - self.curvHoverPlotItem.setData([], []) - self.lineHoverPlotItem.setData([], []) - self.curvAnchors.setData([], []) - self.freeHandItem.setData([], []) - elif ev.key() == Qt.Key_Enter or ev.key() == Qt.Key_Return: - self.ok_cb(True) - - def getOverlay(self): - # Rescale intensity based on hist ticks values - min = self.imgGrad.gradient.listTicks()[0][1] - max = self.imgGrad.gradient.listTicks()[1][1] - img = skimage.exposure.rescale_intensity(self.img, in_range=(min, max)) - alpha = self.alphaScrollBar.value()/self.alphaScrollBar.maximum() - - # Convert img and lab to RGBs - rgb_shape = (self.lab.shape[0], self.lab.shape[1], 3) - labRGB = np.zeros(rgb_shape) - labRGB[self.lab>0] = [1, 1, 1] - imgRGB = skimage.color.gray2rgb(img) - overlay = imgRGB*(1.0-alpha) + labRGB*alpha - - # Color eaach label - for obj in self.rp: - rgb = self.lut[obj.label]/255 - overlay[obj.slice][obj.image] *= rgb - - # Convert (0,1) to (0,255) - overlay = (np.clip(overlay, 0, 1)*255).astype(np.uint8) - return overlay - - def alphaScrollBarMoved(self, alpha_int): - overlay = self.getOverlay() - self.imgItem.setImage(overlay) - - def toggleOverlay(self, checked): - if checked: - self.graphLayout.addItem(self.imgGrad, row=1, col=0) - self.alphaScrollBar.show() - self.alphaScrollBar_label.show() - else: - self.graphLayout.removeItem(self.imgGrad) - self.alphaScrollBar.hide() - self.alphaScrollBar_label.hide() - self.updateImg() - - def help(self): - msg = QMessageBox() - msg.information(self, 'Help', - 'Separate object along a curved line.\n\n' - 'To draw a curved line you will need 3 right-clicks:\n\n' - '1. Right-click outside of the object --> a line appears.\n' - '2. Right-click to end the line and a curve going through the ' - 'mouse cursor will appear.\n' - '3. Once you are happy with the cutting curve right-click again ' - 'and the object will be separated along the curve.\n\n' - 'Note that you can separate as many times as you want.\n\n' - 'Once happy click on the green tick on top-right or ' - 'cancel the process with the "X" button') - - def ok_cb(self, checked): - self.cancel = False - self.close() - - def closeEvent(self, event): - if self.loop is not None: - self.loop.exit() - -class DataFrameModel(QtCore.QAbstractTableModel): - # https://stackoverflow.com/questions/44603119/how-to-display-a-pandas-data-frame-with-pyqt5-pyside2 - DtypeRole = QtCore.Qt.UserRole + 1000 - ValueRole = QtCore.Qt.UserRole + 1001 - - def __init__(self, df=pd.DataFrame(), parent=None): - super(DataFrameModel, self).__init__(parent) - self._dataframe = df - - def setDataFrame(self, dataframe): - self.beginResetModel() - self._dataframe = dataframe.copy() - self.endResetModel() - - def dataFrame(self): - return self._dataframe - - dataFrame = QtCore.Property(pd.DataFrame, fget=dataFrame, - fset=setDataFrame) - - @QtCore.Slot(int, QtCore.Qt.Orientation, result=str) - def headerData(self, section: int, - orientation: QtCore.Qt.Orientation, - role: int = QtCore.Qt.DisplayRole): - if role == QtCore.Qt.DisplayRole: - if orientation == QtCore.Qt.Horizontal: - return self._dataframe.columns[section] - else: - return str(self._dataframe.index[section]) - return QtCore.QVariant() - - def rowCount(self, parent=QtCore.QModelIndex()): - if parent.isValid(): - return 0 - return len(self._dataframe.index) - - def columnCount(self, parent=QtCore.QModelIndex()): - if parent.isValid(): - return 0 - return self._dataframe.columns.size - - def data(self, index, role=QtCore.Qt.DisplayRole): - if not index.isValid() or not (0 <= index.row() < self.rowCount() \ - and 0 <= index.column() < self.columnCount()): - return QtCore.QVariant() - row = self._dataframe.index[index.row()] - col = self._dataframe.columns[index.column()] - dt = self._dataframe[col].dtype - - if role == Qt.TextAlignmentRole: - return Qt.AlignCenter - - val = self._dataframe.iloc[row][col] - if role == QtCore.Qt.DisplayRole: - return str(val) - elif role == DataFrameModel.ValueRole: - return val - if role == DataFrameModel.DtypeRole: - return dt - return QtCore.QVariant() - - def roleNames(self): - roles = { - QtCore.Qt.DisplayRole: b'display', - DataFrameModel.DtypeRole: b'dtype', - DataFrameModel.ValueRole: b'value' - } - return roles - -class pdDataFrameWidget(QMainWindow): - def __init__(self, df, parent=None): - super().__init__(parent) - self.parent = parent - self.setWindowTitle('Cell cycle annotations') - - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - - mainContainer = QWidget() - self.setCentralWidget(mainContainer) - - layout = QVBoxLayout() - self._layout = layout - - self.tableView = QTableView(self) - layout.addWidget(self.tableView) - model = DataFrameModel(df) - self.tableView.setModel(model) - for i in range(len(df.columns)): - self.tableView.resizeColumnToContents(i) - # layout.addWidget(QPushButton('Ok', self)) - mainContainer.setLayout(layout) - - def updateTable(self, df, IDs=None): - if df is None: - df = self.parent.getBaseCca_df() - - if IDs is not None: - df = df.loc[IDs] - - df = df.reset_index() - model = DataFrameModel(df) - self.tableView.setModel(model) - for i in range(len(df.columns)): - self.tableView.resizeColumnToContents(i) - - def setGeometryWindow(self, maxWidth=1024): - width = self.tableView.verticalHeader().width() + 4 - for j in range(self.tableView.model().columnCount()): - width += self.tableView.columnWidth(j) + 4 - height = self.tableView.horizontalHeader().height() + 4 - h = height + (self.tableView.rowHeight(0) + 4)*10 - w = width if width
{filename}

- however you never selected which z-slice
you want to use - when calculating metrics
(e.g., mean, median, amount...etc.)

- Choose one of following options: - """, center=True - ) - infoLabel = QLabel(txt) - mainLayout.addWidget(infoLabel, alignment=Qt.AlignCenter) - - runDataPrepButton = QPushButton( - ' Visualize the data now and select a z-slice ' - ) - buttonsLayout.addWidget(runDataPrepButton, 0, 1, 1, 2) - runDataPrepButton.clicked.connect(self.runDataPrep_cb) - - useMiddleSliceButton = QPushButton( - f' Use the middle z-slice ({int(SizeZ/2)+1}) ' - ) - buttonsLayout.addWidget(useMiddleSliceButton, 1, 1, 1, 2) - useMiddleSliceButton.clicked.connect(self.useMiddleSlice_cb) - - useSameAsChButton = QPushButton( - ' Use the same z-slice used for the channel: ' - ) - useSameAsChButton.clicked.connect(self.useSameAsCh_cb) - - chNameComboBox = QComboBox() - chNameComboBox.addItems(filenamesWithInfo) - # chNameComboBox.setEditable(True) - # chNameComboBox.lineEdit().setAlignment(Qt.AlignCenter) - # chNameComboBox.lineEdit().setReadOnly(True) - self.chNameComboBox = chNameComboBox - buttonsLayout.addWidget(useSameAsChButton, 2, 1) - buttonsLayout.addWidget(chNameComboBox, 2, 2) - - - - buttonsLayout.setColumnStretch(0, 1) - buttonsLayout.setColumnStretch(3, 1) - buttonsLayout.setContentsMargins(10, 0, 10, 0) - - - - cancelButtonLayout = QHBoxLayout() - cancelButton = widgets.cancelPushButton('Cancel') - cancelButtonLayout.addStretch(1) - cancelButtonLayout.addWidget(cancelButton) - cancelButtonLayout.addStretch(1) - cancelButtonLayout.setStretch(1,1) - cancelButton.clicked.connect(self.close) - - mainLayout.addLayout(buttonsLayout) - mainLayout.addSpacing(20) - mainLayout.addLayout(cancelButtonLayout) - mainLayout.addStretch(1) - - self.setLayout(mainLayout) - - font = QFont() - font.setPixelSize(12) - self.setFont(font) - - # self.setModal(True) - - def ok_cb(self, checked=True): - self.cancel = False - self.close() - - def useSameAsCh_cb(self, checked): - self.useSameAsCh = True - self.selectedChannel = self.chNameComboBox.currentText() - self.ok_cb() - - def useMiddleSlice_cb(self, checked): - self.useMiddleSlice = True - self.ok_cb() - - def runDataPrep_cb(self, checked): - self.runDataPrep = True - self.ok_cb() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show() - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - -class SelectSegmFileDialog(QDialog): - def __init__( - self, images_ls, parent_path, parent=None, - addNewFileButton=False, basename='', infoText=None, - fileType='segmentation', allowMultipleSelection=False, - custom_first=None - ): - self.cancel = True - self.selectedItemText = '' - self.selectedItemIdx = None - self.removeOthers = False - self.okAllPos = False - self.newSegmEndName = None - self.allowMultipleSelection = allowMultipleSelection - self.basename = basename - images_ls = sorted(images_ls, key=len) - if custom_first is not None: - images_ls.remove(custom_first) - images_ls.insert(0, custom_first) - - # Remove the 'segm_' part to allow filenameDialog to check if - # a new file is existing (since we only ask for the part after - # 'segm_') - self.existingEndNames = [ - n.replace('segm', '', 1).replace('_', '', 1) for n in images_ls - ] - - self.images_ls = images_ls - self.parent_path = parent_path - super().__init__(parent) - - informativeText = html_utils.paragraph(f""" - The loaded Position folders already contains - {len(self.existingEndNames)} {fileType} masks
- """) - - self.setWindowTitle(f'{fileType.capitalize()} files detected') - is_win = sys.platform.startswith("win") - - mainLayout = QVBoxLayout() - infoLayout = QHBoxLayout() - selectionLayout = QGridLayout() - buttonsLayout = QHBoxLayout() - - # Standard Qt Question icon - label = QLabel() - standardIcon = getattr(QStyle, 'SP_MessageBoxQuestion') - icon = self.style().standardIcon(standardIcon) - pixmap = icon.pixmap(60, 60) - label.setPixmap(pixmap) - infoLayout.addWidget(label) - - infoLabel = QLabel(informativeText) - infoLayout.addWidget(infoLabel) - infoLayout.addStretch(1) - mainLayout.addLayout(infoLayout) - - if infoText is None: - infoText = f'Select which {fileType} file to load:' - - questionText = html_utils.paragraph(infoText) - label = QLabel(questionText) - listWidget = widgets.listWidget() - listWidget.addItems(images_ls) - listWidget.setCurrentRow(0) - listWidget.itemDoubleClicked.connect(self.listDoubleClicked) - if allowMultipleSelection: - listWidget.setSelectionMode( - QAbstractItemView.SelectionMode.ExtendedSelection - ) - self.items = list(images_ls) - self.listWidget = listWidget - - okButton = widgets.okPushButton(' Load selected ') - txt = 'Reveal in Finder...' if is_mac else 'Show in Explorer...' - showInFileManagerButton = widgets.showInFileManagerButton(txt) - cancelButton = widgets.cancelPushButton(' Cancel ') - - if addNewFileButton: - newFileButton = widgets.newFilePushButton('New file...') - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addWidget(showInFileManagerButton) - buttonsLayout.addSpacing(20) - if addNewFileButton: - buttonsLayout.addWidget(newFileButton) - buttonsLayout.addWidget(okButton) - - buttonsLayout.setContentsMargins(0, 10, 0, 10) - - selectionLayout.addWidget(label, 0, 1, alignment=Qt.AlignLeft) - selectionLayout.addWidget(listWidget, 1, 1) - selectionLayout.setColumnStretch(0, 0) - selectionLayout.setColumnStretch(1, 1) - selectionLayout.setColumnStretch(2, 0) - selectionLayout.addLayout(buttonsLayout, 2, 1) - - mainLayout.addLayout(selectionLayout) - self.setLayout(mainLayout) - - self.okButton = okButton - - # Connect events - okButton.clicked.connect(self.ok_cb) - if addNewFileButton: - newFileButton.clicked.connect(self.newFile_cb) - cancelButton.clicked.connect(self.close) - showInFileManagerButton.clicked.connect(self.showInFileManager) - - def listDoubleClicked(self, item): - self.ok_cb() - - def showInFileManager(self, checked=True): - myutils.showInExplorer(self.parent_path) - - def newFile_cb(self): - win = filenameDialog( - basename=f'{self.basename}segm', - hintText='Insert a filename for the segmentation file:', - existingNames=self.existingEndNames - ) - win.exec_() - if win.cancel: - return - self.cancel = False - self.newSegmEndName = win.entryText - self.close() - - def setSelectedItemFromText(self, itemText): - for i in range(self.listWidget.count()): - if self.listWidget.item(i).text() == itemText: - self.listWidget.setCurrentRow(i) - break - - def ok_cb(self, event=None): - self.cancel = False - try: - self.selectedItemText = self.listWidget.selectedItems()[0].text() - except IndexError: - self.cancel = True - self.close() - return - self.selectedItemIdx = self.items.index(self.selectedItemText) - self.selectedItemTexts = [ - selectedItem.text() - for selectedItem in self.listWidget.selectedItems() - ] - self.close() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show() - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - -class QDialogPbar(QDialog): - def __init__(self, title='Progress', infoTxt='', parent=None): - self.workerFinished = False - self.aborted = False - self.clickCount = 0 - super().__init__(parent) - - abort_text = 'Option+Command+C' if is_mac else 'Ctrl+Alt+C' - self.abort_text = abort_text - - self.setWindowTitle(f'{title} ({abort_text} to abort)') - self.setWindowFlags(Qt.Window) - - mainLayout = QVBoxLayout() - pBarLayout = QGridLayout() - - if infoTxt: - infoLabel = QLabel(infoTxt) - mainLayout.addWidget(infoLabel, alignment=Qt.AlignCenter) - - self.progressLabel = QLabel() - - self.QPbar = widgets.ProgressBar(self) - pBarLayout.addWidget(self.QPbar, 0, 0) - self.ETA_label = QLabel('NDh:NDm:NDs') - pBarLayout.addWidget(self.ETA_label, 0, 1) - - self.metricsQPbar = widgets.ProgressBar(self) - self.metricsQPbar.setValue(0) - pBarLayout.addWidget(self.metricsQPbar, 1, 0) - - #pBarLayout.setColumnStretch(2, 1) - - mainLayout.addWidget(self.progressLabel) - mainLayout.addLayout(pBarLayout) - - self.setLayout(mainLayout) - # self.setModal(True) - - def keyPressEvent(self, event): - isCtrlAlt = event.modifiers() == (Qt.ControlModifier | Qt.AltModifier) - if isCtrlAlt and event.key() == Qt.Key_C: - doAbort = self.askAbort() - if doAbort: - self.aborted = True - self.workerFinished = True - self.close() - - def askAbort(self): - msg = widgets.myMessageBox() - txt = html_utils.paragraph(f""" - Aborting with {self.abort_text} to abort - is not safe.

- The system status cannot be predicted and - it will require a restart.

- Are you sure you want to abort? - """) - yesButton, noButton = msg.critical( - self, 'Are you sure you want to abort?', txt, - buttonsTexts=('Yes', 'No') - ) - return msg.clickedButton == yesButton - - - def abort(self): - self.clickCount += 1 - self.aborted = True - if self.clickCount > 3: - self.workerFinished = True - self.close() - - def closeEvent(self, event): - if not self.workerFinished: - event.ignore() - -class FunctionParamsDialog(QBaseDialog): - sigValuesChanged = Signal(dict) - - def __init__( - self, params_argspecs, - function_name='Function', - df_metadata=None, - parent=None, - addApplyButton=False - ): - self.cancel = True - self.df_metadata = df_metadata - - super().__init__(parent) - - self.setWindowTitle(f'{function_name} parameters') - - self.mainLayout = QVBoxLayout() - - widgetsLayout, self.argsWidgets = self.getWidgetsLayout(params_argspecs) - - buttonsLayout = widgets.CancelOkButtonsLayout() - - self.buttonsLayout = buttonsLayout - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - if addApplyButton: - applyButton = widgets.viewPushButton('Apply') - applyButton.clicked.connect(self.emitValuesChanged) - buttonsLayout.insertWidget(3, applyButton) - self.applyButton = applyButton - - self.mainLayout.addLayout(widgetsLayout) - self.mainLayout.addSpacing(20) - self.mainLayout.addLayout(buttonsLayout) - - self.setLayout(self.mainLayout) - - def emitValuesChanged(self, *args, **kwargs): - self.sigValuesChanged.emit(self.functionKwargs()) - - def functionKwargs(self): - function_kwargs = { - argWidget.name:argWidget.valueGetter(argWidget.widget) - for argWidget in self.argsWidgets - } - return function_kwargs - - def kwargWidgetMapper(self) -> Dict[str, tuple]: - kwarg_widget_mapper = { - argWidget.name:(argWidget.widget, argWidget.valueSetter) - for argWidget in self.argsWidgets - } - return kwarg_widget_mapper - - def ok_cb(self): - self.cancel = False - - self.function_kwargs = self.functionKwargs() - - self.close() - - def getValueFromMetadata(self, name): - try: - value = self.df_metadata.at[name, 'values'] - except Exception as e: - # traceback.print_exc() - value = None - return value - - def getWidgetsLayout(self, params_argspecs): - widgetsLayout = QGridLayout() - ArgsWidgets_list = [] - - for row, ArgSpec in enumerate(params_argspecs): - if _types.is_widget_not_required(ArgSpec): - continue - - arg_name = ArgSpec.name - var_name = arg_name.replace('_', ' ') - var_name = f'{var_name[0].upper()}{var_name[1:]}' - label = QLabel(f'{var_name}: ') - metadata_val = self.getValueFromMetadata(ArgSpec.name) - widgetsLayout.addWidget(label, row, 0, alignment=Qt.AlignLeft) - try: - values = ArgSpec.type().values - isCustomListType = True - except Exception as err: - isCustomListType = False - - isVectorEntry = False - try: - if isinstance(ArgSpec.type(), _types.Vector): - isVectorEntry = True - except Exception as err: - pass - - isFolderPath = False - try: - if isinstance(ArgSpec.type(), _types.FolderPath): - isFolderPath = True - except Exception as err: - pass - - isCustomWidget = hasattr(ArgSpec.type, 'isWidget') - - if isCustomWidget: - widget = ArgSpec.type().widget - self.checkIfTypeCLassHasCastDtype(widget) - defaultVal = ArgSpec.default - valueSetter = widget.setValue - valueGetter = widget.value - widgetsLayout.addWidget(widget, row, 1, 1, 2) - try: - widget.sigValueChanged.connect(self.emitValuesChanged) - except Exception as err: - pass - elif isVectorEntry: - vectorLineEdit = widgets.VectorLineEdit() - self.checkIfTypeCLassHasCastDtype(ArgSpec.type) - vectorLineEdit.setValue(ArgSpec.default) - defaultVal = ArgSpec.default - valueSetter = widgets.VectorLineEdit.setValue - valueGetter = widgets.VectorLineEdit.value - widget = vectorLineEdit - widgetsLayout.addWidget(vectorLineEdit, row, 1, 1, 2) - widget.valueChangeFinished.connect(self.emitValuesChanged) - elif isFolderPath: - folderPathControl = widgets.FolderPathControl() - self.checkIfTypeCLassHasCastDtype(ArgSpec.type) - folderPathControl.setText(str(ArgSpec.default)) - widget = folderPathControl - defaultVal = str(ArgSpec.default) - valueSetter = widgets.FolderPathControl.setText - valueGetter = widgets.FolderPathControl.path - widgetsLayout.addWidget(folderPathControl, row, 1, 1, 2) - widget.sigValueChanged.connect(self.emitValuesChanged) - elif ArgSpec.type == bool: - booleanGroup = QButtonGroup() - booleanGroup.setExclusive(True) - checkBox = widgets.Toggle() - checkBox.setChecked(ArgSpec.default) - defaultVal = ArgSpec.default - valueSetter = widgets.Toggle.setChecked - valueGetter = widgets.Toggle.isChecked - widget = checkBox - widgetsLayout.addWidget( - checkBox, row, 1, 1, 2, alignment=Qt.AlignCenter - ) - widget.toggled.connect(self.emitValuesChanged) - elif ArgSpec.type == int: - spinBox = widgets.SpinBox() - if metadata_val is None: - spinBox.setValue(ArgSpec.default) - else: - spinBox.setValue(int(metadata_val)) - spinBox.isMetadataValue = True - defaultVal = ArgSpec.default - valueSetter = QSpinBox.setValue - valueGetter = QSpinBox.value - widget = spinBox - widgetsLayout.addWidget(spinBox, row, 1, 1, 2) - widget.sigValueChanged.connect(self.emitValuesChanged) - elif ArgSpec.type == float: - doubleSpinBox = widgets.FloatLineEdit() - if metadata_val is None: - doubleSpinBox.setValue(ArgSpec.default) - else: - doubleSpinBox.setValue(float(metadata_val)) - doubleSpinBox.isMetadataValue = True - widget = doubleSpinBox - defaultVal = ArgSpec.default - valueSetter = widgets.FloatLineEdit.setValue - valueGetter = widgets.FloatLineEdit.value - widgetsLayout.addWidget(doubleSpinBox, row, 1, 1, 2) - widget.valueChanged.connect(self.emitValuesChanged) - elif ArgSpec.type == os.PathLike: - filePathControl = widgets.filePathControl() - filePathControl.setText(str(ArgSpec.default)) - widget = filePathControl - defaultVal = str(ArgSpec.default) - valueSetter = widgets.filePathControl.setText - valueGetter = widgets.filePathControl.path - widgetsLayout.addWidget(filePathControl, row, 1, 1, 2) - widget.sigValueChanged.connect(self.emitValuesChanged) - elif isCustomListType: - items = ArgSpec.type().values - ArgSpec.type.cast_dtype = _types.to_str - defaultVal = str(ArgSpec.default) - combobox = widgets.AlphaNumericComboBox() - combobox.addItems(items) - combobox.setCurrentValue(defaultVal) - valueSetter = widgets.AlphaNumericComboBox.setCurrentValue - valueGetter = widgets.AlphaNumericComboBox.currentValue - widget = combobox - widgetsLayout.addWidget(combobox, row, 1, 1, 2) - widget.currentTextChanged.connect(self.emitValuesChanged) - else: - lineEdit = QLineEdit() - lineEdit.setText(str(ArgSpec.default)) - lineEdit.setAlignment(Qt.AlignCenter) - widget = lineEdit - defaultVal = str(ArgSpec.default) - valueSetter = QLineEdit.setText - valueGetter = QLineEdit.text - widgetsLayout.addWidget(lineEdit, row, 1, 1, 2) - widget.editingFinished.connect(self.emitValuesChanged) - - if ArgSpec.desc: - infoButton = self.getInfoButton(ArgSpec.name, ArgSpec.desc) - widgetsLayout.addWidget(infoButton, row, 3) - - argsInfo = ArgWidget( - name=ArgSpec.name, - type=ArgSpec.type, - widget=widget, - defaultVal=defaultVal, - valueSetter=valueSetter, - valueGetter=valueGetter - ) - ArgsWidgets_list.append(argsInfo) - - widgetsLayout.setColumnStretch(0, 0) - widgetsLayout.setColumnStretch(1, 1) - widgetsLayout.setColumnStretch(3, 0) - - return widgetsLayout, ArgsWidgets_list - - def checkIfTypeCLassHasCastDtype(self, cls): - cast_dtype = getattr(cls, 'cast_dtype', None) - if callable(cast_dtype): - return - - raise AttributeError( - 'The custom type or widget does not have the `cast_dtype` method. ' - 'Please, implement it. The method should cast the value to the ' - 'correct type.' - ) - - def getInfoButton(self, param_name, infoText): - infoButton = widgets.infoPushButton() - infoButton.param_name = param_name - infoButton.setToolTip( - f'Click to get more info about `{param_name}` parameter...' - ) - infoButton.infoText = infoText - infoButton.clicked.connect(self.showInfoParam) - return infoButton - - def showInfoParam(self): - text = self.sender().infoText - text = html_utils.rst_urls_to_html(text) - text = html_utils.rst_to_html(text) - text = html_utils.paragraph(text) - param_name = self.sender().param_name - msg = widgets.myMessageBox(wrapText=False) - msg.information(self, f'Info about `{param_name}` parameter', text) - -class QDialogModelParams(QDialog): - def __init__( - self, - init_params, - segment_params, - model_name, - is_tracker=False, - url=None, - parent=None, - initLastParams=True, - posData=None, - channels=None, - currentChannelName=None, - segmFileEndnames=None, - df_metadata=None, - force_postprocess_2D=False, - model_module=None, - action_type='', - addPreProcessParams=True, - addPostProcessParams=True, - extraParams=None, - extraParamsTitle=None, - ini_filename=None, - add_additional_segm_params=False - ): - self.cancel = True - super().__init__(parent) - self.channels = channels - self.is_tracker = is_tracker - self.currentChannelName = currentChannelName - self.channelCombobox = None - self.segmFileEndnames = segmFileEndnames - self.df_metadata = df_metadata - self.force_postprocess_2D = force_postprocess_2D - - self.skipSegmentation = False - if len(segment_params) > 0: - if segment_params[0].name.lower().find('skip_segmentation') != -1: - self.skipSegmentation = True - addPreProcessParams = False - else: - self.skipSegmentation = False - if ini_filename is not None: - self.ini_filename = ini_filename - elif is_tracker: - self.ini_filename = 'last_params_trackers.ini' - addPreProcessParams = False - addPostProcessParams = False - else: - self.ini_filename = 'last_params_segm_models.ini' - - self.addPreProcessParams = addPreProcessParams - - self.model_name = model_name - - self.setWindowTitle(f'{model_name} parameters') - - # Create main vertical layout and horizontal layout for two columns - mainLayout = QVBoxLayout() - - gridLayout = QGridLayout() - self.gridLayout = gridLayout - - loadFunc = self.loadLastSelection - - self.paramsGroupPosMapper = {} - - # LEFT COLUMN: Preprocessing params - row, col = 0, 0 - preProcessLayout = None - self.preProcessParamsWidget = None - if addPreProcessParams: - preProcessLayout = QVBoxLayout() - self.preProcessParamsWidget = PreProcessParamsWidget( - parent=self, addApplyButton=False - ) - self.preProcessParamsWidget.setChecked(False) - preProcessLayout.addWidget(self.preProcessParamsWidget) - self.preProcessParamsWidget.sigLoadRecipe.connect( - self.loadPreprocRecipe - ) - gridLayout.addLayout(preProcessLayout, row, col, 1, 2) - self.paramsGroupPosMapper[self.preProcessParamsWidget] = (row, col) - gridLayout.addItem(QSpacerItem(10, 5), 0, col+1) - # gridLayout.setColumnMinimumWidth(col+1, 15) - col += 2 - - # Center COLUMN: Init, Segmentation/Eval - row = 0 - self.secondColLayout = QVBoxLayout() - self.initParamsScrollArea = widgets.ScrollArea() - initParamsScrollAreaLayout = QVBoxLayout() - self.initParamsScrollArea.setVerticalLayout(initParamsScrollAreaLayout) - - initGroupBox, self.init_argsWidgets = self.createGroupParams( - init_params, 'Parameters for model initialization' - ) - self.init_params = init_params - initDefaultButton = widgets.reloadPushButton('Restore default') - initLoadLastSelButton = widgets.OpenFilePushButton( - 'Load last parameters' - ) - initLoadLastSelButton.setIcon(QIcon(':folder-open.svg')) - initButtonsLayout = QHBoxLayout() - initButtonsLayout.addStretch(1) - initButtonsLayout.addWidget(initDefaultButton) - initButtonsLayout.addWidget(initLoadLastSelButton) - initDefaultButton.clicked.connect(self.restoreDefaultInit) - initLoadLastSelButton.clicked.connect( - partial(loadFunc, f'{self.model_name}.init', self.init_argsWidgets) - ) - - initParamsScrollAreaLayout.addWidget(initGroupBox) - - initParamsLayout = QVBoxLayout() - initParamsLayout.addWidget(QLabel(f'{initGroupBox.title()}')) - initGroupBox.setTitle('') - initParamsLayout.addWidget(self.initParamsScrollArea) - initParamsLayout.addLayout(initButtonsLayout) - self.secondColLayout.addLayout(initParamsLayout) - self.paramsGroupPosMapper[self.initParamsScrollArea] = (0, col) - - self.segmentParamsScrollArea = None - if not self.skipSegmentation: - self.segmentParamsScrollArea = widgets.ScrollArea() - segmentParamsScrollAreaLayout = QVBoxLayout() - self.segmentParamsScrollArea.setVerticalLayout( - segmentParamsScrollAreaLayout - ) - if action_type: - runGroupboxTitle = f'Parameters for {action_type}' - elif is_tracker: - runGroupboxTitle = 'Parameters for tracking' - else: - runGroupboxTitle = 'Parameters for segmentation' - - segmentGroupBox, self.argsWidgets = self.createGroupParams( - segment_params, runGroupboxTitle, - addChannelSelector=True - ) - self.segment_params = segment_params - self.segmentGroupBox = segmentGroupBox - segmentDefaultButton = widgets.reloadPushButton('Restore default') - segmentLoadLastSelButton = widgets.OpenFilePushButton( - 'Load last parameters' - ) - segmentButtonsLayout = QHBoxLayout() - segmentButtonsLayout.addStretch(1) - segmentButtonsLayout.addWidget(segmentDefaultButton) - segmentButtonsLayout.addWidget(segmentLoadLastSelButton) - segmentDefaultButton.clicked.connect(self.restoreDefaultSegment) - section = f'{self.model_name}.segment' - segmentLoadLastSelButton.clicked.connect( - partial(loadFunc, section, self.argsWidgets) - ) - segmentParamsScrollAreaLayout.addWidget(segmentGroupBox) - - segmentParamsLayout = QVBoxLayout() - segmentParamsLayout.addWidget( - QLabel(f'{segmentGroupBox.title()}') - ) - segmentGroupBox.setTitle('') - segmentParamsLayout.addWidget(self.segmentParamsScrollArea) - segmentParamsLayout.addLayout(segmentButtonsLayout) - self.secondColLayout.addLayout(segmentParamsLayout) - self.paramsGroupPosMapper[self.segmentParamsScrollArea] = (1, col) - - gridLayout.addLayout(self.secondColLayout, row, col) - - gridLayout.addItem(QSpacerItem(10, 5), 0, col+1) - col += 2 - - # Buttons layout (spans both columns) - buttonsLayout = QHBoxLayout() - cancelButton = widgets.cancelPushButton(' Cancel ') - okButton = widgets.okPushButton(' Ok ') - - enableLoadingSavingRecipe = ( - not is_tracker and (addPreProcessParams or addPostProcessParams) - ) - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - if enableLoadingSavingRecipe: - loadEntireRecipeButton = widgets.OpenFilePushButton( - 'Load saved recipe...' - ) - saveEntireRecipeButton = widgets.savePushButton( - 'Save all parameters to recipe file...' - ) - buttonsLayout.addWidget(loadEntireRecipeButton) - buttonsLayout.addWidget(saveEntireRecipeButton) - loadEntireRecipeButton.clicked.connect(self.loadEntireRecipe) - saveEntireRecipeButton.clicked.connect(self.saveEntireRecipe) - - buttonsLayout.addWidget(okButton) - - buttonsLayout.setContentsMargins(0, 10, 0, 10) - - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.close) - - self.okButton = okButton - - # Extra params in right column - row = 0 - self.extraArgsWidgets = None - self.extraParamsScrollArea = None - if extraParams is not None: - self.extraParamsScrollArea = widgets.ScrollArea() - extraParamsScrollAreaLayout = QVBoxLayout() - self.extraParamsScrollArea.setVerticalLayout( - extraParamsScrollAreaLayout - ) - if extraParamsTitle is None: - extraParamsTitle = 'Additional parameters' - - self.extraGroupBox, self.extraArgsWidgets = self.createGroupParams( - extraParams, extraParamsTitle - ) - - extraDefaultButton = widgets.reloadPushButton('Restore default') - extraLoadLastSelButton = widgets.OpenFilePushButton( - 'Load last parameters' - ) - extraButtonsLayout = QHBoxLayout() - extraButtonsLayout.addStretch(1) - extraButtonsLayout.addWidget(extraDefaultButton) - extraButtonsLayout.addWidget(extraLoadLastSelButton) - extraDefaultButton.clicked.connect(self.restoreDefaultExtra) - section = f'{self.model_name}.extra' - extraLoadLastSelButton.clicked.connect( - partial(loadFunc, section, self.extraArgsWidgets) - ) - - extraParamsScrollAreaLayout.addWidget(self.extraGroupBox) - - extraParamsLayout = QVBoxLayout() - extraParamsLayout.addWidget(QLabel(f'{self.extraGroupBox.title()}')) - self.extraGroupBox.setTitle('') - extraParamsLayout.addWidget(self.extraParamsScrollArea) - extraParamsLayout.addLayout(extraButtonsLayout) - self.paramsGroupPosMapper[self.extraParamsScrollArea] = (row, col) - gridLayout.addLayout(extraParamsLayout, row, col) - row += 1 - - # Post-processing in right-most column - self.postProcessGroupbox = None - self.seeHereLabel = None - thirdColumnLayout = QVBoxLayout() - if addPostProcessParams: - # Add minimum size spinbox which is valid for all models - postProcessGroupbox = PostProcessSegmParams( - 'Post-processing segmentation parameters', posData, - force_postprocess_2D=force_postprocess_2D - ) - postProcessGroupbox.setCheckable(True) - postProcessGroupbox.setChecked(False) - self.postProcessGroupbox = postProcessGroupbox - - thirdColumnLayout.addWidget(postProcessGroupbox) - - postProcDefaultButton = widgets.reloadPushButton('Restore default') - postProcLoadLastSelButton = widgets.OpenFilePushButton( - 'Load last parameters' - ) - postProcButtonsLayout = QHBoxLayout() - postProcButtonsLayout.addStretch(1) - postProcButtonsLayout.addWidget(postProcDefaultButton) - postProcButtonsLayout.addWidget(postProcLoadLastSelButton) - postProcDefaultButton.clicked.connect(self.restoreDefaultPostprocess) - postProcLoadLastSelButton.clicked.connect( - self.loadLastSelectionPostProcess - ) - thirdColumnLayout.addLayout(postProcButtonsLayout) - thirdColumnLayout.addSpacing(15) - - if url is not None: - self.seeHereLabel = self.createSeeHereLabel(url) - thirdColumnLayout.addWidget( - self.seeHereLabel, alignment=Qt.AlignCenter - ) - - self.paramsGroupPosMapper[self.preProcessParamsWidget] = (row, col) - - # Additional segmentation params in right column - self.additionalSegmGroupbox = None - if add_additional_segm_params: - thirdColumnLayout.addWidget(widgets.QHLine()) - additionalSegmGroupbox = self.getAdditionalSegmParams() - thirdColumnLayout.addWidget(additionalSegmGroupbox) - self.additionalSegmGroupbox = additionalSegmGroupbox - self.paramsGroupPosMapper[self.additionalSegmGroupbox] = (row, col) - - thirdColumnLayout.addStretch(1) - gridLayout.addLayout(thirdColumnLayout, row, col) - row += 1 - - # Add everything to main layout - mainLayout.addLayout(gridLayout) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.configPars = self.readLastSelection() - if self.configPars is None: - initLoadLastSelButton.setDisabled(True) - segmentLoadLastSelButton.setDisabled(True) - if self.postProcessGroupbox is not None: - postProcLoadLastSelButton.setDisabled(True) - - if initLastParams: - initLoadLastSelButton.click() - if not self.skipSegmentation: - segmentLoadLastSelButton.click() - - if self.extraArgsWidgets is not None: - extraLoadLastSelButton.click() - - if self.postProcessGroupbox is not None: - postProcLoadLastSelButton.click() - - try: - self.connectCustomSignals(model_module) - except Exception as e: - printl(traceback.format_exc()) - - self.setLayout(mainLayout) - self.setFont(font) - # self.setModal(True) - - def warningNoSegmRecipes(self): - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph( - 'No segmentation recipes found!

' - 'To create a segmentation recipe you need click on ' - 'Save all parameters to recipe file... ' - 'button.' - ) - msg.warning(self, 'No segmentation recipes found!', txt) - - def selectIniFileToLoadEntireRecipe(self): - import qtpy.compat - recipe_filepath = qtpy.compat.getopenfilename( - parent=self, - caption='Select INI file to load entire recipe', - filters='INI (*.ini);;All Files (*)' - )[0] - if not recipe_filepath: - return - - self.loadRecipeFromFilepath(recipe_filepath) - - txt = html_utils.paragraph( - 'Done!

' - 'Segmentation recipe loaded from:' - ) - msg = widgets.myMessageBox() - msg.information( - self, 'Segmentation recipe loaded!', txt, - commands=(recipe_filepath,), - path_to_browse=os.path.dirname(recipe_filepath) - ) - - print('Done. Segmentation recipe loaded from:', recipe_filepath) - - def loadEntireRecipe(self): - segm_recipes_path_model = os.path.join( - segm_recipes_path, self.model_name - ) - - if not os.path.exists(segm_recipes_path_model): - # self.warningNoSegmRecipes() - self.selectIniFileToLoadEntireRecipe() - return - - recipe_files = os.listdir(segm_recipes_path_model) - - if not recipe_files: - # self.warningNoSegmRecipes() - self.selectIniFileToLoadEntireRecipe() - return - - headerLabels = ['Name', 'Date Created'] - items = [] - for recipe_file in recipe_files: - cp = config.ConfigParser() - cp.read(os.path.join(segm_recipes_path_model, recipe_file)) - date_created = cp['info']['created_on'] - items.append((recipe_file, date_created)) - - browseButton = widgets.browseFileButton( - 'Select INI file...', - title='Select INI file to load entire recipe', - openFolder=False, - start_dir=myutils.getMostRecentPath(), - ext={'INI': '.ini'} - ) - win = QTreeDialog( - items, - headerLabels=headerLabels, - title='Select a segmentation recipe to load', - infoText='Select a segmentation recipe to load:
', - path_to_browse=segm_recipes_path_model, - additional_buttons=(browseButton, ) - ) - browseButton.sigPathSelected.connect( - partial( - self.entireRecipeIniFileSelected, - selectRecipeWin=win, - sender=browseButton - ) - ) - win.exec_() - if win.cancel or not hasattr(win, 'selectedText'): - print('Loading segmentation recipe cancelled.') - return - - if win.clickedButton == browseButton: - recipe_filepath = win.selectedIniFilepath - else: - recipe_filename = win.selectedText - recipe_filepath = os.path.join( - segm_recipes_path_model, recipe_filename - ) - - self.loadRecipeFromFilepath(recipe_filepath) - - txt = html_utils.paragraph( - 'Done!

' - 'Segmentation recipe loaded from:' - ) - msg = widgets.myMessageBox() - msg.information( - self, 'Segmentation recipe laoded!', txt, - commands=(recipe_filepath,), - path_to_browse=os.path.dirname(recipe_filepath) - ) - - print('Done. Segmentation recipe loaded from:', recipe_filepath) - - def entireRecipeIniFileSelected( - self, recipe_filepath, selectRecipeWin=None, sender=None - ): - selectRecipeWin.selectedText = 'None' - selectRecipeWin.clickedButton = sender - selectRecipeWin.selectedIniFilepath = recipe_filepath - selectRecipeWin.cancel = False - selectRecipeWin.close() - - def loadRecipeFromFilepath(self, recipe_filepath): - cp = config.ConfigParser() - cp.read(recipe_filepath) - - self.loadPreprocRecipe(configPars=cp) - self.loadLastSelection( - f'{self.model_name}.init', self.init_argsWidgets, configPars=cp - ) - self.loadLastSelection( - f'{self.model_name}.segment', self.argsWidgets, configPars=cp - ) - if self.extraArgsWidgets: - self.loadLastSelection( - f'{self.model_name}.extra', self.extraArgsWidgets, configPars=cp - ) - self.loadLastSelectionPostProcess(configPars=cp) - - def saveEntireRecipe(self): - segm_recipes_path_model = os.path.join( - segm_recipes_path, self.model_name - ) - try: - existingNames=os.listdir(segm_recipes_path_model) - except FileNotFoundError: - existingNames = [] - - win = filenameDialog( - title='Filename for segmentation recipe', - basename='segmentation_recipe', - ext='.ini', - hintText='Insert a filename for the segmentation recipe:', - allowEmpty=False, - parent=self, - existingNames=existingNames - ) - win.exec_() - if win.cancel: - return - - ini_filename = win.filename - os.makedirs(segm_recipes_path, exist_ok=True) - os.makedirs(segm_recipes_path_model, exist_ok=True) - ini_filepath = os.path.join(segm_recipes_path_model, ini_filename) - - configPars = self.getConfigPars(create_new=True) - - if hasattr(self, 'reduceMemUsageToggle'): - configPars[f'{self.model_name}.additional_segm_params'] = {} - reduceMemoryUsage = self.reduceMemUsageToggle.isChecked() - option = self.reduceMemUsageToggle.label - configPars[f'{self.model_name}.additional_segm_params'][option] = ( - str(reduceMemoryUsage) - ) - - configPars['info'] = {} - configPars['info']['created_on'] = datetime.datetime.now().strftime( - r'%Y/%m/%d %H:%M' - ) - - with open(ini_filepath, 'w') as configfile: - configPars.write(configfile) - - txt = html_utils.paragraph( - 'Done!

' - 'Segmentation recipe saved to:' - ) - msg = widgets.myMessageBox() - msg.information( - self, 'Segmnentation recipe saved!', txt, - commands=(ini_filepath,), - path_to_browse=os.path.dirname(ini_filepath) - ) - - print('Done. Segmentation recipe saved to:', ini_filepath) - - def getAdditionalSegmParams(self): - additionalSegmGroupbox = QGroupBox('Additional segmentation parameters') - local_row = 0 - additionalSegmLayout = QGridLayout() - option = 'Reduce memory usage' - additionalSegmLayout.addWidget( - QLabel(f'{option}: '), local_row, 0, - alignment=Qt.AlignRight - ) - self.reduceMemUsageToggle = widgets.Toggle() - additionalSegmLayout.addWidget( - self.reduceMemUsageToggle, local_row, 1, 1, 2, - alignment=Qt.AlignCenter - ) - self.reduceMemUsageToggle.label = option - reduceMemUsageInfoButton = widgets.infoPushButton() - additionalSegmLayout.addWidget(reduceMemUsageInfoButton, local_row, 3) - reduceMemUsageInfoButton.clicked.connect( - self.showInfoReduceMemUsage - ) - additionalSegmLayout.setColumnStretch(0, 0) - additionalSegmLayout.setColumnStretch(1, 1) - additionalSegmLayout.setColumnStretch(3, 0) - additionalSegmGroupbox.setLayout(additionalSegmLayout) - return additionalSegmGroupbox - - def showInfoReduceMemUsage(self): - infoText = html_utils.paragraph(f""" - If you are experiencing memory issues, you can try reducing the - memory usage by toggling this option.

- This will reduce the memory usage by segmenting timelapse data - frame-by-frame instead of all frames at once. - """) - msg = widgets.myMessageBox(wrapText=False) - msg.information( - self, 'Reduce memory usage', infoText - ) - - def loadPreprocRecipe(self, configPars=None): - if self.configPars is None and configPars is None: - return - - if configPars is None: - configPars = self.configPars - - preprocConfigPars = {} - for section in configPars.sections(): - if not section.startswith(f'{self.model_name}.preprocess'): - continue - - preprocConfigPars[section] = configPars[section] - - if not preprocConfigPars: - return - - self.preProcessParamsWidget.loadRecipe(preprocConfigPars) - - def connectCustomSignals(self, model_module): - if model_module is None: - return - - if not hasattr(model_module, 'CustomSignals'): - return - - customSignals = model_module.CustomSignals() - for slot_info in customSignals.slots_info: - group = slot_info['group'] - widget_name = slot_info['widget_name'] - if group == 'init': - ArgsWidgets_list = self.init_argsWidgets - else: - ArgsWidgets_list = self.argsWidgets - for argwidget in ArgsWidgets_list: - if argwidget.name == widget_name: - signal = getattr(argwidget.widget, slot_info['signal']) - signal.connect(partial(slot_info['slot'], self)) - break - - def selectedFeaturesRange(self): - if self.postProcessGroupbox is None: - return {} - return self.postProcessGroupbox.selectedFeaturesRange() - - def groupedFeatures(self): - if self.postProcessGroupbox is None: - return {} - return self.postProcessGroupbox.groupedFeatures() - - def setChannelNames(self, chNames): - if not hasattr(self, 'channelsCombobox'): - return - - items = ['None'] - items.extend(chNames) - self.channelsCombobox.addItems(items) - - def getValueFromMetadata(self, name): - try: - value = self.df_metadata.at[name, 'values'] - except Exception as e: - # traceback.print_exc() - value = None - return value - - def criticalSegmFileRequiredButNoneAvailable(self): - model_name = f'{self.model_name} model' - action_txt = ( - 'Please, segment the correct channel before using ' - f'{self.model_name}.' - ) - if self.model_name == 'skip_segmentation': - model_name = 'Skipping the segmentation' - action_txt = ( - 'To be able to skip the segmentation step, you need ' - 'create at least one segmentation file.' - ) - txt = html_utils.paragraph(f""" - {model_name} - requires an additional segmentation file - but there are none available!

- {action_txt} -

Thank you for you patience! - """) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'Segmentation file required', txt) - raise FileNotFoundError( - 'Model requires segmentation file but none are available.' - ) - - - def checkAddSegmEndnameCombobox(self, ArgSpec, groupBoxLayout, row): - if ArgSpec.name != 'Auxiliary segmentation file': - return False - - if self.segmFileEndnames is None or not self.segmFileEndnames: - self.criticalSegmFileRequiredButNoneAvailable() - - label = QLabel(f'{ArgSpec.name}: ') - groupBoxLayout.addWidget( - label, row, 0, alignment=Qt.AlignRight - ) - items = self.segmFileEndnames - self.segmEndnameCombobox = widgets.QCenteredComboBox() - self.segmEndnameCombobox.addItems(items) - groupBoxLayout.addWidget(self.segmEndnameCombobox, row, 1, 1, 2) - return True - - - def createGroupParams(self, ArgSpecs_list, groupName, addChannelSelector=False): - ArgsWidgets_list = [] - groupBox = QGroupBox(groupName) - groupBoxLayout = QGridLayout() - - start_row = 0 - if self.is_tracker and self.channels is not None and addChannelSelector: - label = QLabel(f'Input image: ') - groupBoxLayout.addWidget( - label, start_row, 0, alignment=Qt.AlignRight - ) - items = ['None', *self.channels] - self.channelCombobox = widgets.QCenteredComboBox() - self.channelCombobox.addItems(items) - groupBoxLayout.addWidget(self.channelCombobox, start_row, 1, 1, 2) - if self.currentChannelName is not None: - self.channelCombobox.setCurrentText(self.currentChannelName) - infoText = ( - 'Some trackers require the intensity image as input.

' - 'If this one does not require it, leave the selected value ' - 'to `None`.' - ) - infoButton = self.getInfoButton('Input image', infoText) - groupBoxLayout.addWidget(infoButton, start_row, 3) - start_row += 1 - - addSecondChannelSelector = addChannelSelector - if len(ArgSpecs_list) > 0: - if addSecondChannelSelector and ArgSpecs_list[0].docstring is not None: - isSingleChannel = ArgSpecs_list[0].docstring.lower().find( - 'single channel only' - ) != -1 - if isSingleChannel: - addSecondChannelSelector = False - - isDualChannelModel = ( - self.model_name.find('cellpose') != -1 - or any([ - _types.is_second_channel_type(ArgSpec.type) - for ArgSpec in ArgSpecs_list - ]) - ) - askSecondChannel = isDualChannelModel and addSecondChannelSelector - - if askSecondChannel: - label = QLabel('Second channel (optional): ') - groupBoxLayout.addWidget(label, start_row, 0, alignment=Qt.AlignRight) - self.channelsCombobox = widgets.QCenteredComboBox() - groupBoxLayout.addWidget(self.channelsCombobox, start_row, 1, 1, 2) - infoText = ( - 'Some models can merge two channels (e.g., cyto + ' - 'nucleus) to obtain better perfomance.\n\n' - 'Select a channel as additional input to the model.' - ) - infoButton = self.getInfoButton('Second channel', infoText) - groupBoxLayout.addWidget(infoButton, start_row, 3) - start_row += 1 - - exclusive_withs = dict() - default_exclusives = dict() - row_mapper = dict() - for row, ArgSpec in enumerate(ArgSpecs_list): - if _types.is_second_channel_type(ArgSpec.type): - continue - - if _types.is_widget_not_required(ArgSpec): - continue - - row = row + start_row - skip = self.checkAddSegmEndnameCombobox( - ArgSpec, groupBoxLayout, row - ) - if skip: - continue - - arg_name = ArgSpec.name - var_name = arg_name.replace('_', ' ') - var_name = f'{var_name[0].upper()}{var_name[1:]}' - label = QLabel(f'{var_name}: ') - metadata_val = self.getValueFromMetadata(ArgSpec.name) - groupBoxLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) - try: - values = ArgSpec.type().values - isCustomListType = True - except Exception as err: - isCustomListType = False - - isVectorEntry = False - try: - if isinstance(ArgSpec.type(), _types.Vector): - isVectorEntry = True - except Exception as err: - pass - - isFolderPath = False - try: - if isinstance(ArgSpec.type(), _types.FolderPath): - isFolderPath = True - except Exception as err: - pass - - try: - exclusive_with = ArgSpec.type().is_exclusive_with - except Exception as err: - exclusive_with = [] - - try: - default_exclusive = ArgSpec.type().default_exclusive - except Exception as err: - default_exclusive = '' - - exclusive_withs[arg_name] = exclusive_with - default_exclusives[arg_name] = default_exclusive - row_mapper[arg_name] = row - - isCustomWidget = hasattr(ArgSpec.type, 'isWidget') - - if isCustomWidget: - widget = ArgSpec.type().widget - defaultVal = ArgSpec.default - valueSetter = widget.setValue - valueGetter = widget.value - changeSig = widget.sigValueChanged - groupBoxLayout.addWidget(widget, row, 1, 1, 2) - elif isVectorEntry: - vectorLineEdit = widgets.VectorLineEdit() - vectorLineEdit.setValue(ArgSpec.default) - defaultVal = ArgSpec.default - valueSetter = widgets.VectorLineEdit.setValue - valueGetter = widgets.VectorLineEdit.value - changeSig = vectorLineEdit.valueChanged - widget = vectorLineEdit - groupBoxLayout.addWidget(vectorLineEdit, row, 1, 1, 2) - elif isFolderPath: - folderPathControl = widgets.FolderPathControl() - folderPathControl.setText(str(ArgSpec.default)) - widget = folderPathControl - defaultVal = str(ArgSpec.default) - valueSetter = widgets.FolderPathControl.setText - valueGetter = widgets.FolderPathControl.path - changeSig = widget.sigValueChanged - groupBoxLayout.addWidget(folderPathControl, row, 1, 1, 2) - elif ArgSpec.type == bool: - booleanGroup = QButtonGroup() - booleanGroup.setExclusive(True) - checkBox = widgets.Toggle() - checkBox.setChecked(ArgSpec.default) - defaultVal = ArgSpec.default - valueSetter = widgets.Toggle.setChecked - valueGetter = widgets.Toggle.isChecked - changeSig = checkBox.toggled - widget = checkBox - groupBoxLayout.addWidget( - checkBox, row, 1, 1, 2, alignment=Qt.AlignCenter - ) - elif ArgSpec.type == int: - spinBox = widgets.SpinBox() - if metadata_val is None: - spinBox.setValue(ArgSpec.default) - else: - spinBox.setValue(int(metadata_val)) - spinBox.isMetadataValue = True - defaultVal = ArgSpec.default - valueSetter = QSpinBox.setValue - valueGetter = QSpinBox.value - changeSig = spinBox.sigValueChanged - widget = spinBox - groupBoxLayout.addWidget(spinBox, row, 1, 1, 2) - elif ArgSpec.type == float: - doubleSpinBox = widgets.FloatLineEdit() - if metadata_val is None: - doubleSpinBox.setValue(ArgSpec.default) - else: - doubleSpinBox.setValue(float(metadata_val)) - doubleSpinBox.isMetadataValue = True - widget = doubleSpinBox - defaultVal = ArgSpec.default - valueSetter = widgets.FloatLineEdit.setValue - valueGetter = widgets.FloatLineEdit.value - changeSig = doubleSpinBox.valueChanged - groupBoxLayout.addWidget(doubleSpinBox, row, 1, 1, 2) - elif ArgSpec.type == os.PathLike: - filePathControl = widgets.filePathControl() - filePathControl.setText(str(ArgSpec.default)) - widget = filePathControl - defaultVal = str(ArgSpec.default) - valueSetter = widgets.filePathControl.setText - valueGetter = widgets.filePathControl.path - changeSig = filePathControl.sigValueChanged - groupBoxLayout.addWidget(filePathControl, row, 1, 1, 2) - elif isCustomListType: - items = ArgSpec.type().values - defaultVal = str(ArgSpec.default) - combobox = widgets.AlphaNumericComboBox() - combobox.addItems(items) - combobox.setCurrentValue(defaultVal) - valueSetter = widgets.AlphaNumericComboBox.setCurrentValue - valueGetter = widgets.AlphaNumericComboBox.currentValue - changeSig = combobox.currentTextChanged - widget = combobox - groupBoxLayout.addWidget(combobox, row, 1, 1, 2) - else: - lineEdit = QLineEdit() - lineEdit.setText(str(ArgSpec.default)) - lineEdit.setAlignment(Qt.AlignCenter) - widget = lineEdit - defaultVal = str(ArgSpec.default) - valueSetter = QLineEdit.setText - valueGetter = QLineEdit.text - changeSig = lineEdit.editingFinished - groupBoxLayout.addWidget(lineEdit, row, 1, 1, 2) - - if ArgSpec.desc: - infoButton = self.getInfoButton(ArgSpec.name, ArgSpec.desc) - groupBoxLayout.addWidget(infoButton, row, 3) - - argsInfo = ArgWidget( - name=ArgSpec.name, - type=ArgSpec.type, - widget=widget, - defaultVal=defaultVal, - valueSetter=valueSetter, - valueGetter=valueGetter, - changeSig=changeSig - ) - ArgsWidgets_list.append(argsInfo) - - exclusive_group = core.connected_components_in_undirected_graph( - exclusive_withs - ) - - for group in exclusive_group: - if len(group) == 1: - continue - for arg_name in group: - default_exclusive = default_exclusives[arg_name] - row = row_mapper[arg_name] - - argsInfo = ArgsWidgets_list[row] - valueSetter = argsInfo.valueSetter - widget = argsInfo.widget - valueGetter = argsInfo.valueGetter - - argsInfo.valueGetter = qutils.replace_certain_vals( - argsInfo.valueGetter, default_exclusive, None - ) - - for arg_name_other in group: - if arg_name == arg_name_other: - continue - row_other = row_mapper[arg_name_other] - argsInfo_other = ArgsWidgets_list[row_other] - changeSig_other = argsInfo_other.changeSig - changeSig_other.connect( - partial(qutils.set_exclusive_valueSetter, widget, - valueSetter, default_exclusive) - ) - - groupBoxLayout.setColumnStretch(0, 0) - groupBoxLayout.setColumnStretch(1, 1) - groupBoxLayout.setColumnStretch(3, 0) - nrows = groupBoxLayout.rowCount() - groupBoxLayout.setRowStretch(nrows, 1) - - groupBox.setLayout(groupBoxLayout) - return groupBox, ArgsWidgets_list - - def getInfoButton(self, param_name, infoText): - infoButton = widgets.infoPushButton() - infoButton.param_name = param_name - infoButton.setToolTip( - f'Click to get more info about `{param_name}` parameter...' - ) - infoButton.infoText = infoText - infoButton.clicked.connect(self.showInfoParam) - return infoButton - - def showInfoParam(self): - text = self.sender().infoText - text = text.replace('\n', '
') - text = html_utils.rst_urls_to_html(text) - text = html_utils.rst_to_html(text) - text = html_utils.paragraph(text) - param_name = self.sender().param_name - msg = widgets.myMessageBox(wrapText=False) - msg.information(self, f'Info about `{param_name}` parameter', text) - - def restoreDefaultInit(self): - for argWidget in self.init_argsWidgets: - defaultVal = argWidget.defaultVal - widget = argWidget.widget - valueSetter = argWidget.valueSetter - qutils.set_exclusive_valueSetter( - widget, valueSetter, defaultVal - ) - - def restoreDefaultSegment(self): - for argWidget in self.argsWidgets: - defaultVal = argWidget.defaultVal - widget = argWidget.widget - valueSetter = argWidget.valueSetter - qutils.set_exclusive_valueSetter( - widget, valueSetter, defaultVal - ) - - def restoreDefaultExtra(self): - for argWidget in self.extraArgsWidgets: - defaultVal = argWidget.defaultVal - widget = argWidget.widget - valueSetter = argWidget.valueSetter - qutils.set_exclusive_valueSetter( - widget, valueSetter, defaultVal - ) - - def restoreDefaultPostprocess(self): - self.postProcessGroupbox.restoreDefault() - - def readLastSelection(self): - self.ini_path = os.path.join(settings_folderpath, self.ini_filename) - - if not os.path.exists(self.ini_path): - return None - - print(f'Reading last selected parameters from: {self.ini_path}') - configPars = config.ConfigParser() - configPars.read(self.ini_path) - return configPars - - def setValuesFromParams(self, init_params, segment_params, extra_params=None): - sections = { - f'{self.model_name}.init': (init_params, self.init_argsWidgets), - f'{self.model_name}.segment': (segment_params, self.argsWidgets), - } - if extra_params is not None: - sections[f'{self.model_name}.extra'] = ( - extra_params, self.extraArgsWidgets - ) - - for section, values in sections.items(): - params, argWidgetList = values - for argWidget in argWidgetList: - val = params.get(argWidget.name) - widget = argWidget.widget - if val is None: - continue - casters = [lambda x: x, int, float, str, bool] - for caster in casters: - try: - argWidget.valueSetter(widget, caster(val)) - break - except Exception as e: - continue - - def loadLastSelection( - self, section, argWidgetList, checked=False, configPars=None - ): - if self.configPars is None and configPars is None: - return - - if configPars is None: - configPars = self.configPars - - getters = ['getboolean', 'getint', 'getfloat', 'get'] - try: - options = configPars.options(section) - except Exception: - return - - for argWidget in argWidgetList: - option = argWidget.name - val = None - for getter in getters: - try: - val = getattr(configPars, getter)(section, option) - break - except Exception as err: - pass - widget = argWidget.widget - - if hasattr(widget, 'isMetadataValue'): - continue - if val is None: - continue - - casters = [lambda x: x, int, float, str, bool] - for caster in casters: - try: - val = caster(val) - valueSetter = argWidget.valueSetter - qutils.set_exclusive_valueSetter( - widget, valueSetter, val - ) - break - except Exception as e: - printl(traceback.format_exc()) - continue - - def loadLastSelectionPostProcess(self, checked=False, configPars=None): - if self.postProcessGroupbox is None: - return - - postProcessSection = f'{self.model_name}.postprocess' - - if isinstance(configPars, bool): - configPars = None - - if configPars is None: - configPars = self.configPars - - if postProcessSection in configPars.sections(): - try: - minSize = configPars.getint( - postProcessSection, 'minSize', fallback=10 - ) - except ValueError: - minSize = 10 - - try: - minSolidity = configPars.getfloat( - postProcessSection, 'minSolidity', fallback=0.5 - ) - except ValueError: - minSolidity = 0.5 - - try: - maxElongation = configPars.getfloat( - postProcessSection, 'maxElongation', fallback=3 - ) - except ValueError: - maxElongation = 3 - - try: - minObjSizeZ = configPars.getint( - postProcessSection, 'min_obj_no_zslices', fallback=3 - ) - except ValueError: - minObjSizeZ = 3 - - kwargs = { - 'min_solidity': minSolidity, - 'min_area': minSize, - 'max_elongation': maxElongation, - 'min_obj_no_zslices': minObjSizeZ - } - self.postProcessGroupbox.restoreFromKwargs(kwargs) - - applyPostProcessing = configPars.getboolean( - postProcessSection, 'applyPostProcessing' - ) - self.postProcessGroupbox.setChecked(applyPostProcessing) - - customPostProcessSection = f'{self.model_name}.custom_postprocess' - if postProcessSection not in configPars.sections(): - return - - selectFeaturesWidget = ( - self.postProcessGroupbox.selectedFeaturesDialog.groupbox - ) - selectFeaturesWidget.resetFields() - f = 0 - for col_name, value in configPars[customPostProcessSection].items(): - low, high = value.split(',') - low = low.strip() - high = high.strip() - if f > 0: - selectFeaturesWidget.addFeatureField() - - selector = selectFeaturesWidget.selectors[f] - selector.selectButton.setText(col_name) - selector.selectButton.setFlat(True) - - feature_group = measurements.get_metric_group_name(col_name) - selector.featureGroup = feature_group - - if low != 'None': - try: - low_val = int(low) - except ValueError: - low_val = float(low) - - selector.lowRangeWidgets.checkbox.setChecked(True) - selector.lowRangeWidgets.spinbox.setValue(low_val) - - if high != 'None': - try: - high_val = int(high) - except ValueError: - high_val = float(high) - - selector.highRangeWidgets.checkbox.setChecked(True) - selector.highRangeWidgets.spinbox.setValue(high_val) - - f += 1 - - def createSeeHereLabel(self, url): - htmlTxt = f'here' - seeHereLabel = QLabel() - seeHereLabel.setText(f""" -

- See {htmlTxt} for details on the parameters -

- """) - seeHereLabel.setTextFormat(Qt.RichText) - seeHereLabel.setTextInteractionFlags(Qt.TextBrowserInteraction) - seeHereLabel.setOpenExternalLinks(True) - seeHereLabel.setStyleSheet("padding:12px 0px 0px 0px;") - return seeHereLabel - - def argsWidgets_to_kwargs(self, argsWidgets): - kwargs_dict = { - argWidget.name:argWidget.valueGetter(argWidget.widget) - for argWidget in argsWidgets - } - return kwargs_dict - - def getInitKwargs(self): - init_kwargs = self.argsWidgets_to_kwargs(self.init_argsWidgets) - if hasattr(self, 'segmEndnameCombobox'): - init_kwargs['segm_endname'] = ( - self.segmEndnameCombobox.currentText() - ) - - return init_kwargs - - def getModelKwargs(self): - if self.skipSegmentation: - return {} - - return self.argsWidgets_to_kwargs(self.argsWidgets) - - def getExtraKwargs(self): - if self.extraArgsWidgets is None: - return {} - - return self.argsWidgets_to_kwargs(self.extraArgsWidgets) - - def ok_cb(self, checked): - self.cancel = False - self.preproc_recipe = None - if self.preProcessParamsWidget is not None: - self.preproc_recipe = self.preProcessParamsWidget.recipe() - if self.preproc_recipe is None: - return - - self.init_kwargs = self.getInitKwargs() - - if self.extraArgsWidgets: - self.extra_kwargs = self.getExtraKwargs() - - self.model_kwargs = self.getModelKwargs() - self.segment_kwargs = self.model_kwargs - - if self.postProcessGroupbox is not None: - self.applyPostProcessing = self.postProcessGroupbox.isChecked() - self.standardPostProcessKwargs = self.postProcessGroupbox.kwargs() - self.secondChannelName = None - if hasattr(self, 'channelsCombobox'): - self.secondChannelName = self.channelsCombobox.currentText() - if self.secondChannelName == 'None': - self.secondChannelName = None - self.inputChannelName = 'None' - if self.channelCombobox is not None: - self.inputChannelName = self.channelCombobox.currentText() - - self.reduceMemoryUsage = False - if hasattr(self, 'reduceMemUsageToggle'): - self.reduceMemoryUsage = self.reduceMemUsageToggle.isChecked() - self.customPostProcessFeatures = self.selectedFeaturesRange() - self.customPostProcessGroupedFeatures = self.groupedFeatures() - self.saveLastSelection() - self.freePosData() - self.close() - - def freePosData(self): - if hasattr(self, 'postProcessGroupbox'): - try: - for selector in self.postProcessGroupbox.selectedFeaturesDialog.groupbox.selectors: - qutils.hardDelete(selector) - except AttributeError: - pass - try: - qutils.hardDelete(self.postProcessGroupbox.selectedFeaturesDialog.groupbox) - except AttributeError: - pass - try: - qutils.hardDelete(self.postProcessGroupbox.selectedFeaturesDialog) - except AttributeError: - pass - try: - qutils.hardDelete(self.postProcessGroupbox) - except AttributeError: - pass - - def getConfigPars(self, create_new=False): - if self.configPars is None or create_new: - configPars = config.ConfigParser() - else: - configPars = self.configPars - - if self.preProcessParamsWidget is not None: - preprocCp = self.preProcessParamsWidget.recipeConfigPars( - self.model_name - ) - for section in preprocCp.sections(): - configPars[section] = preprocCp[section] - - configPars[f'{self.model_name}.init'] = {} - configPars[f'{self.model_name}.segment'] = {} - configPars[f'{self.model_name}.extra'] = {} - - init_kwargs = self.getInitKwargs() - model_kwargs = self.getModelKwargs() - - for key, val in init_kwargs.items(): - configPars[f'{self.model_name}.init'][key] = str(val) - for key, val in model_kwargs.items(): - configPars[f'{self.model_name}.segment'][key] = str(val) - if self.extraArgsWidgets: - extra_kwargs = self.getExtraKwargs() - for key, val in extra_kwargs.items(): - configPars[f'{self.model_name}.extra'][key] = str(val) - - configPars[f'{self.model_name}.postprocess'] = {} - if self.postProcessGroupbox is not None: - postProcKwargs = self.postProcessGroupbox.kwargs() - postProcessConfig = configPars[f'{self.model_name}.postprocess'] - postProcessConfig['minSize'] = str(postProcKwargs['min_area']) - postProcessConfig['minSolidity'] = str(postProcKwargs['min_solidity']) - postProcessConfig['maxElongation'] = str( - postProcKwargs['max_elongation'] - ) - postProcessConfig['min_obj_no_zslices'] = str( - postProcKwargs['min_obj_no_zslices'] - ) - postProcessConfig['applyPostProcessing'] = str( - self.postProcessGroupbox.isChecked() - ) - - custom_postproc_section = f'{self.model_name}.custom_postprocess' - configPars[custom_postproc_section] = {} - if self.postProcessGroupbox is not None: - selectFeaturesWidget = ( - self.postProcessGroupbox.selectedFeaturesDialog.groupbox - ) - for selector in selectFeaturesWidget.selectors: - col_name = selector.selectButton.text() - lowStr = 'None' - highStr = 'None' - if selector.lowRangeWidgets.checkbox.isChecked(): - lowVal = selector.lowRangeWidgets.spinbox.value() - lowStr = str(lowVal) - if selector.highRangeWidgets.checkbox.isChecked(): - highVal = selector.highRangeWidgets.spinbox.value() - highStr = str(highVal) - - configPars[custom_postproc_section][col_name] = ( - f'{lowStr}, {highStr}' - ) - - - return configPars - - def saveLastSelection(self): - self.configPars = self.getConfigPars() - with open(self.ini_path, 'w') as configfile: - self.configPars.write(configfile) - - mode = 'Segmentation' if not self.is_tracker else 'Tracking' - - print(f'{mode} parameters saved at "{self.ini_path}"') - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show() - if self.model_name == 'thresholding': - self.segmentGroupBox.setDisabled(True) - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - self.freePosData() - if hasattr(self, 'loop'): - self.loop.exit() - - def cancel_cb(self, checked): - self.cancel = True - self.freePosData() - - def showEvent(self, event) -> None: - buttonHeight = self.okButton.minimumSizeHint().height() - heightInitParams = ( - self.initParamsScrollArea.minimumHeightNoScrollbar() - ) - heightLeft = 70 + buttonHeight - heightCenter = heightInitParams - heightRight = 0 - if self.segmentParamsScrollArea is not None: - heightSegmentParams = ( - self.segmentParamsScrollArea.minimumHeightNoScrollbar() - ) - heightCenter += (heightSegmentParams+ 70 + buttonHeight) - - rowInitParams, _ = self.paramsGroupPosMapper[self.initParamsScrollArea] - rowSegmParams, _ = self.paramsGroupPosMapper[self.segmentParamsScrollArea] - - numInitParams = len(self.init_params) - numSegmentParams = len(self.segment_params) - - try: - segmentParamsStretch = max(1, round(numSegmentParams/numInitParams)) - except ZeroDivisionError as err: - segmentParamsStretch = 1 - self.secondColLayout.setStretch(rowInitParams, 1) - self.secondColLayout.setStretch(rowSegmParams, segmentParamsStretch) - - if self.extraParamsScrollArea is not None: - heightRight += ( - self.extraParamsScrollArea.minimumHeightNoScrollbar() - + 70 + buttonHeight - ) - - - if self.additionalSegmGroupbox is not None: - heightRight += self.additionalSegmGroupbox.minimumSizeHint().height() - heightRight += buttonHeight - if self.preProcessParamsWidget is not None: - heightPreprocParams = ( - self.preProcessParamsWidget.minimumSizeHint().height() - ) - heightLeft += heightPreprocParams - heightLeft += buttonHeight - if self.postProcessGroupbox is not None: - heightRight += self.postProcessGroupbox.minimumSizeHint().height() - heightRight += buttonHeight - if self.seeHereLabel is not None: - heightRight += self.seeHereLabel.minimumSizeHint().height() - height = max(heightLeft, heightRight, heightCenter) - screenHeight = self.screen().size().height() - screenGeom = self.screen().geometry() - screenLeft = screenGeom.left() - screenRight = screenGeom.right() - screenCenter = (screenLeft + screenRight) / 2 - width = self.sizeHint().width() - windowLeft = int(screenCenter - width/2) - self.move(windowLeft, 20) - - if height >= screenHeight - 150: - height = screenHeight - 150 - self.resize(width, height) - -class downloadModel: - def __init__(self, model_name, parent=None): - self.loop = None - self.model_name = model_name - self._parent = parent - - def download(self): - model_url = myutils._model_url(self.model_name) - if model_url is None: - return - - _, model_path = myutils.get_model_path( - self.model_name, create_temp_dir=False - ) - model_name = self.model_name - model_exists = myutils.check_model_exists(model_path, model_name) - if not model_exists: - self.warnDownloadModel(model_path, self.model_name) - try: - self._parent.logger.info( - f'Downloading {self.model_name} model(s) to "{model_path}"' - ) - except Exception as err: - pass - - success = myutils.download_model(self.model_name) - if not success: - self.criticalDowloadFailed() - - def warnDownloadModel(self, model_path, model_name): - txt = html_utils.paragraph( - 'Cell-ACDC needs to download the model ' - f'{model_name}.

' - 'The files will be dowloaded into the following folder:

' - f'{model_path}

' - 'Progress will be displayed in the terminal.
' - ) - msg = widgets.myMessageBox() - msg.information(self._parent, 'Download model', txt) - - def criticalDowloadFailed(self): - import cellacdc - model_name = self.model_name - m = model_name.lower() - weights_filenames = getattr(cellacdc, f'{m}_weights_filenames') - url, alternative_url = myutils._model_url( - model_name, return_alternative=True - ) - url_href = f'this link' - alternative_url_href = f'this link' - _, model_path = myutils.get_model_path(model_name, create_temp_dir=False) - txt = html_utils.paragraph(f""" - Automatic download of {model_name} failed.

- Please, manually download the model weights from {url_href} or - {alternative_url_href}.

- Next, unzip the content (or move the files if not a zip archive) - of the downloaded file into the following folder:

- {model_path}

- NOTE: if clicking on the link above does not work - copy one of the links below and paste it into the browser

- {url} -

- {alternative_url} - """) - weights_paths = [os.path.join(model_path, f) for f in weights_filenames] - weights = '\n\n'.join(weights_paths) - detailsText = ( - f'Files that {model_name} requires:\n\n' - f'{weights}' - ) - msg = widgets.myMessageBox() - msg.critical( - self._parent, f'Download of {model_name} failed', txt, - detailsText=detailsText - ) - self.close_() - - def close_(self): - return - # self.hide() - # self.close() - # if self.loop is not None: - # self.loop.exit() - -class combineMetricsEquationDialog(QBaseDialog): - sigOk = Signal(object) - - def __init__( - self, allChNames, isZstack, isSegm3D, parent=None, debug=False, - closeOnOk=True - ): - super().__init__(parent) - - self.setWindowTitle('Add combined measurement') - - self.initAttributes() - - self.allChNames = allChNames - - self.cancel = True - self.isOperatorMode = False - self.closeOnOk = closeOnOk - - mainLayout = QVBoxLayout() - equationLayout = QHBoxLayout() - - metricsTreeWidget = QTreeWidget() - metricsTreeWidget.setHeaderHidden(True) - metricsTreeWidget.setFont(font) - self.metricsTreeWidget = metricsTreeWidget - - for chName in allChNames: - channelTreeItem = QTreeWidgetItem(metricsTreeWidget) - channelTreeItem.setText(0, f'{chName} measurements') - metricsTreeWidget.addTopLevelItem(channelTreeItem) - - metrics_desc, bkgr_val_desc = measurements.standard_metrics_desc( - isZstack, chName, isSegm3D=isSegm3D - ) - custom_metrics_desc = measurements.custom_metrics_desc( - isZstack, chName, isSegm3D=isSegm3D - ) - - foregrMetricsTreeItem = QTreeWidgetItem(channelTreeItem) - foregrMetricsTreeItem.setText(0, 'Cell signal measurements') - channelTreeItem.addChild(foregrMetricsTreeItem) - - bkgrMetricsTreeItem = QTreeWidgetItem(channelTreeItem) - bkgrMetricsTreeItem.setText(0, 'Background values') - channelTreeItem.addChild(bkgrMetricsTreeItem) - - if custom_metrics_desc: - customMetricsTreeItem = QTreeWidgetItem(channelTreeItem) - customMetricsTreeItem.setText(0, 'Custom measurements') - channelTreeItem.addChild(customMetricsTreeItem) - - self.addTreeItems( - foregrMetricsTreeItem, metrics_desc.keys(), isCol=True - ) - self.addTreeItems( - bkgrMetricsTreeItem, bkgr_val_desc.keys(), isCol=True - ) - - if custom_metrics_desc: - self.addTreeItems( - customMetricsTreeItem, custom_metrics_desc.keys(), - isCol=True - ) - - self.addChannelLessItems(isZstack, isSegm3D=isSegm3D) - - sizeMetricsTreeItem = QTreeWidgetItem(metricsTreeWidget) - sizeMetricsTreeItem.setText(0, 'Size measurements') - metricsTreeWidget.addTopLevelItem(sizeMetricsTreeItem) - - size_metrics_desc = measurements.get_size_metrics_desc( - isSegm3D, True - ) - self.addTreeItems( - sizeMetricsTreeItem, size_metrics_desc.keys(), isCol=True - ) - - propMetricsTreeItem = QTreeWidgetItem(metricsTreeWidget) - propMetricsTreeItem.setText(0, 'Region properties') - metricsTreeWidget.addTopLevelItem(propMetricsTreeItem) - - props_names = measurements.get_props_names() - self.addTreeItems( - propMetricsTreeItem, props_names, isCol=True - ) - - operatorsLayout = QHBoxLayout() - operatorsLayout.addStretch(1) - - iconSize = 24 - - self.operatorButtons = [] - self.operators = [ - ('add', '+'), - ('subtract', '-'), - ('multiply', '*'), - ('divide', '/'), - ('open_bracket', '('), - ('close_bracket', ')'), - ('square', '**2'), - ('pow', '**'), - ('ln', 'log('), - ('log10', 'log10('), - ] - operatorFont = QFont() - operatorFont.setPixelSize(16) - for name, text in self.operators: - button = QPushButton() - button.setIcon(QIcon(f':{name}.svg')) - button.setIconSize(QSize(iconSize,iconSize)) - button.text = text - operatorsLayout.addWidget(button) - self.operatorButtons.append(button) - button.clicked.connect(self.addOperator) - # button.setFont(operatorFont) - - clearButton = QPushButton() - clearButton.setIcon(QIcon(':clear.svg')) - clearButton.setIconSize(QSize(iconSize,iconSize)) - clearButton.setFont(operatorFont) - - clearEntryButton = QPushButton() - clearEntryButton.setIcon(QIcon(':backspace.svg')) - clearEntryButton.setFont(operatorFont) - clearEntryButton.setIconSize(QSize(iconSize,iconSize)) - - operatorsLayout.addWidget(clearButton) - operatorsLayout.addWidget(clearEntryButton) - operatorsLayout.addStretch(1) - - newColNameLayout = QVBoxLayout() - newColNameLineEdit = widgets.alphaNumericLineEdit() - newColNameLineEdit.setAlignment(Qt.AlignCenter) - self.newColNameLineEdit = newColNameLineEdit - newColNameLayout.addStretch(1) - newColNameLayout.addWidget(QLabel('New measurement name:')) - newColNameLayout.addWidget(newColNameLineEdit) - newColNameLayout.addStretch(1) - - equationDisplayLayout = QVBoxLayout() - equationDisplayLayout.addWidget(QLabel('Equation:')) - equationDisplay = QPlainTextEdit() - # equationDisplay.setReadOnly(True) - self.equationDisplay = equationDisplay - equationDisplayLayout.addWidget(equationDisplay) - equationDisplayLayout.setStretch(0,0) - equationDisplayLayout.setStretch(1,1) - - equationLayout.addLayout(newColNameLayout) - equationLayout.addWidget(QLabel(' = ')) - equationLayout.addLayout(equationDisplayLayout) - equationLayout.setStretch(0,1) - equationLayout.setStretch(1,0) - equationLayout.setStretch(2,2) - - testOutputLayout = QVBoxLayout() - testOutputLayout.addWidget(QLabel('Result of test with random inputs:')) - testOutputDisplay = QTextEdit() - testOutputDisplay.setReadOnly(True) - self.testOutputDisplay = testOutputDisplay - testOutputLayout.addWidget(testOutputDisplay) - testOutputLayout.setStretch(0,0) - testOutputLayout.setStretch(1,1) - - instructions = html_utils.paragraph(""" - Double-click on any of the available measurements - to add it to the equation.

- NOTE: the result will be saved in the acdc_output.csv - file as a column with the same name
- you enter in "New measurement name" - field.

- """) - - buttonsLayout = QHBoxLayout() - - cancelButton = widgets.cancelPushButton('Cancel') - helpButton = widgets.infoPushButton(' Help...') - testButton = widgets.calcPushButton('Test output') - okButton = widgets.okPushButton(' Ok ') - okButton.setDisabled(True) - self.okButton = okButton - - buttonsLayout.addStretch(1) - - if debug: - debugButton = QPushButton('Debug') - debugButton.clicked.connect(self._debug) - buttonsLayout.addWidget(debugButton) - - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(helpButton) - buttonsLayout.addWidget(testButton) - buttonsLayout.addWidget(okButton) - - mainLayout.addWidget(QLabel(instructions)) - mainLayout.addWidget(QLabel('Available measurements:')) - mainLayout.addWidget(metricsTreeWidget) - mainLayout.addLayout(operatorsLayout) - mainLayout.addLayout(equationLayout) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - mainLayout.addLayout(testOutputLayout) - - clearButton.clicked.connect(self.clearEquation) - clearEntryButton.clicked.connect(self.clearEntryEquation) - metricsTreeWidget.itemDoubleClicked.connect(self.addColname) - - helpButton.clicked.connect(self.showHelp) - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.close) - testButton.clicked.connect(self.test_cb) - - self.setLayout(mainLayout) - self.setFont(font) - - self.setStyleSheet(TREEWIDGET_STYLESHEET) - - def addChannelLessItems(self, isZstack, isSegm3D=False): - allChannelsTreeItem = QTreeWidgetItem(self.metricsTreeWidget) - allChannelsTreeItem.setText(0, f'All channels measurements') - metrics_desc, bkgr_val_desc = measurements.standard_metrics_desc( - isZstack, '', isSegm3D=isSegm3D - ) - custom_metrics_desc = measurements.custom_metrics_desc( - isZstack, '', isSegm3D=isSegm3D - ) - - foregrMetricsTreeItem = QTreeWidgetItem(allChannelsTreeItem) - foregrMetricsTreeItem.setText(0, 'Cell signal measurements') - allChannelsTreeItem.addChild(foregrMetricsTreeItem) - - bkgrMetricsTreeItem = QTreeWidgetItem(allChannelsTreeItem) - bkgrMetricsTreeItem.setText(0, 'Background values') - allChannelsTreeItem.addChild(bkgrMetricsTreeItem) - - if custom_metrics_desc: - customMetricsTreeItem = QTreeWidgetItem(allChannelsTreeItem) - customMetricsTreeItem.setText(0, 'Custom measurements') - allChannelsTreeItem.addChild(customMetricsTreeItem) - - self.addTreeItems( - foregrMetricsTreeItem, metrics_desc.keys(), isCol=True, - isChannelLess=True - ) - self.addTreeItems( - bkgrMetricsTreeItem, bkgr_val_desc.keys(), isCol=True, - isChannelLess=True - ) - - if custom_metrics_desc: - self.addTreeItems( - customMetricsTreeItem, custom_metrics_desc.keys(), - isCol=True, isChannelLess=True - ) - - def addOperator(self): - button = self.sender() - text = f'{self.equationDisplay.toPlainText()}{button.text}' - self.equationDisplay.setPlainText(text) - self.clearLenghts.append(len(button.text)) - - def clearEquation(self): - self.isOperatorMode = False - self.equationDisplay.setPlainText('') - self.initAttributes() - - def initAttributes(self): - self.clearLenghts = [] - self.equationColNames = [] - self.channelLessColnames = [] - - def clearEntryEquation(self): - if not self.clearLenghts: - return - - text = self.equationDisplay.toPlainText() - newText = text[:-self.clearLenghts[-1]] - clearedText = text[-self.clearLenghts[-1]:] - self.clearLenghts.pop(-1) - self.equationDisplay.setPlainText(newText) - if clearedText in self.equationColNames: - self.equationColNames.remove(clearedText) - if clearedText in self.channelLessColnames: - self.channelLessColnames.remove(clearedText) - - def addTreeItems( - self, parentItem, itemsText, isCol=False, isChannelLess=False - ): - for text in itemsText: - _item = QTreeWidgetItem(parentItem) - _item.setText(0, text) - parentItem.addChild(_item) - if isCol: - _item.isCol = True - _item.isChannelLess = isChannelLess - - - def addColname(self, item, column): - if not hasattr(item, 'isCol'): - return - - colName = item.text(0) - text = f'{self.equationDisplay.toPlainText()}{colName}' - self.equationDisplay.setPlainText(text) - self.clearLenghts.append(len(colName)) - self.equationColNames.append(colName) - if item.isChannelLess: - self.channelLessColnames.append(colName) - - def _debug(self): - print(self.getEquationsDict()) - - def getEquationsDict(self): - equation = self.equationDisplay.toPlainText() - newColName = self.newColNameLineEdit.text() - if not self.channelLessColnames: - chNamesInTerms = set() - for term in self.equationColNames: - for chName in self.allChNames: - if chName in term: - chNamesInTerms.add(chName) - if len(chNamesInTerms) == 1: - # Equation uses metrics from a single channel --> append channel name - chName = chNamesInTerms.pop() - chColName = f'{chName}_{newColName}' - isMixedChannels = False - return {chColName:equation}, isMixedChannels - else: - # Equation doesn't use all channels metrics nor is single channel - isMixedChannels = True - return {newColName:equation}, isMixedChannels - - isMixedChannels = False - equations = {} - for chName in self.allChNames: - chEquation = equation - chEquationName = newColName - # Append each channel name to channelLess terms - for colName in self.channelLessColnames: - chColName = f'{chName}{colName}' - chEquation = chEquation.replace(colName, chColName) - chEquationName = f'{chName}_{newColName}' - equations[chEquationName] = chEquation - return equations, isMixedChannels - - def ok_cb(self): - if not self.newColNameLineEdit.text(): - self.warnEmptyEquationName() - return - - self.cancel = False - - # Save equation to "/acdc-metrics/combine_metrics.ini" file - config = measurements.read_saved_user_combine_config() - - equationsDict, isMixedChannels = self.getEquationsDict() - for newColName, equation in equationsDict.items(): - config = measurements.add_user_combine_metrics( - config, equation, newColName, isMixedChannels - ) - - isChannelLess = len(self.channelLessColnames) > 0 - if isChannelLess: - channelLess_equation = self.equationDisplay.toPlainText() - equation_name = self.newColNameLineEdit.text() - config = measurements.add_channelLess_combine_metrics( - config, channelLess_equation, equation_name, - self.channelLessColnames - ) - - measurements.save_common_combine_metrics(config) - - self.sigOk.emit(self) - - if self.closeOnOk: - self.close() - - def warnEmptyEquationName(self): - msg = widgets.myMessageBox() - txt = html_utils.paragraph(""" - "New measurement name" field cannot be empty! - """) - msg.critical( - self, 'Empty new measurement name', txt - ) - - def showHelp(self): - txt = measurements.get_combine_metrics_help_txt() - msg = widgets.myMessageBox( - showCentered=False, wrapText=False, - scrollableText=True, enlargeWidthFactor=1.7 - ) - path = measurements.acdc_metrics_path - msg.addShowInFileManagerButton(path, txt='Show saved file...') - msg.information(self, 'Combine measurements help', txt) - - def test_cb(self): - # Evaluate equation with random inputs - equation = self.equationDisplay.toPlainText() - random_data = np.random.rand(1, len(self.equationColNames))*5 - df = pd.DataFrame( - data=random_data, - columns=self.equationColNames - ).round(5) - newColName = self.newColNameLineEdit.text() - try: - df[newColName] = df.eval(equation) - except Exception as e: - traceback.print_exc() - self.testOutputDisplay.setHtml(html_utils.paragraph(e)) - self.testOutputDisplay.setStyleSheet("border: 2px solid red") - return - - self.testOutputDisplay.setStyleSheet("border: 2px solid green") - self.okButton.setDisabled(False) - - result = df.round(5).iloc[0][newColName] - - # Substitute numbers into equation - inputs = df.iloc[0] - equation_numbers = equation - for c, col in enumerate(self.equationColNames): - equation_numbers = equation_numbers.replace(col, str(inputs[c])) - - # Format output into html text - cols = self.equationColNames - inputs_txt = [f'{col} = {input}' for col, input in zip(cols, inputs)] - list_html = html_utils.to_list(inputs_txt) - text = html_utils.paragraph(f""" - By substituting the following random inputs: - {list_html} - we get the equation:

-   {newColName} = {equation_numbers}

- that equals to:

-   {newColName} = {result} - """) - self.testOutputDisplay.setHtml(text) - -class stopFrameDialog(QBaseDialog): - def __init__(self, posDatas, parent=None): - super().__init__(parent=parent) - - self.cancel = True - - self.setWindowTitle('Stop frame') - - mainLayout = QVBoxLayout() - - infoTxt = html_utils.paragraph( - 'Enter a stop frame number for each of the loaded Positions', - center=True - ) - exp_path = posDatas[0].exp_path - exp_path = os.path.normpath(exp_path).split(os.sep) - exp_path = f'...{f"{os.sep}".join(exp_path[-4:])}' - subInfoTxt = html_utils.paragraph( - f'Experiment folder: {exp_path}', font_size='12px', - center=True - ) - infoLabel = QLabel(f'{infoTxt}{subInfoTxt}') - infoLabel.setToolTip(posDatas[0].exp_path) - mainLayout.addWidget(infoLabel) - mainLayout.addSpacing(20) - - self.posDatas = posDatas - for posData in posDatas: - _layout = QHBoxLayout() - _layout.addStretch(1) - _label = QLabel(html_utils.paragraph(f'{posData.pos_foldername}')) - _layout.addWidget(_label) - - _spinBox = QSpinBox() - _spinBox.setMaximum(214748364) - _spinBox.setAlignment(Qt.AlignCenter) - _spinBox.setFont(font) - if posData.acdc_df is not None: - _val = posData.acdc_df.index.get_level_values(0).max()+1 - else: - _val = posData.readLastUsedStopFrameNumber() - if _val is None: - _val = posData.SizeT - _spinBox.setValue(_val) - - posData.stopFrameSpinbox = _spinBox - - _layout.addWidget(_spinBox) - - viewButton = widgets.viewPushButton('Visualize...') - viewButton.clicked.connect( - partial(self.viewChannelData, posData, _spinBox) - ) - _layout.addWidget(viewButton, alignment=Qt.AlignRight) - - _layout.addStretch(1) - - mainLayout.addLayout(_layout) - - buttonsLayout = QHBoxLayout() - - okButton = widgets.okPushButton(' Ok ') - cancelButton = widgets.cancelPushButton(' Cancel ') - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(okButton) - - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.close) - - self.setLayout(mainLayout) - - def viewChannelData(self, posData, spinBox): - self.sender().setText('Loading...') - QTimer.singleShot( - 200, partial(self._viewChannelData, posData, spinBox, self.sender()) - ) - - def _viewChannelData(self, posData, spinBox, senderButton): - chNames = posData.chNames - if len(chNames) > 1: - ch_name_selector = prompts.select_channel_name( - which_channel='segm', allow_abort=False - ) - ch_name_selector.QtPrompt( - self, chNames,'Select channel name to visualize: ' - ) - if ch_name_selector.was_aborted: - return - chName = ch_name_selector.channel_name - else: - chName = chNames[0] - - channel_file_path = load.get_filename_from_channel( - posData.images_path, chName - ) - posData.frame_i = 0 - posData.loadImgData(imgPath=channel_file_path) - self.slideshowWin = imageViewer( - posData=posData, spinBox=spinBox - ) - self.slideshowWin.update_img() - self.slideshowWin.show() - senderButton.setText('Visualize...') - - def ok_cb(self): - self.cancel = False - for posData in self.posDatas: - stopFrameNum = posData.stopFrameSpinbox.value() - posData.stopFrameNum = stopFrameNum - self.close() - -class pgTestWindow(QWidget): - def __init__(self, parent=None): - super().__init__(parent) - - layout = QVBoxLayout() - - self.graphLayout = pg.GraphicsLayoutWidget() - self.ax1 = pg.PlotItem() - self.ax1.setAspectLocked(True) - self.graphLayout.addItem(self.ax1) - - layout.addWidget(self.graphLayout) - - self.setLayout(layout) - - -class CombineMetricsMultiDfsDialog(QBaseDialog): - sigOk = Signal(object, object) - sigClose = Signal(bool) - - def __init__(self, acdcDfs, allChNames, parent=None, debug=False): - super().__init__(parent) - - self.setWindowTitle('Add combined measurement') - - self.initAttributes() - - self.acdcDfs = acdcDfs - self.cancel = True - self.isOperatorMode = False - - mainLayout = QVBoxLayout() - equationLayout = QHBoxLayout() - - treesLayout = QHBoxLayout() - for i, (acdc_df_endname, acdc_df) in enumerate(acdcDfs.items()): - metricsTreeWidget = QTreeWidget() - metricsTreeWidget.setHeaderHidden(True) - metricsTreeWidget.setFont(font) - - classified_metrics = measurements.classify_acdc_df_colnames( - acdc_df, allChNames - ) - - for chName in allChNames: - channelTreeItem = QTreeWidgetItem(metricsTreeWidget) - channelTreeItem.setText(0, f'{chName} measurements') - metricsTreeWidget.addTopLevelItem(channelTreeItem) - - standard_metrics = classified_metrics['foregr'][chName] - bkgr_metrics = classified_metrics['bkgr'][chName] - custom_metrics = classified_metrics['custom'][chName] - - if standard_metrics: - foregrMetricsTreeItem = QTreeWidgetItem(channelTreeItem) - foregrMetricsTreeItem.setText(0, 'Cell signal measurements') - channelTreeItem.addChild(foregrMetricsTreeItem) - self.addTreeItems( - foregrMetricsTreeItem, standard_metrics, - isCol=True, index=i - ) - - if bkgr_metrics: - bkgrMetricsTreeItem = QTreeWidgetItem(channelTreeItem) - bkgrMetricsTreeItem.setText(0, 'Background values') - channelTreeItem.addChild(bkgrMetricsTreeItem) - self.addTreeItems( - bkgrMetricsTreeItem, bkgr_metrics, - isCol=True, index=i - ) - - if custom_metrics: - customMetricsTreeItem = QTreeWidgetItem(channelTreeItem) - customMetricsTreeItem.setText(0, 'Custom measurements') - channelTreeItem.addChild(customMetricsTreeItem) - self.addTreeItems( - customMetricsTreeItem, custom_metrics, - isCol=True, index=i - ) - - if classified_metrics['size']: - sizeMetricsTreeItem = QTreeWidgetItem(metricsTreeWidget) - sizeMetricsTreeItem.setText(0, 'Size measurements') - metricsTreeWidget.addTopLevelItem(sizeMetricsTreeItem) - self.addTreeItems( - sizeMetricsTreeItem, classified_metrics['size'], - isCol=True, index=i - ) - - if classified_metrics['props']: - propMetricsTreeItem = QTreeWidgetItem(metricsTreeWidget) - propMetricsTreeItem.setText(0, 'Region properties') - metricsTreeWidget.addTopLevelItem(propMetricsTreeItem) - self.addTreeItems( - propMetricsTreeItem, classified_metrics['props'], - isCol=True, index=i - ) - - treeLayout = QVBoxLayout() - treeTitle = QLabel(html_utils.paragraph( - f'{i+1}. {acdc_df_endname} measurements ' - )) - treeLayout.addWidget(treeTitle) - treeLayout.addWidget(metricsTreeWidget) - treesLayout.addLayout(treeLayout) - - metricsTreeWidget.index = i - metricsTreeWidget.itemDoubleClicked.connect(self.addColname) - - operatorsLayout = QHBoxLayout() - operatorsLayout.addStretch(1) - - iconSize = 24 - - self.operatorButtons = [] - self.operators = [ - ('add', '+'), - ('subtract', '-'), - ('multiply', '*'), - ('divide', '/'), - ('open_bracket', '('), - ('close_bracket', ')'), - ('square', '**2'), - ('pow', '**'), - ('ln', 'log('), - ('log10', 'log10('), - ] - operatorFont = QFont() - operatorFont.setPixelSize(16) - for name, text in self.operators: - button = QPushButton() - button.setIcon(QIcon(f':{name}.svg')) - button.setIconSize(QSize(iconSize,iconSize)) - button.text = text - operatorsLayout.addWidget(button) - self.operatorButtons.append(button) - button.clicked.connect(self.addOperator) - # button.setFont(operatorFont) - - clearButton = QPushButton() - clearButton.setIcon(QIcon(':clear.svg')) - clearButton.setIconSize(QSize(iconSize,iconSize)) - clearButton.setFont(operatorFont) - - clearEntryButton = QPushButton() - clearEntryButton.setIcon(QIcon(':backspace.svg')) - clearEntryButton.setFont(operatorFont) - clearEntryButton.setIconSize(QSize(iconSize,iconSize)) - - operatorsLayout.addWidget(clearButton) - operatorsLayout.addWidget(clearEntryButton) - operatorsLayout.addStretch(1) - - newColNameLayout = QVBoxLayout() - newColNameLineEdit = widgets.alphaNumericLineEdit() - newColNameLineEdit.setAlignment(Qt.AlignCenter) - self.newColNameLineEdit = newColNameLineEdit - newColNameLayout.addStretch(1) - newColNameLayout.addWidget(QLabel('New measurement name:')) - newColNameLayout.addWidget(newColNameLineEdit) - newColNameLayout.addStretch(1) - - equationDisplayLayout = QVBoxLayout() - equationDisplayLayout.addWidget(QLabel('Equation:')) - equationDisplay = QPlainTextEdit() - # equationDisplay.setReadOnly(True) - self.equationDisplay = equationDisplay - equationDisplayLayout.addWidget(equationDisplay) - equationDisplayLayout.setStretch(0,0) - equationDisplayLayout.setStretch(1,1) - - equationLayout.addLayout(newColNameLayout) - equationLayout.addWidget(QLabel(' = ')) - equationLayout.addLayout(equationDisplayLayout) - equationLayout.setStretch(0,1) - equationLayout.setStretch(1,0) - equationLayout.setStretch(2,2) - - instructions = html_utils.paragraph(""" - Double-click on any of the available measurements - to add it to the equation.

- NOTE: the result will be saved in a new acdc_output - file as a column with the same name
- you enter in "New measurement name" - field.

- """) - - buttonsLayout = QHBoxLayout() - - cancelButton = widgets.cancelPushButton('Cancel') - testButton = widgets.calcPushButton('Test equation') - okButton = widgets.okPushButton(' Ok ') - okButton.setDisabled(True) - self.okButton = okButton - - if debug: - debugButton = QPushButton('Debug') - debugButton.clicked.connect(self._debug) - buttonsLayout.addWidget(debugButton) - - self.statusLabel = QLabel() - buttonsLayout.addWidget(self.statusLabel) - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(testButton) - buttonsLayout.addWidget(okButton) - - mainLayout.addWidget(QLabel(instructions)) - mainLayout.addLayout(treesLayout) - mainLayout.addLayout(operatorsLayout) - mainLayout.addLayout(equationLayout) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - clearButton.clicked.connect(self.clearEquation) - clearEntryButton.clicked.connect(self.clearEntryEquation) - - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.close) - testButton.clicked.connect(self.test_cb) - - self.equationDisplay.textChanged.connect(self.equationChanged) - # self.newColNameLineEdit.editingFinished.connect(self.equationChanged) - - self.setLayout(mainLayout) - self.setFont(font) - - self.setStyleSheet(TREEWIDGET_STYLESHEET) - - def setLogger(self, logger, logs_path, log_path): - self.logger = logger - self.logs_path = logs_path - self.log_path = log_path - - def closeEvent(self, event): - self.sigClose.emit(self.cancel) - return super().closeEvent(event) - - def getCombinedDf(self): - dfs = [] - for i, acdc_df in enumerate(self.acdcDfs.values()): - dfs.append(acdc_df.add_suffix(f'_table{i+1}')) - return pd.concat(dfs, axis=1) - - def _log(self, txt): - if hasattr(self, 'logger'): - self.logger.info(txt) - else: - print(f'[INFO]: {txt}') - - def equationChanged(self): - self.okButton.setDisabled(True) - self.statusLabel.setText('') - - @exception_handler - def test_cb(self): - combined_df = self.getCombinedDf() - new_df = pd.DataFrame(index=combined_df.index) - equation = self.equationDisplay.toPlainText() - newColName = self.newColNameLineEdit.text() - new_df[newColName] = combined_df.eval(equation) - self.okButton.setDisabled(False) - self._log('Equation test was successful.') - self.statusLabel.setText( - 'Equation test was successful. You can now click OK.' - ) - - def addOperator(self): - button = self.sender() - text = f'{self.equationDisplay.toPlainText()}{button.text}' - self.equationDisplay.setPlainText(text) - self.clearLenghts.append(len(button.text)) - - def clearEquation(self): - self.isOperatorMode = False - self.equationDisplay.setPlainText('') - self.initAttributes() - - def initAttributes(self): - self.clearLenghts = [] - self.equationColNames = [] - self.channelLessColnames = [] - - def clearEntryEquation(self): - if not self.clearLenghts: - return - - text = self.equationDisplay.toPlainText() - newText = text[:-self.clearLenghts[-1]] - clearedText = text[-self.clearLenghts[-1]:] - self.clearLenghts.pop(-1) - self.equationDisplay.setPlainText(newText) - if clearedText in self.equationColNames: - self.equationColNames.remove(clearedText) - if clearedText in self.channelLessColnames: - self.channelLessColnames.remove(clearedText) - - def addTreeItems( - self, parentItem, itemsText, isCol=False, isChannelLess=False, - index=None - ): - for text in itemsText: - _item = QTreeWidgetItem(parentItem) - _item.setText(0, text) - parentItem.addChild(_item) - if isCol: - _item.isCol = True - if index is not None: - _item.index = index - _item.isChannelLess = isChannelLess - - def addColname(self, item, column): - if not hasattr(item, 'isCol'): - return - - colName = f'{item.text(0)}_table{item.index+1}' - text = f'{self.equationDisplay.toPlainText()}{colName}' - - self.equationDisplay.setPlainText(text) - self.clearLenghts.append(len(colName)) - self.equationColNames.append(colName) - if item.isChannelLess: - self.channelLessColnames.append(colName) - - def _debug(self): - print(self.getEquationsDict()) - - def ok_cb(self): - if not self.newColNameLineEdit.text(): - self.warnEmptyEquationName() - return - if not self.equationDisplay.toPlainText(): - self.warnEmptyEquation() - return - - self.expression = self.equationDisplay.toPlainText() - self.newColname = self.newColNameLineEdit.text() - self.cancel = False - self.sigOk.emit(self.newColname, self.expression) - self.close() - - def warnEmptyEquation(self): - msg = widgets.myMessageBox() - txt = html_utils.paragraph(""" - "Equation" field cannot be empty! - """) - msg.critical( - self, 'Empty equation', txt - ) - - def warnEmptyEquationName(self): - msg = widgets.myMessageBox() - txt = html_utils.paragraph(""" - "New measurement name" field cannot be empty! - """) - msg.critical( - self, 'Empty new measurement name', txt - ) - -class CombineMetricsMultiDfsSummaryDialog(QBaseDialog): - sigLoadAdditionalAcdcDf = Signal() - - def __init__( - self, acdcDfs, allChNames, parent=None, debug=False - ): - super().__init__(parent) - - self.editedIndex = None - self.cancel = True - self.acdcDfs = acdcDfs - self.allChNames = allChNames - - self.setWindowTitle('Combine measurements summary') - - mainLayout = QVBoxLayout() - viewLayout = QGridLayout() - buttonsLayout = QHBoxLayout() - - row = 0 - txt = html_utils.paragraph('Selected acdc_output tables:') - viewLayout.addWidget(QLabel(txt), row, 0) - - row += 1 - items = [ - f'• Table {i+1}: {e}' - for i, e in enumerate(acdcDfs.keys()) - ] - selectedAcdcDfsList = widgets.readOnlyQList() - selectedAcdcDfsList.addItems(items) - self.selectedAcdcDfsList = selectedAcdcDfsList - - tablesButtonsLayout = QVBoxLayout() - loadAcdcDfButton = widgets.showInFileManagerButton('Load additional tables') - tablesButtonsLayout.addWidget(loadAcdcDfButton) - - loadEquationsButton = widgets.reloadPushButton( - 'Load previously used equations' - ) - tablesButtonsLayout.addWidget(loadEquationsButton) - - - tablesButtonsLayout.addStretch(1) - - viewLayout.addWidget(selectedAcdcDfsList, row, 0) - viewLayout.addLayout(tablesButtonsLayout, row, 1) - viewLayout.setRowStretch(row, 1) - - row += 1 - txt = html_utils.paragraph('Equations:') - viewLayout.addWidget(QLabel(txt), row, 0) - - row += 1 - self.equationsList = widgets.TreeWidget() - self.equationsList.setFont(font) - self.equationsList.setHeaderLabels(['Metric', 'Expression']) - self.equationsList.setSelectionMode( - QAbstractItemView.SelectionMode.ExtendedSelection) - - equationsButtonsLayout = QVBoxLayout() - addEquationButton = widgets.addPushButton('Add metric') - removeEquationButton = widgets.subtractPushButton('Remove metric(s)') - editEquationButton = widgets.editPushButton('Edit metric') - removeEquationButton.setDisabled(True) - editEquationButton.setDisabled(True) - self.removeEquationButton = removeEquationButton - self.editEquationButton = editEquationButton - - equationsButtonsLayout.addWidget(addEquationButton) - equationsButtonsLayout.addWidget(removeEquationButton) - equationsButtonsLayout.addWidget(editEquationButton) - equationsButtonsLayout.addStretch(1) - - viewLayout.addWidget(self.equationsList, row, 0) - viewLayout.addLayout(equationsButtonsLayout, row, 1) - viewLayout.setRowStretch(row, 2) - - cancelButton = widgets.cancelPushButton('Cancel') - okButton = widgets.okPushButton('Ok') - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(okButton) - - viewLayout.setVerticalSpacing(10) - mainLayout.addLayout(viewLayout) - mainLayout.addSpacing(10) - mainLayout.addLayout(buttonsLayout) - - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.close) - addEquationButton.clicked.connect(self.addEquation_cb) - loadAcdcDfButton.clicked.connect(self.loadButtonClicked) - loadEquationsButton.clicked.connect(self.loadEquationsButtonClicked) - removeEquationButton.clicked.connect(self.removeButtonClicked) - editEquationButton.clicked.connect(self.editButtonClicked) - self.equationsList.itemSelectionChanged.connect( - self.onEquationItemSelectionChanged - ) - - self.setLayout(mainLayout) - - def setLogger(self, logger, logs_path, log_path): - self.logger = logger - self.logs_path = logs_path - self.log_path = log_path - - def loadEquationsButtonClicked(self): - MostRecentPath = myutils.getMostRecentPath() - file_path = QFileDialog.getOpenFileName( - self, 'Select equations file', MostRecentPath, "Config Files (*.ini)" - ";;All Files (*)")[0] - if file_path == '': - return - - cp = config.ConfigParser() - cp.read(file_path) - sectionToMatch = [ - f'table{i+1}:{end}' for i, end in enumerate(self.acdcDfs) - ] - sectionToMatch = ';'.join(sectionToMatch) - - lists = {} - nonMatchingLists = {} - groupsDescr = {} - - for section in cp.sections(): - # Tag acdc_output names with html and table(\d+) with html bold tag - listName = ';'.join([ - re.sub(r'table(\d+):(.*)', r'table\g<1>: \g<2>', s) - for s in section.split(';') - ]) - listName = listName.replace(';', ' ; ') - children = [f'{opt} = {cp[section][opt]}' for opt in cp[section]] - if section == sectionToMatch: - groupsDescr[listName] = ( - 'Equations that were calculated from the same ' - 'table names you loaded' - ) - lists[listName] = children - else: - groupsDescr[listName] = ( - 'Equations that were calculated from table names that ' - 'you did not load now' - ) - nonMatchingLists[listName] = children - # # Not implemented yet --> selecting from non matching table names - # # would require an additional widget where the user sets - # # what df1 and df2 are. - # trees[treeName] = children - - if not lists: - msg = widgets.myMessageBox(wrapText=False, showCentered=False) - txt = html_utils.paragraph(""" - None of the equations in the selected file used the same - table names that you loaded.

- See below which table names and equations are present in the loaded file. - """) - with open(file_path) as iniFile: - detailedText = iniFile.read() - - msg.warning(self, 'Not the same tables', txt, showDialog=False) - msg.setDetailedText(detailedText, visible=True) - msg.addShowInFileManagerButton(os.path.dirname(file_path)) - msg.exec_() - return - - selectWindow = MultiListSelector( - lists, groupsDescr=groupsDescr, title='Select equations to load', - infoTxt='Select equations you want to load' - ) - selectWindow.exec_() - if selectWindow.cancel or not selectWindow.selectedItems: - return - - for listName, equations in selectWindow.selectedItems.items(): - for equation in equations: - metricName, expression = equation.split(' = ') - self.addEquation(metricName, expression) - - - def ok_cb(self): - self.cancel = False - self.equations = {} - for i in range(self.equationsList.topLevelItemCount()): - item = self.equationsList.topLevelItem(i) - self.equations[item.text(0)] = item.text(1) - - self.close() - - def loadButtonClicked(self): - self.sigLoadAdditionalAcdcDf.emit() - - def removeButtonClicked(self): - for item in self.equationsList.selectedItems(): - self.equationsList.invisibleRootItem().removeChild(item) - - def editButtonClicked(self): - self.editedItem = self.equationsList.selectedItems()[0] - self.editedIndex = self.equationsList.indexOfTopLevelItem(self.editedItem) - self.addEquation_cb() - - def onEquationItemSelectionChanged(self): - selectedItems = self.equationsList.selectedItems() - if len(selectedItems) == 1: - self.editEquationButton.setDisabled(False) - self.removeEquationButton.setDisabled(False) - elif len(selectedItems) > 1: - self.removeEquationButton.setDisabled(False) - self.editEquationButton.setDisabled(True) - else: - self.removeEquationButton.setDisabled(True) - self.editEquationButton.setDisabled(True) - - def addAcdcDfs(self, acdcDfsDict): - self.acdcDfs = {**self.acdcDfs, **acdcDfsDict} - items = [ - f'• Table {i+1}: {e}' - for i, e in enumerate(self.acdcDfs.keys()) - ] - self.selectedAcdcDfsList = widgets.readOnlyQList() - self.selectedAcdcDfsList.addItems(items) - - def addEquation(self, newColname, expression): - if self.editedIndex is not None: - self.equationsList.invisibleRootItem().removeChild(self.editedItem) - bkgrColor = QColor(*BACKGROUND_RGBA[:3], 200) - item = widgets.TreeWidgetItem( - self.equationsList, columnColors=[None, bkgrColor] - ) - item.setText(0, newColname) - item.setText(1, expression) - if self.editedIndex is not None: - self.equationsList.insertTopLevelItem(self.editedIndex, item) - else: - self.equationsList.addTopLevelItem(item) - self.equationsList.resizeColumnToContents(0) - self.equationsList.resizeColumnToContents(1) - self.editedIndex = None - - def addEquation_cb(self): - self.addEquationWin = CombineMetricsMultiDfsDialog( - self.acdcDfs, self.allChNames, parent=self - ) - if hasattr(self, 'logger'): - self.addEquationWin.setLogger( - self.logger, self.logs_path, self.log_path - ) - if self.editedIndex is not None: - editedMetricName = self.editedItem.text(0) - self.addEquationWin.newColNameLineEdit.setText(editedMetricName) - editedExpression = self.editedItem.text(1) - self.addEquationWin.equationDisplay.setPlainText(editedExpression) - self.addEquationWin.show() - self.addEquationWin.sigOk.connect(self.addEquation) - self.addEquationWin.sigClose.connect(self.addEquationClosed) - - def addEquationClosed(self, cancelled): - if cancelled: - self.editedIndex = None - - def showEvent(self, event) -> None: - self.resize(int(self.width()*2), self.height()) - -class ShortcutEditorDialog(QBaseDialog): - def __init__( - self, widgetsWithShortcut: dict, - delObjectKey='', - delObjectButton: Literal['Middle click', 'Left click']='Middle click', - zoomOutKeyValue: int=None, - parent=None - ): - self.cancel = True - super().__init__(parent) - - self.setWindowTitle('Customize keyboard shortcuts') - - mainLayout = QVBoxLayout() - - self.customShortcuts = {} - self.shortcutLineEdits = {} - - scrollArea = QScrollArea(self) - scrollArea.setWidgetResizable(True) - scrollAreaWidget = QWidget() - entriesLayout = QGridLayout() - - row = 0 - button = widgets.PushButton(self, flat=True) - button.setIcon(QIcon(":del_obj_click.svg")) - self.delObjShortcutLineEdit = widgets.ShortcutLineEdit( - allowModifiers=True, notAllowedModifier=Qt.AltModifier - ) - if delObjectKey is not None: - self.delObjShortcutLineEdit.setText(delObjectKey) - self.delObjButtonCombobox = QComboBox() - self.delObjButtonCombobox.addItems(['Middle click', 'Left click']) - self.delObjButtonCombobox.setCurrentText(delObjectButton) - entriesLayout.addWidget(button, row, 0) - entriesLayout.addWidget(QLabel('Delete object:'), row, 1) - entriesLayout.addWidget(self.delObjShortcutLineEdit, row, 2) - entriesLayout.addWidget( - self.delObjButtonCombobox, row, 3, alignment=Qt.AlignLeft - ) - - row += 1 - name = 'Zoom out' - button = widgets.PushButton(self, flat=True) - label = QLabel('Zoom out:') - self.zoomShortcutLineEdit = widgets.ShortcutLineEdit() - if zoomOutKeyValue is not None: - zoomOutKeySequence = widgets.KeySequenceFromText(zoomOutKeyValue) - self.zoomShortcutLineEdit.setText(zoomOutKeySequence.toString()) - self.zoomShortcutLineEdit.key = zoomOutKeyValue - self.zoomShortcutLineEdit.textChanged.connect( - self.checkDuplicateShortcuts - ) - entriesLayout.addWidget(button, row, 0) - entriesLayout.addWidget(label, row, 1) - entriesLayout.addWidget(self.zoomShortcutLineEdit, row, 2) - self.shortcutLineEdits[name] = self.zoomShortcutLineEdit - - row += 1 - for row, (name, widget) in enumerate(widgetsWithShortcut.items(), start=row): - button = widgets.PushButton(self, flat=True) - try: - button.setIcon(widget.icon()) - except: - pass - label = QLabel(f'{name}:') - shortcutLineEdit = widgets.ShortcutLineEdit() - if hasattr(widget, 'keyPressShortcut'): - shortcutLineEdit.key = widget.keyPressShortcut - shortcut = widgets.KeySequenceFromText(widget.keyPressShortcut) - isShortcutKeyPress = True - else: - shortcut = widget.shortcut() - isShortcutKeyPress = False - shortcutLineEdit.setText(shortcut.toString()) - shortcutLineEdit.textChanged.connect(self.checkDuplicateShortcuts) - shortcutLineEdit.isShortcutKeyPress = isShortcutKeyPress - entriesLayout.addWidget(button, row, 0) - entriesLayout.addWidget(label, row, 1) - entriesLayout.addWidget(shortcutLineEdit, row, 2) - self.shortcutLineEdits[name] = shortcutLineEdit - - entriesLayout.setColumnStretch(0, 0) - entriesLayout.setColumnStretch(1, 0) - entriesLayout.setColumnStretch(2, 1) - entriesLayout.setColumnStretch(3, 0) - - scrollAreaWidget.setLayout(entriesLayout) - scrollArea.setWidget(scrollAreaWidget) - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addWidget(scrollArea) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setFont(font) - self.setLayout(mainLayout) - - def checkDuplicateShortcuts(self, text): - for name, shortcutLineEdit in self.shortcutLineEdits.items(): - if shortcutLineEdit == self.sender(): - continue - if shortcutLineEdit.text() != text: - continue - shortcutLineEdit.setText('') - - def warnInvalidKeySequenceDelObjWithLeftClick(self): - txt = html_utils.paragraph( - 'The selected key sequence to delete objects with "Left click" ' - 'is invalid.

' - 'Only "Middle click" can be used without pressing keys.

' - 'Thank you for your patience!' - ) - msg = widgets.myMessageBox() - msg.warning(self, 'Invalid key sequence to delete objects', txt) - - def ok_cb(self): - delObjButtonText = self.delObjButtonCombobox.currentText() - delObjKeySequence = self.delObjShortcutLineEdit.keySequence - if delObjButtonText == 'Left click' and delObjKeySequence is None: - self.warnInvalidKeySequenceDelObjWithLeftClick() - return - - self.shortcutLineEdits.pop('Zoom out') - self.cancel = False - for name, shortcutLineEdit in self.shortcutLineEdits.items(): - text = shortcutLineEdit.text() - if shortcutLineEdit.isShortcutKeyPress: - self.customShortcuts[name] = (text, shortcutLineEdit.key) - else: - self.customShortcuts[name] = ( - text, shortcutLineEdit.keySequence - ) - - delObjQtButton = ( - Qt.MouseButton.LeftButton if delObjButtonText == 'Left click' - else Qt.MouseButton.MiddleButton - ) - self.delObjAction = delObjKeySequence, delObjQtButton - self.zoomOutKeyValue = self.zoomShortcutLineEdit.key - - self.close() - - def showEvent(self, event) -> None: - self.resize(int(self.width()*1.2), self.height()) - self.move(self.x(), 100) - -class SelectAcdcDfVersionToRestore(QBaseDialog): - def __init__(self, posData, parent=None): - super().__init__(parent=parent) - - self.cancel = True - - self.setWindowTitle('Select annotations table to restore') - - mainLayout = QVBoxLayout() - - acdc_df_filename = os.path.basename(posData.acdc_output_csv_path) - instructionsLabel = html_utils.paragraph( - f'Select an older version of the {acdc_df_filename} ' - 'annotations table to load.

' - 'The datetime refers to the time you replaced the old version with ' - 'a newer one.

' - ) - mainLayout.addWidget(QLabel(instructionsLabel)) - - self.savedListBox = None - if os.path.exists(posData.acdc_output_backup_zip_path): - zip_path = posData.acdc_output_backup_zip_path - self.savedArchivefilepath = zip_path - with zipfile.ZipFile(zip_path, mode='r') as zip: - csv_names = natsorted(zip.namelist(), reverse=True) - - keys = [csv_name[:-4] for csv_name in csv_names] - - self.savedKeys = keys - f = load.ISO_TIMESTAMP_FORMAT - timestamps = [datetime.datetime.strptime(key, f) for key in keys] - items = [date.strftime(r'%d %b %Y, %H:%M:%S') for date in timestamps] - mainLayout.addWidget(QLabel('Saved annotations:')) - self.savedListBox = widgets.listWidget() - self.savedListBox.addItems(items) - mainLayout.addWidget(self.savedListBox) - self.savedListBox.itemSelectionChanged.connect( - self.onItemSelectionChanged - ) - - recovery_folderpath = posData.recoveryFolderpath() - unsaved_recovery_folderpath = os.path.join( - recovery_folderpath, 'never_saved' - ) - self.neverSavedFolderpath = unsaved_recovery_folderpath - files = myutils.listdir(unsaved_recovery_folderpath) - csv_files = [file for file in files if file.endswith('.csv')] - self.neverSavedListBox = None - if csv_files: - csv_names = natsorted(csv_files, reverse=True) - keys = [csv_name[:-4] for csv_name in csv_names] - self.neverSavedKeys = keys - f = load.ISO_TIMESTAMP_FORMAT - timestamps = [ - datetime.datetime.strptime(key, f) for key in keys - ] - items = [date.strftime(r'%d %b %Y, %H:%M:%S') for date in timestamps] - mainLayout.addWidget(QLabel('Never saved annotations:')) - self.neverSavedListBox = widgets.listWidget() - self.neverSavedListBox.addItems(items) - mainLayout.addWidget(self.neverSavedListBox) - self.neverSavedListBox.itemSelectionChanged.connect( - self.onItemSelectionChanged - ) - - cancelOkLayout = widgets.CancelOkButtonsLayout() - - cancelOkLayout.okButton.clicked.connect(self.ok_cb) - cancelOkLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addSpacing(20) - mainLayout.addLayout(cancelOkLayout) - - self.setLayout(mainLayout) - - self.setFont(font) - - def ok_cb(self): - self.cancel = False - try: - for i in range(self.savedListBox.count()): - item = self.savedListBox.item(i) - if item.isSelected(): - self.selectedTimestamp = item.text() - self.selectedKey = self.savedKeys[i] - self.archiveFilePath = self.savedArchivefilepath - break - except Exception as e: - pass - - try: - for i in range(self.neverSavedListBox.count()): - item = self.neverSavedListBox.item(i) - if item.isSelected(): - self.selectedTimestamp = item.text() - self.selectedKey = self.neverSavedKeys[i] - self.archiveFilePath = self.neverSavedFolderpath - break - except Exception as e: - pass - self.close() - - def onItemSelectionChanged(self): - otherListBox = ( - self.savedListBox if self.sender() == self.neverSavedListBox - else self.neverSavedListBox - ) - if otherListBox is None: - return - for i in range(otherListBox.count()): - item = otherListBox.item(i) - item.setSelected(False) - -class ChangeUserProfileFolderPathDialog(QBaseDialog): - def __init__(self, posData, parent=None): - super().__init__(parent=parent) - - self.cancel = True - - self.setWindowTitle('Change user profile folder path') - - mainLayout = QVBoxLayout() - - acdc_folders = load.get_all_acdc_folders(user_profile_path) - acdc_folders_format = [f' - {folder}' for folder in acdc_folders] - acdc_folders_format = '
'.join(acdc_folders_format) - - txt = (f""" - Current user profile path:

- {user_profile_path}

- The user profile contains the following Cell-ACDC folders:

- {acdc_folders_format}

- After clicking "Ok" you will be asked to select the folder where - you want to migrate the user profile data. - """) - - txt = html_utils.paragraph(txt) - label = QLabel(txt) - - mainLayout.addWidget(label) - - buttonsLayout = widgets.CancelOkButtonsLayout() - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - mainLayout.addStretch() - - self.setLayout(mainLayout) - - def ok_cb(self): - self.cancel = False - self.close() - -class SelectFeaturesRange: - def __init__( - self, - posData, - force_postprocess_2D=False, - qparent=None, - sigValueChanged=None - ) -> None: - self.posData = posData - self.qparent = qparent - self.force_postprocess_2D = force_postprocess_2D - self.sigValueChanged = sigValueChanged - - self.lowRangeWidgets = widgets.CheckableSpinBoxWidgets() - self.highRangeWidgets = widgets.CheckableSpinBoxWidgets() - - self.selectButton = widgets.FeatureSelectorButton( - 'Click to select feature...' - ) - self.selectButton.setSizeLongestText( - 'Spotfit intens. metric, Foregr. integral gauss. peak' - ) - self.selectButton.clicked.connect(self.selectFeature) - self.selectButton.setCursor(Qt.PointingHandCursor) - - self.selectedFeatureGroups = {} - - self.widgets = [ - {'pos': (0, 0), 'widget': self.lowRangeWidgets.checkbox}, - {'pos': (1, 0), 'widget': self.lowRangeWidgets.spinbox}, - {'pos': (1, 1), 'widget': widgets.LessThanPushButton(flat=True)}, - {'pos': (1, 2), 'widget': self.selectButton}, - {'pos': (1, 3), 'widget': widgets.LessThanPushButton(flat=True)}, - {'pos': (0, 4), 'widget': self.highRangeWidgets.checkbox}, - {'pos': (1, 4), 'widget': self.highRangeWidgets.spinbox}, - {'pos': (2, 0), 'widget': widgets.VerticalSpacerEmptyWidget(height=10)} - ] - self.columnsStretches = {0: 0, 1: 0, 2: 1, 3: 0, 4: 0} - - def setText(self, text): - self.selectButton.setText(text) - - def selectFeature(self): - loadedChNames = [self.posData.user_ch_name] - notLoadedChNames = [] - isZstack = self.posData.SizeZ > 1 and not self.force_postprocess_2D - isSegm3D = self.posData.isSegm3D and not self.force_postprocess_2D - self.selectFeatureDialog = SetMeasurementsDialog( - loadedChNames, notLoadedChNames, isZstack, isSegm3D, - posData=self.posData, parent=self.qparent, - isSingleSelection=True, is_concat=True - ) - # self.selectFeatureDialog.resizeVertical() - self.selectFeatureDialog.sigClosed.connect(self.setFeatureText) - self.selectFeatureDialog.show() - - def setFeatureText(self): - if self.selectFeatureDialog.cancel: - return - self.selectButton.setFlat(True) - selectedMetricName, selectedMetricGroup = ( - self.selectFeatureDialog.selectedMetricNameAndGroup() - ) - self.selectButton.setText(selectedMetricName) - self.featureGroup = selectedMetricGroup - -class SelectFeaturesRangeDialog(QBaseDialog): - sigValueChanged = Signal(object) - - def __init__(self, posData=None, parent=None, force_postprocess_2D=False): - super().__init__(parent) - - self.force_postprocess_2D = force_postprocess_2D - - layout = QVBoxLayout() - self.setWindowTitle('Custom features for post-processing') - - self.groupbox = SelectFeaturesRangeGroupbox( - posData=posData, parent=parent, - force_postprocess_2D=force_postprocess_2D - ) - - buttonsLayout = QHBoxLayout() - okPushButton = widgets.okPushButton(' Ok ') - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(okPushButton) - - okPushButton.clicked.connect(self.ok_cb) - - layout.addWidget(self.groupbox) - layout.addSpacing(10) - layout.addLayout(buttonsLayout) - - self.setLayout(layout) - - def ok_cb(self): - if self.groupbox.selectedFeaturesRange(): - self.sigValueChanged.emit(None) - self.hide() - -class SelectFeaturesRangeGroupbox(QGroupBox): - def __init__( - self, posData=None, parent=None, force_postprocess_2D=False - ): - super().__init__(parent) - - self.setTitle('Features and thresholds for filtering segmented objects') - # self.setCheckable(True) - - self.posData = posData - self.force_postprocess_2D = force_postprocess_2D - - self._layout = QGridLayout() - self._layout.setVerticalSpacing(0) - - firstSelector = SelectFeaturesRange( - posData, force_postprocess_2D=force_postprocess_2D - ) - self.addButton = widgets.addPushButton(' Add feature ') - self.addButton.setSizePolicy( - QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding - ) - for col, widget in enumerate(firstSelector.widgets): - row, col = widget['pos'] - self._layout.addWidget(widget['widget'], row, col) - for col, stretch in firstSelector.columnsStretches.items(): - self._layout.setColumnStretch(col, stretch) - - lastCol = self._layout.columnCount() - self._layout.addWidget(self.addButton, 0, lastCol+1, 2, 1) - self.lastCol = lastCol+1 - self.selectors = [firstSelector] - - self.setLayout(self._layout) - - # self.setFont(font) - - self.addButton.clicked.connect(self.addFeatureField) - - def addFeatureField(self): - row = self._layout.rowCount() - selector = SelectFeaturesRange( - self.posData, force_postprocess_2D=self.force_postprocess_2D - ) - delButton = widgets.delPushButton('Remove feature') - delButton.setSizePolicy( - QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding - ) - delButton.selector = selector - selector.delButton = delButton - for col, widget in enumerate(selector.widgets): - relRow, col = widget['pos'] - self._layout.addWidget(widget['widget'], relRow+row, col) - self._layout.addWidget(delButton, row, self.lastCol, 2, 1) - self.selectors.append(selector) - delButton.clicked.connect(self.removeFeatureField) - - def resetFields(self): - while len(self.selectors) > 1: - selector = self.selectors[-1] - selector.delButton.click() - firstSelector = self.selectors[0] - firstSelector.selectButton.setText('Click to select feature...') - firstSelector.lowRangeWidgets.checkbox.setChecked(False) - firstSelector.highRangeWidgets.checkbox.setChecked(False) - - def removeFeatureField(self): - delButton = self.sender() - for widget in delButton.selector.widgets: - self._layout.removeWidget(widget['widget']) - self._layout.removeWidget(delButton) - self.selectors.remove(delButton.selector) - - def selectedFeaturesRange(self): - featuresRange = {} - for selector in self.selectors: - if selector.selectButton.text().find('Click') != -1: - continue - featuresRange[selector.selectButton.text()] = ( - selector.lowRangeWidgets.value(), - selector.highRangeWidgets.value() - ) - return featuresRange - - def selectedFeaturesGroup(self): - featuresGroup = {} - for selector in self.selectors: - if selector.selectButton.text().find('Click') != -1: - continue - group = selector.featureGroup - featuresGroup[selector.selectButton.text()] = group - return featuresGroup - - def groupedFeatures(self): - featuresGroup = self.selectedFeaturesGroup() - groupedFeatures = {} - for feature, group in featuresGroup.items(): - group = featuresGroup[feature] - if isinstance(group, str): - key = group - if key not in groupedFeatures: - groupedFeatures[key] = [] - groupedFeatures[key].append(feature) - else: - key, channel = list(group.items())[0] - if key not in groupedFeatures: - groupedFeatures[key] = {} - if channel not in groupedFeatures[key]: - groupedFeatures[key][channel] = [] - groupedFeatures[key][channel].append(feature) - return groupedFeatures - - def setValue(self, value): - pass - -def get_existing_directory(allow_images_path=True, **kwargs): - while True: - folder_path = qtpy.compat.getexistingdirectory(**kwargs) - if not folder_path: - return - - if allow_images_path: - return folder_path - - pos_folderpath = os.path.dirname(folder_path) - is_images_folder = ( - folder_path.endswith('Images') - and os.path.basename(pos_folderpath).startswith('Position_') - and os.path.isdir(folder_path) - ) - if not is_images_folder: - return folder_path - - txt = html_utils.paragraph( - 'You cannot save to the Images folder ' - 'because it is reserved to files that start with the same ' - 'basename.

Thank you for your patience!' - ) - msg = widgets.myMessageBox() - msg.warning(kwargs['parent'], 'Cannot save here', txt) - -class ScaleBarPropertiesDialog(QBaseDialog): - sigValueChanged = Signal(object) - - def __init__( - self, maxLength, maxThickness, PhysicalSizeX, parent=None, - **properties - ): - super().__init__(parent=parent) - - self.cancel = True - self.setWindowTitle('Scale bar properties') - - self.PhysicalSizeX = PhysicalSizeX - - mainLayout = QVBoxLayout() - - formLayout = widgets.FormLayout() - formLayout.setVerticalSpacing(10) - formLayout.setHorizontalSpacing(50) - - row = 0 - unitCombobox = QComboBox() - unitFormWidget = widgets.formWidget( - unitCombobox, labelTextLeft='Physical unit' - ) - unitCombobox.addItems( - ['nm', 'μm', 'mm', 'cm'] - ) - if properties.get('unit') is None: - unitCombobox.setCurrentIndex(1) - else: - unitCombobox.setCurrentText(properties.get('unit')) - formLayout.addFormWidget( - unitFormWidget, row=row, - leftLabelAlignment=Qt.AlignLeft - ) - self.unitCombobox = unitCombobox - - row += 1 - lengthDoubleSpinbox = widgets.DoubleSpinBox() - lengthDoubleSpinbox.setMaximum(maxLength) - lengthDoubleSpinbox.setMinimum(PhysicalSizeX) - lengthDoubleSpinbox.setDecimals(1) - if properties.get('length_unit') is not None: - lengthDoubleSpinbox.setValue(properties.get('length_unit')) - else: - deafultLength = np.ceil(PhysicalSizeX*15) - lengthDoubleSpinbox.setValue(round(deafultLength)) - lengthFormWidget = widgets.formWidget( - lengthDoubleSpinbox, labelTextLeft='Length (μm)' - ) - self.lengthFormWidget = lengthFormWidget - self.lengthDoubleSpinbox = lengthDoubleSpinbox - formLayout.addFormWidget( - lengthFormWidget, row=row, - leftLabelAlignment=Qt.AlignLeft - ) - - row += 1 - thicknessSpinbox = widgets.DoubleSpinBox() - thicknessSpinbox.setMaximum(maxThickness) - thicknessSpinbox.setMinimum(1) - if properties.get('thickness') is not None: - thicknessSpinbox.setValue(properties.get('thickness')) - else: - thicknessSpinbox.setValue(round(4, 1)) - thicknessSpinbox.setDecimals(1) - thicknessFormWidget = widgets.formWidget( - thicknessSpinbox, labelTextLeft='Thickness (pixel)' - ) - formLayout.addFormWidget( - thicknessFormWidget, row=row, - leftLabelAlignment=Qt.AlignLeft - ) - self.thicknessSpinbox = thicknessSpinbox - - row += 1 - locCombobox = QComboBox() - locFormWidget = widgets.formWidget( - locCombobox, labelTextLeft='Location' - ) - locCombobox.addItems( - ['Bottom-right', 'Bottom-left', 'Top-left', 'Top-right', 'Custom'] - ) - loc = properties.get('loc') - if isinstance(loc, str): - locCombobox.setCurrentText(loc.capitalize()) - formLayout.addFormWidget( - locFormWidget, row=row, - leftLabelAlignment=Qt.AlignLeft - ) - self.locCombobox = locCombobox - - row += 1 - self.colorButton = widgets.myColorButton(color=(255, 255, 255)) - if properties.get('color') is not None: - self.colorButton.setColor(properties.get('color')) - colorFormWidget = widgets.formWidget( - self.colorButton, labelTextLeft='Color', - widgetAlignment=Qt.AlignCenter, stretchWidget=False - ) - formLayout.addFormWidget( - colorFormWidget, row=row, - leftLabelAlignment=Qt.AlignLeft - ) - - row += 1 - displayTextToggle = widgets.Toggle() - if properties.get('is_text_visible') is not None: - displayTextToggle.setChecked(properties.get('is_text_visible')) - else: - displayTextToggle.setChecked(True) - displayTextFormWidget = widgets.formWidget( - displayTextToggle, labelTextLeft='Display text', - widgetAlignment=Qt.AlignCenter, stretchWidget=False - ) - formLayout.addFormWidget( - displayTextFormWidget, row=row, - leftLabelAlignment=Qt.AlignLeft - ) - self.displayTextToggle = displayTextToggle - - row += 1 - fontSizeSpinbox = widgets.SpinBox() - if properties.get('font_size') is not None: - fontSizeSpinbox.setValue(int(properties.get('font_size'))) - else: - fontSizeSpinbox.setValue(12) - fontSizeFormWidget = widgets.formWidget( - fontSizeSpinbox, labelTextLeft='Font size (px)' - ) - self.fontSizeSpinbox = fontSizeSpinbox - formLayout.addFormWidget( - fontSizeFormWidget, row=row, - leftLabelAlignment=Qt.AlignLeft - ) - - row += 1 - decimalsSpinbox = widgets.SpinBox() - decimalsSpinbox.setMaximum(6) - decimalsSpinbox.setMinimum(0) - if properties.get('num_decimals') is not None: - decimalsSpinbox.setValue(properties.get('num_decimals')) - else: - decimalsSpinbox.setValue(0) - decimalsFormWidget = widgets.formWidget( - decimalsSpinbox, labelTextLeft='Number of decimals' - ) - formLayout.addFormWidget( - decimalsFormWidget, row=row, - leftLabelAlignment=Qt.AlignLeft - ) - self.decimalsSpinbox = decimalsSpinbox - - row += 1 - moveWithZoomToggle = widgets.Toggle() - moveWithZoomWidget = widgets.formWidget( - moveWithZoomToggle, labelTextLeft='Move scale bar with zoom', - widgetAlignment=Qt.AlignCenter, stretchWidget=False - ) - formLayout.addFormWidget( - moveWithZoomWidget, row=row, - leftLabelAlignment=Qt.AlignLeft - ) - self.moveWithZoomToggle = moveWithZoomToggle - - mainLayout.addLayout(formLayout) - - buttonsLayout = widgets.CancelOkButtonsLayout() - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - mainLayout.addStretch() - - self.setLayout(mainLayout) - self.setFont(font) - - self.unitCombobox.currentTextChanged.connect(self.updateLengthUnit) - self.colorButton.clicked.disconnect() - self.colorButton.clicked.connect(self.selectColor) - - self.colorButton.sigColorChanging.connect(self.onValueChanged) - self.lengthDoubleSpinbox.valueChanged.connect(self.onValueChanged) - self.thicknessSpinbox.valueChanged.connect(self.onValueChanged) - self.locCombobox.currentTextChanged.connect(self.onValueChanged) - self.displayTextToggle.toggled.connect(self.onValueChanged) - self.fontSizeSpinbox.valueChanged.connect(self.onValueChanged) - self.decimalsSpinbox.valueChanged.connect(self.onValueChanged) - self.moveWithZoomToggle.toggled.connect(self.onValueChanged) - - def onValueChanged(self, *args, **kwargs): - self.sigValueChanged.emit(self.kwargs()) - - def selectColor(self): - color = self.colorButton.color() - self.colorButton.origColor = color - self.colorButton.colorDialog.setCurrentColor(color) - self.colorButton.colorDialog.setWindowFlags( - Qt.Window | Qt.WindowStaysOnTopHint - ) - self.colorButton.colorDialog.setParent(self) - self.colorButton.colorDialog.open() - w = self.width() - left = self.pos().x() - colorDialogTop = self.colorButton.colorDialog.pos().y() - self.colorButton.colorDialog.move(w+left+10, colorDialogTop) - - def updateLengthUnit(self, unit): - newText = re.sub( - r'\(.*\)', f'({unit})', - self.lengthFormWidget.labelLeft.text() - ) - self.lengthFormWidget.labelLeft.setText(newText) - self.onValueChanged(self) - - def kwargs(self): - unit = self.unitCombobox.currentText() - length_unit = self.lengthDoubleSpinbox.value() - length_um = _core.convert_length(length_unit, unit, 'μm') - length_pixel = length_um/self.PhysicalSizeX - kwargs = { - 'thickness': self.thicknessSpinbox.value(), - 'length_pixel': length_pixel, - 'length_unit': length_unit, - 'is_text_visible': self.displayTextToggle.isChecked(), - 'color': self.colorButton.color(), - 'loc': self.locCombobox.currentText().lower(), - 'font_size': self.fontSizeSpinbox.value(), - 'unit': unit, - 'num_decimals': self.decimalsSpinbox.value(), - 'move_with_zoom': self.moveWithZoomToggle.isChecked() - } - return kwargs - - def ok_cb(self): - self.cancel = False - self.close() - -class SetColumnNamesDialog(QBaseDialog): - def __init__( - self, columnNames, categories, - optionalCategories=None, parent=None - ): - super().__init__(parent) - - if not optionalCategories: - optionalCategories = None - - self.cancel = True - - mainLayout = QVBoxLayout() - - mainLayout.addWidget(QLabel(html_utils.paragraph( - 'Assign a column to the following categories:
' - ))) - - self.categoriesWidgets = {} - formLayout = QFormLayout() - for row, category in enumerate(categories): - combobox = widgets.ComboBox() - combobox.addItems(columnNames) - if optionalCategories is not None: - text = f'* {category}' - else: - text = category - formLayout.addRow(text, combobox) - self.categoriesWidgets[category] = combobox - - if optionalCategories is not None: - optionalItems = ['None', *columnNames] - for row, category in enumerate(optionalCategories): - combobox = widgets.ComboBox() - combobox.addItems(optionalItems) - formLayout.addRow(category, combobox) - self.categoriesWidgets[category] = combobox - - mainLayout.addLayout(formLayout) - if optionalCategories is not None: - mainLayout.addSpacing(10) - mainLayout.addWidget(QLabel(html_utils.paragraph( - '* mandatory', font_size='11px' - ))) - - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - self.setFont(font) - - def _warnNonUniqueCategories(self, category_1, category_2): - txt = html_utils.paragraph(f""" - The following categories have the same column assigned to it.

- Columns assigned to categories must be unique.

- Categories with the same column: - {html_utils.to_list((category_1, category_2))} - """) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'Non-unique columns', txt) - - def _checkUniqueNames(self): - self.textToCategoryMapper = {} - for category, combobox in self.categoriesWidgets.items(): - if combobox.text() == 'None': - continue - - if combobox.text() not in self.textToCategoryMapper: - self.textToCategoryMapper[combobox.text()] = category - continue - - sameCategory = self.textToCategoryMapper[combobox.text()] - self._warnNonUniqueCategories(category, sameCategory) - return False - - return True - - def ok_cb(self): - proceed = self._checkUniqueNames() - if not proceed: - return - - self.selectedColumns = { - category:combobox.text() - for category, combobox in self.categoriesWidgets.items() - } - self.cancel = False - self.close() - -class CombineFeaturesCalculator(QBaseDialog): - sigOk = Signal(object) - - def __init__( - self, features_groups: dict, - group_name_to_col_mapper: dict=None, - title='Combine features calculator', - parent=None - ): - super().__init__(parent) - - self.cancel = True - - self.setWindowTitle(title) - self.initAttributes() - - mainLayout = QVBoxLayout() - equationLayout = QHBoxLayout() - - metricsTreeWidget = QTreeWidget() - metricsTreeWidget.setHeaderHidden(True) - metricsTreeWidget.setFont(font) - self.metricsTreeWidget = metricsTreeWidget - - for groupName, features in features_groups.items(): - topLevelTreeWidgetItem = QTreeWidgetItem(metricsTreeWidget) - topLevelTreeWidgetItem.setText(0, groupName) - metricsTreeWidget.addTopLevelItem(topLevelTreeWidgetItem) - self.addTreeItems( - topLevelTreeWidgetItem, features, isCol=True, - name_to_col_mapper=group_name_to_col_mapper.get(groupName) - ) - - operatorsLayout = self.createOperatorsLayout() - newFeatureNameLayout = self.createNewFeatureNameLayout() - equationDisplayLayout = self.createEquationDisplayLayout() - - equationLayout.addLayout(newFeatureNameLayout) - equationLayout.addWidget(QLabel(' = ')) - equationLayout.addLayout(equationDisplayLayout) - equationLayout.setStretch(0,1) - equationLayout.setStretch(1,0) - equationLayout.setStretch(2,2) - - testOutputLayout = self.createTestOutputLayout() - buttonsLayout = self.createButtonsOutputLayout() - - instructions = html_utils.paragraph(""" - Double-click on any of the available measurements - to add it to the equation.

- Before clicking the `Ok` button, check that the equation returns - the expected result by clicking the `Test output` button. - """) - - mainLayout.addWidget(QLabel(instructions)) - mainLayout.addWidget(QLabel('Available measurements:')) - mainLayout.addWidget(metricsTreeWidget) - mainLayout.addLayout(operatorsLayout) - mainLayout.addLayout(equationLayout) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - mainLayout.addLayout(testOutputLayout) - - - metricsTreeWidget.itemDoubleClicked.connect(self.addFeatureName) - self.setLayout(mainLayout) - self.setFont(font) - - self.setStyleSheet(TREEWIDGET_STYLESHEET) - - def setExpandedAll(self, expanded): - if expanded: - self.expandAll() - else: - for i in range(self.metricsTreeWidget.topLevelItemCount()): - topLevelItem = self.metricsTreeWidget.topLevelItem(i) - topLevelItem.setExpanded(False) - - def expandAll(self): - for i in range(self.metricsTreeWidget.topLevelItemCount()): - topLevelItem = self.metricsTreeWidget.topLevelItem(i) - topLevelItem.setExpanded(True) - - def addTreeItems( - self, parentItem, itemsText, isCol=False, name_to_col_mapper=None - ): - for text in itemsText: - _item = QTreeWidgetItem(parentItem) - _item.setText(0, text) - parentItem.addChild(_item) - if isCol: - _item.isCol = True - _item.variable_name = text - if name_to_col_mapper is None: - continue - - col_name = name_to_col_mapper.get(text, None) - if col_name is None: - continue - - _item.variable_name = col_name - - - def addFeatureName(self, item, column): - if not hasattr(item, 'isCol'): - return - - colName = item.variable_name - text = f'{self.equationDisplay.toPlainText()}{colName}' - self.equationDisplay.setPlainText(text) - self.clearLenghts.append(len(colName)) - self.equationColNames.append(colName) - - def clearEquation(self): - self.isOperatorMode = False - self.equationDisplay.setPlainText('') - self.initAttributes() - - def createButtonsOutputLayout(self): - buttonsLayout = QHBoxLayout() - - cancelButton = widgets.cancelPushButton('Cancel') - helpButton = widgets.infoPushButton(' Help...') - testButton = widgets.calcPushButton('Test output') - okButton = widgets.okPushButton(' Ok ') - okButton.setDisabled(True) - self.okButton = okButton - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(cancelButton) - buttonsLayout.addSpacing(20) - buttonsLayout.addWidget(helpButton) - buttonsLayout.addWidget(testButton) - buttonsLayout.addWidget(okButton) - - helpButton.clicked.connect(self.showHelp) - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.close) - testButton.clicked.connect(self.test_cb) - - return buttonsLayout - - def ok_cb(self): - if not self.newFeatureNameLineEdit.text(): - self.warnEmptyEquationName() - return - - self.equation = self.equationDisplay.toPlainText() - self.newFeatureName = self.newFeatureNameLineEdit.text() - self.cancel = False - self.close() - self.sigOk.emit(self) - - def test_cb(self): - # Evaluate equation with random inputs - equation = self.equationDisplay.toPlainText() - random_data = np.random.rand(1, len(self.equationColNames))*5 - df = pd.DataFrame( - data=random_data, - columns=self.equationColNames - ).round(5) - newColName = self.newFeatureNameLineEdit.text() - try: - df[newColName] = df.eval(equation) - except Exception as e: - traceback.print_exc() - self.testOutputDisplay.setHtml(html_utils.paragraph(e)) - self.testOutputDisplay.setStyleSheet("border: 2px solid red") - return - - self.testOutputDisplay.setStyleSheet("border: 2px solid green") - self.okButton.setDisabled(False) - - result = df.round(5).iloc[0][newColName] - - # Substitute numbers into equation - inputs = df.iloc[0] - equation_numbers = equation - for c, col in enumerate(self.equationColNames): - equation_numbers = equation_numbers.replace(col, str(inputs[c])) - - # Format output into html text - cols = self.equationColNames - inputs_txt = [f'{col} = {input}' for col, input in zip(cols, inputs)] - list_html = html_utils.to_list(inputs_txt) - text = html_utils.paragraph(f""" - By substituting the following random inputs: - {list_html} - we get the equation:

-   {newColName} = {equation_numbers}

- that equals to:

-   {newColName} = {result} - """) - self.testOutputDisplay.setHtml(text) - - def warnEmptyEquationName(self): - msg = widgets.myMessageBox() - txt = html_utils.paragraph(""" - "New measurement name" field cannot be empty! - """) - msg.critical( - self, 'Empty new measurement name', txt - ) - - def showHelp(self): - pass - - def createTestOutputLayout(self): - testOutputLayout = QVBoxLayout() - testOutputLayout.addWidget(QLabel('Result of test with random inputs:')) - testOutputDisplay = QTextEdit() - testOutputDisplay.setReadOnly(True) - self.testOutputDisplay = testOutputDisplay - testOutputLayout.addWidget(testOutputDisplay) - testOutputLayout.setStretch(0,0) - testOutputLayout.setStretch(1,1) - - return testOutputLayout - - def createEquationDisplayLayout(self): - equationDisplayLayout = QVBoxLayout() - equationDisplayLayout.addWidget(QLabel('Equation:')) - equationDisplay = QPlainTextEdit() - # equationDisplay.setReadOnly(True) - self.equationDisplay = equationDisplay - equationDisplayLayout.addWidget(equationDisplay) - equationDisplayLayout.setStretch(0,0) - equationDisplayLayout.setStretch(1,1) - return equationDisplayLayout - - def createNewFeatureNameLayout(self): - newFeatureNameLayout = QVBoxLayout() - newFeatureNameLineEdit = widgets.alphaNumericLineEdit() - newFeatureNameLineEdit.setAlignment(Qt.AlignCenter) - self.newFeatureNameLineEdit = newFeatureNameLineEdit - newFeatureNameLayout.addStretch(1) - newFeatureNameLayout.addWidget(QLabel('New measurement name:')) - newFeatureNameLayout.addWidget(newFeatureNameLineEdit) - newFeatureNameLayout.addStretch(1) - return newFeatureNameLayout - - def createOperatorsLayout(self): - operatorsLayout = QHBoxLayout() - operatorsLayout.addStretch(1) - - iconSize = 24 - - self.operatorButtons = [] - self.operators = [ - ('add', '+'), - ('subtract', '-'), - ('multiply', '*'), - ('divide', '/'), - ('open_bracket', '('), - ('close_bracket', ')'), - ('square', '**2'), - ('pow', '**'), - ('ln', 'log('), - ('log10', 'log10('), - ] - operatorFont = QFont() - operatorFont.setPixelSize(16) - for name, text in self.operators: - button = QPushButton() - button.setIcon(QIcon(f':{name}.svg')) - button.setIconSize(QSize(iconSize,iconSize)) - button.text = text - operatorsLayout.addWidget(button) - self.operatorButtons.append(button) - button.clicked.connect(self.addOperator) - # button.setFont(operatorFont) - - clearButton = QPushButton() - clearButton.setIcon(QIcon(':clear.svg')) - clearButton.setIconSize(QSize(iconSize,iconSize)) - clearButton.setFont(operatorFont) - - clearEntryButton = QPushButton() - clearEntryButton.setIcon(QIcon(':backspace.svg')) - clearEntryButton.setFont(operatorFont) - clearEntryButton.setIconSize(QSize(iconSize,iconSize)) - - operatorsLayout.addWidget(clearButton) - operatorsLayout.addWidget(clearEntryButton) - operatorsLayout.addStretch(1) - - clearButton.clicked.connect(self.clearEquation) - clearEntryButton.clicked.connect(self.clearEntryEquation) - - return operatorsLayout - - def addOperator(self): - button = self.sender() - text = f'{self.equationDisplay.toPlainText()}{button.text}' - self.equationDisplay.setPlainText(text) - self.clearLenghts.append(len(button.text)) - - def clearEquation(self): - self.isOperatorMode = False - self.equationDisplay.setPlainText('') - self.initAttributes() - - def initAttributes(self): - self.clearLenghts = [] - self.equationColNames = [] - self.channelLessColnames = [] - - def clearEntryEquation(self): - if not self.clearLenghts: - return - - text = self.equationDisplay.toPlainText() - newText = text[:-self.clearLenghts[-1]] - clearedText = text[-self.clearLenghts[-1]:] - self.clearLenghts.pop(-1) - self.equationDisplay.setPlainText(newText) - if clearedText in self.equationColNames: - self.equationColNames.remove(clearedText) - if clearedText in self.channelLessColnames: - self.channelLessColnames.remove(clearedText) - -class QInput(QBaseDialog): - def __init__(self, parent=None, title='Input'): - self.cancel = True - self.allowEmpty = True - - super().__init__(parent) - - self.setWindowTitle(title) - - self.mainLayout = QVBoxLayout() - - self.infoLabel = QLabel() - self.mainLayout.addWidget(self.infoLabel) - - promptLayout = QHBoxLayout() - self.promptLabel = QLabel() - promptLayout.addWidget(self.promptLabel) - self.lineEdit = QLineEdit() - promptLayout.addWidget(self.lineEdit) - - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - self.mainLayout.addLayout(promptLayout) - self.mainLayout.addSpacing(20) - self.mainLayout.addLayout(buttonsLayout) - - self.buttonsLayout = buttonsLayout - - self.setFont(font) - self.setLayout(self.mainLayout) - - def askText(self, prompt, infoText='', allowEmpty=False): - self.allowEmpty = allowEmpty - if infoText: - infoText = f'{infoText}
' - self.infoLabel.setText(html_utils.paragraph(infoText)) - self.promptLabel.setText(prompt) - self.exec_(resizeWidthFactor=1.5) - - def ok_cb(self): - self.answer = self.lineEdit.text() - if not self.allowEmpty and not self.answer: - msg = widgets.myMessageBox(showCentered=False) - msg.critical(self, 'Empty', 'Entry cannot be empty.') - return - self.cancel = False - self.close() - -class InstallPyTorchDialog(QBaseDialog): - def __init__(self, parent=None, caller_name='Cell-ACDC'): - super().__init__(parent=parent) - - self.cancel = True - - mainLayout = QVBoxLayout() - - innerLayout = QGridLayout() - - iconLabel = QLabel(self) - standardIcon = getattr(QStyle, 'SP_MessageBoxInformation') - icon = self.style().standardIcon(standardIcon) - pixmap = icon.pixmap(60, 60) - iconLabel.setPixmap(pixmap) - innerLayout.addWidget(iconLabel, 0, 0, alignment=Qt.AlignTop) - - href = html_utils.href_tag('How to install PyTorch', urls.install_pytorch) - important = html_utils.to_admonition(""" - Should you choose to install PyTorch yourself, make sure to - activate
- the correct acdc environment first
. - """, admonition_type='important') - - infoText = html_utils.paragraph(f""" - {caller_name} needs to install the package PyTorch.

- Select your preferences and click ok to install it now. - You will have to confirm the installation in the terminal.

- Alternatively, you can close {caller_name} and run the command - yourself.

- For more details see this guide: {href}
- {important} - """) - innerLayout.addWidget(QLabel(infoText), 0, 1) - innerLayout.addItem(QSpacerItem(10, 10), 1, 1) - - preferencesLayout = QGridLayout() - - row = 0 - self.osCombobox = QComboBox() - self.osCombobox.addItems(['Linux', 'Mac', 'Windows']) - preferencesLayout.addWidget(QLabel('Your OS'), row, 0) - preferencesLayout.addWidget(self.osCombobox, row, 1) - - if is_mac: - self.osCombobox.setCurrentText('Mac') - elif is_win: - self.osCombobox.setCurrentText('Windows') - - row += 1 - self.pkgManagerCombobox = QComboBox() - self.pkgManagerCombobox.addItems(['Pip']) - if not is_conda_env(): - self.pkgManagerCombobox.setCurrentText('Pip') - self.pkgManagerCombobox.setDisabled(True) - - preferencesLayout.addWidget(QLabel('Package manager'), row, 0) - preferencesLayout.addWidget(self.pkgManagerCombobox, row, 1) - - row += 1 - self.cmptPlatformCombobox = QComboBox() - self.cmptPlatformCombobox.addItems( - ['CPU', 'CUDA 11.8 (NVIDIA GPU)', 'CUDA 12.1 (NVIDIA GPU)'] - ) - - preferencesLayout.addWidget(QLabel('Compute Platform'), row, 0) - preferencesLayout.addWidget(self.cmptPlatformCombobox, row, 1) - - row += 1 - pip_prefix, conda_prefix = myutils.get_pip_conda_prefix() - self.commandWidget = widgets.CopiableCommandWidget( - command=f'{pip_prefix} torch' - ) - preferencesLayout.addWidget(QLabel('Run this command: '), row, 0) - preferencesLayout.addWidget(self.commandWidget, row, 1, 1, 2) - preferencesLayout.setColumnStretch(0, 0) - preferencesLayout.setColumnStretch(1, 0) - preferencesLayout.setColumnStretch(2, 1) - - innerLayout.addLayout(preferencesLayout, 2, 1) - - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addLayout(innerLayout) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - self.osCombobox.currentTextChanged.connect(self.updateCommand) - self.pkgManagerCombobox.currentTextChanged.connect(self.updateCommand) - self.cmptPlatformCombobox.currentTextChanged.connect(self.updateCommand) - - self.updateCommand() - - def updateCommand(self, *args, **kwargs): - osText = self.osCombobox.currentText() - pkgManager = self.pkgManagerCombobox.currentText() - cmptPlatform = self.cmptPlatformCombobox.currentText() - command = myutils.get_pytorch_command()[osText][pkgManager][cmptPlatform] - self.commandWidget.setCommand(command) - - def ok_cb(self): - self.command = self.commandWidget.command() - self.cancel = False - self.close() - -class ExportToVideoParametersDialog(QBaseDialog): - sigOk = Signal(dict) - sigAddScaleBar = Signal(bool) - sigAddTimestamp = Signal(bool) - sigRescaleIntensLut = Signal(str, str) - sigChangeStartTime = Signal(str) - - def __init__( - self, channels, parent=None, startFolderpath='', startFilename='', - startFrameNum=1, SizeT=1, SizeZ=1, isTimelapseVideo=True, - isScaleBarPresent=False, isTimestampPresent=False, - rescaleIntensChannelHowMapper=None, - startTime=None - ): - self.cancel = True - - if rescaleIntensChannelHowMapper is None: - rescaleIntensChannelHowMapper = {} - - super().__init__(parent=parent) - - self.setWindowTitle('Preferences for output video') - - mainLayout = QVBoxLayout() - - gridLayout = QGridLayout() - - navVar = 'frame number' if isTimelapseVideo else 'z-slice' - maxNavVar = SizeT if isTimelapseVideo else SizeZ - - self.isTimelapseVideo = isTimelapseVideo - - row = 0 - gridLayout.addWidget(QLabel(f'Start {navVar}:'), row, 0) - self.startNavVarNumberEntry = widgets.SpinBox() - self.startNavVarNumberEntry.setMinimum(1) - self.startNavVarNumberEntry.setMaximum(maxNavVar-1) - self.startNavVarNumberEntry.setValue(startFrameNum) - gridLayout.addWidget(self.startNavVarNumberEntry, row, 1) - - row += 1 - gridLayout.addWidget(QLabel(f'Stop {navVar}:'), row, 0) - self.stopNavVarNumberEntry = widgets.SpinBox() - self.stopNavVarNumberEntry.setMinimum(2) - self.stopNavVarNumberEntry.setMaximum(maxNavVar) - self.stopNavVarNumberEntry.setValue(maxNavVar) - gridLayout.addWidget(self.stopNavVarNumberEntry, row, 1) - - row += 1 - gridLayout.addWidget(QLabel('File format:'), row, 0) - self.fileFormatCombobox = QComboBox() - self.fileFormatCombobox.addItems(['MP4', 'AVI']) - gridLayout.addWidget(self.fileFormatCombobox, row, 1) - - row += 1 - gridLayout.addWidget(QLabel('Frame rate (FPS):'), row, 0) - self.fpsWidget = widgets.FloatLineEdit(allowNegative=False) - self.fpsWidget.setValue(10.0) - gridLayout.addWidget(self.fpsWidget, row, 1) - - row += 1 - self.dpiWidget = widgets.IntLineEdit(allowNegative=False) - self.dpiWidget.setValue(300) - self.dpiWidget.label = QLabel('DPI') - gridLayout.addWidget(self.dpiWidget.label, row, 0) - gridLayout.addWidget(self.dpiWidget, row, 1) - - row += 1 - gridLayout.addWidget(QLabel('Folder path:'), row, 0) - self.folderPathLineEdit = widgets.ElidingLineEdit(minWidth=240) - self.folderPathLineEdit.setText(startFolderpath) - gridLayout.addWidget(self.folderPathLineEdit, row, 1) - self.browseButton = widgets.browseFileButton( - start_dir=startFolderpath, openFolder=True - ) - gridLayout.addWidget(self.browseButton, row, 2) - - row += 1 - gridLayout.addWidget(QLabel('Filename:'), row, 0) - self.filenameLineEdit = widgets.alphaNumericLineEdit() - self.filenameLineEdit.setAlignment(Qt.AlignCenter) - self.filenameLineEdit.setText(startFilename) - gridLayout.addWidget(self.filenameLineEdit, row, 1) - self.fileFormatLabel = QLabel('.mp4') - gridLayout.addWidget(self.fileFormatLabel, row, 2) - - row += 1 - gridLayout.addWidget(QLabel('Add Scale Bar:'), row, 0) - self.addScaleBarToggle = widgets.Toggle() - gridLayout.addWidget( - self.addScaleBarToggle, row, 1, alignment=Qt.AlignCenter - ) - self.addScaleBarToggle.setChecked(isScaleBarPresent) - - if isTimelapseVideo: - row += 1 - gridLayout.addWidget(QLabel('Add timestamp:'), row, 0) - self.addTimestampToggle = widgets.Toggle() - gridLayout.addWidget( - self.addTimestampToggle, row, 1, alignment=Qt.AlignCenter - ) - self.addTimestampToggle.setChecked(isTimestampPresent) - - for channel in channels: - row += 1 - labelText = f'Rescale intensities (LUT) {channel}:' - gridLayout.addWidget(QLabel(labelText), row, 0) - rescaleItems = ['Rescale each 2D image'] - if SizeZ > 1: - rescaleItems.append('Rescale across z-stack') - if isTimelapseVideo: - rescaleItems.append('Rescale across time frames') - rescaleItems.append('Choose custom levels...') - rescaleItems.append('Do no rescale, display raw image') - rescaleIntensCombobox = QComboBox() - rescaleIntensCombobox.addItems(rescaleItems) - rescaleIntensHow = rescaleIntensChannelHowMapper.get(channel) - if rescaleIntensHow is not None: - rescaleIntensCombobox.setCurrentText(rescaleIntensHow) - gridLayout.addWidget(rescaleIntensCombobox, row, 1) - rescaleIntensCombobox.textActivated.connect( - partial(self.emitRescaleIntens, channel=channel) - ) - - row += 1 - gridLayout.addWidget(QLabel('Save a PNG for each frame:'), row, 0) - self.saveFramesToggle = widgets.Toggle() - gridLayout.addWidget( - self.saveFramesToggle, row, 1, alignment=Qt.AlignCenter - ) - - gridLayout.setColumnStretch(0, 0) - gridLayout.setColumnStretch(1, 1) - gridLayout.setColumnStretch(2, 0) - - self.fileFormatCombobox.currentTextChanged.connect( - self.updateFileFormat - ) - self.browseButton.sigPathSelected.connect(self.updateFolderPath) - self.addScaleBarToggle.toggled.connect(self.addScaleBarToggled) - if isTimelapseVideo: - self.addTimestampToggle.toggled.connect(self.addTimestampToggled) - - buttonsLayout = widgets.CancelOkButtonsLayout() - buttonsLayout.okButton.setText('Export') - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addLayout(gridLayout) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - def emitRescaleIntens(self, how, channel=''): - self.sigRescaleIntensLut.emit(how, channel) - - def addScaleBarToggled(self, checked): - self.sigAddScaleBar.emit(checked) - - def addTimestampToggled(self, checked): - self.sigAddTimestamp.emit(checked) - - def updateFolderPath(self, folderPath): - self.folderPathLineEdit.setText(folderPath) - self.browseButton.setStartPath(folderPath) - - def updateFileFormat(self, fileFormat): - self.fileFormatLabel.setText(f'.{fileFormat.lower()}') - - def validateFolderPath(self): - folderPath = self.folderPathLineEdit.text() - if os.path.exists(folderPath) and os.path.isdir(folderPath): - return True - - text = html_utils.paragraph( - 'The selected folder path is not a valid folder or does not exist' - ) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'Not a valid folder', text) - return False - - def validateFilename(self): - filename = self.filenameLineEdit.text() - if filename: - return True - - text = html_utils.paragraph( - 'The filename cannot be empty!' - ) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'Not a valid folder', text) - return False - - def validate(self): - proceed = self.validateFolderPath() - if not proceed: - return False - - proceed = self.validateFilename() - if not proceed: - return False - - return True - - def preferences(self, makedirs=True): - filename = f'{self.filenameLineEdit.text()}{self.fileFormatLabel.text()}' - avi_filename = f'{self.filenameLineEdit.text()}.avi' - avi_filepath = os.path.join(self.folderPathLineEdit.text(), avi_filename) - png_foldername = ( - f'{self.filenameLineEdit.text()}_frames_PNG' - ) - pngs_folderpath = os.path.join( - self.folderPathLineEdit.text(), png_foldername - ) - if makedirs: - os.makedirs(pngs_folderpath, exist_ok=True) - - preferences = { - 'start_nav_var_num': self.startNavVarNumberEntry.value(), - 'stop_nav_var_num': self.stopNavVarNumberEntry.value(), - 'filepath': os.path.join(self.folderPathLineEdit.text(), filename), - 'filename': self.filenameLineEdit.text(), - 'avi_filepath': avi_filepath, - 'pngs_folderpath': pngs_folderpath, - 'num_digits': len(str(self.stopNavVarNumberEntry.value())), - 'fps': self.fpsWidget.value(), - 'save_pngs': self.saveFramesToggle.isChecked(), - 'is_timelapse': self.isTimelapseVideo, - 'dpi': self.dpiWidget.value(), - } - return preferences - - def ok_cb(self): - proceed = self.validate() - if not proceed: - return - self.cancel = False - self.sigOk.emit(self.preferences()) - self.selected_preferences = self.preferences() - self.close() - -class TimestampPropertiesDialog(QBaseDialog): - sigValueChanged = Signal(object) - - def __init__(self, parent=None, **properties): - super().__init__(parent=parent) - - self.cancel = True - self.setWindowTitle('Timestamp preferences') - - mainLayout = QVBoxLayout() - - formLayout = widgets.FormLayout() - formLayout.setVerticalSpacing(10) - formLayout.setHorizontalSpacing(50) - - row = 0 - self.startTimeWidget = widgets.TimeWidget() - if properties.get('start_timedelta') is not None: - self.startTimeWidget.setValuesFromTimedelta( - properties.get('start_timedelta') - ) - startTimeFormWidget = widgets.formWidget( - self.startTimeWidget, labelTextLeft='Start time', - ) - formLayout.addFormWidget( - startTimeFormWidget, row=row, - leftLabelAlignment=Qt.AlignLeft - ) - - row += 1 - self.colorButton = widgets.myColorButton(color=(255, 255, 255)) - if properties.get('color') is not None: - self.colorButton.setColor(properties.get('color')) - colorFormWidget = widgets.formWidget( - self.colorButton, labelTextLeft='Color', - widgetAlignment=Qt.AlignCenter, stretchWidget=False - ) - formLayout.addFormWidget( - colorFormWidget, row=row, - leftLabelAlignment=Qt.AlignLeft - ) - - row += 1 - fontSizeWidget = widgets.FontSizeWidget() - if properties.get('font_size') is not None: - fontSizeWidget.setValue(properties.get('font_size')) - else: - fontSizeWidget.setValue(12) - fontSizeFormWidget = widgets.formWidget( - fontSizeWidget, labelTextLeft='Font size (px)' - ) - self.fontSizeWidget = fontSizeWidget - formLayout.addFormWidget( - fontSizeFormWidget, row=row, - leftLabelAlignment=Qt.AlignLeft - ) - - row += 1 - locCombobox = QComboBox() - locFormWidget = widgets.formWidget( - locCombobox, labelTextLeft='Location' - ) - locCombobox.addItems( - ['Top-left', 'Top-right', 'Bottom-left', 'Bottom-right', 'Custom'] - ) - loc = properties.get('loc') - if isinstance(loc, str): - locCombobox.setCurrentText(loc.capitalize()) - formLayout.addFormWidget( - locFormWidget, row=row, - leftLabelAlignment=Qt.AlignLeft - ) - self.locCombobox = locCombobox - - row += 1 - moveWithZoomToggle = widgets.Toggle() - moveWithZoomWidget = widgets.formWidget( - moveWithZoomToggle, labelTextLeft='Move timestamp with zoom', - widgetAlignment=Qt.AlignCenter, stretchWidget=False - ) - formLayout.addFormWidget( - moveWithZoomWidget, row=row, - leftLabelAlignment=Qt.AlignLeft - ) - self.moveWithZoomToggle = moveWithZoomToggle - - mainLayout.addLayout(formLayout) - - buttonsLayout = widgets.CancelOkButtonsLayout() - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - mainLayout.addStretch() - - self.setLayout(mainLayout) - self.setFont(font) - - self.colorButton.clicked.disconnect() - self.colorButton.clicked.connect(self.selectColor) - - self.startTimeWidget.sigValueChanged.connect(self.onValueChanged) - - self.locCombobox.currentTextChanged.connect(self.onValueChanged) - self.fontSizeWidget.sigTextChanged.connect(self.onValueChanged) - self.moveWithZoomToggle.toggled.connect(self.onValueChanged) - - def onValueChanged(self, *args, **kwargs): - self.sigValueChanged.emit(self.kwargs()) - - def selectColor(self): - color = self.colorButton.color() - self.colorButton.origColor = color - self.colorButton.colorDialog.setCurrentColor(color) - self.colorButton.colorDialog.setWindowFlags( - Qt.Window | Qt.WindowStaysOnTopHint - ) - self.colorButton.colorDialog.setParent(self) - self.colorButton.colorDialog.open() - w = self.width() - left = self.pos().x() - colorDialogTop = self.colorButton.colorDialog.pos().y() - self.colorButton.colorDialog.move(w+left+10, colorDialogTop) - - def kwargs(self): - kwargs = { - 'color': self.colorButton.color(), - 'start_timedelta': self.startTimeWidget.timedelta(), - 'loc': self.locCombobox.currentText().lower(), - 'font_size': self.fontSizeWidget.text(), - 'move_with_zoom': self.moveWithZoomToggle.isChecked() - } - return kwargs - - def ok_cb(self): - self.cancel = False - self.close() - -class ExportToImageParametersDialog(QBaseDialog): - sigOk = Signal(dict) - sigAddScaleBar = Signal(bool) - sigRangeChanged = Signal(object) - - def __init__( - self, parent=None, startFolderpath='', startFilename='', - startViewRange=None, isScaleBarPresent=False - ): - self.cancel = True - - super().__init__(parent=parent) - - self.setWindowTitle('Preferences for output image') - - mainLayout = QVBoxLayout() - - gridLayout = QGridLayout() - - row = 0 - gridLayout.addWidget(QLabel('View range X axis:'), row, 0) - self.xRangeSelector = widgets.RangeSelector(integers=True) - if startViewRange is not None: - xRange, yRange = startViewRange - self.xRangeSelector.setRange(*xRange) - gridLayout.addWidget(self.xRangeSelector, row, 1) - - row += 1 - gridLayout.addWidget(QLabel('View range Y axis:'), row, 0) - self.yRangeSelector = widgets.RangeSelector(integers=True) - if startViewRange is not None: - xRange, yRange = startViewRange - self.yRangeSelector.setRange(*yRange) - gridLayout.addWidget(self.yRangeSelector, row, 1) - - row += 1 - gridLayout.addWidget(QLabel('Width and Height:'), row, 0) - self.widthHeightSelector = widgets.RangeSelector( - integers=True, ordered=False - ) - if startViewRange is not None: - xRange, yRange = startViewRange - width = int(xRange[1] - xRange[0]) - height = int(yRange[1] - yRange[0]) - self.widthHeightSelector.setRange(width, height) - gridLayout.addWidget(self.widthHeightSelector, row, 1) - self.lockSizeButton = widgets.LockPushButton() - self.lockSizeButton.setCheckable(True) - self.lockSizeButton.setToolTip( - 'Lock width and height' - ) - gridLayout.addWidget(self.lockSizeButton, row, 2) - - row += 1 - gridLayout.addWidget(QLabel('File format:'), row, 0) - self.fileFormatCombobox = QComboBox() - self.fileFormatCombobox.addItems(['SVG', 'PNG', 'TIFF', 'JPEG']) - gridLayout.addWidget(self.fileFormatCombobox, row, 1) - - row += 1 - self.dpiWidget = widgets.IntLineEdit(allowNegative=False) - self.dpiWidget.setValue(300) - self.dpiWidget.label = QLabel('DPI') - gridLayout.addWidget(self.dpiWidget.label, row, 0) - gridLayout.addWidget(self.dpiWidget, row, 1) - self.dpiWidget.hide() - self.dpiWidget.label.hide() - - row += 1 - gridLayout.addWidget(QLabel('Folder path:'), row, 0) - self.folderPathLineEdit = widgets.ElidingLineEdit(minWidth=240) - self.folderPathLineEdit.setText(startFolderpath) - gridLayout.addWidget(self.folderPathLineEdit, row, 1) - self.browseButton = widgets.browseFileButton( - start_dir=startFolderpath, openFolder=True - ) - gridLayout.addWidget(self.browseButton, row, 2) - - row += 1 - gridLayout.addWidget(QLabel('Filename:'), row, 0) - self.filenameLineEdit = widgets.alphaNumericLineEdit() - self.filenameLineEdit.setAlignment(Qt.AlignCenter) - self.filenameLineEdit.setText(startFilename) - gridLayout.addWidget(self.filenameLineEdit, row, 1) - self.fileFormatLabel = QLabel( - f'.{self.fileFormatCombobox.currentText().lower()}' - ) - gridLayout.addWidget(self.fileFormatLabel, row, 2) - - row += 1 - gridLayout.addWidget(QLabel('Add Scale Bar:'), row, 0) - self.addScaleBarToggle = widgets.Toggle() - gridLayout.addWidget( - self.addScaleBarToggle, row, 1, alignment=Qt.AlignCenter - ) - self.addScaleBarToggle.setChecked(isScaleBarPresent) - - self.fileFormatCombobox.currentTextChanged.connect( - self.updateFileFormat - ) - self.browseButton.sigPathSelected.connect(self.updateFolderPath) - self.addScaleBarToggle.toggled.connect(self.addScaleBarToggled) - self.xRangeSelector.sigLowValueChanged.connect(self.x0Changed) - self.xRangeSelector.sigHighValueChanged.connect(self.x1Changed) - self.yRangeSelector.sigLowValueChanged.connect(self.y0Changed) - self.yRangeSelector.sigHighValueChanged.connect(self.y1Changed) - self.widthHeightSelector.sigLowValueChanged.connect(self.widthChanged) - self.widthHeightSelector.sigHighValueChanged.connect(self.heightChanged) - self.widthHeightSelector.sigRangeManuallyChanged.connect( - self.widthHeightManuallyChanged - ) - - buttonsLayout = widgets.CancelOkButtonsLayout() - buttonsLayout.okButton.setText('Export') - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - gridLayout.setColumnStretch(2, 0) - - mainLayout.addLayout(gridLayout) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - def widthHeightManuallyChanged(self, *args): - self.lockSizeButton.setChecked(True) - - def x0Changed(self, *args): - if self.lockSizeButton.isChecked(): - x0, _ = self.xRangeSelector.range() - yRange = self.yRangeSelector.range() - width, height = self.widthHeightSelector.range() - x1 = x0 + width - xRange = (x0, x1) - else: - xRange = self.xRangeSelector.range() - yRange = self.yRangeSelector.range() - _, height = self.widthHeightSelector.range() - width = int(xRange[1] - xRange[0]) - - self.xRangeSelector.setRangeNoEmit(*xRange) - self.yRangeSelector.setRangeNoEmit(*yRange) - self.widthHeightSelector.setRangeNoEmit(width, height) - self.rangeChanged() - - def x1Changed(self, *args): - if self.lockSizeButton.isChecked(): - _, x1 = self.xRangeSelector.range() - yRange = self.yRangeSelector.range() - width, height = self.widthHeightSelector.range() - x0 = x1 - width - xRange = (x0, x1) - else: - xRange = self.xRangeSelector.range() - yRange = self.yRangeSelector.range() - _, height = self.widthHeightSelector.range() - width = int(xRange[1] - xRange[0]) - - self.xRangeSelector.setRangeNoEmit(*xRange) - self.yRangeSelector.setRangeNoEmit(*yRange) - self.widthHeightSelector.setRangeNoEmit(width, height) - - self.rangeChanged() - - def y0Changed(self, *args): - if self.lockSizeButton.isChecked(): - xRange = self.xRangeSelector.range() - y0, _ = self.yRangeSelector.range() - width, height = self.widthHeightSelector.range() - y1 = y0 + height - yRange = (y0, y1) - else: - xRange = self.xRangeSelector.range() - yRange = self.yRangeSelector.range() - width, _ = self.widthHeightSelector.range() - height = int(yRange[1] - yRange[0]) - - self.xRangeSelector.setRangeNoEmit(*xRange) - self.yRangeSelector.setRangeNoEmit(*yRange) - self.widthHeightSelector.setRangeNoEmit(width, height) - - self.rangeChanged() - - def y1Changed(self, *args): - if self.lockSizeButton.isChecked(): - xRange = self.xRangeSelector.range() - _, y1 = self.yRangeSelector.range() - width, height = self.widthHeightSelector.range() - y0 = y1 - height - yRange = (y0, y1) - else: - xRange = self.xRangeSelector.range() - yRange = self.yRangeSelector.range() - width, _ = self.widthHeightSelector.range() - height = int(yRange[1] - yRange[0]) - - self.xRangeSelector.setRangeNoEmit(*xRange) - self.yRangeSelector.setRangeNoEmit(*yRange) - self.widthHeightSelector.setRangeNoEmit(width, height) - - self.rangeChanged() - - def widthChanged(self, *args): - self.widthHeightChanged() - self.rangeChanged() - - def heightChanged(self, *args): - self.widthHeightChanged() - self.rangeChanged() - - def updateViewRangeExportToImageDialog(self, viewBox, viewRange, changed): - xRange, yRange = viewRange - self.xRangeSelector.setRangeNoEmit(*xRange) - self.yRangeSelector.setRangeNoEmit(*yRange) - - def widthHeightChanged(self, *args): - x0, _ = self.xRangeSelector.range() - y0, _ = self.yRangeSelector.range() - width, height = self.widthHeightSelector.range() - x1 = x0 + width - y1 = y0 + height - self.xRangeSelector.setRangeNoEmit(x0, x1) - self.yRangeSelector.setRangeNoEmit(y0, y1) - self.rangeChanged() - - def rangeChanged(self, *args): - xRange = self.xRangeSelector.range() - yRange = self.yRangeSelector.range() - self.sigRangeChanged.emit((xRange, yRange)) - - def addScaleBarToggled(self, checked): - self.sigAddScaleBar.emit(checked) - - def updateFolderPath(self, folderPath): - self.folderPathLineEdit.setText(folderPath) - self.browseButton.setStartPath(folderPath) - - def updateFileFormat(self, fileFormat): - if fileFormat == 'SVG': - self.dpiWidget.hide() - self.dpiWidget.label.hide() - else: - self.dpiWidget.show() - self.dpiWidget.label.show() - - self.fileFormatLabel.setText(f'.{fileFormat.lower()}') - - def validateFolderPath(self): - folderPath = self.folderPathLineEdit.text() - if os.path.exists(folderPath) and os.path.isdir(folderPath): - return True - - text = html_utils.paragraph( - 'The selected folder path is not a valid folder or does not exist' - ) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'Not a valid folder', text) - return False - - def validateFilename(self): - filename = self.filenameLineEdit.text() - if filename: - return True - - text = html_utils.paragraph( - 'The filename cannot be empty!' - ) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'Not a valid folder', text) - return False - - def validate(self): - proceed = self.validateFolderPath() - if not proceed: - return False - - proceed = self.validateFilename() - if not proceed: - return False - - return True - - def setViewRange(self, xRange, yRange, emitSignal=True): - if self.lockSizeButton.isChecked(): - x0, _ = xRange - y0, _ = yRange - width, height = self.widthHeightSelector.range() - x1 = x0 + width - y1 = y0 + height - xRange = (x0, x1) - yRange = (y0, y1) - else: - width = int(xRange[1] - xRange[0]) - height = int(yRange[1] - yRange[0]) - - self.xRangeSelector.setRangeNoEmit(*xRange) - self.yRangeSelector.setRangeNoEmit(*yRange) - self.widthHeightSelector.setRangeNoEmit(width, height) - if not emitSignal: - return - - self.rangeChanged() - - def viewRange(self): - xRange = self.xRangeSelector.range() - yRange = self.yRangeSelector.range() - return (xRange, yRange) - - def preferences(self): - filename = f'{self.filenameLineEdit.text()}{self.fileFormatLabel.text()}' - preferences = { - 'view_range_x': self.xRangeSelector.range(), - 'view_range_y': self.yRangeSelector.range(), - 'filepath': os.path.join(self.folderPathLineEdit.text(), filename), - 'filename': self.filenameLineEdit.text(), - 'dpi': self.dpiWidget.value(), - } - return preferences - - def ok_cb(self): - proceed = self.validate() - if not proceed: - return - self.cancel = False - self.sigOk.emit(self.preferences()) - self.selected_preferences = self.preferences() - self.close() - -class DataPrepSubCropsPathsDialog(QBaseDialog): - def __init__(self, cropPaths=None, parent=None): - self.cancel = True - - super().__init__(parent=parent) - - mainLayout = QVBoxLayout() - - gridLayout = QGridLayout() - row = 0 - - if cropPaths is None: - cropPaths = {os.path.expanduser('~'): 1} - - if any([numCrops>1 for numCrops in cropPaths.values()]): - row += 1 - gridLayout.addWidget( - QLabel('Same folder for all crops:'), row, 0 - ) - self.sameFolderPathToggle = widgets.Toggle() - gridLayout.addWidget( - self.sameFolderPathToggle, row, 1, alignment=Qt.AlignCenter - ) - self.sameFolderPathToggle.setChecked(True) - self.sameFolderPathToggle.toggled.connect(self.setSameFolderPath) - - self.windowMinWidth = 0 - minWidth = int(self.screen().size().width()/3) - self.folderPathLineEdits = defaultdict(list) - for path, numCrops in cropPaths.items(): - row += 1 - gridLayout.addWidget(QLabel('Master Position:'), row, 0) - masterPathLabel = QLabel(f'{path}') - gridLayout.addWidget(masterPathLabel, row, 1) - - scrollArea = QScrollArea() - scrollArea.setWidgetResizable(True) - scrollAreaLayout = QGridLayout() - for i in range(numCrops): - label = QLabel(f'Crop {i+1} folder path:') - scrollAreaLayout.addWidget(label, i, 0) - folderPathLineEdit = widgets.ElidingLineEdit() - folderPathLineEdit.label = label - folderPathLineEdit.setText(path) - scrollAreaLayout.addWidget(folderPathLineEdit, i, 1) - browseButton = widgets.browseFileButton( - start_dir=path, openFolder=True - ) - scrollAreaLayout.addWidget(browseButton, i, 2) - browseButton.sigPathSelected.connect( - partial(self.updateFolderPath, lineEdit=folderPathLineEdit) - ) - self.folderPathLineEdits[path].append(folderPathLineEdit) - folderPathLineEdit.browseButton = browseButton - - scrollAreaLayout.setColumnStretch(0, 0) - scrollAreaLayout.setColumnStretch(1, 1) - scrollAreaLayout.setColumnStretch(2, 0) - container = QWidget() - container.setLayout(scrollAreaLayout) - scrollArea.setWidget(container) - - row += 1 - gridLayout.addWidget(scrollArea, row, 0, 1, 2) - noHorizontalScrollbarWidth = ( - container.sizeHint().width() - + scrollArea.verticalScrollBar().sizeHint().width() + 20 - ) - if noHorizontalScrollbarWidth > self.windowMinWidth: - self.windowMinWidth = noHorizontalScrollbarWidth - - row += 1 - gridLayout.addWidget(widgets.QHLine(), row, 0, 1, 2) - - row += 1 - gridLayout.addItem(QSpacerItem(10, 10), row, 0, 1, 2) - - row += 1 - - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addLayout(gridLayout) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - def show(self, block=False): - self.resize(self.windowMinWidth, self.sizeHint().height()) - super().show(block=block) - - def setSameFolderPath(self, checked): - for masterPath, lineEdits in self.folderPathLineEdits.items(): - referencePath = lineEdits[0].text() - for lineEdit in lineEdits[1:]: - if checked: - lineEdit.setText(referencePath) - - lineEdit.setDisabled(checked) - lineEdit.browseButton.setDisabled(checked) - lineEdit.label.setDisabled(checked) - - def updateFolderPath(self, path, lineEdit=None): - lineEdit.setText(path) - lineEdit.browseButton.setStartPath(path) - - def warnFolderPathNotValid(self, cropNum, masterPath, folderPath): - text = html_utils.paragraph( - f'The following folder path for crop number {cropNum} ' - 'is not a valid folder or does not exist:' - ) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'Not a valid folder', text, commands=(folderPath,)) - - def askOverwritingPaths(self, overwritingPaths): - text = html_utils.paragraph( - 'Data in the following paths will be overwritten with ' - 'cropped data.

' - 'Are you sure you want to continue?' - ) - msg = widgets.myMessageBox(wrapText=False) - _, yesButton = msg.warning( - self, 'Not a valid folder', text, commands=overwritingPaths, - buttonsTexts=('No, let me edit paths', 'Yes, overwrite') - ) - return msg.clickedButton == yesButton - - def validatePaths(self): - for masterPath, lineEdits in self.folderPathLineEdits.items(): - for i, lineEdit in enumerate(lineEdits): - path = lineEdit.text() - if os.path.exists(path) and os.path.isdir(path): - continue - - self.warnFolderPathNotValid(i+1, masterPath, path) - return False - - overwritingPaths = [] - for masterPath, lineEdits in self.folderPathLineEdits.items(): - masterPath = masterPath.replace('\\', '/') - if not masterPath.endswith('Images'): - continue - - for i, lineEdit in enumerate(lineEdits): - path = lineEdit.text() - path = path.replace('\\', '/') - if path == masterPath: - overwritingPaths.append(masterPath) - - if not overwritingPaths: - return True - - return self.askOverwritingPaths(overwritingPaths) - - def paths(self): - selectedPaths = {} - for masterPath, lineEdits in self.folderPathLineEdits.items(): - selectedPaths[masterPath] = [le.text() for le in lineEdits] - return selectedPaths - - def ok_cb(self): - proceed = self.validatePaths() - if not proceed: - return - - self.folderPaths = self.paths() - self.cancel = False - self.close() - -class PreProcessParamsWidget(QWidget): - sigLoadRecipe = Signal() - sigLoadSavedRecipe = Signal() - sigValuesChanged = Signal(list) - - def __init__(self, df_metadata=None, addApplyButton=False, parent=None): - super().__init__(parent) - - mainLayout = QVBoxLayout() - - self.df_metadata = df_metadata - self.addApplyButton = addApplyButton - - groupbox = QGroupBox() - self.groupbox = groupbox - - groupbox.setTitle('Pre-processing') - groupbox.setCheckable(True) - - self.gridLayout = QGridLayout() - self.row = -1 - self.stepsWidgets = {} - - self.gridLayout.setColumnStretch(0, 0) - self.gridLayout.setColumnStretch(1, 1) - self.gridLayout.setColumnStretch(2, 0) - self.gridLayout.setColumnStretch(3, 0) - self.gridLayout.setColumnStretch(4, 0) - groupbox.setLayout(self.gridLayout) - - buttonsLayout = QGridLayout() - row = 0 - col = 0 - buttonsLayout.setColumnStretch(col, 1) - - loadRecipeButton = widgets.OpenFilePushButton('Load saved recipe...') - self.loadRecipeButton = loadRecipeButton - buttonsLayout.addWidget(loadRecipeButton, row, col+2) - - saveRecipeButton = widgets.savePushButton('Save current recipe...') - self.saveRecipeButton = saveRecipeButton - buttonsLayout.addWidget(saveRecipeButton, row+1, col+2) - - loadLastRecipeButton = widgets.reloadPushButton('Load last parameters') - self.loadLastRecipeButton = loadLastRecipeButton - buttonsLayout.addWidget(loadLastRecipeButton, row, col+1) - - self.buttonsLayout = buttonsLayout - - loadLastRecipeButton.clicked.connect(self.emitLoadRecipe) - saveRecipeButton.clicked.connect(self.saveRecipe) - loadRecipeButton.clicked.connect(self.selectAndLoadRecipe) - - mainLayout.addWidget(groupbox) - mainLayout.addSpacing(10) - mainLayout.addLayout(buttonsLayout) - - self.addStep(is_first=True) - - mainLayout.setContentsMargins(0, 0, 0, 0) - self.setLayout(mainLayout) - - def stepSizeHeightHint(self): - stepWidgets = self.stepsWidgets[1] - height = ( - stepWidgets['stepLabel'].minimumSizeHint().height() - + stepWidgets['selector'].minimumSizeHint().height() - ) - return height - - def setChecked(self, checked): - self.groupbox.setChecked(checked) - - def emitLoadRecipe(self): - self.sigLoadRecipe.emit() - - def loadRecipe(self, configPars: dict): - for stepWidgets in list(self.stepsWidgets.values()): - try: - stepWidgets['delButton'].click() - except Exception as err: - pass - - configPars = self.sortStepsConfigPars(configPars) - for s in range(1, len(configPars)): - self.stepsWidgets[1]['addButton'].click() - - for i, (section, section_items) in enumerate(configPars.items()): - step_n = i+1 - selector = self.stepsWidgets[step_n]['selector'] - kwarg_to_value_mapper = {} - for option, value in section_items.items(): - if option == 'method': - selector.setCurrentText(value) - method = value - else: - kwarg_to_value_mapper[option] = value - selector.setParams(method, kwarg_to_value_mapper) - - self.setChecked(True) - - def sortStepsConfigPars(self, configPars: dict): - sortedConfigPars = {} - sortedKeys = sorted( - configPars.keys(), - key=lambda key: int(re.findall(r'step(\d+)', key)[0]) - ) - for key in sortedKeys: - sortedConfigPars[key] = configPars[key] - return sortedConfigPars - - def saveRecipeUI(self, folder_path, ext, title, basename, hintText, - default_text):# -> tuple[Literal[False], Literal['']] | tuple[Literal[True], Any]: - win = filenameDialog( - title=title, - basename=basename, - ext=ext, - hintText=hintText, - allowEmpty=False, - defaultEntry=default_text, - parent=self, - ) - win.exec_() - if win.cancel: - return False, '' - - self.cancel = False - filepath = win.filename - os.makedirs(folder_path, exist_ok=True) - filepath = os.path.join(folder_path, filepath) - - if os.path.exists(filepath): - proceed = self.warnExistingRecipeFile(filepath) - if not proceed: - return False, '' - - return True, filepath - - def saveRecipe(self): - recipe = self.recipe() - if recipe is None: - return - - default_text = '' - for step in recipe[:2]: - method = step['method'] - func_name = config.PREPROCESS_MAPPER[method]['function_name'] - default_text = f'{default_text}-{func_name}' - default_text = default_text.lstrip('-') - - proceed, ini_filepath = self.saveRecipeUI(preproc_recipes_path, '.ini', - 'Filename for pre-processing recipe', - 'preprocessing_recipe', - 'Insert a filename for the pre-processing recipe:', - default_text - ) - if not proceed: - return - - cp = self.recipeConfigPars('acdc') - with open(ini_filepath, 'w') as configfile: - cp.write(configfile) - - self.communicateSavingRecipeFinished(ini_filepath) - - def warnExistingRecipeFile(self, ini_filename): - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph( - 'A file with the following name

' - f'{ini_filename}

' - 'already exists.

' - 'Do you want to overwrite the existing file?' - ) - noButton, yesButton = msg.warning( - self, 'File name existing', txt, - buttonsTexts=( - 'No, stop saving process', - 'Yes, overwrite existing file' - ) - ) - return msg.clickedButton == yesButton - - def warnNoAvailableRecipesToLoad(self): - text = html_utils.paragraph( - 'There are no recipes saved. Sorry about that :(' - ) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'No recipes saved', text) - - # def selectIniFileToLoadRecipe(self): - # import qtpy.compat - # ini_filepath = qtpy.compat.getopenfilename( - # parent=self, - # caption='Select INI file to load pre-processing recipe', - # filters='INI (*.ini);;All Files (*)' - # )[0] - # if not ini_filepath: - # return - - # cp = config.ConfigParser() - # cp.read(ini_filepath) - # preprocConfigPars = {} - # for section in cp.sections(): - # if not section.startswith('acdc.preprocess'): - # continue - - # preprocConfigPars[section] = cp[section] - - # if not preprocConfigPars: - # return - - # self.loadRecipe(preprocConfigPars) - - def selectRecipeFilepath(self, recipes_path, recipe_prefix, ext_label, ext): - availableRecipes = [] - if os.path.exists(recipes_path): - for file in myutils.listdir(recipes_path): - if not file.startswith(recipe_prefix): - continue - endname = file.split(f'{recipe_prefix}_')[1] - availableRecipes.append(endname) - - if not availableRecipes: - import qtpy.compat - filepath = qtpy.compat.getopenfilename( - parent=self, - caption=f'Select {ext_label} file to load recipe', - filters=f'{ext_label} (*.{ext});;All Files (*)' - )[0] - return filepath or None - - browseButton = widgets.browseFileButton( - f'Select {ext_label} file...', - title=f'Select {ext_label} file to load recipe', - openFolder=False, - start_dir=myutils.getMostRecentPath(), - ext={ext_label: f'.{ext}'} - ) - selectRecipeWin = widgets.QDialogListbox( - 'Select recipe', - 'Select recipe to load:\n', - availableRecipes, - multiSelection=False, - allowEmptySelection=False, - parent=self, - additionalButtons=(browseButton,) - ) - browseButton.sigPathSelected.connect( - partial( - self.recipeIniFileSelected, - selectRecipeWin=selectRecipeWin, - sender=browseButton - ) - ) - selectRecipeWin.exec_() - if selectRecipeWin.cancel: - return None - - if selectRecipeWin.clickedButton == browseButton: - return selectRecipeWin.selectedIniFilepath - - selected_endname = selectRecipeWin.selectedItemsText[0] - filename = f'{recipe_prefix}_{selected_endname}' - return os.path.join(recipes_path, filename) - - def selectAndLoadRecipe(self): - filepath = self.selectRecipeFilepath( - preproc_recipes_path, 'preprocessing_recipe', 'INI', 'ini' - ) - if filepath is None: - return - cp = config.ConfigParser() - cp.read(filepath) - preprocConfigPars = { - s: cp[s] for s in cp.sections() - if s.startswith('acdc.preprocess') - } - if not preprocConfigPars: - return - self.loadRecipe(preprocConfigPars) - - def recipeIniFileSelected( - self, ini_filepath, selectRecipeWin=None, sender=None - ): - selectRecipeWin.clickedButton = sender - selectRecipeWin.selectedIniFilepath = ini_filepath - selectRecipeWin.cancel = False - selectRecipeWin.close() - - def communicateSavingRecipeFinished(self, ini_filepath): - text = html_utils.paragraph( - 'Done!

' - 'Pre-processing recipe saved to:' - ) - msg = widgets.myMessageBox(wrapText=False) - msg.information( - self, 'Pre-processing recipe saved!', text, - commands=(ini_filepath,), - path_to_browse=os.path.dirname(ini_filepath) - ) - - def addStep(self, is_first=False): - stepWidgets = {} - - self.row += 1 - - step_n = len(self.stepsWidgets)+1 - label = QLabel(f'Step {step_n}: ') - self.gridLayout.addWidget(label, self.row, 0) - stepWidgets['stepLabel'] = label - - selector = widgets.PreProcessingSelector() - self.gridLayout.addWidget(selector, self.row, 1) - stepWidgets['selector'] = selector - - setParamsButton = widgets.setPushButton() - setParamsButton.setToolTip( - 'Set step parameters' - ) - self.gridLayout.addWidget(setParamsButton, self.row, 2) - setParamsButton.clicked.connect( - partial(self.setParamsStep, selector=selector) - ) - stepWidgets['setParamsButton'] = setParamsButton - - infoButton = widgets.infoPushButton() - self.gridLayout.addWidget(infoButton, self.row, 3) - infoButton.clicked.connect(partial(self.showInfo, selector=selector)) - stepWidgets['infoButton'] = infoButton - - if is_first: - addButton = widgets.addPushButton() - self.gridLayout.addWidget(addButton, self.row, 4) - addButton.clicked.connect(self.addStep) - stepWidgets['addButton'] = addButton - else: - delButton = widgets.delPushButton() - self.gridLayout.addWidget(delButton, self.row, 4) - delButton.clicked.connect(self.removeStep) - delButton.step_n = step_n - stepWidgets['delButton'] = delButton - - self.row += 1 - selector.row = self.row - selector.step_n = step_n - - hline = widgets.QHLine() - self.gridLayout.addWidget(hline, self.row, 0, 1, 6) - stepWidgets['hline'] = hline - self.row += 1 - - self.stepsWidgets[step_n] = stepWidgets - - selector.sigValuesChanged.connect(self.emitValuesChanged) - selector.currentTextChanged.connect( - partial(self.clearInitKwargs, step_n=step_n) - ) - - self.resetStretch() - - def emitValuesChanged(self, functionKwargs, step_n): - self.stepsWidgets[step_n]['step_kwargs'] = functionKwargs - - recipe = self.recipe(warn=False) - if recipe is None: - return - - self.sigValuesChanged.emit(recipe) - - def clearInitKwargs(self, selected_method, step_n=0): - stepWidgets = self.stepsWidgets[step_n] - stepWidgets.pop('step_kwargs', None) - - def resetStretch(self): - for row in range(self.gridLayout.rowCount()): - self.gridLayout.setRowStretch(row, 0) - - self.gridLayout.setRowStretch(self.gridLayout.rowCount(), 1) - self.row = self.gridLayout.rowCount() - 1 - - def showInfo(self, checked=False, selector=None): - if selector is None: - return - - htmlText = selector.htmlInfo() - htmlText = html_utils.paragraph(htmlText) - - method = selector.currentText() - msg = widgets.myMessageBox(wrapText=False) - msg.information(self, f'Info about `{method}`', htmlText) - - def setParamsStep( - self, checked=False, - selector: 'widgets.PreProcessingSelector'=None - ): - step_n = selector.step_n - stepFunctionKwargs = selector.askSetParams( - df_metadata=self.df_metadata, - addApplyButton=self.addApplyButton - ) - if stepFunctionKwargs is None: - return - - self.stepsWidgets[step_n]['step_kwargs'] = stepFunctionKwargs - - def removeStep(self, checked=False, step_n=None): - if step_n is None: - step_n = self.sender().step_n - - stepWidgets = self.stepsWidgets[step_n] - - stepWidgets['stepLabel'].hide() - self.gridLayout.removeWidget(stepWidgets['stepLabel']) - - stepWidgets['selector'].hide() - self.gridLayout.removeWidget(stepWidgets['selector']) - - stepWidgets['infoButton'].hide() - self.gridLayout.removeWidget(stepWidgets['infoButton']) - - # stepWidgets['addButton'].hide() - # self.gridLayout.removeWidget(stepWidgets['addButton']) - - stepWidgets['setParamsButton'].hide() - self.gridLayout.removeWidget(stepWidgets['setParamsButton']) - - stepWidgets['delButton'].hide() - self.gridLayout.removeWidget(stepWidgets['delButton']) - self.row -= 1 - - stepWidgets['hline'].hide() - self.gridLayout.removeWidget(stepWidgets['hline']) - self.row -= 1 - - self.stepsWidgets.pop(step_n) - - stepsWidgetsMapper = {1: self.stepsWidgets[1]} - for i, stepWidgets in enumerate(self.stepsWidgets.values()): - if i == 0: - continue - step_n = i + 1 - label = stepWidgets['stepLabel'] - label.setText(f'Step {step_n}: ') - stepWidgets['delButton'].step_n = step_n - stepWidgets['selector'].step_n = step_n - stepsWidgetsMapper[step_n] = stepWidgets - - self.stepsWidgets = stepsWidgetsMapper - - self.resetStretch() - - def isChecked(self): - return self.groupbox.isChecked() - - def warnStepNotInit(self, method): - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph( - f'The parameters for the preprocessing step {method} ' - 'were not initialized.

' - 'Please, click on the corresponding Set step parameters ' - 'button to initialize this step (cog icon).

' - 'Thank you for your patience!' - ) - msg.warning(self, 'Params not initialized!', txt) - - def recipe(self, warn=True): - recipe = [] - if not self.groupbox.isChecked() and self.groupbox.isCheckable(): - return recipe - - for stepWidgets in self.stepsWidgets.values(): - method = stepWidgets['selector'].currentText() - step_kwargs = stepWidgets.get('step_kwargs') - if step_kwargs is None: - if warn: - self.warnStepNotInit(method) - return - - try: - init_func = config.PREPROCESS_INIT_MAPPER[method]['function'] - init_func(**step_kwargs) - except Exception as err: - pass - - recipe.append({ - 'method': method, 'kwargs': step_kwargs - }) - - return recipe - - def recipeConfigPars(self, model_name): - cp = config.ConfigParser() - if not self.groupbox.isChecked() and self.groupbox.isCheckable(): - return cp - - for s, step in enumerate(self.recipe()): - section = f'{model_name}.preprocess.step{s+1}' - cp[section] = {} - cp[section]['method'] = step['method'] - for option, value in step['kwargs'].items(): - cp[section][option] = str(value) - return cp - -# class QComboBoxChangeColor(QComboBox): -# def __init__(self, forbidden_items=None, parent=None): -# super().__init__(parent) -# self.forbiddenItems = forbidden_items or set() -# self._defaultStyleSheet = self.styleSheet() -# self.currentTextChanged.connect(self._updateColor) - -# def _updateColor(self, text=None): -# if not hasattr(self, '_defaultStyleSheet'): -# self._defaultStyleSheet = self.styleSheet() -# if self.currentText() in self.forbiddenItems: -# self.setStyleSheet( -# self._defaultStyleSheet + """ -# /* Closed state */ -# QComboBox { -# color: red; -# } - -# /* Open state (popup visible) */ -# QComboBox:on { -# color: white; -# } -# """ -# ) -# else: -# self.setStyleSheet(self._defaultStyleSheet) - - - -class CombineChannelsWidget(PreProcessParamsWidget): - sigValuesChangedCombineChannels = Signal() - - def __init__(self, channel_names:Iterable[str], parent=None): - self.channel_names = channel_names - - super().__init__(parent) - - self.parent = parent - qutils.delete_widget(self.loadLastRecipeButton) - qutils.delete_widget(self.saveRecipeButton) - qutils.delete_widget(self.loadRecipeButton) - - def addStep(self, is_first=False): - stepWidgets = {} - - self.row += 1 - if is_first: - self.row += 1 - - step_n = len(self.stepsWidgets)+1 - tooltip = ( - 'Use this text in the formula' - ) - if is_first: - label = QLabel('Formula var') - label.setToolTip( - tooltip - ) - self.gridLayout.addWidget(label, self.row-1, 1) - name_edit = QLineEdit(text=f'img{step_n}') - name_edit.setToolTip( - tooltip - ) - self.gridLayout.addWidget(name_edit, self.row, 1) - stepWidgets['name_edit'] = name_edit - name_edit.textChanged.connect(self.emitValuesChanged) - - tooltip = ( - 'Select a channel or a segmentation mask' - ) - if is_first: - label = QLabel('Channel') - label.setToolTip( - tooltip - ) - self.gridLayout.addWidget(label, self.row-1, 2) - ch_selector = QComboBox() - ch_selector.setToolTip( - tooltip - ) - ch_selector.addItems(self.channel_names) - self.gridLayout.addWidget(ch_selector, self.row, 2) - stepWidgets['selector'] = ch_selector - ch_selector.currentTextChanged.connect(self.setBinarizeCheckableAndNorm) - - # add binarisaion spinbox - tooltip = ( - 'If binarize is selected, the channel will be binarized first, before applying offset and multiplier.\n' - 'If inverse binarize is selected, the channel will be binerized and ' - 'then the logical NOT will be applied.' - ) - if is_first: - label = QLabel('Binarize') - label.setToolTip( - tooltip - ) - self.gridLayout.addWidget(label, self.row-1, 5) - options = ['No', 'binarize', 'inverse binarize'] - self.binarizeCombobox = QComboBox() - self.binarizeCombobox.addItems(options) - self.binarizeCombobox.setCurrentIndex(0) - self.binarizeCombobox.setEnabled(False) - self.binarizeCombobox.setToolTip( - tooltip - ) - self.binarizeCombobox.currentIndexChanged.connect(self.emitValuesChanged) - self.gridLayout.addWidget(self.binarizeCombobox, self.row, 5) - stepWidgets['binarize'] = self.binarizeCombobox - - tooltip = ( - 'Min value of the channel to be normalized to.' - ) - if is_first: - label = QLabel('Min val') - label.setToolTip( - tooltip - ) - self.gridLayout.addWidget(label, self.row-1, 6) - self.minValueSpinbox = QDoubleSpinBox() - self.minValueSpinbox.setRange(-np.inf, np.inf) - self.minValueSpinbox.setSingleStep(0.1) - self.minValueSpinbox.setValue(0) - self.minValueSpinbox.setToolTip( - tooltip - ) - - self.minValueSpinbox.valueChanged.connect(self.emitValuesChanged) - self.gridLayout.addWidget(self.minValueSpinbox, self.row, 6) - stepWidgets['minValueSpinbox'] = self.minValueSpinbox - - tooltip = ( - 'Max value of the channel to be normalized to.' - ) - if is_first: - label = QLabel('Max val') - label.setToolTip( - tooltip - ) - self.gridLayout.addWidget(label, self.row-1, 7) - self.maxValueSpinbox = QDoubleSpinBox() - self.maxValueSpinbox.setRange(-np.inf, np.inf) - self.maxValueSpinbox.setSingleStep(0.1) - self.maxValueSpinbox.setValue(1) - self.maxValueSpinbox.setToolTip( - tooltip - ) - - self.maxValueSpinbox.valueChanged.connect(self.emitValuesChanged) - self.gridLayout.addWidget(self.maxValueSpinbox, self.row, 7) - stepWidgets['maxValueSpinbox'] = self.maxValueSpinbox - - if is_first: - addButton = widgets.addPushButton() - self.gridLayout.addWidget(addButton, self.row, 8) - addButton.clicked.connect(self.addStep) - stepWidgets['addButton'] = addButton - - else: - delButton = widgets.delPushButton() - self.gridLayout.addWidget(delButton, self.row, 8) - delButton.clicked.connect(self.removeStep) - delButton.step_n = step_n - stepWidgets['delButton'] = delButton - - self.row += 1 - ch_selector.row = self.row - ch_selector.step_n = step_n - - hline = widgets.QHLine() - self.gridLayout.addWidget(hline, self.row, 0, 1, 8) - stepWidgets['hline'] = hline - self.row += 1 - - self.stepsWidgets[step_n] = stepWidgets - - self.resetStretch() - self.sigValuesChangedCombineChannels.emit() - self.setBinarizeCheckableAndNorm() - - def emitValuesChanged(self, *args): - self.sigValuesChangedCombineChannels.emit() - - def setBinarizeCheckableAndNorm(self): - for step_n, stepWidgets in self.stepsWidgets.items(): - binarizeSelector = stepWidgets['binarize'] - channel = stepWidgets['selector'].currentText() - if "segm" in channel: - binarizeSelector.setEnabled(True) - # set min and max to 0 and 1 and disable - stepWidgets['minValueSpinbox'].setValue(0) - stepWidgets['maxValueSpinbox'].setValue(1) - stepWidgets['minValueSpinbox'].setEnabled(False) - stepWidgets['maxValueSpinbox'].setEnabled(False) - else: - binarizeSelector.setEnabled(False) - binarizeSelector.setCurrentIndex(0) - # set min and max to 0 and 1 and enable - stepWidgets['minValueSpinbox'].setEnabled(True) - stepWidgets['maxValueSpinbox'].setEnabled(True) - - self.emitValuesChanged() - - def removeStep(self, checked=False, step_n=None): - if step_n is None: - step_n = self.sender().step_n - - stepWidgets = self.stepsWidgets[step_n] - - stepWidgets['name_edit'].hide() - self.gridLayout.removeWidget(stepWidgets['name_edit']) - - stepWidgets['selector'].hide() - self.gridLayout.removeWidget(stepWidgets['selector']) - - stepWidgets['binarize'].hide() - self.gridLayout.removeWidget(stepWidgets['binarize']) - - stepWidgets['minValueSpinbox'].hide() - self.gridLayout.removeWidget(stepWidgets['minValueSpinbox']) - - stepWidgets['maxValueSpinbox'].hide() - self.gridLayout.removeWidget(stepWidgets['maxValueSpinbox']) - - stepWidgets['delButton'].hide() - self.gridLayout.removeWidget(stepWidgets['delButton']) - - self.row -= 1 - - stepWidgets['hline'].hide() - self.gridLayout.removeWidget(stepWidgets['hline']) - self.row -= 1 - - self.stepsWidgets.pop(step_n) - - stepsWidgetsMapper = {1: self.stepsWidgets[1]} - for i, stepWidgets in enumerate(self.stepsWidgets.values()): - if i == 0: - continue - step_n = i + 1 - stepWidgets['delButton'].step_n = step_n - stepWidgets['selector'].step_n = step_n - stepsWidgetsMapper[step_n] = stepWidgets - - self.stepsWidgets = stepsWidgetsMapper - - self.resetStretch() - self.sigValuesChangedCombineChannels.emit() - - def steps(self): - steps = {} - if not self.groupbox.isChecked() and self.groupbox.isCheckable(): - return steps - - for step_number, stepWidgets in self.stepsWidgets.items(): - name = stepWidgets['name_edit'].text() - channel = stepWidgets['selector'].currentText() - binarize = stepWidgets['binarize'].currentText() - min_val = stepWidgets['minValueSpinbox'].value() - max_val = stepWidgets['maxValueSpinbox'].value() - steps[step_number] = { - 'name': name, - 'channel': channel, - 'binarize': binarize, - 'min_val': min_val, - 'max_val': max_val, - } - - steps = dict(sorted(steps.items())) - return steps - -class FormulaEditWidget(QWidget): - sigFormulaChanged = Signal(str, bool) # formula_str, is_valid - - def __init__(self, variable_names=None, parent=None): - super().__init__(parent) - self._variable_names = variable_names or [] - - layout = QVBoxLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - layout.setSpacing(4) - - self._edit = QLineEdit() - self._edit.setPlaceholderText('e.g. img1 + img2 * 0.5') - layout.addWidget(self._edit) - - self._status_label = QLabel() - self._status_label.setWordWrap(True) - self._status_label.setStyleSheet('font-size: 11px;') - layout.addWidget(self._status_label) - - self._edit.textChanged.connect(self._onTextChanged) - self._clearStatus() - - self.parent = parent - - def setVariableNames(self, variable_names): - """Allows setting the variables. - - Parameters - ---------- - variable_names : list - list of variable names (strings) - """ - - self._variable_names = variable_names - self._onTextChanged(self._edit.text()) - - def text(self): - """Returns the current formula text.""" - return self._edit.text() - - def setText(self, text): - """Sets the formula text.""" - self._edit.setText(text) - - def _clearStatus(self): - self._status_label.setText('') - self._status_label.setStyleSheet('font-size: 11px;') - - def _onTextChanged(self, text): - if not text.strip(): - self._clearStatus() - - success, reconstructed_str = self.checkValidity(self._variable_names) - - if success: - self._status_label.setText(f'→ {reconstructed_str}') - self._status_label.setStyleSheet( - 'font-size: 11px; color: green;' - ) - else: - self._status_label.setText(reconstructed_str) - self._status_label.setStyleSheet( - 'font-size: 11px; color: red;' - ) - - self.sigFormulaChanged.emit(text, success) - - def checkValidity(self, variable_names=None): - if variable_names is None: - variable_names = self._variable_names - formula_str = self._edit.text() - arrays = {name: 1 for name in variable_names} - success = False - reconstructed_str = 'ERROR' - forb_ch = self.parent.forbiddenChannels - if forb_ch: - stepsWidgets = self.parent.combineChannelsWidget.stepsWidgets - channels = {stepsWidget['selector'].currentText() for stepsWidget in stepsWidgets.values()} - if forb_ch.intersection(channels): - reconstructed_str = ( - 'Channels that are forbidden are not allowed to be used!:\n' - f'{forb_ch}' - ) - return False, reconstructed_str - if formula_str == '': - reconstructed_str = 'First channel is returned/applied' - return True, reconstructed_str - try: - symbols = {name: sp.Symbol(name) for name in arrays} - expr = sp.sympify(formula_str, locals=symbols) - missing = {str(s) for s in expr.free_symbols} - arrays.keys() - if missing: - reconstructed_str = f'Missing variables: {missing}' - return False, reconstructed_str - - if formula_str == '': - reconstructed_str = '' - return True, reconstructed_str - - # filter out expressions that have no variables - if not any(s.is_Symbol for s in expr.free_symbols): - reconstructed_str = 'No variables used' - return False, reconstructed_str - - reconstructed_str = str(expr) - success = True - except Exception as e: - if 'syntax' in str(e): - reconstructed_str = f'Syntax error' - else: - reconstructed_str = str(e) - success = False - return success, reconstructed_str - -class InitFijiMacroDialog(QBaseDialog): - def __init__(self, parent=None): - self.cancel = True - - super().__init__(parent=parent) - - mainLayout = QVBoxLayout() - - infoLabel = QLabel(html_utils.paragraph( - """ - Place all the raw microscopy files in a folder without any other - file
- and provide the following information: - """ - )) - mainLayout.addWidget(infoLabel) - - gridLayout = QGridLayout() - - row = 0 - label = QLabel('Files internal structure: ') - gridLayout.addWidget(label, row, 0) - self.filesStructureCombobox = QComboBox() - self.filesStructureCombobox.addItems([ - 'Positions (aka "series") embedded in the file', - 'Positions (aka "series") separated, one for each file', - 'Positions (aka "series") and channels separated, one for each file' - ]) - gridLayout.addWidget(self.filesStructureCombobox, row, 1) - self.filesStructureCombobox.currentTextChanged.connect( - self.fileStructureChanged - ) - infoButton = widgets.infoPushButton() - gridLayout.addWidget(infoButton, row, 2) - infoButton.clicked.connect(self.showInfoFileStructure) - - row += 1 - label = QLabel('Folder with raw microscopy files: ') - gridLayout.addWidget(label, row, 0) - self.folderPathLineEdit = widgets.ElidingLineEdit() - gridLayout.addWidget(self.folderPathLineEdit, row, 1) - browseButton = widgets.browseFileButton(openFolder=True) - gridLayout.addWidget(browseButton, row, 2) - browseButton.sigPathSelected.connect( - partial(self.updateFolderPath, lineEdit=self.folderPathLineEdit) - ) - self.folderPathLineEdit.textChanged.connect(self.srcFolderPathChanged) - - row += 1 - label = QLabel('Destination folder: ') - gridLayout.addWidget(label, row, 0) - self.dstfolderPathLineEdit = widgets.ElidingLineEdit() - gridLayout.addWidget(self.dstfolderPathLineEdit, row, 1) - browseButton = widgets.browseFileButton(openFolder=True) - gridLayout.addWidget(browseButton, row, 2) - browseButton.sigPathSelected.connect(self.dstfolderPathLineEdit.setText) - - row += 1 - label = QLabel('Channel(s) name: ') - gridLayout.addWidget(label, row, 0) - self.channelNamesLineEdit = widgets.alphaNumericLineEdit( - additionalChars=' ,' - ) - gridLayout.addWidget(self.channelNamesLineEdit, row, 1) - checkButton = widgets.TestPushButton('Check') - gridLayout.addWidget(checkButton, row, 3) - checkButton.clicked.connect(self.checkChannelNames) - checkButton.setDisabled(True) - self.checkButton = checkButton - infoButton = widgets.infoPushButton() - gridLayout.addWidget(infoButton, row, 2) - infoButton.clicked.connect(self.showInfoChannelName) - - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - gridLayout.setColumnStretch(0, 0) - gridLayout.setColumnStretch(1, 1) - gridLayout.setColumnStretch(2, 0) - gridLayout.setColumnStretch(3, 0) - - mainLayout.addLayout(gridLayout) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - def fileStructureChanged(self, text): - self.checkButton.setDisabled(not 'channels separated' in text) - - def checkChannelNames(self, checked=False): - proceed = self.validate() - if not proceed: - return - - src_folderpath = self.folderPath() - channel_names = self.channelNames() - extension = os.listdir(src_folderpath)[0].split('.')[-1] - basenames = io.move_separate_channels_tiffs_to_pos_folders( - src_folderpath, channel_names, get_only_basenames=True, - extension=extension - ) - pos_folders_texts = [] - for p, basename in enumerate(basenames): - pos_folders_texts.append(f'Position_{p+1}: {basename}') - - pos_folders_html_list = html_utils.to_list( - pos_folders_texts, ordered=True - ) - text = html_utils.paragraph( - 'The following Position folders will be created based on the provided channel names:
' - f'{pos_folders_html_list}' - ) - msg = widgets.myMessageBox(wrapText=False) - msg.information(self, 'Position folders', text) - - def srcFolderPathChanged(self, text): - if self.dstfolderPathLineEdit.text(): - return - - folderPath = self.folderPathLineEdit.text() - self.dstfolderPathLineEdit.setText(folderPath) - - def showInfoFileStructure(self): - txt = html_utils.paragraph(""" - Select whether the microscopy files contains multiple "series".

- This typically depends on how you acquired the images at the - microscope, i.e., you generated multiple microscopy files - (e.g., snapshots), or you setup automatic acquisition of multiple - positions. - """) - msg = widgets.myMessageBox(wrapText=False) - msg.information(self, 'Files structure info', txt) - - def showInfoChannelName(self): - txt = html_utils.paragraph(""" - Enter the channels name. Separate multiple channels with a comma.

- The channel names will be used to name the individual TIFF files - (one for each channel).

- If multiple channels are embedded in the microscopy file, make sure that you write the channels in the right order.
- If you are unsure, open the file in Fiji first - and check the order of channels.

- If the channels are already separated, make sure to write the - full channel name as it appears in the file, including capitalization and spaces.
- For example, if the files are named "pos1_ch1.tif", "pos1_ch2.tif", etc., the channels names should be "ch1, ch2".

- After providing the channel names, you can check that they are correct by clicking on the "Check" button next to the channel names field.
- The number of Positions that will be created will be displayed alongside the basename. - """) - msg = widgets.myMessageBox(wrapText=False) - msg.information(self, 'Files structure info', txt) - - def updateFolderPath(self, path, lineEdit=''): - for file in os.listdir(path): - if not is_alphanumeric_filename(file): - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph( - f""" - The filename {file} contains invalid - characters.

- Valid characters are letters, numbers, spaces, underscores - and dashes.

- Please rename the file and try again.

- Thank you for your patience! - """ - ) - msg.critical( - self, 'Invalid filename', txt, path_to_browse=path - ) - lineEdit.setText('') - return - - lineEdit.setText(path) - - def warnPathEmpty(self, path_name): - txt = html_utils.paragraph(f""" - {path_name} cannot be empty. - """) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'Empty folder path', txt) - - def warnSelectedPathDoesNotExist(self, path): - txt = html_utils.paragraph(""" - The selected path does not exist.

- Selected path: - """) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'Folder path does not exist', txt, commands=(path,)) - - def warnSelectedPathNotAFolder(self, path): - txt = html_utils.paragraph(""" - The selected path is not a folder.

- Selected path: - """) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'Selected path not a folder', txt, commands=(path,)) - - def warnMultipleExtensionsPresent(self, path, extensions): - txt = html_utils.paragraph(f""" - The selected path contains files with different extensions. -

- Extensions present: {extensions}

- Please, make sure that all the files in the folder have the same - extension before proceeding.

- Selected path: - """) - msg = widgets.myMessageBox(wrapText=False) - msg.warning( - self, 'Multiple file extensions detected', txt, commands=(path,) - ) - - def warnChannelNamesEmpty(self): - txt = html_utils.paragraph(""" - Channel(s) name cannot be empty. - """) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'Empty channel name', txt) - - def validate(self): - path = self.folderPath() - dst_path = self.dstfolderPathLineEdit.text() - paths = { - 'Source folder': path, - 'Destination folder': dst_path, - } - for _path_name, _path in paths.items(): - if not _path: - self.warnPathEmpty(_path_name) - return False - - if not os.path.exists(_path): - self.warnSelectedPathDoesNotExist(_path) - return False - - if not os.path.isdir(_path): - self.warnSelectedPathNotAFolder(_path) - return False - - files = myutils.listdir(path) - extensions = set([os.path.splitext(file)[1] for file in files]) - if len(extensions) > 1: - self.warnMultipleExtensionsPresent(path, extensions) - return False - - if not self.channelNamesLineEdit.text(): - self.warnChannelNamesEmpty() - return False - - return True - - def folderPath(self): - return self.folderPathLineEdit.text() - - def channelNames(self): - channel_names = self.channelNamesLineEdit.text().split(',') - channel_names = [ch.strip() for ch in channel_names] - return channel_names - - def ok_cb(self): - proceed = self.validate() - if not proceed: - return - - self.selectedFolderPath = self.folderPath() - self.filesStructure = self.filesStructureCombobox.currentText() - is_multiple_files = self.filesStructure.find('separated') != -1 - is_separate_channels = 'channels separated' in self.filesStructure - dst_folderpath = self.dstfolderPathLineEdit.text() - self.init_macro_args = ( - self.folderPath(), - is_multiple_files, - is_separate_channels, - dst_folderpath, - self.channelNames(), - ) - self.cancel = False - self.close() - -class ImageJRoisToSegmManager(QBaseDialog): - def __init__( - self, rois_filepath, TZYX_shape, - addUseSamePropsForNextPosButton=False, parent=None - ): - import roifile - - self.cancel = True - super().__init__(parent) - - self.setWindowTitle('ROI Manager') - - mainLayout = QVBoxLayout() - - rois = roifile.roiread(rois_filepath) - self.rois = {roi.name: roi for roi in rois} - - roisNamesTreeWidget = widgets.TreeWidget() - roisNamesTreeWidget.setHeaderLabels(['ROI name', 'Cell_ID']) - roisNamesTreeWidget.header().setSectionResizeMode( - QHeaderView.ResizeToContents - ) - # roisNamesTreeWidget.header().setStretchLastSection(False) - for r, roi in enumerate(rois): - item = widgets.TreeWidgetItem() - item.setText(0, roi.name) - item.setText(1, str(r+1)) - roisNamesTreeWidget.addTopLevelItem(item) - roisNamesTreeWidget.setSelectionMode( - QAbstractItemView.SelectionMode.ExtendedSelection - ) - roisNamesTreeWidget.selectAll() - mainLayout.addWidget(QLabel('Select ROIs to convert')) - mainLayout.addWidget(roisNamesTreeWidget) - self.roisNamesTreeWidget = roisNamesTreeWidget - mainLayout.addSpacing(10) - mainLayout.addWidget(widgets.QHLine()) - mainLayout.addSpacing(5) - - gridLayout = None - self.lowZspinbox = None - - SizeT, SizeZ, SizeY, SizeX = TZYX_shape - if SizeZ > 1: - gridLayout = QGridLayout() - self.lowZspinbox = widgets.SpinBox() - self.lowZspinbox.setMinimum(0) - self.lowZspinbox.setMaximum(SizeZ-1) - - self.highZspinbox = widgets.SpinBox() - self.highZspinbox.setMinimum(0) - self.highZspinbox.setMaximum(SizeZ-1) - self.highZspinbox.setValue(SizeZ-1) - - gridLayout.addWidget(QLabel('Repeat 2D ROIs over z-range: '), 1, 0) - - gridLayout.addWidget(QLabel('Start z-slice'), 0, 1) - gridLayout.addWidget(self.lowZspinbox, 1, 1) - - gridLayout.addWidget(QLabel('Stop z-slice'), 0, 2) - gridLayout.addWidget(self.highZspinbox, 1, 2) - - if gridLayout is not None: - mainLayout.addLayout(gridLayout) - mainLayout.addSpacing(5) - mainLayout.addWidget(widgets.QHLine()) - mainLayout.addSpacing(10) - - self.rescaleRoisGroupbox = widgets.RescaleImageJroisGroupbox(TZYX_shape) - self.rescaleRoisGroupbox.setChecked(False) - mainLayout.addWidget(self.rescaleRoisGroupbox) - - buttonsLayout = widgets.CancelOkButtonsLayout() - - self.useSamePropsForNextPos = False - if addUseSamePropsForNextPosButton: - useSamePropsForNextPosButton = widgets.reloadPushButton( - 'Keep the same preferences for all next Positions' - ) - buttonsLayout.insertWidget(3, useSamePropsForNextPosButton) - useSamePropsForNextPosButton.clicked.connect( - self.useSamePropsForNextPosClicked - ) - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - def useSamePropsForNextPosClicked(self): - self.useSamePropsForNextPos = True - self.ok_cb() - - def warnRoiSelectionEmpty(self): - txt = html_utils.paragraph(f""" - You did not select any ROI.

- ROIs selection cannot be empty. Thank you for your patience! - """) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'ROIs selection empty', txt) - - def ok_cb(self): - selectedRois = self.roisNamesTreeWidget.selectedItems() - if not selectedRois: - self.useSamePropsForNextPos = False - self.warnRoiSelectionEmpty() - return - - self.IDsToRoisMapper = {} - for item in selectedRois: - roiName = item.text(0) - ID = int(item.text(1)) - self.IDsToRoisMapper[ID] = self.rois[roiName] - - numRois = self.roisNamesTreeWidget.topLevelItemCount() - self.areAllRoisSelected = len(self.IDsToRoisMapper) == numRois - - self.rescaleSizes = self.rescaleRoisGroupbox.inputOutputSizes() - self.repeatRoisZslicesRange = None - if self.lowZspinbox is not None: - self.repeatRoisZslicesRange = ( - self.lowZspinbox.value(), self.highZspinbox.value()+1 - ) - - self.cancel = False - self.close() - -class ResizeUtilProps(QBaseDialog): - def __init__(self, input_path='', parent=None): - self.cancel = True - super().__init__(parent) - - self.setWindowTitle('Resize Data Properties') - - mainLayout = QVBoxLayout() - - paramsLayout = QGridLayout() - - self._input_path = input_path - - row = 0 - paramsLayout.addWidget(QLabel('Overwrite raw data: '), row, 0) - self.overwriteToggle = widgets.Toggle() - self.overwriteToggle.setChecked(True) - paramsLayout.addWidget( - self.overwriteToggle, row, 1, 1, 2, alignment=Qt.AlignCenter - ) - - row += 1 - paramsLayout.addWidget( - QLabel('Folder path for resized images: '), row, 0 - ) - self.folderPathOutControl = widgets.filePathControl( - browseFolder=True, - fileManagerTitle='Select folder where to save resized data', - elide=True, - startFolder=myutils.getMostRecentPath() - ) - self.folderPathOutControl.setDisabled(True) - paramsLayout.addWidget(self.folderPathOutControl, row, 1, 1, 2) - - row += 1 - paramsLayout.addWidget(QLabel('Text to append to files: '), row, 0) - self.textToAppendLineEdit = widgets.alphaNumericLineEdit() - self.textToAppendLineEdit.setAlignment(Qt.AlignCenter) - self.textToAppendLineEdit.setDisabled(True) - paramsLayout.addWidget(self.textToAppendLineEdit, row, 1, 1, 2) - - row += 1 - paramsLayout.addWidget(QLabel('Resize mode: '), row, 0) - self.downScaleRadioButton = QRadioButton('Downscale') - self.upScaleRadioButton = QRadioButton('Upscale') - self.downScaleRadioButton.setChecked(True) - paramsLayout.addWidget( - self.downScaleRadioButton, row, 1, alignment=Qt.AlignCenter - ) - paramsLayout.addWidget( - self.upScaleRadioButton, row, 2, alignment=Qt.AlignCenter - ) - - row += 1 - paramsLayout.addWidget(QLabel('Resize factor: '), row, 0) - self.factorSpinbox = widgets.FloatLineEdit(allowNegative=False) - self.factorSpinbox.setMinimum(1.0) - self.factorSpinbox.setValue(2.0) - paramsLayout.addWidget(self.factorSpinbox, row, 1, 1, 2) - - paramsLayout.setColumnStretch(0, 0) - paramsLayout.setVerticalSpacing(10) - - self.overwriteToggle.toggled.connect(self.overwriteToggled) - - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addLayout(paramsLayout) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - mainLayout.addStretch(1) - - # self.textToAppendLineEdit.setText(self._getDefaultTextToAppend()) - - self.setLayout(mainLayout) - - def _getDefaultTextToAppend(self): - rescale_mode = 'up' if self.upScaleRadioButton.isChecked() else 'down' - factor = self.factorSpinbox.value() - text = f'{rescale_mode}scaled_factor_{factor}' - return text - - def overwriteToggled(self, checked): - self.folderPathOutControl.setDisabled(checked) - self.textToAppendLineEdit.setDisabled(checked) - if checked: - text = '' - else: - text = self._getDefaultTextToAppend() - self.textToAppendLineEdit.setText(text) - - def warnFolderPathEmpty(self): - txt = html_utils.paragraph(""" - To prevent overwriting raw data the Folder path for - resized images cannot be empty. - """) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'Empty folder path', txt) - - def warnTextToAppendEmpty(self): - txt = html_utils.paragraph(""" - To prevent overwriting raw data the text to append - cannot be empty. - """) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'Empty text to append', txt) - - def ok_cb(self): - self.expFolderpathOut = self.folderPathOutControl.path() - self.textToAppend = self.textToAppendLineEdit.text() - isAccidentalOverwrite = ( - not self.overwriteToggle.isChecked() - and self.expFolderpathOut == self._input_path - and not self.textToAppend - ) - if isAccidentalOverwrite: - self.warnTextToAppendEmpty() - return - - if self.textToAppend and not self.textToAppend.startswith('_'): - self.textToAppend = f'_{self.textToAppend}' - - if self.overwriteToggle.isChecked(): - self.expFolderpathOut = None - - factor = self.factorSpinbox.value() - self.resizeFactor = ( - factor if self.upScaleRadioButton.isChecked() else 1/factor - ) - - self.cancel = False - self.close() - -class LogoDialog(QDialog): - def __init__(self, logo_path, icon_path, parent=None): - super().__init__(parent) - - layout = QVBoxLayout() - - self.setWindowFlags(Qt.FramelessWindowHint) - # self.setWindowFlags(Qt.WindowStaysOnTopHint | Qt.FramelessWindowHint) - # self.setAttribute(Qt.WA_TranslucentBackground) - # self.setWindowIcon(QIcon(icon_path)) - - labelLogo = QLabel() - pixmapLogo = QPixmap(logo_path) - labelLogo.setPixmap(pixmapLogo) - - layout.addWidget(labelLogo) - - self.setLayout(layout) - -class SetCustomLevelsLut(QBaseDialog): - sigLevelsChanged = Signal(object) - - def __init__( - self, - init_min_value=None, - init_max_value=None, - minimum_min_value=0, - maximum_max_value=None, - parent=None - ): - super().__init__(parent=parent) - - self.cancel = True - - self.setWindowTitle('Custom LUT levels') - - layout = QVBoxLayout() - - self.minLevelSlider = widgets.sliderWithSpinBox( - title='Minimum', - title_loc='top', - ) - self.minLevelSlider.setMinimum(minimum_min_value) - - if init_min_value is not None: - self.minLevelSlider.setValue(init_min_value) - - layout.addWidget(self.minLevelSlider) - - self.maxLevelSlider = widgets.sliderWithSpinBox( - title='Maximum', - title_loc='top', - ) - self.maxLevelSlider.setMinimum(minimum_min_value) - if init_max_value is not None: - self.maxLevelSlider.setValue(init_max_value) - - if maximum_max_value is not None: - self.maxLevelSlider.setMaximum(maximum_max_value) - self.minLevelSlider.setMaximum(maximum_max_value) - - layout.addWidget(self.maxLevelSlider) - - self.minLevelSlider.sigValueChange.connect(self.emitLevelsChanged) - self.maxLevelSlider.sigValueChange.connect(self.emitLevelsChanged) - - buttonsLayout = widgets.CancelOkButtonsLayout() - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - layout.addSpacing(20) - layout.addLayout(buttonsLayout) - - self.setLayout(layout) - - def sizeHint(self): - heightHint = super().sizeHint().height() - widthHint = super().sizeHint().width()*2 - return QSize(widthHint, heightHint) - - - def levels(self): - levels = (self.minLevelSlider.value(), self.maxLevelSlider.value()) - return levels - - def emitLevelsChanged(self, value): - self.sigLevelsChanged.emit(self.levels()) - - def ok_cb(self): - self.cancel = False - self.selectedLevels = self.levels() - self.close() - -class FucciPreprocessDialog(FunctionParamsDialog): - def __init__( - self, channel_names, - df_metadata=None, - parent=None, - ): - - from cellacdc.preprocess import fucci_filter - params_argspecs = myutils.get_function_argspec(fucci_filter) - - super().__init__( - params_argspecs, - function_name='FUCCI pre-processing', - df_metadata=df_metadata, - parent=parent, - ) - - channelNamesLayout = QGridLayout() - - row = 0 - label = QLabel('First channel name: ') - channelNamesLayout.addWidget(label, row, 0, alignment=Qt.AlignLeft) - self.firstChNameWidget = QComboBox() - self.firstChNameWidget.addItems(channel_names) - channelNamesLayout.addWidget(self.firstChNameWidget, row, 1) - - row += 1 - label = QLabel('Second channel name: ') - channelNamesLayout.addWidget(label, row, 0, alignment=Qt.AlignLeft) - self.secondChNameWidget = QComboBox() - self.secondChNameWidget.addItems(channel_names) - self.secondChNameWidget.setCurrentText(list(channel_names)[1]) - channelNamesLayout.addWidget(self.secondChNameWidget, row, 1) - - channelNamesLayout.setColumnStretch(0, 0) - channelNamesLayout.setColumnStretch(1, 1) - - self.mainLayout.insertLayout(0, channelNamesLayout) - self.mainLayout.insertWidget(1, widgets.QHLine()) - - def ok_cb(self): - self.firstChannelName = self.firstChNameWidget.currentText() - self.secondChannelName = self.secondChNameWidget.currentText() - super().ok_cb() - -class ViewCcaTableWindow(pdDataFrameWidget): - sigUpdateCcaTable = Signal(object) - - def __init__(self, df, parent=None): - super().__init__(df, parent=parent) - - updateTableButton = widgets.reloadPushButton( - 'Update table with visible IDs...' - ) - buttonsLayout = QHBoxLayout() - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(updateTableButton) - - self._layout.insertLayout(0, buttonsLayout) - - updateTableButton.clicked.connect(self.emitUpdateCcaTable) - - def emitUpdateCcaTable(self): - self.sigUpdateCcaTable.emit(self) - - -class ObjectCountDialog(QBaseDialog): - sigShowEvent = Signal() - sigUpdateCounts = Signal() - - def __init__( - self, - categoryCountMapper: dict, - parent=None, - data: list['load.loadData'] | None=None - ): - super().__init__(parent=parent) - self.setWindowTitle('Object count') - - self.cancel = False - mainLayout = QVBoxLayout() - - cancelOkLayout = widgets.CancelOkButtonsLayout() - cancelOkLayout.okButton.clicked.connect(self.ok_cb) - cancelOkLayout.cancelButton.clicked.connect(self.close) - - self.data = data - if data is not None: - saveCountsButton = widgets.savePushButton( - 'Export counts to CSV table' - ) - saveCountsButton.clicked.connect(self.saveCounts) - cancelOkLayout.insertWidget(3, saveCountsButton) - - updateCountsButton = widgets.reloadPushButton('Update counts') - cancelOkLayout.insertWidget(3, updateCountsButton) - updateCountsButton.clicked.connect(self.emitUpdateCounts) - - mainLayout.addWidget( - QLabel(html_utils.paragraph('Object count
', font_size='18px')), - alignment=Qt.AlignLeft - ) - self.showHideButtons = [] - self.categoryLabelMapper = {} - for category, count in categoryCountMapper.items(): - categoryLayout = QHBoxLayout() - categoryLayout.addSpacing(10) - catText = html_utils.paragraph( - f'
{category}
', font_size='13px' - ) - catLabel = QLabel(catText) - categoryLayout.addWidget(catLabel) - categoryLayout.addStretch(1) - - countText = html_utils.paragraph( - f'
{count}
', font_size='13px' - ) - countLabel = QLabel(countText) - categoryLayout.addWidget(countLabel) - - self.categoryLabelMapper[category] = countLabel - - showHideButton = widgets.showDetailsButton(txt='') - showHideButton.setChecked(True) - showHideButton.sigToggled.connect( - partial(self.showHideCount, labels=(catLabel, countLabel)) - ) - showHideButton.setToolTip(f'Show/hide "{category}" count') - categoryLayout.addSpacing(10) - categoryLayout.addWidget(showHideButton) - showHideButton.category = category - - self.showHideButtons.append(showHideButton) - - categoryLayout.setStretch(0, 0) - categoryLayout.setStretch(1, 0) - categoryLayout.setStretch(3, 0) - - mainLayout.addLayout(categoryLayout) - mainLayout.addWidget(widgets.QHLine()) - - mainLayout.addSpacing(10) - - infoLayout = QHBoxLayout() - self.livePreviewCheckbox = QCheckBox('Live preview') - self.livePreviewCheckbox.setChecked(True) - infoLayout.addWidget(self.livePreviewCheckbox) - infoLayout.addStretch(1) - self.warnLabel = QLabel('') - infoLayout.addWidget(self.warnLabel) - self.livePreviewCheckbox.toggled.connect(self.updateWarnLabel) - mainLayout.addLayout(infoLayout) - - mainLayout.addSpacing(30) - mainLayout.addStretch(1) - mainLayout.addLayout(cancelOkLayout) - - self.setLayout(mainLayout) - - def saveCounts(self, checked=False): - categories = self.activeCategories() - for posData in self.data: - countMapper = posData.countObjectsInSegm(categories) - countMapper.pop('In current frame', None) - df_count_endname = posData.saveObjCounts(countMapper) - - txt = html_utils.paragraph(f""" - Done!

- Objects count table saved in every loaded Position folder
- as a CSV file ending with {df_count_endname} - """) - msg = widgets.myMessageBox(wrapText=False) - msg.information(self, 'Objects count saved', txt) - - def updateWarnLabel(self, checked): - if not checked: - self.warnLabel.setText( - html_utils.paragraph( - 'WARNING: without live preview, counts are not updated', - font_color='red' - ) - ) - else: - self.warnLabel.setText('') - - def emitUpdateCounts(self): - self.sigUpdateCounts.emit() - - def activeCategories(self) -> List[str]: - activeCategories = [] - for showHideButton in self.showHideButtons: - if not showHideButton.isChecked(): - continue - activeCategories.append(showHideButton.category) - - return activeCategories - - def showHideCount(self, checked, labels): - for label in labels: - label.setVisible(checked) - - QTimer.singleShot(100, self.resizeToHeightHint) - - def updateCounts(self, categoryCountMapper): - for category, count in categoryCountMapper.items(): - countLabel = self.categoryLabelMapper[category] - countText = html_utils.paragraph( - f'
{count}
', font_size='13px' - ) - countLabel.setText(countText) - - - def resizeToHeightHint(self): - heightHint = self.sizeHint().height() - self.resize(self.width(), heightHint) - - def showEvent(self, event): - widthHint = self.sizeHint().width() - self.resize(int(widthHint*1.5), self.height()) - self.sigShowEvent.emit() - - def ok_cb(self): - self.cancel = False - self.close() - -class PreProcessRecipeDialog(QBaseDialog): - sigApplyImage = Signal(object) - sigApplyZstack = Signal(object) - sigApplyAllFrames = Signal(object) - sigApplyAllPos = Signal(object) - sigPreviewToggled = Signal(bool) - sigValuesChanged = Signal(list) - sigSavePreprocData = Signal(object) - sigClose = Signal(object) - - def __init__( - self, - isTimelapse=False, - isZstack=False, - isMultiPos=False, - df_metadata=None, - addApplyButton=False, - parent=None, - hideOnClosing=False, - ): - super().__init__(parent=parent) - - self.setWindowTitle('Pre-processing recipe') - - self.cancel = True - self.hideOnClosing = hideOnClosing - - mainLayout = QVBoxLayout() - - keepInputDataTypeLayout = QHBoxLayout() - self.keepInputDataTypeToggle = widgets.Toggle() - self.keepInputDataTypeToggle.setChecked(True) - self.keepInputDataTypeToggle.toggled.connect(self.emitValuesChanged) - - keepInputDataTypeLayout.addStretch(1) - keepInputDataTypeLayout.addWidget(QLabel('Keep input data type: ')) - keepInputDataTypeLayout.addWidget(self.keepInputDataTypeToggle) - keepInputDataTypeInfoButton = widgets.infoPushButton() - keepInputDataTypeLayout.addWidget(keepInputDataTypeInfoButton) - keepInputDataTypeInfoButton.clicked.connect( - self.showInfoKeepInputDataType - ) - self.keepInputDataTypeLayout = keepInputDataTypeLayout - - self.preProcessParamsWidget = PreProcessParamsWidget( - df_metadata=df_metadata, - addApplyButton=addApplyButton, - parent=self - ) - self.preProcessParamsWidget.groupbox.setCheckable(False) - - buttonsLayout = QGridLayout() # self.preProcessParamsWidget.buttonsLayout - self.buttonsLayout = buttonsLayout - self.previewCheckbox = QCheckBox('Preview') - buttonsLayout.addWidget(self.previewCheckbox, 0, 0) - - # Relocate buttons of PreProcessParamsWidget to this dialog - pPPWBL = self.preProcessParamsWidget.buttonsLayout - loadRecipeButtIdx = pPPWBL.indexOf( - self.preProcessParamsWidget.loadRecipeButton - ) - self.loadRecipeButton = pPPWBL.takeAt(loadRecipeButtIdx).widget() - buttonsLayout.addWidget(self.loadRecipeButton, 0, 1) - - saveRecipeButtIdx = pPPWBL.indexOf( - self.preProcessParamsWidget.saveRecipeButton - ) - self.saveRecipeButton = pPPWBL.takeAt(saveRecipeButtIdx).widget() - buttonsLayout.addWidget(self.saveRecipeButton, 1, 1) - - loadLastRecipeButtIdx = pPPWBL.indexOf( - self.preProcessParamsWidget.loadLastRecipeButton - ) - self.loadLastRecipeButton = pPPWBL.takeAt(loadLastRecipeButtIdx).widget() - buttonsLayout.addWidget(self.loadLastRecipeButton, 1, 0) - - self.loadLastRecipeButton.hide() - - # self.cancelButton = widgets.cancelPushButton('Cancel') - # buttonsLayout.insertWidget(2, self.cancelButton) - # buttonsLayout.insertSpacing(3, 20) - - self.allButtons = [ - self.previewCheckbox, - self.loadRecipeButton, - self.saveRecipeButton, - ] - col = 3 - row = 0 - self.applyCurrentFrameButton = widgets.okPushButton( - 'Apply to displayed image' - ) - buttonsLayout.addWidget(self.applyCurrentFrameButton, row, col) - self.applyCurrentFrameButton.clicked.connect( - partial(self.apply, signal=self.sigApplyImage) - ) - self.allButtons.append(self.applyCurrentFrameButton) - - infoLayout = QHBoxLayout() - buttonsHeight = self.applyCurrentFrameButton.sizeHint().height() - self.loadingCircle = widgets.LoadingCircleAnimation( - size=buttonsHeight - ) - sp = self.loadingCircle.sizePolicy() - sp.setRetainSizeWhenHidden(True) - self.loadingCircle.setSizePolicy(sp) - self.loadingCircle.setVisible(False) - infoLayout.addWidget(self.loadingCircle) - - self.infoLabel = QLabel( - "(Feel free to use Cell-ACDC while waiting)" - ) - sp = self.infoLabel.sizePolicy() - sp.setRetainSizeWhenHidden(True) - self.infoLabel.setSizePolicy(sp) - self.infoLabel.hide() - infoLayout.addWidget(self.infoLabel) - - buttonsLayout.addLayout( - infoLayout, row+1, 0, 3, 2, - alignment=Qt.AlignBottom | Qt.AlignLeft - ) - - if isZstack: - row += 1 - self.applyAllZslicesButton = widgets.threeDPushButton( - 'Apply to all z-slices of current image' - ) - buttonsLayout.addWidget(self.applyAllZslicesButton, row, col) - self.applyAllZslicesButton.clicked.connect(self.applyAllZslices) - self.allButtons.append(self.applyAllZslicesButton) - if isTimelapse: - row += 1 - self.applyAllFramesButton = widgets.futurePushButton( - 'Apply to all frames' - ) - buttonsLayout.addWidget(self.applyAllFramesButton, row, col) - self.applyAllFramesButton.clicked.connect(self.applyAllFrames) - self.allButtons.append(self.applyAllFramesButton) - if isMultiPos: - row += 1 - self.applyAllPosButton = widgets.futurePushButton( - 'Apply to all Positions' - ) - buttonsLayout.addWidget(self.applyAllPosButton, row, col) - self.applyAllPosButton.clicked.connect( - partial(self.apply, signal=self.sigApplyAllPos) - ) - self.allButtons.append(self.applyAllPosButton) - - row += 1 - self.savePreprocButton = widgets.savePushButton( - 'Save pre-processed data...' - ) - buttonsLayout.addWidget(self.savePreprocButton, row, col) - - self.allButtons.append(self.savePreprocButton) - self.savePreprocButton.clicked.connect(self.emitSignalSavePreprocData) - - self.previewCheckbox.toggled.connect(self.emitSigPreviewToggled) - self.preProcessParamsWidget.sigValuesChanged.connect( - self.emitValuesChanged - ) - - # self.cancelButton.clicked.connect(self.close) - - mainLayout.addLayout(keepInputDataTypeLayout) - mainLayout.addSpacing(20) - mainLayout.addWidget(self.preProcessParamsWidget) - mainLayout.addLayout(buttonsLayout) - self.mainLayout = mainLayout - - self.setLayout(mainLayout) - - def applyAllZslices(self, checked=False): - # Preview needs to be turned off because we are computing on every - # z-slice - self.previewCheckbox.setChecked(False) - self.apply(signal=self.sigApplyZstack) - - def applyAllFrames(self, checked=False): - # Preview needs to be turned off because we are computing on all frames - self.previewCheckbox.setChecked(False) - self.apply(signal=self.sigApplyAllFrames) - - def emitSigPreviewToggled(self): - self.sigPreviewToggled.emit(self.previewCheckbox.isChecked()) - - def showInfoKeepInputDataType(self): - txt = html_utils.paragraph(""" - If checked, the data type of the pre-processed data will be - the same as the input data type.

- This is useful to avoid saving the pre-processed data as - floating-point numbers (e.g., 32-bit float) which might - increase the file size.

- We recommend keeping this option checked. - """) - msg = widgets.myMessageBox(wrapText=False) - msg.information(self, 'Keep input data type', txt) - - def emitSignalSavePreprocData(self): - self.sigSavePreprocData.emit(self) - - def emitValuesChanged(self): - recipe = self.recipe(warn=False) - if recipe is None: - return - - self.sigValuesChanged.emit(recipe) - - def setDisabled(self, disabled: bool): - self.preProcessParamsWidget.setDisabled(disabled) - self.loadingCircle.setVisible(disabled) - self.infoLabel.setVisible(disabled) - for button in self.allButtons: - try: - button.setDisabled(disabled) - except RuntimeError as e: - printl(traceback.format_exc()) - printl(f'Error: {e}') - printl(f'Button: {button}') - - def apply(self, checked=False, signal: Signal=None): - recipe = self.recipe() - if recipe is None: - return - - if signal is not None: - signal.emit(recipe) - - if self.hideOnClosing: - self.setDisabled(True) - self.infoLabel.setText( - f"{self.sender().text().replace('Apply', 'Applying')}...
" - "(Feel free to use Cell-ACDC while waiting)" - ) - else: - self.ok_cb() - - def appliedFinished(self): - self.setDisabled(False) - - def recipe(self, warn=True): - recipe = self.preProcessParamsWidget.recipe(warn=warn) - if recipe is None: - return - - for step in recipe: - step['keep_input_data_type'] = ( - self.keepInputDataTypeToggle.isChecked() - ) - return recipe - - def recipeConfigPars(self): - return self.preProcessParamsWidget.recipeConfigPars('acdc') - - def ok_cb(self): - if self.hideOnClosing: - self.hide() - return - - self.cancel = False - self.close() - - def close(self): - super().close() - self.sigClose.emit(self) - -class PreProcessRecipeDialogUtil(PreProcessRecipeDialog): - def __init__( - self, - channel_names: Iterable[str], - df_metadata=None, - parent=None, - ): - self.cancel = True - - super().__init__( - isTimelapse=False, - isZstack=False, - isMultiPos=False, - addApplyButton=False, - df_metadata=df_metadata, - parent=parent, - hideOnClosing=False - ) - - self.listSelector = widgets.listWidget( - isMultipleSelection=True, minimizeHeight=True - ) - self.listSelector.addItems(channel_names) - self.listSelector.setCurrentRow(0) - - self.mainLayout.insertWidget(0, self.listSelector) - self.mainLayout.insertWidget( - 0, QLabel('Select channel(s) to pre-process:') - ) - self.mainLayout.insertSpacing(2, 10) - self.mainLayout.insertWidget(2, widgets.QHLine()) - - self.savePreprocButton.hide() - self.previewCheckbox.hide() - self.applyCurrentFrameButton.setText('Ok') - - buttonsLayout = self.preProcessParamsWidget.buttonsLayout - - saveRecipeButtonIndex = buttonsLayout.indexOf( - self.preProcessParamsWidget.saveRecipeButton - ) - - if saveRecipeButtonIndex == -1: - return - - saveRecipeButtonItem = buttonsLayout.takeAt(saveRecipeButtonIndex) - - buttonsLayout.addItem(saveRecipeButtonItem, 0, 2) - - def warnChannelSelectionEmpty(self): - txt = html_utils.paragraph(""" - You did not select any channel.

- Channel selection cannot be empty.

- Thank you for your patience! - """) - - def ok_cb(self): - selectedChannelItems = self.listSelector.selectedItems() - if not selectedChannelItems: - self.warnChannelSelectionEmpty() - - recipe = self.recipe() - if recipe is None: - return - - self.selectedRecipe = recipe - self.selectedChannels = [item.text() for item in selectedChannelItems] - - self.cancel = False - self.close() - - -# class ComboDelegate(QStyledItemDelegate): -# def __init__(self, bad_values, parent=None): -# super().__init__(parent) -# self.bad_values = bad_values - -# def paint(self, painter, option, index): -# text = index.data() -# if text in self.bad_values: -# option.palette.setColor(option.palette.Text, QColor("red")) -# super().paint(painter, option, index) - -class CombineChannelsSetupDialog(PreProcessRecipeDialog): - sigApplyImage = Signal(dict, bool, str) - sigApplyZstack = Signal(dict, bool, str) - sigApplyAllFrames = Signal(dict, bool, str) - sigApplyAllPos = Signal(dict, bool, str) - sigValuesChanged = Signal() - sigSaveAsSegmCheckboxToggled = Signal(bool) - - - # sigApplyAllZslices = Signal(dict, bool, str) - # sigApplyAllFramesZslices = Signal(dict, bool, str) - - def __init__( - self, - channel_names, - df_metadata=None, - parent=None, - hideOnClosing=False, - isTimelapse=False, - isZstack=False, - isMultiPos=False, - ): - - self.combineChannelsWidget = CombineChannelsWidget(channel_names, parent=self) - self.warnExistingRecipeFile = self.combineChannelsWidget.warnExistingRecipeFile - self.communicateSavingRecipeFinished = self.combineChannelsWidget.communicateSavingRecipeFinished - self.saveRecipeUI = self.combineChannelsWidget.saveRecipeUI - self.selectRecipeFilepath = self.combineChannelsWidget.selectRecipeFilepath - - super().__init__( - isTimelapse=isTimelapse, - isZstack=isZstack, - isMultiPos=isMultiPos, - df_metadata=df_metadata, - parent=parent, - hideOnClosing=hideOnClosing, - ) - - self.combineChannelsWidget.sigValuesChangedCombineChannels.connect( - self.emitValuesChangedSteps - ) - - - self.segm_blinked = False - self.validFormula = True # allow empty formula - self.forbiddenChannels = set() # channels that cannot be combined - - self.mainLayout.setSpacing(4) - - self.mainLayout.insertWidget(2, self.combineChannelsWidget) - self.combineChannelsWidget.groupbox.setCheckable(False) - self.combineChannelsWidget.groupbox.setTitle('Combine and manipulate channels and/or segmentation files') - - self.formulaEditWidget = FormulaEditWidget(parent=self) - self._updateFormulaVariableNames() - self.formulaEditWidget.sigFormulaChanged.connect(self.formulaChanged) - self.formulaEditWidget.setToolTip( - 'Enter a formula to combine the channels. For example ' - '"img1 + img2 * 0.5"' - ) - self.mainLayout.insertWidget(3, self.formulaEditWidget) - - buttonsLayoutSaveGroup = QGridLayout() - - row = 0 - col = 0 - loadRecipeButton = widgets.OpenFilePushButton('Load saved recipe') - self.loadRecipeButtonComb = loadRecipeButton - buttonsLayoutSaveGroup.addWidget(loadRecipeButton, row, col) - self.loadRecipeButtonComb.clicked.connect(self.selectAndLoadRecipe) - - col += 1 - saveRecipeButton = widgets.savePushButton('Save current recipe') - self.saveRecipeButtonComb = saveRecipeButton - buttonsLayoutSaveGroup.addWidget(saveRecipeButton, row, col) - saveRecipeButton.clicked.connect(self.saveRecipe) - saveRecipeButton.setToolTip( - 'Save the current recipe to a file\n' - f'Location: {combine_channels_recipes_path}' - ) - - col += 1 - loadLastRecipeButton = widgets.reloadPushButton('Load last recipe') - self.loadLastRecipeButtonComb = loadLastRecipeButton - buttonsLayoutSaveGroup.addWidget(loadLastRecipeButton, row, col) - self.mainLayout.addLayout(buttonsLayoutSaveGroup) - loadLastRecipeButton.clicked.connect(self.loadLastRecipe) - self.setLoadLastRecipe() - - loadLastRecipeButton.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) - loadLastRecipeButton.customContextMenuRequested.connect( - self._showLoadRecipeContextMenu - ) - - self.cancel = True - - self.setWindowTitle('Combine and manipulate channels and/or segmentation files') - self.preProcessParamsWidget.hide() - self.mainLayout.removeWidget(self.preProcessParamsWidget) - - self.savePreprocButton.setText('Save combined data...') - - - tooltip = ( - 'Save as a segmentation file, for example ' - 'when combining a binary mask with a segmentation mask.' - ) - label = QLabel('Save as segmentation:') - self.saveAsSegmlabel = label - label.setToolTip(tooltip) - self.saveAsSegmCheckbox = widgets.Toggle() - self.saveAsSegmCheckbox.setToolTip(tooltip) - self.saveAsSegmCheckbox.setChecked(False) - self.saveAsSegmCheckbox.setEnabled(False) - self.saveAsSegmCheckbox.toggled.connect(self.emitSaveAsSegmCheckboxToggled) - - self.keepInputDataTypeLayout.insertWidget(0, label) - self.keepInputDataTypeLayout.insertWidget(1, self.saveAsSegmCheckbox) - - def setLoadLastRecipe(self): - filepath = self._lastRecipePath() - if not os.path.exists(filepath): - self.loadLastRecipeButtonComb.setEnabled(False) - - def returLoadSecondLastRecipe(self): - filepath = self._secondLastRecipePath() - if not os.path.exists(filepath): - return False - return True - - def _showLoadRecipeContextMenu(self, pos): - menu = QMenu(self) - action = menu.addAction('Load recipe from before the last one') - action.triggered.connect(self.loadPreviousRecipe) - action.setEnabled(self.returLoadSecondLastRecipe()) - menu.exec(self.loadLastRecipeButtonComb.mapToGlobal(pos)) - - def loadPreviousRecipe(self): - filepath = self._secondLastRecipePath() - if not os.path.exists(filepath): - return - - self.loadRecipe(filepath) - - def loadLastRecipe(self): - filepath = self._lastRecipePath() - if not os.path.exists(filepath): - return - - self.loadRecipe(filepath) - - def saveLastRecipe(self): - os.makedirs(combine_channels_recipes_path, exist_ok=True) - filepath = self._lastRecipePath() - - same = False - if os.path.exists(filepath): - steps_curr = self._getSaveRecipyDict() - with open(filepath, 'r') as f: - steps_prev = json.load(f) - same = self._recipesMatch(steps_curr, steps_prev) - - if same: - return - - if os.path.exists(filepath): - new_filename = self._secondLastRecipePath() - if os.path.exists(new_filename): - os.remove(new_filename) - os.rename(filepath, new_filename) - self.saveRecipe(filepath=filepath) - - - def _recipesMatch(self, steps_curr, steps_prev): - # Normalize current dict to strings for comparison with JSON-loaded dict - def normalize(d): - return {str(k): str(v) for k, v in d.items()} - - for raw_key in steps_curr: - key = str(raw_key) - if key not in steps_prev: - return False - if key in ('formula', 'keep_input_data_type', 'save_as_segm'): - if str(steps_curr[raw_key]) != str(steps_prev[key]): - return False - else: - step_dict = normalize(steps_curr[raw_key]) - step_dict_prev = steps_prev[key] - for key2, val2 in step_dict.items(): - if key2 not in step_dict_prev: - return False - if val2 != str(step_dict_prev[key2]): - return False - return True - - def _lastRecipePath(self): - return os.path.join(combine_channels_recipes_path, '.last_combine_channels_recipe.json') - - def _secondLastRecipePath(self): - return os.path.join(combine_channels_recipes_path, '.previous_combine_channels_recipe.json') - - def _getSaveRecipyDict(self): - steps = self.combineChannelsWidget.steps() # already returns a copy - formula = self.formulaEditWidget.text() - steps['formula'] = formula - steps['keep_input_data_type'] = self.keepInputDataTypeToggle.isChecked() - steps['save_as_segm'] = self.saveAsSegmCheckbox.isChecked() - return steps - - def saveRecipe(self, dummy=None, filepath=None): - os.makedirs(combine_channels_recipes_path, exist_ok=True) - - filepath_provided = filepath is not None - if not filepath_provided: - folder_content = myutils.listdir(combine_channels_recipes_path) - num_recipes = len(folder_content) - default_text = f'{num_recipes + 1}' - proceed, filepath = self.saveRecipeUI( - combine_channels_recipes_path, '.json', - 'Save recipe', 'combine_channels_recipe', - 'Insert a filename for the recipe:', - default_text - ) - - if not proceed: - return - - steps = self._getSaveRecipyDict() - - with open(filepath, 'w') as f: - json.dump(steps, f, indent=2) - - if not filepath_provided: - self.communicateSavingRecipeFinished(filepath) - - def selectAndLoadRecipe(self): - filepath = self.selectRecipeFilepath( - combine_channels_recipes_path, - 'combine_channels_recipe', 'JSON', 'json' - ) - if filepath is None: - return - - self.loadRecipe(filepath) - - def loadRecipe(self, filepath): - with open(filepath, 'r') as f: - recipe = json.load(f) - - recipe = dict(sorted(recipe.items())) - keys_used = set() - for key, value in recipe.items(): - if key == 'formula': - formula = value - continue - if key == 'keep_input_data_type': - self.keepInputDataTypeToggle.setChecked(value) - continue - if key == 'save_as_segm': - self.saveAsSegmCheckbox.setChecked(value) - continue - - name = value['name'] - channel = value['channel'] - binarize = value['binarize'] - min_val = float(value['min_val']) - max_val = float(value['max_val']) - key = int(key) - stepWidgetsNum = len(self.combineChannelsWidget.stepsWidgets) - if key > stepWidgetsNum: - self.combineChannelsWidget.addStep() - - stepWidgets = self.combineChannelsWidget.stepsWidgets[key] - idx = stepWidgets['selector'].findText(channel) - if idx == -1: - stepWidgets['selector'].addItem(channel) - # stepWidgets['selector'].forbiddenItems.add(channel) - blinker = qutils.QControlBlink( - stepWidgets['selector'], - qparent=self - ) - blinker.start() - stepWidgets['selector'].blinker = blinker - self.forbiddenChannels.add(channel) - - stepWidgets['selector'].setCurrentText(channel) - stepWidgets['name_edit'].setText(name) - stepWidgets['binarize'].setCurrentText(binarize) - stepWidgets['minValueSpinbox'].setValue(min_val) - stepWidgets['maxValueSpinbox'].setValue(max_val) - - keys_used.add(key) - - # remove extra steps - keys_present = set(range(1, len(self.combineChannelsWidget.stepsWidgets)+1)) - extra_keys = keys_present - keys_used - extra_keys = list(extra_keys) - extra_keys.sort(reverse=True) - for key in extra_keys: - self.combineChannelsWidget.removeStep(step_n = key) - # updates key dynamically so I have to rely that missing indx are always last steps - - # update formula - self.formulaEditWidget.setText(formula) - - for stepWidgets in self.combineChannelsWidget.stepsWidgets.values(): - combo = stepWidgets['selector'] - # set forbidden channels red in all steps - for i in range(combo.count()): - item = combo.itemText(i) - if item in self.forbiddenChannels: - combo.setItemData(i, QColor('red'), Qt.ForegroundRole) - - def _updateFormulaVariableNames(self): - names = [ - stepWidgets['name_edit'].text() - for stepWidgets in self.combineChannelsWidget.stepsWidgets.values() - ] - self.formulaEditWidget.setVariableNames(names) - - def formulaChanged(self, formula_str, is_valid): - self.setButtonsEnabled(is_valid) - self.validFormula = is_valid - if is_valid: - self.sigValuesChanged.emit() - - def setButtonsEnabled(self, enabled): - for i in range(self.buttonsLayout.count()): - item = self.buttonsLayout.itemAt(i) - widget = item.widget() - if widget is None: - continue - if isinstance(widget, QPushButton): - label = widget.text().lower().rstrip().lstrip() - if 'apply' in label or 'save' in label or 'ok' in label: - if enabled: - try: - widget.setEnabled(True) - except: - pass - else: - try: - widget.setDisabled(True) - except: - pass - - - def saveAsSegm(self): - return self.saveAsSegmCheckbox.isChecked() - - def emitSaveAsSegmCheckboxToggled(self): - if self.validFormula: - self.sigSaveAsSegmCheckboxToggled.emit(self.saveAsSegm()) - - def autoCheckSaveAsSegmCheckbox(self): - any_not_seg = False - for step in self.combineChannelsWidget.steps().values(): - channel = step['channel'] - if 'segm' not in channel: - any_not_seg = True - break - - if any_not_seg: - self.saveAsSegmCheckbox.setChecked(False) - self.saveAsSegmCheckbox.setEnabled(False) - else: - if not self.segm_blinked: - self.saveAsSegmCheckbox.setEnabled(True) - self.blinker = qutils.QControlBlink( - self.saveAsSegmCheckbox, - qparent=self - ) - self.blinker.start() - self.segm_blinked = True - - def apply(self, checked=False, signal: Signal=None): - steps = self.combineChannelsWidget.steps() - formula = self.formulaEditWidget.text() - keep_input_dtype = self.keepInputDataTypeToggle.isChecked() - if not steps or not self.validFormula: - return - - if signal is not None: - try: - signal.emit(steps, formula) - except TypeError as err: - signal.emit(steps, keep_input_dtype, formula) - - - self.saveLastRecipe() - if self.hideOnClosing: - self.setDisabled(True) - self.infoLabel.setText( - f"{self.sender().text().replace('Apply', 'Applying')}...
" - "(Feel free to use Cell-ACDC while waiting)" - ) - else: - self.ok_cb(saveLastRecipe=False) - # Not needed anymore since now we funnel all changes to the formulaEditWidget, which then verifies the formula and - # emits a signal via formulaChangeda - # def emitValuesChanged(self): - # if not self.validFormula: - # return - # self.sigValuesChanged.emit() - - def emitValuesChangedSteps(self): - self.autoCheckSaveAsSegmCheckbox() - self._updateFormulaVariableNames() - - def ok_cb(self, dummy=None, saveLastRecipe=True): - if not self.validFormula: - return - - if saveLastRecipe: - self.saveLastRecipe() - - self.keepInputDataType = self.keepInputDataTypeToggle.isChecked() - self.selectedSteps = self.combineChannelsWidget.steps() - self.formula = self.formulaEditWidget.text() - self.cancel = False - self.close() - -class CombineChannelsSetupDialogUtil(CombineChannelsSetupDialog): - def __init__( - self, - channel_names, - df_metadata=None, - parent=None, - ): - - super().__init__( - channel_names, - parent=parent, - df_metadata=df_metadata - ) - - # add int input for number of workers - - - self.mainLayout.addSpacing(20) - - qutils.hide_and_delete_layout(self.buttonsLayout) - buttonsLayout = widgets.CancelOkButtonsLayout() - self.buttonsLayout = buttonsLayout - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - self.mainLayout.addLayout(buttonsLayout) - - self.nThreadsSpinBox = QSpinBox() - self.nThreadsSpinBox.setMinimum(1) - self.nThreadsSpinBox.setValue(4) - self.nThreadsSpinBox.setToolTip("Number of threads to use for processing") - self.mainLayout.addWidget(QLabel("Number of threads:")) - self.mainLayout.addWidget(self.nThreadsSpinBox) - -class CombineChannelsSetupDialogGUI(CombineChannelsSetupDialog): - def __init__( - self, - channel_names: Iterable[str], - df_metadata=None, - isTimelapse=False, - isZstack=False, - isMultiPos=False, - parent=None, - hideOnClosing=False - ): - super().__init__( - channel_names, - df_metadata=df_metadata, - isTimelapse=isTimelapse, - isZstack=isZstack, - isMultiPos=isMultiPos, - parent=parent, - hideOnClosing=hideOnClosing, - ) - - # remove the preprocess buttons, we use the comb version of them - qutils.delete_widget(self.loadLastRecipeButton) - qutils.delete_widget(self.saveRecipeButton) - qutils.delete_widget(self.loadRecipeButton) - - # self.allButtons.remove(self.loadLastRecipeButton) - self.allButtons.remove(self.saveRecipeButton) - self.allButtons.remove(self.loadRecipeButton) - - self.previewCheckbox.setChecked(True) - self.saveAsSegmlabel.setText('Save and view as segmentation') - - def steps(self, return_keepInputDataType=False): - steps = self.combineChannelsWidget.steps() - formula = self.formulaEditWidget.text() - # if not return_keepInputDataType: - # return steps, formula - - keep_input_dtype = self.keepInputDataTypeToggle.isChecked() - return steps, keep_input_dtype, formula - -class QCropTrangeTool(QBaseDialog): - sigClose = Signal() - sigTvalueChanged = Signal(int) - sigReset = Signal() - sigCrop = Signal(int, int) - - def __init__( - self, SizeT, - cropButtonText='Apply crop', - parent=None, - addDoNotShowAgain=False, - title='Select frames range' - ): - super().__init__(parent) - - self.cancel = True - - self.setWindowFlags(Qt.Tool | Qt.WindowStaysOnTopHint) - - self.SizeT = SizeT - self.numDigits = len(str(self.SizeT)) - - self.setWindowTitle(title) - - layout = QGridLayout() - buttonsLayout = QHBoxLayout() - - self.startFrameScrollbar = widgets.sliderWithSpinBox( - spinbox_loc='left', - maximum_on_label=SizeT - ) - self.startFrameScrollbar.setMaximum(SizeT, including_spinbox=True) - self.startFrameScrollbar.setMinimum(1, including_spinbox=True) - - self.endFrameScrollbar = widgets.sliderWithSpinBox( - spinbox_loc='left', - maximum_on_label=SizeT - ) - self.endFrameScrollbar.setMaximum(SizeT, including_spinbox=True) - self.endFrameScrollbar.setMinimum(1, including_spinbox=True) - self.endFrameScrollbar.setValue(SizeT) - - cancelButton = widgets.cancelPushButton('Cancel') - cropButton = widgets.okPushButton(cropButtonText) - buttonsLayout.addWidget(cropButton) - buttonsLayout.addWidget(cancelButton) - - row = 0 - layout.addWidget( - QLabel('Start frame n. '), row, 0, alignment=Qt.AlignRight - ) - layout.addWidget(self.startFrameScrollbar, row, 2) - - row += 1 - layout.setRowStretch(row, 5) - layout.addItem(QSpacerItem(10, 10), row, 0) - - row += 1 - layout.addWidget( - QLabel('Stop frame n. '), row, 0, alignment=Qt.AlignRight - ) - layout.addWidget(self.endFrameScrollbar, row, 2) - - row += 1 - if addDoNotShowAgain: - self.doNotShowAgainCheckbox = QCheckBox('Do not ask again') - layout.addWidget( - self.doNotShowAgainCheckbox, row, 2, alignment=Qt.AlignLeft - ) - row += 1 - - layout.addItem(QSpacerItem(10, 20), row, 0) - layout.addLayout(buttonsLayout, row+1, 2, alignment=Qt.AlignRight) - - layout.setColumnStretch(0, 0) - layout.setColumnStretch(1, 0) - layout.setColumnStretch(2, 10) - - self.setLayout(layout) - - # resetButton.clicked.connect(self.emitReset) - cropButton.clicked.connect(self.emitCrop) - cancelButton.clicked.connect(self.close) - self.startFrameScrollbar.sigValueChange.connect(self.TvalueChanged) - self.endFrameScrollbar.sigValueChange.connect(self.TvalueChanged) - - def emitReset(self): - self.sigReset.emit() - - def emitCrop(self): - self.cancel = False - low_z = self.startFrameScrollbar.value() - 1 - high_z = self.endFrameScrollbar.value() - 1 - self.sigCrop.emit(low_z, high_z) - self.close() - - def updateScrollbars(self, start_frame_i, lower_frame_i): - self.startFrameScrollbar.setValue(start_frame_i + 1) - self.endFrameScrollbar.setValue(lower_frame_i + 1) - - def TvalueChanged(self, value): - frame_i = value - 1 - self.sigTvalueChanged.emit(frame_i) - - def showEvent(self, event): - self.resize(int(self.width()*2.0), self.height()) - - def closeEvent(self, event): - super().closeEvent(event) - self.sigClose.emit() - -class QTreeDialog(QBaseDialog): - def __init__( - self, - items: List[Tuple[str]], - headerLabels: List[str]=None, - parent=None, - infoText='Select item', - title='Select item', - path_to_browse=None, - additional_buttons=None, - ): - self.cancel = True - super().__init__(parent) - - self.setWindowTitle(title) - - mainLayout = QVBoxLayout() - - infoLabel = QLabel(html_utils.paragraph(infoText)) - - self.treeWidget = widgets.TreeWidget() - if headerLabels is not None: - self.treeWidget.setHeaderLabels(headerLabels) - else: - self.treeWidget.setHeaderHidden(True) - - for row, texts in enumerate(items): - item = widgets.TreeWidgetItem(self.treeWidget) - for i, text in enumerate(texts): - item.setText(i, text) - self.treeWidget.addTopLevelItem(item) - - self.treeWidget.resizeColumnToContents(0) - self.treeWidget.resizeColumnToContents(1) - - # self.treeWidget.header().setStretchLastSection(False) - - buttonsLayout = widgets.CancelOkButtonsLayout() - - if path_to_browse is not None: - browseButton = widgets.showInFileManagerButton( - setDefaultText=True - ) - browseButton.setPathToBrowse(path_to_browse) - buttonsLayout.insertWidget(3, browseButton) - - if additional_buttons is not None: - for btn in additional_buttons: - buttonsLayout.insertWidget(3, btn) - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addWidget(infoLabel) - mainLayout.addWidget(self.treeWidget) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - def show(self, block=False): - w = self.sizeHint().width() - h = self.sizeHint().height() - self.resize(int(w*1.3), h) - super().show(block=block) - - def ok_cb(self): - self.clickedButton = self.sender() - self.cancel = False - self.selectedItem = self.treeWidget.currentItem() - self.selectedText = self.selectedItem.text(0) - self.close() - -class SelectFoldersToAnalyse(QBaseDialog): - def __init__( - self, parent=None, - preSelectedPaths=None, - onlyExpPaths=False, - scanFolderTree=True, - instructionsText='Select experiment folders to analyse', - askSelectPosFolders=False - ): - super().__init__(parent) - - self.cancel = True - self.onlyExpPaths = onlyExpPaths - self.setWindowTitle('Select experiments to analyse') - self.scanTree = scanFolderTree - self.askSelectPosFolders = askSelectPosFolders - - mainLayout = QVBoxLayout() - - instructionsText = html_utils.paragraph( - f'{instructionsText}

' - 'Drag and drop folders or click on Add folder button to ' - 'add as many folders ' - 'as needed.
', font_size='14px' - ) - instructionsLabel = QLabel(instructionsText) - instructionsLabel.setAlignment(Qt.AlignCenter) - - infoText = html_utils.paragraph( - 'A valid folder is either a Position folder, ' - 'or an experiment folder (containing Position_n folders),
' - 'or any folder that contains multiple experiment folders.

' - - 'In the last case, Cell-ACDC will automatically scan the entire tree of ' - 'sub-directories
' - 'and will add all experiments having the right folder structure.
', - font_size='12px' - ) - infoLabel = QLabel(infoText) - infoLabel.setAlignment(Qt.AlignCenter) - - self.listWidget = widgets.listWidget() - self.listWidget.setSelectionMode( - QAbstractItemView.SelectionMode.ExtendedSelection - ) - if preSelectedPaths is not None: - self.listWidget.addItems(preSelectedPaths) - - buttonsLayout = widgets.CancelOkButtonsLayout() - - delButton = widgets.delPushButton('Remove selected path(s)') - browseButton = widgets.browseFileButton( - 'Add folder...', openFolder=True, - start_dir=myutils.getMostRecentPath() - ) - - buttonsLayout.insertWidget(3, delButton) - buttonsLayout.insertWidget(4, browseButton) - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - browseButton.sigPathSelected.connect(self.addFolderPath) - delButton.clicked.connect(self.removePaths) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addWidget(instructionsLabel) - mainLayout.addWidget(infoLabel) - mainLayout.addWidget(self.listWidget) - - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - mainLayout.addStretch(1) - - self.setLayout(mainLayout) - - self.setAcceptDrops(True) - - self.setFont(font) - - def dragEnterEvent(self, event): - event.acceptProposedAction() - - def dropEvent(self, event): - event.setDropAction(Qt.CopyAction) - for url in event.mimeData().urls(): - dropped_path = url.toLocalFile() - if os.path.isfile(dropped_path): - dropped_path = os.path.dirname(dropped_path) - - QTimer.singleShot(50, partial(self.addFolderPath, dropped_path)) - - def pathsList(self): - return [ - self.listWidget.item(i).text().replace('\\', '/') - for i in range(self.listWidget.count()) - ] - - def expFolderToPosFoldernamesMapper(self): - expPathsPosFoldernamesMapper = defaultdict(set) - for selectedPath in self.pathsList(): - pos_foldernames = myutils.get_pos_foldernames( - selectedPath, check_if_is_sub_folder=True - ) - if not pos_foldernames: - images_path = myutils.get_images_folderpath(selectedPath) - expPathsPosFoldernamesMapper[selectedPath].add('') - else: - expPath = load.get_exp_path(selectedPath) - expPathsPosFoldernamesMapper[expPath].update(pos_foldernames) - - expPathsPosFoldernamesMapper = { - expPath: natsorted(pos_foldernames) - for expPath, pos_foldernames in expPathsPosFoldernamesMapper.items() - } - return expPathsPosFoldernamesMapper - - def ok_cb(self): - self.cancel = False - self.paths = self.pathsList() - self.selectedExpFolderToPosFoldernamesMapper = ( - self.expFolderToPosFoldernamesMapper() - ) - self.close() - - def warnNoValidPathsFound(self, selected_path): - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph(""" - The selected path (see below) does not contain any valid folder.

- Please, make sure to select a Position folder, the Images folder - inside a Position folder, or any folder containing a Position folder - as a sub-directory.

- Thank you for your patience!

- Selected path: - """) - msg.warning( - self, 'Training workflow generated', txt, - commands=(f'{selected_path}',), - path_to_browse=selected_path - ) - - def warnNoValidExpPaths(self, selected_path): - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph(""" - The selected folder does - not contain any valid experiment folders. - """) - command = selected_path.replace('\\', os.sep) - command = selected_path.replace('/', os.sep) - msg.warning( - self, 'No valid folders found', txt, - commands=(command,), - path_to_browse=selected_path - ) - - def parse_select_from_exp_paths( - self, exp_paths: dict[os.PathLike, Iterable[str]] - ): - if not self.askSelectPosFolders: - return list(exp_paths.keys()) - - paths = [] - for exp_path, pos_foldernames in exp_paths.items(): - if len(pos_foldernames) == 1: - paths.append(exp_path) - continue - - informativeText = html_utils.paragraph( - 'The following experiment folder

' - f'{exp_path}

' - 'contains multiple Position folders.

' - 'Please, select which Position folder(s) you want to analyse:
' - ) - select_folder = load.select_exp_folder() - values = select_folder.get_values_dataprep(exp_path) - select_folder.QtPrompt( - self, values, toggleMulti=True, - informativeText=informativeText, - selectedValues=values - ) - if select_folder.cancel: - return - - for pos in select_folder.selected_pos: - paths.append(os.path.join(exp_path, pos)) - - return paths - - def addFolderPath(self, selected_path): - myutils.addToRecentPaths(selected_path) - - folder_type = myutils.determine_folder_type(selected_path) - is_pos_folder, is_images_folder, folder_path = folder_type - if is_pos_folder: - paths = [selected_path] - elif is_images_folder: - paths = [os.path.dirname(selected_path)] - elif self.scanTree: - print(f'Scanning selected folder "{selected_path}"...') - exp_paths = path.get_posfolderpaths_walk(selected_path) - if not exp_paths: - self.warnNoValidExpPaths(selected_path) - return - - paths = self.parse_select_from_exp_paths(exp_paths) - if paths is None: - return - else: - paths = [selected_path] - - if not paths: - self.warnNoValidPathsFound(selected_path) - - for selectedPath in paths: - if self.onlyExpPaths: - selectedPath = load.get_exp_path(selectedPath) - - selectedPath = selectedPath.replace('\\', '/') - if selectedPath in self.pathsList(): - print( - f'[WARNING]: The following path was already selected: ' - f'"{selectedPath}"' - ) - return - - self.listWidget.addItem(selectedPath) - - def removePaths(self): - for item in self.listWidget.selectedItems(): - row = self.listWidget.row(item) - self.listWidget.takeItem(row) - -class OverlayLabelsAppearanceDialog(QBaseDialog): - sigValuesChanged = Signal(object) - - def __init__(self, scatterPlotItem: pg.ScatterPlotItem=None, parent=None): - super().__init__(parent) - - self.cancel = True - - self.setWindowTitle('Overlay contours appearance properties') - - mainLayout = QVBoxLayout() - - formLayout = widgets.FormLayout() - - row = -1 - - row += 1 - self.colorButton = widgets.myColorButton(color=(255, 0, 0)) - self.colorButton.clicked.disconnect() - self.colorButton.clicked.connect(self.selectColor) - self.colorButton.setCursor(Qt.PointingHandCursor) - self.colorWidget = widgets.formWidget( - self.colorButton, addInfoButton=False, stretchWidget=False, - labelTextLeft='Symbol color: ', parent=self, - widgetAlignment='left' - ) - if scatterPlotItem is not None: - pen = scatterPlotItem.opts['pen'] - color = pen.color() - self.colorButton.setColor(color) - formLayout.addFormWidget(self.colorWidget, row=row) - - row += 1 - self.penWidthSpinBox = widgets.SpinBox() - self.penWidthSpinBox.setMinimum(0) - self.penWidthSpinBox.setValue(2) - - self.penWidthWidget = widgets.formWidget( - self.penWidthSpinBox, addInfoButton=False, stretchWidget=False, - labelTextLeft='Symbol weight: ', parent=self, - widgetAlignment='left' - ) - if scatterPlotItem is not None: - pen = scatterPlotItem.opts['pen'] - width = pen.width() - self.penWidthSpinBox.setValue(width) - formLayout.addFormWidget(self.penWidthWidget, row=row) - - row += 1 - self.opacitySlider = widgets.sliderWithSpinBox( - isFloat=True, normalize=True - ) - self.opacitySlider.setMinimum(0) - self.opacitySlider.setMaximum(100) - self.opacitySlider.setValue(0.8) - - self.opacityWidget = widgets.formWidget( - self.opacitySlider, addInfoButton=False, stretchWidget=True, - labelTextLeft='Symbol opacity: ', parent=self - ) - if scatterPlotItem is not None: - brush = scatterPlotItem.opts['brush'] - alpha = brush.color().alpha() - opacity = alpha/255 - self.opacitySlider.setValue(opacity) - formLayout.addFormWidget(self.opacityWidget, row=row) - - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addLayout(formLayout) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - def selectColor(self): - color = self.colorButton.color() - self.colorButton.origColor = color - self.colorButton.colorDialog.setCurrentColor(color) - self.colorButton.colorDialog.setWindowFlags( - Qt.Window | Qt.WindowStaysOnTopHint - ) - self.colorButton.colorDialog.open() - w = self.width() - left = self.pos().x() - colorDialogTop = self.colorButton.colorDialog.pos().y() - self.colorButton.colorDialog.move(w+left+10, colorDialogTop) - - def getBrush(self): - r, g, b, _ = self.colorButton.color().getRgb() - alpha = round(self.opacitySlider.value()*255) - brushColor = (r, g, b, alpha) - brush = pg.mkBrush(brushColor) - return brush - - def getPen(self): - color = self.colorButton.color() - penWidth = self.penWidthSpinBox.value() - if penWidth == 0: - return - - pen = pg.mkPen(color, width=penWidth) - return pen - - def ok_cb(self): - self.cancel = False - self.properties = { - 'brush': self.getBrush(), - 'pen': self.getPen() - } - self.close() - -class AutoSaveIntervalDialog(QBaseDialog): - sigValueChanged = Signal(float, str) - - def __init__(self, parent=None): - super().__init__(parent) - - self.cancel = True - - self.setWindowTitle('Change autosave interval') - - mainLayout = QVBoxLayout() - - self.autoSaveIntervalWidget = ( - widgets.AutoSaveIntervalWidget(parent=self) - ) - - mainLayout.addWidget(QLabel('Autosave interval:')) - mainLayout.addWidget(self.autoSaveIntervalWidget) - - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - def setValues(self, autoSaveIntevalValue, autoSaveIntervalUnit): - self.autoSaveIntervalWidget.spinbox.setValue(autoSaveIntevalValue) - self.autoSaveIntervalWidget.unitCombobox.setCurrentText( - autoSaveIntervalUnit - ) - - def sizeHint(self): - defaultWidth = super().sizeHint().width() - defaultHeight = super().sizeHint().height() - return QSize(defaultWidth*2, defaultHeight) - - def ok_cb(self): - self.cancel = False - self.sigValueChanged.emit( - self.autoSaveIntervalWidget.spinbox.value(), - self.autoSaveIntervalWidget.unitCombobox.currentText() - ) - self.close() - -class TestSegmModelInitalDialog(QBaseDialog): - def __init__(self, parent=None): - super().__init__(parent) - - self.cancel = True - - mainLayout = QVBoxLayout() - entriesLayout = widgets.FormLayout() - - row = 0 - self.startFrameNumberSpinbox = widgets.SpinBox() - self.startFrameNumberSpinbox.setMinimum(1) - - self.startFrameNumberFormWidget = widgets.formWidget( - self.startFrameNumberSpinbox, - labelTextLeft='Start frame number', - addActivateCheckbox=True - ) - entriesLayout.addFormWidget(self.startFrameNumberFormWidget, row=row) - - row += 1 - self.stopFrameNumberSpinbox = widgets.SpinBox() - self.stopFrameNumberSpinbox.setMinimum(1) - - self.stopFrameNumberFormWidget = widgets.formWidget( - self.stopFrameNumberSpinbox, - labelTextLeft='Stop frame number', - addActivateCheckbox=True - ) - entriesLayout.addFormWidget(self.stopFrameNumberFormWidget, row=row) - - row += 1 - self.startZsliceNumberSpinbox = widgets.SpinBox() - self.startZsliceNumberSpinbox.setMinimum(1) - - self.startZsliceNumberFormWidget = widgets.formWidget( - self.startZsliceNumberSpinbox, - labelTextLeft='Start z-slice number', - addActivateCheckbox=True - ) - entriesLayout.addFormWidget(self.startZsliceNumberFormWidget, row=row) - - row += 1 - self.stopZsliceNumberSpinbox = widgets.SpinBox() - self.stopZsliceNumberSpinbox.setMinimum(1) - - self.stopZsliceNumberFormWidget = widgets.formWidget( - self.stopZsliceNumberSpinbox, - labelTextLeft='Stop z-slice number', - addActivateCheckbox=True - ) - entriesLayout.addFormWidget(self.stopZsliceNumberFormWidget, row=row) - - row += 1 - - self.isTimelapseToggleFormWidget = widgets.formWidget( - widgets.Toggle(), - labelTextLeft='Is timelapse?', - stretchWidget=False, - valueGetterName='isChecked' - ) - entriesLayout.addFormWidget(self.isTimelapseToggleFormWidget, row=row) - - - # self.stopFrameNumberSpinbox - # self.startZsliceNumberSpinbox - # self.stopZsliceNumberSpinbox - # self.isTimelapseToggle - - buttonsLayout = widgets.CancelOkButtonsLayout() - - buttonsLayout.okButton.clicked.connect(self.ok_cb) - buttonsLayout.cancelButton.clicked.connect(self.close) - - mainLayout.addLayout(entriesLayout) - mainLayout.addSpacing(20) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - def ok_cb(self): - self.cancel = False - - self.start_frame_n = self.startFrameNumberFormWidget.value() - self.stop_frame_n = self.stopFrameNumberFormWidget.value() - self.start_z_slice_n = self.startZsliceNumberFormWidget.value() - self.stop_z_slice_n = self.stopZsliceNumberFormWidget.value() - self.is_timelapse = self.isTimelapseToggleFormWidget.value() - - self.close() \ No newline at end of file +from .dialogs import * # noqa: F403 diff --git a/cellacdc/autopilot.py b/cellacdc/autopilot.py index 2c44fd7e6..073724f10 100644 --- a/cellacdc/autopilot.py +++ b/cellacdc/autopilot.py @@ -1,68 +1,79 @@ import os -from qtpy.QtCore import ( - QTimer, QThread, Signal, QObject -) +from qtpy.QtCore import QTimer, QThread, Signal, QObject + +from . import load, printl, utils -from . import load, printl, myutils class AutoPilotProfile: def __init__(self): self.lastLoadingProfile = [] def storeSelectedChannel(self, user_channel): - self.lastLoadingProfile.append({ - 'windowTitle': 'Select channel name', - 'windowActions': ('ComboBox.setCurrentText', 'ok_cb'), - 'windowActionsArgs': ((user_channel,), tuple()) - }) + self.lastLoadingProfile.append( + { + "windowTitle": "Select channel name", + "windowActions": ("ComboBox.setCurrentText", "ok_cb"), + "windowActionsArgs": ((user_channel,), tuple()), + } + ) def storeSelectedSegmFile(self, selectedSegmEndName): - self.lastLoadingProfile.append({ - 'windowTitle': 'Multiple segm.npz files detected', - 'windowActions': ('listWidget.setSelectedItemFromText', 'ok_cb'), - 'windowActionsArgs': ((selectedSegmEndName,), tuple()) - }) - + self.lastLoadingProfile.append( + { + "windowTitle": "Multiple segm.npz files detected", + "windowActions": ("listWidget.setSelectedItemFromText", "ok_cb"), + "windowActionsArgs": ((selectedSegmEndName,), tuple()), + } + ) + def storeOkAskInputMetadata(self): - self.lastLoadingProfile.append({ - 'windowTitle': 'Image properties', - 'windowActions': ('ok_cb',), - 'windowActionsArgs': (tuple(),) - }) - + self.lastLoadingProfile.append( + { + "windowTitle": "Image properties", + "windowActions": ("ok_cb",), + "windowActionsArgs": (tuple(),), + } + ) + def storeLoadSavedData(self): - self.lastLoadingProfile.append({ - 'windowTitle': 'Recover unsaved data?', - 'windowActions': ('clickButtonFromText',), - 'windowActionsArgs': (('Load saved data',),) - }) - + self.lastLoadingProfile.append( + { + "windowTitle": "Recover unsaved data?", + "windowActions": ("clickButtonFromText",), + "windowActionsArgs": (("Load saved data",),), + } + ) + def storeClickMessageBox(self, windowTitle, buttonTextToClick): - self.lastLoadingProfile.append({ - 'windowTitle': windowTitle, - 'windowActions': ('clickButtonFromText',), - 'windowActionsArgs': ((buttonTextToClick,),) - }) - + self.lastLoadingProfile.append( + { + "windowTitle": windowTitle, + "windowActions": ("clickButtonFromText",), + "windowActionsArgs": ((buttonTextToClick,),), + } + ) + def storeLoadedFluoChannels(self, loadedChannels): - self.lastLoadingProfile.append({ - 'windowTitle': 'Select channel to load', - 'windowActions': ('setSelectedItems', 'ok_cb'), - 'windowActionsArgs': ((loadedChannels,), tuple()) - }) - + self.lastLoadingProfile.append( + { + "windowTitle": "Select channel to load", + "windowActions": ("setSelectedItems", "ok_cb"), + "windowActionsArgs": ((loadedChannels,), tuple()), + } + ) + def getCopy(self): return self.lastLoadingProfile.copy() -class AutoPilot: +class AutoPilot: def __init__(self, parentWin) -> None: self.parentWin = parentWin self.app = parentWin.app self.isFinished = True self.loadingProfile = parentWin.AutoPilotProfile.getCopy() - + def _askSelectPos(self): posData = self.parentWin.data[self.parentWin.pos_i] exp_path = posData.exp_path @@ -74,16 +85,16 @@ def _askSelectPos(self): select_folder.QtPrompt(self.parentWin, values, allowMultiSelection=False) if select_folder.cancel: return - + posPath = os.path.join(exp_path, select_folder.selected_pos[0]) return posPath def execLoadPos(self): posPath = self._askSelectPos() if posPath is None: - self.parentWin.logger.info('Loading Position cancelled.') + self.parentWin.logger.info("Loading Position cancelled.") return - + self.isFinished = False self.timer = QTimer() self.timer.timeout.connect(self.loadPosTimerCallback) @@ -96,8 +107,8 @@ def loadPosTimerCallback(self): if not self.loadingProfile: self.timer.stop() return - - windowTitle = self.loadingProfile[0]['windowTitle'] + + windowTitle = self.loadingProfile[0]["windowTitle"] for window in openWindows: if not window.windowTitle(): continue @@ -105,13 +116,12 @@ def loadPosTimerCallback(self): continue if not windowTitle == window.windowTitle(): continue - - windowActions = self.loadingProfile[0]['windowActions'] - windowActionsArgs = self.loadingProfile[0]['windowActionsArgs'] + + windowActions = self.loadingProfile[0]["windowActions"] + windowActionsArgs = self.loadingProfile[0]["windowActionsArgs"] for action, args in zip(windowActions, windowActionsArgs): - func = myutils.get_chained_attr(window, action) + func = utils.get_chained_attr(window, action) func(*args) - + self.loadingProfile.pop(0) break - \ No newline at end of file diff --git a/cellacdc/bioformats/__init__.py b/cellacdc/bioformats/__init__.py index 920de1d63..7944d6b6c 100755 --- a/cellacdc/bioformats/__init__.py +++ b/cellacdc/bioformats/__init__.py @@ -5,9 +5,7 @@ # Copyright (c) 2009-2014 Broad Institute # All rights reserved. -'''Bioformats package - wrapper for loci.bioformats java code - -''' +"""Bioformats package - wrapper for loci.bioformats java code""" from __future__ import absolute_import, unicode_literals @@ -22,50 +20,228 @@ from . import formatreader as _formatreader from . import formatwriter as _formatwriter -_jars_dir = os.path.join(os.path.dirname(__file__), 'jars') +_jars_dir = os.path.join(os.path.dirname(__file__), "jars") -JAR_VERSION = '6.5.1' +JAR_VERSION = "6.5.1" -JARS = javabridge.JARS + [os.path.realpath(os.path.join(_jars_dir, name + '.jar')) - for name in ['bioformats_package']] +JARS = javabridge.JARS + [ + os.path.realpath(os.path.join(_jars_dir, name + ".jar")) + for name in ["bioformats_package"] +] """List of directories, jar files, and zip files that should be added to the Java virtual machine's class path.""" # See http://www.loci.wisc.edu/software/bio-formats -READABLE_FORMATS = ('1sc', '2fl', 'acff', 'afi', 'afm', 'aim', 'al3d', 'ali', - 'am', 'amiramesh', 'apl', 'arf', 'avi', 'bif', 'bin', 'bip', - 'bmp', 'btf', 'c01', 'cfg', 'ch5', 'cif', 'cr2', 'crw', - 'cxd', 'czi', 'dat', 'dcm', 'dib', 'dicom', 'dm2', 'dm3', - 'dm4', 'dti', 'dv', 'eps', 'epsi', 'exp', 'fdf', 'fff', - 'ffr', 'fits', 'flex', 'fli', 'frm', 'gel', 'gif', 'grey', - 'h5', 'hdf', 'hdr', 'hed', 'his', 'htd', 'html', 'hx', 'i2i', - 'ics', 'ids', 'im3', 'img', 'ims', 'inr', 'ipl', 'ipm', 'ipw', - 'j2k', 'jp2', 'jpeg', 'jpf', 'jpg', 'jpk', 'jpx', 'klb', - 'l2d', 'labels', 'lei', 'lif', 'liff', 'lim', 'lms', 'lsm', - 'map', 'mdb', 'mea', 'mnc', 'mng', 'mod', 'mov', 'mrc', 'mrcs', - 'mrw', 'msr', 'mtb', 'mvd2', 'naf', 'nd', 'nd2', 'ndpi', 'ndpis', - 'nef', 'nhdr', 'nii', 'nii.gz', 'nrrd', 'obf', 'obsep', 'oib', - 'oif', 'oir', 'ome', 'ome.btf', 'ome.tf2', 'ome.tf8', 'ome.tif', - 'ome.tiff', 'ome.xml', 'par', 'pbm', 'pcoraw', 'pcx', 'pds', - 'pgm', 'pic', 'pict', 'png', 'pnl', 'ppm', 'pr3', 'ps', 'psd', - 'qptiff', 'r3d', 'raw', 'rcpnl', 'rec', 'res', 'scn', 'sdt', - 'seq', 'sif', 'sld', 'sm2', 'sm3', 'spc', 'spe', 'spi', 'st', - 'stk', 'stp', 'svs', 'sxm', 'tc.', 'tf2', 'tf8', 'tfr', 'tga', - 'tif', 'tiff', 'tnb', 'top', 'txt', 'v', 'vff', 'vms', 'vsi', - 'vws', 'wat', 'wlz', 'wpi', 'xdce', 'xml', 'xqd', 'xqf', 'xv', - 'xys', 'zfp', 'zfr', 'zvi') - -WRITABLE_FORMATS = ('avi', 'eps', 'epsi', 'ics', 'ids', 'jp2', 'jpeg', 'jpg', - 'mov', 'ome', 'ome.tiff', 'png', 'ps', 'tif', 'tiff') +READABLE_FORMATS = ( + "1sc", + "2fl", + "acff", + "afi", + "afm", + "aim", + "al3d", + "ali", + "am", + "amiramesh", + "apl", + "arf", + "avi", + "bif", + "bin", + "bip", + "bmp", + "btf", + "c01", + "cfg", + "ch5", + "cif", + "cr2", + "crw", + "cxd", + "czi", + "dat", + "dcm", + "dib", + "dicom", + "dm2", + "dm3", + "dm4", + "dti", + "dv", + "eps", + "epsi", + "exp", + "fdf", + "fff", + "ffr", + "fits", + "flex", + "fli", + "frm", + "gel", + "gif", + "grey", + "h5", + "hdf", + "hdr", + "hed", + "his", + "htd", + "html", + "hx", + "i2i", + "ics", + "ids", + "im3", + "img", + "ims", + "inr", + "ipl", + "ipm", + "ipw", + "j2k", + "jp2", + "jpeg", + "jpf", + "jpg", + "jpk", + "jpx", + "klb", + "l2d", + "labels", + "lei", + "lif", + "liff", + "lim", + "lms", + "lsm", + "map", + "mdb", + "mea", + "mnc", + "mng", + "mod", + "mov", + "mrc", + "mrcs", + "mrw", + "msr", + "mtb", + "mvd2", + "naf", + "nd", + "nd2", + "ndpi", + "ndpis", + "nef", + "nhdr", + "nii", + "nii.gz", + "nrrd", + "obf", + "obsep", + "oib", + "oif", + "oir", + "ome", + "ome.btf", + "ome.tf2", + "ome.tf8", + "ome.tif", + "ome.tiff", + "ome.xml", + "par", + "pbm", + "pcoraw", + "pcx", + "pds", + "pgm", + "pic", + "pict", + "png", + "pnl", + "ppm", + "pr3", + "ps", + "psd", + "qptiff", + "r3d", + "raw", + "rcpnl", + "rec", + "res", + "scn", + "sdt", + "seq", + "sif", + "sld", + "sm2", + "sm3", + "spc", + "spe", + "spi", + "st", + "stk", + "stp", + "svs", + "sxm", + "tc.", + "tf2", + "tf8", + "tfr", + "tga", + "tif", + "tiff", + "tnb", + "top", + "txt", + "v", + "vff", + "vms", + "vsi", + "vws", + "wat", + "wlz", + "wpi", + "xdce", + "xml", + "xqd", + "xqf", + "xv", + "xys", + "zfp", + "zfr", + "zvi", +) + +WRITABLE_FORMATS = ( + "avi", + "eps", + "epsi", + "ics", + "ids", + "jp2", + "jpeg", + "jpg", + "mov", + "ome", + "ome.tiff", + "png", + "ps", + "tif", + "tiff", +) OMETiffWriter = _formatwriter.make_ome_tiff_writer_class() ChannelSeparator = _formatreader.make_reader_wrapper_class( - "loci/formats/ChannelSeparator") + "loci/formats/ChannelSeparator" +) from .metadatatools import createOMEXMLMetadata as create_ome_xml_metadata from .metadatatools import wrap_imetadata_object from . import metadatatools as _metadatatools + PixelType = _metadatatools.make_pixel_type_class() get_metadata_options = _metadatatools.get_metadata_options @@ -84,6 +260,7 @@ # Metadata from .omexml import OMEXML + get_omexml_metadata = _formatreader.get_omexml_metadata # Writing images @@ -94,10 +271,20 @@ # Omero -from .formatreader import use_omero_credentials, set_omero_credentials, get_omero_credentials +from .formatreader import ( + use_omero_credentials, + set_omero_credentials, + get_omero_credentials, +) from .formatreader import set_omero_login_hook, omero_logout, has_omero_packages -from .formatreader import K_OMERO_SERVER, K_OMERO_PORT, K_OMERO_USER, K_OMERO_SESSION_ID,\ - K_OMERO_PASSWORD, K_OMERO_CONFIG_FILE +from .formatreader import ( + K_OMERO_SERVER, + K_OMERO_PORT, + K_OMERO_USER, + K_OMERO_SESSION_ID, + K_OMERO_PASSWORD, + K_OMERO_CONFIG_FILE, +) from . import omexml diff --git a/cellacdc/bioformats/formatreader.py b/cellacdc/bioformats/formatreader.py index a63d5b2c0..73ad66c10 100755 --- a/cellacdc/bioformats/formatreader.py +++ b/cellacdc/bioformats/formatreader.py @@ -5,7 +5,7 @@ # Copyright (c) 2009-2014 Broad Institute # All rights reserved. -'''formatreader.py - mechanism to wrap a bioformats ReaderWrapper and ImageReader +"""formatreader.py - mechanism to wrap a bioformats ReaderWrapper and ImageReader Example: import bioformats.formatreader as biordr @@ -20,13 +20,14 @@ my_red_image, my_green_image, my_blue_image = \ [cs.open_bytes(cs.getIndex(0,i,0)) for i in range(3)] -''' +""" from __future__ import absolute_import, unicode_literals __version__ = "$Revision$" import logging + logger = logging.getLogger(__name__) import errno import numpy as np @@ -40,6 +41,7 @@ else: from urllib import url2pathname from urllib2 import urlopen, urlparse, unquote + urlparse = urlparse.urlparse import shutil @@ -56,6 +58,7 @@ try: from omero_reader import OmeroReader, OMERO_IMPORTED from omero_reader.utils import omero_reader_enabled + OMERO_READER_IMPORTED = True except ImportError: pass @@ -65,162 +68,221 @@ K_OMERO_USER = "omero_user" K_OMERO_SESSION_ID = "omero_session_id" K_OMERO_CONFIG_FILE = "omero_config_file" -'''The cleartext password - only used if password is provided on command-line''' +"""The cleartext password - only used if password is provided on command-line""" K_OMERO_PASSWORD = "omero_password" + def make_format_tools_class(): - '''Get a wrapper for the loci/formats/FormatTools class + """Get a wrapper for the loci/formats/FormatTools class The FormatTools class has many of the constants needed by other classes as statics. - ''' + """ + class FormatTools(object): - '''A wrapper for loci.formats.FormatTools + """A wrapper for loci.formats.FormatTools See http://hudson.openmicroscopy.org.uk/job/LOCI/javadoc/loci/formats/FormatTools.html - ''' + """ + env = jutil.get_env() - klass = env.find_class('loci/formats/FormatTools') - CAN_GROUP = jutil.get_static_field(klass, 'CAN_GROUP','I') - CANNOT_GROUP = jutil.get_static_field(klass, 'CANNOT_GROUP','I') - DOUBLE = jutil.get_static_field(klass, 'DOUBLE','I') - FLOAT = jutil.get_static_field(klass, 'FLOAT', 'I') - INT16 = jutil.get_static_field(klass, 'INT16', 'I') - INT32 = jutil.get_static_field(klass, 'INT32', 'I') - INT8 = jutil.get_static_field(klass, 'INT8', 'I') - MUST_GROUP = jutil.get_static_field(klass, 'MUST_GROUP', 'I') - UINT16 = jutil.get_static_field(klass, 'UINT16', 'I') - UINT32 = jutil.get_static_field(klass, 'UINT32', 'I') - UINT8 = jutil.get_static_field(klass, 'UINT8', 'I') + klass = env.find_class("loci/formats/FormatTools") + CAN_GROUP = jutil.get_static_field(klass, "CAN_GROUP", "I") + CANNOT_GROUP = jutil.get_static_field(klass, "CANNOT_GROUP", "I") + DOUBLE = jutil.get_static_field(klass, "DOUBLE", "I") + FLOAT = jutil.get_static_field(klass, "FLOAT", "I") + INT16 = jutil.get_static_field(klass, "INT16", "I") + INT32 = jutil.get_static_field(klass, "INT32", "I") + INT8 = jutil.get_static_field(klass, "INT8", "I") + MUST_GROUP = jutil.get_static_field(klass, "MUST_GROUP", "I") + UINT16 = jutil.get_static_field(klass, "UINT16", "I") + UINT32 = jutil.get_static_field(klass, "UINT32", "I") + UINT8 = jutil.get_static_field(klass, "UINT8", "I") @classmethod def getPixelTypeString(cls, pixel_type): - return jutil.static_call('loci/formats/FormatTools', 'getPixelTypeString', '(I)Ljava/lang/String;', pixel_type) + return jutil.static_call( + "loci/formats/FormatTools", + "getPixelTypeString", + "(I)Ljava/lang/String;", + pixel_type, + ) return FormatTools + def make_iformat_reader_class(): - '''Bind a Java class that implements IFormatReader to a Python class + """Bind a Java class that implements IFormatReader to a Python class Returns a class that implements IFormatReader through calls to the implemented class passed in. The returned class can be subclassed to provide additional bindings. - ''' + """ + class IFormatReader(object): - '''A wrapper for loci.formats.IFormatReader + """A wrapper for loci.formats.IFormatReader See http://hudson.openmicroscopy.org.uk/job/LOCI/javadoc/loci/formats/ImageReader.html - ''' - close = jutil.make_method('close','()V', - 'Close the currently open file and free memory') - getDimensionOrder = jutil.make_method('getDimensionOrder', - '()Ljava/lang/String;', - 'Return the dimension order as a five-character string, e.g. "XYCZT"') - getGlobalMetadata = jutil.make_method('getGlobalMetadata', - '()Ljava/util/Hashtable;', - 'Obtains the hashtable containing the global metadata field/value pairs') + """ + + close = jutil.make_method( + "close", "()V", "Close the currently open file and free memory" + ) + getDimensionOrder = jutil.make_method( + "getDimensionOrder", + "()Ljava/lang/String;", + 'Return the dimension order as a five-character string, e.g. "XYCZT"', + ) + getGlobalMetadata = jutil.make_method( + "getGlobalMetadata", + "()Ljava/util/Hashtable;", + "Obtains the hashtable containing the global metadata field/value pairs", + ) getMetadata = getGlobalMetadata - getMetadataValue = jutil.make_method('getMetadataValue', - '(Ljava/lang/String;)' - 'Ljava/lang/Object;', - 'Look up a specific metadata value from the store') - getSeriesMetadata = jutil.make_method('getSeriesMetadata', - '()Ljava/util/Hashtable;', - 'Obtains the hashtable contaning the series metadata field/value pairs') - getSeriesCount = jutil.make_method('getSeriesCount', - '()I', - 'Return the # of image series in the file') - getSeries = jutil.make_method('getSeries', '()I', - 'Return the currently selected image series') - getImageCount = jutil.make_method('getImageCount', - '()I','Determines the number of images in the current file') - getIndex = jutil.make_method('getIndex', '(III)I', - 'Get the plane index given z, c, t') - getRGBChannelCount = jutil.make_method('getRGBChannelCount', - '()I','Gets the number of channels per RGB image (if not RGB, this returns 1') - getSizeC = jutil.make_method('getSizeC', '()I', - 'Get the number of color planes') - getSizeT = jutil.make_method('getSizeT', '()I', - 'Get the number of frames in the image') - getSizeX = jutil.make_method('getSizeX', '()I', - 'Get the image width') - getSizeY = jutil.make_method('getSizeY', '()I', - 'Get the image height') - getSizeZ = jutil.make_method('getSizeZ', '()I', - 'Get the image depth') - getPixelType = jutil.make_method('getPixelType', '()I', - 'Get the pixel type: see FormatTools for types') - isLittleEndian = jutil.make_method('isLittleEndian', - '()Z','Return True if the data is in little endian order') - isRGB = jutil.make_method('isRGB', '()Z', - 'Return True if images in the file are RGB') - isInterleaved = jutil.make_method('isInterleaved', '()Z', - 'Return True if image colors are interleaved within a plane') - isIndexed = jutil.make_method('isIndexed', '()Z', - 'Return True if the raw data is indexes in a lookup table') - openBytes = jutil.make_method('openBytes','(I)[B', - 'Get the specified image plane as a byte array') - openBytesXYWH = jutil.make_method('openBytes','(IIIII)[B', - '''Get the specified image plane as a byte array + getMetadataValue = jutil.make_method( + "getMetadataValue", + "(Ljava/lang/String;)Ljava/lang/Object;", + "Look up a specific metadata value from the store", + ) + getSeriesMetadata = jutil.make_method( + "getSeriesMetadata", + "()Ljava/util/Hashtable;", + "Obtains the hashtable contaning the series metadata field/value pairs", + ) + getSeriesCount = jutil.make_method( + "getSeriesCount", "()I", "Return the # of image series in the file" + ) + getSeries = jutil.make_method( + "getSeries", "()I", "Return the currently selected image series" + ) + getImageCount = jutil.make_method( + "getImageCount", + "()I", + "Determines the number of images in the current file", + ) + getIndex = jutil.make_method( + "getIndex", "(III)I", "Get the plane index given z, c, t" + ) + getRGBChannelCount = jutil.make_method( + "getRGBChannelCount", + "()I", + "Gets the number of channels per RGB image (if not RGB, this returns 1", + ) + getSizeC = jutil.make_method( + "getSizeC", "()I", "Get the number of color planes" + ) + getSizeT = jutil.make_method( + "getSizeT", "()I", "Get the number of frames in the image" + ) + getSizeX = jutil.make_method("getSizeX", "()I", "Get the image width") + getSizeY = jutil.make_method("getSizeY", "()I", "Get the image height") + getSizeZ = jutil.make_method("getSizeZ", "()I", "Get the image depth") + getPixelType = jutil.make_method( + "getPixelType", "()I", "Get the pixel type: see FormatTools for types" + ) + isLittleEndian = jutil.make_method( + "isLittleEndian", "()Z", "Return True if the data is in little endian order" + ) + isRGB = jutil.make_method( + "isRGB", "()Z", "Return True if images in the file are RGB" + ) + isInterleaved = jutil.make_method( + "isInterleaved", + "()Z", + "Return True if image colors are interleaved within a plane", + ) + isIndexed = jutil.make_method( + "isIndexed", + "()Z", + "Return True if the raw data is indexes in a lookup table", + ) + openBytes = jutil.make_method( + "openBytes", "(I)[B", "Get the specified image plane as a byte array" + ) + openBytesXYWH = jutil.make_method( + "openBytes", + "(IIIII)[B", + """Get the specified image plane as a byte array (corresponds to openBytes(int no, int x, int y, int w, int h)) no - image plane number x,y - offset into image - w,h - dimensions of image to return''') - setSeries = jutil.make_method('setSeries','(I)V','Set the currently selected image series') - setGroupFiles = jutil.make_method('setGroupFiles', '(Z)V', - 'Force reader to group or not to group files in a multi-file set') - setMetadataStore = jutil.make_method('setMetadataStore', - '(Lloci/formats/meta/MetadataStore;)V', - 'Sets the default metadata store for this reader.') - setMetadataOptions = jutil.make_method('setMetadataOptions', - '(Lloci/formats/in/MetadataOptions;)V', - 'Sets the metadata options used when reading metadata') + w,h - dimensions of image to return""", + ) + setSeries = jutil.make_method( + "setSeries", "(I)V", "Set the currently selected image series" + ) + setGroupFiles = jutil.make_method( + "setGroupFiles", + "(Z)V", + "Force reader to group or not to group files in a multi-file set", + ) + setMetadataStore = jutil.make_method( + "setMetadataStore", + "(Lloci/formats/meta/MetadataStore;)V", + "Sets the default metadata store for this reader.", + ) + setMetadataOptions = jutil.make_method( + "setMetadataOptions", + "(Lloci/formats/in/MetadataOptions;)V", + "Sets the metadata options used when reading metadata", + ) isThisTypeS = jutil.make_method( - 'isThisType', - '(Ljava/lang/String;)Z', - 'Return true if the filename might be handled by this reader') + "isThisType", + "(Ljava/lang/String;)Z", + "Return true if the filename might be handled by this reader", + ) isThisTypeSZ = jutil.make_method( - 'isThisType', - '(Ljava/lang/String;Z)Z', - '''Return true if the named file is handled by this reader. + "isThisType", + "(Ljava/lang/String;Z)Z", + """Return true if the named file is handled by this reader. filename - name of file allowOpen - True if the reader is allowed to open files when making its determination - ''') + """, + ) isThisTypeStream = jutil.make_method( - 'isThisType', - '(Lloci/common/RandomAccessInputStream;)Z', - '''Return true if the stream might be parseable by this reader. + "isThisType", + "(Lloci/common/RandomAccessInputStream;)Z", + """Return true if the stream might be parseable by this reader. stream - the RandomAccessInputStream to be used to read the file contents Note that both isThisTypeS and isThisTypeStream must return true - for the type to truly be handled.''') - def setId(self, path): - '''Set the name of the file''' - jutil.call(self.o, 'setId', - '(Ljava/lang/String;)V', - path) + for the type to truly be handled.""", + ) - getMetadataStore = jutil.make_method('getMetadataStore', '()Lloci/formats/meta/MetadataStore;', - 'Retrieves the current metadata store for this reader.') + def setId(self, path): + """Set the name of the file""" + jutil.call(self.o, "setId", "(Ljava/lang/String;)V", path) + + getMetadataStore = jutil.make_method( + "getMetadataStore", + "()Lloci/formats/meta/MetadataStore;", + "Retrieves the current metadata store for this reader.", + ) get8BitLookupTable = jutil.make_method( - 'get8BitLookupTable', - '()[[B', 'Get a lookup table for 8-bit indexed images') + "get8BitLookupTable", "()[[B", "Get a lookup table for 8-bit indexed images" + ) get16BitLookupTable = jutil.make_method( - 'get16BitLookupTable', - '()[[S', 'Get a lookup table for 16-bit indexed images') + "get16BitLookupTable", + "()[[S", + "Get a lookup table for 16-bit indexed images", + ) + def get_class_name(self): - return jutil.call(jutil.call(self.o, 'getClass', '()Ljava/lang/Class;'), - 'getName', '()Ljava/lang/String;') + return jutil.call( + jutil.call(self.o, "getClass", "()Ljava/lang/Class;"), + "getName", + "()Ljava/lang/String;", + ) @property def suffixNecessary(self): - if self.get_class_name() == 'loci.formats.in.JPKReader': - return True; + if self.get_class_name() == "loci.formats.in.JPKReader": + return True env = jutil.get_env() klass = env.get_object_class(self.o) field_id = env.get_field_id(klass, "suffixNecessary", "Z") @@ -230,8 +292,8 @@ def suffixNecessary(self): @property def suffixSufficient(self): - if self.get_class_name() == 'loci.formats.in.JPKReader': - return True; + if self.get_class_name() == "loci.formats.in.JPKReader": + return True env = jutil.get_env() klass = env.get_object_class(self.o) field_id = env.get_field_id(klass, "suffixSufficient", "Z") @@ -239,94 +301,109 @@ def suffixSufficient(self): return None return env.get_boolean_field(self.o, field_id) - return IFormatReader + def get_class_list(): - '''Return a wrapped instance of loci.formats.ClassList''' + """Return a wrapped instance of loci.formats.ClassList""" + # # This uses the reader.txt file from inside the loci_tools.jar # class ClassList(object): remove_class = jutil.make_method( - 'removeClass', '(Ljava/lang/Class;)V', - 'Remove the given class from the class list') + "removeClass", + "(Ljava/lang/Class;)V", + "Remove the given class from the class list", + ) add_class = jutil.make_method( - 'addClass', '(Ljava/lang/Class;)V', - 'Add the given class to the back of the class list') + "addClass", + "(Ljava/lang/Class;)V", + "Add the given class to the back of the class list", + ) get_classes = jutil.make_method( - 'getClasses', '()[Ljava/lang/Class;', - 'Get the classes in the list as an array') + "getClasses", + "()[Ljava/lang/Class;", + "Get the classes in the list as an array", + ) def __init__(self): env = jutil.get_env() - class_name = 'loci/formats/ImageReader' + class_name = "loci/formats/ImageReader" klass = env.find_class(class_name) - base_klass = env.find_class('loci/formats/IFormatReader') - self.o = jutil.make_instance("loci/formats/ClassList", - "(Ljava/lang/String;" - "Ljava/lang/Class;" # base - "Ljava/lang/Class;)V", # location in jar - "readers.txt", base_klass, klass) + base_klass = env.find_class("loci/formats/IFormatReader") + self.o = jutil.make_instance( + "loci/formats/ClassList", + "(Ljava/lang/String;" + "Ljava/lang/Class;" # base + "Ljava/lang/Class;)V", # location in jar + "readers.txt", + base_klass, + klass, + ) problem_classes = [ # BDReader will read all .tif files in an experiment if it's # called to load a .tif. # - 'loci.formats.in.BDReader', + "loci.formats.in.BDReader", # # MRCReader will read .stk files which should be read # by MetamorphReader # - 'loci.formats.in.MRCReader' - ] + "loci.formats.in.MRCReader", + ] for problem_class in problem_classes: # Move to back klass = jutil.class_for_name(problem_class) self.remove_class(klass) self.add_class(klass) + return ClassList() def make_image_reader_class(): - '''Return an image reader class for the given Java environment''' + """Return an image reader class for the given Java environment""" env = jutil.get_env() - class_name = 'loci/formats/ImageReader' + class_name = "loci/formats/ImageReader" klass = env.find_class(class_name) - base_klass = env.find_class('loci/formats/IFormatReader') + base_klass = env.find_class("loci/formats/IFormatReader") IFormatReader = make_iformat_reader_class() class_list = get_class_list() class ImageReader(IFormatReader): - new_fn = jutil.make_new(class_name, '(Lloci/formats/ClassList;)V') + new_fn = jutil.make_new(class_name, "(Lloci/formats/ClassList;)V") + def __init__(self): self.new_fn(class_list.o) - getFormat = jutil.make_method('getFormat', - '()Ljava/lang/String;', - 'Get a string describing the format of this file') - getReader = jutil.make_method('getReader', - '()Lloci/formats/IFormatReader;') + + getFormat = jutil.make_method( + "getFormat", + "()Ljava/lang/String;", + "Get a string describing the format of this file", + ) + getReader = jutil.make_method("getReader", "()Lloci/formats/IFormatReader;") + def allowOpenToCheckType(self, allow): - '''Allow the "isThisType" function to open files + """Allow the "isThisType" function to open files For the cluster, you want to tell potential file formats not to open the image file to test if it's their format. - ''' + """ if not hasattr(self, "allowOpenToCheckType_method"): self.allowOpenToCheckType_method = None class_wrapper = jutil.get_class_wrapper(self.o) methods = class_wrapper.getMethods() for method in jutil.get_env().get_object_array_elements(methods): m = jutil.get_method_wrapper(method) - if m.getName() in ('allowOpenToCheckType', 'setAllowOpenFiles'): + if m.getName() in ("allowOpenToCheckType", "setAllowOpenFiles"): self.allowOpenToCheckType_method = m if self.allowOpenToCheckType_method is not None: - object_class = env.find_class('java/lang/Object') + object_class = env.find_class("java/lang/Object") jexception = jutil.get_env().exception_occurred() if jexception is not None: raise jutil.JavaException(jexception) - boolean_value = jutil.make_instance('java/lang/Boolean', - '(Z)V', allow) + boolean_value = jutil.make_instance("java/lang/Boolean", "(Z)V", allow) args = jutil.get_env().make_object_array(1, object_class) jexception = jutil.get_env().exception_occurred() if jexception is not None: @@ -336,51 +413,68 @@ def allowOpenToCheckType(self, allow): if jexception is not None: raise jutil.JavaException(jexception) self.allowOpenToCheckType_method.invoke(self.o, args) + return ImageReader def make_reader_wrapper_class(class_name): - '''Make an ImageReader wrapper class + """Make an ImageReader wrapper class class_name - the name of the wrapper class, for instance, "loci/formats/ChannelSeparator" You can instantiate an instance of the wrapper class like this: rdr = ChannelSeparator(ImageReader()) - ''' + """ IFormatReader = make_iformat_reader_class() + class ReaderWrapper(IFormatReader): - __doc__ = '''A wrapper for %s + __doc__ = ( + """A wrapper for %s See http://hudson.openmicroscopy.org.uk/job/LOCI/javadoc/loci/formats/ImageReader.html - '''%class_name - new_fn = jutil.make_new(class_name, '(Lloci/formats/IFormatReader;)V') + """ + % class_name + ) + new_fn = jutil.make_new(class_name, "(Lloci/formats/IFormatReader;)V") + def __init__(self, rdr): self.new_fn(rdr) - setId = jutil.make_method('setId', '(Ljava/lang/String;)V', - 'Set the name of the data file') + setId = jutil.make_method( + "setId", "(Ljava/lang/String;)V", "Set the name of the data file" + ) + return ReaderWrapper + __has_omero_jars = None + + def has_omero_packages(): - '''Return True if we can find the packages needed for OMERO + """Return True if we can find the packages needed for OMERO In order to run OMERO, you'll need the OMERO client and ICE on your class path (not supplied with python-bioformats and specific to your server's version) - ''' + """ global __has_omero_jars if __has_omero_jars is None: class_loader = jutil.static_call( - "java/lang/ClassLoader", "getSystemClassLoader", - "()Ljava/lang/ClassLoader;") - for klass in ("Glacier2.PermissionDeniedException", - "loci.ome.io.OmeroReader", "omero.client"): + "java/lang/ClassLoader", "getSystemClassLoader", "()Ljava/lang/ClassLoader;" + ) + for klass in ( + "Glacier2.PermissionDeniedException", + "loci.ome.io.OmeroReader", + "omero.client", + ): try: jutil.call( - class_loader, "loadClass", - "(Ljava/lang/String;)Ljava/lang/Class;", klass) + class_loader, + "loadClass", + "(Ljava/lang/String;)Ljava/lang/Class;", + klass, + ) except: __has_omero_jars = False break @@ -388,6 +482,7 @@ def has_omero_packages(): __has_omero_jars = True return __has_omero_jars + __omero_server = None __omero_username = None __omero_session_id = None @@ -398,8 +493,9 @@ def has_omero_packages(): # __omero_password = None + def set_omero_credentials(omero_server, omero_port, omero_username, omero_password): - '''Set the credentials to be used to connect to the Omero server + """Set the credentials to be used to connect to the Omero server :param omero_server: DNS name of the server @@ -412,7 +508,7 @@ def set_omero_credentials(omero_server, omero_port, omero_username, omero_passwo The session ID is valid after this function is called. An exception is thrown if the login fails. :func:`bioformats.omero_logout()` can be called to log out. - ''' + """ global __omero_server global __omero_username global __omero_session_id @@ -425,60 +521,74 @@ def set_omero_credentials(omero_server, omero_port, omero_username, omero_passwo var serverFactory = client.createSession(user, password); client.getSessionId(); """ - __omero_session_id = jutil.run_script(script, dict( - server = __omero_server, - port = __omero_port, - user = __omero_username, - password = omero_password)) + __omero_session_id = jutil.run_script( + script, + dict( + server=__omero_server, + port=__omero_port, + user=__omero_username, + password=omero_password, + ), + ) return __omero_session_id + def get_omero_credentials(): - '''Return a pickleable dictionary representing the Omero credentials. + """Return a pickleable dictionary representing the Omero credentials. Call :func:`bioformats.use_omero_credentials` in some other process to use this. - ''' + """ if __omero_session_id is None: omero_login() - return dict(omero_server = __omero_server, - omero_port = __omero_port, - omero_user = __omero_username, - omero_session_id = __omero_session_id) + return dict( + omero_server=__omero_server, + omero_port=__omero_port, + omero_user=__omero_username, + omero_session_id=__omero_session_id, + ) + def omero_login(): if __omero_config_file is not None and os.path.isfile(__omero_config_file): env = jutil.get_env() config = env.make_object_array(1, env.find_class("java/lang/String")) env.set_object_array_element( - config, 0, env.new_string("--Ice.Config=%s" % __omero_config_file)) + config, 0, env.new_string("--Ice.Config=%s" % __omero_config_file) + ) script = """ var client = Packages.omero.client(config); client.createSession(); client.getSessionId(); """ __omero_session_id = jutil.run_script(script, dict(config=config)) - elif all([x is not None for x in - (__omero_server, __omero_port, __omero_username, __omero_password)]): - set_omero_credentials(__omero_server, __omero_port, __omero_username, - __omero_password) + elif all( + [ + x is not None + for x in (__omero_server, __omero_port, __omero_username, __omero_password) + ] + ): + set_omero_credentials( + __omero_server, __omero_port, __omero_username, __omero_password + ) else: __omero_login_fn() return __omero_session_id -def omero_logout(): - '''Abandon any current Omero session. - ''' +def omero_logout(): + """Abandon any current Omero session.""" global __omero_session_id __omero_session_id = None + def use_omero_credentials(credentials): - '''Use the session ID from an existing login as credentials. + """Use the session ID from an existing login as credentials. :param credentials: credentials from get_omero_credentials. - ''' + """ global __omero_server global __omero_username global __omero_session_id @@ -492,18 +602,18 @@ def use_omero_credentials(credentials): __omero_config_file = credentials.get(K_OMERO_CONFIG_FILE, None) __omero_password = credentials.get(K_OMERO_PASSWORD, None) + __omero_login_fn = None -def set_omero_login_hook(fn): - '''Set the function to be called when a login to Omero is needed. - ''' + +def set_omero_login_hook(fn): + """Set the function to be called when a login to Omero is needed.""" global __omero_login_fn __omero_login_fn = fn -def get_omero_reader(): - '''Return an ``loci.ome.io.OMEROReader`` instance, wrapped as a FormatReader. - ''' +def get_omero_reader(): + """Return an ``loci.ome.io.OMEROReader`` instance, wrapped as a FormatReader.""" script = """ var rdr = new Packages.loci.ome.io.OmeroReader(); rdr.setServer(server); @@ -515,31 +625,41 @@ def get_omero_reader(): if __omero_session_id is None: omero_login() - jrdr = jutil.run_script(script, dict( - server = __omero_server, - port = __omero_port, - username = __omero_username, - sessionID = __omero_session_id)) + jrdr = jutil.run_script( + script, + dict( + server=__omero_server, + port=__omero_port, + username=__omero_username, + sessionID=__omero_session_id, + ), + ) rdr = make_iformat_reader_class()() rdr.o = jrdr return rdr -def load_using_bioformats_url(url, c=None, z=0, t=0, series=None, index=None, - rescale = True, - wants_max_intensity = False, - channel_names = None): - '''Load a file from Bio-formats via a URL - - ''' +def load_using_bioformats_url( + url, + c=None, + z=0, + t=0, + series=None, + index=None, + rescale=True, + wants_max_intensity=False, + channel_names=None, +): + """Load a file from Bio-formats via a URL""" with ImageReader(url=url) as rdr: - return rdr.read(c, z, t, series, index, rescale, wants_max_intensity, - channel_names) + return rdr.read( + c, z, t, series, index, rescale, wants_max_intensity, channel_names + ) class ImageReader(object): - '''Find the appropriate reader for a file. + """Find the appropriate reader for a file. This class is meant to be harnessed to a scope like this: @@ -549,7 +669,7 @@ class ImageReader(object): It uses `__enter__` and `__exit__` to manage the random access stream that can be used to cache the file contents in memory. - ''' + """ def __init__(self, path=None, url=None, perform_init=True): self.stream = None @@ -559,7 +679,7 @@ def __init__(self, path=None, url=None, perform_init=True): if url is not None: url = str(url) if url.lower().startswith(file_scheme): - url = url2pathname(url[len(file_scheme):]) + url = url2pathname(url[len(file_scheme) :]) path = url self.path = path @@ -578,12 +698,11 @@ def __init__(self, path=None, url=None, perform_init=True): return except jutil.JavaException as e: je = e.throwable + if jutil.is_instance_of(je, "loci/formats/FormatException"): + je = jutil.call(je, "getCause", "()Ljava/lang/Throwable;") if jutil.is_instance_of( - je, "loci/formats/FormatException"): - je = jutil.call(je, "getCause", - "()Ljava/lang/Throwable;") - if jutil.is_instance_of( - je, "Glacier2/PermissionDeniedException"): + je, "Glacier2/PermissionDeniedException" + ): omero_logout() omero_login() else: @@ -591,13 +710,18 @@ def __init__(self, path=None, url=None, perform_init=True): for line in traceback.format_exc().split("\n"): logger.warn(line) if jutil.is_instance_of( - je, "java/io/FileNotFoundException"): + je, "java/io/FileNotFoundException" + ): raise IOError( errno.ENOENT, - "The file, \"%s\", does not exist." % path, - path) + 'The file, "%s", does not exist.' % path, + path, + ) e2 = IOError( - errno.EINVAL, "Could not load the file as an image (see log for details)", path.encode('utf-8')) + errno.EINVAL, + "Could not load the file as an image (see log for details)", + path.encode("utf-8"), + ) raise e2 else: # @@ -610,14 +734,11 @@ def __init__(self, path=None, url=None, perform_init=True): filename = os.path.split(path)[1] if not os.path.isfile(self.path): - raise IOError( - errno.ENOENT, - "The file, \"%s\", does not exist." % path, - path) + raise IOError(errno.ENOENT, 'The file, "%s", does not exist.' % path, path) - self.stream = jutil.make_instance('loci/common/RandomAccessInputStream', - '(Ljava/lang/String;)V', - self.path) + self.stream = jutil.make_instance( + "loci/common/RandomAccessInputStream", "(Ljava/lang/String;)V", self.path + ) self.rdr = None class_list = get_class_list() @@ -657,9 +778,10 @@ def __init__(self, path=None, url=None, perform_init=True): rdr; """ IFormatReader = make_iformat_reader_class() - jrdr = jutil.run_script(find_rdr_script, dict(class_list = class_list, - filename = filename, - stream = self.stream)) + jrdr = jutil.run_script( + find_rdr_script, + dict(class_list=class_list, filename=filename, stream=self.stream), + ) if jrdr is None: raise ValueError("Could not find a Bio-Formats reader for %s", self.path) self.rdr = IFormatReader() @@ -669,24 +791,26 @@ def __init__(self, path=None, url=None, perform_init=True): def download(self, url): scheme = urlparse(url)[0] - ext = url[url.rfind("."):] + ext = url[url.rfind(".") :] urlpath = urlparse(url)[2] filename = unquote(urlpath.split("/")[-1]) self.using_temp_file = True - if scheme == 's3': - client = boto3.client('s3') - bucket_name, key = re.compile('s3://([\w\d\-\.]+)/(.*)').search(url).groups() + if scheme == "s3": + client = boto3.client("s3") + bucket_name, key = ( + re.compile("s3://([\w\d\-\.]+)/(.*)").search(url).groups() + ) url = client.generate_presigned_url( - 'get_object', - Params={'Bucket': bucket_name, 'Key': key.replace("+", " ")} + "get_object", + Params={"Bucket": bucket_name, "Key": key.replace("+", " ")}, ) cellacdc = urlopen(url) dest_fd, self.path = tempfile.mkstemp(suffix=ext) try: - with os.fdopen(dest_fd, 'wb') as dest: + with os.fdopen(dest_fd, "wb") as dest: shutil.copyfileobj(cellacdc, dest) except: os.remove(self.path) @@ -707,7 +831,7 @@ def close(self): del self.rdr.o del self.rdr if hasattr(self, "stream") and self.stream is not None: - jutil.call(self.stream, 'close', '()V') + jutil.call(self.stream, "close", "()V") del self.stream if self.using_temp_file: os.remove(self.path) @@ -715,7 +839,7 @@ def close(self): # # Run the Java garbage collector here. # - jutil.static_call("java/lang/System", "gc","()V") + jutil.static_call("java/lang/System", "gc", "()V") def init_reader(self): mdoptions = metadatatools.get_metadata_options(metadatatools.ALL) @@ -731,29 +855,38 @@ def init_reader(self): logger.warn(line) je = e.throwable if has_omero_packages() and jutil.is_instance_of( - je, "Glacier2/PermissionDeniedException"): + je, "Glacier2/PermissionDeniedException" + ): # Handle at a higher level raise - if jutil.is_instance_of( - je, "loci/formats/FormatException"): - je = jutil.call(je, "getCause", - "()Ljava/lang/Throwable;") - if jutil.is_instance_of( - je, "java/io/FileNotFoundException"): + if jutil.is_instance_of(je, "loci/formats/FormatException"): + je = jutil.call(je, "getCause", "()Ljava/lang/Throwable;") + if jutil.is_instance_of(je, "java/io/FileNotFoundException"): raise IOError( errno.ENOENT, - "The file, \"%s\", does not exist." % self.path, - self.path) + 'The file, "%s", does not exist.' % self.path, + self.path, + ) e2 = IOError( - errno.EINVAL, "Could not load the file as an image (see log for details)", - self.path.encode('utf-8')) + errno.EINVAL, + "Could not load the file as an image (see log for details)", + self.path.encode("utf-8"), + ) raise e2 - - def read(self, c = None, z = 0, t = 0, series = None, index = None, - rescale = True, wants_max_intensity = False, channel_names = None, - XYWH=None): - '''Read a single plane from the image reader file. + def read( + self, + c=None, + z=0, + t=0, + series=None, + index=None, + rescale=True, + wants_max_intensity=False, + channel_names=None, + XYWH=None, + ): + """Read a single plane from the image reader file. :param c: read from this channel. `None` = read color image if multichannel or interleaved RGB. :param z: z-stack index @@ -766,17 +899,18 @@ def read(self, c = None, z = 0, t = 0, series = None, index = None, return a tuple of image and max intensity :param channel_names: provide the channel names for the OME metadata :param XYWH: a (x, y, w, h) tuple - ''' + """ FormatTools = make_format_tools_class() - ChannelSeparator = make_reader_wrapper_class( - "loci/formats/ChannelSeparator") + ChannelSeparator = make_reader_wrapper_class("loci/formats/ChannelSeparator") env = jutil.get_env() if series is not None: self.rdr.setSeries(series) if XYWH is not None: assert isinstance(XYWH, tuple) and len(XYWH) == 4, "Invalid XYWH tuple" - openBytes_func = lambda x: self.rdr.openBytesXYWH(x, XYWH[0], XYWH[1], XYWH[2], XYWH[3]) + openBytes_func = lambda x: self.rdr.openBytesXYWH( + x, XYWH[0], XYWH[1], XYWH[2], XYWH[3] + ) width, height = XYWH[2], XYWH[3] else: openBytes_func = self.rdr.openBytes @@ -791,32 +925,34 @@ def read(self, c = None, z = 0, t = 0, series = None, index = None, dtype = np.uint8 scale = 255 elif pixel_type == FormatTools.UINT16: - dtype = 'u2' + dtype = "u2" scale = 65535 elif pixel_type == FormatTools.INT16: - dtype = 'i2' + dtype = "i2" scale = 65535 elif pixel_type == FormatTools.UINT32: - dtype = 'u4' + dtype = "u4" scale = 2147483647 elif pixel_type == FormatTools.INT32: - dtype = 'i4' - scale = 2147483647-1 + dtype = "i4" + scale = 2147483647 - 1 elif pixel_type == FormatTools.FLOAT: - dtype = 'f4' + dtype = "f4" scale = 1 elif pixel_type == FormatTools.DOUBLE: - dtype = 'f8' + dtype = "f8" scale = 1 - max_sample_value = self.rdr.getMetadataValue('MaxSampleValue') + max_sample_value = self.rdr.getMetadataValue("MaxSampleValue") if max_sample_value is not None: try: - scale = jutil.call(max_sample_value, 'intValue', '()I') + scale = jutil.call(max_sample_value, "intValue", "()I") except: - logger.warning("WARNING: failed to get MaxSampleValue for image. Intensities may be improperly scaled.") + logger.warning( + "WARNING: failed to get MaxSampleValue for image. Intensities may be improperly scaled." + ) if index is not None: image = np.frombuffer(openBytes_func(index), dtype) - if len(image) / height / width in (3,4): + if len(image) / height / width in (3, 4): n_channels = int(len(image) / height / width) if self.rdr.isInterleaved(): image.shape = (height, width, n_channels) @@ -826,13 +962,13 @@ def read(self, c = None, z = 0, t = 0, series = None, index = None, else: image.shape = (height, width) elif self.rdr.isRGB() and self.rdr.isInterleaved(): - index = self.rdr.getIndex(z,0,t) + index = self.rdr.getIndex(z, 0, t) image = np.frombuffer(openBytes_func(index), dtype) image.shape = (height, width, self.rdr.getSizeC()) if image.shape[2] > 3: image = image[:, :, :3] elif c is not None and self.rdr.getRGBChannelCount() == 1: - index = self.rdr.getIndex(z,c,t) + index = self.rdr.getIndex(z, c, t) image = np.frombuffer(openBytes_func(index), dtype) image.shape = (height, width) elif self.rdr.getRGBChannelCount() > 1: @@ -840,10 +976,17 @@ def read(self, c = None, z = 0, t = 0, series = None, index = None, rdr = ChannelSeparator(self.rdr) planes = [ np.frombuffer( - (rdr.openBytes(rdr.getIndex(z,i,t)) if XYWH is None else - rdr.openBytesXYWH(rdr.getIndex(z,i,t), XYWH[0], XYWH[1], XYWH[2], XYWH[3])), - dtype - ) for i in range(n_planes)] + ( + rdr.openBytes(rdr.getIndex(z, i, t)) + if XYWH is None + else rdr.openBytesXYWH( + rdr.getIndex(z, i, t), XYWH[0], XYWH[1], XYWH[2], XYWH[3] + ) + ), + dtype, + ) + for i in range(n_planes) + ] if len(planes) > 3: planes = planes[:3] @@ -852,12 +995,13 @@ def read(self, c = None, z = 0, t = 0, series = None, index = None, # see issue #775 planes.append(np.zeros(planes[0].shape, planes[0].dtype)) image = np.dstack(planes) - image.shape=(height, width, 3) + image.shape = (height, width, 3) del rdr elif self.rdr.getSizeC() > 1: images = [ - np.frombuffer(openBytes_func(self.rdr.getIndex(z,i,t)), dtype) - for i in range(self.rdr.getSizeC())] + np.frombuffer(openBytes_func(self.rdr.getIndex(z, i, t)), dtype) + for i in range(self.rdr.getSizeC()) + ] image = np.dstack(images) image.shape = (height, width, self.rdr.getSizeC()) if not channel_names is None: @@ -874,30 +1018,35 @@ def read(self, c = None, z = 0, t = 0, series = None, index = None, # But sometimes the table is the identity table and just generates # a monochrome RGB image # - index = self.rdr.getIndex(z,0,t) - image = np.frombuffer(openBytes_func(index),dtype) + index = self.rdr.getIndex(z, 0, t) + image = np.frombuffer(openBytes_func(index), dtype) if pixel_type in (FormatTools.INT16, FormatTools.UINT16): lut = self.rdr.get16BitLookupTable() if lut is not None: lut = np.array( - [env.get_short_array_elements(d) - for d in env.get_object_array_elements(lut)])\ - .transpose() + [ + env.get_short_array_elements(d) + for d in env.get_object_array_elements(lut) + ] + ).transpose() else: lut = self.rdr.get8BitLookupTable() if lut is not None: lut = np.array( - [env.get_byte_array_elements(d) - for d in env.get_object_array_elements(lut)])\ - .transpose() + [ + env.get_byte_array_elements(d) + for d in env.get_object_array_elements(lut) + ] + ).transpose() image.shape = (height, width) - if (lut is not None) \ - and not np.all(lut == np.arange(lut.shape[0])[:, np.newaxis]): + if (lut is not None) and not np.all( + lut == np.arange(lut.shape[0])[:, np.newaxis] + ): image = lut[image, :] else: - index = self.rdr.getIndex(z,0,t) - image = np.frombuffer(openBytes_func(index),dtype) - image.shape = (height,width) + index = self.rdr.getIndex(z, 0, t) + image = np.frombuffer(openBytes_func(index), dtype) + image.shape = (height, width) if rescale: image = image.astype(np.float32) / float(scale) @@ -905,6 +1054,7 @@ def read(self, c = None, z = 0, t = 0, series = None, index = None, return image, scale return image + ################### # # A cache mechanism for image readers @@ -924,14 +1074,15 @@ def read(self, c = None, z = 0, t = 0, series = None, index = None, # The image reader cache associates path/url with a reader __image_reader_cache = {} + def get_image_reader(key, path=None, url=None): - '''Make or find an image reader appropriate for the given path + """Make or find an image reader appropriate for the given path path - pathname to the reader on disk. key - use this key to keep only a single cache member associated with that key open at a time. - ''' + """ logger.debug("Getting image reader for: %s, %s, %s" % (key, path, url)) if key in __image_reader_key_cache: old_path, old_url = __image_reader_key_cache[key] @@ -946,23 +1097,26 @@ def get_image_reader(key, path=None, url=None): # is True OMERO python reader can be used to directly request # the image pixels from the server. # Following this route gives almost 10x speed up. - if OMERO_READER_IMPORTED and OMERO_IMPORTED and \ - omero_reader_enabled() and \ - url is not None and url.lower().startswith("omero:"): + if ( + OMERO_READER_IMPORTED + and OMERO_IMPORTED + and omero_reader_enabled() + and url is not None + and url.lower().startswith("omero:") + ): logger.debug("Initializing Python reader.") rdr = OmeroReader(__omero_server, __omero_session_id, url=url) else: logger.debug("Falling back to Java reader.") rdr = ImageReader(path=path, url=url) old_count = 0 - __image_reader_cache[path, url] = (old_count+1, rdr) + __image_reader_cache[path, url] = (old_count + 1, rdr) __image_reader_key_cache[key] = (path, url) return rdr -def release_image_reader(key): - '''Tell the cache that it should flush the reference for the given key - ''' +def release_image_reader(key): + """Tell the cache that it should flush the reference for the given key""" if key in __image_reader_key_cache: path, url = __image_reader_key_cache[key] del __image_reader_key_cache[key] @@ -971,21 +1125,30 @@ def release_image_reader(key): rdr.close() del __image_reader_cache[path, url] else: - __image_reader_cache[path, url] = (old_count-1, rdr) + __image_reader_cache[path, url] = (old_count - 1, rdr) + def clear_image_reader_cache(): - '''Get rid of any open image readers''' + """Get rid of any open image readers""" for use_count, rdr in __image_reader_cache.values(): logger.debug("Closing reader %s" % rdr) rdr.close() __image_reader_cache.clear() __image_reader_key_cache.clear() -def load_using_bioformats(path, c=None, z=0, t=0, series=None, index=None, - rescale = True, - wants_max_intensity = False, - channel_names = None): - '''Load the given image file using the Bioformats library. + +def load_using_bioformats( + path, + c=None, + z=0, + t=0, + series=None, + index=None, + rescale=True, + wants_max_intensity=False, + channel_names=None, +): + """Load the given image file using the Bioformats library. :param path: path to the file :param z: the frame index in the `z` (depth) dimension. @@ -994,14 +1157,16 @@ def load_using_bioformats(path, c=None, z=0, t=0, series=None, index=None, :returns: either a 2-d (grayscale) or 3-d (2-d + 3 RGB planes) image. - ''' + """ with ImageReader(path=path) as rdr: - return rdr.read(c, z, t, series, index, rescale, wants_max_intensity, - channel_names) + return rdr.read( + c, z, t, series, index, rescale, wants_max_intensity, channel_names + ) + def get_omexml_metadata(path=None, url=None): - '''Read the OME metadata from a file using Bio-formats + """Read the OME metadata from a file using Bio-formats :param path: path to the file @@ -1010,7 +1175,7 @@ def get_omexml_metadata(path=None, url=None): :returns: the metdata as XML. - ''' + """ with ImageReader(path=path, url=url, perform_init=False) as rdr: # # Below, "in" is a keyword and Rhino's parser is just a little wonky I fear. @@ -1033,5 +1198,5 @@ def get_omexml_metadata(path=None, url=None): var xml = service.getOMEXML(metadata); xml; """ - xml = jutil.run_script(script, dict(path=rdr.path, reader = rdr.rdr)) + xml = jutil.run_script(script, dict(path=rdr.path, reader=rdr.rdr)) return xml diff --git a/cellacdc/bioformats/formatwriter.py b/cellacdc/bioformats/formatwriter.py index c60c10ca0..834a002d6 100755 --- a/cellacdc/bioformats/formatwriter.py +++ b/cellacdc/bioformats/formatwriter.py @@ -5,7 +5,7 @@ # Copyright (c) 2009-2014 Broad Institute # All rights reserved. -'''formatwriter.py - mechanism to wrap a bioformats WriterWrapper and ImageWriter +"""formatwriter.py - mechanism to wrap a bioformats WriterWrapper and ImageWriter The following file formats can be written using Bio-Formats: @@ -28,7 +28,7 @@ and is especially useful for formats that do not support multiple images per file. -''' +""" from __future__ import absolute_import, print_function, unicode_literals @@ -43,10 +43,19 @@ import javabridge from ..bioformats import omexml as ome -def write_image(pathname, pixels, pixel_type, - c = 0, z = 0, t = 0, - size_c = 1, size_z = 1, size_t = 1, - channel_names = None): + +def write_image( + pathname, + pixels, + pixel_type, + c=0, + z=0, + t=0, + size_c=1, + size_z=1, + size_t=1, + channel_names=None, +): """Write the image using bioformats. :param filename: save to this filename @@ -86,7 +95,8 @@ def write_image(pathname, pixels, pixel_type, p.SizeC = pixels.shape[2] p.Channel(0).SamplesPerPixel = pixels.shape[2] omexml.structured_annotations.add_original_metadata( - ome.OM_SAMPLES_PER_PIXEL, str(pixels.shape[2])) + ome.OM_SAMPLES_PER_PIXEL, str(pixels.shape[2]) + ) elif size_c > 1: p.channel_count = size_c @@ -105,21 +115,20 @@ def write_image(pathname, pixels, pixel_type, writer.saveBytes(index, buffer); writer.close(); """ - jutil.run_script(script, - dict(path=pathname, - xml=xml, - index=index, - buffer=pixel_buffer)) + jutil.run_script( + script, dict(path=pathname, xml=xml, index=index, buffer=pixel_buffer) + ) + def convert_pixels_to_buffer(pixels, pixel_type): - '''Convert the pixels in the image into a buffer of the right pixel type + """Convert the pixels in the image into a buffer of the right pixel type pixels - a 2d monochrome or color image pixel_type - one of the OME pixel types returns a 1-d byte array - ''' + """ if pixel_type in (ome.PT_UINT8, ome.PT_INT8, ome.PT_BIT): as_dtype = np.uint8 elif pixel_type in (ome.PT_UINT16, ome.PT_INT16): @@ -136,269 +145,404 @@ def convert_pixels_to_buffer(pixels, pixel_type): env = jutil.get_env() return env.make_byte_array(buf) + def make_iformat_writer_class(class_name): - '''Bind a Java class that implements IFormatWriter to a Python class + """Bind a Java class that implements IFormatWriter to a Python class Returns a class that implements IFormatWriter through calls to the implemented class passed in. The returned class can be subclassed to provide additional bindings. - ''' + """ + class IFormatWriter(object): - '''A wrapper for loci.formats.IFormatWriter + """A wrapper for loci.formats.IFormatWriter See http://hudson.openmicroscopy.org.uk/job/LOCI/javadoc/loci/formats/ImageWriter.html - ''' - canDoStacks = jutil.make_method('canDoStacks', '()Z', - 'Reports whether the writer can save multiple images to a single file.') - getColorModel = jutil.make_method('getColorModel', '()Ljava/awt/image/ColorModel;', - 'Gets the color model.') - getCompression = jutil.make_method('getCompression', '()Ljava/lang/String;', - 'Gets the current compression type.') - getCompressionTypes = jutil.make_method('getCompressionTypes', '()[Ljava/lang/String;', - 'Gets the available compression types.') - getFramesPerSecond = jutil.make_method('getFramesPerSecond', '()I', - 'Gets the frames per second to use when writing.') - getMetadataRetrieve = jutil.make_method('getMetadataRetrieve', '()Lloci/formats/meta/MetadataRetrieve;', - 'Retrieves the current metadata retrieval object for this writer.') - getPixelTypes = jutil.make_method('getPixelTypes', '()[I', - 'Gets the supported pixel types.') -# getPixelTypes = jutil.make_method('getPixelTypes', '(Ljava/lang/String;)[I', -# 'Gets the supported pixel types for the given codec.') - isInterleaved = jutil.make_method('isInterleaved', '()Z', - 'Gets whether or not the channels in an image are interleaved.') - isSupportedType = jutil.make_method('isSupportedType', '(I)Z', - 'Checks if the given pixel type is supported.') - saveBytes = jutil.make_method('saveBytes', '([BZ)V', - 'Saves the given byte array to the current file.') - saveBytesIB = jutil.make_method('saveBytes', '(I[B)V', - 'Saves bytes, first arg is image #') -# saveBytes = jutil.make_method('saveBytes', '([BIZZ)V', -# 'Saves the given byte array to the given series in the current file.') - savePlane = jutil.make_method('savePlane', '(Ljava/lang/Object;Z)V', - 'Saves the given image plane to the current file.') -# savePlane = jutil.make_method('savePlane', '(Ljava/lang/Object;IZZ)V', -# 'Saves the given image plane to the given series in the current file.') - setColorModel = jutil.make_method('setColorModel', '(Ljava/awt/image/ColorModel;)V', - 'Sets the color model.') - setCompression = jutil.make_method('setCompression', '(Ljava/lang/String;)V', - 'Sets the current compression type.') - setFramesPerSecond = jutil.make_method('setFramesPerSecond', '(I)V', - 'Sets the frames per second to use when writing.') - setInterleaved = jutil.make_method('setInterleaved', '(Z)V', - 'Sets whether or not the channels in an image are interleaved.') - setMetadataRetrieve = jutil.make_method('setMetadataRetrieve', '(Lloci/formats/meta/MetadataRetrieve;)V', - 'Sets the metadata retrieval object from which to retrieve standardized metadata.') + """ + + canDoStacks = jutil.make_method( + "canDoStacks", + "()Z", + "Reports whether the writer can save multiple images to a single file.", + ) + getColorModel = jutil.make_method( + "getColorModel", "()Ljava/awt/image/ColorModel;", "Gets the color model." + ) + getCompression = jutil.make_method( + "getCompression", + "()Ljava/lang/String;", + "Gets the current compression type.", + ) + getCompressionTypes = jutil.make_method( + "getCompressionTypes", + "()[Ljava/lang/String;", + "Gets the available compression types.", + ) + getFramesPerSecond = jutil.make_method( + "getFramesPerSecond", + "()I", + "Gets the frames per second to use when writing.", + ) + getMetadataRetrieve = jutil.make_method( + "getMetadataRetrieve", + "()Lloci/formats/meta/MetadataRetrieve;", + "Retrieves the current metadata retrieval object for this writer.", + ) + getPixelTypes = jutil.make_method( + "getPixelTypes", "()[I", "Gets the supported pixel types." + ) + # getPixelTypes = jutil.make_method('getPixelTypes', '(Ljava/lang/String;)[I', + # 'Gets the supported pixel types for the given codec.') + isInterleaved = jutil.make_method( + "isInterleaved", + "()Z", + "Gets whether or not the channels in an image are interleaved.", + ) + isSupportedType = jutil.make_method( + "isSupportedType", "(I)Z", "Checks if the given pixel type is supported." + ) + saveBytes = jutil.make_method( + "saveBytes", "([BZ)V", "Saves the given byte array to the current file." + ) + saveBytesIB = jutil.make_method( + "saveBytes", "(I[B)V", "Saves bytes, first arg is image #" + ) + # saveBytes = jutil.make_method('saveBytes', '([BIZZ)V', + # 'Saves the given byte array to the given series in the current file.') + savePlane = jutil.make_method( + "savePlane", + "(Ljava/lang/Object;Z)V", + "Saves the given image plane to the current file.", + ) + # savePlane = jutil.make_method('savePlane', '(Ljava/lang/Object;IZZ)V', + # 'Saves the given image plane to the given series in the current file.') + setColorModel = jutil.make_method( + "setColorModel", "(Ljava/awt/image/ColorModel;)V", "Sets the color model." + ) + setCompression = jutil.make_method( + "setCompression", + "(Ljava/lang/String;)V", + "Sets the current compression type.", + ) + setFramesPerSecond = jutil.make_method( + "setFramesPerSecond", + "(I)V", + "Sets the frames per second to use when writing.", + ) + setInterleaved = jutil.make_method( + "setInterleaved", + "(Z)V", + "Sets whether or not the channels in an image are interleaved.", + ) + setMetadataRetrieve = jutil.make_method( + "setMetadataRetrieve", + "(Lloci/formats/meta/MetadataRetrieve;)V", + "Sets the metadata retrieval object from which to retrieve standardized metadata.", + ) setValidBitsPerPixel = jutil.make_method( - 'setValidBitsPerPixel', '(I)V', - 'Sets the number of valid bits per pixel') + "setValidBitsPerPixel", "(I)V", "Sets the number of valid bits per pixel" + ) setSeries = jutil.make_method( - 'setSeries', '(I)V', - '''Set the series for the image file + "setSeries", + "(I)V", + """Set the series for the image file series - the zero-based index of the image stack in the file, - for instance in a multi-image tif.''') + for instance in a multi-image tif.""", + ) return IFormatWriter + def make_image_writer_class(): - '''Return an image writer class for the given Java environment''' + """Return an image writer class for the given Java environment""" env = jutil.get_env() - class_name = 'loci/formats/ImageWriter' + class_name = "loci/formats/ImageWriter" klass = env.find_class(class_name) - base_klass = env.find_class('loci/formats/IFormatWriter') + base_klass = env.find_class("loci/formats/IFormatWriter") IFormatWriter = make_iformat_writer_class(class_name) # # This uses the writers.txt file from inside the loci_tools.jar # - class_list = jutil.make_instance("loci/formats/ClassList", - "(Ljava/lang/String;" - "Ljava/lang/Class;" # base - "Ljava/lang/Class;)V", # location in jar - "writers.txt", base_klass, klass) + class_list = jutil.make_instance( + "loci/formats/ClassList", + "(Ljava/lang/String;" + "Ljava/lang/Class;" # base + "Ljava/lang/Class;)V", # location in jar + "writers.txt", + base_klass, + klass, + ) + class ImageWriter(IFormatWriter): - new_fn = jutil.make_new(class_name, '(Lloci/formats/ClassList;)V') + new_fn = jutil.make_new(class_name, "(Lloci/formats/ClassList;)V") + def __init__(self): self.new_fn(class_list) - setId = jutil.make_method('setId', '(Ljava/lang/String;)V', - 'Sets the current file name.') - addStatusListener = jutil.make_method('addStatusListener', '()Lloci/formats/StatusListener;', - 'Adds a listener for status update events.') - close = jutil.make_method('close','()V', - 'Closes currently open file(s) and frees allocated memory.') - getFormat = jutil.make_method('getFormat', '()Ljava/lang/String;', - 'Gets the name of this file format.') - getNativeDataType = jutil.make_method('getNativeDataType', '()Ljava/lang/Class;', - 'Returns the native data type of image planes for this reader, as returned by IFormatReader.openPlane(int, int, int, int, int) or IFormatWriter#saveData.') - getStatusListeners = jutil.make_method('getStatusListeners', '()[Lloci/formats/StatusListener;', - 'Gets a list of all registered status update listeners.') - getSuffixes = jutil.make_method('getSuffixes', '()Ljava/lang/String;', - 'Gets the default file suffixes for this file format.') - getWriter = jutil.make_method('getWriter', '()Lloci/formats/IFormatWriter;', - 'Gets the writer used to save the current file.') -# getWriter = jutil.make_method('getWriter', '(Ljava/lang/Class)Lloci/formats/IFormatWriter;', -# 'Gets the file format writer instance matching the given class.') -# getWriter = jutil.make_method('getWriter', '(Ljava/lang/String;)Lloci/formats/IFormatWriter;', -# 'Gets the writer used to save the given file.') - getWriters = jutil.make_method('getWriters', '()[Lloci/formats/IFormatWriter;', - 'Gets all constituent file format writers.') - isThisType = jutil.make_method('isThisType', '(Ljava/lang/String;)Z', - 'Checks if the given string is a valid filename for this file format.') - removeStatusListener = jutil.make_method('removeStatusListener', '(Lloci/formats/StatusListener;)V', - 'Saves the given byte array to the current file.') + setId = jutil.make_method( + "setId", "(Ljava/lang/String;)V", "Sets the current file name." + ) + addStatusListener = jutil.make_method( + "addStatusListener", + "()Lloci/formats/StatusListener;", + "Adds a listener for status update events.", + ) + close = jutil.make_method( + "close", "()V", "Closes currently open file(s) and frees allocated memory." + ) + getFormat = jutil.make_method( + "getFormat", "()Ljava/lang/String;", "Gets the name of this file format." + ) + getNativeDataType = jutil.make_method( + "getNativeDataType", + "()Ljava/lang/Class;", + "Returns the native data type of image planes for this reader, as returned by IFormatReader.openPlane(int, int, int, int, int) or IFormatWriter#saveData.", + ) + getStatusListeners = jutil.make_method( + "getStatusListeners", + "()[Lloci/formats/StatusListener;", + "Gets a list of all registered status update listeners.", + ) + getSuffixes = jutil.make_method( + "getSuffixes", + "()Ljava/lang/String;", + "Gets the default file suffixes for this file format.", + ) + getWriter = jutil.make_method( + "getWriter", + "()Lloci/formats/IFormatWriter;", + "Gets the writer used to save the current file.", + ) + # getWriter = jutil.make_method('getWriter', '(Ljava/lang/Class)Lloci/formats/IFormatWriter;', + # 'Gets the file format writer instance matching the given class.') + # getWriter = jutil.make_method('getWriter', '(Ljava/lang/String;)Lloci/formats/IFormatWriter;', + # 'Gets the writer used to save the given file.') + getWriters = jutil.make_method( + "getWriters", + "()[Lloci/formats/IFormatWriter;", + "Gets all constituent file format writers.", + ) + isThisType = jutil.make_method( + "isThisType", + "(Ljava/lang/String;)Z", + "Checks if the given string is a valid filename for this file format.", + ) + removeStatusListener = jutil.make_method( + "removeStatusListener", + "(Lloci/formats/StatusListener;)V", + "Saves the given byte array to the current file.", + ) + return ImageWriter + def make_ome_tiff_writer_class(): - '''Return a class that wraps loci.formats.out.OMETiffWriter''' - class_name = 'loci/formats/out/OMETiffWriter' + """Return a class that wraps loci.formats.out.OMETiffWriter""" + class_name = "loci/formats/out/OMETiffWriter" IFormatWriter = make_iformat_writer_class(class_name) class OMETiffWriter(IFormatWriter): - def __init__(self): - self.new_fn = jutil.make_new(self.class_name, '()V') - self.setId = jutil.make_method('setId', '(Ljava/lang/String;)V', - 'Sets the current file name.') + self.new_fn = jutil.make_new(self.class_name, "()V") + self.setId = jutil.make_method( + "setId", "(Ljava/lang/String;)V", "Sets the current file name." + ) self.close = jutil.make_method( - 'close','()V', - 'Closes currently open file(s) and frees allocated memory.') + "close", + "()V", + "Closes currently open file(s) and frees allocated memory.", + ) self.saveBytesIFD = jutil.make_method( - 'saveBytes', '(I[BLloci/formats/tiff/IFD;)V', - '''save a byte array to an image channel + "saveBytes", + "(I[BLloci/formats/tiff/IFD;)V", + """save a byte array to an image channel index - image index bytes - byte array to save ifd - a loci.formats.tiff.IFD instance that gives all of the - IFD values associated with the channel''') + IFD values associated with the channel""", + ) self.new_fn() return OMETiffWriter + def make_writer_wrapper_class(class_name): - '''Make an ImageWriter wrapper class + """Make an ImageWriter wrapper class class_name - the name of the wrapper class You can instantiate an instance of the wrapper class like this: writer = XXX(ImageWriter()) - ''' + """ IFormatWriter = make_iformat_writer_class(class_name) + class WriterWrapper(IFormatWriter): - __doc__ = '''A wrapper for %s + __doc__ = ( + """A wrapper for %s See http://hudson.openmicroscopy.org.uk/job/LOCI/javadoc/loci/formats/ImageWriter.html - '''%class_name - new_fn = jutil.make_new(class_name, '(Lloci/formats/IFormatWriter;)V') + """ + % class_name + ) + new_fn = jutil.make_new(class_name, "(Lloci/formats/IFormatWriter;)V") + def __init__(self, writer): self.new_fn(writer) - setId = jutil.make_method('setId', '(Ljava/lang/String;)V', - 'Sets the current file name.') + setId = jutil.make_method( + "setId", "(Ljava/lang/String;)V", "Sets the current file name." + ) + return WriterWrapper def make_format_writer_class(class_name): - '''Make a FormatWriter wrapper class + """Make a FormatWriter wrapper class class_name - the name of a class that implements loci.formats.FormatWriter Known names in the loci.formats.out package: APNGWriter, AVIWriter, EPSWriter, ICSWriter, ImageIOWriter, JPEG2000Writer, JPEGWriter, LegacyQTWriter, OMETiffWriter, OMEXMLWriter, QTWriter, TiffWriter - ''' - new_fn = jutil.make_new(class_name, - '(Ljava/lang/String;Ljava/lang/String;)V') + """ + new_fn = jutil.make_new(class_name, "(Ljava/lang/String;Ljava/lang/String;)V") + class FormatWriter(object): - __doc__ = '''A wrapper for %s implementing loci.formats.FormatWriter - See http://hudson.openmicroscopy.org.uk/job/LOCI/javadoc/loci/formats/FormatWriter'''%class_name + __doc__ = ( + """A wrapper for %s implementing loci.formats.FormatWriter + See http://hudson.openmicroscopy.org.uk/job/LOCI/javadoc/loci/formats/FormatWriter""" + % class_name + ) + def __init__(self): self.new_fn() - canDoStacks = jutil.make_method('canDoStacks','()Z', - 'Reports whether the writer can save multiple images to a single file') - getColorModel = jutil.make_method('getColorModel', - '()Ljava/awt/image/ColorModel;', - 'Gets the color model') - getCompression = jutil.make_method('getCompression', - '()Ljava/lang/String;', - 'Gets the current compression type') - getCompressionTypes = jutil.make_method('getCompressionTypes', - '()[Ljava/lang/String;', - 'Gets the available compression types') - getFramesPerSecond = jutil.make_method('getFramesPerSecond', - '()I', "Gets the frames per second to use when writing") - getMetadataRetrieve = jutil.make_method('getMetadataRetrieve', - '()Lloci/formats/meta/MetadataRetrieve;', - 'Retrieves the current metadata retrieval object for this writer.') - - getPixelTypes = jutil.make_method('getPixelTypes', - '()[I') - isInterleaved = jutil.make_method('isInterleaved','()Z', - 'Gets whether or not the channels in an image are interleaved') - isSupportedType = jutil.make_method('isSupportedType','(I)Z', - 'Checks if the given pixel type is supported') - saveBytes = jutil.make_method('saveBytes', '([BZ)V', - 'Saves the given byte array to the current file') - setColorModel = jutil.make_method('setColorModel', - '(Ljava/awt/image/ColorModel;)V', - 'Sets the color model') - setCompression = jutil.make_method('setCompression', - '(Ljava/lang/String;)V', - 'Sets the current compression type') - setFramesPerSecond = jutil.make_method('setFramesPerSecond', - '(I)V', - 'Sets the frames per second to use when writing') - setId = jutil.make_method('setId','(Ljava/lang/String;)V', - 'Sets the current file name') - setInterleaved = jutil.make_method('setInterleaved', '(Z)V', - 'Sets whether or not the channels in an image are interleaved') - setMetadataRetrieve = jutil.make_method('setMetadataRetrieve', - '(Lloci/formats/meta/MetadataRetrieve;)V', - 'Sets the metadata retrieval object from which to retrieve standardized metadata') + canDoStacks = jutil.make_method( + "canDoStacks", + "()Z", + "Reports whether the writer can save multiple images to a single file", + ) + getColorModel = jutil.make_method( + "getColorModel", "()Ljava/awt/image/ColorModel;", "Gets the color model" + ) + getCompression = jutil.make_method( + "getCompression", + "()Ljava/lang/String;", + "Gets the current compression type", + ) + getCompressionTypes = jutil.make_method( + "getCompressionTypes", + "()[Ljava/lang/String;", + "Gets the available compression types", + ) + getFramesPerSecond = jutil.make_method( + "getFramesPerSecond", + "()I", + "Gets the frames per second to use when writing", + ) + getMetadataRetrieve = jutil.make_method( + "getMetadataRetrieve", + "()Lloci/formats/meta/MetadataRetrieve;", + "Retrieves the current metadata retrieval object for this writer.", + ) + + getPixelTypes = jutil.make_method("getPixelTypes", "()[I") + isInterleaved = jutil.make_method( + "isInterleaved", + "()Z", + "Gets whether or not the channels in an image are interleaved", + ) + isSupportedType = jutil.make_method( + "isSupportedType", "(I)Z", "Checks if the given pixel type is supported" + ) + saveBytes = jutil.make_method( + "saveBytes", "([BZ)V", "Saves the given byte array to the current file" + ) + setColorModel = jutil.make_method( + "setColorModel", "(Ljava/awt/image/ColorModel;)V", "Sets the color model" + ) + setCompression = jutil.make_method( + "setCompression", + "(Ljava/lang/String;)V", + "Sets the current compression type", + ) + setFramesPerSecond = jutil.make_method( + "setFramesPerSecond", + "(I)V", + "Sets the frames per second to use when writing", + ) + setId = jutil.make_method( + "setId", "(Ljava/lang/String;)V", "Sets the current file name" + ) + setInterleaved = jutil.make_method( + "setInterleaved", + "(Z)V", + "Sets whether or not the channels in an image are interleaved", + ) + setMetadataRetrieve = jutil.make_method( + "setMetadataRetrieve", + "(Lloci/formats/meta/MetadataRetrieve;)V", + "Sets the metadata retrieval object from which to retrieve standardized metadata", + ) + return FormatWriter + def getRGBColorSpace(): - '''Get a Java object that represents an RGB color space + """Get a Java object that represents an RGB color space See java.awt.color.ColorSpace: this returns the linear RGB color space - ''' - cs_linear_rgb = jutil.get_static_field('java/awt/color/ColorSpace', - 'CS_LINEAR_RGB', 'I') - return jutil.static_call('java/awt/color/ColorSpace', 'getInstance', - '(I)Ljava/awt/color/ColorSpace;', - cs_linear_rgb) + """ + cs_linear_rgb = jutil.get_static_field( + "java/awt/color/ColorSpace", "CS_LINEAR_RGB", "I" + ) + return jutil.static_call( + "java/awt/color/ColorSpace", + "getInstance", + "(I)Ljava/awt/color/ColorSpace;", + cs_linear_rgb, + ) + def getGrayColorSpace(): - '''Get a Java object that represents an RGB color space + """Get a Java object that represents an RGB color space See java.awt.color.ColorSpace: this returns the linear RGB color space - ''' - cs_gray = jutil.get_static_field('java/awt/color/ColorSpace', - 'CS_GRAY', 'I') - return jutil.static_call('java/awt/color/ColorSpace', 'getInstance', - '(I)Ljava/awt/color/ColorSpace;', - cs_gray) - -'''Constant for color model transparency indicating bitmask transparency''' -BITMASK = 'BITMASK' -'''Constant for color model transparency indicting an opaque color model''' -OPAQUE = 'OPAQUE' -'''Constant for color model transparency indicating a transparent color model''' -TRANSPARENT = 'TRANSPARENT' -'''Constant for color model transfer type indicating byte per pixel''' -TYPE_BYTE = 'TYPE_BYTE' -'''Constant for color model transfer type indicating unsigned short per pixel''' -TYPE_USHORT = 'TYPE_USHORT' -'''Constant for color model transfer type indicating integer per pixel''' -TYPE_INT = 'TYPE_INT' - -def getColorModel(color_space, - has_alpha=False, - is_alpha_premultiplied = False, - transparency = OPAQUE, - transfer_type = TYPE_BYTE): - '''Return a java.awt.image.ColorModel color model + """ + cs_gray = jutil.get_static_field("java/awt/color/ColorSpace", "CS_GRAY", "I") + return jutil.static_call( + "java/awt/color/ColorSpace", + "getInstance", + "(I)Ljava/awt/color/ColorSpace;", + cs_gray, + ) + + +"""Constant for color model transparency indicating bitmask transparency""" +BITMASK = "BITMASK" +"""Constant for color model transparency indicting an opaque color model""" +OPAQUE = "OPAQUE" +"""Constant for color model transparency indicating a transparent color model""" +TRANSPARENT = "TRANSPARENT" +"""Constant for color model transfer type indicating byte per pixel""" +TYPE_BYTE = "TYPE_BYTE" +"""Constant for color model transfer type indicating unsigned short per pixel""" +TYPE_USHORT = "TYPE_USHORT" +"""Constant for color model transfer type indicating integer per pixel""" +TYPE_INT = "TYPE_INT" + + +def getColorModel( + color_space, + has_alpha=False, + is_alpha_premultiplied=False, + transparency=OPAQUE, + transfer_type=TYPE_BYTE, +): + """Return a java.awt.image.ColorModel color model color_space - a java.awt.color.ColorSpace such as returned by getGrayColorSpace or getRGBColorSpace @@ -412,16 +556,22 @@ def getColorModel(color_space, transparency - one of BITMASK, OPAQUE or TRANSPARENT. transfer_type - one of TYPE_BYTE, TYPE_USHORT, TYPE_INT - ''' - jtransparency = jutil.get_static_field('java/awt/Transparency', - transparency, - 'I') - jtransfer_type = jutil.get_static_field('java/awt/image/DataBuffer', - transfer_type, 'I') - return jutil.make_instance('java/awt/image/ComponentColorModel', - '(Ljava/awt/color/ColorSpace;ZZII)V', - color_space, has_alpha, is_alpha_premultiplied, - jtransparency, jtransfer_type) + """ + jtransparency = jutil.get_static_field("java/awt/Transparency", transparency, "I") + jtransfer_type = jutil.get_static_field( + "java/awt/image/DataBuffer", transfer_type, "I" + ) + return jutil.make_instance( + "java/awt/image/ComponentColorModel", + "(Ljava/awt/color/ColorSpace;ZZII)V", + color_space, + has_alpha, + is_alpha_premultiplied, + jtransparency, + jtransfer_type, + ) + + if __name__ == "__main__": import wx import matplotlib.backends.backend_wxagg as mmmm @@ -431,22 +581,24 @@ def getColorModel(color_space, app = wx.PySimpleApp() -# dlg = wx.FileDialog(None) -# if dlg.ShowModal()==wx.ID_OK: -# filename = dlg.Path -# else: -# app.Exit() -# sys.exit() + # dlg = wx.FileDialog(None) + # if dlg.ShowModal()==wx.ID_OK: + # filename = dlg.Path + # else: + # app.Exit() + # sys.exit() - filename = '/Users/afraser/Desktop/cpa_example/images/AS_09125_050116000001_A01f00d0.png' - filename = '/Users/afraser/Desktop/wedding/header.jpg' + filename = ( + "/Users/afraser/Desktop/cpa_example/images/AS_09125_050116000001_A01f00d0.png" + ) + filename = "/Users/afraser/Desktop/wedding/header.jpg" - out_file = '/Users/afraser/Desktop/test_output.avi' + out_file = "/Users/afraser/Desktop/test_output.avi" try: os.remove(out_file) - print('previous output file deleted') + print("previous output file deleted") except: - print('no output file to delete') + print("no output file to delete") env = jutil.attach() ImageReader = make_image_reader_class() @@ -464,13 +616,13 @@ def getColorModel(color_space, t = 4 images = [] for tt in range(t): - images += [(np.random.rand(w, h, c) * 255).astype('uint8')] + images += [(np.random.rand(w, h, c) * 255).astype("uint8")] imeta = createOMEXMLMetadata() meta = wrap_imetadata_object(imeta) meta.createRoot() meta.setPixelsBigEndian(True, 0, 0) - meta.setPixelsDimensionOrder('XYCZT', 0, 0) + meta.setPixelsDimensionOrder("XYCZT", 0, 0) meta.setPixelsPixelType(FormatTools.getPixelTypeString(FormatTools.UINT8), 0, 0) meta.setPixelsSizeX(w, 0, 0) meta.setPixelsSizeY(h, 0, 0) @@ -479,30 +631,34 @@ def getColorModel(color_space, meta.setPixelsSizeT(t, 0, 0) meta.setLogicalChannelSamplesPerPixel(c, 0, 0) - print('big endian:', meta.getPixelsBigEndian(0, 0)) - print('dim order:', meta.getPixelsDimensionOrder(0, 0)) - print('pixel type:', meta.getPixelsPixelType(0, 0)) - print('size x:', meta.getPixelsSizeX(0, 0)) - print('size y:', meta.getPixelsSizeY(0, 0)) - print('size c:', meta.getPixelsSizeC(0, 0)) - print('size z:', meta.getPixelsSizeZ(0, 0)) - print('size t:', meta.getPixelsSizeT(0, 0)) - print('samples per pixel:', meta.getLogicalChannelSamplesPerPixel(0, 0)) + print("big endian:", meta.getPixelsBigEndian(0, 0)) + print("dim order:", meta.getPixelsDimensionOrder(0, 0)) + print("pixel type:", meta.getPixelsPixelType(0, 0)) + print("size x:", meta.getPixelsSizeX(0, 0)) + print("size y:", meta.getPixelsSizeY(0, 0)) + print("size c:", meta.getPixelsSizeC(0, 0)) + print("size z:", meta.getPixelsSizeZ(0, 0)) + print("size t:", meta.getPixelsSizeT(0, 0)) + print("samples per pixel:", meta.getLogicalChannelSamplesPerPixel(0, 0)) writer.setMetadataRetrieve(meta) writer.setId(out_file) for image in images: - if len(image.shape)==3 and image.shape[2] == 3: - save_im = np.array([image[:,:,0], image[:,:,1], image[:,:,2]]).astype(np.uint8).flatten() + if len(image.shape) == 3 and image.shape[2] == 3: + save_im = ( + np.array([image[:, :, 0], image[:, :, 1], image[:, :, 2]]) + .astype(np.uint8) + .flatten() + ) else: save_im = image.astype(np.uint8).flatten() writer.saveBytes(env.make_byte_array(save_im), (image is images[-1])) writer.close() - print('Done writing image :)') -# import PIL.Image as Image -# im = Image.open(out_file, 'r') -# im.show() + print("Done writing image :)") + # import PIL.Image as Image + # im = Image.open(out_file, 'r') + # im.show() jutil.detach() app.MainLoop() diff --git a/cellacdc/bioformats/log4j.py b/cellacdc/bioformats/log4j.py index 90d838b03..657d2f76a 100755 --- a/cellacdc/bioformats/log4j.py +++ b/cellacdc/bioformats/log4j.py @@ -7,8 +7,9 @@ import javabridge + def basic_config(): - '''Configure logging for "ERROR" level''' + """Configure logging for "ERROR" level""" log4j = javabridge.JClassWrapper("loci.common.Log4jTools") log4j.enableLogging() log4j.setRootLevel("ERROR") diff --git a/cellacdc/bioformats/metadatatools.py b/cellacdc/bioformats/metadatatools.py index ed6f4d1bf..cd95dd467 100755 --- a/cellacdc/bioformats/metadatatools.py +++ b/cellacdc/bioformats/metadatatools.py @@ -5,9 +5,7 @@ # Copyright (c) 2009-2014 Broad Institute # All rights reserved. -''' metadatatools.py - mechanism to wrap some bioformats metadata classes - -''' +"""metadatatools.py - mechanism to wrap some bioformats metadata classes""" from __future__ import absolute_import, unicode_literals @@ -16,191 +14,327 @@ from javabridge import jutil from .. import bioformats + def createOMEXMLMetadata(): - '''Creates an OME-XML metadata object using reflection, to avoid direct + """Creates an OME-XML metadata object using reflection, to avoid direct dependencies on the optional loci.formats.ome package. - ''' - return jutil.static_call('loci/formats/MetadataTools', 'createOMEXMLMetadata', '()Lloci/formats/meta/IMetadata;') + """ + return jutil.static_call( + "loci/formats/MetadataTools", + "createOMEXMLMetadata", + "()Lloci/formats/meta/IMetadata;", + ) class MetadataStore(object): - ''' ''' + """ """ + def __init__(self, o): self.o = o - createRoot = jutil.make_method('createRoot', '()V', '') + createRoot = jutil.make_method("createRoot", "()V", "") + def setPixelsBigEndian(self, bigEndian, imageIndex, binDataIndex): - '''Set the endianness for a particular image + """Set the endianness for a particular image bigEndian - True for big-endian, False for little-endian imageIndex - index of the image in question from IFormatReader.get_index? binDataIndex - ??? - ''' + """ # Post loci_tools 4.2 try: - jutil.call(self.o, 'setPixelsBinDataBigEndian', - '(Ljava/lang/Boolean;II)V', - bigEndian, imageIndex, binDataIndex) + jutil.call( + self.o, + "setPixelsBinDataBigEndian", + "(Ljava/lang/Boolean;II)V", + bigEndian, + imageIndex, + binDataIndex, + ) except jutil.JavaException: - jutil.call(self.o, 'setPixelsBigEndian', '(Ljava/lang/Boolean;II)V', - bigEndian, imageIndex, binDataIndex) + jutil.call( + self.o, + "setPixelsBigEndian", + "(Ljava/lang/Boolean;II)V", + bigEndian, + imageIndex, + binDataIndex, + ) def setPixelsDimensionOrder(self, dimension_order, imageIndex, binDataIndex): - '''Set the dimension order for a series''' + """Set the dimension order for a series""" # Post loci_tools 4.2 - use ome.xml.model.DimensionOrder try: jdimension_order = jutil.static_call( - 'ome/xml/model/enums/DimensionOrder', 'fromString', - '(Ljava/lang/String;)Lome/xml/model/enums/DimensionOrder;', - dimension_order) - jutil.call(self.o, 'setPixelsDimensionOrder', - '(Lome/xml/model/enums/DimensionOrder;I)V', - jdimension_order, imageIndex) + "ome/xml/model/enums/DimensionOrder", + "fromString", + "(Ljava/lang/String;)Lome/xml/model/enums/DimensionOrder;", + dimension_order, + ) + jutil.call( + self.o, + "setPixelsDimensionOrder", + "(Lome/xml/model/enums/DimensionOrder;I)V", + jdimension_order, + imageIndex, + ) except jutil.JavaException: - jutil.call(self.o, 'setPixelsDimensionOrder', - '(Ljava/lang/String;II)V', - dimension_order, imageIndex, binDataIndex) + jutil.call( + self.o, + "setPixelsDimensionOrder", + "(Ljava/lang/String;II)V", + dimension_order, + imageIndex, + binDataIndex, + ) setPixelsPixelType = jutil.make_method( - 'setPixelsPixelType', '(Ljava/lang/String;II)V', - '''Sets the pixel storage type + "setPixelsPixelType", + "(Ljava/lang/String;II)V", + """Sets the pixel storage type pixel_type - text representation of the type, e.g. "uint8" imageIndex - ? binDataIndex - ? WARNING: only available in BioFormats < 4.2 - ''') + """, + ) setPixelsType = jutil.make_method( - 'setPixelsType', '(Lome/xml/model/enums/PixelType;I)V', - '''Set the pixel storage type + "setPixelsType", + "(Lome/xml/model/enums/PixelType;I)V", + """Set the pixel storage type pixel_type - one of the enumerated values from PixelType. imageIndex - ? See the ome.xml.model.enums.PixelType and make_pixel_type_class's PixelType for possible values. - ''') + """, + ) def setPixelsSizeX(self, x, imageIndex, binDataIndex): try: - jutil.call(self.o, 'setPixelsSizeX', - '(Lome/xml/model/primitives/PositiveInteger;I)V', - PositiveInteger(x), imageIndex) + jutil.call( + self.o, + "setPixelsSizeX", + "(Lome/xml/model/primitives/PositiveInteger;I)V", + PositiveInteger(x), + imageIndex, + ) except jutil.JavaException: - jutil.call(self.o, 'setPixelsSizeX', - '(Ljava/lang/Integer;II)V', x, imageIndex, binDataIndex) + jutil.call( + self.o, + "setPixelsSizeX", + "(Ljava/lang/Integer;II)V", + x, + imageIndex, + binDataIndex, + ) def setPixelsSizeY(self, y, imageIndex, binDataIndex): try: - jutil.call(self.o, 'setPixelsSizeY', - '(Lome/xml/model/primitives/PositiveInteger;I)V', - PositiveInteger(y), imageIndex) + jutil.call( + self.o, + "setPixelsSizeY", + "(Lome/xml/model/primitives/PositiveInteger;I)V", + PositiveInteger(y), + imageIndex, + ) except jutil.JavaException: - jutil.call(self.o, 'setPixelsSizeY', - '(Ljava/lang/Integer;II)V', y, imageIndex, binDataIndex) + jutil.call( + self.o, + "setPixelsSizeY", + "(Ljava/lang/Integer;II)V", + y, + imageIndex, + binDataIndex, + ) def setPixelsSizeZ(self, z, imageIndex, binDataIndex): try: - jutil.call(self.o, 'setPixelsSizeZ', - '(Lome/xml/model/primitives/PositiveInteger;I)V', - PositiveInteger(z), imageIndex) + jutil.call( + self.o, + "setPixelsSizeZ", + "(Lome/xml/model/primitives/PositiveInteger;I)V", + PositiveInteger(z), + imageIndex, + ) except jutil.JavaException: - jutil.call(self.o, 'setPixelsSizeZ', - '(Ljava/lang/Integer;II)V', z, imageIndex, binDataIndex) + jutil.call( + self.o, + "setPixelsSizeZ", + "(Ljava/lang/Integer;II)V", + z, + imageIndex, + binDataIndex, + ) def setPixelsSizeC(self, c, imageIndex, binDataIndex): try: - jutil.call(self.o, 'setPixelsSizeC', - '(Lome/xml/model/primitives/PositiveInteger;I)V', - PositiveInteger(c), imageIndex) + jutil.call( + self.o, + "setPixelsSizeC", + "(Lome/xml/model/primitives/PositiveInteger;I)V", + PositiveInteger(c), + imageIndex, + ) except jutil.JavaException: - jutil.call(self.o, 'setPixelsSizeC', - '(Ljava/lang/Integer;II)V', c, imageIndex, binDataIndex) + jutil.call( + self.o, + "setPixelsSizeC", + "(Ljava/lang/Integer;II)V", + c, + imageIndex, + binDataIndex, + ) def setPixelsSizeT(self, t, imageIndex, binDataIndex): try: - jutil.call(self.o, 'setPixelsSizeT', - '(Lome/xml/model/primitives/PositiveInteger;I)V', - PositiveInteger(t), imageIndex) + jutil.call( + self.o, + "setPixelsSizeT", + "(Lome/xml/model/primitives/PositiveInteger;I)V", + PositiveInteger(t), + imageIndex, + ) except jutil.JavaException: - jutil.call(self.o, 'setPixelsSizeT', - '(Ljava/lang/Integer;II)V', t, imageIndex, binDataIndex) - - def setLogicalChannelSamplesPerPixel(self, samplesPerPixel, imageIndex, channelIndex): - 'For a particular LogicalChannel, sets number of channel components in the logical channel.' + jutil.call( + self.o, + "setPixelsSizeT", + "(Ljava/lang/Integer;II)V", + t, + imageIndex, + binDataIndex, + ) + + def setLogicalChannelSamplesPerPixel( + self, samplesPerPixel, imageIndex, channelIndex + ): + "For a particular LogicalChannel, sets number of channel components in the logical channel." try: - jutil.call(self.o, 'setChannelSamplesPerPixel', - '(Lome/xml/model/primitives/PositiveInteger;II)V', - PositiveInteger(samplesPerPixel), - imageIndex, channelIndex) + jutil.call( + self.o, + "setChannelSamplesPerPixel", + "(Lome/xml/model/primitives/PositiveInteger;II)V", + PositiveInteger(samplesPerPixel), + imageIndex, + channelIndex, + ) except jutil.JavaException: - jutil.call(self.o, 'setLogicalChannelSamplesPerPixel', - '(Ljava/lang/Integer;II)V', samplesPerPixel, - imageIndex, channelIndex) + jutil.call( + self.o, + "setLogicalChannelSamplesPerPixel", + "(Ljava/lang/Integer;II)V", + samplesPerPixel, + imageIndex, + channelIndex, + ) + setImageID = jutil.make_method( - 'setImageID', '(Ljava/lang/String;I)V', - '''Tag the indexed image with a name + "setImageID", + "(Ljava/lang/String;I)V", + """Tag the indexed image with a name id - the name, for instance Image:0 imageIndex - the index of the image (series???) - ''') + """, + ) setPixelsID = jutil.make_method( - 'setPixelsID', '(Ljava/lang/String;I)V', - '''Tag the pixels with a name (???) + "setPixelsID", + "(Ljava/lang/String;I)V", + """Tag the pixels with a name (???) id - the name, for instance Pixels:0 imageIndex - the index of the image (???) - ''') + """, + ) setChannelID = jutil.make_method( - 'setChannelID', '(Ljava/lang/String;II)V', - '''Give an ID name to the given channel + "setChannelID", + "(Ljava/lang/String;II)V", + """Give an ID name to the given channel id - the name of the channel imageIndex - (???) - channelIndex - index of the channel to be ID'ed''') + channelIndex - index of the channel to be ID'ed""", + ) + class MetadataRetrieve(object): - ''' ''' + """ """ + def __init__(self, o): self.o = o - getPixelsBigEndian = jutil.make_method('getPixelsBigEndian', '(II)Ljava/lang/Boolean;', - 'For a particular Pixels, gets endianness of the pixels set.') - getPixelsDimensionOrder = jutil.make_method('getPixelsDimensionOrder', '(II)Ljava/lang/String;', - 'For a particular Pixels, gets the dimension order of the pixels set.') - getPixelsPixelType = jutil.make_method('getPixelsPixelType', '(II)Ljava/lang/String;', - 'For a particular Pixels, gets the pixel type.') - getPixelsSizeX = jutil.make_method('getPixelsSizeX', '(II)Ljava/lang/Integer;', - 'For a particular Pixels, gets The size of an individual plane or section\'s X axis (width).') - getPixelsSizeY = jutil.make_method('getPixelsSizeY', '(II)Ljava/lang/Integer;', - 'For a particular Pixels, gets The size of an individual plane or section\'s Y axis (height).') - getPixelsSizeZ = jutil.make_method('getPixelsSizeZ', '(II)Ljava/lang/Integer;', - 'For a particular Pixels, gets number of optical sections per stack.') - getPixelsSizeC = jutil.make_method('getPixelsSizeC', '(II)Ljava/lang/Integer;', - 'For a particular Pixels, gets number of channels per timepoint.') - getPixelsSizeT = jutil.make_method('getPixelsSizeT', '(II)Ljava/lang/Integer;', - 'For a particular Pixels, gets number of timepoints.') - getLogicalChannelSamplesPerPixel = jutil.make_method('getLogicalChannelSamplesPerPixel', '(II)Ljava/lang/Integer;', - 'For a particular LogicalChannel, gets number of channel components in the logical channel.') - getChannelName = jutil.make_method('getChannelName', - '(II)Ljava/lang/String;', - '''Get the name for a particular channel. + getPixelsBigEndian = jutil.make_method( + "getPixelsBigEndian", + "(II)Ljava/lang/Boolean;", + "For a particular Pixels, gets endianness of the pixels set.", + ) + getPixelsDimensionOrder = jutil.make_method( + "getPixelsDimensionOrder", + "(II)Ljava/lang/String;", + "For a particular Pixels, gets the dimension order of the pixels set.", + ) + getPixelsPixelType = jutil.make_method( + "getPixelsPixelType", + "(II)Ljava/lang/String;", + "For a particular Pixels, gets the pixel type.", + ) + getPixelsSizeX = jutil.make_method( + "getPixelsSizeX", + "(II)Ljava/lang/Integer;", + "For a particular Pixels, gets The size of an individual plane or section's X axis (width).", + ) + getPixelsSizeY = jutil.make_method( + "getPixelsSizeY", + "(II)Ljava/lang/Integer;", + "For a particular Pixels, gets The size of an individual plane or section's Y axis (height).", + ) + getPixelsSizeZ = jutil.make_method( + "getPixelsSizeZ", + "(II)Ljava/lang/Integer;", + "For a particular Pixels, gets number of optical sections per stack.", + ) + getPixelsSizeC = jutil.make_method( + "getPixelsSizeC", + "(II)Ljava/lang/Integer;", + "For a particular Pixels, gets number of channels per timepoint.", + ) + getPixelsSizeT = jutil.make_method( + "getPixelsSizeT", + "(II)Ljava/lang/Integer;", + "For a particular Pixels, gets number of timepoints.", + ) + getLogicalChannelSamplesPerPixel = jutil.make_method( + "getLogicalChannelSamplesPerPixel", + "(II)Ljava/lang/Integer;", + "For a particular LogicalChannel, gets number of channel components in the logical channel.", + ) + getChannelName = jutil.make_method( + "getChannelName", + "(II)Ljava/lang/String;", + """Get the name for a particular channel. imageIndex - image # to query (use C = 0) - channelIndex - channel # to query''') - getChannelID = jutil.make_method('getChannelID', - '(II)Ljava/lang/String;', - '''Get the OME channel ID for a particular channel. + channelIndex - channel # to query""", + ) + getChannelID = jutil.make_method( + "getChannelID", + "(II)Ljava/lang/String;", + """Get the OME channel ID for a particular channel. imageIndex - image # to query (use C = 0) - channelIndex - channel # to query''') + channelIndex - channel # to query""", + ) def wrap_imetadata_object(o): - ''' Returns a python object wrapping the functionality of the given - IMetaData object (as returned by createOMEXMLMetadata) ''' + """Returns a python object wrapping the functionality of the given + IMetaData object (as returned by createOMEXMLMetadata)""" + class IMetadata(MetadataStore, MetadataRetrieve): - ''' ''' + """ """ + def __init__(self, o): MetadataStore.__init__(self, o) MetadataRetrieve.__init__(self, o) @@ -208,55 +342,91 @@ def __init__(self, o): return IMetadata(o) + __pixel_type_class = None + + def make_pixel_type_class(): - '''The class, ome.xml.model.enums.PixelType + """The class, ome.xml.model.enums.PixelType The Java class has enumerations for the various image data types such as UINT8 or DOUBLE - ''' + """ global __pixel_type_class if __pixel_type_class is None: + class PixelType(object): - '''Provide enums from ome.xml.model.enums.PixelType''' + """Provide enums from ome.xml.model.enums.PixelType""" + def __init__(self): - klass = jutil.get_env().find_class('ome/xml/model/enums/PixelType') - self.INT8 = jutil.get_static_field(klass, 'INT8', 'Lome/xml/model/enums/PixelType;') - self.INT16 = jutil.get_static_field(klass, 'INT16', 'Lome/xml/model/enums/PixelType;') - self.INT32 = jutil.get_static_field(klass, 'INT32', 'Lome/xml/model/enums/PixelType;') - self.UINT8 = jutil.get_static_field(klass, 'UINT8', 'Lome/xml/model/enums/PixelType;') - self.UINT16 = jutil.get_static_field(klass, 'UINT16', 'Lome/xml/model/enums/PixelType;') - self.UINT32 = jutil.get_static_field(klass, 'UINT32', 'Lome/xml/model/enums/PixelType;') - self.FLOAT = jutil.get_static_field(klass, 'FLOAT', 'Lome/xml/model/enums/PixelType;') - self.BIT = jutil.get_static_field(klass, 'BIT', 'Lome/xml/model/enums/PixelType;') - self.DOUBLE = jutil.get_static_field(klass, 'DOUBLE', 'Lome/xml/model/enums/PixelType;') - self.COMPLEX = jutil.get_static_field(klass, 'COMPLEX', 'Lome/xml/model/enums/PixelType;') - self.DOUBLECOMPLEX = jutil.get_static_field(klass, 'DOUBLECOMPLEX', 'Lome/xml/model/enums/PixelType;') + klass = jutil.get_env().find_class("ome/xml/model/enums/PixelType") + self.INT8 = jutil.get_static_field( + klass, "INT8", "Lome/xml/model/enums/PixelType;" + ) + self.INT16 = jutil.get_static_field( + klass, "INT16", "Lome/xml/model/enums/PixelType;" + ) + self.INT32 = jutil.get_static_field( + klass, "INT32", "Lome/xml/model/enums/PixelType;" + ) + self.UINT8 = jutil.get_static_field( + klass, "UINT8", "Lome/xml/model/enums/PixelType;" + ) + self.UINT16 = jutil.get_static_field( + klass, "UINT16", "Lome/xml/model/enums/PixelType;" + ) + self.UINT32 = jutil.get_static_field( + klass, "UINT32", "Lome/xml/model/enums/PixelType;" + ) + self.FLOAT = jutil.get_static_field( + klass, "FLOAT", "Lome/xml/model/enums/PixelType;" + ) + self.BIT = jutil.get_static_field( + klass, "BIT", "Lome/xml/model/enums/PixelType;" + ) + self.DOUBLE = jutil.get_static_field( + klass, "DOUBLE", "Lome/xml/model/enums/PixelType;" + ) + self.COMPLEX = jutil.get_static_field( + klass, "COMPLEX", "Lome/xml/model/enums/PixelType;" + ) + self.DOUBLECOMPLEX = jutil.get_static_field( + klass, "DOUBLECOMPLEX", "Lome/xml/model/enums/PixelType;" + ) + __pixel_type_class = PixelType return __pixel_type_class -MINIMUM = 'MINIMUM' -NO_OVERLAYS = 'NO_OVERLAYS' -ALL = 'ALL' + +MINIMUM = "MINIMUM" +NO_OVERLAYS = "NO_OVERLAYS" +ALL = "ALL" + def get_metadata_options(level): - '''Get an instance of the MetadataOptions interface + """Get an instance of the MetadataOptions interface level - MINIMUM, NO_OVERLAYS or ALL to set the metadata retrieval level The object returned can be used in setMetadataOptions in a format reader. - ''' - jlevel = jutil.get_static_field('loci/formats/in/MetadataLevel', level, - 'Lloci/formats/in/MetadataLevel;') - return jutil.make_instance('loci/formats/in/DefaultMetadataOptions', - '(Lloci/formats/in/MetadataLevel;)V', - jlevel) + """ + jlevel = jutil.get_static_field( + "loci/formats/in/MetadataLevel", level, "Lloci/formats/in/MetadataLevel;" + ) + return jutil.make_instance( + "loci/formats/in/DefaultMetadataOptions", + "(Lloci/formats/in/MetadataLevel;)V", + jlevel, + ) def PositiveInteger(some_number): - '''Return an instance of ome.xml.model.primitives.PositiveInteger + """Return an instance of ome.xml.model.primitives.PositiveInteger some_number - the number to be wrapped up in the class - ''' - return jutil.make_instance('ome/xml/model/primitives/PositiveInteger', - '(Ljava/lang/Integer;)V', some_number) + """ + return jutil.make_instance( + "ome/xml/model/primitives/PositiveInteger", + "(Ljava/lang/Integer;)V", + some_number, + ) diff --git a/cellacdc/bioformats/noseplugin.py b/cellacdc/bioformats/noseplugin.py index 0e6ca5efb..d95fbed7f 100755 --- a/cellacdc/bioformats/noseplugin.py +++ b/cellacdc/bioformats/noseplugin.py @@ -17,22 +17,24 @@ class Log4JPlugin(Plugin): - ''' + """ Plugin that initializes Log4J in order to avoid Bioformats error messages. - ''' + """ + enabled = False name = "log4j" - score = 90 # Less than the score of javabridge.nosetests.JavaBridgePlugin + score = 90 # Less than the score of javabridge.nosetests.JavaBridgePlugin def begin(self): - javabridge.static_call("org/apache/log4j/BasicConfigurator", - "configure", "()V") - log4j_logger = javabridge.static_call("org/apache/log4j/Logger", - "getRootLogger", - "()Lorg/apache/log4j/Logger;") - warn_level = javabridge.get_static_field("org/apache/log4j/Level","ERROR", - "Lorg/apache/log4j/Level;") - javabridge.call(log4j_logger, "setLevel", "(Lorg/apache/log4j/Level;)V", - warn_level) + javabridge.static_call("org/apache/log4j/BasicConfigurator", "configure", "()V") + log4j_logger = javabridge.static_call( + "org/apache/log4j/Logger", "getRootLogger", "()Lorg/apache/log4j/Logger;" + ) + warn_level = javabridge.get_static_field( + "org/apache/log4j/Level", "ERROR", "Lorg/apache/log4j/Level;" + ) + javabridge.call( + log4j_logger, "setLevel", "(Lorg/apache/log4j/Level;)V", warn_level + ) diff --git a/cellacdc/bioformats/omexml.py b/cellacdc/bioformats/omexml.py index 739054711..8f2381876 100755 --- a/cellacdc/bioformats/omexml.py +++ b/cellacdc/bioformats/omexml.py @@ -5,9 +5,7 @@ # Copyright (c) 2009-2014 Broad Institute # All rights reserved. -"""omexml.py read and write OME xml - -""" +"""omexml.py read and write OME xml""" from __future__ import absolute_import, unicode_literals @@ -15,24 +13,30 @@ from xml.etree import cElementTree as ElementTree import sys + if sys.version_info.major == 3: from io import StringIO - uenc = 'unicode' + + uenc = "unicode" else: from cStringIO import StringIO - uenc = 'utf-8' + + uenc = "utf-8" import datetime import logging from functools import reduce + logger = logging.getLogger(__file__) import re import uuid + def xsd_now(): - '''Return the current time in xsd:dateTime format''' + """Return the current time in xsd:dateTime format""" return datetime.datetime.now().isoformat() + DEFAULT_NOW = xsd_now() # # The namespaces @@ -69,7 +73,10 @@ def xsd_now(): -""".format(ns_ome_default=NS_DEFAULT.format(ns_key='ome'), ns_sa_default=NS_DEFAULT.format(ns_key='sa')) +""".format( + ns_ome_default=NS_DEFAULT.format(ns_key="ome"), + ns_sa_default=NS_DEFAULT.format(ns_key="sa"), +) # # These are the OME-XML pixel types - not all supported by subimager @@ -99,16 +106,16 @@ def xsd_now(): # The text for these can be found in # loci.formats.in.BaseTiffReader.initStandardMetadata # -'''IFD # 254''' +"""IFD # 254""" OM_NEW_SUBFILE_TYPE = "NewSubfileType" -'''IFD # 256''' +"""IFD # 256""" OM_IMAGE_WIDTH = "ImageWidth" -'''IFD # 257''' +"""IFD # 257""" OM_IMAGE_LENGTH = "ImageLength" -'''IFD # 258''' +"""IFD # 258""" OM_BITS_PER_SAMPLE = "BitsPerSample" -'''IFD # 262''' +"""IFD # 262""" OM_PHOTOMETRIC_INTERPRETATION = "PhotometricInterpretation" PI_WHITE_IS_ZERO = "WhiteIsZero" PI_BLACK_IS_ZERO = "BlackIsZero" @@ -120,89 +127,89 @@ def xsd_now(): PI_CIE_LAB = "CIELAB" PI_CFA_ARRAY = "Color Filter Array" -'''BioFormats infers the image type from the photometric interpretation''' +"""BioFormats infers the image type from the photometric interpretation""" OM_METADATA_PHOTOMETRIC_INTERPRETATION = "MetaDataPhotometricInterpretation" MPI_RGB = "RGB" MPI_MONOCHROME = "Monochrome" MPI_CMYK = "CMYK" -'''IFD # 263''' -OM_THRESHHOLDING = "Threshholding" # (sic) -'''IFD # 264 (but can be 265 if the orientation = 8)''' +"""IFD # 263""" +OM_THRESHHOLDING = "Threshholding" # (sic) +"""IFD # 264 (but can be 265 if the orientation = 8)""" OM_CELL_WIDTH = "CellWidth" -'''IFD # 265''' +"""IFD # 265""" OM_CELL_LENGTH = "CellLength" -'''IFD # 266''' +"""IFD # 266""" OM_FILL_ORDER = "FillOrder" -'''IFD # 279''' +"""IFD # 279""" OM_DOCUMENT_NAME = "Document Name" -'''IFD # 271''' +"""IFD # 271""" OM_MAKE = "Make" -'''IFD # 272''' +"""IFD # 272""" OM_MODEL = "Model" -'''IFD # 274''' +"""IFD # 274""" OM_ORIENTATION = "Orientation" -'''IFD # 277''' +"""IFD # 277""" OM_SAMPLES_PER_PIXEL = "SamplesPerPixel" -'''IFD # 280''' +"""IFD # 280""" OM_MIN_SAMPLE_VALUE = "MinSampleValue" -'''IFD # 281''' +"""IFD # 281""" OM_MAX_SAMPLE_VALUE = "MaxSampleValue" -'''IFD # 282''' +"""IFD # 282""" OM_X_RESOLUTION = "XResolution" -'''IFD # 283''' +"""IFD # 283""" OM_Y_RESOLUTION = "YResolution" -'''IFD # 284''' +"""IFD # 284""" OM_PLANAR_CONFIGURATION = "PlanarConfiguration" PC_CHUNKY = "Chunky" PC_PLANAR = "Planar" -'''IFD # 286''' +"""IFD # 286""" OM_X_POSITION = "XPosition" -'''IFD # 287''' +"""IFD # 287""" OM_Y_POSITION = "YPosition" -'''IFD # 288''' +"""IFD # 288""" OM_FREE_OFFSETS = "FreeOffsets" -'''IFD # 289''' +"""IFD # 289""" OM_FREE_BYTECOUNTS = "FreeByteCounts" -'''IFD # 290''' +"""IFD # 290""" OM_GRAY_RESPONSE_UNIT = "GrayResponseUnit" -'''IFD # 291''' +"""IFD # 291""" OM_GRAY_RESPONSE_CURVE = "GrayResponseCurve" -'''IFD # 292''' +"""IFD # 292""" OM_T4_OPTIONS = "T4Options" -'''IFD # 293''' +"""IFD # 293""" OM_T6_OPTIONS = "T6Options" -'''IFD # 296''' +"""IFD # 296""" OM_RESOLUTION_UNIT = "ResolutionUnit" -'''IFD # 297''' +"""IFD # 297""" OM_PAGE_NUMBER = "PageNumber" -'''IFD # 301''' +"""IFD # 301""" OM_TRANSFER_FUNCTION = "TransferFunction" -'''IFD # 305''' +"""IFD # 305""" OM_SOFTWARE = "Software" -'''IFD # 306''' +"""IFD # 306""" OM_DATE_TIME = "DateTime" -'''IFD # 315''' +"""IFD # 315""" OM_ARTIST = "Artist" -'''IFD # 316''' +"""IFD # 316""" OM_HOST_COMPUTER = "HostComputer" -'''IFD # 317''' +"""IFD # 317""" OM_PREDICTOR = "Predictor" -'''IFD # 318''' +"""IFD # 318""" OM_WHITE_POINT = "WhitePoint" -'''IFD # 322''' +"""IFD # 322""" OM_TILE_WIDTH = "TileWidth" -'''IFD # 323''' +"""IFD # 323""" OM_TILE_LENGTH = "TileLength" -'''IFD # 324''' +"""IFD # 324""" OM_TILE_OFFSETS = "TileOffsets" -'''IFD # 325''' +"""IFD # 325""" OM_TILE_BYTE_COUNT = "TileByteCount" -'''IFD # 332''' +"""IFD # 332""" OM_INK_SET = "InkSet" -'''IFD # 33432''' +"""IFD # 33432""" OM_COPYRIGHT = "Copyright" # # Well row/column naming conventions @@ -210,72 +217,82 @@ def xsd_now(): NC_LETTER = "letter" NC_NUMBER = "number" + def page_name_original_metadata(index): - '''Get the key name for the page name metadata data for the indexed tiff page + """Get the key name for the page name metadata data for the indexed tiff page These are TIFF IFD #'s 285+ index - zero-based index of the page - ''' + """ return "PageName #%d" % index + def get_text(node): - '''Get the contents of text nodes in a parent node''' + """Get the contents of text nodes in a parent node""" return node.text + def set_text(node, text): - '''Set the text of a parent''' + """Set the text of a parent""" node.text = text + def qn(namespace, tag_name): - '''Return the qualified name for a given namespace and tag name + """Return the qualified name for a given namespace and tag name This is the ElementTree representation of a qualified name - ''' + """ return "{%s}%s" % (namespace, tag_name) + def split_qn(qn): - '''Split a qualified tag name or return None if namespace not present''' - m = re.match('\{(.*)\}(.*)', qn) + """Split a qualified tag name or return None if namespace not present""" + m = re.match("\{(.*)\}(.*)", qn) return m.group(1), m.group(2) if m else None + def get_namespaces(node): - '''Get top-level XML namespaces from a node.''' - ns_lib = {'ome': None, 'sa': None, 'spw': None} + """Get top-level XML namespaces from a node.""" + ns_lib = {"ome": None, "sa": None, "spw": None} for child in node.iter(): ns = split_qn(child.tag)[0] match = re.match(NS_RE, ns) if match: - ns_key = match.group('ns_key').lower() + ns_key = match.group("ns_key").lower() ns_lib[ns_key] = ns return ns_lib + def get_float_attr(node, attribute): - '''Cast an element attribute to a float or return None if not present''' + """Cast an element attribute to a float or return None if not present""" attr = node.get(attribute) return None if attr is None else float(attr) + def get_int_attr(node, attribute): - '''Cast an element attribute to an int or return None if not present''' + """Cast an element attribute to an int or return None if not present""" attr = node.get(attribute) return None if attr is None else int(attr) + def make_text_node(parent, namespace, tag_name, text): - '''Either make a new node and add the given text or replace the text + """Either make a new node and add the given text or replace the text parent - the parent node to the node to be created or found namespace - the namespace of the node's qualified name tag_name - the tag name of the node's qualified name text - the text to be inserted - ''' + """ qname = qn(namespace, tag_name) node = parent.find(qname) if node is None: node = ElementTree.SubElement(parent, qname) set_text(node, text) + class OMEXML(object): - '''Reads and writes OME-XML with methods to get and set it. + """Reads and writes OME-XML with methods to get and set it. The OMEXML class has four main purposes: to parse OME-XML, to output OME-XML, to provide a structured mechanism for inspecting OME-XML and to @@ -316,7 +333,8 @@ class OMEXML(object): See the `OME-XML schema documentation `_. - ''' + """ + def __init__(self, xml=None): if xml is None: xml = default_xml @@ -329,7 +347,7 @@ def __init__(self, xml=None): self.dom = ElementTree.ElementTree(ElementTree.fromstring(xml)) # determine OME namespaces self.ns = get_namespaces(self.dom.getroot()) - if self.ns['ome'] is None: + if self.ns["ome"] is None: raise Exception("Error: String not in OME-XML format") def __str__(self): @@ -342,9 +360,9 @@ def __str__(self): ElementTree.register_namespace(ns_key, ns) ElementTree.register_namespace("om", NS_ORIGINAL_METADATA) result = StringIO() - ElementTree.ElementTree(self.root_node).write(result, - encoding=uenc, - method="xml") + ElementTree.ElementTree(self.root_node).write( + result, encoding=uenc, method="xml" + ) return result.getvalue() def to_xml(self, indent="\t", newline="\n", encoding=uenc): @@ -358,24 +376,27 @@ def root_node(self): return self.dom.getroot() def get_image_count(self): - '''The number of images (= series) specified by the XML''' - return len(self.root_node.findall(qn(self.ns['ome'], "Image"))) + """The number of images (= series) specified by the XML""" + return len(self.root_node.findall(qn(self.ns["ome"], "Image"))) def set_image_count(self, value): - '''Add or remove image nodes as needed''' + """Add or remove image nodes as needed""" assert value > 0 root = self.root_node if self.image_count > value: - image_nodes = root.find(qn(self.ns['ome'], "Image")) + image_nodes = root.find(qn(self.ns["ome"], "Image")) for image_node in image_nodes[value:]: root.remove(image_node) - while(self.image_count < value): - new_image = self.Image(ElementTree.SubElement(root, qn(self.ns['ome'], "Image"))) + while self.image_count < value: + new_image = self.Image( + ElementTree.SubElement(root, qn(self.ns["ome"], "Image")) + ) new_image.ID = str(uuid.uuid4()) new_image.Name = "default.png" new_image.AcquisitionDate = xsd_now() new_pixels = self.Pixels( - ElementTree.SubElement(new_image.node, qn(self.ns['ome'], "Pixels"))) + ElementTree.SubElement(new_image.node, qn(self.ns["ome"], "Pixels")) + ) new_pixels.ID = str(uuid.uuid4()) new_pixels.DimensionOrder = DO_XYCTZ new_pixels.PixelType = PT_UINT8 @@ -385,7 +406,8 @@ def set_image_count(self, value): new_pixels.SizeY = 512 new_pixels.SizeZ = 1 new_channel = self.Channel( - ElementTree.SubElement(new_pixels.node, qn(self.ns['ome'], "Channel"))) + ElementTree.SubElement(new_pixels.node, qn(self.ns["ome"], "Channel")) + ) new_channel.ID = "Channel%d:0" % self.image_count new_channel.Name = new_channel.ID new_channel.SamplesPerPixel = 1 @@ -398,21 +420,23 @@ def plates(self): @property def structured_annotations(self): - '''Return the structured annotations container + """Return the structured annotations container returns a wrapping of OME/StructuredAnnotations. It creates the element if it doesn't exist. - ''' - node = self.root_node.find(qn(self.ns['sa'], "StructuredAnnotations")) + """ + node = self.root_node.find(qn(self.ns["sa"], "StructuredAnnotations")) if node is None: node = ElementTree.SubElement( - self.root_node, qn(self.ns['sa'], "StructuredAnnotations")) + self.root_node, qn(self.ns["sa"], "StructuredAnnotations") + ) return self.StructuredAnnotations(node) class Image(object): - '''Representation of the OME/Image element''' + """Representation of the OME/Image element""" + def __init__(self, node): - '''Initialize with the DOM Image node''' + """Initialize with the DOM Image node""" self.node = node self.ns = get_namespaces(self.node) @@ -426,12 +450,14 @@ def set_ID(self, value): def get_Name(self): return self.node.get("Name") + def set_Name(self, value): self.node.set("Name", value) + Name = property(get_Name, set_Name) def get_AcquisitionDate(self): - '''The date in ISO-8601 format''' + """The date in ISO-8601 format""" acquired_date = self.node.find(qn(self.ns["ome"], "AcquisitionDate")) if acquired_date is None: return None @@ -441,14 +467,15 @@ def set_AcquisitionDate(self, date): acquired_date = self.node.find(qn(self.ns["ome"], "AcquisitionDate")) if acquired_date is None: acquired_date = ElementTree.SubElement( - self.node, qn(self.ns["ome"], "AcquisitionDate")) + self.node, qn(self.ns["ome"], "AcquisitionDate") + ) set_text(acquired_date, date) - AcquisitionDate = property(get_AcquisitionDate, set_AcquisitionDate) + AcquisitionDate = property(get_AcquisitionDate, set_AcquisitionDate) @property def Pixels(self): - '''The OME/Image/Pixels element. + """The OME/Image/Pixels element. Example: @@ -458,49 +485,57 @@ def Pixels(self): >>> stack_count = pixels.SizeZ >>> timepoint_count = pixels.SizeT - ''' - return OMEXML.Pixels(self.node.find(qn(self.ns['ome'], "Pixels"))) + """ + return OMEXML.Pixels(self.node.find(qn(self.ns["ome"], "Pixels"))) def roiref(self, index=0): - '''The OME/Image/ROIRef element''' - return OMEXML.ROIRef(self.node.findall(qn(self.ns['ome'], "ROIRef"))[index]) + """The OME/Image/ROIRef element""" + return OMEXML.ROIRef(self.node.findall(qn(self.ns["ome"], "ROIRef"))[index]) def get_roiref_count(self): - return len(self.node.findall(qn(self.ns['ome'], "ROIRef"))) + return len(self.node.findall(qn(self.ns["ome"], "ROIRef"))) + def set_roiref_count(self, value): - '''Add or remove roirefs as needed''' + """Add or remove roirefs as needed""" assert value > 0 if self.roiref_count > value: - roiref_nodes = self.node.find(qn(self.ns['ome'], "ROIRef")) + roiref_nodes = self.node.find(qn(self.ns["ome"], "ROIRef")) for roiref_node in roiref_nodes[value:]: self.node.remove(roiref_node) - while(self.roiref_count < value): + while self.roiref_count < value: iteration = self.roiref_count - 1 - new_roiref = OMEXML.ROIRef(ElementTree.SubElement(self.node, qn(self.ns['ome'], "ROIRef"))) + new_roiref = OMEXML.ROIRef( + ElementTree.SubElement(self.node, qn(self.ns["ome"], "ROIRef")) + ) new_roiref.set_ID("ROI:" + str(iteration)) roiref_count = property(get_roiref_count, set_roiref_count) def image(self, index=0): - '''Return an image node by index''' - return self.Image(self.root_node.findall(qn(self.ns['ome'], "Image"))[index]) + """Return an image node by index""" + return self.Image(self.root_node.findall(qn(self.ns["ome"], "Image"))[index]) class Channel(object): - '''The OME/Image/Pixels/Channel element''' + """The OME/Image/Pixels/Channel element""" + def __init__(self, node): self.node = node self.ns = get_namespaces(node) def get_ID(self): return self.node.get("ID") + def set_ID(self, value): self.node.set("ID", value) + ID = property(get_ID, set_ID) def get_Name(self): return self.node.get("Name") + def set_Name(self, value): self.node.set("Name", value) + Name = property(get_Name, set_Name) def get_SamplesPerPixel(self): @@ -508,9 +543,10 @@ def get_SamplesPerPixel(self): def set_SamplesPerPixel(self, value): self.node.set("SamplesPerPixel", str(value)) + SamplesPerPixel = property(get_SamplesPerPixel, set_SamplesPerPixel) - #--------------------- + # --------------------- # The following section is from the Allen Institute for Cell Science version of this file # which can be found at https://github.com/AllenCellModeling/aicsimageio/blob/master/aicsimageio/vendor/omexml.py class TiffData(object): @@ -526,7 +562,7 @@ def __init__(self, node): self.ns = get_namespaces(self.node) def get_FirstZ(self): - '''The Z index of the plane''' + """The Z index of the plane""" return get_int_attr(self.node, "FirstZ") def set_FirstZ(self, value): @@ -535,7 +571,7 @@ def set_FirstZ(self, value): FirstZ = property(get_FirstZ, set_FirstZ) def get_FirstC(self): - '''The channel index of the plane''' + """The channel index of the plane""" return get_int_attr(self.node, "FirstC") def set_FirstC(self, value): @@ -544,7 +580,7 @@ def set_FirstC(self, value): FirstC = property(get_FirstC, set_FirstC) def get_FirstT(self): - '''The T index of the plane''' + """The T index of the plane""" return get_int_attr(self.node, "FirstT") def set_FirstT(self, value): @@ -553,7 +589,7 @@ def set_FirstT(self, value): FirstT = property(get_FirstT, set_FirstT) def get_IFD(self): - '''plane index within tiff file''' + """plane index within tiff file""" return get_int_attr(self.node, "IFD") def set_IFD(self, value): @@ -562,7 +598,7 @@ def set_IFD(self, value): IFD = property(get_IFD, set_IFD) def get_plane_count(self): - '''How many planes in this TiffData. Should always be 1''' + """How many planes in this TiffData. Should always be 1""" return get_int_attr(self.node, "PlaneCount") def set_plane_count(self, value): @@ -571,18 +607,19 @@ def set_plane_count(self, value): plane_count = property(get_plane_count, set_plane_count) class Plane(object): - '''The OME/Image/Pixels/Plane element + """The OME/Image/Pixels/Plane element The Plane element represents one 2-dimensional image plane. It has the Z, C and T indices of the plane and optionally has the X, Y, Z, exposure time and a relative time delta. - ''' + """ + def __init__(self, node): self.node = node self.ns = get_namespaces(self.node) def get_TheZ(self): - '''The Z index of the plane''' + """The Z index of the plane""" return get_int_attr(self.node, "TheZ") def set_TheZ(self, value): @@ -591,7 +628,7 @@ def set_TheZ(self, value): TheZ = property(get_TheZ, set_TheZ) def get_TheC(self): - '''The channel index of the plane''' + """The channel index of the plane""" return get_int_attr(self.node, "TheC") def set_TheC(self, value): @@ -600,7 +637,7 @@ def set_TheC(self, value): TheC = property(get_TheC, set_TheC) def get_TheT(self): - '''The T index of the plane''' + """The T index of the plane""" return get_int_attr(self.node, "TheT") def set_TheT(self, value): @@ -609,7 +646,7 @@ def set_TheT(self, value): TheT = property(get_TheT, set_TheT) def get_DeltaT(self): - '''# of seconds since the beginning of the experiment''' + """# of seconds since the beginning of the experiment""" return get_float_attr(self.node, "DeltaT") def set_DeltaT(self, value): @@ -624,13 +661,13 @@ def get_ExposureTime(self): return None def set_ExposureTime(self, value): - '''Units are seconds. Duration of acquisition????''' + """Units are seconds. Duration of acquisition????""" self.node.set("ExposureTime", str(value)) ExposureTime = property(get_ExposureTime, set_ExposureTime) def get_PositionX(self): - '''X position of stage''' + """X position of stage""" position_x = self.node.get("PositionX") if position_x is not None: return float(position_x) @@ -642,7 +679,7 @@ def set_PositionX(self, value): PositionX = property(get_PositionX, set_PositionX) def get_PositionY(self): - '''Y position of stage''' + """Y position of stage""" return get_float_attr(self.node, "PositionY") def set_PositionY(self, value): @@ -651,7 +688,7 @@ def set_PositionY(self, value): PositionY = property(get_PositionY, set_PositionY) def get_PositionZ(self): - '''Z position of stage''' + """Z position of stage""" return get_float_attr(self.node, "PositionZ") def set_PositionZ(self, value): @@ -684,129 +721,155 @@ def set_PositionZUnit(self, value): PositionZUnit = property(get_PositionZUnit, set_PositionZUnit) class Pixels(object): - '''The OME/Image/Pixels element + """The OME/Image/Pixels element The Pixels element represents the pixels in an OME image and, for an OME-XML encoded image, will actually contain the base-64 encoded pixel data. It has the X, Y, Z, C, and T extents of the image and it specifies the channel interleaving and channel depth. - ''' + """ + def __init__(self, node): self.node = node self.ns = get_namespaces(self.node) def get_ID(self): return self.node.get("ID") + def set_ID(self, value): self.node.set("ID", value) + ID = property(get_ID, set_ID) def get_DimensionOrder(self): - '''The ordering of image planes in the file + """The ordering of image planes in the file A 5-letter code indicating the ordering of pixels, from the most rapidly varying to least. Use the DO_* constants (for instance DO_XYZCT) to compare and set this. - ''' + """ return self.node.get("DimensionOrder") + def set_DimensionOrder(self, value): self.node.set("DimensionOrder", value) + DimensionOrder = property(get_DimensionOrder, set_DimensionOrder) def get_PixelType(self): - '''The pixel bit type, for instance PT_UINT8 + """The pixel bit type, for instance PT_UINT8 The pixel type specifies the datatype used to encode pixels in the image data. You can use the PT_* constants to compare and set the pixel type. - ''' + """ return self.node.get("Type") def get_PhysicalSizeXUnit(self): - '''The unit of length of a pixel in X direction.''' + """The unit of length of a pixel in X direction.""" return self.node.get("PhysicalSizeXUnit") + def set_PhysicalSizeXUnit(self, value): self.node.set("PhysicalSizeXUnit", str(value)) + PhysicalSizeXUnit = property(get_PhysicalSizeXUnit, set_PhysicalSizeXUnit) def get_PhysicalSizeYUnit(self): - '''The unit of length of a pixel in Y direction.''' + """The unit of length of a pixel in Y direction.""" return self.node.get("PhysicalSizeYUnit") + def set_PhysicalSizeYUnit(self, value): self.node.set("PhysicalSizeYUnit", str(value)) + PhysicalSizeYUnit = property(get_PhysicalSizeYUnit, set_PhysicalSizeYUnit) def get_PhysicalSizeZUnit(self): - '''The unit of length of a voxel in Z direction.''' + """The unit of length of a voxel in Z direction.""" return self.node.get("PhysicalSizeZUnit") + def set_PhysicalSizeZUnit(self, value): self.node.set("PhysicalSizeZUnit", str(value)) + PhysicalSizeZUnit = property(get_PhysicalSizeZUnit, set_PhysicalSizeZUnit) def get_PhysicalSizeX(self): - '''The length of a single pixel in X direction.''' + """The length of a single pixel in X direction.""" return get_float_attr(self.node, "PhysicalSizeX") + def set_PhysicalSizeX(self, value): self.node.set("PhysicalSizeX", str(value)) + PhysicalSizeX = property(get_PhysicalSizeX, set_PhysicalSizeX) def get_PhysicalSizeY(self): - '''The length of a single pixel in Y direction.''' + """The length of a single pixel in Y direction.""" return get_float_attr(self.node, "PhysicalSizeY") + def set_PhysicalSizeY(self, value): self.node.set("PhysicalSizeY", str(value)) + PhysicalSizeY = property(get_PhysicalSizeY, set_PhysicalSizeY) def get_PhysicalSizeZ(self): - '''The size of a voxel in Z direction or None for 2D images.''' + """The size of a voxel in Z direction or None for 2D images.""" return get_float_attr(self.node, "PhysicalSizeZ") + def set_PhysicalSizeZ(self, value): self.node.set("PhysicalSizeZ", str(value)) + PhysicalSizeZ = property(get_PhysicalSizeZ, set_PhysicalSizeZ) def set_PixelType(self, value): self.node.set("Type", value) + PixelType = property(get_PixelType, set_PixelType) def get_SizeX(self): - '''The dimensions of the image in the X direction in pixels''' + """The dimensions of the image in the X direction in pixels""" return get_int_attr(self.node, "SizeX") + def set_SizeX(self, value): self.node.set("SizeX", str(value)) + SizeX = property(get_SizeX, set_SizeX) def get_SizeY(self): - '''The dimensions of the image in the Y direction in pixels''' + """The dimensions of the image in the Y direction in pixels""" return get_int_attr(self.node, "SizeY") + def set_SizeY(self, value): self.node.set("SizeY", str(value)) + SizeY = property(get_SizeY, set_SizeY) def get_SizeZ(self): - '''The dimensions of the image in the Z direction in pixels''' + """The dimensions of the image in the Z direction in pixels""" return get_int_attr(self.node, "SizeZ") def set_SizeZ(self, value): self.node.set("SizeZ", str(value)) + SizeZ = property(get_SizeZ, set_SizeZ) def get_SizeT(self): - '''The dimensions of the image in the T direction in pixels''' + """The dimensions of the image in the T direction in pixels""" return get_int_attr(self.node, "SizeT") def set_SizeT(self, value): self.node.set("SizeT", str(value)) + SizeT = property(get_SizeT, set_SizeT) def get_SizeC(self): - '''The dimensions of the image in the C direction in pixels''' + """The dimensions of the image in the C direction in pixels""" return get_int_attr(self.node, "SizeC") + def set_SizeC(self, value): self.node.set("SizeC", str(value)) + SizeC = property(get_SizeC, set_SizeC) def get_channel_count(self): - '''The number of channels in the image + """The number of channels in the image You can change the number of channels in the image by setting the channel_count: @@ -814,20 +877,21 @@ def get_channel_count(self): pixels.channel_count = 3 pixels.Channel(0).Name = "Red" ... - ''' - return len(self.node.findall(qn(self.ns['ome'], "Channel"))) + """ + return len(self.node.findall(qn(self.ns["ome"], "Channel"))) def set_channel_count(self, value): assert value > 0 channel_count = self.channel_count if channel_count > value: - channels = self.node.findall(qn(self.ns['ome'], "Channel")) + channels = self.node.findall(qn(self.ns["ome"], "Channel")) for channel in channels[value:]: self.node.remove(channel) else: for _ in range(channel_count, value): new_channel = OMEXML.Channel( - ElementTree.SubElement(self.node, qn(self.ns['ome'], "Channel"))) + ElementTree.SubElement(self.node, qn(self.ns["ome"], "Channel")) + ) new_channel.ID = str(uuid.uuid4()) new_channel.Name = new_channel.ID new_channel.SamplesPerPixel = 1 @@ -835,13 +899,14 @@ def set_channel_count(self, value): channel_count = property(get_channel_count, set_channel_count) def Channel(self, index=0): - '''Get the indexed channel from the Pixels element''' - channel = self.node.findall(qn(self.ns['ome'], "Channel"))[index] + """Get the indexed channel from the Pixels element""" + channel = self.node.findall(qn(self.ns["ome"], "Channel"))[index] return OMEXML.Channel(channel) + channel = Channel def get_plane_count(self): - '''The number of planes in the image + """The number of planes in the image An image with only one plane or an interleaved color plane will often not have any planes. @@ -852,49 +917,53 @@ def get_plane_count(self): pixels.plane_count = 3 pixels.Plane(0).TheZ=pixels.Plane(0).TheC=pixels.Plane(0).TheT=0 ... - ''' - return len(self.node.findall(qn(self.ns['ome'], "Plane"))) + """ + return len(self.node.findall(qn(self.ns["ome"], "Plane"))) def set_plane_count(self, value): assert value >= 0 plane_count = self.plane_count if plane_count > value: - planes = self.node.findall(qn(self.ns['ome'], "Plane")) + planes = self.node.findall(qn(self.ns["ome"], "Plane")) for plane in planes[value:]: self.node.remove(plane) else: for _ in range(plane_count, value): new_plane = OMEXML.Plane( - ElementTree.SubElement(self.node, qn(self.ns['ome'], "Plane"))) + ElementTree.SubElement(self.node, qn(self.ns["ome"], "Plane")) + ) plane_count = property(get_plane_count, set_plane_count) def Plane(self, index=0): - '''Get the indexed plane from the Pixels element''' - plane = self.node.findall(qn(self.ns['ome'], "Plane"))[index] + """Get the indexed plane from the Pixels element""" + plane = self.node.findall(qn(self.ns["ome"], "Plane"))[index] return OMEXML.Plane(plane) + plane = Plane def get_tiffdata_count(self): - return len(self.node.findall(qn(self.ns['ome'], "TiffData"))) + return len(self.node.findall(qn(self.ns["ome"], "TiffData"))) def set_tiffdata_count(self, value): assert value >= 0 - tiffdatas = self.node.findall(qn(self.ns['ome'], "TiffData")) + tiffdatas = self.node.findall(qn(self.ns["ome"], "TiffData")) for td in tiffdatas: self.node.remove(td) for _ in range(0, value): new_tiffdata = OMEXML.TiffData( - ElementTree.SubElement(self.node, qn(self.ns['ome'], "TiffData"))) + ElementTree.SubElement(self.node, qn(self.ns["ome"], "TiffData")) + ) tiffdata_count = property(get_tiffdata_count, set_tiffdata_count) def tiffdata(self, index=0): - data = self.node.findall(qn(self.ns['ome'], "TiffData"))[index] + data = self.node.findall(qn(self.ns["ome"], "TiffData"))[index] return OMEXML.TiffData(data) class Instrument(object): - '''Representation of the OME/Instrument element''' + """Representation of the OME/Instrument element""" + def __init__(self, node): self.node = node self.ns = get_namespaces(self.node) @@ -909,16 +978,16 @@ def set_ID(self, value): @property def Detector(self): - return OMEXML.Detector(self.node.find(qn(self.ns['ome'], "Detector"))) + return OMEXML.Detector(self.node.find(qn(self.ns["ome"], "Detector"))) @property def Objective(self): - return OMEXML.Objective(self.node.find(qn(self.ns['ome'], "Objective"))) - + return OMEXML.Objective(self.node.find(qn(self.ns["ome"], "Objective"))) def instrument(self, index=0): - return self.Instrument(self.root_node.findall(qn(self.ns['ome'], "Instrument"))[index]) - + return self.Instrument( + self.root_node.findall(qn(self.ns["ome"], "Instrument"))[index] + ) class Objective(object): def __init__(self, node): @@ -927,8 +996,10 @@ def __init__(self, node): def get_ID(self): return self.node.get("ID") + def set_ID(self, value): self.node.set("ID", value) + ID = property(get_ID, set_ID) def get_LensNA(self): @@ -936,18 +1007,25 @@ def get_LensNA(self): def set_LensNA(self, value): self.node.set("LensNA", value) + LensNA = property(get_LensNA, set_LensNA) def get_NominalMagnification(self): return self.node.get("NominalMagnification") + def set_NominalMagnification(self, value): self.node.set("NominalMagnification", value) - NominalMagnification = property(get_NominalMagnification, set_NominalMagnification) + + NominalMagnification = property( + get_NominalMagnification, set_NominalMagnification + ) def get_WorkingDistanceUnit(self): return get_int_attr(self.node, "WorkingDistanceUnit") + def set_WorkingDistanceUnit(self, value): self.node.set("WorkingDistanceUnit", str(value)) + WorkingDistanceUnit = property(get_WorkingDistanceUnit, set_WorkingDistanceUnit) class Detector(object): @@ -957,8 +1035,10 @@ def __init__(self, node): def get_ID(self): return self.node.get("ID") + def set_ID(self, value): self.node.set("ID", value) + ID = property(get_ID, set_ID) def get_Gain(self): @@ -966,24 +1046,27 @@ def get_Gain(self): def set_Gain(self, value): self.node.set("Gain", value) + Gain = property(get_Gain, set_Gain) def get_Model(self): return self.node.get("Model") + def set_Model(self, value): self.node.set("Model", value) + Model = property(get_Model, set_Model) def get_Type(self): return get_int_attr(self.node, "Type") + def set_Type(self, value): self.node.set("Type", str(value)) - Type = property(get_Type, set_Type) - + Type = property(get_Type, set_Type) class StructuredAnnotations(dict): - '''The OME/StructuredAnnotations element + """The OME/StructuredAnnotations element Structured annotations let OME-XML represent metadata from other file formats, for example the tag metadata in TIFF files. The @@ -1000,7 +1083,7 @@ class StructuredAnnotations(dict): callers will be using these to read tag data that's not represented in OME-XML such as the bits per sample and min and max sample values. - ''' + """ def __init__(self, node): self.node = node @@ -1016,8 +1099,9 @@ def __contains__(self, key): return self.has_key(key) def keys(self): - return filter(lambda x: x is not None, - [child.get("ID") for child in self.node]) + return filter( + lambda x: x is not None, [child.get("ID") for child in self.node] + ) def has_key(self, key): for child in self.node: @@ -1026,30 +1110,33 @@ def has_key(self, key): return False def add_original_metadata(self, key, value): - '''Create an original data key/value pair + """Create an original data key/value pair key - the original metadata's key name, for instance OM_PHOTOMETRIC_INTERPRETATION value - the value, for instance, "RGB" returns the ID for the structured annotation. - ''' + """ xml_annotation = ElementTree.SubElement( - self.node, qn(self.ns['sa'], "XMLAnnotation")) + self.node, qn(self.ns["sa"], "XMLAnnotation") + ) node_id = str(uuid.uuid4()) xml_annotation.set("ID", node_id) - xa_value = ElementTree.SubElement(xml_annotation, qn(self.ns['sa'], "Value")) + xa_value = ElementTree.SubElement( + xml_annotation, qn(self.ns["sa"], "Value") + ) ov = ElementTree.SubElement( - xa_value, qn(NS_ORIGINAL_METADATA, "OriginalMetadata")) + xa_value, qn(NS_ORIGINAL_METADATA, "OriginalMetadata") + ) ov_key = ElementTree.SubElement(ov, qn(NS_ORIGINAL_METADATA, "Key")) set_text(ov_key, key) - ov_value = ElementTree.SubElement( - ov, qn(NS_ORIGINAL_METADATA, "Value")) + ov_value = ElementTree.SubElement(ov, qn(NS_ORIGINAL_METADATA, "Value")) set_text(ov_value, value) return node_id def iter_original_metadata(self): - '''An iterator over the original metadata in structured annotations + """An iterator over the original metadata in structured annotations returns (, ()) @@ -1059,7 +1146,7 @@ def iter_original_metadata(self): is the original metadata key, typically one of the OM_* names of a TIFF tag is the value for the metadata - ''' + """ # # Here's the XML we're traversing: # @@ -1074,13 +1161,18 @@ def iter_original_metadata(self): # # # - for annotation_node in self.node.findall(qn(self.ns['sa'], "XMLAnnotation")): + for annotation_node in self.node.findall( + qn(self.ns["sa"], "XMLAnnotation") + ): # annotation_id = annotation_node.get("ID") - for xa_value_node in annotation_node.findall(qn(self.ns['sa'], "Value")): + for xa_value_node in annotation_node.findall( + qn(self.ns["sa"], "Value") + ): # for om_node in xa_value_node.findall( - qn(NS_ORIGINAL_METADATA, "OriginalMetadata")): + qn(NS_ORIGINAL_METADATA, "OriginalMetadata") + ): # key_node = om_node.find(qn(NS_ORIGINAL_METADATA, "Key")) value_node = om_node.find(qn(NS_ORIGINAL_METADATA, "Value")) @@ -1090,35 +1182,38 @@ def iter_original_metadata(self): if key_text is not None and value_text is not None: yield annotation_id, (key_text, value_text) else: - logger.warn("Original metadata was missing key or value:" + om_node.toxml()) + logger.warn( + "Original metadata was missing key or value:" + + om_node.toxml() + ) return def has_original_metadata(self, key): - '''True if there is an original metadata item with the given key''' - return any([k == key - for annotation_id, (k, v) - in self.iter_original_metadata()]) + """True if there is an original metadata item with the given key""" + return any( + [k == key for annotation_id, (k, v) in self.iter_original_metadata()] + ) def get_original_metadata_value(self, key, default=None): - '''Return the value for a particular original metadata key + """Return the value for a particular original metadata key key - key to search for default - default value to return if not found - ''' + """ for annotation_id, (k, v) in self.iter_original_metadata(): if k == key: return v return default def get_original_metadata_refs(self, ids): - '''For a given ID, get the matching original metadata references + """For a given ID, get the matching original metadata references ids - collection of IDs to match returns a dictionary of key to value - ''' + """ d = {} - for annotation_id, (k,v) in self.iter_original_metadata(): + for annotation_id, (k, v) in self.iter_original_metadata(): if annotation_id in ids: d[k] = v return d @@ -1128,13 +1223,14 @@ def OriginalMetadata(self): return OMEXML.OriginalMetadata(self) class OriginalMetadata(dict): - '''View original metadata as a dictionary + """View original metadata as a dictionary Original metadata holds "vendor-specific" metadata including TIFF tag values. - ''' + """ + def __init__(self, sa): - '''Initialized with the structured_annotations class instance''' + """Initialized with the structured_annotations class instance""" self.sa = sa def __getitem__(self, key): @@ -1154,9 +1250,9 @@ def __len__(self): return len(list(self.sa_iter_original_metadata())) def keys(self): - return [key - for annotation_id, (key, value) - in self.sa.iter_original_metadata()] + return [ + key for annotation_id, (key, value) in self.sa.iter_original_metadata() + ] def has_key(self, key): for annotation_id, (k, value) in self.sa.iter_original_metadata(): @@ -1169,38 +1265,41 @@ def iteritems(self): yield key, value class PlatesDucktype(object): - '''It looks like a list of plates''' + """It looks like a list of plates""" + def __init__(self, root): self.root = root self.ns = get_namespaces(self.root) def __getitem__(self, key): - plates = self.root.findall(qn(self.ns['spw'], "Plate")) + plates = self.root.findall(qn(self.ns["spw"], "Plate")) if isinstance(key, slice): return [OMEXML.Plate(plate) for plate in plates[key]] return OMEXML.Plate(plates[key]) def __len__(self): - return len(self.root.findall(qn(self.ns['spw'], "Plate"))) + return len(self.root.findall(qn(self.ns["spw"], "Plate"))) def __iter__(self): - for plate in self.root.iterfind(qn(self.ns['spw'], "Plate")): + for plate in self.root.iterfind(qn(self.ns["spw"], "Plate")): yield OMEXML.Plate(plate) - def newPlate(self, name, plate_id = str(uuid.uuid4())): + def newPlate(self, name, plate_id=str(uuid.uuid4())): new_plate_node = ElementTree.SubElement( - self.root, qn(self.ns['spw'], "Plate")) + self.root, qn(self.ns["spw"], "Plate") + ) new_plate = OMEXML.Plate(new_plate_node) new_plate.ID = plate_id new_plate.Name = name return new_plate class Plate(object): - '''The SPW:Plate element + """The SPW:Plate element This represents the plate element of the SPW schema: http://www.openmicroscopy.org/Schemas/SPW/2007-06/ - ''' + """ + def __init__(self, node): self.node = node self.ns = get_namespaces(self.node) @@ -1244,8 +1343,10 @@ def get_ColumnNamingConvention(self): def set_ColumnNamingConvention(self, value): assert value in (NC_LETTER, NC_NUMBER) self.node.set("ColumnNamingConvention", value) - ColumnNamingConvention = property(get_ColumnNamingConvention, - set_ColumnNamingConvention) + + ColumnNamingConvention = property( + get_ColumnNamingConvention, set_ColumnNamingConvention + ) def get_RowNamingConvention(self): # Consider a default if not defined of NC_LETTER @@ -1254,14 +1355,15 @@ def get_RowNamingConvention(self): def set_RowNamingConvention(self, value): assert value in (NC_LETTER, NC_NUMBER) self.node.set("RowNamingConvention", value) - RowNamingConvention = property(get_RowNamingConvention, - set_RowNamingConvention) + + RowNamingConvention = property(get_RowNamingConvention, set_RowNamingConvention) def get_WellOriginX(self): return get_float_attr(self.node, "WellOriginX") def set_WellOriginX(self, value): self.node.set("WellOriginX", str(value)) + WellOriginX = property(get_WellOriginX, set_WellOriginX) def get_WellOriginY(self): @@ -1269,6 +1371,7 @@ def get_WellOriginY(self): def set_WellOriginY(self, value): self.node.set("WellOriginY", str(value)) + WellOriginY = property(get_WellOriginY, set_WellOriginY) def get_Rows(self): @@ -1288,32 +1391,39 @@ def set_Columns(self, value): Columns = property(get_Columns, set_Columns) def get_Description(self): - description = self.node.find(qn(self.ns['spw'], "Description")) + description = self.node.find(qn(self.ns["spw"], "Description")) if description is None: return None return get_text(description) def set_Description(self, text): - make_text_node(self.node, self.ns['spw'], "Description", text) + make_text_node(self.node, self.ns["spw"], "Description", text) + Description = property(get_Description, set_Description) def get_Well(self): - '''The well dictionary / list''' + """The well dictionary / list""" return OMEXML.WellsDucktype(self) + Well = property(get_Well) def get_well_name(self, well): - '''Get a well's name, using the row and column convention''' - result = "".join([ - "%02d" % (i+1) if convention == NC_NUMBER - else "ABCDEFGHIJKLMNOP"[i] - for i, convention - in ((well.Row, self.RowNamingConvention or NC_LETTER), - (well.Column, self.ColumnNamingConvention or NC_NUMBER))]) + """Get a well's name, using the row and column convention""" + result = "".join( + [ + "%02d" % (i + 1) + if convention == NC_NUMBER + else "ABCDEFGHIJKLMNOP"[i] + for i, convention in ( + (well.Row, self.RowNamingConvention or NC_LETTER), + (well.Column, self.ColumnNamingConvention or NC_NUMBER), + ) + ] + ) return result class WellsDucktype(dict): - '''The WellsDucktype lets you retrieve and create wells + """The WellsDucktype lets you retrieve and create wells The WellsDucktype looks like a dictionary but lets you reference the wells in a plate using indexing. Types of indexes: @@ -1326,17 +1436,18 @@ class WellsDucktype(dict): by ID - e.g. plate.Well["Well:0:0:0"] If the ducktype is unable to parse a well name, it assumes you're using an ID. - ''' + """ + def __init__(self, plate): self.plate_node = plate.node self.plate = plate self.ns = get_namespaces(self.plate_node) def __len__(self): - return len(self.plate_node.findall(qn(self.ns['spw'], "Well"))) + return len(self.plate_node.findall(qn(self.ns["spw"], "Well"))) def __getitem__(self, key): - all_wells = self.plate_node.findall(qn(self.ns['spw'], "Well")) + all_wells = self.plate_node.findall(qn(self.ns["spw"], "Well")) if isinstance(key, slice): return [OMEXML.Well(w) for w in all_wells[key]] if hasattr(key, "__len__") and len(key) == 2: @@ -1357,26 +1468,27 @@ def __getitem__(self, key): return None def __iter__(self): - '''Return the standard name for all wells on the plate + """Return the standard name for all wells on the plate for instance, 'B03' for a well with Row=1, Column=2 for a plate with the standard row and column naming convention - ''' - all_wells = self.plate_node.findall(qn(self.ns['spw'], "Well")) + """ + all_wells = self.plate_node.findall(qn(self.ns["spw"], "Well")) well = OMEXML.Well(None) for w in all_wells: well.node = w yield self.plate.get_well_name(well) - def new(self, row, column, well_id = str(uuid.uuid4())): - '''Create a new well at the given row and column + def new(self, row, column, well_id=str(uuid.uuid4())): + """Create a new well at the given row and column row - index of well's row column - index of well's column well_id - the ID attribute for the well - ''' + """ well_node = ElementTree.SubElement( - self.plate_node, qn(self.ns['spw'], "Well")) + self.plate_node, qn(self.ns["spw"], "Well") + ) well = OMEXML.Well(well_node) well.Row = row well.Column = column @@ -1389,24 +1501,31 @@ def __init__(self, node): def get_Column(self): return get_int_attr(self.node, "Column") + def set_Column(self, value): self.node.set("Column", str(value)) + Column = property(get_Column, set_Column) def get_Row(self): return get_int_attr(self.node, "Row") + def set_Row(self, value): self.node.set("Row", str(value)) + Row = property(get_Row, set_Row) def get_ID(self): return self.node.get("ID") + def set_ID(self, value): self.node.set("ID", value) + ID = property(get_ID, set_ID) def get_Sample(self): return OMEXML.WellSampleDucktype(self.node) + Sample = property(get_Sample) def get_ExternalDescription(self): @@ -1434,59 +1553,64 @@ def set_Color(self, value): Color = property(get_Color, set_Color) class WellSampleDucktype(list): - '''The WellSample elements in a well + """The WellSample elements in a well This is made to look like an indexable list so that you can do things like: wellsamples[0:2] - ''' + """ + def __init__(self, well_node): self.well_node = well_node self.ns = get_namespaces(self.well_node) def __len__(self): - return len(self.well_node.findall(qn(self.ns['spw'], "WellSample"))) + return len(self.well_node.findall(qn(self.ns["spw"], "WellSample"))) def __getitem__(self, key): - all_samples = self.well_node.findall(qn(self.ns['spw'], "WellSample")) + all_samples = self.well_node.findall(qn(self.ns["spw"], "WellSample")) if isinstance(key, slice): - return [OMEXML.WellSample(s) - for s in all_samples[key]] + return [OMEXML.WellSample(s) for s in all_samples[key]] return OMEXML.WellSample(all_samples[int(key)]) def __iter__(self): - '''Iterate through the well samples.''' - all_samples = self.well_node.findall(qn(self.ns['spw'], "WellSample")) + """Iterate through the well samples.""" + all_samples = self.well_node.findall(qn(self.ns["spw"], "WellSample")) for s in all_samples: yield OMEXML.WellSample(s) - def new(self, wellsample_id = str(uuid.uuid4()), index = None): - '''Create a new well sample - ''' + def new(self, wellsample_id=str(uuid.uuid4()), index=None): + """Create a new well sample""" if index is None: index = reduce(max, [s.Index for s in self], -1) + 1 new_node = ElementTree.SubElement( - self.well_node, qn(self.ns['spw'], "WellSample")) + self.well_node, qn(self.ns["spw"], "WellSample") + ) s = OMEXML.WellSample(new_node) s.ID = wellsample_id s.Index = index class WellSample(object): - '''The WellSample is a location within a well''' + """The WellSample is a location within a well""" + def __init__(self, node): self.node = node self.ns = get_namespaces(self.node) def get_ID(self): return self.node.get("ID") + def set_ID(self, value): self.node.set("ID", value) + ID = property(get_ID, set_ID) def get_PositionX(self): return get_float_attr(self.node, "PositionX") + def set_PositionX(self, value): self.node.set("PositionX", str(value)) + PositionX = property(get_PositionX, set_PositionX) def get_PositionY(self): @@ -1494,6 +1618,7 @@ def get_PositionY(self): def set_PositionY(self, value): self.node.set("PositionY", str(value)) + PositionY = property(get_PositionY, set_PositionY) def get_Timepoint(self): @@ -1503,6 +1628,7 @@ def set_Timepoint(self, value): if isinstance(value, datetime.datetime): value = value.isoformat() self.node.set("Timepoint", value) + Timepoint = property(get_Timepoint, set_Timepoint) def get_Index(self): @@ -1514,22 +1640,22 @@ def set_Index(self, value): Index = property(get_Index, set_Index) def get_ImageRef(self): - '''Get the ID of the image of this site''' - ref = self.node.find(qn(self.ns['spw'], "ImageRef")) + """Get the ID of the image of this site""" + ref = self.node.find(qn(self.ns["spw"], "ImageRef")) if ref is None: return None return ref.get("ID") def set_ImageRef(self, value): - '''Add a reference to the image of this site''' - ref = self.node.find(qn(self.ns['spw'], "ImageRef")) + """Add a reference to the image of this site""" + ref = self.node.find(qn(self.ns["spw"], "ImageRef")) if ref is None: - ref = ElementTree.SubElement(self.node, qn(self.ns['spw'], "ImageRef")) + ref = ElementTree.SubElement(self.node, qn(self.ns["spw"], "ImageRef")) ref.set("ID", value) + ImageRef = property(get_ImageRef, set_ImageRef) class ROIRef(object): - def __init__(self, node): self.node = node self.ns = get_namespaces(self.node) @@ -1538,36 +1664,38 @@ def get_ID(self): return self.node.get("ID") def set_ID(self, value): - ''' + """ ID will automatically be in the format "ROI:value" and must match the ROI ID (that uses the same formatting) - ''' + """ self.node.set("ID", "ROI:" + str(value)) ID = property(get_ID, set_ID) def get_roi_count(self): - return len(self.root_node.findall(qn(self.ns['ome'], "ROI"))) + return len(self.root_node.findall(qn(self.ns["ome"], "ROI"))) def set_roi_count(self, value): - '''Add or remove roi nodes as needed''' + """Add or remove roi nodes as needed""" assert value > 0 root = self.root_node if self.roi_count > value: - roi_nodes = root.find(qn(self.ns['ome'], "ROI")) + roi_nodes = root.find(qn(self.ns["ome"], "ROI")) for roi_node in roi_nodes[value:]: root.remove(roi_node) - while(self.roi_count < value): + while self.roi_count < value: iteration = self.roi_count - 1 - new_roi = self.ROI(ElementTree.SubElement(root, qn(self.ns['ome'], "ROI"))) + new_roi = self.ROI(ElementTree.SubElement(root, qn(self.ns["ome"], "ROI"))) new_roi.ID = str(iteration) new_roi.Name = "Marker " + str(iteration) new_Union = self.Union( - ElementTree.SubElement(new_roi.node, qn(self.ns['ome'], "Union"))) + ElementTree.SubElement(new_roi.node, qn(self.ns["ome"], "Union")) + ) new_Rectangle = self.Rectangle( - ElementTree.SubElement(new_Union.node, qn(self.ns['ome'], "Rectangle"))) + ElementTree.SubElement(new_Union.node, qn(self.ns["ome"], "Rectangle")) + ) new_Rectangle.set_ID("Shape:" + str(iteration) + ":0") new_Rectangle.set_TheZ(0) new_Rectangle.set_TheC(0) @@ -1583,11 +1711,10 @@ def set_roi_count(self, value): roi_count = property(get_roi_count, set_roi_count) def roi(self, index=0): - '''Return an ROI node by index''' - return self.ROI(self.root_node.findall(qn(self.ns['ome'], "ROI"))[index]) + """Return an ROI node by index""" + return self.ROI(self.root_node.findall(qn(self.ns["ome"], "ROI"))[index]) class ROI(object): - def __init__(self, node): self.node = node self.ns = get_namespaces(self.node) @@ -1596,11 +1723,11 @@ def get_ID(self): return self.node.get("ID") def set_ID(self, value): - ''' + """ ID will automatically be in the format "ROI:value" and must match the ROIRef ID (that uses the same formatting) - ''' + """ self.node.set("ID", "ROI:" + str(value)) ID = property(get_ID, set_ID) @@ -1615,21 +1742,19 @@ def set_Name(self, value): @property def Union(self): - '''The OME/ROI/Union element.''' - return OMEXML.Union(self.node.find(qn(self.ns['ome'], "Union"))) + """The OME/ROI/Union element.""" + return OMEXML.Union(self.node.find(qn(self.ns["ome"], "Union"))) class Union(object): - def __init__(self, node): self.node = node self.ns = get_namespaces(self.node) def Rectangle(self): - '''The OME/ROI/Union element. Currently only rectangle ROIs are available.''' - return OMEXML.Rectangle(self.node.find(qn(self.ns['ome'], "Rectangle"))) + """The OME/ROI/Union element. Currently only rectangle ROIs are available.""" + return OMEXML.Rectangle(self.node.find(qn(self.ns["ome"], "Rectangle"))) class Rectangle(object): - def __init__(self, node): self.node = node self.ns = get_namespaces(self.node) @@ -1654,12 +1779,12 @@ def get_StrokeWidth(self): return self.node.get("StrokeWidth") def set_StrokeWidth(self, value): - ''' + """ Colour is set using RGBA to integer conversion calculated using function from: https://docs.openmicroscopy.org/omero/5.5.1/developers/Python.html RGB colours: Red=-16776961, Green=16711935, Blue=65535 - ''' + """ self.node.set("StrokeWidth", str(value)) StrokeWidth = property(get_StrokeWidth, set_StrokeWidth) @@ -1705,7 +1830,7 @@ def set_Y(self, value): Y = property(get_Y, set_Y) def get_TheZ(self): - '''The Z index of the plane''' + """The Z index of the plane""" return get_int_attr(self.node, "TheZ") def set_TheZ(self, value): @@ -1714,7 +1839,7 @@ def set_TheZ(self, value): TheZ = property(get_TheZ, set_TheZ) def get_TheC(self): - '''The channel index of the plane''' + """The channel index of the plane""" return get_int_attr(self.node, "TheC") def set_TheC(self, value): @@ -1723,7 +1848,7 @@ def set_TheC(self, value): TheC = property(get_TheC, set_TheC) def get_TheT(self): - '''The T index of the plane''' + """The T index of the plane""" return get_int_attr(self.node, "TheT") def set_TheT(self, value): diff --git a/cellacdc/bioformats/tests/locate_jars.py b/cellacdc/bioformats/tests/locate_jars.py index 9c284289e..da3a162c1 100755 --- a/cellacdc/bioformats/tests/locate_jars.py +++ b/cellacdc/bioformats/tests/locate_jars.py @@ -11,9 +11,9 @@ jars = bioformats.JARS print(jars) jv.start_vm(class_path=jars) -paths = jv.JClassWrapper('java.lang.System').getProperty('java.class.path').split(";") +paths = jv.JClassWrapper("java.lang.System").getProperty("java.class.path").split(";") for path in paths: - print("%s: %s" %("exists" if os.path.isfile(path) else "missing", path)) + print("%s: %s" % ("exists" if os.path.isfile(path) else "missing", path)) jv.kill_vm() diff --git a/cellacdc/cca_functions.py b/cellacdc/cca_functions.py index cd87389b5..51b3c8a71 100755 --- a/cellacdc/cca_functions.py +++ b/cellacdc/cca_functions.py @@ -15,71 +15,74 @@ from typing import Iterable from . import GUI_INSTALLED + if GUI_INSTALLED: from qtpy.QtWidgets import QFileDialog from . import widgets from . import _run from . import load, cca_df_colnames -from . import myutils, prompts, html_utils, printl +from . import utils, prompts, html_utils, printl default_summable_columns = ( - 'cell_area_um2', - 'cell_vol_fl', - 'cell_vol_vox', - 'cell_area_pxl', - 'num_spots', - 'ref_ch_vol_um3', - 'ref_ch_num_fragments', - 'ref_ch_vol_vox' + "cell_area_um2", + "cell_vol_fl", + "cell_vol_vox", + "cell_area_pxl", + "num_spots", + "ref_ch_vol_um3", + "ref_ch_num_fragments", + "ref_ch_vol_vox", ) + def configuration_dialog(): app, _ = _run._setup_app(splashscreen=False) - + continue_selection = True data_dirs = [] positions = [] while continue_selection: - MostRecentPath = myutils.getMostRecentPath() + MostRecentPath = utils.getMostRecentPath() data_dir = QFileDialog.getExistingDirectory( - None, 'Select experiment folder containing Position_n folders ', - MostRecentPath + None, + "Select experiment folder containing Position_n folders ", + MostRecentPath, ) if not data_dir: continue_selection = False break - myutils.addToRecentPaths(data_dir) + utils.addToRecentPaths(data_dir) foldername = os.path.basename(data_dir) - if foldername == 'Images': + if foldername == "Images": pos_path = os.path.dirname(data_dir) data_dir = os.path.dirname(pos_path) pos = [os.path.basename(pos_path)] - elif foldername.find('Position_') != -1: + elif foldername.find("Position_") != -1: pos_path = data_dir data_dir = os.path.dirname(data_dir) pos = [os.path.basename(pos_path)] else: - available_pos = myutils.get_pos_foldernames(data_dir) + available_pos = utils.get_pos_foldernames(data_dir) if not available_pos: - print('******************************') - print('Selected folder does not contain any Position folders.') + print("******************************") + print("Selected folder does not contain any Position folders.") print(f'Selected folder: "{data_dir}"') - print('******************************') + print("******************************") raise FileNotFoundError win = widgets.QDialogListbox( - 'Position Selection', - 'Select which position(s) you want to analyse', - available_pos + "Position Selection", + "Select which position(s) you want to analyse", + available_pos, ) win.show() win.exec_() if win.cancel: - print('******************************') - print('Execution aborted by the user') - print('******************************') + print("******************************") + print("Execution aborted by the user") + print("******************************") raise InterruptedError pos = win.selectedItemsText @@ -87,140 +90,160 @@ def configuration_dialog(): positions.append(pos) msg = widgets.myMessageBox() txt = html_utils.paragraph( - 'Do you wish to select Positions from other experiments?' + "Do you wish to select Positions from other experiments?" ) yes, no = msg.question( - None, 'Continue selection?', txt, buttonsTexts=(' Yes ', ' No ') + None, "Continue selection?", txt, buttonsTexts=(" Yes ", " No ") ) continue_selection = msg.clickedButton == yes if len(data_dirs) == 0: - print('******************************') + print("******************************") print("No positions selected!") - print('******************************') + print("******************************") raise IndexError("No positions selected!") return data_dirs, positions, app + def find_available_channels(filenames, first_pos_dir): ch_name_selector = prompts.select_channel_name() - ch_names, warn = ch_name_selector.get_available_channels( - filenames, first_pos_dir - ) + ch_names, warn = ch_name_selector.get_available_channels(filenames, first_pos_dir) return ch_names, ch_name_selector.basename + def get_segm_endname(images_path, basename): segm_files = load.get_segm_files(images_path) - segm_endnames = load.get_endnames( - basename, segm_files - ) + segm_endnames = load.get_endnames(basename, segm_files) if not segm_endnames: msg = widgets.myMessageBox() txt = html_utils.paragraph(f""" The following position does not contain valid segmentation files.

{images_path}
""") - msg.critical(None, 'Segmentation file(s) not found', txt) + msg.critical(None, "Segmentation file(s) not found", txt) raise FileNotFoundError(f'Segmentation files not found in "{images_path}"') if len(segm_endnames) == 1: return segm_endnames[0] - + selectSegmWin = widgets.QDialogListbox( - 'Select segmentation file', - 'Select segmentation file to use as ROI:\n', - segm_endnames, multiSelection=False, parent=None + "Select segmentation file", + "Select segmentation file to use as ROI:\n", + segm_endnames, + multiSelection=False, + parent=None, ) selectSegmWin.exec_() if selectSegmWin.cancel: - raise FileNotFoundError(f'Segmentation file selection aborted by the user.') - + raise FileNotFoundError(f"Segmentation file selection aborted by the user.") + return selectSegmWin.selectedItemsText[0] - - + + def calculate_downstream_data( - file_names, - image_folders, - positions, - channels, - segm_endname, - force_recalculation=False, - calculate_fluo_metrics=True, - save_features_to_acdc_df=False, - ): + file_names, + image_folders, + positions, + channels, + segm_endname, + force_recalculation=False, + calculate_fluo_metrics=True, + save_features_to_acdc_df=False, +): no_of_channels = len(channels) overall_df = pd.DataFrame() for file_idx, file in enumerate(file_names): for pos_idx, pos_dir in enumerate(image_folders[file_idx]): - channel_data = ('placeholder')*no_of_channels - print(f'Load files for {file}, {positions[file_idx][pos_idx]}...') + channel_data = ("placeholder") * no_of_channels + print(f"Load files for {file}, {positions[file_idx][pos_idx]}...") acdc_df_path = None try: *channel_data, seg_mask, cc_data, metadata, cc_props, acdc_df_path = ( _load_files( - pos_dir, channels, segm_endname, - load_channels_data=calculate_fluo_metrics + pos_dir, + channels, + segm_endname, + load_channels_data=calculate_fluo_metrics, ) ) except TypeError: - print(f'File {file}, position {positions[file_idx][pos_idx]} skipped due to missing segmentation mask/CC annotations.') + print( + f"File {file}, position {positions[file_idx][pos_idx]} skipped due to missing segmentation mask/CC annotations." + ) continue - print(f'Number of cells in position: {len(cc_data.Cell_ID.unique())}') - print(f'Number of annotated frames in position: {cc_data.frame_i.max()+1}') + print(f"Number of cells in position: {len(cc_data.Cell_ID.unique())}") + print( + f"Number of annotated frames in position: {cc_data.frame_i.max() + 1}" + ) cc_data = _rename_columns(cc_data) is_timelapse_data, is_zstack_data = False, False - if int(metadata.loc['SizeT'])>1: - is_timelapse_data=True - if int(metadata.loc['SizeZ'])>1: - is_zstack_data=True + if int(metadata.loc["SizeT"]) > 1: + is_timelapse_data = True + if int(metadata.loc["SizeZ"]) > 1: + is_zstack_data = True if cc_props is not None and not force_recalculation: - print('Cell Cycle property data already existing, loaded from disk...') - overall_df = pd.concat([overall_df, cc_props], ignore_index=True).reset_index(drop=True) + print("Cell Cycle property data already existing, loaded from disk...") + overall_df = pd.concat( + [overall_df, cc_props], ignore_index=True + ).reset_index(drop=True) else: - print(f'Calculate regionprops on each frame based on Segmentation...') - rp_df = _calculate_rp_df(seg_mask, is_timelapse_data, is_zstack_data, metadata, max_frame=cc_data.frame_i.max()+1) - print(f'Calculate signal metrics for every channel and cell...') + print(f"Calculate regionprops on each frame based on Segmentation...") + rp_df = _calculate_rp_df( + seg_mask, + is_timelapse_data, + is_zstack_data, + metadata, + max_frame=cc_data.frame_i.max() + 1, + ) + print(f"Calculate signal metrics for every channel and cell...") flu_signal_df = _calculate_flu_signal( seg_mask, channel_data, channels, cc_data, is_timelapse_data, - is_zstack_data + is_zstack_data, ) temp_df = cc_data.merge( - rp_df, on=['frame_i', 'Cell_ID'], how='left', - suffixes=('_gui', '') + rp_df, on=["frame_i", "Cell_ID"], how="left", suffixes=("_gui", "") ) temp_df = temp_df.merge( - flu_signal_df, on=['frame_i', 'Cell_ID'], how='left', - suffixes=('_gui', '') + flu_signal_df, + on=["frame_i", "Cell_ID"], + how="left", + suffixes=("_gui", ""), ) # calculate amount of corrected signal by multiplying mean with area if is_timelapse_data: for channel in channels: - temp_df[f'{channel}_corrected_amount'] = ( - temp_df[f'{channel}_corrected_mean'] - * temp_df['area'] + temp_df[f"{channel}_corrected_amount"] = ( + temp_df[f"{channel}_corrected_mean"] * temp_df["area"] ) try: - temp_df[f'{channel}_corrected_concentration'] = ( - temp_df[f'{channel}_corrected_amount'] - / temp_df['cell_vol_fl'] + temp_df[f"{channel}_corrected_concentration"] = ( + temp_df[f"{channel}_corrected_amount"] + / temp_df["cell_vol_fl"] ) except KeyError: - print(f'Volume is missing in acdc output, NaNs inserted in concentration columns of channel {channel}') - temp_df[f'{channel}_corrected_concentration'] = None - temp_df['max_frame_pos'] = cc_data.frame_i.max() - temp_df['file'] = file - temp_df['selection_subset'] = file_idx - temp_df['position'] = positions[file_idx][pos_idx] - temp_df['directory'] = pos_dir - print('Saving calculated data for next time...') - files_in_curr_dir = myutils.listdir(pos_dir) + print( + f"Volume is missing in acdc output, NaNs inserted in concentration columns of channel {channel}" + ) + temp_df[f"{channel}_corrected_concentration"] = None + temp_df["max_frame_pos"] = cc_data.frame_i.max() + temp_df["file"] = file + temp_df["selection_subset"] = file_idx + temp_df["position"] = positions[file_idx][pos_idx] + temp_df["directory"] = pos_dir + print("Saving calculated data for next time...") + files_in_curr_dir = utils.listdir(pos_dir) common_prefix = _determine_common_prefix(files_in_curr_dir) - save_path = os.path.join(pos_dir, f'{common_prefix}cca_properties_downstream.csv') + save_path = os.path.join( + pos_dir, f"{common_prefix}cca_properties_downstream.csv" + ) temp_df.to_csv(save_path, index=False) - overall_df = pd.concat([overall_df, temp_df], ignore_index=True).reset_index(drop=True) + overall_df = pd.concat( + [overall_df, temp_df], ignore_index=True + ).reset_index(drop=True) # if save_features_to_acdc_df: # acdc_df = cc_data.set_index(['frame_i', 'Cell_ID']) @@ -228,7 +251,7 @@ def calculate_downstream_data( # acdc_df_path # import pdb; pdb.set_trace() - print('Done!') + print("Done!") return overall_df, is_timelapse_data, is_zstack_data @@ -237,108 +260,148 @@ def calculate_relatives_data(overall_df, channels): overall_df_rel = overall_df.copy() overall_df = overall_df.merge( overall_df_rel, - how='left', - left_on=['frame_i', 'relative_ID', 'max_frame_pos', 'file', 'selection_subset', 'position', 'directory'], - right_on=['frame_i', 'Cell_ID', 'max_frame_pos', 'file', 'selection_subset', 'position', 'directory'], - suffixes = ('', '_rel') + how="left", + left_on=[ + "frame_i", + "relative_ID", + "max_frame_pos", + "file", + "selection_subset", + "position", + "directory", + ], + right_on=[ + "frame_i", + "Cell_ID", + "max_frame_pos", + "file", + "selection_subset", + "position", + "directory", + ], + suffixes=("", "_rel"), ) # for every channel, calculate amount from mother and bud cells combined for ch in channels: try: - overall_df[f'{ch}_combined_amount_mother_bud'] = overall_df.apply( + overall_df[f"{ch}_combined_amount_mother_bud"] = overall_df.apply( lambda x: ( - x.loc[f'{ch}_corrected_amount'] - + x.loc[f'{ch}_corrected_amount_rel'] - if x.loc['cell_cycle_stage']=='S' - else x.loc[f'{ch}_corrected_amount'] - ), - axis=1 + x.loc[f"{ch}_corrected_amount"] + + x.loc[f"{ch}_corrected_amount_rel"] + if x.loc["cell_cycle_stage"] == "S" + else x.loc[f"{ch}_corrected_amount"] + ), + axis=1, ) - overall_df[f'{ch}_combined_raw_sum_mother_bud'] = overall_df.apply( + overall_df[f"{ch}_combined_raw_sum_mother_bud"] = overall_df.apply( lambda x: ( - x.loc[f'{ch}_raw_sum'] - + x.loc[f'{ch}_raw_sum_rel'] - if x.loc['cell_cycle_stage']=='S' - else x.loc[f'{ch}_raw_sum'] - ), - axis=1 + x.loc[f"{ch}_raw_sum"] + x.loc[f"{ch}_raw_sum_rel"] + if x.loc["cell_cycle_stage"] == "S" + else x.loc[f"{ch}_raw_sum"] + ), + axis=1, ) except KeyError: continue - overall_df['combined_mother_bud_volume'] = overall_df.apply( - lambda x: x.loc['cell_vol_fl']+x.loc['cell_vol_fl_rel'] if\ - x.loc['cell_cycle_stage']=='S' else\ - x.loc['cell_vol_fl'], - axis=1 + overall_df["combined_mother_bud_volume"] = overall_df.apply( + lambda x: ( + x.loc["cell_vol_fl"] + x.loc["cell_vol_fl_rel"] + if x.loc["cell_cycle_stage"] == "S" + else x.loc["cell_vol_fl"] + ), + axis=1, ) return overall_df def calculate_per_phase_quantities(overall_df, group_cols, channels): # group by group columns, aggregate some other columns - phase_grouped = overall_df.sort_values( - 'frame_i' - ).groupby(group_cols).agg( - # perform some calculations relating to the whole phase: - phase_area_growth=('cell_area_um2', lambda x: x.iloc[-1]-x.iloc[0]), - phase_volume_growth=('cell_vol_fl', lambda x: x.iloc[-1]-x.iloc[0]), - phase_area_at_beginning=('cell_area_um2', 'first'), - phase_volume_at_beginning=('cell_vol_fl', 'first'), - phase_volume_at_end=('cell_vol_fl', 'last'), - phase_daughter_area_growth=('cell_area_um2_rel', lambda x: x.iloc[-1]-x.iloc[0]), - phase_daughter_volume_growth=('cell_vol_fl_rel', lambda x: x.iloc[-1]-x.iloc[0]), - phase_length=('frame_i', lambda x: max(x)-min(x)), - phase_begin = ('frame_i', 'min'), - phase_end = ('frame_i', 'max'), - phase_combined_volume_at_end = ('combined_mother_bud_volume','last') - ).reset_index() + phase_grouped = ( + overall_df.sort_values("frame_i") + .groupby(group_cols) + .agg( + # perform some calculations relating to the whole phase: + phase_area_growth=("cell_area_um2", lambda x: x.iloc[-1] - x.iloc[0]), + phase_volume_growth=("cell_vol_fl", lambda x: x.iloc[-1] - x.iloc[0]), + phase_area_at_beginning=("cell_area_um2", "first"), + phase_volume_at_beginning=("cell_vol_fl", "first"), + phase_volume_at_end=("cell_vol_fl", "last"), + phase_daughter_area_growth=( + "cell_area_um2_rel", + lambda x: x.iloc[-1] - x.iloc[0], + ), + phase_daughter_volume_growth=( + "cell_vol_fl_rel", + lambda x: x.iloc[-1] - x.iloc[0], + ), + phase_length=("frame_i", lambda x: max(x) - min(x)), + phase_begin=("frame_i", "min"), + phase_end=("frame_i", "max"), + phase_combined_volume_at_end=("combined_mother_bud_volume", "last"), + ) + .reset_index() + ) # calculate some quantities in a for loop for all available channels and merge results. phase_grouped_flu = pd.DataFrame(columns=group_cols) for ch in channels: - if f'{ch}_corrected_mean' in overall_df.columns: - flu_temp = overall_df.sort_values( - 'frame_i' - ).groupby(group_cols).agg({ - # perform some calculations on flu data: - f'{ch}_corrected_amount': 'first', - f'{ch}_corrected_mean': 'first', - f'{ch}_corrected_concentration': ['first','last'], - f'{ch}_combined_amount_mother_bud': ['first','last'] - }).reset_index() + if f"{ch}_corrected_mean" in overall_df.columns: + flu_temp = ( + overall_df.sort_values("frame_i") + .groupby(group_cols) + .agg( + { + # perform some calculations on flu data: + f"{ch}_corrected_amount": "first", + f"{ch}_corrected_mean": "first", + f"{ch}_corrected_concentration": ["first", "last"], + f"{ch}_combined_amount_mother_bud": ["first", "last"], + } + ) + .reset_index() + ) # collapse multiindex into column name with aggregation as suffix - flu_temp.columns = ['_'.join(col) if col[1]!='' else col[0] for col in flu_temp.columns.values] + flu_temp.columns = [ + "_".join(col) if col[1] != "" else col[0] + for col in flu_temp.columns.values + ] # rename columns into meaningful names - flu_temp = flu_temp.rename({ - f'{ch}_corrected_amount_first': f'phase_{ch}_amount_at_beginning', - f'{ch}_corrected_mean_first': f'phase_{ch}_mean_at_beginning', - f'{ch}_corrected_concentration_first': f'phase_{ch}_concentration_at_beginning', - f'{ch}_corrected_concentration_last': f'phase_{ch}_concentration_at_end', - f'{ch}_combined_amount_mother_bud_first': f'phase_{ch}_combined_amount_at_beginning', - f'{ch}_combined_amount_mother_bud_last': f'phase_{ch}_combined_amount_at_end', - }, axis=1) - phase_grouped_flu = phase_grouped_flu.merge(flu_temp, how='right', on=group_cols, suffixes=('','')) + flu_temp = flu_temp.rename( + { + f"{ch}_corrected_amount_first": f"phase_{ch}_amount_at_beginning", + f"{ch}_corrected_mean_first": f"phase_{ch}_mean_at_beginning", + f"{ch}_corrected_concentration_first": f"phase_{ch}_concentration_at_beginning", + f"{ch}_corrected_concentration_last": f"phase_{ch}_concentration_at_end", + f"{ch}_combined_amount_mother_bud_first": f"phase_{ch}_combined_amount_at_beginning", + f"{ch}_combined_amount_mother_bud_last": f"phase_{ch}_combined_amount_at_end", + }, + axis=1, + ) + phase_grouped_flu = phase_grouped_flu.merge( + flu_temp, how="right", on=group_cols, suffixes=("", "") + ) # detect complete cell cycle phases and complete cell cycles temp = np.logical_and( phase_grouped.phase_begin > 0, - phase_grouped.phase_end < phase_grouped.max_frame_pos + phase_grouped.phase_end < phase_grouped.max_frame_pos, ) # this or is for disappearing cells - if 'max_t' in overall_df.columns: + if "max_t" in overall_df.columns: complete_phase_indices = np.logical_and( - temp, - phase_grouped.phase_end < phase_grouped.max_t + temp, phase_grouped.phase_end < phase_grouped.max_t ) else: complete_phase_indices = temp - phase_grouped['complete_phase'] = complete_phase_indices.astype(int) + phase_grouped["complete_phase"] = complete_phase_indices.astype(int) no_of_compl_phases_per_cycle = phase_grouped.groupby( - ['Cell_ID', 'generation_num', 'position', 'file'] - )['complete_phase'].transform('sum') + ["Cell_ID", "generation_num", "position", "file"] + )["complete_phase"].transform("sum") complete_cycle_indices = no_of_compl_phases_per_cycle == 2 - phase_grouped['complete_cycle'] = complete_cycle_indices.astype(int) + phase_grouped["complete_cycle"] = complete_cycle_indices.astype(int) # join phase-grouped data with - phase_grouped = phase_grouped.merge(phase_grouped_flu, how='left', on=group_cols, suffixes=('','')) + phase_grouped = phase_grouped.merge( + phase_grouped_flu, how="left", on=group_cols, suffixes=("", "") + ) return phase_grouped @@ -348,9 +411,8 @@ def _determine_common_prefix(filenames): # Determine the basename based on intersection of all .tif _, ext = os.path.splitext(file) sm = difflib.SequenceMatcher(None, file, basename) - i, j, k = sm.find_longest_match(0, len(file), - 0, len(basename)) - basename = file[i:i+k] + i, j, k = sm.find_longest_match(0, len(file), 0, len(basename)) + basename = file[i : i + k] return basename @@ -377,7 +439,7 @@ def _auto_rescale_intensity(img, perc=0.01, clip_min=False): scaled to [0,1] afterwards """ if perc > 0: - vmin, vmax = np.percentile(img, q=(perc, 100-perc)) + vmin, vmax = np.percentile(img, q=(perc, 100 - perc)) clip_min_indices = img < vmin clip_max_indices = img > vmax if clip_min: @@ -385,83 +447,90 @@ def _auto_rescale_intensity(img, perc=0.01, clip_min=False): img[clip_max_indices] = vmax else: vmin, vmax = img.min(), img.max() - scaled_img = (img-vmin)/(vmax-vmin) + scaled_img = (img - vmin) / (vmax - vmin) return scaled_img -def load_acdc_output_only( - file_names, - image_folders, - positions, - segm_endnames - ): + +def load_acdc_output_only(file_names, image_folders, positions, segm_endnames): """ Function to load only the acdc output. Use when fluorescent file is too big to load into RAM. #TODO: move to cca_functions """ - + overall_df = pd.DataFrame() for file_idx, file in enumerate(file_names): - acdc_output_endname = segm_endnames[file_idx].replace('segm', 'acdc_output') + acdc_output_endname = segm_endnames[file_idx].replace("segm", "acdc_output") for pos_idx, pos_dir in enumerate(image_folders[file_idx]): try: cc_stage_path = glob.glob( - os.path.join(f'{pos_dir}', f'*{acdc_output_endname}.csv') + os.path.join(f"{pos_dir}", f"*{acdc_output_endname}.csv") )[0] except IndexError: - cc_stage_path = glob.glob(os.path.join(f'{pos_dir}', '*cc_stage.csv'))[0] + cc_stage_path = glob.glob(os.path.join(f"{pos_dir}", "*cc_stage.csv"))[ + 0 + ] temp_df = pd.read_csv(cc_stage_path) - temp_df['max_frame_pos'] = temp_df.frame_i.max() - temp_df['file'] = file - temp_df['selection_subset'] = file_idx - temp_df['position'] = positions[file_idx][pos_idx] - temp_df['directory'] = pos_dir + temp_df["max_frame_pos"] = temp_df.frame_i.max() + temp_df["file"] = file + temp_df["selection_subset"] = file_idx + temp_df["position"] = positions[file_idx][pos_idx] + temp_df["directory"] = pos_dir overall_df = pd.concat([overall_df, temp_df]) return overall_df + def _load_channels_data(file_dir, channel_names, no_of_aligned_files): channel_files = [] if no_of_aligned_files > 0: for channel in channel_names: try: - ch_aligned_path = glob.glob(os.path.join(f'{file_dir}', f'*{channel}_aligned.npz'))[0] - channel_files.append(np.load(ch_aligned_path)['arr_0']) + ch_aligned_path = glob.glob( + os.path.join(f"{file_dir}", f"*{channel}_aligned.npz") + )[0] + channel_files.append(np.load(ch_aligned_path)["arr_0"]) except IndexError: try: - ch_aligned_path = glob.glob(os.path.join(f'{file_dir}', f'*{channel}_aligned.npy'))[0] + ch_aligned_path = glob.glob( + os.path.join(f"{file_dir}", f"*{channel}_aligned.npy") + )[0] channel_files.append(np.load(ch_aligned_path)) except IndexError: - print(f'Could not find an aligned file for channel {channel}') - print(f'Resulting data will not contain fluorescent data for this channel') + print(f"Could not find an aligned file for channel {channel}") + print( + f"Resulting data will not contain fluorescent data for this channel" + ) channel_files.append(None) else: for channel in channel_names: try: - ch_not_aligned_path = ( - glob.glob(os.path.join(f'{file_dir}', f'*{channel}.tif'))[0] - ) + ch_not_aligned_path = glob.glob( + os.path.join(f"{file_dir}", f"*{channel}.tif") + )[0] channel_files.append(imread(ch_not_aligned_path)) except IndexError: - print(f'Could not find any file for channel {channel}') - print(f'Resulting data will not contain fluorescent data for this channel') + print(f"Could not find any file for channel {channel}") + print( + f"Resulting data will not contain fluorescent data for this channel" + ) channel_files.append(None) return channel_files + def _load_files(file_dir, channels, segm_endname, load_channels_data=True): """ Function to load files of all given channels and the corresponding segmentation masks. Check first if aligned files are available and use them if so. """ - acdc_output_endname = segm_endname.replace('segm', 'acdc_output') - no_of_aligned_files = len( - glob.glob(os.path.join(f'{file_dir}', '*aligned.npz')) + acdc_output_endname = segm_endname.replace("segm", "acdc_output") + no_of_aligned_files = len(glob.glob(os.path.join(f"{file_dir}", "*aligned.npz"))) + seg_mask_available = ( + len(glob.glob(os.path.join(f"{file_dir}", f"*_{segm_endname}.npz"))) > 0 ) - seg_mask_available = len( - glob.glob(os.path.join(f'{file_dir}', f'*_{segm_endname}.npz')) - ) > 0 acdc_output_available = ( - len(glob.glob(os.path.join(f'{file_dir}', f'*{acdc_output_endname}.csv'))) - + len(glob.glob(os.path.join(f'{file_dir}', '*cc_stage*'))) > 0 + len(glob.glob(os.path.join(f"{file_dir}", f"*{acdc_output_endname}.csv"))) + + len(glob.glob(os.path.join(f"{file_dir}", "*cc_stage*"))) + > 0 ) if not (seg_mask_available and acdc_output_available): return None @@ -473,97 +542,145 @@ def _load_files(file_dir, channels, segm_endname, load_channels_data=True): # append segmentation file try: segm_file_path = glob.glob( - os.path.join(f'{file_dir}', f'*_{segm_endname}.npz') + os.path.join(f"{file_dir}", f"*_{segm_endname}.npz") )[0] - channel_files.append(np.load(segm_file_path)['arr_0']) + channel_files.append(np.load(segm_file_path)["arr_0"]) except IndexError: - segm_file_path = glob.glob(os.path.join(f'{file_dir}', '*_segm.npy'))[0] + segm_file_path = glob.glob(os.path.join(f"{file_dir}", "*_segm.npy"))[0] # assume segmentation mask to be .npy channel_files.append(np.load(segm_file_path)) # append cc-data try: cc_stage_path = glob.glob( - os.path.join(f'{file_dir}', f'*{acdc_output_endname}.csv') + os.path.join(f"{file_dir}", f"*{acdc_output_endname}.csv") )[0] except IndexError: - cc_stage_path = glob.glob(os.path.join(f'{file_dir}', '*cc_stage.csv'))[0] + cc_stage_path = glob.glob(os.path.join(f"{file_dir}", "*cc_stage.csv"))[0] # assume cell cycle output of ACDC to be .csv channel_files.append(pd.read_csv(cc_stage_path)) # append metadata if available, else append None - if len(glob.glob(os.path.join(f'{file_dir}', '*metadata*'))) > 0: - metadata_path = glob.glob(os.path.join(f'{file_dir}', '*metadata.csv'))[0] + if len(glob.glob(os.path.join(f"{file_dir}", "*metadata*"))) > 0: + metadata_path = glob.glob(os.path.join(f"{file_dir}", "*metadata.csv"))[0] # assume calculated metadata to be .csv - channel_files.append(pd.read_csv(metadata_path).set_index('Description')) + channel_files.append(pd.read_csv(metadata_path).set_index("Description")) else: channel_files.append(None) # append cc-properties if available, else append None - if len(glob.glob(os.path.join(f'{file_dir}', '*_downstream*'))) > 0: - cc_props_path = glob.glob(os.path.join(f'{file_dir}', '*_downstream*'))[0] + if len(glob.glob(os.path.join(f"{file_dir}", "*_downstream*"))) > 0: + cc_props_path = glob.glob(os.path.join(f"{file_dir}", "*_downstream*"))[0] # assume calculated cc properties to be .csv channel_files.append(pd.read_csv(cc_props_path)) else: channel_files.append(None) return (*channel_files, cc_stage_path) -def _calculate_rp_df(seg_mask, is_timelapse_data, is_zstack_data, metadata, max_frame=1, label_input=False): + +def _calculate_rp_df( + seg_mask, + is_timelapse_data, + is_zstack_data, + metadata, + max_frame=1, + label_input=False, +): """ function to calculate regionprops based on a 2D(!) segmentation mask. TODO: insert check if 3D segmentation mask is available and calculate more regionprops. """ if label_input: - #generate labeled video only when input is not labeled yet + # generate labeled video only when input is not labeled yet labeled_data = label(seg_mask) else: labeled_data = seg_mask.copy() # calculate rp's for rings t_df = pd.DataFrame() - props = ('label', 'area', 'convex_area', 'filled_area','major_axis_length', - 'minor_axis_length', 'orientation', 'perimeter', 'centroid', 'solidity') - rename_dict = {'label':'Cell_ID', 'centroid-0':'centroid_y', 'centroid-1':'centroid_x'} + props = ( + "label", + "area", + "convex_area", + "filled_area", + "major_axis_length", + "minor_axis_length", + "orientation", + "perimeter", + "centroid", + "solidity", + ) + rename_dict = { + "label": "Cell_ID", + "centroid-0": "centroid_y", + "centroid-1": "centroid_x", + } if is_timelapse_data: for t, img in enumerate(tqdm(labeled_data)): # build time-dependent dataframes for further use (later for cca) if img.max() > 0: - t_rp_df = pd.DataFrame(regionprops_table(img.astype(int), properties=props)).rename(columns=rename_dict) - t_rp_df['frame_i'] = t + t_rp_df = pd.DataFrame( + regionprops_table(img.astype(int), properties=props) + ).rename(columns=rename_dict) + t_rp_df["frame_i"] = t # calculate volumes based on regionprops if metadata is None: warnings.warn("No metadata available. Volumes are not calculated") - t_rp_df['cell_vol_vox_downstream'] = 0 - t_rp_df['cell_vol_fl_downstream'] = 0 + t_rp_df["cell_vol_vox_downstream"] = 0 + t_rp_df["cell_vol_fl_downstream"] = 0 else: t_rp = regionprops(img.astype(int)) - vol_vox = [_calc_rot_vol(obj, metadata.loc["PhysicalSizeY"], metadata.loc["PhysicalSizeX"])[0] for obj in t_rp] - vol_fl = [_calc_rot_vol(obj, metadata.loc["PhysicalSizeY"], metadata.loc["PhysicalSizeX"])[1] for obj in t_rp] + vol_vox = [ + _calc_rot_vol( + obj, + metadata.loc["PhysicalSizeY"], + metadata.loc["PhysicalSizeX"], + )[0] + for obj in t_rp + ] + vol_fl = [ + _calc_rot_vol( + obj, + metadata.loc["PhysicalSizeY"], + metadata.loc["PhysicalSizeX"], + )[1] + for obj in t_rp + ] assert len(t_rp_df) == len(vol_vox) - t_rp_df['cell_vol_vox_downstream'] = vol_vox - t_rp_df['cell_vol_fl_downstream'] = vol_fl + t_rp_df["cell_vol_vox_downstream"] = vol_vox + t_rp_df["cell_vol_fl_downstream"] = vol_fl # determine id's which are falsely merged by 3D-labeling for r_id in t_rp_df.Cell_ID.unique(): - bin_label = label((img==r_id).astype(int)) - t_rp_df.loc[t_rp_df['Cell_ID']==r_id, '2d_label_count'] = bin_label.max() + bin_label = label((img == r_id).astype(int)) + t_rp_df.loc[t_rp_df["Cell_ID"] == r_id, "2d_label_count"] = ( + bin_label.max() + ) t_df = pd.concat([t_df, t_rp_df], ignore_index=True) # calculate global features by grouping - grouped_df = t_df.groupby('Cell_ID').agg( - min_t=('frame_i', min), - max_t=('frame_i', max), - lifespan=('frame_i', lambda x: max(x)-min(x)+1) - ).reset_index() - merged_df = t_df.merge(grouped_df, how='left', on='Cell_ID') + grouped_df = ( + t_df.groupby("Cell_ID") + .agg( + min_t=("frame_i", min), + max_t=("frame_i", max), + lifespan=("frame_i", lambda x: max(x) - min(x) + 1), + ) + .reset_index() + ) + merged_df = t_df.merge(grouped_df, how="left", on="Cell_ID") # calculate further indicators based on merged data - merged_df['age'] = merged_df['frame_i'] - merged_df['min_t'] + 1 - merged_df['frames_till_gone'] = merged_df['max_t'] - merged_df['frame_i'] - merged_df['elongation'] = merged_df['major_axis_length']/merged_df['minor_axis_length'] + merged_df["age"] = merged_df["frame_i"] - merged_df["min_t"] + 1 + merged_df["frames_till_gone"] = merged_df["max_t"] - merged_df["frame_i"] + merged_df["elongation"] = ( + merged_df["major_axis_length"] / merged_df["minor_axis_length"] + ) return merged_df else: - rp_df = pd.DataFrame(regionprops_table(labeled_data.astype(int), properties=props)).rename(columns=rename_dict) + rp_df = pd.DataFrame( + regionprops_table(labeled_data.astype(int), properties=props) + ).rename(columns=rename_dict) for r_id in rp_df.Cell_ID.unique(): - bin_label = label((labeled_data==r_id).astype(int)) - rp_df.loc[rp_df['Cell_ID']==r_id, '2d_label_count'] = bin_label.max() - rp_df['elongation'] = rp_df['major_axis_length']/rp_df['minor_axis_length'] - rp_df['frame_i'] = 0 + bin_label = label((labeled_data == r_id).astype(int)) + rp_df.loc[rp_df["Cell_ID"] == r_id, "2d_label_count"] = bin_label.max() + rp_df["elongation"] = rp_df["major_axis_length"] / rp_df["minor_axis_length"] + rp_df["frame_i"] = 0 return rp_df @@ -605,18 +722,20 @@ def _calc_rot_vol(obj, PhysicalSizeY=1, PhysicalSizeX=1, logger=None): try: if is3Dobj: # For 3D objects we use a max projection for the rotation - obj_lab = obj.image.max(axis=0).astype(np.uint32)*obj.label + obj_lab = obj.image.max(axis=0).astype(np.uint32) * obj.label obj = regionprops(obj_lab)[0] - vox_to_fl = float(PhysicalSizeY)*pow(float(PhysicalSizeX), 2) + vox_to_fl = float(PhysicalSizeY) * pow(float(PhysicalSizeX), 2) rotate_ID_img = skimage.transform.rotate( - obj.image.astype(np.single), -(obj.orientation*180/np.pi), - resize=True, order=3 + obj.image.astype(np.single), + -(obj.orientation * 180 / np.pi), + resize=True, + order=3, ) - radii = np.sum(rotate_ID_img, axis=1)/2 - vol_vox = np.sum(np.pi*(radii**2)) + radii = np.sum(rotate_ID_img, axis=1) / 2 + vol_vox = np.sum(np.pi * (radii**2)) if vox_to_fl is not None: - return vol_vox, float(vol_vox*vox_to_fl) + return vol_vox, float(vol_vox * vox_to_fl) else: return vol_vox, vol_vox except Exception as e: @@ -627,79 +746,87 @@ def _calc_rot_vol(obj, PhysicalSizeY=1, PhysicalSizeX=1, logger=None): return np.nan, np.nan -def _calculate_flu_signal(seg_mask, channel_data, channels, cc_data, is_timelapse_data, is_zstack_data): +def _calculate_flu_signal( + seg_mask, channel_data, channels, cc_data, is_timelapse_data, is_zstack_data +): """ function to calculate sum and scaled sum of fluorescence signal per frame and cell. channel_data is a list-like of TYX arrays, one for each channel. channels are the name of the channels in the tuple. cc_data the output of acdc. - """ + """ max_frame = cc_data.frame_i.max() - df = pd.DataFrame(columns=['frame_i', 'Cell_ID']) + df = pd.DataFrame(columns=["frame_i", "Cell_ID"]) bg_medians = [] - + if seg_mask.ndim == 4: raise TypeError( - '4D segmentation masks not supported. ' - 'Feel free to request the new feature on our GitHub page ' - 'https://github.com/SchmollerLab/Cell_ACDC/issues' + "4D segmentation masks not supported. " + "Feel free to request the new feature on our GitHub page " + "https://github.com/SchmollerLab/Cell_ACDC/issues" ) - + for i, ch_img in enumerate(channel_data): if ch_img.ndim == 3: continue - + # Use sum projections for 4D data channel_data[i] = ch_img.sum(axis=1) - + for ch_idx, ch_array in enumerate(channel_data): if ch_array is None: bg_medians.append(None) else: bg_index = np.logical_and( - seg_mask[:max_frame+1]==0, ch_array[:max_frame+1]!=0 + seg_mask[: max_frame + 1] == 0, ch_array[: max_frame + 1] != 0 ) - ch_medians = [np.median(ch_array[t][bg_index[t]]) for t in range(max_frame+1)] + ch_medians = [ + np.median(ch_array[t][bg_index[t]]) for t in range(max_frame + 1) + ] bg_medians.append(ch_medians) if is_timelapse_data: for cell_id in tqdm(cc_data.Cell_ID.unique()): - temp_df = pd.DataFrame(columns=['frame_i', 'Cell_ID']) - times = range(max_frame+1) - temp_df['frame_i'] = times; temp_df['Cell_ID'] = cell_id - index_array = (seg_mask[:max_frame+1] == cell_id) - channel_data_cut = [c_arr[:max_frame+1] if c_arr is not None else None for c_arr in channel_data] + temp_df = pd.DataFrame(columns=["frame_i", "Cell_ID"]) + times = range(max_frame + 1) + temp_df["frame_i"] = times + temp_df["Cell_ID"] = cell_id + index_array = seg_mask[: max_frame + 1] == cell_id + channel_data_cut = [ + c_arr[: max_frame + 1] if c_arr is not None else None + for c_arr in channel_data + ] for c_idx, c_array in enumerate(channel_data_cut): if c_array is not None: - cell_signal = c_array*index_array + cell_signal = c_array * index_array # cell_signal = c_array[index_array] - summed = np.sum(cell_signal, axis=(1,2)) + summed = np.sum(cell_signal, axis=(1, 2)) # count = np.sum(cell_signal!=0, axis=(1,2)) - count = np.sum(index_array, axis=(1,2)) - mean_signal = np.divide(summed, count, where=count!=0) + count = np.sum(index_array, axis=(1, 2)) + mean_signal = np.divide(summed, count, where=count != 0) # mean_signal = np.mean(cell_signal, axis=(1,2)) corrected_signal = mean_signal - np.array(bg_medians[c_idx]) - temp_df[f'{channels[c_idx]}_corrected_mean'] = corrected_signal - temp_df[f'{channels[c_idx]}_raw_sum'] = summed + temp_df[f"{channels[c_idx]}_corrected_mean"] = corrected_signal + temp_df[f"{channels[c_idx]}_raw_sum"] = summed else: - temp_df[f'{channels[c_idx]}_corrected_mean'] = 0 - temp_df[f'{channels[c_idx]}_raw_sum'] = 0 + temp_df[f"{channels[c_idx]}_corrected_mean"] = 0 + temp_df[f"{channels[c_idx]}_raw_sum"] = 0 df = pd.concat([df, temp_df], ignore_index=True) - signal_indices = np.array(['_corrected_mean' in col for col in df.columns]) - keep_rows = df.loc[:,signal_indices].sum(axis=1) > 0 + signal_indices = np.array(["_corrected_mean" in col for col in df.columns]) + keep_rows = df.loc[:, signal_indices].sum(axis=1) > 0 df = df[keep_rows] - df = df.sort_values(['frame_i', 'Cell_ID']).reset_index(drop=True) + df = df.sort_values(["frame_i", "Cell_ID"]).reset_index(drop=True) return df def _rename_columns(cc_data): rename_dict = { - 'Cell cycle stage': 'cell_cycle_stage', - '# of cycles': 'generation_num', - "Relative's ID": 'relative_ID', - 'Relationship': 'relationship', - 'Emerg_frame_i': 'emerg_frame_i', - 'Division_frame_i': 'division_frame_i', - 'Discard': 'is_cell_excluded' + "Cell cycle stage": "cell_cycle_stage", + "# of cycles": "generation_num", + "Relative's ID": "relative_ID", + "Relationship": "relationship", + "Emerg_frame_i": "emerg_frame_i", + "Division_frame_i": "division_frame_i", + "Discard": "is_cell_excluded", } cc_data.columns = [rename_dict.get(col, col) for col in cc_data.columns] return cc_data @@ -710,133 +837,131 @@ def binned_mean_stats(x, values, nbins, bins_min_count): function to calculate binned means and corresponding standard errors for evenly spaced bins in the data ("x" gets distributed in bins, stats are calculated on "values") """ - bin_counts, _, _ = binned_statistic(x, values, statistic='count', bins=nbins) + bin_counts, _, _ = binned_statistic(x, values, statistic="count", bins=nbins) bin_means, bin_edges, _ = binned_statistic(x, values, bins=nbins) - bin_std, _, _ = binned_statistic(x, values, statistic='std', bins=nbins) - bin_standard_errors = bin_std/np.sqrt(bin_counts) - bin_width = (bin_edges[1] - bin_edges[0]) - bin_centers = bin_edges[1:] - bin_width/2 - x_errorbar = bin_centers[bin_counts>bins_min_count] - y_errorbar = bin_means[bin_counts>bins_min_count] - err_errorbar = 1.96 * bin_standard_errors[bin_counts>bins_min_count] + bin_std, _, _ = binned_statistic(x, values, statistic="std", bins=nbins) + bin_standard_errors = bin_std / np.sqrt(bin_counts) + bin_width = bin_edges[1] - bin_edges[0] + bin_centers = bin_edges[1:] - bin_width / 2 + x_errorbar = bin_centers[bin_counts > bins_min_count] + y_errorbar = bin_means[bin_counts > bins_min_count] + err_errorbar = 1.96 * bin_standard_errors[bin_counts > bins_min_count] return x_errorbar, y_errorbar, err_errorbar -def calculate_effect_size_cohen(data, group1, group2, cat_column='size_category', val_column='Pp38_concentration'): +def calculate_effect_size_cohen( + data, group1, group2, cat_column="size_category", val_column="Pp38_concentration" +): assert cat_column in data.columns and val_column in data.columns - data_gr1 = data[data[cat_column]==group1] - data_gr2 = data[data[cat_column]==group2] + data_gr1 = data[data[cat_column] == group1] + data_gr2 = data[data[cat_column] == group2] n1 = len(data_gr1) n2 = len(data_gr2) s1 = np.var(data_gr1[val_column]) s2 = np.var(data_gr2[val_column]) - cohen_s = np.sqrt( - ((n1-1)*s1+(n2-1)*s2) / (n1+n2-2) - ) - effect_size = (np.mean(data_gr1[val_column])- np.mean(data_gr2[val_column])) / cohen_s + cohen_s = np.sqrt(((n1 - 1) * s1 + (n2 - 1) * s2) / (n1 + n2 - 2)) + effect_size = ( + np.mean(data_gr1[val_column]) - np.mean(data_gr2[val_column]) + ) / cohen_s return effect_size -def calculate_effect_size_glass(data, group1, group2, cat_column='size_category', val_column='Pp38_concentration'): + +def calculate_effect_size_glass( + data, group1, group2, cat_column="size_category", val_column="Pp38_concentration" +): assert cat_column in data.columns and val_column in data.columns - data_gr1 = data[data[cat_column]==group1] - data_gr2 = data[data[cat_column]==group2] + data_gr1 = data[data[cat_column] == group1] + data_gr2 = data[data[cat_column] == group2] glass_s = np.std(data_gr2[val_column]) - effect_size = (np.mean(data_gr1[val_column])- np.mean(data_gr2[val_column])) / glass_s + effect_size = ( + np.mean(data_gr1[val_column]) - np.mean(data_gr2[val_column]) + ) / glass_s return effect_size + def _add_end_of_frame_i_column(acdc_df): cca_df_idx = acdc_df.cell_cycle_stage.dropna().index cca_df = acdc_df.loc[cca_df_idx][cca_df_colnames] - acdc_df['end_of_cell_cycle_frame_i'] = np.nan - + acdc_df["end_of_cell_cycle_frame_i"] = np.nan + will_divice_cca_df_S = cca_df[ - (cca_df.cell_cycle_stage == 'S') & (cca_df.will_divide > 0) + (cca_df.cell_cycle_stage == "S") & (cca_df.will_divide > 0) ].reset_index() - - cca_df['end_of_cell_cycle_frame_i'] = -1 - grouped_ID_gen_num = will_divice_cca_df_S.groupby( - ['Cell_ID', 'generation_num'] - ) - + + cca_df["end_of_cell_cycle_frame_i"] = -1 + grouped_ID_gen_num = will_divice_cca_df_S.groupby(["Cell_ID", "generation_num"]) + end_cc_frame_i_per_cycle = grouped_ID_gen_num.agg( - end_of_cell_cycle_frame_i=('frame_i', 'max') + end_of_cell_cycle_frame_i=("frame_i", "max") ) - + cca_df_with_gen_num_idx = ( - cca_df.reset_index() - .set_index(['Cell_ID', 'generation_num']) - .sort_index() + cca_df.reset_index().set_index(["Cell_ID", "generation_num"]).sort_index() ) - + for row in end_cc_frame_i_per_cycle.itertuples(): ID, gen_num = row.Index end_cc_frame_i = row.end_of_cell_cycle_frame_i idx = (ID, gen_num) - cca_df_with_gen_num_idx.loc[idx, 'end_of_cell_cycle_frame_i'] = ( - end_cc_frame_i - ) - + cca_df_with_gen_num_idx.loc[idx, "end_of_cell_cycle_frame_i"] = end_cc_frame_i + cca_df = ( cca_df_with_gen_num_idx.reset_index() - .set_index(['frame_i', 'Cell_ID']) + .set_index(["frame_i", "Cell_ID"]) .sort_index() ) - - acdc_df.loc[cca_df_idx, 'end_of_cell_cycle_frame_i'] = ( - cca_df['end_of_cell_cycle_frame_i'] - ) + + acdc_df.loc[cca_df_idx, "end_of_cell_cycle_frame_i"] = cca_df[ + "end_of_cell_cycle_frame_i" + ] return acdc_df + def _extend_will_divide_to_G1(acdc_df): - acdc_df = acdc_df.drop(columns=['level_0', 'index'], errors='ignore') + acdc_df = acdc_df.drop(columns=["level_0", "index"], errors="ignore") acdc_df = acdc_df.reset_index() - acdc_df_will_divide_true = acdc_df[acdc_df['will_divide'] > 0] - grouped = acdc_df_will_divide_true.groupby(['Cell_ID', 'generation_num']) - for (ID, gen_num) in grouped.groups.keys(): - mask = ( - (acdc_df['Cell_ID'] == ID) - & (acdc_df['generation_num'] == gen_num) - ) - acdc_df.loc[mask, 'will_divide'] = 1.0 - acdc_df = ( - acdc_df.reset_index() - .set_index(['frame_i', 'Cell_ID']) - .sort_index() - ) + acdc_df_will_divide_true = acdc_df[acdc_df["will_divide"] > 0] + grouped = acdc_df_will_divide_true.groupby(["Cell_ID", "generation_num"]) + for ID, gen_num in grouped.groups.keys(): + mask = (acdc_df["Cell_ID"] == ID) & (acdc_df["generation_num"] == gen_num) + acdc_df.loc[mask, "will_divide"] = 1.0 + acdc_df = acdc_df.reset_index().set_index(["frame_i", "Cell_ID"]).sort_index() return acdc_df - + + def add_derived_cell_cycle_columns(acdc_df: pd.DataFrame): - if 'cell_cycle_stage' not in acdc_df.columns: + if "cell_cycle_stage" not in acdc_df.columns: return acdc_df - + acdc_df = _extend_will_divide_to_G1(acdc_df) acdc_df = _add_end_of_frame_i_column(acdc_df) - + return acdc_df - + + def add_generation_num_of_relative_ID( - acdc_df, prefix_index: Iterable[str]=None, reset_index=True - ): - relID_index_col = ['frame_i', 'relative_ID', 'Cell_ID'] - ID_index_col = ['frame_i', 'Cell_ID', 'relative_ID'] - + acdc_df, prefix_index: Iterable[str] = None, reset_index=True +): + relID_index_col = ["frame_i", "relative_ID", "Cell_ID"] + ID_index_col = ["frame_i", "Cell_ID", "relative_ID"] + if prefix_index is not None: relID_index_col = [*prefix_index, *relID_index_col] ID_index_col = [*prefix_index, *ID_index_col] - + if reset_index: acdc_df = acdc_df.reset_index() - + acdc_df_by_rel_ID = acdc_df.set_index(relID_index_col) acdc_df_by_rel_ID.index relative_ID_idx = acdc_df_by_rel_ID.index acdc_df_by_frame_i = acdc_df.set_index(ID_index_col) relative_ID_idx = relative_ID_idx.intersection(acdc_df_by_frame_i.index) - acdc_df_by_frame_i['generation_num_relID'] = -1 + acdc_df_by_frame_i["generation_num_relID"] = -1 - acdc_df_by_frame_i.loc[relative_ID_idx, 'generation_num_relID'] = ( - acdc_df_by_rel_ID.loc[relative_ID_idx, 'generation_num'] + acdc_df_by_frame_i.loc[relative_ID_idx, "generation_num_relID"] = ( + acdc_df_by_rel_ID.loc[relative_ID_idx, "generation_num"] ) # Fix where generation_num_relID is still -1 @@ -844,29 +969,29 @@ def add_generation_num_of_relative_ID( acdc_df_to_fix = ( acdc_df_by_frame_i[to_fix_mask] .reset_index() - .set_index([*prefix_index, 'frame_i', 'relative_ID']) + .set_index([*prefix_index, "frame_i", "relative_ID"]) ) - acdc_df_by_cellID = ( - acdc_df_by_rel_ID.reset_index() - .set_index([*prefix_index, 'frame_i', 'Cell_ID']) + acdc_df_by_cellID = acdc_df_by_rel_ID.reset_index().set_index( + [*prefix_index, "frame_i", "Cell_ID"] ) # Intersection takes care of disappearing relative_IDs fixing_idx = acdc_df_to_fix.index.intersection(acdc_df_by_cellID.index) - acdc_df_to_fix.loc[fixing_idx, 'generation_num_relID'] = ( - acdc_df_by_cellID.loc[fixing_idx, 'generation_num'].values - ) + acdc_df_to_fix.loc[fixing_idx, "generation_num_relID"] = acdc_df_by_cellID.loc[ + fixing_idx, "generation_num" + ].values index_to_fix = acdc_df_by_frame_i[to_fix_mask].index - acdc_df_by_frame_i.loc[index_to_fix, 'generation_num_relID'] = ( - acdc_df_to_fix['generation_num_relID'].values - ) - + acdc_df_by_frame_i.loc[index_to_fix, "generation_num_relID"] = acdc_df_to_fix[ + "generation_num_relID" + ].values + acdc_df_with_col = acdc_df_by_frame_i.reset_index() return acdc_df_with_col - + + def get_IDs_gen_num_will_divide_wrong(global_cca_df): - """Get a list of (ID, gen_num) of cells whose `will_divide`>0 but the + """Get a list of (ID, gen_num) of cells whose `will_divide`>0 but the next generation does not exist (i.e., `will_divide` is wrong) Parameters @@ -877,86 +1002,89 @@ def get_IDs_gen_num_will_divide_wrong(global_cca_df): Returns ------- list of tuples - List of (ID, gen_num) of cells whose `will_divide`>0 but the + List of (ID, gen_num) of cells whose `will_divide`>0 but the next generation does not exist (i.e., `will_divide` is wrong) - + Notes ----- - To get the (ID, gen_num) where `will_divide` is wrong we first get an - index of (ID, gen_num) where `will_divide`>0. - - Then we get the same index but with (ID, gen_num+1) which is the next - generation. - - Finally we check if (ID, gen_num+1) actually exists in the annotations. - If not, those are wrongly annotated with `will_divide`>0. To check for - the existence we get the difference between the next gen index and the - whole DataFrame (i.e., get the (ID, gen_num+1) that do not exist in + To get the (ID, gen_num) where `will_divide` is wrong we first get an + index of (ID, gen_num) where `will_divide`>0. + + Then we get the same index but with (ID, gen_num+1) which is the next + generation. + + Finally we check if (ID, gen_num+1) actually exists in the annotations. + If not, those are wrongly annotated with `will_divide`>0. To check for + the existence we get the difference between the next gen index and the + whole DataFrame (i.e., get the (ID, gen_num+1) that do not exist in annotations). - """ + """ global_cca_will_divide = ( - global_cca_df[(global_cca_df['will_divide'] > 0)] + global_cca_df[(global_cca_df["will_divide"] > 0)] ).reset_index() - + ID_gen_num_index = ( - global_cca_df.reset_index() - .set_index(['Cell_ID', 'generation_num']) - .index + global_cca_df.reset_index().set_index(["Cell_ID", "generation_num"]).index ) - + # Next generation index - next_gen_will_divide_cca_df = ( - global_cca_will_divide[['Cell_ID', 'generation_num']].copy() - ) - next_gen_will_divide_cca_df['generation_num'] += 1 + next_gen_will_divide_cca_df = global_cca_will_divide[ + ["Cell_ID", "generation_num"] + ].copy() + next_gen_will_divide_cca_df["generation_num"] += 1 next_gen_will_divide_index = ( next_gen_will_divide_cca_df.reset_index() - .set_index(['Cell_ID', 'generation_num']) + .set_index(["Cell_ID", "generation_num"]) .index ) - - # (ID, gen_num) list of cells with will_divide>0 but whose next + + # (ID, gen_num) list of cells with will_divide>0 but whose next # generation number actually does not exist IDs_will_divide_next_gen_does_not_exist = ( next_gen_will_divide_index.difference(ID_gen_num_index) - .to_frame().to_numpy() # .to_list() + .to_frame() + .to_numpy() # .to_list() ) IDs_will_divide_next_gen_does_not_exist[:, -1] -= 1 - - IDs_will_divide_wrong = list(zip( - IDs_will_divide_next_gen_does_not_exist[:,0], - IDs_will_divide_next_gen_does_not_exist[:, 1] - )) + + IDs_will_divide_wrong = list( + zip( + IDs_will_divide_next_gen_does_not_exist[:, 0], + IDs_will_divide_next_gen_does_not_exist[:, 1], + ) + ) return IDs_will_divide_wrong - + + def generate_mother_bud_total_df( - df, - column_operation_mapper: dict[str, str], - do_copy_all_nonselected_columns=True, - grouping_columns=None, - entity_colname='entity' - ): + df, + column_operation_mapper: dict[str, str], + do_copy_all_nonselected_columns=True, + grouping_columns=None, + entity_colname="entity", +): if grouping_columns is None: grouping_columns = [] - - df_G1 = df[df['cell_cycle_stage'] == 'G1'] - df_S = df[(df['cell_cycle_stage'] == 'S')] - df_S_bud = df_S[df_S['relationship'] == 'bud'] - df_S_moth = df_S[df_S['relationship'] == 'mother'] - + + df_G1 = df[df["cell_cycle_stage"] == "G1"] + df_S = df[(df["cell_cycle_stage"] == "S")] + df_S_bud = df_S[df_S["relationship"] == "bud"] + df_S_moth = df_S[df_S["relationship"] == "mother"] + df_S_bud_relID = df_S_bud.reset_index().set_index( - [*grouping_columns, 'frame_i', 'relative_ID'] + [*grouping_columns, "frame_i", "relative_ID"] ) - df_S_bud_relID.index.names = [*grouping_columns, 'frame_i', 'Cell_ID'] + df_S_bud_relID.index.names = [*grouping_columns, "frame_i", "Cell_ID"] df_S_moth = df_S_moth.reset_index().set_index( - [*grouping_columns, 'frame_i', 'Cell_ID'] + [*grouping_columns, "frame_i", "Cell_ID"] ) - + columns_to_add = [ - col for col, operation in column_operation_mapper.items() - if 'sum' in operation.lower() + col + for col, operation in column_operation_mapper.items() + if "sum" in operation.lower() ] - + if do_copy_all_nonselected_columns: df_S_tot = df_S_moth.copy() else: @@ -965,22 +1093,18 @@ def generate_mother_bud_total_df( df_G1 = df_G1[columns_to_keep].copy() df_S_bud = df_S_bud[columns_to_keep].copy() df_S_moth = df_S_moth[columns_to_keep].copy() - - df_S_tot[columns_to_add] = ( - df_S_tot[columns_to_add] + df_S_bud_relID[columns_to_add] - ) - - df_S_tot = df_S_tot.drop(columns='level_0', errors='ignore').reset_index() - df_S_moth = df_S_moth.drop(columns='level_0', errors='ignore').reset_index() + + df_S_tot[columns_to_add] = df_S_tot[columns_to_add] + df_S_bud_relID[columns_to_add] + + df_S_tot = df_S_tot.drop(columns="level_0", errors="ignore").reset_index() + df_S_moth = df_S_moth.drop(columns="level_0", errors="ignore").reset_index() df_S_bud = df_S_bud.reset_index() - + final_df = pd.concat( - [df_G1, df_S_moth, df_S_bud, df_S_tot], - keys=['G1', 'Mother', 'Bud', 'Total'], + [df_G1, df_S_moth, df_S_bud, df_S_tot], + keys=["G1", "Mother", "Bud", "Total"], names=[entity_colname], - ignore_index=True + ignore_index=True, ) - + return final_df - - \ No newline at end of file diff --git a/cellacdc/cli.py b/cellacdc/cli.py index 04b4e578f..b1b75f442 100644 --- a/cellacdc/cli.py +++ b/cellacdc/cli.py @@ -15,7 +15,7 @@ from . import load from . import error_up_str from . import issues_url -from . import myutils +from . import utils from . import config from . import core from . import features @@ -24,20 +24,23 @@ from . import favourite_func_metrics_csv_path from . import cca_functions + class HeadlessSignal: def __init__(self, *args): pass - + def emit(self, *args, **kwargs): pass + class ProgressCliSignal: def __init__(self, logger_func): self.logger_func = logger_func - + def emit(self, text): self.logger_func(text) + class KernelCliSignals: def __init__(self, logger_func): self.finished = HeadlessSignal(float) @@ -51,36 +54,36 @@ def __init__(self, logger_func): self.debug = HeadlessSignal(object) self.critical = HeadlessSignal(object) + class _WorkflowKernel: def __init__(self, logger, log_path, is_cli=False): self.logger = logger self.log_path = log_path self.is_cli = is_cli - + @exception_handler_cli def parse_paths(self, workflow_params): - paths_to_segm = workflow_params['paths_info']['paths'] - if 'initialization' in workflow_params: - ch_name = workflow_params['initialization']['user_ch_name'] - elif 'measurements' in workflow_params: - channels = workflow_params['measurements']['channels'] - channel_names_to_skip = ( - workflow_params['measurements']['channel_names_to_skip'] - ) + paths_to_segm = workflow_params["paths_info"]["paths"] + if "initialization" in workflow_params: + ch_name = workflow_params["initialization"]["user_ch_name"] + elif "measurements" in workflow_params: + channels = workflow_params["measurements"]["channels"] + channel_names_to_skip = workflow_params["measurements"][ + "channel_names_to_skip" + ] channels = [ch for ch in channels if ch not in channel_names_to_skip] ch_name = channels[0] else: printl(workflow_params, pretty=True) raise KeyError( - 'Cannot find channel name in workflow parameters. ' - 'See above.' + "Cannot find channel name in workflow parameters. See above." ) parsed_paths = [] for path in paths_to_segm: if os.path.isfile(path): parsed_paths.append(path) continue - + images_paths = load.get_images_paths(path) ch_filepaths = load.get_user_ch_paths(images_paths, ch_name) parsed_paths.extend(ch_filepaths) @@ -88,39 +91,38 @@ def parse_paths(self, workflow_params): @exception_handler_cli def parse_stop_frame_numbers(self, workflow_params): - stop_frames_param = ( - workflow_params['paths_info']['stop_frame_numbers'] - ) + stop_frames_param = workflow_params["paths_info"]["stop_frame_numbers"] return [int(n) for n in stop_frames_param] - + def quit(self, error=None): if not self.is_cli and error is not None: raise error - - self.logger.info('='*50) + + self.logger.info("=" * 50) if error is not None: self.logger.exception(traceback.format_exc()) - print('-'*60) - self.logger.info(f'[ERROR]: {error}{error_up_str}') + print("-" * 60) + self.logger.info(f"[ERROR]: {error}{error_up_str}") err_msg = ( - 'Cell-ACDC aborted due to **error**. ' - 'More details above or in the following log file:\n\n' - f'{self.log_path}\n\n' - 'If you cannot solve it, you can report this error by opening ' - 'an issue on our ' - 'GitHub page at the following link:\n\n' - f'{issues_url}\n\n' - 'Please **send the log file** when reporting a bug, thanks!' + "Cell-ACDC aborted due to **error**. " + "More details above or in the following log file:\n\n" + f"{self.log_path}\n\n" + "If you cannot solve it, you can report this error by opening " + "an issue on our " + "GitHub page at the following link:\n\n" + f"{issues_url}\n\n" + "Please **send the log file** when reporting a bug, thanks!" ) self.logger.info(err_msg) else: self.logger.info( - 'Cell-ACDC command-line interface closed. ' - f'{myutils.get_salute_string()}' + "Cell-ACDC command-line interface closed. " + f"{utils.get_salute_string()}" ) - self.logger.info('='*50) + self.logger.info("=" * 50) exit() + class SegmKernel(_WorkflowKernel): def __init__(self, logger, log_path, is_cli): super().__init__(logger, log_path, is_cli=is_cli) @@ -129,100 +131,92 @@ def __init__(self, logger, log_path, is_cli): def parse_custom_postproc_features_grouped(self, workflow_params): custom_postproc_grouped_features = {} for section, options in workflow_params.items(): - if not section.startswith('postprocess_features.'): + if not section.startswith("postprocess_features."): continue - category = section.split('.')[-1] + category = section.split(".")[-1] for option, value in options.items(): - if option == 'names': - values = value.strip('\n').strip().split('\n') + if option == "names": + values = value.strip("\n").strip().split("\n") custom_postproc_grouped_features[category] = values continue channel = option if category not in custom_postproc_grouped_features: - custom_postproc_grouped_features[category] = { - channel: [value] - } + custom_postproc_grouped_features[category] = {channel: [value]} elif channel not in custom_postproc_grouped_features[category]: - custom_postproc_grouped_features[category][channel] = ( - [value] - ) + custom_postproc_grouped_features[category][channel] = [value] else: custom_postproc_grouped_features[category][channel].append(value) return custom_postproc_grouped_features - - @exception_handler_cli + + @exception_handler_cli def init_args_from_params(self, workflow_params, logger_func): - args = workflow_params['initialization'].copy() - - initialization_section = workflow_params['initialization'] - args['use3DdataFor2Dsegm'] = initialization_section.get( - 'use3DdataFor2Dsegm', False + args = workflow_params["initialization"].copy() + + initialization_section = workflow_params["initialization"] + args["use3DdataFor2Dsegm"] = initialization_section.get( + "use3DdataFor2Dsegm", False ) - args['model_kwargs'] = workflow_params['segmentation_model_params'] - args['track_params'] = workflow_params.get('tracker_params', {}) - args['standard_postrocess_kwargs'] = ( - workflow_params.get('standard_postprocess_features', {}) + args["model_kwargs"] = workflow_params["segmentation_model_params"] + args["track_params"] = workflow_params.get("tracker_params", {}) + args["standard_postrocess_kwargs"] = workflow_params.get( + "standard_postprocess_features", {} ) - args['custom_postproc_features'] = ( - workflow_params.get('custom_postprocess_features', {}) + args["custom_postproc_features"] = workflow_params.get( + "custom_postprocess_features", {} ) - args['custom_postproc_grouped_features'] = ( + args["custom_postproc_grouped_features"] = ( self.parse_custom_postproc_features_grouped(workflow_params) ) - - args['SizeT'] = workflow_params['metadata']['SizeT'] - args['SizeZ'] = workflow_params['metadata']['SizeZ'] - args['logger_func'] = logger_func - args['init_model_kwargs'] = ( - workflow_params.get('init_segmentation_model_params', {}) - ) - args['init_tracker_kwargs'] = ( - workflow_params.get('init_tracker_params', {}) - ) - - args['preproc_recipe'] = config.preprocess_ini_items_to_recipe( - workflow_params + + args["SizeT"] = workflow_params["metadata"]["SizeT"] + args["SizeZ"] = workflow_params["metadata"]["SizeZ"] + args["logger_func"] = logger_func + args["init_model_kwargs"] = workflow_params.get( + "init_segmentation_model_params", {} ) - args['reduce_memory_usage'] = initialization_section.get( - 'reduce_memory_usage', False + args["init_tracker_kwargs"] = workflow_params.get("init_tracker_params", {}) + + args["preproc_recipe"] = config.preprocess_ini_items_to_recipe(workflow_params) + args["reduce_memory_usage"] = initialization_section.get( + "reduce_memory_usage", False ) - + self.init_args(**args) - + @exception_handler_cli def init_args( - self, - user_ch_name, - segm_endname, - model_name, - do_tracking, - do_postprocess, - do_save, - image_channel_tracker, - standard_postrocess_kwargs, - custom_postproc_grouped_features, - custom_postproc_features, - isSegm3D, - use_ROI, - second_channel_name, - use3DdataFor2Dsegm, - model_kwargs, - track_params, - SizeT, - SizeZ, - tracker_name='', - model=None, - preproc_recipe=None, - init_model_kwargs=None, - init_tracker_kwargs=None, - tracker=None, - signals=None, - logger_func=print, - innerPbar_available=False, - is_segment3DT_available=False, - reduce_memory_usage=False, - use_freehand_ROI=True - ): + self, + user_ch_name, + segm_endname, + model_name, + do_tracking, + do_postprocess, + do_save, + image_channel_tracker, + standard_postrocess_kwargs, + custom_postproc_grouped_features, + custom_postproc_features, + isSegm3D, + use_ROI, + second_channel_name, + use3DdataFor2Dsegm, + model_kwargs, + track_params, + SizeT, + SizeZ, + tracker_name="", + model=None, + preproc_recipe=None, + init_model_kwargs=None, + init_tracker_kwargs=None, + tracker=None, + signals=None, + logger_func=print, + innerPbar_available=False, + is_segment3DT_available=False, + reduce_memory_usage=False, + use_freehand_ROI=True, + ): self.user_ch_name = user_ch_name self.segm_endname = segm_endname self.model_name = model_name @@ -256,473 +250,136 @@ def init_args( self.model_kwargs = model_kwargs self.tracker_name = tracker_name self.init_tracker( - self.do_tracking, track_params, tracker_name=tracker_name, - tracker=tracker + self.do_tracking, track_params, tracker_name=tracker_name, tracker=tracker ) - + from cellacdc.workflow.adapters import workflow_context_from_segm_kernel + from cellacdc.workflow.pipelines.segm import build_position_segm_graph + + self._workflow_ctx = workflow_context_from_segm_kernel(self) + self._position_segm_graph = build_position_segm_graph(self._workflow_ctx).compile() + @exception_handler_cli def init_segm_model(self, posData): self.signals.progress.emit( - f'\nInitializing {self.model_name} segmentation model...' + f"\nInitializing {self.model_name} segmentation model..." ) - acdcSegment = myutils.import_segment_module(self.model_name) - init_argspecs, segment_argspecs = myutils.getModelArgSpec(acdcSegment) - self.init_model_kwargs = myutils.parse_model_params( + acdcSegment = utils.import_segment_module(self.model_name) + init_argspecs, segment_argspecs = utils.getModelArgSpec(acdcSegment) + self.init_model_kwargs = utils.parse_model_params( init_argspecs, self.init_model_kwargs ) - self.model_kwargs = myutils.parse_model_params( + self.model_kwargs = utils.parse_model_params( segment_argspecs, self.model_kwargs ) if self.second_channel_name is not None: - self.init_model_kwargs['is_rgb'] = True + self.init_model_kwargs["is_rgb"] = True - self.model = myutils.init_segm_model( + self.model = utils.init_segm_model( acdcSegment, posData, self.init_model_kwargs ) if self.model is None: # The model was not initialized correctly return self.is_segment3DT_available = any( - [name=='segment3DT' for name in dir(self.model)] + [name == "segment3DT" for name in dir(self.model)] ) - + if hasattr(self, "_workflow_ctx"): + self._workflow_ctx.model = self.model + self._workflow_ctx.is_segment3dt_available = self.is_segment3DT_available + self._workflow_ctx.init_model_kwargs = dict(self.init_model_kwargs or {}) + self._workflow_ctx.model_kwargs = dict(self.model_kwargs or {}) + @exception_handler_cli - def init_tracker( - self, do_tracking, track_params, tracker_name='', tracker=None - ): + def init_tracker(self, do_tracking, track_params, tracker_name="", tracker=None): if not do_tracking: self.tracker = None return - + if tracker is None: - self.signals.progress.emit(f'Initializing {tracker_name} tracker...') - tracker_module = myutils.import_tracker_module(tracker_name) - init_argspecs, track_argspecs = myutils.getTrackerArgSpec( + self.signals.progress.emit(f"Initializing {tracker_name} tracker...") + tracker_module = utils.import_tracker_module(tracker_name) + init_argspecs, track_argspecs = utils.getTrackerArgSpec( tracker_module, realTime=False ) - self.init_tracker_kwargs = myutils.parse_model_params( + self.init_tracker_kwargs = utils.parse_model_params( init_argspecs, self.init_tracker_kwargs ) - self.init_tracker_kwargs = myutils.parse_model_params( + self.init_tracker_kwargs = utils.parse_model_params( init_argspecs, self.init_tracker_kwargs ) - track_params = myutils.parse_model_params( - track_argspecs, track_params - ) + track_params = utils.parse_model_params(track_argspecs, track_params) tracker = tracker_module.tracker(**self.init_tracker_kwargs) - + self.track_params = track_params self.tracker = tracker - + def _tracker_track(self, lab, tracker_input_img=None): tracked_lab = core.tracker_track( - lab, self.tracker, self.track_params, - intensity_img=tracker_input_img, - logger_func=self.logger_func + lab, + self.tracker, + self.track_params, + intensity_img=tracker_input_img, + logger_func=self.logger_func, ) return tracked_lab - - @exception_handler_cli - def run( - self, - img_path, - stop_frame_n - ): - posData = load.loadData(img_path, self.user_ch_name) - - self.logger_func(f'Loading {posData.relPath}...') - posData.getBasenameAndChNames() - posData.buildPaths() - posData.loadImgData() - posData.loadOtherFiles( - load_segm_data=False, - load_acdc_df=False, - load_shifts=True, - loadSegmInfo=True, - load_delROIsInfo=False, - load_dataPrep_ROIcoords=True, - load_bkgr_data=True, - load_last_tracked_i=False, - load_metadata=True, - load_dataprep_free_roi=True, - end_filename_segm=self.segm_endname - ) - # Get only name from the string 'segm_.npz' - endName = ( - self.segm_endname.replace('segm', '', 1) - .replace('_', '', 1) - .split('.')[0] + @exception_handler_cli + def run(self, img_path, stop_frame_n): + from cellacdc.workflow.adapters import ( + runnable_config_from_segm_kernel, + sync_segm_kernel_from_context, + update_workflow_context_from_segm_kernel, ) - if endName: - # Create a new file that is not the default 'segm.npz' - posData.setFilePaths(endName) - - segmFilename = os.path.basename(posData.segm_npz_path) - if self.do_save: - self.logger_func(f'\nSegmentation file {segmFilename}...') - - posData.SizeT = self.SizeT - if self.SizeZ > 1: - SizeZ = posData.img_data.shape[-3] - posData.SizeZ = SizeZ - else: - posData.SizeZ = 1 - - posData.isSegm3D = self.isSegm3D - posData.saveMetadata() - - isROIactive = False - if posData.dataPrep_ROIcoords is not None and self.use_ROI: - df_roi = posData.dataPrep_ROIcoords.loc[0] - isROIactive = df_roi.at['cropped', 'value'] == 0 - x0, x1, y0, y1 = df_roi['value'].astype(int)[:4] - Y, X = posData.img_data.shape[-2:] - x0 = x0 if x0>0 else 0 - y0 = y0 if y0>0 else 0 - x1 = x1 if x1 1: - self.t0 = 0 - if posData.SizeZ > 1 and not self.isSegm3D and not self.use3DdataFor2Dsegm: - # 2D segmentation on 3D data over time - img_data = posData.img_data - - if self.second_channel_name is not None: - second_ch_data_slice = secondChImgData[self.t0:stop_i] - if isROIactive: - Y, X = img_data.shape[-2:] - img_data = img_data[:, :, y0:y1, x0:x1] - if self.second_channel_name is not None: - second_ch_data_slice = second_ch_data_slice[:, :, y0:y1, x0:x1] - pad_info = ((0, 0), (y0, Y-y1), (x0, X-x1)) - - img_data_slice = img_data[self.t0:stop_i] - postprocess_img = img_data - - Y, X = img_data.shape[-2:] - newShape = (stop_i, Y, X) - img_data = np.zeros(newShape, img_data.dtype) - - if self.second_channel_name is not None: - second_ch_data = np.zeros(newShape, secondChImgData.dtype) - df = posData.segmInfo_df.loc[posData.filename] - for z_info in df[:stop_i].itertuples(): - i = z_info.Index - z = z_info.z_slice_used_dataPrep - zProjHow = z_info.which_z_proj - img = img_data_slice[i] - if self.second_channel_name is not None: - second_ch_img = second_ch_data_slice[i] - if zProjHow == 'single z-slice': - img_data[i] = img[z] - if self.second_channel_name is not None: - second_ch_data[i] = second_ch_img[z] - elif zProjHow == 'max z-projection': - img_data[i] = img.max(axis=0) - if self.second_channel_name is not None: - second_ch_data[i] = second_ch_img.max(axis=0) - elif zProjHow == 'mean z-projection': - img_data[i] = img.mean(axis=0) - if self.second_channel_name is not None: - second_ch_data[i] = second_ch_img.mean(axis=0) - elif zProjHow == 'median z-proj.': - img_data[i] = np.median(img, axis=0) - if self.second_channel_name is not None: - second_ch_data[i] = np.median(second_ch_img, axis=0) - elif posData.SizeZ > 1 and (self.isSegm3D or self.use3DdataFor2Dsegm): - # 3D segmentation on 3D data over time - img_data = posData.img_data[self.t0:stop_i] - postprocess_img = img_data - if self.second_channel_name is not None: - second_ch_data = secondChImgData[self.t0:stop_i] - if isROIactive: - Y, X = img_data.shape[-2:] - img_data = img_data[:, :, y0:y1, x0:x1] - if self.second_channel_name is not None: - second_ch_data = second_ch_data[:, :, y0:y1, x0:x1] - pad_info = ((0, 0), (0, 0), (y0, Y-y1), (x0, X-x1)) - else: - # 2D data over time - img_data = posData.img_data[self.t0:stop_i] - postprocess_img = img_data - if self.second_channel_name is not None: - second_ch_data = secondChImgData[self.t0:stop_i] - if isROIactive: - Y, X = img_data.shape[-2:] - img_data = img_data[:, y0:y1, x0:x1] - if self.second_channel_name is not None: - second_ch_data = second_ch_data[:, :, y0:y1, x0:x1] - pad_info = ((0, 0), (y0, Y-y1), (x0, X-x1)) - else: - if posData.SizeZ > 1 and not self.isSegm3D and not self.use3DdataFor2Dsegm: - img_data = posData.img_data - if self.second_channel_name is not None: - second_ch_data = secondChImgData - if isROIactive: - Y, X = img_data.shape[-2:] - pad_info = ((y0, Y-y1), (x0, X-x1)) - img_data = img_data[:, y0:y1, x0:x1] - if self.second_channel_name is not None: - second_ch_data = second_ch_data[:, :, y0:y1, x0:x1] - - postprocess_img = img_data - # 2D segmentation on single 3D image - z_info = posData.segmInfo_df.loc[posData.filename].iloc[0] - z = z_info.z_slice_used_dataPrep - zProjHow = z_info.which_z_proj - if zProjHow == 'single z-slice': - img_data = img_data[z] - if self.second_channel_name is not None: - second_ch_data = second_ch_data[z] - elif zProjHow == 'max z-projection': - img_data = img_data.max(axis=0) - if self.second_channel_name is not None: - second_ch_data = second_ch_data.max(axis=0) - elif zProjHow == 'mean z-projection': - img_data = img_data.mean(axis=0) - if self.second_channel_name is not None: - second_ch_data = second_ch_data.mean(axis=0) - elif zProjHow == 'median z-proj.': - img_data = np.median(img_data, axis=0) - if self.second_channel_name is not None: - second_ch_data[i] = np.median(second_ch_data, axis=0) - elif posData.SizeZ > 1 and (self.isSegm3D or self.use3DdataFor2Dsegm): - # 3D segmentation on 3D z-stack - img_data = posData.img_data - if self.second_channel_name is not None: - second_ch_data = secondChImgData - if isROIactive: - Y, X = img_data.shape[-2:] - pad_info = ((0, 0), (y0, Y-y1), (x0, X-x1)) - img_data = img_data[:, y0:y1, x0:x1] - if self.second_channel_name is not None: - second_ch_data = second_ch_data[:, y0:y1, x0:x1] - postprocess_img = img_data - else: - # Single 2D image - img_data = posData.img_data - if self.second_channel_name is not None: - second_ch_data = secondChImgData - if isROIactive: - Y, X = img_data.shape[-2:] - pad_info = ((y0, Y-y1), (x0, X-x1)) - img_data = img_data[y0:y1, x0:x1] - if self.second_channel_name is not None: - second_ch_data = second_ch_data[y0:y1, x0:x1] - postprocess_img = img_data - - self.logger_func(f'\nImage shape = {img_data.shape}') - - if self.model is None: - self.init_segm_model(posData) - - if self.model is None: - self.logger_func( - f'\nSegmentation model {self.model_name} was not initialized!' - ) - return - - """Segmentation routine""" - self.logger_func(f'\nSegmenting with {self.model_name}...') - t0 = time.perf_counter() - if posData.SizeT > 1: - if self.innerPbar_available and self.signals is not None: - self.signals.resetInnerPbar.emit(len(img_data)) - - if self.is_segment3DT_available and img_data.ndim == 3: - self.model_kwargs['signals'] = ( - self.signals, self.innerPbar_available - ) - if self.second_channel_name is not None: - img_data = self.model.second_ch_img_to_stack( - img_data, second_ch_data - ) - lab_stack = core.segm_model_segment( - self.model, img_data, self.model_kwargs, - is_timelapse_model_and_data=True, - preproc_recipe=self.preproc_recipe, - posData=posData - ) - if self.innerPbar_available: - # emit one pos done - self.signals.progressBar.emit(1) - else: - lab_stack = [] - pbar = tqdm(total=len(img_data), ncols=100) - for t, img in enumerate(img_data): - if self.second_channel_name is not None: - img = self.model.second_ch_img_to_stack( - img, second_ch_data[t] - ) - - lab = core.segm_model_segment( - self.model, img, self.model_kwargs, frame_i=t, - preproc_recipe=self.preproc_recipe, - posData=posData - ) - lab_stack.append(lab) - if self.innerPbar_available: - self.signals.innerProgressBar.emit(1) - else: - self.signals.progressBar.emit(1) - pbar.update() - pbar.close() - lab_stack = np.array(lab_stack, dtype=np.uint32) - if self.innerPbar_available: - # emit one pos done - self.signals.progressBar.emit(1) - else: - if self.second_channel_name is not None: - img_data = self.model.second_ch_img_to_stack( - img_data, second_ch_data - ) - - lab_stack = core.segm_model_segment( - self.model, img_data, self.model_kwargs, frame_i=0, - preproc_recipe=self.preproc_recipe, - posData=posData - ) - self.signals.progressBar.emit(1) - # lab_stack = smooth_contours(lab_stack, radius=2) - - posData.saveSamEmbeddings(logger_func=self.logger_func) - - if len(posData.dataPrepFreeRoiPoints) > 0 and self.use_freehand_ROI: - self.logger_func( - 'Removing objects outside the dataprep free-hand ROI...' - ) - lab_stack = posData.clearSegmObjsDataPrepFreeRoi( - lab_stack, is_timelapse=posData.SizeT > 1 - ) - - if self.do_postprocess: - if posData.SizeT > 1: - pbar = tqdm(total=len(lab_stack), ncols=100) - for t, lab in enumerate(lab_stack): - lab_cleaned = core.post_process_segm( - lab, **self.standard_postrocess_kwargs - ) - lab_stack[t] = lab_cleaned - if self.custom_postproc_features: - lab_filtered = features.custom_post_process_segm( - posData, self.custom_postproc_grouped_features, - lab_cleaned, postprocess_img, t, posData.filename, - posData.user_ch_name, self.custom_postproc_features - ) - lab_stack[t] = lab_filtered - pbar.update() - pbar.close() - else: - lab_stack = core.post_process_segm( - lab_stack, **self.standard_postrocess_kwargs - ) - if self.custom_postproc_features: - lab_stack = features.custom_post_process_segm( - posData, self.custom_postproc_grouped_features, - lab_stack, postprocess_img, 0, posData.filename, - posData.user_ch_name, self.custom_postproc_features - ) + from cellacdc.workflow.state import PositionState - if posData.SizeT > 1 and self.do_tracking: - self.logger_func(f'\nTracking with {self.tracker_name} tracker...') - if self.do_save: - # Since tracker could raise errors we save the not-tracked - # version which will eventually be overwritten - self.logger_func(f'Saving NON-tracked masks of {posData.relPath}...') - io.savez_compressed(posData.segm_npz_path, lab_stack) - - self.signals.innerPbar_available = self.innerPbar_available - self.track_params['signals'] = self.signals - if self.image_channel_tracker is not None: - # Check if loading the image for the tracker is required - if 'image' in self.track_params: - trackerInputImage = self.track_params.pop('image') - else: - self.logger_func( - 'Loading image data of channel ' - f'"{self.image_channel_tracker}"') - trackerInputImage = posData.loadChannelData( - self.image_channel_tracker) - tracked_stack = self._tracker_track( - lab_stack, tracker_input_img=trackerInputImage - ) - else: - tracked_stack = self._tracker_track(lab_stack) - posData.fromTrackerToAcdcDf(self.tracker, tracked_stack, save=True) - else: - tracked_stack = lab_stack - try: - if self.innerPbar_available: - self.signals.innerProgressBar.emit(stop_frame_n) - else: - self.signals.progressBar.emit(stop_frame_n) - except AttributeError: - if self.innerPbar_available: - self.signals.innerProgressBar.emit(1) - else: - self.signals.progressBar.emit(1) - - if isROIactive: - self.logger_func(f'Padding with zeros {pad_info}...') - tracked_stack = np.pad(tracked_stack, pad_info, mode='constant') - - if self.do_save: - self.logger_func(f'Saving {posData.relPath}...') - io.savez_compressed(posData.segm_npz_path, tracked_stack) - - t_end = time.perf_counter() + update_workflow_context_from_segm_kernel(self._workflow_ctx, self) + state = self._position_segm_graph.invoke( + PositionState(img_path=img_path, stop_frame_n=stop_frame_n), + runnable_config_from_segm_kernel(self), + ) + sync_segm_kernel_from_context(self, self._workflow_ctx) + return state - self.logger_func(f'\n{posData.relPath} done.') class ComputeMeasurementsKernel(_WorkflowKernel): def __init__(self, logger, log_path, is_cli): super().__init__(logger, log_path, is_cli=is_cli) self.setup_done = False - + def init_args(self, channel_names, end_filename_segm): self.ch_names = channel_names self.end_filename_segm = end_filename_segm self.notLoadedChNames = [] self.save_object_counts_table = False - - def log(self, message, level='INFO'): + + def log(self, message, level="INFO"): try: self.logger.log(message, level=level) return except Exception as err: pass - + try: self.logger.log(message) return except Exception as err: pass - + try: log_func = getattr(self.logger, level.lower()) log_func(message) return except Exception as err: pass - + def _set_metrics_func_from_posData(self, posData): - (metrics_func, all_metrics_names, custom_func_dict, total_metrics, - ch_indipend_custom_func_dict) = measurements.getMetricsFunc(posData) + ( + metrics_func, + all_metrics_names, + custom_func_dict, + total_metrics, + ch_indipend_custom_func_dict, + ) = measurements.getMetricsFunc(posData) self.metrics_func = metrics_func self.all_metrics_names = all_metrics_names self.total_metrics = total_metrics @@ -731,106 +388,95 @@ def _set_metrics_func_from_posData(self, posData): self.mixed_channel_combine_metrics = [] self.channel_names = posData.chNames self.not_loaded_channel_names = [] - + def to_workflow_config_params(self): params = { - 'channels': '\n'.join(self.ch_names), - 'end_filename_segm': self.end_filename_segm + "channels": "\n".join(self.ch_names), + "end_filename_segm": self.end_filename_segm, } - params['channel_names_to_skip'] = '\n'.join(self.chNamesToSkip) - params['channel_names_to_process'] = '\n'.join(self.chNamesToProcess) + params["channel_names_to_skip"] = "\n".join(self.chNamesToSkip) + params["channel_names_to_process"] = "\n".join(self.chNamesToProcess) calc_for_each_zslice = [ - f'{channel},{value}' + f"{channel},{value}" for channel, value in self.calc_for_each_zslice_mapper.items() ] - params['calc_for_each_zslice_channels'] = '\n'.join(calc_for_each_zslice) - + params["calc_for_each_zslice_channels"] = "\n".join(calc_for_each_zslice) + for channel, colnames in self.metricsToSkip.items(): - params[f'metrics_to_skip_{channel}'] = '\n'.join(colnames) - + params[f"metrics_to_skip_{channel}"] = "\n".join(colnames) + for channel, colnames in self.metricsToSave.items(): - params[f'metrics_to_save_{channel}'] = '\n'.join(colnames) - - params['calc_for_each_zslice_size'] = str( - self.calc_size_for_each_zslice - ) - - params['size_metrics_to_save'] = '\n'.join(self.sizeMetricsToSave) - params['regionprops_to_save'] = '\n'.join(self.regionPropsToSave) - if hasattr(self, 'chIndipendCustomMetricsToSave'): - params['channel_indipendent_custom_metrics_to_save'] = ( - '\n'.join(self.chIndipendCustomMetricsToSave) + params[f"metrics_to_save_{channel}"] = "\n".join(colnames) + + params["calc_for_each_zslice_size"] = str(self.calc_size_for_each_zslice) + + params["size_metrics_to_save"] = "\n".join(self.sizeMetricsToSave) + params["regionprops_to_save"] = "\n".join(self.regionPropsToSave) + if hasattr(self, "chIndipendCustomMetricsToSave"): + params["channel_indipendent_custom_metrics_to_save"] = "\n".join( + self.chIndipendCustomMetricsToSave ) - if hasattr(self, 'mixedChCombineMetricsToSkip'): - params['mixed_combine_metrics_to_skip'] = ( - '\n'.join(self.mixedChCombineMetricsToSkip) + if hasattr(self, "mixedChCombineMetricsToSkip"): + params["mixed_combine_metrics_to_skip"] = "\n".join( + self.mixedChCombineMetricsToSkip ) - - params['save_object_counts_table'] = self.save_object_counts_table - + + params["save_object_counts_table"] = self.save_object_counts_table + return params - + def set_metrics_from_workflow_config_params(self, config_params): - self.init_args( - config_params['channels'], - config_params['end_filename_segm'] - ) - - self.chNamesToSkip = config_params['channel_names_to_skip'] + self.init_args(config_params["channels"], config_params["end_filename_segm"]) + + self.chNamesToSkip = config_params["channel_names_to_skip"] self.chNamesToProcess = config_params.get( - 'channel_names_to_process', config_params['channels'] + "channel_names_to_process", config_params["channels"] ) - self.metricsToSkip = {chName:[] for chName in self.ch_names} - self.metricsToSave = {chName:[] for chName in self.ch_names} + self.metricsToSkip = {chName: [] for chName in self.ch_names} + self.metricsToSave = {chName: [] for chName in self.ch_names} self.mixedChCombineMetricsToSkip = [] self.calc_for_each_zslice_mapper = {} - self.calc_size_for_each_zslice = ( - config_params['calc_for_each_zslice_size'] - ) - self.sizeMetricsToSave = config_params['size_metrics_to_save'] - self.regionPropsToSave = config_params['regionprops_to_save'] + self.calc_size_for_each_zslice = config_params["calc_for_each_zslice_size"] + self.sizeMetricsToSave = config_params["size_metrics_to_save"] + self.regionPropsToSave = config_params["regionprops_to_save"] self.save_object_counts_table = config_params.get( - 'save_object_counts_table', False + "save_object_counts_table", False ) - if 'channel_indipendent_custom_metrics_to_save' in config_params: - self.chIndipendCustomMetricsToSave = ( - config_params['channel_indipendent_custom_metrics_to_save'] - ) - - if 'mixed_combine_metrics_to_skip' in config_params: - self.mixedChCombineMetricsToSkip = ( - config_params['mixed_combine_metrics_to_skip'] - ) - - for channel_value in config_params['calc_for_each_zslice_channels']: - channel, value = channel_value.split(',') - value = value.lower() == 'true' + if "channel_indipendent_custom_metrics_to_save" in config_params: + self.chIndipendCustomMetricsToSave = config_params[ + "channel_indipendent_custom_metrics_to_save" + ] + + if "mixed_combine_metrics_to_skip" in config_params: + self.mixedChCombineMetricsToSkip = config_params[ + "mixed_combine_metrics_to_skip" + ] + + for channel_value in config_params["calc_for_each_zslice_channels"]: + channel, value = channel_value.split(",") + value = value.lower() == "true" self.calc_for_each_zslice_mapper[channel] = value - + for channel in self.ch_names: - metrics_to_skip = config_params.get( - f'metrics_to_skip_{channel}', '' - ) + metrics_to_skip = config_params.get(f"metrics_to_skip_{channel}", "") if metrics_to_skip: self.metricsToSkip[channel] = metrics_to_skip - - metrics_to_save = config_params.get( - f'metrics_to_save_{channel}', '' - ) + + metrics_to_save = config_params.get(f"metrics_to_save_{channel}", "") if metrics_to_save: self.metricsToSave[channel] = metrics_to_save - + def set_save_objects_count_table(self, yes: bool): self.save_object_counts_table = yes - + def set_metrics_from_set_measurements_dialog(self, setMeasurementsDialog): self.chNamesToSkip = [] self.chNamesToProcess = [] - self.metricsToSkip = {chName:[] for chName in self.ch_names} - self.metricsToSave = {chName:[] for chName in self.ch_names} + self.metricsToSkip = {chName: [] for chName in self.ch_names} + self.metricsToSave = {chName: [] for chName in self.ch_names} self.calc_for_each_zslice_mapper = {} self.calc_size_for_each_zslice = False - + # Remove unchecked metrics and load checked not loaded channels for chNameGroupbox in setMeasurementsDialog.chNameGroupboxes: chName = chNameGroupbox.chName @@ -838,7 +484,7 @@ def set_metrics_from_set_measurements_dialog(self, setMeasurementsDialog): # Skip entire channel self.chNamesToSkip.append(chName) continue - + self.chNamesToProcess.append(chName) self.calc_for_each_zslice_mapper[chName] = ( chNameGroupbox.calcForEachZsliceRequested @@ -849,7 +495,7 @@ def set_metrics_from_set_measurements_dialog(self, setMeasurementsDialog): self.metricsToSkip[chName].append(colname) else: self.metricsToSave[chName].append(colname) - func_name = colname[len(chName):] + func_name = colname[len(chName) :] self.calc_size_for_each_zslice = ( setMeasurementsDialog.sizeMetricsQGBox.calcForEachZsliceRequested @@ -871,7 +517,7 @@ def set_metrics_from_set_measurements_dialog(self, setMeasurementsDialog): for checkBox in setMeasurementsDialog.regionPropsQGBox.checkBoxes: if checkBox.isChecked(): self.regionPropsToSave.append(checkBox.text()) - + self.regionPropsToSave = tuple(self.regionPropsToSave) if setMeasurementsDialog.chIndipendCustomeMetricsQGBox is not None: @@ -886,14 +532,12 @@ def set_metrics_from_set_measurements_dialog(self, setMeasurementsDialog): for checkBox in checkBoxes: if skipAll: continue - + if checkBox.isChecked(): - chIndipendCustomMetricsToSave.append(checkBox.text()) + chIndipendCustomMetricsToSave.append(checkBox.text()) + + self.chIndipendCustomMetricsToSave = tuple(chIndipendCustomMetricsToSave) - self.chIndipendCustomMetricsToSave = tuple( - chIndipendCustomMetricsToSave - ) - self.mixedChCombineMetricsToSkip = [] if setMeasurementsDialog.mixedChannelsCombineMetricsQGBox is not None: skipAll = ( @@ -917,67 +561,72 @@ def _init_metrics_to_save(self, posData): self.isSegm3D = posData.getIsSegm3D() if self.metricsToSave is None: - # self.metricsToSave means that the user did not set + # self.metricsToSave means that the user did not set # through setMeasurements dialog --> save all measurements - self.metricsToSave = {chName:[] for chName in posData.loadedChNames} + self.metricsToSave = {chName: [] for chName in posData.loadedChNames} isManualBackgrPresent = posData.manualBackgroundLab is not None for chName in posData.loadedChNames: metrics_desc, bkgr_val_desc = measurements.standard_metrics_desc( - posData.SizeZ>1, chName, isSegm3D=self.isSegm3D, - isManualBackgrPresent=isManualBackgrPresent + posData.SizeZ > 1, + chName, + isSegm3D=self.isSegm3D, + isManualBackgrPresent=isManualBackgrPresent, ) self.metricsToSave[chName].extend(metrics_desc.keys()) self.metricsToSave[chName].extend(bkgr_val_desc.keys()) custom_metrics_desc = measurements.custom_metrics_desc( - posData.SizeZ>1, chName, posData=posData, - isSegm3D=self.isSegm3D, return_combine=False - ) - self.metricsToSave[chName].extend( - custom_metrics_desc.keys() + posData.SizeZ > 1, + chName, + posData=posData, + isSegm3D=self.isSegm3D, + return_combine=False, ) - + self.metricsToSave[chName].extend(custom_metrics_desc.keys()) + # Get metrics parameters --> function name, how etc self.metrics_func, _ = measurements.standard_metrics_func() self.custom_func_dict = measurements.get_custom_metrics_func() params = measurements.get_metrics_params( self.metricsToSave, self.metrics_func, self.custom_func_dict ) - (bkgr_metrics_params, foregr_metrics_params, - concentration_metrics_params, custom_metrics_params) = params + ( + bkgr_metrics_params, + foregr_metrics_params, + concentration_metrics_params, + custom_metrics_params, + ) = params self.bkgr_metrics_params = bkgr_metrics_params self.foregr_metrics_params = foregr_metrics_params self.concentration_metrics_params = concentration_metrics_params self.custom_metrics_params = custom_metrics_params - + self.ch_indipend_custom_func_dict = ( measurements.get_channel_indipendent_custom_metrics_func() ) - if not hasattr(self, 'chIndipendCustomMetricsToSave'): + if not hasattr(self, "chIndipendCustomMetricsToSave"): self.chIndipendCustomMetricsToSave = list( measurements.ch_indipend_custom_metrics_desc( - posData.SizeZ>1, isSegm3D=self.isSegm3D, + posData.SizeZ > 1, + isSegm3D=self.isSegm3D, ).keys() ) - + self.ch_indipend_custom_func_params = ( measurements.get_channel_indipend_custom_metrics_params( - self.ch_indipend_custom_func_dict, - self.chIndipendCustomMetricsToSave + self.ch_indipend_custom_func_dict, self.chIndipendCustomMetricsToSave ) ) - + def _load_posData(self, img_path, end_filename_segm): images_path = os.path.dirname(img_path) - exp_foldername = os.path.basename( - os.path.dirname(os.path.dirname(images_path)) - ) - basename, channel_names = myutils.getBasenameAndChNames( - images_path, useExt=('.tif', '.h5') + exp_foldername = os.path.basename(os.path.dirname(os.path.dirname(images_path))) + basename, channel_names = utils.getBasenameAndChNames( + images_path, useExt=(".tif", ".h5") ) posData = load.loadData(img_path, channel_names[0]) - - posData.getBasenameAndChNames(useExt=('.tif', '.h5')) + + posData.getBasenameAndChNames(useExt=(".tif", ".h5")) posData.buildPaths() posData.loadImgData() @@ -994,139 +643,179 @@ def _load_posData(self, img_path, end_filename_segm): load_customAnnot=True, load_customCombineMetrics=True, end_filename_segm=end_filename_segm, - load_dataPrep_ROIcoords=True + load_dataPrep_ROIcoords=True, ) posData.labelSegmData() - + self.isSegm3D = posData.getIsSegm3D() - + # Allow single 2D/3D image if posData.SizeT == 1: posData.img_data = posData.img_data[np.newaxis] - + if posData.segm_data is not None: posData.segm_data = posData.segm_data[np.newaxis] - + return posData - + def _load_image_data(self, posData, channel_names): if posData.fluo_data_dict: - return - + return + # Load fluorescence channels data since not loaded in GUI posData.loadedChNames = [] for c, channel in enumerate(channel_names): if channel in self.chNamesToSkip: - continue - + continue + if channel == posData.user_ch_name: img_data = posData.img_data filename = posData.filename bkgrData = posData.bkgrData else: - filepath = load.get_filename_from_channel( - posData.images_path, channel - ) + filepath = load.get_filename_from_channel(posData.images_path, channel) img_data, bkgrData = self._load_channel_data(filepath) if posData.SizeT == 1: img_data = img_data[np.newaxis] - + filename_ext = os.path.basename(filepath) filename, _ = os.path.splitext(filename_ext) - + posData.loadedChNames.append(channel) posData.loadedFluoChannels.add(channel) posData.fluo_data_dict[filename] = img_data posData.fluo_bkgrData_dict[filename] = bkgrData - + def init_signals(self, computeMetricsWorker, saveDataWorker): self.customMetricsCritical = HeadlessSignal() self.regionPropsCritical = HeadlessSignal() - + if saveDataWorker is not None: self.customMetricsCritical = saveDataWorker.customMetricsCritical self.regionPropsCritical = saveDataWorker.regionPropsCritical - + elif computeMetricsWorker is not None: saveDataWorker = computeMetricsWorker.mainWin.gui.saveDataWorker self.customMetricsCritical = saveDataWorker.customMetricsCritical self.regionPropsCritical = saveDataWorker.regionPropsCritical - - @exception_handler_cli - def run( - self, - img_path: os.PathLike='', - stop_frame_n: int=1, - end_filename_segm: str='', - computeMetricsWorker=None, - saveDataWorker=None, - posData=None, - save_metrics=True, - do_init_metrics=True, - last_cca_frame_i=None - ): - if posData is None: - posData = self._load_posData(img_path, end_filename_segm) - + + def _run_metrics_cli( + self, + posData, + stop_frame_n: int, + save_metrics: bool = True, + last_cca_frame_i=None, + ): channel_names = posData.chNames - images_path = posData.images_path exp_foldername = os.path.basename(posData.exp_path) - - self._set_metrics_func_from_posData(posData) - if computeMetricsWorker is not None and do_init_metrics: - computeMetricsWorker.emitSigInitMetricsDialog(posData) - if computeMetricsWorker.abort: - computeMetricsWorker.signals.finished.emit(computeMetricsWorker) - return - - if self.setup_done: - computeMetricsWorker.signals.finished.emit(computeMetricsWorker) - return - - computeMetricsWorker.emitSigAskRunNow() - if computeMetricsWorker.abort or computeMetricsWorker.savedToWorkflow: - computeMetricsWorker.signals.finished.emit(computeMetricsWorker) - return - - if not posData.segmFound: - rel_path = ( - f'...{os.sep}{exp_foldername}' - f'{os.sep}{posData.pos_foldername}' - ) - self.log( - f'Skipping "{rel_path}" ' - f'because segm. file was not found.' - ) - return - - self.init_signals(computeMetricsWorker, saveDataWorker) - + self._set_metrics_func_from_posData(posData) + self.init_signals(None, None) self.log( - 'Loading the following files:\n' - f'Segmentation file name: {os.path.basename(posData.segm_npz_path)}\n' - f'ACDC output file name: {os.path.basename(posData.acdc_output_csv_path)}' + "Loading the following files:\n" + f"Segmentation file name: {os.path.basename(posData.segm_npz_path)}\n" + f"ACDC output file name: {os.path.basename(posData.acdc_output_csv_path)}" ) - + posData.init_segmInfo_df() - - if computeMetricsWorker is not None: - computeMetricsWorker.emitSigComputeVolume(posData, stop_frame_n) - self._init_metrics_to_save(posData) - - if computeMetricsWorker is not None: - computeMetricsWorker.signals.initProgressBar.emit(stop_frame_n) - + channels_to_load = [ - ch for ch in channel_names if not ch in self.chNamesToSkip - and ch in self.chNamesToProcess + ch + for ch in channel_names + if ch not in self.chNamesToSkip and ch in self.chNamesToProcess ] - - self.log(f'Loading channels {channels_to_load}...') - + self.log(f"Loading channels {channels_to_load}...") self._load_image_data(posData, channels_to_load) - + + acdc_df_li = [] + keys = [] + for frame_i in range(stop_frame_n): + lab = posData.segm_data[frame_i] + if not np.any(lab): + continue + + if frame_i == 0: + self.log("\nComputing cell volume...") + rp = skimage.measure.regionprops(lab) + rp = self._calc_volume_metrics(rp, posData) + + posData.lab = lab + posData.rp = rp + + if posData.acdc_df is None: + acdc_df = utils.getBaseAcdcDf(rp) + else: + try: + acdc_df = posData.acdc_df.loc[frame_i].copy() + except Exception: + acdc_df = utils.getBaseAcdcDf(rp) + + key = (frame_i, posData.TimeIncrement * frame_i) + acdc_df = load.pd_bool_and_float_to_int_to_str( + acdc_df, inplace=False, colsToCastInt=[] + ) + + if not save_metrics: + acdc_df_li.append(acdc_df) + keys.append(key) + continue + + try: + acdc_df = self._add_volume_metrics(acdc_df, rp, posData) + acdc_df, calc_metrics_addtional_args = self._init_calc_metrics( + acdc_df, rp, frame_i, lab, posData, saveDataWorker=None + ) + acdc_df = self._calc_metrics_iter_channels( + acdc_df, rp, frame_i, lab, posData, *calc_metrics_addtional_args + ) + except Exception as error: + self.log(f"\n{traceback.format_exc()}") + + if frame_i == 0: + acdc_df_li.append(acdc_df) + keys.append(key) + continue + + try: + prev_lab = posData.segm_data[frame_i - 1] + acdc_df = self._add_velocity_measurement( + acdc_df, prev_lab, lab, posData + ) + except Exception as error: + self.log(f"\n{traceback.format_exc()}") + + acdc_df_li.append(acdc_df) + keys.append(key) + + if not acdc_df_li: + print("-" * 30) + self.log( + "All selected positions in the experiment folder " + f"{exp_foldername} have EMPTY segmentation mask. " + "Metrics will not be saved." + ) + print("-" * 30) + return + + self._concat_and_save_acdc_df( + acdc_df_li, + keys, + posData, + save_metrics, + computeMetricsWorker=None, + saveDataWorker=None, + last_cca_frame_i=last_cca_frame_i, + ) + + def _compute_metrics_gui_frames( + self, + posData, + stop_frame_n, + save_metrics=True, + computeMetricsWorker=None, + saveDataWorker=None, + ): acdc_df_li = [] keys = [] for frame_i in range(stop_frame_n): @@ -1134,72 +823,67 @@ def run( stop = saveDataWorker.checkAbort() if stop: break - - lab = posData.segm_data[frame_i] + + lab = posData.segm_data[frame_i] if not np.any(lab): - # Empty segmentation mask --> skip continue - + acdc_df = None if computeMetricsWorker is not None: - rp = posData.allData_li[frame_i]['regionprops'] + rp = posData.allData_li[frame_i]["regionprops"] elif saveDataWorker is not None: - rp = posData.allData_li[frame_i]['regionprops'] - acdc_df = posData.allData_li[frame_i]['acdc_df'] + rp = posData.allData_li[frame_i]["regionprops"] + acdc_df = posData.allData_li[frame_i]["acdc_df"] if acdc_df is None: continue else: if frame_i == 0: - self.log('\nComputing cell volume...') + self.log("\nComputing cell volume...") rp = skimage.measure.regionprops(lab) rp = self._calc_volume_metrics(rp, posData) - + posData.lab = lab posData.rp = rp - + if acdc_df is None: if posData.acdc_df is None: - acdc_df = myutils.getBaseAcdcDf(rp) + acdc_df = utils.getBaseAcdcDf(rp) else: try: acdc_df = posData.acdc_df.loc[frame_i].copy() - except: - acdc_df = myutils.getBaseAcdcDf(rp) - - key = (frame_i, posData.TimeIncrement*frame_i) + except Exception: + acdc_df = utils.getBaseAcdcDf(rp) + + key = (frame_i, posData.TimeIncrement * frame_i) acdc_df = load.pd_bool_and_float_to_int_to_str( acdc_df, inplace=False, colsToCastInt=[] ) - + if not save_metrics: if saveDataWorker is not None: saveDataWorker.emitUpdateProgressBar() acdc_df_li.append(acdc_df) keys.append(key) continue - + try: acdc_df = self._add_volume_metrics(acdc_df, rp, posData) acdc_df, calc_metrics_addtional_args = self._init_calc_metrics( - acdc_df, rp, frame_i, lab, posData, - saveDataWorker=saveDataWorker + acdc_df, rp, frame_i, lab, posData, saveDataWorker=saveDataWorker ) acdc_df = self._calc_metrics_iter_channels( - acdc_df, rp, frame_i, lab, posData, - *calc_metrics_addtional_args + acdc_df, rp, frame_i, lab, posData, *calc_metrics_addtional_args ) except Exception as error: - traceback_format = traceback.format_exc() - self.log(f'\n{traceback_format}') + traceback_format = traceback.format_exc() + self.log(f"\n{traceback_format}") if computeMetricsWorker is not None: computeMetricsWorker.standardMetricsErrors[str(error)] = ( traceback_format ) if saveDataWorker is not None: - saveDataWorker.addMetricsCritical.emit( - traceback_format, str(error) - ) - + saveDataWorker.addMetricsCritical.emit(traceback_format, str(error)) + if frame_i == 0: if saveDataWorker is not None: saveDataWorker.emitUpdateProgressBar() @@ -1208,142 +892,236 @@ def run( continue try: - prev_lab = posData.segm_data[frame_i-1] + prev_lab = posData.segm_data[frame_i - 1] acdc_df = self._add_velocity_measurement( acdc_df, prev_lab, lab, posData ) except Exception as error: traceback_format = traceback.format_exc() - self.log(f'\n{traceback_format}') + self.log(f"\n{traceback_format}") if computeMetricsWorker is not None: - e = str(error) - computeMetricsWorker.standardMetricsErrors[e] = ( + computeMetricsWorker.standardMetricsErrors[str(error)] = ( traceback_format ) - + acdc_df_li.append(acdc_df) keys.append(key) if computeMetricsWorker is not None: computeMetricsWorker.signals.progressBar.emit(1) - + if saveDataWorker is not None: saveDataWorker.emitUpdateProgressBar() - - if not acdc_df_li: - print('-'*30) - self.log( - 'All selected positions in the experiment folder ' - f'{exp_foldername} have EMPTY segmentation mask. ' - 'Metrics will not be saved.' + + return acdc_df_li, keys + + def _run_metrics_gui_via_graph( + self, + img_path="", + stop_frame_n=1, + end_filename_segm="", + computeMetricsWorker=None, + saveDataWorker=None, + posData=None, + save_metrics=True, + do_init_metrics=True, + last_cca_frame_i=None, + ): + from cellacdc.workflow.pipelines.measurements_gui import ( + build_gui_measurements_graph, + ) + from cellacdc.workflow.runnable import RunnableConfig + from cellacdc.workflow.state import MeasurementsGuiContext, MeasurementsGuiState + + ctx = MeasurementsGuiContext( + kernel=self, + compute_metrics_worker=computeMetricsWorker, + save_data_worker=saveDataWorker, + save_metrics=save_metrics, + do_init_metrics=do_init_metrics, + last_cca_frame_i=last_cca_frame_i, + end_filename_segm=end_filename_segm or self.end_filename_segm, + ) + graph = build_gui_measurements_graph( + ctx, + pos_data_loaded=posData is not None, + ).compile() + return graph.invoke( + MeasurementsGuiState( + img_path=img_path, + stop_frame_n=stop_frame_n, + pos_data=posData, + ), + RunnableConfig(logger_func=self.log), + ) + + def _run_metrics_via_graph( + self, + img_path, + stop_frame_n, + end_filename_segm, + save_metrics=True, + last_cca_frame_i=None, + ): + from cellacdc.workflow.pipelines.measurements import ( + build_measurements_position_graph, + ) + from cellacdc.workflow.runnable import RunnableConfig + from cellacdc.workflow.state import MeasurementsContext, MeasurementsState + + ctx = MeasurementsContext( + end_filename_segm=end_filename_segm or self.end_filename_segm, + kernel=self, + save_metrics=save_metrics, + ) + ctx.last_cca_frame_i = last_cca_frame_i + graph = build_measurements_position_graph(ctx).compile() + return graph.invoke( + MeasurementsState(img_path=img_path, stop_frame_n=stop_frame_n), + RunnableConfig(logger_func=self.log), + ) + + @exception_handler_cli + def run( + self, + img_path: os.PathLike = "", + stop_frame_n: int = 1, + end_filename_segm: str = "", + computeMetricsWorker=None, + saveDataWorker=None, + posData=None, + save_metrics=True, + do_init_metrics=True, + last_cca_frame_i=None, + ): + if ( + computeMetricsWorker is None + and saveDataWorker is None + and posData is None + ): + return self._run_metrics_via_graph( + img_path, + stop_frame_n, + end_filename_segm or self.end_filename_segm, + save_metrics=save_metrics, + last_cca_frame_i=last_cca_frame_i, ) - print('-'*30) - return - - self._concat_and_save_acdc_df( - acdc_df_li, keys, posData, save_metrics, - computeMetricsWorker=computeMetricsWorker, + + return self._run_metrics_gui_via_graph( + img_path=img_path, + stop_frame_n=stop_frame_n, + end_filename_segm=end_filename_segm or self.end_filename_segm, + computeMetricsWorker=computeMetricsWorker, saveDataWorker=saveDataWorker, - last_cca_frame_i=last_cca_frame_i + posData=posData, + save_metrics=save_metrics, + do_init_metrics=do_init_metrics, + last_cca_frame_i=last_cca_frame_i, ) - + def _concat_and_save_acdc_df( - self, acdc_df_li, keys, posData, save_metrics, - computeMetricsWorker=None, saveDataWorker=None, - last_cca_frame_i=None - ): - + self, + acdc_df_li, + keys, + posData, + save_metrics, + computeMetricsWorker=None, + saveDataWorker=None, + last_cca_frame_i=None, + ): + all_frames_acdc_df = pd.concat( - acdc_df_li, keys=keys, names=['frame_i', 'time_seconds', 'Cell_ID'] + acdc_df_li, keys=keys, names=["frame_i", "time_seconds", "Cell_ID"] ) - + if save_metrics: self._add_combined_metrics( posData, all_frames_acdc_df, saveDataWorker=saveDataWorker ) - + all_frames_acdc_df = self._add_additional_metadata( posData, all_frames_acdc_df, posData.segm_data ) - all_frames_acdc_df = self._remove_deprecated_rows( - all_frames_acdc_df - ) - all_frames_acdc_df = self._add_derived_cell_cycle_columns( - all_frames_acdc_df - ) + all_frames_acdc_df = self._remove_deprecated_rows(all_frames_acdc_df) + all_frames_acdc_df = self._add_derived_cell_cycle_columns(all_frames_acdc_df) all_frames_acdc_df = load._fix_will_divide(all_frames_acdc_df) custom_annot_columns = posData.getCustomAnnotColumnNames() - self.log( - f'Saving acdc_output to: "{posData.acdc_output_csv_path}"' - ) - + self.log(f'Saving acdc_output to: "{posData.acdc_output_csv_path}"') + self._save_acdc_df( - all_frames_acdc_df, posData, custom_annot_columns, - computeMetricsWorker=computeMetricsWorker, + all_frames_acdc_df, + posData, + custom_annot_columns, + computeMetricsWorker=computeMetricsWorker, saveDataWorker=saveDataWorker, - last_cca_frame_i=last_cca_frame_i + last_cca_frame_i=last_cca_frame_i, ) - + if not self.save_object_counts_table: return - + countMapper = posData.countObjectsInSegm() - countMapper.pop('In current frame', None) + countMapper.pop("In current frame", None) df_count_endname = posData.saveObjCounts(countMapper) - - self.log( - 'Saved object counts table to file ending with: ' - f'"{df_count_endname}"' - ) - + + self.log(f'Saved object counts table to file ending with: "{df_count_endname}"') + def _remove_deprecated_rows(self, df): v1_2_4_rc25_deprecated_cols = [ - 'editIDclicked_x', 'editIDclicked_y', - 'editIDnewID', 'editIDnewIDs' + "editIDclicked_x", + "editIDclicked_y", + "editIDnewID", + "editIDnewIDs", ] - df = df.drop(columns=v1_2_4_rc25_deprecated_cols, errors='ignore') + df = df.drop(columns=v1_2_4_rc25_deprecated_cols, errors="ignore") # Remove old gui_ columns from version < v1.2.4.rc-7 - gui_columns = df.filter(regex='gui_*').columns - df = df.drop(columns=gui_columns, errors='ignore') - cell_id_cols = df.filter(regex='Cell_ID.*').columns - df = df.drop(columns=cell_id_cols, errors='ignore') - time_seconds_cols = df.filter(regex='time_seconds.*').columns - df = df.drop(columns=time_seconds_cols, errors='ignore') - df = df.drop(columns='relative_ID_tree', errors='ignore') - df = df.drop(columns=['level_0', 'index'], errors='ignore') + gui_columns = df.filter(regex="gui_*").columns + df = df.drop(columns=gui_columns, errors="ignore") + cell_id_cols = df.filter(regex="Cell_ID.*").columns + df = df.drop(columns=cell_id_cols, errors="ignore") + time_seconds_cols = df.filter(regex="time_seconds.*").columns + df = df.drop(columns=time_seconds_cols, errors="ignore") + df = df.drop(columns="relative_ID_tree", errors="ignore") + df = df.drop(columns=["level_0", "index"], errors="ignore") return df - + def _save_acdc_df( - self, all_frames_acdc_df, posData, custom_annot_columns, - computeMetricsWorker=None, saveDataWorker=None, - last_cca_frame_i=None - ): + self, + all_frames_acdc_df, + posData, + custom_annot_columns, + computeMetricsWorker=None, + saveDataWorker=None, + last_cca_frame_i=None, + ): try: if saveDataWorker is not None: load.store_copy_acdc_df( - posData, posData.acdc_output_csv_path, - log_func=saveDataWorker.progress.emit + posData, + posData.acdc_output_csv_path, + log_func=saveDataWorker.progress.emit, ) load.save_acdc_df_file( - all_frames_acdc_df, posData.acdc_output_csv_path, + all_frames_acdc_df, + posData.acdc_output_csv_path, custom_annot_columns=custom_annot_columns, - last_cca_frame_i=last_cca_frame_i + last_cca_frame_i=last_cca_frame_i, ) posData.acdc_df = all_frames_acdc_df except PermissionError as error: traceback_str = traceback.format_exc() if computeMetricsWorker is not None: computeMetricsWorker.emitSigPermissionErrorAndSave( - posData, traceback_str, all_frames_acdc_df, - custom_annot_columns + posData, traceback_str, all_frames_acdc_df, custom_annot_columns ) - + if saveDataWorker is not None: saveDataWorker.emitSigPermissionErrorAndSave( - all_frames_acdc_df, posData.acdc_output_csv_path, - custom_annot_columns + all_frames_acdc_df, + posData.acdc_output_csv_path, + custom_annot_columns, ) except Exception as error: if saveDataWorker is not None: @@ -1351,7 +1129,7 @@ def _save_acdc_df( saveDataWorker.critical.emit(error) saveDataWorker.waitCond.wait(saveDataWorker.mutex) saveDataWorker.mutex.unlock() - + def _load_channel_data(self, channel_path): self.log(f'Loading fluorescence image data from "{channel_path}"...') images_path = os.path.dirname(channel_path) @@ -1359,28 +1137,28 @@ def _load_channel_data(self, channel_path): # Load overlay frames and align if needed filename = os.path.basename(channel_path) filename_noEXT, ext = os.path.splitext(filename) - if ext == '.npy' or ext == '.npz': + if ext == ".npy" or ext == ".npz": img_data = np.load(channel_path) try: - img_data = np.squeeze(img_data['arr_0']) + img_data = np.squeeze(img_data["arr_0"]) except Exception as e: img_data = np.squeeze(img_data) # Load background data bkgrData_path = os.path.join( - images_path, f'{filename_noEXT}_bkgrRoiData.npz' + images_path, f"{filename_noEXT}_bkgrRoiData.npz" ) if os.path.exists(bkgrData_path): bkgrData = np.load(bkgrData_path) - elif ext == '.tif' or ext == '.tiff': - aligned_filename = f'{filename_noEXT}_aligned.npz' + elif ext == ".tif" or ext == ".tiff": + aligned_filename = f"{filename_noEXT}_aligned.npz" aligned_path = os.path.join(images_path, aligned_filename) if os.path.exists(aligned_path): - img_data = np.load(aligned_path)['arr_0'] + img_data = np.load(aligned_path)["arr_0"] # Load background data bkgrData_path = os.path.join( - images_path, f'{aligned_filename}_bkgrRoiData.npz' + images_path, f"{aligned_filename}_bkgrRoiData.npz" ) if os.path.exists(bkgrData_path): bkgrData = np.load(bkgrData_path) @@ -1389,7 +1167,7 @@ def _load_channel_data(self, channel_path): # Load background data bkgrData_path = os.path.join( - images_path, f'{filename_noEXT}_bkgrRoiData.npz' + images_path, f"{filename_noEXT}_bkgrRoiData.npz" ) if os.path.exists(bkgrData_path): bkgrData = np.load(bkgrData_path) @@ -1397,7 +1175,7 @@ def _load_channel_data(self, channel_path): return None, None return img_data, bkgrData - + def _calc_volume_metrics(self, rp, posData): PhysicalSizeY = posData.PhysicalSizeY PhysicalSizeX = posData.PhysicalSizeX @@ -1409,14 +1187,14 @@ def _calc_volume_metrics(self, rp, posData): obj.vol_vox = vol_vox obj.vol_fl = vol_fl return rp - + def _add_volume_metrics(self, df, rp, posData): PhysicalSizeY = posData.PhysicalSizeY PhysicalSizeX = posData.PhysicalSizeX - yx_pxl_to_um2 = PhysicalSizeY*PhysicalSizeX - vox_to_fl_3D = PhysicalSizeY*PhysicalSizeX*posData.PhysicalSizeZ - - init_list = [-2]*len(rp) + yx_pxl_to_um2 = PhysicalSizeY * PhysicalSizeX + vox_to_fl_3D = PhysicalSizeY * PhysicalSizeX * posData.PhysicalSizeZ + + init_list = [-2] * len(rp) IDs = init_list.copy() IDs_vol_vox = init_list.copy() IDs_area_pxl = init_list.copy() @@ -1435,54 +1213,52 @@ def _add_volume_metrics(self, df, rp, posData): IDs_vol_vox[i] = np.nan IDs_vol_fl[i] = np.nan IDs_area_pxl[i] = obj.area - IDs_area_um2[i] = obj.area*yx_pxl_to_um2 + IDs_area_um2[i] = obj.area * yx_pxl_to_um2 if self.isSegm3D: IDs_vol_vox_3D[i] = obj.area - IDs_vol_fl_3D[i] = obj.area*vox_to_fl_3D - - df['cell_area_pxl'] = pd.Series(data=IDs_area_pxl, index=IDs, dtype=float) - df['cell_vol_vox'] = pd.Series(data=IDs_vol_vox, index=IDs, dtype=float) - df['cell_area_um2'] = pd.Series(data=IDs_area_um2, index=IDs, dtype=float) - df['cell_vol_fl'] = pd.Series(data=IDs_vol_fl, index=IDs, dtype=float) + IDs_vol_fl_3D[i] = obj.area * vox_to_fl_3D + + df["cell_area_pxl"] = pd.Series(data=IDs_area_pxl, index=IDs, dtype=float) + df["cell_vol_vox"] = pd.Series(data=IDs_vol_vox, index=IDs, dtype=float) + df["cell_area_um2"] = pd.Series(data=IDs_area_um2, index=IDs, dtype=float) + df["cell_vol_fl"] = pd.Series(data=IDs_vol_fl, index=IDs, dtype=float) if self.isSegm3D: - df['cell_vol_vox_3D'] = pd.Series( + df["cell_vol_vox_3D"] = pd.Series( data=IDs_vol_vox_3D, index=IDs, dtype=float ) - df['cell_vol_fl_3D'] = pd.Series( - data=IDs_vol_fl_3D, index=IDs, dtype=float - ) + df["cell_vol_fl_3D"] = pd.Series(data=IDs_vol_fl_3D, index=IDs, dtype=float) return df - + def _check_zSlice(self, posData, frame_i, saveDataWorker=None): if posData.SizeZ == 1: return True - + # Iteare fluo channels and get 2D data from 3D if needed filenames = posData.fluo_data_dict.keys() for chName, filename in zip(posData.loadedChNames, filenames): if chName in self.chNamesToSkip: - continue - + continue + idx = (filename, frame_i) try: - if posData.segmInfo_df.at[idx, 'resegmented_in_gui']: - col = 'z_slice_used_gui' + if posData.segmInfo_df.at[idx, "resegmented_in_gui"]: + col = "z_slice_used_gui" else: - col = 'z_slice_used_dataPrep' + col = "z_slice_used_dataPrep" z_slice = posData.segmInfo_df.at[idx, col] except KeyError: try: # Try to see if the user already selected z-slice in prev pos segmInfo_df = pd.read_csv(posData.segmInfo_df_csv_path) - index_col = ['filename', 'frame_i'] + index_col = ["filename", "frame_i"] posData.segmInfo_df = segmInfo_df.set_index(index_col) - col = 'z_slice_used_dataPrep' + col = "z_slice_used_dataPrep" z_slice = posData.segmInfo_df.at[idx, col] except KeyError as e: if saveDataWorker is not None: saveDataWorker.progress.emit( f'z-slice for channel "{chName}" absent. ' - 'Follow instructions on pop-up dialogs.' + "Follow instructions on pop-up dialogs." ) saveDataWorker.mutex.lock() saveDataWorker.askZsliceAbsent.emit(filename, posData) @@ -1491,41 +1267,40 @@ def _check_zSlice(self, posData, frame_i, saveDataWorker=None): if saveDataWorker.abort: return False saveDataWorker.progress.emit( - f'Saving (check terminal for additional progress info)...' + f"Saving (check terminal for additional progress info)..." ) segmInfo_df = pd.read_csv(posData.segmInfo_df_csv_path) - index_col = ['filename', 'frame_i'] + index_col = ["filename", "frame_i"] posData.segmInfo_df = segmInfo_df.set_index(index_col) - col = 'z_slice_used_dataPrep' + col = "z_slice_used_dataPrep" z_slice = posData.segmInfo_df.at[idx, col] else: print( - f'[WARNING]: z-slice for channel {chName} absent. ' - 'Using middle z-slice for calculating metrics.' + f"[WARNING]: z-slice for channel {chName} absent. " + "Using middle z-slice for calculating metrics." ) middle_z = round(np.median(np.arange(posData.SizeZ))) - new_row = pd.DataFrame({ - 'z_slice_used_dataPrep': [middle_z], - 'resegmented_in_gui': [0], - 'which_z_proj': 'single z-slice', - 'is_from_dataPrep': [0], - 'z_slice_used_gui': [-1], - 'which_z_proj_gui': 'single z-slice', + new_row = pd.DataFrame( + { + "z_slice_used_dataPrep": [middle_z], + "resegmented_in_gui": [0], + "which_z_proj": "single z-slice", + "is_from_dataPrep": [0], + "z_slice_used_gui": [-1], + "which_z_proj_gui": "single z-slice", }, - index=[idx] - ) - posData.segmInfo_df = pd.concat( - [posData.segmInfo_df, new_row] + index=[idx], ) + posData.segmInfo_df = pd.concat([posData.segmInfo_df, new_row]) posData.segmInfo_df.to_csv(posData.segmInfo_df_csv_path) return True - + def _init_calc_metrics( - self, acdc_df, rp, frame_i, lab, posData, saveDataWorker=None - ): - yx_pxl_to_um2 = posData.PhysicalSizeY*posData.PhysicalSizeX + self, acdc_df, rp, frame_i, lab, posData, saveDataWorker=None + ): + yx_pxl_to_um2 = posData.PhysicalSizeY * posData.PhysicalSizeX vox_to_fl_3D = ( - posData.PhysicalSizeY*posData.PhysicalSizeX*posData.PhysicalSizeZ + posData.PhysicalSizeY * posData.PhysicalSizeX * posData.PhysicalSizeZ ) manualBackgrLab = posData.manualBackgroundLab @@ -1538,7 +1313,7 @@ def _init_calc_metrics( size_metrics_to_save = self.sizeMetricsToSave regionprops_to_save = self.regionPropsToSave custom_func_dict = self.custom_func_dict - + calc_size_for_each_zslice = self.calc_size_for_each_zslice # Pre-populate columns with zeros @@ -1554,35 +1329,36 @@ def _init_calc_metrics( df = df.combine_first(acdc_df) # Check if z-slice is present for 3D z-stack data - proceed = self._check_zSlice( - posData, frame_i, saveDataWorker=saveDataWorker - ) + proceed = self._check_zSlice(posData, frame_i, saveDataWorker=saveDataWorker) if not proceed: return df, [] - + df = measurements.add_size_metrics( - df, rp, size_metrics_to_save, isSegm3D, yx_pxl_to_um2, - vox_to_fl_3D, calc_size_for_each_zslice=calc_size_for_each_zslice + df, + rp, + size_metrics_to_save, + isSegm3D, + yx_pxl_to_um2, + vox_to_fl_3D, + calc_size_for_each_zslice=calc_size_for_each_zslice, ) - + # Get background masks - autoBkgr_masks = measurements.get_autoBkgr_mask( - lab, isSegm3D, posData, frame_i - ) + autoBkgr_masks = measurements.get_autoBkgr_mask(lab, isSegm3D, posData, frame_i) # self._emitSigDebug((lab, frame_i, autoBkgr_masks)) - + autoBkgr_mask, autoBkgr_mask_proj = autoBkgr_masks dataPrepBkgrROI_mask = measurements.get_bkgrROI_mask(posData, isSegm3D) - + calc_metrics_addtional_args = ( - autoBkgr_mask, - autoBkgr_mask_proj, + autoBkgr_mask, + autoBkgr_mask_proj, dataPrepBkgrROI_mask, - manualBackgrRp + manualBackgrRp, ) - + return df, calc_metrics_addtional_args - + def _init_metrics(self, posData, isSegm3D): self.chNamesToSkip = [] loadedChannels = posData.setLoadedChannelNames(returnList=True) @@ -1598,26 +1374,25 @@ def _init_metrics(self, posData, isSegm3D): if isSegm3D: self.regionPropsToSave = measurements.get_props_names_3D() else: - self.regionPropsToSave = measurements.get_props_names() + self.regionPropsToSave = measurements.get_props_names() self.mixedChCombineMetricsToSkip = [] self.chIndipendCustomMetricsToSave = list( measurements.ch_indipend_custom_metrics_desc( - posData.SizeZ>1, isSegm3D=isSegm3D, + posData.SizeZ > 1, + isSegm3D=isSegm3D, ).keys() ) self.sizeMetricsToSave = list( - measurements.get_size_metrics_desc( - isSegm3D, posData.SizeT>1 - ).keys() + measurements.get_size_metrics_desc(isSegm3D, posData.SizeT > 1).keys() ) - + exp_path = posData.exp_path - posFoldernames = myutils.get_pos_foldernames(exp_path) + posFoldernames = utils.get_pos_foldernames(exp_path) for pos in posFoldernames: - images_path = os.path.join(exp_path, pos, 'Images') - for file in myutils.listdir(images_path): - if not file.endswith('custom_combine_metrics.ini'): + images_path = os.path.join(exp_path, pos, "Images") + for file in utils.listdir(images_path): + if not file.endswith("custom_combine_metrics.ini"): continue filePath = os.path.join(images_path, file) configPars = load.read_config_metrics(filePath) @@ -1625,75 +1400,98 @@ def _init_metrics(self, posData, isSegm3D): posData.combineMetricsConfig = load.add_configPars_metrics( configPars, posData.combineMetricsConfig ) - + def _add_custom_metrics( - self, posData, frame_i, isSegm3D, df, rp, custom_metrics_params, - lab, calc_for_each_zslice_mapper - ): - iter_channels = zip( - posData.loadedChNames, - posData.fluo_data_dict.items() - ) + self, + posData, + frame_i, + isSegm3D, + df, + rp, + custom_metrics_params, + lab, + calc_for_each_zslice_mapper, + ): + iter_channels = zip(posData.loadedChNames, posData.fluo_data_dict.items()) # Add custom measurements for channel, (filename, channel_data) in iter_channels: if channel in self.chNamesToSkip: - continue - + continue + foregr_img = channel_data[frame_i] - + iter_other_channels = zip( - posData.loadedChNames, - posData.fluo_data_dict.items() + posData.loadedChNames, posData.fluo_data_dict.items() ) other_channels_foregr_imgs = { - ch:ch_data[frame_i] for ch, (_, ch_data) in iter_other_channels + ch: ch_data[frame_i] + for ch, (_, ch_data) in iter_other_channels if ch != channel } - + # Get the z-slice if we have z-stacks z = posData.zSliceSegmentation(filename, frame_i) - + foregr_data = measurements.get_foregr_data(foregr_img, isSegm3D, z) - + df = measurements.add_custom_metrics( - df, rp, channel, foregr_data, - custom_metrics_params[channel], - isSegm3D, lab, foregr_img, + df, + rp, + channel, + foregr_data, + custom_metrics_params[channel], + isSegm3D, + lab, + foregr_img, other_channels_foregr_imgs, z_slice=z, customMetricsCritical=self.customMetricsCritical, ) - + if not calc_for_each_zslice_mapper.get(channel, False): continue - + # Repeat measureemnts for each z-slice pbar_z = tqdm( - total=posData.SizeZ, desc='Computing for z-slices: ', - ncols=100, leave=False, unit='z-slice' + total=posData.SizeZ, + desc="Computing for z-slices: ", + ncols=100, + leave=False, + unit="z-slice", ) for z in range(posData.SizeZ): - foregr_data = measurements.get_foregr_data( - foregr_img, isSegm3D, z - ) - foregr_data = {'zSlice': foregr_data['zSlice']} - + foregr_data = measurements.get_foregr_data(foregr_img, isSegm3D, z) + foregr_data = {"zSlice": foregr_data["zSlice"]} + df = measurements.add_custom_metrics( - df, rp, channel, foregr_data, - custom_metrics_params[channel], - isSegm3D, lab, foregr_img, + df, + rp, + channel, + foregr_data, + custom_metrics_params[channel], + isSegm3D, + lab, + foregr_img, other_channels_foregr_imgs, z_slice=z, text_to_append_to_col=str(z), - customMetricsCritical=self.customMetricsCritical, + customMetricsCritical=self.customMetricsCritical, ) - + return df - + def _calc_metrics_iter_channels( - self, acdc_df, rp, frame_i, lab, posData, autoBkgr_mask, - autoBkgr_mask_proj, dataPrepBkgrROI_mask, manualBackgrRp - ): + self, + acdc_df, + rp, + frame_i, + lab, + posData, + autoBkgr_mask, + autoBkgr_mask_proj, + dataPrepBkgrROI_mask, + manualBackgrRp, + ): all_channels_foregr_data = {} all_channels_foregr_imgs = {} all_channels_z_slices = {} @@ -1705,94 +1503,123 @@ def _calc_metrics_iter_channels( concentration_metrics_params = self.concentration_metrics_params regionprops_to_save = self.regionPropsToSave custom_metrics_params = self.custom_metrics_params - ch_indipend_custom_func_params = ( - self.ch_indipend_custom_func_params - ) + ch_indipend_custom_func_params = self.ch_indipend_custom_func_params images_path = posData.images_path # Iterate channels - iter_channels = zip( - posData.loadedChNames, - posData.fluo_data_dict.items() - ) + iter_channels = zip(posData.loadedChNames, posData.fluo_data_dict.items()) for channel, (filename, channel_data) in iter_channels: if channel in self.chNamesToSkip: - continue + continue foregr_img = channel_data[frame_i] # Get the z-slice if we have z-stacks z = posData.zSliceSegmentation(filename, frame_i) - + # Get the background data bkgr_data = measurements.get_bkgr_data( - foregr_img, posData, filename, frame_i, autoBkgr_mask, z, - autoBkgr_mask_proj, dataPrepBkgrROI_mask, isSegm3D, lab + foregr_img, + posData, + filename, + frame_i, + autoBkgr_mask, + z, + autoBkgr_mask_proj, + dataPrepBkgrROI_mask, + isSegm3D, + lab, ) - + foregr_data = measurements.get_foregr_data(foregr_img, isSegm3D, z) - + all_channels_foregr_data[channel] = foregr_data all_channels_foregr_imgs[channel] = foregr_img all_channels_z_slices[channel] = z # Compute background values acdc_df = measurements.add_bkgr_values( - acdc_df, bkgr_data, bkgr_metrics_params[channel], metrics_func, - manualBackgrRp=manualBackgrRp, foregr_data=foregr_data + acdc_df, + bkgr_data, + bkgr_metrics_params[channel], + metrics_func, + manualBackgrRp=manualBackgrRp, + foregr_data=foregr_data, ) # Iterate objects and compute foreground metrics acdc_df = measurements.add_foregr_standard_metrics( - acdc_df, rp, channel, foregr_data, - foregr_metrics_params[channel], - metrics_func, isSegm3D, - lab, foregr_img, + acdc_df, + rp, + channel, + foregr_data, + foregr_metrics_params[channel], + metrics_func, + isSegm3D, + lab, + foregr_img, manualBackgrRp=manualBackgrRp, - z_slice=z + z_slice=z, ) if not calc_for_each_zslice_mapper.get(channel, False): continue - + # Repeat measureemnts for each z-slice pbar_z = tqdm( - total=posData.SizeZ, desc='Computing for z-slices: ', - ncols=100, leave=False, unit='z-slice' + total=posData.SizeZ, + desc="Computing for z-slices: ", + ncols=100, + leave=False, + unit="z-slice", ) for z in range(posData.SizeZ): # Get the background data bkgr_data = measurements.get_bkgr_data( - foregr_img, posData, filename, frame_i, autoBkgr_mask, z, - autoBkgr_mask_proj, dataPrepBkgrROI_mask, isSegm3D, lab + foregr_img, + posData, + filename, + frame_i, + autoBkgr_mask, + z, + autoBkgr_mask_proj, + dataPrepBkgrROI_mask, + isSegm3D, + lab, ) bkgr_data = { - 'autoBkgr': {'zSlice': bkgr_data['autoBkgr']['zSlice']}, - 'dataPrepBkgr': {'zSlice': bkgr_data['dataPrepBkgr']['zSlice']} + "autoBkgr": {"zSlice": bkgr_data["autoBkgr"]["zSlice"]}, + "dataPrepBkgr": {"zSlice": bkgr_data["dataPrepBkgr"]["zSlice"]}, } - - foregr_data = measurements.get_foregr_data( - foregr_img, isSegm3D, z - ) - foregr_data = {'zSlice': foregr_data['zSlice']} + + foregr_data = measurements.get_foregr_data(foregr_img, isSegm3D, z) + foregr_data = {"zSlice": foregr_data["zSlice"]} # Compute background values acdc_df = measurements.add_bkgr_values( - acdc_df, bkgr_data, bkgr_metrics_params[channel], + acdc_df, + bkgr_data, + bkgr_metrics_params[channel], metrics_func, - manualBackgrRp=manualBackgrRp, + manualBackgrRp=manualBackgrRp, foregr_data=foregr_data, - text_to_append_to_col=str(z) + text_to_append_to_col=str(z), ) # Iterate objects and compute foreground metrics acdc_df = measurements.add_foregr_standard_metrics( - acdc_df, rp, channel, foregr_data, - foregr_metrics_params[channel], - metrics_func, isSegm3D, - lab, foregr_img, + acdc_df, + rp, + channel, + foregr_data, + foregr_metrics_params[channel], + metrics_func, + isSegm3D, + lab, + foregr_img, manualBackgrRp=manualBackgrRp, - z_slice=z, text_to_append_to_col=str(z) + z_slice=z, + text_to_append_to_col=str(z), ) pbar_z.update() pbar_z.close() @@ -1800,77 +1627,81 @@ def _calc_metrics_iter_channels( acdc_df = measurements.add_concentration_metrics( acdc_df, concentration_metrics_params ) - + # Add region properties try: acdc_df, rp_errors = measurements.add_regionprops_metrics( - acdc_df, lab, regionprops_to_save, - logger_func=self.logger.exception + acdc_df, lab, regionprops_to_save, logger_func=self.logger.exception ) if rp_errors: - print('\n') + print("\n") err_message = ( - 'WARNING: Some objects had the following errors:\n' - f'{rp_errors}\n' - 'Region properties with errors were saved as `Not A Number`.' + "WARNING: Some objects had the following errors:\n" + f"{rp_errors}\n" + "Region properties with errors were saved as `Not A Number`." ) self.logger.exception(err_message) - err_txt = 'Morphological properties error' + err_txt = "Morphological properties error" self.regionPropsCritical.emit(err_message, err_txt) except Exception as error: traceback_format = traceback.format_exc() self.regionPropsCritical.emit(traceback_format, str(error)) acdc_df = self._add_custom_metrics( - posData, frame_i, isSegm3D, acdc_df, rp, custom_metrics_params, - lab, calc_for_each_zslice_mapper + posData, + frame_i, + isSegm3D, + acdc_df, + rp, + custom_metrics_params, + lab, + calc_for_each_zslice_mapper, ) - + acdc_df = measurements.add_ch_indipend_custom_metrics( - acdc_df, rp, all_channels_foregr_data, - ch_indipend_custom_func_params, - isSegm3D, lab, all_channels_foregr_imgs, + acdc_df, + rp, + all_channels_foregr_data, + ch_indipend_custom_func_params, + isSegm3D, + lab, + all_channels_foregr_imgs, all_channels_z_slices=all_channels_z_slices, - customMetricsCritical=self.customMetricsCritical, + customMetricsCritical=self.customMetricsCritical, ) - + # Remove 0s columns acdc_df = acdc_df.loc[:, (acdc_df != -2).any(axis=0)] return acdc_df - + def _add_velocity_measurement(self, acdc_df, prev_lab, lab, posData): - if 'velocity_pixel' not in self.sizeMetricsToSave: + if "velocity_pixel" not in self.sizeMetricsToSave: return acdc_df - - if 'velocity_um' not in self.sizeMetricsToSave: - spacing = None + + if "velocity_um" not in self.sizeMetricsToSave: + spacing = None elif self.isSegm3D: - spacing = np.array([ - posData.PhysicalSizeZ, - posData.PhysicalSizeY, - posData.PhysicalSizeX - ]) + spacing = np.array( + [posData.PhysicalSizeZ, posData.PhysicalSizeY, posData.PhysicalSizeX] + ) else: - spacing = np.array([ - posData.PhysicalSizeY, - posData.PhysicalSizeX - ]) + spacing = np.array([posData.PhysicalSizeY, posData.PhysicalSizeX]) velocities_pxl, velocities_um = core.compute_twoframes_velocity( prev_lab, lab, spacing=spacing ) - acdc_df['velocity_pixel'] = velocities_pxl - acdc_df['velocity_um'] = velocities_um + acdc_df["velocity_pixel"] = velocities_pxl + acdc_df["velocity_um"] = velocities_um return acdc_df - + def _add_combined_metrics(self, posData, df, saveDataWorker=None): - # Add channel specifc combined metrics (from equations and + # Add channel specifc combined metrics (from equations and # from user_path_equations sections) config = posData.combineMetricsConfig for chName in posData.loadedChNames: metricsToSkipChannel = self.metricsToSkip.get(chName, []) - posDataEquations = config['equations'] - userPathChEquations = config['user_path_equations'] + posDataEquations = config["equations"] + userPathChEquations = config["user_path_equations"] for newColName, equation in posDataEquations.items(): if not newColName.startswith(chName): continue @@ -1889,16 +1720,16 @@ def _add_combined_metrics(self, posData, df, saveDataWorker=None): ) # Add mixed channels combined metrics - mixedChannelsEquations = config['mixed_channels_equations'] + mixedChannelsEquations = config["mixed_channels_equations"] for newColName, equation in mixedChannelsEquations.items(): if newColName in self.mixedChCombineMetricsToSkip: continue - cols = re.findall(r'[A-Za-z0-9]+_[A-Za-z0-9_]+', equation) + cols = re.findall(r"[A-Za-z0-9]+_[A-Za-z0-9_]+", equation) if all([col in df.columns for col in cols]): self._df_eval_equation( df, newColName, equation, saveDataWorker=saveDataWorker ) - + def _df_eval_equation(self, df, newColName, expr, saveDataWorker=None): try: df[newColName] = df.eval(expr) @@ -1907,7 +1738,7 @@ def _df_eval_equation(self, df, newColName, expr, saveDataWorker=None): saveDataWorker.sigCombinedMetricsMissingColumn.emit( str(error), newColName ) - + try: df[newColName] = df.eval(expr) except Exception as error: @@ -1915,77 +1746,71 @@ def _df_eval_equation(self, df, newColName, expr, saveDataWorker=None): saveDataWorker.customMetricsCritical.emit( traceback.format_exc(), newColName ) - + def _add_additional_metadata( - self, posData: load.loadData, df: pd.DataFrame, saved_segm_data - ): + self, posData: load.loadData, df: pd.DataFrame, saved_segm_data + ): for col, val in posData.additionalMetadataValues().items(): if col in df.columns: df.pop(col) df.insert(0, col, val) - + try: - df.pop('time_minutes') + df.pop("time_minutes") except Exception as e: pass try: - df.pop('time_hours') + df.pop("time_hours") except Exception as e: pass try: - time_seconds = df.index.get_level_values('time_seconds') - df.insert(0, 'time_minutes', time_seconds/60) - df.insert(1, 'time_hours', time_seconds/3600) + time_seconds = df.index.get_level_values("time_seconds") + df.insert(0, "time_minutes", time_seconds / 60) + df.insert(1, "time_hours", time_seconds / 3600) except Exception as e: pass - + df = self._add_disappears_before_end(df, saved_segm_data) return df - - def _add_disappears_before_end( - self, acdc_df: pd.DataFrame, saved_segm_data - ): - acdc_df = acdc_df.drop('time_seconds', axis=1, errors='ignore') - acdc_df = ( - acdc_df.reset_index() - .set_index(['frame_i', 'Cell_ID']) - .sort_index() - ) - acdc_df['disappears_before_end'] = 0 + + def _add_disappears_before_end(self, acdc_df: pd.DataFrame, saved_segm_data): + acdc_df = acdc_df.drop("time_seconds", axis=1, errors="ignore") + acdc_df = acdc_df.reset_index().set_index(["frame_i", "Cell_ID"]).sort_index() + acdc_df["disappears_before_end"] = 0 for frame_i, lab in enumerate(saved_segm_data): if frame_i == 0: continue - + try: df_frame = acdc_df.loc[frame_i] except KeyError: break - - prev_lab = saved_segm_data[frame_i-1] + + prev_lab = saved_segm_data[frame_i - 1] prev_rp = skimage.measure.regionprops(prev_lab) - + curr_rp = skimage.measure.regionprops(lab) curr_rp_mapper = {obj.label: obj for obj in curr_rp} lost_IDs = [] for prev_obj in prev_rp: if curr_rp_mapper.get(prev_obj.label) is None: lost_IDs.append(prev_obj.label) - - if 'parent_ID_tree' in df_frame.columns: - parent_IDs = set(df_frame['parent_ID_tree'].values) + + if "parent_ID_tree" in df_frame.columns: + parent_IDs = set(df_frame["parent_ID_tree"].values) lost_IDs = [ID for ID in lost_IDs if ID not in parent_IDs] - + if not lost_IDs: continue - - idx = pd.IndexSlice[frame_i-1, lost_IDs] + + idx = pd.IndexSlice[frame_i - 1, lost_IDs] try: - acdc_df.loc[idx, 'disappears_before_end'] = 1 + acdc_df.loc[idx, "disappears_before_end"] = 1 except Exception as err: printl(frame_i, lost_IDs) - + return acdc_df - + def _add_derived_cell_cycle_columns(self, all_frames_acdc_df): try: all_frames_acdc_df = cca_functions.add_derived_cell_cycle_columns( @@ -1993,5 +1818,5 @@ def _add_derived_cell_cycle_columns(self, all_frames_acdc_df): ) except Exception as err: self.sigLog.emit(traceback.format_exc()) - + return all_frames_acdc_df diff --git a/cellacdc/colors.py b/cellacdc/colors.py index 0c71ef0e7..69aabf2ad 100644 --- a/cellacdc/colors.py +++ b/cellacdc/colors.py @@ -19,27 +19,30 @@ try: import networkx as nx + NETWORKX_INSTALLED = True except: NETWORKX_INSTALLED = False -__all__ = ['ColorMap'] +__all__ = ["ColorMap"] FLUO_CHANNELS_COLORS = { - 'mCardinal': (255, 0, 255), - 'mNeonGreen': (0, 255, 0), - 'NeonGreen': (0, 255, 0), - 'mNG': (0, 255, 0), - 'mScarlet': (255, 0, 255), - 'mScarlet-I3': (255, 0, 255), - 'mKate': (255, 0, 255), - 'mKate2': (255, 0, 255), - 'GFP': (0, 255, 0), - 'EGFP': (0, 255, 0), - 'mCitrine': (255, 255, 0) + "mCardinal": (255, 0, 255), + "mNeonGreen": (0, 255, 0), + "NeonGreen": (0, 255, 0), + "mNG": (0, 255, 0), + "mScarlet": (255, 0, 255), + "mScarlet-I3": (255, 0, 255), + "mKate": (255, 0, 255), + "mKate2": (255, 0, 255), + "GFP": (0, 255, 0), + "EGFP": (0, 255, 0), + "mCitrine": (255, 255, 0), } _mapCache = {} + + def getFromMatplotlib(name): """ Added to pyqtgraph 0.12 copied/pasted here to allow pyqtgraph <0.12. Link: @@ -55,45 +58,49 @@ def getFromMatplotlib(name): return None cm = None col_map = plt.get_cmap(name) - if hasattr(col_map, '_segmentdata'): # handle LinearSegmentedColormap + if hasattr(col_map, "_segmentdata"): # handle LinearSegmentedColormap data = col_map._segmentdata - if ('red' in data) and isinstance(data['red'], (Sequence, np.ndarray)): - positions = set() # super-set of handle positions in individual channels - for key in ['red','green','blue']: + if ("red" in data) and isinstance(data["red"], (Sequence, np.ndarray)): + positions = set() # super-set of handle positions in individual channels + for key in ["red", "green", "blue"]: for tup in data[key]: positions.add(tup[0]) - col_data = np.zeros((len(positions),4 )) - col_data[:,-1] = sorted(positions) - for idx, key in enumerate(['red','green','blue']): - positions = np.zeros( len(data[key] ) ) - comp_vals = np.zeros( len(data[key] ) ) - for idx2, tup in enumerate( data[key] ): + col_data = np.zeros((len(positions), 4)) + col_data[:, -1] = sorted(positions) + for idx, key in enumerate(["red", "green", "blue"]): + positions = np.zeros(len(data[key])) + comp_vals = np.zeros(len(data[key])) + for idx2, tup in enumerate(data[key]): positions[idx2] = tup[0] - comp_vals[idx2] = tup[1] # these are sorted in the raw data - col_data[:,idx] = np.interp(col_data[:,3], positions, comp_vals) - cm = ColorMap(pos=col_data[:,-1], color=255*col_data[:,:3]+0.5) + comp_vals[idx2] = tup[1] # these are sorted in the raw data + col_data[:, idx] = np.interp(col_data[:, 3], positions, comp_vals) + cm = ColorMap(pos=col_data[:, -1], color=255 * col_data[:, :3] + 0.5) # some color maps (gnuplot in particular) are defined by RGB component functions: - elif ('red' in data) and isinstance(data['red'], Callable): + elif ("red" in data) and isinstance(data["red"], Callable): col_data = np.zeros((64, 4)) - col_data[:,-1] = np.linspace(0., 1., 64) - for idx, key in enumerate(['red','green','blue']): - col_data[:,idx] = np.clip( data[key](col_data[:,-1]), 0, 1) - cm = ColorMap(pos=col_data[:,-1], color=255*col_data[:,:3]+0.5) - elif hasattr(col_map, 'colors'): # handle ListedColormap + col_data[:, -1] = np.linspace(0.0, 1.0, 64) + for idx, key in enumerate(["red", "green", "blue"]): + col_data[:, idx] = np.clip(data[key](col_data[:, -1]), 0, 1) + cm = ColorMap(pos=col_data[:, -1], color=255 * col_data[:, :3] + 0.5) + elif hasattr(col_map, "colors"): # handle ListedColormap col_data = np.array(col_map.colors) - cm = ColorMap(pos=np.linspace(0.0, 1.0, col_data.shape[0]), - color=255*col_data[:,:3]+0.5 ) + cm = ColorMap( + pos=np.linspace(0.0, 1.0, col_data.shape[0]), + color=255 * col_data[:, :3] + 0.5, + ) if cm is not None: cm.name = name _mapCache[name] = cm return cm + def get_pg_gradient(colors): - ticks_pos = np.linspace(0,1,len(colors)) + ticks_pos = np.linspace(0, 1, len(colors)) ticks = [(tick_pos, color) for tick_pos, color in zip(ticks_pos, colors)] - gradient = {'ticks': ticks, 'mode': 'rgb'} + gradient = {"ticks": ticks, "mode": "rgb"} return gradient + def lighten_color(color, amount=0.3, hex=True): """ Lightens the given color by multiplying (1-luminosity) by the given amount. @@ -109,31 +116,33 @@ def lighten_color(color, amount=0.3, hex=True): c = matplotlib.colors.cnames[color] except: c = color - + c = colorsys.rgb_to_hls(*matplotlib.colors.to_rgb(c)) lightened_c = colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2]) if hex: - lightened_c = tuple([int(round(v*255)) for v in lightened_c]) - lightened_c = '#%02x%02x%02x' % lightened_c + lightened_c = tuple([int(round(v * 255)) for v in lightened_c]) + lightened_c = "#%02x%02x%02x" % lightened_c return lightened_c -def rgb_str_to_values(rgbString, errorRgb=(255,255,255)): + +def rgb_str_to_values(rgbString, errorRgb=(255, 255, 255)): try: - r, g, b = re.findall(r'(\d+), (\d+), (\d+)', rgbString)[0] + r, g, b = re.findall(r"(\d+), (\d+), (\d+)", rgbString)[0] r, g, b = int(r), int(g), int(b) except TypeError: try: r, g, b = rgbString except Exception as e: - print('======================') + print("======================") traceback.print_exc() - print('======================') + print("======================") r, g, b = errorRgb return r, g, b -def rgba_str_to_values(rgbaString, errorRgb=(255,255,255,255)): + +def rgba_str_to_values(rgbaString, errorRgb=(255, 255, 255, 255)): try: - m = re.findall(r'(\d+), *(\d+), *(\d+),* *(\d+)*', rgbaString) + m = re.findall(r"(\d+), *(\d+), *(\d+),* *(\d+)*", rgbaString) r, g, b, a = m[0] if a: r, g, b, a = int(r), int(g), int(b), int(a) @@ -147,13 +156,15 @@ def rgba_str_to_values(rgbaString, errorRgb=(255,255,255,255)): r, g, b, a = errorRgb return r, g, b, a -def get_lut_from_colors(colors, name='mycmap', N=256, to_uint8=False): + +def get_lut_from_colors(colors, name="mycmap", N=256, to_uint8=False): cmap = LinearSegmentedColormap.from_list(name, colors, N=256) - lut = np.array([cmap(i)[:3] for i in np.linspace(0,1,256)]) + lut = np.array([cmap(i)[:3] for i in np.linspace(0, 1, 256)]) if to_uint8: - lut = (lut*255).astype(np.uint8) + lut = (lut * 255).astype(np.uint8) return lut + def plt_colormap_to_pg_lut(name: str, ncolors=256): cmap = plt.get_cmap(name) colors = [cmap(i) for i in np.linspace(0, 1, ncolors)] @@ -161,6 +172,7 @@ def plt_colormap_to_pg_lut(name: str, ncolors=256): lut = np.round(lut_float * 255).astype(np.uint8) return lut + def invertRGB(rgb_img, max_val=1.0): # see https://forum.image.sc/t/invert-rgb-image-without-changing-colors/33571 R = rgb_img[:, :, 0] @@ -169,42 +181,43 @@ def invertRGB(rgb_img, max_val=1.0): GB_mean = np.mean([G, B], axis=0) RB_mean = np.mean([R, B], axis=0) RG_mean = np.mean([R, G], axis=0) - rgb_img[:, :, 0] = max_val-GB_mean - rgb_img[:, :, 1] = max_val-RB_mean - rgb_img[:, :, 2] = max_val-RG_mean + rgb_img[:, :, 0] = max_val - GB_mean + rgb_img[:, :, 1] = max_val - RB_mean + rgb_img[:, :, 2] = max_val - RG_mean return rgb_img + def rescale_RGB(rgb_img, saturation_val=1.0): - rescaled_rgb = rgb_img-rgb_img.min() + rescaled_rgb = rgb_img - rgb_img.min() max_val = rescaled_rgb.max() - brightness = saturation_val/max_val - return rescaled_rgb*brightness - + brightness = saturation_val / max_val + return rescaled_rgb * brightness + -def get_greedy_lut(lab, lut, ids=None): +def get_greedy_lut(lab, lut, ids=None): if ids is None: ids = [obj.label for obj in skimage.measure.regionprops(lab)] - + if len(ids) == 0: return lut - + if len(ids) == 1: greedy_lut = np.copy(lut) greedy_lut[:] = greedy_lut[-1] - greedy_lut[0] = [0]*lut.shape[-1] + greedy_lut[0] = [0] * lut.shape[-1] return greedy_lut - + max_ID = max(ids, default=0) if max_ID + 1 > len(lut): # Repeat lut entries if not enough colors - lut = np.concatenate([lut]*((max_ID // len(lut))+1), axis=0) - + lut = np.concatenate([lut] * ((max_ID // len(lut)) + 1), axis=0) + if lab.ndim == 3: lab = lab.max(axis=0) - + expanded = skimage.segmentation.expand_labels(lab, distance=7) - adj_M = np.zeros([expanded.max() + 1]*2, dtype=bool) - + adj_M = np.zeros([expanded.max() + 1] * 2, dtype=bool) + # Taken from https://stackoverflow.com/questions/26486898/matrix-of-labels-to-adjacency-matrix adj_M[expanded[:, :-1], expanded[:, 1:]] = 1 adj_M[expanded[:, 1:], expanded[:, :-1]] = 1 @@ -216,15 +229,14 @@ def get_greedy_lut(lab, lut, ids=None): # adj_M = adj_M[1:, 1:] graph = nx.from_numpy_array(adj_M) - color_ids = nx.coloring.greedy_color( - graph, strategy='connected_sequential' - ) - - n_foregr_colors = len(lut)-1 + color_ids = nx.coloring.greedy_color(graph, strategy="connected_sequential") + + n_foregr_colors = len(lut) - 1 n_colors_greedy = max([color_id for color_id in color_ids.values()]) color_idxs = { - id:abs(int(n_foregr_colors * c/n_colors_greedy)-n_foregr_colors) - for id, c in color_ids.items() if id!=0 + id: abs(int(n_foregr_colors * c / n_colors_greedy) - n_foregr_colors) + for id, c in color_ids.items() + if id != 0 } greedy_lut = np.copy(lut) @@ -232,115 +244,129 @@ def get_greedy_lut(lab, lut, ids=None): return greedy_lut + def rgb_uint_to_html_hex(rgb): r, g, b = rgb - hex_color = f'#{r:02x}{g:02x}{b:02x}' + hex_color = f"#{r:02x}{g:02x}{b:02x}" return hex_color + def hex_to_rgb(hex): - if hex.startswith('#'): + if hex.startswith("#"): hex = hex[1:] - - return tuple(int(hex[i:i+2], 16) for i in (0, 2, 4)) + + return tuple(int(hex[i : i + 2], 16) for i in (0, 2, 4)) + def hierarchical_weights(alphas): alphas = np.array([1.0, *alphas]) if len(alphas) == 0: return alphas - + weights = [] - for i, a_ref in enumerate(alphas): - weight = np.prod(1-alphas[i+1:]) * a_ref + for i, a_ref in enumerate(alphas): + weight = np.prod(1 - alphas[i + 1 :]) * a_ref weights.append(weight) - + return weights[::-1] + def hierarchical_blend(images, weights): if len(images) == 1: return images[0] - + # Stack all images and do weighted sum stacked = np.stack(images, axis=0) # shape: (N, H, W) return np.tensordot(weights, stacked, axes=(0, 0)) + def merge_two_grayscale_imgs( - img1, img2, rgb1, rgb2, alpha=0.5, - brightness1=1.0, brightness2=1.0, dtype=np.uint8, - inverted=False - ): + img1, + img2, + rgb1, + rgb2, + alpha=0.5, + brightness1=1.0, + brightness2=1.0, + dtype=np.uint8, + inverted=False, +): if img1.max() > 1.0: img1 = skimage.exposure.rescale_intensity(img1, out_range=(0, 1.0)) - + if img2.max() > 1.0: img2 = skimage.exposure.rescale_intensity(img2, out_range=(0, 1.0)) - - img1_bright = np.clip(img1*brightness1, 0, 1.0) - img2_bright = np.clip(img2*brightness1, 0, 1.0) - - img1_rgb = (skimage.color.gray2rgb(img1_bright)*rgb1).astype(dtype) - img2_rgb = (skimage.color.gray2rgb(img2_bright)*rgb2).astype(dtype) - - merge = (alpha*img1_rgb + (1-alpha)*img2_rgb).astype(dtype) - + + img1_bright = np.clip(img1 * brightness1, 0, 1.0) + img2_bright = np.clip(img2 * brightness1, 0, 1.0) + + img1_rgb = (skimage.color.gray2rgb(img1_bright) * rgb1).astype(dtype) + img2_rgb = (skimage.color.gray2rgb(img2_bright) * rgb2).astype(dtype) + + merge = (alpha * img1_rgb + (1 - alpha) * img2_rgb).astype(dtype) + if inverted: merge_inverted = merge.copy() - merge_inverted[..., 0] = 255-((merge[..., 1]+merge[..., 2])/2) - merge_inverted[..., 1] = 255-((merge[..., 0]+merge[..., 2])/2) - merge_inverted[..., 2] = 255-((merge[..., 1]+merge[..., 0])/2) + merge_inverted[..., 0] = 255 - ((merge[..., 1] + merge[..., 2]) / 2) + merge_inverted[..., 1] = 255 - ((merge[..., 0] + merge[..., 2]) / 2) + merge_inverted[..., 2] = 255 - ((merge[..., 1] + merge[..., 0]) / 2) merge = merge_inverted return merge + def pg_ticks_to_colormap(ticks): positions = [] colors = [] for pos, color in ticks: positions.append(pos) colors.append(color) - + cmap = ColorMap(positions, colors) return cmap -def color_palette(name='Okabe_lto', **sns_color_palette_kwargs): + +def color_palette(name="Okabe_lto", **sns_color_palette_kwargs): """Create seaborn color palette (default or custom ones). Parameters ---------- name : str, optional Name of the color palette. Default 'Okabe_lto'. - + References ---------- https://thenode.biologists.com/data-visualization-with-flying-colors/research/ https://www.nature.com/articles/nmeth.1618 """ - if name == 'Okabe_lto': + if name == "Okabe_lto": colors = ( - '#0072B2', - '#F0E442', - '#009E73', - '#56B4E9', - '#E69F00', - '#000000', - '#CC79A7', - '#D55E00', + "#0072B2", + "#F0E442", + "#009E73", + "#56B4E9", + "#E69F00", + "#000000", + "#CC79A7", + "#D55E00", ) return sns.color_palette(colors) - elif name == 'Wong': + elif name == "Wong": colors = ( - (0, 0, 0), - (230, 159, 0), - (86, 180, 233), - (0, 158, 115), - (240, 228, 66), - (0, 114, 178), - (213, 94, 0), - (204, 121, 167), + (0, 0, 0), + (230, 159, 0), + (86, 180, 233), + (0, 158, 115), + (240, 228, 66), + (0, 114, 178), + (213, 94, 0), + (204, 121, 167), ) return sns.color_palette(colors) - + return sns.color_palette(**sns_color_palette_kwargs) + def grayscale_apply_lut(image, lut): """ Map a grayscale image to RGBA using a lookup table. @@ -359,11 +385,12 @@ def grayscale_apply_lut(image, lut): """ # Normalize image to [0, N-1] N = lut.shape[0] - img = np.clip(image, 0, 1) if image.dtype.kind == 'f' else image / 255.0 + img = np.clip(image, 0, 1) if image.dtype.kind == "f" else image / 255.0 indices = np.clip((img * (N - 1)).astype(int), 0, N - 1) rgba = lut[indices] return rgba + def get_complementary_color(rgba_str: str) -> str: r, g, b, a = rgba_str_to_values(rgba_str) - return f'rgba({255 - r}, {255 - g}, {255 - b}, {a})' \ No newline at end of file + return f"rgba({255 - r}, {255 - g}, {255 - b}, {a})" diff --git a/cellacdc/components/__init__.py b/cellacdc/components/__init__.py new file mode 100644 index 000000000..ef991cf0b --- /dev/null +++ b/cellacdc/components/__init__.py @@ -0,0 +1 @@ +"""Reusable GUI components extracted from widgets.py and apps.py.""" diff --git a/cellacdc/components/base.py b/cellacdc/components/base.py new file mode 100644 index 000000000..06e13c336 --- /dev/null +++ b/cellacdc/components/base.py @@ -0,0 +1,65 @@ +from qtpy.QtCore import QEventLoop, Qt +from qtpy.QtWidgets import QDialog, QMainWindow + +from .. import printl + + +class QBaseDialog(QDialog): + def __init__(self, parent=None): + super().__init__(parent) + + def exec_(self, resizeWidthFactor=None): + if resizeWidthFactor is not None: + self.show() + self.resize(int(self.width() * resizeWidthFactor), self.height()) + self.show(block=True) + + def show(self, block=False): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show() + + try: + self.setEnabled(True) + except Exception as err: + pass + + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + def keyPressEvent(self, event) -> None: + if event.key() == Qt.Key_Escape: + event.ignore() + return + + super().keyPressEvent(event) + + +class QBaseWindow(QMainWindow): + def __init__(self, parent=None): + super().__init__(parent) + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show() + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + def keyPressEvent(self, event) -> None: + if event.key() == Qt.Key_Escape: + event.ignore() + return + + super().keyPressEvent(event) diff --git a/cellacdc/components/buttons.py b/cellacdc/components/buttons.py new file mode 100644 index 000000000..bda76b40d --- /dev/null +++ b/cellacdc/components/buttons.py @@ -0,0 +1,691 @@ +import os +from functools import partial + +from qtpy.QtCore import ( + QEvent, + QTimer, + Qt, + QUrl, + QSize, +) +from qtpy.QtGui import ( + QBrush, + QIcon, + QLinearGradient, + QPainter, + QPixmap, +) +from qtpy.QtWidgets import ( + QApplication, + QFileDialog, + QGridLayout, + QHBoxLayout, + QLabel, + QPushButton, + QWidget, + QWidgetAction, +) + +from .. import utils + +class PushButton(QPushButton): + def __init__( + self, *args, icon=None, alignIconLeft=False, flat=False, hoverable=False + ): + super().__init__(*args) + if icon is not None: + self.setIcon(icon) + self.alignIconLeft = alignIconLeft + self._text = None + if flat: + self.setFlat(True) + if hoverable: + self.installEventFilter(self) + + def setRetainSizeWhenHidden(self, retainSize): + sp = self.sizePolicy() + sp.setRetainSizeWhenHidden(retainSize) + self.setSizePolicy(sp) + + def eventFilter(self, object, event): + if event.type() == QEvent.Type.HoverEnter: + self.setFlat(False) + elif event.type() == QEvent.Type.HoverLeave: + self.setFlat(True) + return False + + def show(self): + text = self.text() + if not self.alignIconLeft: + super().show() + return + + self._text = text + self.setStyleSheet("text-align:left;") + self.setLayout(QGridLayout()) + textLabel = QLabel(self._text) + textLabel.setAlignment(Qt.AlignRight | Qt.AlignVCenter) + textLabel.setAttribute(Qt.WA_TransparentForMouseEvents, True) + self._layout().addWidget(textLabel) + super().show() + + def confirmAction(self): + self.baseIcon = self.icon() + self.setIcon(QIcon(":greenTick.svg")) + QTimer.singleShot(2000, self.resetButton) + + def resetButton(self): + self.setIcon(self.baseIcon) + + def setText(self, text): + if self._text is None: + super().setText(text) + else: + super().setText(self._text) + + +class LoadPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":fork_lift.svg")) + + +class mergePushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":merge-IDs.svg")) + + +class okPushButton(PushButton): + def __init__(self, *args, isDefault=True, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":yesGray.svg")) + if isDefault: + self.setDefault(True) + # QShortcut(Qt.Key_Return, self, self.click) + # QShortcut(Qt.Key_Enter, self, self.click) + + +class MagnifyingGlassPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":magnGlass.svg")) + + +class MagnifyingGlassAllPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":magnGlass_all.svg")) + + +class AssignNewIDButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":assign_new_id.svg")) + + +class LockPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":lock.svg")) + self.toggled.connect(self.onToggled) + + def onToggled(self, checked): + if not self.isCheckable(): + return + + if checked: + self.setIcon(QIcon(":lock_closed.svg")) + else: + self.setIcon(QIcon(":lock_open.svg")) + + def setCheckable(self, checkable: bool): + super().setCheckable(checkable) + if checkable: + self.setIcon(QIcon(":lock_open.svg")) + else: + self.setIcon(QIcon(":lock.svg")) + + +class SkipPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":skip_arrow.svg")) + + +class BedPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":bed.svg")) + + +class BedPlusLabelPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":bed_plus_label.svg")) + iconH = self.iconSize().height() + iconW = int(iconH * 2.5) + self.setIconSize(QSize(iconW, iconH)) + + +class NoBedPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":no_bed.svg")) + + +class NavigatePushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":navigate.svg")) + + +class SwitchPlaneButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":switch_2d_plane.svg")) + self._planes = ("xy", "zy", "zx") + self._idx = 0 + + def switchPlane(self): + self._idx += 1 + + def setPlane(self, plane): + self._idx = self._planes.index(plane) + + def plane(self): + return self._planes[self._idx % 3] + + def depthAxes(self): + plane = self.plane() + for axes in "xyz": + if axes not in plane: + return axes + + +class zoomPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":zoom_out.svg")) + + def setIconZoomOut(self): + self.setIcon(QIcon(":zoom_out.svg")) + + def setIconZoomIn(self): + self.setIcon(QIcon(":zoom_in.svg")) + + +class WarningButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":warning.svg")) + + +class reloadPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":reload.svg")) + + +class savePushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":file-save.svg")) + + +class autoPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":cog_play.svg")) + + +class newFilePushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":file-new.svg")) + + +class helpPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":help.svg")) + + +class viewPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":eye.svg")) + + +class infoPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":info.svg")) + + +class threeDPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":3d.svg")) + + +class twoDPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":2d.svg")) + + +class addPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":add.svg")) + + +class futurePushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":arrow_future.svg")) + + +class FutureAllPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":arrow_future_all.svg")) + + +class currentPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":arrow_current.svg")) + + +class arrowUpPushButton(PushButton): + def __init__(self, *args, **kwargs): + alignIconLeft = kwargs.get("alignIconLeft", False) + super().__init__( + *args, icon=QIcon(":arrow-up.svg"), alignIconLeft=alignIconLeft + ) + + +class arrowDownPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":arrow-down.svg")) + + +class selectAllPushButton(PushButton): + sigClicked = Signal(object, bool) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._status = "deselect" + self.setIcon(QIcon(":deselect_all.svg")) + self.setText("Deselect all") + self.clicked.connect(self.onClicked) + self.setMinimumWidth(self.sizeHint().width()) + + def setChecked(self, checked): + if checked: + self._status == "deselect" + else: + self._status == "select" + self.click() + + def onClicked(self): + if self._status == "select": + icon_fn = ":deselect_all.svg" + self._status = "deselect" + checked = True + text = "Deselect all" + else: + icon_fn = ":select_all.svg" + text = "Select all" + self._status = "select" + checked = False + self.setIcon(QIcon(icon_fn)) + self.setText(text) + self.sigClicked.emit(self, checked) + + +class subtractPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":subtract.svg")) + + +class continuePushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":continue.svg")) + + +class calcPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":calc.svg")) + + +class playPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":play.svg")) + + +class stopPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":stop.svg")) + + +class copyPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":edit-copy.svg")) + self.clicked.connect(self.onClicked) + self._text_to_copy = None + + def setTextToCopy(self, text): + self._text_to_copy = text + + def onClicked(self): + self._original_text = self.text() + if self._text_to_copy is not None: + cb = QApplication.clipboard() + cb.clear(mode=cb.Clipboard) + cb.setText(self._text_to_copy, mode=cb.Clipboard) + + super().setText("Copied!") + self.setIcon(QIcon(":greenTick.svg")) + QTimer.singleShot(2000, self.resetButton) + + def resetButton(self): + self.setText(self._original_text) + self.setIcon(QIcon(":edit-copy.svg")) + + +class OpenFilePushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":folder-open.svg")) + + +class movePushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":folder-move.svg")) + + +class DownloadPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":download.svg")) + + +class showInFileManagerButton(PushButton): + def __init__(self, *args, setDefaultText=False, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":drawer.svg")) + self._path_to_browse = None + if setDefaultText: + self.setDefaultText() + + def setDefaultText(self): + self._text = utils.get_show_in_file_manager_text() + self.setText(self._text) + + def setPathToBrowse(self, path: os.PathLike): + self._path_to_browse = path + self.clicked.connect(partial(utils.showInExplorer, path)) + + +class OpenUrlButton(PushButton): + def __init__(self, url, *args, **kwargs): + self._url = url + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":browser.svg")) + self.clicked.connect(self.openUrl) + + def openUrl(self): + QDesktopServices.openUrl(QUrl(self._url)) + + +class LessThanPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":less_than.svg")) + flat = kwargs.get("flat") + if flat is not None: + self.setFlat(True) + + +class showDetailsButton(PushButton): + sigToggled = Signal(bool) + + def __init__(self, *args, txt="Show details...", parent=None): + super().__init__(txt, parent) + # self.setText(txt) + self.txt = txt + self.checkedIcon = QIcon(":hideUp.svg") + self.uncheckedIcon = QIcon(":showDown.svg") + self.setIcon(self.uncheckedIcon) + self.toggled.connect(self.onClicked) + self.setCheckable(True) + w = self.sizeHint().width() + 10 + self.setFixedWidth(w) + + def onClicked(self, checked): + if checked: + self.setText(self.txt.replace("Show", "Hide")) + self.setIcon(self.checkedIcon) + else: + self.setText(self.txt) + self.setIcon(self.uncheckedIcon) + + self.sigToggled.emit(checked) + + +class cancelPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":cancelButton.svg")) + + +class setPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":cog.svg")) + + +class TrainPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":train.svg")) + + +class noPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":no.svg")) + + +class editPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":edit-id.svg")) + + +class delPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":bin.svg")) + + +class eraserPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":eraser.svg")) + + +class CrossCursorPointButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":cross_cursor.svg")) + + +class TestPushButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":test.svg")) + + +class browseFileButton(PushButton): + sigPathSelected = Signal(str) + + def __init__( + self, + *args, + ext=None, + title="Select file", + start_dir="", + openFolder=False, + **kwargs, + ): + """PushButton with sigPathSelected Signal to select file or folder + + Parameters + ---------- + ext : dict or None, optional + If not None, this is a dictionary of + {'FILE NAME': ['.ext1', '.ext2', ...]}. + For example, to allow only selection of CSV files, + pass {'CSV': ['.csv']}. + + Note that the 'FILE NAME' is arbitrary. Default is None + title : str, optional + Title of the File Manager window. Default is 'Select file' + start_dir : str, optional + Directory where the File Manager window will initially be open. + Default is '' + openFolder : bool, optional + If True, allows for selection of folders instead of files. + Default is False + """ + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":folder-open.svg")) + self.clicked.connect(self.browse) + + self._title = title + self._start_dir = start_dir + self._openFolder = openFolder + self._file_types = "All Files (*)" + if ext is not None: + s_li = [] + for name, extensions in ext.items(): + _s = "" + if isinstance(extensions, str): + extensions = [extensions] + for ext in extensions: + _s = f"{_s}*{ext} " + s_li.append(f"{name} {_s.strip()}") + + self._file_types = ";;".join(s_li) + self._file_types = f"{self._file_types};;All Files (*)" + + def setStartPath(self, start_path): + self._start_dir = start_path + + def browse(self): + if self._openFolder: + fileDialog = QFileDialog.getExistingDirectory + args = (self, self._title, self._start_dir) + else: + fileDialog = QFileDialog.getOpenFileName + args = (self, self._title, self._start_dir, self._file_types) + file_path = fileDialog(*args) + if not isinstance(file_path, str): + file_path = file_path[0] + if file_path: + self.sigPathSelected.emit(file_path) + + +def getPushButton(buttonText, qparent=None): + isCancelButton = ( + buttonText.lower().find("cancel") != -1 + or buttonText.lower().find("abort") != -1 + ) + isYesButton = ( + buttonText.lower().find("yes") != -1 + or buttonText.lower().find("ok") != -1 + or buttonText.lower().find("continue") != -1 + or buttonText.lower().find("recommended") != -1 + ) + isSettingsButton = buttonText.lower().find("set") != -1 + isNoButton = ( + buttonText.replace(" ", "").lower() == "no" + or buttonText.lower().find("Do not ") != -1 + or buttonText.lower().find("no, ") != -1 + ) + isDelButton = buttonText.lower().find("delete") != -1 + isAddButton = buttonText.lower().find("add ") != -1 + is3Dbutton = buttonText.find(" 3D ") != -1 + is2Dbutton = buttonText.find(" 2D ") != -1 + isSaveButton = buttonText.lower().find("overwrite") != -1 + isNewFileButton = buttonText.lower().find("rename") != -1 + isTryAgainButton = buttonText.lower().find("try again") != -1 + + if isCancelButton: + button = cancelPushButton(buttonText, qparent) + if qparent is not None: + qparent.addCancelButton(button=button) + elif isYesButton: + button = okPushButton(buttonText, qparent) + if qparent is not None: + qparent.okButton = button + elif isSettingsButton: + button = setPushButton(buttonText, qparent) + elif isNoButton: + button = noPushButton(buttonText, qparent) + elif isDelButton: + button = delPushButton(buttonText, qparent) + elif isAddButton: + button = addPushButton(buttonText, qparent) + elif is3Dbutton: + button = threeDPushButton(buttonText, qparent) + elif is2Dbutton: + button = twoDPushButton(buttonText, qparent) + elif isSaveButton: + button = savePushButton(buttonText, qparent) + elif isNewFileButton: + button = newFilePushButton(buttonText, qparent) + elif isTryAgainButton: + button = reloadPushButton(buttonText, qparent) + else: + button = QPushButton(buttonText, qparent) + + return button, isCancelButton + + +def CustomGradientMenuAction(gradient: QLinearGradient, name: str, parent): + pixmap = QPixmap(100, 15) + painter = QPainter(pixmap) + brush = QBrush(gradient) + painter.fillRect(QRect(0, 0, 100, 15), brush) + painter.end() + label = QLabel() + label.setPixmap(pixmap) + label.setContentsMargins(1, 1, 1, 1) + labelName = QLabel(name) + hbox = QHBoxLayout() + delButton = delPushButton() + hbox.addWidget(labelName) + hbox.addStretch(1) + hbox.addWidget(label) + hbox.addWidget(delButton) + widget = QWidget() + widget.setLayout(hbox) + action = QWidgetAction(parent) + action.name = name + action.setDefaultWidget(widget) + action.delButton = delButton + delButton.action = action + return action diff --git a/cellacdc/components/inputs_basic.py b/cellacdc/components/inputs_basic.py new file mode 100644 index 000000000..253ef2d40 --- /dev/null +++ b/cellacdc/components/inputs_basic.py @@ -0,0 +1,167 @@ +import re + +from qtpy.QtCore import ( + QEvent, + Qt, + Signal, +) +from qtpy.QtGui import ( + QFontMetrics, + QKeyEvent, + QRegularExpressionValidator, +) +from qtpy.QtWidgets import ( + QLineEdit, + QScrollBar, +) + +from .palette import LINEEDIT_INVALID_ENTRY_STYLESHEET + +class ElidingLineEdit(QLineEdit): + def __init__(self, parent=None, minWidth=None): + super().__init__(parent) + self._text = "" + self._minWidth = minWidth + if minWidth is not None: + self.setMinimumWidth(minWidth) + + self.textEdited.connect(self.setText) + self.installEventFilter(self) + self._elide = True + + def setText(self, text: str, width=None, elide=True) -> None: + if width is None: + width = self._minWidth + + if width is None: + try: + textToPrevRatio = len(text) / len(self.text()) + width = round(self.width() * textToPrevRatio) + except ZeroDivisionError: + width = self.width() + + if width > self.width(): + width = self.width() + + self._text = text + if not elide or not self._elide: + super().setText(text) + return + + fm = QFontMetrics(self.font()) + elidedText = fm.elidedText(text, Qt.ElideLeft, width) + + super().setText(elidedText) + self.setToolTip(text) + + def text(self): + return self._text + + def resizeEvent(self, event): + newWidth = event.size().width() + self.setText(self._text, width=newWidth) + event.accept() + + def eventFilter(self, a0: "QObject", a1: "QEvent") -> bool: + isFocusIn = a1.type() == QEvent.Type.FocusIn + if isFocusIn and (self.isReadOnly() or not self.isEnabled()): + self.clearFocus() + return True + return super().eventFilter(a0, a1) + + def focusInEvent(self, event): + super().focusInEvent(event) + self._elide = False + self.setText(self._text, elide=False) + self.setCursorPosition(len(self.text())) + + def focusOutEvent(self, event): + self._elide = True + super().focusOutEvent(event) + self.setText(self._text) + + +class ValidLineEdit(QLineEdit): + def __init__(self, parent=None): + super().__init__(parent) + + def setInvalidStyleSheet(self): + self.setStyleSheet(LINEEDIT_INVALID_ENTRY_STYLESHEET) + + def setValidStyleSheet(self): + self.setStyleSheet("") + + +class KeepIDsLineEdit(ValidLineEdit): + sigIDsChanged = Signal(list) + sigSort = Signal() + sigEnterPressed = Signal() + + def __init__(self, instructionsLabel, parent=None): + super().__init__(parent) + + self.validPattern = "^[0-9-, ]+$" + regExpr = QRegularExpression(self.validPattern) + self.setValidator(QRegularExpressionValidator(regExpr)) + + self.textChanged.connect(self.onTextChanged) + self.editingFinished.connect(self.onEditingFinished) + + self.instructionsText = instructionsLabel.text() + self._label = instructionsLabel + + def keyPressEvent(self, event) -> None: + super().keyPressEvent(event) + if event.text() == ",": + self.sigSort.emit() + elif event.key() == Qt.Key.Key_Return or event.key() == Qt.Key.Key_Enter: + self.sigEnterPressed.emit() + + def onTextChanged(self, text): + IDs = [] + rangesMatch = re.findall(r"(\d+-\d+)", text) + if rangesMatch: + for rangeText in rangesMatch: + start, stop = rangeText.split("-") + start, stop = int(start), int(stop) + IDs.extend(range(start, stop + 1)) + text = re.sub(r"(\d+)-(\d+)", "", text) + IDsMatch = re.findall(r"(\d+)", text) + if IDsMatch: + for ID in IDsMatch: + IDs.append(int(ID)) + self.IDs = sorted(list(set(IDs))) + self.sigIDsChanged.emit(self.IDs) + + def onEditingFinished(self): + self.sigSort.emit() + + def warnNotExistingID(self): + self.setInvalidStyleSheet() + self._label.setText( + " Some of the IDs are not existing --> they will be IGNORED" + ) + self._label.setStyleSheet("color: red") + + def setInstructionsText(self): + self.setValidStyleSheet() + self._label.setText(self.instructionsText) + self._label.setStyleSheet("") + + +class ScrollBar(QScrollBar): + def __init__(self, *args): + super().__init__(*args) + self.installEventFilter(self) + self.setContextMenuPolicy(Qt.NoContextMenu) + + def eventFilter(self, object, event) -> bool: + if event.type() == QEvent.Type.Wheel: + return True + elif event.type() == QEvent.Type.MouseButtonPress: + # Filter right-click to prevent context menu + return event.button() == Qt.MouseButton.RightButton + elif event.type() == QEvent.Type.MouseButtonRelease: + # Filter right-click to prevent context menu + return event.button() == Qt.MouseButton.RightButton + return False diff --git a/cellacdc/components/layout.py b/cellacdc/components/layout.py new file mode 100644 index 000000000..341d53e58 --- /dev/null +++ b/cellacdc/components/layout.py @@ -0,0 +1,256 @@ +from qtpy.QtCore import QEvent, Qt, Signal +from qtpy.QtGui import QColor, QPalette +from qtpy.QtWidgets import ( + QCheckBox, + QFrame, + QGridLayout, + QGroupBox, + QHBoxLayout, + QScrollArea, + QSizePolicy, + QVBoxLayout, + QWidget, +) + +import pyqtgraph as pg + +from .buttons import cancelPushButton, okPushButton +from .palette import BASE_COLOR + +class VerticalSpacerEmptyWidget(QWidget): + + def __init__(self, parent=None, height=5) -> None: + super().__init__(parent) + self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum) + self.setFixedHeight(height) +class QHWidgetSpacer(QWidget): + def __init__(self, width=10, parent=None) -> None: + super().__init__(parent) + self.setFixedWidth(width) + + +class QVWidgetSpacer(QWidget): + def __init__(self, height=10, parent=None) -> None: + super().__init__(parent) + self.setFixedHeight(height) + + +class QHLine(QFrame): + def __init__(self, shadow="Sunken", parent=None, color=None): + super().__init__(parent) + self.setFrameShape(QFrame.Shape.HLine) + self.setFrameShadow(getattr(QFrame, shadow)) + if color is not None: + self.setColor(color) + + def setColor(self, color): + qcolor = pg.mkColor(color) + pal = self.palette() + pal.setColor(QPalette.ColorRole.WindowText, qcolor) + self.setPalette(pal) + + +class QVLine(QFrame): + def __init__(self, shadow="Plain", parent=None, color=None): + super().__init__(parent) + self.setFrameShape(QFrame.Shape.VLine) + self.setFrameShadow(getattr(QFrame.Shadow, shadow)) + if color is not None: + self.setColor(color) + + def setColor(self, color): + qcolor = pg.mkColor(color) + pal = self.palette() + pal.setColor(QPalette.ColorRole.WindowText, qcolor) + self.setPalette(pal) + + +class VerticalResizeHline(QFrame): + dragged = Signal(object) + clicked = Signal(object) + released = Signal(object) + + def __init__(self): + super().__init__() + self.setCursor(Qt.SplitVCursor) + self.setFrameShape(QFrame.Shape.HLine) + self.setFrameShadow(QFrame.Shadow.Sunken) + self.installEventFilter(self) + self.isMousePressed = False + self._height = 4 + self.setMinimumHeight(self._height) + + def mousePressEvent(self, event) -> None: + self.isMousePressed = True + self.clicked.emit(event) + return super().mousePressEvent(event) + + def mouseMoveEvent(self, event) -> None: + self.dragged.emit(event) + return super().mouseMoveEvent(event) + + def mouseReleaseEvent(self, event) -> None: + self.isMousePressed = False + self.released.emit(event) + return super().mouseReleaseEvent(event) + + def eventFilter(self, object, event): + if event.type() == QEvent.Type.Enter: + self.setLineWidth(0) + self.setMidLineWidth(self._height) + pal = self.palette() + pal.setColor(QPalette.ColorRole.WindowText, QColor(BASE_COLOR)) + self.setPalette(pal) + # self.setStyleSheet('background-color: #4d4d4d') + elif event.type() == QEvent.Type.Leave: + self.setMidLineWidth(0) + self.setLineWidth(1) + return False + + +class GroupBox(QGroupBox): + def __init__(self, *args, keyPressCallback=None): + super().__init__(*args) + self.keyPressCallback = None + self.setFocusPolicy(Qt.NoFocus) + + def keyPressEvent(self, event) -> None: + event.ignore() + if self.keyPressCallback is None: + return + + self.keyPressCallback() + + +class CheckBox(QCheckBox): + def __init__(self, *args, keyPressCallback=None): + super().__init__(*args) + self.keyPressCallback = None + self.setFocusPolicy(Qt.NoFocus) + + def keyPressEvent(self, event) -> None: + event.ignore() + if self.keyPressCallback is None: + return + + self.keyPressCallback() + + +class CancelOkButtonsLayout(QHBoxLayout): + def __init__(self, *args, additionalButtons=None): + super().__init__(*args) + + self.cancelButton = cancelPushButton("Cancel") + self.okButton = okPushButton(" Ok ") + + self.addStretch(1) + self.addWidget(self.cancelButton) + self.addSpacing(20) + + if additionalButtons is not None: + for button in additionalButtons: + self.addWidget(button) + + self.addWidget(self.okButton) + +class FormLayout(QGridLayout): + def __init__(self): + QGridLayout.__init__(self) + + def addFormWidget( + self, formWidget, leftLabelAlignment=Qt.AlignRight, align=None, row=0 + ): + for col, item in enumerate(formWidget.items): + if col == 0: + alignment = leftLabelAlignment + elif col == 2: + alignment = Qt.AlignLeft + else: + alignment = align + try: + if alignment is None: + self.addWidget(item, row, col) + else: + self.addWidget(item, row, col, alignment=alignment) + except TypeError: + self.addLayout(item, row, col) + + +class ScrollArea(QScrollArea): + sigLeaveEvent = Signal() + + def __init__( + self, parent=None, resizeVerticalOnShow=False, dropArrowKeyEvents=False + ) -> None: + super().__init__(parent) + self.setWidgetResizable(True) + self.setFrameStyle(QFrame.Shape.NoFrame) + self.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded) + self.containerWidget = None + self.resizeVerticalOnShow = resizeVerticalOnShow + self.isOnlyVertical = False + self.dropArrowKeyEvents = dropArrowKeyEvents + + def setVerticalLayout(self, layout, widget=None): + if widget is None: + self.containerWidget = QWidget() + else: + self.containerWidget = widget + self.containerWidget.setLayout(layout) + self.containerWidget.setSizePolicy( + QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred + ) + self.setWidget(self.containerWidget) + self.containerWidget.installEventFilter(self) + self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + self.isOnlyVertical = True + + def setWidget(self, widget): + self.containerWidget = widget + super().setWidget(widget) + + def _resizeHorizontal(self): + self.setMinimumWidth( + self.containerWidget.minimumSizeHint().width() + + self.verticalScrollBar().width() + ) + + def minimumWidthNoScrollbar(self) -> int: + width = ( + self.containerWidget.minimumSizeHint().width() + + self.verticalScrollBar().width() + ) + return width + + def minimumHeightNoScrollbar(self) -> int: + height = ( + self.containerWidget.minimumSizeHint().height() + + self.horizontalScrollBar().height() + ) + return height + + def _resizeVertical(self): + height = ( + self.containerWidget.minimumSizeHint().height() + + self.horizontalScrollBar().height() + ) + self.containerWidget.setSizePolicy( + QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred + ) + + self.setFixedHeight(height) + + def eventFilter(self, object, event: QEvent): + if event.type() == QEvent.Type.Leave: + self.sigLeaveEvent.emit() + + if object != self.containerWidget: + return False + + isResize = event.type() == QEvent.Type.Resize + isShow = event.type() == QEvent.Type.Show + if isResize and self.isOnlyVertical: + self._resizeHorizontal() + elif isShow and self.resizeVerticalOnShow: + self._resizeVertical() + return False diff --git a/cellacdc/components/lists.py b/cellacdc/components/lists.py new file mode 100644 index 000000000..0264d43b5 --- /dev/null +++ b/cellacdc/components/lists.py @@ -0,0 +1,542 @@ +from qtpy.QtCore import ( + QAbstractItemModel, + QAbstractListModel, + QDataStream, + QIODevice, + QItemSelection, + QItemSelectionModel, + QModelIndex, + Qt, + Signal, + QSize, + QByteArray, + QObject, + QMimeData, +) +from qtpy.QtGui import QBrush +from qtpy.QtWidgets import ( + QAbstractItemView, + QComboBox, + QHBoxLayout, + QLayout, + QLabel, + QListView, + QListWidget, + QListWidgetItem, + QTreeWidget, + QTreeWidgetItem, + QTreeWidgetItemIterator, + QTextEdit, + QWidget, +) + +from .. import html_utils +from .palette import LISTWIDGET_STYLESHEET, TREEWIDGET_STYLESHEET, font + +class _ReorderableListModel(QAbstractListModel): + """ + ReorderableListModel is a list model which implements reordering of its + items via drag-n-drop + """ + + dragDropFinished = Signal() + + def __init__(self, items, parent=None): + QAbstractItemModel.__init__(self, parent) + self.nodes = items + self.lastDroppedItems = [] + self.pendingRemoveRowsAfterDrop = False + + def rowForItem(self, text): + """ + rowForItem method returns the row corresponding to the passed in item + or None if no such item exists in the model + """ + try: + row = self.nodes.index(text) + except ValueError: + return None + return row + + def index(self, row, column, parent): + if row < 0 or row >= len(self.nodes): + return QModelIndex() + return self.createIndex(row, column) + + def parent(self, index): + return QModelIndex() + + def rowCount(self, index): + if index.isValid(): + return 0 + return len(self.nodes) + + def data(self, index, role): + if not index.isValid(): + return None + if role == Qt.DisplayRole: + row = index.row() + if row < 0 or row >= len(self.nodes): + return None + return self.nodes[row] + elif role == Qt.SizeHintRole: + return QSize(48, 32) + else: + return None + + def supportedDropActions(self): + return Qt.MoveAction + + def flags(self, index): + if not index.isValid(): + return Qt.ItemIsEnabled + return ( + Qt.ItemIsEnabled + | Qt.ItemIsSelectable + | Qt.ItemIsDragEnabled + | Qt.ItemIsDropEnabled + ) + + def insertRows(self, row, count, index): + if index.isValid(): + return False + if count <= 0: + return False + # inserting 'count' empty rows starting at 'row' + self.beginInsertRows(QModelIndex(), row, row + count - 1) + for i in range(0, count): + self.nodes.insert(row + i, "") + self.endInsertRows() + return True + + def removeRows(self, row, count, index): + if index.isValid(): + return False + if count <= 0: + return False + num_rows = self.rowCount(QModelIndex()) + self.beginRemoveRows(QModelIndex(), row, row + count - 1) + for i in range(count, 0, -1): + self.nodes.pop(row - i + 1) + self.endRemoveRows() + + if self.pendingRemoveRowsAfterDrop: + """ + If we got here, it means this call to removeRows is the automatic + 'cleanup' action after drag-n-drop performed by Qt + """ + self.pendingRemoveRowsAfterDrop = False + self.dragDropFinished.emit() + + return True + + def setData(self, index, value, role): + if not index.isValid(): + return False + if index.row() < 0 or index.row() > len(self.nodes): + return False + self.nodes[index.row()] = str(value) + self.dataChanged.emit(index, index) + return True + + def mimeTypes(self): + return ["application/vnd.treeviewdragdrop.list"] + + def mimeData(self, indexes): + mimedata = QMimeData() + encoded_data = QByteArray() + stream = QDataStream(encoded_data, QIODevice.WriteOnly) + for index in indexes: + if index.isValid(): + text = self.data(index, 0) + stream << QByteArray(text.encode("utf-8")) + mimedata.setData("application/vnd.treeviewdragdrop.list", encoded_data) + return mimedata + + def dropMimeData(self, data, action, row, column, parent): + if action == Qt.IgnoreAction: + return True + if not data.hasFormat("application/vnd.treeviewdragdrop.list"): + return False + if column > 0: + return False + + num_rows = self.rowCount(QModelIndex()) + if num_rows <= 0: + return False + + if row < 0: + if parent.isValid(): + row = parent.row() + else: + return False + + encoded_data = data.data("application/vnd.treeviewdragdrop.list") + stream = QDataStream(encoded_data, QIODevice.ReadOnly) + + new_items = [] + rows = 0 + while not stream.atEnd(): + text = QByteArray() + stream >> text + text = bytes(text).decode("utf-8") + index = self.nodes.index(text) + new_items.append((text, index)) + rows += 1 + + self.lastDroppedItems = [] + for text, index in new_items: + target_row = row + if index < row: + target_row += 1 + self.beginInsertRows(QModelIndex(), target_row, target_row) + self.nodes.insert(target_row, self.nodes[index]) + self.endInsertRows() + self.lastDroppedItems.append(text) + row += 1 + + self.pendingRemoveRowsAfterDrop = True + return True + + +class _SelectionModel(QItemSelectionModel): + def __init__(self, parent=None, isSingleSelection=False): + QItemSelectionModel.__init__(self, parent) + self.isSingleSelection = isSingleSelection + + def onModelItemsReordered(self): + new_selection = QItemSelection() + new_index = QModelIndex() + for item in self.model().lastDroppedItems: + row = self.model().rowForItem(item) + if row is None: + continue + new_index = self.model().index(row, 0, QModelIndex()) + new_selection.select(new_index, new_index) + + self.clearSelection() + flags = ( + QItemSelectionModel.SelectionFlag.ClearAndSelect + | QItemSelectionModel.SelectionFlag.Rows + | QItemSelectionModel.SelectionFlag.Current + ) + self.select(new_selection, flags) + self.setCurrentIndex(new_index, flags) + if not self.isSingleSelection: + self.reset() + + +class ReorderableListView(QListView): + def __init__(self, items=None, parent=None, isSingleSelection=False) -> None: + super().__init__(parent) + if items is None: + items = [] + + self.isSingleSelection = isSingleSelection + self._model = _ReorderableListModel(items) + self._selectionModel = _SelectionModel(self._model) + self._model.dragDropFinished.connect(self._selectionModel.onModelItemsReordered) + self.setModel(self._model) + self.setSelectionModel(self._selectionModel) + self.setDragDropMode(QAbstractItemView.DragDropMode.InternalMove) + self.setDragDropOverwriteMode(False) + styleSheet = f""" + QListView {{ + selection-background-color: rgba(200, 200, 200, 0.30); + selection-color: black; + show-decoration-selected: 1; + }} + QListView::item {{ + border-bottom: 1px solid rgba(180, 180, 180, 0.5); + }} + QListView::item:hover {{ + background-color: rgba(200, 200, 200, 0.30); + }} + """ + self.setStyleSheet(styleSheet) + + def setItems(self, items): + self._model.nodes = items + + def items(self): + return self._model.nodes + + # def mouseReleaseEvent(self, e: QMouseEvent) -> None: + # super().mouseReleaseEvent(e) + # self._selectionModel.reset() + + +class listWidget(QListWidget): + def __init__( + self, *args, isMultipleSelection=False, minimizeHeight=False, **kwargs + ): + super().__init__(*args, **kwargs) + self.itemHeight = None + self.setStyleSheet(LISTWIDGET_STYLESHEET) + self.setFont(font) + if isMultipleSelection: + self.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) + + self.minimizeHeight = minimizeHeight + + def setSelectedAll(self, selected): + for i in range(self.count()): + self.item(i).setSelected(selected) + + def setSelectedItems(self, itemsText): + for i in range(self.count()): + item = self.item(i) + item.setSelected(item.text() in itemsText) + + def addItems(self, labels) -> None: + super().addItems(labels) + if self.itemHeight is not None: + self.setItemHeight() + + if self.minimizeHeight: + itemHeight = self.sizeHintForRow(0) + self.setMaximumHeight(itemHeight * self.count() + itemHeight * 2) + + def addItem(self, text): + super().addItem(text) + if self.itemHeight is None: + return + self.setItemHeight() + + def setItemHeight(self, height=40): + self.itemHeight = height + for i in range(self.count()): + item = self.item(i) + item.setSizeHint(QSize(0, height)) + + def selectedItemsText(self): + return [item.text() for item in self.selectedItems()] + + +class OrderableListWidget(QWidget): + sigEnterEvent = Signal(object) + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._labels = [] + + def setParentItem(self, item): + self._item = item + + def setLabelsColor(self, selected): + if selected: + stylesheet = "color : black" + else: + stylesheet = "" + + for label in self._labels: + label.setStyleSheet(stylesheet) + + def enterEvent(self, event): + super().enterEvent(event) + self.setLabelsColor(True) + self.sigEnterEvent.emit(self._item) + + # def leaveEvent(self, event): + # super().leaveEvent(event) + # self.setLabelsColor(self._item.isSelected()) + # printl('leave', self._item.isSelected()) + + def addLabel(self, label): + self._labels.append(label) + + +class OrderableList(listWidget): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.setMouseTracking(True) + self.itemEntered.connect(self.onItemEntered) + + def onItemEntered(self, enteredItem): + enteredRow = self.row(enteredItem) + for i in range(self.count()): + item = self.item(i) + item._container.setLabelsColor(i == enteredRow or item.isSelected()) + + def leaveEvent(self, event): + super().leaveEvent(event) + for i in range(self.count()): + item = self.item(i) + item._container.setLabelsColor(item.isSelected()) + + def addItems(self, items): + self.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) + nr_items = len(items) + nn = [str(n) for n in range(1, nr_items + 1)] + for i, item in enumerate(items): + itemW = QListWidgetItem() + itemContainer = OrderableListWidget() + itemContainer.setParentItem(itemW) + itemText = QLabel(item) + tableNrLabel = QLabel("| Table nr.") + itemContainer.addLabel(tableNrLabel) + itemContainer.addLabel(itemText) + itemLayout = QHBoxLayout() + itemNumberWidget = QComboBox() + itemNumberWidget.addItems(nn) + itemLayout.addWidget(itemText) + itemLayout.addWidget(tableNrLabel) + itemLayout.addWidget(itemNumberWidget) + itemContainer.setLayout(itemLayout) + itemLayout.setSizeConstraint(QLayout.SizeConstraint.SetFixedSize) + itemW.setSizeHint(itemContainer.sizeHint()) + self.addItem(itemW) + self.setItemWidget(itemW, itemContainer) + itemW._text = item + itemW._nrWidget = itemNumberWidget + itemW._container = itemContainer + itemNumberWidget.setDisabled(True) + itemNumberWidget.textActivated.connect(self.onTextActivated) + itemNumberWidget._currentNr = 1 + itemNumberWidget.row = i + itemContainer.sigEnterEvent.connect(self.onItemEntered) + + self.itemSelectionChanged.connect(self.onItemSelectionChanged) + + def keyPressEvent(self, event) -> None: + if event.key() == Qt.Key_Escape: + self.clearSelection() + event.ignore() + return + super().keyPressEvent(event) + + def updateNr(self): + for i in range(self.count()): + item = self.item(i) + item._currentNr = int(item._nrWidget.currentText()) + + def onItemSelectionChanged(self): + for i in range(self.count()): + item = self.item(i) + item._container.setLabelsColor(item.isSelected()) + item._nrWidget.setDisabled(not item.isSelected()) + if item._nrWidget.currentText() != "1": + item._nrWidget.setCurrentText("1") + item._currentNr = 1 + + for i, item in enumerate(self.selectedItems()): + item._nrWidget.setCurrentText(f"{i + 1}") + item._currentNr = i + 1 + + def onTextActivated(self, text): + changedNr = self.sender()._currentNr + for item in self.selectedItems(): + row = self.row(item) + if self.sender().row == row: + changedNr = item._currentNr + continue + + for item in self.selectedItems(): + row = self.row(item) + if self.sender().row == row: + continue + nr = int(item._nrWidget.currentText()) + if nr == int(text): + item._nrWidget.setCurrentText(str(changedNr)) + break + + self.updateNr() + + +class TreeWidget(QTreeWidget): + def __init__(self, *args, multiSelection=False): + super().__init__(*args) + self.setStyleSheet(TREEWIDGET_STYLESHEET) + self.setFont(font) + if multiSelection: + self.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) + self.itemClicked.connect(self.selectAllChildren) + + self.isCtrlDown = False + self.isShiftDown = False + + def keyPressEvent(self, ev): + if ev.key() == Qt.Key_Escape: + self.clearSelection() + elif ev.key() == Qt.Key_Control: + self.isCtrlDown = True + elif ev.key() == Qt.Key_Shift: + self.isShiftDown = True + + def keyReleaseEvent(self, ev): + if ev.key() == Qt.Key_Control: + self.isCtrlDown = False + elif ev.key() == Qt.Key_Shift: + self.isShiftDown = False + + def onFocusChanged(self): + self.isCtrlDown = False + self.isShiftDown = False + + def selectAllChildren(self, item_or_label): + label = None + if isinstance(item_or_label, QLabel): + label = item_or_label + else: + item = item_or_label + if item.childCount() == 0: + return + + if label is not None: + if not self.isCtrlDown and not self.isShiftDown: + self.clearSelection() + label.item.setSelected(True) + if self.isShiftDown: + selectionStarted = False + it = QTreeWidgetItemIterator(self) + while it: + item = it.value() + if item is None: + break + if item.isSelected(): + selectionStarted = not selectionStarted + if selectionStarted: + item.setSelected(True) + it += 1 + + for item in self.selectedItems(): + if item.parent() is None: + for i in range(item.childCount()): + item.child(i).setSelected(True) + + + +class TreeWidgetItem(QTreeWidgetItem): + def __init__(self, *args, columnColors=None): + super().__init__(*args) + + if columnColors is not None: + for c, color in enumerate(columnColors): + if color is None: + continue + self.setBackground(c, QBrush(color)) + + +class FilterObject(QObject): + sigFilteredEvent = Signal(object, object) + + def __init__(self) -> None: + super().__init__() + + def eventFilter(self, object, event): + self.sigFilteredEvent.emit(object, event) + return super().eventFilter(object, event) + + +class readOnlyQList(QTextEdit): + def __init__(self, parent=None): + super().__init__(parent) + self.setReadOnly(True) + self.items = [] + + def addItems(self, items): + self.items.extend(items) + items = [str(item) for item in self.items] + columnList = html_utils.paragraph("
".join(items)) + self.setText(columnList) + diff --git a/cellacdc/components/palette.py b/cellacdc/components/palette.py new file mode 100644 index 000000000..fe5d088a0 --- /dev/null +++ b/cellacdc/components/palette.py @@ -0,0 +1,131 @@ +import os +import operator + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.colors import LinearSegmentedColormap + +import pyqtgraph as pg +from qtpy.QtGui import QFont + +from .. import config, settings_folderpath +from .. import _palettes + +LINEEDIT_WARNING_STYLESHEET = _palettes.lineedit_warning_stylesheet() +LINEEDIT_INVALID_ENTRY_STYLESHEET = _palettes.lineedit_invalid_entry_stylesheet() +TREEWIDGET_STYLESHEET = _palettes.TreeWidgetStyleSheet() +LISTWIDGET_STYLESHEET = _palettes.ListWidgetStyleSheet() +BASE_COLOR = _palettes.base_color() +PROGRESSBAR_QCOLOR = _palettes.QProgressBarColor() +PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR = _palettes.QProgressBarHighlightedTextColor() +TEXT_COLOR = _palettes.text_float_rgba() + +font = QFont() +font.setPixelSize(12) + +custom_cmaps_filepath = os.path.join(settings_folderpath, "custom_colormaps.ini") + +str_to_operator_mapper = {"+": operator.add, "-": operator.sub} + +sign_int_mapper = {"+": 1, "-": -1} + + +def removeHSVcmaps(): + hsv_cmaps = [] + for g, grad in pg.graphicsItems.GradientEditorItem.Gradients.items(): + if grad["mode"] == "hsv": + hsv_cmaps.append(g) + for g in hsv_cmaps: + del pg.graphicsItems.GradientEditorItem.Gradients[g] + + +def renamePgCmaps(): + Gradients = pg.graphicsItems.GradientEditorItem.Gradients + try: + Gradients["hot"] = Gradients.pop("thermal") + except KeyError: + pass + try: + Gradients.pop("greyclip") + except KeyError: + pass + + +def _tab20gradient(): + cmap = plt.get_cmap("tab20") + ticks = [(t, tuple([int(v * 255) for v in cmap(t)])) for t in np.linspace(0, 1, 20)] + gradient = {"ticks": ticks, "mode": "rgb"} + return gradient + + +def _tab10gradient(): + cmap = plt.get_cmap("tab10") + ticks = [(t, tuple([int(v * 255) for v in cmap(t)])) for t in np.linspace(0, 1, 20)] + gradient = {"ticks": ticks, "mode": "rgb"} + return gradient + + +def getCustomGradients(name="image"): + CustomGradients = {} + if not os.path.exists(custom_cmaps_filepath): + return CustomGradients + + cp = config.ConfigParser() + cp.read(custom_cmaps_filepath) + for section in cp.sections(): + if not section.startswith(f"{name}"): + continue + + cmap_name = section[len(f"{name}.") :] + CustomGradients[cmap_name] = {"ticks": [], "mode": "rgb"} + for option in cp.options(section): + value = cp[section][option] + pos, *rgb = value.split(",") + rgb = tuple([int(c) for c in rgb]) + pos = float(pos) + CustomGradients[cmap_name]["ticks"].append((pos, rgb)) + return CustomGradients + + +def addGradients(): + Gradients = pg.graphicsItems.GradientEditorItem.Gradients + Gradients["cividis"] = { + "ticks": [ + (0.0, (0, 34, 78, 255)), + (0.25, (66, 78, 108, 255)), + (0.5, (124, 123, 120, 255)), + (0.75, (187, 173, 108, 255)), + (1.0, (254, 232, 56, 255)), + ], + "mode": "rgb", + } + Gradients["cool"] = { + "ticks": [(0.0, (0, 255, 255, 255)), (1.0, (255, 0, 255, 255))], + "mode": "rgb", + } + Gradients["sunset"] = { + "ticks": [ + (0.0, (71, 118, 148, 255)), + (0.4, (222, 213, 141, 255)), + (0.8, (229, 184, 155, 255)), + (1.0, (240, 127, 97, 255)), + ], + "mode": "rgb", + } + Gradients["tab20"] = _tab20gradient() + Gradients["tab10"] = _tab10gradient() + cmaps = {} + for name, gradient in Gradients.items(): + ticks = gradient["ticks"] + colors = [tuple([v / 255 for v in tick[1]]) for tick in ticks] + cmaps[name] = LinearSegmentedColormap.from_list(name, colors, N=256) + return cmaps, Gradients + + +nonInvertibleCmaps = ["cool", "sunset", "bipolar"] + +renamePgCmaps() +removeHSVcmaps() +cmaps, Gradients = addGradients() +GradientsLabels = Gradients.copy() +GradientsImage = Gradients.copy() diff --git a/cellacdc/components/path_controls.py b/cellacdc/components/path_controls.py new file mode 100644 index 000000000..12537f3c1 --- /dev/null +++ b/cellacdc/components/path_controls.py @@ -0,0 +1,74 @@ +from qtpy.QtCore import Signal +from qtpy.QtGui import QShowEvent +from qtpy.QtWidgets import QFrame, QHBoxLayout, QLineEdit + +from .buttons import browseFileButton +from .inputs_basic import ElidingLineEdit + +class filePathControl(QFrame): + sigValueChanged = Signal(str) + + def __init__( + self, + parent=None, + browseFolder=False, + fileManagerTitle="Select file", + validExtensions=None, + startFolder="", + elide=False, + ): + super().__init__(parent) + + layout = QHBoxLayout() + if elide: + self.le = ElidingLineEdit() + else: + self.le = QLineEdit() + + self.browseButton = browseFileButton( + openFolder=browseFolder, + title=fileManagerTitle, + ext=validExtensions, + start_dir=startFolder, + ) + + layout.addWidget(self.le) + layout.addWidget(self.browseButton) + self.setLayout(layout) + + self.le.editingFinished.connect(self.setTextTooltip) + self.browseButton.sigPathSelected.connect(self.setText) + + self.setFrameStyle(QFrame.Shape.StyledPanel) + + def setText(self, text): + self.le.setText(text) + self.le.setToolTip(text) + self.sigValueChanged.emit(self.le.text()) + + def setTextTooltip(self): + self.le.setToolTip(self.le.text()) + self.sigValueChanged.emit(self.le.text()) + + def path(self): + return self.le.text() + + def showEvent(self, a0: QShowEvent) -> None: + self.le.setFixedHeight(self.browseButton.height()) + return super().showEvent(a0) + + +class FolderPathControl(filePathControl): + def __init__(self, **kwargs): + super().__init__(browseFolder=True, fileManagerTitle="Select folder", **kwargs) + + +class CsvFilePathControl(filePathControl): + def __init__(self, **kwargs): + super().__init__( + browseFolder=False, + fileManagerTitle="Select a CSV file", + validExtensions={"CSV files": [".csv", ".CSV"]}, + **kwargs, + ) + diff --git a/cellacdc/components/progress.py b/cellacdc/components/progress.py new file mode 100644 index 000000000..bcb61cc96 --- /dev/null +++ b/cellacdc/components/progress.py @@ -0,0 +1,245 @@ +import logging +import math +import sys +import time + +import numpy as np +import pyqtgraph as pg +from qtpy.QtCore import Property, QPropertyAnimation, QObject, QPointF, Qt, Signal +from qtpy.QtGui import QFont, QPalette, QPainter, QColor, QPen +from qtpy.QtWidgets import ( + QGraphicsBlurEffect, + QLabel, + QPlainTextEdit, + QProgressBar, + QTextEdit, +) + +from .. import _palettes, utils +from .palette import PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR, PROGRESSBAR_QCOLOR + + +class XStream(QObject): + _stdout = None + _stderr = None + messageWritten = Signal(str) + + def flush(self): + pass + + def fileno(self): + return -1 + + def write(self, msg): + if not self.signalsBlocked(): + self.messageWritten.emit(msg) + + @staticmethod + def stdout(): + if not XStream._stdout: + XStream._stdout = XStream() + sys.stdout = XStream._stdout + return XStream._stdout + + @staticmethod + def stderr(): + if not XStream._stderr: + XStream._stderr = XStream() + sys.stderr = XStream._stderr + return XStream._stderr + + +class QtHandler(logging.Handler): + def __init__(self): + super().__init__() + + def emit(self, record): + record = self.format(record) + if record: + XStream.stdout().write("%s\n" % record) + + +class QLog(QPlainTextEdit): + sigClose = Signal() + + def __init__(self, *args, logger=None): + super().__init__(*args) + self.logger = logger + self.setReadOnly(True) + + def connect(self): + XStream.stdout().messageWritten.connect(self.writeStdOutput) + + def writeStdOutput(self, text: str) -> None: + super().insertPlainText(text) + self.verticalScrollBar().setValue(self.verticalScrollBar().maximum()) + + def writeStdErr(self, text: str) -> None: + super().insertPlainText(text) + self.verticalScrollBar().setValue(self.verticalScrollBar().maximum()) + if self.logger is not None: + self.logger.exception(text) + + def insertPlainText(self, text: str) -> None: + super().insertPlainText(f"{text}\n") + self.verticalScrollBar().setValue(self.verticalScrollBar().maximum()) + + def closeEvent(self, event) -> None: + super().closeEvent(event) + self.sigClose.emit() + + +class QLogConsole(QTextEdit): + def __init__(self, parent=None): + super().__init__(parent) + self.setReadOnly(True) + font = QFont() + font.setPixelSize(13) + self.setFont(font) + + def write(self, message): + message = message.replace("\r ", "") + if message: + self.apppendText(message) + + def append(self, text: str) -> None: + super().append(text) + self.verticalScrollBar().setValue(self.verticalScrollBar().maximum()) + + def insertPlainText(self, text: str) -> None: + super().append(text) + self.verticalScrollBar().setValue(self.verticalScrollBar().maximum()) + + +class ProgressBar(QProgressBar): + def __init__(self, parent=None): + super().__init__(parent) + palette = self.palette() + palette.setColor(QPalette.ColorRole.Highlight, PROGRESSBAR_QCOLOR) + palette.setColor( + QPalette.ColorRole.HighlightedText, PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR + ) + self.setPalette(palette) + + +class ProgressBarWithETA(ProgressBar): + def __init__(self, parent=None): + self.parent = parent + super().__init__(parent=parent) + self.ETA_label = QLabel("NDh:NDm:NDs") + + def update(self, step: int): + self.setValue(self.value() + step) + t = time.perf_counter() + if not hasattr(self, "last_time_update"): + self.last_time_update = t + self.mean_value_duration = None + return + seconds_per_value = (t - self.last_time_update) / step + value_left = self.maximum() - self.value() + if self.mean_value_duration is None: + self.mean_value_duration = seconds_per_value + else: + self.mean_value_duration = ( + self.mean_value_duration * (self.value() - 1) + seconds_per_value + ) / self.value() + + seconds_left = self.mean_value_duration * value_left + ETA = utils.seconds_to_ETA(seconds_left) + self.ETA_label.setText(ETA) + self.last_time_update = t + return ETA + + def show(self): + QProgressBar.show(self) + self.ETA_label.show() + + def hide(self): + QProgressBar.hide(self) + self.ETA_label.hide() + + +class NoneWidget: + def __init__(self): + pass + + def value(self): + return None + + def setValue(self, value): + return + + +class LoadingCircleAnimation(QLabel): + def __init__(self, size=32, motionBlur=False, parent=None): + super().__init__(parent) + self.setAlignment(Qt.AlignCenter) + self._size = size + size % 2 + self._radius = int(self._size / 2) + self.setFixedSize(self._size, self._size) + self._dotDiameter = int(self._size * 0.15) + self._dotDiameter = self._dotDiameter + self._dotDiameter % 2 + self._dotRadius = int(self._dotDiameter / 2) + + self._rgb = _palettes.getPainterColor()[:3] + self._index = 0 + + self.setBrushesAndAngles() + + if motionBlur: + blurEffect = QGraphicsBlurEffect() + blurRadius = self._size * 0.02 + if blurRadius < 1: + blurRadius = 1 + blurEffect.setBlurRadius(blurRadius) + self.setGraphicsEffect(blurEffect) + + self.animation = QPropertyAnimation(self, b"index", self) + self.animation.setStartValue(0) + self.animation.setEndValue(11) + self.animation.setLoopCount(-1) + self.animation.setDuration(1200) + self.animation.start() + + self.update() + + def setVisible(self, visible): + if visible: + self.animation.start() + else: + self.animation.stop() + super().setVisible(visible) + + def setBrushesAndAngles(self): + self._brushes = [] + self._pens = [] + alphas = np.round(np.linspace(0, 255, 12)).astype(int) + self._angles = np.arange(0, 360, 30) + for alpha in alphas: + color = QColor(*self._rgb, alpha) + self._brushes.append(pg.mkBrush(color)) + self._pens.append(pg.mkPen(color)) + + @Property(int) + def index(self): + return self._index + + @index.setter + def index(self, value): + self._index = value + self.update() + + def paintEvent(self, event): + painter = QPainter(self) + painter.setRenderHint(QPainter.Antialiasing) + painter.translate(self._radius, self._radius) + for i in range(12): + idx = i - self._index + angle = self._angles[i] + painter.setBrush(self._brushes[idx]) + painter.setPen(self._pens[idx]) + x = (self._radius - self._dotRadius) * math.cos(angle * math.pi / 180) + y = (self._radius - self._dotRadius) * math.sin(angle * math.pi / 180) + painter.drawEllipse(QPointF(x, y), self._dotRadius, self._dotRadius) + + painter.end() diff --git a/cellacdc/config.py b/cellacdc/config.py index da319ee89..cb09d747e 100755 --- a/cellacdc/config.py +++ b/cellacdc/config.py @@ -4,23 +4,25 @@ import os import json -from typing import get_type_hints +from typing import get_type_hints import re from . import printl, debug_true_filepath + class ConfigParser(configparser.ConfigParser): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.optionxform = str - + def __repr__(self) -> str: string = pprint.pformat( {section: dict(self[section]) for section in self.sections()} ) return string + from . import GUI_INSTALLED if GUI_INSTALLED: @@ -30,7 +32,7 @@ class QtWarningHandler(QObject): sigGeometryWarning = Signal(object) def _resizeWarningHandler(self, msg_type, msg_log_context, msg_string): - if msg_string.find('Unable to set geometry') != -1: + if msg_string.find("Unable to set geometry") != -1: try: self.sigGeometryWarning.emit(msg_string) except Exception as e: @@ -42,113 +44,114 @@ def _resizeWarningHandler(self, msg_type, msg_log_context, msg_string): qInstallMessageHandler(warningHandler._resizeWarningHandler) help_text = ( - 'Welcome to Cell-ACDC!\n\n' - 'You can run Cell-ACDC both as a GUI or in the command line.\n' - 'To run the GUI type `acdc`. To run the command line type `acdc -p `.\n' - 'The `` must be a workflow INI file.\n' - 'If you do not have one, use the GUI to set up the parameters.\n\n' - 'Enjoy!' + "Welcome to Cell-ACDC!\n\n" + "You can run Cell-ACDC both as a GUI or in the command line.\n" + "To run the GUI type `acdc`. To run the command line type `acdc -p `.\n" + "The `` must be a workflow INI file.\n" + "If you do not have one, use the GUI to set up the parameters.\n\n" + "Enjoy!" ) try: ap = argparse.ArgumentParser( - prog='acdc', description=help_text, - formatter_class=argparse.RawTextHelpFormatter + prog="acdc", + description=help_text, + formatter_class=argparse.RawTextHelpFormatter, ) - + ap.add_argument( - '-p', '--params', - default='', + "-p", + "--params", + default="", type=str, - metavar='PATH_TO_PARAMS', - help=('Path of the ".ini" workflow file') - ) - - ap.add_argument( - '-v', '--version', action='store_true', - help=( - 'Get information about Cell-ACDC version and environment' - ) + metavar="PATH_TO_PARAMS", + help=('Path of the ".ini" workflow file'), ) - + ap.add_argument( - '--reset', action='store_true', - help=( - 'Reset Cell-ACDC settings' - ) + "-v", + "--version", + action="store_true", + help=("Get information about Cell-ACDC version and environment"), ) - + + ap.add_argument("--reset", action="store_true", help=("Reset Cell-ACDC settings")) + ap.add_argument( - '-info', '--info', action='store_true', - help=( - 'Get information about Cell-ACDC version and environment' - ) + "-info", + "--info", + action="store_true", + help=("Get information about Cell-ACDC version and environment"), ) - + ap.add_argument( - '-y', '--yes', action='store_true', + "-y", + "--yes", + action="store_true", help=( 'Sets confirmation values to "yes" automatically. Users will ' - 'not be prompted for confirmation when installing Cell-ACDC for the first time.' - ) + "not be prompted for confirmation when installing Cell-ACDC for the first time." + ), ) - + ap.add_argument( - '-d', '--debug', action='store_true', + "-d", + "--debug", + action="store_true", help=( - 'Used for debugging. Test code with' + "Used for debugging. Test code with" '"from cellacdc.config import parser_args, debug = parser_args["debug"]", ' - 'if debug: ' - ) + "if debug: " + ), ) ap.add_argument( - '--install_details', - default='', + "--install_details", + default="", type=str, - metavar='PATH_TO_INSTALL_DETAILS', - help=('Path of the "install_details.json" file') + metavar="PATH_TO_INSTALL_DETAILS", + help=('Path of the "install_details.json" file'), ) - + ap.add_argument( - '--cpModelsDownload', - action='store_true', - help=('Whether to download cellpose models'), + "--cpModelsDownload", + action="store_true", + help=("Whether to download cellpose models"), # metavar='CP_MODELS_DOWNLOAD_FLAG' ) ap.add_argument( - '--YeaZModelsDownload', - action='store_true', - help=('Whether to download YeaZ models'), + "--YeaZModelsDownload", + action="store_true", + help=("Whether to download YeaZ models"), # metavar='YEAZ_MODELS_DOWNLOAD_FLAG' ) ap.add_argument( - '--DeepSeaModelsDownload', - action='store_true', - help=('Whether to download DeepSea models'), + "--DeepSeaModelsDownload", + action="store_true", + help=("Whether to download DeepSea models"), # metavar='DEEPSEA_MODELS_DOWNLOAD_FLAG' ) ap.add_argument( - '--StarDistModelsDownload', - action='store_true', - help=('Whether to download StarDist models'), + "--StarDistModelsDownload", + action="store_true", + help=("Whether to download StarDist models"), # metavar='STARDIST_MODELS_DOWNLOAD_FLAG' ) ap.add_argument( - '--TrackastraModelsDownload', - action='store_true', - help=('Whether to download Trackastra models'), + "--TrackastraModelsDownload", + action="store_true", + help=("Whether to download Trackastra models"), # metavar='TRACKASTRA_MODELS_DOWNLOAD_FLAG' ) - + ap.add_argument( - '--AllModelsDownload', - action='store_true', + "--AllModelsDownload", + action="store_true", help=( - 'Whether to download models for Cellpose, YeaZ, DeepSea, StarDist, Trackastra.' + "Whether to download models for Cellpose, YeaZ, DeepSea, StarDist, Trackastra." ), ) @@ -158,106 +161,121 @@ def _resizeWarningHandler(self, msg_type, msg_log_context, msg_string): parser_args, unknown = ap.parse_known_args() parser_args = vars(parser_args) if os.path.exists(debug_true_filepath): - parser_args['debug'] = True - - install_details = parser_args.get('install_details') - if install_details and install_details != '': + parser_args["debug"] = True + + install_details = parser_args.get("install_details") + if install_details and install_details != "": try: - with open(parser_args['install_details'], 'r') as f: + with open(parser_args["install_details"], "r") as f: install_details = json.load(f) - for pathlike in ['conda_path', 'clone_path', 'venv_path', 'target_dir',]: + for pathlike in [ + "conda_path", + "clone_path", + "venv_path", + "target_dir", + ]: if pathlike in install_details: - install_details[pathlike] = f'"{os.path.abspath(install_details[pathlike])}"' - parser_args['install_details'] = install_details + install_details[pathlike] = ( + f'"{os.path.abspath(install_details[pathlike])}"' + ) + parser_args["install_details"] = install_details except Exception as e: printl( - 'Error reading install details from file: ' - f'{parser_args["install_details"]}. Error: {e}' + "Error reading install details from file: " + f"{parser_args['install_details']}. Error: {e}" ) - parser_args['install_details'] = {} - - + parser_args["install_details"] = {} + + except Exception as err: - import pdb; pdb.set_trace() - print('Importing from notebook, ignoring Cell-ACDC argument parser...') + import pdb + + pdb.set_trace() + print("Importing from notebook, ignoring Cell-ACDC argument parser...") parser_args = {} - parser_args['debug'] = False + parser_args["debug"] = False + def preprocessing_mapper(): from cellacdc import preprocess, cellacdc_path, acdc_regex from inspect import getmembers, isfunction + functions = getmembers(preprocess, isfunction) - preprocess_py_path = os.path.join(cellacdc_path, 'preprocess.py') - with open(preprocess_py_path, 'r') as py_file: + preprocess_py_path = os.path.join(cellacdc_path, "preprocess.py") + with open(preprocess_py_path, "r") as py_file: text = py_file.read() valid_functions_names = acdc_regex.get_function_names(text) mapper = {} for func_name, func in functions: - if func_name.startswith('_'): + if func_name.startswith("_"): continue - - if func_name == 'dummy_filter' and not parser_args['debug']: + + if func_name == "dummy_filter" and not parser_args["debug"]: continue - + if func_name not in valid_functions_names: continue - - method = func_name.title().replace('_', ' ') + + method = func_name.title().replace("_", " ") mapper[method] = { - 'function': func, - 'docstring': func.__doc__, - 'function_name': func_name - } + "function": func, + "docstring": func.__doc__, + "function_name": func_name, + } return mapper + def preprocessing_init_func_mapper(): from cellacdc import preprocess, cellacdc_path, acdc_regex from inspect import getmembers, isfunction + functions = getmembers(preprocess, isfunction) - preprocess_py_path = os.path.join(cellacdc_path, 'preprocess.py') - with open(preprocess_py_path, 'r') as py_file: + preprocess_py_path = os.path.join(cellacdc_path, "preprocess.py") + with open(preprocess_py_path, "r") as py_file: text = py_file.read() valid_functions_names = acdc_regex.get_function_names(text) mapper = {} for func_name, func in functions: - if not func_name.startswith('_init_'): + if not func_name.startswith("_init_"): continue - - method = func_name.lstrip('_init_').title().replace('_', ' ') + + method = func_name.lstrip("_init_").title().replace("_", " ") mapper[method] = { - 'function': func, - 'docstring': func.__doc__, - 'function_name': func_name - } + "function": func, + "docstring": func.__doc__, + "function_name": func_name, + } return mapper + def preprocess_recipe_to_ini_items(preproc_recipe): if preproc_recipe is None: return {} - + ini_items = {} for s, step in enumerate(preproc_recipe): - section = f'preprocess.step{s+1}' + section = f"preprocess.step{s + 1}" ini_items[section] = {} - ini_items[section]['method'] = step['method'] - for option, value in step['kwargs'].items(): + ini_items[section]["method"] = step["method"] + for option, value in step["kwargs"].items(): ini_items[section][option] = str(value) return ini_items + def preprocess_ini_items_to_recipe(ini_items): recipe = {} - + for section, section_items in ini_items.items(): - if not section.startswith('preprocess.step'): + if not section.startswith("preprocess.step"): continue - - step_n = int(re.findall(r'step(\d+)', section)[0]) - recipe[step_n] = {'method': section_items['method']} + + step_n = int(re.findall(r"step(\d+)", section)[0]) + recipe[step_n] = {"method": section_items["method"]} kwargs = {} for option, value_str in section_items.items(): - if option == 'method': + if option == "method": continue - + value = value_str if isinstance(value_str, str): for _type in (int, float, str): @@ -266,17 +284,18 @@ def preprocess_ini_items_to_recipe(ini_items): break except Exception as e: continue - + kwargs[option] = value - - recipe[step_n]['kwargs'] = kwargs - + + recipe[step_n]["kwargs"] = kwargs + recipe = [value for key, value in sorted(recipe.items())] - + if not recipe: return - + return recipe + PREPROCESS_MAPPER = preprocessing_mapper() -PREPROCESS_INIT_MAPPER = preprocessing_init_func_mapper() \ No newline at end of file +PREPROCESS_INIT_MAPPER = preprocessing_init_func_mapper() diff --git a/cellacdc/core.py b/cellacdc/core.py index 80ae1df0e..b92e22c84 100755 --- a/cellacdc/core.py +++ b/cellacdc/core.py @@ -3,9 +3,7 @@ from typing import List, Dict, Any, Iterable, Tuple, Callable, Union, Literal import os import time -from concurrent.futures import ( - ThreadPoolExecutor, ProcessPoolExecutor, as_completed -) +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed from functools import partial from importlib import import_module import numpy as np @@ -35,7 +33,7 @@ from tqdm import tqdm -from . import load, myutils +from . import load, utils from . import cca_df_colnames, printl, base_cca_dict, base_cca_tree_dict from . import features from . import error_up_str @@ -51,40 +49,41 @@ from . import favourite_func_metrics_csv_path from . import default_index_cols -from ._types import ( - ChannelsDict -) +from ._types import ChannelsDict + def get_indices_dash_pattern(arr, line_length, gap): n = len(arr) - sampling_rate = (line_length+gap) + sampling_rate = line_length + gap n_lines = n // sampling_rate - tot_len = n_lines*sampling_rate - indices_2D = np.arange(tot_len).reshape((n_lines,sampling_rate)) + tot_len = n_lines * sampling_rate + indices_2D = np.arange(tot_len).reshape((n_lines, sampling_rate)) indices = (indices_2D[:, :line_length]).flatten() return indices + def get_line(r0, c0, r1, c1, dashed=True): x1, x2 = sorted((c0, c1)) - Dc = (c0-c1) - Dr = (r0-r1) - dist = np.ceil(np.sqrt(np.square(Dr)+np.square(Dc))) - + Dc = c0 - c1 + Dr = r0 - r1 + dist = np.ceil(np.sqrt(np.square(Dr) + np.square(Dc))) + if Dc == 0: - xx = np.array([c0]*int(dist)) + xx = np.array([c0] * int(dist)) y1, y2 = sorted((r0, r1)) yy = np.linspace(y1, y2, len(xx)) else: xx = np.linspace(x1, x2, int(dist)) - m = Dr/Dc - q = (c0*r1 - c1*r0)/Dc - yy = xx*m+q + m = Dr / Dc + q = (c0 * r1 - c1 * r0) / Dc + yy = xx * m + q if dashed: indices = get_indices_dash_pattern(xx, 4, 3) xx = xx[indices] yy = yy[indices] return xx, yy + def np_replace_values(arr, old_values, new_values): # See method_jdehesa https://stackoverflow.com/questions/45735230/how-to-replace-a-list-of-values-in-a-numpy-array old_values = np.asarray(old_values) @@ -97,6 +96,7 @@ def np_replace_values(arr, old_values, new_values): arr = replacer[arr - n_min] return arr + def nearest_nonzero_2D(a, y, x, max_dist=None, return_coords=False): value = a[round(y), round(x)] if value > 0: @@ -105,7 +105,7 @@ def nearest_nonzero_2D(a, y, x, max_dist=None, return_coords=False): else: return value r, c = np.nonzero(a) - dist = ((r - y)**2 + (c - x)**2) + dist = (r - y) ** 2 + (c - x) ** 2 if dist.size == 0: if return_coords: return 0, 0, 0 @@ -124,23 +124,25 @@ def nearest_nonzero_2D(a, y, x, max_dist=None, return_coords=False): else: return a[y_nearest, x_nearest] + def nearest_nonzero_1D(arr, x, return_index=False): if arr[x] > 0: if return_index: return arr[x], x else: return arr[x] - nonzero_idxs, = np.nonzero(arr) - dist = (nonzero_idxs - x)**2 + (nonzero_idxs,) = np.nonzero(arr) + dist = (nonzero_idxs - x) ** 2 min_idx = dist.argmin() nearest_nonzero_idx = nonzero_idxs[min_idx] val = arr[nearest_nonzero_idx] - + if return_index: return val, nearest_nonzero_idx else: return val + def nearest_nonzero_z_idx_from_z_centroid(obj, current_z=-1): zc = obj.local_centroid[0] z_obj_local = int(zc) @@ -148,20 +150,21 @@ def nearest_nonzero_z_idx_from_z_centroid(obj, current_z=-1): z_obj_global = z_obj_local + obj.bbox[0] if current_z == z_obj_global and is_obj_slice_not_empty: return current_z - - zslices_not_empty_arr = np.any(obj.image, axis=(1,2)).astype(np.uint8) + + zslices_not_empty_arr = np.any(obj.image, axis=(1, 2)).astype(np.uint8) _, nearest_nonzero_z_local = nearest_nonzero_1D( zslices_not_empty_arr, z_obj_local, return_index=True ) nearest_nonzero_z_global = nearest_nonzero_z_local + obj.bbox[0] return nearest_nonzero_z_global + def compute_twoframes_velocity(prev_lab, lab, spacing=None): prev_rp = skimage.measure.regionprops(prev_lab) rp = skimage.measure.regionprops(lab) prev_IDs = [obj.label for obj in prev_rp] - velocities_pxl = [0]*len(rp) - velocities_um = [0]*len(rp) + velocities_pxl = [0] * len(rp) + velocities_um = [0] * len(rp) for i, obj in enumerate(rp): if obj.label not in prev_IDs: continue @@ -171,17 +174,18 @@ def compute_twoframes_velocity(prev_lab, lab, spacing=None): v_pixel = np.linalg.norm(diff) velocities_pxl[i] = v_pixel if spacing is not None: - v_um = np.linalg.norm(diff*spacing) + v_um = np.linalg.norm(diff * spacing) velocities_um[i] = v_um return velocities_pxl, velocities_um + def nearest_points_objects(objs_arr: np.ndarray, other_obj: np.ndarray): """Find the nearest points between all objects in objs_arr and other_obj Parameters ---------- - objs_arr : (N, P, 2) np.ndarray of floats - Array with N pages (one for each object), P rows (number of points) + objs_arr : (N, P, 2) np.ndarray of floats + Array with N pages (one for each object), P rows (number of points) and 2 columns for y, x coordinates other_obj : (P1, 2) np.ndarray Array with P1 rows (number of points) and 2 columns for y, x coordinates @@ -189,20 +193,21 @@ def nearest_points_objects(objs_arr: np.ndarray, other_obj: np.ndarray): Returns ------- (N,) np.ndarray - Array with N elements where the ith element is the minimum distance + Array with N elements where the ith element is the minimum distance between object objs_arr[i] and other_obj - """ + """ # diff[l, k, i] = objs_arr[l][k] - other_obj[i] diff = objs_arr[:, :, np.newaxis] - other_obj - + # dist[l, i, j] = math.dist(objs_arr[l][i], other_obj[j]) dist = np.linalg.norm(diff, axis=3) - + # min_dist[l] = min_dist(objs_arr[l], other_obj) min_dist = np.nanmin(dist, axis=(1, 2)) - + return min_dist + def nearest_point_2Dyx(points, all_others): """ Given 2D array of [y, x] coordinates points and all_others return the @@ -223,6 +228,7 @@ def nearest_point_2Dyx(points, all_others): min_dist = np.min(dist) return min_dist, nearest_point + def lab_replace_values(lab, rp, oldIDs, newIDs, in_place=True): if not in_place: lab = lab.copy() @@ -240,15 +246,16 @@ def lab_replace_values(lab, rp, oldIDs, newIDs, in_place=True): lab[obj.slice][obj.image] = newIDs[idx] except OverflowError: # it should be uint32 already but sometimes it was not - lab = lab.astype(np.uint32) + lab = lab.astype(np.uint32) lab[obj.slice][obj.image] = newIDs[idx] return lab + def post_process_segm(labels, return_delIDs=False, **kwargs): - min_solidity = kwargs.get('min_solidity') - min_area = kwargs.get('min_area') - max_elongation = kwargs.get('max_elongation') - min_obj_no_zslices = kwargs.get('min_obj_no_zslices') + min_solidity = kwargs.get("min_solidity") + min_area = kwargs.get("min_area") + max_elongation = kwargs.get("max_elongation") + min_obj_no_zslices = kwargs.get("min_obj_no_zslices") if labels.ndim == 3: delIDs = set() if min_obj_no_zslices is not None: @@ -259,11 +266,10 @@ def post_process_segm(labels, return_delIDs=False, **kwargs): if obj_no_zslices < min_obj_no_zslices: labels[obj.slice][obj.image] = 0 delIDs.add(obj.label) - + for z, lab in enumerate(labels): _result = post_process_segm_lab2D( - lab, min_solidity, min_area, max_elongation, - return_delIDs=return_delIDs + lab, min_solidity, min_area, max_elongation, return_delIDs=return_delIDs ) if return_delIDs: lab, _delIDs = _result @@ -277,8 +283,7 @@ def post_process_segm(labels, return_delIDs=False, **kwargs): result = labels else: result = post_process_segm_lab2D( - labels, min_solidity, min_area, max_elongation, - return_delIDs=return_delIDs + labels, min_solidity, min_area, max_elongation, return_delIDs=return_delIDs ) if return_delIDs: @@ -288,10 +293,10 @@ def post_process_segm(labels, return_delIDs=False, **kwargs): labels = result return labels + def post_process_segm_lab2D( - lab, min_solidity=None, min_area=None, max_elongation=None, - return_delIDs=False - ): + lab, min_solidity=None, min_area=None, max_elongation=None, return_delIDs=False +): """ function to remove cells with areamax_elongation @@ -317,7 +322,7 @@ def post_process_segm_lab2D( if max_elongation is not None: # NOTE: single pixel horizontal or vertical lines minor_axis_length=0 minor_axis_length = max(1, obj.minor_axis_length) - elongation = obj.major_axis_length/minor_axis_length + elongation = obj.major_axis_length / minor_axis_length if elongation > max_elongation: lab[obj.slice][obj.image] = 0 if return_delIDs: @@ -328,6 +333,7 @@ def post_process_segm_lab2D( else: return lab + def connect_3Dlab_zboundaries(lab): connected_lab = np.zeros_like(lab) rp = skimage.measure.regionprops(lab) @@ -335,59 +341,65 @@ def connect_3Dlab_zboundaries(lab): if len(obj.image) == 1: lab[obj.slice][obj.image] = obj.label continue - + # Take the center non-zero z-area as reference object z_areas = [np.count_nonzero(z_img) for z_img in obj.image] nonzero_z_areas = [z_area for z_area in z_areas if z_area > 0] - nonzero_center_idx = int(len(nonzero_z_areas)/2) + nonzero_center_idx = int(len(nonzero_z_areas) / 2) nonzero_center_z_area = nonzero_z_areas[nonzero_center_idx] center_idx = z_areas.index(nonzero_center_z_area) max_obj_image = obj.image[center_idx] num_zslices = len(obj.image) - + for z in range(num_zslices): connected_lab[obj.slice][z][max_obj_image] = obj.label - + return connected_lab + def stack_2Dlab_to_3D(lab, SizeZ): return np.tile(lab, (SizeZ, 1, 1)) + def track_sub_cell_objects_third_segm_acdc_df( - track_parent_objs_segm_data, parent_objs_acdc_df - ): + track_parent_objs_segm_data, parent_objs_acdc_df +): if parent_objs_acdc_df is None: return - + keys = [] dfs = [] for frame_i, lab in enumerate(track_parent_objs_segm_data): rp = skimage.measure.regionprops(lab) IDs = [obj.label for obj in rp] if frame_i not in parent_objs_acdc_df.index.get_level_values(0): - acdc_df_frame_i = myutils.getBaseAcdcDf(rp) + acdc_df_frame_i = utils.getBaseAcdcDf(rp) else: acdc_df_frame_i = parent_objs_acdc_df.loc[frame_i] cols = acdc_df_frame_i.columns.intersection(all_non_metrics_cols) acdc_df_frame_i = acdc_df_frame_i[cols] - + dfs.append(acdc_df_frame_i) keys.append(frame_i) - third_segm_acdc_df = pd.concat( - dfs, keys=keys, names=['frame_i', 'Cell_ID'] - ) + third_segm_acdc_df = pd.concat(dfs, keys=keys, names=["frame_i", "Cell_ID"]) return third_segm_acdc_df - + + def track_sub_cell_objects_acdc_df( - tracked_subobj_segm_data, subobj_acdc_df, all_old_sub_ids, - all_num_objects_per_cells, SizeT=None, sigProgress=None, - tracked_cells_segm_data=None, cells_acdc_df=None - ): + tracked_subobj_segm_data, + subobj_acdc_df, + all_old_sub_ids, + all_num_objects_per_cells, + SizeT=None, + sigProgress=None, + tracked_cells_segm_data=None, + cells_acdc_df=None, +): if SizeT == 1: tracked_subobj_segm_data = tracked_subobj_segm_data[np.newaxis] if tracked_cells_segm_data is not None: tracked_cells_segm_data = tracked_cells_segm_data[np.newaxis] - + if tracked_cells_segm_data is not None: acdc_df_list = [] sub_acdc_df_list = [] @@ -398,22 +410,20 @@ def track_sub_cell_objects_acdc_df( sub_ids = [sub_obj.label for sub_obj in rp_sub] old_sub_ids = all_old_sub_ids[frame_i] if subobj_acdc_df is None: - sub_acdc_df_frame_i = myutils.getBaseAcdcDf(rp_sub) + sub_acdc_df_frame_i = utils.getBaseAcdcDf(rp_sub) elif frame_i not in subobj_acdc_df.index.get_level_values(0): - sub_acdc_df_frame_i = myutils.getBaseAcdcDf(rp_sub) + sub_acdc_df_frame_i = utils.getBaseAcdcDf(rp_sub) else: - sub_acdc_df_frame_i = ( - subobj_acdc_df.loc[frame_i].rename(index=old_sub_ids) - ) - if 'relative_ID' in sub_acdc_df_frame_i.columns: - sub_acdc_df_frame_i['relative_ID'] = ( - sub_acdc_df_frame_i['relative_ID'].replace(old_sub_ids) - ) - + sub_acdc_df_frame_i = subobj_acdc_df.loc[frame_i].rename(index=old_sub_ids) + if "relative_ID" in sub_acdc_df_frame_i.columns: + sub_acdc_df_frame_i["relative_ID"] = sub_acdc_df_frame_i[ + "relative_ID" + ].replace(old_sub_ids) + cols = sub_acdc_df_frame_i.columns.intersection(all_non_metrics_cols) sub_acdc_df_list.append(sub_acdc_df_frame_i.loc[sub_ids, cols]) keys_sub.append(frame_i) - + if tracked_cells_segm_data is not None: num_objects_per_cells = all_num_objects_per_cells[frame_i] lab = tracked_cells_segm_data[frame_i] @@ -422,100 +432,101 @@ def track_sub_cell_objects_acdc_df( # --> check with `IDs_with_sub_obj = ... if id in lab` IDs_with_sub_obj = [id for id in sub_ids if id in lab] if cells_acdc_df is None: - acdc_df_frame_i = myutils.getBaseAcdcDf(rp) + acdc_df_frame_i = utils.getBaseAcdcDf(rp) else: acdc_df_frame_i = cells_acdc_df.loc[[frame_i]].copy() - + cols = acdc_df_frame_i.columns.intersection(all_non_metrics_cols) acdc_df_frame_i = acdc_df_frame_i[cols] - - acdc_df_frame_i['num_sub_cell_objs_per_cell'] = 0 - acdc_df_frame_i.loc[IDs_with_sub_obj, 'num_sub_cell_objs_per_cell'] = ([ + + acdc_df_frame_i["num_sub_cell_objs_per_cell"] = 0 + acdc_df_frame_i.loc[IDs_with_sub_obj, "num_sub_cell_objs_per_cell"] = [ num_objects_per_cells[id] for id in IDs_with_sub_obj - ]) + ] acdc_df_list.append(acdc_df_frame_i) keys_cells.append(frame_i) if sigProgress is not None: sigProgress.emit(1) - + sub_tracked_acdc_df = pd.concat( - sub_acdc_df_list, keys=keys_sub, names=['frame_i', 'Cell_ID'] + sub_acdc_df_list, keys=keys_sub, names=["frame_i", "Cell_ID"] ) - + tracked_acdc_df = None if tracked_cells_segm_data is not None: tracked_acdc_df = pd.concat( - acdc_df_list, keys=keys_cells, names=['frame_i', 'Cell_ID'] + acdc_df_list, keys=keys_cells, names=["frame_i", "Cell_ID"] ) - + return sub_tracked_acdc_df, tracked_acdc_df - + + def track_sub_cell_objects( - cells_segm_data, - subobj_segm_data, - IoAthresh, - how: Literal[ - 'delete_sub', 'delete_cells', 'delete_both', 'only_track' - ]='delete_sub', - SizeT: int | None =None, - sigProgress=None, - relabel_sub_obj_lab=False - ): - """Function used to track sub-cellular objects and assign the same ID of - the cell they belong to. - - For each sub-cellular object calculate the interesection over area with cells - --> get max IoA in case it is touching more than one cell + cells_segm_data, + subobj_segm_data, + IoAthresh, + how: Literal[ + "delete_sub", "delete_cells", "delete_both", "only_track" + ] = "delete_sub", + SizeT: int | None = None, + sigProgress=None, + relabel_sub_obj_lab=False, +): + """Function used to track sub-cellular objects and assign the same ID of + the cell they belong to. + + For each sub-cellular object calculate the interesection over area with cells + --> get max IoA in case it is touching more than one cell --> assign that cell if IoA >= IoA thresh Args: - cells_segm_data (ndarray): 2D, 3D or 4D array of `int` type cotaining + cells_segm_data (ndarray): 2D, 3D or 4D array of `int` type cotaining the cells segmentation masks. - subobj_segm_data (ndarray): 2D, 3D or 4D array of `int` type cotaining + subobj_segm_data (ndarray): 2D, 3D or 4D array of `int` type cotaining the sub-cellular segmentation masks (e.g., nuclei). - IoAthresh (float): Minimum percentage (0-1) of the sub-cellular object's + IoAthresh (float): Minimum percentage (0-1) of the sub-cellular object's area to assign it to a cell - how (str, optional): Strategy to take with untracked objects. - Options are 'delete_sub' to delete untracked sub-cellular objects, - 'delete_cells' to delete cells that do not have any sub-cellular - object assigned to it, 'delete_both', and 'only_track' to keep - untracked objects. Note that 'delete_sub' is actually not used - because we add tracked sub-objects to an array initialized with + how (str, optional): Strategy to take with untracked objects. + Options are 'delete_sub' to delete untracked sub-cellular objects, + 'delete_cells' to delete cells that do not have any sub-cellular + object assigned to it, 'delete_both', and 'only_track' to keep + untracked objects. Note that 'delete_sub' is actually not used + because we add tracked sub-objects to an array initialized with zeros. Defaults to 'delete_sub'. SizeT (int, optional): Number of frames. Pass `SizeT=1` for non-timelapse data. Defaults to None --> assume first dimension of segm data is SizeT. - sigProgress (qtpy.QtCore.Signal, optional): If provided it will emit - 1 for each complete frame. Used to update GUI progress bars. + sigProgress (qtpy.QtCore.Signal, optional): If provided it will emit + 1 for each complete frame. Used to update GUI progress bars. Defaults to None --> do not emit signal. - - relabel_sub_obj_lab (bool, optional): Re-label sub-cellular objects + + relabel_sub_obj_lab (bool, optional): Re-label sub-cellular objects segmentation labels before tracking them. - + Returns: - tuple: A tuple `(tracked_subobj_segm_data, tracked_cells_segm_data, - all_num_objects_per_cells, old_sub_ids)` where `tracked_subobj_segm_data` is the - segmentation mask of the sub-cellular objects with the same IDs of - the cells they belong to, `tracked_cells_segm_data` is the segmentation - masks of the cells that do have at least on sub-cellular object - (`None` if `how != 'delete_sub'`), `all_num_objects_per_cells` is - a list of dictionary (one per frame) where the dictionaries have + tuple: A tuple `(tracked_subobj_segm_data, tracked_cells_segm_data, + all_num_objects_per_cells, old_sub_ids)` where `tracked_subobj_segm_data` is the + segmentation mask of the sub-cellular objects with the same IDs of + the cells they belong to, `tracked_cells_segm_data` is the segmentation + masks of the cells that do have at least on sub-cellular object + (`None` if `how != 'delete_sub'`), `all_num_objects_per_cells` is + a list of dictionary (one per frame) where the dictionaries have cell IDs as keys and the number of sub-cellular objects per cell as values, and `all_old_sub_ids` is a list of dictionaries (one per frame) - where each dictionary has the new sub-cellular objects' ids as keys and + where each dictionary has the new sub-cellular objects' ids as keys and the old (replaced) ids. - """ + """ if SizeT == 1: cells_segm_data = cells_segm_data[np.newaxis] subobj_segm_data = subobj_segm_data[np.newaxis] tracked_cells_segm_data = None - tracked_subobj_segm_data = np.zeros_like(subobj_segm_data) + tracked_subobj_segm_data = np.zeros_like(subobj_segm_data) segm_data_zip = zip(cells_segm_data, subobj_segm_data) old_tracked_sub_obj_IDs = set() @@ -525,7 +536,7 @@ def track_sub_cell_objects( all_old_sub_ids = [{} for _ in range(len(cells_segm_data))] for frame_i, (lab, lab_sub) in enumerate(segm_data_zip): rp = skimage.measure.regionprops(lab) - num_objects_per_cells = {obj.label:0 for obj in rp} + num_objects_per_cells = {obj.label: 0 for obj in rp} if relabel_sub_obj_lab: lab_sub = skimage.measure.label(lab_sub) rp_sub = skimage.measure.regionprops(lab_sub) @@ -535,41 +546,39 @@ def track_sub_cell_objects( untracked_sub_objs_frame_i = set() for sub_obj in rp_sub: intersect_mask = lab[sub_obj.slice][sub_obj.image] - intersect_IDs, intersections = np.unique( - intersect_mask, return_counts=True - ) + intersect_IDs, intersections = np.unique(intersect_mask, return_counts=True) if intersect_IDs[0] == 0: intersect_IDs = intersect_IDs[1:] intersections = intersections[1:] - + if len(intersect_IDs) == 0: untracked_sub_objs_frame_i.add(sub_obj.label) continue - + argmax = intersections.argmax() intersect_ID = intersect_IDs[argmax] intersection = intersections[argmax] - - IoA = intersection/sub_obj.area + + IoA = intersection / sub_obj.area if IoA < IoAthresh: # Do not add untracked sub-obj untracked_sub_objs_frame_i.add(sub_obj.label) continue - + all_old_sub_ids[frame_i][sub_obj.label] = intersect_ID tracked_lab_sub[sub_obj.slice][sub_obj.image] = intersect_ID num_objects_per_cells[intersect_ID] += 1 old_tracked_sub_obj_IDs.add(sub_obj.label) cells_IDs_with_sub_obj.append(intersect_ID) tracked_sub_obj_original_IDs.append(sub_obj.label) - + all_num_objects_per_cells.append(num_objects_per_cells) all_cells_IDs_with_sub_obj.append(cells_IDs_with_sub_obj) - + if sigProgress is not None: sigProgress.emit(1) - - if how == 'delete_both' or how == 'delete_cells': + + if how == "delete_both" or how == "delete_cells": # Delete cells that do not have a sub-cellular object tracked_cells_segm_data = cells_segm_data.copy() for frame_i, lab in enumerate(tracked_cells_segm_data): @@ -580,11 +589,11 @@ def track_sub_cell_objects( if obj.label in cells_IDs_with_sub_obj: # Cell has sub-object do not delete continue - + tracked_lab[obj.slice][obj.image] = 0 - - if how == 'only_track' or how == 'delete_cells': - # Assign unique IDs to untracked sub-cellular objects and add them + + if how == "only_track" or how == "delete_cells": + # Assign unique IDs to untracked sub-cellular objects and add them # to all_old_sub_ids maxSubObjID = tracked_subobj_segm_data.max() + 1 for sub_obj_ID in np.unique(subobj_segm_data): @@ -594,9 +603,9 @@ def track_sub_cell_objects( if sub_obj_ID in old_tracked_sub_obj_IDs: # sub_obj_ID has already ben tracked continue - + tracked_subobj_segm_data[subobj_segm_data == sub_obj_ID] = maxSubObjID - + for frame_i, lab_sub in enumerate(subobj_segm_data): if sub_obj_ID not in lab_sub: continue @@ -605,59 +614,62 @@ def track_sub_cell_objects( if SizeT == 1: tracked_subobj_segm_data = tracked_subobj_segm_data[0] - if how == 'delete_both': + if how == "delete_both": tracked_cells_segm_data = tracked_cells_segm_data[0] - + return ( - tracked_subobj_segm_data, tracked_cells_segm_data, - all_num_objects_per_cells, all_old_sub_ids + tracked_subobj_segm_data, + tracked_cells_segm_data, + all_num_objects_per_cells, + all_old_sub_ids, ) + def _calc_airy_radius(wavelen, NA): - airy_radius_nm = (1.22 * wavelen)/(2*NA) - airy_radius_um = airy_radius_nm*1E-3 #convert nm to µm + airy_radius_nm = (1.22 * wavelen) / (2 * NA) + airy_radius_um = airy_radius_nm * 1e-3 # convert nm to µm return airy_radius_nm, airy_radius_um + def calc_resolution_limited_vol( - wavelen, NA, yx_resolution_multi, zyx_vox_dim, z_resolution_limit - ): + wavelen, NA, yx_resolution_multi, zyx_vox_dim, z_resolution_limit +): airy_radius_nm, airy_radius_um = _calc_airy_radius(wavelen, NA) - yx_resolution = airy_radius_um*yx_resolution_multi - zyx_resolution = np.asarray( - [z_resolution_limit, yx_resolution, yx_resolution] - ) - zyx_resolution_pxl = zyx_resolution/np.asarray(zyx_vox_dim) + yx_resolution = airy_radius_um * yx_resolution_multi + zyx_resolution = np.asarray([z_resolution_limit, yx_resolution, yx_resolution]) + zyx_resolution_pxl = zyx_resolution / np.asarray(zyx_vox_dim) return zyx_resolution, zyx_resolution_pxl, airy_radius_nm + def align_frames_3D(data, slices=None, user_shifts=None, sigPyqt=None): - registered_shifts = np.zeros((len(data),2), int) + registered_shifts = np.zeros((len(data), 2), int) data_aligned = np.copy(data) for frame_i, frame_V in enumerate(data): if frame_i == 0: # skip first frame - continue + continue if user_shifts is None: slice = slices[frame_i] curr_frame_img = frame_V[slice] - prev_frame_img = data_aligned[frame_i-1, slice] + prev_frame_img = data_aligned[frame_i - 1, slice] shifts = skimage.registration.phase_cross_correlation( prev_frame_img, curr_frame_img - )[0] + )[0] else: shifts = user_shifts[frame_i] - + shifts = shifts.astype(int) aligned_frame_V = np.copy(frame_V) - aligned_frame_V = np.roll(aligned_frame_V, tuple(shifts), axis=(1,2)) + aligned_frame_V = np.roll(aligned_frame_V, tuple(shifts), axis=(1, 2)) # Pad rolled sides with 0s y, x = shifts - if y>0: + if y > 0: aligned_frame_V[:, :y] = 0 - elif y<0: + elif y < 0: aligned_frame_V[:, y:] = 0 - if x>0: + if x > 0: aligned_frame_V[:, :, :x] = 0 - elif x<0: + elif x < 0: aligned_frame_V[:, :, x:] = 0 data_aligned[frame_i] = aligned_frame_V registered_shifts[frame_i] = shifts @@ -669,6 +681,7 @@ def align_frames_3D(data, slices=None, user_shifts=None, sigPyqt=None): # plt.show() return data_aligned, registered_shifts + def revert_alignment(saved_shifts, img_data, sigPyqt=None): shifts = -saved_shifts reverted_data = np.zeros_like(img_data) @@ -683,36 +696,36 @@ def revert_alignment(saved_shifts, img_data, sigPyqt=None): sigPyqt.emit(1) return reverted_data + def align_frames_2D( - data, slices=None, register=True, user_shifts=None, pbar=False, - sigPyqt=None - ): - registered_shifts = np.zeros((len(data),2), int) + data, slices=None, register=True, user_shifts=None, pbar=False, sigPyqt=None +): + registered_shifts = np.zeros((len(data), 2), int) data_aligned = np.copy(data) for frame_i, frame_V in enumerate(tqdm(data, ncols=100)): if frame_i == 0: # skip first frame - continue - + continue + curr_frame_img = frame_V - prev_frame_img = data_aligned[frame_i-1] #previously aligned frame, slice + prev_frame_img = data_aligned[frame_i - 1] # previously aligned frame, slice if user_shifts is None: shifts = skimage.registration.phase_cross_correlation( prev_frame_img, curr_frame_img - )[0] + )[0] else: shifts = user_shifts[frame_i] shifts = shifts.astype(int) aligned_frame_V = np.copy(frame_V) - aligned_frame_V = np.roll(aligned_frame_V, tuple(shifts), axis=(0,1)) + aligned_frame_V = np.roll(aligned_frame_V, tuple(shifts), axis=(0, 1)) y, x = shifts - if y>0: + if y > 0: aligned_frame_V[:y] = 0 - elif y<0: + elif y < 0: aligned_frame_V[y:] = 0 - if x>0: + if x > 0: aligned_frame_V[:, :x] = 0 - elif x<0: + elif x < 0: aligned_frame_V[:, x:] = 0 data_aligned[frame_i] = aligned_frame_V registered_shifts[frame_i] = shifts @@ -724,6 +737,7 @@ def align_frames_2D( # plt.show() return data_aligned, registered_shifts + def label_3d_segm(labels): """Label objects in 3D array that is the result of applying 2D segmentation model on each z-slice. @@ -747,38 +761,37 @@ def label_3d_segm(labels): return labels + def get_obj_contours( - obj=None, - obj_image=None, - obj_bbox=None, - all_external=False, - all=False, - only_longest_contour=True, - local=False, - ): + obj=None, + obj_image=None, + obj_bbox=None, + all_external=False, + all=False, + only_longest_contour=True, + local=False, +): if all: retrieveMode = cv2.RETR_CCOMP else: retrieveMode = cv2.RETR_EXTERNAL - + if obj_image is None: obj_image = obj.image - + obj_image = obj_image.astype(np.uint8) - + if obj_bbox is None and not local: obj_bbox = obj.bbox - - contours, _ = cv2.findContours( - obj_image, retrieveMode, cv2.CHAIN_APPROX_NONE - ) + + contours, _ = cv2.findContours(obj_image, retrieveMode, cv2.CHAIN_APPROX_NONE) if all or all_external: if local: return [np.squeeze(cont, axis=1) for cont in contours] else: min_y, min_x, _, _ = obj_bbox - return [np.squeeze(cont, axis=1)+[min_x, min_y] for cont in contours] - + return [np.squeeze(cont, axis=1) + [min_x, min_y] for cont in contours] + if len(contours) > 1 and only_longest_contour: contours_len = [len(c) for c in contours] max_len_idx = contours_len.index(max(contours_len)) @@ -792,25 +805,29 @@ def get_obj_contours( contour += [min_x, min_y] return contour + def smooth_contours(lab, radius=2): - sigma = 2*radius + 1 + sigma = 2 * radius + 1 smooth_lab = np.zeros_like(lab) for obj in skimage.measure.regionprops(lab): cont = get_obj_contours(obj) - x = cont[:,0] - y = cont[:,1] + x = cont[:, 0] + y = cont[:, 1] x = np.append(x, x[0:sigma]) y = np.append(y, y[0:sigma]) - x = np.round(skimage.filters.gaussian(x, sigma=sigma, - preserve_range=True)).astype(int) - y = np.round(skimage.filters.gaussian(y, sigma=sigma, - preserve_range=True)).astype(int) + x = np.round( + skimage.filters.gaussian(x, sigma=sigma, preserve_range=True) + ).astype(int) + y = np.round( + skimage.filters.gaussian(y, sigma=sigma, preserve_range=True) + ).astype(int) temp_mask = np.zeros(lab.shape, bool) temp_mask[y, x] = True temp_mask = scipy.ndimage.morphology.binary_fill_holes(temp_mask) smooth_lab[temp_mask] = obj.label return smooth_lab + def get_labels_to_IDs_mapper(tracked_labels): labels_to_IDs_mapper = {} uniqueID = 1 @@ -819,40 +836,38 @@ def get_labels_to_IDs_mapper(tracked_labels): if tracked_label in labels_to_IDs_mapper: # Cell existed in the past, ID already stored continue - - parent_label, _, sister_label = tracked_label.rpartition('_') + + parent_label, _, sister_label = tracked_label.rpartition("_") if not parent_label: # Single-cell that was not mapped yet ID = uniqueID uniqueID += 1 - elif sister_label == '0': + elif sister_label == "0": # Sister label == 0 --> keep mother ID - ID = labels_to_IDs_mapper[parent_label].split('_')[0] + ID = labels_to_IDs_mapper[parent_label].split("_")[0] elif ( - sister_label == '1' - and f'{parent_label}_0' not in tracked_frame_labels - ): + sister_label == "1" and f"{parent_label}_0" not in tracked_frame_labels + ): # Daughter cell without a sister --> keep mother ID - ID = labels_to_IDs_mapper[parent_label].split('_')[0] + ID = labels_to_IDs_mapper[parent_label].split("_")[0] else: # Sister label == 1 --> assign new ID ID = uniqueID uniqueID += 1 - labels_to_IDs_mapper[tracked_label] = f'{ID}_{frame_i}' + labels_to_IDs_mapper[tracked_label] = f"{ID}_{frame_i}" return labels_to_IDs_mapper + def annotate_lineage_tree_from_labels(tracked_labels, labels_to_IDs_mapper): - IDs_to_labels_mapper = { - ID:label for label, ID in labels_to_IDs_mapper.items() - } + IDs_to_labels_mapper = {ID: label for label, ID in labels_to_IDs_mapper.items()} cca_dfs = [] keys = [] pbar = tqdm(total=len(tracked_labels), ncols=100) for frame_i, tracked_frame_labels in enumerate(tracked_labels): keys.append(frame_i) IDs = [ - int(labels_to_IDs_mapper[label].split('_')[0]) + int(labels_to_IDs_mapper[label].split("_")[0]) for label in tracked_frame_labels ] if frame_i == 0: @@ -860,9 +875,9 @@ def annotate_lineage_tree_from_labels(tracked_labels, labels_to_IDs_mapper): cca_dfs.append(cca_df) pbar.update() continue - + # Get cca_df from previous frame for existing cells - cca_df = cca_dfs[frame_i-1] + cca_df = cca_dfs[frame_i - 1] is_in_index = cca_df.index.isin(IDs) cca_df = cca_df[is_in_index] new_cells_cca_dfs = [] @@ -870,86 +885,95 @@ def annotate_lineage_tree_from_labels(tracked_labels, labels_to_IDs_mapper): for ID in IDs: if ID in cca_df.index: continue - + newID = ID # New cell --> store cca info - label = IDs_to_labels_mapper[f'{newID}_{frame_i}'] - parent_label, _, sister_label = label.rpartition('_') + label = IDs_to_labels_mapper[f"{newID}_{frame_i}"] + parent_label, _, sister_label = label.rpartition("_") if not parent_label: # New single-cell --> check if it existed in past frames - for i in range(frame_i-2, -1, -1): - past_cca_df = cca_dfs[frame_i-1] + for i in range(frame_i - 2, -1, -1): + past_cca_df = cca_dfs[frame_i - 1] if newID in past_cca_df.index: cca_df_single_ID = past_cca_df.loc[[newID]] break else: cca_df_single_ID = getBaseCca_df([newID]) - cca_df_single_ID.loc[newID, 'emerg_frame_i'] = frame_i + cca_df_single_ID.loc[newID, "emerg_frame_i"] = frame_i else: # New cell resulting from division --> store division - mothID = int(labels_to_IDs_mapper[parent_label].split('_')[0]) + mothID = int(labels_to_IDs_mapper[parent_label].split("_")[0]) cca_df_single_ID = getBaseCca_df([newID]) try: - cca_df.at[mothID, 'generation_num'] += 1 + cca_df.at[mothID, "generation_num"] += 1 except Exception as e: - import pdb; pdb.set_trace() - cca_df.at[mothID, 'division_frame_i'] = frame_i - cca_df.at[mothID, 'relative_ID'] = newID - cca_df_single_ID.at[newID, 'emerg_frame_i'] = frame_i - cca_df_single_ID.at[newID, 'division_frame_i'] = frame_i - cca_df_single_ID.at[newID, 'generation_num'] = 1 - cca_df_single_ID.at[newID, 'relative_ID'] = mothID + import pdb + + pdb.set_trace() + cca_df.at[mothID, "division_frame_i"] = frame_i + cca_df.at[mothID, "relative_ID"] = newID + cca_df_single_ID.at[newID, "emerg_frame_i"] = frame_i + cca_df_single_ID.at[newID, "division_frame_i"] = frame_i + cca_df_single_ID.at[newID, "generation_num"] = 1 + cca_df_single_ID.at[newID, "relative_ID"] = mothID new_cells_cca_dfs.append(cca_df_single_ID) - + cca_df = pd.concat([cca_df, *new_cells_cca_dfs]).sort_index() cca_dfs.append(cca_df) pbar.update() pbar.close() return cca_dfs + def getBaseCca_df(IDs, with_tree_cols=False): row_data = base_cca_dict if with_tree_cols: row_data = {**base_cca_dict, **base_cca_tree_dict} - data = [row_data]*len(IDs) - cca_df = pd.DataFrame(data, index=IDs) + data = [row_data] * len(IDs) + cca_df = pd.DataFrame(data, index=IDs) if with_tree_cols: - cca_df['Cell_ID_tree'] = IDs - cca_df.index.name = 'Cell_ID' + cca_df["Cell_ID_tree"] = IDs + cca_df.index.name = "Cell_ID" return cca_df -def apply_tracking_from_table( - segmData, trackColsInfo, src_df, signal=None, logger=print, - pbarMax=None, debug=False - ): - frameIndexCol = trackColsInfo['frameIndexCol'] - if trackColsInfo['isFirstFrameOne']: +def apply_tracking_from_table( + segmData, + trackColsInfo, + src_df, + signal=None, + logger=print, + pbarMax=None, + debug=False, +): + frameIndexCol = trackColsInfo["frameIndexCol"] + + if trackColsInfo["isFirstFrameOne"]: # Zeroize frames since first frame starts at 1 src_df[frameIndexCol] = src_df[frameIndexCol] - 1 - logger('Applying tracking info...') + logger("Applying tracking info...") grouped = src_df.groupby(frameIndexCol) iterable = grouped if signal is not None else tqdm(grouped, ncols=100) - trackIDsCol = trackColsInfo['trackIDsCol'] - maskIDsCol = trackColsInfo['maskIDsCol'] - xCentroidCol = trackColsInfo['xCentroidCol'] - yCentroidCol = trackColsInfo['yCentroidCol'] - deleteUntrackedIDs = trackColsInfo['deleteUntrackedIDs'] + trackIDsCol = trackColsInfo["trackIDsCol"] + maskIDsCol = trackColsInfo["maskIDsCol"] + xCentroidCol = trackColsInfo["xCentroidCol"] + yCentroidCol = trackColsInfo["yCentroidCol"] + deleteUntrackedIDs = trackColsInfo["deleteUntrackedIDs"] trackedIDsMapper = {} deleteIDsMapper = {} for frame_i, df_frame in iterable: if frame_i == len(segmData): - print('') + print("") logger( - '[WARNING]: segmentation data has less frames than the ' + "[WARNING]: segmentation data has less frames than the " f'frames in the "{frameIndexCol}" column.' ) if signal is not None and pbarMax is not None: - signal.emit(pbarMax-frame_i) + signal.emit(pbarMax - frame_i) break lab = segmData[frame_i] @@ -968,7 +992,7 @@ def apply_tracking_from_table( deleteIDs = [] if deleteUntrackedIDs: - if xCentroidCol == 'None': + if xCentroidCol == "None": maskIDsTracked = df_frame[maskIDsCol].dropna().apply(round).values else: xx = df_frame[xCentroidCol].dropna().apply(round).values @@ -987,16 +1011,16 @@ def apply_tracking_from_table( # First iterate IDs and make sure there are no overlapping IDs for row in df_frame.itertuples(): trackedID = getattr(row, trackIDsCol) - if xCentroidCol == 'None': + if xCentroidCol == "None": maskID = getattr(row, maskIDsCol) else: xc = getattr(row, xCentroidCol) yc = getattr(row, yCentroidCol) maskID = lab[round(yc), round(xc)] - + if not maskID > 0: continue - + if maskID == trackedID: continue @@ -1011,56 +1035,55 @@ def apply_tracking_from_table( if uniqueID in trackIDs: uniqueID = maxTrackID + 1 maxTrackID += 1 - - lab[lab==trackedID] = uniqueID + + lab[lab == trackedID] = uniqueID firstPassMapper_i[int(trackedID)] = int(uniqueID) - if xCentroidCol == 'None': - mask = df_frame[maskIDsCol]==trackedID + if xCentroidCol == "None": + mask = df_frame[maskIDsCol] == trackedID df_frame.loc[mask, maskIDsCol] = int(uniqueID) - + # print(f'First = {int(trackedID)} --> {int(uniqueID)}') if firstPassMapper_i: - trackedIDsMapper[str(frame_i)] = {'first_pass': firstPassMapper_i} + trackedIDsMapper[str(frame_i)] = {"first_pass": firstPassMapper_i} secondPassMapper_i = {} for row in df_frame.itertuples(): trackedID = getattr(row, trackIDsCol) - if xCentroidCol == 'None': + if xCentroidCol == "None": maskID = getattr(row, maskIDsCol) else: xc = getattr(row, xCentroidCol) yc = getattr(row, yCentroidCol) maskID = lab[round(yc), round(xc)] - + if not maskID > 0: continue - + if maskID == trackedID: continue - lab[lab==maskID] = trackedID - secondPassMapper_i[int(maskID)] = int(trackedID) + lab[lab == maskID] = trackedID + secondPassMapper_i[int(maskID)] = int(trackedID) + + # print(f'Second = {int(maskID)} --> {int(trackedID)}') - # print(f'Second = {int(maskID)} --> {int(trackedID)}') - if secondPassMapper_i: if firstPassMapper_i: - trackedIDsMapper[str(frame_i)]['second_pass'] = secondPassMapper_i + trackedIDsMapper[str(frame_i)]["second_pass"] = secondPassMapper_i else: - trackedIDsMapper[str(frame_i)] = {'second_pass': secondPassMapper_i} + trackedIDsMapper[str(frame_i)] = {"second_pass": secondPassMapper_i} if signal is not None: signal.emit(1) # print('*'*40) # import pdb; pdb.set_trace() - + return segmData, trackedIDsMapper, deleteIDsMapper -def apply_trackedIDs_mapper_to_acdc_df( - tracked_IDs_mapper, deleted_IDs_mapper, acdc_df - ): + +def apply_trackedIDs_mapper_to_acdc_df(tracked_IDs_mapper, deleted_IDs_mapper, acdc_df): acdc_dfs_renamed = [] for frame_i, acdc_df_i in acdc_df.groupby(level=0): df_renamed = acdc_df_i @@ -1073,81 +1096,100 @@ def apply_trackedIDs_mapper_to_acdc_df( if mapper_i is None: acdc_dfs_renamed.append(df_renamed) continue - - first_pass = mapper_i.get('first_pass') + + first_pass = mapper_i.get("first_pass") if first_pass is not None: - first_pass = {int(k):int(v) for k,v in first_pass.items()} + first_pass = {int(k): int(v) for k, v in first_pass.items()} # Substitute mask IDs with tracked IDs df_renamed = df_renamed.rename(index=first_pass, level=1) - if 'relative_ID' in df_renamed.columns: - relIDs = df_renamed['relative_ID'] - df_renamed['relative_ID'] = relIDs.replace(tracked_IDs_mapper) - - second_pass = mapper_i.get('second_pass') + if "relative_ID" in df_renamed.columns: + relIDs = df_renamed["relative_ID"] + df_renamed["relative_ID"] = relIDs.replace(tracked_IDs_mapper) + + second_pass = mapper_i.get("second_pass") if second_pass is not None: - second_pass = {int(k):int(v) for k,v in second_pass.items()} + second_pass = {int(k): int(v) for k, v in second_pass.items()} # Substitute mask IDs with tracked IDs df_renamed = df_renamed.rename(index=second_pass, level=1) - if 'relative_ID' in df_renamed.columns: - relIDs = df_renamed['relative_ID'] - df_renamed['relative_ID'] = relIDs.replace(tracked_IDs_mapper) - + if "relative_ID" in df_renamed.columns: + relIDs = df_renamed["relative_ID"] + df_renamed["relative_ID"] = relIDs.replace(tracked_IDs_mapper) + acdc_dfs_renamed.append(df_renamed) - + acdc_df = pd.concat(acdc_dfs_renamed).sort_index() return acdc_df + def _get_cca_info_warn_text( - newID, parentID, frame_i, maskID_colname, x_colname, y_colname, - df_frame, src_df, frame_idx_colname, trackID_colname - ): + newID, + parentID, + frame_i, + maskID_colname, + x_colname, + y_colname, + df_frame, + src_df, + frame_idx_colname, + trackID_colname, +): txt = ( - f'\n[WARNING]: The parent ID of {newID} at frame index ' - f'{frame_i} is {parentID}, but this parent {parentID} ' - f'does not exist at previous frame {frame_i-1} -->\n' - f' --> Setting ID {newID} as a new cell without a parent.\n\n' - 'More details:\n' + f"\n[WARNING]: The parent ID of {newID} at frame index " + f"{frame_i} is {parentID}, but this parent {parentID} " + f"does not exist at previous frame {frame_i - 1} -->\n" + f" --> Setting ID {newID} as a new cell without a parent.\n\n" + "More details:\n" ) try: - df_prev_frame = src_df[src_df[frame_idx_colname] == frame_i-1] + df_prev_frame = src_df[src_df[frame_idx_colname] == frame_i - 1] df_prev_frame = df_prev_frame.set_index(trackID_colname) - if maskID_colname != 'None': + if maskID_colname != "None": maskID_of_newID = df_frame.at[newID, maskID_colname] maskID_of_parentID = df_prev_frame.at[parentID, maskID_colname] details_txt = ( f' - "{maskID_colname}" of ID {newID} = {maskID_of_newID}\n' f' - "{maskID_colname}" of ID {parentID} = {maskID_of_parentID}\n' ) - txt = f'{txt}{details_txt}' - if x_colname != 'None': + txt = f"{txt}{details_txt}" + if x_colname != "None": xc_of_newID = df_frame.at[newID, x_colname] xc_of_parentID = df_prev_frame.at[parentID, x_colname] yc_of_newID = df_frame.at[newID, y_colname] yc_of_parentID = df_prev_frame.at[parentID, y_colname] details_txt = ( - f' - (x,y) coordinates of ID {newID} = {(xc_of_newID, yc_of_newID)}\n' - f' - (x,y) coordinates of ID {parentID} = {(xc_of_parentID, yc_of_parentID)}\n' + f" - (x,y) coordinates of ID {newID} = {(xc_of_newID, yc_of_newID)}\n" + f" - (x,y) coordinates of ID {parentID} = {(xc_of_parentID, yc_of_parentID)}\n" ) - txt = f'{txt}{details_txt}' + txt = f"{txt}{details_txt}" except Exception as e: # import pdb; pdb.set_trace() pass return txt + def add_cca_info_from_parentID_col( - src_df, acdc_df, frame_idx_colname, IDs_colname, parentID_colname, - SizeT, signal=None, trackedData=None, logger=print, - maskID_colname='None', x_colname='None', y_colname='None' - ): + src_df, + acdc_df, + frame_idx_colname, + IDs_colname, + parentID_colname, + SizeT, + signal=None, + trackedData=None, + logger=print, + maskID_colname="None", + x_colname="None", + y_colname="None", +): grouped = src_df.groupby(frame_idx_colname) acdc_dfs = [] keys = [] - iterable = grouped if signal is not None else tqdm(grouped, ncols=100) + iterable = grouped if signal is not None else tqdm(grouped, ncols=100) for frame_i, df_frame in iterable: if frame_i == SizeT: break - + if trackedData is not None: lab = trackedData[frame_i] @@ -1160,13 +1202,13 @@ def add_cca_info_from_parentID_col( oldIDs = [] newIDs = IDs else: - prevIDs = acdc_df.loc[frame_i-1].index.values + prevIDs = acdc_df.loc[frame_i - 1].index.values newIDs = [ID for ID in IDs if ID not in prevIDs] oldIDs = [ID for ID in IDs if ID in prevIDs] - + if oldIDs: # For the oldIDs copy from previous cca_df - prev_acdc_df = acdc_dfs[frame_i-1].filter(oldIDs, axis=0) + prev_acdc_df = acdc_dfs[frame_i - 1].filter(oldIDs, axis=0) cca_df.loc[prev_acdc_df.index] = prev_acdc_df for newID in newIDs: @@ -1174,80 +1216,80 @@ def add_cca_info_from_parentID_col( parentID = int(df_frame.at[newID, parentID_colname]) except Exception as e: parentID = -1 - + parentGenNum = None if parentID > 1: - prev_acdc_df = acdc_dfs[frame_i-1] + prev_acdc_df = acdc_dfs[frame_i - 1] try: - parentGenNum = prev_acdc_df.at[parentID, 'generation_num'] + parentGenNum = prev_acdc_df.at[parentID, "generation_num"] except Exception as e: parentGenNum = None - logger('*'*40) + logger("*" * 40) warn_txt = _get_cca_info_warn_text( - newID, parentID, frame_i, maskID_colname, x_colname, - y_colname, df_frame, src_df, frame_idx_colname, - IDs_colname + newID, + parentID, + frame_i, + maskID_colname, + x_colname, + y_colname, + df_frame, + src_df, + frame_idx_colname, + IDs_colname, ) logger(warn_txt) - logger('*'*40) - + logger("*" * 40) + if parentGenNum is not None: - prentGenNumTree = ( - prev_acdc_df.at[parentID, 'generation_num_tree'] - ) - newGenNumTree = prentGenNumTree+1 - parentRootID = ( - prev_acdc_df.at[parentID, 'root_ID_tree'] - ) - cca_df.at[newID, 'is_history_known'] = True - cca_df.at[newID, 'cell_cycle_stage'] = 'G1' - cca_df.at[newID, 'generation_num'] = parentGenNum+1 - cca_df.at[newID, 'emerg_frame_i'] = frame_i - cca_df.at[newID, 'division_frame_i'] = frame_i - cca_df.at[newID, 'relationship'] = 'mother' - cca_df.at[newID, 'generation_num_tree'] = newGenNumTree - cca_df.at[newID, 'Cell_ID_tree'] = newID - cca_df.at[newID, 'root_ID_tree'] = parentRootID - cca_df.at[newID, 'parent_ID_tree'] = parentID + prentGenNumTree = prev_acdc_df.at[parentID, "generation_num_tree"] + newGenNumTree = prentGenNumTree + 1 + parentRootID = prev_acdc_df.at[parentID, "root_ID_tree"] + cca_df.at[newID, "is_history_known"] = True + cca_df.at[newID, "cell_cycle_stage"] = "G1" + cca_df.at[newID, "generation_num"] = parentGenNum + 1 + cca_df.at[newID, "emerg_frame_i"] = frame_i + cca_df.at[newID, "division_frame_i"] = frame_i + cca_df.at[newID, "relationship"] = "mother" + cca_df.at[newID, "generation_num_tree"] = newGenNumTree + cca_df.at[newID, "Cell_ID_tree"] = newID + cca_df.at[newID, "root_ID_tree"] = parentRootID + cca_df.at[newID, "parent_ID_tree"] = parentID # sister ID is the other cell with the same parent ID - sisterIDmask = ( - (df_frame[parentID_colname] == parentID) - & (df_frame.index != newID) + sisterIDmask = (df_frame[parentID_colname] == parentID) & ( + df_frame.index != newID ) sisterID_df = df_frame[sisterIDmask] if len(sisterID_df) == 1: sisterID = sisterID_df.index[0] else: sisterID = -1 - cca_df.at[newID, 'sister_ID_tree'] = sisterID + cca_df.at[newID, "sister_ID_tree"] = sisterID else: # Set new ID without a parent as history unknown - cca_df.at[newID, 'is_history_known'] = False - cca_df.at[newID, 'cell_cycle_stage'] = 'G1' - cca_df.at[newID, 'generation_num'] = 2 - cca_df.at[newID, 'emerg_frame_i'] = frame_i - cca_df.at[newID, 'division_frame_i'] = -1 - cca_df.at[newID, 'relationship'] = 'mother' - cca_df.at[newID, 'generation_num_tree'] = 1 - cca_df.at[newID, 'Cell_ID_tree'] = newID - cca_df.at[newID, 'root_ID_tree'] = newID - cca_df.at[newID, 'parent_ID_tree'] = -1 - cca_df.at[newID, 'sister_ID_tree'] = -1 - + cca_df.at[newID, "is_history_known"] = False + cca_df.at[newID, "cell_cycle_stage"] = "G1" + cca_df.at[newID, "generation_num"] = 2 + cca_df.at[newID, "emerg_frame_i"] = frame_i + cca_df.at[newID, "division_frame_i"] = -1 + cca_df.at[newID, "relationship"] = "mother" + cca_df.at[newID, "generation_num_tree"] = 1 + cca_df.at[newID, "Cell_ID_tree"] = newID + cca_df.at[newID, "root_ID_tree"] = newID + cca_df.at[newID, "parent_ID_tree"] = -1 + cca_df.at[newID, "sister_ID_tree"] = -1 + acdc_df_i[cca_df.columns] = cca_df acdc_dfs.append(acdc_df_i) keys.append(frame_i) if signal is not None: signal.emit(1) - + if acdc_dfs: - acdc_df_with_cca = pd.concat( - acdc_dfs, keys=keys, names=['frame_i', 'Cell_ID'] - ) + acdc_df_with_cca = pd.concat(acdc_dfs, keys=keys, names=["frame_i", "Cell_ID"]) if len(acdc_df_with_cca) == len(acdc_df): # All frames from existing acdc_df were cca annotated in src_table acdc_df_with_cca = pd.concat( - acdc_dfs, keys=keys, names=['frame_i', 'Cell_ID'] + acdc_dfs, keys=keys, names=["frame_i", "Cell_ID"] ) return acdc_df_with_cca else: @@ -1258,10 +1300,10 @@ def add_cca_info_from_parentID_col( else: # No annotations present in src_table return acdc_df - - + return acdc_df + def cca_df_to_acdc_df(cca_df, rp, acdc_df=None): if acdc_df is None: IDs = [] @@ -1275,138 +1317,136 @@ def cca_df_to_acdc_df(cca_df, rp, acdc_df=None): is_cell_excluded_li.append(0) xx_centroid.append(int(obj.centroid[1])) yy_centroid.append(int(obj.centroid[0])) - acdc_df = pd.DataFrame({ - 'Cell_ID': IDs, - 'is_cell_dead': is_cell_dead_li, - 'is_cell_excluded': is_cell_excluded_li, - 'x_centroid': xx_centroid, - 'y_centroid': yy_centroid, - 'was_manually_edited': is_cell_excluded_li.copy() - }).set_index('Cell_ID') - - acdc_df = acdc_df.join(cca_df, how='left') + acdc_df = pd.DataFrame( + { + "Cell_ID": IDs, + "is_cell_dead": is_cell_dead_li, + "is_cell_excluded": is_cell_excluded_li, + "x_centroid": xx_centroid, + "y_centroid": yy_centroid, + "was_manually_edited": is_cell_excluded_li.copy(), + } + ).set_index("Cell_ID") + + acdc_df = acdc_df.join(cca_df, how="left") return acdc_df + class LineageTree: def __init__(self, acdc_df, logging_func=print, debug=False) -> None: - acdc_df = load.pd_bool_and_float_to_int_to_str(acdc_df, colsToCastInt=[]).reset_index() + acdc_df = load.pd_bool_and_float_to_int_to_str( + acdc_df, colsToCastInt=[] + ).reset_index() acdc_df = self._normalize_gen_num(acdc_df).reset_index() - acdc_df = acdc_df.drop(columns=['index', 'level_0'], errors='ignore') - self.acdc_df = acdc_df.set_index(['frame_i', 'Cell_ID']) + acdc_df = acdc_df.drop(columns=["index", "level_0"], errors="ignore") + self.acdc_df = acdc_df.set_index(["frame_i", "Cell_ID"]) self.df = acdc_df.copy() self.cca_df_colnames = cca_df_colnames self.log = logging_func self.debug = debug - + def build(self): - self.log('Building lineage tree...') + self.log("Building lineage tree...") try: - df_G1 = self.acdc_df[self.acdc_df['cell_cycle_stage'] == 'G1'] + df_G1 = self.acdc_df[self.acdc_df["cell_cycle_stage"] == "G1"] self.df_G1 = df_G1[self.cca_df_colnames].copy() - self.new_col_loc = df_G1.columns.get_loc('division_frame_i') + 1 + self.new_col_loc = df_G1.columns.get_loc("division_frame_i") + 1 except Exception as error: return error - + self.df = self.add_lineage_tree_table_to_acdc_df() - self.log('Lineage tree built successfully!') - + self.log("Lineage tree built successfully!") + def _normalize_gen_num(self, acdc_df): - ''' + """ Since the user is allowed to start the generation_num of unknown mother cells with any number we need to normalise this to 2 --> Create a new 'normalized_gen_num' column where we make sure that mother cells with unknown history have 'normalized_gen_num' starting from 2 (required by the logic of _build_tree) - ''' - acdc_df = acdc_df.drop(columns=['level_0', 'index'], errors='ignore') - acdc_df = ( - acdc_df.reset_index() - .drop(columns='index', errors='ignore') - ) + """ + acdc_df = acdc_df.drop(columns=["level_0", "index"], errors="ignore") + acdc_df = acdc_df.reset_index().drop(columns="index", errors="ignore") # Get the starting generation number of the unknown mother cells - df_emerg = acdc_df.groupby('Cell_ID').agg('first') - history_unknown_mask = df_emerg['is_history_known'] == 0 - moth_mask = df_emerg['relationship'] == 'mother' + df_emerg = acdc_df.groupby("Cell_ID").agg("first") + history_unknown_mask = df_emerg["is_history_known"] == 0 + moth_mask = df_emerg["relationship"] == "mother" df_emerg_moth_uknown = df_emerg[(history_unknown_mask) & (moth_mask)] # Get the difference from 2 - df_diff = 2 - df_emerg_moth_uknown['generation_num'] + df_diff = 2 - df_emerg_moth_uknown["generation_num"] # Build a normalizing df with the number to be added for each cell - normalizing_df = pd.DataFrame( - data=acdc_df[['frame_i', 'Cell_ID']] - ).set_index('Cell_ID') - normalizing_df['gen_num_diff'] = 0 - normalizing_df.loc[df_emerg_moth_uknown.index, 'gen_num_diff'] = ( - df_diff + normalizing_df = pd.DataFrame(data=acdc_df[["frame_i", "Cell_ID"]]).set_index( + "Cell_ID" ) + normalizing_df["gen_num_diff"] = 0 + normalizing_df.loc[df_emerg_moth_uknown.index, "gen_num_diff"] = df_diff # Add the normalising_df to create the new normalized_gen_num col - normalizing_df = normalizing_df.reset_index().set_index( - ['frame_i', 'Cell_ID'] - ) - acdc_df = acdc_df.set_index(['frame_i', 'Cell_ID']) - acdc_df['normalized_gen_num'] = ( - acdc_df['generation_num'] + normalizing_df['gen_num_diff'] + normalizing_df = normalizing_df.reset_index().set_index(["frame_i", "Cell_ID"]) + acdc_df = acdc_df.set_index(["frame_i", "Cell_ID"]) + acdc_df["normalized_gen_num"] = ( + acdc_df["generation_num"] + normalizing_df["gen_num_diff"] ) return acdc_df - + def _build_tree(self, gen_df, ID): current_ID = gen_df.index.get_level_values(1)[0] if current_ID != ID: return gen_df - ''' + """ Add generation number tree: --> At the start of a branch we set the generation number as either 0 (if also start of tree) or relative ID generation number tree --> This value called gen_num_relID_tree is added to the current generation_num - ''' + """ ID_slice = pd.IndexSlice[:, ID] - relID = gen_df.loc[ID_slice, 'relative_ID'].iloc[0] + relID = gen_df.loc[ID_slice, "relative_ID"].iloc[0] relID_slice = pd.IndexSlice[:, relID] - gen_nums_tree = gen_df['generation_num_tree'].values + gen_nums_tree = gen_df["generation_num_tree"].values start_frame_i = gen_df.index.get_level_values(0)[0] if self.is_new_tree: - try: - gen_num_relID_tree = self.df_G1.at[ - (start_frame_i, relID), 'generation_num_tree' - ] - 1 + try: + gen_num_relID_tree = ( + self.df_G1.at[(start_frame_i, relID), "generation_num_tree"] - 1 + ) except Exception as e: gen_num_relID_tree = 0 self.branch_start_gen_num[ID] = gen_num_relID_tree else: gen_num_relID_tree = self.branch_start_gen_num[ID] - + updated_gen_nums_tree = gen_nums_tree + gen_num_relID_tree - gen_df['generation_num_tree'] = updated_gen_nums_tree - - '''Assign unique ID every consecutive division''' + gen_df["generation_num_tree"] = updated_gen_nums_tree + + """Assign unique ID every consecutive division""" if self.is_new_tree: # Keep start ID for cell at the top of the branch Cell_ID_tree = ID - gen_df['Cell_ID_tree'] = [ID]*len(gen_df) + gen_df["Cell_ID_tree"] = [ID] * len(gen_df) else: Cell_ID_tree = self.uniqueID self.uniqueID += 1 - - gen_df['Cell_ID_tree'] = [Cell_ID_tree]*len(gen_df) - ''' + gen_df["Cell_ID_tree"] = [Cell_ID_tree] * len(gen_df) + + """ Assign parent ID --> existing ID between relID and ID in prev gen_num_tree - ''' - gen_num_tree = gen_df.loc[ID_slice, 'generation_num_tree'].iloc[0] - + """ + gen_num_tree = gen_df.loc[ID_slice, "generation_num_tree"].iloc[0] + prev_gen_G1_existing = True - if gen_num_tree > 1: + if gen_num_tree > 1: prev_gen_num_tree = gen_num_tree - 1 try: # Parent ID is the Cell_ID_tree that current ID had in prev gen prev_gen_df = self.gen_dfs[(ID, prev_gen_num_tree)] except Exception as e: - # Parent ID is the Cell_ID_tree that the relative of the + # Parent ID is the Cell_ID_tree that the relative of the # current ID had in prev gen try: prev_gen_df = self.gen_dfs[(relID, prev_gen_num_tree)] @@ -1415,21 +1455,21 @@ def _build_tree(self, gen_df, ID): # starts at 2 (cell appeared in S and then started G1) prev_gen_G1_existing = False pass - + if prev_gen_G1_existing: try: - parent_ID = prev_gen_df.loc[relID_slice, 'Cell_ID_tree'].iloc[0] + parent_ID = prev_gen_df.loc[relID_slice, "Cell_ID_tree"].iloc[0] except Exception as e: - parent_ID = prev_gen_df.loc[ID_slice, 'Cell_ID_tree'].iloc[0] - gen_df['parent_ID_tree'] = parent_ID + parent_ID = prev_gen_df.loc[ID_slice, "Cell_ID_tree"].iloc[0] + gen_df["parent_ID_tree"] = parent_ID else: # Cell appeared in S in previous frame - idx = (start_frame_i-1, ID) - was_bud = self.acdc_df.loc[idx, 'relationship'] == 'bud' + idx = (start_frame_i - 1, ID) + was_bud = self.acdc_df.loc[idx, "relationship"] == "bud" if was_bud: - # This is a bud of the first frame where the algorithm + # This is a bud of the first frame where the algorithm # thinks is a new tree --> correct - parent_ID = self.acdc_df.loc[idx, 'relative_ID'] + parent_ID = self.acdc_df.loc[idx, "relative_ID"] try: self.branch_start_gen_num[ID] = ( self.branch_start_gen_num[parent_ID] + 2 @@ -1439,20 +1479,20 @@ def _build_tree(self, gen_df, ID): self.branch_start_gen_num[ID] = gen_num_parentID_tree else: parent_ID = ID - + Cell_ID_tree = self.uniqueID self.uniqueID += 1 - gen_df['Cell_ID_tree'] = [Cell_ID_tree]*len(gen_df) + gen_df["Cell_ID_tree"] = [Cell_ID_tree] * len(gen_df) else: parent_ID = -1 - - ''' + + """ Assign root ID --> at start of branch (self.is_new_tree is True) the root_ID is ID if gen_num_tree == 1 otherwise we go back until the parent_ID == -1 --> store this and use when traversing branch - ''' + """ if self.is_new_tree: if gen_num_tree == 2 and prev_gen_G1_existing: root_ID_tree = parent_ID @@ -1460,14 +1500,14 @@ def _build_tree(self, gen_df, ID): prev_gen_num_tree = gen_num_tree - 1 prev_gen_idx = parent_ID parent_ID_df = self.gen_dfs_by_ID_tree[prev_gen_idx] - root_ID_tree = parent_ID_df['parent_ID_tree'].iloc[0] + root_ID_tree = parent_ID_df["parent_ID_tree"].iloc[0] while prev_gen_num_tree > 2: prev_gen_num_tree -= 1 prev_gen_idx = root_ID_tree parent_ID_df = self.gen_dfs_by_ID_tree[prev_gen_idx] - root_ID_tree = parent_ID_df['parent_ID_tree'].iloc[0] + root_ID_tree = parent_ID_df["parent_ID_tree"].iloc[0] if root_ID_tree == -1: - root_ID_tree = parent_ID_df['root_ID_tree'].iloc[0] + root_ID_tree = parent_ID_df["root_ID_tree"].iloc[0] elif parent_ID > 0: # We started a new tree of a bud that appeared already in S # --> the root ID is the parent_ID (mother cell) @@ -1477,49 +1517,51 @@ def _build_tree(self, gen_df, ID): self.root_IDs_trees[ID] = root_ID_tree else: root_ID_tree = self.root_IDs_trees[ID] - - gen_df['root_ID_tree'] = root_ID_tree + + gen_df["root_ID_tree"] = root_ID_tree if self.debug: printl( - f'Traversing ID: {ID}\n' - f'Parent ID: {parent_ID}\n' - f'Is new tree: {self.is_new_tree}\n' - f'Relative ID: {relID}\n' - f'Relative ID generation num tree: {gen_num_relID_tree}\n' - f'Generation number tree: {gen_num_tree}\n' - f'New cell ID tree: {Cell_ID_tree}\n' - f'Start branch gen number: {self.branch_start_gen_num[ID]}\n' - f'Start of tree frame n.: {start_frame_i+1}\n' - f'root_ID_tree: {root_ID_tree}' + f"Traversing ID: {ID}\n" + f"Parent ID: {parent_ID}\n" + f"Is new tree: {self.is_new_tree}\n" + f"Relative ID: {relID}\n" + f"Relative ID generation num tree: {gen_num_relID_tree}\n" + f"Generation number tree: {gen_num_tree}\n" + f"New cell ID tree: {Cell_ID_tree}\n" + f"Start branch gen number: {self.branch_start_gen_num[ID]}\n" + f"Start of tree frame n.: {start_frame_i + 1}\n" + f"root_ID_tree: {root_ID_tree}" ) - import pdb; pdb.set_trace() - + import pdb + + pdb.set_trace() + self.gen_dfs[(ID, gen_num_tree)] = gen_df self.gen_dfs_by_ID_tree[Cell_ID_tree] = gen_df - + self.is_new_tree = False - + return gen_df - + def add_lineage_tree_table_to_acdc_df(self): Cell_ID_tree_vals = self.df_G1.index.get_level_values(1) - self.df_G1['Cell_ID_tree'] = Cell_ID_tree_vals - self.df_G1['parent_ID_tree'] = -1 - self.df_G1['root_ID_tree'] = -1 - self.df_G1['sister_ID_tree'] = -1 - - self.df_G1['generation_num_tree'] = self.df_G1['generation_num'] - + self.df_G1["Cell_ID_tree"] = Cell_ID_tree_vals + self.df_G1["parent_ID_tree"] = -1 + self.df_G1["root_ID_tree"] = -1 + self.df_G1["sister_ID_tree"] = -1 + + self.df_G1["generation_num_tree"] = self.df_G1["generation_num"] + # For cells that starts at ccs = 2 subtract 1 - history_unknown_mask = self.df_G1['is_history_known'] == 0 - ccs_greater_one_mask = self.df_G1['generation_num'] > 1 + history_unknown_mask = self.df_G1["is_history_known"] == 0 + ccs_greater_one_mask = self.df_G1["generation_num"] > 1 subtract_gen_num_mask = (history_unknown_mask) & (ccs_greater_one_mask) - self.df_G1.loc[subtract_gen_num_mask, 'generation_num_tree'] = ( - self.df_G1.loc[subtract_gen_num_mask, 'generation_num'] - 1 + self.df_G1.loc[subtract_gen_num_mask, "generation_num_tree"] = ( + self.df_G1.loc[subtract_gen_num_mask, "generation_num"] - 1 ) - - cols_tree = [col for col in self.df_G1.columns if col.endswith('_tree')] + + cols_tree = [col for col in self.df_G1.columns if col.endswith("_tree")] frames_idx = self.df_G1.dropna().index.get_level_values(0).unique() not_annotated_IDs = self.df_G1.index.get_level_values(1).unique().to_list() @@ -1534,24 +1576,22 @@ def add_lineage_tree_table_to_acdc_df(self): if not not_annotated_IDs: # Built tree for every ID --> exit break - + df_i = self.df_G1.loc[frame_i] IDs = np.sort(df_i.index.array) for ID in IDs: if ID not in not_annotated_IDs: # Tree already built in previous frame iteration --> skip continue - + self.is_new_tree = True # Iterate the branch till the end - df_tree_iter = ( - self.df_G1 - .groupby(['Cell_ID', 'generation_num'], group_keys=False) - .apply(self._build_tree, ID) - ) + df_tree_iter = self.df_G1.groupby( + ["Cell_ID", "generation_num"], group_keys=False + ).apply(self._build_tree, ID) self.df_G1 = df_tree_iter not_annotated_IDs.remove(ID) - + self._add_sister_ID() for c, col_tree in enumerate(cols_tree): @@ -1563,107 +1603,110 @@ def add_lineage_tree_table_to_acdc_df(self): self._build_tree_S(cols_tree) return self.acdc_df - + def _err_msg_add_sister_ID(self, relative_ID, frame_i, df): ID = df.index.get_level_values(1)[0] txt = ( - f'There is a problem with Cell ID {relative_ID} ' - f'at frame n. {frame_i+1}. ' - 'Make sure that annotations are correct before trying again.\n\n' - 'More info: error happened when trying to set the `sister_ID` of ' - f'cell ID {ID} to {relative_ID}. It might be that ID {relative_ID} ' - f'is not in G1 at frame n. {frame_i+1}' + f"There is a problem with Cell ID {relative_ID} " + f"at frame n. {frame_i + 1}. " + "Make sure that annotations are correct before trying again.\n\n" + "More info: error happened when trying to set the `sister_ID` of " + f"cell ID {ID} to {relative_ID}. It might be that ID {relative_ID} " + f"is not in G1 at frame n. {frame_i + 1}" ) return txt - + def _add_sister_ID(self): - grouped_ID_tree = self.df_G1.groupby('Cell_ID_tree') + grouped_ID_tree = self.df_G1.groupby("Cell_ID_tree") for Cell_ID_tree, df in grouped_ID_tree: - relative_ID = df['relative_ID'].iloc[0] + relative_ID = df["relative_ID"].iloc[0] if relative_ID == -1: continue start_frame_i = df.index.get_level_values(0)[0] try: sister_ID_tree = self.df_G1.at[ - (start_frame_i, relative_ID), 'Cell_ID_tree' + (start_frame_i, relative_ID), "Cell_ID_tree" ] except KeyError as error: raise KeyError( self._err_msg_add_sister_ID(relative_ID, start_frame_i, df) ) from error - - self.df_G1.loc[df.index, 'sister_ID_tree'] = sister_ID_tree - + + self.df_G1.loc[df.index, "sister_ID_tree"] = sister_ID_tree + def _build_tree_S(self, cols_tree): - '''In S we consider the bud still the same as the mother in the tree - --> either copy the tree information from the G1 phase or, in case + """In S we consider the bud still the same as the mother in the tree + --> either copy the tree information from the G1 phase or, in case the cell doesn't have a G1 (before S) because it appeared already in S, copy from the current S phase (e.g., Cell_ID_tree = Cell_ID) - ''' - S_mask = self.acdc_df['cell_cycle_stage'] == 'S' + """ + S_mask = self.acdc_df["cell_cycle_stage"] == "S" df_S = self.acdc_df[S_mask].copy() - gen_acdc_df = self.acdc_df.reset_index().set_index( - ['Cell_ID', 'generation_num', 'cell_cycle_stage'] - ).sort_index() + gen_acdc_df = ( + self.acdc_df.reset_index() + .set_index(["Cell_ID", "generation_num", "cell_cycle_stage"]) + .sort_index() + ) for row_S in df_S.itertuples(): relationship = row_S.relationship - if relationship == 'mother': + if relationship == "mother": idx_ID = row_S.Index[1] idx_gen_num = row_S.generation_num else: idx_ID = row_S.relative_ID frame_i = row_S.Index[0] - idx_gen_num = self.acdc_df.at[(frame_i, idx_ID), 'generation_num'] + idx_gen_num = self.acdc_df.at[(frame_i, idx_ID), "generation_num"] cc_df = gen_acdc_df.loc[(idx_ID, idx_gen_num)] - if 'G1' in cc_df.index: - row_G1 = cc_df.loc[['G1']].iloc[0] + if "G1" in cc_df.index: + row_G1 = cc_df.loc[["G1"]].iloc[0] for col_tree in cols_tree: self.acdc_df.loc[row_S.Index, col_tree] = row_G1[col_tree] else: # Cell that was already in S at appearance --> There is not G1 to copy from - sister_ID = cc_df.iloc[0]['relative_ID'] - self.acdc_df.loc[row_S.Index, 'Cell_ID_tree'] = idx_ID - self.acdc_df.loc[row_S.Index, 'parent_ID_tree'] = -1 - self.acdc_df.loc[row_S.Index, 'root_ID_tree'] = idx_ID - self.acdc_df.loc[row_S.Index, 'generation_num_tree'] = 1 - self.acdc_df.loc[row_S.Index, 'sister_ID_tree'] = sister_ID - + sister_ID = cc_df.iloc[0]["relative_ID"] + self.acdc_df.loc[row_S.Index, "Cell_ID_tree"] = idx_ID + self.acdc_df.loc[row_S.Index, "parent_ID_tree"] = -1 + self.acdc_df.loc[row_S.Index, "root_ID_tree"] = idx_ID + self.acdc_df.loc[row_S.Index, "generation_num_tree"] = 1 + self.acdc_df.loc[row_S.Index, "sister_ID_tree"] = sister_ID + def newick(self): - if 'Cell_ID_tree' not in self.acdc_df.columns: + if "Cell_ID_tree" not in self.acdc_df.columns: self.build() - + df = self.df.reset_index() - + def plot(self): - if 'Cell_ID_tree' not in self.acdc_df.columns: + if "Cell_ID_tree" not in self.acdc_df.columns: self.build() - + df = self.df.reset_index() - + def to_arboretum(self, rebuild=False): # See https://github.com/lowe-lab-ucl/arboretum/blob/main/examples/show_sample_data.py - if 'Cell_ID_tree' not in self.acdc_df.columns or rebuild: + if "Cell_ID_tree" not in self.acdc_df.columns or rebuild: self.build() df = self.df.reset_index() - tracks_cols = ['Cell_ID_tree', 'frame_i', 'y_centroid', 'x_centroid'] + tracks_cols = ["Cell_ID_tree", "frame_i", "y_centroid", "x_centroid"] tracks_data = df[tracks_cols].to_numpy() - graph_df = df.groupby('Cell_ID_tree').agg('first').reset_index() + graph_df = df.groupby("Cell_ID_tree").agg("first").reset_index() graph_df = graph_df[graph_df.parent_ID_tree > 0] graph = { - child_ID:[parent_ID] for child_ID, parent_ID - in zip(graph_df.Cell_ID_tree, graph_df.parent_ID_tree) + child_ID: [parent_ID] + for child_ID, parent_ID in zip( + graph_df.Cell_ID_tree, graph_df.parent_ID_tree + ) } - properties = pd.DataFrame({ - 't': df.frame_i, - 'root': df.root_ID_tree, - 'parent': df.parent_ID_tree - }) + properties = pd.DataFrame( + {"t": df.frame_i, "root": df.root_ID_tree, "parent": df.parent_ID_tree} + ) return tracks_data, graph, properties + def brownian(x0, n, dt, delta, out=None): """ Generate an instance of Brownian motion (i.e. the Wiener process): @@ -1675,7 +1718,7 @@ def brownian(x0, n, dt, delta, out=None): independence of N on different time intervals; that is, if [t0, t1) and [t2, t3) are disjoint intervals, then N(a, b; t0, t1) and N(a, b; t2, t3) are independent. - + Written as an iteration scheme, X(t + dt) = X(t) + N(0, delta**2 * dt; t, t+dt) @@ -1705,7 +1748,7 @@ def brownian(x0, n, dt, delta, out=None): Returns ------- A numpy array of floats with shape `x0.shape + (n,)`. - + Note that the initial value `x0` is not included in the returned array. """ @@ -1713,14 +1756,14 @@ def brownian(x0, n, dt, delta, out=None): # For each element of x0, generate a sample of n numbers from a # normal distribution. - r = norm.rvs(size=x0.shape + (n,), scale=delta*sqrt(dt)) + r = norm.rvs(size=x0.shape + (n,), scale=delta * sqrt(dt)) # If `out` was not given, create an output array. if out is None: out = np.empty(r.shape) # This computes the Brownian motion by forming the cumulative sum of - # the random samples. + # the random samples. np.cumsum(r, axis=-1, out=out) # Add the initial condition. @@ -1728,229 +1771,219 @@ def brownian(x0, n, dt, delta, out=None): return out + def preprocess_multi_pos_from_recipe( - image_data: Iterable[np.ndarray], - recipe: List[Dict[str, Any]] - ): - pbar = tqdm(total=len(image_data), unit='Position', ncols=100) + image_data: Iterable[np.ndarray], recipe: List[Dict[str, Any]] +): + pbar = tqdm(total=len(image_data), unit="Position", ncols=100) preprocessed_data = [] for pos_i, image in enumerate(image_data): - preprocessed_image = preprocess_zstack_from_recipe( - image, recipe, pbar_pos=1 - ) + preprocessed_image = preprocess_zstack_from_recipe(image, recipe, pbar_pos=1) preprocessed_data.append(preprocessed_image) pbar.update() pbar.close() return preprocessed_data -def preprocess_video_from_recipe( - image, recipe: List[Dict[str, Any]], pbar_pos=0 - ): + +def preprocess_video_from_recipe(image, recipe: List[Dict[str, Any]], pbar_pos=0): if image.ndim < 3: raise TypeError( - 'Only 3D or 4D videos allowed. ' - f'Input image has {image.ndim} dimensions!' + f"Only 3D or 4D videos allowed. Input image has {image.ndim} dimensions!" ) preprocessed_image = image for step in recipe: - method = step['method'] - func = PREPROCESS_MAPPER[method]['function'] - kwargs = step['kwargs'] + method = step["method"] + func = PREPROCESS_MAPPER[method]["function"] + kwargs = step["kwargs"] argspecs = inspect.getfullargspec(func) is_func_time_capable = False is_func_zstack_capable = False for arg in argspecs.args: - if arg == 'apply_to_all_frames': + if arg == "apply_to_all_frames": is_func_time_capable = True - elif arg == 'apply_to_all_zslices': + elif arg == "apply_to_all_zslices": is_func_zstack_capable = True - + if is_func_time_capable and is_func_zstack_capable: kwargs["apply_to_all_zslices"] = True kwargs["apply_to_all_frames"] = True - preprocessed_image = func( - preprocessed_image, - **kwargs - ) + preprocessed_image = func(preprocessed_image, **kwargs) else: pbar = tqdm( - total=len(preprocessed_image), unit='frame', ncols=100, - position=pbar_pos + total=len(preprocessed_image), + unit="frame", + ncols=100, + position=pbar_pos, ) for frame_i, frame_img in enumerate(preprocessed_image): if frame_img.ndim == 3: preprocessed_img = preprocess_zstack_from_recipe( - frame_img, (step,), pbar_pos=pbar_pos+1 + frame_img, (step,), pbar_pos=pbar_pos + 1 ) if preprocessed_img.dtype != preprocessed_image.dtype: - preprocessed_image = ( - preprocessed_image.astype(preprocessed_img.dtype) + preprocessed_image = preprocessed_image.astype( + preprocessed_img.dtype ) preprocessed_image[frame_i] = preprocessed_img else: - preprocessed_img = preprocess_image_from_recipe( - frame_img, (step,) - ) + preprocessed_img = preprocess_image_from_recipe(frame_img, (step,)) if preprocessed_img.dtype != preprocessed_image.dtype: - preprocessed_image = ( - preprocessed_image.astype(preprocessed_img.dtype) + preprocessed_image = preprocessed_image.astype( + preprocessed_img.dtype ) preprocessed_image[frame_i] = preprocessed_img pbar.update() pbar.close() - + return preprocessed_image - -def preprocess_zstack_from_recipe( - image, recipe: List[Dict[str, Any]], pbar_pos=0 - ): + + +def preprocess_zstack_from_recipe(image, recipe: List[Dict[str, Any]], pbar_pos=0): if image.ndim != 3: raise TypeError( - 'Only 3D z-stack images allowed. ' - f'Input image has {image.ndim} dimensions!' + f"Only 3D z-stack images allowed. Input image has {image.ndim} dimensions!" ) - + preprocessed_image = image for step in recipe: - method = step['method'] - func = PREPROCESS_MAPPER[method]['function'] - kwargs = step['kwargs'] + method = step["method"] + func = PREPROCESS_MAPPER[method]["function"] + kwargs = step["kwargs"] argspecs = inspect.getfullargspec(func) is_func_zstack_capable = False for arg in argspecs.args: - if arg == 'apply_to_all_zslices': + if arg == "apply_to_all_zslices": is_func_zstack_capable = True break - + if is_func_zstack_capable: - kwargs['apply_to_all_zslices'] = True - preprocessed_image = func( - preprocessed_image, **kwargs - ) + kwargs["apply_to_all_zslices"] = True + preprocessed_image = func(preprocessed_image, **kwargs) else: pbar = tqdm( - total=len(preprocessed_image), unit='z-slice', ncols=100, - position=pbar_pos + total=len(preprocessed_image), + unit="z-slice", + ncols=100, + position=pbar_pos, ) for z_slice, img in enumerate(preprocessed_image): preprocessed_img = func(img, **kwargs) if preprocessed_img.dtype != preprocessed_image.dtype: - preprocessed_image = ( - preprocessed_image.astype(preprocessed_img.dtype) + preprocessed_image = preprocessed_image.astype( + preprocessed_img.dtype ) preprocessed_image[z_slice] = preprocessed_img pbar.update() pbar.close() - + return preprocessed_image + all_kwargs_to_pop = ( - ('apply_to_all_zslices',), - ('apply_to_all_frames',), - ('apply_to_all_frames', 'apply_to_all_zslices'), + ("apply_to_all_zslices",), + ("apply_to_all_frames",), + ("apply_to_all_frames", "apply_to_all_zslices"), ) + + def preprocess_image_from_recipe(image, recipe: List[Dict[str, Any]]): preprocessed_image = image for step in recipe: - method = step['method'] - func = PREPROCESS_MAPPER[method]['function'] - kwargs = step['kwargs'] + method = step["method"] + func = PREPROCESS_MAPPER[method]["function"] + kwargs = step["kwargs"] for kwargs_to_pop in all_kwargs_to_pop: test_kwargs = kwargs.copy() try: preprocessed_image = func(preprocessed_image, **test_kwargs) break except TypeError as err: - if not 'unexpected keyword argument' in str(err): + if not "unexpected keyword argument" in str(err): raise err - + for kwarg_to_pop in kwargs_to_pop: test_kwargs.pop(kwarg_to_pop, None) - + return preprocessed_image + def pop_signals_kwarg_if_not_needed(func, kwargs): args = inspect.getfullargspec(func).args - if 'signals' in args: + if "signals" in args: return kwargs - - kwargs.pop('signals', None) + + kwargs.pop("signals", None) return kwargs + def segm_model_segment( - model, image, model_kwargs, frame_i=None, preproc_recipe=None, - is_timelapse_model_and_data=False, posData=None, start_z_slice=0, - ): + model, + image, + model_kwargs, + frame_i=None, + preproc_recipe=None, + is_timelapse_model_and_data=False, + posData=None, + start_z_slice=0, +): if preproc_recipe is not None: if is_timelapse_model_and_data: filtered_image = np.zeros(image.shape) for i, img in enumerate(image): img = preprocess_image_from_recipe(img, preproc_recipe) filtered_image[i] = img - image = filtered_image # .astype(image.dtype) + image = filtered_image # .astype(image.dtype) else: image = preprocess_image_from_recipe(image, preproc_recipe) - + if is_timelapse_model_and_data: - model_kwargs = pop_signals_kwarg_if_not_needed( - model.segment3DT, model_kwargs - ) + model_kwargs = pop_signals_kwarg_if_not_needed(model.segment3DT, model_kwargs) segm_data = model.segment3DT(image, **model_kwargs) - return segm_data - - model_kwargs = pop_signals_kwarg_if_not_needed( - model.segment, model_kwargs - ) + return segm_data + + model_kwargs = pop_signals_kwarg_if_not_needed(model.segment, model_kwargs) # Some models have `start_z_slice` kwarg try: lab = model.segment( - image, - frame_i=frame_i, - posData=posData, - start_z_slice=start_z_slice, - **model_kwargs + image, + frame_i=frame_i, + posData=posData, + start_z_slice=start_z_slice, + **model_kwargs, ) return lab except TypeError as err: - if str(err).find('unexpected keyword argument') == -1: + if str(err).find("unexpected keyword argument") == -1: # Raise error since it's not about the missing posData kwarg raise err - + # Some models have posData as kwarg and frame_i as second arg try: - lab = model.segment( - image, - frame_i=frame_i, - posData=posData, - **model_kwargs - ) + lab = model.segment(image, frame_i=frame_i, posData=posData, **model_kwargs) return lab except TypeError as err: - if str(err).find('unexpected keyword argument') == -1: + if str(err).find("unexpected keyword argument") == -1: # Raise error since it's not about the missing posData kwarg raise err - + # Some models have frame_i as second arg try: - lab = model.segment( - image, - frame_i=frame_i, - **model_kwargs - ) + lab = model.segment(image, frame_i=frame_i, **model_kwargs) return lab except TypeError as err: pass - + lab = model.segment(image, **model_kwargs) return lab + def filter_segm_objs_from_table_coords(lab, df): cols = [] if lab.ndim == 3: - cols = ['z'] - cols.extend(('y', 'x')) + cols = ["z"] + cols.extend(("y", "x")) coords = df[cols].values.T IDs_to_keep = lab[tuple(coords)] mask_to_keep = np.isin(lab, IDs_to_keep) @@ -1958,16 +1991,16 @@ def filter_segm_objs_from_table_coords(lab, df): filtered_lab[~mask_to_keep] = 0 return filtered_lab + def tracker_track( - segm_data, tracker, track_params, intensity_img=None, - logger_func=print - ): + segm_data, tracker, track_params, intensity_img=None, logger_func=print +): if intensity_img is not None: args_to_try = (tuple(), (intensity_img,)) else: args_to_try = (tuple(),) - kwargs_to_remove = ('', 'signals') + kwargs_to_remove = ("", "signals") for args, kwarg_to_remove in product(args_to_try, kwargs_to_remove): kwargs = track_params.copy() kwargs.pop(kwarg_to_remove, None) @@ -1975,17 +2008,18 @@ def tracker_track( tracked_video = tracker.track(segm_data, *args, **kwargs) return tracked_video except Exception as err: - is_unexpected_kwarg = (str(err).find( - "got an unexpected keyword argument 'signals'" - ) != -1) - is_missing_arg = (str(err).find( - "missing 1 required positional argument:" - ) != -1) + is_unexpected_kwarg = ( + str(err).find("got an unexpected keyword argument 'signals'") != -1 + ) + is_missing_arg = ( + str(err).find("missing 1 required positional argument:") != -1 + ) if is_unexpected_kwarg or is_missing_arg: continue else: raise err + def _relabel_sequential(segm_data): relabelled, fw, inv = skimage.segmentation.relabel_sequential(segm_data) newIDs = list(inv.in_values) @@ -1994,6 +2028,7 @@ def _relabel_sequential(segm_data): oldIDs.append(-1) return relabelled, oldIDs, newIDs + def _relabel_sequential_timelapse(segm_data): """Relabel IDs sequentially frame-by-frame @@ -2005,14 +2040,14 @@ def _relabel_sequential_timelapse(segm_data): Returns ------- 3-tuple of (numpy.ndarray, list, list) - First element is the relabelled segmentation data. + First element is the relabelled segmentation data. Second element is the list of the old IDs. Third element is the list of the new IDs. - """ + """ mapper_old_to_new_IDs = {-1: -1} relabelled = np.zeros_like(segm_data) lastID = 0 - pbar = tqdm(total=len(segm_data), ncols=100, unit=' frame') + pbar = tqdm(total=len(segm_data), ncols=100, unit=" frame") for frame_i, lab in enumerate(segm_data): if frame_i == 0: relab, oldIDs_i, newIDs_i = _relabel_sequential(lab) @@ -2020,7 +2055,7 @@ def _relabel_sequential_timelapse(segm_data): lastID = max(newIDs_i) relabelled[frame_i] = relab continue - + rp = skimage.measure.regionprops(lab) for obj in rp: newID = mapper_old_to_new_IDs.get(obj.label) @@ -2038,6 +2073,7 @@ def _relabel_sequential_timelapse(segm_data): newIDs = list(mapper_old_to_new_IDs.values()) return relabelled, oldIDs, newIDs + def relabel_sequential(segm_data, is_timelapse=False): if is_timelapse: relabelled, oldIDs, newIDs = _relabel_sequential_timelapse(segm_data) @@ -2045,48 +2081,42 @@ def relabel_sequential(segm_data, is_timelapse=False): relabelled, oldIDs, newIDs = _relabel_sequential(segm_data) return relabelled, oldIDs, newIDs + class CcaIntegrityChecker: def __init__(self, cca_df, lab, lab_IDs): self.lab = lab self.lab_IDs = lab_IDs self.cca_df = cca_df - self.cca_df_S = cca_df[cca_df['cell_cycle_stage'] == 'S'] - self.cca_df_G1 = cca_df[cca_df['cell_cycle_stage'] == 'G1'] + self.cca_df_S = cca_df[cca_df["cell_cycle_stage"] == "S"] + self.cca_df_G1 = cca_df[cca_df["cell_cycle_stage"] == "G1"] def get_num_mothers_and_buds_in_S(self): cca_df_S = self.cca_df_S - cca_df_S_buds = cca_df_S[cca_df_S['relationship'] == 'bud'] - cca_df_S_mothers = cca_df_S[cca_df_S['relationship'] == 'mother'] + cca_df_S_buds = cca_df_S[cca_df_S["relationship"] == "bud"] + cca_df_S_mothers = cca_df_S[cca_df_S["relationship"] == "mother"] num_buds = len(cca_df_S_buds) num_mothers = len(cca_df_S_mothers) return num_mothers, num_buds - + def get_mother_IDs_with_multiple_buds(self): cca_df_S = self.cca_df_S - cca_df_S_buds = cca_df_S[cca_df_S['relationship'] == 'bud'] - mothers_of_buds = cca_df_S_buds['relative_ID'] - mother_IDs_with_multiple_buds = ( - mothers_of_buds[mothers_of_buds.duplicated()] - ) + cca_df_S_buds = cca_df_S[cca_df_S["relationship"] == "bud"] + mothers_of_buds = cca_df_S_buds["relative_ID"] + mother_IDs_with_multiple_buds = mothers_of_buds[mothers_of_buds.duplicated()] return mother_IDs_with_multiple_buds.values - + def get_IDs_cycles_without_G1(self, global_cca_df): - global_cca_df_moths_hist_known = ( - global_cca_df[ - (global_cca_df['relationship'] == 'mother') - & (global_cca_df['is_history_known'] > 0) - ] - ) + global_cca_df_moths_hist_known = global_cca_df[ + (global_cca_df["relationship"] == "mother") + & (global_cca_df["is_history_known"] > 0) + ] grouped_cycles = global_cca_df_moths_hist_known.reset_index().groupby( - ['Cell_ID', 'generation_num'] + ["Cell_ID", "generation_num"] ) - G1_not_present_mask = ( - grouped_cycles['cell_cycle_stage'] - .agg(lambda x: ~x.eq('G1').any()) - ) - IDs_cycles_without_G1 = ( - G1_not_present_mask[G1_not_present_mask].index.to_list() + G1_not_present_mask = grouped_cycles["cell_cycle_stage"].agg( + lambda x: ~x.eq("G1").any() ) + IDs_cycles_without_G1 = G1_not_present_mask[G1_not_present_mask].index.to_list() return IDs_cycles_without_G1 def get_IDs_gen_num_will_divide_wrong(self, global_cca_df): @@ -2094,150 +2124,146 @@ def get_IDs_gen_num_will_divide_wrong(self, global_cca_df): global_cca_df ) return IDs_will_divide_wrong - + def get_bud_IDs_gen_num_nonzero(self): cca_df_S = self.cca_df_S - cca_df_S_buds = cca_df_S[cca_df_S['relationship'] == 'bud'] - bud_IDs_gen_num_nonzero = ( - cca_df_S_buds[cca_df_S_buds['generation_num'] != 0] - .index.to_list() - ) + cca_df_S_buds = cca_df_S[cca_df_S["relationship"] == "bud"] + bud_IDs_gen_num_nonzero = cca_df_S_buds[ + cca_df_S_buds["generation_num"] != 0 + ].index.to_list() return bud_IDs_gen_num_nonzero - + def get_moth_IDs_gen_num_non_greater_one(self): cca_df_S = self.cca_df_S - cca_df_S_moths = cca_df_S[cca_df_S['relationship'] == 'mother'] - moth_IDs_gen_num_non_greater_one = ( - cca_df_S_moths[cca_df_S_moths['generation_num'] < 1] - .index.to_list() - ) + cca_df_S_moths = cca_df_S[cca_df_S["relationship"] == "mother"] + moth_IDs_gen_num_non_greater_one = cca_df_S_moths[ + cca_df_S_moths["generation_num"] < 1 + ].index.to_list() return moth_IDs_gen_num_non_greater_one - + def get_buds_G1(self): cca_df_S = self.cca_df_S - cca_df_S_buds = cca_df_S[cca_df_S['relationship'] == 'bud'] - buds_G1 = ( - cca_df_S_buds[cca_df_S_buds['cell_cycle_stage'] == 'G1'] - .index.to_list() - ) + cca_df_S_buds = cca_df_S[cca_df_S["relationship"] == "bud"] + buds_G1 = cca_df_S_buds[ + cca_df_S_buds["cell_cycle_stage"] == "G1" + ].index.to_list() return buds_G1 - + def get_cell_S_rel_ID_zero(self): cca_df_S = self.cca_df_S - cell_S_rel_ID_zero = ( - cca_df_S[cca_df_S['relative_ID'] < 1] - .index.to_list() - ) + cell_S_rel_ID_zero = cca_df_S[cca_df_S["relative_ID"] < 1].index.to_list() return cell_S_rel_ID_zero - + def get_ID_rel_ID_mismatches(self): ID_rel_ID_mismatches = [] for row in self.cca_df_S.itertuples(): ID = row.Index relID = row.relative_ID - relID_of_relID = self.cca_df.at[relID, 'relative_ID'] - + relID_of_relID = self.cca_df.at[relID, "relative_ID"] + if relID_of_relID != ID: ID_rel_ID_mismatches.append((ID, relID, relID_of_relID)) - + return ID_rel_ID_mismatches def get_lonely_cells_in_S(self): lonely_cells_in_S = [] for row in self.cca_df_S.itertuples(): - ID = row.Index + ID = row.Index if row.relative_ID in self.lab_IDs: continue - + if ID not in self.lab_IDs: # Mother-bud pair gone entirely continue - + # ID is in S but its relative_ID does not exist in lab lonely_cells_in_S.append(ID) - + return lonely_cells_in_S + def cellpose_v3_run_denoise( - image, - run_params, - denoise_model=None, - init_params=None, - timelapse=False, - isZstack=False - ): + image, + run_params, + denoise_model=None, + init_params=None, + timelapse=False, + isZstack=False, +): if denoise_model is None: - from cellacdc.models.cellpose_v3 import _denoise + from cellacdc.segmenters.cellpose_v3 import _denoise + denoise_model = _denoise.CellposeDenoiseModel(**init_params) - - denoised_img = denoise_model.run(image, timelapse=timelapse,isZstack=isZstack, **run_params)# may have to give rgb stuff too! + + denoised_img = denoise_model.run( + image, timelapse=timelapse, isZstack=isZstack, **run_params + ) # may have to give rgb stuff too! return denoised_img -def closest_n_divisible_by_m(n, m) : + +def closest_n_divisible_by_m(n, m): # Find the quotient q = int(n / m) - + # 1st possible closest number n1 = m * q - + # 2nd possible closest number - if((n * m) > 0) : - n2 = (m * (q + 1)) - else : - n2 = (m * (q - 1)) - + if (n * m) > 0: + n2 = m * (q + 1) + else: + n2 = m * (q - 1) + # if true, then n1 is the required closest number - if (abs(n - n1) < abs(n - n2)) : + if abs(n - n1) < abs(n - n2): return n1 - - # else n2 is the required closest number + + # else n2 is the required closest number return n2 + def fucci_pipeline_executor_map(input, **filter_kwargs): frame_i, (ch1_img, ch2_img) = input - - ch1_img = skimage.exposure.rescale_intensity( - ch1_img, out_range=(0, 0.5) - ) - ch2_img = skimage.exposure.rescale_intensity( - ch2_img, out_range=(0, 0.5) - ) - + + ch1_img = skimage.exposure.rescale_intensity(ch1_img, out_range=(0, 0.5)) + ch2_img = skimage.exposure.rescale_intensity(ch2_img, out_range=(0, 0.5)) + sum_img = ch1_img + ch2_img - + processed_img = preprocess.fucci_filter(sum_img, **filter_kwargs) - + return frame_i, processed_img + def preprocess_exceutor_map( - input: Tuple[int, np.ndarray], - recipe: List[Dict[str, Any]]=None, - ): + input: Tuple[int, np.ndarray], + recipe: List[Dict[str, Any]] = None, +): if recipe is None: return input - + frame_i, image = input if image.ndim == 3: preprocessed_image = preprocess_zstack_from_recipe(image, recipe) else: preprocessed_image = preprocess_image_from_recipe(image, recipe) - + return frame_i, preprocessed_image + def preprocess_image_from_recipe_multithread( - image: np.ndarray, - recipe: List[Dict[str, Any]], - n_threads: int=None - ): + image: np.ndarray, recipe: List[Dict[str, Any]], n_threads: int = None +): preprocessed_image = image for step in recipe: - method = step['method'] - func = PREPROCESS_MAPPER[method]['function'] - kwargs = step['kwargs'] + method = step["method"] + func = PREPROCESS_MAPPER[method]["function"] + kwargs = step["kwargs"] argspecs = inspect.getfullargspec(func) is_func_time_capable = False for arg in argspecs.args: - if arg == 'apply_to_all_frames': + if arg == "apply_to_all_frames": is_func_time_capable = True break @@ -2250,10 +2276,7 @@ def preprocess_image_from_recipe_multithread( pbar = tqdm(total=num_frames, ncols=100) with ThreadPoolExecutor(max_workers=n_threads) as executor: iterable = enumerate(preprocessed_image) - func = partial( - preprocess_exceutor_map, - recipe=(step,) - ) + func = partial(preprocess_exceutor_map, recipe=(step,)) futures = {executor.submit(func, arg) for arg in iterable} for future in as_completed(futures): try: @@ -2264,28 +2287,27 @@ def preprocess_image_from_recipe_multithread( printl(e) raise e pbar.close() - + return preprocessed_image + def combine_channels_multithread( - steps: Dict[str, Dict[str, Any]], - formula: str, - images_paths: List[str], - keep_input_data_type: bool, - save_filepaths: List[str]=None, - n_threads: int=None, - signals=None, - logger_func: Callable=None, - output_as_segm: bool = False, - ): + steps: Dict[str, Dict[str, Any]], + formula: str, + images_paths: List[str], + keep_input_data_type: bool, + save_filepaths: List[str] = None, + n_threads: int = None, + signals=None, + logger_func: Callable = None, + output_as_segm: bool = False, +): with ThreadPoolExecutor(max_workers=n_threads) as executor: if signals: signals.initProgressBar.emit(len(images_paths)) else: - pbar = tqdm( - total=len(images_paths), ncols=100, desc='Combining channels' - ) - + pbar = tqdm(total=len(images_paths), ncols=100, desc="Combining channels") + func = partial( combine_channels_executor_map, keep_input_data_type=keep_input_data_type, @@ -2303,21 +2325,21 @@ def combine_channels_multithread( else: pbar.update() -def combine_channels_multithread_return_imgs( - steps: Dict[str, Dict[str, Any]], - data: list['load.loadData'], - keep_input_data_type: bool, - keys: List[Tuple[Union[int, None], Union[int, None], Union[int, None]]], - n_threads: int=None, - signals=None, - logger_func: Callable=None, - output_as_segm: bool = False, - formula: str = None, - ): +def combine_channels_multithread_return_imgs( + steps: Dict[str, Dict[str, Any]], + data: list["load.loadData"], + keep_input_data_type: bool, + keys: List[Tuple[Union[int, None], Union[int, None], Union[int, None]]], + n_threads: int = None, + signals=None, + logger_func: Callable = None, + output_as_segm: bool = False, + formula: str = None, +): total = len(keys) - + output_imgs = [None] * total keys_out = [0] * total res_i = 0 @@ -2327,7 +2349,7 @@ def combine_channels_multithread_return_imgs( if signals: signals.initProgressBar.emit(total) else: - pbar = tqdm(total=len(total), ncols=100, desc='Combining channels') + pbar = tqdm(total=len(total), ncols=100, desc="Combining channels") func = partial( combine_channels_executor_map_return_img, data=data, @@ -2364,41 +2386,44 @@ def combine_channels_multithread_return_imgs( return output_imgs, keys_out + def combine_channels_executor_map(args, **kwargs): images_path, save_filepath = args - kwargs['save_filepath'] = save_filepath - kwargs['images_path'] = images_path + kwargs["save_filepath"] = save_filepath + kwargs["images_path"] = images_path return combine_channels_func(**kwargs) + def combine_channels_executor_map_return_img(args, **kwargs): key = args - kwargs['key'] = key + kwargs["key"] = key return combine_channels_func(**kwargs) + def _combine_channels_multiplier_apply(binarize, input_img): - if binarize == 'binarize': - input_img = (input_img > 0) - elif binarize == 'inverse binarize': + if binarize == "binarize": + input_img = input_img > 0 + elif binarize == "inverse binarize": input_img = ~(input_img > 0) return input_img + def _get_img_from_data_key(data, key, num_dim, seg=False): - n_dim_data = data.ndim - 1 # - 1 dim for x y + n_dim_data = data.ndim - 1 # - 1 dim for x y n_dim_key = len(key) if seg and n_dim_key == n_dim_data + 1: # here a 2D segmentation is used for 3D image return data[key[1]] - if num_dim == 3: # t x y + if num_dim == 3: # t x y return data[key[1]] - elif num_dim == 4: # t z x y + elif num_dim == 4: # t z x y return data[key[1]][key[2]] - elif num_dim == 2: # z x y, but t is always there + elif num_dim == 2: # z x y, but t is always there return data[0] else: - raise ValueError( - f'Invalid number of dimensions in img_data. {num_dim}' - ) - + raise ValueError(f"Invalid number of dimensions in img_data. {num_dim}") + + def _log_printl_fallback(txt, logger_func): if logger_func is not None: try: @@ -2408,7 +2433,8 @@ def _log_printl_fallback(txt, logger_func): pass else: printl(txt) - + + def _add_missing_dims(segm, target_shape, use_broadcast=False): """ Expand segmentation by replicating existing data along missing dims. @@ -2426,7 +2452,7 @@ def _add_missing_dims(segm, target_shape, use_broadcast=False): segm_expanded : np.ndarray text : str """ - text = '' + text = "" if segm.shape == target_shape: return segm, text @@ -2434,8 +2460,8 @@ def _add_missing_dims(segm, target_shape, use_broadcast=False): # 2D -> 3D (Y,X -> Z,Y,X) if segm.ndim == 2 and len(target_shape) == 3: text = ( - 'The segmentation mask is 2D but the image data is 3D. ' - 'Replicating mask across Z.' + "The segmentation mask is 2D but the image data is 3D. " + "Replicating mask across Z." ) y, x = segm.shape z = target_shape[0] @@ -2452,8 +2478,8 @@ def _add_missing_dims(segm, target_shape, use_broadcast=False): # 3D -> 4D (T,Y,X -> T,Z,Y,X) if segm.ndim == 3 and len(target_shape) == 4: text = ( - 'The segmentation mask is 2Dt but the image data is 3Dt. ' - 'Replicating mask across Z.' + "The segmentation mask is 2Dt but the image data is 3Dt. " + "Replicating mask across Z." ) t, y, x = segm.shape z = target_shape[1] @@ -2467,26 +2493,27 @@ def _add_missing_dims(segm, target_shape, use_broadcast=False): return segm_expanded, text - raise ValueError( - f'Invalid shape. segm: {segm.shape}, target: {target_shape}' - ) - + raise ValueError(f"Invalid shape. segm: {segm.shape}, target: {target_shape}") + + def _verify_shape_ndim(img_data, target_dims, target_shape, is_segm=False): def _shape_mismatch_error(indices, img_data, target_shape): mismatches = [ - f' axis {i}: got {img_data.shape[idx]}, expected {target_shape[i]}' + f" axis {i}: got {img_data.shape[idx]}, expected {target_shape[i]}" for idx, i in indices if img_data.shape[idx] != target_shape[i] ] if mismatches: raise ValueError( - f'Shape mismatch:\n' + '\n'.join(mismatches) + '\n' - f'img shape={img_data.shape}, target shape={target_shape}' + f"Shape mismatch:\n" + "\n".join(mismatches) + "\n" + f"img shape={img_data.shape}, target shape={target_shape}" ) if img_data.ndim == target_dims: # Check all axes directly - _shape_mismatch_error([(i, i) for i in range(target_dims)], img_data, target_shape) + _shape_mismatch_error( + [(i, i) for i in range(target_dims)], img_data, target_shape + ) elif is_segm and img_data.ndim + 1 == target_dims: if target_dims == 3: @@ -2497,15 +2524,16 @@ def _shape_mismatch_error(indices, img_data, target_shape): _shape_mismatch_error([(0, 0), (1, 2), (2, 3)], img_data, target_shape) else: raise ValueError( - f'Invalid segmentation mask dimensions: ' - f'got {img_data.ndim}D mask for {target_dims}D target.' + f"Invalid segmentation mask dimensions: " + f"got {img_data.ndim}D mask for {target_dims}D target." ) else: raise ValueError( - f'Invalid image dimensions: ' - f'got {img_data.ndim}D data, expected {target_dims}D.' + f"Invalid image dimensions: " + f"got {img_data.ndim}D data, expected {target_dims}D." ) + def _update_target_shape_target_dims(target_dims, target_shape, img_data): if target_dims < img_data.ndim: target_dims_old = target_dims @@ -2516,7 +2544,7 @@ def _update_target_shape_target_dims(target_dims, target_shape, img_data): img_shape = [0, 0] + img_shape elif img_data.ndim == 3: img_shape = img_shape[:1] + [0] + img_shape[1:] - + try: for i in range(len(target_shape)): if target_shape[i] < img_shape[i]: @@ -2526,64 +2554,66 @@ def _update_target_shape_target_dims(target_dims, target_shape, img_data): raise err return target_dims, target_shape - + def combine_channels_func( - steps: Dict[str, Dict[str, Any]], - keep_input_data_type: bool, - save_filepath: str=None, - return_img: bool=False, - logger_func: Callable=None, - images_path: str = None, - key: str = None, - data = None, - output_as_segm: bool = False, - formula: str = None, - ):# -> tuple[Any | ndarray, str, str | Any] | None: + steps: Dict[str, Dict[str, Any]], + keep_input_data_type: bool, + save_filepath: str = None, + return_img: bool = False, + logger_func: Callable = None, + images_path: str = None, + key: str = None, + data=None, + output_as_segm: bool = False, + formula: str = None, +): # -> tuple[Any | ndarray, str, str | Any] | None: if not save_filepath and not return_img: - raise ValueError('Either save_filepath must be provided or return_img must be true') - + raise ValueError( + "Either save_filepath must be provided or return_img must be true" + ) + provided = sum(x is not None for x in (data, key)) provided += return_img if provided not in (0, 3): - raise ValueError('return_img, data, and key must all be provided together or not at all') - + raise ValueError( + "return_img, data, and key must all be provided together or not at all" + ) + fluo_ch_data_list = dict() segm_ch_data_list = dict() - - channel_names = [step['channel'] for step in steps.values()] + + channel_names = [step["channel"] for step in steps.values()] channel_keys = steps.keys() - segm_channels, fluo_channel_names, current_segm = myutils.separate_fluo_segment_channels(channel_names) + segm_channels, fluo_channel_names, current_segm = ( + utils.separate_fluo_segment_channels(channel_names) + ) original_dtype = None - + target_dims = 0 target_shape = [0, 0, 0, 0] if data is None: for channel in fluo_channel_names: - ch_filepath = load.get_filepath_from_endname( - images_path, channel - ) + ch_filepath = load.get_filepath_from_endname(images_path, channel) ch_image_data = load.load_image_file(ch_filepath) if original_dtype is None: original_dtype = ch_image_data.dtype - - ch_image_data = myutils.img_to_float(ch_image_data) + + ch_image_data = utils.img_to_float(ch_image_data) target_dims, target_shape = _update_target_shape_target_dims( - target_dims, target_shape, ch_image_data + target_dims, target_shape, ch_image_data ) fluo_ch_data_list[channel] = ch_image_data for channel in segm_channels: - ch_filepath = load.get_filepath_from_endname( - images_path, channel - ) + ch_filepath = load.get_filepath_from_endname(images_path, channel) ch_image_data = load.load_image_file(ch_filepath) if original_dtype is None: original_dtype = ch_image_data.dtype ch_image_data = ch_image_data.astype(np.uint32) target_dims, target_shape = _update_target_shape_target_dims( - target_dims, target_shape, ch_image_data + target_dims, target_shape, ch_image_data ) segm_ch_data_list[channel] = ch_image_data else: @@ -2593,7 +2623,9 @@ def combine_channels_func( n_dim -= 1 # if posData.SizeT == 1: # actually t is always there, we only need to subtract for curr. segm # n_dim -= 1 - is_2D_segm_on_3D = posData.SizeZ != 1 and posData.allData_li[0]['labels'].ndim == 2 + is_2D_segm_on_3D = ( + posData.SizeZ != 1 and posData.allData_li[0]["labels"].ndim == 2 + ) fluo_data_dict = posData.fluo_data_dict segm_data_dict = posData.ol_labels_data imgs_path = posData.images_path @@ -2604,17 +2636,21 @@ def combine_channels_func( ) channel_full_name = pathlib.Path(channel_path).stem # remove the file extension - - channel_img_data = _get_img_from_data_key(fluo_data_dict[channel_full_name], key, n_dim) + + channel_img_data = _get_img_from_data_key( + fluo_data_dict[channel_full_name], key, n_dim + ) if original_dtype is None: original_dtype = channel_img_data.dtype - channel_img_data_float = myutils.img_to_float(channel_img_data) + channel_img_data_float = utils.img_to_float(channel_img_data) target_dims, target_shape = _update_target_shape_target_dims( target_dims, target_shape, channel_img_data_float ) fluo_ch_data_list[channel] = channel_img_data_float for channel in segm_channels: - channel_img_data = _get_img_from_data_key(segm_data_dict[channel], key, n_dim, seg=True) + channel_img_data = _get_img_from_data_key( + segm_data_dict[channel], key, n_dim, seg=True + ) if original_dtype is None: original_dtype = channel_img_data.dtype channel_img_data_int = channel_img_data.astype(np.uint32) @@ -2622,22 +2658,24 @@ def combine_channels_func( target_dims, target_shape, channel_img_data_int ) segm_ch_data_list[channel] = channel_img_data_int - if current_segm: # here we dont need to get/appply target dim, as we already ignore z slice key if segm is 2D and image 3D (time is always treated differently!) - if posData.SizeT == 1: # actually t is always there, we only need to subtract for curr. segm + if current_segm: # here we dont need to get/appply target dim, as we already ignore z slice key if segm is 2D and image 3D (time is always treated differently!) + if ( + posData.SizeT == 1 + ): # actually t is always there, we only need to subtract for curr. segm n_dim -= 1 if posData.frame_i != key[1]: if n_dim == 4 and not is_2D_segm_on_3D: - channel_img_data = posData.allData_li[key[1]]['labels'][key[2]] + channel_img_data = posData.allData_li[key[1]]["labels"][key[2]] elif n_dim == 4 and is_2D_segm_on_3D: - channel_img_data = posData.allData_li[key[1]]['labels'] + channel_img_data = posData.allData_li[key[1]]["labels"] elif n_dim == 3 and posData.SizeZ == 1: - channel_img_data = posData.allData_li[key[1]]['labels'] + channel_img_data = posData.allData_li[key[1]]["labels"] elif n_dim == 3 and posData.SizeT == 1 and not is_2D_segm_on_3D: - channel_img_data = posData.allData_li[0]['labels'][key[2]] + channel_img_data = posData.allData_li[0]["labels"][key[2]] elif n_dim == 3 and posData.SizeT == 1 and is_2D_segm_on_3D: - channel_img_data = posData.allData_li[0]['labels'] + channel_img_data = posData.allData_li[0]["labels"] else: - channel_img_data = posData.allData_li[0]['labels'] + channel_img_data = posData.allData_li[0]["labels"] else: if n_dim == 4 and not is_2D_segm_on_3D: channel_img_data = posData.lab[key[2]] @@ -2653,14 +2691,14 @@ def combine_channels_func( target_dims, target_shape = _update_target_shape_target_dims( target_dims, target_shape, channel_img_data_int ) - segm_ch_data_list['current segm.'] = channel_img_data_int - + segm_ch_data_list["current segm."] = channel_img_data_int + target_shape_new = [] for dim in target_shape: if dim == 0: continue target_shape_new.append(dim) - + target_shape = tuple(target_shape_new) # _log_printl_fallback(f'target shape: {target_shape}', logger_func) for i, ch in zip(channel_keys, channel_names): @@ -2673,31 +2711,32 @@ def combine_channels_func( ch_image_data, text = _add_missing_dims(ch_image_data, target_shape) if text: _log_printl_fallback(text, logger_func) - _verify_shape_ndim(ch_image_data, target_dims, target_shape, is_segm=False) # false since we already expanded + _verify_shape_ndim( + ch_image_data, target_dims, target_shape, is_segm=False + ) # false since we already expanded segm_ch_data_list[ch] = ch_image_data else: raise ValueError(f'Channel "{ch}" not found.') - if steps[i]['channel'] != ch: - raise ValueError(f'Channel "{ch}" not found.') - steps[i]['channel_data'] = ch_image_data - + if steps[i]["channel"] != ch: + raise ValueError(f'Channel "{ch}" not found.') + steps[i]["channel_data"] = ch_image_data + for i, step_info in steps.items(): - binarize = step_info['binarize'] - steps[i]['channel_data'] = _combine_channels_multiplier_apply( - binarize, step_info['channel_data'] + binarize = step_info["binarize"] + steps[i]["channel_data"] = _combine_channels_multiplier_apply( + binarize, step_info["channel_data"] ) - norm_min, norm_max = step_info['min_val'], step_info['max_val'] + norm_min, norm_max = step_info["min_val"], step_info["max_val"] # use rescale_intensity to normalize if norm_min == 0 and norm_max == 1: - continue # cases where either the fields where disabled/reset or default, where we already normalized - steps[i]['channel_data'] = skimage.exposure.rescale_intensity( - steps[i]['channel_data'], - out_range=(norm_min, norm_max) + continue # cases where either the fields where disabled/reset or default, where we already normalized + steps[i]["channel_data"] = skimage.exposure.rescale_intensity( + steps[i]["channel_data"], out_range=(norm_min, norm_max) ) - - if formula != '': - input_img_data = {step['name']: step['channel_data'] for step in steps.values()} - + + if formula != "": + input_img_data = {step["name"]: step["channel_data"] for step in steps.values()} + symbols = {name: sp.Symbol(name) for name in input_img_data} expr = sp.sympify(formula, locals=symbols) @@ -2705,173 +2744,166 @@ def combine_channels_func( func = sp.lambdify( [symbols[v] for v in used_vars], # fixed order! expr, - modules="numpy" + modules="numpy", ) args = [input_img_data[v] for v in used_vars] output_img = func(*args) else: key0 = list(steps.keys())[0] - output_img = steps[key0]['channel_data'] + output_img = steps[key0]["channel_data"] if not output_as_segm: - output_img = skimage.exposure.rescale_intensity( - output_img, out_range=(0, 1) - ) - - txt = '' + output_img = skimage.exposure.rescale_intensity(output_img, out_range=(0, 1)) + + txt = "" if keep_input_data_type and not output_as_segm: try: - output_img = myutils.convert_to_dtype( - output_img, original_dtype - ) - method = 'cellacdc.myutils.convert_to_dtype' - warning = 'safe' - prefix = '' + output_img = utils.convert_to_dtype(output_img, original_dtype) + method = "cellacdc.utils.convert_to_dtype" + warning = "safe" + prefix = "" except Exception as err: dtype_info = np.iinfo(original_dtype) dtype_max = dtype_info.max dtype_min = dtype_info.min is_in_bounds = ( - output_img.max() <= dtype_max - and output_img.min() >= dtype_min + output_img.max() <= dtype_max and output_img.min() >= dtype_min ) if is_in_bounds: output_img = output_img.astype(original_dtype) - method = 'output_img.astype(original_dtype)' - warning = 'safe' # if weights were set correctly' - prefix = '[WARNING]: ' + method = "output_img.astype(original_dtype)" + warning = "safe" # if weights were set correctly' + prefix = "[WARNING]: " else: output_img = skimage.exposure.rescale_intensity( output_img, out_range=(dtype_min, dtype_max) ) output_img = output_img.astype(original_dtype) method = ( - 'skimage.exposure.rescale_intensity ' - '-> output_img.astype (original_dtype)' + "skimage.exposure.rescale_intensity " + "-> output_img.astype (original_dtype)" ) - warning = '!RESCALING! the image data' - prefix = '[WARNING]: ' + warning = "!RESCALING! the image data" + prefix = "[WARNING]: " txt = ( - f'{prefix}Converted output image to {original_dtype} ' - f'using {method}, which is {warning}' + f"{prefix}Converted output image to {original_dtype} " + f"using {method}, which is {warning}" ) if not return_img: _log_printl_fallback(txt, logger_func) elif output_as_segm: - output_img[output_img<0] = 0 + output_img[output_img < 0] = 0 output_img = output_img.astype(np.uint32) - + if return_img: return output_img, key, txt - - txt = f'Saving combined {"segmentation" if output_as_segm else "image"} to {save_filepath}' + + txt = f"Saving combined {'segmentation' if output_as_segm else 'image'} to {save_filepath}" _log_printl_fallback(txt, logger_func) - - io.save_image_data( # handles saving img and segm + io.save_image_data( # handles saving img and segm save_filepath, output_img ) return None + def get_selected_channels(steps): selected_channel = set() for step in steps.values(): - ch = step['channel'] - if ch == 'current segm.': + ch = step["channel"] + if ch == "current segm.": continue selected_channel.add(ch) - + return selected_channel + def split_segm_masks_mother_bud_line( - cells_segm_data, segm_data_to_split, acdc_df, - debug=False - ): - acdc_df = acdc_df.set_index(['frame_i', 'Cell_ID']) + cells_segm_data, segm_data_to_split, acdc_df, debug=False +): + acdc_df = acdc_df.set_index(["frame_i", "Cell_ID"]) split_segm_away = np.zeros_like(segm_data_to_split) split_segm_close = np.zeros_like(segm_data_to_split) - + pbar = tqdm(total=len(cells_segm_data), ncols=100, position=1, leave=False) for frame_i, lab in enumerate(cells_segm_data): rp = skimage.measure.regionprops(lab) - rp_mapper = {obj.label:obj for obj in rp} + rp_mapper = {obj.label: obj for obj in rp} for obj in rp: try: - ccs = acdc_df.at[(frame_i, obj.label), 'cell_cycle_stage'] + ccs = acdc_df.at[(frame_i, obj.label), "cell_cycle_stage"] except Exception as err: continue - - if ccs != 'S': + + if ccs != "S": continue - + try: - relationship = acdc_df.at[(frame_i, obj.label), 'relationship'] + relationship = acdc_df.at[(frame_i, obj.label), "relationship"] except Exception as err: continue - - if relationship == 'bud': + + if relationship == "bud": continue - - bud_ID = int(acdc_df.at[(frame_i, obj.label), 'relative_ID']) + + bud_ID = int(acdc_df.at[(frame_i, obj.label), "relative_ID"]) obj_bud = rp_mapper[bud_ID] - + moth_ID = obj.label yc_m, xc_m = obj.centroid yc_b, xc_b = obj_bud.centroid - - slope_mb = (yc_b - yc_m)/(xc_b - yc_b) + + slope_mb = (yc_b - yc_m) / (xc_b - yc_b) if slope_mb != 0: - slope_perp = -1/slope_mb - interc_perp = yc_m - xc_m*slope_perp + slope_perp = -1 / slope_mb + interc_perp = yc_m - xc_m * slope_perp else: slope_perp = np.inf interc_perp = np.nan - + ref_p1, ref_p2 = get_split_line_ref_points_img( lab, slope_perp, interc_perp, xc_m, yc_m ) - + if debug: from cellacdc import _debug + _debug.split_segm_masks_mother_bud_line( lab, obj, obj_bud, ref_p1, ref_p2 ) - + for z, lab_split in enumerate(segm_data_to_split[frame_i]): - lab_split_yy, lab_split_xx = np.nonzero(lab_split==obj.label) + lab_split_yy, lab_split_xx = np.nonzero(lab_split == obj.label) if len(lab_split_yy) == 0: continue - + query_points = np.column_stack((lab_split_xx, lab_split_yy)) close_to_bud_mask = classify_points_plane_split_by_line( ref_p1, ref_p2, query_points, (xc_b, yc_b) ) - + split_close_yy = lab_split_yy[close_to_bud_mask] split_close_xx = lab_split_xx[close_to_bud_mask] - - split_segm_close[frame_i, z, split_close_yy, split_close_xx] = ( - obj.label - ) - + + split_segm_close[frame_i, z, split_close_yy, split_close_xx] = obj.label + split_away_yy = lab_split_yy[~close_to_bud_mask] split_away_xx = lab_split_xx[~close_to_bud_mask] - - split_segm_away[frame_i, z, split_away_yy, split_away_xx] = ( - obj.label - ) - + + split_segm_away[frame_i, z, split_away_yy, split_away_xx] = obj.label + pbar.update() - pbar.close() - + pbar.close() + return split_segm_close, split_segm_away + def classify_points_plane_split_by_line( - p1, p2, query_points: np.ndarray, relative_to_p - ): + p1, p2, query_points: np.ndarray, relative_to_p +): """Classify points on plane crossed by a line connecting p1 and p2 relative to `relative_to_p` point @@ -2880,24 +2912,24 @@ def classify_points_plane_split_by_line( p1 : (x, y) of floats First point of the line p2 : (x, y) of floats - Second point + Second point query_points : (N, 2) np.ndarray (x, y) coordinates of the points to classify - + References ---------- https://stackoverflow.com/questions/45766534/finding-cross-product-to-find-points-above-below-a-line-in-matplotlib - """ + """ relative_p_arr = np.array([relative_to_p]) a = np.array(p1) b = np.array(p2) - - class_relative_p = (np.cross(relative_p_arr-a, b-a) <= 0).astype(int)[0] - class_query_points = (np.cross(query_points-a, b-a) <= 0).astype(int) + + class_relative_p = (np.cross(relative_p_arr - a, b - a) <= 0).astype(int)[0] + class_query_points = (np.cross(query_points - a, b - a) <= 0).astype(int) query_points_mask = class_query_points == class_relative_p - - return query_points_mask - + + return query_points_mask + def get_split_line_ref_points_img(img, slope, interc, xc, yc): Y, X = img.shape @@ -2913,46 +2945,48 @@ def get_split_line_ref_points_img(img, slope, interc, xc, yc): y_ref1 = yc else: y0 = 0 - x0 = y0 - interc/slope - + x0 = y0 - interc / slope + x1 = X - y1 = slope*x1 + interc - + y1 = slope * x1 + interc + x2 = 0 y2 = interc - + y3 = Y - x3 = (y3 - interc)/slope - + x3 = (y3 - interc) / slope + if x0 < X: x_ref_0 = x0 y_ref_0 = y0 else: x_ref_0 = x1 y_ref_0 = y1 - + if x3 > 0: x_ref1 = x3 y_ref1 = y3 else: x_ref1 = x2 y_ref1 = y2 - + return (x_ref_0, y_ref_0), (x_ref1, y_ref1) + def _compute_obj_to_all_objs_contour_dist_pairs( - input, all_objs_contours_arr=None, all_contours=None, pbar=None - ): + input, all_objs_contours_arr=None, all_contours=None, pbar=None +): j, other_obj = input - other_obj_contours = all_contours[(other_obj.label, 'None', False, False)] + other_obj_contours = all_contours[(other_obj.label, "None", False, False)] min_distances_to_other = nearest_points_objects( all_objs_contours_arr, other_obj_contours - ) + ) return other_obj.label, min_distances_to_other + def _compute_all_obj_to_obj_contour_dist_pairs( - all_contours: dict, rp, prev_rp=None, restrict_search=True - ): + all_contours: dict, rp, prev_rp=None, restrict_search=True +): if prev_rp is not None: prev_IDs = set([obj.label for obj in prev_rp]) new_IDs = set([obj.label for obj in rp if obj.label not in prev_IDs]) @@ -2963,36 +2997,32 @@ def _compute_all_obj_to_obj_contour_dist_pairs( current_rp = rp other_rp = rp num_cols = len(current_rp) - + max_distance = np.inf if restrict_search: - max_distance = 3*np.max([obj.major_axis_length for obj in rp]) - + max_distance = 3 * np.max([obj.major_axis_length for obj in rp]) + calculated_pairs = {} num_rows = len(current_rp) num_objs = len(rp) IDs = [obj.label for obj in rp] dist_matrix_df = pd.DataFrame( - index=IDs, - columns=IDs, - data=np.full((num_objs, num_objs), np.inf) - ) - len_longest_contour = np.max( - [len(contours) for contours in all_contours.values()] + index=IDs, columns=IDs, data=np.full((num_objs, num_objs), np.inf) ) + len_longest_contour = np.max([len(contours) for contours in all_contours.values()]) all_objs_contours_arr = np.full((num_rows, len_longest_contour, 2), np.nan) current_rp_mapper = {} for o, obj in enumerate(current_rp): - obj_contours = all_contours[(obj.label, 'None', False, False)] - all_objs_contours_arr[o, :len(obj_contours)] = obj_contours + obj_contours = all_contours[(obj.label, "None", False, False)] + all_objs_contours_arr[o, : len(obj_contours)] = obj_contours current_rp_mapper[o] = obj - - pbar = tqdm(total=num_rows*num_cols, ncols=100, leave=False) + + pbar = tqdm(total=num_rows * num_cols, ncols=100, leave=False) with ThreadPoolExecutor() as executor: iterable = enumerate(other_rp) - + func = partial( - _compute_obj_to_all_objs_contour_dist_pairs, + _compute_obj_to_all_objs_contour_dist_pairs, all_objs_contours_arr=all_objs_contours_arr, all_contours=all_contours, pbar=pbar, @@ -3005,39 +3035,46 @@ def _compute_all_obj_to_obj_contour_dist_pairs( return dist_matrix_df + def convexity_defects(img, eps_percent): img = img.astype(np.uint8) - contours, _ = cv2.findContours(img,2,1) + contours, _ = cv2.findContours(img, 2, 1) cnt = contours[0] - cnt = cv2.approxPolyDP(cnt,eps_percent*cv2.arcLength(cnt,True),True) # see https://www.programcreek.com/python/example/89457/cv22.convexityDefects - hull = cv2.convexHull(cnt,returnPoints = False) # see https://opencv-python-tutroals.readthedocs.io/en/latest/py_tutorials/py_imgproc/py_contours/py_contours_more_functions/py_contours_more_functions.html - defects = cv2.convexityDefects(cnt,hull) # see https://opencv-python-tutroals.readthedocs.io/en/latest/py_tutorials/py_imgproc/py_contours/py_contours_more_functions/py_contours_more_functions.html + cnt = cv2.approxPolyDP( + cnt, eps_percent * cv2.arcLength(cnt, True), True + ) # see https://www.programcreek.com/python/example/89457/cv22.convexityDefects + hull = cv2.convexHull( + cnt, returnPoints=False + ) # see https://opencv-python-tutroals.readthedocs.io/en/latest/py_tutorials/py_imgproc/py_contours/py_contours_more_functions/py_contours_more_functions.html + defects = cv2.convexityDefects( + cnt, hull + ) # see https://opencv-python-tutroals.readthedocs.io/en/latest/py_tutorials/py_imgproc/py_contours/py_contours_more_functions/py_contours_more_functions.html return cnt, defects -def split_connected_components(lab, rp=None, max_ID=None): + +def split_connected_components(lab, rp=None, max_ID=None): if rp is None: lab = skimage.measure.regionprops(lab) - + if max_ID is None: max_ID = max([obj.label for obj in rp], default=1) - + split_occured = False for obj in rp: lab_obj = skimage.measure.label(obj.image) rp_lab_obj = skimage.measure.regionprops(lab_obj) - if len(rp_lab_obj)<=1: + if len(rp_lab_obj) <= 1: continue lab_obj += max_ID - _slice = obj.slice # self.getObjSlice(obj.slice) - _objMask = obj.image # self.getObjImage(obj.image) + _slice = obj.slice # self.getObjSlice(obj.slice) + _objMask = obj.image # self.getObjImage(obj.image) lab[_slice][_objMask] = lab_obj[_objMask] split_occured = True max_ID += 1 return split_occured -def split_along_convexity_defects( - ID, lab, max_ID, max_i=1, eps_percent=0.01 - ): + +def split_along_convexity_defects(ID, lab, max_ID, max_i=1, eps_percent=0.01): lab_ID_bool = lab == ID # First try separating by labelling lab_ID = lab_ID_bool.astype(int) @@ -3058,20 +3095,18 @@ def split_along_convexity_defects( if len(defects) != 2: return lab, success, [] - defects_points = [0]*len(defects) + defects_points = [0] * len(defects) for i, defect in enumerate(defects): - s,e,f,d = defect[0] - x,y = tuple(cnt[f][0]) - defects_points[i] = (y,x) + s, e, f, d = defect[0] + x, y = tuple(cnt[f][0]) + defects_points[i] = (y, x) (r0, c0), (r1, c1) = defects_points rr, cc, _ = skimage.draw.line_aa(r0, c0, r1, c1) sep_bud_img = np.copy(lab_ID_bool) sep_bud_img[rr, cc] = False - - sep_bud_label = skimage.measure.label( - sep_bud_img, connectivity=2 - ) - + + sep_bud_label = skimage.measure.label(sep_bud_img, connectivity=2) + rp_sep = skimage.measure.regionprops(sep_bud_label) IDs_sep = [obj.label for obj in rp_sep] areas = [obj.area for obj in rp_sep] @@ -3080,16 +3115,16 @@ def split_along_convexity_defects( orig_sblab = np.copy(sep_bud_label) # sep_bud_label = np.zeros_like(sep_bud_label) ID1 = ID - ID2 = max_ID+max_i - sep_bud_label[orig_sblab==curr_ID_moth] = ID1 - sep_bud_label[orig_sblab==curr_ID_bud] = ID2 + ID2 = max_ID + max_i + sep_bud_label[orig_sblab == curr_ID_moth] = ID1 + sep_bud_label[orig_sblab == curr_ID_bud] = ID2 splittedIDs = [ID1, ID2] # sep_bud_label *= (max_ID+max_i) temp_sep_bud_lab = sep_bud_label.copy() for r, c in zip(rr, cc): if lab_ID_bool[r, c]: nearest_ID = nearest_nonzero_2D(sep_bud_label, r, c) - temp_sep_bud_lab[r,c] = nearest_ID + temp_sep_bud_lab[r, c] = nearest_ID sep_bud_label = temp_sep_bud_lab sep_bud_label_mask = sep_bud_label != 0 # plt.imshow_tk(sep_bud_label, dots_coords=np.asarray(defects_points)) @@ -3098,25 +3133,25 @@ def split_along_convexity_defects( success = True return lab, success, splittedIDs + def validate_multidimensional_recipe( - recipe: List[Dict[str, Any]], - apply_to_all_zslices=False, - apply_to_all_frames=False - ): + recipe: List[Dict[str, Any]], apply_to_all_zslices=False, apply_to_all_frames=False +): for step in recipe: - method = step['method'] - func = PREPROCESS_MAPPER[method]['function'] - kwargs = step['kwargs'] - + method = step["method"] + func = PREPROCESS_MAPPER[method]["function"] + kwargs = step["kwargs"] + argspecs = inspect.getfullargspec(func) for arg in argspecs.args: - if arg == 'apply_to_all_frames': - kwargs['apply_to_all_frames'] = apply_to_all_frames - if arg == 'apply_to_all_zslices': - kwargs['apply_to_all_zslices'] = apply_to_all_zslices - + if arg == "apply_to_all_frames": + kwargs["apply_to_all_frames"] = apply_to_all_frames + if arg == "apply_to_all_zslices": + kwargs["apply_to_all_zslices"] = apply_to_all_zslices + return recipe + def insert_missing_object(lab_dst, obj, all_dst_IDs, assignments_mapper): added_ID = assignments_mapper.get(obj.label) if obj.label not in all_dst_IDs: @@ -3136,24 +3171,23 @@ def insert_missing_object(lab_dst, obj, all_dst_IDs, assignments_mapper): # --> need to assign the same ID as before lab_dst[obj.slice][obj.image] = added_ID all_dst_IDs.add(added_ID) - + return lab_dst, assignments_mapper, all_dst_IDs -def insert_missing_objects( - segm_dst, segm_src, is_timelapse=True, display_pbar=True - ): + +def insert_missing_objects(segm_dst, segm_src, is_timelapse=True, display_pbar=True): if not is_timelapse: segm_dst = segm_dst[np.newaxis] segm_src = segm_src[np.newaxis] - + all_dst_IDs = set() for lab_dst in segm_dst: rp = skimage.measure.regionprops(lab_dst) all_dst_IDs.update([obj.label for obj in rp]) - + if display_pbar: pbar = tqdm(total=len(segm_src), ncols=100, leave=False) - + assignments_mapper = {} for frame_i, (lab_src, lab_dst) in enumerate(zip(segm_src, segm_dst)): rp = skimage.measure.regionprops(lab_src) @@ -3170,50 +3204,52 @@ def insert_missing_objects( lab_dst, assignments_mapper, all_dst_IDs = out segm_dst[frame_i] = lab_dst continue - + # Check if merged --> the masks do not coincide obj_dst = rp_dst_mapper[obj_dst_ID] is_merged = not ( len(obj_dst.coords) == len(obj.coords) and np.all(obj_dst.coords == obj.coords) ) - + if not is_merged: continue - + lab_dst, assignments_mapper, all_dst_IDs = insert_missing_object( lab_dst, obj, all_dst_IDs, assignments_mapper ) segm_dst[frame_i] = lab_dst - + if display_pbar: pbar.update() - + if display_pbar: pbar.close() - + return segm_dst - + + def process_lab(task): i, lab = task # Assuming this function processes each lab independently data_dict = {} rp = skimage.measure.regionprops(lab) IDs = [obj.label for obj in rp] - data_dict['IDs'] = IDs - data_dict['regionprops'] = rp - data_dict['IDs_idxs'] = {ID: idx for idx, ID in enumerate(IDs)} - + data_dict["IDs"] = IDs + data_dict["regionprops"] = rp + data_dict["IDs_idxs"] = {ID: idx for idx, ID in enumerate(IDs)} + return i, data_dict, IDs # Return index, data_dict, and IDs + def parallel_count_objects(posData, logger_func): benchmark = True - #futile attempt to use multiprocessing to speed things up - logger_func('Counting total number of segmented objects...') - + # futile attempt to use multiprocessing to speed things up + logger_func("Counting total number of segmented objects...") + allIDs = set() seg_data = posData.segm_data - + # Initialize empty data dictionary to avoid recalculating each time tasks = [(i, lab) for i, lab in enumerate(seg_data)] @@ -3222,22 +3258,25 @@ def parallel_count_objects(posData, logger_func): # Process in batches to optimize memory usage and control parallelism with ThreadPoolExecutor() as executor: futures = [executor.submit(process_lab, task) for task in tasks] - + # Process results as they are completed for future in tqdm(as_completed(futures), total=len(futures), ncols=100): i, data_dict, IDs = future.result() - posData.allData_li[i] = myutils.get_empty_stored_data_dict() # or directly assign if it's mutable - posData.allData_li[i]['IDs'] = data_dict['IDs'] - posData.allData_li[i]['regionprops'] = data_dict['regionprops'] - posData.allData_li[i]['IDs_idxs'] = data_dict['IDs_idxs'] + posData.allData_li[i] = ( + utils.get_empty_stored_data_dict() + ) # or directly assign if it's mutable + posData.allData_li[i]["IDs"] = data_dict["IDs"] + posData.allData_li[i]["regionprops"] = data_dict["regionprops"] + posData.allData_li[i]["IDs_idxs"] = data_dict["IDs_idxs"] allIDs.update(IDs) - + if benchmark: t1 = time.perf_counter() - logger_func(f'Counting objects took {(t1 - t0)*1000:.2f} ms') + logger_func(f"Counting objects took {(t1 - t0) * 1000:.2f} ms") return allIDs, posData + def count_objects(posData, logger_func): benchmark = False @@ -3247,18 +3286,18 @@ def count_objects(posData, logger_func): if not np.any(segm_data): allIDs = [] return allIDs, posData - - logger_func('Counting total number of segmented objects...') + + logger_func("Counting total number of segmented objects...") pbar = tqdm(total=len(segm_data), ncols=100) if benchmark: t0 = time.perf_counter() for i, lab in enumerate(segm_data): - posData.allData_li[i] = myutils.get_empty_stored_data_dict() + posData.allData_li[i] = utils.get_empty_stored_data_dict() rp = skimage.measure.regionprops(lab) IDs = [obj.label for obj in rp] - posData.allData_li[i]['IDs'] = IDs - posData.allData_li[i]['regionprops'] = rp - posData.allData_li[i]['IDs_idxs'] = { # IDs_idxs[obj.label] = idx + posData.allData_li[i]["IDs"] = IDs + posData.allData_li[i]["regionprops"] = rp + posData.allData_li[i]["IDs_idxs"] = { # IDs_idxs[obj.label] = idx ID: idx for idx, ID in enumerate(IDs) } allIDs.update(IDs) @@ -3266,9 +3305,10 @@ def count_objects(posData, logger_func): pbar.close() if benchmark: t1 = time.perf_counter() - logger_func(f'Counting objects took {(t1 - t0)*1000:.2f} ms') + logger_func(f"Counting objects took {(t1 - t0) * 1000:.2f} ms") return allIDs, posData + def fix_sparse_directML(verbose=True): """DirectML does not support sparse tensors, so we need to fallback to CPU. This function replaces `torch.sparse_coo_tensor`, `torch._C._sparse_coo_tensor_unsafe`, @@ -3284,9 +3324,9 @@ def fix_sparse_directML(verbose=True): import warnings def fallback_to_cpu_on_sparse_error(func, verbose=True): - @functools.wraps(func) # wrapper shinanigans (thanks chatgpt) + @functools.wraps(func) # wrapper shinanigans (thanks chatgpt) def wrapper(*args, **kwargs): - device_arg = kwargs.get('device', None) # get desired device from kwargs + device_arg = kwargs.get("device", None) # get desired device from kwargs # Ensure indices are int64 if args[0] looks like indices, # I got random errors from it not being int64 @@ -3294,65 +3334,87 @@ def wrapper(*args, **kwargs): if args[0].dtype != torch.int64: args = (args[0].to(dtype=torch.int64),) + args[1:] - try: # try to perform the operation and move to dml if possible - result = func(*args, **kwargs) # run function with current args and kwargs + try: # try to perform the operation and move to dml if possible + result = func( + *args, **kwargs + ) # run function with current args and kwargs if device_arg is not None and str(device_arg).lower() == "dml": - try: # try to move result to dml + try: # try to move result to dml result.to("dml") - except RuntimeError as e: # moving failed, falling back to cpu + except RuntimeError as e: # moving failed, falling back to cpu if verbose: - warnings.warn(f"Sparse op failed on DirectML, falling back to CPU: {e}") - kwargs['device'] = torch.device("cpu") - return func(*args, **kwargs) # try again, after setting device to cpu - return result # just return result if all worked well - - except RuntimeError as e: # try and run on dlm, if it fails, fallback to cpu + warnings.warn( + f"Sparse op failed on DirectML, falling back to CPU: {e}" + ) + kwargs["device"] = torch.device("cpu") + return func( + *args, **kwargs + ) # try again, after setting device to cpu + return result # just return result if all worked well + + except ( + RuntimeError + ) as e: # try and run on dlm, if it fails, fallback to cpu if "sparse" in str(e).lower() or "not implemented" in str(e).lower(): if verbose: - warnings.warn(f"Sparse op failed on DirectML, falling back to CPU: {e}") - kwargs['device'] = torch.device("cpu") # if rutime warning caused by sparse tensor, set device to cpu + warnings.warn( + f"Sparse op failed on DirectML, falling back to CPU: {e}" + ) + kwargs["device"] = torch.device( + "cpu" + ) # if rutime warning caused by sparse tensor, set device to cpu # Re-apply indices dtype correction before retrying on CPU. Just in case (maybe first one not needed?) if len(args) >= 1 and isinstance(args[0], torch.Tensor): if args[0].dtype != torch.int64: args = (args[0].to(dtype=torch.int64),) + args[1:] - return func(*args, **kwargs) # run function again with cpu device + return func(*args, **kwargs) # run function again with cpu device else: - raise e # catch and other runtime errors + raise e # catch and other runtime errors return wrapper # --- Patch Sparse Tensor Constructors --- # High-level API - torch.sparse_coo_tensor = fallback_to_cpu_on_sparse_error(torch.sparse_coo_tensor, verbose=verbose) + torch.sparse_coo_tensor = fallback_to_cpu_on_sparse_error( + torch.sparse_coo_tensor, verbose=verbose + ) # Low-level API if hasattr(torch._C, "_sparse_coo_tensor_unsafe"): - torch._C._sparse_coo_tensor_unsafe = fallback_to_cpu_on_sparse_error(torch._C._sparse_coo_tensor_unsafe, verbose=verbose) + torch._C._sparse_coo_tensor_unsafe = fallback_to_cpu_on_sparse_error( + torch._C._sparse_coo_tensor_unsafe, verbose=verbose + ) if hasattr(torch._C, "_sparse_coo_tensor_with_dims_and_tensors"): - torch._C._sparse_coo_tensor_with_dims_and_tensors = fallback_to_cpu_on_sparse_error( - torch._C._sparse_coo_tensor_with_dims_and_tensors, verbose=verbose + torch._C._sparse_coo_tensor_with_dims_and_tensors = ( + fallback_to_cpu_on_sparse_error( + torch._C._sparse_coo_tensor_with_dims_and_tensors, verbose=verbose + ) + ) + + if hasattr(torch.sparse, "SparseTensor"): + torch.sparse.SparseTensor = fallback_to_cpu_on_sparse_error( + torch.sparse.SparseTensor, verbose=verbose ) - if hasattr(torch.sparse, 'SparseTensor'): - torch.sparse.SparseTensor = fallback_to_cpu_on_sparse_error(torch.sparse.SparseTensor, verbose=verbose) - # suppress warnings if not verbose: import warnings + warnings.filterwarnings("once", message="Sparse op failed on DirectML*") -def connected_components_in_undirected_graph(undirected_graph:dict): + +def connected_components_in_undirected_graph(undirected_graph: dict): # Build undirected graph graph = defaultdict(set) for key, val in undirected_graph.items(): for other in val: graph[key].add(other) graph[other].add(key) # Make it bidirectional - + visited = set() groups = [] @@ -3361,31 +3423,34 @@ def dfs(node, group): group.append(node) for neighbor in graph[node]: if neighbor not in visited: - dfs(neighbor, group) # recursive call to visit neighbors + dfs(neighbor, group) # recursive call to visit neighbors for key in graph: if key not in visited: group = [] dfs(key, group) groups.append(group) - + return groups -def apply_func_to_imgs(image:np.ndarray, - func: Callable, - *args, - workers: int = 10, - iter_axis:List[int]|int= None, - target_shape:List[int] = None, - target_type: type = None, - target_axis_iter: List[int]|int = None, - parallel: bool = True, - benchmark: bool = False, - processpool: bool = False, - **kwargs): + +def apply_func_to_imgs( + image: np.ndarray, + func: Callable, + *args, + workers: int = 10, + iter_axis: List[int] | int = None, + target_shape: List[int] = None, + target_type: type = None, + target_axis_iter: List[int] | int = None, + parallel: bool = True, + benchmark: bool = False, + processpool: bool = False, + **kwargs, +): """Apply a function to each image. This is done along the iter_axis (can also be a single int). Then the processed image is put in the target_axis_iter (can also be a single int). - (If target_axis_iter, target_shape or target_type are None, + (If target_axis_iter, target_shape or target_type are None, they are taken from the input image). Example of iter_axis: [0, 1] and target_axis_iter: [1, 0] means that the function is applied to each [0, 1, ...] slice of the input and the processed image is put in the [1, 0, ...] slice of the output image. @@ -3397,11 +3462,11 @@ def apply_func_to_imgs(image:np.ndarray, ---------- image : np.ndarray Image to be processed - + func : Callable Function to be applied to each image. First argument should be the image itself, one kwarg should be `frame_index_out`, - should `return processed_image, frame_index_out`. `frame_index_out` just needs to + should `return processed_image, frame_index_out`. `frame_index_out` just needs to be passed along, no need to slice the image in `func`. *args : tuple Additional arguments to be passed to the function @@ -3444,30 +3509,29 @@ def apply_func_to_imgs(image:np.ndarray, out = func(image, *args, **kwargs, frame_index_out=None)[1] if benchmark: t1 = time.perf_counter() - printl(f"Processing time: {(t1 - t0)*1000:.2f} ms, no parallel since iter_axis is None") - return out - + printl( + f"Processing time: {(t1 - t0) * 1000:.2f} ms, no parallel since iter_axis is None" + ) + return out + if isinstance(iter_axis, int): iter_axis = [iter_axis] - + if isinstance(target_axis_iter, int): iter_axis = [target_axis_iter] - + if target_axis_iter is None: target_axis_iter = iter_axis - + if target_shape is None: target_shape = image_shape - + if target_type is None: target_type = type(image.flat[0]) - - image_out = np.empty( - target_shape, dtype=target_type - ) + image_out = np.empty(target_shape, dtype=target_type) - input_output_mapper = myutils.get_input_output_mapper( + input_output_mapper = utils.get_input_output_mapper( image_shape, iter_axis, target_shape, target_axis_iter ) @@ -3478,25 +3542,25 @@ def apply_func_to_imgs(image:np.ndarray, executor_func = ThreadPoolExecutor with executor_func() as executor: futures = { - executor.submit(func, image[i_in], *args, frame_index_out=i_out, **kwargs) + executor.submit( + func, image[i_in], *args, frame_index_out=i_out, **kwargs + ) for i_in, i_out in input_output_mapper } for future in tqdm( - as_completed(futures), - total=len(futures), - desc="Processing frames" - ): + as_completed(futures), total=len(futures), desc="Processing frames" + ): i, processed = future.result() image_out[i] = processed else: for i_in, i_out in tqdm(input_output_mapper, desc="Processing frames"): processed = func(image[i_in], *args, frame_index_out=i_out, **kwargs)[1] image_out[i_out] = processed - + if benchmark: t1 = time.perf_counter() - printl(f"Processing time: {(t1 - t0)*1000:.2f} ms") + printl(f"Processing time: {(t1 - t0) * 1000:.2f} ms") return image_out @@ -3506,7 +3570,7 @@ def fill_holes_in_segmentation(labels): for obj in skimage.measure.regionprops(labels): label_id = obj.label mask_filled = scipy.ndimage.binary_fill_holes(obj.image) - + region = filled[obj.slice] # Only fill where mask_filled is True and region is still background fill_mask = mask_filled & (region == 0) @@ -3515,25 +3579,24 @@ def fill_holes_in_segmentation(labels): return filled -def natsort_acdc_columns( - columns: Iterable[str], - prepend_default_index_cols=True - ): + +def natsort_acdc_columns(columns: Iterable[str], prepend_default_index_cols=True): sorted_cols = natsorted(columns, key=str.casefold) if not prepend_default_index_cols: return sorted_cols - + cols_to_prepend = [] for col in default_index_cols: if col not in sorted_cols: continue - + sorted_cols.remove(col) cols_to_prepend.append(col) - + sorted_cols = [*cols_to_prepend, *sorted_cols] return sorted_cols + def linear_fit_3d(xx, yy, zz): points = np.column_stack((xx, yy, zz)) centroid = points.mean(axis=0) @@ -3543,42 +3606,44 @@ def linear_fit_3d(xx, yy, zz): return centroid, d + def binary_fill_holes(mask, slice_by_slice=True): if not slice_by_slice: mask = scipy.ndimage.binary_fill_holes(mask) return mask - + if mask.ndim == 2: mask = scipy.ndimage.binary_fill_holes(mask) return mask - + for z, mask_z in enumerate(mask): if not np.any(mask_z): continue - + mask[z] = scipy.ndimage.binary_fill_holes(mask_z) - + return mask + def convex_hull_mask(mask: np.ndarray, slice_by_slice=True): if not slice_by_slice: mask = skimage.morphology.convex_hull_image(mask) return mask - + if mask.ndim == 2: mask = skimage.morphology.convex_hull_image(mask) return mask - + mask_rp = skimage.measure.regionprops(mask.astype(np.uint8)) if len(mask_rp) == 0: return mask - + mask_obj = mask_rp[0] for z, mask_obj_img_z in enumerate(mask_obj.image): if not np.any(mask_obj_img_z): continue - + mask_obj_hull_z = skimage.morphology.convex_hull_image(mask_obj_img_z) mask[mask_obj.slice][z] = mask_obj_hull_z - - return mask \ No newline at end of file + + return mask diff --git a/cellacdc/data.py b/cellacdc/data.py index da20e2398..19ab6f14f 100644 --- a/cellacdc/data.py +++ b/cellacdc/data.py @@ -6,33 +6,33 @@ from . import data_path, load, base_cca_dict, cca_df_colnames + class _Data: def __init__( - self, images_path, intensity_image_path, acdc_df_path, segm_path, - basename - ): + self, images_path, intensity_image_path, acdc_df_path, segm_path, basename + ): self.images_path = images_path self.intensity_image_path = intensity_image_path self.acdc_df_path = acdc_df_path self.segm_path = segm_path self.basename = basename - + def filename(self): return os.path.basename(self.intensity_image_path) - + def channel_name(self): filename, ext = os.path.splitext(self.filename()) - return filename[len(self.basename):] - + return filename[len(self.basename) :] + def acdc_df(self): return load._load_acdc_df_file(self.acdc_df_path) - + def image_data(self): return load.load_image_file(self.intensity_image_path) - + def segm_data(self): - return np.load(self.segm_path)['arr_0'] - + return np.load(self.segm_path)["arr_0"] + def cca_df(self): acdc_df = load._load_acdc_df_file(self.acdc_df_path).dropna() cca_df = acdc_df[cca_df_colnames] @@ -40,103 +40,100 @@ def cca_df(self): cca_df = cca_df.astype(dtypes) return cca_df + class FissionYeastAnnotated(_Data): def __init__(self): images_path = os.path.join( - data_path, 'test_symm_div_acdc_tracker', 'Images', - ) - intensity_image_path = os.path.join( - images_path, 'bknapp_Movie_S1.tif' - ) - acdc_df_path = os.path.join( - images_path, 'bknapp_Movie_S1_acdc_output.csv' - ) - segm_path = os.path.join( - images_path, 'bknapp_Movie_S1_segm.npz' - ) - basename = 'bknapp_Movie_S1_' + data_path, + "test_symm_div_acdc_tracker", + "Images", + ) + intensity_image_path = os.path.join(images_path, "bknapp_Movie_S1.tif") + acdc_df_path = os.path.join(images_path, "bknapp_Movie_S1_acdc_output.csv") + segm_path = os.path.join(images_path, "bknapp_Movie_S1_segm.npz") + basename = "bknapp_Movie_S1_" super().__init__( - images_path, intensity_image_path, acdc_df_path, segm_path, - basename + images_path, intensity_image_path, acdc_df_path, segm_path, basename ) - + def posData(self): from . import load - return load.loadData(self.intensity_image_path, '') - + + return load.loadData(self.intensity_image_path, "") + + class DeepSeaAnnotation(_Data): def __init__(self): images_path = os.path.join( - data_path, 'deep_sea', 'Images', - ) - intensity_image_path = os.path.join( - images_path, 'set_22_MESC.tif' - ) - acdc_df_path = os.path.join( - images_path, 'set_22_MESC_acdc_output.csv' - ) - segm_path = os.path.join( - images_path, 'set_22_MESC_segm.tif' - ) - basename = 'set_22_MESC_' + data_path, + "deep_sea", + "Images", + ) + intensity_image_path = os.path.join(images_path, "set_22_MESC.tif") + acdc_df_path = os.path.join(images_path, "set_22_MESC_acdc_output.csv") + segm_path = os.path.join(images_path, "set_22_MESC_segm.tif") + basename = "set_22_MESC_" super().__init__( - images_path, intensity_image_path, acdc_df_path, segm_path, - basename + images_path, intensity_image_path, acdc_df_path, segm_path, basename ) - + def posData(self): from . import load - return load.loadData(self.intensity_image_path, '') + + return load.loadData(self.intensity_image_path, "") + class YeastTimeLapseAnnotated(_Data): def __init__(self): images_path = os.path.join( - data_path, 'test_timelapse', 'Yagya_Kurt_presentation', - 'Position_6', 'Images' + data_path, + "test_timelapse", + "Yagya_Kurt_presentation", + "Position_6", + "Images", ) intensity_image_path = os.path.join( - images_path, 'SCGE_5strains_23092021_Dia_Ph3.tif' + images_path, "SCGE_5strains_23092021_Dia_Ph3.tif" ) acdc_df_path = os.path.join( - images_path, 'SCGE_5strains_23092021_acdc_output.csv' - ) - segm_path = os.path.join( - images_path, 'SCGE_5strains_23092021_segm.npz' + images_path, "SCGE_5strains_23092021_acdc_output.csv" ) - basename = 'SCGE_5strains_23092021_' + segm_path = os.path.join(images_path, "SCGE_5strains_23092021_segm.npz") + basename = "SCGE_5strains_23092021_" super().__init__( - images_path, intensity_image_path, acdc_df_path, segm_path, - basename + images_path, intensity_image_path, acdc_df_path, segm_path, basename ) - + def posData(self): from . import load - return load.loadData(self.intensity_image_path, 'Dia_Ph3') + + return load.loadData(self.intensity_image_path, "Dia_Ph3") + class pomBseenDualChannelData(_Data): def __init__(self): images_path = os.path.join( - data_path, 'test_pomBseen', 'dual_channel', - 'Position_3', 'Images' + data_path, "test_pomBseen", "dual_channel", "Position_3", "Images" ) intensity_image_path = os.path.join( - images_path, 'Demo_two_channel_input_image_BF.tif' + images_path, "Demo_two_channel_input_image_BF.tif" ) acdc_df_path = os.path.join( - images_path, 'Demo_two_channel_input_image_acdc_output.csv' + images_path, "Demo_two_channel_input_image_acdc_output.csv" ) segm_path = os.path.join( - images_path, 'Demo_two_channel_input_image_segm_bf.npz' + images_path, "Demo_two_channel_input_image_segm_bf.npz" ) - basename = 'Demo_two_channel_input_image_' + basename = "Demo_two_channel_input_image_" super().__init__( - images_path, intensity_image_path, acdc_df_path, segm_path, - basename + images_path, intensity_image_path, acdc_df_path, segm_path, basename ) - + def posData(self): from . import load - return load.loadData(self.intensity_image_path, 'BF') + + return load.loadData(self.intensity_image_path, "BF") + class _YeastTimeLapseAnnotatedJordan(_Data): def __init__(self, custom_data_path=None): @@ -145,162 +142,160 @@ def __init__(self, custom_data_path=None): else: _data_path = data_path images_path = os.path.join( - _data_path, 'gh_issue_394_Jordan', 'Position_1', 'Images' + _data_path, "gh_issue_394_Jordan", "Position_1", "Images" ) intensity_image_path = os.path.join( - images_path, '220630_JX_MS380_2ng-uL-aTc_pos04_g_s1_Phase.tif' + images_path, "220630_JX_MS380_2ng-uL-aTc_pos04_g_s1_Phase.tif" ) acdc_df_path = os.path.join( - images_path, '220630_JX_MS380_2ng-uL-aTc_pos04_g_s1_acdc_output.csv' + images_path, "220630_JX_MS380_2ng-uL-aTc_pos04_g_s1_acdc_output.csv" ) segm_path = os.path.join( - images_path, '220630_JX_MS380_2ng-uL-aTc_pos04_g_s1_segm.npz' + images_path, "220630_JX_MS380_2ng-uL-aTc_pos04_g_s1_segm.npz" ) - basename = '220630_JX_MS380_2ng-uL-aTc_pos04_g_s1_' + basename = "220630_JX_MS380_2ng-uL-aTc_pos04_g_s1_" super().__init__( - images_path, intensity_image_path, acdc_df_path, segm_path, - basename + images_path, intensity_image_path, acdc_df_path, segm_path, basename ) - + def posData(self): from . import load - return load.loadData(self.intensity_image_path, 'Dia_Ph3') + + return load.loadData(self.intensity_image_path, "Dia_Ph3") + class Cdc42TimeLapseData(_Data): def __init__(self): images_path = os.path.join( - data_path, 'test_timelapse', 'Kurt_ring', 'Cdc42', - 'Position_1', 'Images' + data_path, "test_timelapse", "Kurt_ring", "Cdc42", "Position_1", "Images" ) intensity_image_path = os.path.join( - images_path, 'SCGE_DLY16570_1-15_DLY16571_16-30_corr_s01_Dia_Ph3.tif' + images_path, "SCGE_DLY16570_1-15_DLY16571_16-30_corr_s01_Dia_Ph3.tif" ) acdc_df_path = os.path.join( - images_path, 'SCGE_DLY16570_1-15_DLY16571_16-30_corr_s01_acdc_output.csv' + images_path, "SCGE_DLY16570_1-15_DLY16571_16-30_corr_s01_acdc_output.csv" ) segm_path = os.path.join( - images_path, 'SCGE_DLY16570_1-15_DLY16571_16-30_corr_s01_segm.npz' + images_path, "SCGE_DLY16570_1-15_DLY16571_16-30_corr_s01_segm.npz" ) - basename = 'SCGE_DLY16570_1-15_DLY16571_16-30_corr_s01_' + basename = "SCGE_DLY16570_1-15_DLY16571_16-30_corr_s01_" super().__init__( - images_path, intensity_image_path, acdc_df_path, segm_path, - basename + images_path, intensity_image_path, acdc_df_path, segm_path, basename ) self.intensity_image_path = intensity_image_path - + def cdc42_data(self): - return load.imread(os.path.join( - self.images_path, - 'SCGE_DLY16570_1-15_DLY16571_16-30_corr_s01_tdTomato_Ph3__YEAST.tif' - )) - + return load.imread( + os.path.join( + self.images_path, + "SCGE_DLY16570_1-15_DLY16571_16-30_corr_s01_tdTomato_Ph3__YEAST.tif", + ) + ) + def posData(self): from . import load - return load.loadData(self.intensity_image_path, 'Ph3__YEAST') + + return load.loadData(self.intensity_image_path, "Ph3__YEAST") + class YeastMitoTimelapse(_Data): def __init__(self): images_path = os.path.join( - data_path, 'test_4D', 'Lisa_mito', 'Position_5', 'Images' + data_path, "test_4D", "Lisa_mito", "Position_5", "Images" ) intensity_image_path = os.path.join( - images_path, 'Point0019_ChannelGFP,mCardinal,Ph-3_Seq0019_s5_Ph_3.tif' + images_path, "Point0019_ChannelGFP,mCardinal,Ph-3_Seq0019_s5_Ph_3.tif" ) acdc_df_path = os.path.join( - images_path, 'Point0019_ChannelGFP,mCardinal,Ph-3_Seq0019_s5_acdc_output.csv' + images_path, + "Point0019_ChannelGFP,mCardinal,Ph-3_Seq0019_s5_acdc_output.csv", ) segm_path = os.path.join( - images_path, 'Point0019_ChannelGFP,mCardinal,Ph-3_Seq0019_s5_segm.npz' + images_path, "Point0019_ChannelGFP,mCardinal,Ph-3_Seq0019_s5_segm.npz" ) - basename = 'Point0019_ChannelGFP,mCardinal,Ph-3_Seq0019_s5_' + basename = "Point0019_ChannelGFP,mCardinal,Ph-3_Seq0019_s5_" super().__init__( - images_path, intensity_image_path, acdc_df_path, segm_path, - basename + images_path, intensity_image_path, acdc_df_path, segm_path, basename ) - + def mito_segm(self): - return np.load(os.path.join( - self.images_path, - 'Point0019_ChannelGFP,mCardinal,Ph-3_Seq0019_s5_GFP_segm_mask_otsu.npz' - ))['arr_0'] - + return np.load( + os.path.join( + self.images_path, + "Point0019_ChannelGFP,mCardinal,Ph-3_Seq0019_s5_GFP_segm_mask_otsu.npz", + ) + )["arr_0"] + def cells_3D_segm(self): - return np.load(os.path.join( - self.images_path, - 'Point0019_ChannelGFP,mCardinal,Ph-3_Seq0019_s5_segm_7slices.npz' - ))['arr_0'] + return np.load( + os.path.join( + self.images_path, + "Point0019_ChannelGFP,mCardinal,Ph-3_Seq0019_s5_segm_7slices.npz", + ) + )["arr_0"] + class BABYtestData(_Data): def __init__(self): images_path = os.path.join( - data_path, 'test_BABY', 'evolve_testG_Brightfield', 'Position_1', - 'Images' - ) - intensity_image_path = os.path.join( - images_path, 'evolve_testG_Brightfield.tif' - ) - acdc_df_path = os.path.join( - images_path, 'evolve_testG_acdc_output.csv' + data_path, "test_BABY", "evolve_testG_Brightfield", "Position_1", "Images" ) - segm_path = os.path.join( - images_path, 'evolve_testG_segm.npz' - ) - basename = 'evolve_testG_' + intensity_image_path = os.path.join(images_path, "evolve_testG_Brightfield.tif") + acdc_df_path = os.path.join(images_path, "evolve_testG_acdc_output.csv") + segm_path = os.path.join(images_path, "evolve_testG_segm.npz") + basename = "evolve_testG_" super().__init__( - images_path, intensity_image_path, acdc_df_path, segm_path, - basename + images_path, intensity_image_path, acdc_df_path, segm_path, basename ) - + def posData(self): from . import load - return load.loadData(self.intensity_image_path, 'Brightfield') + + return load.loadData(self.intensity_image_path, "Brightfield") + class YeastMitoSnapshotData(_Data): def __init__(self): images_path = os.path.join( - data_path, 'test_snapshots', 'mtDNA_Anika', 'Position_10', - 'Images' - ) - intensity_image_path = os.path.join( - images_path, 'ASY15-1_0nM-10_s10_mNeon.tif' - ) - acdc_df_path = os.path.join( - images_path, 'ASY15-1_0nM-10_s10_acdc_output.csv' + data_path, "test_snapshots", "mtDNA_Anika", "Position_10", "Images" ) - segm_path = os.path.join( - images_path, 'ASY15-1_0nM-10_s10_segm.npz' - ) - basename = 'ASY15-1_0nM-10_s10_' + intensity_image_path = os.path.join(images_path, "ASY15-1_0nM-10_s10_mNeon.tif") + acdc_df_path = os.path.join(images_path, "ASY15-1_0nM-10_s10_acdc_output.csv") + segm_path = os.path.join(images_path, "ASY15-1_0nM-10_s10_segm.npz") + basename = "ASY15-1_0nM-10_s10_" super().__init__( - images_path, intensity_image_path, acdc_df_path, segm_path, - basename + images_path, intensity_image_path, acdc_df_path, segm_path, basename ) - + def posData(self): from . import load - return load.loadData(self.intensity_image_path, 'mNeon') + + return load.loadData(self.intensity_image_path, "mNeon") + class MIA_KC_htb1_mCitrine(_Data): def __init__(self): images_path = os.path.join( - data_path, 'budding_yeast', 'TimeLapse_2D', - 'MIA_KC_htb1_mCitrine_labeled', 'Position_2', 'Images' + data_path, + "budding_yeast", + "TimeLapse_2D", + "MIA_KC_htb1_mCitrine_labeled", + "Position_2", + "Images", ) intensity_image_path = os.path.join( - images_path, '19-03-2021_KCY050_SCGE_s02_phase_contr.tif' + images_path, "19-03-2021_KCY050_SCGE_s02_phase_contr.tif" ) acdc_df_path = os.path.join( - images_path, '19-03-2021_KCY050_SCGE_s02_acdc_output.csv' - ) - segm_path = os.path.join( - images_path, '19-03-2021_KCY050_SCGE_s02_segm.npz' + images_path, "19-03-2021_KCY050_SCGE_s02_acdc_output.csv" ) - basename = '19-03-2021_KCY050_SCGE_s02_' + segm_path = os.path.join(images_path, "19-03-2021_KCY050_SCGE_s02_segm.npz") + basename = "19-03-2021_KCY050_SCGE_s02_" super().__init__( - images_path, intensity_image_path, acdc_df_path, segm_path, - basename + images_path, intensity_image_path, acdc_df_path, segm_path, basename ) def posData(self): from . import load - return load.loadData(self.intensity_image_path, 'phase_contr') \ No newline at end of file + + return load.loadData(self.intensity_image_path, "phase_contr") diff --git a/cellacdc/dataPrep.py b/cellacdc/dataPrep.py index 10b558445..9b01102b7 100755 --- a/cellacdc/dataPrep.py +++ b/cellacdc/dataPrep.py @@ -18,42 +18,69 @@ from tifffile.tifffile import TiffWriter, TiffFile from qtpy.QtCore import ( - Qt, QFile, QEventLoop, QSize, QRect, QRectF, - QObject, QThread, Signal, QSettings, QMutex, QWaitCondition + Qt, + QFile, + QEventLoop, + QSize, + QRect, + QRectF, + QObject, + QThread, + Signal, + QSettings, + QMutex, + QWaitCondition, ) from qtpy.QtGui import ( - QIcon, QKeySequence, QCursor, QTextBlockFormat, - QTextCursor, QFont + QIcon, + QKeySequence, + QCursor, + QTextBlockFormat, + QTextCursor, + QFont, ) from qtpy.QtWidgets import ( - QAction, QLabel, QWidget, QMainWindow, QMenu, QToolBar, QGridLayout, - QScrollBar, QComboBox, QFileDialog, QAbstractSlider, QMessageBox + QAction, + QLabel, + QWidget, + QMainWindow, + QMenu, + QToolBar, + QGridLayout, + QScrollBar, + QComboBox, + QFileDialog, + QAbstractSlider, + QMessageBox, ) from qtpy.compat import getexistingdirectory import pyqtgraph as pg -pg.setConfigOption('imageAxisOrder', 'row-major') + +pg.setConfigOption("imageAxisOrder", "row-major") # Custom modules from . import exception_handler -from . import load, prompts, apps, core, myutils +from . import load, prompts, apps, core, utils from . import widgets -from . import html_utils, myutils, darkBkgrColor, printl +from . import html_utils, utils, darkBkgrColor, printl from . import autopilot, workers from . import recentPaths_path from . import urls from . import io from .help import about -if os.name == 'nt': +if os.name == "nt": try: # Set taskbar icon in windows import ctypes - myappid = 'schmollerlab.cellacdc.pyqt.v1' # arbitrary string + + myappid = "schmollerlab.cellacdc.pyqt.v1" # arbitrary string ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(myappid) except Exception as e: pass + class toCsvWorker(QObject): finished = Signal() progress = Signal(int) @@ -66,29 +93,28 @@ def run(self): posData.segmInfo_df.to_csv(posData.segmInfo_df_csv_path) self.finished.emit() + class dataPrepWin(QMainWindow): sigClose = Signal(object) - def __init__( - self, parent=None, buttonToRestore=None, mainWin=None, - version=None - ): + def __init__(self, parent=None, buttonToRestore=None, mainWin=None, version=None): from .config import parser_args - self.debug = parser_args['debug'] + + self.debug = parser_args["debug"] super().__init__(parent) self._version = version - logger, logs_path, log_path, log_filename = myutils.setupLogger( - module='dataPrep' + logger, logs_path, log_path, log_filename = utils.setupLogger( + module="dataPrep" ) self.logger = logger if self._version is not None: - logger.info(f'Initializing Data Prep module v{self._version}...') + logger.info(f"Initializing Data Prep module v{self._version}...") else: - logger.info(f'Initializing Data Prep module...') + logger.info(f"Initializing Data Prep module...") self.log_path = log_path self.log_filename = log_filename @@ -99,7 +125,7 @@ def __init__( if mainWin is not None: self.app = mainWin.app - self._acdc_version = myutils.read_version() + self._acdc_version = utils.read_version() self.setWindowTitle(f"Cell-ACDC v{self._acdc_version} - data prep") self.setGeometry(100, 50, 850, 800) self.setWindowIcon(QIcon(":icon.ico")) @@ -149,17 +175,15 @@ def keyPressEvent(self, event): printl(self.freeRoiItem.bbox()) printl(self.freeRoiItem.mask().shape) from cellacdc.plot import imshow - - imshow( - posData.dataPrepFreeRoiLocalMask, self.freeRoiItem.mask() - ) + + imshow(posData.dataPrepFreeRoiLocalMask, self.freeRoiItem.mask()) cropROI = posData.cropROIs[0] x0, y0 = [int(round(c)) for c in cropROI.pos()] w, h = [int(round(c)) for c in cropROI.size()] - x1, y1 = x0+w, y0+h - + x1, y1 = x0 + w, y0 + h + printl(x0, y0) - + # printl(posData.all_npz_paths) # printl(posData.tif_paths) # for r, roi in enumerate(posData.bkgrROIs): @@ -191,15 +215,12 @@ def keyPressEvent(self, event): def gui_createActions(self): # File actions self.aboutAction = QAction("About Cell-ACDC", self) - self.infoAction = ( - QAction(QIcon(":info.svg"), "&How to prep the data...", self) - ) - self.openFolderAction = QAction( - QIcon(":folder-open.svg"), "&Open...", self - ) + self.infoAction = QAction(QIcon(":info.svg"), "&How to prep the data...", self) + self.openFolderAction = QAction(QIcon(":folder-open.svg"), "&Open...", self) self.exitAction = QAction("&Exit", self) - self.showInExplorerAction = QAction(QIcon(":drawer.svg"), - "&Show in Explorer/Finder", self) + self.showInExplorerAction = QAction( + QIcon(":drawer.svg"), "&Show in Explorer/Finder", self + ) self.showInExplorerAction.setDisabled(True) # Toolbar actions @@ -214,7 +235,7 @@ def gui_createActions(self): self.loadPosAction = QAction("Load different Position...", self) self.loadPosAction.setShortcut("Shift+P") - + toolTip = ( "Add crop ROI for multiple crops\n\n" "Multiple crops will be saved as Position_1, Position_2 " @@ -227,87 +248,85 @@ def gui_createActions(self): self.addBkrgRoiActon = QAction(QIcon(":bkgrRoi.svg"), toolTip, self) self.addBkrgRoiActon.setDisabled(True) - self.ZbackAction = QAction(QIcon(":zback.svg"), - "Use same z-slice from first frame to here", - self) + self.ZbackAction = QAction( + QIcon(":zback.svg"), "Use same z-slice from first frame to here", self + ) self.ZbackAction.setEnabled(False) - self.ZforwAction = QAction(QIcon(":zforw.svg"), - "Use same z-slice from here to last frame", - self) + self.ZforwAction = QAction( + QIcon(":zforw.svg"), "Use same z-slice from here to last frame", self + ) self.ZforwAction.setEnabled(False) - self.interpAction = QAction(QIcon(":interp.svg"), - "Interpolate z-slice from first slice to here", - self) + self.interpAction = QAction( + QIcon(":interp.svg"), "Interpolate z-slice from first slice to here", self + ) self.interpAction.setEnabled(False) - self.cropAction = QAction(QIcon(":crop.svg"), "Crop XY", self) self.cropAction.setToolTip( - 'Crop XY.\n\n' - 'If the button is disabled you need to click on the Start button ' - 'first.\n\n' - 'You can add as many crop ROIs as needed. If you use more than ' - 'one, the cropped data will be saved into sub-folders of each ' - 'cropped Position\n\n' - 'After adjusting the crop ROIs, click this button to apply the ' - 'crop and activate the save button.\n\n' - 'To save the cropped data click the Save button.' + "Crop XY.\n\n" + "If the button is disabled you need to click on the Start button " + "first.\n\n" + "You can add as many crop ROIs as needed. If you use more than " + "one, the cropped data will be saved into sub-folders of each " + "cropped Position\n\n" + "After adjusting the crop ROIs, click this button to apply the " + "crop and activate the save button.\n\n" + "To save the cropped data click the Save button." ) self.cropZaction = QAction(QIcon(":cropZ.svg"), "Crop z-slices", self) self.cropZaction.setToolTip( - 'Crop upper and bottom Z-slices.\n\n' - 'If the button is disabled you need to click on the Start button ' - 'first.\n\n' - 'USAGE: Click this button, adjust the lower and upper z-slices ' + "Crop upper and bottom Z-slices.\n\n" + "If the button is disabled you need to click on the Start button " + "first.\n\n" + "USAGE: Click this button, adjust the lower and upper z-slices " 'and click on "Apply crop" to activate the save button.\n\n' - 'To save the cropped data click the Save button.' + "To save the cropped data click the Save button." ) self.cropZaction.setEnabled(False) self.cropZaction.setCheckable(True) - + self.cropTaction = QAction( QIcon(":cropT.svg"), "Crop frames (time points)", self ) self.cropTaction.setToolTip( - 'Crop a specified time range.\n\n' - 'If the button is disabled you need to click on the Start button ' - 'first.\n\n' - 'USAGE: Click this button, adjust the start and end frame numbers ' + "Crop a specified time range.\n\n" + "If the button is disabled you need to click on the Start button " + "first.\n\n" + "USAGE: Click this button, adjust the start and end frame numbers " 'and click on "Apply crop" to activate the save button.\n\n' - 'To save the cropped data click the Save button.' + "To save the cropped data click the Save button." ) self.cropTaction.setEnabled(False) self.cropTaction.setCheckable(True) - + self.freeRoiAction = QAction( - QIcon(':drawFreeRoi.svg'), "Draw a freehand ROI", self + QIcon(":drawFreeRoi.svg"), "Draw a freehand ROI", self ) self.freeRoiAction.setToolTip( - 'Draw a freehand ROI.\n\n' - 'To remove a previously drawn ROI, activate the tool, ' + "Draw a freehand ROI.\n\n" + "To remove a previously drawn ROI, activate the tool, " 'right-click on the ROI, and select "Remove free-hand ROI".\n\n' - 'When running segmentation later in the segmentation module, ' - 'the objects outside of this ROI will be automatically removed ' - 'from the segmentation masks.\n\n' + "When running segmentation later in the segmentation module, " + "the objects outside of this ROI will be automatically removed " + "from the segmentation masks.\n\n" ) self.freeRoiAction.setEnabled(False) self.freeRoiAction.setCheckable(True) - - self.saveAction = QAction( - QIcon(":file-save.svg"), "Crop and save", self) + + self.saveAction = QAction(QIcon(":file-save.svg"), "Crop and save", self) self.saveAction.setEnabled(False) self.saveAction.setToolTip( - 'Save the image data.\n\n' - 'Saving is needed only to save the aligned (timelapse) and/or ' - 'cropped image data.\n\n' - 'If you did not align and you do not need cropping, there is ' - 'no need to save. The information about the z-slice to use for ' - 'segmentation, the background ROIs, and the ROI has already ' - 'been saved automatically.\n\n' - 'If the button is disabled you need to click on the Start button ' - 'first.' + "Save the image data.\n\n" + "Saving is needed only to save the aligned (timelapse) and/or " + "cropped image data.\n\n" + "If you did not align and you do not need cropping, there is " + "no need to save. The information about the z-slice to use for " + "segmentation, the background ROIs, and the ROI has already " + "been saved automatically.\n\n" + "If the button is disabled you need to click on the Start button " + "first." ) self.startAction = QAction(QIcon(":start.svg"), "Start process!", self) @@ -318,7 +337,7 @@ def gui_createActions(self): def gui_createMenuBar(self): menuBar = self.menuBar() menuBar.setNativeMenuBar(False) - + # File menu fileMenu = QMenu("&File", self) menuBar.addMenu(fileMenu) @@ -329,7 +348,7 @@ def gui_createMenuBar(self): fileMenu.addAction(self.loadPosAction) fileMenu.addSeparator() fileMenu.addAction(self.exitAction) - + # Help menu helpMenu = menuBar.addMenu("&Help") helpMenu.addAction(self.infoAction) @@ -348,17 +367,17 @@ def gui_createToolBars(self): fileToolBar.addAction(self.openFolderAction) fileToolBar.addAction(self.showInExplorerAction) fileToolBar.addAction(self.saveAction) - + editToolbar = self.addToolBar("Edit") # fileToolBar.setIconSize(QSize(toolbarSize, toolbarSize)) editToolbar.setMovable(False) - + editToolbar.addAction(self.startAction) editToolbar.addAction(self.cropAction) editToolbar.addAction(self.cropZaction) editToolbar.addAction(self.cropTaction) editToolbar.addAction(self.freeRoiAction) - + navigateToolbar = QToolBar("Navigate", self) # navigateToolbar.setIconSize(QSize(toolbarSize, toolbarSize)) self.addToolBar(navigateToolbar) @@ -374,16 +393,18 @@ def gui_createToolBars(self): self.ROIshapeComboBox = QComboBox() self.ROIshapeComboBox.setFont(apps.font) - self.ROIshapeComboBox.SizeAdjustPolicy(QComboBox.SizeAdjustPolicy.AdjustToContents) - self.ROIshapeComboBox.addItems([' 256x256 ']) - ROIshapeLabel = QLabel(html_utils.paragraph( - '   ROI standard shape: ') + self.ROIshapeComboBox.SizeAdjustPolicy( + QComboBox.SizeAdjustPolicy.AdjustToContents + ) + self.ROIshapeComboBox.addItems([" 256x256 "]) + ROIshapeLabel = QLabel( + html_utils.paragraph("   ROI standard shape: ") ) ROIshapeLabel.setBuddy(self.ROIshapeComboBox) navigateToolbar.addWidget(ROIshapeLabel) navigateToolbar.addWidget(self.ROIshapeComboBox) - self.ROIshapeLabel = QLabel(' Current ROI shape: 256 x 256') + self.ROIshapeLabel = QLabel(" Current ROI shape: 256 x 256") navigateToolbar.addWidget(self.ROIshapeLabel) def gui_connectActions(self): @@ -424,18 +445,17 @@ def gui_addGraphicsItems(self): self.ax1 = pg.PlotItem() self.ax1.invertY(True) self.ax1.setAspectLocked(True) - self.ax1.hideAxis('bottom') - self.ax1.hideAxis('left') + self.ax1.hideAxis("bottom") + self.ax1.hideAxis("left") self.graphLayout.addItem(self.ax1, row=1, col=1) - #Image histogram + # Image histogram self.hist = widgets.myHistogramLUTitem() self.graphLayout.addItem(self.hist, row=1, col=0) # Title - self.titleLabel = pg.LabelItem(justify='center', color='w', size='14pt') - self.titleLabel.setText( - 'File --> Open or Open recent to start the process') + self.titleLabel = pg.LabelItem(justify="center", color="w", size="14pt") + self.titleLabel.setText("File --> Open or Open recent to start the process") self.graphLayout.addItem(self.titleLabel, row=0, col=1) # Current frame text @@ -459,14 +479,14 @@ def removeFreeRoi(self): self.freeRoiMask = None posData = self.data[self.pos_i] posData.removeDataPrepFreeRoi(logger_func=self.logger.info) - + def showRemoveFreeRoiContextMenu(self, event): self.removeFreeRoiMenu = QMenu(self) - action = QAction('Remove free-hand ROI') + action = QAction("Remove free-hand ROI") action.triggered.connect(self.removeFreeRoi) self.removeFreeRoiMenu.addAction(action) self.removeFreeRoiMenu.exec_(event.screenPos()) - + def gui_connectGraphicsEvents(self): self.img.hoverEvent = self.gui_hoverEventImg self.img.mousePressEvent = self.gui_mousePressEventImg @@ -482,23 +502,26 @@ def gui_createImgWidgets(self): self.navigateScrollbar = QScrollBar(Qt.Horizontal) self.navigateScrollbar.setFixedHeight(20) self.navigateScrollbar.setDisabled(True) - navSB_label = QLabel('') + navSB_label = QLabel("") navSB_label.setFont(_font) self.navigateSB_label = navSB_label - self.zSliceScrollBar = QScrollBar(Qt.Horizontal) self.zSliceScrollBar.setFixedHeight(20) self.zSliceScrollBar.setDisabled(True) - _z_label = QLabel('z-slice ') + _z_label = QLabel("z-slice ") _z_label.setFont(_font) self.z_label = _z_label self.zProjComboBox = QComboBox() - self.zProjComboBox.addItems(['single z-slice', - 'max z-projection', - 'mean z-projection', - 'median z-proj.']) + self.zProjComboBox.addItems( + [ + "single z-slice", + "max z-projection", + "mean z-projection", + "median z-proj.", + ] + ) self.zProjComboBox.setDisabled(True) self.img_Widglayout.addWidget(navSB_label, 0, 0, alignment=Qt.AlignCenter) @@ -520,45 +543,43 @@ def gui_hoverEventImg(self, event): Y, X = _img.shape if xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y: val = _img[ydata, xdata] - self.wcLabel.setText( - f'(x={xdata:.2f}, y={ydata:.2f}, value={val:.2f})' - ) + self.wcLabel.setText(f"(x={xdata:.2f}, y={ydata:.2f}, value={val:.2f})") else: - self.wcLabel.setText(f'') + self.wcLabel.setText(f"") except Exception as e: - self.wcLabel.setText(f'') + self.wcLabel.setText(f"") def showInExplorer(self): try: posData = self.data[self.pos_i] systems = { - 'nt': os.startfile, - 'posix': lambda foldername: os.system('xdg-open "%s"' % foldername), - 'os2': lambda foldername: os.system('open "%s"' % foldername) - } + "nt": os.startfile, + "posix": lambda foldername: os.system('xdg-open "%s"' % foldername), + "os2": lambda foldername: os.system('open "%s"' % foldername), + } systems.get(os.name, os.startfile)(posData.images_path) except AttributeError: pass - + def loadPosTriggered(self): if not self.isDataLoaded: return - + self.startAutomaticLoadingPos() - + def startAutomaticLoadingPos(self): self.AutoPilot = autopilot.AutoPilot(self) self.AutoPilot.execLoadPos() - + def stopAutomaticLoadingPos(self): if self.AutoPilot is None: return - + if self.AutoPilot.timer.isActive(): self.AutoPilot.timer.stop() self.AutoPilot = None - + def updatePos(self): self.updateCropZtool() self.setImageNameText() @@ -567,13 +588,13 @@ def updatePos(self): self.updateFreeRoiItem() self.updateBkgrROIs() self.saveBkgrROIs(self.data[self.pos_i]) - + def clearCurrentPos(self): self.removeBkgrROIs() self.removeCropROIs() def skip10ahead_frames(self): - if self.frame_i < self.num_frames-10: + if self.frame_i < self.num_frames - 10: self.frame_i += 10 else: self.frame_i = 0 @@ -583,7 +604,7 @@ def skip10back_frames(self): if self.frame_i > 9: self.frame_i -= 10 else: - self.frame_i = self.num_frames-1 + self.frame_i = self.num_frames - 1 self.update_img() def updateNavigateItems(self): @@ -592,24 +613,22 @@ def updateNavigateItems(self): # self.frameLabel.setText( # f'Current position = {self.pos_i+1}/{self.num_pos} ' # f'({posData.pos_foldername})') - self.navigateSB_label.setText(f'Pos n. {self.pos_i+1}') + self.navigateSB_label.setText(f"Pos n. {self.pos_i + 1}") try: self.navigateScrollbar.valueChanged.disconnect() except TypeError: pass - self.navigateScrollbar.setValue(self.pos_i+1) + self.navigateScrollbar.setValue(self.pos_i + 1) else: # self.frameLabel.setText( # f'Current frame = {self.frame_i+1}/{self.num_frames}') - self.navigateSB_label.setText(f'frame n. {self.frame_i+1}') + self.navigateSB_label.setText(f"frame n. {self.frame_i + 1}") try: self.navigateScrollbar.valueChanged.disconnect() except TypeError: pass - self.navigateScrollbar.setValue(self.frame_i+1) - self.navigateScrollbar.valueChanged.connect( - self.navigateScrollbarValueChanged - ) + self.navigateScrollbar.setValue(self.frame_i + 1) + self.navigateScrollbar.valueChanged.connect(self.navigateScrollbarValueChanged) def getImage(self, posData, img_data, frame_i, force_z=None): if posData.SizeT > 1: @@ -618,19 +637,19 @@ def getImage(self, posData, img_data, frame_i, force_z=None): img = img_data.copy() if posData.SizeZ > 1: if force_z is not None: - self.z_label.setText(f'z-slice {force_z+1}/{posData.SizeZ}') + self.z_label.setText(f"z-slice {force_z + 1}/{posData.SizeZ}") img = img[force_z] return img - df = posData.segmInfo_df + df = posData.segmInfo_df idx = (posData.filename, frame_i) try: - z = df.at[idx, 'z_slice_used_dataPrep'] + z = df.at[idx, "z_slice_used_dataPrep"] except Exception as e: duplicated_idx = df.index.duplicated() posData.segmInfo_df = df[~duplicated_idx] - z = posData.segmInfo_df.at[idx, 'z_slice_used_dataPrep'] - - zProjHow = posData.segmInfo_df.at[idx, 'which_z_proj'] + z = posData.segmInfo_df.at[idx, "z_slice_used_dataPrep"] + + zProjHow = posData.segmInfo_df.at[idx, "which_z_proj"] try: self.zProjComboBox.currentTextChanged.disconnect() except TypeError: @@ -638,17 +657,17 @@ def getImage(self, posData, img_data, frame_i, force_z=None): self.zProjComboBox.setCurrentText(zProjHow) self.zProjComboBox.currentTextChanged.connect(self.updateZproj) - if zProjHow == 'single z-slice': + if zProjHow == "single z-slice": self.zSliceScrollBar.valueChanged.disconnect() self.zSliceScrollBar.setSliderPosition(z) self.zSliceScrollBar.valueChanged.connect(self.update_z_slice) - self.z_label.setText(f'z-slice {z+1}/{posData.SizeZ}') + self.z_label.setText(f"z-slice {z + 1}/{posData.SizeZ}") img = img[z] - elif zProjHow == 'max z-projection': + elif zProjHow == "max z-projection": img = img.max(axis=0) - elif zProjHow == 'mean z-projection': + elif zProjHow == "mean z-projection": img = img.mean(axis=0) - elif zProjHow == 'median z-proj.': + elif zProjHow == "median z-proj.": img = np.median(img, axis=0) return img @@ -657,16 +676,16 @@ def update_img(self): self.updateNavigateItems() posData = self.data[self.pos_i] img = self.getImage(posData, posData.img_data, self.frame_i) - if self.zProjComboBox.currentText() == 'single z-slice': + if self.zProjComboBox.currentText() == "single z-slice": zslice = self.zSliceScrollBar.sliderPosition() else: zslice = None - + self.img.setCurrentZsliceIndex(zslice) self.img.setCurrentPosIndex(self.pos_i) self.img.setCurrentFrameIndex(self.frame_i) self.img.setImage(img) - self.zSliceScrollBar.setMaximum(posData.SizeZ-1) + self.zSliceScrollBar.setMaximum(posData.SizeZ - 1) def addAndConnectROI(self, roi): if roi not in self.ax1.items: @@ -675,15 +694,15 @@ def addAndConnectROI(self, roi): roi.sigRegionChanged.connect(self.updateCurrentRoiShape) roi.sigRegionChangeFinished.connect(self.ROImovingFinished) - + def addAndConnectCropROIs(self): if self.startAction.isEnabled() or self.onlySelectingZslice: return posData = self.data[self.pos_i] - if not hasattr(posData, 'cropROIs'): + if not hasattr(posData, "cropROIs"): return - + for cropROI in posData.cropROIs: self.addAndConnectROI(cropROI) @@ -692,9 +711,9 @@ def removeCropROIs(self): return posData = self.data[self.pos_i] - if not hasattr(posData, 'cropROIs'): - return - + if not hasattr(posData, "cropROIs"): + return + if posData.cropROIs is None: return @@ -707,9 +726,9 @@ def removeCropROIs(self): cropROI.sigRegionChangeFinished.disconnect() except TypeError: pass - + for c, cropROI in enumerate(posData.cropROIs): - cropROI.label.setText(f'ROI n. {c+1}') + cropROI.label.setText(f"ROI n. {c + 1}") def updateBkgrROIs(self): if self.startAction.isEnabled() or self.onlySelectingZslice: @@ -749,9 +768,7 @@ def init_attr(self): else: self.navigateScrollbar.setDisabled(True) self.navigateScrollbar.setValue(1) - self.navigateScrollbar.valueChanged.connect( - self.navigateScrollbarValueChanged - ) + self.navigateScrollbar.valueChanged.connect(self.navigateScrollbarValueChanged) self.startFrameIdxCrop = 0 self.endFrameIdxCrop = None self.isFreeRoiDrag = False @@ -760,10 +777,10 @@ def navigateScrollbarValueChanged(self, value): if self.num_pos > 1: self.removeBkgrROIs() self.removeCropROIs() - self.pos_i = value-1 + self.pos_i = value - 1 self.updatePos() else: - self.frame_i = value-1 + self.frame_i = value - 1 self.update_img() @exception_handler @@ -772,36 +789,34 @@ def crop(self, data, posData, cropROI): x0, y0 = [int(round(c)) for c in cropROI.pos()] w, h = [int(round(c)) for c in cropROI.size()] if data.ndim == 4: - croppedData = croppedData[:, :, y0:y0+h, x0:x0+w] + croppedData = croppedData[:, :, y0 : y0 + h, x0 : x0 + w] elif data.ndim == 3: - croppedData = croppedData[:, y0:y0+h, x0:x0+w] + croppedData = croppedData[:, y0 : y0 + h, x0 : x0 + w] elif data.ndim == 2: - croppedData = croppedData[y0:y0+h, x0:x0+w] - + croppedData = croppedData[y0 : y0 + h, x0 : x0 + w] + SizeZ = posData.SizeZ if posData.SizeZ > 1: idx = (posData.filename, 0) try: - lower_z = int(posData.segmInfo_df['crop_lower_z_slice'].iloc[0]) + lower_z = int(posData.segmInfo_df["crop_lower_z_slice"].iloc[0]) except KeyError: lower_z = 0 try: - upper_z = int(posData.segmInfo_df['crop_upper_z_slice'].iloc[0]) + upper_z = int(posData.segmInfo_df["crop_upper_z_slice"].iloc[0]) except KeyError: - upper_z = posData.SizeZ-1 + upper_z = posData.SizeZ - 1 if croppedData.ndim == 4: - croppedData = croppedData[:, lower_z:upper_z+1] + croppedData = croppedData[:, lower_z : upper_z + 1] elif croppedData.ndim == 3: - croppedData = croppedData[lower_z:upper_z+1] - SizeZ = (upper_z-lower_z)+1 - + croppedData = croppedData[lower_z : upper_z + 1] + SizeZ = (upper_z - lower_z) + 1 + if posData.SizeT > 1: - croppedData = croppedData[ - self.startFrameIdxCrop:self.endFrameIdxCrop - ] - + croppedData = croppedData[self.startFrameIdxCrop : self.endFrameIdxCrop] + return croppedData, SizeZ def saveBkgrROIs(self, posData): @@ -809,7 +824,7 @@ def saveBkgrROIs(self, posData): return ROIstates = [roi.saveState() for roi in posData.bkgrROIs] - with open(posData.dataPrepBkgrROis_path, 'w') as json_fp: + with open(posData.dataPrepBkgrROis_path, "w") as json_fp: json.dump(ROIstates, json_fp) def saveBkgrData(self, posData): @@ -822,21 +837,21 @@ def saveBkgrData(self, posData): for chName in posData.chNames: alignedFound = False tifFound = False - for file in myutils.listdir(posData.images_path): + for file in utils.listdir(posData.images_path): filePath = os.path.join(posData.images_path, file) filenameNOext, _ = os.path.splitext(file) - if file.endswith(f'{chName}_aligned.npz'): + if file.endswith(f"{chName}_aligned.npz"): aligned_filename = filenameNOext aligned_filePath = filePath alignedFound = True - elif file.find(f'{chName}.tif') != -1: + elif file.find(f"{chName}.tif") != -1: tif_filename = filenameNOext tif_path = filePath tifFound = True if alignedFound: filename = aligned_filename - chData = np.load(aligned_filePath)['arr_0'] + chData = np.load(aligned_filePath)["arr_0"] elif tifFound: filename = tif_filename chData = load.imread(tif_path) @@ -845,7 +860,7 @@ def saveBkgrData(self, posData): for r, roi in enumerate(posData.bkgrROIs): xl, yt = [int(round(c)) for c in roi.pos()] w, h = [int(round(c)) for c in roi.size()] - if not yt+h>yt or not xl+w>xl: + if not yt + h > yt or not xl + w > xl: # Prevent 0 height or 0 width roi continue is4D = posData.SizeT > 1 and posData.SizeZ > 1 @@ -853,21 +868,21 @@ def saveBkgrData(self, posData): is3Dt = posData.SizeT > 1 and posData.SizeZ == 1 is2D = posData.SizeT == 1 and posData.SizeZ == 1 if is4D: - bkgr_data = chData[:, :, yt:yt+h, xl:xl+w] + bkgr_data = chData[:, :, yt : yt + h, xl : xl + w] elif is3Dz or is3Dt: - bkgr_data = chData[:, yt:yt+h, xl:xl+w] + bkgr_data = chData[:, yt : yt + h, xl : xl + w] elif is2D: - bkgr_data = chData[yt:yt+h, xl:xl+w] - bkgrROI_data[f'roi{r}_data'] = bkgr_data + bkgr_data = chData[yt : yt + h, xl : xl + w] + bkgrROI_data[f"roi{r}_data"] = bkgr_data if bkgrROI_data: - bkgr_data_fn = f'{filename}_bkgrRoiData.npz' + bkgr_data_fn = f"{filename}_bkgrRoiData.npz" bkgr_data_path = os.path.join(posData.images_path, bkgr_data_fn) - print('---------------------------------') - self.logger.info('Saving background data to:') + print("---------------------------------") + self.logger.info("Saving background data to:") self.logger.info(bkgr_data_path) - print('*********************************') - print('') + print("*********************************") + print("") io.savez_compressed(bkgr_data_path, **bkgrROI_data) def removeAllROIs(self, event): @@ -889,8 +904,8 @@ def removeROI(self, event): except Exception as e: posData.cropROIs.remove(self.roi_to_del) for c, cropROI in enumerate(posData.cropROIs): - cropROI.label.setText(f'ROI n. {c+1}') - + cropROI.label.setText(f"ROI n. {c + 1}") + self.ax1.removeItem(self.roi_to_del.label) self.ax1.removeItem(self.roi_to_del) if not posData.bkgrROIs: @@ -900,7 +915,7 @@ def removeROI(self, event): pass else: self.saveBkgrROIs(posData) - + def gui_raiseContextMenuRoi(self, roi, event, is_bkgr_ROI=True): self.roi_to_del = roi self.roiContextMenu = QMenu(self) @@ -908,17 +923,17 @@ def gui_raiseContextMenuRoi(self, roi, event, is_bkgr_ROI=True): separator.setSeparator(True) self.roiContextMenu.addAction(separator) if is_bkgr_ROI: - action1 = QAction('Remove background ROI') + action1 = QAction("Remove background ROI") else: - action1 = QAction('Remove crop ROI') + action1 = QAction("Remove crop ROI") action1.triggered.connect(self.removeROI) self.roiContextMenu.addAction(action1) if is_bkgr_ROI: - action2 = QAction('Remove ALL background ROIs') + action2 = QAction("Remove ALL background ROIs") action2.triggered.connect(self.removeAllROIs) self.roiContextMenu.addAction(action2) self.roiContextMenu.exec_(event.screenPos()) - + def gui_mousePressEventImg(self, event): posData = self.data[self.pos_i] right_click = event.button() == Qt.MouseButton.RightButton @@ -926,7 +941,7 @@ def gui_mousePressEventImg(self, event): freeRoiActive = self.freeRoiAction.isChecked() dragImg = left_click and not freeRoiActive - + if dragImg: pg.ImageItem.mousePressEvent(self.img, event) event.ignore() @@ -934,21 +949,23 @@ def gui_mousePressEventImg(self, event): x, y = event.pos().x(), event.pos().y() xdata, ydata = int(x), int(y) - + if freeRoiActive and self.freeRoiMask is not None and right_click: if self.isClickOnFreeRoi(xdata, ydata): self.showRemoveFreeRoiContextMenu(event) return - + handleSize = 7 # Check if right click on ROI for r, roi in enumerate(posData.bkgrROIs): x0, y0 = [int(c) for c in roi.pos()] w, h = [int(c) for c in roi.size()] - x1, y1 = x0+w, y0+h + x1, y1 = x0 + w, y0 + h clickedOnROI = ( - x>=x0-handleSize and x<=x1+handleSize - and y>=y0-handleSize and y<=y1+handleSize + x >= x0 - handleSize + and x <= x1 + handleSize + and y >= y0 - handleSize + and y <= y1 + handleSize ) raiseContextMenuRoi = right_click and clickedOnROI dragRoi = left_click and clickedOnROI @@ -958,45 +975,47 @@ def gui_mousePressEventImg(self, event): elif dragRoi and not freeRoiActive: event.ignore() return - + if left_click and freeRoiActive: self.isFreeRoiDrag = True self.freeRoiItem.clear() return - - if not hasattr(posData, 'cropROIs'): + + if not hasattr(posData, "cropROIs"): return - + if posData.cropROIs is None: return - + for c, cropROI in enumerate(posData.cropROIs): x0, y0 = [int(c) for c in cropROI.pos()] w, h = [int(c) for c in cropROI.size()] - x1, y1 = x0+w, y0+h + x1, y1 = x0 + w, y0 + h clickedOnROI = ( - x>=x0-handleSize and x<=x1+handleSize - and y>=y0-handleSize and y<=y1+handleSize + x >= x0 - handleSize + and x <= x1 + handleSize + and y >= y0 - handleSize + and y <= y1 + handleSize ) dragRoi = left_click and clickedOnROI if dragRoi: event.ignore() return - raiseContextMenuRoi = right_click and clickedOnROI and c>0 + raiseContextMenuRoi = right_click and clickedOnROI and c > 0 if raiseContextMenuRoi: self.gui_raiseContextMenuRoi(cropROI, event, is_bkgr_ROI=False) - + def gui_mouseDragEventImg(self, event): posData = self.data[self.pos_i] x, y = event.pos().x(), event.pos().y() Y, X = posData.img_data.shape[-2:] xdata, ydata = int(x), int(y) - if not myutils.is_in_bounds(xdata, ydata, X, Y): + if not utils.is_in_bounds(xdata, ydata, X, Y): return - + if self.isFreeRoiDrag: self.freeRoiItem.addPoint(xdata, ydata) - + def saveFreeRoi(self): posData = self.data[self.pos_i] xx, yy = self.freeRoiItem.getData() @@ -1006,25 +1025,25 @@ def saveFreeRoi(self): self.freeRoiItem, logger_func=self.logger.info ) self.dataPrepFreeRoiSaved() - + def gui_mouseReleaseEventImg(self, event): posData = self.data[self.pos_i] if self.isFreeRoiDrag: self.freeRoiItem.closeCurve() self.saveFreeRoi() - + def dataPrepFreeRoiSaved(self): msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph( - 'Free-hand ROI saved.

' - 'When you segment the data, the objects outside of this ROI will be ' - 'automatically removed from the segmentation masks.

' - 'To remove the free-hand ROI, right-click on it and select ' + "Free-hand ROI saved.

" + "When you segment the data, the objects outside of this ROI will be " + "automatically removed from the segmentation masks.

" + "To remove the free-hand ROI, right-click on it and select " '"Remove free-hand ROI".

' - 'See you at the next one!' + "See you at the next one!" ) - msg.information(self, 'Free-hand ROI saved', txt) - + msg.information(self, "Free-hand ROI saved", txt) + def isClickOnFreeRoi(self, xdata, ydata): y0, x0 = self.freeRoiYXorigin local_x = xdata - x0 @@ -1033,61 +1052,60 @@ def isClickOnFreeRoi(self, xdata, ydata): return False if local_y < 0 or local_y >= self.freeRoiMask.shape[0]: return False - + return self.freeRoiMask[local_y, local_x] - + def getAllChannelsPaths(self, posData): _zip = zip(posData.tif_paths, posData.all_npz_paths) for tif_path, npz_path in _zip: if self.align: - uncropped_data = np.load(npz_path)['arr_0'] + uncropped_data = np.load(npz_path)["arr_0"] else: uncropped_data = load.imread(tif_path) - + yield uncropped_data, npz_path, tif_path - + def saveCroppedChannel(self, cropped_data, npz_path, tif_path, posData): - self.logger.info( - f'Saving cropped data with shape {cropped_data.shape}' - ) + self.logger.info(f"Saving cropped data with shape {cropped_data.shape}") if self.align or os.path.exists(npz_path): - self.logger.info(f'Saving: {npz_path}') + self.logger.info(f"Saving: {npz_path}") temp_npz = self.getTempfilePath(npz_path) io.savez_compressed(temp_npz, cropped_data) self.moveTempFile(temp_npz, npz_path) - self.logger.info(f'Saving: {tif_path}') + self.logger.info(f"Saving: {tif_path}") temp_tif = self.getTempfilePath(tif_path) - myutils.to_tiff( - temp_tif, cropped_data, - SizeT=getattr(posData, 'SizeT', None), - SizeZ=getattr(posData, 'SizeZ', None), - TimeIncrement=getattr(posData, 'TimeIncrement', None), - PhysicalSizeZ=getattr(posData, 'PhysicalSizeZ', None), - PhysicalSizeY=getattr(posData, 'PhysicalSizeY', None), - PhysicalSizeX=getattr(posData, 'PhysicalSizeX', None), + utils.to_tiff( + temp_tif, + cropped_data, + SizeT=getattr(posData, "SizeT", None), + SizeZ=getattr(posData, "SizeZ", None), + TimeIncrement=getattr(posData, "TimeIncrement", None), + PhysicalSizeZ=getattr(posData, "PhysicalSizeZ", None), + PhysicalSizeY=getattr(posData, "PhysicalSizeY", None), + PhysicalSizeX=getattr(posData, "PhysicalSizeX", None), ) self.moveTempFile(temp_tif, tif_path) - + def saveCroppedSegmData(self, posData, segm_npz_path, cropROI): if not posData.segmFound: return - self.logger.info(f'Saving: {segm_npz_path}') + self.logger.info(f"Saving: {segm_npz_path}") croppedSegm, _ = self.crop(posData.segm_data, posData, cropROI) temp_npz = self.getTempfilePath(segm_npz_path) io.savez_compressed(temp_npz, croppedSegm) self.moveTempFile(temp_npz, segm_npz_path) - + def correctAcdcDfCrop(self, posData, acdc_output_csv_path, cropROI): try: # Correct acdc_df if present and save if posData.acdc_df is not None: x0, y0 = [int(round(c)) for c in cropROI.pos()] - self.logger.info(f'Saving: {acdc_output_csv_path}') + self.logger.info(f"Saving: {acdc_output_csv_path}") df = posData.acdc_df - df['x_centroid'] -= x0 - df['y_centroid'] -= y0 + df["x_centroid"] -= x0 + df["y_centroid"] -= y0 try: df.to_csv(acdc_output_csv_path) except PermissionError: @@ -1095,194 +1113,173 @@ def correctAcdcDfCrop(self, posData, acdc_output_csv_path, cropROI): df.to_csv(acdc_output_csv_path) except Exception as e: pass - + def copyAdditionalFilesToCropFolder( - self, posData, subImagesPath, cropBasename, cropIdx=0 - ): - subImagesPath = subImagesPath.replace('\\', '/') - parentImagesPath = posData.images_path.replace('\\', '/') + self, posData, subImagesPath, cropBasename, cropIdx=0 + ): + subImagesPath = subImagesPath.replace("\\", "/") + parentImagesPath = posData.images_path.replace("\\", "/") if parentImagesPath == subImagesPath: return - + basename = posData.basename try: df_roi = posData.dataPrep_ROIcoords.loc[[cropIdx]] - df_roi_filename = os.path.basename( - posData.dataPrepROI_coords_path - ) - df_roi_endname = df_roi_filename[len(basename):] - crop_df_roi_filename = f'{cropBasename}{df_roi_endname}' - df_roi_filepath = os.path.join( - subImagesPath, crop_df_roi_filename - ) + df_roi_filename = os.path.basename(posData.dataPrepROI_coords_path) + df_roi_endname = df_roi_filename[len(basename) :] + crop_df_roi_filename = f"{cropBasename}{df_roi_endname}" + df_roi_filepath = os.path.join(subImagesPath, crop_df_roi_filename) df_roi.to_csv(df_roi_filepath) except IndexError: pass - - for file in myutils.listdir(posData.images_path): + + for file in utils.listdir(posData.images_path): copy_file = ( - file.endswith('bkgrRoiData.npz') - or file.endswith('dataPrep_bkgrROIs.json') - or file.endswith('segmInfo.csv') - or file.endswith('dataPrepFreeRoi.npz') + file.endswith("bkgrRoiData.npz") + or file.endswith("dataPrep_bkgrROIs.json") + or file.endswith("segmInfo.csv") + or file.endswith("dataPrepFreeRoi.npz") ) - is_metadata_file = file.endswith('metadata.csv') + is_metadata_file = file.endswith("metadata.csv") if not copy_file and not is_metadata_file: continue - + src_filepath = os.path.join(posData.images_path, file) - endname = file[len(basename):] - crop_filename = f'{cropBasename}{endname}' + endname = file[len(basename) :] + crop_filename = f"{cropBasename}{endname}" sub_filepath = os.path.join(subImagesPath, crop_filename) if os.path.exists(sub_filepath): continue - + if copy_file: shutil.copyfile(src_filepath, sub_filepath) elif is_metadata_file: - df_metadata = pd.read_csv( - src_filepath, index_col='Description' - ) - df_metadata.at['basename', 'values'] = cropBasename + df_metadata = pd.read_csv(src_filepath, index_col="Description") + df_metadata.at["basename", "values"] = cropBasename df_metadata.to_csv(sub_filepath) - + def saveSingleCrop(self, posData, cropROI, dstPath): if dstPath != posData.images_path: - currentSubPosFolders = myutils.get_pos_foldernames(dstPath) + currentSubPosFolders = utils.get_pos_foldernames(dstPath) currentSubPosNumbers = [ - int(pos.split('_')[-1]) for pos in currentSubPosFolders + int(pos.split("_")[-1]) for pos in currentSubPosFolders ] startPosNumber = max(currentSubPosNumbers, default=0) + 1 cropNum = startPosNumber - subPosFolder = f'Position_{cropNum}' + subPosFolder = f"Position_{cropNum}" subPosFolderPath = os.path.join(dstPath, subPosFolder) - subImagesPath = os.path.join(subPosFolderPath, 'Images') + subImagesPath = os.path.join(subPosFolderPath, "Images") os.makedirs(subImagesPath) - cropBasename = f'{posData.basename}crop{cropNum}_' + cropBasename = f"{posData.basename}crop{cropNum}_" else: subImagesPath = dstPath cropBasename = posData.basename - + self._saveCroppedData(posData, subImagesPath, cropROI, cropBasename) - + def _saveCroppedData( - self, posData, subImagesPath, cropROI, cropBasename, cropIdx=0 - ): + self, posData, subImagesPath, cropROI, cropBasename, cropIdx=0 + ): basename = posData.basename _iter = self.getAllChannelsPaths(posData) for uncropped_data, npz_path, tif_path in _iter: cropped_data, _ = self.crop(uncropped_data, posData, cropROI) npz_filename = os.path.basename(npz_path) tif_filename = os.path.basename(tif_path) - npz_endname = npz_filename[len(basename):] - tif_endname = tif_filename[len(basename):] - crop_npz_filename = f'{cropBasename}{npz_endname}' - crop_tif_filename = f'{cropBasename}{tif_endname}' + npz_endname = npz_filename[len(basename) :] + tif_endname = tif_filename[len(basename) :] + crop_npz_filename = f"{cropBasename}{npz_endname}" + crop_tif_filename = f"{cropBasename}{tif_endname}" sub_npz_filepath = os.path.join(subImagesPath, crop_npz_filename) sub_tif_filepath = os.path.join(subImagesPath, crop_tif_filename) self.saveCroppedChannel( - cropped_data, sub_npz_filepath, sub_tif_filepath, - posData + cropped_data, sub_npz_filepath, sub_tif_filepath, posData ) - + segm_filename = os.path.basename(posData.segm_npz_path) - segm_endname = segm_filename[len(basename):] - crop_segm_filename = f'{cropBasename}{segm_endname}' + segm_endname = segm_filename[len(basename) :] + crop_segm_filename = f"{cropBasename}{segm_endname}" sub_segm_filepath = os.path.join(subImagesPath, crop_segm_filename) self.saveCroppedSegmData(posData, sub_segm_filepath, cropROI) - + acdc_df_filename = os.path.basename(posData.acdc_output_csv_path) - acdc_df_endname = acdc_df_filename[len(basename):] - crop_acdc_df_filename = f'{cropBasename}{acdc_df_endname}' + acdc_df_endname = acdc_df_filename[len(basename) :] + crop_acdc_df_filename = f"{cropBasename}{acdc_df_endname}" acdc_df_filepath = os.path.join(subImagesPath, crop_acdc_df_filename) self.correctAcdcDfCrop(posData, acdc_df_filepath, cropROI) - - self.saveMasterFolderPathTxt( - posData, subImagesPath, basename=cropBasename - ) - + + self.saveMasterFolderPathTxt(posData, subImagesPath, basename=cropBasename) + self.copyAdditionalFilesToCropFolder( posData, subImagesPath, cropBasename, cropIdx=cropIdx ) - + def saveMasterFolderPathTxt(self, posData, subImagesPath, basename=None): - subImagesPath = subImagesPath.replace('\\', '/') - parentImagesPath = posData.images_path.replace('\\', '/') + subImagesPath = subImagesPath.replace("\\", "/") + parentImagesPath = posData.images_path.replace("\\", "/") if parentImagesPath == subImagesPath: return - + if basename is None: basename = posData.basename - - filename = f'{basename}master_position.txt' + + filename = f"{basename}master_position.txt" filepath = os.path.join(subImagesPath, filename) - masterPos = posData.pos_path.replace('\\', os.sep).replace('/', os.sep) - with open(filepath, 'w') as txt: + masterPos = posData.pos_path.replace("\\", os.sep).replace("/", os.sep) + with open(filepath, "w") as txt: txt.write(masterPos) - + def startCropWorker(self, posData, dstPath): # Disable clicks on image during alignment self.img.mousePressEvent = None - + if posData.SizeT > 1: self.progressWin = apps.QDialogWorkerProgress( - title='Saving cropped data', + title="Saving cropped data", parent=self, - pbarDesc=f'Saving cropped data...' + pbarDesc=f"Saving cropped data...", ) self.progressWin.show(self.app) self.progressWin.mainPbar.setMaximum(0) - + self._thread = QThread() - + self.cropWorker = workers.DataPrepCropWorker(posData, self, dstPath) self.cropWorker.moveToThread(self._thread) - + self.cropWorker.moveToThread(self._thread) self.cropWorker.signals.finished.connect(self._thread.quit) - self.cropWorker.signals.finished.connect( - self.cropWorker.deleteLater - ) + self.cropWorker.signals.finished.connect(self.cropWorker.deleteLater) self._thread.finished.connect(self._thread.deleteLater) - self.cropWorker.signals.finished.connect( - self.cropWorkerFinished - ) + self.cropWorker.signals.finished.connect(self.cropWorkerFinished) self.cropWorker.signals.progress.connect(self.workerProgress) - self.cropWorker.signals.initProgressBar.connect( - self.workerInitProgressbar - ) - self.cropWorker.signals.progressBar.connect( - self.workerUpdateProgressbar - ) - self.cropWorker.signals.critical.connect( - self.workerCritical - ) - + self.cropWorker.signals.initProgressBar.connect(self.workerInitProgressbar) + self.cropWorker.signals.progressBar.connect(self.workerUpdateProgressbar) + self.cropWorker.signals.critical.connect(self.workerCritical) + self._thread.started.connect(self.cropWorker.run) self._thread.start() return self.cropWorker - + def startSaveBkgrDataWorker(self, posData): # Disable clicks on image during alignment self.img.mousePressEvent = None - + if posData.SizeT > 1: self.progressWin = apps.QDialogWorkerProgress( - title='Saving background data', + title="Saving background data", parent=self, - pbarDesc=f'Saving background data...' + pbarDesc=f"Saving background data...", ) self.progressWin.show(self.app) self.progressWin.mainPbar.setMaximum(0) - + self._thread = QThread() - - self.saveBkgrDataWorker = workers.DataPrepSaveBkgrDataWorker( - posData, self - ) + + self.saveBkgrDataWorker = workers.DataPrepSaveBkgrDataWorker(posData, self) self.saveBkgrDataWorker.moveToThread(self._thread) - + self.saveBkgrDataWorker.moveToThread(self._thread) self.saveBkgrDataWorker.signals.finished.connect(self._thread.quit) self.saveBkgrDataWorker.signals.finished.connect( @@ -1300,14 +1297,12 @@ def startSaveBkgrDataWorker(self, posData): self.saveBkgrDataWorker.signals.progressBar.connect( self.workerUpdateProgressbar ) - self.saveBkgrDataWorker.signals.critical.connect( - self.workerCritical - ) - + self.saveBkgrDataWorker.signals.critical.connect(self.workerCritical) + self._thread.started.connect(self.saveBkgrDataWorker.run) self._thread.start() return self.saveBkgrDataWorker - + def saveCroppedData(self, posData, cropDstPaths): if len(posData.cropROIs) == 1: worker = self.startCropWorker(posData, cropDstPaths[0]) @@ -1315,31 +1310,31 @@ def saveCroppedData(self, posData, cropDstPaths): else: self.saveMultiCrops(posData, cropDstPaths) - self.logger.info(f'{posData.pos_foldername} saved!') - print(f'--------------------------------') - print('') - - def saveMultiCrops(self, posData, cropDstPaths): + self.logger.info(f"{posData.pos_foldername} saved!") + print(f"--------------------------------") + print("") + + def saveMultiCrops(self, posData, cropDstPaths): basename = posData.basename for p, cropROI in enumerate(posData.cropROIs): parentSubPosPath = cropDstPaths[p] - currentSubPosFolders = myutils.get_pos_foldernames(parentSubPosPath) + currentSubPosFolders = utils.get_pos_foldernames(parentSubPosPath) currentSubPosNumbers = [ - int(pos.split('_')[-1]) for pos in currentSubPosFolders + int(pos.split("_")[-1]) for pos in currentSubPosFolders ] startPosNumber = max(currentSubPosNumbers, default=0) + 1 cropNum = startPosNumber - subPosFolder = f'Position_{cropNum}' + subPosFolder = f"Position_{cropNum}" subPosFolderPath = os.path.join(parentSubPosPath, subPosFolder) - subImagesPath = os.path.join(subPosFolderPath, 'Images') + subImagesPath = os.path.join(subPosFolderPath, "Images") os.makedirs(subImagesPath) - - cropBasename = f'{basename}crop{cropNum}_' - + + cropBasename = f"{basename}crop{cropNum}_" + self._saveCroppedData( posData, subImagesPath, cropROI, cropBasename, cropIdx=p ) - + def saveROIcoords(self, doCrop, posData): dfs = [] keys = [] @@ -1348,43 +1343,39 @@ def saveROIcoords(self, doCrop, posData): w, h = [int(round(c)) for c in cropROI.size()] Y, X = self.img.image.shape - x1, y1 = x0+w, y0+h + x1, y1 = x0 + w, y0 + h - x0 = x0 if x0>0 else 0 - y0 = y0 if y0>0 else 0 - x1 = x1 if x1 0 else 0 + y0 = y0 if y0 > 0 else 0 + x1 = x1 if x1 < X else X + y1 = y1 if y1 < Y else Y - if x0<=0 and y0<=0 and x1>=X and y1>=Y: + if x0 <= 0 and y0 <= 0 and x1 >= X and y1 >= Y: # ROI coordinates are the exact image shape. No need to save them continue - + keys.append(c) - description = ['x_left', 'x_right', 'y_top', 'y_bottom', 'cropped'] + description = ["x_left", "x_right", "y_top", "y_bottom", "cropped"] values = [x0, x1, y0, y1, int(doCrop)] - df_roi = ( - pd.DataFrame({'description': description, 'value': values}) - .set_index('description') - ) - + df_roi = pd.DataFrame( + {"description": description, "value": values} + ).set_index("description") + dfs.append(df_roi) - + if not dfs: return - - df = pd.concat(dfs, keys=keys, names=['roi_id']) - self.logger.info( - f'Saving ROI coords ' - f'to "{posData.dataPrepROI_coords_path}"' - ) + df = pd.concat(dfs, keys=keys, names=["roi_id"]) + + self.logger.info(f'Saving ROI coords to "{posData.dataPrepROI_coords_path}"') try: df.to_csv(posData.dataPrepROI_coords_path) except PermissionError: self.permissionErrorCritical(posData.dataPrepROI_coords_path) df.to_csv(posData.dataPrepROI_coords_path) - + posData.dataPrep_ROIcoords = df def openCropZtool(self, checked): @@ -1401,7 +1392,7 @@ def openCropZtool(self, checked): # Restore original z-slice df = posData.segmInfo_df idx = (posData.filename, self.frame_i) - z = posData.segmInfo_df.at[idx, 'z_slice_used_dataPrep'] + z = posData.segmInfo_df.at[idx, "z_slice_used_dataPrep"] self.zSliceScrollBar.setValue(z) def openCropTtool(self, checked): @@ -1415,7 +1406,7 @@ def openCropTtool(self, checked): else: self.cropZtool.close() self.cropZtool = None - + def cropZtoolvalueChanged(self, whichZ, z): self.zSliceScrollBar.valueChanged.disconnect() self.zSliceScrollBar.setValue(z) @@ -1438,12 +1429,12 @@ def updateCropZtool(self): return try: - lower_z = int(posData.segmInfo_df['crop_lower_z_slice'].iloc[0]) + lower_z = int(posData.segmInfo_df["crop_lower_z_slice"].iloc[0]) except KeyError: lower_z = 0 try: - upper_z = int(posData.segmInfo_df['crop_upper_z_slice'].iloc[0]) + upper_z = int(posData.segmInfo_df["crop_upper_z_slice"].iloc[0]) except KeyError: upper_z = posData.SizeZ @@ -1455,59 +1446,57 @@ def cropZtoolClosed(self): self.cropZtool = None posData = self.data[self.pos_i] idx = (posData.filename, self.frame_i) - z = posData.segmInfo_df.at[idx, 'z_slice_used_dataPrep'] + z = posData.segmInfo_df.at[idx, "z_slice_used_dataPrep"] self.zSliceScrollBar.setSliderPosition(z) self.cropZaction.toggled.disconnect() self.cropZaction.setChecked(False) self.cropZaction.toggled.connect(self.openCropZtool) - + def cropTtoolClosed(self): self.cropTtool = None self.cropTaction.toggled.disconnect() self.cropTaction.setChecked(False) self.cropTaction.toggled.connect(self.openCropTtool) - + def cropTtoolvalueChanged(self, frame_i): - self.navigateScrollbar.setValue(frame_i+1) - + self.navigateScrollbar.setValue(frame_i + 1) + def applyCropTrange(self, start_frame_i, end_frame_i): self.startFrameIdxCrop = start_frame_i self.endFrameIdxCrop = end_frame_i + 1 self.logger.info( - f'Previewing cropped frames ({start_frame_i+1},{end_frame_i+1})...' + f"Previewing cropped frames ({start_frame_i + 1},{end_frame_i + 1})..." ) for posData in self.data: posData.img_data[:start_frame_i] = 0 - posData.img_data[end_frame_i+1:] = 0 - + posData.img_data[end_frame_i + 1 :] = 0 + self.update_img() note_text = ( - f'Done. Frames outside of the range ({start_frame_i+1},{end_frame_i+1}) ' + f"Done. Frames outside of the range ({start_frame_i + 1},{end_frame_i + 1}) " 'will appear black now. To save cropped data, click on the "Save" ' - 'button on the top toolbar.' + "button on the top toolbar." ) self.logger.info(note_text) - + txt = html_utils.paragraph(f""" Cropping frames applied.

Note that this is just a preview where the frames outside of the - range ({start_frame_i+1},{end_frame_i+1}) will look black.

+ range ({start_frame_i + 1},{end_frame_i + 1}) will look black.

To save cropped data, click on the Save cropped data button on the top toolbar. """) msg = widgets.myMessageBox(wrapText=False) - msg.information(self, 'Preview cropped frames', txt) + msg.information(self, "Preview cropped frames", txt) def addFreeRoiItem(self): if self.freeRoiItem is not None: return - - self.freeRoiItem = widgets.PlotCurveItem( - pen=pg.mkPen(color='r', width=2) - ) + + self.freeRoiItem = widgets.PlotCurveItem(pen=pg.mkPen(color="r", width=2)) self.ax1.addItem(self.freeRoiItem) self.updateFreeRoiItem() - + def updateFreeRoiItem(self): posData = self.data[self.pos_i] for point in posData.dataPrepFreeRoiPoints: @@ -1516,25 +1505,25 @@ def updateFreeRoiItem(self): if len(posData.dataPrepFreeRoiPoints) == 0: self.freeRoiMask = None return - + xx, yy = self.freeRoiItem.getData() self.freeRoiYXorigin = (yy.min(), xx.min()) self.freeRoiMask = self.freeRoiItem.mask() - + def freeRoiActionToggled(self, checked): if checked: self.hideROIs() self.addFreeRoiItem() else: self.reAddROIs() - + def getCroppedData(self, askCropping=True, doValidateFreeRoi=False): for p, posData in enumerate(self.data): self.saveBkgrROIs(posData) # Get crop shape and print it data = posData.img_data - + allCropsData = [] for cropROI in posData.cropROIs: croppedData, SizeZ = self.crop(data, posData, cropROI) @@ -1542,74 +1531,70 @@ def getCroppedData(self, askCropping=True, doValidateFreeRoi=False): croppedShapes = [cropped.shape for cropped in allCropsData] isCropped = any([shape != data.shape for shape in croppedShapes]) - + proceed = True if isCropped: if p == 0 and askCropping: - proceed = self.askCropping( - data.shape, croppedShapes - ) + proceed = self.askCropping(data.shape, croppedShapes) doCrop = proceed else: doCrop = True else: doCrop = False - + if not proceed: self.setEnabledCropActions(True) - txt = ('Cropping cancelled.') - self.titleLabel.setText(txt, color='r') + txt = "Cropping cancelled." + self.titleLabel.setText(txt, color="r") self.logger.info(txt) yield None elif not isCropped: self.setEnabledCropActions(True) txt = ( - 'Crop ROI has same shape of the image --> no need to crop. ' - 'Process stopped.' + "Crop ROI has same shape of the image --> no need to crop. " + "Process stopped." ) - self.titleLabel.setText(txt, color='r') + self.titleLabel.setText(txt, color="r") self.logger.info(txt) - yield 'continue' + yield "continue" elif not doValidateFreeRoi: yield croppedShapes, posData, SizeZ, doCrop else: - proceed = self.validateFreeRoi(posData, warn=p==0) + proceed = self.validateFreeRoi(posData, warn=p == 0) if not proceed: self.setEnabledCropActions(True) - txt = ('Cropping cancelled because overlaps with free roi.') - self.titleLabel.setText(txt, color='r') + txt = "Cropping cancelled because overlaps with free roi." + self.titleLabel.setText(txt, color="r") self.logger.info(txt) yield None else: yield croppedShapes, posData, SizeZ, doCrop - + def validateFreeRoi(self, posData, warn=True): posData.loadDataPrepFreeRoi(logger_func=self.logger.info) if len(posData.dataPrepFreeRoiPoints) == 0: return True - + if len(posData.cropROIs) > 1: if warn: proceed = self.warnMultiCropsWithFreeRoi() else: proceed = True - + if proceed: posData.removeDataPrepFreeRoi() self.freeRoiItem.clear() self.freeRoiMask = None - return proceed - + return proceed + cropROI = posData.cropROIs[0] x0, y0 = [int(round(c)) for c in cropROI.pos()] w, h = [int(round(c)) for c in cropROI.size()] - x1, y1 = x0+w, y0+h - + x1, y1 = x0 + w, y0 + h + y0f, x0f, y1f, x1f = posData.dataPrepFreeRoiBbox - - is_free_roi_in_crop_bounds = ( - x0f >= x0 and x1f <= x1 and y0f >= y0 and y1f <= y1 - ) + + is_free_roi_in_crop_bounds = x0f >= x0 and x1f <= x1 and y0f >= y0 and y1f <= y1 if not is_free_roi_in_crop_bounds and warn: proceed = self.warnFreeRoiOverlapsWithCropRoi() else: @@ -1617,7 +1602,7 @@ def validateFreeRoi(self, posData, warn=True): if not proceed: return False - + # Adjust free-hand ROI according to crop ROI local_mask = posData.dataPrepFreeRoiLocalMask crop_x0, crop_y0, crop_x1, crop_y1 = None, None, None, None @@ -1626,36 +1611,36 @@ def validateFreeRoi(self, posData, warn=True): x0f = 0 else: x0f = x0f - x0 - + if y0f < y0: crop_y0 = y0 - y0f y0f = 0 else: y0f = y0f - y0 - + if x1f > x1: crop_x1 = x1 - x1f x1f = w else: x1f = x1f - x0 - + if y1f > y1: crop_y1 = y1 - y1f y1f = h else: y1f = y1f - y0 - local_mask = posData.dataPrepFreeRoiLocalMask[ - crop_y0:crop_y1, crop_x0:crop_x1 - ] + local_mask = posData.dataPrepFreeRoiLocalMask[crop_y0:crop_y1, crop_x0:crop_x1] bbox = (y0f, x0f, y1f, x1f) posData.saveDataPrepFreeRoi( - self.freeRoiItem, logger_func=self.logger.info, - bbox=bbox, local_mask=local_mask + self.freeRoiItem, + logger_func=self.logger.info, + bbox=bbox, + local_mask=local_mask, ) - + return proceed - + def warnFreeRoiOverlapsWithCropRoi(self): txt = html_utils.paragraph(f""" The crop ROI is smaller than the free-hand ROI.

@@ -1663,14 +1648,13 @@ def warnFreeRoiOverlapsWithCropRoi(self): """) msg = widgets.myMessageBox(wrapText=False) noButton, yesButton = msg.warning( - self, 'Crop ROI is smaller than free-hand ROI', txt, - buttonsTexts=( - 'No, stop cropping process', - 'Yes, continue with cropping' - ), + self, + "Crop ROI is smaller than free-hand ROI", + txt, + buttonsTexts=("No, stop cropping process", "Yes, continue with cropping"), ) return msg.clickedButton == yesButton - + def warnMultiCropsWithFreeRoi(self): txt = html_utils.paragraph(f""" You are about to create multiple crops and you also have a @@ -1682,94 +1666,91 @@ def warnMultiCropsWithFreeRoi(self): """) msg = widgets.myMessageBox(wrapText=False) noButton, yesButton = msg.warning( - self, 'Multiple crops with free-hand ROI', txt, - buttonsTexts=( - 'No, stop cropping process', - 'Yes, continue with cropping' - ), + self, + "Multiple crops with free-hand ROI", + txt, + buttonsTexts=("No, stop cropping process", "Yes, continue with cropping"), ) return msg.clickedButton == yesButton - + def applyCropZslices(self, low_z, high_z): self.logger.info( - f'Previewing cropped z-slices in the range ({low_z+1},{high_z+1})...' + f"Previewing cropped z-slices in the range ({low_z + 1},{high_z + 1})..." ) for posData in self.data: - posData.segmInfo_df['crop_lower_z_slice'] = low_z - posData.segmInfo_df['crop_upper_z_slice'] = high_z + posData.segmInfo_df["crop_lower_z_slice"] = low_z + posData.segmInfo_df["crop_upper_z_slice"] = high_z if posData.SizeT > 1: posData.img_data[:, :low_z] = 0 - posData.img_data[:, high_z+1:] = 0 + posData.img_data[:, high_z + 1 :] = 0 else: posData.img_data[:low_z] = 0 - posData.img_data[high_z+1:] = 0 - + posData.img_data[high_z + 1 :] = 0 + self.update_img() note_text = ( - f'Done. Z-slices outside of the range ({low_z+1},{high_z+1}) ' + f"Done. Z-slices outside of the range ({low_z + 1},{high_z + 1}) " 'will appear black now. To save cropped data, click on the "Save" ' - 'button on the top toolbar.' + "button on the top toolbar." ) self.logger.info(note_text) - + txt = html_utils.paragraph(f""" Cropping z-slice applied.

Note that this is just a preview where the z-slices outside of the - range ({low_z+1},{high_z+1}) will look black.

+ range ({low_z + 1},{high_z + 1}) will look black.

To save cropped data, click on the Save cropped data button on the top toolbar. """) msg = widgets.myMessageBox(wrapText=False) - msg.information(self, 'Preview cropped z-slices', txt) - + msg.information(self, "Preview cropped z-slices", txt) + def applyCropYX(self): for posData in self.data: for cropROI in posData.cropROIs: x0, y0 = [int(round(c)) for c in cropROI.pos()] w, h = [int(round(c)) for c in cropROI.size()] cropMask = np.zeros(posData.img_data.shape, dtype=bool) - cropMask[..., y0:y0+h, x0:x0+w] = True + cropMask[..., y0 : y0 + h, x0 : x0 + w] = True posData.img_data[~cropMask] = 0 - self.update_img() - + self.update_img() + def saveActionTriggered(self): if self.tempFilesToMove: cancel = self.warnSaveAlignedNotReversible() if not cancel: self.startMoveTempFilesWorker() self.waitMoveTempFilesWorker() - + self.cropAndSave() - + @exception_handler def cropAndSave(self): cropPaths = {} - for cropInfo in self.getCroppedData( - askCropping=True, doValidateFreeRoi=True - ): + for cropInfo in self.getCroppedData(askCropping=True, doValidateFreeRoi=True): if cropInfo is None: # Process cancelled by the user return - if cropInfo == 'continue': + if cropInfo == "continue": continue - + croppedShapes, posData, SizeZ, doCrop = cropInfo if len(croppedShapes) == 1: masterPath = posData.images_path else: masterPath = posData.pos_path - + cropPaths[masterPath] = len(croppedShapes) - + if not cropPaths: return - + win = apps.DataPrepSubCropsPathsDialog(cropPaths=cropPaths) win.exec_() if win.cancel: - txt = 'Cropping cancelled.' - self.titleLabel.setText(txt, color='r') + txt = "Cropping cancelled." + self.titleLabel.setText(txt, color="r") return dstPaths = win.folderPaths @@ -1778,92 +1759,90 @@ def cropAndSave(self): if cropInfo is None: # Process cancelled by the user return - - if cropInfo == 'continue': + + if cropInfo == "continue": continue - + croppedShapes, posData, SizeZ, doCrop = cropInfo posData.SizeZ = SizeZ # Update metadata with cropped SizeZ - posData.metadata_df.at['SizeZ', 'values'] = SizeZ + posData.metadata_df.at["SizeZ", "values"] = SizeZ posData.metadata_df.to_csv(posData.metadata_csv_path) - self.logger.info(f'Cropping {posData.relPath}...') + self.logger.info(f"Cropping {posData.relPath}...") self.titleLabel.setText( - 'Cropping... (check progress in the terminal)', - color='w') + "Cropping... (check progress in the terminal)", color="w" + ) - croppedShapesFormat = [f' --> {shape}' for shape in croppedShapes] - croppedShapesFormat = '\n'.join(croppedShapesFormat) - self.logger.info(f'Cropped data shape:\n{croppedShapesFormat}') + croppedShapesFormat = [f" --> {shape}" for shape in croppedShapes] + croppedShapesFormat = "\n".join(croppedShapesFormat) + self.logger.info(f"Cropped data shape:\n{croppedShapesFormat}") self.saveROIcoords(doCrop, posData) - self.logger.info('Saving background data...') - + self.logger.info("Saving background data...") + worker = self.startSaveBkgrDataWorker(posData) self.waitWorker(worker) - + if len(croppedShapes) == 1: masterPath = posData.images_path else: masterPath = posData.pos_path - - self.logger.info('Cropping...') + + self.logger.info("Cropping...") self.saveCroppedData(posData, dstPaths[masterPath]) - + for posData in self.data: self.disconnectROIs(posData) if posData.SizeZ > 1: # Save segmInfo try: - low_z = posData.segmInfo_df['crop_lower_z_slice'] - posData.segmInfo_df['z_slice_used_dataPrep'] -= low_z + low_z = posData.segmInfo_df["crop_lower_z_slice"] + posData.segmInfo_df["z_slice_used_dataPrep"] -= low_z except Exception as err: - pass - + pass + posData.segmInfo_df = posData.segmInfo_df.drop( - columns=['crop_lower_z_slice', 'crop_upper_z_slice'], - errors='ignore' + columns=["crop_lower_z_slice", "crop_upper_z_slice"], + errors="ignore", ) posData.segmInfo_df.to_csv(posData.segmInfo_df_csv_path) - txt = ( - 'Saved! You can close the program or load another position.' - ) - self.titleLabel.setText(txt, color='g') + txt = "Saved! You can close the program or load another position." + self.titleLabel.setText(txt, color="g") msg = widgets.myMessageBox() - txt = html_utils.paragraph(txt.replace('! ', '!

')) - msg.information(self, 'Data prep done', txt) - - self.saveAction.setEnabled(False) - + txt = html_utils.paragraph(txt.replace("! ", "!

")) + msg.information(self, "Data prep done", txt) + + self.saveAction.setEnabled(False) + def setEnabledCropActions(self, enabled): self.cropAction.setEnabled(enabled) self.cropZaction.setEnabled(enabled) self.saveAction.setEnabled(enabled) self.cropTaction.setEnabled(enabled) self.freeRoiAction.setEnabled(enabled) - - if not hasattr(self, 'data'): + + if not hasattr(self, "data"): return - + posData = self.data[self.pos_i] if posData.SizeZ == 1: self.cropZaction.setEnabled(False) - + if posData.SizeT == 1: self.cropTaction.setEnabled(False) - + def removeAllHandles(self, roi): for handle in roi.handles: - item = handle['item'] + item = handle["item"] item.disconnectROI(roi) if len(item.rois) == 0 and roi.scene() is not None: roi.scene().removeItem(item) roi.handles = [] roi.stateChanged() - + def disconnectROIs(self, posData): for cropROI in posData.cropROIs: try: @@ -1884,41 +1863,41 @@ def disconnectROIs(self, posData): roi.removable = False self.removeAllHandles(roi) - + self.addCropRoiActon.setDisabled(True) self.addBkrgRoiActon.setDisabled(True) self.cropTaction.setDisabled(True) self.freeRoiAction.setDisabled(True) - - self.logger.info('ROIs disconnected.') + + self.logger.info("ROIs disconnected.") def permissionErrorCritical(self, path): msg = QMessageBox() msg.critical( - self, 'Permission denied', - f'The below file is open in another app (Excel maybe?).\n\n' - f'{path}\n\n' + self, + "Permission denied", + f"The below file is open in another app (Excel maybe?).\n\n" + f"{path}\n\n" 'Close file and then press "Ok".', - msg.Ok + msg.Ok, ) def askCropping(self, dataShape, croppedShapes): - header_text = (f""" + header_text = f""" Data-prep information saved.

- """) + """ if len(self.data) > 1: - info_text = (""" + info_text = """ Do you also want to save cropped data?

- """) + """ else: - info_text = (f""" + info_text = f""" Do you also want to save cropped data from shape {dataShape} to the following shapes: {html_utils.to_list(croppedShapes, ordered=True)} - """) + """ important = html_utils.to_admonition( - 'Saving cropped data cannot be undone.', - admonition_type='Important' + "Saving cropped data cannot be undone.", admonition_type="Important" ) msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph(f""" @@ -1931,28 +1910,28 @@ def askCropping(self, dataShape, croppedShapes): Do you want to continue with saving cropped data? """) noButton, yesButton = msg.warning( - self, 'Crop?', txt, - buttonsTexts=('No, do not crop.', 'Yes, crop please.') + self, "Crop?", txt, buttonsTexts=("No, do not crop.", "Yes, crop please.") ) return msg.clickedButton == yesButton def getDefaultROI(self, shrinkFactor=1): Y, X = self.img.image.shape - w, h = int(X*shrinkFactor), int(Y*shrinkFactor) + w, h = int(X * shrinkFactor), int(Y * shrinkFactor) - xc, yc = int(round(X/2)), int(round(Y/2)) + xc, yc = int(round(X / 2)), int(round(Y / 2)) # yt, xl = int(round(xc-w/2)), int(round(yc-h/2)) yt, xl = 0, 0 # Add ROI Rectangle cropROI = pg.ROI( - [xl, yt], [w, h], + [xl, yt], + [w, h], rotatable=False, removable=False, - pen=pg.mkPen(color='r'), - maxBounds=QRectF(QRect(0,0,X,Y)), + pen=pg.mkPen(color="r"), + maxBounds=QRectF(QRect(0, 0, X, Y)), scaleSnap=True, - translateSnap=True + translateSnap=True, ) return cropROI @@ -1960,7 +1939,7 @@ def setROIprops(self, roi, roiNumber=1): xl, yt = [int(round(c)) for c in roi.pos()] roi.handleSize = 7 - roi.label = pg.TextItem(f'ROI n. {roiNumber}', color='r') + roi.label = pg.TextItem(f"ROI n. {roiNumber}", color="r") roi.label.setFont(self.roiLabelFont) # hLabel = roi.label.rect().bottom() roi.label.setPos(xl, yt) @@ -1998,7 +1977,7 @@ def init_data(self, user_ch_file_paths, user_ch_name): load_last_tracked_i=False, load_metadata=True, load_dataprep_free_roi=True, - getTifPath=True + getTifPath=True, ) # If data was cropped then dataPrep_ROIcoords are useless @@ -2006,22 +1985,23 @@ def init_data(self, user_ch_file_paths, user_ch_name): posData.dataPrep_ROIcoords = None posData.loadAllImgPaths() - if f==0 and not self.metadataAlreadyAsked: + if f == 0 and not self.metadataAlreadyAsked: proceed = posData.askInputMetadata( self.num_pos, - ask_SizeT=self.num_pos==1, + ask_SizeT=self.num_pos == 1, ask_TimeIncrement=False, ask_PhysicalSizes=False, save=True, - askSegm3D=False + askSegm3D=False, ) self.isSegm3D = posData.isSegm3D self.SizeT = posData.SizeT self.SizeZ = posData.SizeZ if not proceed: self.titleLabel.setText( - 'File --> Open or Open recent to start the process', - color='w') + "File --> Open or Open recent to start the process", + color="w", + ) return False self.AutoPilotProfile.storeOkAskInputMetadata() else: @@ -2042,20 +2022,20 @@ def init_data(self, user_ch_file_paths, user_ch_name): posData.SizeT = 1 posData.saveMetadata() except AttributeError: - print('') - print('====================================') + print("") + print("====================================") traceback.print_exc() - print('====================================') - print('') + print("====================================") + print("") self.titleLabel.setText( - 'File --> Open or Open recent to start the process', - color='w') + "File --> Open or Open recent to start the process", color="w" + ) return False if posData is None: self.titleLabel.setText( - 'File --> Open or Open recent to start the process', - color='w') + "File --> Open or Open recent to start the process", color="w" + ) return False img_shape = posData.img_data.shape @@ -2063,31 +2043,32 @@ def init_data(self, user_ch_file_paths, user_ch_name): self.user_ch_name = user_ch_name SizeT = posData.SizeT SizeZ = posData.SizeZ - if f==0: - print('') - self.logger.info(f'Data shape = {img_shape}') - self.logger.info(f'Number of frames = {SizeT}') - self.logger.info(f'Number of z-slices per frame = {SizeZ}') + if f == 0: + print("") + self.logger.info(f"Data shape = {img_shape}") + self.logger.info(f"Number of frames = {SizeT}") + self.logger.info(f"Number of z-slices per frame = {SizeZ}") data.append(posData) - if SizeT>1 and self.num_pos>1: + if SizeT > 1 and self.num_pos > 1: path = os.path.normpath(file_path) path_li = path.split(os.sep) - rel_path = f'.../{"/".join(path_li[-3:])}' + rel_path = f".../{'/'.join(path_li[-3:])}" msg = QMessageBox() msg.critical( - self, 'Multiple Pos loading not allowed.', - f'The file {rel_path} has multiple frames over time.\n\n' - 'Loading multiple positions that contain frames over time ' - 'is not allowed.\n\n' - 'To analyse frames over time load one position at the time', - msg.Ok + self, + "Multiple Pos loading not allowed.", + f"The file {rel_path} has multiple frames over time.\n\n" + "Loading multiple positions that contain frames over time " + "is not allowed.\n\n" + "To analyse frames over time load one position at the time", + msg.Ok, ) self.titleLabel.setText( - 'File --> Open or Open recent to start the process', - color='w') + "File --> Open or Open recent to start the process", color="w" + ) return False - + self.data = data self.init_segmInfo_df() self.init_attr() @@ -2104,7 +2085,7 @@ def init_segmInfo_df(self): ) if NO_segmInfo and posData.SizeZ > 1: filename = posData.filename - df = myutils.getDefault_SegmInfo_df(posData, filename) + df = utils.getDefault_SegmInfo_df(posData, filename) if posData.segmInfo_df is None: posData.segmInfo_df = df else: @@ -2121,7 +2102,7 @@ def init_segmInfo_df(self): self.z_label.setDisabled(False) self.zSliceScrollBar.setDisabled(False) self.zProjComboBox.setDisabled(False) - self.zSliceScrollBar.setMaximum(posData.SizeZ-1) + self.zSliceScrollBar.setMaximum(posData.SizeZ - 1) self.zSliceScrollBar.valueChanged.connect(self.update_z_slice) self.zProjComboBox.currentTextChanged.connect(self.updateZproj) if posData.SizeT > 1: @@ -2130,7 +2111,7 @@ def init_segmInfo_df(self): self.ZforwAction.setEnabled(True) df = posData.segmInfo_df idx = (posData.filename, self.frame_i) - how = posData.segmInfo_df.at[idx, 'which_z_proj'] + how = posData.segmInfo_df.at[idx, "which_z_proj"] self.zProjComboBox.setCurrentText(how) else: self.zSliceScrollBar.setDisabled(True) @@ -2138,37 +2119,36 @@ def init_segmInfo_df(self): self.z_label.setDisabled(True) def update_z_slice(self, z): - if self.zProjComboBox.currentText() == 'single z-slice': + if self.zProjComboBox.currentText() == "single z-slice": posData = self.data[self.pos_i] df = posData.segmInfo_df idx = (posData.filename, self.frame_i) - posData.segmInfo_df.at[idx, 'z_slice_used_dataPrep'] = z - posData.segmInfo_df.at[idx, 'z_slice_used_gui'] = z + posData.segmInfo_df.at[idx, "z_slice_used_dataPrep"] = z + posData.segmInfo_df.at[idx, "z_slice_used_gui"] = z self.update_img() posData.segmInfo_df.to_csv(posData.segmInfo_df_csv_path) - def updateZproj(self, how): posData = self.data[self.pos_i] for frame_i in range(self.frame_i, posData.SizeT): df = posData.segmInfo_df idx = (posData.filename, self.frame_i) - posData.segmInfo_df.at[idx, 'which_z_proj'] = how - posData.segmInfo_df.at[idx, 'which_z_proj_gui'] = how - if how == 'single z-slice': + posData.segmInfo_df.at[idx, "which_z_proj"] = how + posData.segmInfo_df.at[idx, "which_z_proj_gui"] = how + if how == "single z-slice": self.zSliceScrollBar.setDisabled(False) - self.z_label.setStyleSheet('color: black') + self.z_label.setStyleSheet("color: black") self.update_z_slice(self.zSliceScrollBar.sliderPosition()) else: self.zSliceScrollBar.setDisabled(True) - self.z_label.setStyleSheet('color: gray') + self.z_label.setStyleSheet("color: gray") self.update_img() # Apply same z-proj to future pos if posData.SizeT == 1: - for posData in self.data[self.pos_i+1:]: + for posData in self.data[self.pos_i + 1 :]: idx = (posData.filename, self.frame_i) - posData.segmInfo_df.at[idx, 'which_z_proj'] = how + posData.segmInfo_df.at[idx, "which_z_proj"] = how self.save_segmInfo_df_pos() @@ -2189,205 +2169,190 @@ def useSameZ_fromHereBack(self, event): how = self.zProjComboBox.currentText() posData = self.data[self.pos_i] df = posData.segmInfo_df - z = df.at[(posData.filename, self.frame_i), 'z_slice_used_dataPrep'] + z = df.at[(posData.filename, self.frame_i), "z_slice_used_dataPrep"] if posData.SizeT > 1: for i in range(0, self.frame_i): - df.at[(posData.filename, i), 'z_slice_used_dataPrep'] = z - df.at[(posData.filename, i), 'z_slice_used_gui'] = z - df.at[(posData.filename, i), 'which_z_proj'] = how + df.at[(posData.filename, i), "z_slice_used_dataPrep"] = z + df.at[(posData.filename, i), "z_slice_used_gui"] = z + df.at[(posData.filename, i), "which_z_proj"] = how posData.segmInfo_df.to_csv(posData.segmInfo_df_csv_path) elif posData.SizeZ > 1: - for _posData in self.data[:self.pos_i]: + for _posData in self.data[: self.pos_i]: df = _posData.segmInfo_df - df.at[(_posData.filename, 0), 'z_slice_used_dataPrep'] = z - df.at[(_posData.filename, 0), 'z_slice_used_gui'] = z - df.at[(_posData.filename, 0), 'which_z_proj'] = how + df.at[(_posData.filename, 0), "z_slice_used_dataPrep"] = z + df.at[(_posData.filename, 0), "z_slice_used_gui"] = z + df.at[(_posData.filename, 0), "which_z_proj"] = how self.save_segmInfo_df_pos() def useSameZ_fromHereForw(self, event): how = self.zProjComboBox.currentText() posData = self.data[self.pos_i] df = posData.segmInfo_df - z = df.at[(posData.filename, self.frame_i), 'z_slice_used_dataPrep'] + z = df.at[(posData.filename, self.frame_i), "z_slice_used_dataPrep"] if posData.SizeT > 1: for i in range(self.frame_i, posData.SizeT): - df.at[(posData.filename, i), 'z_slice_used_dataPrep'] = z - df.at[(posData.filename, i), 'z_slice_used_gui'] = z - df.at[(posData.filename, i), 'which_z_proj'] = how + df.at[(posData.filename, i), "z_slice_used_dataPrep"] = z + df.at[(posData.filename, i), "z_slice_used_gui"] = z + df.at[(posData.filename, i), "which_z_proj"] = how posData.segmInfo_df.to_csv(posData.segmInfo_df_csv_path) elif posData.SizeZ > 1: - for _posData in self.data[self.pos_i:]: + for _posData in self.data[self.pos_i :]: df = _posData.segmInfo_df - df.at[(_posData.filename, 0), 'z_slice_used_dataPrep'] = z - df.at[(_posData.filename, 0), 'z_slice_used_gui'] = z - df.at[(_posData.filename, 0), 'which_z_proj'] = how + df.at[(_posData.filename, 0), "z_slice_used_dataPrep"] = z + df.at[(_posData.filename, 0), "z_slice_used_gui"] = z + df.at[(_posData.filename, 0), "which_z_proj"] = how self.save_segmInfo_df_pos() def interp_z(self, event): posData = self.data[self.pos_i] df = posData.segmInfo_df - x0, z0 = 0, df.at[(posData.filename, 0), 'z_slice_used_dataPrep'] + x0, z0 = 0, df.at[(posData.filename, 0), "z_slice_used_dataPrep"] x1 = self.frame_i - z1 = df.at[(posData.filename, x1), 'z_slice_used_dataPrep'] + z1 = df.at[(posData.filename, x1), "z_slice_used_dataPrep"] f = scipy.interpolate.interp1d([x0, x1], [z0, z1]) xx = np.arange(0, self.frame_i) zz = np.round(f(xx)).astype(int) for i in range(self.frame_i): - df.at[(posData.filename, i), 'z_slice_used_dataPrep'] = zz[i] - df.at[(posData.filename, i), 'z_slice_used_gui'] = zz[i] - df.at[(posData.filename, i), 'which_z_proj'] = 'single z-slice' + df.at[(posData.filename, i), "z_slice_used_dataPrep"] = zz[i] + df.at[(posData.filename, i), "z_slice_used_gui"] = zz[i] + df.at[(posData.filename, i), "which_z_proj"] = "single z-slice" posData.segmInfo_df.to_csv(posData.segmInfo_df_csv_path) - + def waitAlignDataWorker(self): self.alignDataWorkerLoop = QEventLoop(self) self.alignDataWorkerLoop.exec_() - + def waitWorker(self, worker): worker.loop = QEventLoop(self) worker.loop.exec_() - - def workerProgress(self, text, loggerLevel='INFO'): + + def workerProgress(self, text, loggerLevel="INFO"): if self.progressWin is not None: - self.progressWin.logConsole.append('-'*60) + self.progressWin.logConsole.append("-" * 60) self.progressWin.logConsole.append(text) self.logger.log(getattr(logging, loggerLevel), text) - + def startAlignDataWorker(self, posData, align, user_ch_name, progressText): # Disable clicks on image during alignment self.img.mousePressEvent = None - + if posData.SizeT > 1: self.progressWin = apps.QDialogWorkerProgress( - title='Aligning data', - parent=self, - pbarDesc=progressText + title="Aligning data", parent=self, pbarDesc=progressText ) self.progressWin.show(self.app) self.progressWin.mainPbar.setMaximum(0) - + self._thread = QThread() self.alignDataWorkerMutex = QMutex() self.alignDataWorkerWaitCond = QWaitCondition() - + self.alignDataWorker = workers.AlignDataWorker( - posData, self, self.alignDataWorkerMutex, - self.alignDataWorkerWaitCond + posData, self, self.alignDataWorkerMutex, self.alignDataWorkerWaitCond ) self.alignDataWorker.set_attr(align, user_ch_name) self.alignDataWorker.moveToThread(self._thread) - + self.alignDataWorker.signals.finished.connect(self._thread.quit) - self.alignDataWorker.signals.finished.connect( - self.alignDataWorker.deleteLater - ) + self.alignDataWorker.signals.finished.connect(self.alignDataWorker.deleteLater) self._thread.finished.connect(self._thread.deleteLater) - self.alignDataWorker.signals.finished.connect( - self.alignDataWorkerFinished - ) + self.alignDataWorker.signals.finished.connect(self.alignDataWorkerFinished) self.alignDataWorker.signals.progress.connect(self.workerProgress) - self.alignDataWorker.signals.initProgressBar.connect( - self.workerInitProgressbar - ) - self.alignDataWorker.signals.progressBar.connect( - self.workerUpdateProgressbar - ) - self.alignDataWorker.signals.critical.connect( - self.workerCritical - ) + self.alignDataWorker.signals.initProgressBar.connect(self.workerInitProgressbar) + self.alignDataWorker.signals.progressBar.connect(self.workerUpdateProgressbar) + self.alignDataWorker.signals.critical.connect(self.workerCritical) + + self.alignDataWorker.sigAskAlignSegmData.connect(self.askAlignSegmData) + self.alignDataWorker.sigWarnTifAligned.connect(self.warnTifAligned) - self.alignDataWorker.sigAskAlignSegmData.connect( - self.askAlignSegmData - ) - self.alignDataWorker.sigWarnTifAligned.connect( - self.warnTifAligned - ) - self._thread.started.connect(self.alignDataWorker.run) self._thread.start() - + @exception_handler def prepData(self, event): self.titleLabel.setText( - 'Prepping data... (check progress in the terminal)', - color='w') + "Prepping data... (check progress in the terminal)", color="w" + ) self.tempFilesToMove = {} doZip = False for p, posData in enumerate(self.data): self.startAction.setDisabled(True) nonTifFound = ( - any([npz is not None for npz in posData.npz_paths]) or - any([npy is not None for npy in posData.npy_paths]) or - posData.segmFound + any([npz is not None for npz in posData.npz_paths]) + or any([npy is not None for npy in posData.npy_paths]) + or posData.segmFound ) imagesPath = posData.images_path - zipPath = f'{imagesPath}.zip' - if nonTifFound and p==0: + zipPath = f"{imagesPath}.zip" + if nonTifFound and p == 0: txt = ( - 'Additional NON-tif files detected.

' - 'The requested experiment folder already contains .npy ' - 'or .npz files ' - 'most likely from previous analysis runs.

' - 'To avoid data losses we recommend zipping the ' + "Additional NON-tif files detected.

" + "The requested experiment folder already contains .npy " + "or .npz files " + "most likely from previous analysis runs.

" + "To avoid data losses we recommend zipping the " '"Images" folder.

' - 'If everything looks fine after prepping the data, ' - 'you can manually ' - 'delete the zip archive.

' - 'Do you want to automatically zip now?

' - 'PS: Zip archive location:

' - f'{zipPath}' + "If everything looks fine after prepping the data, " + "you can manually " + "delete the zip archive.

" + "Do you want to automatically zip now?

" + "PS: Zip archive location:

" + f"{zipPath}" ) txt = html_utils.paragraph(txt) msg = widgets.myMessageBox() _, yes, no = msg.warning( - self, 'NON-Tif data detected!', txt, - buttonsTexts=('Cancel', 'Yes', 'No') + self, + "NON-Tif data detected!", + txt, + buttonsTexts=("Cancel", "Yes", "No"), ) if msg.cancel: self.setEnabledCropActions(True) - self.titleLabel.setText('Process aborted', color='w') + self.titleLabel.setText("Process aborted", color="w") return if yes == msg.clickedButton: doZip = True if doZip: - self.logger.info(f'Zipping Images folder: {zipPath}') - shutil.make_archive(imagesPath, 'zip', imagesPath) + self.logger.info(f"Zipping Images folder: {zipPath}") + shutil.make_archive(imagesPath, "zip", imagesPath) success = self.alignData(self.user_ch_name, posData) if not success: - self.titleLabel.setText('Data prep cancelled.', color='r') + self.titleLabel.setText("Data prep cancelled.", color="r") return - if posData.SizeZ>1: + if posData.SizeZ > 1: posData.segmInfo_df.to_csv(posData.segmInfo_df_csv_path) # For loop did not break, proceed with the rest self.update_img() - self.logger.info('Done.') + self.logger.info("Done.") self.addROIs() self.saveROIcoords(False, self.data[self.pos_i]) self.saveBkgrROIs(self.data[self.pos_i]) self.setEnabledCropActions(True) txt = ( - 'Data successfully prepped. You can now crop the images, ' - 'place the background ROIs, or close the program' + "Data successfully prepped. You can now crop the images, " + "place the background ROIs, or close the program" ) - self.titleLabel.setText(txt, color='w') + self.titleLabel.setText(txt, color="w") msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph(txt) - msg.information(self, 'Dataprep completed', txt) + msg.information(self, "Dataprep completed", txt) def setStandardRoiShape(self, text): posData = self.data[self.pos_i] - if not hasattr(posData, 'cropROIs'): - return - + if not hasattr(posData, "cropROIs"): + return + if posData.cropROIs is None: return - if len(posData.cropROIs)>1: + if len(posData.cropROIs) > 1: return Y, X = posData.img_data.shape[-2:] - m = re.findall(r'(\d+)x(\d+)', text) + m = re.findall(r"(\d+)x(\d+)", text) w, h = int(m[0][0]), int(m[0][1]) # xc, yc = int(round(X/2)), int(round(Y/2)) # yt, xl = int(round(xc-w/2)), int(round(yc-h/2)) @@ -2396,34 +2361,34 @@ def setStandardRoiShape(self, text): def hideROIs(self): posData = self.data[self.pos_i] - if not hasattr(posData, 'cropROIs'): + if not hasattr(posData, "cropROIs"): return - + if posData.cropROIs is None: return - + for cropROI in posData.cropROIs: self.ax1.removeItem(cropROI.label) self.ax1.removeItem(cropROI) - + def reAddROIs(self): posData = self.data[self.pos_i] - if not hasattr(posData, 'cropROIs'): + if not hasattr(posData, "cropROIs"): return - + if posData.cropROIs is None: return - + for cropROI in posData.cropROIs: self.ax1.addItem(cropROI.label) self.ax1.addItem(cropROI) - + def addROIs(self): Y, X = self.img.image.shape - max_size = round(int(np.log2(min([Y, X])/16))) - items = [f'{16*(2**i)}x{16*(2**i)}' for i in range(1, max_size+1)] - items.append(f'{X}x{Y}') + max_size = round(int(np.log2(min([Y, X]) / 16))) + items = [f"{16 * (2**i)}x{16 * (2**i)}" for i in range(1, max_size + 1)] + items.append(f"{X}x{Y}") self.ROIshapeComboBox.clear() self.ROIshapeComboBox.addItems(items) self.ROIshapeComboBox.setCurrentText(items[-1]) @@ -2440,32 +2405,31 @@ def addROIs(self): n = 1 for roi_id, df_roi in grouped: df_roi = df_roi.loc[roi_id] - xl = df_roi.at['x_left', 'value'] - yt = df_roi.at['y_top', 'value'] - w = df_roi.at['x_right', 'value'] - xl - h = df_roi.at['y_bottom', 'value'] - yt + xl = df_roi.at["x_left", "value"] + yt = df_roi.at["y_top", "value"] + w = df_roi.at["x_right", "value"] - xl + h = df_roi.at["y_bottom", "value"] - yt cropROI = pg.ROI( - [xl, yt], [w, h], + [xl, yt], + [w, h], rotatable=False, removable=False, - pen=pg.mkPen(color='r'), - maxBounds=QRectF(QRect(0,0,X,Y)), + pen=pg.mkPen(color="r"), + maxBounds=QRectF(QRect(0, 0, X, Y)), scaleSnap=True, - translateSnap=True + translateSnap=True, ) self.setROIprops(cropROI, roiNumber=n) posData.cropROIs.append(cropROI) n += 1 - + self.addAndConnectCropROIs() try: self.ROIshapeComboBox.currentTextChanged.disconnect() except Exception as e: pass - self.ROIshapeComboBox.currentTextChanged.connect( - self.setStandardRoiShape - ) + self.ROIshapeComboBox.currentTextChanged.connect(self.setStandardRoiShape) self.addBkrgRoiActon.setDisabled(False) self.addCropRoiActon.setDisabled(False) @@ -2486,15 +2450,16 @@ def getDefaultBkgrROI(self): Y, X = self.img.image.shape xRange, yRange = self.ax1.viewRange() xl, yt = abs(xRange[0]), abs(yRange[0]) - w, h = int(X/8), int(Y/8) + w, h = int(X / 8), int(Y / 8) bkgrROI = pg.ROI( - [xl, yt], [w, h], + [xl, yt], + [w, h], rotatable=False, removable=False, - pen=pg.mkPen(color=(255,255,255)), - maxBounds=QRectF(QRect(0,0,X,Y)), + pen=pg.mkPen(color=(255, 255, 255)), + maxBounds=QRectF(QRect(0, 0, X, Y)), scaleSnap=True, - translateSnap=True + translateSnap=True, ) return bkgrROI @@ -2502,7 +2467,7 @@ def setBkgrROIprops(self, bkgrROI): bkgrROI.handleSize = 7 xl, yt = [int(round(c)) for c in bkgrROI.pos()] - bkgrROI.label = pg.TextItem('Bkgr. ROI', color=(255,255,255)) + bkgrROI.label = pg.TextItem("Bkgr. ROI", color=(255, 255, 255)) bkgrROI.label.setFont(self.roiLabelFont) # hLabel = bkgrROI.label.rect().bottom() bkgrROI.label.setPos(xl, yt) @@ -2527,7 +2492,7 @@ def setBkgrROIprops(self, bkgrROI): def addCropROI(self): cropROI = self.getDefaultROI(shrinkFactor=0.5) posData = self.data[self.pos_i] - self.setROIprops(cropROI, roiNumber=len(posData.cropROIs)+1) + self.setROIprops(cropROI, roiNumber=len(posData.cropROIs) + 1) posData.cropROIs.append(cropROI) self.addAndConnectROI(cropROI) @@ -2544,8 +2509,8 @@ def addDefaultBkgrROI(self, checked=False): bkgrROI.sigRegionChangeFinished.connect(self.bkgrROImovingFinished) def bkgrROIMoving(self, roi): - roi.setPen(color=(255,255,0)) - roi.label.setColor((255,255,0)) + roi.setPen(color=(255, 255, 0)) + roi.label.setColor((255, 255, 0)) # roi.label.setText(txt, color=(255,255,0), size=self.roiLabelSize) xl, yt = [int(round(c)) for c in roi.pos()] # hLabel = roi.label.rect().bottom() @@ -2553,8 +2518,8 @@ def bkgrROIMoving(self, roi): def bkgrROImovingFinished(self, roi): txt = roi.label.toPlainText() - roi.setPen(color=(255,255,255)) - roi.label.setColor((255,255,255)) + roi.setPen(color=(255, 255, 255)) + roi.label.setColor((255, 255, 255)) # roi.label.setText(txt, color=(150,150,150), size=self.roiLabelSize) posData = self.data[self.pos_i] idx = posData.bkgrROIs.index(roi) @@ -2563,21 +2528,21 @@ def bkgrROImovingFinished(self, roi): def ROImovingFinished(self, roi): txt = roi.label.toPlainText() - roi.setPen(color='r') - roi.label.setColor('r') + roi.setPen(color="r") + roi.label.setColor("r") # roi.label.setText(txt, color='r', size=self.roiLabelSize) self.saveROIcoords(False, self.data[self.pos_i]) def updateCurrentRoiShape(self, roi): - roi.setPen(color=(255,255,0)) - roi.label.setColor((255,255,0)) + roi.setPen(color=(255, 255, 0)) + roi.label.setColor((255, 255, 0)) # roi.label.setText('ROI', color=(255,255,0), size=self.roiLabelSize) xl, yt = [int(round(c)) for c in roi.pos()] # hLabel = roi.label.rect().bottom() roi.label.setPos(xl, yt) w, h = [int(round(c)) for c in roi.size()] - self.ROIshapeLabel.setText(f' Current ROI shape: {w} x {h}') - + self.ROIshapeLabel.setText(f" Current ROI shape: {w} x {h}") + def alignDataWorkerFinished(self, result): if self.progressWin is not None: self.progressWin.workerFinished = True @@ -2585,7 +2550,7 @@ def alignDataWorkerFinished(self, result): self.progressWin = None self.alignDataWorkerLoop.exit() self.img.mousePressEvent = self.gui_mousePressEventImg - + def saveBkgrDataWorkerFinished(self, result): if self.progressWin is not None: self.progressWin.workerFinished = True @@ -2593,7 +2558,7 @@ def saveBkgrDataWorkerFinished(self, result): self.progressWin = None self.saveBkgrDataWorker.loop.exit() self.img.mousePressEvent = self.gui_mousePressEventImg - + def cropWorkerFinished(self, result): if self.progressWin is not None: self.progressWin.workerFinished = True @@ -2601,23 +2566,23 @@ def cropWorkerFinished(self, result): self.progressWin = None self.cropWorker.loop.exit() self.img.mousePressEvent = self.gui_mousePressEventImg - + def workerInitProgressbar(self, totalIter): self.progressWin.mainPbar.setValue(0) if totalIter == 1: totalIter = 0 self.progressWin.mainPbar.setMaximum(totalIter) - + def workerUpdateProgressbar(self, step): self.progressWin.mainPbar.update(step) - + @exception_handler def workerCritical(self, error): if self.progressWin is not None: self.progressWin.workerFinished = True self.progressWin.close() raise error - + def warnZeroPaddingAlignment(self): txt = html_utils.paragraph(""" To align the frames, Cell-ACDC needs to shift the images @@ -2632,16 +2597,15 @@ def warnZeroPaddingAlignment(self): """) msg = widgets.myMessageBox(showCentered=False, wrapText=False) msg.information( - self, 'Padding alignment shifts', txt, - buttonsTexts=('Cancel', 'Ok') + self, "Padding alignment shifts", txt, buttonsTexts=("Cancel", "Ok") ) if msg.cancel: return False return True - + def alignData(self, user_ch_name, posData): align = False - progressText = 'Aligning data...' + progressText = "Aligning data..." if posData.SizeT > 1: msg = widgets.myMessageBox(showCentered=False) if posData.loaded_shifts is not None: @@ -2665,16 +2629,15 @@ def alignData(self, user_ch_name, posData): aligning. """) _, yesButton, noButton = msg.question( - self, 'Align frames?', txt, - buttonsTexts=('Cancel', 'Yes', 'No') + self, "Align frames?", txt, buttonsTexts=("Cancel", "Yes", "No") ) if msg.cancel: return False elif msg.clickedButton == noButton: align = False # Create 0, 0 shifts to perform 0 alignment - posData.loaded_shifts = np.zeros((self.num_frames,2), int) - progressText = 'Skipping alignment...' + posData.loaded_shifts = np.zeros((self.num_frames, 2), int) + progressText = "Skipping alignment..." else: if posData.loaded_shifts is not None: # Discard current shifts since we want to repeat it @@ -2683,8 +2646,8 @@ def alignData(self, user_ch_name, posData): elif posData.SizeT == 1: align = False # Create 0, 0 shifts to perform 0 alignment - posData.loaded_shifts = np.zeros((self.num_frames,2), int) - + posData.loaded_shifts = np.zeros((self.num_frames, 2), int) + if align: proceed = self.warnZeroPaddingAlignment() if not proceed: @@ -2694,22 +2657,21 @@ def alignData(self, user_ch_name, posData): if align: self.logger.info(progressText) - self.titleLabel.setText(progressText, color='w') + self.titleLabel.setText(progressText, color="w") self.startAlignDataWorker(posData, align, user_ch_name, progressText) self.waitAlignDataWorker() - + return not self.alignDataWorker.doAbort def askAlignSegmData(self): msg = widgets.myMessageBox() txt = html_utils.paragraph( - 'Cell-ACDC found an existing segmentation mask.

' - 'Do you need to align that too?' + "Cell-ACDC found an existing segmentation mask.

" + "Do you need to align that too?" ) _, noButton = msg.question( - self, 'Align segmentation data?', txt, - buttonsTexts=('Yes', 'No') + self, "Align segmentation data?", txt, buttonsTexts=("Yes", "No") ) self.alignDataWorker.doNotAlignSegmData = msg.clickedButton == noButton self.alignDataWorker.restart() @@ -2721,24 +2683,26 @@ def detectTifAlignment(self, tif_data, posData): for img in tif_data: if posData.SizeZ > 1: firtsCol = img[:, :, 0] - lastCol = img[:, : -1] + lastCol = img[:, :-1] firstRow = img[:, 0] lastRow = img[:, -1] else: firtsCol = img[:, 0] - lastCol = img[: -1] + lastCol = img[:-1] firstRow = img[0] lastRow = img[-1] someZeros = ( - not np.any(firstRow) or not np.any(firtsCol) - or not np.any(lastRow) or not np.any(lastCol) + not np.any(firstRow) + or not np.any(firtsCol) + or not np.any(lastRow) + or not np.any(lastCol) ) if someZeros: numFramesWith0s += 1 return numFramesWith0s def warnTifAligned(self, numFramesWith0s, tifPath, posData): - if numFramesWith0s>0 and posData.loaded_shifts is not None: + if numFramesWith0s > 0 and posData.loaded_shifts is not None: msg = widgets.myMessageBox() txt = html_utils.paragraph(""" Cell-ACDC detected that the .tif file contains LREADY @@ -2750,8 +2714,7 @@ def warnTifAligned(self, numFramesWith0s, tifPath, posData): Do you want to continue? """) msg.warning( - self, 'Tif data ALREADY aligned!', txt, - buttonsTexts=('Cancel', 'Yes') + self, "Tif data ALREADY aligned!", txt, buttonsTexts=("Cancel", "Yes") ) if msg.cancel: self.alignDataWorker.doAbort = True @@ -2764,35 +2727,35 @@ def getTempfilePath(self, path): return tempFilePath def moveTempFile(self, source, dst): - self.logger.info(f'Moving temp file: {source}') + self.logger.info(f"Moving temp file: {source}") tempDir = os.path.dirname(source) shutil.move(source, dst) shutil.rmtree(tempDir) def storeTempFileMove(self, source, dst): self.tempFilesToMove[source] = dst - + def getMostRecentPath(self): if os.path.exists(recentPaths_path): - df = pd.read_csv(recentPaths_path, index_col='index') - if 'opened_last_on' in df.columns: - df = df.sort_values('opened_last_on', ascending=False) - self.MostRecentPath = df.iloc[0]['path'] + df = pd.read_csv(recentPaths_path, index_col="index") + if "opened_last_on" in df.columns: + df = df.sort_values("opened_last_on", ascending=False) + self.MostRecentPath = df.iloc[0]["path"] if not isinstance(self.MostRecentPath, str): - self.MostRecentPath = '' + self.MostRecentPath = "" else: - self.MostRecentPath = '' + self.MostRecentPath = "" def addToRecentPaths(self, exp_path): if not os.path.exists(exp_path): return if os.path.exists(recentPaths_path): - df = pd.read_csv(recentPaths_path, index_col='index') - recentPaths = df['path'].to_list() - if 'opened_last_on' in df.columns: - openedOn = df['opened_last_on'].to_list() + df = pd.read_csv(recentPaths_path, index_col="index") + recentPaths = df["path"].to_list() + if "opened_last_on" in df.columns: + openedOn = df["opened_last_on"].to_list() else: - openedOn = [np.nan]*len(recentPaths) + openedOn = [np.nan] * len(recentPaths) if exp_path in recentPaths: pop_idx = recentPaths.index(exp_path) recentPaths.pop(pop_idx) @@ -2806,10 +2769,13 @@ def addToRecentPaths(self, exp_path): else: recentPaths = [exp_path] openedOn = [datetime.datetime.now()] - df = pd.DataFrame({'path': recentPaths, - 'opened_last_on': pd.Series(openedOn, - dtype='datetime64[ns]')}) - df.index.name = 'index' + df = pd.DataFrame( + { + "path": recentPaths, + "opened_last_on": pd.Series(openedOn, dtype="datetime64[ns]"), + } + ) + df.index.name = "index" df.to_csv(recentPaths_path) def populateOpenRecent(self): @@ -2817,10 +2783,10 @@ def populateOpenRecent(self): self.openRecentMenu.clear() # Step 1. Read recent Paths if os.path.exists(recentPaths_path): - df = pd.read_csv(recentPaths_path, index_col='index') - if 'opened_last_on' in df.columns: - df = df.sort_values('opened_last_on', ascending=False) - recentPaths = df['path'].to_list() + df = pd.read_csv(recentPaths_path, index_col="index") + if "opened_last_on" in df.columns: + df = df.sort_values("opened_last_on", ascending=False) + recentPaths = df["path"].to_list() else: recentPaths = [] # Step 2. Dynamically create the actions @@ -2837,8 +2803,7 @@ def populateOpenRecent(self): @exception_handler def loadFiles(self, exp_path, user_ch_file_paths, user_ch_name): self.titleLabel.setText( - 'Loading data (check progress in the terminal)...', - color='w' + "Loading data (check progress in the terminal)...", color="w" ) version = self._acdc_version self.setWindowTitle(f'Cell-ACDC v{version} - Data Prep. - "{exp_path}"') @@ -2853,9 +2818,9 @@ def loadFiles(self, exp_path, user_ch_file_paths, user_ch_name): # Connect events at the end of loading data process self.gui_connectGraphicsEvents() - + exp_path = self.data[self.pos_i].exp_path - pos_foldernames = myutils.get_pos_foldernames(exp_path) + pos_foldernames = utils.get_pos_foldernames(exp_path) if len(pos_foldernames) == 1: # There is only one position --> disable switch pos action self.loadPosAction.setDisabled(True) @@ -2864,21 +2829,20 @@ def loadFiles(self, exp_path, user_ch_file_paths, user_ch_name): if self.titleText is None: self.titleLabel.setText( - 'Data successfully loaded.
' + "Data successfully loaded.
" 'Press "START" button (top-left) to start prepping your data.', - color='w') + color="w", + ) else: - self.titleLabel.setText( - self.titleText, - color='w') + self.titleLabel.setText(self.titleText, color="w") self.openFolderAction.setEnabled(True) self.startAction.setEnabled(True) self.showInExplorerAction.setEnabled(True) self.setImageNameText() - + self.img.preComputedMinMaxValues(self.data) - + self.update_img() self.setFontSizeROIlabels() @@ -2888,9 +2852,9 @@ def setImageNameText(self): self.statusbar.clearMessage() posData = self.data[self.pos_i] txt = ( - f'{posData.pos_foldername} || ' - f'Basename: {posData.basename} || ' - f'Loaded channel: {posData.filename_ext}' + f"{posData.pos_foldername} || " + f"Basename: {posData.basename} || " + f"Loaded channel: {posData.filename_ext}" ) self.statusbar.showMessage(txt) @@ -2905,21 +2869,21 @@ def initLoading(self): self.setCenterAlignmentTitle() self.openFolderAction.setEnabled(False) self.setEnabledCropActions(False) - + self.freeRoiItem = None self.freeRoiMask = None - + self.saveSegmInfoWorkers = [] def showAbout(self): self.aboutWin = about.QDialogAbout(parent=self) self.aboutWin.show() - + def showHowToDataPrep(self): - myutils.browse_url(urls.dataprep_docs) - + utils.browse_url(urls.dataprep_docs) + def openRecentFile(self, path): - self.logger.info(f'Opening recent folder: {path}') + self.logger.info(f"Opening recent folder: {path}") self.openFolder(exp_path=path) def openFolder(self, checked=False, exp_path=None): @@ -2928,29 +2892,32 @@ def openFolder(self, checked=False, exp_path=None): if exp_path is None: self.getMostRecentPath() exp_path = QFileDialog.getExistingDirectory( - self, 'Select experiment folder containing Position_n folders ' - 'or specific Position_n folder', self.MostRecentPath) + self, + "Select experiment folder containing Position_n folders " + "or specific Position_n folder", + self.MostRecentPath, + ) self.addToRecentPaths(exp_path) - if exp_path == '': + if exp_path == "": self.openFolderAction.setEnabled(True) self.titleLabel.setText( - 'File --> Open or Open recent to start the process', - color='w') + "File --> Open or Open recent to start the process", color="w" + ) return - folder_type = myutils.determine_folder_type(exp_path) + folder_type = utils.determine_folder_type(exp_path) is_pos_folder, is_images_folder, exp_path = folder_type - self.titleLabel.setText('Loading data...', color='w') + self.titleLabel.setText("Loading data...", color="w") self.setWindowTitle( f'Cell-ACDC v{self._acdc_version} - Data Prep - "{exp_path}"' ) self.setCenterAlignmentTitle() ch_name_selector = prompts.select_channel_name( - which_channel='segm', allow_abort=False + which_channel="segm", allow_abort=False ) if not is_pos_folder and not is_images_folder: @@ -2958,44 +2925,42 @@ def openFolder(self, checked=False, exp_path=None): values = select_folder.get_values_dataprep(exp_path) if not values: txt = ( - 'The selected folder:\n\n ' - f'{exp_path}\n\n' - 'is not a valid folder. ' - 'Select a folder that contains the Position_n folders' + "The selected folder:\n\n " + f"{exp_path}\n\n" + "is not a valid folder. " + "Select a folder that contains the Position_n folders" ) msg = QMessageBox() - msg.critical( - self, 'Incompatible folder', txt, msg.Ok - ) + msg.critical(self, "Incompatible folder", txt, msg.Ok) self.titleLabel.setText( - 'File --> Open or Open recent to start the process', - color='w') + "File --> Open or Open recent to start the process", color="w" + ) self.openFolderAction.setEnabled(True) return select_folder.QtPrompt(self, values, allow_cancel=False) if select_folder.cancel: self.titleLabel.setText( - 'File --> Open or Open recent to start the process', - color='w') + "File --> Open or Open recent to start the process", color="w" + ) self.openFolderAction.setEnabled(True) return images_paths = [] for pos in select_folder.selected_pos: - images_paths.append(os.path.join(exp_path, pos, 'Images')) + images_paths.append(os.path.join(exp_path, pos, "Images")) if select_folder.cancel: self.titleLabel.setText( - 'File --> Open or Open recent to start the process', - color='w') + "File --> Open or Open recent to start the process", color="w" + ) self.openFolderAction.setEnabled(True) return elif is_pos_folder: pos_foldername = os.path.basename(exp_path) exp_path = os.path.dirname(exp_path) - images_paths = [os.path.join(exp_path, pos_foldername, 'Images')] + images_paths = [os.path.join(exp_path, pos_foldername, "Images")] elif is_images_folder: images_paths = [exp_path] @@ -3004,21 +2969,21 @@ def openFolder(self, checked=False, exp_path=None): # Get info from first position selected images_path = self.images_paths[0] - filenames = myutils.listdir(images_path) + filenames = utils.listdir(images_path) if ch_name_selector.is_first_call: - ch_names, warn = ( - ch_name_selector.get_available_channels(filenames, images_path) + ch_names, warn = ch_name_selector.get_available_channels( + filenames, images_path ) ch_names = ch_name_selector.askChannelName( filenames, images_path, warn, ch_names ) if ch_name_selector.was_aborted: self.titleLabel.setText( - 'File --> Open or Open recent to start the process', - color='w') + "File --> Open or Open recent to start the process", color="w" + ) self.openFolderAction.setEnabled(True) return - + if not ch_names: self.criticalNoTifFound(images_path) elif len(ch_names) > 1: @@ -3028,15 +2993,13 @@ def openFolder(self, checked=False, exp_path=None): ch_name_selector.setUserChannelName() if ch_name_selector.was_aborted: self.titleLabel.setText( - 'File --> Open or Open recent to start the process', - color='w') + "File --> Open or Open recent to start the process", color="w" + ) self.openFolderAction.setEnabled(True) return user_ch_name = ch_name_selector.user_ch_name - user_ch_file_paths = load.get_user_ch_paths( - self.images_paths, user_ch_name - ) + user_ch_file_paths = load.get_user_ch_paths(self.images_paths, user_ch_name) self.AutoPilotProfile.storeSelectedChannel(user_ch_name) self.loadFiles(exp_path, user_ch_file_paths, user_ch_name) @@ -3047,18 +3010,18 @@ def openFolder(self, checked=False, exp_path=None): def setFontSizeROIlabels(self): Y, X = self.img.image.shape factor = 50 - self.pt = int(X/factor) - self.roiLabelSize = '11px' + self.pt = int(X / factor) + self.roiLabelSize = "11px" self.roiLabelFont = QFont() self.roiLabelFont.setPixelSize(13) def criticalNoTifFound(self, images_path): - err_title = f'No .tif files found in folder.' + err_title = f"No .tif files found in folder." err_msg = ( f'The folder "{images_path}" does not contain .tif files.\n\n' 'Only .tif files can be loaded with "Open Folder" button.\n\n' 'Try with "File --> Open image/video file..." and directly select ' - 'the file you want to load.' + "the file you want to load." ) msg = QMessageBox() msg.critical(self, err_title, err_msg, msg.Ok) @@ -3081,53 +3044,37 @@ def askSaveAlignedData(self): Cell-ACDC detected aligned data that was not saved.

Do you want to save the aligned data? """) - buttonsTexts = ( - 'Cancel', 'No, close data-prep', 'Yes, save aligned data' - ) + buttonsTexts = ("Cancel", "No, close data-prep", "Yes, save aligned data") _, noButton, yesAlignButton = msg.question( - self, 'Save cropped data?', txt, buttonsTexts=buttonsTexts + self, "Save cropped data?", txt, buttonsTexts=buttonsTexts ) return msg.clickedButton == yesAlignButton, msg.cancel - + def startMoveTempFilesWorker(self): self.progressWin = apps.QDialogWorkerProgress( - title='Saving aligned data', - parent=self, - pbarDesc='Saving aligned data' + title="Saving aligned data", parent=self, pbarDesc="Saving aligned data" ) self.progressWin.show(self.app) self.progressWin.mainPbar.setMaximum(len(self.tempFilesToMove)) - + self.saveAlignedThread = QThread() - self.saveAlignedWorker = workers.MoveTempFilesWorker( - self.tempFilesToMove - ) - + self.saveAlignedWorker = workers.MoveTempFilesWorker(self.tempFilesToMove) + self.saveAlignedWorker.moveToThread(self.saveAlignedThread) - self.saveAlignedWorker.signals.finished.connect( - self.saveAlignedThread.quit - ) + self.saveAlignedWorker.signals.finished.connect(self.saveAlignedThread.quit) self.saveAlignedWorker.signals.finished.connect( self.saveAlignedWorker.deleteLater ) - self.saveAlignedThread.finished.connect( - self.saveAlignedThread.deleteLater - ) - - self.saveAlignedWorker.signals.finished.connect( - self.saveAlignedWorkerFinished - ) + self.saveAlignedThread.finished.connect(self.saveAlignedThread.deleteLater) + + self.saveAlignedWorker.signals.finished.connect(self.saveAlignedWorkerFinished) self.saveAlignedWorker.signals.progress.connect(self.workerProgress) self.saveAlignedWorker.signals.initProgressBar.connect( self.workerInitProgressbar ) - self.saveAlignedWorker.signals.progressBar.connect( - self.workerUpdateProgressbar - ) - self.saveAlignedWorker.signals.critical.connect( - self.workerCritical - ) - + self.saveAlignedWorker.signals.progressBar.connect(self.workerUpdateProgressbar) + self.saveAlignedWorker.signals.critical.connect(self.workerCritical) + self.saveAlignedThread.started.connect(self.saveAlignedWorker.run) self.saveAlignedThread.start() @@ -3138,51 +3085,49 @@ def saveAlignedWorkerFinished(self): self.progressWin = None self.saveAlignedWorkerLoop.exit() self.tempFilesToMove = {} - + def waitMoveTempFilesWorker(self): self.saveAlignedWorkerLoop = QEventLoop(self) self.saveAlignedWorkerLoop.exec_() - + def removeAlignShiftsFile(self): for posData in self.data: posData = self.data[self.pos_i] if posData.align_shifts_path is None: continue - + if not os.path.exists(posData.align_shifts_path): continue - - self.logger.info( - f'Removing align shifts file: {posData.align_shifts_path}' - ) + + self.logger.info(f"Removing align shifts file: {posData.align_shifts_path}") try: os.remove(posData.align_shifts_path) except Exception as e: pass - + def handleAlignedDataOnClosing(self): - if not hasattr(self, 'tempFilesToMove'): + if not hasattr(self, "tempFilesToMove"): return True - + if not self.tempFilesToMove: return True - + saveAligned, cancel = self.askSaveAlignedData() if cancel: return False - + if not saveAligned: self.removeAlignShiftsFile() return True - + cancel = self.warnSaveAlignedNotReversible() if cancel: return True - + self.startMoveTempFilesWorker() self.waitMoveTempFilesWorker() return True - + def warnSaveAlignedNotReversible(self): msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph(""" @@ -3190,59 +3135,57 @@ def warnSaveAlignedNotReversible(self): Do you want to continue with saving the aligned data? """) _, yesButton = msg.warning( - self, 'Save aligned data?', txt, - buttonsTexts=('Cancel', 'Yes, save aligned data') + self, + "Save aligned data?", + txt, + buttonsTexts=("Cancel", "Yes, save aligned data"), ) return msg.cancel - + def askCropAndSave(self): if not self.saveAction.isEnabled(): return True - + isCropped = False for p, posData in enumerate(self.data): - data = posData.img_data + data = posData.img_data allCropsData = [] for cropROI in posData.cropROIs: croppedData, SizeZ = self.crop(data, posData, cropROI) allCropsData.append(croppedData) - isCropped = any( - [cropped.shape != data.shape for cropped in allCropsData] - ) + isCropped = any([cropped.shape != data.shape for cropped in allCropsData]) if isCropped: break - + if not isCropped: return True - + msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph(""" You seem to have cropping information that you did not save.

Do you want to save cropped data? """) - buttonsTexts = ( - 'Cancel', 'No, close data-prep', 'Yes, save cropped data' - ) + buttonsTexts = ("Cancel", "No, close data-prep", "Yes, save cropped data") _, noButton, yesButton = msg.question( - self, 'Save cropped data?', txt, buttonsTexts=buttonsTexts + self, "Save cropped data?", txt, buttonsTexts=buttonsTexts ) if msg.cancel: return False - + if msg.clickedButton == yesButton: self.cropAndSave() - + return True - + def closeEvent(self, event): self.saveWindowGeometry() - + proceed = self.handleAlignedDataOnClosing() if not proceed: event.ignore() return - + proceed = self.askCropAndSave() if not proceed: event.ignore() @@ -3251,8 +3194,7 @@ def closeEvent(self, event): if self.buttonToRestore is not None: button, color, text = self.buttonToRestore button.setText(text) - button.setStyleSheet( - f'QPushButton {{background-color: {color};}}') + button.setStyleSheet(f"QPushButton {{background-color: {color};}}") self.mainWin.setWindowState(Qt.WindowNoState) self.mainWin.setWindowState(Qt.WindowActive) self.mainWin.raise_() @@ -3260,7 +3202,7 @@ def closeEvent(self, event): event.ignore() self.hide() - self.logger.info('Closing dataPrep logger...') + self.logger.info("Closing dataPrep logger...") handlers = self.logger.handlers[:] for handler in handlers: handler.close() @@ -3268,17 +3210,17 @@ def closeEvent(self, event): if self.loop is not None: self.loop.exit() - + self.sigClose.emit(self) gc.collect() def saveWindowGeometry(self): - settings = QSettings('schmollerlab', 'acdc_dataPrep') + settings = QSettings("schmollerlab", "acdc_dataPrep") settings.setValue("geometry", self.saveGeometry()) def readSettings(self): - settings = QSettings('schmollerlab', 'acdc_dataPrep') - if settings.value('geometry') is not None: + settings = QSettings("schmollerlab", "acdc_dataPrep") + if settings.value("geometry") is not None: self.restoreGeometry(settings.value("geometry")) def show(self): diff --git a/cellacdc/dataReStruct.py b/cellacdc/dataReStruct.py index b1f629f4a..966c590ba 100644 --- a/cellacdc/dataReStruct.py +++ b/cellacdc/dataReStruct.py @@ -11,17 +11,15 @@ from qtpy.QtCore import QThread from qtpy.QtWidgets import QFileDialog -from . import apps, html_utils, myutils, printl, widgets, workers +from . import apps, html_utils, utils, printl, widgets, workers # Frame number must be at the end with .ext, e.g., _t01.tif -frame_name_patterns = ( - r'_(day)?(\d+)\.[A-Za-z0-9]+$', - r'_(t)?(\d+)\.[A-Za-z0-9]+$' -) +frame_name_patterns = (r"_(day)?(\d+)\.[A-Za-z0-9]+$", r"_(t)?(\d+)\.[A-Za-z0-9]+$") + def get_frame_num_and_pattern(filename): # Start with random un-matching pattern - matching_frame_name_pattern = r'^\.+' + matching_frame_name_pattern = r"^\.+" for frame_name_pattern in frame_name_patterns: try: frameNumber = re.findall(frame_name_pattern, filename)[0][1] @@ -31,42 +29,46 @@ def get_frame_num_and_pattern(filename): frameNumber = None return matching_frame_name_pattern, frameNumber + def readFilenamePattern(fileName): - matching_frame_name_pattern, frameNumber = get_frame_num_and_pattern( - fileName - ) - - s = re.sub(matching_frame_name_pattern, '', fileName) + matching_frame_name_pattern, frameNumber = get_frame_num_and_pattern(fileName) + + s = re.sub(matching_frame_name_pattern, "", fileName) for i, c in enumerate(s[::-1]): - if c == '_': + if c == "_": break channelName = s[-i:] - posName = s[:-i-1] - if channelName.endswith('.tif'): + posName = s[: -i - 1] + if channelName.endswith(".tif"): channelName = channelName[:-4] - + return posName, frameNumber, channelName def _log(mainWin, text): mainWin.log(text) + def run(mainWin): items = ( - 'Multiple files, one for each time-point', - 'Multiple files, one for each channel' + "Multiple files, one for each time-point", + "Multiple files, one for each channel", ) selectHowWin = apps.QDialogCombobox( - 'Select how files are structured', items, - 'Select how files are structured', - CbLabel='', parent=mainWin + "Select how files are structured", + items, + "Select how files are structured", + CbLabel="", + parent=mainWin, ) selectHowWin.exec_() if selectHowWin.cancel: return False - - mainWin.log(f'[Data Re-Struct] Selected file structure = "{selectHowWin.selectedItemText}"') + + mainWin.log( + f'[Data Re-Struct] Selected file structure = "{selectHowWin.selectedItemText}"' + ) msg = widgets.myMessageBox(showCentered=False, wrapText=False) txt = html_utils.paragraph(""" @@ -74,39 +76,35 @@ def run(mainWin): into an empty folder before closing this dialogue.

Note that there should be no other files in this folder. - """ - ) + """) msg.information( - mainWin, 'Microscopy files location', txt, - buttonsTexts=('Cancel', 'Done') + mainWin, "Microscopy files location", txt, buttonsTexts=("Cancel", "Done") ) if msg.cancel: return False - + mainWin.log( - '[Data Re-Struct] Asking to select the folder that contains the image files...' + "[Data Re-Struct] Asking to select the folder that contains the image files..." ) - MostRecentPath = myutils.getMostRecentPath() + MostRecentPath = utils.getMostRecentPath() rootFolderPath = QFileDialog.getExistingDirectory( - mainWin.progressWin, 'Select folder containing the image files', - MostRecentPath) - myutils.addToRecentPaths(rootFolderPath) + mainWin.progressWin, "Select folder containing the image files", MostRecentPath + ) + utils.addToRecentPaths(rootFolderPath) if not rootFolderPath: return False - - mainWin.log( - '[Data Re-Struct] Asking in which folder to save the images files...' - ) + + mainWin.log("[Data Re-Struct] Asking in which folder to save the images files...") dstFolderPath = QFileDialog.getExistingDirectory( - mainWin.progressWin, - 'Select the folder in which to save the images files', - rootFolderPath + mainWin.progressWin, + "Select the folder in which to save the images files", + rootFolderPath, ) - myutils.addToRecentPaths(dstFolderPath) + utils.addToRecentPaths(dstFolderPath) if not rootFolderPath: return False - - mainWin.log('[Data Re-Struct] Checking file format of loaded files...') + + mainWin.log("[Data Re-Struct] Checking file format of loaded files...") validFilenames = checkFileFormat(rootFolderPath, mainWin) if not validFilenames: return False @@ -118,46 +116,48 @@ def run(mainWin): return started elif selectHowWin.selectedItemIdx == 1: msg = widgets.myMessageBox(wrapText=False) - copyButton = widgets.copyPushButton('Copy files') - moveButton = widgets.movePushButton('Move files') + copyButton = widgets.copyPushButton("Copy files") + moveButton = widgets.movePushButton("Move files") txt = html_utils.paragraph( - 'Do you want to copy or move the files to the ' - 'Position folders?' + "Do you want to copy or move the files to the Position folders?" ) msg.question( - mainWin, 'Copy or move files?', txt, - buttonsTexts=('Cancel', copyButton, moveButton) + mainWin, + "Copy or move files?", + txt, + buttonsTexts=("Cancel", copyButton, moveButton), ) if msg.cancel: return False - action = 'copy' if msg.clickedButton == copyButton else 'move' + action = "copy" if msg.clickedButton == copyButton else "move" started = _run_multi_files_multi_pos( mainWin, rootFolderPath, dstFolderPath, action ) return started - + return True + def checkFileFormat(folderPath, mainWin): - ls = natsorted(myutils.listdir(folderPath)) + ls = natsorted(utils.listdir(folderPath)) files = [ - filename for filename in ls + filename + for filename in ls if os.path.isfile(os.path.join(folderPath, filename)) ] if not files: msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph( - 'The following folder

' - f'{folderPath}

' - 'does not contain any file!
' + "The following folder

" + f"{folderPath}

" + "does not contain any file!
" ) msg.addShowInFileManagerButton(folderPath) - msg.critical( - mainWin, 'Multiple extensions detected', txt - ) + msg.critical(mainWin, "Multiple extensions detected", txt) return [] all_ext = [ - os.path.splitext(filename)[1] for filename in ls + os.path.splitext(filename)[1] + for filename in ls if os.path.isfile(os.path.join(folderPath, filename)) ] counter = collections.Counter(all_ext) @@ -167,21 +167,21 @@ def checkFileFormat(folderPath, mainWin): if not is_ext_unique: msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph( - 'The following folder

' - f'{folderPath}

' - 'contains files with different file extensions ' - f'(extensions detected: {unique_ext})

' - f'However, the most common extension is {most_common_ext}, ' - 'do you want to proceed with
' - f'loading only files with extension {most_common_ext}?' + "The following folder

" + f"{folderPath}

" + "contains files with different file extensions " + f"(extensions detected: {unique_ext})

" + f"However, the most common extension is {most_common_ext}, " + "do you want to proceed with
" + f"loading only files with extension {most_common_ext}?" ) _, proceedWithMostCommon = msg.warning( - mainWin, 'Multiple extensions detected', txt, - buttonsTexts=('Cancel', 'Yes') + mainWin, "Multiple extensions detected", txt, buttonsTexts=("Cancel", "Yes") ) if proceedWithMostCommon == msg.clickedButton: files = [ - filename for filename in files + filename + for filename in files if os.path.splitext(filename)[1] == most_common_ext ] otherExt = [ext for ext in unique_ext if ext != most_common_ext] @@ -190,14 +190,14 @@ def checkFileFormat(folderPath, mainWin): return files + def saveTiff(filePath, data, waitCond): - myutils.to_tiff(filePath, data) + utils.to_tiff(filePath, data) waitCond.wakeAll() del data -def _run_multi_files_timepoints( - mainWin, validFilenames, rootFolderPath, dstFolderPath - ): + +def _run_multi_files_timepoints(mainWin, validFilenames, rootFolderPath, dstFolderPath): sampleFilename = validFilenames[0] win = apps.MultiTimePointFilePattern( @@ -206,21 +206,21 @@ def _run_multi_files_timepoints( win.exec_() if win.cancel: return False - - matching_frame_name_pattern, frameNumber = get_frame_num_and_pattern( - sampleFilename - ) + + matching_frame_name_pattern, frameNumber = get_frame_num_and_pattern(sampleFilename) mainWin.thread = QThread() mainWin.restructWorker = workers.RestructMultiTimepointsWorker( - win.allChannels, matching_frame_name_pattern, win.basename, - validFilenames, rootFolderPath, dstFolderPath, - segmFolderPath=win.segmFolderPath + win.allChannels, + matching_frame_name_pattern, + win.basename, + validFilenames, + rootFolderPath, + dstFolderPath, + segmFolderPath=win.segmFolderPath, ) mainWin.restructWorker.moveToThread(mainWin.thread) mainWin.restructWorker.signals.finished.connect(mainWin.thread.quit) - mainWin.restructWorker.signals.finished.connect( - mainWin.restructWorker.deleteLater - ) + mainWin.restructWorker.signals.finished.connect(mainWin.restructWorker.deleteLater) mainWin.thread.finished.connect(mainWin.thread.deleteLater) # Custom signals @@ -230,9 +230,7 @@ def _run_multi_files_timepoints( mainWin.restructWorker.signals.initProgressBar.connect( mainWin.workerInitProgressbar ) - mainWin.restructWorker.signals.progressBar.connect( - mainWin.workerUpdateProgressbar - ) + mainWin.restructWorker.signals.progressBar.connect(mainWin.workerUpdateProgressbar) mainWin.restructWorker.sigSaveTiff.connect(saveTiff) mainWin.thread.started.connect(mainWin.restructWorker.run) @@ -240,6 +238,7 @@ def _run_multi_files_timepoints( return True + def _run_multi_files_multi_pos(mainWin, rootFolderPath, dstFolderPath, action): mainWin.thread = QThread() mainWin.restructWorker = workers.RestructMultiPosWorker( @@ -247,9 +246,7 @@ def _run_multi_files_multi_pos(mainWin, rootFolderPath, dstFolderPath, action): ) mainWin.restructWorker.moveToThread(mainWin.thread) mainWin.restructWorker.signals.finished.connect(mainWin.thread.quit) - mainWin.restructWorker.signals.finished.connect( - mainWin.restructWorker.deleteLater - ) + mainWin.restructWorker.signals.finished.connect(mainWin.restructWorker.deleteLater) mainWin.thread.finished.connect(mainWin.thread.deleteLater) # Custom signals @@ -259,11 +256,9 @@ def _run_multi_files_multi_pos(mainWin, rootFolderPath, dstFolderPath, action): mainWin.restructWorker.signals.initProgressBar.connect( mainWin.workerInitProgressbar ) - mainWin.restructWorker.signals.progressBar.connect( - mainWin.workerUpdateProgressbar - ) + mainWin.restructWorker.signals.progressBar.connect(mainWin.workerUpdateProgressbar) mainWin.thread.started.connect(mainWin.restructWorker.run) mainWin.thread.start() - return True \ No newline at end of file + return True diff --git a/cellacdc/dataStruct.py b/cellacdc/dataStruct.py index b41dbbc28..d2dd70a64 100755 --- a/cellacdc/dataStruct.py +++ b/cellacdc/dataStruct.py @@ -23,21 +23,26 @@ from itertools import permutations from qtpy.QtWidgets import ( - QApplication, QMainWindow, QFileDialog, - QVBoxLayout, QPushButton, QLabel, QStyleFactory, - QWidget, QMessageBox, QPlainTextEdit, QHBoxLayout -) -from qtpy.QtCore import ( - Qt, QObject, Signal, QThread, QMutex, QWaitCondition, - QEventLoop + QApplication, + QMainWindow, + QFileDialog, + QVBoxLayout, + QPushButton, + QLabel, + QStyleFactory, + QWidget, + QMessageBox, + QPlainTextEdit, + QHBoxLayout, ) +from qtpy.QtCore import Qt, QObject, Signal, QThread, QMutex, QWaitCondition, QEventLoop from qtpy import QtGui # Here we use from cellacdc because this script is laucnhed in # a separate process that doesn't have a parent package from . import issues_url from . import exception_handler -from . import apps, myutils, widgets, html_utils, printl +from . import apps, utils, widgets, html_utils, printl from . import load, settings_csv_path from . import _palettes from . import recentPaths_path, cellacdc_path, settings_folderpath @@ -47,15 +52,17 @@ from . import acdc_regex from . import io -if os.name == 'nt': +if os.name == "nt": try: # Set taskbar icon in windows import ctypes - myappid = 'schmollerlab.cellacdc.pyqt.v1' # arbitrary string + + myappid = "schmollerlab.cellacdc.pyqt.v1" # arbitrary string ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(myappid) except: pass + def worker_exception_handler(func): @wraps(func) def run(self): @@ -65,8 +72,10 @@ def run(self): result = None self.critical.emit(error) return result + return run + class bioFormatsWorker(QObject): finished = Signal() progress = Signal(str) @@ -74,21 +83,43 @@ class bioFormatsWorker(QObject): initPbar = Signal(int) criticalError = Signal(str, str, str) confirmMetadata = Signal( - str, float, int, int, int, int, - float, str, float, float, float, - str, list, list, str, str, object + str, + float, + int, + int, + int, + int, + float, + str, + float, + float, + float, + str, + list, + list, + str, + str, + object, ) critical = Signal(object) sigFinishedReadingSampleImageData = Signal(object) def __init__( - self, raw_src_path, rawFilenames, exp_dst_path, - mutex, waitCond, rawDataStruct, - bioformats_backend: Literal['bioio', 'python-bioformats'], - lazy_load=True, move_raw_microscopy_files=True, - overwrite=False, add_files=False, create_new=False, - start_pos_n=1 - ): + self, + raw_src_path, + rawFilenames, + exp_dst_path, + mutex, + waitCond, + rawDataStruct, + bioformats_backend: Literal["bioio", "python-bioformats"], + lazy_load=True, + move_raw_microscopy_files=True, + overwrite=False, + add_files=False, + create_new=False, + start_pos_n=1, + ): QObject.__init__(self) self.raw_src_path = raw_src_path self.exp_dst_path = exp_dst_path @@ -107,34 +138,41 @@ def __init__( self.bioformats_backend = bioformats_backend self.lazy_load = lazy_load self.move_raw_microscopy_files = move_raw_microscopy_files - + def _readSampleDataPythonBioformats( - self, bioformats, rawFilePath, sampleImgData, SizeC, SizeT, SizeZ, - sampleSizeT, sampleSizeZ - ): + self, + bioformats, + rawFilePath, + sampleImgData, + SizeC, + SizeT, + SizeZ, + sampleSizeT, + sampleSizeZ, + ): dimsIdx = {} allChannelsData = None with bioformats.ImageReader(rawFilePath) as reader: permut_pbar = tqdm(total=6, ncols=100) - for dimsOrd in permutations('zct', 3): - if allChannelsData is not None and self.bioformats_backend == 'bioio': - sampleImgData[''.join(dimsOrd)] = allChannelsData + for dimsOrd in permutations("zct", 3): + if allChannelsData is not None and self.bioformats_backend == "bioio": + sampleImgData["".join(dimsOrd)] = allChannelsData permut_pbar.update(1) continue - + allChannelsData = [] idxs = self.buildIndexes(SizeC, SizeT, SizeZ, dimsOrd) - numIter = SizeC*sampleSizeT*sampleSizeZ + numIter = SizeC * sampleSizeT * sampleSizeZ pbar = tqdm(total=numIter, ncols=100, leave=False) skipPermutation = False for c in range(SizeC): - dimsIdx['c'] = c + dimsIdx["c"] = c imgData_tz = [] - for t in range(sampleSizeT): - dimsIdx['t'] = t + for t in range(sampleSizeT): + dimsIdx["t"] = t imgData_z = [] for z in range(sampleSizeZ): - dimsIdx['z'] = z + dimsIdx["z"] = z try: idx = self.getIndex(idxs, dimsIdx, dimsOrd) imgData = reader.read( @@ -143,10 +181,10 @@ def _readSampleDataPythonBioformats( except Exception as e: skipPermutation = True break - imgData_z.append(imgData) + imgData_z.append(imgData) pbar.update() if skipPermutation: - break + break imgData_z = np.array(imgData_z, dtype=imgData.dtype) imgData_z = np.squeeze(imgData_z) imgData_tz.append(imgData_z) @@ -157,49 +195,49 @@ def _readSampleDataPythonBioformats( pbar.close() permut_pbar.update(1) if not skipPermutation: - sampleImgData[''.join(dimsOrd)] = allChannelsData + sampleImgData["".join(dimsOrd)] = allChannelsData permut_pbar.close() return sampleImgData - + def readSampleData(self, rawFilePath, SizeC, SizeT, SizeZ): - if self.bioformats_backend == 'bioio': + if self.bioformats_backend == "bioio": from cellacdc import acdc_bioio_bioformats as bioformats else: import javabridge from cellacdc import bioformats - + sampleImgData = {} - self.progress.emit('Reading sample image data...') - - if self.bioformats_backend == 'bioio': - # To avoid running Java in the main process, we spawn a new + self.progress.emit("Reading sample image data...") + + if self.bioformats_backend == "bioio": + # To avoid running Java in the main process, we spawn a new # process that runs a python script to read sample data, save # it to disk, and then load it back here. import subprocess from . import _process, bioio_sample_data_folderpath - + read_sample_data_py_filepath = os.path.join( - os.path.dirname(bioformats.__file__), '_read_sample_data.py' + os.path.dirname(bioformats.__file__), "_read_sample_data.py" ) uuid4 = uuid.uuid4() command = ( - f'{sys.executable}, ' - f'{read_sample_data_py_filepath}, ' - f'-f, {rawFilePath}, ' - f'-c, {SizeC}, ' - f'-t, {SizeT}, ' - f'-z, {SizeZ},' - f'-uuid, {uuid4}' + f"{sys.executable}, " + f"{read_sample_data_py_filepath}, " + f"-f, {rawFilePath}, " + f"-c, {SizeC}, " + f"-t, {SizeT}, " + f"-z, {SizeZ}," + f"-uuid, {uuid4}" ) if not self.lazy_load: - command = f'{command}, -a' - - args = [sys.executable, _process.__file__, '-c', command] + command = f"{command}, -a" + + args = [sys.executable, _process.__file__, "-c", command] subprocess.run(args) - + bioformats._utils.check_raise_exception(uuid4) - + allChannelsData = [] for c in range(SizeC): filepath = os.path.join( @@ -208,35 +246,41 @@ def readSampleData(self, rawFilePath, SizeC, SizeT, SizeZ): channel_data = np.load(filepath) allChannelsData.append(channel_data) os.remove(filepath) - - for dimsOrd in permutations('zct', 3): - sampleImgData[''.join(dimsOrd)] = allChannelsData + + for dimsOrd in permutations("zct", 3): + sampleImgData["".join(dimsOrd)] = allChannelsData else: if SizeT >= 4: sampleSizeT = 4 else: - sampleSizeT = SizeT + sampleSizeT = SizeT if SizeZ > 20: sampleSizeZ = 20 else: sampleSizeZ = SizeZ sampleImgData = self._readSampleDataPythonBioformats( - bioformats, rawFilePath, sampleImgData, SizeC, SizeT, SizeZ, - sampleSizeT, sampleSizeZ + bioformats, + rawFilePath, + sampleImgData, + SizeC, + SizeT, + SizeZ, + sampleSizeT, + sampleSizeZ, ) - + self.sigFinishedReadingSampleImageData.emit(sampleImgData) return sampleImgData def getSizeZ(self, rawFilePath): - if self.bioformats_backend == 'bioio': + if self.bioformats_backend == "bioio": from cellacdc import acdc_bioio_bioformats as bioformats else: import javabridge from cellacdc import bioformats - + try: - if rawFilePath.endswith('.ome.tif'): + if rawFilePath.endswith(".ome.tif"): metadata = load.OMEXML(rawFilePath) metadataXML = metadata.omexml_string else: @@ -249,51 +293,46 @@ def getSizeZ(self, rawFilePath): def _readMetadataBioIO(self, rawFilePath): from . import bioio_sample_data_folderpath, _process from . import acdc_bioio_bioformats as bioformats - + import subprocess - + read_metadata_py_filepath = os.path.join( - os.path.dirname(bioformats.__file__), '_read_metadata.py' + os.path.dirname(bioformats.__file__), "_read_metadata.py" ) uuid4 = uuid.uuid4() command = ( - f'{sys.executable}, {read_metadata_py_filepath}, ' - f'-f, {rawFilePath}, ' - f'-uuid, {uuid4}' + f"{sys.executable}, {read_metadata_py_filepath}, " + f"-f, {rawFilePath}, " + f"-uuid, {uuid4}" ) - - args = [sys.executable, _process.__file__, '-c', command] + + args = [sys.executable, _process.__file__, "-c", command] subprocess.run(args) - + bioformats._utils.check_raise_exception(uuid4) metadataXML_filepath = os.path.join( - bioio_sample_data_folderpath, 'metadataXML.txt' + bioio_sample_data_folderpath, "metadataXML.txt" ) metadataXML = bioformats.Metadata().init_from_file(metadataXML_filepath) - metadata_filepath = os.path.join( - bioio_sample_data_folderpath, 'metadata.txt' - ) - metadata = bioformats.OMEXML().init_from_file( - metadata_filepath, rawFilePath - ) + metadata_filepath = os.path.join(bioio_sample_data_folderpath, "metadata.txt") + metadata = bioformats.OMEXML().init_from_file(metadata_filepath, rawFilePath) return metadata, metadataXML - - + def readMetadata(self, raw_src_path, filename): - if self.bioformats_backend == 'bioio': + if self.bioformats_backend == "bioio": from cellacdc import acdc_bioio_bioformats as bioformats else: import javabridge from cellacdc import bioformat - + rawFilePath = os.path.join(raw_src_path, filename) - self.progress.emit('Reading OME metadata...') + self.progress.emit("Reading OME metadata...") try: - if rawFilePath.endswith('.ome.tif'): + if rawFilePath.endswith(".ome.tif"): metadata = load.OMEXML(rawFilePath) metadataXML = metadata.omexml_string else: @@ -304,20 +343,17 @@ def readMetadata(self, raw_src_path, filename): traceback.print_exc() self.isCriticalError = True self.criticalError.emit( - 'reading image data or metadata', - traceback.format_exc(), filename + "reading image data or metadata", traceback.format_exc(), filename ) return True try: LensNA = float(metadata.instrument().Objective.LensNA) except Exception as e: - self.progress.emit( - '===================================================') + self.progress.emit("===================================================") self.progress.emit(rawFilePath) - self.progress.emit('WARNING: LensNA not found in metadata.') - self.progress.emit( - '===================================================') + self.progress.emit("WARNING: LensNA not found in metadata.") + self.progress.emit("===================================================") LensNA = 1.4 if self.rawDataStruct != 2: @@ -325,11 +361,13 @@ def readMetadata(self, raw_src_path, filename): SizeS = int(metadata.get_image_count()) except Exception as e: self.progress.emit( - '===================================================') + "===================================================" + ) self.progress.emit(rawFilePath) - self.progress.emit('WARNING: SizeS not found in metadata.') + self.progress.emit("WARNING: SizeS not found in metadata.") self.progress.emit( - '===================================================') + "===================================================" + ) SizeS = 1 else: SizeS = self.SizeS @@ -337,126 +375,105 @@ def readMetadata(self, raw_src_path, filename): try: SizeZ = int(metadata.image().Pixels.SizeZ) except Exception as e: - self.progress.emit( - '===================================================') + self.progress.emit("===================================================") self.progress.emit(rawFilePath) - self.progress.emit('WARNING: SizeZ not found in metadata.') - self.progress.emit( - '===================================================') + self.progress.emit("WARNING: SizeZ not found in metadata.") + self.progress.emit("===================================================") SizeZ = 1 try: SizeT = int(metadata.image().Pixels.SizeT) except Exception as e: - self.progress.emit( - '===================================================') + self.progress.emit("===================================================") self.progress.emit(rawFilePath) - self.progress.emit('WARNING: SizeT not found in metadata.') - self.progress.emit( - '===================================================') + self.progress.emit("WARNING: SizeT not found in metadata.") + self.progress.emit("===================================================") SizeT = 1 try: Pixels = metadata.image().Pixels - TimeIncrement = float(Pixels.node.get('TimeIncrement')) + TimeIncrement = float(Pixels.node.get("TimeIncrement")) except Exception as e: - self.progress.emit( - '===================================================') + self.progress.emit("===================================================") self.progress.emit(rawFilePath) - self.progress.emit('WARNING: TimeIncrement not found in metadata.') - self.progress.emit( - '===================================================') + self.progress.emit("WARNING: TimeIncrement not found in metadata.") + self.progress.emit("===================================================") TimeIncrement = 1.0 try: Pixels = metadata.image().Pixels - TimeIncrementUnit = Pixels.node.get('TimeIncrementUnit') + TimeIncrementUnit = Pixels.node.get("TimeIncrementUnit") if TimeIncrementUnit is None: raise except Exception as e: - self.progress.emit( - '===================================================') + self.progress.emit("===================================================") self.progress.emit(rawFilePath) - self.progress.emit('WARNING: TimeIncrementUnit not found in metadata.') - self.progress.emit( - '===================================================') - TimeIncrementUnit = 's' + self.progress.emit("WARNING: TimeIncrementUnit not found in metadata.") + self.progress.emit("===================================================") + TimeIncrementUnit = "s" try: SizeC = int(metadata.image().Pixels.SizeC) except Exception as e: - self.progress.emit( - '===================================================') + self.progress.emit("===================================================") self.progress.emit(rawFilePath) - self.progress.emit('WARNING: SizeC not found in metadata.') - self.progress.emit( - '===================================================') + self.progress.emit("WARNING: SizeC not found in metadata.") + self.progress.emit("===================================================") SizeC = 1 try: PhysicalSizeX = float(metadata.image().Pixels.PhysicalSizeX) except Exception as e: - self.progress.emit( - '===================================================') + self.progress.emit("===================================================") self.progress.emit(rawFilePath) - self.progress.emit('WARNING: PhysicalSizeX not found in metadata.') - self.progress.emit( - '===================================================') + self.progress.emit("WARNING: PhysicalSizeX not found in metadata.") + self.progress.emit("===================================================") PhysicalSizeX = 1.0 try: PhysicalSizeY = float(metadata.image().Pixels.PhysicalSizeY) except Exception as e: - self.progress.emit( - '===================================================') + self.progress.emit("===================================================") self.progress.emit(rawFilePath) - self.progress.emit('WARNING: PhysicalSizeY not found in metadata.') - self.progress.emit( - '===================================================') + self.progress.emit("WARNING: PhysicalSizeY not found in metadata.") + self.progress.emit("===================================================") PhysicalSizeY = 1.0 try: PhysicalSizeZ = float(metadata.image().Pixels.PhysicalSizeZ) except Exception as e: - self.progress.emit( - '===================================================') + self.progress.emit("===================================================") self.progress.emit(rawFilePath) - self.progress.emit('WARNING: PhysicalSizeZ not found in metadata.') - self.progress.emit( - '===================================================') + self.progress.emit("WARNING: PhysicalSizeZ not found in metadata.") + self.progress.emit("===================================================") PhysicalSizeZ = 1.0 try: Pixels = metadata.image().Pixels - PhysicalSizeUnit = Pixels.node.get('PhysicalSizeXUnit') + PhysicalSizeUnit = Pixels.node.get("PhysicalSizeXUnit") if PhysicalSizeUnit is None: raise except Exception as e: - self.progress.emit( - '===================================================') + self.progress.emit("===================================================") self.progress.emit(rawFilePath) - self.progress.emit('WARNING: PhysicalSizeUnit not found in metadata.') - self.progress.emit( - '===================================================') - PhysicalSizeUnit = 'μm' + self.progress.emit("WARNING: PhysicalSizeUnit not found in metadata.") + self.progress.emit("===================================================") + PhysicalSizeUnit = "μm" try: ImageName = metadata.image().Name if ImageName is None: raise except Exception as e: - self.progress.emit( - '===================================================') + self.progress.emit("===================================================") self.progress.emit(rawFilePath) - self.progress.emit('WARNING: Image Name not found in metadata.') - self.progress.emit( - '===================================================') - ImageName = '' - + self.progress.emit("WARNING: Image Name not found in metadata.") + self.progress.emit("===================================================") + ImageName = "" if self.rawDataStruct != 2: try: - chNames = ['']*SizeC + chNames = [""] * SizeC for c in range(SizeC): try: chNames[c] = metadata.image().Pixels.Channel(c).Name @@ -464,19 +481,21 @@ def readMetadata(self, raw_src_path, filename): pass except Exception as e: self.progress.emit( - '===================================================') + "===================================================" + ) self.progress.emit(rawFilePath) - self.progress.emit('WARNING: chNames not found in metadata.') + self.progress.emit("WARNING: chNames not found in metadata.") self.progress.emit( - '===================================================') - chNames = ['']*SizeC + "===================================================" + ) + chNames = [""] * SizeC else: chNames = self.chNames SizeC = len(self.chNames) if self.rawDataStruct != 2: try: - emWavelens = [500.0]*SizeC + emWavelens = [500.0] * SizeC for c in range(SizeC): try: Channel = metadata.image().Pixels.Channel(c) @@ -487,14 +506,16 @@ def readMetadata(self, raw_src_path, filename): except Exception as e: traceback.print_exc() self.progress.emit( - '===================================================') + "===================================================" + ) self.progress.emit(rawFilePath) - self.progress.emit('WARNING: EmissionWavelength not found in metadata.') + self.progress.emit("WARNING: EmissionWavelength not found in metadata.") self.progress.emit( - '===================================================') - emWavelens = [500.0]*SizeC + "===================================================" + ) + emWavelens = [500.0] * SizeC else: - emWavelens = [500.0]*SizeC + emWavelens = [500.0] * SizeC if self.trustMetadataReader: self.LensNA = LensNA @@ -514,14 +535,25 @@ def readMetadata(self, raw_src_path, filename): while True: self.mutex.lock() if self.rawDataStruct != 2: - sampleImgData = self.readSampleData( - rawFilePath, SizeC, SizeT, SizeZ - ) + sampleImgData = self.readSampleData(rawFilePath, SizeC, SizeT, SizeZ) self.confirmMetadata.emit( - filename, LensNA, SizeT, SizeZ, SizeC, SizeS, - TimeIncrement, TimeIncrementUnit, PhysicalSizeX, PhysicalSizeY, - PhysicalSizeZ, PhysicalSizeUnit, chNames, emWavelens, ImageName, - rawFilePath, sampleImgData + filename, + LensNA, + SizeT, + SizeZ, + SizeC, + SizeS, + TimeIncrement, + TimeIncrementUnit, + PhysicalSizeX, + PhysicalSizeY, + PhysicalSizeZ, + PhysicalSizeUnit, + chNames, + emWavelens, + ImageName, + rawFilePath, + sampleImgData, ) self.waitCond.wait(self.mutex) self.mutex.unlock() @@ -566,50 +598,58 @@ def readMetadata(self, raw_src_path, filename): self.saveChannels = self.metadataWin.saveChannels self.emWavelens = self.metadataWin.emWavelens self.addImageName = self.metadataWin.addImageName - + return False def saveToPosFolder( - self, p, raw_src_path, exp_dst_path, filename, series, pos_n, - p_idx=0, - ): + self, + p, + raw_src_path, + exp_dst_path, + filename, + series, + pos_n, + p_idx=0, + ): rawFilePath = os.path.join(raw_src_path, filename) - if os.path.basename(raw_src_path) == 'raw_microscopy_files': + if os.path.basename(raw_src_path) == "raw_microscopy_files": raw_src_path = os.path.dirname(raw_src_path) - in_file_pos_name = f'Position_{p+1}' + in_file_pos_name = f"Position_{p + 1}" savePos = ( - 'All Positions' in self.selectedPos - or in_file_pos_name in self.selectedPos + "All Positions" in self.selectedPos or in_file_pos_name in self.selectedPos ) if not savePos: return False - pos_path = os.path.join(exp_dst_path, f'Position_{pos_n}') - images_path = os.path.join(pos_path, 'Images') + pos_path = os.path.join(exp_dst_path, f"Position_{pos_n}") + images_path = os.path.join(pos_path, "Images") if os.path.exists(images_path) and self.overwritePos: shutil.rmtree(images_path) - + if os.path.exists(images_path) and self.createNew: - images_path = re.sub( - r'Position_\d+', f'Position_{pos_n}', images_path - ) - + images_path = re.sub(r"Position_\d+", f"Position_{pos_n}", images_path) + if not os.path.exists(images_path): os.makedirs(images_path, exist_ok=True) - + self.saveData( - images_path, rawFilePath, filename, p, series, pos_n, p_idx=p_idx, + images_path, + rawFilePath, + filename, + p, + series, + pos_n, + p_idx=p_idx, ) - + return False def _saveDataPythonBioformats( - self, bioformats, rawFilePath, series, images_path, filenameNOext, - s0p, idxs - ): + self, bioformats, rawFilePath, series, images_path, filenameNOext, s0p, idxs + ): SizeZ = self.getSizeZ(rawFilePath) with bioformats.ImageReader(rawFilePath) as reader: iter = enumerate(zip(self.chNames, self.saveChannels)) @@ -619,155 +659,166 @@ def _saveDataPythonBioformats( continue self.progress.emit( - f' Saving channel {c+1}/{len(self.chNames)} ({chName})' + f" Saving channel {c + 1}/{len(self.chNames)} ({chName})" ) self.saveImgDataChannel( - reader, series, images_path, filenameNOext, s0p, - chName, c, idxs, SizeZ + reader, + series, + images_path, + filenameNOext, + s0p, + chName, + c, + idxs, + SizeZ, ) - + def _saveDataPythonBioformatsSingleChannel( - self, bioformats, rawFilePath, series, images_path, filenameNOext, - s0p, idxs, chName, c_idx - ): + self, + bioformats, + rawFilePath, + series, + images_path, + filenameNOext, + s0p, + idxs, + chName, + c_idx, + ): SizeZ = self.getSizeZ(rawFilePath) with bioformats.ImageReader(rawFilePath) as reader: self.progress.emit( - f' Saving channel {c_idx+1}/{len(self.chNames)} ({chName})' + f" Saving channel {c_idx + 1}/{len(self.chNames)} ({chName})" ) imgData_ch = [] self.saveImgDataChannel( - reader, series, images_path, filenameNOext, s0p, - chName, 0, idxs, SizeZ + reader, series, images_path, filenameNOext, s0p, chName, 0, idxs, SizeZ ) - + def removeInvalidCharacters(self, chName_in): # Remove invalid charachters chName = "".join( - c if c.isalnum() or c=='_' or c=='' else '_' for c in chName_in + c if c.isalnum() or c == "_" or c == "" else "_" for c in chName_in ) - trim_ = chName.endswith('_') + trim_ = chName.endswith("_") while trim_: chName = chName[:-1] - trim_ = chName.endswith('_') + trim_ = chName.endswith("_") def getFilename( - self, filenameNOext, s0p, appendTxt, series, ext, - return_basename=False - ): + self, filenameNOext, s0p, appendTxt, series, ext, return_basename=False + ): # Do not allow dots in the filename since it breaks stuff here and there - filenameNOext = filenameNOext.replace('.', '_') + filenameNOext = filenameNOext.replace(".", "_") if self.addImageName: try: ImageName = self.metadata.image(index=series).Name if not isinstance(ImageName, str): raise except Exception as e: - ImageName = '' + ImageName = "" self.removeInvalidCharacters(ImageName) - basename = f'{filenameNOext}_{ImageName}_s{s0p}_' - filename = f'{basename}{appendTxt}{ext}' + basename = f"{filenameNOext}_{ImageName}_s{s0p}_" + filename = f"{basename}{appendTxt}{ext}" else: - basename = f'{filenameNOext}_s{s0p}_' - filename = f'{basename}{appendTxt}{ext}' + basename = f"{filenameNOext}_s{s0p}_" + filename = f"{basename}{appendTxt}{ext}" if return_basename: return filename, basename else: return filename - + def buildIndexes(self, SizeC, SizeT, SizeZ): - SizesCTZ = {'c': SizeC, 't': SizeT, 'z': SizeZ} + SizesCTZ = {"c": SizeC, "t": SizeT, "z": SizeZ} idxs = {} - k_key, i_key, j_key = 'ztc' + k_key, i_key, j_key = "ztc" idx = 0 for k in range(SizesCTZ[k_key]): for i in range(SizesCTZ[i_key]): for j in range(SizesCTZ[j_key]): - idxs[(k,i,j)] = idx + idxs[(k, i, j)] = idx idx += 1 return idxs def getIndex(self, idxs, dimsIdx): - dims = tuple([dimsIdx.get(v, 0) for v in 'ztc']) + dims = tuple([dimsIdx.get(v, 0) for v in "ztc"]) return idxs[dims] - + def saveImgDataChannel( - self, reader, series, images_path, filenameNOext, s0p, chName, - ch_idx, idxs, SizeZ - ): + self, + reader, + series, + images_path, + filenameNOext, + s0p, + chName, + ch_idx, + idxs, + SizeZ, + ): savedSizeT = self.timeRangeToSave[1] - self.timeRangeToSave[0] + 1 if self.to_h5: - filename = self.getFilename( - filenameNOext, s0p, chName, series, '.h5' - ) + filename = self.getFilename(filenameNOext, s0p, chName, series, ".h5") tempDir = tempfile.mkdtemp() tempFilepath = os.path.join(tempDir, filename) - print('==========================================================') + print("==========================================================") print(f'.h5 tempfile: "{tempFilepath}"') - print('==========================================================') - h5f = h5py.File(tempFilepath, 'w') + print("==========================================================") + h5f = h5py.File(tempFilepath, "w") # Read SizeX and SizeY from the shape of one image - imgData = reader.read( - c=ch_idx, z=0, t=0, series=series, rescale=False - ) + imgData = reader.read(c=ch_idx, z=0, t=0, series=series, rescale=False) shape = (savedSizeT, SizeZ, *imgData.shape) - chunks = (1,1,*imgData.shape) + chunks = (1, 1, *imgData.shape) imgData_ch = h5f.create_dataset( - 'data', shape, dtype=imgData.dtype, - chunks=chunks, shuffle=False + "data", shape, dtype=imgData.dtype, chunks=chunks, shuffle=False ) else: - filename = self.getFilename( - filenameNOext, s0p, chName, series, '.tif' - ) + filename = self.getFilename(filenameNOext, s0p, chName, series, ".tif") imgData_ch = [] - framesRange = range( - self.timeRangeToSave[0]-1, - self.timeRangeToSave[1] - ) + framesRange = range(self.timeRangeToSave[0] - 1, self.timeRangeToSave[1]) filePath = os.path.join(images_path, filename) - dimsIdx = {'c': ch_idx} + dimsIdx = {"c": ch_idx} numFrames = len(framesRange) - num_imgs = numFrames*SizeZ + num_imgs = numFrames * SizeZ pbar = tqdm( - total=num_imgs, - ncols=100, - desc=f'Reading image (z 0/{SizeZ}, t 0/{numFrames})' + total=num_imgs, + ncols=100, + desc=f"Reading image (z 0/{SizeZ}, t 0/{numFrames})", ) for out_t, t in enumerate(framesRange): imgData_z = [] - dimsIdx['t'] = t + dimsIdx["t"] = t for z in range(SizeZ): pbar.set_description( - f'Reading image (z {z+1}/{SizeZ}, t {out_t+1}/{numFrames})' + f"Reading image (z {z + 1}/{SizeZ}, t {out_t + 1}/{numFrames})" ) - dimsIdx['z'] = z + dimsIdx["z"] = z if self.rawDataStruct != 2: idx = self.getIndex(idxs, dimsIdx) else: idx = None imgData = reader.read( - c=ch_idx, z=z, t=t, series=series, rescale=False, - index=idx + c=ch_idx, z=z, t=t, series=series, rescale=False, index=idx ) if self.to_h5: imgData_ch[out_t, z] = imgData else: imgData_z.append(imgData) - + pbar.update() if not self.to_h5: imgData_z = np.squeeze(np.array(imgData_z, dtype=imgData.dtype)) imgData_ch.append(imgData_z) pbar.close() - + if not self.to_h5: imgData_ch = np.squeeze(np.array(imgData_ch, dtype=imgData.dtype)) - myutils.to_tiff( - filePath, imgData_ch, + utils.to_tiff( + filePath, + imgData_ch, SizeT=savedSizeT, SizeZ=self.SizeZ, TimeIncrement=self.TimeIncrement, @@ -780,91 +831,91 @@ def saveImgDataChannel( shutil.move(tempFilepath, filePath) shutil.rmtree(tempDir) - def saveData( - self, images_path, rawFilePath, filename, p, series, pos_n, - p_idx=0 - ): - if self.bioformats_backend == 'bioio': + def saveData(self, images_path, rawFilePath, filename, p, series, pos_n, p_idx=0): + if self.bioformats_backend == "bioio": from cellacdc import acdc_bioio_bioformats as bioformats else: import javabridge from cellacdc import bioformats - + s0p = str(pos_n).zfill(self.numPosDigits) self.progress.emit( - f'Position {pos_n}/{self.numPos}: saving data to {images_path}...' + f"Position {pos_n}/{self.numPos}: saving data to {images_path}..." ) filenameNOext, ext = os.path.splitext(filename) metadataXML_path = os.path.join( images_path, - self.getFilename(filenameNOext, s0p, 'metadataXML', series, '.txt') + self.getFilename(filenameNOext, s0p, "metadataXML", series, ".txt"), ) - with open(metadataXML_path, 'w', encoding="utf-8") as txt: + with open(metadataXML_path, "w", encoding="utf-8") as txt: txt.write(str(self.metadataXML)) metadata_filename, basename = self.getFilename( - filenameNOext, s0p, 'metadata', series, '.csv', - return_basename=True + filenameNOext, s0p, "metadata", series, ".csv", return_basename=True ) metadata_csv_path = os.path.join(images_path, metadata_filename) - savedSizeT = ( - self.timeRangeToSave[1] - self.timeRangeToSave[0] + 1 - ) - df = pd.DataFrame({ - 'LensNA': self.LensNA, - 'SizeT': savedSizeT, - 'SizeZ': self.SizeZ, - 'TimeIncrement': self.TimeIncrement, - 'PhysicalSizeZ': self.PhysicalSizeZ, - 'PhysicalSizeY': self.PhysicalSizeY, - 'PhysicalSizeX': self.PhysicalSizeX, - 'basename': basename - }, index=['values']).T - df.index.name = 'Description' + savedSizeT = self.timeRangeToSave[1] - self.timeRangeToSave[0] + 1 + df = pd.DataFrame( + { + "LensNA": self.LensNA, + "SizeT": savedSizeT, + "SizeZ": self.SizeZ, + "TimeIncrement": self.TimeIncrement, + "PhysicalSizeZ": self.PhysicalSizeZ, + "PhysicalSizeY": self.PhysicalSizeY, + "PhysicalSizeX": self.PhysicalSizeX, + "basename": basename, + }, + index=["values"], + ).T + df.index.name = "Description" ch_metadata = [ - chName for c, chName in enumerate(self.chNames) - if self.saveChannels[c] + chName for c, chName in enumerate(self.chNames) if self.saveChannels[c] ] description = [ - f'channel_{c}_name' for c in range(self.SizeC) - if self.saveChannels[c] + f"channel_{c}_name" for c in range(self.SizeC) if self.saveChannels[c] ] - ch_metadata.extend([ - wavelen for c, wavelen in enumerate(self.emWavelens) - if self.saveChannels[c] - ]) - description.extend([ - f'channel_{c}_emWavelen' for c in range(self.SizeC) - if self.saveChannels[c] - ]) - - df_channelNames = pd.DataFrame({ - 'Description': description, - 'values': ch_metadata - }).set_index('Description') + ch_metadata.extend( + [ + wavelen + for c, wavelen in enumerate(self.emWavelens) + if self.saveChannels[c] + ] + ) + description.extend( + [ + f"channel_{c}_emWavelen" + for c in range(self.SizeC) + if self.saveChannels[c] + ] + ) + + df_channelNames = pd.DataFrame( + {"Description": description, "values": ch_metadata} + ).set_index("Description") df = pd.concat([df, df_channelNames]) if os.path.exists(metadata_csv_path): # Keep channel names already existing and not saved now - existing_df = pd.read_csv(metadata_csv_path).set_index('Description') + existing_df = pd.read_csv(metadata_csv_path).set_index("Description") for c, chName in enumerate(self.chNames): if self.saveChannels[c]: continue - chName_idx = f'channel_{c}_name' - chWavelen_idx = f'channel_{c}_emWavelen' + chName_idx = f"channel_{c}_name" + chWavelen_idx = f"channel_{c}_emWavelen" try: - existing_chName = existing_df.at[chName_idx, 'values'] - df.at[chName_idx, 'values'] = existing_chName + existing_chName = existing_df.at[chName_idx, "values"] + df.at[chName_idx, "values"] = existing_chName except Exception as e: traceback.print_exc() pass - + try: - existing_chWavelen = existing_df.at[chWavelen_idx, 'values'] - df.at[chWavelen_idx, 'values'] = existing_chWavelen + existing_chWavelen = existing_df.at[chWavelen_idx, "values"] + df.at[chWavelen_idx, "values"] = existing_chWavelen except Exception as e: traceback.print_exc() pass @@ -872,53 +923,58 @@ def saveData( df.to_csv(metadata_csv_path) idxs = self.buildIndexes(self.SizeC, self.SizeT, self.SizeZ) - if self.rawDataStruct != 2: - if self.bioformats_backend == 'bioio': + if self.rawDataStruct != 2: + if self.bioformats_backend == "bioio": import subprocess from . import _process - + save_data_py_filepath = os.path.join( - os.path.dirname(bioformats.__file__), '_save_data.py' + os.path.dirname(bioformats.__file__), "_save_data.py" ) zyx_physical_sizes = ( - self.PhysicalSizeZ, self.PhysicalSizeY, self.PhysicalSizeX - ) - zyx_physical_sizes = " ".join( - [str(val) for val in zyx_physical_sizes] + self.PhysicalSizeZ, + self.PhysicalSizeY, + self.PhysicalSizeX, ) + zyx_physical_sizes = " ".join([str(val) for val in zyx_physical_sizes]) uuid4 = uuid.uuid4() command = ( - f'{sys.executable}, {save_data_py_filepath}, ' - f'-f, {rawFilePath}, ' - f'-d, {" ".join([str(val) for val in self.saveChannels])}, ' - f'-c, {" ".join(self.chNames)}, ' - f'-s, {series}, ' - f'-i, {images_path}, ' - f'-p, {filenameNOext}, ' - f'-pos, {s0p}, ' - f'-t, {self.SizeT}, ' - f'-z, {self.getSizeZ(rawFilePath)}, ' - f'-time_increment, {self.TimeIncrement}, ' - f'-zyx, {zyx_physical_sizes}, ' - f'-r, {" ".join([str(val) for val in self.timeRangeToSave])}, ' - f'-uuid, {uuid4}' + f"{sys.executable}, {save_data_py_filepath}, " + f"-f, {rawFilePath}, " + f"-d, {' '.join([str(val) for val in self.saveChannels])}, " + f"-c, {' '.join(self.chNames)}, " + f"-s, {series}, " + f"-i, {images_path}, " + f"-p, {filenameNOext}, " + f"-pos, {s0p}, " + f"-t, {self.SizeT}, " + f"-z, {self.getSizeZ(rawFilePath)}, " + f"-time_increment, {self.TimeIncrement}, " + f"-zyx, {zyx_physical_sizes}, " + f"-r, {' '.join([str(val) for val in self.timeRangeToSave])}, " + f"-uuid, {uuid4}" ) if self.to_h5: - command = f'{command}, -to_h5' - + command = f"{command}, -to_h5" + if not self.lazy_load: - command = f'{command}, -a' - - args = [sys.executable, _process.__file__, '-c', command] + command = f"{command}, -a" + + args = [sys.executable, _process.__file__, "-c", command] subprocess.run(args) - + bioformats._utils.check_raise_exception(uuid4) - + self.progressPbar.emit(len(self.chNames)) - else: + else: self._saveDataPythonBioformats( - bioformats, rawFilePath, series, images_path, - filenameNOext, s0p, idxs + bioformats, + rawFilePath, + series, + images_path, + filenameNOext, + s0p, + idxs, ) elif self.rawDataStruct == 2: @@ -930,74 +986,81 @@ def saveData( if not saveCh: continue - rawFilename = f'{basename}{pos_n}_{chName}' + rawFilename = f"{basename}{pos_n}_{chName}" pos_rawFilenames.append(rawFilename) raw_src_path = os.path.dirname(rawFilePath) rawFilePath = [ os.path.join(raw_src_path, f) - for f in myutils.listdir(raw_src_path) - if f.find(rawFilename)!=-1 + for f in utils.listdir(raw_src_path) + if f.find(rawFilename) != -1 ][0] - if self.bioformats_backend == 'bioio': + if self.bioformats_backend == "bioio": import subprocess from . import _process - + save_data_py_filepath = os.path.join( - os.path.dirname(bioformats.__file__), - '_save_data_single_channel.py' + os.path.dirname(bioformats.__file__), + "_save_data_single_channel.py", ) zyx_physical_sizes = ( - self.PhysicalSizeZ, - self.PhysicalSizeY, - self.PhysicalSizeX + self.PhysicalSizeZ, + self.PhysicalSizeY, + self.PhysicalSizeX, ) zyx_physical_sizes = " ".join( [str(val) for val in zyx_physical_sizes] ) uuid4 = uuid.uuid4() command = ( - f'{sys.executable}, {save_data_py_filepath}, ' - f'-f, {rawFilePath}, ' - f'-d, {" ".join([str(val) for val in self.saveChannels])}, ' - f'-c, {chName}, ' - f'-ch_idx, {c}, ' - f'-s, {series}, ' - f'-i, {images_path}, ' - f'-p, {filenameNOext}, ' - f'-pos, {s0p}, ' - f'-t, {self.SizeT}, ' - f'-z, {self.getSizeZ(rawFilePath)}, ' - f'-time_increment, {self.TimeIncrement}, ' - f'-zyx, {zyx_physical_sizes}, ' - f'-r, {" ".join([str(val) for val in self.timeRangeToSave])}, ' - f'-uuid, {uuid4}' + f"{sys.executable}, {save_data_py_filepath}, " + f"-f, {rawFilePath}, " + f"-d, {' '.join([str(val) for val in self.saveChannels])}, " + f"-c, {chName}, " + f"-ch_idx, {c}, " + f"-s, {series}, " + f"-i, {images_path}, " + f"-p, {filenameNOext}, " + f"-pos, {s0p}, " + f"-t, {self.SizeT}, " + f"-z, {self.getSizeZ(rawFilePath)}, " + f"-time_increment, {self.TimeIncrement}, " + f"-zyx, {zyx_physical_sizes}, " + f"-r, {' '.join([str(val) for val in self.timeRangeToSave])}, " + f"-uuid, {uuid4}" ) if self.to_h5: - command = f'{command}, -to_h5' - - args = [sys.executable, _process.__file__, '-c', command] + command = f"{command}, -to_h5" + + args = [sys.executable, _process.__file__, "-c", command] subprocess.run(args) - + bioformats._utils.check_raise_exception(uuid4) - + self.progressPbar.emit(1) - else: + else: self._saveDataPythonBioformatsSingleChannel( - bioformats, rawFilePath, series, images_path, - filenameNOext, s0p, idxs, chName, c + bioformats, + rawFilePath, + series, + images_path, + filenameNOext, + s0p, + idxs, + chName, + c, ) if self.moveOtherFiles or self.copyOtherFiles: # Move the other files present in the folder if they # contain "otherFilename" in the name - otherFilename = f'{basename}{pos_n}' + otherFilename = f"{basename}{pos_n}" rawFilePath = set() - for f in myutils.listdir(raw_src_path): + for f in utils.listdir(raw_src_path): notRawFile = all( - [f.find(rawName)==-1 for rawName in pos_rawFilenames] + [f.find(rawName) == -1 for rawName in pos_rawFilenames] ) - isPosFile = f.find(otherFilename)!=-1 + isPosFile = f.find(otherFilename) != -1 if isPosFile and notRawFile: rawFilePath.add(os.path.join(raw_src_path, f)) @@ -1005,12 +1068,12 @@ def saveData( # Determine basename, posNum and chName to build # filename as "basename_s01_chName.ext" _filename = os.path.basename(file) - m = re.findall(fr'{basename}(\d+)_(.+)', _filename) - if not m or len(m[0])!=2: + m = re.findall(rf"{basename}(\d+)_(.+)", _filename) + if not m or len(m[0]) != 2: dst = os.path.join(images_path, _filename) else: _chNameWithExt = m[0][1] - _filename = f'{filenameNOext}_s{s0p}_{_chNameWithExt}' + _filename = f"{filenameNOext}_s{s0p}_{_chNameWithExt}" dst = os.path.join(images_path, _filename) if self.moveOtherFiles: try: @@ -1027,16 +1090,17 @@ def saveData( def run(self): raw_src_path = self.raw_src_path exp_dst_path = self.exp_dst_path - - if self.bioformats_backend == 'python-bioformats': + + if self.bioformats_backend == "python-bioformats": import javabridge from cellacdc import bioformats + javabridge.start_vm(class_path=bioformats.JARS, run_headless=True) - self.progress.emit('Java VM running.') - + self.progress.emit("Java VM running.") + self.cancelled = False self.isCriticalError = False - + for p, filename in enumerate(self.rawFilenames): pos_n = p + self.start_pos_n if self.rawDataStruct == 0: @@ -1049,12 +1113,16 @@ def run(self): self.numPos = self.SizeS self.numPosDigits = len(str(self.numPos)) if p == 0: - self.initPbar.emit(self.numPos*self.SizeC) - + self.initPbar.emit(self.numPos * self.SizeC) + for in_file_p in range(self.SizeS): cancel = self.saveToPosFolder( - in_file_p, raw_src_path, exp_dst_path, filename, - in_file_p, pos_n + in_file_p, + raw_src_path, + exp_dst_path, + filename, + in_file_p, + pos_n, ) if cancel: self.cancelled = True @@ -1069,7 +1137,7 @@ def run(self): self.numPos = len(self.rawFilenames) self.numPosDigits = len(str(self.numPos)) if p == 0: - self.initPbar.emit(self.numPos*self.SizeC) + self.initPbar.emit(self.numPos * self.SizeC) cancel = self.saveToPosFolder( p, raw_src_path, exp_dst_path, filename, 0, pos_n ) @@ -1081,9 +1149,7 @@ def run(self): break # Move files to raw_microscopy_files folder - self.move_to_raw_microscopy_files_folder( - self.raw_src_path, filename - ) + self.move_to_raw_microscopy_files_folder(self.raw_src_path, filename) if self.rawDataStruct == 2: filename = self.rawFilenames[0] @@ -1091,48 +1157,45 @@ def run(self): abort = self.readMetadata(raw_src_path, filename) if abort: self.cancelled = True - if self.bioformats_backend == 'python-bioformats': + if self.bioformats_backend == "python-bioformats": javabridge.kill_vm() self.finished.emit() return - + self.numPos = len(self.posNums) self.numPosDigits = len(str(self.numPos)) - self.initPbar.emit(self.numPos*self.SizeC) + self.initPbar.emit(self.numPos * self.SizeC) for p_idx, pos in enumerate(self.posNums): - p = pos-1 + p = pos - 1 abort = self.saveToPosFolder( - p, raw_src_path, exp_dst_path, self.basename, 0, - pos, p_idx=p_idx + p, raw_src_path, exp_dst_path, self.basename, 0, pos, p_idx=p_idx ) if abort: self.cancelled = True break for filename in self.rawFilenames: - self.move_to_raw_microscopy_files_folder( - self.raw_src_path, filename - ) + self.move_to_raw_microscopy_files_folder(self.raw_src_path, filename) - if self.bioformats_backend == 'python-bioformats': + if self.bioformats_backend == "python-bioformats": javabridge.kill_vm() self.finished.emit() - + def move_to_raw_microscopy_files_folder(self, raw_src_path, filename): # Move files to raw_microscopy_files folder foldername = os.path.basename(raw_src_path) - + if self.cancelled: return - - if foldername == 'raw_microscopy_files': + + if foldername == "raw_microscopy_files": return - + if not self.move_raw_microscopy_files: return - + rawFilePath = os.path.join(self.raw_src_path, filename) - raw_path = os.path.join(raw_src_path, 'raw_microscopy_files') + raw_path = os.path.join(raw_src_path, "raw_microscopy_files") if not os.path.exists(raw_path): os.mkdir(raw_path) dst = os.path.join(raw_path, filename) @@ -1141,17 +1204,23 @@ def move_to_raw_microscopy_files_folder(self, raw_src_path, filename): except PermissionError as e: self.progress.emit(e) + class createDataStructWin(QMainWindow): def __init__( - self, parent=None, allowExit=False, buttonToRestore=None, - mainWin=None, start_JVM=True, version=None - ): + self, + parent=None, + allowExit=False, + buttonToRestore=None, + mainWin=None, + start_JVM=True, + version=None, + ): super().__init__(parent) self._version = version - logger, logs_path, log_path, log_filename = myutils.setupLogger( - module='dataStruct' + logger, logs_path, log_path, log_filename = utils.setupLogger( + module="dataStruct" ) self.logger = logger self.log_path = log_path @@ -1159,9 +1228,9 @@ def __init__( self.logs_path = logs_path if self._version is not None: - logger.info(f'Initializing Data structure module v{self._version}...') + logger.info(f"Initializing Data structure module v{self._version}...") else: - logger.info(f'Initializing Data structure module...') + logger.info(f"Initializing Data structure module...") self.start_JVM = start_JVM self.allowExit = allowExit @@ -1169,11 +1238,9 @@ def __init__( self.buttonToRestore = buttonToRestore self.mainWin = mainWin self.metadataDialogIsOpen = False - self.df_settings = pd.read_csv( - settings_csv_path, index_col='setting' - ) + self.df_settings = pd.read_csv(settings_csv_path, index_col="setting") - version = myutils.read_version() + version = utils.read_version() self.setWindowTitle(f"Cell-ACDC v{version} - Data structure") self.setWindowIcon(QtGui.QIcon(":icon.ico")) @@ -1182,9 +1249,7 @@ def __init__( mainLayout = QVBoxLayout() - label = QLabel( - 'Creating data structure from raw microscopy file(s)...' - ) + label = QLabel("Creating data structure from raw microscopy file(s)...") label.setStyleSheet("padding:5px 10px 10px 10px;") label.setAlignment(Qt.AlignCenter) @@ -1194,8 +1259,7 @@ def __init__( label.setFont(font) mainLayout.addWidget(label) - informativeHtml = ( - """ + informativeHtml = """ @@ -1224,7 +1288,6 @@ def __init__( """ - ) informativeText = QLabel(self) @@ -1237,104 +1300,104 @@ def __init__( self.logWin.setReadOnly(True) mainLayout.addWidget(self.logWin) - abortButton = widgets.cancelPushButton(' Stop processs ') + abortButton = widgets.cancelPushButton(" Stop processs ") abortButton.clicked.connect(self.close) - + buttonsLayout = QHBoxLayout() buttonsLayout.addStretch(1) buttonsLayout.addWidget(abortButton) - + mainLayout.addLayout(buttonsLayout) mainLayout.setContentsMargins(20, 0, 20, 20) mainContainer.setLayout(mainLayout) self.mainLayout = mainLayout - + try: import javabridge from cellacdc import bioformats - self.bioformats_backend = 'python-bioformats' + + self.bioformats_backend = "python-bioformats" except Exception as e: pass - - self.bioformats_backend = 'bioio' + + self.bioformats_backend = "bioio" success = self.checkInstallBioIO(parent) if success: return - - self.bioformats_backend = 'python-bioformats' + + self.bioformats_backend = "python-bioformats" self.checkInstallPythonBioformats(parent) def checkInstallBioIO(self, parent): - myutils.check_install_package( - 'BioIO', - import_pkg_name='bioio', - pypi_name='bioio', - min_version='0.1.0', + utils.check_install_package( + "BioIO", + import_pkg_name="bioio", + pypi_name="bioio", + min_version="0.1.0", parent=parent, ) - + return True - + def checkInstallPythonBioformats(self, parent): from . import is_win, is_mac - + if not is_win and not is_mac: if parent is None: self.show() self.criticalOSnotSupported() self.close() - raise OSError('This module is supported ONLY on Windows 10/10 and macOS') + raise OSError("This module is supported ONLY on Windows 10/10 and macOS") - success, jar_dst_path = myutils.download_bioformats_jar( - qparent=self, logger_info=self.logger.info, - logger_exception=self.logger.exception + success, jar_dst_path = utils.download_bioformats_jar( + qparent=self, + logger_info=self.logger.info, + logger_exception=self.logger.exception, ) - self.logger.info('Checking if Java is installed...') - myutils.check_upgrade_javabridge() + self.logger.info("Checking if Java is installed...") + utils.check_upgrade_javabridge() try: import javabridge except ModuleNotFoundError as e: - print('======================================') + print("======================================") traceback_str = traceback.format_exc() self.logger.exception(traceback_str) - print('======================================') - cancel = myutils.install_javabridge_help(parent=self) + print("======================================") + cancel = utils.install_javabridge_help(parent=self) if cancel: - raise ModuleNotFoundError( - 'User aborted javabridge installation' - ) + raise ModuleNotFoundError("User aborted javabridge installation") - isGitInstalled = myutils.check_git_installed(parent=self) + isGitInstalled = utils.check_git_installed(parent=self) if not isGitInstalled: raise ModuleNotFoundError( - 'Git is not installed. Install from ' - 'https://git-scm.com/book/en/v2/Getting-Started-Installing-Git' + "Git is not installed. Install from " + "https://git-scm.com/book/en/v2/Getting-Started-Installing-Git" ) try: - jre_path, jdk_path, url = myutils.download_java() + jre_path, jdk_path, url = utils.download_java() except Exception as e: - print('======================================') + print("======================================") traceback_str = traceback.format_exc() self.logger.exception(traceback_str) - print('======================================') - java_info = myutils.get_java_url() + print("======================================") + java_info = utils.get_java_url() url, file_size, os_foldername, unzipped_foldername = java_info - acdc_java_path, _ = myutils.get_acdc_java_path() + acdc_java_path, _ = utils.get_acdc_java_path() java_href = f'this' s = ( - f'1. Download {java_href} .zip file and unzip it.
' - '2. Inside the unzipped folder there should be a folder called ' + f"1. Download {java_href} .zip file and unzip it.
" + "2. Inside the unzipped folder there should be a folder called " f'"{unzipped_foldername}". Open that folder and copy its ' - 'content to the following path:

' - f'{os.path.join(acdc_java_path, os_foldername)}' + "content to the following path:

" + f"{os.path.join(acdc_java_path, os_foldername)}" ) note = ( - '

NOTE: if clicking on the link above does not work ' - 'copy the link below and paste it into the browser

' - f'{url}' + "

NOTE: if clicking on the link above does not work " + "copy the link below and paste it into the browser

" + f"{url}" ) msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph(f""" @@ -1343,79 +1406,69 @@ def checkInstallPythonBioformats(self, parent): launching this module again.

{s}{note} """) - msg.warning(self, 'Java not found', txt) - - err = s.replace('
', ' ') - err = err.replace('this', '') + msg.warning(self, "Java not found", txt) + + err = s.replace("
", " ") + err = err.replace("this", "") raise ModuleNotFoundError( - 'Installation of module "javabridge" failed. ' - f'{err}' + f'Installation of module "javabridge" failed. {err}' ) if not is_win: - cancel = myutils.install_java() + cancel = utils.install_java() if cancel: - raise ModuleNotFoundError( - 'User aborted Java installation' - ) + raise ModuleNotFoundError("User aborted Java installation") return - myutils.install_javabridge() + utils.install_javabridge() except Exception as e: - print('======================================') + print("======================================") traceback_str = traceback.format_exc() self.logger.exception(traceback_str) - print('======================================') - cancel = myutils.install_java() + print("======================================") + cancel = utils.install_java() if cancel: - raise ModuleNotFoundError( - 'User aborted Java installation' - ) + raise ModuleNotFoundError("User aborted Java installation") return - myutils.install_javabridge( - force_compile=True, attempt_uninstall_first=True - ) + utils.install_javabridge(force_compile=True, attempt_uninstall_first=True) try: import javabridge from cellacdc import bioformats except Exception as e: - print('===============================================================') + print("===============================================================") traceback_str = traceback.format_exc() self.logger.exception(traceback_str) error_msg = ( 'Error while importing "javabridge" and "bioformats".\n\n' - f'Please report error here: {issues_url}\n' + f"Please report error here: {issues_url}\n" ) print(error_msg) - print('===============================================================') + print("===============================================================") - title = 'Import javabridge/bioformats error' - txt = error_msg.replace('\n', '
') - txt = txt.replace( - issues_url, html_utils.href_tag(issues_url, issues_url) - ) + title = "Import javabridge/bioformats error" + txt = error_msg.replace("\n", "
") + txt = txt.replace(issues_url, html_utils.href_tag(issues_url, issues_url)) txt = html_utils.paragraph(txt) msg = widgets.myMessageBox(wrapText=False) - msg.critical( - self, title, txt, detailsText=traceback_str - ) + msg.critical(self, title, txt, detailsText=traceback_str) raise ModuleNotFoundError( - 'Error when importing javabridge. See above for details.' + "Error when importing javabridge. See above for details." ) def criticalOSnotSupported(self): from cellacdc import widgets + if self.parent() is None: msg = widgets.myMessageBox(self) else: msg = widgets.myMessageBox(self.parent()) - msg.setIcon(iconName='SP_MessageBoxCritical') - msg.setWindowTitle('Not a supported OS') - msg.addButton(' Ok ') - err_msg = (f""" + msg.setIcon(iconName="SP_MessageBoxCritical") + msg.setWindowTitle("Not a supported OS") + msg.addButton(" Ok ") + err_msg = f"""

Unfortunately, the module "0. Create data structure from microscopy file(s)" is functional only on Windows 10/11 and macOS.

@@ -1439,58 +1492,56 @@ def criticalOSnotSupported(self): here .

- """) + """ msg.addText(err_msg) # msg_label = msg.findChild(QLabel, "qt_msgbox_label") # msg_label.setOpenExternalLinks(False) # msg_label.linkActivated.connect(self.on_linkActivated) msg.exec_() - def on_linkActivated(self, link): - if link == 'manual': + if link == "manual": systems = { - 'nt': os.startfile, - 'posix': lambda foldername: os.system('xdg-open "%s"' % foldername), - 'os2': lambda foldername: os.system('open "%s"' % foldername) - } + "nt": os.startfile, + "posix": lambda foldername: os.system('xdg-open "%s"' % foldername), + "os2": lambda foldername: os.system('open "%s"' % foldername), + } main_path = pathlib.Path(__file__).resolve().parents[1] - userManual_path = main_path / 'UserManual' + userManual_path = main_path / "UserManual" systems.get(os.name, os.startfile)(userManual_path) - elif link == 'fiji': + elif link == "fiji": systems = { - 'nt': os.startfile, - 'posix': lambda foldername: os.system('xdg-open "%s"' % foldername), - 'os2': lambda foldername: os.system('open "%s"' % foldername) - } + "nt": os.startfile, + "posix": lambda foldername: os.system('xdg-open "%s"' % foldername), + "os2": lambda foldername: os.system('open "%s"' % foldername), + } main_path = pathlib.Path(__file__).resolve().parents[1] - fijiMacros_path = main_path / 'FijiMacros' + fijiMacros_path = main_path / "FijiMacros" systems.get(os.name, os.startfile)(fijiMacros_path) - def getMostRecentPath(self): if os.path.exists(recentPaths_path): - df = pd.read_csv(recentPaths_path, index_col='index') - if 'opened_last_on' in df.columns: - df = df.sort_values('opened_last_on', ascending=False) - self.MostRecentPath = df.iloc[0]['path'] + df = pd.read_csv(recentPaths_path, index_col="index") + if "opened_last_on" in df.columns: + df = df.sort_values("opened_last_on", ascending=False) + self.MostRecentPath = df.iloc[0]["path"] if not isinstance(self.MostRecentPath, str): - self.MostRecentPath = '' + self.MostRecentPath = "" else: - self.MostRecentPath = '' + self.MostRecentPath = "" def addToRecentPaths(self, raw_src_path): if not os.path.exists(raw_src_path): return if os.path.exists(recentPaths_path): - df = pd.read_csv(recentPaths_path, index_col='index') - recentPaths = df['path'].to_list() - if 'opened_last_on' in df.columns: - openedOn = df['opened_last_on'].to_list() + df = pd.read_csv(recentPaths_path, index_col="index") + recentPaths = df["path"].to_list() + if "opened_last_on" in df.columns: + openedOn = df["opened_last_on"].to_list() else: - openedOn = [np.nan]*len(recentPaths) + openedOn = [np.nan] * len(recentPaths) if raw_src_path in recentPaths: pop_idx = recentPaths.index(raw_src_path) recentPaths.pop(pop_idx) @@ -1504,15 +1555,18 @@ def addToRecentPaths(self, raw_src_path): else: recentPaths = [raw_src_path] openedOn = [datetime.datetime.now()] - df = pd.DataFrame({'path': recentPaths, - 'opened_last_on': pd.Series(openedOn, - dtype='datetime64[ns]')}) - df.index.name = 'index' + df = pd.DataFrame( + { + "path": recentPaths, + "opened_last_on": pd.Series(openedOn, dtype="datetime64[ns]"), + } + ) + df.index.name = "index" df.to_csv(recentPaths_path) @exception_handler def main(self): - self.log('Asking how raw data is structured...') + self.log("Asking how raw data is structured...") rawDataStruct, abort = self.askRawDataStruct() if abort: self.close() @@ -1525,39 +1579,32 @@ def main(self): self.close() return - self.log('Instructing to move raw data...') + self.log("Instructing to move raw data...") proceed = self.instructMoveRawFiles() if not proceed: self.close() return - self.log( - 'Asking to select the folder that contains the microscopy files...' - ) + self.log("Asking to select the folder that contains the microscopy files...") self.getMostRecentPath() raw_src_path = QFileDialog.getExistingDirectory( - self, 'Select folder containing the microscopy files', - self.MostRecentPath + self, "Select folder containing the microscopy files", self.MostRecentPath ) self.addToRecentPaths(raw_src_path) - if raw_src_path == '': + if raw_src_path == "": self.close() return - + self.log(f'Selected folder: "{raw_src_path}"') - - self.log( - 'Checking file format of loaded files...' - ) + + self.log("Checking file format of loaded files...") rawFilenames = self.checkFileFormat(raw_src_path) if not rawFilenames: self.close() return - - self.log( - 'Checking file names of loaded files...' - ) + + self.log("Checking file names of loaded files...") proceed, rawFilenames = self.checkFileNames(rawFilenames, raw_src_path) if not proceed: self.close() @@ -1569,41 +1616,38 @@ def main(self): self.close() return - self.log( - 'Asking in which folder to save the images files...' - ) + self.log("Asking in which folder to save the images files...") exp_dst_path = QFileDialog.getExistingDirectory( - self, 'Select the folder in which to save the images files', - raw_src_path + self, "Select the folder in which to save the images files", raw_src_path ) if not exp_dst_path: self.close() return - + out = self.askPosFoldersExisting(exp_dst_path) if out is None: self.close() return overwrite, add_files, create_new, start_pos_n = out - - self.log('Instructing to move raw data...') + + self.log("Instructing to move raw data...") loadEntirePosIntoRam = self.askHowToLoadData() if loadEntirePosIntoRam is None: self.close() return - + if not loadEntirePosIntoRam: self._installLazyLoadModules() - + self.loadEntirePosIntoRam = loadEntirePosIntoRam self.addToRecentPaths(exp_dst_path) self.addPbar() - + self.initBioIO(raw_src_path, rawFilenames) - + move_raw_microscopy_files = False if exp_dst_path == raw_src_path: move_raw_microscopy_files, cancel = self.askMoveRawMicroscopyFiles() @@ -1616,9 +1660,13 @@ def main(self): self.waitCond = QWaitCondition() self.thread = QThread() self.worker = bioFormatsWorker( - raw_src_path, rawFilenames, exp_dst_path, - self.mutex, self.waitCond, rawDataStruct, - self.bioformats_backend, + raw_src_path, + rawFilenames, + exp_dst_path, + self.mutex, + self.waitCond, + rawDataStruct, + self.bioformats_backend, lazy_load=not self.loadEntirePosIntoRam, move_raw_microscopy_files=move_raw_microscopy_files, overwrite=overwrite, @@ -1649,7 +1697,7 @@ def main(self): self.thread.started.connect(self.worker.run) self.thread.start() - + def askMoveRawMicroscopyFiles(self): msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph(f""" @@ -1658,29 +1706,28 @@ def askMoveRawMicroscopyFiles(self): to a sub-folder called raw_microscopy_files? """) _, doNotMoveButton, moveButton = msg.warning( - self, 'Too many objects', txt, - buttonsTexts=( - 'Cancel', 'No, do not move the files', 'Yes, move the files' - ) + self, + "Too many objects", + txt, + buttonsTexts=("Cancel", "No, do not move the files", "Yes, move the files"), ) return msg.clickedButton == moveButton, msg.cancel - + def _installLazyLoadModules(self): - myutils.check_install_package( - 'zarr', - installer='pip', + utils.check_install_package( + "zarr", + installer="pip", is_cli=False, parent=self, ) - + @exception_handler def workerCritical(self, error): raise error def instructManualStruct(self): - manual_url = 'https://github.com/SchmollerLab/Cell_ACDC/blob/main/UserManual/Cell-ACDC_User_Manual.pdf' - txt = ( - f""" + manual_url = "https://github.com/SchmollerLab/Cell_ACDC/blob/main/UserManual/Cell-ACDC_User_Manual.pdf" + txt = f"""

If you would like to add compatibility with your raw microscopy files,
you can request a new feature here.

@@ -1692,9 +1739,8 @@ def instructManualStruct(self): "Manually create data structure from microscopy file(s)"

""" - ) msg = QMessageBox(self) - msg.setWindowTitle('Data structure not available') + msg.setWindowTitle("Data structure not available") msg.setIcon(msg.Information) msg.setText(txt) msg.setTextInteractionFlags(Qt.TextBrowserInteraction) @@ -1702,46 +1748,47 @@ def instructManualStruct(self): msg.exec_() def initBioIO(self, raw_src_path, raw_filenames): - if self.bioformats_backend == 'python-bioformats': + if self.bioformats_backend == "python-bioformats": return - + from cellacdc import acdc_bioio_bioformats as bioformats + raw_filepath = os.path.join(raw_src_path, raw_filenames[0]) - + bioformats.install.install_reader_dependencies( - raw_filepath, + raw_filepath, exception=Exception( - 'Failed installing reader dependencies from the GUI, ' - 'trying from terminal...' + "Failed installing reader dependencies from the GUI, " + "trying from terminal..." ), - qparent=self + qparent=self, ) - + import subprocess from . import _process - + init_reader_py_filepath = os.path.join( - os.path.dirname(bioformats.__file__), '_init_reader.py' + os.path.dirname(bioformats.__file__), "_init_reader.py" ) uuid4 = uuid.uuid4() command = ( - f'{sys.executable}, {init_reader_py_filepath}, ' - f'-f, {raw_filepath}, ' - f'-uuid, {uuid4}' + f"{sys.executable}, {init_reader_py_filepath}, " + f"-f, {raw_filepath}, " + f"-uuid, {uuid4}" ) - - args = [sys.executable, _process.__file__, '-c', command] + + args = [sys.executable, _process.__file__, "-c", command] subprocess.run(args) - + bioformats._utils.check_raise_exception(uuid4) - + def addPbar(self): self.QPbar = widgets.ProgressBar(self) self.QPbar.setValue(0) self.mainLayout.insertWidget(3, self.QPbar) def updatePbar(self, deltaPbar): - self.QPbar.setValue(self.QPbar.value()+deltaPbar) + self.QPbar.setValue(self.QPbar.value() + deltaPbar) def setPbarMax(self, max): self.QPbar.setMaximum(max) @@ -1749,23 +1796,18 @@ def setPbarMax(self, max): def taskEnded(self): if self.worker.cancelled and not self.worker.isCriticalError: msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph( - 'Conversion task cancelled.' - ) - msg.critical( - self, 'Conversion task cancelled.', txt - ) + txt = html_utils.paragraph("Conversion task cancelled.") + msg.critical(self, "Conversion task cancelled.", txt) self.close() elif not self.worker.cancelled: msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph( - 'Conversion task ended.

' - 'Files saved to' - ) + txt = html_utils.paragraph("Conversion task ended.

Files saved to") abort = msg.information( - self, 'Conversion task ended.', txt, - commands=(self.worker.exp_dst_path,), - path_to_browse=self.worker.exp_dst_path + self, + "Conversion task ended.", + txt, + commands=(self.worker.exp_dst_path,), + path_to_browse=self.worker.exp_dst_path, ) self.close() @@ -1774,18 +1816,20 @@ def log(self, text): self.logger.info(text) def askRawDataStruct(self): - infoText = html_utils.paragraph( - 'Select how you have your raw microscopy files arranged' + infoText = html_utils.paragraph( + "Select how you have your raw microscopy files arranged" ) win = apps.QDialogCombobox( - 'Raw data structure', + "Raw data structure", [ - 'Single microscopy file with multiple positions', - 'One or more microscopy files, one for each position', - 'One or more microscopy files, one for each channel', - 'NONE of the above' + "Single microscopy file with multiple positions", + "One or more microscopy files, one for each position", + "One or more microscopy files, one for each channel", + "NONE of the above", ], - infoText, CbLabel='', parent=self + infoText, + CbLabel="", + parent=self, ) win.exec_() if not win.cancel: @@ -1795,9 +1839,9 @@ def askRawDataStruct(self): def instructMoveRawFiles(self): msg = widgets.myMessageBox(showCentered=False, wrapText=False) tip_admon = html_utils.to_admonition( - 'If you have a single gray-scale TIFF file, ' - 'placing into a folder called Images will be enough.', - admonition_type='tip', + "If you have a single gray-scale TIFF file, " + "placing into a folder called Images will be enough.", + admonition_type="tip", ) txt = html_utils.paragraph(f""" Put all of the raw microscopy files from the same experiment @@ -1809,11 +1853,12 @@ def instructMoveRawFiles(self): by the microscope, for example '.czi' (Zeiss), '.nd2' (Nikon), '.lif' (Leica), etc.

{tip_admon} - """ - ) + """) msg.information( - self, 'Microscopy files location', txt, - buttonsTexts=('Cancel', widgets.okPushButton('Done')) + self, + "Microscopy files location", + txt, + buttonsTexts=("Cancel", widgets.okPushButton("Done")), ) if msg.cancel: return False @@ -1834,18 +1879,20 @@ def askHowToLoadData(self): """ ) _, loadFrameButton, loadPosButton = msg.warning( - self, 'Loading data', txt, + self, + "Loading data", + txt, buttonsTexts=( - 'Cancel', - widgets.twoDPushButton('No, load one frame (2D) at a time'), - widgets.FutureAllPushButton('Yes, load entire position at once') - ) + "Cancel", + widgets.twoDPushButton("No, load one frame (2D) at a time"), + widgets.FutureAllPushButton("Yes, load entire position at once"), + ), ) if msg.cancel: return None return msg.clickedButton == loadPosButton - + def warnSelectedPathEmpty(self, raw_src_path): msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph( @@ -1857,24 +1904,28 @@ def warnSelectedPathEmpty(self, raw_src_path): """ ) msg.warning( - self, 'Empty folder', txt, - commands=(raw_src_path, ), - path_to_browse=raw_src_path + self, + "Empty folder", + txt, + commands=(raw_src_path,), + path_to_browse=raw_src_path, ) - + def checkFileFormat(self, raw_src_path): self.moveOtherFiles = False self.copyOtherFiles = False - ls = natsorted(myutils.listdir(raw_src_path)) + ls = natsorted(utils.listdir(raw_src_path)) files = [ - filename for filename in ls + filename + for filename in ls if os.path.isfile(os.path.join(raw_src_path, filename)) ] if not files: self.warnSelectedPathEmpty(raw_src_path) return [] all_ext = [ - os.path.splitext(filename)[1] for filename in ls + os.path.splitext(filename)[1] + for filename in ls if os.path.isfile(os.path.join(raw_src_path, filename)) ] counter = Counter(all_ext) @@ -1883,10 +1934,10 @@ def checkFileFormat(self, raw_src_path): most_common_ext, _ = counter.most_common(1)[0] if not is_ext_unique: if not most_common_ext: - most_common_ext_msg = '' + most_common_ext_msg = "" else: most_common_ext_msg = most_common_ext - + msg = widgets.myMessageBox(showCentered=False) txt = html_utils.paragraph(f""" The following folder @@ -1902,21 +1953,24 @@ def checkFileFormat(self, raw_src_path):
""") _, yesButton, noButton = msg.warning( - self, 'Multiple extensions detected', txt, + self, + "Multiple extensions detected", + txt, buttonsTexts=( - 'Cancel', 'Yes, load only most common', - 'No, load all files' - ) + "Cancel", + "Yes, load only most common", + "No, load all files", + ), ) if msg.cancel: return [] if msg.clickedButton == yesButton: files = [ - filename for filename in files + filename + for filename in files if os.path.splitext(filename)[1] == most_common_ext ] - otherExt = [ - ext for ext in unique_ext if ext != most_common_ext] + otherExt = [ext for ext in unique_ext if ext != most_common_ext] files = self.askActionWithOtherFiles(files, otherExt) else: return files @@ -1948,40 +2002,38 @@ def checkFileNames(self, raw_filenames, raw_src_path): 'Rename file (replace invalid characters with "-")' ) msg.warning( - self, 'Invalid filename', txt, + self, + "Invalid filename", + txt, path_to_browse=raw_src_path, buttonsTexts=( - 'Let me rename files myself', - renameWithUnderscoresButton, - renameWithDashesButton - ) + "Let me rename files myself", + renameWithUnderscoresButton, + renameWithDashesButton, + ), ) if msg.clickedButton == renameWithUnderscoresButton: - self.log( - 'Renaming files to replace invalid characters with "_"...' - ) + self.log('Renaming files to replace invalid characters with "_"...') renamed_filenames = io.rename_files_replace_invalid_chars( - raw_filenames, raw_src_path, replacement_char='_' + raw_filenames, raw_src_path, replacement_char="_" ) return True, renamed_filenames elif msg.clickedButton == renameWithDashesButton: - self.log( - 'Renaming files to replace invalid characters with "-"...' - ) + self.log('Renaming files to replace invalid characters with "-"...') renamed_filenames = io.rename_files_replace_invalid_chars( - raw_filenames, raw_src_path, replacement_char='-' + raw_filenames, raw_src_path, replacement_char="-" ) return True, renamed_filenames else: return False, [] return True, raw_filenames - + def askActionWithOtherFiles(self, files, otherExt): self.moveOtherFiles = False msg = QMessageBox(self) - msg.setWindowTitle('Action with the other files?') - txt = (f""" + msg.setWindowTitle("Action with the other files?") + txt = f"""

What should I do with the other files (ext: {otherExt}) in the folder?

@@ -1989,21 +2041,13 @@ def askActionWithOtherFiles(self, files, otherExt): as the raw files will be moved or copied.

- """) + """ msg.setIcon(msg.Question) msg.setText(txt) - leaveButton = QPushButton( - 'Leave them where they are' - ) - moveButton = QPushButton( - 'Attempt MOVING to their Position folder' - ) - copyButton = QPushButton( - 'Attempt COPYING to their Position folder' - ) - cancelButton = QPushButton( - 'Cancel' - ) + leaveButton = QPushButton("Leave them where they are") + moveButton = QPushButton("Attempt MOVING to their Position folder") + copyButton = QPushButton("Attempt COPYING to their Position folder") + cancelButton = QPushButton("Cancel") msg.addButton(leaveButton, msg.YesRole) msg.addButton(moveButton, msg.NoRole) msg.addButton(copyButton, msg.RejectRole) @@ -2024,16 +2068,17 @@ def askActionWithOtherFiles(self, files, otherExt): elif msg.clickedButton() == cancelButton: return [] - def warnMultipleFiles(self, files): win = apps.QDialogCombobox( - 'Multiple microscopy files detected!', files, - '

' - 'You selected "Single microscopy file", ' - 'but the folder contains multiple files.
' - '

', - CbLabel='Select which file to load: ', parent=self, - iconPixmap=QtGui.QPixmap(':warning.svg') + "Multiple microscopy files detected!", + files, + '

' + 'You selected "Single microscopy file", ' + "but the folder contains multiple files.
" + "

", + CbLabel="Select which file to load: ", + parent=self, + iconPixmap=QtGui.QPixmap(":warning.svg"), ) win.exec_() if win.cancel: @@ -2048,7 +2093,7 @@ def attemptSeparateMultiChannel(self, rawFilenames): stripped_filenames = [] for file in rawFilenames: filename, ext = os.path.splitext(file) - m_iter = myutils.findalliter(fr'(\d+)_(.+)', filename) + m_iter = utils.findalliter(rf"(\d+)_(.+)", filename) if len(m_iter) <= 1: self.criticalNoFilenamePattern() return False @@ -2059,7 +2104,7 @@ def attemptSeparateMultiChannel(self, rawFilenames): posNum, chName = int(m[0][0]), m[0][1] self.chNames.add(chName) self.posNums.add(posNum) - ch_idx = filename.find(f'{posNum}_{chName}') + ch_idx = filename.find(f"{posNum}_{chName}") stripped_filenames.append(filename[:ch_idx]) except Exception as e: traceback_str = traceback.format_exc() @@ -2067,7 +2112,7 @@ def attemptSeparateMultiChannel(self, rawFilenames): self.criticalNoFilenamePattern(error=traceback.format_exc()) return False - basename = myutils.getBasename(stripped_filenames) + basename = utils.getBasename(stripped_filenames) if not basename: self.criticalNoFilenamePattern() return False @@ -2079,10 +2124,8 @@ def attemptSeparateMultiChannel(self, rawFilenames): self.SizeS = len(self.posNums) return True - - def criticalNoFilenamePattern(self, error=''): - txt = ( - """ + def criticalNoFilenamePattern(self, error=""): + txt = """ Files are named with a non-compatible pattern.

In order to automatically generate the required data structure from "Multiple files, one for each channel" the filenames must @@ -2093,9 +2136,8 @@ def criticalNoFilenamePattern(self, error=''): Note that the channel MUST be separated from the rest of the name by an underscore "_" """ - ) msg = QMessageBox(self) - msg.setWindowTitle('Non-compatible pattern') + msg.setWindowTitle("Non-compatible pattern") msg.setIcon(msg.Critical) msg.setText(txt) if error: @@ -2105,8 +2147,8 @@ def criticalNoFilenamePattern(self, error=''): def criticalBioFormats(self, actionTxt, tracebackFormat, filename): msg = widgets.myMessageBox(self, wrapText=True) - url = 'https://docs.openmicroscopy.org/bio-formats/6.7.0/supported-formats.html' - seeHere = f'here' + url = "https://docs.openmicroscopy.org/bio-formats/6.7.0/supported-formats.html" + seeHere = f'here' _, ext = os.path.splitext(filename) txt = html_utils.paragraph( @@ -2124,31 +2166,55 @@ def criticalBioFormats(self, actionTxt, tracebackFormat, filename): You were trying to read file: {filename} """ ) - - msg.critical( - self, 'Error with Bio-Formats', txt, detailsText=tracebackFormat - ) + + msg.critical(self, "Error with Bio-Formats", txt, detailsText=tracebackFormat) self.close() def askConfirmMetadata( - self, filename, LensNA, SizeT, SizeZ, SizeC, SizeS, - TimeIncrement, TimeIncrementUnit, PhysicalSizeX, PhysicalSizeY, - PhysicalSizeZ, PhysicalSizeUnit, chNames, emWavelens, ImageName, - rawFilePath, sampleImgData - ): + self, + filename, + LensNA, + SizeT, + SizeZ, + SizeC, + SizeS, + TimeIncrement, + TimeIncrementUnit, + PhysicalSizeX, + PhysicalSizeY, + PhysicalSizeZ, + PhysicalSizeUnit, + chNames, + emWavelens, + ImageName, + rawFilePath, + sampleImgData, + ): if self.rawDataStruct == 2: filename = self.basename self.metadataDialogIsOpen = True self.metadataWin = apps.QDialogMetadataXML( - title=f'Metadata for {filename}', rawFilename=filename, - LensNA=LensNA, SizeT=SizeT, SizeZ=SizeZ, SizeC=SizeC, SizeS=SizeS, - TimeIncrement=TimeIncrement, TimeIncrementUnit=TimeIncrementUnit, - PhysicalSizeX=PhysicalSizeX, PhysicalSizeY=PhysicalSizeY, - PhysicalSizeZ=PhysicalSizeZ, PhysicalSizeUnit=PhysicalSizeUnit, - ImageName=ImageName, chNames=chNames, emWavelens=emWavelens, - parent=self, rawDataStruct=self.rawDataStruct, - sampleImgData=sampleImgData, rawFilePath=rawFilePath + title=f"Metadata for {filename}", + rawFilename=filename, + LensNA=LensNA, + SizeT=SizeT, + SizeZ=SizeZ, + SizeC=SizeC, + SizeS=SizeS, + TimeIncrement=TimeIncrement, + TimeIncrementUnit=TimeIncrementUnit, + PhysicalSizeX=PhysicalSizeX, + PhysicalSizeY=PhysicalSizeY, + PhysicalSizeZ=PhysicalSizeZ, + PhysicalSizeUnit=PhysicalSizeUnit, + ImageName=ImageName, + chNames=chNames, + emWavelens=emWavelens, + parent=self, + rawDataStruct=self.rawDataStruct, + sampleImgData=sampleImgData, + rawFilePath=rawFilePath, ) self.metadataWin.exec_() self.metadataDialogIsOpen = False @@ -2156,43 +2222,45 @@ def askConfirmMetadata( self.waitCond.wakeAll() def askPosFoldersExisting(self, exp_dst_path): - pos_foldernames = myutils.get_pos_foldernames(exp_dst_path) + pos_foldernames = utils.get_pos_foldernames(exp_dst_path) if not pos_foldernames: return False, False, False, 1 - + msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph( - 'The selected destination folder already contains Position folders.

' - 'Do you want to overwrite all of its content, ' - 'add files to the existing Position folders,
' - 'or create new Position folders?' + "The selected destination folder already contains Position folders.

" + "Do you want to overwrite all of its content, " + "add files to the existing Position folders,
" + "or create new Position folders?" ) _, overwriteButton, addFilesButton, createNewButton = msg.warning( - self, 'Warning: existing Position folders detected!', txt, - buttonsTexts=( - 'Cancel', - 'Overwrite', - 'Add files', - widgets.newFilePushButton('Create new'), + self, + "Warning: existing Position folders detected!", + txt, + buttonsTexts=( + "Cancel", + "Overwrite", + "Add files", + widgets.newFilePushButton("Create new"), ), - path_to_browse=exp_dst_path + path_to_browse=exp_dst_path, ) if msg.cancel: - return - + return + overwrite = overwriteButton == msg.clickedButton add_files = addFilesButton == msg.clickedButton create_new = createNewButton == msg.clickedButton - + start_pos_n = 1 if create_new: - pos_ns = [int(pos.split('_')[-1]) for pos in pos_foldernames] + pos_ns = [int(pos.split("_")[-1]) for pos in pos_foldernames] start_pos_n = max(pos_ns) + 1 - + return overwrite, add_files, create_new, start_pos_n def closeEvent(self, event): - self.logger.info('Closing data structure logger...') + self.logger.info("Closing data structure logger...") handlers = self.logger.handlers[:] for handler in handlers: handler.close() @@ -2201,21 +2269,21 @@ def closeEvent(self, event): if self.buttonToRestore is not None: button, color = self.buttonToRestore button.setText('0. Attempt "Create data structure" again') - button.setStyleSheet( - f'QPushButton {{background-color: {color};}}') + button.setStyleSheet(f"QPushButton {{background-color: {color};}}") self.mainWin.setWindowState(Qt.WindowNoState) self.mainWin.setWindowState(Qt.WindowActive) self.mainWin.raise_() + class InitFijiMacro: def __init__(self, acdcLauncher): self.acdcLauncher = acdcLauncher self.logger = self.acdcLauncher.logger - + def askSelectInstalledFiji(self): - if os.path.exists(myutils.get_fiji_exec_folderpath()): + if os.path.exists(utils.get_fiji_exec_folderpath()): return False, False - + txt = html_utils.paragraph(f""" Do you already have Fiji (ImageJ)?

If yes, click on the Select Fiji location button below @@ -2223,63 +2291,61 @@ def askSelectInstalledFiji(self): Alternatively, you can ignore this and let Cell-ACDC automatically download Fiji for you. """) - selectFijiButton = ( - widgets.OpenFilePushButton('Select Fiji.app location') - ) - downloadFijiButton = ( - widgets.DownloadPushButton('Download Fiji.app') - ) + selectFijiButton = widgets.OpenFilePushButton("Select Fiji.app location") + downloadFijiButton = widgets.DownloadPushButton("Download Fiji.app") msg = widgets.myMessageBox(wrapText=False) msg.did_user_selected_fiji = False msg.question( - self.acdcLauncher, 'Select Fiji location', txt, - buttonsTexts=( - 'Cancel', selectFijiButton, downloadFijiButton - ), - showDialog=False + self.acdcLauncher, + "Select Fiji location", + txt, + buttonsTexts=("Cancel", selectFijiButton, downloadFijiButton), + showDialog=False, ) selectFijiButton.clicked.disconnect() selectFijiButton.clicked.connect( partial(self.selectFijiLocation, messagebox=msg) ) msg.exec_() - + return msg.cancel, msg.did_user_selected_fiji - + def selectFijiLocation(self, checked=True, messagebox=None): import qtpy.compat + filepath = qtpy.compat.getopenfilename( - parent=messagebox, - caption='Select Fiji.app location', - filters='Application (*.app);;All Files (*)' + parent=messagebox, + caption="Select Fiji.app location", + filters="Application (*.app);;All Files (*)", )[0] if not filepath: return - + from cellacdc import fiji_location_filepath - with open(fiji_location_filepath, 'w') as txt: + + with open(fiji_location_filepath, "w") as txt: txt.write(os.path.join(filepath)) - + messagebox.did_user_selected_fiji = True messagebox.cancel = False messagebox.close() - + def run(self): cancel, did_user_selected_fiji = self.askSelectInstalledFiji() if cancel: self.cancel() return - - txt = (f""" + + txt = f""" In order to run Bio-Formats on your system, Cell-ACDC will use Fiji (ImageJ) from the command line.

The process entails the creation of a macro (.ijm) file and its execution from the command line.

If you prefer to run the macro yourself, you can go through its creation process and cancel its execution later. - """) - self.logger.info('Testing Fiji command...') - fiji_success = myutils.test_fiji_base_command(self.logger.info) + """ + self.logger.info("Testing Fiji command...") + fiji_success = utils.test_fiji_base_command(self.logger.info) commands = None if not fiji_success: if not did_user_selected_fiji: @@ -2287,96 +2353,98 @@ def run(self): shutil.rmtree(acdc_fiji_path) except Exception as err: pass - - href = html_utils.href_tag('here', urls.fiji_downloads) - note_download_txt = (f""" + + href = html_utils.href_tag("here", urls.fiji_downloads) + note_download_txt = f""" Before continuing, Fiji will be automatically downloaded now.

If the download fails, please download the zip file from {href} and unzip it in the following location: - """) - admon = html_utils.to_admonition( - note_download_txt, admonition_type='note' - ) - txt = f'{txt}
{admon}' + """ + admon = html_utils.to_admonition(note_download_txt, admonition_type="note") + txt = f"{txt}
{admon}" commands = (acdc_fiji_path,) - + txt = html_utils.paragraph(txt) msg = widgets.myMessageBox(wrapText=False) msg.information( - self.acdcLauncher, 'Running Fiji in the command line', txt, - buttonsTexts=('Cancel', 'Ok'), - commands=commands + self.acdcLauncher, + "Running Fiji in the command line", + txt, + buttonsTexts=("Cancel", "Ok"), + commands=commands, ) if msg.cancel: self.cancel() return - - myutils.download_fiji(logger_func=self.logger.info) - + + utils.download_fiji(logger_func=self.logger.info) + win = apps.InitFijiMacroDialog(parent=self.acdcLauncher) win.exec_() if win.cancel: self.cancel() return - + init_macro_args = win.init_macro_args is_separate_channels = init_macro_args[2] macro_filepath = fiji_macros.init_macro(*init_macro_args) macro_command = fiji_macros.command_run_macro(macro_filepath) - - txt = (""" + + txt = """ Cell-ACDC will now run the macro in the terminal.

During the process, the GUI will be unresponsive, while progress will be displayed in the terminal.

If you prefer, you can stop the process now and run the command yourself, or even run the macro directly from the Fiji GUI.
- """) - + """ + if is_separate_channels: important_admon = html_utils.to_admonition( - 'There are still steps to run after the macro finishes, so ' - 'if you run it yourself, ' - 'please close this dialogue only after the macro completes.', - admonition_type='important' + "There are still steps to run after the macro finishes, so " + "if you run it yourself, " + "please close this dialogue only after the macro completes.", + admonition_type="important", ) - txt = f'{txt}{important_admon}' - - txt = f'{txt}
Command to run the macro:' - + txt = f"{txt}{important_admon}" + + txt = f"{txt}
Command to run the macro:" + txt = html_utils.paragraph(txt) msg = widgets.myMessageBox(wrapText=False) _, _, okButton = msg.information( - self.acdcLauncher, 'Fiji macro command', txt, - buttonsTexts=('Cancel', 'I already ran the macro', 'Ok'), + self.acdcLauncher, + "Fiji macro command", + txt, + buttonsTexts=("Cancel", "I already ran the macro", "Ok"), commands=(macro_filepath), - path_to_browse=os.path.dirname(macro_filepath) + path_to_browse=os.path.dirname(macro_filepath), ) if msg.cancel: self.cancel() return - + success = True if msg.clickedButton == okButton: - success = fiji_macros.run_macro(macro_command) - + success = fiji_macros.run_macro(macro_command) + files_folderpath = init_macro_args[0] dst_folderpath = init_macro_args[3] channels = init_macro_args[4] if success and is_separate_channels: - self.logger.info('Moving files to Position folders...') + self.logger.info("Moving files to Position folders...") success = io.move_separate_channels_tiffs_to_pos_folders( dst_folderpath, channels ) - + if success: txt = html_utils.paragraph(""" Macro execution completed. Path to the macro file: """) - msg_func = 'information' + msg_func = "information" else: - href = html_utils.href_tag('GitHub page', urls.issues_url) + href = html_utils.href_tag("GitHub page", urls.issues_url) txt = html_utils.paragraph(f""" Macro execution completed with errors. More details in the terminal.

@@ -2384,15 +2452,15 @@ def run(self): {href}

Path to the macro file: """) - msg_func = 'information' - + msg_func = "information" + msg = widgets.myMessageBox(wrapText=False) getattr(msg, msg_func)( - self.acdcLauncher, 'Macro execution completed', txt, - commands=(macro_filepath,) + self.acdcLauncher, + "Macro execution completed", + txt, + commands=(macro_filepath,), ) - + def cancel(self): - self.logger.info('Running Bio-Formats from Fiji process cancelled.') - - + self.logger.info("Running Bio-Formats from Fiji process cancelled.") diff --git a/cellacdc/data_source.py b/cellacdc/data_source.py new file mode 100644 index 000000000..656b12ade --- /dev/null +++ b/cellacdc/data_source.py @@ -0,0 +1,342 @@ +"""Unified experiment data for decoupling the GUI from filesystem loading.""" + +from __future__ import annotations + +import os +import tempfile +from dataclasses import dataclass, field +from typing import Literal + +import numpy as np +import pandas as pd + +VolumeAxes = Literal["yx", "zyx", "tyx", "tzyx"] +PathKind = Literal["file", "experiment", "images", "folder"] +DataSourceKind = Literal["memory", "path"] + + +@dataclass +class ArrayDataSource: + """Specification for building in-memory position data.""" + + image: np.ndarray + labels: np.ndarray | None = None + name: str = "data" + channel_name: str = "cells" + axes: VolumeAxes = "tyx" + workspace: str | os.PathLike | None = None + time_increment: float = 1.0 + physical_size_xy: tuple[float, float] = (1.0, 1.0) + physical_size_z: float = 1.0 + is_segm_3d: bool = False + metadata: dict[str, str | float | int] = field(default_factory=dict) + + +class ExperimentData: + """Unified dataset handle for the Cell-ACDC script API. + + Use :meth:`from_arrays` or :meth:`from_path` to create instances. + """ + + name: str + source: DataSourceKind + path: str | None + path_kind: PathKind | None + _positions: list | None + + def __init__(self): + pass + + @classmethod + def from_arrays( + cls, + image: np.ndarray, + labels: np.ndarray | None = None, + **kwargs, + ) -> ExperimentData: + """Create dataset data from in-memory arrays.""" + self = cls() + load_data_cls = kwargs.pop("_load_data_cls", None) + name = kwargs.get("name", "data") + pos = pos_data_from_kwargs( + image, + labels, + _load_data_cls=load_data_cls, + **kwargs, + ) + self.source = "memory" + self.name = name + self.path = None + self.path_kind = None + self._positions = [pos] + return self + + @classmethod + def from_path(cls, path: str | os.PathLike, **kwargs) -> ExperimentData: + """Create a dataset handle from a filesystem path.""" + path = os.fspath(path) + if not os.path.exists(path): + raise FileNotFoundError(path) + + self = cls() + name = kwargs.get("name", "data") + self.source = "path" + self.path = path + self.path_kind = _detect_path_kind(path) + self.name = ( + os.path.basename(path.rstrip(os.sep)) if name == "data" else name + ) + self._positions = None + return self + + @property + def is_materialized(self) -> bool: + return self.source == "memory" and self._positions is not None + + @property + def positions(self) -> list: + if not self.is_materialized: + raise RuntimeError( + "Path-based ExperimentData is loaded by the viewer on demand. " + "Use Viewer(data) or data.load_into(window)." + ) + return self._positions + + def load_into(self, window) -> None: + if self.source == "memory": + window.loadFromExperimentData(self) + return + + if self.path_kind == "file": + window.openFile(file_path=self.path) + elif self.path_kind == "images": + window.openFolder(exp_path=self.path) + elif self.path_kind == "experiment": + window.openFolder(exp_path=self.path) + else: + if os.path.isdir(self.path): + window.openFolder(exp_path=self.path) + else: + window.openFile(file_path=self.path) + + +def _detect_path_kind(path: str) -> PathKind: + if os.path.isfile(path): + return "file" + + basename = os.path.basename(path.rstrip(os.sep)) + if basename == "Images": + return "images" + + try: + entries = os.listdir(path) + except OSError: + return "folder" + + if any(entry.startswith("Position") and os.path.isdir(os.path.join(path, entry)) for entry in entries): + return "experiment" + + return "folder" + + +def normalize_volume( + array: np.ndarray, + *, + axes: VolumeAxes = "tyx", +) -> tuple[np.ndarray, int, int]: + """Return (array, SizeT, SizeZ) in Cell-ACDC's pre-finalize layout.""" + arr = np.asarray(array) + if arr.ndim == 2: + if axes != "yx": + raise ValueError( + f"A 2D array requires axes='yx', got axes={axes!r}." + ) + return arr, 1, 1 + + if arr.ndim == 3: + if axes == "zyx": + return arr, 1, arr.shape[0] + if axes == "tyx": + return arr, arr.shape[0], 1 + raise ValueError( + f"A 3D array requires axes='tyx' or 'zyx', got axes={axes!r}." + ) + + if arr.ndim == 4: + if axes != "tzyx": + raise ValueError( + f"A 4D array requires axes='tzyx', got axes={axes!r}." + ) + return arr, arr.shape[0], arr.shape[1] + + raise ValueError( + f"Expected a 2D, 3D, or 4D array, got shape {arr.shape}." + ) + + +def _finalize_pos_data_arrays(pos_data) -> None: + """Match the array layout produced by ``loadDataWorker``.""" + if pos_data.SizeT == 1: + pos_data.img_data = pos_data.img_data[np.newaxis] + if pos_data.segm_data is not None: + pos_data.segm_data = pos_data.segm_data[np.newaxis] + + pos_data.img_data_shape = pos_data.img_data.shape + pos_data.dset = pos_data.img_data + if pos_data.segm_data is not None: + pos_data.segmSizeT = len(pos_data.segm_data) + + +def _write_metadata_csv( + metadata_csv_path: os.PathLike, + *, + basename: str, + size_t: int, + size_z: int, + size_y: int, + size_x: int, + channel_name: str, + time_increment: float, + physical_size_xy: tuple[float, float], + physical_size_z: float, + is_segm_3d: bool, + extra: dict[str, str | float | int], +) -> None: + rows = { + "basename": basename, + "SizeT": size_t, + "SizeZ": size_z, + "SizeY": size_y, + "SizeX": size_x, + "TimeIncrement": time_increment, + "PhysicalSizeX": physical_size_xy[0], + "PhysicalSizeY": physical_size_xy[1], + "PhysicalSizeZ": physical_size_z, + "segm_isSegm3D": str(is_segm_3d), + f"{channel_name}_name": channel_name, + } + rows.update(extra) + df = pd.DataFrame( + {"Description": list(rows.keys()), "values": [str(v) for v in rows.values()]} + ) + df.to_csv(metadata_csv_path, index=False) + + +def pos_data_from_arrays(source: ArrayDataSource, *, _load_data_cls=None): + """Build a ``loadData`` instance backed by in-memory arrays.""" + if _load_data_cls is None: + from cellacdc import load + + _load_data_cls = load.loadData + + image, size_t, size_z = normalize_volume(source.image, axes=source.axes) + size_y, size_x = image.shape[-2:] + + labels = source.labels + if labels is not None: + labels, labels_size_t, labels_size_z = normalize_volume( + labels, axes=source.axes + ) + if (labels_size_t, labels_size_z, *labels.shape[-2:]) != ( + size_t, + size_z, + size_y, + size_x, + ): + raise ValueError( + "Labels shape must match the image shape for the given axes." + ) + labels = labels.astype(np.uint32, copy=False) + + if source.workspace is None: + workspace = tempfile.mkdtemp(prefix="cellacdc_") + else: + workspace = os.fspath(source.workspace) + os.makedirs(workspace, exist_ok=True) + + exp_path = os.path.join(workspace, source.name) + pos_path = os.path.join(exp_path, "Position_001") + images_path = os.path.join(pos_path, "Images") + os.makedirs(images_path, exist_ok=True) + + basename = f"{source.name}_" + channel_name = source.channel_name + img_filename = f"{basename}{channel_name}.npz" + img_path = os.path.join(images_path, img_filename) + + pos = _load_data_cls(img_path, channel_name, log_func=print) + pos.basename = basename + pos.chNames = [channel_name] + pos.filename = f"{basename}{channel_name}" + pos.filename_ext = img_filename + pos.ext = ".npz" + pos.images_folder_files = [img_filename] + pos.img_data = image + pos.SizeT = size_t + pos.SizeZ = size_z + pos.SizeY = size_y + pos.SizeX = size_x + pos.loadSizeS = 1 + pos.loadSizeT = size_t + pos.loadSizeZ = size_z + pos.TimeIncrement = source.time_increment + pos.PhysicalSizeX = source.physical_size_xy[0] + pos.PhysicalSizeY = source.physical_size_xy[1] + pos.PhysicalSizeZ = source.physical_size_z + pos.isSegm3D = source.is_segm_3d + pos.is_in_memory = True + + pos.buildPaths() + metadata_csv_path = pos.metadata_csv_path + _write_metadata_csv( + metadata_csv_path, + basename=basename, + size_t=size_t, + size_z=size_z, + size_y=size_y, + size_x=size_x, + channel_name=channel_name, + time_increment=source.time_increment, + physical_size_xy=source.physical_size_xy, + physical_size_z=source.physical_size_z, + is_segm_3d=source.is_segm_3d, + extra=source.metadata, + ) + pos.metadataFound = True + pos.metadata_df = pd.read_csv(metadata_csv_path).set_index("Description") + pos.extractMetadata() + + if labels is not None: + pos.segmFound = True + pos.segm_data = labels + pos.labelBoolSegm = False + else: + pos.segmFound = False + pos.labelBoolSegm = False + pos.loadOtherFiles( + load_segm_data=False, + create_new_segm=True, + load_acdc_df=False, + load_metadata=False, + ) + pos.setBlankSegmData(pos.SizeT, pos.SizeZ, size_y, size_x) + + pos.acdc_df_found = False + pos.acdc_df = None + pos.segmInfo_df = None + pos.allData_li = [None] * pos.SizeT + pos.frame_i = 0 + + _finalize_pos_data_arrays(pos) + return pos + + +def pos_data_from_kwargs( + image: np.ndarray, + labels: np.ndarray | None = None, + *, + _load_data_cls=None, + **kwargs, +): + source = ArrayDataSource(image=image, labels=labels, **kwargs) + return pos_data_from_arrays(source, _load_data_cls=_load_data_cls) diff --git a/cellacdc/debugutils.py b/cellacdc/debugutils.py index b55d0eef1..b135cfeba 100644 --- a/cellacdc/debugutils.py +++ b/cellacdc/debugutils.py @@ -1,11 +1,12 @@ import inspect, os, datetime, sys, traceback -from . import cellacdc_path, myutils +from . import cellacdc_path, utils import gc import psutil -def showRefGraph(object_str:str, debug:bool=True): + +def showRefGraph(object_str: str, debug: bool = True): """Save a reference graph of the given object type. @@ -18,48 +19,51 @@ def showRefGraph(object_str:str, debug:bool=True): """ if not debug: return - + try: import objgraph except ImportError: - conda_prefix, pip_prefix = myutils.get_pip_conda_prefix() + conda_prefix, pip_prefix = utils.get_pip_conda_prefix() - print(f"objgraph is not installed. Install it with '{pip_prefix} objgraph' to use reference graph features, as well as https://www.graphviz.org/") + print( + f"objgraph is not installed. Install it with '{pip_prefix} objgraph' to use reference graph features, as well as https://www.graphviz.org/" + ) return caller_func = inspect.currentframe().f_back.f_code.co_name caller_file = inspect.currentframe().f_back.f_code.co_filename - caller_file = os.path.basename(caller_file).rstrip('.py') + caller_file = os.path.basename(caller_file).rstrip(".py") caller_line = inspect.currentframe().f_back.f_lineno - timestap = datetime.datetime.now().strftime('%H_%M_%S') + timestap = datetime.datetime.now().strftime("%H_%M_%S") - ref_graph_path = os.path.join( - os.path.dirname(cellacdc_path), - '.ref_graphs' - ) + ref_graph_path = os.path.join(os.path.dirname(cellacdc_path), ".ref_graphs") os.makedirs(ref_graph_path, exist_ok=True) - - filename = os.path.join(ref_graph_path, f'ref_graph_{timestap}_{object_str}_from_{caller_file}_{caller_func}_{caller_line}.svg') - timestap = datetime.datetime.now().strftime('%H:%M:%S') + filename = os.path.join( + ref_graph_path, + f"ref_graph_{timestap}_{object_str}_from_{caller_file}_{caller_func}_{caller_line}.svg", + ) + + timestap = datetime.datetime.now().strftime("%H:%M:%S") currentframe = inspect.currentframe() outerframes = inspect.getouterframes(currentframe) callingframe = outerframes[1].frame callingframe_info = inspect.getframeinfo(callingframe) filepath = callingframe_info.filename - fileinfo_str = ( - f'File "{filepath}", line {callingframe_info.lineno} - {timestap}:' - ) - + fileinfo_str = f'File "{filepath}", line {callingframe_info.lineno} - {timestap}:' gc.collect() instances = objgraph.by_type(object_str) if instances: objgraph.show_backrefs(instances, max_depth=500, filename=filename) - print(fileinfo_str, f'Graph saved as "{filename}" \n for {len(instances)} instances of "{object_str}"') + print( + fileinfo_str, + f'Graph saved as "{filename}" \n for {len(instances)} instances of "{object_str}"', + ) else: - print(fileinfo_str, f'No {object_str} instances found.') + print(fileinfo_str, f"No {object_str} instances found.") + def print_largest_attributes( obj, top_n=10, return_list=False, show_percent=True, process_mem=None @@ -67,7 +71,7 @@ def print_largest_attributes( attrs = [] total = 0 for attr in dir(obj): - if attr.startswith('__'): + if attr.startswith("__"): continue try: val = getattr(obj, attr) @@ -85,7 +89,9 @@ def print_largest_attributes( percent = (size / total * 100) if total > 0 else 0 proc_percent = (size / process_mem * 100) if process_mem else 0 if show_percent and process_mem: - print(f"{attr:30} {size:10,} bytes {percent:6.2f}% of obj {proc_percent:6.2f}% of proc {typ}") + print( + f"{attr:30} {size:10,} bytes {percent:6.2f}% of obj {proc_percent:6.2f}% of proc {typ}" + ) elif show_percent: print(f"{attr:30} {size:10,} bytes {percent:6.2f}% {typ}") else: @@ -93,6 +99,7 @@ def print_largest_attributes( if return_list: return attrs[:top_n] + def print_call_stack(debug=True, depth=None): if not debug: return @@ -100,11 +107,12 @@ def print_call_stack(debug=True, depth=None): stack = stack[:-1] if depth: depth = depth + 1 - stack = stack[-depth:] + stack = stack[-depth:] print("Call stack:") for line in stack: print(line.strip()) + def print_largest_attributes_for_all_classes(package_prefix="cellacdc", top_n=5): # Find all classes defined in your package classes = set() @@ -119,12 +127,13 @@ def print_largest_attributes_for_all_classes(package_prefix="cellacdc", top_n=5) continue print(f"\nClass: {cls.__module__}.{cls.__name__} ({len(instances)} instances)") for i, inst in enumerate(instances): - print(f" Instance {i+1}:") + print(f" Instance {i + 1}:") try: print_largest_attributes(inst, top_n=top_n) except Exception as e: print(f" Could not inspect instance: {e}") + def print_largest_classes(package_prefix="cellacdc", top_n=10, max_instances=100): """ Print classes (optionally filtered by module prefix) sorted by total memory usage. @@ -134,6 +143,7 @@ def print_largest_classes(package_prefix="cellacdc", top_n=10, max_instances=100 import gc import psutil import os + try: from pympler import asizeof except ImportError: @@ -185,7 +195,7 @@ def print_largest_classes(package_prefix="cellacdc", top_n=10, max_instances=100 # scale up if sampled if counted > 0 and n > counted: - total_size *= (n / counted) + total_size *= n / counted if total_size > 0: class_mem.append((cls, total_size, n)) @@ -193,7 +203,7 @@ def print_largest_classes(package_prefix="cellacdc", top_n=10, max_instances=100 # ✅ Sort by memory class_mem.sort(key=lambda x: x[1], reverse=True) - print(f"Total process memory: {process_mem/1024**2:.1f} MB") + print(f"Total process memory: {process_mem / 1024**2:.1f} MB") print(f"{'Class':60} {'Instances':>10} {'Total MB':>12} {'% of proc':>10}") for cls, total_size, n in class_mem[:top_n]: @@ -201,7 +211,7 @@ def print_largest_classes(package_prefix="cellacdc", top_n=10, max_instances=100 name = f"{cls.__module__}.{cls.__name__}" - print(f"{name:<60} {n:10} {total_size/1024**2:12.2f} {percent:9.2f}%") + print(f"{name:<60} {n:10} {total_size / 1024**2:12.2f} {percent:9.2f}%") # Example usage: diff --git a/cellacdc/dialogs/__init__.py b/cellacdc/dialogs/__init__.py new file mode 100644 index 000000000..3a88c346c --- /dev/null +++ b/cellacdc/dialogs/__init__.py @@ -0,0 +1,228 @@ +"""Cell-ACDC dialog windows.""" + +from ._base import ( + ArgWidget, + QBaseDialog, +) + +from .export import ( + ExportToImageParametersDialog, + ExportToVideoParametersDialog, + LogoDialog, + ObjectCountDialog, + ScaleBarPropertiesDialog, + ShortcutEditorDialog, + TimestampPropertiesDialog, + ViewTextDialog, + pdDataFrameWidget, +) + +from .general import ( + AddPointsLayerDialog, + EditPointsLayerAppearanceDialog, + QDialogCombobox, + QDialogPbar, + QDialogWorkerProgress, + QLineEditDialog, + QTreeDialog, + QtSelectItems, + SelectSegmFileDialog, + SetCustomLevelsLut, + _PointsLayerAppearanceGroupbox, + askStopFrameSegm, + customAnnotationDialog, + get_existing_directory, + imageViewer, + pgTestWindow, +) + +from .measurements import ( + CombineFeaturesCalculator, + CombineMetricsMultiDfsDialog, + CombineMetricsMultiDfsSummaryDialog, + ComputeMetricsErrorsDialog, + SelectFeaturesRange, + SelectFeaturesRangeDialog, + SelectFeaturesRangeGroupbox, + SetMeasurementsDialog, + combineMetricsEquationDialog, +) + +from .metadata import ( + AutoSaveIntervalDialog, + MultiListSelector, + MultiTimePointFilePattern, + OrderableListWidgetDialog, + OverlayLabelsAppearanceDialog, + QCropTrangeTool, + QCropZtool, + QDialogAppendTextFilename, + QDialogEntriesWidget, + QDialogMetadata, + QDialogMetadataXML, + QDialogZsliceAbsent, + SelectFoldersToAnalyse, + SetColumnNamesDialog, + TreeSelectorDialog, + TreesSelectorDialog, + filenameDialog, + selectPositionsMultiExp, +) + +from .models import ( + ChangeUserProfileFolderPathDialog, + DataFrameModel, + InstallPyTorchDialog, + QDialogModelParams, + QDialogSelectModel, + QInput, + SelectAcdcDfVersionToRestore, + SelectPromptableModelDialog, + addCustomModelMessages, + addCustomPromptModelMessages, + downloadModel, +) + +from .preprocess import ( + CombineChannelsSetupDialog, + CombineChannelsSetupDialogGUI, + CombineChannelsSetupDialogUtil, + CombineChannelsWidget, + DataPrepSubCropsPathsDialog, + FormulaEditWidget, + FucciPreprocessDialog, + FunctionParamsDialog, + FutureFramesAction_QDialog, + ImageJRoisToSegmManager, + InitFijiMacroDialog, + PostProcessSegmDialog, + PostProcessSegmParams, + PreProcessParamsWidget, + PreProcessRecipeDialog, + PreProcessRecipeDialogUtil, + QDialogAutomaticThresholding, + ResizeUtilProps, + TestSegmModelInitalDialog, + randomWalkerDialog, + startStopFramesDialog, + stopFrameDialog, + wandToleranceWidget, +) + +from .tracking import ( + ApplyTrackTableSelectColumnsDialog, + BayesianTrackerParamsWin, + CellACDCTrackerParamsWin, + DeltaTrackerParamsWin, + EditIDDialog, + FindIDDialog, + GenerateMotherBudTotalTableSelectColumnsDialog, + NumericEntryDialog, + TrackSubCellObjectsDialog, + ViewCcaTableWindow, + editCcaTableWidget, + manualSeparateGui, +) + +__all__ = [ + "ArgWidget", + "QBaseDialog", + "ExportToImageParametersDialog", + "ExportToVideoParametersDialog", + "LogoDialog", + "ObjectCountDialog", + "ScaleBarPropertiesDialog", + "ShortcutEditorDialog", + "TimestampPropertiesDialog", + "ViewTextDialog", + "pdDataFrameWidget", + "AddPointsLayerDialog", + "EditPointsLayerAppearanceDialog", + "QDialogCombobox", + "QDialogPbar", + "QDialogWorkerProgress", + "QLineEditDialog", + "QTreeDialog", + "QtSelectItems", + "SelectSegmFileDialog", + "SetCustomLevelsLut", + "_PointsLayerAppearanceGroupbox", + "askStopFrameSegm", + "customAnnotationDialog", + "get_existing_directory", + "imageViewer", + "pgTestWindow", + "CombineFeaturesCalculator", + "CombineMetricsMultiDfsDialog", + "CombineMetricsMultiDfsSummaryDialog", + "ComputeMetricsErrorsDialog", + "SelectFeaturesRange", + "SelectFeaturesRangeDialog", + "SelectFeaturesRangeGroupbox", + "SetMeasurementsDialog", + "combineMetricsEquationDialog", + "AutoSaveIntervalDialog", + "MultiListSelector", + "MultiTimePointFilePattern", + "OrderableListWidgetDialog", + "OverlayLabelsAppearanceDialog", + "QCropTrangeTool", + "QCropZtool", + "QDialogAppendTextFilename", + "QDialogEntriesWidget", + "QDialogMetadata", + "QDialogMetadataXML", + "QDialogZsliceAbsent", + "SelectFoldersToAnalyse", + "SetColumnNamesDialog", + "TreeSelectorDialog", + "TreesSelectorDialog", + "filenameDialog", + "selectPositionsMultiExp", + "ChangeUserProfileFolderPathDialog", + "DataFrameModel", + "InstallPyTorchDialog", + "QDialogModelParams", + "QDialogSelectModel", + "QInput", + "SelectAcdcDfVersionToRestore", + "SelectPromptableModelDialog", + "addCustomModelMessages", + "addCustomPromptModelMessages", + "downloadModel", + "CombineChannelsSetupDialog", + "CombineChannelsSetupDialogGUI", + "CombineChannelsSetupDialogUtil", + "CombineChannelsWidget", + "DataPrepSubCropsPathsDialog", + "FormulaEditWidget", + "FucciPreprocessDialog", + "FunctionParamsDialog", + "FutureFramesAction_QDialog", + "ImageJRoisToSegmManager", + "InitFijiMacroDialog", + "PostProcessSegmDialog", + "PostProcessSegmParams", + "PreProcessParamsWidget", + "PreProcessRecipeDialog", + "PreProcessRecipeDialogUtil", + "QDialogAutomaticThresholding", + "ResizeUtilProps", + "TestSegmModelInitalDialog", + "randomWalkerDialog", + "startStopFramesDialog", + "stopFrameDialog", + "wandToleranceWidget", + "ApplyTrackTableSelectColumnsDialog", + "BayesianTrackerParamsWin", + "CellACDCTrackerParamsWin", + "DeltaTrackerParamsWin", + "EditIDDialog", + "FindIDDialog", + "GenerateMotherBudTotalTableSelectColumnsDialog", + "NumericEntryDialog", + "TrackSubCellObjectsDialog", + "ViewCcaTableWindow", + "editCcaTableWidget", + "manualSeparateGui", +] diff --git a/cellacdc/dialogs/_base.py b/cellacdc/dialogs/_base.py new file mode 100644 index 000000000..b107439f8 --- /dev/null +++ b/cellacdc/dialogs/_base.py @@ -0,0 +1,181 @@ +"""Cell-ACDC dialog windows: _base.""" + +import os +import sys +import re +from typing import Literal, Callable, Dict, Iterable, List, Tuple +import datetime +import pathlib +from collections import defaultdict +import zipfile +from heapq import nlargest +import matplotlib +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +from matplotlib.patches import Rectangle, Circle, PathPatch, Path +import numpy as np +import scipy.interpolate + +try: + import tkinter as tk +except Exception as err: + pass + +import cv2 +import traceback +from itertools import combinations, permutations +from collections import namedtuple +from natsort import natsorted + +# from MyWidgets import Slider, Button, MyRadioButtons +from skimage.measure import label, regionprops +from functools import partial +import skimage.filters +import skimage.measure +import skimage.morphology +import skimage.exposure +import skimage.draw +import skimage.registration +import skimage.color +import skimage.segmentation +from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk +import matplotlib.pyplot as plt +import seaborn as sns +import pandas as pd +import math +import time +import sympy as sp +import json +import html + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from qtpy import QtCore +from qtpy.QtGui import ( + QIcon, + QFontMetrics, + QKeySequence, + QFont, + QRegularExpressionValidator, + QCursor, + QKeyEvent, + QPixmap, + QFont, + QPalette, + QMouseEvent, + QColor, +) +from qtpy.QtCore import ( + Qt, + QSize, + QEvent, + Signal, + QEventLoop, + QTimer, + QRegularExpression, +) +from qtpy.QtWidgets import ( + QFileDialog, + QApplication, + QMainWindow, + QMenu, + QLabel, + QToolBar, + QScrollBar, + QWidget, + QVBoxLayout, + QLineEdit, + QPushButton, + QHBoxLayout, + QDialog, + QFormLayout, + QListWidget, + QAbstractItemView, + QButtonGroup, + QCheckBox, + QSizePolicy, + QComboBox, + QSlider, + QGridLayout, + QSpinBox, + QToolButton, + QTableView, + QTextBrowser, + QDoubleSpinBox, + QScrollArea, + QFrame, + QProgressBar, + QGroupBox, + QRadioButton, + QDockWidget, + QMessageBox, + QStyle, + QPlainTextEdit, + QSpacerItem, + QTreeWidget, + QTreeWidgetItem, + QTextEdit, + QSplashScreen, + QAction, + QListWidgetItem, + QActionGroup, + QHeaderView, + QStyledItemDelegate, +) +import qtpy.compat + +from .. import exception_handler +from .. import load, prompts, core, measurements, html_utils +from .. import is_mac, is_win, is_linux, settings_folderpath, config +from .. import preproc_recipes_path, segm_recipes_path, combine_channels_recipes_path +from .. import is_conda_env +from .. import printl +from .. import colors +from .. import issues_url +from .. import utils +from .. import qutils +from .. import _palettes +from .. import base_cca_dict +from .. import widgets +from .. import user_profile_path, promptable_models_path, models_path +from .. import features +from .. import _core +from .. import _types +from .. import plot +from .. import urls +from ..acdc_regex import float_regex, is_alphanumeric_filename, to_alphanumeric +from .. import _base_widgets +from .. import io +from .. import cca_functions +from .. import path + +POSITIVE_FLOAT_REGEX = float_regex(allow_negative=False) +TREEWIDGET_STYLESHEET = _palettes.TreeWidgetStyleSheet() +LISTWIDGET_STYLESHEET = _palettes.ListWidgetStyleSheet() +BACKGROUND_RGBA = _palettes.get_disabled_colors()["Button"] + +font = QFont() +font.setPixelSize(12) +italicFont = QFont() +italicFont.setPixelSize(12) +italicFont.setItalic(True) + +class ArgWidget: + def __init__( + self, name, type, widget, defaultVal, valueSetter, valueGetter, changeSig=None + ): + self.name = name + self.type = type + self.widget = widget + self.defaultVal = defaultVal + self.valueSetter = valueSetter + self.valueGetter = valueGetter + if changeSig is not None: + self.changeSig = changeSig + + +class QBaseDialog(_base_widgets.QBaseDialog): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) diff --git a/cellacdc/dialogs/export.py b/cellacdc/dialogs/export.py new file mode 100644 index 000000000..a11857f59 --- /dev/null +++ b/cellacdc/dialogs/export.py @@ -0,0 +1,1512 @@ +"""Cell-ACDC dialog windows: export.""" + +import os +import sys +import re +from typing import Literal, Callable, Dict, Iterable, List, Tuple +import datetime +import pathlib +from collections import defaultdict +import zipfile +from heapq import nlargest +import matplotlib +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +from matplotlib.patches import Rectangle, Circle, PathPatch, Path +import numpy as np +import scipy.interpolate + +try: + import tkinter as tk +except Exception as err: + pass + +import cv2 +import traceback +from itertools import combinations, permutations +from collections import namedtuple +from natsort import natsorted + +# from MyWidgets import Slider, Button, MyRadioButtons +from skimage.measure import label, regionprops +from functools import partial +import skimage.filters +import skimage.measure +import skimage.morphology +import skimage.exposure +import skimage.draw +import skimage.registration +import skimage.color +import skimage.segmentation +from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk +import matplotlib.pyplot as plt +import seaborn as sns +import pandas as pd +import math +import time +import sympy as sp +import json +import html + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from qtpy import QtCore +from qtpy.QtGui import ( + QIcon, + QFontMetrics, + QKeySequence, + QFont, + QRegularExpressionValidator, + QCursor, + QKeyEvent, + QPixmap, + QFont, + QPalette, + QMouseEvent, + QColor, +) +from qtpy.QtCore import ( + Qt, + QSize, + QEvent, + Signal, + QEventLoop, + QTimer, + QRegularExpression, +) +from qtpy.QtWidgets import ( + QFileDialog, + QApplication, + QMainWindow, + QMenu, + QLabel, + QToolBar, + QScrollBar, + QWidget, + QVBoxLayout, + QLineEdit, + QPushButton, + QHBoxLayout, + QDialog, + QFormLayout, + QListWidget, + QAbstractItemView, + QButtonGroup, + QCheckBox, + QSizePolicy, + QComboBox, + QSlider, + QGridLayout, + QSpinBox, + QToolButton, + QTableView, + QTextBrowser, + QDoubleSpinBox, + QScrollArea, + QFrame, + QProgressBar, + QGroupBox, + QRadioButton, + QDockWidget, + QMessageBox, + QStyle, + QPlainTextEdit, + QSpacerItem, + QTreeWidget, + QTreeWidgetItem, + QTextEdit, + QSplashScreen, + QAction, + QListWidgetItem, + QActionGroup, + QHeaderView, + QStyledItemDelegate, +) +import qtpy.compat + +from .. import exception_handler +from .. import load, prompts, core, measurements, html_utils +from .. import is_mac, is_win, is_linux, settings_folderpath, config +from .. import preproc_recipes_path, segm_recipes_path, combine_channels_recipes_path +from .. import is_conda_env +from .. import printl +from .. import colors +from .. import issues_url +from .. import utils +from .. import qutils +from .. import _palettes +from .. import base_cca_dict +from .. import widgets +from .. import user_profile_path, promptable_models_path, models_path +from .. import features +from .. import _core +from .. import _types +from .. import plot +from .. import urls +from ..acdc_regex import float_regex, is_alphanumeric_filename, to_alphanumeric +from .. import _base_widgets +from .. import io +from .. import cca_functions +from .. import path + +POSITIVE_FLOAT_REGEX = float_regex(allow_negative=False) +TREEWIDGET_STYLESHEET = _palettes.TreeWidgetStyleSheet() +LISTWIDGET_STYLESHEET = _palettes.ListWidgetStyleSheet() +BACKGROUND_RGBA = _palettes.get_disabled_colors()["Button"] + +font = QFont() +font.setPixelSize(12) +italicFont = QFont() +italicFont.setPixelSize(12) +italicFont.setItalic(True) + +from ._base import ( + QBaseDialog, +) + +class ViewTextDialog(QBaseDialog): + def __init__(self, text, parent=None): + super().__init__(parent) + + mainLayout = QVBoxLayout() + + textViewWidget = QTextEdit() + textViewWidget.setReadOnly(True) + + textViewWidget.setText(text) + + buttonsLayout = QHBoxLayout() + okButton = widgets.okPushButton("Ok") + + okButton.clicked.connect(self.close) + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(okButton) + + mainLayout.addWidget(textViewWidget) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + self.setFont(font) + + +class pdDataFrameWidget(QMainWindow): + def __init__(self, df, parent=None): + super().__init__(parent) + self.parent = parent + self.setWindowTitle("Cell cycle annotations") + + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + + mainContainer = QWidget() + self.setCentralWidget(mainContainer) + + layout = QVBoxLayout() + self._layout = layout + + self.tableView = QTableView(self) + layout.addWidget(self.tableView) + model = DataFrameModel(df) + self.tableView.setModel(model) + for i in range(len(df.columns)): + self.tableView.resizeColumnToContents(i) + # layout.addWidget(QPushButton('Ok', self)) + mainContainer.setLayout(layout) + + def updateTable(self, df, IDs=None): + if df is None: + df = self.parent.getBaseCca_df() + + if IDs is not None: + df = df.loc[IDs] + + df = df.reset_index() + model = DataFrameModel(df) + self.tableView.setModel(model) + for i in range(len(df.columns)): + self.tableView.resizeColumnToContents(i) + + def setGeometryWindow(self, maxWidth=1024): + width = self.tableView.verticalHeader().width() + 4 + for j in range(self.tableView.model().columnCount()): + width += self.tableView.columnWidth(j) + 4 + height = self.tableView.horizontalHeader().height() + 4 + h = height + (self.tableView.rowHeight(0) + 4) * 10 + w = width if width < maxWidth else maxWidth + self.setGeometry(100, 100, w, h) + + # Center window + parent = self.parent + if parent is not None: + # Center the window on main window + mainWinGeometry = parent.geometry() + mainWinLeft = mainWinGeometry.left() + mainWinTop = mainWinGeometry.top() + mainWinWidth = mainWinGeometry.width() + mainWinHeight = mainWinGeometry.height() + mainWinCenterX = int(mainWinLeft + mainWinWidth / 2) + mainWinCenterY = int(mainWinTop + mainWinHeight / 2) + winGeometry = self.geometry() + winWidth = winGeometry.width() + winHeight = winGeometry.height() + winLeft = int(mainWinCenterX - winWidth / 2) + winRight = int(mainWinCenterY - winHeight / 2) + self.move(winLeft, winRight) + + def closeEvent(self, event): + self.parent.ccaTableWin = None + + +class ShortcutEditorDialog(QBaseDialog): + def __init__( + self, + widgetsWithShortcut: dict, + delObjectKey="", + delObjectButton: Literal["Middle click", "Left click"] = "Middle click", + zoomOutKeyValue: int = None, + parent=None, + ): + self.cancel = True + super().__init__(parent) + + self.setWindowTitle("Customize keyboard shortcuts") + + mainLayout = QVBoxLayout() + + self.customShortcuts = {} + self.shortcutLineEdits = {} + + scrollArea = QScrollArea(self) + scrollArea.setWidgetResizable(True) + scrollAreaWidget = QWidget() + entriesLayout = QGridLayout() + + row = 0 + button = widgets.PushButton(self, flat=True) + button.setIcon(QIcon(":del_obj_click.svg")) + self.delObjShortcutLineEdit = widgets.ShortcutLineEdit( + allowModifiers=True, notAllowedModifier=Qt.AltModifier + ) + if delObjectKey is not None: + self.delObjShortcutLineEdit.setText(delObjectKey) + self.delObjButtonCombobox = QComboBox() + self.delObjButtonCombobox.addItems(["Middle click", "Left click"]) + self.delObjButtonCombobox.setCurrentText(delObjectButton) + entriesLayout.addWidget(button, row, 0) + entriesLayout.addWidget(QLabel("Delete object:"), row, 1) + entriesLayout.addWidget(self.delObjShortcutLineEdit, row, 2) + entriesLayout.addWidget( + self.delObjButtonCombobox, row, 3, alignment=Qt.AlignLeft + ) + + row += 1 + name = "Zoom out" + button = widgets.PushButton(self, flat=True) + label = QLabel("Zoom out:") + self.zoomShortcutLineEdit = widgets.ShortcutLineEdit() + if zoomOutKeyValue is not None: + zoomOutKeySequence = widgets.KeySequenceFromText(zoomOutKeyValue) + self.zoomShortcutLineEdit.setText(zoomOutKeySequence.toString()) + self.zoomShortcutLineEdit.key = zoomOutKeyValue + self.zoomShortcutLineEdit.textChanged.connect(self.checkDuplicateShortcuts) + entriesLayout.addWidget(button, row, 0) + entriesLayout.addWidget(label, row, 1) + entriesLayout.addWidget(self.zoomShortcutLineEdit, row, 2) + self.shortcutLineEdits[name] = self.zoomShortcutLineEdit + + row += 1 + for row, (name, widget) in enumerate(widgetsWithShortcut.items(), start=row): + button = widgets.PushButton(self, flat=True) + try: + button.setIcon(widget.icon()) + except: + pass + label = QLabel(f"{name}:") + shortcutLineEdit = widgets.ShortcutLineEdit() + if hasattr(widget, "keyPressShortcut"): + shortcutLineEdit.key = widget.keyPressShortcut + shortcut = widgets.KeySequenceFromText(widget.keyPressShortcut) + isShortcutKeyPress = True + else: + shortcut = widget.shortcut() + isShortcutKeyPress = False + shortcutLineEdit.setText(shortcut.toString()) + shortcutLineEdit.textChanged.connect(self.checkDuplicateShortcuts) + shortcutLineEdit.isShortcutKeyPress = isShortcutKeyPress + entriesLayout.addWidget(button, row, 0) + entriesLayout.addWidget(label, row, 1) + entriesLayout.addWidget(shortcutLineEdit, row, 2) + self.shortcutLineEdits[name] = shortcutLineEdit + + entriesLayout.setColumnStretch(0, 0) + entriesLayout.setColumnStretch(1, 0) + entriesLayout.setColumnStretch(2, 1) + entriesLayout.setColumnStretch(3, 0) + + scrollAreaWidget.setLayout(entriesLayout) + scrollArea.setWidget(scrollAreaWidget) + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addWidget(scrollArea) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setFont(font) + self.setLayout(mainLayout) + + def checkDuplicateShortcuts(self, text): + for name, shortcutLineEdit in self.shortcutLineEdits.items(): + if shortcutLineEdit == self.sender(): + continue + if shortcutLineEdit.text() != text: + continue + shortcutLineEdit.setText("") + + def warnInvalidKeySequenceDelObjWithLeftClick(self): + txt = html_utils.paragraph( + 'The selected key sequence to delete objects with "Left click" ' + "is invalid.

" + 'Only "Middle click" can be used without pressing keys.

' + "Thank you for your patience!" + ) + msg = widgets.myMessageBox() + msg.warning(self, "Invalid key sequence to delete objects", txt) + + def ok_cb(self): + delObjButtonText = self.delObjButtonCombobox.currentText() + delObjKeySequence = self.delObjShortcutLineEdit.keySequence + if delObjButtonText == "Left click" and delObjKeySequence is None: + self.warnInvalidKeySequenceDelObjWithLeftClick() + return + + self.shortcutLineEdits.pop("Zoom out") + self.cancel = False + for name, shortcutLineEdit in self.shortcutLineEdits.items(): + text = shortcutLineEdit.text() + if shortcutLineEdit.isShortcutKeyPress: + self.customShortcuts[name] = (text, shortcutLineEdit.key) + else: + self.customShortcuts[name] = (text, shortcutLineEdit.keySequence) + + delObjQtButton = ( + Qt.MouseButton.LeftButton + if delObjButtonText == "Left click" + else Qt.MouseButton.MiddleButton + ) + self.delObjAction = delObjKeySequence, delObjQtButton + self.zoomOutKeyValue = self.zoomShortcutLineEdit.key + + self.close() + + def showEvent(self, event) -> None: + self.resize(int(self.width() * 1.2), self.height()) + self.move(self.x(), 100) + + +class ScaleBarPropertiesDialog(QBaseDialog): + sigValueChanged = Signal(object) + + def __init__( + self, maxLength, maxThickness, PhysicalSizeX, parent=None, **properties + ): + super().__init__(parent=parent) + + self.cancel = True + self.setWindowTitle("Scale bar properties") + + self.PhysicalSizeX = PhysicalSizeX + + mainLayout = QVBoxLayout() + + formLayout = widgets.FormLayout() + formLayout.setVerticalSpacing(10) + formLayout.setHorizontalSpacing(50) + + row = 0 + unitCombobox = QComboBox() + unitFormWidget = widgets.formWidget(unitCombobox, labelTextLeft="Physical unit") + unitCombobox.addItems(["nm", "μm", "mm", "cm"]) + if properties.get("unit") is None: + unitCombobox.setCurrentIndex(1) + else: + unitCombobox.setCurrentText(properties.get("unit")) + formLayout.addFormWidget( + unitFormWidget, row=row, leftLabelAlignment=Qt.AlignLeft + ) + self.unitCombobox = unitCombobox + + row += 1 + lengthDoubleSpinbox = widgets.DoubleSpinBox() + lengthDoubleSpinbox.setMaximum(maxLength) + lengthDoubleSpinbox.setMinimum(PhysicalSizeX) + lengthDoubleSpinbox.setDecimals(1) + if properties.get("length_unit") is not None: + lengthDoubleSpinbox.setValue(properties.get("length_unit")) + else: + deafultLength = np.ceil(PhysicalSizeX * 15) + lengthDoubleSpinbox.setValue(round(deafultLength)) + lengthFormWidget = widgets.formWidget( + lengthDoubleSpinbox, labelTextLeft="Length (μm)" + ) + self.lengthFormWidget = lengthFormWidget + self.lengthDoubleSpinbox = lengthDoubleSpinbox + formLayout.addFormWidget( + lengthFormWidget, row=row, leftLabelAlignment=Qt.AlignLeft + ) + + row += 1 + thicknessSpinbox = widgets.DoubleSpinBox() + thicknessSpinbox.setMaximum(maxThickness) + thicknessSpinbox.setMinimum(1) + if properties.get("thickness") is not None: + thicknessSpinbox.setValue(properties.get("thickness")) + else: + thicknessSpinbox.setValue(round(4, 1)) + thicknessSpinbox.setDecimals(1) + thicknessFormWidget = widgets.formWidget( + thicknessSpinbox, labelTextLeft="Thickness (pixel)" + ) + formLayout.addFormWidget( + thicknessFormWidget, row=row, leftLabelAlignment=Qt.AlignLeft + ) + self.thicknessSpinbox = thicknessSpinbox + + row += 1 + locCombobox = QComboBox() + locFormWidget = widgets.formWidget(locCombobox, labelTextLeft="Location") + locCombobox.addItems( + ["Bottom-right", "Bottom-left", "Top-left", "Top-right", "Custom"] + ) + loc = properties.get("loc") + if isinstance(loc, str): + locCombobox.setCurrentText(loc.capitalize()) + formLayout.addFormWidget( + locFormWidget, row=row, leftLabelAlignment=Qt.AlignLeft + ) + self.locCombobox = locCombobox + + row += 1 + self.colorButton = widgets.myColorButton(color=(255, 255, 255)) + if properties.get("color") is not None: + self.colorButton.setColor(properties.get("color")) + colorFormWidget = widgets.formWidget( + self.colorButton, + labelTextLeft="Color", + widgetAlignment=Qt.AlignCenter, + stretchWidget=False, + ) + formLayout.addFormWidget( + colorFormWidget, row=row, leftLabelAlignment=Qt.AlignLeft + ) + + row += 1 + displayTextToggle = widgets.Toggle() + if properties.get("is_text_visible") is not None: + displayTextToggle.setChecked(properties.get("is_text_visible")) + else: + displayTextToggle.setChecked(True) + displayTextFormWidget = widgets.formWidget( + displayTextToggle, + labelTextLeft="Display text", + widgetAlignment=Qt.AlignCenter, + stretchWidget=False, + ) + formLayout.addFormWidget( + displayTextFormWidget, row=row, leftLabelAlignment=Qt.AlignLeft + ) + self.displayTextToggle = displayTextToggle + + row += 1 + fontSizeSpinbox = widgets.SpinBox() + if properties.get("font_size") is not None: + fontSizeSpinbox.setValue(int(properties.get("font_size"))) + else: + fontSizeSpinbox.setValue(12) + fontSizeFormWidget = widgets.formWidget( + fontSizeSpinbox, labelTextLeft="Font size (px)" + ) + self.fontSizeSpinbox = fontSizeSpinbox + formLayout.addFormWidget( + fontSizeFormWidget, row=row, leftLabelAlignment=Qt.AlignLeft + ) + + row += 1 + decimalsSpinbox = widgets.SpinBox() + decimalsSpinbox.setMaximum(6) + decimalsSpinbox.setMinimum(0) + if properties.get("num_decimals") is not None: + decimalsSpinbox.setValue(properties.get("num_decimals")) + else: + decimalsSpinbox.setValue(0) + decimalsFormWidget = widgets.formWidget( + decimalsSpinbox, labelTextLeft="Number of decimals" + ) + formLayout.addFormWidget( + decimalsFormWidget, row=row, leftLabelAlignment=Qt.AlignLeft + ) + self.decimalsSpinbox = decimalsSpinbox + + row += 1 + moveWithZoomToggle = widgets.Toggle() + moveWithZoomWidget = widgets.formWidget( + moveWithZoomToggle, + labelTextLeft="Move scale bar with zoom", + widgetAlignment=Qt.AlignCenter, + stretchWidget=False, + ) + formLayout.addFormWidget( + moveWithZoomWidget, row=row, leftLabelAlignment=Qt.AlignLeft + ) + self.moveWithZoomToggle = moveWithZoomToggle + + mainLayout.addLayout(formLayout) + + buttonsLayout = widgets.CancelOkButtonsLayout() + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + mainLayout.addStretch() + + self.setLayout(mainLayout) + self.setFont(font) + + self.unitCombobox.currentTextChanged.connect(self.updateLengthUnit) + self.colorButton.clicked.disconnect() + self.colorButton.clicked.connect(self.selectColor) + + self.colorButton.sigColorChanging.connect(self.onValueChanged) + self.lengthDoubleSpinbox.valueChanged.connect(self.onValueChanged) + self.thicknessSpinbox.valueChanged.connect(self.onValueChanged) + self.locCombobox.currentTextChanged.connect(self.onValueChanged) + self.displayTextToggle.toggled.connect(self.onValueChanged) + self.fontSizeSpinbox.valueChanged.connect(self.onValueChanged) + self.decimalsSpinbox.valueChanged.connect(self.onValueChanged) + self.moveWithZoomToggle.toggled.connect(self.onValueChanged) + + def onValueChanged(self, *args, **kwargs): + self.sigValueChanged.emit(self.kwargs()) + + def selectColor(self): + color = self.colorButton.color() + self.colorButton.origColor = color + self.colorButton.colorDialog.setCurrentColor(color) + self.colorButton.colorDialog.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + self.colorButton.colorDialog.setParent(self) + self.colorButton.colorDialog.open() + w = self.width() + left = self.pos().x() + colorDialogTop = self.colorButton.colorDialog.pos().y() + self.colorButton.colorDialog.move(w + left + 10, colorDialogTop) + + def updateLengthUnit(self, unit): + newText = re.sub(r"\(.*\)", f"({unit})", self.lengthFormWidget.labelLeft.text()) + self.lengthFormWidget.labelLeft.setText(newText) + self.onValueChanged(self) + + def kwargs(self): + unit = self.unitCombobox.currentText() + length_unit = self.lengthDoubleSpinbox.value() + length_um = _core.convert_length(length_unit, unit, "μm") + length_pixel = length_um / self.PhysicalSizeX + kwargs = { + "thickness": self.thicknessSpinbox.value(), + "length_pixel": length_pixel, + "length_unit": length_unit, + "is_text_visible": self.displayTextToggle.isChecked(), + "color": self.colorButton.color(), + "loc": self.locCombobox.currentText().lower(), + "font_size": self.fontSizeSpinbox.value(), + "unit": unit, + "num_decimals": self.decimalsSpinbox.value(), + "move_with_zoom": self.moveWithZoomToggle.isChecked(), + } + return kwargs + + def ok_cb(self): + self.cancel = False + self.close() + + +class ExportToVideoParametersDialog(QBaseDialog): + sigOk = Signal(dict) + sigAddScaleBar = Signal(bool) + sigAddTimestamp = Signal(bool) + sigRescaleIntensLut = Signal(str, str) + sigChangeStartTime = Signal(str) + + def __init__( + self, + channels, + parent=None, + startFolderpath="", + startFilename="", + startFrameNum=1, + SizeT=1, + SizeZ=1, + isTimelapseVideo=True, + isScaleBarPresent=False, + isTimestampPresent=False, + rescaleIntensChannelHowMapper=None, + startTime=None, + ): + self.cancel = True + + if rescaleIntensChannelHowMapper is None: + rescaleIntensChannelHowMapper = {} + + super().__init__(parent=parent) + + self.setWindowTitle("Preferences for output video") + + mainLayout = QVBoxLayout() + + gridLayout = QGridLayout() + + navVar = "frame number" if isTimelapseVideo else "z-slice" + maxNavVar = SizeT if isTimelapseVideo else SizeZ + + self.isTimelapseVideo = isTimelapseVideo + + row = 0 + gridLayout.addWidget(QLabel(f"Start {navVar}:"), row, 0) + self.startNavVarNumberEntry = widgets.SpinBox() + self.startNavVarNumberEntry.setMinimum(1) + self.startNavVarNumberEntry.setMaximum(maxNavVar - 1) + self.startNavVarNumberEntry.setValue(startFrameNum) + gridLayout.addWidget(self.startNavVarNumberEntry, row, 1) + + row += 1 + gridLayout.addWidget(QLabel(f"Stop {navVar}:"), row, 0) + self.stopNavVarNumberEntry = widgets.SpinBox() + self.stopNavVarNumberEntry.setMinimum(2) + self.stopNavVarNumberEntry.setMaximum(maxNavVar) + self.stopNavVarNumberEntry.setValue(maxNavVar) + gridLayout.addWidget(self.stopNavVarNumberEntry, row, 1) + + row += 1 + gridLayout.addWidget(QLabel("File format:"), row, 0) + self.fileFormatCombobox = QComboBox() + self.fileFormatCombobox.addItems(["MP4", "AVI"]) + gridLayout.addWidget(self.fileFormatCombobox, row, 1) + + row += 1 + gridLayout.addWidget(QLabel("Frame rate (FPS):"), row, 0) + self.fpsWidget = widgets.FloatLineEdit(allowNegative=False) + self.fpsWidget.setValue(10.0) + gridLayout.addWidget(self.fpsWidget, row, 1) + + row += 1 + self.dpiWidget = widgets.IntLineEdit(allowNegative=False) + self.dpiWidget.setValue(300) + self.dpiWidget.label = QLabel("DPI") + gridLayout.addWidget(self.dpiWidget.label, row, 0) + gridLayout.addWidget(self.dpiWidget, row, 1) + + row += 1 + gridLayout.addWidget(QLabel("Folder path:"), row, 0) + self.folderPathLineEdit = widgets.ElidingLineEdit(minWidth=240) + self.folderPathLineEdit.setText(startFolderpath) + gridLayout.addWidget(self.folderPathLineEdit, row, 1) + self.browseButton = widgets.browseFileButton( + start_dir=startFolderpath, openFolder=True + ) + gridLayout.addWidget(self.browseButton, row, 2) + + row += 1 + gridLayout.addWidget(QLabel("Filename:"), row, 0) + self.filenameLineEdit = widgets.alphaNumericLineEdit() + self.filenameLineEdit.setAlignment(Qt.AlignCenter) + self.filenameLineEdit.setText(startFilename) + gridLayout.addWidget(self.filenameLineEdit, row, 1) + self.fileFormatLabel = QLabel(".mp4") + gridLayout.addWidget(self.fileFormatLabel, row, 2) + + row += 1 + gridLayout.addWidget(QLabel("Add Scale Bar:"), row, 0) + self.addScaleBarToggle = widgets.Toggle() + gridLayout.addWidget(self.addScaleBarToggle, row, 1, alignment=Qt.AlignCenter) + self.addScaleBarToggle.setChecked(isScaleBarPresent) + + if isTimelapseVideo: + row += 1 + gridLayout.addWidget(QLabel("Add timestamp:"), row, 0) + self.addTimestampToggle = widgets.Toggle() + gridLayout.addWidget( + self.addTimestampToggle, row, 1, alignment=Qt.AlignCenter + ) + self.addTimestampToggle.setChecked(isTimestampPresent) + + for channel in channels: + row += 1 + labelText = f"Rescale intensities (LUT) {channel}:" + gridLayout.addWidget(QLabel(labelText), row, 0) + rescaleItems = ["Rescale each 2D image"] + if SizeZ > 1: + rescaleItems.append("Rescale across z-stack") + if isTimelapseVideo: + rescaleItems.append("Rescale across time frames") + rescaleItems.append("Choose custom levels...") + rescaleItems.append("Do no rescale, display raw image") + rescaleIntensCombobox = QComboBox() + rescaleIntensCombobox.addItems(rescaleItems) + rescaleIntensHow = rescaleIntensChannelHowMapper.get(channel) + if rescaleIntensHow is not None: + rescaleIntensCombobox.setCurrentText(rescaleIntensHow) + gridLayout.addWidget(rescaleIntensCombobox, row, 1) + rescaleIntensCombobox.textActivated.connect( + partial(self.emitRescaleIntens, channel=channel) + ) + + row += 1 + gridLayout.addWidget(QLabel("Save a PNG for each frame:"), row, 0) + self.saveFramesToggle = widgets.Toggle() + gridLayout.addWidget(self.saveFramesToggle, row, 1, alignment=Qt.AlignCenter) + + gridLayout.setColumnStretch(0, 0) + gridLayout.setColumnStretch(1, 1) + gridLayout.setColumnStretch(2, 0) + + self.fileFormatCombobox.currentTextChanged.connect(self.updateFileFormat) + self.browseButton.sigPathSelected.connect(self.updateFolderPath) + self.addScaleBarToggle.toggled.connect(self.addScaleBarToggled) + if isTimelapseVideo: + self.addTimestampToggle.toggled.connect(self.addTimestampToggled) + + buttonsLayout = widgets.CancelOkButtonsLayout() + buttonsLayout.okButton.setText("Export") + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addLayout(gridLayout) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + def emitRescaleIntens(self, how, channel=""): + self.sigRescaleIntensLut.emit(how, channel) + + def addScaleBarToggled(self, checked): + self.sigAddScaleBar.emit(checked) + + def addTimestampToggled(self, checked): + self.sigAddTimestamp.emit(checked) + + def updateFolderPath(self, folderPath): + self.folderPathLineEdit.setText(folderPath) + self.browseButton.setStartPath(folderPath) + + def updateFileFormat(self, fileFormat): + self.fileFormatLabel.setText(f".{fileFormat.lower()}") + + def validateFolderPath(self): + folderPath = self.folderPathLineEdit.text() + if os.path.exists(folderPath) and os.path.isdir(folderPath): + return True + + text = html_utils.paragraph( + "The selected folder path is not a valid folder or does not exist" + ) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "Not a valid folder", text) + return False + + def validateFilename(self): + filename = self.filenameLineEdit.text() + if filename: + return True + + text = html_utils.paragraph("The filename cannot be empty!") + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "Not a valid folder", text) + return False + + def validate(self): + proceed = self.validateFolderPath() + if not proceed: + return False + + proceed = self.validateFilename() + if not proceed: + return False + + return True + + def preferences(self, makedirs=True): + filename = f"{self.filenameLineEdit.text()}{self.fileFormatLabel.text()}" + avi_filename = f"{self.filenameLineEdit.text()}.avi" + avi_filepath = os.path.join(self.folderPathLineEdit.text(), avi_filename) + png_foldername = f"{self.filenameLineEdit.text()}_frames_PNG" + pngs_folderpath = os.path.join(self.folderPathLineEdit.text(), png_foldername) + if makedirs: + os.makedirs(pngs_folderpath, exist_ok=True) + + preferences = { + "start_nav_var_num": self.startNavVarNumberEntry.value(), + "stop_nav_var_num": self.stopNavVarNumberEntry.value(), + "filepath": os.path.join(self.folderPathLineEdit.text(), filename), + "filename": self.filenameLineEdit.text(), + "avi_filepath": avi_filepath, + "pngs_folderpath": pngs_folderpath, + "num_digits": len(str(self.stopNavVarNumberEntry.value())), + "fps": self.fpsWidget.value(), + "save_pngs": self.saveFramesToggle.isChecked(), + "is_timelapse": self.isTimelapseVideo, + "dpi": self.dpiWidget.value(), + } + return preferences + + def ok_cb(self): + proceed = self.validate() + if not proceed: + return + self.cancel = False + self.sigOk.emit(self.preferences()) + self.selected_preferences = self.preferences() + self.close() + + +class TimestampPropertiesDialog(QBaseDialog): + sigValueChanged = Signal(object) + + def __init__(self, parent=None, **properties): + super().__init__(parent=parent) + + self.cancel = True + self.setWindowTitle("Timestamp preferences") + + mainLayout = QVBoxLayout() + + formLayout = widgets.FormLayout() + formLayout.setVerticalSpacing(10) + formLayout.setHorizontalSpacing(50) + + row = 0 + self.startTimeWidget = widgets.TimeWidget() + if properties.get("start_timedelta") is not None: + self.startTimeWidget.setValuesFromTimedelta( + properties.get("start_timedelta") + ) + startTimeFormWidget = widgets.formWidget( + self.startTimeWidget, + labelTextLeft="Start time", + ) + formLayout.addFormWidget( + startTimeFormWidget, row=row, leftLabelAlignment=Qt.AlignLeft + ) + + row += 1 + self.colorButton = widgets.myColorButton(color=(255, 255, 255)) + if properties.get("color") is not None: + self.colorButton.setColor(properties.get("color")) + colorFormWidget = widgets.formWidget( + self.colorButton, + labelTextLeft="Color", + widgetAlignment=Qt.AlignCenter, + stretchWidget=False, + ) + formLayout.addFormWidget( + colorFormWidget, row=row, leftLabelAlignment=Qt.AlignLeft + ) + + row += 1 + fontSizeWidget = widgets.FontSizeWidget() + if properties.get("font_size") is not None: + fontSizeWidget.setValue(properties.get("font_size")) + else: + fontSizeWidget.setValue(12) + fontSizeFormWidget = widgets.formWidget( + fontSizeWidget, labelTextLeft="Font size (px)" + ) + self.fontSizeWidget = fontSizeWidget + formLayout.addFormWidget( + fontSizeFormWidget, row=row, leftLabelAlignment=Qt.AlignLeft + ) + + row += 1 + locCombobox = QComboBox() + locFormWidget = widgets.formWidget(locCombobox, labelTextLeft="Location") + locCombobox.addItems( + ["Top-left", "Top-right", "Bottom-left", "Bottom-right", "Custom"] + ) + loc = properties.get("loc") + if isinstance(loc, str): + locCombobox.setCurrentText(loc.capitalize()) + formLayout.addFormWidget( + locFormWidget, row=row, leftLabelAlignment=Qt.AlignLeft + ) + self.locCombobox = locCombobox + + row += 1 + moveWithZoomToggle = widgets.Toggle() + moveWithZoomWidget = widgets.formWidget( + moveWithZoomToggle, + labelTextLeft="Move timestamp with zoom", + widgetAlignment=Qt.AlignCenter, + stretchWidget=False, + ) + formLayout.addFormWidget( + moveWithZoomWidget, row=row, leftLabelAlignment=Qt.AlignLeft + ) + self.moveWithZoomToggle = moveWithZoomToggle + + mainLayout.addLayout(formLayout) + + buttonsLayout = widgets.CancelOkButtonsLayout() + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + mainLayout.addStretch() + + self.setLayout(mainLayout) + self.setFont(font) + + self.colorButton.clicked.disconnect() + self.colorButton.clicked.connect(self.selectColor) + + self.startTimeWidget.sigValueChanged.connect(self.onValueChanged) + + self.locCombobox.currentTextChanged.connect(self.onValueChanged) + self.fontSizeWidget.sigTextChanged.connect(self.onValueChanged) + self.moveWithZoomToggle.toggled.connect(self.onValueChanged) + + def onValueChanged(self, *args, **kwargs): + self.sigValueChanged.emit(self.kwargs()) + + def selectColor(self): + color = self.colorButton.color() + self.colorButton.origColor = color + self.colorButton.colorDialog.setCurrentColor(color) + self.colorButton.colorDialog.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + self.colorButton.colorDialog.setParent(self) + self.colorButton.colorDialog.open() + w = self.width() + left = self.pos().x() + colorDialogTop = self.colorButton.colorDialog.pos().y() + self.colorButton.colorDialog.move(w + left + 10, colorDialogTop) + + def kwargs(self): + kwargs = { + "color": self.colorButton.color(), + "start_timedelta": self.startTimeWidget.timedelta(), + "loc": self.locCombobox.currentText().lower(), + "font_size": self.fontSizeWidget.text(), + "move_with_zoom": self.moveWithZoomToggle.isChecked(), + } + return kwargs + + def ok_cb(self): + self.cancel = False + self.close() + + +class ExportToImageParametersDialog(QBaseDialog): + sigOk = Signal(dict) + sigAddScaleBar = Signal(bool) + sigRangeChanged = Signal(object) + + def __init__( + self, + parent=None, + startFolderpath="", + startFilename="", + startViewRange=None, + isScaleBarPresent=False, + ): + self.cancel = True + + super().__init__(parent=parent) + + self.setWindowTitle("Preferences for output image") + + mainLayout = QVBoxLayout() + + gridLayout = QGridLayout() + + row = 0 + gridLayout.addWidget(QLabel("View range X axis:"), row, 0) + self.xRangeSelector = widgets.RangeSelector(integers=True) + if startViewRange is not None: + xRange, yRange = startViewRange + self.xRangeSelector.setRange(*xRange) + gridLayout.addWidget(self.xRangeSelector, row, 1) + + row += 1 + gridLayout.addWidget(QLabel("View range Y axis:"), row, 0) + self.yRangeSelector = widgets.RangeSelector(integers=True) + if startViewRange is not None: + xRange, yRange = startViewRange + self.yRangeSelector.setRange(*yRange) + gridLayout.addWidget(self.yRangeSelector, row, 1) + + row += 1 + gridLayout.addWidget(QLabel("Width and Height:"), row, 0) + self.widthHeightSelector = widgets.RangeSelector(integers=True, ordered=False) + if startViewRange is not None: + xRange, yRange = startViewRange + width = int(xRange[1] - xRange[0]) + height = int(yRange[1] - yRange[0]) + self.widthHeightSelector.setRange(width, height) + gridLayout.addWidget(self.widthHeightSelector, row, 1) + self.lockSizeButton = widgets.LockPushButton() + self.lockSizeButton.setCheckable(True) + self.lockSizeButton.setToolTip("Lock width and height") + gridLayout.addWidget(self.lockSizeButton, row, 2) + + row += 1 + gridLayout.addWidget(QLabel("File format:"), row, 0) + self.fileFormatCombobox = QComboBox() + self.fileFormatCombobox.addItems(["SVG", "PNG", "TIFF", "JPEG"]) + gridLayout.addWidget(self.fileFormatCombobox, row, 1) + + row += 1 + self.dpiWidget = widgets.IntLineEdit(allowNegative=False) + self.dpiWidget.setValue(300) + self.dpiWidget.label = QLabel("DPI") + gridLayout.addWidget(self.dpiWidget.label, row, 0) + gridLayout.addWidget(self.dpiWidget, row, 1) + self.dpiWidget.hide() + self.dpiWidget.label.hide() + + row += 1 + gridLayout.addWidget(QLabel("Folder path:"), row, 0) + self.folderPathLineEdit = widgets.ElidingLineEdit(minWidth=240) + self.folderPathLineEdit.setText(startFolderpath) + gridLayout.addWidget(self.folderPathLineEdit, row, 1) + self.browseButton = widgets.browseFileButton( + start_dir=startFolderpath, openFolder=True + ) + gridLayout.addWidget(self.browseButton, row, 2) + + row += 1 + gridLayout.addWidget(QLabel("Filename:"), row, 0) + self.filenameLineEdit = widgets.alphaNumericLineEdit() + self.filenameLineEdit.setAlignment(Qt.AlignCenter) + self.filenameLineEdit.setText(startFilename) + gridLayout.addWidget(self.filenameLineEdit, row, 1) + self.fileFormatLabel = QLabel( + f".{self.fileFormatCombobox.currentText().lower()}" + ) + gridLayout.addWidget(self.fileFormatLabel, row, 2) + + row += 1 + gridLayout.addWidget(QLabel("Add Scale Bar:"), row, 0) + self.addScaleBarToggle = widgets.Toggle() + gridLayout.addWidget(self.addScaleBarToggle, row, 1, alignment=Qt.AlignCenter) + self.addScaleBarToggle.setChecked(isScaleBarPresent) + + self.fileFormatCombobox.currentTextChanged.connect(self.updateFileFormat) + self.browseButton.sigPathSelected.connect(self.updateFolderPath) + self.addScaleBarToggle.toggled.connect(self.addScaleBarToggled) + self.xRangeSelector.sigLowValueChanged.connect(self.x0Changed) + self.xRangeSelector.sigHighValueChanged.connect(self.x1Changed) + self.yRangeSelector.sigLowValueChanged.connect(self.y0Changed) + self.yRangeSelector.sigHighValueChanged.connect(self.y1Changed) + self.widthHeightSelector.sigLowValueChanged.connect(self.widthChanged) + self.widthHeightSelector.sigHighValueChanged.connect(self.heightChanged) + self.widthHeightSelector.sigRangeManuallyChanged.connect( + self.widthHeightManuallyChanged + ) + + buttonsLayout = widgets.CancelOkButtonsLayout() + buttonsLayout.okButton.setText("Export") + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + gridLayout.setColumnStretch(2, 0) + + mainLayout.addLayout(gridLayout) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + def widthHeightManuallyChanged(self, *args): + self.lockSizeButton.setChecked(True) + + def x0Changed(self, *args): + if self.lockSizeButton.isChecked(): + x0, _ = self.xRangeSelector.range() + yRange = self.yRangeSelector.range() + width, height = self.widthHeightSelector.range() + x1 = x0 + width + xRange = (x0, x1) + else: + xRange = self.xRangeSelector.range() + yRange = self.yRangeSelector.range() + _, height = self.widthHeightSelector.range() + width = int(xRange[1] - xRange[0]) + + self.xRangeSelector.setRangeNoEmit(*xRange) + self.yRangeSelector.setRangeNoEmit(*yRange) + self.widthHeightSelector.setRangeNoEmit(width, height) + self.rangeChanged() + + def x1Changed(self, *args): + if self.lockSizeButton.isChecked(): + _, x1 = self.xRangeSelector.range() + yRange = self.yRangeSelector.range() + width, height = self.widthHeightSelector.range() + x0 = x1 - width + xRange = (x0, x1) + else: + xRange = self.xRangeSelector.range() + yRange = self.yRangeSelector.range() + _, height = self.widthHeightSelector.range() + width = int(xRange[1] - xRange[0]) + + self.xRangeSelector.setRangeNoEmit(*xRange) + self.yRangeSelector.setRangeNoEmit(*yRange) + self.widthHeightSelector.setRangeNoEmit(width, height) + + self.rangeChanged() + + def y0Changed(self, *args): + if self.lockSizeButton.isChecked(): + xRange = self.xRangeSelector.range() + y0, _ = self.yRangeSelector.range() + width, height = self.widthHeightSelector.range() + y1 = y0 + height + yRange = (y0, y1) + else: + xRange = self.xRangeSelector.range() + yRange = self.yRangeSelector.range() + width, _ = self.widthHeightSelector.range() + height = int(yRange[1] - yRange[0]) + + self.xRangeSelector.setRangeNoEmit(*xRange) + self.yRangeSelector.setRangeNoEmit(*yRange) + self.widthHeightSelector.setRangeNoEmit(width, height) + + self.rangeChanged() + + def y1Changed(self, *args): + if self.lockSizeButton.isChecked(): + xRange = self.xRangeSelector.range() + _, y1 = self.yRangeSelector.range() + width, height = self.widthHeightSelector.range() + y0 = y1 - height + yRange = (y0, y1) + else: + xRange = self.xRangeSelector.range() + yRange = self.yRangeSelector.range() + width, _ = self.widthHeightSelector.range() + height = int(yRange[1] - yRange[0]) + + self.xRangeSelector.setRangeNoEmit(*xRange) + self.yRangeSelector.setRangeNoEmit(*yRange) + self.widthHeightSelector.setRangeNoEmit(width, height) + + self.rangeChanged() + + def widthChanged(self, *args): + self.widthHeightChanged() + self.rangeChanged() + + def heightChanged(self, *args): + self.widthHeightChanged() + self.rangeChanged() + + def updateViewRangeExportToImageDialog(self, viewBox, viewRange, changed): + xRange, yRange = viewRange + self.xRangeSelector.setRangeNoEmit(*xRange) + self.yRangeSelector.setRangeNoEmit(*yRange) + + def widthHeightChanged(self, *args): + x0, _ = self.xRangeSelector.range() + y0, _ = self.yRangeSelector.range() + width, height = self.widthHeightSelector.range() + x1 = x0 + width + y1 = y0 + height + self.xRangeSelector.setRangeNoEmit(x0, x1) + self.yRangeSelector.setRangeNoEmit(y0, y1) + self.rangeChanged() + + def rangeChanged(self, *args): + xRange = self.xRangeSelector.range() + yRange = self.yRangeSelector.range() + self.sigRangeChanged.emit((xRange, yRange)) + + def addScaleBarToggled(self, checked): + self.sigAddScaleBar.emit(checked) + + def updateFolderPath(self, folderPath): + self.folderPathLineEdit.setText(folderPath) + self.browseButton.setStartPath(folderPath) + + def updateFileFormat(self, fileFormat): + if fileFormat == "SVG": + self.dpiWidget.hide() + self.dpiWidget.label.hide() + else: + self.dpiWidget.show() + self.dpiWidget.label.show() + + self.fileFormatLabel.setText(f".{fileFormat.lower()}") + + def validateFolderPath(self): + folderPath = self.folderPathLineEdit.text() + if os.path.exists(folderPath) and os.path.isdir(folderPath): + return True + + text = html_utils.paragraph( + "The selected folder path is not a valid folder or does not exist" + ) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "Not a valid folder", text) + return False + + def validateFilename(self): + filename = self.filenameLineEdit.text() + if filename: + return True + + text = html_utils.paragraph("The filename cannot be empty!") + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "Not a valid folder", text) + return False + + def validate(self): + proceed = self.validateFolderPath() + if not proceed: + return False + + proceed = self.validateFilename() + if not proceed: + return False + + return True + + def setViewRange(self, xRange, yRange, emitSignal=True): + if self.lockSizeButton.isChecked(): + x0, _ = xRange + y0, _ = yRange + width, height = self.widthHeightSelector.range() + x1 = x0 + width + y1 = y0 + height + xRange = (x0, x1) + yRange = (y0, y1) + else: + width = int(xRange[1] - xRange[0]) + height = int(yRange[1] - yRange[0]) + + self.xRangeSelector.setRangeNoEmit(*xRange) + self.yRangeSelector.setRangeNoEmit(*yRange) + self.widthHeightSelector.setRangeNoEmit(width, height) + if not emitSignal: + return + + self.rangeChanged() + + def viewRange(self): + xRange = self.xRangeSelector.range() + yRange = self.yRangeSelector.range() + return (xRange, yRange) + + def preferences(self): + filename = f"{self.filenameLineEdit.text()}{self.fileFormatLabel.text()}" + preferences = { + "view_range_x": self.xRangeSelector.range(), + "view_range_y": self.yRangeSelector.range(), + "filepath": os.path.join(self.folderPathLineEdit.text(), filename), + "filename": self.filenameLineEdit.text(), + "dpi": self.dpiWidget.value(), + } + return preferences + + def ok_cb(self): + proceed = self.validate() + if not proceed: + return + self.cancel = False + self.sigOk.emit(self.preferences()) + self.selected_preferences = self.preferences() + self.close() + + +class LogoDialog(QDialog): + def __init__(self, logo_path, icon_path, parent=None): + super().__init__(parent) + + layout = QVBoxLayout() + + self.setWindowFlags(Qt.FramelessWindowHint) + # self.setWindowFlags(Qt.WindowStaysOnTopHint | Qt.FramelessWindowHint) + # self.setAttribute(Qt.WA_TranslucentBackground) + # self.setWindowIcon(QIcon(icon_path)) + + labelLogo = QLabel() + pixmapLogo = QPixmap(logo_path) + labelLogo.setPixmap(pixmapLogo) + + layout.addWidget(labelLogo) + + self.setLayout(layout) + + +class ObjectCountDialog(QBaseDialog): + sigShowEvent = Signal() + sigUpdateCounts = Signal() + + def __init__( + self, + categoryCountMapper: dict, + parent=None, + data: list["load.loadData"] | None = None, + ): + super().__init__(parent=parent) + self.setWindowTitle("Object count") + + self.cancel = False + mainLayout = QVBoxLayout() + + cancelOkLayout = widgets.CancelOkButtonsLayout() + cancelOkLayout.okButton.clicked.connect(self.ok_cb) + cancelOkLayout.cancelButton.clicked.connect(self.close) + + self.data = data + if data is not None: + saveCountsButton = widgets.savePushButton("Export counts to CSV table") + saveCountsButton.clicked.connect(self.saveCounts) + cancelOkLayout.insertWidget(3, saveCountsButton) + + updateCountsButton = widgets.reloadPushButton("Update counts") + cancelOkLayout.insertWidget(3, updateCountsButton) + updateCountsButton.clicked.connect(self.emitUpdateCounts) + + mainLayout.addWidget( + QLabel(html_utils.paragraph("Object count
", font_size="18px")), + alignment=Qt.AlignLeft, + ) + self.showHideButtons = [] + self.categoryLabelMapper = {} + for category, count in categoryCountMapper.items(): + categoryLayout = QHBoxLayout() + categoryLayout.addSpacing(10) + catText = html_utils.paragraph(f"
{category}
", font_size="13px") + catLabel = QLabel(catText) + categoryLayout.addWidget(catLabel) + categoryLayout.addStretch(1) + + countText = html_utils.paragraph(f"
{count}
", font_size="13px") + countLabel = QLabel(countText) + categoryLayout.addWidget(countLabel) + + self.categoryLabelMapper[category] = countLabel + + showHideButton = widgets.showDetailsButton(txt="") + showHideButton.setChecked(True) + showHideButton.sigToggled.connect( + partial(self.showHideCount, labels=(catLabel, countLabel)) + ) + showHideButton.setToolTip(f'Show/hide "{category}" count') + categoryLayout.addSpacing(10) + categoryLayout.addWidget(showHideButton) + showHideButton.category = category + + self.showHideButtons.append(showHideButton) + + categoryLayout.setStretch(0, 0) + categoryLayout.setStretch(1, 0) + categoryLayout.setStretch(3, 0) + + mainLayout.addLayout(categoryLayout) + mainLayout.addWidget(widgets.QHLine()) + + mainLayout.addSpacing(10) + + infoLayout = QHBoxLayout() + self.livePreviewCheckbox = QCheckBox("Live preview") + self.livePreviewCheckbox.setChecked(True) + infoLayout.addWidget(self.livePreviewCheckbox) + infoLayout.addStretch(1) + self.warnLabel = QLabel("") + infoLayout.addWidget(self.warnLabel) + self.livePreviewCheckbox.toggled.connect(self.updateWarnLabel) + mainLayout.addLayout(infoLayout) + + mainLayout.addSpacing(30) + mainLayout.addStretch(1) + mainLayout.addLayout(cancelOkLayout) + + self.setLayout(mainLayout) + + def saveCounts(self, checked=False): + categories = self.activeCategories() + for posData in self.data: + countMapper = posData.countObjectsInSegm(categories) + countMapper.pop("In current frame", None) + df_count_endname = posData.saveObjCounts(countMapper) + + txt = html_utils.paragraph(f""" + Done!

+ Objects count table saved in every loaded Position folder
+ as a CSV file ending with {df_count_endname} + """) + msg = widgets.myMessageBox(wrapText=False) + msg.information(self, "Objects count saved", txt) + + def updateWarnLabel(self, checked): + if not checked: + self.warnLabel.setText( + html_utils.paragraph( + "WARNING: without live preview, counts are not updated", + font_color="red", + ) + ) + else: + self.warnLabel.setText("") + + def emitUpdateCounts(self): + self.sigUpdateCounts.emit() + + def activeCategories(self) -> List[str]: + activeCategories = [] + for showHideButton in self.showHideButtons: + if not showHideButton.isChecked(): + continue + activeCategories.append(showHideButton.category) + + return activeCategories + + def showHideCount(self, checked, labels): + for label in labels: + label.setVisible(checked) + + QTimer.singleShot(100, self.resizeToHeightHint) + + def updateCounts(self, categoryCountMapper): + for category, count in categoryCountMapper.items(): + countLabel = self.categoryLabelMapper[category] + countText = html_utils.paragraph(f"
{count}
", font_size="13px") + countLabel.setText(countText) + + def resizeToHeightHint(self): + heightHint = self.sizeHint().height() + self.resize(self.width(), heightHint) + + def showEvent(self, event): + widthHint = self.sizeHint().width() + self.resize(int(widthHint * 1.5), self.height()) + self.sigShowEvent.emit() + + def ok_cb(self): + self.cancel = False + self.close() + +# Sibling imports (deferred to avoid import cycles) +from .models import ( + DataFrameModel, +) + diff --git a/cellacdc/dialogs/general.py b/cellacdc/dialogs/general.py new file mode 100644 index 000000000..697fe6655 --- /dev/null +++ b/cellacdc/dialogs/general.py @@ -0,0 +1,3414 @@ +"""Cell-ACDC dialog windows: general.""" + +import os +import sys +import re +from typing import Literal, Callable, Dict, Iterable, List, Tuple +import datetime +import pathlib +from collections import defaultdict +import zipfile +from heapq import nlargest +import matplotlib +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +from matplotlib.patches import Rectangle, Circle, PathPatch, Path +import numpy as np +import scipy.interpolate + +try: + import tkinter as tk +except Exception as err: + pass + +import cv2 +import traceback +from itertools import combinations, permutations +from collections import namedtuple +from natsort import natsorted + +# from MyWidgets import Slider, Button, MyRadioButtons +from skimage.measure import label, regionprops +from functools import partial +import skimage.filters +import skimage.measure +import skimage.morphology +import skimage.exposure +import skimage.draw +import skimage.registration +import skimage.color +import skimage.segmentation +from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk +import matplotlib.pyplot as plt +import seaborn as sns +import pandas as pd +import math +import time +import sympy as sp +import json +import html + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from qtpy import QtCore +from qtpy.QtGui import ( + QIcon, + QFontMetrics, + QKeySequence, + QFont, + QRegularExpressionValidator, + QCursor, + QKeyEvent, + QPixmap, + QFont, + QPalette, + QMouseEvent, + QColor, +) +from qtpy.QtCore import ( + Qt, + QSize, + QEvent, + Signal, + QEventLoop, + QTimer, + QRegularExpression, +) +from qtpy.QtWidgets import ( + QFileDialog, + QApplication, + QMainWindow, + QMenu, + QLabel, + QToolBar, + QScrollBar, + QWidget, + QVBoxLayout, + QLineEdit, + QPushButton, + QHBoxLayout, + QDialog, + QFormLayout, + QListWidget, + QAbstractItemView, + QButtonGroup, + QCheckBox, + QSizePolicy, + QComboBox, + QSlider, + QGridLayout, + QSpinBox, + QToolButton, + QTableView, + QTextBrowser, + QDoubleSpinBox, + QScrollArea, + QFrame, + QProgressBar, + QGroupBox, + QRadioButton, + QDockWidget, + QMessageBox, + QStyle, + QPlainTextEdit, + QSpacerItem, + QTreeWidget, + QTreeWidgetItem, + QTextEdit, + QSplashScreen, + QAction, + QListWidgetItem, + QActionGroup, + QHeaderView, + QStyledItemDelegate, +) +import qtpy.compat + +from .. import exception_handler +from .. import load, prompts, core, measurements, html_utils +from .. import is_mac, is_win, is_linux, settings_folderpath, config +from .. import preproc_recipes_path, segm_recipes_path, combine_channels_recipes_path +from .. import is_conda_env +from .. import printl +from .. import colors +from .. import issues_url +from .. import utils +from .. import qutils +from .. import _palettes +from .. import base_cca_dict +from .. import widgets +from .. import user_profile_path, promptable_models_path, models_path +from .. import features +from .. import _core +from .. import _types +from .. import plot +from .. import urls +from ..acdc_regex import float_regex, is_alphanumeric_filename, to_alphanumeric +from .. import _base_widgets +from .. import io +from .. import cca_functions +from .. import path + +POSITIVE_FLOAT_REGEX = float_regex(allow_negative=False) +TREEWIDGET_STYLESHEET = _palettes.TreeWidgetStyleSheet() +LISTWIDGET_STYLESHEET = _palettes.ListWidgetStyleSheet() +BACKGROUND_RGBA = _palettes.get_disabled_colors()["Button"] + +font = QFont() +font.setPixelSize(12) +italicFont = QFont() +italicFont.setPixelSize(12) +italicFont.setItalic(True) + +from ._base import ( + QBaseDialog, +) + +class customAnnotationDialog(QDialog): + sigDeleteSelecAnnot = Signal(object) + + def __init__(self, savedCustomAnnot, parent=None, state=None): + self.cancel = True + self.loop = None + self.clickedButton = None + self.savedCustomAnnot = savedCustomAnnot + + self.internalNames = measurements.get_all_acdc_df_colnames(include_custom=False) + + super().__init__(parent) + + self.setWindowTitle("Custom annotation") + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + + layout = widgets.FormLayout() + + row = 0 + typeCombobox = QComboBox() + typeCombobox.addItems( + ["Single time-point", "Multiple time-points", "Multiple values class"] + ) + if state is not None: + typeCombobox.setCurrentText(state["type"]) + self.typeCombobox = typeCombobox + body_txt = """ + Single time-point annotation: use this to annotate + an event that happens on a single frame in time + (e.g. cell division). +

+ Multiple time-points annotation: use this to annotate + an event that has a duration, i.e., a start frame and a stop + frame (e.g. cell cycle phase).

+ Multiple values class annotation: use this to annotate a class + that has multiple values. An example could be a cell cycle stage + that can have different values, such as 2-cells division + or 4-cells division. + """ + typeInfoTxt = f"{html_utils.paragraph(body_txt)}" + self.typeWidget = widgets.formWidget( + typeCombobox, + addInfoButton=True, + labelTextLeft="Type: ", + parent=self, + infoTxt=typeInfoTxt, + ) + layout.addFormWidget(self.typeWidget, row=row) + typeCombobox.currentTextChanged.connect(self.warnType) + + row += 1 + nameInfoTxt = """ + Name of the column that will be saved in the acdc_output.csv + file.

+ Valid charachters are letters and numbers separate by underscore + or dash only.

+ Additionally, some names are reserved because they are used + by Cell-ACDC for standard measurements.

+ Internally reserved names: + """ + self.nameInfoTxt = f"{html_utils.paragraph(nameInfoTxt)}" + self.nameWidget = widgets.formWidget( + widgets.alphaNumericLineEdit(), + addInfoButton=True, + labelTextLeft="Name: ", + parent=self, + infoTxt=self.nameInfoTxt, + ) + self.nameWidget.infoButton.disconnect() + self.nameWidget.infoButton.clicked.connect(self.showNameInfo) + if state is not None: + self.nameWidget.widget.setText(state["name"]) + self.nameWidget.widget.textChanged.connect(self.checkName) + layout.addFormWidget(self.nameWidget, row=row) + + row += 1 + self.nameInfoLabel = QLabel() + layout.addWidget(self.nameInfoLabel, row, 0, 1, 2, alignment=Qt.AlignCenter) + + row += 1 + spacing = QSpacerItem(10, 10) + layout.addItem(spacing, row, 0) + + row += 1 + symbolInfoTxt = """ + Symbol that will be drawn on the annotated cell at + the requested time frame. + """ + symbolInfoTxt = f"{html_utils.paragraph(symbolInfoTxt)}" + self.symbolWidget = widgets.formWidget( + widgets.pgScatterSymbolsCombobox(), + addInfoButton=True, + labelTextLeft="Symbol: ", + parent=self, + infoTxt=symbolInfoTxt, + ) + if state is not None: + self.symbolWidget.widget.setCurrentText(state["symbol"]) + layout.addFormWidget(self.symbolWidget, row=row) + + row += 1 + shortcutInfoTxt = """ + Shortcut that you can use to activate/deactivate annotation + of this event.

Leave empty if you don't need a shortcut. + """ + shortcutInfoTxt = f"{html_utils.paragraph(shortcutInfoTxt)}" + self.shortcutWidget = widgets.formWidget( + widgets.ShortcutLineEdit(), + addInfoButton=True, + labelTextLeft="Shortcut: ", + parent=self, + infoTxt=shortcutInfoTxt, + ) + if state is not None: + self.shortcutWidget.widget.setText(state["shortcut"]) + layout.addFormWidget(self.shortcutWidget, row=row) + + row += 1 + descInfoTxt = """ + Description will be used as the tool tip that will be + displayed when you hover with th mouse cursor on the toolbar button + specific for this annotation + """ + descInfoTxt = f"{html_utils.paragraph(descInfoTxt)}" + self.descWidget = widgets.formWidget( + QPlainTextEdit(), + addInfoButton=True, + labelTextLeft="Description: ", + parent=self, + infoTxt=descInfoTxt, + ) + if state is not None: + self.descWidget.widget.setPlainText(state["description"]) + layout.addFormWidget(self.descWidget, row=row) + + row += 1 + optionsGroupBox = QGroupBox("Additional options") + optionsLayout = QGridLayout() + toggle = widgets.Toggle() + toggle.setChecked(True) + self.keepActiveToggle = toggle + toggleLabel = QLabel("Keep tool active after using it: ") + colorButtonLabel = QLabel("Symbol color: ") + self.hideAnnotTooggle = widgets.Toggle() + self.hideAnnotTooggle.setChecked(True) + hideAnnotTooggleLabel = QLabel("Hide annotation when button is not active: ") + self.colorButton = widgets.myColorButton(color=(255, 0, 0)) + self.colorButton.clicked.disconnect() + self.colorButton.clicked.connect(self.selectColor) + + optionsLayout.setColumnStretch(0, 1) + optRow = 0 + optionsLayout.addWidget(toggleLabel, optRow, 1) + optionsLayout.addWidget(toggle, optRow, 2) + optRow += 1 + optionsLayout.addWidget(hideAnnotTooggleLabel, optRow, 1) + optionsLayout.addWidget(self.hideAnnotTooggle, optRow, 2) + optionsLayout.setColumnStretch(3, 1) + optRow += 1 + optionsLayout.addWidget(colorButtonLabel, optRow, 1) + optionsLayout.addWidget(self.colorButton, optRow, 2) + + optionsGroupBox.setLayout(optionsLayout) + layout.addWidget(optionsGroupBox, row, 1, alignment=Qt.AlignCenter) + optionsInfoButton = QPushButton(self) + optionsInfoButton.setCursor(Qt.WhatsThisCursor) + optionsInfoButton.setIcon(QIcon(":info.svg")) + optionsInfoButton.clicked.connect(self.showOptionsInfo) + layout.addWidget(optionsInfoButton, row, 3, alignment=Qt.AlignRight) + + row += 1 + layout.addItem(QSpacerItem(5, 5), row, 0) + + row += 1 + noteText = ( + "NOTE: you can change these options later with
" + "RIGHT-click on the associated left-side toolbar button.
" + ) + noteLabel = QLabel(html_utils.paragraph(noteText, font_size="11px")) + layout.addWidget(noteLabel, row, 1, 1, 3) + + buttonsLayout = QHBoxLayout() + + self.loadSavedAnnotButton = widgets.OpenFilePushButton(" Load annotation... ") + if not savedCustomAnnot: + self.loadSavedAnnotButton.setDisabled(True) + self.okButton = widgets.okPushButton(" Ok ") + cancelButton = widgets.cancelPushButton("Cancel") + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(self.loadSavedAnnotButton) + buttonsLayout.addWidget(self.okButton) + + cancelButton.clicked.connect(self.cancelCallBack) + self.cancelButton = cancelButton + self.loadSavedAnnotButton.clicked.connect(self.loadSavedAnnot) + self.okButton.clicked.connect(self.ok_cb) + self.okButton.setFocus() + + mainLayout = QVBoxLayout() + + noteTxt = """ + Custom annotations will be saved in the acdc_output.csv
+ file as a column with the name you write in the field Name
+ """ + noteTxt = f"{html_utils.paragraph(noteTxt, font_size='15px')}" + noteLabel = QLabel(noteTxt) + noteLabel.setAlignment(Qt.AlignCenter) + mainLayout.addWidget(noteLabel) + + mainLayout.addLayout(layout) + mainLayout.addStretch(1) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + def checkName(self, text): + if not text: + txt = "Name cannot be empty" + self.nameInfoLabel.setText( + html_utils.paragraph(txt, font_size="11px", font_color="red") + ) + return + for name in self.internalNames: + if name.find(text) != -1: + txt = f'"{text}" cannot be part of the name, because reserved.' + self.nameInfoLabel.setText( + html_utils.paragraph(txt, font_size="11px", font_color="red") + ) + break + else: + self.nameInfoLabel.setText("") + + def loadSavedAnnot(self): + items = list(self.savedCustomAnnot.keys()) + self.selectAnnotWin = widgets.QDialogListbox( + "Load annotation parameters", + "Select annotation to load:", + items, + additionalButtons=("Delete selected annnotations",), + parent=self, + multiSelection=False, + ) + for button in self.selectAnnotWin._additionalButtons: + button.disconnect() + button.clicked.connect(self.deleteSelectedAnnot) + self.selectAnnotWin.exec_() + if self.selectAnnotWin.cancel: + return + if self.selectAnnotWin.listBox.count() == 0: + return + if not self.selectAnnotWin.selectedItemsText: + self.warnNoItemsSelected() + return + selectedName = self.selectAnnotWin.selectedItemsText[-1] + selectedAnnot = self.savedCustomAnnot[selectedName] + self.typeCombobox.setCurrentText(selectedAnnot["type"]) + self.nameWidget.widget.setText(selectedAnnot["name"]) + self.symbolWidget.widget.setCurrentText(selectedAnnot["symbol"]) + self.shortcutWidget.widget.setText(selectedAnnot["shortcut"]) + self.descWidget.widget.setPlainText(selectedAnnot["description"]) + self.colorButton.setColor(selectedAnnot["symbolColor"]) + keySequence = widgets.macShortcutToWindows(selectedAnnot["shortcut"]) + if keySequence: + self.shortcutWidget.widget.keySequence = widgets.KeySequenceFromText( + keySequence + ) + + def warnNoItemsSelected(self): + msg = widgets.myMessageBox(parent=self) + msg.setIcon(iconName="SP_MessageBoxWarning") + msg.setWindowTitle("Delete annotation?") + msg.addText("You didn't select any annotation!") + msg.addButton(" Ok ") + msg.exec_() + + def deleteSelectedAnnot(self): + msg = widgets.myMessageBox(parent=self) + msg.setIcon(iconName="SP_MessageBoxWarning") + msg.setWindowTitle("Delete annotation?") + msg.addText("Are you sure you want to delete the selected annotations?") + msg.addButton("Yes") + cancelButton = msg.addButton(" Cancel ") + msg.exec_() + if msg.clickedButton == cancelButton: + return + for item in self.selectAnnotWin.listBox.selectedItems(): + name = item.text() + self.savedCustomAnnot.pop(name) + self.sigDeleteSelecAnnot.emit(self.selectAnnotWin.listBox.selectedItems()) + items = list(self.savedCustomAnnot.keys()) + self.selectAnnotWin.listBox.clear() + self.selectAnnotWin.listBox.addItems(items) + + def selectColor(self): + color = self.colorButton.color() + self.colorButton.origColor = color + self.colorButton.colorDialog.setCurrentColor(color) + self.colorButton.colorDialog.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + self.colorButton.colorDialog.open() + w = self.width() + left = self.pos().x() + colorDialogTop = self.colorButton.colorDialog.pos().y() + self.colorButton.colorDialog.move(w + left + 10, colorDialogTop) + + def warnType(self, currentText): + if currentText == "Single time-point": + return + + self.typeCombobox.setCurrentIndex(0) + + txt = """ + Unfortunately, the only annotation type that is available so far is + Single time-point.

+ We are working on implementing the other types too, so stay tuned!

+ Thank you for your patience! + """ + txt = f"{html_utils.paragraph(txt)}" + msg = widgets.myMessageBox() + msg.setIcon(iconName="SP_MessageBoxWarning") + msg.setWindowTitle(f"Feature not implemented yet") + msg.addText(txt) + msg.addButton(" Ok ") + msg.exec_() + + def showOptionsInfo(self): + info = """ + Keep tool active after using it: Choose whether the tool + should stay active or not after annotating.

+ Hide annotation when button is not active: Choose whether + annotation on the cell/object should be visible only if the + button is active or also when it is not active.
+ NOTE: annotations are always stored no matter whether + they are visible or not.

+ Symbol color: Choose color of the symbol that will be used + to label annotated cell/object. + """ + info = f"{html_utils.paragraph(info)}" + msg = widgets.myMessageBox() + msg.setIcon() + msg.setWindowTitle(f"Additional options info") + msg.addText(info) + msg.addButton(" Ok ") + msg.exec_() + + def ok_cb(self, checked=True): + self.cancel = False + self.clickedButton = self.okButton + self.close() + + def cancelCallBack(self, checked=True): + self.cancel = True + self.clickedButton = self.cancelButton + self.close() + + def showNameInfo(self): + msg = widgets.myMessageBox() + listView = widgets.readOnlyQList(msg) + listView.addItems(self.internalNames) + # listView.setSelectionMode(QAbstractItemView.SelectionMode.NoSelection) + msg.information( + self, "Annotation Name info", self.nameInfoTxt, widgets=listView + ) + + def closeEvent(self, event): + if self.clickedButton is None or self.clickedButton == self.cancelButton: + # cancel button or closed with 'x' button + self.cancel = True + return + + if self.clickedButton == self.okButton and not self.nameWidget.widget.text(): + msg = QMessageBox() + msg.critical(self, "Empty name", "The name cannot be empty!", msg.Ok) + event.ignore() + self.cancel = True + return + + if self.clickedButton == self.okButton and self.nameInfoLabel.text(): + msg = widgets.myMessageBox() + listView = widgets.listWidget(msg) + listView.addItems(self.internalNames) + listView.setSelectionMode(QAbstractItemView.SelectionMode.NoSelection) + name = self.nameWidget.widget.text() + txt = ( + f'"{name}" cannot be part of the name, ' + "because it is reserved for standard measurements " + "saved by Cell-ACDC.

" + "Internally reserved names:" + ) + msg.critical( + self, "Not a valid name", html_utils.paragraph(txt), widgets=listView + ) + event.ignore() + self.cancel = True + return + + self.toolTip = ( + f"Name: {self.nameWidget.widget.text()}\n\n" + f"Type: {self.typeWidget.widget.currentText()}\n\n" + f"Usage: activate the button and RIGHT-CLICK on cell to annotate\n\n" + f"Description: {self.descWidget.widget.toPlainText()}\n\n" + f'SHORTCUT: "{self.shortcutWidget.widget.text()}"' + ) + + symbol = self.symbolWidget.widget.currentText() + self.symbol = re.findall(r"\'(.+)\'", symbol)[0] + + self.state = { + "type": self.typeWidget.widget.currentText(), + "name": self.nameWidget.widget.text(), + "symbol": self.symbolWidget.widget.currentText(), + "shortcut": self.shortcutWidget.widget.text(), + "description": self.descWidget.widget.toPlainText(), + "keepActive": self.keepActiveToggle.isChecked(), + "isHideChecked": self.hideAnnotTooggle.isChecked(), + "symbolColor": self.colorButton.color(), + } + + if self.loop is not None: + self.loop.exit() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + super().show() + if block: + self.loop = QEventLoop() + self.loop.exec_() + + +class _PointsLayerAppearanceGroupbox(QGroupBox): + def __init__(self, *args): + super().__init__(*args) + + self.setTitle("Points appearance") + + layout = widgets.FormLayout() + + "----------------------------------------------------------------------" + row = 0 + symbolInfoTxt = """ + Symbol used to draw the points. + """ + symbolInfoTxt = f"{html_utils.paragraph(symbolInfoTxt)}" + self.symbolWidget = widgets.formWidget( + widgets.pgScatterSymbolsCombobox(), + addInfoButton=True, + labelTextLeft="Symbol: ", + parent=self, + infoTxt=symbolInfoTxt, + stretchWidget=False, + ) + layout.addFormWidget(self.symbolWidget, row=row) + "----------------------------------------------------------------------" + + "----------------------------------------------------------------------" + row += 1 + self.colorButton = widgets.myColorButton(color=(255, 0, 0)) + self.colorWidget = widgets.formWidget( + self.colorButton, stretchWidget=True, labelTextLeft="Colour: ", parent=self + ) + layout.addFormWidget(self.colorWidget, align=Qt.AlignLeft, row=row) + self.colorButton.clicked.disconnect() + self.colorButton.clicked.connect(self.selectColor) + "----------------------------------------------------------------------" + + "----------------------------------------------------------------------" + row += 1 + self.sizeSpinBox = widgets.SpinBox() + self.sizeSpinBox.setValue(5) + self.sizeWidget = widgets.formWidget( + self.sizeSpinBox, stretchWidget=True, labelTextLeft="Size: ", parent=self + ) + layout.addFormWidget(self.sizeWidget, row=row) + "----------------------------------------------------------------------" + + "----------------------------------------------------------------------" + row += 1 + zHeightTooltip = ( + 'If "Z-depth" is greater than 1, the points will be annotated ' + "in all the z-slices in the range `z - (Z-depth/2) < z < z + (Z-depth/2)`\n" + "where `z` is the center z-slice of the added point." + ) + self.zHeightSpinBox = widgets.OddSpinBox() + self.zHeightSpinBox.setValue(1) + self.zHeightSpinBox.setMinimum(1) + self.zHeightWidget = widgets.formWidget( + self.zHeightSpinBox, + stretchWidget=True, + labelTextLeft="Z-depth: ", + parent=self, + toolTip=zHeightTooltip, + ) + layout.addFormWidget(self.zHeightWidget, row=row) + "----------------------------------------------------------------------" + + "----------------------------------------------------------------------" + row += 1 + shortcutInfoTxt = """ + Shortcut that you can use to hide/show points. + """ + shortcutInfoTxt = f"{html_utils.paragraph(shortcutInfoTxt)}" + self.shortcutWidget = widgets.formWidget( + widgets.ShortcutLineEdit(), + addInfoButton=True, + labelTextLeft="Shortcut: ", + parent=self, + infoTxt=shortcutInfoTxt, + ) + layout.addFormWidget(self.shortcutWidget, row=row) + "----------------------------------------------------------------------" + + self.setLayout(layout) + + def restoreState(self, state): + self.shortcutWidget.widget.setText(state["shortcut"]) + self.colorButton.setColor(state["color"]) + self.symbolWidget.widget.setCurrentText(state["symbol"]) + self.sizeSpinBox.setValue(state["pointSize"]) + self.zHeightSpinBox.setValue(state["zHeight"]) + + def selectColor(self): + color = self.colorButton.color() + self.colorButton.origColor = color + self.colorButton.colorDialog.setCurrentColor(color) + self.colorButton.colorDialog.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + self.colorButton.colorDialog.open() + w = self.width() + left = self.pos().x() + colorDialogTop = self.colorButton.colorDialog.pos().y() + self.colorButton.colorDialog.move(w + left + 10, colorDialogTop) + + def state(self): + r, g, b, a = self.colorButton.color().getRgb() + _state = { + "symbol": self.symbolWidget.widget.currentText(), + "color": (r, g, b), + "pointSize": self.sizeSpinBox.value(), + "zHeight": self.zHeightSpinBox.value(), + "shortcut": self.shortcutWidget.widget.text(), + } + return _state + + +class AddPointsLayerDialog(QBaseDialog): + sigClosed = Signal() + sigCriticalReadTable = Signal(str) + sigLoadedTable = Signal(object, str) + sigCheckClickEntryTableEndnameExists = Signal(str, bool) + + def __init__( + self, + channelNames=None, + imagesPath="", + SizeT=1, + hideCentroidsSection=False, + hideWeightedCentroidsSection=False, + hideFromTableSection=False, + hideManualEntrySection=False, + hideWithMouseClicksSection=False, + parent=None, + ): + self.cancel = True + super().__init__(parent) + + self._parent = parent + + self.imagesPath = imagesPath + + self.setWindowTitle("Add points layer") + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + + mainLayout = QVBoxLayout() + + scrollArea = widgets.ScrollArea() + typeGroupbox = QGroupBox("Points to draw") + typeLayout = QGridLayout() + typeGroupbox.setLayout(typeLayout) + typeLayout.addItem(QSpacerItem(10, 1), 0, 0) + typeLayout.setColumnStretch(0, 0) + typeLayout.setColumnStretch(2, 1) + vSpacing = 15 + + row = 0 + + sections = ( + ("addCentroidsSection", hideCentroidsSection), + ("addWeightedCentroidsSection", hideWeightedCentroidsSection), + ("addFromTableSection", hideFromTableSection), + ("addManualEntrySection", hideManualEntrySection), + ("addWithMouseClicksSection", hideWithMouseClicksSection), + ) + radioButtonChecked = False + for section, hideSection in sections: + addFunc = getattr(self, section) + row, sectionWidgets = addFunc( + row, + typeLayout, + imagesPath=imagesPath, + SizeT=SizeT, + channelNames=channelNames, + ) + if not hideSection: + spacer = QSpacerItem(1, vSpacing) + typeLayout.addItem(spacer, row, 0) + row += 1 + if not radioButtonChecked: + sectionWidgets[0].setChecked(True) + radioButtonChecked = True + continue + + for widget in sectionWidgets: + widget.setVisible(False) + + self.scrollArea = scrollArea + scrollArea.setWidget(typeGroupbox) + + self.appearanceGroupbox = _PointsLayerAppearanceGroupbox() + self.appearanceGroupbox.sizeSpinBox.setValue(3) + + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + self.buttonsLayout = buttonsLayout + + mainLayout.addWidget(scrollArea) + mainLayout.addSpacing(20) + _layout = QHBoxLayout() + _layout.addWidget(self.appearanceGroupbox) + _layout.addStretch(1) + mainLayout.addLayout(_layout) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + self.setFont(font) + + def addCentroidsSection(self, row, layout, **kwargs): + sectionWidgets = [] + self.centroidsRadiobutton = QRadioButton("Centroids") + layout.addWidget(self.centroidsRadiobutton, row, 0, 1, 2) + sectionWidgets.append(self.centroidsRadiobutton) + + self.centroidsRadiobutton.setChecked(True) + return row + 1, sectionWidgets + + def addWeightedCentroidsSection(self, row, layout, channelNames=None, **kwargs): + if channelNames is None: + channelNames = [] + + sectionWidgets = [] + + self.weightedCentroidsRadiobutton = QRadioButton("Weighted centroids") + layout.addWidget(self.weightedCentroidsRadiobutton, row, 0, 1, 2) + sectionWidgets.append(self.weightedCentroidsRadiobutton) + + row += 1 + label = QLabel("Weighing channel: ") + label.setEnabled(False) + layout.addWidget(label, row, 1) + sectionWidgets.append(label) + + self.channelNameForWeightedCentr = widgets.QCenteredComboBox() + if channelNames: + self.channelNameForWeightedCentr.addItems(channelNames) + self.channelNameForWeightedCentr.setDisabled(True) + layout.addWidget(self.channelNameForWeightedCentr, row, 2) + sectionWidgets.append(self.channelNameForWeightedCentr) + + self.weightedCentroidsRadiobutton.toggled.connect(label.setEnabled) + self.weightedCentroidsRadiobutton.toggled.connect( + self.channelNameForWeightedCentr.setEnabled + ) + + return row + 1, sectionWidgets + + def addFromTableSection(self, row, layout, imagesPath="", SizeT=1, **kwargs): + sectionWidgets = [] + + self.fromTableRadiobutton = QRadioButton("From table") + layout.addWidget(self.fromTableRadiobutton, row, 0, 1, 2) + sectionWidgets.append(self.fromTableRadiobutton) + self.fromTableRadiobutton.widgets = [] + + row += 1 + self.tablePath = widgets.ElidingLineEdit() + self.tablePath.label = QLabel("Table file path: ") + layout.addWidget(self.tablePath.label, row, 1) + layout.addWidget(self.tablePath, row, 2) + self.fromTableRadiobutton.widgets.append(self.tablePath) + sectionWidgets.append(self.tablePath.label) + sectionWidgets.append(self.tablePath) + + browseButton = widgets.browseFileButton( + start_dir=imagesPath, ext={"Table": [".csv", ".h5"]} + ) + layout.addWidget(browseButton, row, 3) + browseButton.sigPathSelected.connect(self.tablePathSelected) + self.browseTableButton = browseButton + self.fromTableRadiobutton.widgets.append(browseButton) + sectionWidgets.append(browseButton) + + row += 1 + self.xColName = widgets.QCenteredComboBox() + self.xColName.addItem("None") + self.xColName.label = QLabel("X coord. column: ") + layout.addWidget(self.xColName.label, row, 1) + layout.addWidget(self.xColName, row, 2) + self.xColName.currentTextChanged.connect(self.checkColNameX) + self.fromTableRadiobutton.widgets.append(self.xColName) + sectionWidgets.append(self.xColName.label) + sectionWidgets.append(self.xColName) + + row += 1 + self.yColName = widgets.QCenteredComboBox() + self.yColName.addItem("None") + self.yColName.label = QLabel("Y coord. column: ") + layout.addWidget(self.yColName.label, row, 1) + layout.addWidget(self.yColName, row, 2) + self.yColName.currentTextChanged.connect(self.checkColNameY) + self.fromTableRadiobutton.widgets.append(self.yColName) + sectionWidgets.append(self.yColName.label) + sectionWidgets.append(self.yColName) + + row += 1 + self.zColName = widgets.QCenteredComboBox() + self.zColName.addItem("None") + self.zColName.label = QLabel("Z coord. column: ") + layout.addWidget(self.zColName.label, row, 1) + layout.addWidget(self.zColName, row, 2) + self.zColName.currentTextChanged.connect(self.checkColNameZ) + self.fromTableRadiobutton.widgets.append(self.zColName) + sectionWidgets.append(self.zColName.label) + sectionWidgets.append(self.zColName) + + row += 1 + self.tColName = widgets.QCenteredComboBox() + self.tColName.addItem("None") + self.tColName.label = QLabel("Frame index column: ") + layout.addWidget(self.tColName.label, row, 1) + layout.addWidget(self.tColName, row, 2) + self.fromTableRadiobutton.widgets.append(self.tColName) + sectionWidgets.append(self.tColName.label) + sectionWidgets.append(self.tColName) + + if SizeT == 1: + self.tColName.clear() + self.tColName.addItem("None") + self.tColName.label.setVisible(False) + self.tColName.setVisible(False) + + self.fromTableRadiobutton.toggled.connect(self.enableRadioButtonWidgets) + self.enableRadioButtonWidgets(False, sender=self.fromTableRadiobutton) + + return row + 1, sectionWidgets + + def addManualEntrySection(self, row, layout, SizeT=1, **kwargs): + sectionWidgets = [] + + self.manualEntryRadiobutton = QRadioButton("Manual entry") + layout.addWidget(self.manualEntryRadiobutton, row, 0, 1, 2) + self.manualEntryRadiobutton.widgets = [] + sectionWidgets.append(self.manualEntryRadiobutton) + + row += 1 + self.manualXspinbox = widgets.NumericCommaLineEdit() + self.manualXspinbox.label = QLabel("X coords: ") + layout.addWidget(self.manualXspinbox.label, row, 1) + layout.addWidget(self.manualXspinbox, row, 2) + self.manualEntryRadiobutton.widgets.append(self.manualXspinbox) + sectionWidgets.append(self.manualXspinbox.label) + sectionWidgets.append(self.manualXspinbox) + + row += 1 + self.manualYspinbox = widgets.NumericCommaLineEdit() + self.manualYspinbox.label = QLabel("Y coords: ") + layout.addWidget(self.manualYspinbox.label, row, 1) + layout.addWidget(self.manualYspinbox, row, 2) + self.manualEntryRadiobutton.widgets.append(self.manualYspinbox) + sectionWidgets.append(self.manualYspinbox.label) + sectionWidgets.append(self.manualYspinbox) + + row += 1 + self.manualZspinbox = widgets.NumericCommaLineEdit() + self.manualZspinbox.label = QLabel("Z coords: ") + layout.addWidget(self.manualZspinbox.label, row, 1) + layout.addWidget(self.manualZspinbox, row, 2) + self.manualEntryRadiobutton.widgets.append(self.manualZspinbox) + sectionWidgets.append(self.manualZspinbox.label) + sectionWidgets.append(self.manualZspinbox) + + row += 1 + self.manualTspinbox = widgets.NumericCommaLineEdit() + self.manualTspinbox.label = QLabel("Frame numbers: ") + layout.addWidget(self.manualTspinbox.label, row, 1) + layout.addWidget(self.manualTspinbox, row, 2) + self.manualEntryRadiobutton.widgets.append(self.manualTspinbox) + sectionWidgets.append(self.manualTspinbox.label) + sectionWidgets.append(self.manualTspinbox) + + if SizeT == 1: + self.manualTspinbox.setVisible(False) + self.manualTspinbox.label.setVisible(False) + + self.manualEntryRadiobutton.toggled.connect(self.enableRadioButtonWidgets) + self.enableRadioButtonWidgets(False, sender=self.manualEntryRadiobutton) + + return row + 1, sectionWidgets + + def addWithMouseClicksSection(self, row, layout, imagesPath="", **kwargs): + sectionWidgets = [] + + self.clickEntryIsLoadedDf = None + + self.clickEntryRadiobutton = QRadioButton("Add points with mouse clicks") + layout.addWidget(self.clickEntryRadiobutton, row, 0, 1, 2) + self.clickEntryRadiobutton.widgets = [] + sectionWidgets.append(self.clickEntryRadiobutton) + + row += 1 + self.snapToMaxToggle = widgets.Toggle() + self.snapToMaxToggle.label = QLabel("Snap to closest maximum: ") + layout.addWidget(self.snapToMaxToggle.label, row, 1) + layout.addWidget(self.snapToMaxToggle, row, 2, alignment=Qt.AlignCenter) + sectionWidgets.append(self.snapToMaxToggle.label) + sectionWidgets.append(self.snapToMaxToggle) + + self.snapToMaxInfoButton = widgets.infoPushButton() + layout.addWidget(self.snapToMaxInfoButton, row, 3) + sectionWidgets.append(self.snapToMaxInfoButton) + + self.snapToMaxInfoButton.clicked.connect(self.showSnapToMaxButton) + self.clickEntryRadiobutton.widgets.append(self.snapToMaxToggle) + self.clickEntryRadiobutton.widgets.append(self.snapToMaxInfoButton) + + row += 1 + self.autoPilotToggle = widgets.Toggle() + self.autoPilotToggle.label = QLabel("Use auto-pilot: ") + layout.addWidget(self.autoPilotToggle.label, row, 1) + layout.addWidget(self.autoPilotToggle, row, 2, alignment=Qt.AlignCenter) + sectionWidgets.append(self.autoPilotToggle.label) + sectionWidgets.append(self.autoPilotToggle) + self.autoPilotInfoButton = widgets.infoPushButton() + layout.addWidget(self.autoPilotInfoButton, row, 3) + sectionWidgets.append(self.autoPilotInfoButton) + + self.autoPilotInfoButton.clicked.connect(self.showAutoPilotInfo) + self.clickEntryRadiobutton.widgets.append(self.autoPilotToggle) + self.clickEntryRadiobutton.widgets.append(self.autoPilotInfoButton) + + row += 1 + self.clickEntryTableEndname = widgets.alphaNumericLineEdit() + self.clickEntryTableEndname.setText("points_added_by_clicking") + self.clickEntryTableEndname.setAlignment(Qt.AlignCenter) + self.clickEntryTableEndname.label = QLabel("Table endname: ") + loadButton = widgets.browseFileButton(start_dir=imagesPath, ext={"CSV": ".csv"}) + layout.addWidget(loadButton, row, 3) + sectionWidgets.append(loadButton) + + loadButton.sigPathSelected.connect(self.loadClickEntryTable) + self.loadButton = loadButton + self.clickEntryLoadTableButton = loadButton + layout.addWidget(self.clickEntryTableEndname.label, row, 1) + layout.addWidget(self.clickEntryTableEndname, row, 2) + self.clickEntryRadiobutton.widgets.append(self.clickEntryTableEndname) + self.clickEntryTableEndname.editingFinished.connect( + self.emitCheckClickEntryTableEndnameExists + ) + sectionWidgets.append(self.clickEntryTableEndname) + sectionWidgets.append(self.clickEntryTableEndname.label) + + row += 1 + instructionsText = html_utils.paragraph( + "
Left-click to annotate a new point with a new id.

" + "Right-click to annotate a point with the same id

" + "Same click used to delete objects to annotate
" + "a point with id = 0 (negative prompt)

" + "Click on point to delete it", + font_size="11px", + ) + self.instructionsLabel = QLabel(instructionsText) + self.instructionsLabel.label = QLabel("Instructions") + layout.addWidget(self.instructionsLabel.label, row, 1) + layout.addWidget(self.instructionsLabel, row, 2) + self.clickEntryRadiobutton.widgets.append(self.instructionsLabel) + sectionWidgets.append(self.instructionsLabel) + sectionWidgets.append(self.instructionsLabel.label) + + self.clickEntryRadiobutton.toggled.connect(self.enableRadioButtonWidgets) + self.clickEntryRadiobutton.toggled.connect( + self.emitCheckClickEntryTableEndnameExists + ) + self.enableRadioButtonWidgets(False, sender=self.clickEntryRadiobutton) + + return row + 1, sectionWidgets + + def emitCheckClickEntryTableEndnameExists(self, *args, **kwargs): + if not self.clickEntryRadiobutton.isChecked(): + return + self.clickEntryIsLoadedDf = None + tableEndName = self.clickEntryTableEndname.text() + self.sigCheckClickEntryTableEndnameExists.emit(tableEndName, False) + + def loadClickEntryTable(self, csv_path): + self.clickEntryIsLoadedDf = None + posData = load.loadData(csv_path, "points") + posData.getBasenameAndChNames(qparent=self) + basename = posData.basename + filename = os.path.basename(csv_path) + filename, ext = os.path.splitext(filename) + if not basename.endswith("_"): + basename = f"{basename}_" + + endname = filename[len(basename) :] + self.clickEntryTableEndname.setText(endname) + self.sigCheckClickEntryTableEndnameExists.emit(endname, True) + + def showAutoPilotInfo(self): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph(""" + With Auto-pilot mode active, Cell-ACDC will automatically zoom on + to an object
+ to allow you clicking on the points you want to add.

+ You can then go to the next object by pressing the + Enter key or go back to the
+ previous object by pressing Backspace. + """) + msg.information(self, "Auto-pilot info", txt) + + def showSnapToMaxButton(self): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph(""" + With mode active, Cell-ACDC will + automatically add the point
+ to the closest maximum within the point footprint (defined in + the appearance settings). + """) + msg.information(self, "Snap to closest maximum info", txt) + + def closeEvent(self, event): + self.sigClosed.emit() + + def enableRadioButtonWidgets(self, enabled, sender=None): + if sender is None: + sender = self.sender() + for widget in sender.widgets: + widget.setDisabled(not enabled) + try: + widget.label.setDisabled(not enabled) + except: + pass + + def _readTable(self, path): + return load.load_df_points_layer(path) + + def tryAutoFillColNames(self, df): + if "x" in df.columns: + self.xColName.setCurrentText("x") + + if "y" in df.columns: + self.yColName.setCurrentText("y") + + if "z" in df.columns: + self.zColName.setCurrentText("z") + + if "frame_i" in df.columns: + self.tColName.setCurrentText("frame_i") + + def tablePathSelected(self, path): + self.tablePath.setText(path) + try: + df = self._readTable(path) + self.xColName.addItems(df.columns) + self.yColName.addItems(df.columns) + self.zColName.addItems(df.columns) + self.tColName.addItems(df.columns) + self.tryAutoFillColNames(df) + self.sigLoadedTable.emit(df, os.path.basename(path)) + self.browseTableButton.confirmAction() + except Exception as e: + traceback_format = traceback.format_exc() + self.sigCriticalReadTable.emit(traceback_format) + self.criticalReadTable(path, traceback_format) + self.tablePath.setText("") + + def criticalLenMismatchManualEntry(self): + txt = html_utils.paragraph(f""" + X coords and Y coords must have the same length. + """) + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + msg.critical(self, f"X and Y have different length", txt) + + def criticalColNameIsNone(self, axis): + txt = html_utils.paragraph(f""" + The "{axis.upper()} coord. column" cannot be "None" + """) + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + msg.critical(self, f"{axis.upper()} coord. is None", txt) + + def criticalReadTable(self, path, traceback_format): + txt = html_utils.paragraph(f""" + Something went wrong when reading the table from the + following path:

+ {path}

+ See the error message below. + """) + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + detailsText = traceback_format + msg.critical(self, "Error when reading table", txt, detailsText=detailsText) + + def criticalEmptyTablePath(self): + txt = html_utils.paragraph(f""" + The table file path cannot be empty. + """) + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + msg.critical(self, "Table file path is empty", txt) + + def state(self): + _state = self.appearanceGroupbox.state() + return _state + + def _checkSelectedColName(self, colName, label): + labelsToCheck = ["z", "y", "x"] + labelsToCheck.remove(label) + for labelToCheck in labelsToCheck: + if colName.find(labelToCheck) != -1: + break + else: + return True + + txt = html_utils.paragraph(f""" + Are you sure that the {label.upper()} coord. column should contain + the letter {labelToCheck}? + """) + + msg = widgets.myMessageBox(wrapText=False) + _, noButton, yesButton = msg.warning( + self, + "Check column name", + txt, + buttonsTexts=("Cancel", "No, let me correct it", "Yes, I am"), + ) + if msg.cancel or msg.clickedButton == noButton: + return False + return True + + def checkColNameX(self, text): + accepted = self._checkSelectedColName(text, "x") + if accepted: + return + self.xColName.setCurrentText("None") + + def checkColNameY(self, text): + accepted = self._checkSelectedColName(text, "y") + if accepted: + return + self.yColName.setCurrentText("None") + + def checkColNameZ(self, text): + accepted = self._checkSelectedColName(text, "z") + if accepted: + return + self.zColName.setCurrentText("None") + + def ok_cb(self): + self.pointsData = {} + self.loadedDfInfo = None + self.loadedDf = None + self.weighingChannel = "" + if self.fromTableRadiobutton.isChecked(): + tablePath = self.tablePath.text() + if not tablePath: + self.criticalEmptyTablePath() + return + + try: + df = self._readTable(tablePath) + tColName = self.tColName.currentText() + xColName = self.xColName.currentText() + yColName = self.yColName.currentText() + zColName = self.zColName.currentText() + + self.loadedDfInfo = { + "filepath": tablePath, + "t": tColName, + "z": zColName, + "y": yColName, + "x": xColName, + } + + self._df_to_pointsData(df, tColName, zColName, yColName, xColName) + + except Exception as e: + traceback_format = traceback.format_exc() + self.sigCriticalReadTable.emit(traceback_format) + self.criticalReadTable(tablePath, traceback_format) + return + + if self.xColName.currentText() == "None": + self.criticalColNameIsNone("x") + return + if self.yColName.currentText() == "None": + self.criticalColNameIsNone("y") + return + + self.layerType = os.path.basename(self.tablePath.text()) + self.layerTypeIdx = 2 + elif self.centroidsRadiobutton.isChecked(): + self.layerType = "Centroids" + self.layerTypeIdx = 0 + elif self.weightedCentroidsRadiobutton.isChecked(): + channel = self.channelNameForWeightedCentr.currentText() + self.weighingChannel = channel + self.layerType = f"Centroids weighted by channel {channel}" + self.layerTypeIdx = 1 + elif self.manualEntryRadiobutton.isChecked(): + xx = self.manualXspinbox.values() + yy = self.manualYspinbox.values() + if len(xx) != len(yy): + self.criticalLenMismatchManualEntry() + return + zz = self.manualZspinbox.values() + tt = [t + 1 for t in self.manualTspinbox.values()] + df = pd.DataFrame({"x": xx, "y": yy, "id": np.arange(1, len(xx) + 1)}) + if tt: + df["t"] = tt + tCol = "t" + else: + tCol = "None" + if zz: + df["z"] = zz + zCol = "z" + else: + zCol = "None" + + self._df_to_pointsData(df, tCol, zCol, "y", "x") + + self.layerType = "Manual entry" + self.layerTypeIdx = 3 + elif self.clickEntryRadiobutton.isChecked(): + self.layerType = "Click to annotate point" + self.description = ( + "Left-click to add a point, click on point to delete it.\n" + "With auto-pilot you can navigate through object with Up/Down arrows." + ) + self.clickEntryTableEndnameText = self.clickEntryTableEndname.text() + self.layerTypeIdx = 4 + + self.cancel = False + symbol = self.appearanceGroupbox.symbolWidget.widget.currentText() + self.symbol = re.findall(r"\'(.+)\'", symbol)[0] + self.symbolText = symbol + self.color = self.appearanceGroupbox.colorButton.color() + self.pointSize = self.appearanceGroupbox.sizeSpinBox.value() + self.zHeight = self.appearanceGroupbox.zHeightSpinBox.value() + shortcutWidget = self.appearanceGroupbox.shortcutWidget + self.shortcut = shortcutWidget.widget.text() + self.keySequence = shortcutWidget.widget.keySequence + self.close() + + def _df_to_pointsData(self, df, tColName, zColName, yColName, xColName): + self.pointsData = load.loaded_df_to_points_data( + df, tColName, zColName, yColName, xColName + ) + + def showEvent(self, event) -> None: + if self._parent is None: + screen = self.screen() + else: + screen = self._parent.screen() + screenWidth = screen.size().width() + screenHeight = screen.size().height() + + maxHeight = screenHeight - 100 + + buttonHeight = self.buttonsLayout.okButton.minimumSizeHint().height() + height = ( + self.scrollArea.minimumHeightNoScrollbar() + + self.appearanceGroupbox.sizeHint().height() + + buttonHeight + + 70 + ) + width = self.scrollArea.minimumWidthNoScrollbar() + 50 + + height = min(height, maxHeight) + + self.resize(width, height) + + screenLeft = screen.geometry().x() + screenTop = screen.geometry().y() + w, h = self.width(), self.height() + left = int(screenLeft + screenWidth / 2 - w / 2) + top = int(screenTop + screenHeight / 2 - h / 2 - 20) + + self.move(left, top) + + +class EditPointsLayerAppearanceDialog(QBaseDialog): + sigClosed = Signal() + + def __init__(self, parent=None): + self.cancel = True + super().__init__(parent) + + self._parent = parent + + self.setWindowTitle("Custom annotation") + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + + mainLayout = QVBoxLayout() + + self.appearanceGroupbox = _PointsLayerAppearanceGroupbox() + + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addWidget(self.appearanceGroupbox) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + self.setFont(font) + + def restoreState(self, state): + self.appearanceGroupbox.restoreState(state) + + def closeEvent(self, event): + super().closeEvent(event) + self.sigClosed.emit() + + def state(self): + _state = self.appearanceGroupbox.state() + return _state + + def ok_cb(self): + self.cancel = False + symbol = self.appearanceGroupbox.symbolWidget.widget.currentText() + self.symbol = re.findall(r"\'(.+)\'", symbol)[0] + self.color = self.appearanceGroupbox.colorButton.color() + self.pointSize = self.appearanceGroupbox.sizeSpinBox.value() + self.zHeight = self.appearanceGroupbox.zHeightSpinBox.value() + shortcutWidget = self.appearanceGroupbox.shortcutWidget + self.shortcut = shortcutWidget.widget.text() + self.keySequence = shortcutWidget.widget.keySequence + self.close() + + +class QDialogWorkerProgress(QDialog): + sigClosed = Signal(bool) + + def __init__( + self, + title="Progress", + infoTxt="", + showInnerPbar=False, + pbarDesc="", + parent=None, + ): + self.workerFinished = False + self.aborted = False + self.clickCount = 0 + super().__init__(parent) + + abort_text = "Option+Command+C" if is_mac else "Ctrl+Alt+C" + self.abort_text = abort_text + + self.setWindowTitle(f"{title} ({abort_text} to abort)") + self.setWindowFlags(Qt.Window) + + mainLayout = QVBoxLayout() + pBarLayout = QGridLayout() + + if infoTxt: + infoLabel = QLabel(infoTxt) + mainLayout.addWidget(infoLabel, alignment=Qt.AlignCenter) + + self.progressLabel = QLabel(pbarDesc) + + self.mainPbar = widgets.ProgressBarWithETA(self) + self.mainPbar.setValue(0) + pBarLayout.addWidget(self.mainPbar, 0, 0) + pBarLayout.addWidget(self.mainPbar.ETA_label, 0, 1) + + self.innerPbar = widgets.ProgressBarWithETA(self) + self.innerPbar.setValue(0) + pBarLayout.addWidget(self.innerPbar, 1, 0) + pBarLayout.addWidget(self.innerPbar.ETA_label, 1, 1) + if showInnerPbar: + self.innerPbar.show() + else: + self.innerPbar.hide() + + self.logConsole = widgets.QLogConsole() + + mainLayout.addWidget(self.progressLabel) + mainLayout.addLayout(pBarLayout) + mainLayout.addWidget(self.logConsole) + + self.setLayout(mainLayout) + # self.setModal(True) + + def keyPressEvent(self, event): + isCtrlAlt = event.modifiers() == (Qt.ControlModifier | Qt.AltModifier) + if isCtrlAlt and event.key() == Qt.Key_C: + doAbort = self.askAbort() + if doAbort: + self.aborted = True + self.workerFinished = True + self.close() + + def askAbort(self): + msg = widgets.myMessageBox() + txt = html_utils.paragraph(f""" + Aborting with {self.abort_text} to abort is + not safe.

+ The system status cannot be predicted and + it will require a restart.

+ Are you sure you want to abort? + """) + yesButton, noButton = msg.critical( + self, "Are you sure you want to abort?", txt, buttonsTexts=("Yes", "No") + ) + return msg.clickedButton == yesButton + + def closeEvent(self, event): + if not self.workerFinished: + event.ignore() + return + + self.sigClosed.emit(self.aborted) + + def log(self, text): + self.logConsole.append(text) + + def show(self, app): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + QDialog.show(self) + screen = app.primaryScreen() + screenWidth = screen.size().width() + screenHeight = screen.size().height() + parentGeometry = self.parent().geometry() + mainWinLeft, mainWinWidth = parentGeometry.left(), parentGeometry.width() + mainWinTop, mainWinHeight = parentGeometry.top(), parentGeometry.height() + mainWinCenterX = int(mainWinLeft + mainWinWidth / 2) + mainWinCenterY = int(mainWinTop + mainWinHeight / 2) + + width = int(screenWidth / 3) + width = width if self.width() < width else self.width() + height = int(screenHeight / 3) + left = int(mainWinCenterX - width / 2) + left = left if left >= 0 else 0 + top = int(mainWinCenterY - height / 2) + + self.setGeometry(left, top, width, height) + + +class QDialogCombobox(QDialog): + def __init__( + self, + title, + ComboBoxItems, + informativeText, + CbLabel="Select value: ", + parent=None, + defaultChannelName=None, + iconPixmap=None, + centeredCombobox=False, + ): + self.cancel = True + self.selectedItemText = "" + self.selectedItemIdx = None + super().__init__(parent=parent) + self.setWindowTitle(title) + + mainLayout = QVBoxLayout() + infoLayout = QHBoxLayout() + topLayout = QHBoxLayout() + bottomLayout = QHBoxLayout() + + self.mainLayout = mainLayout + + if iconPixmap is not None: + label = QLabel() + # padding: top, left, bottom, right + # label.setStyleSheet("padding:5px 0px 12px 0px;") + label.setPixmap(iconPixmap) + infoLayout.addWidget(label) + + if informativeText: + infoLabel = QLabel(informativeText) + infoLayout.addWidget(infoLabel, alignment=Qt.AlignCenter) + + if CbLabel: + label = QLabel(CbLabel) + topLayout.addWidget(label, alignment=Qt.AlignRight) + + if centeredCombobox: + combobox = widgets.QCenteredComboBox() + else: + combobox = QComboBox() + combobox.addItems(ComboBoxItems) + if defaultChannelName is not None and defaultChannelName in ComboBoxItems: + combobox.setCurrentText(defaultChannelName) + self.ComboBox = combobox + topLayout.addWidget(combobox) + topLayout.setContentsMargins(0, 10, 0, 0) + + okButton = widgets.okPushButton("Ok") + + cancelButton = widgets.cancelPushButton("Cancel") + + bottomLayout.addStretch(1) + bottomLayout.addWidget(cancelButton) + bottomLayout.addSpacing(20) + bottomLayout.addWidget(okButton) + bottomLayout.setContentsMargins(0, 10, 0, 0) + + mainLayout.addLayout(infoLayout) + mainLayout.addLayout(topLayout) + mainLayout.addLayout(bottomLayout) + self.setLayout(mainLayout) + + # self.setModal(True) + + # Connect events + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.close) + self.loop = None + + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + self.setFont(font) + + def ok_cb(self, checked=False): + self.cancel = False + self.selectedItemText = self.ComboBox.currentText() + self.selectedItemIdx = self.ComboBox.currentIndex() + self.close() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + QDialog.show(self) + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class imageViewer(QMainWindow): + """Main Window.""" + + sigClosed = Signal() + sigHoveringImage = Signal(object, object) + + def __init__( + self, + parent=None, + posData=None, + button_toUncheck=None, + spinBox=None, + linkWindow=None, + enableOverlay=False, + isSigleFrame=False, + enableMirroredCursor=False, + ): + self.button_toUncheck = button_toUncheck + self.parent = parent + self.posData = posData + self.spinBox = spinBox + self.linkWindow = linkWindow + self.enableMirroredCursor = enableMirroredCursor + self.isSigleFrame = isSigleFrame + self.minMaxValuesMapper = None + """Initializer.""" + super().__init__(parent) + + if posData is None: + posData = self.parent.data[self.parent.pos_i] + self.posData = posData + self.enableOverlay = enableOverlay + + self.gui_createActions() + self.gui_createMenuBar() + self.gui_createToolBars() + + self.gui_createStatusBar() + + self.gui_createGraphics() + + self.gui_connectImgActions() + + self.gui_createImgWidgets() + self.gui_connectActions() + + self.gui_setSingleFrameMode(self.isSigleFrame) + + self.setupMirroredCursor() + + mainContainer = QWidget() + self.setCentralWidget(mainContainer) + + mainLayout = QGridLayout() + mainLayout.addWidget(self.graphLayout, 0, 0, 1, 1) + mainLayout.addLayout(self.img_Widglayout, 1, 0) + + mainContainer.setLayout(mainLayout) + + self.frame_i = posData.frame_i + self.num_frames = posData.SizeT + + version = utils.read_version() + self.setWindowTitle(f"Cell-ACDC v{version} - {posData.relPath}") + + def gui_createActions(self): + # File actions + self.exitAction = QAction("&Exit", self) + + # Toolbar actions + self.prevAction = QAction("Previous frame", self) + self.nextAction = QAction("Next Frame", self) + self.jumpForwardAction = QAction("Jump to 10 frames ahead", self) + self.jumpBackwardAction = QAction("Jump to 10 frames back", self) + self.prevAction.setShortcut("left") + self.nextAction.setShortcut("right") + self.jumpForwardAction.setShortcut("up") + self.jumpBackwardAction.setShortcut("down") + self.addAction(self.nextAction) + self.addAction(self.prevAction) + self.addAction(self.jumpBackwardAction) + self.addAction(self.jumpForwardAction) + if self.enableOverlay: + self.overlayButton = widgets.rightClickToolButton(parent=self) + self.overlayButton.setIcon(QIcon(":overlay.svg")) + self.overlayButton.setCheckable(True) + + def gui_createMenuBar(self): + menuBar = self.menuBar() + # File menu + fileMenu = QMenu("&File", self) + menuBar.addMenu(fileMenu) + # fileMenu.addAction(self.newAction) + fileMenu.addAction(self.exitAction) + + def gui_createToolBars(self): + toolbarSize = 30 + + editToolBar = QToolBar("Edit", self) + editToolBar.setIconSize(QSize(toolbarSize, toolbarSize)) + self.addToolBar(editToolBar) + + self.editToolBar = editToolBar + + if self.enableOverlay: + editToolBar.addWidget(self.overlayButton) + + if self.linkWindow: + # Insert a spacing + editToolBar.addWidget(QLabel(" ")) + self.linkWindowCheckbox = QCheckBox("Link to main GUI") + self.linkWindowCheckbox.setChecked(True) + editToolBar.addWidget(self.linkWindowCheckbox) + + if self.enableMirroredCursor: + self.showMirroredCursorCheckbox = QCheckBox( + "Show mirrored cursor from main window" + ) + self.showMirroredCursorCheckbox.setChecked(True) + editToolBar.addWidget(self.showMirroredCursorCheckbox) + + def setupMirroredCursor(self): + self.cursor = pg.ScatterPlotItem( + symbol="+", + pxMode=True, + pen=pg.mkPen("k", width=1), + brush=pg.mkBrush("w"), + size=16, + tip=None, + ) + self.Plot.addItem(self.cursor) + + def gui_connectActions(self): + self.exitAction.triggered.connect(self.close) + self.prevAction.triggered.connect(self.prev_frame) + self.nextAction.triggered.connect(self.next_frame) + self.jumpForwardAction.triggered.connect(self.skip10ahead_frames) + self.jumpBackwardAction.triggered.connect(self.skip10back_frames) + if self.enableOverlay: + self.overlayButton.toggled.connect(self.overlay_cb) + self.overlayButton.sigRightClick.connect(self.showOverlayContextMenu) + + def gui_setSingleFrameMode(self, isSingleFrame: bool): + if not isSingleFrame: + return + + self.framesScrollBar.setDisabled(True) + self.framesScrollBar.setVisible(False) + self.frameLabel.hide() + self.t_label.hide() + self.prevAction.triggered.disconnect() + self.nextAction.triggered.disconnect() + self.jumpForwardAction.triggered.disconnect() + self.jumpBackwardAction.triggered.disconnect() + self.editToolBar.setVisible(False) + + def showOverlayContextMenu(self, event): + if not self.overlayButton.isChecked(): + return + + if self.parent is not None: + self.overlayContextMenu.exec_(QCursor.pos()) + + def gui_createStatusBar(self): + self.statusbar = self.statusBar() + # Temporary message + self.statusbar.showMessage("Ready", 3000) + # Permanent widget + self.wcLabel = QLabel(f"") + self.statusbar.addPermanentWidget(self.wcLabel) + + def gui_createGraphics(self): + self.graphLayout = pg.GraphicsLayoutWidget() + + # Plot Item container for image + self.Plot = pg.PlotItem() + self.Plot.invertY(True) + self.Plot.setAspectLocked(True) + self.Plot.hideAxis("bottom") + self.Plot.hideAxis("left") + self.graphLayout.addItem(self.Plot, row=1, col=1) + + # Image Item + self.img = widgets.BaseImageItem() + self.img.setEnableAutoLevels(True) + self.Plot.addItem(self.img) + + # Image histogram + self.imgGrad = widgets.myHistogramLUTitem(isViewer=True) + self.imgGrad.gradient.showMenu = self.showLutItemOverlayContextMenu + self.imgGrad.vb.raiseContextMenu = lambda x: None + self.imgGrad.setImageItem(self.img) + self.graphLayout.addItem(self.imgGrad, row=1, col=0) + + # Current frame text + self.frameLabel = pg.LabelItem(justify="center", color="w", size="14pt") + self.frameLabel.setText(" ") + self.graphLayout.addItem(self.frameLabel, row=2, col=0, colspan=2) + + if not self.enableOverlay: + return + + def gui_createOverlayItems(self): + self.createOverlayChannelsActions() + self.overlayLayersItems = {} + for ch in self.posData.chNames: + if ch == self.parent.user_ch_name: + continue + overlayItems = self.getOverlayItems(ch) + imageItem, lutItem, alphaScrollbar = overlayItems + lutItem.vb.raiseContextMenu = lambda x: None + lutItem.gradient.showMenu = self.showLutItemOverlayContextMenu + lutItem.overlayColorButton.sigColorChanging.connect(self.updateOlColors) + self.addAlphaScrollbar(ch, imageItem, alphaScrollbar) + self.overlayLayersItems[ch] = overlayItems + self.Plot.addItem(imageItem) + + def createOverlayChannelsActions(self): + self.overlayLutItemAdditionalActions = [] + separator = QAction(self) + separator.setSeparator(True) + self.overlayLutItemAdditionalActions.append(separator) + section = self.imgGrad.gradient.menu.addSection("Select channel to adjust: ") + self.overlayLutItemAdditionalActions.append(section) + self.imgGrad.gradient.menu.removeAction(section) + + self.overlayChNamesActionGroup = QActionGroup(self) + self.overlayChNamesActionGroup.setExclusive(True) + for chName in self.posData.chNames: + action = QAction(chName, self) + action.setCheckable(True) + if chName == self.parent.user_ch_name: + action.setChecked(True) + self.overlayChNamesActionGroup.addAction(action) + self.overlayChNamesActionGroup.triggered.connect( + self.chNameGradientActionClicked + ) + + def chNameGradientActionClicked(self, action): + # Action triggered from lutItem + self.checkedOverlayChName = action.text() + if action.text() == self.posData.user_ch_name: + self.setOverlayItemsVisible("", False) + else: + self.setOverlayItemsVisible(action.text(), True) + + def showLutItemOverlayContextMenu(self, event): + lutItem = self.currentLutItem + + for action in self.overlayLutItemAdditionalActions: + try: + lutItem.gradient.menu.removeAction(action) + except Exception as e: + pass + + for action in self.overlayChNamesActionGroup.actions(): + try: + lutItem.gradient.menu.removeAction(action) + except Exception as e: + pass + + if self.overlayButton.isChecked(): + for action in self.overlayLutItemAdditionalActions: + lutItem.gradient.menu.addAction(action) + + for action in self.overlayChNamesActionGroup.actions(): + if action.text() == self.posData.user_ch_name: + lutItem.gradient.menu.addAction(action) + continue + for filename in self.posData.ol_data: + if filename.endswith(action.text()): + lutItem.gradient.menu.addAction(action) + break + if filename.endswith(f"{action.text()}_aligned"): + lutItem.gradient.menu.addAction(action) + break + + try: + # Convert QPointF to QPoint + lutItem.gradient.menu.popup(event.screenPos().toPoint()) + except AttributeError: + lutItem.gradient.menu.popup(event.screenPos()) + + def gui_connectImgActions(self): + self.img.hoverEvent = self.gui_hoverEventImg + + def gui_createImgWidgets(self): + if self.posData is None: + posData = self.parent.data[self.parent.pos_i] + else: + posData = self.posData + self.img_Widglayout = QGridLayout() + + # Frames scrollbar + self.framesScrollBar = QScrollBar(Qt.Horizontal) + # self.framesScrollBar.setFixedHeight(20) + self.framesScrollBar.setMinimum(1) + self.framesScrollBar.setMaximum(posData.SizeT) + t_label = QLabel("frame ") + _font = QFont() + _font.setPixelSize(12) + t_label.setFont(_font) + self.img_Widglayout.addWidget(t_label, 0, 0, alignment=Qt.AlignRight) + self.img_Widglayout.addWidget(self.framesScrollBar, 0, 1, 1, 20) + self.t_label = t_label + self.framesScrollBar.valueChanged.connect(self.framesScrollBarMoved) + + # z-slice scrollbar + self.zSliceScrollBar = QScrollBar(Qt.Horizontal) + # self.zSliceScrollBar.setFixedHeight(20) + self.zSliceScrollBar.setMaximum(self.posData.SizeZ - 1) + _z_label = QLabel("z-slice ") + _font = QFont() + _font.setPixelSize(12) + _z_label.setFont(_font) + self.z_label = _z_label + self.img_Widglayout.addWidget(_z_label, 1, 0, alignment=Qt.AlignCenter) + self.img_Widglayout.addWidget(self.zSliceScrollBar, 1, 1, 1, 20) + + if self.posData.SizeZ == 1: + self.zSliceScrollBar.setDisabled(True) + self.zSliceScrollBar.setVisible(False) + _z_label.setVisible(False) + + self.img_Widglayout.setContentsMargins(100, 0, 50, 0) + self.zSliceScrollBar.valueChanged.connect(self.update_z_slice) + + if self.enableOverlay: + self.setOverlayColors() + self.gui_createOverlayItems() + self.createOverlayContextMenu() + + self.img.alphaScrollbar = self.addAlphaScrollbar( + self.parent.user_ch_name, self.img + ) + + def getOverlayItems(self, channelName): + imageItem = pg.ImageItem() + imageItem.setOpacity(0.5) + + lutItem = widgets.myHistogramLUTitem(isViewer=True) + + lutItem.setImageItem(imageItem) + lutItem.vb.raiseContextMenu = lambda x: None + initColor = self.overlayRGBs.pop(0) + self.parent.initColormapOverlayLayerItem(initColor, lutItem) + lutItem.addOverlayColorButton(initColor, channelName) + lutItem.initColor = initColor + lutItem.hide() + + alphaScrollBar = self.addAlphaScrollbar(channelName, imageItem) + return imageItem, lutItem, alphaScrollBar + + def setMirroredCursorPos(self, x, y): + if not self.enableMirroredCursor: + return + + if not self.showMirroredCursorCheckbox.isChecked(): + return + + self.cursor.setData([x], [y]) + + def setOverlayColors(self): + self.overlayRGBs = [ + (255, 255, 0), + (252, 72, 254), + (49, 222, 134), + (22, 108, 27), + ] + cmap = matplotlib.colormaps["gist_rainbow"] + self.overlayRGBs.extend( + [tuple([round(c * 255) for c in cmap(i)][:3]) for i in np.linspace(0, 1, 8)] + ) + + def setOpacityOverlayLayersItems(self, value, imageItem=None): + if imageItem is None: + imageItem = self.sender().imageItem + alpha = value / self.sender().maximum() + else: + alpha = value + imageItem.setOpacity(alpha) + + def overlay_cb(self, checked): + if checked: + if self.posData.ol_data is None: + selectedChannels = self.askSelectOverlayChannel() + if selectedChannels is None: + self.overlayButton.toggled.disconnect() + self.overlayButton.setChecked(False) + self.overlayButton.toggled.connect(self.overlay_cb) + return + success = self.parent.loadOverlayData(selectedChannels) + if not success: + return False + lastChannel = selectedChannels[-1] + self.checkedOverlayChName = lastChannel + imageItem = self.overlayLayersItems[lastChannel][0] + self.setOpacityOverlayLayersItems(0.5, imageItem=imageItem) + self.img.setOpacity(0.5) + self.setCheckedOverlayContextMenusActions(selectedChannels) + else: + self.checkedOverlayChName = self.parent.imgGrad.checkedChannelname + selectedChannels = self.parent.checkedOverlayChannels + self.setCheckedOverlayContextMenusActions(selectedChannels) + self.setOverlayItemsVisible(self.checkedOverlayChName, True) + else: + self.img.setOpacity(1.0) + self.setOverlayItemsVisible("", False) + for items in self.overlayLayersItems.values(): + imageItem = items[0] + imageItem.clear() + self.update_img() + + def createOverlayContextMenu(self): + ch_names = [ + ch for ch in self.posData.chNames if ch != self.posData.user_ch_name + ] + self.overlayContextMenu = QMenu() + self.overlayContextMenu.addSeparator() + self.checkedOverlayChannels = set() + for chName in ch_names: + action = QAction(chName, self.overlayContextMenu) + action.setCheckable(True) + action.toggled.connect(self.overlayChannelToggled) + self.overlayContextMenu.addAction(action) + + def setCheckedOverlayContextMenusActions(self, channelNames): + for action in self.overlayContextMenu.actions(): + if action.text() not in channelNames: + continue + action.setChecked(True) + self.checkedOverlayChannels.add(action.text()) + + def overlayChannelToggled(self, checked): + # Action toggled from overlayButton context menu + channelName = self.sender().text() + if checked: + posData = self.posData + if channelName not in posData.loadedFluoChannels: + self.parent.loadOverlayData([channelName], addToExisting=True) + self.setOverlayItemsVisible(channelName, True) + self.checkedOverlayChannels.add(channelName) + self.updateOlColors(None) + else: + self.checkedOverlayChannels.remove(channelName) + imageItem = self.overlayLayersItems[channelName][0] + imageItem.clear() + try: + channelToShow = next(iter(self.checkedOverlayChannels)) + self.setOverlayItemsVisible(channelToShow, True) + except StopIteration: + self.setOverlayItemsVisible("", False) + self.update_img() + + def updateOlColors(self, button): + lutItem = self.overlayLayersItems[self.checkedOverlayChName][1] + rgb = lutItem.overlayColorButton.color().getRgb()[:3] + self.parent.initColormapOverlayLayerItem(rgb, lutItem) + lutItem.overlayColorButton.setColor(rgb) + + def addAlphaScrollbar(self, channelName, imageItem, alphaScrollBar=None): + if alphaScrollBar is None: + alphaScrollBar = QScrollBar(Qt.Horizontal) + label = QLabel(f"Alpha {channelName}") + label.setFont(font) + label.hide() + alphaScrollBar.imageItem = imageItem + alphaScrollBar.label = label + alphaScrollBar.setFixedHeight(self.parent.h) + alphaScrollBar.hide() + alphaScrollBar.setMinimum(0) + alphaScrollBar.setMaximum(40) + alphaScrollBar.setValue(20) + alphaScrollBar.setToolTip( + f"Control the alpha value of the overlaid channel {channelName}.\n" + "alpha=0 results in NO overlay,\n" + "alpha=1 results in only fluorescence data visible" + ) + self.img_Widglayout.addWidget( + alphaScrollBar.label, 2, 0, alignment=Qt.AlignRight + ) + self.img_Widglayout.addWidget(alphaScrollBar, 2, 1, 1, 20) + sp = alphaScrollBar.label.sizePolicy() + sp.setRetainSizeWhenHidden(True) + alphaScrollBar.label.setSizePolicy(sp) + + sp = alphaScrollBar.sizePolicy() + sp.setRetainSizeWhenHidden(True) + alphaScrollBar.setSizePolicy(sp) + + alphaScrollBar.valueChanged.connect(self.setOpacityOverlayLayersItems) + return alphaScrollBar + + def setOverlayItemsVisible(self, channelName, visible): + if visible: + self.imgGrad.hide() + self.img.alphaScrollbar.hide() + self.img.alphaScrollbar.label.hide() + try: + self.graphLayout.removeItem(self.imgGrad) + except Exception as e: + pass + itemsToShow = None + for name, items in self.overlayLayersItems.items(): + _, lutItem, alphaSB = items + if name == channelName: + itemsToShow = items + else: + lutItem.hide() + alphaSB.hide() + alphaSB.label.hide() + try: + self.graphLayout.removeItem(lutItem) + except Exception as e: + pass + + if itemsToShow is None: + self.graphLayout.addItem(self.imgGrad, row=1, col=0) + self.imgGrad.show() + self.currentLutItem = self.imgGrad + self.img.alphaScrollbar.show() + self.img.alphaScrollbar.label.show() + else: + _, lutItem, alphaSB = itemsToShow + lutItem.show() + alphaSB.show() + alphaSB.label.show() + self.currentLutItem = lutItem + self.graphLayout.addItem(lutItem, row=1, col=0) + else: + if self.overlayButton.isChecked(): + self.img.alphaScrollbar.show() + self.img.alphaScrollbar.label.show() + else: + self.img.alphaScrollbar.hide() + self.img.alphaScrollbar.label.hide() + for name, items in self.overlayLayersItems.items(): + _, lutItem, alphaSB = items + lutItem.hide() + alphaSB.hide() + alphaSB.label.hide() + try: + self.graphLayout.removeItem(lutItem) + except Exception as e: + pass + self.graphLayout.addItem(self.imgGrad, row=1, col=0) + self.imgGrad.show() + self.currentLutItem = self.imgGrad + + def framesScrollBarMoved(self, frame_n): + self.frame_i = frame_n - 1 + self.t_label.setText(f"frame n. {self.frame_i + 1}/{self.num_frames}") + if self.spinBox is not None: + self.spinBox.setValue(frame_n) + self.update_img() + + def gui_hoverEventImg(self, event): + # Update x, y, value label bottom right + try: + x, y = event.pos() + xdata, ydata = int(x), int(y) + _img = self.img.image + Y, X = _img.shape + if xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y: + val = _img[ydata, xdata] + self.wcLabel.setText(f"(x={x:.2f}, y={y:.2f}, value={val:.2f})") + else: + self.wcLabel.setText(f"") + except Exception as e: + self.wcLabel.setText(f"") + + emitHovering = ( + self.enableMirroredCursor and self.showMirroredCursorCheckbox.isChecked() + ) + if emitHovering: + if event.isExit(): + x, y = None, None + else: + x, y = event.pos() + self.sigHoveringImage.emit(x, y) + self.cursor.setData([], []) + + def next_frame(self): + if self.frame_i < self.num_frames - 1: + self.frame_i += 1 + else: + self.frame_i = 0 + self.update_img() + + def prev_frame(self): + if self.frame_i > 0: + self.frame_i -= 1 + else: + self.frame_i = self.num_frames - 1 + self.update_img() + + def skip10ahead_frames(self): + if self.frame_i < self.num_frames - 10: + self.frame_i += 10 + else: + self.frame_i = 0 + self.update_img() + + def skip10back_frames(self): + if self.frame_i > 9: + self.frame_i -= 10 + else: + self.frame_i = self.num_frames - 1 + self.update_img() + + def update_z_slice(self, z): + if self.posData is None: + posData = self.parent.data[self.parent.pos_i] + else: + posData = self.posData + idx = (posData.filename, posData.frame_i) + posData.segmInfo_df.at[idx, "z_slice_used_gui"] = z + + self.z_label.setText(f"z-slice {z + 1:02}/{posData.SizeZ}") + self.img.setCurrentZsliceIndex(z) + self.update_img() + + def getImage(self): + posData = self.posData + frame_i = self.frame_i + if posData.SizeZ > 1: + idx = (posData.filename, frame_i) + z = posData.segmInfo_df.at[idx, "z_slice_used_gui"] + zProjHow = posData.segmInfo_df.at[idx, "which_z_proj_gui"] + img = posData.img_data[frame_i] + if zProjHow == "single z-slice": + self.zSliceScrollBar.setSliderPosition(z) + self.z_label.setText(f"z-slice {z + 1:02}/{posData.SizeZ}") + img = img[z].copy() + elif zProjHow == "max z-projection": + img = img.max(axis=0).copy() + elif zProjHow == "mean z-projection": + img = img.mean(axis=0).copy() + elif zProjHow == "median z-proj.": + img = np.median(img, axis=0).copy() + else: + img = posData.img_data[frame_i].copy() + return img + + def update_img(self): + self.frameLabel.setText(f"Current frame = {self.frame_i + 1}/{self.num_frames}") + if self.parent is None: + img = self.getImage() + else: + img = self.parent.getImage(frame_i=self.frame_i, raw=True) + + self.img.setCurrentFrameIndex(self.frame_i) + self.img.setImage(img) + self.framesScrollBar.setSliderPosition(self.frame_i + 1) + + if not self.enableOverlay: + return + + if not self.overlayButton.isChecked(): + return + + self.setOverlayImages(frame_i=self.frame_i) + + def askSelectOverlayChannel(self): + ch_names = [ + ch for ch in self.posData.chNames if ch != self.posData.user_ch_name + ] + selectFluo = widgets.QDialogListbox( + "Select channel", + "Select channel names to overlay:\n", + ch_names, + multiSelection=True, + parent=self, + ) + selectFluo.exec_() + if selectFluo.cancel: + return + + return selectFluo.selectedItemsText + + def setOverlayImages(self, frame_i=None): + posData = self.posData + for filename in posData.ol_data: + chName = utils.get_chname_from_basename( + filename, posData.basename, remove_ext=False + ) + if chName not in self.checkedOverlayChannels: + continue + + imageItem = self.overlayLayersItems[chName][0] + ol_img = self.parent.getOlImg(filename, frame_i=frame_i) + imageItem.setImage(ol_img) + + def closeEvent(self, event): + if self.button_toUncheck is not None: + self.button_toUncheck.setChecked(False) + self.sigClosed.emit() + + def show(self, left=None, top=None): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + QMainWindow.show(self) + try: + self.framesScrollBar.setFixedHeight(self.parent.h) + except Exception as e: + pass + try: + self.zSliceScrollBar.setFixedHeight(self.parent.h) + except Exception as e: + pass + + try: + self.img.alphaScrollbar.setFixedHeight(self.parent.h) + except Exception as e: + pass + if left is not None and top is not None: + self.setGeometry(left, top, 850, 800) + + +class askStopFrameSegm(QDialog): + def __init__(self, user_ch_file_paths, user_ch_name, parent=None): + self.parent = parent + self.cancel = True + + super().__init__(parent) + self.setWindowTitle("Enter stop frame") + + self.visualizeWindows = [] + + mainLayout = QVBoxLayout() + buttonsLayout = QHBoxLayout() + + # Message + infoTxt = html_utils.paragraph(""" + Enter a stop frame number when to stop + segmentation for each Position loaded: + """) + infoLabel = QLabel(infoTxt, self) + infoLabel.setAlignment(Qt.AlignCenter) + # padding: top, left, bottom, right + infoLabel.setStyleSheet("padding:0px 0px 8px 0px;") + + self.dataDict = {} + + exp_path_pos_mapper = path.get_exp_path_pos_foldernames_mapper( + user_ch_file_paths + ) + + columnsLayout = QHBoxLayout() + mainScrollArea = widgets.ScrollArea() + mainScrollAreaWidget = QWidget() + mainScrollAreaWidget.setLayout(columnsLayout) + mainScrollArea.setWidget(mainScrollAreaWidget) + self.mainScrollArea = mainScrollArea + + # Form layout widget + self.spinBoxes = [] + self.tab_idx = 0 + iter_items = exp_path_pos_mapper.items() + self.groupboxScrollAreas = [] + + for col, (exp_path, pos_folders_files) in enumerate(iter_items): + groupboxScrollArea = widgets.ScrollArea() + self.groupboxScrollAreas.append(groupboxScrollArea) + groupbox = QGroupBox() + groupbox.setCheckable(False) + groupbox.setToolTip(exp_path) + groupboxLayout = QFormLayout() + groupbox.setLayout(groupboxLayout) + groupboxScrollArea.setWidget(groupbox) + columnsLayout.addWidget(groupboxScrollArea) + pos_folders = pos_folders_files["pos_foldernames"] + filenames = pos_folders_files["filenames"] + for i, pos_foldername in enumerate(pos_folders): + img_filename = filenames[i] + images_path = os.path.join(exp_path, pos_foldername, "Images") + img_path = os.path.join(images_path, img_filename) + spinBox = widgets.mySpinBox() + spinBox.sigTabEvent.connect(self.keyTabEventSpinbox) + posData = load.loadData(img_path, user_ch_name, QParent=parent) + posData.getBasenameAndChNames(qparent=self) + posData.buildPaths() + posData.loadOtherFiles( + load_segm_data=False, + load_metadata=True, + loadSegmInfo=True, + ) + spinBox.setMaximum(posData.SizeT) + stopFrameNum = posData.readLastUsedStopFrameNumber() + if stopFrameNum is None: + spinBox.setValue(posData.SizeT) + else: + spinBox.setValue(stopFrameNum) + spinBox.setAlignment(Qt.AlignCenter) + visualizeButton = widgets.viewPushButton("Visualize") + visualizeButton.clicked.connect(self.visualize_cb) + formLabel = QLabel(html_utils.paragraph(f"{pos_foldername} ")) + layout = QHBoxLayout() + layout.addWidget(formLabel, alignment=Qt.AlignRight) + layout.addWidget(spinBox) + layout.addWidget(visualizeButton) + self.dataDict[visualizeButton] = (spinBox, posData) + groupboxLayout.addRow(layout) + spinBox.idx = i + self.spinBoxes.append(spinBox) + + fm = QFontMetrics(self.font()) + elidedTitle = fm.elidedText( + exp_path, Qt.ElideLeft, groupbox.sizeHint().width() + ) + groupbox.setTitle(elidedTitle) + + mainLayout.addWidget(infoLabel, alignment=Qt.AlignCenter) + mainLayout.addWidget(mainScrollArea) + + okButton = widgets.okPushButton("Ok") + okButton.setShortcut(Qt.Key_Enter) + + cancelButton = widgets.cancelPushButton("Cancel") + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(okButton) + + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.close) + + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + # # self.setModal(True) + + def keyTabEventSpinbox(self, event, sender): + self.tab_idx += 1 + if self.tab_idx >= len(self.spinBoxes): + self.tab_idx = 0 + focusSpinbox = self.spinBoxes[self.tab_idx] + focusSpinbox.setFocus() + + def saveStopFrameNumbers(self): + for spinBox, posData in self.dataDict.values(): + posData.metadata_df.at["stop_frame_num", "values"] = spinBox.value() + posData.metadataToCsv() + + def ok_cb(self, event): + self.cancel = False + try: + self.saveStopFrameNumbers() + except Exception as err: + printl(traceback.format_exc()) + self.stopFrames = [ + spinBox.value() for spinBox, posData in self.dataDict.values() + ] + self.close() + + def closeEvent(self, event): + for window in self.visualizeWindows: + window.close() + + def visualize_cb(self, checked=True): + self.setDisabled(True) + spinBox, posData = self.dataDict[self.sender()] + print("Loading image data...") + posData.loadImgData() + posData.frame_i = spinBox.value() - 1 + win = plot.imshow( + posData.img_data, lut="gray", figure_title=posData.relPath, block=False + ) + self.visualizeWindows.append(win) + self.setDisabled(False) + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + screenSize = self.screen().size() + maxWidth = screenSize.width() - 50 + maxHeight = screenSize.height() - 100 + width, height = 0, 0 + for scrollArea in self.groupboxScrollAreas: + width += scrollArea.minimumWidthNoScrollbar() + scrollAreaHeight = scrollArea.minimumHeightNoScrollbar() + if scrollAreaHeight > height: + height = scrollAreaHeight + + width += 70 + height += self.sizeHint().height() - self.mainScrollArea.sizeHint().height() + + if width > maxWidth: + width = maxWidth + + if height > maxHeight: + height = maxHeight + + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show() + self.resize(width, height) + self.move(25, 50) + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class QLineEditDialog(QDialog): + def __init__( + self, + title="Entry messagebox", + msg="Entry value", + defaultTxt="", + parent=None, + allowedValues=None, + warnLastFrame=False, + isInteger=False, + isFloat=False, + stretchEntry=True, + allowEmpty=True, + allowedTextEntries=None, + allowText=False, + lastVisitedFrame=None, + allowList=False, + ): + QDialog.__init__(self, parent) + + self.loop = None + self.cancel = True + self.assignNewID = False + self.allowedValues = allowedValues + self.warnLastFrame = warnLastFrame + self.isFloat = isFloat + self.allowEmpty = allowEmpty + self.isInteger = isInteger + self.allowedTextEntries = allowedTextEntries + self.allowText = allowText + self.lastVisitedFrame = lastVisitedFrame + if allowedValues and warnLastFrame: + self.maxValue = max(allowedValues) + + self.setWindowTitle(title) + + # Layouts + mainLayout = QVBoxLayout() + LineEditLayout = QVBoxLayout() + buttonsLayout = QHBoxLayout() + + # Widgets + if not msg.startswith(" np.iinfo(np.uint32).max: + self.entryWidget.setText(str(np.iinfo(np.uint32).max)) + except Exception as e: + text = text.replace(newChar, "") + self.entryWidget.setText(text) + return + + if self.allowedValues is not None: + currentVal = self.value() + if self.allowList: + currentVal = currentVal[-1] + if currentVal not in self.allowedValues: + self.notValidLabel.setText(f"{currentVal} not existing!") + else: + self.notValidLabel.setText("") + + def warnValLessLastFrame(self, val): + msg = widgets.myMessageBox() + warn_txt = html_utils.paragraph(f""" + WARNING: saving until a frame number below the last visited + frame ({self.lastVisitedFrame}) will result in LOSS of information + about any edit or annotation you did on frames + {val + 1}-{self.lastVisitedFrame}.

+ Are you sure you want to proceed? + """) + msg.warning( + self, + "WARNING: Potential loss of information", + warn_txt, + buttonsTexts=("Cancel", "Yes, I am sure."), + ) + return msg.cancel + + def warnValMoreLastVisitedFrame(self, val): + msg = widgets.myMessageBox() + warn_txt = html_utils.paragraph(f""" + The last visited/validated frame is {self.lastVisitedFrame} + .

+ Are you sure you want to save until frame n. {val}?
+ """) + msg.warning( + self, + "Saving past last visited frame", + warn_txt, + buttonsTexts=("Cancel", "Yes, I am sure."), + ) + return msg.cancel + + def ok_cb(self, event): + if not self.allowEmpty and not self.entryWidget.text(): + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + msg.critical( + self, + "Empty text", + html_utils.paragraph("Text entry field cannot be empty"), + ) + return + if self.allowedTextEntries is not None: + if self.entryWidget.text() not in self.allowedTextEntries: + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + txt = html_utils.paragraph( + f'"{self.entryWidget.text()}" is not a valid entry.

' + "Valid entries are:
" + f"{html_utils.to_list(self.allowedTextEntries)}" + ) + msg.critical(self, "Not a valid entry", txt) + return + + if self.allowedValues: + if self.notValidLabel.text(): + return + + val = self.value() + + if self.warnLastFrame and self.lastVisitedFrame is not None: + if val < self.lastVisitedFrame: + cancel = self.warnValLessLastFrame(val) + if cancel: + return + + if self.lastVisitedFrame is not None: + if val > self.lastVisitedFrame: + cancel = self.warnValMoreLastVisitedFrame(val) + if cancel: + return + + self.cancel = False + try: + self.EntryID = int(val) + except Exception as err: + self.EntryID = val + + self.enteredValue = val + self.close() + + def cancel_cb(self, event): + self.cancel = True + self.close() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show() + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class QtSelectItems(QDialog): + def __init__( + self, + title, + items, + informativeText, + CbLabel="Select value: ", + parent=None, + showInFileManagerPath=None, + ): + self.cancel = True + self.selectedItemsText = "" + self.selectedItemsIdx = None + self.showInFileManagerPath = showInFileManagerPath + self.items = items + super().__init__(parent) + self.setWindowTitle(title) + + mainLayout = QVBoxLayout() + topLayout = QHBoxLayout() + self.topLayout = topLayout + bottomLayout = QHBoxLayout() + + stretchRow = 0 + if informativeText: + infoLabel = QLabel(informativeText) + mainLayout.addWidget(infoLabel, alignment=Qt.AlignCenter) + stretchRow = 1 + + label = QLabel(CbLabel) + topLayout.addWidget(label, alignment=Qt.AlignRight) + + combobox = QComboBox(self) + combobox.addItems(items) + self.ComboBox = combobox + topLayout.addWidget(combobox) + + okButton = widgets.okPushButton("Ok") + cancelButton = widgets.cancelPushButton("Cancel") + if showInFileManagerPath is not None: + txt = utils.get_open_filemaneger_os_string() + showInFileManagerButton = widgets.showInFileManagerButton(txt) + + bottomLayout.addStretch(1) + bottomLayout.addWidget(cancelButton) + bottomLayout.addSpacing(20) + if showInFileManagerPath is not None: + bottomLayout.addWidget(showInFileManagerButton) + bottomLayout.addWidget(okButton) + + multiPosButton = QPushButton("Multiple selection") + multiPosButton.setCheckable(True) + self.multiPosButton = multiPosButton + bottomLayout.addWidget(multiPosButton, alignment=Qt.AlignLeft) + + listBox = widgets.listWidget() + listBox.addItems(items) + listBox.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) + listBox.setCurrentRow(0) + listBox.setFont(font) + topLayout.addWidget(listBox) + listBox.hide() + self.ListBox = listBox + + mainLayout.addLayout(topLayout) + mainLayout.addSpacing(20) + mainLayout.addLayout(bottomLayout) + + self.setLayout(mainLayout) + self.mainLayout = mainLayout + self.topLayout = topLayout + + # self.setModal(True) + + # Connect events + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.close) + multiPosButton.toggled.connect(self.toggleMultiSelection) + if showInFileManagerPath is not None: + showInFileManagerButton.clicked.connect(self.showInFileManager) + + self.setFont(font) + + def setSelectedItems(self, selectedItemsText): + if self.multiPosButton.isChecked(): + for i in range(self.ListBox.count()): + item = self.ListBox.item(i) + if item.text() in selectedItemsText: + item.setSelected(True) + else: + idx = self.items.index(selectedItemsText[0]) + self.ComboBox.setCurrentIndex(idx) + + def showInFileManager(self): + selectedTexts, _ = self.getSelectedItems() + folder = selectedTexts[0].split("(")[0].strip() + path = os.path.join(self.showInFileManagerPath, folder) + if os.path.exists(path) and os.path.isdir(path): + showPath = path + else: + showPath = self.showInFileManagerPath + utils.showInExplorer(showPath) + + def toggleMultiSelection(self, checked): + if checked: + self.multiPosButton.setText("Single selection") + self.ComboBox.hide() + self.ListBox.show() + # Show 10 items + n = self.ListBox.count() + if n > 10: + h = sum([self.ListBox.sizeHintForRow(i) for i in range(10)]) + else: + h = sum([self.ListBox.sizeHintForRow(i) for i in range(n)]) + self.ListBox.setMinimumHeight(h + 5) + self.ListBox.setFocusPolicy(Qt.StrongFocus) + self.ListBox.setFocus() + self.ListBox.setCurrentRow(0) + self.mainLayout.setStretchFactor(self.topLayout, 2) + else: + self.multiPosButton.setText("Multiple selection") + self.ListBox.hide() + self.ComboBox.show() + self.resize(self.width(), self.singleSelectionHeight) + + def getSelectedItems(self): + if self.multiPosButton.isChecked(): + selectedItems = self.ListBox.selectedItems() + selectedItemsText = [item.text() for item in selectedItems] + selectedItemsText = natsorted(selectedItemsText) + selectedItemsIdx = [self.items.index(txt) for txt in selectedItemsText] + else: + selectedItemsText = [self.ComboBox.currentText()] + selectedItemsIdx = [self.ComboBox.currentIndex()] + return selectedItemsText, selectedItemsIdx + + def ok_cb(self, event): + self.cancel = False + self.selectedItemsText, self.selectedItemsIdx = self.getSelectedItems() + self.close() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show() + self.singleSelectionHeight = self.height() + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class SelectSegmFileDialog(QDialog): + def __init__( + self, + images_ls, + parent_path, + parent=None, + addNewFileButton=False, + basename="", + infoText=None, + fileType="segmentation", + allowMultipleSelection=False, + custom_first=None, + ): + self.cancel = True + self.selectedItemText = "" + self.selectedItemIdx = None + self.removeOthers = False + self.okAllPos = False + self.newSegmEndName = None + self.allowMultipleSelection = allowMultipleSelection + self.basename = basename + images_ls = sorted(images_ls, key=len) + if custom_first is not None: + images_ls.remove(custom_first) + images_ls.insert(0, custom_first) + + # Remove the 'segm_' part to allow filenameDialog to check if + # a new file is existing (since we only ask for the part after + # 'segm_') + self.existingEndNames = [ + n.replace("segm", "", 1).replace("_", "", 1) for n in images_ls + ] + + self.images_ls = images_ls + self.parent_path = parent_path + super().__init__(parent) + + informativeText = html_utils.paragraph(f""" + The loaded Position folders already contains + {len(self.existingEndNames)} {fileType} masks
+ """) + + self.setWindowTitle(f"{fileType.capitalize()} files detected") + is_win = sys.platform.startswith("win") + + mainLayout = QVBoxLayout() + infoLayout = QHBoxLayout() + selectionLayout = QGridLayout() + buttonsLayout = QHBoxLayout() + + # Standard Qt Question icon + label = QLabel() + standardIcon = getattr(QStyle, "SP_MessageBoxQuestion") + icon = self.style().standardIcon(standardIcon) + pixmap = icon.pixmap(60, 60) + label.setPixmap(pixmap) + infoLayout.addWidget(label) + + infoLabel = QLabel(informativeText) + infoLayout.addWidget(infoLabel) + infoLayout.addStretch(1) + mainLayout.addLayout(infoLayout) + + if infoText is None: + infoText = f"Select which {fileType} file to load:" + + questionText = html_utils.paragraph(infoText) + label = QLabel(questionText) + listWidget = widgets.listWidget() + listWidget.addItems(images_ls) + listWidget.setCurrentRow(0) + listWidget.itemDoubleClicked.connect(self.listDoubleClicked) + if allowMultipleSelection: + listWidget.setSelectionMode( + QAbstractItemView.SelectionMode.ExtendedSelection + ) + self.items = list(images_ls) + self.listWidget = listWidget + + okButton = widgets.okPushButton(" Load selected ") + txt = "Reveal in Finder..." if is_mac else "Show in Explorer..." + showInFileManagerButton = widgets.showInFileManagerButton(txt) + cancelButton = widgets.cancelPushButton(" Cancel ") + + if addNewFileButton: + newFileButton = widgets.newFilePushButton("New file...") + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addWidget(showInFileManagerButton) + buttonsLayout.addSpacing(20) + if addNewFileButton: + buttonsLayout.addWidget(newFileButton) + buttonsLayout.addWidget(okButton) + + buttonsLayout.setContentsMargins(0, 10, 0, 10) + + selectionLayout.addWidget(label, 0, 1, alignment=Qt.AlignLeft) + selectionLayout.addWidget(listWidget, 1, 1) + selectionLayout.setColumnStretch(0, 0) + selectionLayout.setColumnStretch(1, 1) + selectionLayout.setColumnStretch(2, 0) + selectionLayout.addLayout(buttonsLayout, 2, 1) + + mainLayout.addLayout(selectionLayout) + self.setLayout(mainLayout) + + self.okButton = okButton + + # Connect events + okButton.clicked.connect(self.ok_cb) + if addNewFileButton: + newFileButton.clicked.connect(self.newFile_cb) + cancelButton.clicked.connect(self.close) + showInFileManagerButton.clicked.connect(self.showInFileManager) + + def listDoubleClicked(self, item): + self.ok_cb() + + def showInFileManager(self, checked=True): + utils.showInExplorer(self.parent_path) + + def newFile_cb(self): + win = filenameDialog( + basename=f"{self.basename}segm", + hintText="Insert a filename for the segmentation file:", + existingNames=self.existingEndNames, + ) + win.exec_() + if win.cancel: + return + self.cancel = False + self.newSegmEndName = win.entryText + self.close() + + def setSelectedItemFromText(self, itemText): + for i in range(self.listWidget.count()): + if self.listWidget.item(i).text() == itemText: + self.listWidget.setCurrentRow(i) + break + + def ok_cb(self, event=None): + self.cancel = False + try: + self.selectedItemText = self.listWidget.selectedItems()[0].text() + except IndexError: + self.cancel = True + self.close() + return + self.selectedItemIdx = self.items.index(self.selectedItemText) + self.selectedItemTexts = [ + selectedItem.text() for selectedItem in self.listWidget.selectedItems() + ] + self.close() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show() + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class QDialogPbar(QDialog): + def __init__(self, title="Progress", infoTxt="", parent=None): + self.workerFinished = False + self.aborted = False + self.clickCount = 0 + super().__init__(parent) + + abort_text = "Option+Command+C" if is_mac else "Ctrl+Alt+C" + self.abort_text = abort_text + + self.setWindowTitle(f"{title} ({abort_text} to abort)") + self.setWindowFlags(Qt.Window) + + mainLayout = QVBoxLayout() + pBarLayout = QGridLayout() + + if infoTxt: + infoLabel = QLabel(infoTxt) + mainLayout.addWidget(infoLabel, alignment=Qt.AlignCenter) + + self.progressLabel = QLabel() + + self.QPbar = widgets.ProgressBar(self) + pBarLayout.addWidget(self.QPbar, 0, 0) + self.ETA_label = QLabel("NDh:NDm:NDs") + pBarLayout.addWidget(self.ETA_label, 0, 1) + + self.metricsQPbar = widgets.ProgressBar(self) + self.metricsQPbar.setValue(0) + pBarLayout.addWidget(self.metricsQPbar, 1, 0) + + # pBarLayout.setColumnStretch(2, 1) + + mainLayout.addWidget(self.progressLabel) + mainLayout.addLayout(pBarLayout) + + self.setLayout(mainLayout) + # self.setModal(True) + + def keyPressEvent(self, event): + isCtrlAlt = event.modifiers() == (Qt.ControlModifier | Qt.AltModifier) + if isCtrlAlt and event.key() == Qt.Key_C: + doAbort = self.askAbort() + if doAbort: + self.aborted = True + self.workerFinished = True + self.close() + + def askAbort(self): + msg = widgets.myMessageBox() + txt = html_utils.paragraph(f""" + Aborting with {self.abort_text} to abort + is not safe.

+ The system status cannot be predicted and + it will require a restart.

+ Are you sure you want to abort? + """) + yesButton, noButton = msg.critical( + self, "Are you sure you want to abort?", txt, buttonsTexts=("Yes", "No") + ) + return msg.clickedButton == yesButton + + def abort(self): + self.clickCount += 1 + self.aborted = True + if self.clickCount > 3: + self.workerFinished = True + self.close() + + def closeEvent(self, event): + if not self.workerFinished: + event.ignore() + + +class pgTestWindow(QWidget): + def __init__(self, parent=None): + super().__init__(parent) + + layout = QVBoxLayout() + + self.graphLayout = pg.GraphicsLayoutWidget() + self.ax1 = pg.PlotItem() + self.ax1.setAspectLocked(True) + self.graphLayout.addItem(self.ax1) + + layout.addWidget(self.graphLayout) + + self.setLayout(layout) + + +def get_existing_directory(allow_images_path=True, **kwargs): + while True: + folder_path = qtpy.compat.getexistingdirectory(**kwargs) + if not folder_path: + return + + if allow_images_path: + return folder_path + + pos_folderpath = os.path.dirname(folder_path) + is_images_folder = ( + folder_path.endswith("Images") + and os.path.basename(pos_folderpath).startswith("Position_") + and os.path.isdir(folder_path) + ) + if not is_images_folder: + return folder_path + + txt = html_utils.paragraph( + "You cannot save to the Images folder " + "because it is reserved to files that start with the same " + "basename.

Thank you for your patience!" + ) + msg = widgets.myMessageBox() + msg.warning(kwargs["parent"], "Cannot save here", txt) + + +class SetCustomLevelsLut(QBaseDialog): + sigLevelsChanged = Signal(object) + + def __init__( + self, + init_min_value=None, + init_max_value=None, + minimum_min_value=0, + maximum_max_value=None, + parent=None, + ): + super().__init__(parent=parent) + + self.cancel = True + + self.setWindowTitle("Custom LUT levels") + + layout = QVBoxLayout() + + self.minLevelSlider = widgets.sliderWithSpinBox( + title="Minimum", + title_loc="top", + ) + self.minLevelSlider.setMinimum(minimum_min_value) + + if init_min_value is not None: + self.minLevelSlider.setValue(init_min_value) + + layout.addWidget(self.minLevelSlider) + + self.maxLevelSlider = widgets.sliderWithSpinBox( + title="Maximum", + title_loc="top", + ) + self.maxLevelSlider.setMinimum(minimum_min_value) + if init_max_value is not None: + self.maxLevelSlider.setValue(init_max_value) + + if maximum_max_value is not None: + self.maxLevelSlider.setMaximum(maximum_max_value) + self.minLevelSlider.setMaximum(maximum_max_value) + + layout.addWidget(self.maxLevelSlider) + + self.minLevelSlider.sigValueChange.connect(self.emitLevelsChanged) + self.maxLevelSlider.sigValueChange.connect(self.emitLevelsChanged) + + buttonsLayout = widgets.CancelOkButtonsLayout() + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + layout.addSpacing(20) + layout.addLayout(buttonsLayout) + + self.setLayout(layout) + + def sizeHint(self): + heightHint = super().sizeHint().height() + widthHint = super().sizeHint().width() * 2 + return QSize(widthHint, heightHint) + + def levels(self): + levels = (self.minLevelSlider.value(), self.maxLevelSlider.value()) + return levels + + def emitLevelsChanged(self, value): + self.sigLevelsChanged.emit(self.levels()) + + def ok_cb(self): + self.cancel = False + self.selectedLevels = self.levels() + self.close() + + +class QTreeDialog(QBaseDialog): + def __init__( + self, + items: List[Tuple[str]], + headerLabels: List[str] = None, + parent=None, + infoText="Select item", + title="Select item", + path_to_browse=None, + additional_buttons=None, + ): + self.cancel = True + super().__init__(parent) + + self.setWindowTitle(title) + + mainLayout = QVBoxLayout() + + infoLabel = QLabel(html_utils.paragraph(infoText)) + + self.treeWidget = widgets.TreeWidget() + if headerLabels is not None: + self.treeWidget.setHeaderLabels(headerLabels) + else: + self.treeWidget.setHeaderHidden(True) + + for row, texts in enumerate(items): + item = widgets.TreeWidgetItem(self.treeWidget) + for i, text in enumerate(texts): + item.setText(i, text) + self.treeWidget.addTopLevelItem(item) + + self.treeWidget.resizeColumnToContents(0) + self.treeWidget.resizeColumnToContents(1) + + # self.treeWidget.header().setStretchLastSection(False) + + buttonsLayout = widgets.CancelOkButtonsLayout() + + if path_to_browse is not None: + browseButton = widgets.showInFileManagerButton(setDefaultText=True) + browseButton.setPathToBrowse(path_to_browse) + buttonsLayout.insertWidget(3, browseButton) + + if additional_buttons is not None: + for btn in additional_buttons: + buttonsLayout.insertWidget(3, btn) + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addWidget(infoLabel) + mainLayout.addWidget(self.treeWidget) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + def show(self, block=False): + w = self.sizeHint().width() + h = self.sizeHint().height() + self.resize(int(w * 1.3), h) + super().show(block=block) + + def ok_cb(self): + self.clickedButton = self.sender() + self.cancel = False + self.selectedItem = self.treeWidget.currentItem() + self.selectedText = self.selectedItem.text(0) + self.close() + +# Sibling imports (deferred to avoid import cycles) +from .metadata import ( + filenameDialog, +) + diff --git a/cellacdc/dialogs/measurements.py b/cellacdc/dialogs/measurements.py new file mode 100644 index 000000000..9a82a63e1 --- /dev/null +++ b/cellacdc/dialogs/measurements.py @@ -0,0 +1,2979 @@ +"""Cell-ACDC dialog windows: measurements.""" + +import os +import sys +import re +from typing import Literal, Callable, Dict, Iterable, List, Tuple +import datetime +import pathlib +from collections import defaultdict +import zipfile +from heapq import nlargest +import matplotlib +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +from matplotlib.patches import Rectangle, Circle, PathPatch, Path +import numpy as np +import scipy.interpolate + +try: + import tkinter as tk +except Exception as err: + pass + +import cv2 +import traceback +from itertools import combinations, permutations +from collections import namedtuple +from natsort import natsorted + +# from MyWidgets import Slider, Button, MyRadioButtons +from skimage.measure import label, regionprops +from functools import partial +import skimage.filters +import skimage.measure +import skimage.morphology +import skimage.exposure +import skimage.draw +import skimage.registration +import skimage.color +import skimage.segmentation +from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk +import matplotlib.pyplot as plt +import seaborn as sns +import pandas as pd +import math +import time +import sympy as sp +import json +import html + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from qtpy import QtCore +from qtpy.QtGui import ( + QIcon, + QFontMetrics, + QKeySequence, + QFont, + QRegularExpressionValidator, + QCursor, + QKeyEvent, + QPixmap, + QFont, + QPalette, + QMouseEvent, + QColor, +) +from qtpy.QtCore import ( + Qt, + QSize, + QEvent, + Signal, + QEventLoop, + QTimer, + QRegularExpression, +) +from qtpy.QtWidgets import ( + QFileDialog, + QApplication, + QMainWindow, + QMenu, + QLabel, + QToolBar, + QScrollBar, + QWidget, + QVBoxLayout, + QLineEdit, + QPushButton, + QHBoxLayout, + QDialog, + QFormLayout, + QListWidget, + QAbstractItemView, + QButtonGroup, + QCheckBox, + QSizePolicy, + QComboBox, + QSlider, + QGridLayout, + QSpinBox, + QToolButton, + QTableView, + QTextBrowser, + QDoubleSpinBox, + QScrollArea, + QFrame, + QProgressBar, + QGroupBox, + QRadioButton, + QDockWidget, + QMessageBox, + QStyle, + QPlainTextEdit, + QSpacerItem, + QTreeWidget, + QTreeWidgetItem, + QTextEdit, + QSplashScreen, + QAction, + QListWidgetItem, + QActionGroup, + QHeaderView, + QStyledItemDelegate, +) +import qtpy.compat + +from .. import exception_handler +from .. import load, prompts, core, measurements, html_utils +from .. import is_mac, is_win, is_linux, settings_folderpath, config +from .. import preproc_recipes_path, segm_recipes_path, combine_channels_recipes_path +from .. import is_conda_env +from .. import printl +from .. import colors +from .. import issues_url +from .. import utils +from .. import qutils +from .. import _palettes +from .. import base_cca_dict +from .. import widgets +from .. import user_profile_path, promptable_models_path, models_path +from .. import features +from .. import _core +from .. import _types +from .. import plot +from .. import urls +from ..acdc_regex import float_regex, is_alphanumeric_filename, to_alphanumeric +from .. import _base_widgets +from .. import io +from .. import cca_functions +from .. import path + +POSITIVE_FLOAT_REGEX = float_regex(allow_negative=False) +TREEWIDGET_STYLESHEET = _palettes.TreeWidgetStyleSheet() +LISTWIDGET_STYLESHEET = _palettes.ListWidgetStyleSheet() +BACKGROUND_RGBA = _palettes.get_disabled_colors()["Button"] + +font = QFont() +font.setPixelSize(12) +italicFont = QFont() +italicFont.setPixelSize(12) +italicFont.setItalic(True) + +from ._base import ( + QBaseDialog, +) + +class SetMeasurementsDialog(QBaseDialog): + sigClosed = Signal() + sigCancel = Signal() + sigRestart = Signal() + + def __init__( + self, + loadedChNames, + notLoadedChNames, + isZstack, + isSegm3D, + favourite_funcs=None, + parent=None, + allPos_acdc_df_cols=None, + acdc_df_path=None, + posData=None, + addCombineMetricCallback=None, + allPosData=None, + is_concat=False, + isSingleSelection=False, + state=None, + ): + super().__init__(parent=parent) + + self.checkBoxedGroup = QButtonGroup() + self.checkBoxedGroup.setExclusive(isSingleSelection) + + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + + self.cancel = True + + self.delExistingCols = False + self.okClicked = False + self.is_concat = is_concat + self.allPos_acdc_df_cols = allPos_acdc_df_cols + self.acdc_df_path = acdc_df_path + self.allPosData = allPosData + self.doNotWarn = False + + self.setWindowTitle("Set measurements") + # self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + + layout = QVBoxLayout() + + searchLayout = QHBoxLayout() + + searchLineEdit = widgets.SearchLineEdit() + searchLayout.addStretch(5) + searchLayout.addWidget(searchLineEdit) + searchLayout.setStretch(1, 3) + + mainScrollArea = widgets.ScrollArea() + mainScrollAreaWidget = QWidget() + mainScrollArea.setWidget(mainScrollAreaWidget) + + groupsLayout = QGridLayout() + self.groupsLayout = groupsLayout + + mainScrollAreaWidget.setLayout(groupsLayout) + + buttonsLayout = QHBoxLayout() + + self.chNameGroupboxes = [] + self.all_metrics = [] + + col = 0 + for col, chName in enumerate(loadedChNames): + channelGBox = widgets.channelMetricsQGBox( + isZstack, + chName, + isSegm3D, + favourite_funcs=favourite_funcs, + posData=posData, + is_concat=is_concat, + ) + channelGBox.chName = chName + groupsLayout.addWidget(channelGBox, 0, col, 3, 1) + self.chNameGroupboxes.append(channelGBox) + channelGBox.sigDelClicked.connect(self.delMixedChannelCombineMetric) + channelGBox.sigCheckboxToggled.connect(self.channelCheckboxToggled) + groupsLayout.setColumnStretch(col, 5) + self.all_metrics.extend([c.text() for c in channelGBox.checkBoxes]) + + current_col = col + 1 + for col, chName in enumerate(notLoadedChNames): + channelGBox = widgets.channelMetricsQGBox( + isZstack, + chName, + isSegm3D, + favourite_funcs=favourite_funcs, + posData=posData, + is_concat=is_concat, + ) + channelGBox.setChecked(False) + channelGBox.chName = chName + groupsLayout.addWidget(channelGBox, 0, current_col, 3, 1) + self.chNameGroupboxes.append(channelGBox) + groupsLayout.setColumnStretch(current_col, 5) + channelGBox.sigDelClicked.connect(self.delMixedChannelCombineMetric) + channelGBox.sigCheckboxToggled.connect(self.channelCheckboxToggled) + current_col += 1 + self.all_metrics.extend([c.text() for c in channelGBox.checkBoxes]) + + current_col += 1 + + if posData is None: + isTimelapse = False + else: + isTimelapse = posData.SizeT > 1 + size_metrics_desc = measurements.get_size_metrics_desc(isSegm3D, isTimelapse) + if not isSegm3D: + size_metrics_desc = { + key: val + for key, val in size_metrics_desc.items() + if not key.endswith("_3D") + } + + row = 0 + sizeMetricsQGBox = widgets._metricsQGBox( + size_metrics_desc, + "Physical measurements", + favourite_funcs=favourite_funcs, + isZstack=isZstack, + addCalcForEachZsliceToggle=isSegm3D, + ) + self.all_metrics.extend([c.text() for c in sizeMetricsQGBox.checkBoxes]) + self.sizeMetricsQGBox = sizeMetricsQGBox + for sizeCheckbox in sizeMetricsQGBox.checkBoxes: + sizeCheckbox.toggled.connect(self.sizeMetricToggled) + groupsLayout.addWidget(sizeMetricsQGBox, row, current_col) + groupsLayout.setRowStretch(0, 1) + groupsLayout.setColumnStretch(current_col, 3) + row += 1 + + props_info_txt_mapper = measurements.get_props_info_txt_mapper( + isSegm3D=isSegm3D + ) + rp_desc = props_info_txt_mapper + regionPropsQGBox = widgets._metricsQGBox( + rp_desc, + "Morphological properties", + favourite_funcs=favourite_funcs, + isZstack=isZstack, + ) + self.regionPropsQGBox = regionPropsQGBox + for rpCheckbox in regionPropsQGBox.checkBoxes: + rpCheckbox.toggled.connect(self.rpMetricToggled) + groupsLayout.addWidget(regionPropsQGBox, row, current_col) + groupsLayout.setRowStretch(1, 2) + self.all_metrics.extend([c.text() for c in regionPropsQGBox.checkBoxes]) + row += 1 + + # Custom metrics that are channel indipendent + self.chIndipendCustomeMetricsQGBox = None + out = measurements.ch_indipend_custom_metrics_desc( + isZstack, + isSegm3D=isSegm3D, + ) + ch_indipend_custom_metrics_desc = out + if ch_indipend_custom_metrics_desc: + self.chIndipendCustomeMetricsQGBox = widgets._metricsQGBox( + ch_indipend_custom_metrics_desc, + "Channel indipendent custom measurements", + favourite_funcs=favourite_funcs, + isZstack=isZstack, + parent=self, + ) + groupsLayout.addWidget(self.chIndipendCustomeMetricsQGBox, row, current_col) + groupsLayout.setRowStretch(1, 1) + row += 1 + + desc, equations = measurements.combine_mixed_channels_desc( + isSegm3D=isSegm3D, posData=posData, available_cols=self.all_metrics + ) + self.mixedChannelsCombineMetricsQGBox = None + if desc: + self.mixedChannelsCombineMetricsQGBox = widgets._metricsQGBox( + desc, + "Mixed channels combined measurements", + favourite_funcs=favourite_funcs, + isZstack=isZstack, + equations=equations, + addDelButton=True, + ) + self.mixedChannelsCombineMetricsQGBox.sigDelClicked.connect( + self.delMixedChannelCombineMetric + ) + groupsLayout.addWidget( + self.mixedChannelsCombineMetricsQGBox, row, current_col + ) + groupsLayout.setRowStretch(1, 1) + if not self.is_concat: + self.setDisabledMetricsRequestedForCombined(False) + self.mixedChannelsCombineMetricsQGBox.toggled.connect( + self.setDisabledMetricsRequestedForCombined + ) + for combCheckbox in self.mixedChannelsCombineMetricsQGBox.checkBoxes: + combCheckbox.toggled.connect( + self.setDisabledMetricsRequestedForCombined + ) + else: + for combCheckbox in self.mixedChannelsCombineMetricsQGBox.checkBoxes: + combCheckbox.toggled.connect(self.mixedChannelsMetricToggled) + row += 1 + + self.last_row = row + self.last_col = current_col + + okButton = widgets.okPushButton(" Ok ") + cancelButton = widgets.cancelPushButton("Cancel") + if addCombineMetricCallback is not None: + addCombineMetricButton = widgets.addPushButton( + "Add combined measurement..." + ) + addCombineMetricButton.clicked.connect(addCombineMetricCallback) + self.okButton = okButton + + loadLastSelButton = widgets.reloadPushButton("Load last selection...") + self.deselectAllButton = QPushButton("Deselect all") + self.deselectAllButton.setIcon(QIcon(":deselect_all.svg")) + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(self.deselectAllButton) + buttonsLayout.addSpacing(20) + + if addCombineMetricCallback is not None: + buttonsLayout.addWidget(addCombineMetricButton) + buttonsLayout.addSpacing(20) + + saveCurrentSelectionButton = widgets.savePushButton("Save current selection...") + saveCurrentSelectionButton.clicked.connect(self.saveCurrentSelectionClicked) + + buttonsLayout.addWidget(saveCurrentSelectionButton) + + loadSavedSelectionButton = widgets.OpenFilePushButton("Load saved selection...") + loadSavedSelectionButton.clicked.connect(self.loadSavedSelectionClicked) + buttonsLayout.addWidget(loadSavedSelectionButton) + + buttonsLayout.addWidget(loadLastSelButton) + + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(okButton) + + self.okButton = okButton + + layout.addLayout(searchLayout) + layout.addSpacing(10) + # layout.addLayout(groupsLayout) + layout.addWidget(mainScrollArea) + layout.addLayout(buttonsLayout) + + self.setLayout(layout) + + if state is not None: + self.setState(state) + + searchLineEdit.textEdited.connect(self.searchAndHighlight) + self.deselectAllButton.clicked.connect(self.deselectAll) + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.close) + loadLastSelButton.clicked.connect(self.loadLastSelection) + + self.addCheckboxesToGroup() + + for channelGBox in self.chNameGroupboxes: + for checkbox in channelGBox.checkBoxes: + self.channelCheckboxToggled(checkbox) + + def allMetricsDict(self): + all_metrics = { + "standard": {}, + "regionprop": [], + "size": [], + "mixed_channels": [], + } + for chNameGroupbox in self.chNameGroupboxes: + channel_name = chNameGroupbox.chName + for checkBox in chNameGroupbox.checkBoxes: + if channel_name not in all_metrics["standard"]: + all_metrics["standard"][channel_name] = [] + all_metrics["standard"][channel_name].append(checkBox.text()) + + for checkBox in self.regionPropsQGBox.checkBoxes: + all_metrics["regionprop"].append(checkBox.text()) + + for checkBox in self.sizeMetricsQGBox.checkBoxes: + all_metrics["size"].append(checkBox.text()) + + if self.chIndipendCustomeMetricsQGBox is not None: + checkBoxes = self.chIndipendCustomeMetricsQGBox.checkBoxes + for checkBox in checkBoxes: + all_metrics["ch_indipend_custom_metric"].append(checkBox.text()) + + if self.mixedChannelsCombineMetricsQGBox is None: + return + + checkBoxes = self.mixedChannelsCombineMetricsQGBox.checkBoxes + for checkBox in checkBoxes: + all_metrics["mixed_channels"].append(checkBox.text()) + + return all_metrics + + def searchAndHighlight(self, text): + for chNameGroupbox in self.chNameGroupboxes: + for groupbox in chNameGroupbox.groupboxes: + groupbox.highlightCheckboxesFromSearchText(text) + + self.regionPropsQGBox.highlightCheckboxesFromSearchText(text) + self.sizeMetricsQGBox.highlightCheckboxesFromSearchText(text) + + if self.chIndipendCustomeMetricsQGBox is not None: + self.chIndipendCustomeMetricsQGBox.highlightCheckboxesFromSearchText(text) + + if self.mixedChannelsCombineMetricsQGBox is None: + return + + self.mixedChannelsCombineMetricsQGBox.highlightCheckboxesFromSearchText(text) + + def selectedMetricNameAndGroup(self): + for chNameGroupbox in self.chNameGroupboxes: + for checkBox in chNameGroupbox.checkBoxes: + if checkBox.isChecked(): + return checkBox.text(), {"standard": chNameGroupbox.chName} + + for checkBox in self.regionPropsQGBox.checkBoxes: + if checkBox.isChecked(): + return checkBox.text(), "regionprop" + + for checkBox in self.sizeMetricsQGBox.checkBoxes: + if checkBox.isChecked(): + return checkBox.text(), "size" + + if self.chIndipendCustomeMetricsQGBox is not None: + checkBoxes = self.chIndipendCustomeMetricsQGBox.checkBoxes + for checkBox in checkBoxes: + if checkBox.isChecked(): + return checkBox.text(), "ch_indipend_custom_metric" + + if self.mixedChannelsCombineMetricsQGBox is None: + return + + checkBoxes = self.mixedChannelsCombineMetricsQGBox.checkBoxes + for checkBox in checkBoxes: + if checkBox.isChecked(): + return checkBox.text(), "mixed_channels" + + def selectedMetricGroup(self): + for chNameGroupbox in self.chNameGroupboxes: + for checkBox in chNameGroupbox.checkBoxes: + if checkBox.isChecked(): + return checkBox.text() + + for checkBox in self.regionPropsQGBox.checkBoxes: + if checkBox.isChecked(): + return checkBox.text() + + for checkBox in self.sizeMetricsQGBox.checkBoxes: + if checkBox.isChecked(): + return checkBox.text() + + if self.chIndipendCustomeMetricsQGBox is not None: + checkBoxes = self.chIndipendCustomeMetricsQGBox.checkBoxes + for checkBox in checkBoxes: + if checkBox.isChecked(): + return checkBox.text() + + if self.mixedChannelsCombineMetricsQGBox is None: + return + + checkBoxes = self.mixedChannelsCombineMetricsQGBox.checkBoxes + for checkBox in checkBoxes: + if checkBox.isChecked(): + return checkBox.text() + + def addCheckboxesToGroup(self): + for chNameGroupbox in self.chNameGroupboxes: + for checkBox in chNameGroupbox.checkBoxes: + self.checkBoxedGroup.addButton(checkBox) + + for checkBox in self.regionPropsQGBox.checkBoxes: + self.checkBoxedGroup.addButton(checkBox) + + for checkBox in self.sizeMetricsQGBox.checkBoxes: + self.checkBoxedGroup.addButton(checkBox) + + if self.chIndipendCustomeMetricsQGBox is not None: + checkBoxes = self.chIndipendCustomeMetricsQGBox.checkBoxes + for checkBox in checkBoxes: + self.checkBoxedGroup.addButton(checkBox) + + if self.mixedChannelsCombineMetricsQGBox is None: + return + + checkBoxes = self.mixedChannelsCombineMetricsQGBox.checkBoxes + for checkBox in checkBoxes: + self.checkBoxedGroup.addButton(checkBox) + + def channelCheckboxToggled(self, checkbox): + # Make sure to automatically check the requested cell_vol metric for + # concentration metrics + if checkbox.text().find("concentration_") == -1: + return + + if self.is_concat: + # When this dialogue is used in concatenate pos utility we do not + # need to check that certain metrics are present + return + + pattern = r".+_from_vol_([a-z]+)(_3D)?(_?[A-Za-z0-9]*)" + repl = r"cell_vol_\1\2" + cell_vol_metric_name = re.sub(pattern, repl, checkbox.text()) + for sizeCheckbox in self.sizeMetricsQGBox.checkBoxes: + if sizeCheckbox.text() == cell_vol_metric_name: + break + else: + # Make sure to not check for similarly named custom metrics + return + + if checkbox.isChecked(): + sizeCheckbox.setChecked(True) + sizeCheckbox.isRequired = True + else: + # Do not enable cell vol checkbox is any of the other + # concentration metrics requiring it is checked + unit = cell_vol_metric_name[9:] + is3D = unit.endswith("3D") + for channelGBox in self.chNameGroupboxes: + if not channelGBox.isChecked(): + continue + for _checkbox in channelGBox.checkBoxes: + if _checkbox.text().find(f"_from_vol_{unit}") == -1: + continue + if not is3D and _checkbox.text().find(f"{unit}_3D") != -1: + # Metric is 3D but the cell_vol is not + continue + if _checkbox.isChecked(): + return + sizeCheckbox.isRequired = False + + def rpMetricToggled(self, checked): + pass + + def mixedChannelsMetricToggled(self, checked): + pass + + def sizeMetricToggled(self, checked): + """Method called when a checkbox of a size metric is toggled. + Check if the size value is required and explain why it cannot be + unchecked. + + Parameters + ---------- + checked : bool + State of the checkbox toggled + """ + checkbox = self.sender() + + if self.is_concat: + # When this dialogue is used in concatenate pos utility we do not + # need to check that certain metrics are present + return + + if not hasattr(checkbox, "isRequired"): + return + + if not checkbox.isRequired: + return + + if checkbox.isChecked(): + return + + checkbox.setChecked(True) + + if self.doNotWarn: + return + + linked_autoBkgr_metric = checkbox.text().replace("cell", "_autoBkgr_from") + linked_dataPrepBkgr_metric = checkbox.text().replace( + "cell", "_dataPrepBkgr_from" + ) + txt = html_utils.paragraph(f""" + This physical measurement cannot be unchecked + because it is required + by the {linked_autoBkgr_metric} and + {linked_dataPrepBkgr_metric} measurements + that you requested to save.

+ + Thank you for you patience! + """) + msg = widgets.myMessageBox(showCentered=False) + msg.warning(self, "Physical measurement required", txt) + + def deselectAll(self): + self.doNotWarn = True + for chNameGroupbox in self.chNameGroupboxes: + for gb in chNameGroupbox.groupboxes: + gb.checkAll(None, False) + cgb = getattr(chNameGroupbox, "customMetricsQGBox", None) + if cgb is not None: + cgb.checkAll(None, False) + + self.sizeMetricsQGBox.checkAll(None, False) + self.regionPropsQGBox.checkAll(None, False) + if self.chIndipendCustomeMetricsQGBox is not None: + self.chIndipendCustomeMetricsQGBox.checkAll(None, False) + + if self.mixedChannelsCombineMetricsQGBox is not None: + self.mixedChannelsCombineMetricsQGBox.checkAll(None, False) + self.doNotWarn = False + + def delMixedChannelCombineMetric(self, colname_to_del, hlayout): + cp = measurements.read_saved_user_combine_config() + for section in cp.sections(): + cp.remove_option(section, colname_to_del) + measurements.save_common_combine_metrics(cp) + + for i in range(hlayout.count()): + item = hlayout.itemAt(i) + w = item.widget() + if w is None: + continue + w.hide() + + if self.allPosData is not None: + for posData in self.allPosData: + _config = posData.combineMetricsConfig + for section in _config.sections(): + _config.remove_option(section, colname_to_del) + posData.saveCombineMetrics() + + def setState(self, state): + self.doNotWarn = True + for chNameGroupbox in self.chNameGroupboxes: + measurementsInfo = state.get(chNameGroupbox.title()) + if not measurementsInfo: + chNameGroupbox.setChecked(False) + else: + for checkBox in chNameGroupbox.checkBoxes: + colname = checkBox.text() + checkBox.setChecked(measurementsInfo[colname]) + + measurementsInfo = state.get(self.sizeMetricsQGBox.title()) + if not measurementsInfo: + self.sizeMetricsQGBox.setChecked(False) + else: + for checkBox in self.sizeMetricsQGBox.checkBoxes: + checked = checkBox.isChecked() + colname = checkBox.text() + checkBox.setChecked(measurementsInfo[colname]) + + measurementsInfo = state.get(self.regionPropsQGBox.title()) + if not measurementsInfo: + self.regionPropsQGBox.setChecked(False) + else: + self.regionPropsToSave = [] + for checkBox in self.regionPropsQGBox.checkBoxes: + checked = checkBox.isChecked() + colname = checkBox.text() + checkBox.setChecked(measurementsInfo[colname]) + + if self.chIndipendCustomeMetricsQGBox is not None: + measurementsInfo = state.get(self.chIndipendCustomeMetricsQGBox.title()) + if not measurementsInfo: + self.chIndipendCustomeMetricsQGBox.setChecked(False) + else: + checkBoxes = self.chIndipendCustomeMetricsQGBox.checkBoxes + for checkBox in checkBoxes: + checked = checkBox.isChecked() + colname = checkBox.text() + key = self.chIndipendCustomeMetricsQGBox.title() + checkBox.setChecked(measurementsInfo[colname]) + + if self.mixedChannelsCombineMetricsQGBox is not None: + measurementsInfo = state.get(self.mixedChannelsCombineMetricsQGBox.title()) + if not measurementsInfo: + self.mixedChannelsCombineMetricsQGBox.setChecked(False) + else: + checkBoxes = self.mixedChannelsCombineMetricsQGBox.checkBoxes + for checkBox in checkBoxes: + checked = checkBox.isChecked() + colname = checkBox.text() + key = self.mixedChannelsCombineMetricsQGBox.title() + checkBox.setChecked(measurementsInfo[colname]) + + self.doNotWarn = False + + def state(self): + state = {self.sizeMetricsQGBox.title(): {}, self.regionPropsQGBox.title(): {}} + for chNameGroupbox in self.chNameGroupboxes: + state[chNameGroupbox.title()] = {} + if not chNameGroupbox.isChecked(): + # Channel unchecked + continue + else: + for checkBox in chNameGroupbox.checkBoxes: + colname = checkBox.text() + state[chNameGroupbox.title()][colname] = checkBox.isChecked() + + if not self.sizeMetricsQGBox.isChecked(): + pass + else: + for checkBox in self.sizeMetricsQGBox.checkBoxes: + checked = checkBox.isChecked() + colname = checkBox.text() + state[self.sizeMetricsQGBox.title()][colname] = checked + + if not self.regionPropsQGBox.isChecked(): + pass + else: + self.regionPropsToSave = [] + for checkBox in self.regionPropsQGBox.checkBoxes: + checked = checkBox.isChecked() + colname = checkBox.text() + state[self.regionPropsQGBox.title()][colname] = checked + + if self.chIndipendCustomeMetricsQGBox is not None: + state[self.chIndipendCustomeMetricsQGBox.title()] = {} + if self.chIndipendCustomeMetricsQGBox.isChecked(): + checkBoxes = self.chIndipendCustomeMetricsQGBox.checkBoxes + for checkBox in checkBoxes: + checked = checkBox.isChecked() + key = self.chIndipendCustomeMetricsQGBox.title() + colname = checkBox.text() + state[key][colname] = checked + + if self.mixedChannelsCombineMetricsQGBox is not None: + state[self.mixedChannelsCombineMetricsQGBox.title()] = {} + if self.mixedChannelsCombineMetricsQGBox.isChecked(): + checkBoxes = self.mixedChannelsCombineMetricsQGBox.checkBoxes + for checkBox in checkBoxes: + checked = checkBox.isChecked() + key = self.mixedChannelsCombineMetricsQGBox.title() + colname = checkBox.text() + state[key][colname] = checked + + return state + + def restoreState(self, state): + for chNameGroupbox in self.chNameGroupboxes: + _state = state.get(chNameGroupbox.title()) + if _state is None or not _state: + continue + for checkBox in chNameGroupbox.checkBoxes: + isChecked = _state.get(checkBox.text()) + if isChecked is None: + continue + checkBox.setChecked(isChecked) + + _state = state.get(self.sizeMetricsQGBox.title()) + if _state is None or not _state: + pass + else: + for checkBox in self.sizeMetricsQGBox.checkBoxes: + isChecked = _state.get(checkBox.text()) + if isChecked is None: + continue + checkBox.setChecked(isChecked) + + _state = state.get(self.regionPropsQGBox.title()) + if _state is None or not _state: + pass + else: + for checkBox in self.regionPropsQGBox.checkBoxes: + isChecked = _state.get(checkBox.text()) + if isChecked is None: + continue + checkBox.setChecked(isChecked) + + if self.chIndipendCustomeMetricsQGBox is not None: + _state = state.get(self.chIndipendCustomeMetricsQGBox.title()) + if _state is None or not _state: + pass + else: + for checkBox in self.chIndipendCustomeMetricsQGBox.checkBoxes: + isChecked = _state.get(checkBox.text()) + if isChecked is None: + continue + checkBox.setChecked(isChecked) + + if self.mixedChannelsCombineMetricsQGBox is not None: + _state = state.get(self.mixedChannelsCombineMetricsQGBox.title()) + if _state is None or not _state: + pass + else: + for checkBox in self.mixedChannelsCombineMetricsQGBox.checkBoxes: + isChecked = _state.get(checkBox.text()) + if isChecked is None: + continue + checkBox.setChecked(isChecked) + + def currentSelectionMapper(self): + current_selected_meas = defaultdict(dict) + + for chNameGroupbox in self.chNameGroupboxes: + if not chNameGroupbox.isChecked(): + continue + + chName = chNameGroupbox.chName + for checkBox in chNameGroupbox.checkBoxes: + if not checkBox.isChecked(): + continue + + current_selected_meas[chName][checkBox.text()] = "Yes" + + size_selected_meas = current_selected_meas.get(self.sizeMetricsQGBox.title()) + if self.sizeMetricsQGBox.isChecked(): + for checkBox in self.sizeMetricsQGBox.checkBoxes: + if not checkBox.isChecked(): + continue + + section = self.sizeMetricsQGBox.title() + current_selected_meas[section][checkBox.text()] = "Yes" + + size_selected_meas = current_selected_meas.get(self.regionPropsQGBox.title()) + if self.regionPropsQGBox.isChecked(): + for checkBox in self.regionPropsQGBox.checkBoxes: + if not checkBox.isChecked(): + continue + + section = self.regionPropsQGBox.title() + current_selected_meas[section][checkBox.text()] = "Yes" + + if self.chIndipendCustomeMetricsQGBox is not None: + if self.chIndipendCustomeMetricsQGBox.isChecked(): + for checkBox in self.chIndipendCustomeMetricsQGBox.checkBoxes: + if not checkBox.isChecked(): + continue + + section = self.chIndipendCustomeMetricsQGBox.title() + current_selected_meas[section][checkBox.text()] = "Yes" + + if self.mixedChannelsCombineMetricsQGBox is not None: + if self.mixedChannelsCombineMetricsQGBox.isChecked(): + for checkBox in self.mixedChannelsCombineMetricsQGBox.checkBoxes: + if not checkBox.isChecked(): + continue + + section = self.mixedChannelsCombineMetricsQGBox.title() + current_selected_meas[section][checkBox.text()] = "Yes" + + return current_selected_meas + + def saveCurrentSelectionClicked(self): + current_selection_mapper = self.currentSelectionMapper() + defaultEntry = "_and_".join(current_selection_mapper.keys()) + defaultEntry = defaultEntry.replace(" ", "_").lower() + saved_selections = io.get_saved_measurements_selections() + win = filenameDialog( + basename="", + ext="", + hintText="Insert a name for the current selection:", + existingNames=saved_selections, + allowEmpty=False, + defaultEntry=defaultEntry, + ) + win.exec_() + if win.cancel: + return + + filename = win.filename + ini_filepath = io.save_measurements_selections( + filename, current_selection_mapper + ) + + msg = widgets.myMessageBox(wrapText=False, showCentered=False) + txt = html_utils.paragraph(f""" + Done!

+ Current selection saved with name {filename} at + the following path: + """) + msg.information( + self, + "Selection saved", + txt, + commands=(ini_filepath,), + path_to_browse=os.path.dirname(ini_filepath), + ) + + def loadSavedSelectionClicked(self): + self.doNotWarn = True + + saved_selections = io.get_saved_measurements_selections() + + selectNameWin = widgets.QDialogListbox( + "Choose selection to load", + "Choose selection to load:\n", + saved_selections, + multiSelection=False, + parent=self, + ) + selectNameWin.exec_() + if selectNameWin.cancel: + return + + selection_mapper = io.read_measurements_selections( + selectNameWin.selectedItemsText[0] + ) + + self.setCurrentSelectionFromMapper(selection_mapper) + + self.doNotWarn = False + + def saveLastSelection(self): + last_selected_meas = self.currentSelectionMapper() + load.write_last_selected_set_measurements(last_selected_meas) + + def setCurrentSelectionFromMapper(self, selection_mapper): + for chNameGroupbox in self.chNameGroupboxes: + chName = chNameGroupbox.chName + chSelectedMeas = selection_mapper.get(chName) + if chSelectedMeas is None: + chNameGroupbox.setChecked(False) + continue + + chNameGroupbox.setChecked(True) + for checkBox in chNameGroupbox.checkBoxes: + checked = chSelectedMeas.get(checkBox.text()) + if checked is not None: + checkBox.setChecked(True) + else: + checkBox.setChecked(False) + + size_selected_meas = selection_mapper.get(self.sizeMetricsQGBox.title()) + if size_selected_meas is None: + self.sizeMetricsQGBox.setChecked(False) + else: + self.sizeMetricsQGBox.setChecked(True) + for checkBox in self.sizeMetricsQGBox.checkBoxes: + checked = size_selected_meas.get(checkBox.text()) + if checked is not None: + checkBox.setChecked(True) + else: + checkBox.setChecked(False) + + size_selected_meas = selection_mapper.get(self.regionPropsQGBox.title()) + if size_selected_meas is None: + self.regionPropsQGBox.setChecked(False) + else: + self.regionPropsQGBox.setChecked(True) + for checkBox in self.regionPropsQGBox.checkBoxes: + checked = size_selected_meas.get(checkBox.text()) + if checked is not None: + checkBox.setChecked(True) + else: + checkBox.setChecked(False) + + if self.chIndipendCustomeMetricsQGBox is not None: + ch_indip_custom_metrics = selection_mapper.get( + self.chIndipendCustomeMetricsQGBox.title() + ) + if size_selected_meas is None: + self.chIndipendCustomeMetricsQGBox.setChecked(False) + else: + self.chIndipendCustomeMetricsQGBox.setChecked(True) + for checkBox in self.chIndipendCustomeMetricsQGBox.checkBoxes: + checked = size_selected_meas.get(checkBox.text()) + if checked is not None: + checkBox.setChecked(True) + else: + checkBox.setChecked(False) + + if self.mixedChannelsCombineMetricsQGBox is not None: + ch_indip_custom_metrics = selection_mapper.get( + self.mixedChannelsCombineMetricsQGBox.title() + ) + if size_selected_meas is None: + self.mixedChannelsCombineMetricsQGBox.setChecked(False) + else: + self.mixedChannelsCombineMetricsQGBox.setChecked(True) + for checkBox in self.mixedChannelsCombineMetricsQGBox.checkBoxes: + checked = size_selected_meas.get(checkBox.text()) + if checked is not None: + checkBox.setChecked(True) + else: + checkBox.setChecked(False) + + def loadLastSelection(self): + self.doNotWarn = True + last_selected_meas = load.read_last_selected_set_measurements() + last_selected_meas = dict(last_selected_meas) + + self.setCurrentSelectionFromMapper(last_selected_meas) + + self.doNotWarn = False + + def setDisabledMetricsRequestedForCombined(self, checked): + checkbox = self.sender() + + if self.is_concat: + # When this dialogue is used in concatenate pos utility we do not + # need to check that certain metrics are present + return + + # Set checked and disable those metrics that are requested for + # combined measurements + allCheckboxes = [] + + for chNameGroupbox in self.chNameGroupboxes: + for chCheckBox in chNameGroupbox.checkBoxes: + chCheckBox.setDisabled(False) + allCheckboxes.append(chCheckBox) + + for sizeCheckBox in self.sizeMetricsQGBox.checkBoxes: + sizeCheckBox.setDisabled(False) + allCheckboxes.append(chCheckBox) + + for rpCheckBox in self.regionPropsQGBox.checkBoxes: + rpCheckBox.setDisabled(False) + allCheckboxes.append(chCheckBox) + + if not self.mixedChannelsCombineMetricsQGBox.isChecked(): + return + + for cb in allCheckboxes: + metricName = cb.text() + for combCheckbox in self.mixedChannelsCombineMetricsQGBox.checkBoxes: + equation = combCheckbox.equation + if equation.find(metricName) == -1: + continue + elif combCheckbox.isChecked(): + cb.setChecked(True) + cb.setDisabled(True) + cb.setToolTip( + "This metric cannot be removed because it is required " + f'by the combined measurement "{combCheckbox.text()}"' + ) + + def keyPressEvent(self, a0: QKeyEvent) -> None: + state = self.state() + return super().keyPressEvent(a0) + + def closeEvent(self, event): + if self.cancel: + self.sigCancel.emit() + super().closeEvent(event) + + def restart(self): + self.cancel = False + self.close() + self.sigRestart.emit() + + def setDisabledNotExistingMeasurements(self, existing_colnames): + self.existing_colnames = existing_colnames + for chNameGroupbox in self.chNameGroupboxes: + for checkBox in chNameGroupbox.checkBoxes: + colname = checkBox.text() + if colname in existing_colnames: + checkBox.setChecked(True) + continue + + checkBox.setChecked(False) + checkBox.setDisabled(True) + self.setNotExistingMeasurementTooltip(checkBox) + + for checkBox in self.sizeMetricsQGBox.checkBoxes: + colname = checkBox.text() + if colname in existing_colnames: + checkBox.setChecked(True) + continue + checkBox.setChecked(False) + checkBox.setDisabled(True) + self.setNotExistingMeasurementTooltip(checkBox) + + for checkBox in self.regionPropsQGBox.checkBoxes: + prop_name = checkBox.text() + for existing_col in existing_colnames: + if prop_name == existing_col: + checkBox.setChecked(True) + break + m = re.match(rf"{prop_name}-\d", existing_col) + if m is not None: + checkBox.setChecked(True) + break + else: + checkBox.setChecked(False) + checkBox.setDisabled(True) + self.setNotExistingMeasurementTooltip(checkBox) + + if self.mixedChannelsCombineMetricsQGBox is None: + return + + for combCheckbox in self.mixedChannelsCombineMetricsQGBox.checkBoxes: + colname = combCheckbox.text() + if colname in existing_colnames: + combCheckbox.setChecked(True) + continue + combCheckbox.setChecked(False) + combCheckbox.setDisabled(True) + self.setNotExistingMeasurementTooltip(combCheckbox) + + def addNonMeasurementColumns(self, colnames): + additionalCols = measurements.get_non_measurements_cols( + colnames, self.all_metrics + ) + if not additionalCols: + return + self.nonMeasurementsGroupbox = widgets.CheckboxesGroupBox( + additionalCols, title="Additional columns", checkable=True + ) + self.groupsLayout.addWidget( + self.nonMeasurementsGroupbox, 0, self.last_col + 1, self.last_row + 1, 1 + ) + + def setNotExistingMeasurementTooltip(self, checkBox): + checkBox.setToolTip( + "Measurement is disabled because it is not present in selected " + "acdc_output tables, hence it cannot be addded to concatenated " + "table. " + ) + + def ok_cb(self): + for chNameGroupbox in self.chNameGroupboxes: + chNameGroupbox.calcForEachZsliceRequested = ( + chNameGroupbox.isCalcForEachZsliceRequested() + ) + + self.sizeMetricsQGBox.calcForEachZsliceRequested = ( + self.sizeMetricsQGBox.isCalcForEachZsliceRequested() + ) + + if self.allPos_acdc_df_cols is None: + self.saveLastSelection() + self.cancel = False + self.close() + self.sigClosed.emit() + return + + self.okClicked = True + existing_colnames = self.allPos_acdc_df_cols + unchecked_existing_colnames = [] + unchecked_existing_rps = [] + for chNameGroupbox in self.chNameGroupboxes: + for checkBox in chNameGroupbox.checkBoxes: + colname = checkBox.text() + is_existing = colname in existing_colnames + if not chNameGroupbox.isChecked() and is_existing: + unchecked_existing_colnames.append(colname) + continue + if not checkBox.isChecked() and is_existing: + unchecked_existing_colnames.append(colname) + + for checkBox in self.sizeMetricsQGBox.checkBoxes: + colname = checkBox.text() + is_existing = colname in existing_colnames + if not self.sizeMetricsQGBox.isChecked() and is_existing: + unchecked_existing_colnames.append(colname) + continue + + if not checkBox.isChecked() and is_existing: + unchecked_existing_colnames.append(colname) + for checkBox in self.regionPropsQGBox.checkBoxes: + colname = checkBox.text() + is_existing = any([col == colname for col in existing_colnames]) + if not self.regionPropsQGBox.isChecked() and is_existing: + unchecked_existing_rps.append(colname) + continue + + if not checkBox.isChecked() and is_existing: + unchecked_existing_rps.append(colname) + + if unchecked_existing_colnames or unchecked_existing_rps: + cancel, self.delExistingCols = self.warnUncheckedExistingMeasurements( + unchecked_existing_colnames, unchecked_existing_rps + ) + self.existingUncheckedColnames = unchecked_existing_colnames + self.existingUncheckedRps = unchecked_existing_rps + if cancel: + return + + self.saveLastSelection() + self.cancel = False + self.close() + self.sigClosed.emit() + + def warnUncheckedExistingMeasurements( + self, unchecked_existing_colnames, unchecked_existing_rps + ): + msg = widgets.myMessageBox() + msg.setWidth(500) + msg.addShowInFileManagerButton(self.acdc_df_path) + txt = html_utils.paragraph( + "You chose to not save some measurements that are " + "already present in the saved acdc_output.csv " + "file.

" + "Do you want to delete these measurements or " + "keep them?

" + "Existing measurements not selected:" + ) + listView = widgets.readOnlyQList(msg) + items = unchecked_existing_colnames.copy() + items.extend(unchecked_existing_rps) + listView.addItems(items) + _, delButton, keepButton = msg.warning( + self, + "Unchecked existing measurements", + txt, + widgets=listView, + buttonsTexts=("Cancel", "Delete", "Keep"), + ) + return msg.cancel, msg.clickedButton == delButton + + def show(self, block=False): + super().show(block=False) + self.deselectAllButton.setMinimumHeight(self.okButton.height()) + screenWidth = self.screen().size().width() + screenHeight = self.screen().size().height() + screenLeft = self.screen().geometry().x() + screenTop = self.screen().geometry().y() + h = screenHeight - 200 + minColWith = screenWidth / 5 + w = minColWith * (self.last_col + 1) + xLeft = int((screenWidth - w) / 2) + if w > screenWidth: + self.move(screenLeft + 10, screenTop + 50) + self.resize(screenWidth - 20, h) + else: + self.move(screenLeft + xLeft, screenTop + 50) + self.resize(int(w), h) + super().show(block=block) + + +class ComputeMetricsErrorsDialog(QBaseDialog): + def __init__(self, errorsDict, log_path="", parent=None, log_type="custom_metrics"): + super().__init__(parent) + + self.errorsDict = errorsDict + + layout = QGridLayout() + + self.setWindowTitle("Errors summary") + + label = QLabel(self) + standardIcon = getattr(QStyle, "SP_MessageBoxWarning") + icon = self.style().standardIcon(standardIcon) + pixmap = icon.pixmap(60, 60) + label.setPixmap(pixmap) + layout.addWidget(label, 0, 0, alignment=Qt.AlignTop) + + if log_type == "custom_metrics": + infoText = """ + When computing custom metrics the following metrics + were ignored because they raised an error.

+ """ + elif log_type == "standard_metrics": + infoText = """ + Some or all of the standard metrics were NOT saved + because Cell-ACDC encoutered the following errors.

+ """ + elif log_type == "region_props": + rp_url = "https://scikit-image.org/docs/0.18.x/api/skimage.measure.html#skimage.measure.regionprops" + rp_href = f'skimage.measure.regionprops' + infoText = f""" + Region properties were NOT saved because Cell-ACDC + encoutered the following errors.
+ Region properties are calculated using the scikit-image + function called {rp_href}.

+ """ + elif log_type == "missing_annot": + infoText = """ + The following Positions were SKIPPED because they did + not have cell cycle annotations.

+ To add lineage tree information you first need to do the + cell cycle analysis in module 3 "Main GUI".

+ """ + else: + infoText = """ + Process raised the errors listed below.

+ """ + + github_issues_href = f"here" + noteText = f""" + NOTE: If you need help understanding these errors you can + open an issue on our github page {github_issues_href}. + """ + + infoLabel = QLabel(html_utils.paragraph(f"{infoText}{noteText}")) + infoLabel.setOpenExternalLinks(True) + layout.addWidget(infoLabel, 0, 1) + + scrollArea = QScrollArea() + scrollAreaWidget = QWidget() + textLayout = QVBoxLayout() + for func_name, traceback_format in errorsDict.items(): + nameLabel = QLabel(f"{func_name}: ") + errorMessage = f"\n{traceback_format}" + errorLabel = QLabel(errorMessage) + errorLabel.setTextInteractionFlags( + Qt.TextSelectableByMouse | Qt.TextSelectableByKeyboard + ) + # errorLabel.setStyleSheet("background-color: white") + errorLabel.setFrameShape(QFrame.Shape.Panel) + errorLabel.setFrameShadow(QFrame.Shadow.Sunken) + textLayout.addWidget(nameLabel) + textLayout.addWidget(errorLabel) + textLayout.addStretch(1) + + scrollAreaWidget.setLayout(textLayout) + scrollArea.setWidget(scrollAreaWidget) + + layout.addWidget(scrollArea, 1, 1) + + buttonsLayout = QHBoxLayout() + showLogButton = widgets.showInFileManagerButton("Show log file...") + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(showLogButton) + + copyButton = widgets.copyPushButton("Copy error message") + copyButton.clicked.connect(self.copyErrorMessage) + buttonsLayout.addWidget(copyButton) + self.copyButton = copyButton + self.copyButton.text = "Copy error message" + self.copyButton.icon = self.copyButton.icon() + + okButton = widgets.okPushButton(" Ok ") + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(okButton) + + showLogButton.clicked.connect(partial(utils.showInExplorer, log_path)) + okButton.clicked.connect(self.close) + layout.setVerticalSpacing(10) + layout.addLayout(buttonsLayout, 2, 1) + + self.setLayout(layout) + self.setFont(font) + + def copyErrorMessage(self): + cb = QApplication.clipboard() + cb.clear(mode=cb.Clipboard) + copiedText = "" + for _, traceback_format in self.errorsDict.items(): + errorBlock = f"{'=' * 30}\n{traceback_format}{'*' * 30}" + copiedText = f"{copiedText}{errorBlock}" + cb.setText(copiedText, mode=cb.Clipboard) + print("Error message copied.") + self.copyButton.setIcon(QIcon(":okButton.svg")) + self.copyButton.setText(" Copied to clipboard!") + QTimer.singleShot(2000, self.restoreCopyButton) + + def restoreCopyButton(self): + self.copyButton.setText(self.copyButton.text) + self.copyButton.setIcon(self.copyButton.icon) + + def showEvent(self, a0) -> None: + self.copyButton.setFixedWidth(self.copyButton.width()) + return super().showEvent(a0) + + +class combineMetricsEquationDialog(QBaseDialog): + sigOk = Signal(object) + + def __init__( + self, allChNames, isZstack, isSegm3D, parent=None, debug=False, closeOnOk=True + ): + super().__init__(parent) + + self.setWindowTitle("Add combined measurement") + + self.initAttributes() + + self.allChNames = allChNames + + self.cancel = True + self.isOperatorMode = False + self.closeOnOk = closeOnOk + + mainLayout = QVBoxLayout() + equationLayout = QHBoxLayout() + + metricsTreeWidget = QTreeWidget() + metricsTreeWidget.setHeaderHidden(True) + metricsTreeWidget.setFont(font) + self.metricsTreeWidget = metricsTreeWidget + + for chName in allChNames: + channelTreeItem = QTreeWidgetItem(metricsTreeWidget) + channelTreeItem.setText(0, f"{chName} measurements") + metricsTreeWidget.addTopLevelItem(channelTreeItem) + + metrics_desc, bkgr_val_desc = measurements.standard_metrics_desc( + isZstack, chName, isSegm3D=isSegm3D + ) + custom_metrics_desc = measurements.custom_metrics_desc( + isZstack, chName, isSegm3D=isSegm3D + ) + + foregrMetricsTreeItem = QTreeWidgetItem(channelTreeItem) + foregrMetricsTreeItem.setText(0, "Cell signal measurements") + channelTreeItem.addChild(foregrMetricsTreeItem) + + bkgrMetricsTreeItem = QTreeWidgetItem(channelTreeItem) + bkgrMetricsTreeItem.setText(0, "Background values") + channelTreeItem.addChild(bkgrMetricsTreeItem) + + if custom_metrics_desc: + customMetricsTreeItem = QTreeWidgetItem(channelTreeItem) + customMetricsTreeItem.setText(0, "Custom measurements") + channelTreeItem.addChild(customMetricsTreeItem) + + self.addTreeItems(foregrMetricsTreeItem, metrics_desc.keys(), isCol=True) + self.addTreeItems(bkgrMetricsTreeItem, bkgr_val_desc.keys(), isCol=True) + + if custom_metrics_desc: + self.addTreeItems( + customMetricsTreeItem, custom_metrics_desc.keys(), isCol=True + ) + + self.addChannelLessItems(isZstack, isSegm3D=isSegm3D) + + sizeMetricsTreeItem = QTreeWidgetItem(metricsTreeWidget) + sizeMetricsTreeItem.setText(0, "Size measurements") + metricsTreeWidget.addTopLevelItem(sizeMetricsTreeItem) + + size_metrics_desc = measurements.get_size_metrics_desc(isSegm3D, True) + self.addTreeItems(sizeMetricsTreeItem, size_metrics_desc.keys(), isCol=True) + + propMetricsTreeItem = QTreeWidgetItem(metricsTreeWidget) + propMetricsTreeItem.setText(0, "Region properties") + metricsTreeWidget.addTopLevelItem(propMetricsTreeItem) + + props_names = measurements.get_props_names() + self.addTreeItems(propMetricsTreeItem, props_names, isCol=True) + + operatorsLayout = QHBoxLayout() + operatorsLayout.addStretch(1) + + iconSize = 24 + + self.operatorButtons = [] + self.operators = [ + ("add", "+"), + ("subtract", "-"), + ("multiply", "*"), + ("divide", "/"), + ("open_bracket", "("), + ("close_bracket", ")"), + ("square", "**2"), + ("pow", "**"), + ("ln", "log("), + ("log10", "log10("), + ] + operatorFont = QFont() + operatorFont.setPixelSize(16) + for name, text in self.operators: + button = QPushButton() + button.setIcon(QIcon(f":{name}.svg")) + button.setIconSize(QSize(iconSize, iconSize)) + button.text = text + operatorsLayout.addWidget(button) + self.operatorButtons.append(button) + button.clicked.connect(self.addOperator) + # button.setFont(operatorFont) + + clearButton = QPushButton() + clearButton.setIcon(QIcon(":clear.svg")) + clearButton.setIconSize(QSize(iconSize, iconSize)) + clearButton.setFont(operatorFont) + + clearEntryButton = QPushButton() + clearEntryButton.setIcon(QIcon(":backspace.svg")) + clearEntryButton.setFont(operatorFont) + clearEntryButton.setIconSize(QSize(iconSize, iconSize)) + + operatorsLayout.addWidget(clearButton) + operatorsLayout.addWidget(clearEntryButton) + operatorsLayout.addStretch(1) + + newColNameLayout = QVBoxLayout() + newColNameLineEdit = widgets.alphaNumericLineEdit() + newColNameLineEdit.setAlignment(Qt.AlignCenter) + self.newColNameLineEdit = newColNameLineEdit + newColNameLayout.addStretch(1) + newColNameLayout.addWidget(QLabel("New measurement name:")) + newColNameLayout.addWidget(newColNameLineEdit) + newColNameLayout.addStretch(1) + + equationDisplayLayout = QVBoxLayout() + equationDisplayLayout.addWidget(QLabel("Equation:")) + equationDisplay = QPlainTextEdit() + # equationDisplay.setReadOnly(True) + self.equationDisplay = equationDisplay + equationDisplayLayout.addWidget(equationDisplay) + equationDisplayLayout.setStretch(0, 0) + equationDisplayLayout.setStretch(1, 1) + + equationLayout.addLayout(newColNameLayout) + equationLayout.addWidget(QLabel(" = ")) + equationLayout.addLayout(equationDisplayLayout) + equationLayout.setStretch(0, 1) + equationLayout.setStretch(1, 0) + equationLayout.setStretch(2, 2) + + testOutputLayout = QVBoxLayout() + testOutputLayout.addWidget(QLabel("Result of test with random inputs:")) + testOutputDisplay = QTextEdit() + testOutputDisplay.setReadOnly(True) + self.testOutputDisplay = testOutputDisplay + testOutputLayout.addWidget(testOutputDisplay) + testOutputLayout.setStretch(0, 0) + testOutputLayout.setStretch(1, 1) + + instructions = html_utils.paragraph(""" + Double-click on any of the available measurements + to add it to the equation.

+ NOTE: the result will be saved in the acdc_output.csv + file as a column with the same name
+ you enter in "New measurement name" + field.

+ """) + + buttonsLayout = QHBoxLayout() + + cancelButton = widgets.cancelPushButton("Cancel") + helpButton = widgets.infoPushButton(" Help...") + testButton = widgets.calcPushButton("Test output") + okButton = widgets.okPushButton(" Ok ") + okButton.setDisabled(True) + self.okButton = okButton + + buttonsLayout.addStretch(1) + + if debug: + debugButton = QPushButton("Debug") + debugButton.clicked.connect(self._debug) + buttonsLayout.addWidget(debugButton) + + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(helpButton) + buttonsLayout.addWidget(testButton) + buttonsLayout.addWidget(okButton) + + mainLayout.addWidget(QLabel(instructions)) + mainLayout.addWidget(QLabel("Available measurements:")) + mainLayout.addWidget(metricsTreeWidget) + mainLayout.addLayout(operatorsLayout) + mainLayout.addLayout(equationLayout) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + mainLayout.addLayout(testOutputLayout) + + clearButton.clicked.connect(self.clearEquation) + clearEntryButton.clicked.connect(self.clearEntryEquation) + metricsTreeWidget.itemDoubleClicked.connect(self.addColname) + + helpButton.clicked.connect(self.showHelp) + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.close) + testButton.clicked.connect(self.test_cb) + + self.setLayout(mainLayout) + self.setFont(font) + + self.setStyleSheet(TREEWIDGET_STYLESHEET) + + def addChannelLessItems(self, isZstack, isSegm3D=False): + allChannelsTreeItem = QTreeWidgetItem(self.metricsTreeWidget) + allChannelsTreeItem.setText(0, f"All channels measurements") + metrics_desc, bkgr_val_desc = measurements.standard_metrics_desc( + isZstack, "", isSegm3D=isSegm3D + ) + custom_metrics_desc = measurements.custom_metrics_desc( + isZstack, "", isSegm3D=isSegm3D + ) + + foregrMetricsTreeItem = QTreeWidgetItem(allChannelsTreeItem) + foregrMetricsTreeItem.setText(0, "Cell signal measurements") + allChannelsTreeItem.addChild(foregrMetricsTreeItem) + + bkgrMetricsTreeItem = QTreeWidgetItem(allChannelsTreeItem) + bkgrMetricsTreeItem.setText(0, "Background values") + allChannelsTreeItem.addChild(bkgrMetricsTreeItem) + + if custom_metrics_desc: + customMetricsTreeItem = QTreeWidgetItem(allChannelsTreeItem) + customMetricsTreeItem.setText(0, "Custom measurements") + allChannelsTreeItem.addChild(customMetricsTreeItem) + + self.addTreeItems( + foregrMetricsTreeItem, metrics_desc.keys(), isCol=True, isChannelLess=True + ) + self.addTreeItems( + bkgrMetricsTreeItem, bkgr_val_desc.keys(), isCol=True, isChannelLess=True + ) + + if custom_metrics_desc: + self.addTreeItems( + customMetricsTreeItem, + custom_metrics_desc.keys(), + isCol=True, + isChannelLess=True, + ) + + def addOperator(self): + button = self.sender() + text = f"{self.equationDisplay.toPlainText()}{button.text}" + self.equationDisplay.setPlainText(text) + self.clearLenghts.append(len(button.text)) + + def clearEquation(self): + self.isOperatorMode = False + self.equationDisplay.setPlainText("") + self.initAttributes() + + def initAttributes(self): + self.clearLenghts = [] + self.equationColNames = [] + self.channelLessColnames = [] + + def clearEntryEquation(self): + if not self.clearLenghts: + return + + text = self.equationDisplay.toPlainText() + newText = text[: -self.clearLenghts[-1]] + clearedText = text[-self.clearLenghts[-1] :] + self.clearLenghts.pop(-1) + self.equationDisplay.setPlainText(newText) + if clearedText in self.equationColNames: + self.equationColNames.remove(clearedText) + if clearedText in self.channelLessColnames: + self.channelLessColnames.remove(clearedText) + + def addTreeItems(self, parentItem, itemsText, isCol=False, isChannelLess=False): + for text in itemsText: + _item = QTreeWidgetItem(parentItem) + _item.setText(0, text) + parentItem.addChild(_item) + if isCol: + _item.isCol = True + _item.isChannelLess = isChannelLess + + def addColname(self, item, column): + if not hasattr(item, "isCol"): + return + + colName = item.text(0) + text = f"{self.equationDisplay.toPlainText()}{colName}" + self.equationDisplay.setPlainText(text) + self.clearLenghts.append(len(colName)) + self.equationColNames.append(colName) + if item.isChannelLess: + self.channelLessColnames.append(colName) + + def _debug(self): + print(self.getEquationsDict()) + + def getEquationsDict(self): + equation = self.equationDisplay.toPlainText() + newColName = self.newColNameLineEdit.text() + if not self.channelLessColnames: + chNamesInTerms = set() + for term in self.equationColNames: + for chName in self.allChNames: + if chName in term: + chNamesInTerms.add(chName) + if len(chNamesInTerms) == 1: + # Equation uses metrics from a single channel --> append channel name + chName = chNamesInTerms.pop() + chColName = f"{chName}_{newColName}" + isMixedChannels = False + return {chColName: equation}, isMixedChannels + else: + # Equation doesn't use all channels metrics nor is single channel + isMixedChannels = True + return {newColName: equation}, isMixedChannels + + isMixedChannels = False + equations = {} + for chName in self.allChNames: + chEquation = equation + chEquationName = newColName + # Append each channel name to channelLess terms + for colName in self.channelLessColnames: + chColName = f"{chName}{colName}" + chEquation = chEquation.replace(colName, chColName) + chEquationName = f"{chName}_{newColName}" + equations[chEquationName] = chEquation + return equations, isMixedChannels + + def ok_cb(self): + if not self.newColNameLineEdit.text(): + self.warnEmptyEquationName() + return + + self.cancel = False + + # Save equation to "/acdc-metrics/combine_metrics.ini" file + config = measurements.read_saved_user_combine_config() + + equationsDict, isMixedChannels = self.getEquationsDict() + for newColName, equation in equationsDict.items(): + config = measurements.add_user_combine_metrics( + config, equation, newColName, isMixedChannels + ) + + isChannelLess = len(self.channelLessColnames) > 0 + if isChannelLess: + channelLess_equation = self.equationDisplay.toPlainText() + equation_name = self.newColNameLineEdit.text() + config = measurements.add_channelLess_combine_metrics( + config, channelLess_equation, equation_name, self.channelLessColnames + ) + + measurements.save_common_combine_metrics(config) + + self.sigOk.emit(self) + + if self.closeOnOk: + self.close() + + def warnEmptyEquationName(self): + msg = widgets.myMessageBox() + txt = html_utils.paragraph(""" + "New measurement name" field cannot be empty! + """) + msg.critical(self, "Empty new measurement name", txt) + + def showHelp(self): + txt = measurements.get_combine_metrics_help_txt() + msg = widgets.myMessageBox( + showCentered=False, + wrapText=False, + scrollableText=True, + enlargeWidthFactor=1.7, + ) + path = measurements.acdc_metrics_path + msg.addShowInFileManagerButton(path, txt="Show saved file...") + msg.information(self, "Combine measurements help", txt) + + def test_cb(self): + # Evaluate equation with random inputs + equation = self.equationDisplay.toPlainText() + random_data = np.random.rand(1, len(self.equationColNames)) * 5 + df = pd.DataFrame(data=random_data, columns=self.equationColNames).round(5) + newColName = self.newColNameLineEdit.text() + try: + df[newColName] = df.eval(equation) + except Exception as e: + traceback.print_exc() + self.testOutputDisplay.setHtml(html_utils.paragraph(e)) + self.testOutputDisplay.setStyleSheet("border: 2px solid red") + return + + self.testOutputDisplay.setStyleSheet("border: 2px solid green") + self.okButton.setDisabled(False) + + result = df.round(5).iloc[0][newColName] + + # Substitute numbers into equation + inputs = df.iloc[0] + equation_numbers = equation + for c, col in enumerate(self.equationColNames): + equation_numbers = equation_numbers.replace(col, str(inputs[c])) + + # Format output into html text + cols = self.equationColNames + inputs_txt = [f"{col} = {input}" for col, input in zip(cols, inputs)] + list_html = html_utils.to_list(inputs_txt) + text = html_utils.paragraph(f""" + By substituting the following random inputs: + {list_html} + we get the equation:

+   {newColName} = {equation_numbers}

+ that equals to:

+   {newColName} = {result} + """) + self.testOutputDisplay.setHtml(text) + + +class CombineMetricsMultiDfsDialog(QBaseDialog): + sigOk = Signal(object, object) + sigClose = Signal(bool) + + def __init__(self, acdcDfs, allChNames, parent=None, debug=False): + super().__init__(parent) + + self.setWindowTitle("Add combined measurement") + + self.initAttributes() + + self.acdcDfs = acdcDfs + self.cancel = True + self.isOperatorMode = False + + mainLayout = QVBoxLayout() + equationLayout = QHBoxLayout() + + treesLayout = QHBoxLayout() + for i, (acdc_df_endname, acdc_df) in enumerate(acdcDfs.items()): + metricsTreeWidget = QTreeWidget() + metricsTreeWidget.setHeaderHidden(True) + metricsTreeWidget.setFont(font) + + classified_metrics = measurements.classify_acdc_df_colnames( + acdc_df, allChNames + ) + + for chName in allChNames: + channelTreeItem = QTreeWidgetItem(metricsTreeWidget) + channelTreeItem.setText(0, f"{chName} measurements") + metricsTreeWidget.addTopLevelItem(channelTreeItem) + + standard_metrics = classified_metrics["foregr"][chName] + bkgr_metrics = classified_metrics["bkgr"][chName] + custom_metrics = classified_metrics["custom"][chName] + + if standard_metrics: + foregrMetricsTreeItem = QTreeWidgetItem(channelTreeItem) + foregrMetricsTreeItem.setText(0, "Cell signal measurements") + channelTreeItem.addChild(foregrMetricsTreeItem) + self.addTreeItems( + foregrMetricsTreeItem, standard_metrics, isCol=True, index=i + ) + + if bkgr_metrics: + bkgrMetricsTreeItem = QTreeWidgetItem(channelTreeItem) + bkgrMetricsTreeItem.setText(0, "Background values") + channelTreeItem.addChild(bkgrMetricsTreeItem) + self.addTreeItems( + bkgrMetricsTreeItem, bkgr_metrics, isCol=True, index=i + ) + + if custom_metrics: + customMetricsTreeItem = QTreeWidgetItem(channelTreeItem) + customMetricsTreeItem.setText(0, "Custom measurements") + channelTreeItem.addChild(customMetricsTreeItem) + self.addTreeItems( + customMetricsTreeItem, custom_metrics, isCol=True, index=i + ) + + if classified_metrics["size"]: + sizeMetricsTreeItem = QTreeWidgetItem(metricsTreeWidget) + sizeMetricsTreeItem.setText(0, "Size measurements") + metricsTreeWidget.addTopLevelItem(sizeMetricsTreeItem) + self.addTreeItems( + sizeMetricsTreeItem, classified_metrics["size"], isCol=True, index=i + ) + + if classified_metrics["props"]: + propMetricsTreeItem = QTreeWidgetItem(metricsTreeWidget) + propMetricsTreeItem.setText(0, "Region properties") + metricsTreeWidget.addTopLevelItem(propMetricsTreeItem) + self.addTreeItems( + propMetricsTreeItem, + classified_metrics["props"], + isCol=True, + index=i, + ) + + treeLayout = QVBoxLayout() + treeTitle = QLabel( + html_utils.paragraph( + f"{i + 1}. {acdc_df_endname} measurements " + ) + ) + treeLayout.addWidget(treeTitle) + treeLayout.addWidget(metricsTreeWidget) + treesLayout.addLayout(treeLayout) + + metricsTreeWidget.index = i + metricsTreeWidget.itemDoubleClicked.connect(self.addColname) + + operatorsLayout = QHBoxLayout() + operatorsLayout.addStretch(1) + + iconSize = 24 + + self.operatorButtons = [] + self.operators = [ + ("add", "+"), + ("subtract", "-"), + ("multiply", "*"), + ("divide", "/"), + ("open_bracket", "("), + ("close_bracket", ")"), + ("square", "**2"), + ("pow", "**"), + ("ln", "log("), + ("log10", "log10("), + ] + operatorFont = QFont() + operatorFont.setPixelSize(16) + for name, text in self.operators: + button = QPushButton() + button.setIcon(QIcon(f":{name}.svg")) + button.setIconSize(QSize(iconSize, iconSize)) + button.text = text + operatorsLayout.addWidget(button) + self.operatorButtons.append(button) + button.clicked.connect(self.addOperator) + # button.setFont(operatorFont) + + clearButton = QPushButton() + clearButton.setIcon(QIcon(":clear.svg")) + clearButton.setIconSize(QSize(iconSize, iconSize)) + clearButton.setFont(operatorFont) + + clearEntryButton = QPushButton() + clearEntryButton.setIcon(QIcon(":backspace.svg")) + clearEntryButton.setFont(operatorFont) + clearEntryButton.setIconSize(QSize(iconSize, iconSize)) + + operatorsLayout.addWidget(clearButton) + operatorsLayout.addWidget(clearEntryButton) + operatorsLayout.addStretch(1) + + newColNameLayout = QVBoxLayout() + newColNameLineEdit = widgets.alphaNumericLineEdit() + newColNameLineEdit.setAlignment(Qt.AlignCenter) + self.newColNameLineEdit = newColNameLineEdit + newColNameLayout.addStretch(1) + newColNameLayout.addWidget(QLabel("New measurement name:")) + newColNameLayout.addWidget(newColNameLineEdit) + newColNameLayout.addStretch(1) + + equationDisplayLayout = QVBoxLayout() + equationDisplayLayout.addWidget(QLabel("Equation:")) + equationDisplay = QPlainTextEdit() + # equationDisplay.setReadOnly(True) + self.equationDisplay = equationDisplay + equationDisplayLayout.addWidget(equationDisplay) + equationDisplayLayout.setStretch(0, 0) + equationDisplayLayout.setStretch(1, 1) + + equationLayout.addLayout(newColNameLayout) + equationLayout.addWidget(QLabel(" = ")) + equationLayout.addLayout(equationDisplayLayout) + equationLayout.setStretch(0, 1) + equationLayout.setStretch(1, 0) + equationLayout.setStretch(2, 2) + + instructions = html_utils.paragraph(""" + Double-click on any of the available measurements + to add it to the equation.

+ NOTE: the result will be saved in a new acdc_output + file as a column with the same name
+ you enter in "New measurement name" + field.

+ """) + + buttonsLayout = QHBoxLayout() + + cancelButton = widgets.cancelPushButton("Cancel") + testButton = widgets.calcPushButton("Test equation") + okButton = widgets.okPushButton(" Ok ") + okButton.setDisabled(True) + self.okButton = okButton + + if debug: + debugButton = QPushButton("Debug") + debugButton.clicked.connect(self._debug) + buttonsLayout.addWidget(debugButton) + + self.statusLabel = QLabel() + buttonsLayout.addWidget(self.statusLabel) + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(testButton) + buttonsLayout.addWidget(okButton) + + mainLayout.addWidget(QLabel(instructions)) + mainLayout.addLayout(treesLayout) + mainLayout.addLayout(operatorsLayout) + mainLayout.addLayout(equationLayout) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + clearButton.clicked.connect(self.clearEquation) + clearEntryButton.clicked.connect(self.clearEntryEquation) + + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.close) + testButton.clicked.connect(self.test_cb) + + self.equationDisplay.textChanged.connect(self.equationChanged) + # self.newColNameLineEdit.editingFinished.connect(self.equationChanged) + + self.setLayout(mainLayout) + self.setFont(font) + + self.setStyleSheet(TREEWIDGET_STYLESHEET) + + def setLogger(self, logger, logs_path, log_path): + self.logger = logger + self.logs_path = logs_path + self.log_path = log_path + + def closeEvent(self, event): + self.sigClose.emit(self.cancel) + return super().closeEvent(event) + + def getCombinedDf(self): + dfs = [] + for i, acdc_df in enumerate(self.acdcDfs.values()): + dfs.append(acdc_df.add_suffix(f"_table{i + 1}")) + return pd.concat(dfs, axis=1) + + def _log(self, txt): + if hasattr(self, "logger"): + self.logger.info(txt) + else: + print(f"[INFO]: {txt}") + + def equationChanged(self): + self.okButton.setDisabled(True) + self.statusLabel.setText("") + + @exception_handler + def test_cb(self): + combined_df = self.getCombinedDf() + new_df = pd.DataFrame(index=combined_df.index) + equation = self.equationDisplay.toPlainText() + newColName = self.newColNameLineEdit.text() + new_df[newColName] = combined_df.eval(equation) + self.okButton.setDisabled(False) + self._log("Equation test was successful.") + self.statusLabel.setText("Equation test was successful. You can now click OK.") + + def addOperator(self): + button = self.sender() + text = f"{self.equationDisplay.toPlainText()}{button.text}" + self.equationDisplay.setPlainText(text) + self.clearLenghts.append(len(button.text)) + + def clearEquation(self): + self.isOperatorMode = False + self.equationDisplay.setPlainText("") + self.initAttributes() + + def initAttributes(self): + self.clearLenghts = [] + self.equationColNames = [] + self.channelLessColnames = [] + + def clearEntryEquation(self): + if not self.clearLenghts: + return + + text = self.equationDisplay.toPlainText() + newText = text[: -self.clearLenghts[-1]] + clearedText = text[-self.clearLenghts[-1] :] + self.clearLenghts.pop(-1) + self.equationDisplay.setPlainText(newText) + if clearedText in self.equationColNames: + self.equationColNames.remove(clearedText) + if clearedText in self.channelLessColnames: + self.channelLessColnames.remove(clearedText) + + def addTreeItems( + self, parentItem, itemsText, isCol=False, isChannelLess=False, index=None + ): + for text in itemsText: + _item = QTreeWidgetItem(parentItem) + _item.setText(0, text) + parentItem.addChild(_item) + if isCol: + _item.isCol = True + if index is not None: + _item.index = index + _item.isChannelLess = isChannelLess + + def addColname(self, item, column): + if not hasattr(item, "isCol"): + return + + colName = f"{item.text(0)}_table{item.index + 1}" + text = f"{self.equationDisplay.toPlainText()}{colName}" + + self.equationDisplay.setPlainText(text) + self.clearLenghts.append(len(colName)) + self.equationColNames.append(colName) + if item.isChannelLess: + self.channelLessColnames.append(colName) + + def _debug(self): + print(self.getEquationsDict()) + + def ok_cb(self): + if not self.newColNameLineEdit.text(): + self.warnEmptyEquationName() + return + if not self.equationDisplay.toPlainText(): + self.warnEmptyEquation() + return + + self.expression = self.equationDisplay.toPlainText() + self.newColname = self.newColNameLineEdit.text() + self.cancel = False + self.sigOk.emit(self.newColname, self.expression) + self.close() + + def warnEmptyEquation(self): + msg = widgets.myMessageBox() + txt = html_utils.paragraph(""" + "Equation" field cannot be empty! + """) + msg.critical(self, "Empty equation", txt) + + def warnEmptyEquationName(self): + msg = widgets.myMessageBox() + txt = html_utils.paragraph(""" + "New measurement name" field cannot be empty! + """) + msg.critical(self, "Empty new measurement name", txt) + + +class CombineMetricsMultiDfsSummaryDialog(QBaseDialog): + sigLoadAdditionalAcdcDf = Signal() + + def __init__(self, acdcDfs, allChNames, parent=None, debug=False): + super().__init__(parent) + + self.editedIndex = None + self.cancel = True + self.acdcDfs = acdcDfs + self.allChNames = allChNames + + self.setWindowTitle("Combine measurements summary") + + mainLayout = QVBoxLayout() + viewLayout = QGridLayout() + buttonsLayout = QHBoxLayout() + + row = 0 + txt = html_utils.paragraph("Selected acdc_output tables:") + viewLayout.addWidget(QLabel(txt), row, 0) + + row += 1 + items = [ + f"• Table {i + 1}: {e}" + for i, e in enumerate(acdcDfs.keys()) + ] + selectedAcdcDfsList = widgets.readOnlyQList() + selectedAcdcDfsList.addItems(items) + self.selectedAcdcDfsList = selectedAcdcDfsList + + tablesButtonsLayout = QVBoxLayout() + loadAcdcDfButton = widgets.showInFileManagerButton("Load additional tables") + tablesButtonsLayout.addWidget(loadAcdcDfButton) + + loadEquationsButton = widgets.reloadPushButton("Load previously used equations") + tablesButtonsLayout.addWidget(loadEquationsButton) + + tablesButtonsLayout.addStretch(1) + + viewLayout.addWidget(selectedAcdcDfsList, row, 0) + viewLayout.addLayout(tablesButtonsLayout, row, 1) + viewLayout.setRowStretch(row, 1) + + row += 1 + txt = html_utils.paragraph("Equations:") + viewLayout.addWidget(QLabel(txt), row, 0) + + row += 1 + self.equationsList = widgets.TreeWidget() + self.equationsList.setFont(font) + self.equationsList.setHeaderLabels(["Metric", "Expression"]) + self.equationsList.setSelectionMode( + QAbstractItemView.SelectionMode.ExtendedSelection + ) + + equationsButtonsLayout = QVBoxLayout() + addEquationButton = widgets.addPushButton("Add metric") + removeEquationButton = widgets.subtractPushButton("Remove metric(s)") + editEquationButton = widgets.editPushButton("Edit metric") + removeEquationButton.setDisabled(True) + editEquationButton.setDisabled(True) + self.removeEquationButton = removeEquationButton + self.editEquationButton = editEquationButton + + equationsButtonsLayout.addWidget(addEquationButton) + equationsButtonsLayout.addWidget(removeEquationButton) + equationsButtonsLayout.addWidget(editEquationButton) + equationsButtonsLayout.addStretch(1) + + viewLayout.addWidget(self.equationsList, row, 0) + viewLayout.addLayout(equationsButtonsLayout, row, 1) + viewLayout.setRowStretch(row, 2) + + cancelButton = widgets.cancelPushButton("Cancel") + okButton = widgets.okPushButton("Ok") + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(okButton) + + viewLayout.setVerticalSpacing(10) + mainLayout.addLayout(viewLayout) + mainLayout.addSpacing(10) + mainLayout.addLayout(buttonsLayout) + + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.close) + addEquationButton.clicked.connect(self.addEquation_cb) + loadAcdcDfButton.clicked.connect(self.loadButtonClicked) + loadEquationsButton.clicked.connect(self.loadEquationsButtonClicked) + removeEquationButton.clicked.connect(self.removeButtonClicked) + editEquationButton.clicked.connect(self.editButtonClicked) + self.equationsList.itemSelectionChanged.connect( + self.onEquationItemSelectionChanged + ) + + self.setLayout(mainLayout) + + def setLogger(self, logger, logs_path, log_path): + self.logger = logger + self.logs_path = logs_path + self.log_path = log_path + + def loadEquationsButtonClicked(self): + MostRecentPath = utils.getMostRecentPath() + file_path = QFileDialog.getOpenFileName( + self, + "Select equations file", + MostRecentPath, + "Config Files (*.ini);;All Files (*)", + )[0] + if file_path == "": + return + + cp = config.ConfigParser() + cp.read(file_path) + sectionToMatch = [f"table{i + 1}:{end}" for i, end in enumerate(self.acdcDfs)] + sectionToMatch = ";".join(sectionToMatch) + + lists = {} + nonMatchingLists = {} + groupsDescr = {} + + for section in cp.sections(): + # Tag acdc_output names with html and table(\d+) with html bold tag + listName = ";".join( + [ + re.sub( + r"table(\d+):(.*)", r"table\g<1>: \g<2>", s + ) + for s in section.split(";") + ] + ) + listName = listName.replace(";", " ; ") + children = [f"{opt} = {cp[section][opt]}" for opt in cp[section]] + if section == sectionToMatch: + groupsDescr[listName] = ( + "Equations that were calculated from the same " + "table names you loaded" + ) + lists[listName] = children + else: + groupsDescr[listName] = ( + "Equations that were calculated from table names that " + "you did not load now" + ) + nonMatchingLists[listName] = children + # # Not implemented yet --> selecting from non matching table names + # # would require an additional widget where the user sets + # # what df1 and df2 are. + # trees[treeName] = children + + if not lists: + msg = widgets.myMessageBox(wrapText=False, showCentered=False) + txt = html_utils.paragraph(""" + None of the equations in the selected file used the same + table names that you loaded.

+ See below which table names and equations are present in the loaded file. + """) + with open(file_path) as iniFile: + detailedText = iniFile.read() + + msg.warning(self, "Not the same tables", txt, showDialog=False) + msg.setDetailedText(detailedText, visible=True) + msg.addShowInFileManagerButton(os.path.dirname(file_path)) + msg.exec_() + return + + selectWindow = MultiListSelector( + lists, + groupsDescr=groupsDescr, + title="Select equations to load", + infoTxt="Select equations you want to load", + ) + selectWindow.exec_() + if selectWindow.cancel or not selectWindow.selectedItems: + return + + for listName, equations in selectWindow.selectedItems.items(): + for equation in equations: + metricName, expression = equation.split(" = ") + self.addEquation(metricName, expression) + + def ok_cb(self): + self.cancel = False + self.equations = {} + for i in range(self.equationsList.topLevelItemCount()): + item = self.equationsList.topLevelItem(i) + self.equations[item.text(0)] = item.text(1) + + self.close() + + def loadButtonClicked(self): + self.sigLoadAdditionalAcdcDf.emit() + + def removeButtonClicked(self): + for item in self.equationsList.selectedItems(): + self.equationsList.invisibleRootItem().removeChild(item) + + def editButtonClicked(self): + self.editedItem = self.equationsList.selectedItems()[0] + self.editedIndex = self.equationsList.indexOfTopLevelItem(self.editedItem) + self.addEquation_cb() + + def onEquationItemSelectionChanged(self): + selectedItems = self.equationsList.selectedItems() + if len(selectedItems) == 1: + self.editEquationButton.setDisabled(False) + self.removeEquationButton.setDisabled(False) + elif len(selectedItems) > 1: + self.removeEquationButton.setDisabled(False) + self.editEquationButton.setDisabled(True) + else: + self.removeEquationButton.setDisabled(True) + self.editEquationButton.setDisabled(True) + + def addAcdcDfs(self, acdcDfsDict): + self.acdcDfs = {**self.acdcDfs, **acdcDfsDict} + items = [ + f"• Table {i + 1}: {e}" + for i, e in enumerate(self.acdcDfs.keys()) + ] + self.selectedAcdcDfsList = widgets.readOnlyQList() + self.selectedAcdcDfsList.addItems(items) + + def addEquation(self, newColname, expression): + if self.editedIndex is not None: + self.equationsList.invisibleRootItem().removeChild(self.editedItem) + bkgrColor = QColor(*BACKGROUND_RGBA[:3], 200) + item = widgets.TreeWidgetItem( + self.equationsList, columnColors=[None, bkgrColor] + ) + item.setText(0, newColname) + item.setText(1, expression) + if self.editedIndex is not None: + self.equationsList.insertTopLevelItem(self.editedIndex, item) + else: + self.equationsList.addTopLevelItem(item) + self.equationsList.resizeColumnToContents(0) + self.equationsList.resizeColumnToContents(1) + self.editedIndex = None + + def addEquation_cb(self): + self.addEquationWin = CombineMetricsMultiDfsDialog( + self.acdcDfs, self.allChNames, parent=self + ) + if hasattr(self, "logger"): + self.addEquationWin.setLogger(self.logger, self.logs_path, self.log_path) + if self.editedIndex is not None: + editedMetricName = self.editedItem.text(0) + self.addEquationWin.newColNameLineEdit.setText(editedMetricName) + editedExpression = self.editedItem.text(1) + self.addEquationWin.equationDisplay.setPlainText(editedExpression) + self.addEquationWin.show() + self.addEquationWin.sigOk.connect(self.addEquation) + self.addEquationWin.sigClose.connect(self.addEquationClosed) + + def addEquationClosed(self, cancelled): + if cancelled: + self.editedIndex = None + + def showEvent(self, event) -> None: + self.resize(int(self.width() * 2), self.height()) + + +class SelectFeaturesRange: + def __init__( + self, posData, force_postprocess_2D=False, qparent=None, sigValueChanged=None + ) -> None: + self.posData = posData + self.qparent = qparent + self.force_postprocess_2D = force_postprocess_2D + self.sigValueChanged = sigValueChanged + + self.lowRangeWidgets = widgets.CheckableSpinBoxWidgets() + self.highRangeWidgets = widgets.CheckableSpinBoxWidgets() + + self.selectButton = widgets.FeatureSelectorButton("Click to select feature...") + self.selectButton.setSizeLongestText( + "Spotfit intens. metric, Foregr. integral gauss. peak" + ) + self.selectButton.clicked.connect(self.selectFeature) + self.selectButton.setCursor(Qt.PointingHandCursor) + + self.selectedFeatureGroups = {} + + self.widgets = [ + {"pos": (0, 0), "widget": self.lowRangeWidgets.checkbox}, + {"pos": (1, 0), "widget": self.lowRangeWidgets.spinbox}, + {"pos": (1, 1), "widget": widgets.LessThanPushButton(flat=True)}, + {"pos": (1, 2), "widget": self.selectButton}, + {"pos": (1, 3), "widget": widgets.LessThanPushButton(flat=True)}, + {"pos": (0, 4), "widget": self.highRangeWidgets.checkbox}, + {"pos": (1, 4), "widget": self.highRangeWidgets.spinbox}, + {"pos": (2, 0), "widget": widgets.VerticalSpacerEmptyWidget(height=10)}, + ] + self.columnsStretches = {0: 0, 1: 0, 2: 1, 3: 0, 4: 0} + + def setText(self, text): + self.selectButton.setText(text) + + def selectFeature(self): + loadedChNames = [self.posData.user_ch_name] + notLoadedChNames = [] + isZstack = self.posData.SizeZ > 1 and not self.force_postprocess_2D + isSegm3D = self.posData.isSegm3D and not self.force_postprocess_2D + self.selectFeatureDialog = SetMeasurementsDialog( + loadedChNames, + notLoadedChNames, + isZstack, + isSegm3D, + posData=self.posData, + parent=self.qparent, + isSingleSelection=True, + is_concat=True, + ) + # self.selectFeatureDialog.resizeVertical() + self.selectFeatureDialog.sigClosed.connect(self.setFeatureText) + self.selectFeatureDialog.show() + + def setFeatureText(self): + if self.selectFeatureDialog.cancel: + return + self.selectButton.setFlat(True) + selectedMetricName, selectedMetricGroup = ( + self.selectFeatureDialog.selectedMetricNameAndGroup() + ) + self.selectButton.setText(selectedMetricName) + self.featureGroup = selectedMetricGroup + + +class SelectFeaturesRangeDialog(QBaseDialog): + sigValueChanged = Signal(object) + + def __init__(self, posData=None, parent=None, force_postprocess_2D=False): + super().__init__(parent) + + self.force_postprocess_2D = force_postprocess_2D + + layout = QVBoxLayout() + self.setWindowTitle("Custom features for post-processing") + + self.groupbox = SelectFeaturesRangeGroupbox( + posData=posData, parent=parent, force_postprocess_2D=force_postprocess_2D + ) + + buttonsLayout = QHBoxLayout() + okPushButton = widgets.okPushButton(" Ok ") + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(okPushButton) + + okPushButton.clicked.connect(self.ok_cb) + + layout.addWidget(self.groupbox) + layout.addSpacing(10) + layout.addLayout(buttonsLayout) + + self.setLayout(layout) + + def ok_cb(self): + if self.groupbox.selectedFeaturesRange(): + self.sigValueChanged.emit(None) + self.hide() + + +class SelectFeaturesRangeGroupbox(QGroupBox): + def __init__(self, posData=None, parent=None, force_postprocess_2D=False): + super().__init__(parent) + + self.setTitle("Features and thresholds for filtering segmented objects") + # self.setCheckable(True) + + self.posData = posData + self.force_postprocess_2D = force_postprocess_2D + + self._layout = QGridLayout() + self._layout.setVerticalSpacing(0) + + firstSelector = SelectFeaturesRange( + posData, force_postprocess_2D=force_postprocess_2D + ) + self.addButton = widgets.addPushButton(" Add feature ") + self.addButton.setSizePolicy( + QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding + ) + for col, widget in enumerate(firstSelector.widgets): + row, col = widget["pos"] + self._layout.addWidget(widget["widget"], row, col) + for col, stretch in firstSelector.columnsStretches.items(): + self._layout.setColumnStretch(col, stretch) + + lastCol = self._layout.columnCount() + self._layout.addWidget(self.addButton, 0, lastCol + 1, 2, 1) + self.lastCol = lastCol + 1 + self.selectors = [firstSelector] + + self.setLayout(self._layout) + + # self.setFont(font) + + self.addButton.clicked.connect(self.addFeatureField) + + def addFeatureField(self): + row = self._layout.rowCount() + selector = SelectFeaturesRange( + self.posData, force_postprocess_2D=self.force_postprocess_2D + ) + delButton = widgets.delPushButton("Remove feature") + delButton.setSizePolicy( + QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding + ) + delButton.selector = selector + selector.delButton = delButton + for col, widget in enumerate(selector.widgets): + relRow, col = widget["pos"] + self._layout.addWidget(widget["widget"], relRow + row, col) + self._layout.addWidget(delButton, row, self.lastCol, 2, 1) + self.selectors.append(selector) + delButton.clicked.connect(self.removeFeatureField) + + def resetFields(self): + while len(self.selectors) > 1: + selector = self.selectors[-1] + selector.delButton.click() + firstSelector = self.selectors[0] + firstSelector.selectButton.setText("Click to select feature...") + firstSelector.lowRangeWidgets.checkbox.setChecked(False) + firstSelector.highRangeWidgets.checkbox.setChecked(False) + + def removeFeatureField(self): + delButton = self.sender() + for widget in delButton.selector.widgets: + self._layout.removeWidget(widget["widget"]) + self._layout.removeWidget(delButton) + self.selectors.remove(delButton.selector) + + def selectedFeaturesRange(self): + featuresRange = {} + for selector in self.selectors: + if selector.selectButton.text().find("Click") != -1: + continue + featuresRange[selector.selectButton.text()] = ( + selector.lowRangeWidgets.value(), + selector.highRangeWidgets.value(), + ) + return featuresRange + + def selectedFeaturesGroup(self): + featuresGroup = {} + for selector in self.selectors: + if selector.selectButton.text().find("Click") != -1: + continue + group = selector.featureGroup + featuresGroup[selector.selectButton.text()] = group + return featuresGroup + + def groupedFeatures(self): + featuresGroup = self.selectedFeaturesGroup() + groupedFeatures = {} + for feature, group in featuresGroup.items(): + group = featuresGroup[feature] + if isinstance(group, str): + key = group + if key not in groupedFeatures: + groupedFeatures[key] = [] + groupedFeatures[key].append(feature) + else: + key, channel = list(group.items())[0] + if key not in groupedFeatures: + groupedFeatures[key] = {} + if channel not in groupedFeatures[key]: + groupedFeatures[key][channel] = [] + groupedFeatures[key][channel].append(feature) + return groupedFeatures + + def setValue(self, value): + pass + + +class CombineFeaturesCalculator(QBaseDialog): + sigOk = Signal(object) + + def __init__( + self, + features_groups: dict, + group_name_to_col_mapper: dict = None, + title="Combine features calculator", + parent=None, + ): + super().__init__(parent) + + self.cancel = True + + self.setWindowTitle(title) + self.initAttributes() + + mainLayout = QVBoxLayout() + equationLayout = QHBoxLayout() + + metricsTreeWidget = QTreeWidget() + metricsTreeWidget.setHeaderHidden(True) + metricsTreeWidget.setFont(font) + self.metricsTreeWidget = metricsTreeWidget + + for groupName, features in features_groups.items(): + topLevelTreeWidgetItem = QTreeWidgetItem(metricsTreeWidget) + topLevelTreeWidgetItem.setText(0, groupName) + metricsTreeWidget.addTopLevelItem(topLevelTreeWidgetItem) + self.addTreeItems( + topLevelTreeWidgetItem, + features, + isCol=True, + name_to_col_mapper=group_name_to_col_mapper.get(groupName), + ) + + operatorsLayout = self.createOperatorsLayout() + newFeatureNameLayout = self.createNewFeatureNameLayout() + equationDisplayLayout = self.createEquationDisplayLayout() + + equationLayout.addLayout(newFeatureNameLayout) + equationLayout.addWidget(QLabel(" = ")) + equationLayout.addLayout(equationDisplayLayout) + equationLayout.setStretch(0, 1) + equationLayout.setStretch(1, 0) + equationLayout.setStretch(2, 2) + + testOutputLayout = self.createTestOutputLayout() + buttonsLayout = self.createButtonsOutputLayout() + + instructions = html_utils.paragraph(""" + Double-click on any of the available measurements + to add it to the equation.

+ Before clicking the `Ok` button, check that the equation returns + the expected result by clicking the `Test output` button. + """) + + mainLayout.addWidget(QLabel(instructions)) + mainLayout.addWidget(QLabel("Available measurements:")) + mainLayout.addWidget(metricsTreeWidget) + mainLayout.addLayout(operatorsLayout) + mainLayout.addLayout(equationLayout) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + mainLayout.addLayout(testOutputLayout) + + metricsTreeWidget.itemDoubleClicked.connect(self.addFeatureName) + self.setLayout(mainLayout) + self.setFont(font) + + self.setStyleSheet(TREEWIDGET_STYLESHEET) + + def setExpandedAll(self, expanded): + if expanded: + self.expandAll() + else: + for i in range(self.metricsTreeWidget.topLevelItemCount()): + topLevelItem = self.metricsTreeWidget.topLevelItem(i) + topLevelItem.setExpanded(False) + + def expandAll(self): + for i in range(self.metricsTreeWidget.topLevelItemCount()): + topLevelItem = self.metricsTreeWidget.topLevelItem(i) + topLevelItem.setExpanded(True) + + def addTreeItems(self, parentItem, itemsText, isCol=False, name_to_col_mapper=None): + for text in itemsText: + _item = QTreeWidgetItem(parentItem) + _item.setText(0, text) + parentItem.addChild(_item) + if isCol: + _item.isCol = True + _item.variable_name = text + if name_to_col_mapper is None: + continue + + col_name = name_to_col_mapper.get(text, None) + if col_name is None: + continue + + _item.variable_name = col_name + + def addFeatureName(self, item, column): + if not hasattr(item, "isCol"): + return + + colName = item.variable_name + text = f"{self.equationDisplay.toPlainText()}{colName}" + self.equationDisplay.setPlainText(text) + self.clearLenghts.append(len(colName)) + self.equationColNames.append(colName) + + def clearEquation(self): + self.isOperatorMode = False + self.equationDisplay.setPlainText("") + self.initAttributes() + + def createButtonsOutputLayout(self): + buttonsLayout = QHBoxLayout() + + cancelButton = widgets.cancelPushButton("Cancel") + helpButton = widgets.infoPushButton(" Help...") + testButton = widgets.calcPushButton("Test output") + okButton = widgets.okPushButton(" Ok ") + okButton.setDisabled(True) + self.okButton = okButton + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(helpButton) + buttonsLayout.addWidget(testButton) + buttonsLayout.addWidget(okButton) + + helpButton.clicked.connect(self.showHelp) + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.close) + testButton.clicked.connect(self.test_cb) + + return buttonsLayout + + def ok_cb(self): + if not self.newFeatureNameLineEdit.text(): + self.warnEmptyEquationName() + return + + self.equation = self.equationDisplay.toPlainText() + self.newFeatureName = self.newFeatureNameLineEdit.text() + self.cancel = False + self.close() + self.sigOk.emit(self) + + def test_cb(self): + # Evaluate equation with random inputs + equation = self.equationDisplay.toPlainText() + random_data = np.random.rand(1, len(self.equationColNames)) * 5 + df = pd.DataFrame(data=random_data, columns=self.equationColNames).round(5) + newColName = self.newFeatureNameLineEdit.text() + try: + df[newColName] = df.eval(equation) + except Exception as e: + traceback.print_exc() + self.testOutputDisplay.setHtml(html_utils.paragraph(e)) + self.testOutputDisplay.setStyleSheet("border: 2px solid red") + return + + self.testOutputDisplay.setStyleSheet("border: 2px solid green") + self.okButton.setDisabled(False) + + result = df.round(5).iloc[0][newColName] + + # Substitute numbers into equation + inputs = df.iloc[0] + equation_numbers = equation + for c, col in enumerate(self.equationColNames): + equation_numbers = equation_numbers.replace(col, str(inputs[c])) + + # Format output into html text + cols = self.equationColNames + inputs_txt = [f"{col} = {input}" for col, input in zip(cols, inputs)] + list_html = html_utils.to_list(inputs_txt) + text = html_utils.paragraph(f""" + By substituting the following random inputs: + {list_html} + we get the equation:

+   {newColName} = {equation_numbers}

+ that equals to:

+   {newColName} = {result} + """) + self.testOutputDisplay.setHtml(text) + + def warnEmptyEquationName(self): + msg = widgets.myMessageBox() + txt = html_utils.paragraph(""" + "New measurement name" field cannot be empty! + """) + msg.critical(self, "Empty new measurement name", txt) + + def showHelp(self): + pass + + def createTestOutputLayout(self): + testOutputLayout = QVBoxLayout() + testOutputLayout.addWidget(QLabel("Result of test with random inputs:")) + testOutputDisplay = QTextEdit() + testOutputDisplay.setReadOnly(True) + self.testOutputDisplay = testOutputDisplay + testOutputLayout.addWidget(testOutputDisplay) + testOutputLayout.setStretch(0, 0) + testOutputLayout.setStretch(1, 1) + + return testOutputLayout + + def createEquationDisplayLayout(self): + equationDisplayLayout = QVBoxLayout() + equationDisplayLayout.addWidget(QLabel("Equation:")) + equationDisplay = QPlainTextEdit() + # equationDisplay.setReadOnly(True) + self.equationDisplay = equationDisplay + equationDisplayLayout.addWidget(equationDisplay) + equationDisplayLayout.setStretch(0, 0) + equationDisplayLayout.setStretch(1, 1) + return equationDisplayLayout + + def createNewFeatureNameLayout(self): + newFeatureNameLayout = QVBoxLayout() + newFeatureNameLineEdit = widgets.alphaNumericLineEdit() + newFeatureNameLineEdit.setAlignment(Qt.AlignCenter) + self.newFeatureNameLineEdit = newFeatureNameLineEdit + newFeatureNameLayout.addStretch(1) + newFeatureNameLayout.addWidget(QLabel("New measurement name:")) + newFeatureNameLayout.addWidget(newFeatureNameLineEdit) + newFeatureNameLayout.addStretch(1) + return newFeatureNameLayout + + def createOperatorsLayout(self): + operatorsLayout = QHBoxLayout() + operatorsLayout.addStretch(1) + + iconSize = 24 + + self.operatorButtons = [] + self.operators = [ + ("add", "+"), + ("subtract", "-"), + ("multiply", "*"), + ("divide", "/"), + ("open_bracket", "("), + ("close_bracket", ")"), + ("square", "**2"), + ("pow", "**"), + ("ln", "log("), + ("log10", "log10("), + ] + operatorFont = QFont() + operatorFont.setPixelSize(16) + for name, text in self.operators: + button = QPushButton() + button.setIcon(QIcon(f":{name}.svg")) + button.setIconSize(QSize(iconSize, iconSize)) + button.text = text + operatorsLayout.addWidget(button) + self.operatorButtons.append(button) + button.clicked.connect(self.addOperator) + # button.setFont(operatorFont) + + clearButton = QPushButton() + clearButton.setIcon(QIcon(":clear.svg")) + clearButton.setIconSize(QSize(iconSize, iconSize)) + clearButton.setFont(operatorFont) + + clearEntryButton = QPushButton() + clearEntryButton.setIcon(QIcon(":backspace.svg")) + clearEntryButton.setFont(operatorFont) + clearEntryButton.setIconSize(QSize(iconSize, iconSize)) + + operatorsLayout.addWidget(clearButton) + operatorsLayout.addWidget(clearEntryButton) + operatorsLayout.addStretch(1) + + clearButton.clicked.connect(self.clearEquation) + clearEntryButton.clicked.connect(self.clearEntryEquation) + + return operatorsLayout + + def addOperator(self): + button = self.sender() + text = f"{self.equationDisplay.toPlainText()}{button.text}" + self.equationDisplay.setPlainText(text) + self.clearLenghts.append(len(button.text)) + + def clearEquation(self): + self.isOperatorMode = False + self.equationDisplay.setPlainText("") + self.initAttributes() + + def initAttributes(self): + self.clearLenghts = [] + self.equationColNames = [] + self.channelLessColnames = [] + + def clearEntryEquation(self): + if not self.clearLenghts: + return + + text = self.equationDisplay.toPlainText() + newText = text[: -self.clearLenghts[-1]] + clearedText = text[-self.clearLenghts[-1] :] + self.clearLenghts.pop(-1) + self.equationDisplay.setPlainText(newText) + if clearedText in self.equationColNames: + self.equationColNames.remove(clearedText) + if clearedText in self.channelLessColnames: + self.channelLessColnames.remove(clearedText) + +# Sibling imports (deferred to avoid import cycles) +from .metadata import ( + MultiListSelector, + filenameDialog, +) + diff --git a/cellacdc/dialogs/metadata.py b/cellacdc/dialogs/metadata.py new file mode 100644 index 000000000..cdc7ec627 --- /dev/null +++ b/cellacdc/dialogs/metadata.py @@ -0,0 +1,3590 @@ +"""Cell-ACDC dialog windows: metadata.""" + +import os +import sys +import re +from typing import Literal, Callable, Dict, Iterable, List, Tuple +import datetime +import pathlib +from collections import defaultdict +import zipfile +from heapq import nlargest +import matplotlib +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +from matplotlib.patches import Rectangle, Circle, PathPatch, Path +import numpy as np +import scipy.interpolate + +try: + import tkinter as tk +except Exception as err: + pass + +import cv2 +import traceback +from itertools import combinations, permutations +from collections import namedtuple +from natsort import natsorted + +# from MyWidgets import Slider, Button, MyRadioButtons +from skimage.measure import label, regionprops +from functools import partial +import skimage.filters +import skimage.measure +import skimage.morphology +import skimage.exposure +import skimage.draw +import skimage.registration +import skimage.color +import skimage.segmentation +from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk +import matplotlib.pyplot as plt +import seaborn as sns +import pandas as pd +import math +import time +import sympy as sp +import json +import html + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from qtpy import QtCore +from qtpy.QtGui import ( + QIcon, + QFontMetrics, + QKeySequence, + QFont, + QRegularExpressionValidator, + QCursor, + QKeyEvent, + QPixmap, + QFont, + QPalette, + QMouseEvent, + QColor, +) +from qtpy.QtCore import ( + Qt, + QSize, + QEvent, + Signal, + QEventLoop, + QTimer, + QRegularExpression, +) +from qtpy.QtWidgets import ( + QFileDialog, + QApplication, + QMainWindow, + QMenu, + QLabel, + QToolBar, + QScrollBar, + QWidget, + QVBoxLayout, + QLineEdit, + QPushButton, + QHBoxLayout, + QDialog, + QFormLayout, + QListWidget, + QAbstractItemView, + QButtonGroup, + QCheckBox, + QSizePolicy, + QComboBox, + QSlider, + QGridLayout, + QSpinBox, + QToolButton, + QTableView, + QTextBrowser, + QDoubleSpinBox, + QScrollArea, + QFrame, + QProgressBar, + QGroupBox, + QRadioButton, + QDockWidget, + QMessageBox, + QStyle, + QPlainTextEdit, + QSpacerItem, + QTreeWidget, + QTreeWidgetItem, + QTextEdit, + QSplashScreen, + QAction, + QListWidgetItem, + QActionGroup, + QHeaderView, + QStyledItemDelegate, +) +import qtpy.compat + +from .. import exception_handler +from .. import load, prompts, core, measurements, html_utils +from .. import is_mac, is_win, is_linux, settings_folderpath, config +from .. import preproc_recipes_path, segm_recipes_path, combine_channels_recipes_path +from .. import is_conda_env +from .. import printl +from .. import colors +from .. import issues_url +from .. import utils +from .. import qutils +from .. import _palettes +from .. import base_cca_dict +from .. import widgets +from .. import user_profile_path, promptable_models_path, models_path +from .. import features +from .. import _core +from .. import _types +from .. import plot +from .. import urls +from ..acdc_regex import float_regex, is_alphanumeric_filename, to_alphanumeric +from .. import _base_widgets +from .. import io +from .. import cca_functions +from .. import path + +POSITIVE_FLOAT_REGEX = float_regex(allow_negative=False) +TREEWIDGET_STYLESHEET = _palettes.TreeWidgetStyleSheet() +LISTWIDGET_STYLESHEET = _palettes.ListWidgetStyleSheet() +BACKGROUND_RGBA = _palettes.get_disabled_colors()["Button"] + +font = QFont() +font.setPixelSize(12) +italicFont = QFont() +italicFont.setPixelSize(12) +italicFont.setItalic(True) + +from ._base import ( + QBaseDialog, +) + +class filenameDialog(QDialog): + def __init__( + self, + ext=".npz", + basename="", + title="Insert file name", + hintText="", + existingNames="", + parent=None, + allowEmpty=True, + helpText="", + defaultEntry="", + resizeOnShow=True, + additionalButtons=None, + addDoNotSaveButton=False, + ): + self.cancel = True + super().__init__(parent) + + self.resizeOnShow = resizeOnShow + + if hintText.find("segmentation") != -1: + if helpText: + helpText = f"{helpText}" + helpText_loc = """ + With Cell-ACDC you can create as many segmentation files + as you want.

+ If you plan to create only one file then you can leave the + text entry empty.
+ Cell-ACDC will save the segmentation file with the filename + ending with _segm.npz.

+ However, we recommend to insert some text that will easily + allow you to identify what is the segmentation file about.

+ For example, if you are about to segment the channel + phase_contr, you could write + phase_contr.
+ Cell-ACDC will then save the file with the + filename ending with _segm_phase_contr.npz.

+ This way you can create multiple segmentation files, + for example one for each channel or one for each segmentation model.

+ Note that the numerical features and annotations will be saved + in a CSV file ending with the same text as the segmentation file,
+ e.g., ending with _acdc_output_phase_contr.csv. + """ + helpText = f"{helpText}{html_utils.paragraph(helpText_loc)}" + + self.isSegmFile = basename.endswith("_segm") + self.allowEmpty = allowEmpty + self.basename = basename + if ext and not ext.startswith("."): + ext = f".{ext}" + self.ext = ext + + self.setWindowTitle(title) + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + + layout = QVBoxLayout() + entryLayout = QGridLayout() + buttonsLayout = QHBoxLayout() + + hintLabel = QLabel(hintText) + + basenameLabel = QLabel(basename) + + self.lineEdit = widgets.alphaNumericLineEdit(onlyWarn=True) + self.lineEdit.setAlignment(Qt.AlignCenter) + defaultEntry = to_alphanumeric(defaultEntry) + defaultEntry = defaultEntry.replace(".", "_") + self.lineEdit.setText(defaultEntry) + + extLabel = QLabel(ext) + + self.filenameLabel = QLabel() + self.filenameLabel.setText(f"{basename}{ext}") + + entryLayout.addWidget(basenameLabel, 0, 1) + entryLayout.addWidget(self.lineEdit, 0, 2) + entryLayout.addWidget(extLabel, 0, 3) + entryLayout.addWidget(self.filenameLabel, 1, 1, 1, 3, alignment=Qt.AlignCenter) + # entryLayout.setColumnStretch(0, 1) + entryLayout.setColumnStretch(2, 1) + + self.warningInvalidCharLabel = QLabel() + + okButton = widgets.okPushButton("Ok") + cancelButton = widgets.cancelPushButton("Cancel") + self.okButton = okButton + + buttonsLayout.addStretch() + buttonsLayout.addWidget(cancelButton) + + if addDoNotSaveButton: + doNotSaveButton = widgets.noPushButton("Do not save") + doNotSaveButton.clicked.connect(self.doNotSave_cb) + buttonsLayout.addWidget(doNotSaveButton) + self.doNotSave = False + + buttonsLayout.addSpacing(20) + if helpText: + helpButton = widgets.helpPushButton("Help...") + helpButton.clicked.connect(partial(self.showHelp, helpText)) + buttonsLayout.addWidget(helpButton) + if additionalButtons is not None: + for button in additionalButtons: + buttonsLayout.addWidget(button) + buttonsLayout.addWidget(okButton) + + cancelButton.clicked.connect(self.close) + okButton.clicked.connect(self.ok_cb) + self.lineEdit.textChanged.connect(self.updateFilename) + self.lineEdit.sigInvalidCharactersEntered.connect( + self.warnInvalidCharactersEntered + ) + + self.existingNames = [] + if existingNames: + self.existingNames = existingNames + # self.lineEdit.editingFinished.connect(self.checkExistingNames) + + layout.addWidget(hintLabel) + layout.addSpacing(20) + layout.addLayout(entryLayout) + layout.addSpacing(10) + layout.addWidget(self.warningInvalidCharLabel) + layout.addStretch(1) + layout.addSpacing(20) + layout.addLayout(buttonsLayout) + + self.setLayout(layout) + self.setFont(font) + + if defaultEntry: + self.updateFilename(defaultEntry) + + def doNotSave_cb(self): + msg = widgets.myMessageBox() + txt = html_utils.paragraph( + "Are you sure you do not want to save the file?" + ) + noButton, yesButton = msg.warning( + self, "Do not save?", txt, buttonsTexts=("No", "Yes") + ) + if msg.clickedButton == noButton: + return + + self.doNotSave = True + self.cancel = False + self.close() + + def showHelp(self, text): + text = html_utils.paragraph(text) + msg = widgets.myMessageBox(wrapText=False) + msg.information(self, "Filename help", text) + + def _text(self): + return self.lineEdit.text() + + def warnInvalidCharactersEntered(self, characters: set[str]): + statement = "is not a valid character" + if len(characters) > 1: + statement = "are not valid characters" + + characters_str = "".join(characters) + characters_str = html.escape(characters_str) + warning_text = html_utils.span(f""" + WARNING: "{characters_str}" {statement}.
+ """) + warning_text = ( + f"{warning_text}" + "Valid characters are letters, numbers, underscore, and dash." + ) + self.warningInvalidCharLabel.setText(warning_text) + + def checkExistingNames(self): + is_existing = ( + self._text() in self.existingNames + or self.filenameLabel.text() in self.existingNames + ) + if not is_existing: + return True + + filename = self.filenameLabel.text() + msg = widgets.myMessageBox() + txt = html_utils.paragraph( + "The following file

" + f"{filename}

" + "is already existing.

" + "Do you want to overwrite the existing file?" + ) + noButton, yesButton = msg.warning( + self, "File name existing", txt, buttonsTexts=("No", "Yes") + ) + return msg.clickedButton == yesButton + + def updateFilename(self, text): + if self.lineEdit.invalidCharacters(): + return + + if not text: + self.filenameLabel.setText(f"{self.basename}{self.ext}") + else: + text = text.replace(" ", "_") + if self.basename: + if self.basename.endswith("_"): + self.filenameLabel.setText(f"{self.basename}{text}{self.ext}") + else: + self.filenameLabel.setText(f"{self.basename}_{text}{self.ext}") + else: + self.filenameLabel.setText(f"{text}{self.ext}") + + self.warningInvalidCharLabel.setText("") + + def checkEmptyText(self): + if self.allowEmpty: + return True + + if self._text(): + return True + + msg = widgets.myMessageBox() + msg.critical( + self, + "Empty text", + html_utils.paragraph("Text entry field cannot be empty"), + ) + return False + + def checkSegmFilename(self): + if not self.isSegmFile: + return True + + if "segm" not in self._text(): + return True + + msg = widgets.myMessageBox() + txt = html_utils.paragraph( + "The text appended to the filename cannot contain the text " + '"segm".

' + "Sorry, that would confuse me. Thank you for your patience!" + ) + msg.critical(self, 'Cannot use "segm" in filename', txt) + return False + + def ok_cb(self, checked=True): + if self.warningInvalidCharLabel.text(): + return + + valid = self.checkExistingNames() + if not valid: + return + + valid = self.checkEmptyText() + if not valid: + return + + valid = self.checkSegmFilename() + if not valid: + return + + self.filename = self.filenameLabel.text() + self.entryText = self._text() + self.cancel = False + self.close() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + super().show() + if self.resizeOnShow: + self.lineEdit.setMinimumWidth(self.lineEdit.width() * 2) + self.okButton.setDefault(True) + if block: + self.loop = QEventLoop() + self.loop.exec_() + + +class QDialogMetadataXML(QDialog): + def __init__( + self, + title="Metadata", + LensNA=1.0, + rawFilename="test", + SizeT=1, + SizeZ=1, + SizeC=1, + SizeS=1, + TimeIncrement=1.0, + TimeIncrementUnit="s", + PhysicalSizeX=1.0, + PhysicalSizeY=1.0, + PhysicalSizeZ=1.0, + PhysicalSizeUnit="μm", + ImageName="", + chNames=None, + emWavelens=None, + parent=None, + rawDataStruct=None, + sampleImgData=None, + rawFilePath=None, + ): + self.cancel = True + self.trust = False + self.overWrite = False + rawFilename = os.path.splitext(rawFilename)[0] + self.rawFilename = self.removeInvalidCharacters(rawFilename) + self.rawFilePath = rawFilePath + self.sampleImgData = sampleImgData + self.ImageName = ImageName + self.rawDataStruct = rawDataStruct + self.readSampleImgDataAgain = False + self.requestedReadingSampleImageDataAgain = False + self.imageViewer = None + super().__init__(parent) + self.setWindowTitle(title) + font = QFont() + font.setPixelSize(12) + self.setFont(font) + + mainLayout = QVBoxLayout() + entriesLayout = QGridLayout() + self.channelNameLayouts = ( + QVBoxLayout(), + QVBoxLayout(), + QVBoxLayout(), + QVBoxLayout(), + ) + self.channelEmWLayouts = ( + QVBoxLayout(), + QVBoxLayout(), + QVBoxLayout(), + QVBoxLayout(), + ) + buttonsLayout = QGridLayout() + + infoLabel = QLabel() + infoTxt = "Confirm/Edit the metadata below." + infoLabel.setText(infoTxt) + # padding: top, left, bottom, right + infoLabel.setStyleSheet("font-size:12pt; padding:0px 0px 5px 0px;") + mainLayout.addWidget(infoLabel, alignment=Qt.AlignCenter) + + noteLabel = QLabel() + noteLabel.setText( + f"NOTE: If you are not sure about some of the entries " + 'you can try to click "Ok".\n' + "If they are wrong you will get " + "an error message later when trying to read the data." + ) + noteLabel.setAlignment(Qt.AlignCenter) + mainLayout.addWidget(noteLabel, alignment=Qt.AlignCenter) + + row = 0 + to_tif_radiobutton = QRadioButton(".tif") + to_tif_radiobutton.setChecked(True) + to_h5_radiobutton = QRadioButton(".h5") + to_h5_radiobutton.setToolTip( + ".h5 is highly recommended for big datasets to avoid memory issues.\n" + "As a rule of thumb, if the single position, single channel file\n" + "is larger than 1/5 of the available RAM we recommend using .h5 format" + ) + self.to_h5_radiobutton = to_h5_radiobutton + txt = "File format: " + label = QLabel(txt) + fileFormatLayout = QHBoxLayout() + fileFormatLayout.addStretch(1) + fileFormatLayout.addWidget(to_tif_radiobutton) + fileFormatLayout.addStretch(1) + fileFormatLayout.addWidget(to_h5_radiobutton) + fileFormatLayout.addStretch(1) + entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) + entriesLayout.addLayout(fileFormatLayout, row, 1) + to_h5_radiobutton.toggled.connect(self.updateFileFormat) + + row += 1 + self.SizeS_SB = QSpinBox() + self.SizeS_SB.setAlignment(Qt.AlignCenter) + self.SizeS_SB.setMinimum(1) + self.SizeS_SB.setMaximum(2147483647) + self.SizeS_SB.setValue(SizeS) + txt = "Number of positions (SizeS): " + label = QLabel(txt) + entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) + entriesLayout.addWidget(self.SizeS_SB, row, 1) + + if rawDataStruct == 0: + row += 1 + self.SizeS_SB.setValue(1) + self.SizeS_SB.setDisabled(True) + self.posSelector = widgets.ExpandableListBox() + positions = ["All positions"] + positions.extend([f"Position_{i + 1}" for i in range(SizeS)]) + self.posSelector.addItems(positions) + txt = "Positions to save: " + label = QLabel(txt) + entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) + entriesLayout.addWidget(self.posSelector, row, 1) + self.SizeS_SB.valueChanged.connect(self.SizeSvalueChanged) + + row += 1 + self.LensNA_DSB = QDoubleSpinBox() + self.LensNA_DSB.setAlignment(Qt.AlignCenter) + self.LensNA_DSB.setSingleStep(0.1) + self.LensNA_DSB.setValue(LensNA) + txt = "Numerical Aperture Objective Lens: " + label = QLabel(txt) + entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) + entriesLayout.addWidget(self.LensNA_DSB, row, 1) + + row += 1 + self.SizeT_SB = QSpinBox() + self.SizeT_SB.setAlignment(Qt.AlignCenter) + self.SizeT_SB.setMinimum(1) + self.SizeT_SB.setMaximum(2147483647) + self.SizeT_SB.setValue(SizeT) + txt = "Number of frames (SizeT): " + label = QLabel(txt) + entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) + entriesLayout.addWidget(self.SizeT_SB, row, 1) + self.SizeT_SB.valueChanged.connect(self.hideShowTimeIncrement) + + row += 1 + self.timeRangeToSaveWidget = widgets.RangeSelector(integers=True) + self.timeRangeToSaveWidget.setRange(1, SizeT) + txt = "Time range to save: " + label = QLabel(txt) + self.timeRangeToSaveWidget.label = label + entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) + entriesLayout.addWidget(self.timeRangeToSaveWidget, row, 1) + + row += 1 + self.SizeZ_SB = QSpinBox() + self.SizeZ_SB.setAlignment(Qt.AlignCenter) + self.SizeZ_SB.setMinimum(1) + self.SizeZ_SB.setMaximum(2147483647) + self.SizeZ_SB.setValue(SizeZ) + txt = "Number of z-slices in the z-stack (SizeZ): " + label = QLabel(txt) + entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) + entriesLayout.addWidget(self.SizeZ_SB, row, 1) + self.SizeZ_SB.valueChanged.connect(self.hideShowPhysicalSizeZ) + + row += 1 + self.TimeIncrement_DSB = widgets.FloatLineEdit( + allowNegative=False, warningValues={1.0} + ) + self.TimeIncrement_DSB.setValue(TimeIncrement) + self.TimeIncrement_DSB.setMinimum(0.0) + txt = "Frame interval: " + label = QLabel(txt) + self.TimeIncrement_Label = label + entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) + entriesLayout.addWidget(self.TimeIncrement_DSB, row, 1) + + self.TimeIncrementUnit_CB = QComboBox() + unitItems = ["ms", "seconds", "minutes", "hours"] + currentTxt = [unit for unit in unitItems if unit.startswith(TimeIncrementUnit)] + self.TimeIncrementUnit_CB.addItems(unitItems) + if currentTxt: + self.TimeIncrementUnit_CB.setCurrentText(currentTxt[0]) + entriesLayout.addWidget( + self.TimeIncrementUnit_CB, row, 2, alignment=Qt.AlignLeft + ) + + row += 1 + self.PhysicalSizeX_DSB = QDoubleSpinBox() + self.PhysicalSizeX_DSB.setAlignment(Qt.AlignCenter) + self.PhysicalSizeX_DSB.setMaximum(2147483647.0) + self.PhysicalSizeX_DSB.setSingleStep(0.001) + self.PhysicalSizeX_DSB.setDecimals(7) + self.PhysicalSizeX_DSB.setValue(PhysicalSizeX) + txt = "Pixel width (X): " + label = QLabel(txt) + entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) + entriesLayout.addWidget(self.PhysicalSizeX_DSB, row, 1) + + self.PhysicalSizeUnit_CB = QComboBox() + unitItems = ["nm", "μm", "mm", "cm"] + currentTxt = [unit for unit in unitItems if unit.startswith(PhysicalSizeUnit)] + self.PhysicalSizeUnit_CB.addItems(unitItems) + if currentTxt: + self.PhysicalSizeUnit_CB.setCurrentText(currentTxt[0]) + else: + self.PhysicalSizeUnit_CB.setCurrentText(unitItems[1]) + entriesLayout.addWidget( + self.PhysicalSizeUnit_CB, row, 2, alignment=Qt.AlignLeft + ) + self.PhysicalSizeUnit_CB.currentTextChanged.connect(self.updatePSUnit) + + row += 1 + self.PhysicalSizeY_DSB = QDoubleSpinBox() + self.PhysicalSizeY_DSB.setAlignment(Qt.AlignCenter) + self.PhysicalSizeY_DSB.setMaximum(2147483647.0) + self.PhysicalSizeY_DSB.setSingleStep(0.001) + self.PhysicalSizeY_DSB.setDecimals(7) + self.PhysicalSizeY_DSB.setValue(PhysicalSizeY) + txt = "Pixel height (Y): " + label = QLabel(txt) + entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) + entriesLayout.addWidget(self.PhysicalSizeY_DSB, row, 1) + + self.PhysicalSizeYUnit_Label = QLabel() + self.PhysicalSizeYUnit_Label.setStyleSheet( + "font-size:13px; padding:5px 0px 2px 0px;" + ) + unit = self.PhysicalSizeUnit_CB.currentText() + self.PhysicalSizeYUnit_Label.setText(unit) + entriesLayout.addWidget(self.PhysicalSizeYUnit_Label, row, 2) + + row += 1 + self.PhysicalSizeZ_DSB = QDoubleSpinBox() + self.PhysicalSizeZ_DSB.setAlignment(Qt.AlignCenter) + self.PhysicalSizeZ_DSB.setMaximum(2147483647.0) + self.PhysicalSizeZ_DSB.setSingleStep(0.001) + self.PhysicalSizeZ_DSB.setDecimals(7) + self.PhysicalSizeZ_DSB.setValue(PhysicalSizeZ) + txt = "Voxel depth (Z): " + self.PSZlabel = QLabel(txt) + entriesLayout.addWidget(self.PSZlabel, row, 0, alignment=Qt.AlignRight) + entriesLayout.addWidget(self.PhysicalSizeZ_DSB, row, 1) + + self.PhysicalSizeZUnit_Label = QLabel() + # padding: top, left, bottom, right + self.PhysicalSizeZUnit_Label.setStyleSheet( + "font-size:13px; padding:5px 0px 2px 0px;" + ) + unit = self.PhysicalSizeUnit_CB.currentText() + self.PhysicalSizeZUnit_Label.setText(unit) + entriesLayout.addWidget(self.PhysicalSizeZUnit_Label, row, 2) + + if SizeZ == 1: + self.PSZlabel.hide() + self.PhysicalSizeZ_DSB.hide() + self.PhysicalSizeZUnit_Label.hide() + + row += 1 + self.SizeC_SB = QSpinBox() + self.SizeC_SB.setAlignment(Qt.AlignCenter) + self.SizeC_SB.setMinimum(1) + self.SizeC_SB.setMaximum(2147483647) + self.SizeC_SB.setValue(SizeC) + txt = "Number of channels (SizeC): " + label = QLabel(txt) + entriesLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) + entriesLayout.addWidget(self.SizeC_SB, row, 1) + self.SizeC_SB.valueChanged.connect(self.addRemoveChannels) + + row += 1 + for j, layout in enumerate(self.channelNameLayouts): + entriesLayout.addLayout(layout, row, j) + + self.chNames_QLEs = [] + self.saveChannels_QCBs = [] + self.filename_QLabels = [] + self.showChannelDataButtons = [] + + ext = "h5" if self.to_h5_radiobutton.isChecked() else "tif" + for c in range(SizeC): + chName_QLE = QLineEdit() + chName_QLE.setStyleSheet("") + chName_QLE.setAlignment(Qt.AlignCenter) + chName_QLE.textChanged.connect(self.checkChNames) + if chNames is not None: + chName_QLE.setText(chNames[c]) + else: + chName_QLE.setText(f"channel_{c}") + filename = f"" + + txt = f"Channel {c} name: " + label = QLabel(txt) + + filenameDescLabel = QLabel(f"e.g., filename for channel {c}: ") + + chName = chName_QLE.text() + chName = self.removeInvalidCharacters(chName) + rawFilename = self.elidedRawFilename() + filenameLabel = QLabel(f""" +

{rawFilename}_{chName}.{ext}

+ """) + filenameLabel.setToolTip(f"{self.rawFilename}_{chName}.{ext}") + + checkBox = QCheckBox("Save this channel") + checkBox.setChecked(True) + checkBox.stateChanged.connect(self.saveCh_checkBox_cb) + + self.channelNameLayouts[0].addWidget(label, alignment=Qt.AlignRight) + self.channelNameLayouts[0].addWidget( + filenameDescLabel, alignment=Qt.AlignRight + ) + self.channelNameLayouts[1].addWidget(chName_QLE) + self.channelNameLayouts[1].addWidget( + filenameLabel, alignment=Qt.AlignCenter + ) + + self.channelNameLayouts[2].addWidget(checkBox) + if c == 0 and ImageName: + addImageName_QCB = QCheckBox("Include image name") + addImageName_QCB.stateChanged.connect(self.addImageName_cb) + self.addImageName_QCB = addImageName_QCB + self.channelNameLayouts[2].addWidget(addImageName_QCB) + else: + self.addImageName_QCB = QCheckBox("dummy") + self.addImageName_QCB.hide() + self.channelNameLayouts[2].addWidget(QLabel()) + + showChannelDataButton = QPushButton() + showChannelDataButton.setIcon(QIcon(":eye-plus.svg")) + showChannelDataButton.clicked.connect(self.showChannelData) + self.channelNameLayouts[3].addWidget(showChannelDataButton) + if self.sampleImgData is None: + showChannelDataButton.setDisabled(True) + + self.chNames_QLEs.append(chName_QLE) + self.saveChannels_QCBs.append(checkBox) + self.filename_QLabels.append(filenameLabel) + self.showChannelDataButtons.append(showChannelDataButton) + + self.checkChNames() + + row += 1 + for j, layout in enumerate(self.channelEmWLayouts): + entriesLayout.addLayout(layout, row, j) + + self.emWavelens_DSBs = [] + for c in range(SizeC): + row += 1 + emWavelen_DSB = QDoubleSpinBox() + emWavelen_DSB.setAlignment(Qt.AlignCenter) + emWavelen_DSB.setMaximum(2147483647.0) + emWavelen_DSB.setSingleStep(0.001) + emWavelen_DSB.setDecimals(2) + if emWavelens is not None: + emWavelen_DSB.setValue(emWavelens[c]) + else: + emWavelen_DSB.setValue(500.0) + + txt = f"Channel {c} emission wavelength: " + label = QLabel(txt) + self.channelEmWLayouts[0].addWidget(label, alignment=Qt.AlignRight) + self.channelEmWLayouts[1].addWidget(emWavelen_DSB) + self.emWavelens_DSBs.append(emWavelen_DSB) + + unit = QLabel("nm") + unit.setStyleSheet("font-size:13px; padding:5px 0px 2px 0px;") + self.channelEmWLayouts[2].addWidget(unit) + + entriesLayout.setContentsMargins(0, 15, 0, 0) + + if rawDataStruct is None or rawDataStruct != -1: + okButton = widgets.okPushButton(" Ok ") + elif rawDataStruct == 1: + okButton = QPushButton(" Load next position ") + buttonsLayout.addWidget(okButton, 0, 1) + + self.trustButton = None + self.overWriteButton = None + if rawDataStruct == 1: + trustButton = QPushButton( + " Trust metadata reader\n for all next positions " + ) + trustButton.setToolTip( + "If you didn't have to manually modify metadata entries\n" + "it is very likely that metadata from the metadata reader\n" + "will be correct also for all the next positions.\n\n" + "Click this button to stop showing this dialog and use\n" + "the metadata from the reader\n" + "(except for channel names, I will use the manually entered)" + ) + buttonsLayout.addWidget(trustButton, 1, 1) + self.trustButton = trustButton + + overWriteButton = QPushButton( + " Use the above metadata\n for all the next positions " + ) + overWriteButton.setToolTip( + "If you had to manually modify metadata entries\n" + "AND you know they will be the same for all next positions\n" + "you can click this button to stop showing this dialog\n" + "and use the same metadata for all the next positions." + ) + buttonsLayout.addWidget(overWriteButton, 1, 2) + self.overWriteButton = overWriteButton + + trustButton.clicked.connect(self.ok_cb) + overWriteButton.clicked.connect(self.ok_cb) + + cancelButton = widgets.cancelPushButton("Cancel") + buttonsLayout.addWidget(cancelButton, 0, 2) + buttonsLayout.setColumnStretch(0, 1) + buttonsLayout.setColumnStretch(3, 1) + buttonsLayout.setContentsMargins(0, 10, 0, 0) + + mainLayout.addLayout(entriesLayout) + mainLayout.addLayout(buttonsLayout) + mainLayout.addStretch(1) + + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.cancel_cb) + + self.hideShowTimeIncrement(SizeT) + self.readSampleImgDataAgain = False + + self.setLayout(mainLayout) + # self.setModal(True) + + def saveCh_checkBox_cb(self, state): + self.checkChNames() + idx = self.saveChannels_QCBs.index(self.sender()) + LE = self.chNames_QLEs[idx] + idx *= 2 + LE.setDisabled(state == 0) + label = self.channelNameLayouts[0].itemAt(idx).widget() + if state == 0: + label.setStyleSheet("color: gray; font-size: 10pt") + else: + label.setStyleSheet("color: black; font-size: 10pt") + + label = self.channelNameLayouts[0].itemAt(idx + 1).widget() + if state == 0: + label.setStyleSheet("color: gray; font-size: 10pt") + else: + label.setStyleSheet("color: black; font-size: 10pt") + + label = self.channelNameLayouts[1].itemAt(idx + 1).widget() + if state == 0: + label.setStyleSheet("color: gray; font-size: 10pt") + else: + label.setStyleSheet("color: black; font-size: 10pt") + + def addImageName_cb(self, state): + for idx in range(self.SizeC_SB.value()): + self.updateFilename(idx) + + def setInvalidChName_StyleSheet(self, LE): + LE.setStyleSheet( + "border-radius: 4px;border: 1.5px solid red;padding: 1px 0px 1px 0px" + ) + + def removeInvalidCharacters(self, chName): + # Remove invalid charachters + chName = "".join( + c if c.isalnum() or c == "_" or c == "" else "_" for c in chName + ) + trim_ = chName.endswith("_") + while trim_: + chName = chName[:-1] + trim_ = chName.endswith("_") + return chName + + def updateFileFormat(self, is_h5): + for idx in range(len(self.chNames_QLEs)): + self.updateFilename(idx) + + def SizeSvalueChanged(self, SizeS): + positions = ["All positions"] + positions.extend([f"Position_{i + 1}" for i in range(SizeS)]) + self.posSelector.setItems(positions) + + def elidedRawFilename(self): + n = 31 + idx = int((n - 3) / 2) + if len(self.rawFilename) > 21: + elidedText = f"{self.rawFilename[:idx]}...{self.rawFilename[-idx:]}" + else: + elidedText = self.rawFilename + return elidedText + + def updateFilename(self, idx): + chName = self.chNames_QLEs[idx].text() + chName = self.removeInvalidCharacters(chName) + if self.rawDataStruct == 2: + rawFilename = f"{self.rawFilename}_s{idx + 1}" + else: + rawFilename = self.rawFilename + + ext = "h5" if self.to_h5_radiobutton.isChecked() else "tif" + + rawFilename = self.elidedRawFilename() + + filenameLabel = self.filename_QLabels[idx] + if self.addImageName_QCB.isChecked(): + self.ImageName = self.removeInvalidCharacters(self.ImageName) + filename = f""" +

+ {rawFilename}_{self.ImageName}_{chName}.{ext} +

+ """ + fullFilename = f"{self.rawFilename}_{self.ImageName}_{chName}.{ext}" + else: + filename = f""" +

+ {rawFilename}_{chName}.{ext} +

+ """ + fullFilename = f"{self.rawFilename}_{chName}.{ext}" + filenameLabel.setToolTip(fullFilename) + filenameLabel.setText(filename) + + def checkChNames(self, text=""): + if self.sender() in self.chNames_QLEs: + idx = self.chNames_QLEs.index(self.sender()) + self.updateFilename(idx) + elif self.sender() in self.saveChannels_QCBs: + idx = self.saveChannels_QCBs.index(self.sender()) + self.updateFilename(idx) + + areChNamesValid = True + if len(self.chNames_QLEs) == 1: + LE1 = self.chNames_QLEs[0] + saveCh = self.saveChannels_QCBs[0].isChecked() + if not saveCh: + LE1.setStyleSheet("") + return areChNamesValid + + s1 = LE1.text() + if not s1: + self.setInvalidChName_StyleSheet(LE1) + areChNamesValid = False + else: + LE1.setStyleSheet("") + return areChNamesValid + + for LE1, LE2 in combinations(self.chNames_QLEs, 2): + s1 = LE1.text() + s2 = LE2.text() + LE1_idx = self.chNames_QLEs.index(LE1) + LE2_idx = self.chNames_QLEs.index(LE2) + saveCh1 = self.saveChannels_QCBs[LE1_idx].isChecked() + saveCh2 = self.saveChannels_QCBs[LE2_idx].isChecked() + if not s1 or not s2 or s1 == s2: + if not s1 and saveCh1: + self.setInvalidChName_StyleSheet(LE1) + areChNamesValid = False + else: + LE1.setStyleSheet("") + if not s2 and saveCh2: + self.setInvalidChName_StyleSheet(LE2) + areChNamesValid = False + else: + LE2.setStyleSheet("") + if s1 == s2 and saveCh1 and saveCh2: + self.setInvalidChName_StyleSheet(LE1) + self.setInvalidChName_StyleSheet(LE2) + areChNamesValid = False + else: + LE1.setStyleSheet("") + LE2.setStyleSheet("") + return areChNamesValid + + def hideShowTimeIncrement(self, value): + if self.TimeIncrement_DSB.isVisible() and value == 1: + self.readSampleImgDataAgain = True + + if not self.TimeIncrement_DSB.isVisible() and value > 1: + self.readSampleImgDataAgain = True + + if value > 1: + self.TimeIncrement_DSB.show() + self.TimeIncrementUnit_CB.show() + self.TimeIncrement_Label.show() + self.timeRangeToSaveWidget.show() + self.timeRangeToSaveWidget.label.show() + self.timeRangeToSaveWidget.setRange(1, value) + else: + self.TimeIncrement_DSB.hide() + self.TimeIncrementUnit_CB.hide() + self.TimeIncrement_Label.hide() + self.timeRangeToSaveWidget.hide() + self.timeRangeToSaveWidget.label.hide() + + def hideShowPhysicalSizeZ(self, value): + if value > 1: + self.PSZlabel.show() + self.PhysicalSizeZ_DSB.show() + self.PhysicalSizeZUnit_Label.show() + else: + self.PSZlabel.hide() + self.PhysicalSizeZ_DSB.hide() + self.PhysicalSizeZUnit_Label.hide() + self.readSampleImgDataAgain = True + + def updatePSUnit(self, unit): + self.PhysicalSizeYUnit_Label.setText(unit) + self.PhysicalSizeZUnit_Label.setText(unit) + + def warnRestart(self): + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + txt = html_utils.paragraph(""" + Since you manually changed some of the metadata, this dialogue will now restart
+ because it needs to read the image data again.

+ Thank you for your patience. + """) + msg.warning(self, "Restart required", txt) + + def showChannelData(self, checked=False, idx=None): + if self.readSampleImgDataAgain: + # User changed SizeZ, SizeT, or SizeC --> we need to read sample + # image again + del self.sampleImgData + self.requestedReadingSampleImageDataAgain = True + self.sampleImgData = None + self.warnRestart() + self.getValues() + self.cancel = False + self.close() + return + + if idx is None: + idx = self.showChannelDataButtons.index(self.sender()) + dimsOrder = "ctz" + imgData = self.sampleImgData[dimsOrder][idx] + posData = utils.utilClass() + posData.frame_i = 0 + sampleSizeT = 4 if self.SizeT_SB.value() >= 4 else self.SizeT_SB.value() + posData.SizeT = sampleSizeT + SizeZ = self.SizeZ_SB.value() + posData.SizeZ = 20 if SizeZ > 20 else SizeZ + posData.filename = f"{self.rawFilename}_C={idx}" + posData.segmInfo_df = pd.DataFrame( + { + "filename": [posData.filename] * sampleSizeT, + "frame_i": range(sampleSizeT), + "which_z_proj_gui": ["single z-slice"] * sampleSizeT, + "z_slice_used_gui": [int(posData.SizeZ / 2)] * sampleSizeT, + } + ).set_index(["filename", "frame_i"]) + path_li = os.path.normpath(self.rawFilePath).split(os.sep) + posData.relPath = f"{f'{os.sep}'.join(path_li[-3:1])}" + posData.relPath = f"{posData.relPath}{os.sep}{posData.filename}" + if sampleSizeT == 1: + posData.img_data = [imgData] # single frame data + else: + posData.img_data = imgData + + if self.imageViewer is not None: + self.imageViewer.close() + + self.imageViewer = imageViewer( + posData=posData, isSigleFrame=False, enableOverlay=False + ) + self.imageViewer.channelIndex = idx + self.imageViewer.update_img() + self.imageViewer.sigClosed.connect(self.imageViewerClosed) + self.imageViewer.show() + + def imageViewerClosed(self): + self.imageViewer = None + + def addRemoveChannels(self, value): + self.readSampleImgDataAgain = True + currentSizeC = len(self.chNames_QLEs) + DeltaChannels = abs(value - currentSizeC) + ext = "h5" if self.to_h5_radiobutton.isChecked() else "tif" + if value > currentSizeC: + for c in range(currentSizeC, currentSizeC + DeltaChannels): + chName_QLE = QLineEdit() + chName_QLE.setStyleSheet("") + chName_QLE.setAlignment(Qt.AlignCenter) + chName_QLE.setText(f"channel_{c}") + chName_QLE.textChanged.connect(self.checkChNames) + + txt = f"Channel {c} name: " + label = QLabel(txt) + + filenameDescLabel = QLabel(f"e.g., filename for channel {c}: ") + + chName = chName_QLE.text() + rawFilename = self.elidedRawFilename() + filenameLabel = QLabel(f""" +

{rawFilename}_{chName}.{ext}

+ """) + filenameLabel.setToolTip(f"{self.rawFilename}_{chName}.{ext}") + + checkBox = QCheckBox("Save this channel") + checkBox.setChecked(True) + checkBox.stateChanged.connect(self.saveCh_checkBox_cb) + + self.channelNameLayouts[0].addWidget(label, alignment=Qt.AlignRight) + self.channelNameLayouts[0].addWidget( + filenameDescLabel, alignment=Qt.AlignRight + ) + self.channelNameLayouts[1].addWidget(chName_QLE) + self.channelNameLayouts[1].addWidget( + filenameLabel, alignment=Qt.AlignCenter + ) + + self.channelNameLayouts[2].addWidget(checkBox) + self.channelNameLayouts[2].addWidget(QLabel()) + + showChannelDataButton = QPushButton() + showChannelDataButton.setIcon(QIcon(":eye-plus.svg")) + showChannelDataButton.clicked.connect(self.showChannelData) + self.channelNameLayouts[3].addWidget(showChannelDataButton) + if self.sampleImgData is None: + showChannelDataButton.setDisabled(True) + + self.chNames_QLEs.append(chName_QLE) + self.saveChannels_QCBs.append(checkBox) + self.filename_QLabels.append(filenameLabel) + self.showChannelDataButtons.append(showChannelDataButton) + + emWavelen_DSB = QDoubleSpinBox() + emWavelen_DSB.setAlignment(Qt.AlignCenter) + emWavelen_DSB.setMaximum(2147483647.0) + emWavelen_DSB.setSingleStep(0.001) + emWavelen_DSB.setDecimals(2) + emWavelen_DSB.setValue(500.0) + unit = QLabel("nm") + unit.setStyleSheet("font-size:13px; padding:5px 0px 2px 0px;") + + txt = f"Channel {c} emission wavelength: " + label = QLabel(txt) + self.channelEmWLayouts[0].addWidget(label, alignment=Qt.AlignRight) + self.channelEmWLayouts[1].addWidget(emWavelen_DSB) + self.channelEmWLayouts[2].addWidget(unit) + self.emWavelens_DSBs.append(emWavelen_DSB) + else: + for c in range(currentSizeC, currentSizeC + DeltaChannels): + idx = (c - 1) * 2 + label1 = self.channelNameLayouts[0].itemAt(idx).widget() + label2 = self.channelNameLayouts[0].itemAt(idx + 1).widget() + chName_QLE = self.channelNameLayouts[1].itemAt(idx).widget() + filename_L = self.channelNameLayouts[1].itemAt(idx + 1).widget() + checkBox = self.channelNameLayouts[2].itemAt(idx).widget() + dummyLabel = self.channelNameLayouts[2].itemAt(idx + 1).widget() + showButton = self.showChannelDataButtons[-1] + showButton.clicked.disconnect() + + self.channelNameLayouts[0].removeWidget(label1) + self.channelNameLayouts[0].removeWidget(label2) + self.channelNameLayouts[1].removeWidget(chName_QLE) + self.channelNameLayouts[1].removeWidget(filename_L) + self.channelNameLayouts[2].removeWidget(checkBox) + self.channelNameLayouts[2].removeWidget(dummyLabel) + self.channelNameLayouts[3].removeWidget(showButton) + + self.chNames_QLEs.pop(-1) + self.saveChannels_QCBs.pop(-1) + self.filename_QLabels.pop(-1) + self.showChannelDataButtons.pop(-1) + + label = self.channelEmWLayouts[0].itemAt(c - 1).widget() + emWavelen_DSB = self.channelEmWLayouts[1].itemAt(c - 1).widget() + unit = self.channelEmWLayouts[2].itemAt(c - 1).widget() + self.channelEmWLayouts[0].removeWidget(label) + self.channelEmWLayouts[1].removeWidget(emWavelen_DSB) + self.channelEmWLayouts[2].removeWidget(unit) + self.emWavelens_DSBs.pop(-1) + + self.adjustSize() + + def ok_cb(self, event): + areChNamesValid = self.checkChNames() + if not areChNamesValid: + err_msg = html_utils.paragraph( + "Channel names cannot be empty or equal to each other." + "

" + "Insert a unique text for each channel name." + ) + msg = widgets.myMessageBox() + msg.critical(self, "Invalid channel names", err_msg) + return + + self.getValues() + self.convertUnits() + + if self.sender() == self.trustButton: + self.trust = True + elif self.sender() == self.overWriteButton: + self.overWrite = True + + self.cancel = False + self.close() + + def getValues(self): + self.LensNA = self.LensNA_DSB.value() + self.SizeT = self.SizeT_SB.value() + self.SizeZ = self.SizeZ_SB.value() + self.SizeC = self.SizeC_SB.value() + self.SizeS = self.SizeS_SB.value() + self.timeRangeToSave = self.timeRangeToSaveWidget.range() + self.TimeIncrement = self.TimeIncrement_DSB.value() + self.PhysicalSizeX = self.PhysicalSizeX_DSB.value() + self.PhysicalSizeY = self.PhysicalSizeY_DSB.value() + self.PhysicalSizeZ = self.PhysicalSizeZ_DSB.value() + self.to_h5 = self.to_h5_radiobutton.isChecked() + if hasattr(self, "posSelector"): + self.selectedPos = self.posSelector.selectedItemsText() + else: + self.selectedPos = ["All Positions"] + self.chNames = [] + if hasattr(self, "addImageName_QCB"): + self.addImageName = self.addImageName_QCB.isChecked() + else: + self.addImageName = False + self.saveChannels = [] + for LE, QCB in zip(self.chNames_QLEs, self.saveChannels_QCBs): + s = LE.text() + s = "".join(c if c.isalnum() or c == "_" or c == "" else "_" for c in s) + trim_ = s.endswith("_") + while trim_: + s = s[:-1] + trim_ = s.endswith("_") + self.chNames.append(s) + self.saveChannels.append(QCB.isChecked()) + self.emWavelens = [DSB.value() for DSB in self.emWavelens_DSBs] + + def convertUnits(self): + timeUnit = self.TimeIncrementUnit_CB.currentText() + if timeUnit == "ms": + self.TimeIncrement /= 1000 + elif timeUnit == "minutes": + self.TimeIncrement *= 60 + elif timeUnit == "hours": + self.TimeIncrement *= 3600 + + PhysicalSizeUnit = self.PhysicalSizeUnit_CB.currentText() + if timeUnit == "nm": + self.PhysicalSizeX /= 1000 + self.PhysicalSizeY /= 1000 + self.PhysicalSizeZ /= 1000 + elif timeUnit == "mm": + self.PhysicalSizeX *= 1000 + self.PhysicalSizeY *= 1000 + self.PhysicalSizeZ *= 1000 + elif timeUnit == "cm": + self.PhysicalSizeX *= 1e4 + self.PhysicalSizeY *= 1e4 + self.PhysicalSizeZ *= 1e4 + + def cancel_cb(self, event): + self.cancel = True + self.close() + + def exec_(self): + self.show(block=True) + + def setSize(self): + h = self.SizeS_SB.height() + self.TimeIncrement_DSB.setMinimumHeight(h) + + def show(self, block=False): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show() + self.setSize() + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class MultiTimePointFilePattern(QBaseDialog): + def __init__(self, fileName, folderPath, readPatternFunc=None, parent=None): + super().__init__(parent) + + self.setWindowTitle("File name pattern") + self.cancel = True + self.additionalChannelWidgets = {} + + mainLayout = QVBoxLayout() + self.readPatternFunc = readPatternFunc + + infoText = html_utils.paragraph(""" + The image files for each time-point must be named with the following pattern:

+ position_channel_timepoint +

+ For example a file with name "pos1_GFP_1.tif" would be the first time-point of the channell GFP
+ and position called pos1.

+ The Position number will be determined by alphabetically sorting + all the image files.

+ Please, provide the channel names below. + Optionally, you can provide a basename
+ that will be pre-pended to the name of all created files.

+ You can also provide a folder path containing the segmentation masks file.
+ These files MUST be named exactly as the raw files. +
+ """) + + noteLayout = QHBoxLayout() + noteText = html_utils.paragraph(""" + Channels do not need to have the same number of frames, + however, Cell-ACDC will place
+ the frames at the right frame number + (given by timepoint number at the end
+ of the filename) and it will fill missing frames with zeros. + """) + noteLayout.addWidget( + QLabel(html_utils.to_admonition(noteText)), + # alignment=(Qt.AlignTop | Qt.AlignRight) + ) + + mainLayout.addWidget(QLabel(infoText)) + mainLayout.addLayout(noteLayout) + noteLayout.setStretch(0, 0) + noteLayout.setStretch(1, 1) + + label = QLabel( + html_utils.paragraph(f"Sample file name: {fileName}") + ) + mainLayout.addWidget(label, alignment=Qt.AlignCenter) + mainLayout.addSpacing(5) + + channelName = "" + posName = "" + frameNumber = None + if readPatternFunc is not None: + posName, frameNumber, channelName = readPatternFunc(fileName) + + formLayout = QGridLayout() + + ncols = 3 + self.vLayouts = [QVBoxLayout() for _ in range(ncols)] + for j, l in enumerate(self.vLayouts): + formLayout.addLayout(l, 0, j) + + row = 0 + items = QLabel("Position name: "), widgets.ReadOnlyLineEdit(), QLabel() + label, self.posNameEntry, button = items + self.posNameEntry.setAlignment(Qt.AlignCenter) + self.posNameEntry.setText(str(posName)) + for j, w in enumerate(items): + self.vLayouts[j].addWidget(w) + + row += 1 + items = (QLabel("Frame number name: "), widgets.ReadOnlyLineEdit(), QLabel()) + self.frameNumberEntry = items[1] + self.frameNumberEntry.setText(str(frameNumber)) + self.frameNumberEntry.setAlignment(Qt.AlignCenter) + for j, w in enumerate(items): + self.vLayouts[j].addWidget(w) + + row += 1 + self.channelNameLE = widgets.alphaNumericLineEdit() + items = ( + QLabel("Channel_1 name: "), + self.channelNameLE, + widgets.addPushButton(" Add channel"), + ) + self.addChannelButton = items[2] + self.addChannelButton._row = row + self.channelNameLE.setAlignment(Qt.AlignCenter) + self.channelNameLE.setText(channelName) + for j, w in enumerate(items): + self.vLayouts[j].addWidget(w) + + row += 1 + items = ( + QLabel("Basename (optional): "), + widgets.alphaNumericLineEdit(), + QLabel(), + ) + label, self.baseNameLE, button = items + self.baseNameLE.setAlignment(Qt.AlignCenter) + for j, w in enumerate(items): + self.vLayouts[j].addWidget(w) + + row += 1 + items = QLabel("File will be saved as: "), QLineEdit(), QLabel() + label, self.relPathEntry, button = items + self.relPathEntry.setAlignment(Qt.AlignCenter) + for j, w in enumerate(items): + self.vLayouts[j].addWidget(w) + + row += 1 + items = ( + QLabel("Segmentation masks folder path: "), + widgets.ElidingLineEdit(), + widgets.browseFileButton( + "Browse...", + title="Select folder containing segmentation masks", + start_dir=folderPath, + openFolder=True, + ), + ) + label, self.segmFolderPathEntry, button = items + button.sigPathSelected.connect(self.segmFolderpathSelected) + self.segmFolderPathEntry.setAlignment(Qt.AlignCenter) + for j, w in enumerate(items): + self.vLayouts[j].addWidget(w) + + self.formLayout = formLayout + + self.updateRelativePath() + + self.channelNameLE.textChanged.connect(self.updateRelativePath) + self.baseNameLE.textChanged.connect(self.updateRelativePath) + self.addChannelButton.clicked.connect(self.addChannel) + + mainLayout.addLayout(formLayout) + + buttonsLayout = widgets.CancelOkButtonsLayout() + showInFileManagerButton = widgets.showInFileManagerButton( + utils.get_open_filemaneger_os_string() + ) + buttonsLayout.insertWidget(3, showInFileManagerButton) + func = partial(utils.showInExplorer, folderPath) + showInFileManagerButton.clicked.connect(func) + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + mainLayout.addStretch() + + self.setLayout(mainLayout) + + self.setFont(font) + + def segmFolderpathSelected(self, path): + self.segmFolderPathEntry.setText(path) + + def addChannel(self): + self.addChannelButton._row += 1 + row = self.addChannelButton._row + + channel_idx = len(self.additionalChannelWidgets) + items = ( + QLabel(f"Channel_{channel_idx + 1} name: "), + widgets.alphaNumericLineEdit(), + widgets.subtractPushButton("Remove channel"), + ) + label, lineEdit, button = items + lineEdit.setAlignment(Qt.AlignCenter) + button.clicked.connect(self.removeChannel) + button._row = row + for j, w in enumerate(items): + self.vLayouts[j].insertWidget(row, w) + + self.additionalChannelWidgets[row] = items + lineEdit.setFocus() + + def removeChannel(self): + row = self.sender()._row + for j, w in enumerate(self.additionalChannelWidgets[row]): + self.vLayouts[j].removeWidget(w) + + self.additionalChannelWidgets.pop(row) + self.addChannelButton._row -= 1 + + def checkChannelNames(self): + allChannels = [self.channelNameLE.text()] + allChannels.extend( + [w[1].text() for w in self.additionalChannelWidgets.values()] + ) + for ch1, ch2 in combinations(allChannels, 2): + if ch1 == ch2: + break + if not ch1 or not ch2: + break + else: + # Channel names are fine + return allChannels + + msg = widgets.myMessageBox(wrapText=False, showCentered=False) + txt = html_utils.paragraph(""" + Some channel names are empty or not different from each other. + """) + msg.critical(self, "Select two or more items", txt) + return None + + def updateRelativePath(self, text=""): + posName = self.posNameEntry.text() + frameNumber = self.frameNumberEntry.text() + channelName = self.channelNameLE.text() + basename = self.baseNameLE.text() + if basename: + filename = f"{basename}_{posName}_{channelName}.tif" + else: + filename = f"{posName}_{channelName}.tif" + relPath = f"...{os.sep}Position_1{os.sep}Images{os.sep}{filename}" + self.relPathEntry.setText(relPath) + + def ok_cb(self): + allChannels = self.checkChannelNames() + if allChannels is None: + return + self.allChannels = allChannels + self.basename = self.baseNameLE.text() + self.segmFolderPath = self.segmFolderPathEntry.text() + self.cancel = False + self.close() + + def showEvent(self, event) -> None: + self.channelNameLE.setFocus() + + +class OrderableListWidgetDialog(QBaseDialog): + def __init__( + self, items, title="Select items", infoTxt="", helpText="", parent=None + ): + super().__init__(parent) + + self.selectedItemsText = [] + + self.cancel = True + self.setWindowTitle(title) + + mainLayout = QVBoxLayout() + self.helpText = helpText + + if infoTxt: + mainLayout.addWidget(QLabel(html_utils.paragraph(infoTxt))) + + self.listWidget = widgets.OrderableList() + self.listWidget.addItems(items) + + buttonsLayout = widgets.CancelOkButtonsLayout() + if helpText: + helpButton = widgets.helpPushButton("Help...") + buttonsLayout.insertWidget(3, helpButton) + helpButton.clicked.connect(self.showHelp) + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addWidget(self.listWidget) + mainLayout.addSpacing(10) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + def showHelp(self): + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + txt = html_utils.paragraph(self.helpText) + msg.information(self, "Select tables help", txt) + + def ok_cb(self): + self.cancel = False + self.selectedItemsText = [None] * len(self.listWidget.selectedItems()) + for itemW in self.listWidget.selectedItems(): + idx = int(itemW._nrWidget.currentText()) - 1 + if idx >= len(self.selectedItemsText): + idx = len(self.selectedItemsText) - 1 + self.selectedItemsText[idx] = itemW._text + self.close() + + +class QDialogAppendTextFilename(QDialog): + def __init__(self, filename, ext, parent=None, font=None): + super().__init__(parent) + self.cancel = True + filenameNOext, _ = os.path.splitext(filename) + self.filenameNOext = filenameNOext + if ext.find(".") == -1: + ext = f".{ext}" + self.ext = ext + + self.setWindowTitle("Append text to file name") + + mainLayout = QVBoxLayout() + formLayout = QFormLayout() + buttonsLayout = QHBoxLayout() + + if font is not None: + self.setFont(font) + + self.LE = QLineEdit() + self.LE.setAlignment(Qt.AlignCenter) + formLayout.addRow("Appended text", self.LE) + self.LE.textChanged.connect(self.updateFinalFilename) + + self.finalName_label = QLabel(f'Final file name: "{filenameNOext}_{ext}"') + # padding: top, left, bottom, right + self.finalName_label.setStyleSheet("font-size:13px; padding:5px 0px 0px 0px;") + + okButton = widgets.okPushButton("Ok") + okButton.setShortcut(Qt.Key_Enter) + + cancelButton = widgets.cancelPushButton("Cancel") + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(okButton) + + buttonsLayout.setContentsMargins(0, 10, 0, 0) + + mainLayout.addLayout(formLayout) + mainLayout.addWidget(self.finalName_label, alignment=Qt.AlignCenter) + mainLayout.addLayout(buttonsLayout) + + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.close) + + self.formLayout = formLayout + + self.setLayout(mainLayout) + # self.setModal(True) + + def updateFinalFilename(self, text): + finalFilename = f"{self.filenameNOext}_{text}{self.ext}" + self.finalName_label.setText(f'Final file name: "{finalFilename}"') + + def ok_cb(self, event): + if not self.LE.text(): + err_msg = "Appended name cannot be empty!" + msg = QMessageBox() + msg.critical(self, "Empty name", err_msg, msg.Ok) + return + self.cancel = False + self.close() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show() + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class QDialogEntriesWidget(QDialog): + def __init__( + self, entriesLabels, defaultTxts, winTitle="Input", parent=None, font=None + ): + self.cancel = True + self.entriesTxt = [] + self.entriesLabels = entriesLabels + self.QLEs = [] + super().__init__(parent) + self.setWindowTitle(winTitle) + + mainLayout = QVBoxLayout() + formLayout = QFormLayout() + buttonsLayout = QHBoxLayout() + + if font is not None: + self.setFont(font) + + for label, txt in zip(entriesLabels, defaultTxts): + LE = QLineEdit() + LE.setAlignment(Qt.AlignCenter) + LE.setText(txt) + formLayout.addRow(label, LE) + self.QLEs.append(LE) + + okButton = widgets.okPushButton("Ok") + okButton.setShortcut(Qt.Key_Enter) + + cancelButton = widgets.cancelPushButton("Cancel") + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(okButton) + + buttonsLayout.setContentsMargins(0, 10, 0, 0) + + mainLayout.addLayout(formLayout) + mainLayout.addLayout(buttonsLayout) + + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.close) + + self.formLayout = formLayout + + self.setLayout(mainLayout) + # self.setModal(True) + + def ok_cb(self, event): + self.cancel = False + self.entriesTxt = [ + self.formLayout.itemAt(i, 1).widget().text() + for i in range(len(self.entriesLabels)) + ] + self.close() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show() + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class QDialogMetadata(QDialog): + def __init__( + self, + SizeT, + SizeZ, + TimeIncrement, + PhysicalSizeZ, + PhysicalSizeY, + PhysicalSizeX, + ask_SizeT, + ask_TimeIncrement, + ask_PhysicalSizes, + parent=None, + font=None, + imgDataShape=None, + posData=None, + singlePos=False, + askSegm3D=True, + additionalValues=None, + forceEnableAskSegm3D=False, + SizeT_metadata=None, + SizeZ_metadata=None, + basename="", + ): + self.cancel = True + self.ask_TimeIncrement = ask_TimeIncrement + self.ask_PhysicalSizes = ask_PhysicalSizes + self.askSegm3D = askSegm3D + self.imgDataShape = imgDataShape + self.posData = posData + self._additionalValues = additionalValues + self.SizeT_metadata = SizeT_metadata + self.SizeZ_metadata = SizeZ_metadata + super().__init__(parent) + self.setWindowTitle("Image properties") + + mainLayout = QVBoxLayout() + gridLayout = QGridLayout() + # formLayout = QFormLayout() + buttonsLayout = QGridLayout() + + if imgDataShape is not None: + label = QLabel( + html_utils.paragraph( + f"Image data shape = {imgDataShape}
" + ) + ) + mainLayout.addWidget(label, alignment=Qt.AlignCenter) + + row = 0 + self.basenameLineEdit = None + if basename: + gridLayout.addWidget( + QLabel("Basename (read-only)"), row, 0, alignment=Qt.AlignRight + ) + self.basenameLineEdit = QLineEdit() + self.basenameLineEdit.setReadOnly(True) + self.basenameLineEdit.setText(basename) + minWidth = ( + self.basenameLineEdit.fontMetrics().boundingRect(basename).width() + 10 + ) + self.basenameLineEdit.setMinimumWidth(minWidth) + self.basenameLineEdit.setAlignment(Qt.AlignCenter) + gridLayout.addWidget(self.basenameLineEdit, row, 1) + row += 1 + + gridLayout.addWidget( + QLabel("Number of frames (SizeT)"), row, 0, alignment=Qt.AlignRight + ) + self.SizeT_SpinBox = QSpinBox() + self.SizeT_SpinBox.setMinimum(1) + self.SizeT_SpinBox.setMaximum(2147483647) + SizeTinfoButton = widgets.infoPushButton() + self.allowEditSizeTcheckbox = QCheckBox("Let me edit it") + if ask_SizeT: + self.SizeT_SpinBox.setValue(SizeT) + SizeTinfoButton.hide() + self.allowEditSizeTcheckbox.hide() + else: + self.SizeT_SpinBox.setValue(1) + self.SizeT_SpinBox.setDisabled(True) + SizeTinfoButton.show() + SizeTinfoButton.clicked.connect(self.showWhySizeTisGrayed) + self.allowEditSizeTcheckbox.show() + self.allowEditSizeTcheckbox.toggled.connect(self.allowEditSizeT) + self.SizeT_SpinBox.setAlignment(Qt.AlignCenter) + self.SizeT_SpinBox.valueChanged.connect(self.TimeIncrementShowHide) + gridLayout.addWidget(self.SizeT_SpinBox, row, 1) + gridLayout.addWidget(SizeTinfoButton, row, 2) + gridLayout.setColumnStretch(2, 0) + gridLayout.addWidget(self.allowEditSizeTcheckbox, row, 3) + gridLayout.setColumnStretch(3, 0) + + row += 1 + gridLayout.addWidget( + QLabel("Number of z-slices (SizeZ)"), row, 0, alignment=Qt.AlignRight + ) + self.SizeZ_SpinBox = QSpinBox() + self.SizeZ_SpinBox.setMinimum(1) + self.SizeZ_SpinBox.setMaximum(2147483647) + self.SizeZ_SpinBox.setValue(SizeZ) + self.SizeZ_SpinBox.setAlignment(Qt.AlignCenter) + self.SizeZ_SpinBox.valueChanged.connect(self.SizeZvalueChanged) + gridLayout.addWidget(self.SizeZ_SpinBox, row, 1) + + row += 1 + self.TimeIncrementLabel = QLabel("Time interval (s)") + gridLayout.addWidget(self.TimeIncrementLabel, row, 0, alignment=Qt.AlignRight) + self.TimeIncrementSpinBox = widgets.FloatLineEdit() + self.TimeIncrementSpinBox.setValue(TimeIncrement) + gridLayout.addWidget(self.TimeIncrementSpinBox, row, 1) + + if SizeT == 1 or not ask_TimeIncrement: + self.TimeIncrementSpinBox.hide() + self.TimeIncrementLabel.hide() + + row += 1 + self.PhysicalSizeZLabel = QLabel("Physical Size Z (um/pixel)") + gridLayout.addWidget(self.PhysicalSizeZLabel, row, 0, alignment=Qt.AlignRight) + self.PhysicalSizeZSpinBox = widgets.FloatLineEdit() + self.PhysicalSizeZSpinBox.setValue(PhysicalSizeZ) + gridLayout.addWidget(self.PhysicalSizeZSpinBox, row, 1) + + if SizeZ == 1 or not ask_PhysicalSizes: + self.PhysicalSizeZSpinBox.hide() + self.PhysicalSizeZLabel.hide() + + row += 1 + self.PhysicalSizeYLabel = QLabel("Physical Size Y (um/pixel)") + gridLayout.addWidget(self.PhysicalSizeYLabel, row, 0, alignment=Qt.AlignRight) + self.PhysicalSizeYSpinBox = widgets.FloatLineEdit() + self.PhysicalSizeYSpinBox.setValue(PhysicalSizeY) + gridLayout.addWidget(self.PhysicalSizeYSpinBox, row, 1) + + if not ask_PhysicalSizes: + self.PhysicalSizeYSpinBox.hide() + self.PhysicalSizeYLabel.hide() + + row += 1 + self.PhysicalSizeXLabel = QLabel("Physical Size X (um/pixel)") + gridLayout.addWidget(self.PhysicalSizeXLabel, row, 0, alignment=Qt.AlignRight) + self.PhysicalSizeXSpinBox = widgets.FloatLineEdit() + self.PhysicalSizeXSpinBox.setValue(PhysicalSizeX) + gridLayout.addWidget(self.PhysicalSizeXSpinBox, row, 1) + + if not ask_PhysicalSizes: + self.PhysicalSizeXSpinBox.hide() + self.PhysicalSizeXLabel.hide() + + row += 1 + self.isSegm3Dtoggle = widgets.Toggle() + if posData is not None: + self.isSegm3Dtoggle.setChecked(posData.getIsSegm3D()) + disableToggle = ( + # Disable toggle if not force enable and if + # segm data was found (we cannot change the shape of + # loaded segmentation in the GUI) + posData.segmFound is not None + and posData.segmFound + and not forceEnableAskSegm3D + ) + if disableToggle: + self.isSegm3Dtoggle.setDisabled(True) + self.isSegm3DLabel = QLabel("Work with 3D segmentation masks (z-stack)") + gridLayout.addWidget(self.isSegm3DLabel, row, 0, alignment=Qt.AlignRight) + gridLayout.addWidget(self.isSegm3Dtoggle, row, 1, alignment=Qt.AlignCenter) + self.infoButtonSegm3D = QPushButton(self) + self.infoButtonSegm3D.setCursor(Qt.WhatsThisCursor) + self.infoButtonSegm3D.setIcon(QIcon(":info.svg")) + gridLayout.addWidget(self.infoButtonSegm3D, row, 2, alignment=Qt.AlignLeft) + self.infoButtonSegm3D.clicked.connect(self.infoSegm3D) + if SizeZ == 1 or not askSegm3D: + self.isSegm3DLabel.hide() + self.isSegm3Dtoggle.hide() + self.infoButtonSegm3D.hide() + + self.SizeZvalueChanged(SizeZ) + + self.additionalFieldsWidgets = [] + addFieldButton = widgets.addPushButton("Add custom field") + addFieldInfoButton = widgets.infoPushButton() + addFieldInfoButton.clicked.connect(self.showAddFieldInfo) + addFieldButton.clicked.connect(self.addField) + addFieldLayout = QHBoxLayout() + addFieldLayout.addStretch(1) + addFieldLayout.addWidget(addFieldButton) + addFieldLayout.addWidget(addFieldInfoButton) + addFieldLayout.addStretch(1) + + if singlePos: + okTxt = "Apply only to this Position" + else: + okTxt = "Ok for loaded Positions" + okButton = widgets.okPushButton(okTxt) + okButton.setToolTip("Save metadata only for current positionh") + okButton.setShortcut(Qt.Key_Enter) + self.okButton = okButton + + if ask_TimeIncrement or ask_PhysicalSizes: + okAllButton = QPushButton("Apply to ALL Positions") + okAllButton.setToolTip( + "Update existing Physical Sizes, Time interval, cell volume (fl), " + "cell area (um^2), and time (s) for all the positions " + "in the experiment folder." + ) + self.okAllButton = okAllButton + + selectButton = QPushButton("Select the Positions to be updated") + selectButton.setToolTip( + "Ask to select positions then update existing Physical Sizes, " + "Time interval, cell volume (fl), cell area (um^2), and time (s)" + "for selected positions." + ) + self.selectButton = selectButton + else: + self.okAllButton = None + self.selectButton = None + okButton.setText("Ok") + + cancelButton = widgets.cancelPushButton("Cancel") + + buttonsLayout.setColumnStretch(0, 1) + buttonsLayout.addWidget(okButton, 0, 1) + if ask_TimeIncrement or ask_PhysicalSizes: + buttonsLayout.addWidget(okAllButton, 0, 2) + buttonsLayout.addWidget(selectButton, 1, 1) + buttonsLayout.addWidget(cancelButton, 1, 2) + else: + buttonsLayout.addWidget(cancelButton, 0, 2) + buttonsLayout.setColumnStretch(3, 1) + + gridLayout.setColumnMinimumWidth(1, 100) + mainLayout.addLayout(gridLayout) + mainLayout.addSpacing(10) + mainLayout.addLayout(addFieldLayout) + # mainLayout.addLayout(formLayout) + mainLayout.addSpacing(20) + mainLayout.addStretch(1) + mainLayout.addLayout(buttonsLayout) + self.mainLayout = mainLayout + + okButton.clicked.connect(self.ok_cb) + if ask_TimeIncrement or ask_PhysicalSizes: + okAllButton.clicked.connect(self.ok_cb) + selectButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.cancel_cb) + + self.addAdditionalValues(additionalValues) + + self.setLayout(mainLayout) + self.setFont(font) + # self.setModal(True) + + def showWhySizeTisGrayed(self): + txt = html_utils.paragraph(f""" + The "Number of frames" field is grayed-out because you loaded multiple Positions.

+ Cell-ACDC cannot load multiple time-lapse Positions, + so it is assuming you are loading NON time-lapse data.

+ To load time-lapse data, load one Position at a time.

+ Note that you can still edit the number of frames if you need to correct it.
+ However, you can only edit the metadata, then the loading process will be stopped. + """) + msg = widgets.myMessageBox(wrapText=False, showCentered=False) + msg.information(self, "Why is the number of frames grayed out?", txt) + + def addAdditionalValues(self, values): + if values is None: + return + + for i, (name, value) in enumerate(values.items()): + self.addField() + nameWidget = self.additionalFieldsWidgets[i]["nameWidget"] + valueWidget = self.additionalFieldsWidgets[i]["valueWidget"] + nameWidget.setText(str(name).strip("__")) + valueWidget.setText(str(value)) + + def addField(self): + nameWidget = QLineEdit() + nameWidget.setAlignment(Qt.AlignCenter) + valueWidget = QLineEdit() + valueWidget.setAlignment(Qt.AlignCenter) + removeButton = widgets.delPushButton() + + fieldLayout = QGridLayout() + fieldLayout.addWidget(QLabel("Name"), 0, 0) + fieldLayout.addWidget(nameWidget, 1, 0) + fieldLayout.addWidget(QLabel("Value"), 0, 1) + fieldLayout.addWidget(valueWidget, 1, 1) + fieldLayout.addWidget(removeButton, 1, 2) + + self.additionalFieldsWidgets.append( + { + "nameWidget": nameWidget, + "valueWidget": valueWidget, + "removeButton": removeButton, + "layout": fieldLayout, + } + ) + + idx = len(self.additionalFieldsWidgets) - 1 + removeButton.clicked.connect(partial(self.removeField, idx)) + + row = self.mainLayout.count() - 3 + self.mainLayout.insertLayout(row, fieldLayout) + + def removeField(self, idx): + widgets = self.additionalFieldsWidgets[idx] + + layoutToRemove = widgets["layout"] + for row in range(layoutToRemove.rowCount()): + for col in range(layoutToRemove.columnCount()): + item = layoutToRemove.itemAtPosition(row, col) + if item is not None: + widget = item.widget() + layoutToRemove.removeWidget(widget) + + self.additionalFieldsWidgets.pop(idx) + + self.mainLayout.removeItem(layoutToRemove) + + def showAddFieldInfo(self): + msg = widgets.myMessageBox() + txt = html_utils.paragraph(""" + Add a field (name and value) that will be saved to the + metadata.csv file and as a column in the + acdc_output.csv table.

+ Example: a strain name or the replicate number. + """) + msg.information(self, "Add field info", txt) + + def infoSegm3D(self): + txt = ( + "Cell-ACDC supports both 2D and 3D segmentation. If your data " + "also have a time dimension, then you can choose to segment " + "a specific z-slice (2D segmentation mask per frame) or all of them " + "(3D segmentation mask per frame)

" + "In any case, if you choose to activate 3D segmentation then the " + "segmentation mask will have the same number of z-slices " + "of the image data.

" + "Additionally, in the model parameters window, you will be able " + "to choose if you want to segment the entire 3D volume at once " + "or use the 2D model on each z-slice, one by one.

" + "NOTE: if the toggle is disabled it means you already " + "loaded segmentation data and the shape cannot be changed now.
" + "if you need to start with a blank segmentation, " + 'use the "Create a new segmentation file" button instead of the ' + '"Load folder" button.' + "
" + ) + msg = widgets.myMessageBox() + msg.setIcon() + msg.setWindowTitle(f"3D segmentation info") + msg.addText(html_utils.paragraph(txt)) + msg.addButton(" Ok ") + msg.exec_() + + def SizeZvalueChanged(self, val): + if len(self.imgDataShape) < 3: + return + + if val > 1 and self.imgDataShape is not None: + maxSizeZ = self.imgDataShape[-3] + self.SizeZ_SpinBox.setMaximum(maxSizeZ) + else: + self.SizeZ_SpinBox.setMaximum(2147483647) + + if val > 1: + if self.ask_PhysicalSizes: + self.PhysicalSizeZSpinBox.show() + self.PhysicalSizeZLabel.show() + if self.askSegm3D: + self.isSegm3DLabel.show() + self.isSegm3Dtoggle.show() + self.infoButtonSegm3D.show() + else: + self.PhysicalSizeZSpinBox.hide() + self.PhysicalSizeZLabel.hide() + self.isSegm3DLabel.hide() + self.isSegm3Dtoggle.hide() + self.infoButtonSegm3D.hide() + + self.checkSegmDataShape() + + def checkSegmDataShape(self): + if self.posData is None: + return + + if self.isSegm3Dtoggle.isEnabled(): + return + + SizeT = self.SizeT_SpinBox.value() + SizeZ = self.SizeZ_SpinBox.value() + segm_data_ndim = self.posData.segm_data.ndim + isSegm3D = False + if segm_data_ndim == 4: + # Segm data is 4D so it must be 3D over time + isSegm3D = True + elif segm_data_ndim == 3 and SizeZ > 1 and SizeT == 1: + # Segm data is 3D while SizeT == 1 and SizeZ > 1 + # --> also segm is 3D z-stack + isSegm3D = True + + self.isSegm3Dtoggle.setDisabled(False) + self.isSegm3Dtoggle.setChecked(isSegm3D) + self.isSegm3Dtoggle.setDisabled(True) + + def TimeIncrementShowHide(self, val): + self.checkSegmDataShape() + if not self.ask_TimeIncrement: + return + + if val > 1: + self.TimeIncrementSpinBox.show() + self.TimeIncrementLabel.show() + else: + self.TimeIncrementSpinBox.hide() + self.TimeIncrementLabel.hide() + + def allowEditSizeT(self, checked): + if checked: + self.SizeT_SpinBox.setDisabled(False) + if self.SizeT_metadata is not None: + self.SizeT_SpinBox.setValue(self.SizeT_metadata) + else: + self.SizeT_SpinBox.setDisabled(True) + self.SizeT_SpinBox.setValue(1) + + def warnEditingMetadata(self, Size, Size_metadata, which_dim): + txt = html_utils.paragraph(f""" + The number of {which_dim} in the saved metadata is {Size_metadata}, + but you are requesting to change it to {Size}.

+ Are you sure you want to proceed? + """) + msg = widgets.myMessageBox(wrapText=False, showCentered=False) + _, noButton, yesButton = msg.warning( + self, + "WARNING: Edinting saved metadata", + txt, + buttonsTexts=("Cancel", "No", "Yes, edit the metadata"), + ) + return msg.clickedButton == yesButton + + def ok_cb(self, checked=False): + self.cancel = False + self.SizeT = self.SizeT_SpinBox.value() + self.SizeZ = self.SizeZ_SpinBox.value() + + if self.SizeT_metadata is not None: + if self.SizeT != self.SizeT_metadata: + proceed = self.warnEditingMetadata( + self.SizeT, self.SizeT_metadata, "frames" + ) + if not proceed: + return + + if self.SizeZ_metadata is not None: + if self.SizeZ != self.SizeZ_metadata: + proceed = self.warnEditingMetadata( + self.SizeZ, self.SizeZ_metadata, "z-slices" + ) + if not proceed: + return + + self.isSegm3D = self.isSegm3Dtoggle.isChecked() + + self.TimeIncrement = self.TimeIncrementSpinBox.value() + self.PhysicalSizeX = self.PhysicalSizeXSpinBox.value() + self.PhysicalSizeY = self.PhysicalSizeYSpinBox.value() + self.PhysicalSizeZ = self.PhysicalSizeZSpinBox.value() + self._additionalValues = { + f"__{field['nameWidget'].text()}": field["valueWidget"].text() + for field in self.additionalFieldsWidgets + } + proceed = self.checkShapeMismatchMetadata() + if not proceed: + return + + if self.posData is not None and self.sender() != self.okButton: + exp_path = self.posData.exp_path + pos_foldernames = utils.get_pos_foldernames(exp_path) + if self.sender() == self.selectButton: + select_folder = load.select_exp_folder() + select_folder.pos_foldernames = pos_foldernames + select_folder.QtPrompt( + self, pos_foldernames, allow_cancel=False, toggleMulti=True + ) + pos_foldernames = select_folder.selected_pos + for pos in pos_foldernames: + images_path = os.path.join(exp_path, pos, "Images") + ls = utils.listdir(images_path) + search = [file for file in ls if file.find("metadata.csv") != -1] + metadata_df = None + if search: + fileName = search[0] + metadata_csv_path = os.path.join(images_path, fileName) + metadata_df = pd.read_csv(metadata_csv_path).set_index( + "Description" + ) + if metadata_df is not None: + metadata_df.at["TimeIncrement", "values"] = self.TimeIncrement + metadata_df.at["PhysicalSizeZ", "values"] = self.PhysicalSizeZ + metadata_df.at["PhysicalSizeY", "values"] = self.PhysicalSizeY + metadata_df.at["PhysicalSizeX", "values"] = self.PhysicalSizeX + metadata_df.to_csv(metadata_csv_path) + + search = [file for file in ls if file.find("acdc_output.csv") != -1] + acdc_df = None + if search: + fileName = search[0] + acdc_df_path = os.path.join(images_path, fileName) + acdc_df = pd.read_csv(acdc_df_path) + yx_pxl_to_um2 = self.PhysicalSizeY * self.PhysicalSizeX + vox_to_fl = self.PhysicalSizeY * (self.PhysicalSizeX**2) + if "cell_vol_fl" not in acdc_df.columns: + continue + acdc_df["cell_vol_fl"] = acdc_df["cell_vol_vox"] * vox_to_fl + acdc_df["cell_area_um2"] = acdc_df["cell_area_pxl"] * yx_pxl_to_um2 + acdc_df["time_seconds"] = acdc_df["frame_i"] * self.TimeIncrement + try: + acdc_df.to_csv(acdc_df_path, index=False) + except PermissionError: + err_msg = html_utils.paragraph( + "The below file is open in another app " + "(Excel maybe?).

" + f"{acdc_df_path}

" + 'Close file and then press "Ok".' + ) + msg = widgets.myMessageBox() + msg.critical(self, "Permission denied", err_msg) + acdc_df.to_csv(acdc_df_path, index=False) + + elif self.sender() == self.selectButton: + pass + + self.close() + + def checkShapeMismatchMetadata(self): + valid4D = True + valid3D = True + valid2D = True + if self.imgDataShape is None: + self.close() + elif len(self.imgDataShape) == 4: + T, Z, Y, X = self.imgDataShape + valid4D = self.SizeT == T and self.SizeZ == Z + elif len(self.imgDataShape) == 3: + TorZ, Y, X = self.imgDataShape + valid3D = self.SizeT == TorZ or self.SizeZ == TorZ + elif len(self.imgDataShape) == 2: + valid2D = self.SizeT == 1 and self.SizeZ == 1 + + valid = all([valid4D, valid3D, valid2D]) + if valid: + return True + + if not valid4D: + txt = f""" + You loaded 4D data, hence the number of frames MUST be + {T}
and the number of z-slices MUST be {Z}.

+ What do you want to do? + """ + if not valid3D: + txt = f""" + You loaded 3D data, hence either the number of frames or + the number of z-slices is {TorZ}.

+ However, if the number of frames is greater than 1 then the
+ number of z-slices MUST be 1, and vice-versa.

+ What do you want to do? + """ + + if not valid2D: + txt = f""" + You loaded 2D data, hence the number of frames MUST be 1 + and the number of z-slices MUST be 1.

+ What do you want to do? + """ + + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph(txt) + + continueButton = widgets.okPushButton("Continue anyway") + correctButton = widgets.editPushButton("Let me correct") + + msg.warning( + self, + "Shape-metadata mismatch", + txt, + buttonsTexts=(continueButton, correctButton), + ) + if msg.cancel or msg.clickedButton == correctButton: + return False + + return True + + def cancel_cb(self, event): + self.cancel = True + self.close() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show() + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class QCropZtool(QBaseDialog): + sigClose = Signal() + sigZvalueChanged = Signal(str, int) + sigReset = Signal() + sigCrop = Signal(int, int) + + def __init__( + self, + SizeZ, + cropButtonText="Apply crop", + parent=None, + addDoNotShowAgain=False, + title="Select z-slices", + ): + super().__init__(parent) + + self.cancel = True + + self.setWindowFlags(Qt.Tool | Qt.WindowStaysOnTopHint) + + self.SizeZ = SizeZ + self.numDigits = len(str(self.SizeZ)) + + self.setWindowTitle(title) + + layout = QGridLayout() + buttonsLayout = QHBoxLayout() + + self.lowerZscrollbar = widgets.ScrollBarWithNumericControl() + self.lowerZscrollbar.setMaximum(SizeZ) + self.lowerZscrollbar.setMinimum(1) + self.lowerZscrollbar.setValue(1) + + self.upperZscrollbar = widgets.ScrollBarWithNumericControl() + self.upperZscrollbar.setMaximum(SizeZ) + self.upperZscrollbar.setValue(SizeZ) + + cancelButton = widgets.cancelPushButton("Cancel") + cropButton = widgets.okPushButton(cropButtonText) + buttonsLayout.addWidget(cropButton) + buttonsLayout.addWidget(cancelButton) + + row = 0 + layout.addWidget(QLabel("Lower z-slice "), row, 0, alignment=Qt.AlignRight) + layout.addWidget(self.lowerZscrollbar, row, 1) + + row += 1 + layout.setRowStretch(row, 5) + + row += 1 + layout.addWidget(QLabel("Upper z-slice "), row, 0, alignment=Qt.AlignRight) + layout.addWidget(self.upperZscrollbar, row, 1) + + row += 1 + if addDoNotShowAgain: + self.doNotShowAgainCheckbox = QCheckBox("Do not ask again") + layout.addWidget( + self.doNotShowAgainCheckbox, row, 1, alignment=Qt.AlignLeft + ) + row += 1 + + layout.addLayout(buttonsLayout, row, 1, alignment=Qt.AlignRight) + + layout.setColumnStretch(0, 0) + layout.setColumnStretch(1, 10) + + self.setLayout(layout) + + # resetButton.clicked.connect(self.emitReset) + cropButton.clicked.connect(self.emitCrop) + cancelButton.clicked.connect(self.close) + self.lowerZscrollbar.sigValueChanged.connect(self.ZvalueChanged) + self.upperZscrollbar.sigValueChanged.connect(self.ZvalueChanged) + + def emitReset(self): + self.sigReset.emit() + + def emitCrop(self): + self.cancel = False + low_z = self.lowerZscrollbar.value() - 1 + high_z = self.upperZscrollbar.value() - 1 + self.sigCrop.emit(low_z, high_z) + self.close() + + def updateScrollbars(self, lower_z, upper_z): + self.lowerZscrollbar.setValue(lower_z + 1) + self.upperZscrollbar.setValue(upper_z + 1) + + def ZvalueChanged(self, value): + which = "lower" if self.sender() == self.lowerZscrollbar else "upper" + if which == "lower" and value > self.upperZscrollbar.value() - 1: + self.lowerZscrollbar.setValue(self.upperZscrollbar.value() - 1) + return + if which == "upper" and value < self.lowerZscrollbar.value() + 1: + self.upperZscrollbar.setValue(self.lowerZscrollbar.value() + 1) + return + + z_slice_n = value - 1 + self.sigZvalueChanged.emit(which, z_slice_n) + + def showEvent(self, event): + self.resize(int(self.width() * 1.5), self.height()) + + def closeEvent(self, event): + super().closeEvent(event) + self.sigClose.emit() + + +class TreeSelectorDialog(QBaseDialog): + sigItemDoubleClicked = Signal(object) + + def __init__( + self, + title="Tree selector", + infoTxt="", + parent=None, + multiSelection=True, + widthFactor=None, + heightFactor=None, + expandOnDoubleClick=False, + isTopLevelSelectable=True, + allItemsExpanded=True, + allowNoSelection=True, + ): + super().__init__(parent) + + self.setWindowTitle(title) + + self.cancel = True + self.widthFactor = widthFactor + self.heightFactor = heightFactor + self.allItemsExpanded = allItemsExpanded + self.mainLayout = QVBoxLayout() + self._isTopLevelSelectable = isTopLevelSelectable + self.allowNoSelection = allowNoSelection + + if infoTxt: + self.mainLayout.addWidget(QLabel(html_utils.paragraph(infoTxt))) + + self.treeWidget = widgets.TreeWidget(multiSelection=multiSelection) + self.treeWidget.setExpandsOnDoubleClick(expandOnDoubleClick) + self.treeWidget.setHeaderHidden(True) + self.mainLayout.addWidget(self.treeWidget) + + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + self.mainLayout.addSpacing(20) + self.mainLayout.addLayout(buttonsLayout) + + self.buttonsLayout = buttonsLayout + + self.setLayout(self.mainLayout) + + self.treeWidget.itemClicked.connect(self.onItemClicked) + self.treeWidget.itemDoubleClicked.connect(self.onItemDoubleClicked) + + def onItemDoubleClicked(self, item): + self.sigItemDoubleClicked.emit(item) + + def onItemClicked(self, item): + if self._isTopLevelSelectable: + return + if item.parent() is None: + item.setSelected(False) + + def addTree(self, tree: dict): + for topLevel, children in tree.items(): + topLevelItem = widgets.TreeWidgetItem(self.treeWidget) + topLevelItem.setText(0, topLevel) + self.treeWidget.addTopLevelItem(topLevelItem) + childrenItems = [widgets.TreeWidgetItem([c]) for c in children] + topLevelItem.addChildren(childrenItems) + if not self.allItemsExpanded: + continue + topLevelItem.setExpanded(True) + + def resizeVertical(self): + if not self.isVisible(): + self.show() + + currentTreeWidgetHeight = self.treeWidget.height() + treeWidgetHeight = 0 + for i in range(self.treeWidget.topLevelItemCount()): + topLevelItem = self.treeWidget.topLevelItem(i) + rect = self.treeWidget.visualItemRect(topLevelItem) + treeWidgetHeight += rect.height() + for j in range(topLevelItem.childCount()): + childItem = topLevelItem.child(j) + rect = self.treeWidget.visualItemRect(childItem) + treeWidgetHeight += rect.height() + + deltaHeight = treeWidgetHeight - currentTreeWidgetHeight + 10 + self.resize(self.width(), self.height() + deltaHeight) + self.move(self.x(), 20) + + def setCurrentItem(self, itemText: dict): + if not itemText: + return + for i in range(self.treeWidget.topLevelItemCount()): + topLevelItem = self.treeWidget.topLevelItem(i) + topLevelName = topLevelItem.text(0) + childText = itemText.get(topLevelName) + if childText is None: + continue + for j in range(topLevelItem.childCount()): + childItem = topLevelItem.child(j) + childItemText = childItem.text(0) + if childItemText == childText: + childItem.setSelected(True) + topLevelItem.setExpanded(True) + self.treeWidget.scrollToItem(topLevelItem) + break + + def selectedItems(self): + self._selectedItems = {} + for i in range(self.treeWidget.topLevelItemCount()): + topLevelItem = self.treeWidget.topLevelItem(i) + topLevelName = topLevelItem.text(0) + for j in range(topLevelItem.childCount()): + childItem = topLevelItem.child(j) + if not childItem.isSelected(): + continue + if topLevelName not in self._selectedItems: + self._selectedItems[topLevelName] = [childItem.text(0)] + else: + self._selectedItems[topLevelName].append(childItem.text(0)) + return self._selectedItems + + def warnSelectionIsEmpty(self): + txt = html_utils.paragraph(""" + You did not select anything :(.

+ Please press Cancel to exit without selecting items. + Thanks! + """) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "Selection is empty", txt) + + def ok_cb(self): + if not self.allowNoSelection and not self.selectedItems(): + self.warnSelectionIsEmpty() + return + self.cancel = False + self.close() + + def showEvent(self, event) -> None: + super().showEvent(event) + if self.widthFactor is not None: + self.resize(int(self.width() * self.widthFactor), self.height()) + if self.heightFactor is not None: + self.resize(self.width(), int(self.height() * self.heightFactor)) + + +class TreesSelectorDialog(QBaseDialog): + def __init__( + self, trees, groupsDescr=None, title="Trees selector", infoTxt="", parent=None + ): + super().__init__(parent) + + self.setWindowTitle(title) + + self.cancel = True + self.mainLayout = QVBoxLayout() + + if infoTxt: + self.mainLayout.addWidget(QLabel(html_utils.paragraph(infoTxt))) + + self.treeWidgets = {} + self.setLayout(self.mainLayout) + + createdGroupLayouts = {} + for treeName, tree in trees.items(): + if groupsDescr is None: + groupName = "" + else: + groupName = groupsDescr.get(treeName, "Group info missing") + groupLayout = createdGroupLayouts.get(groupName, None) + if groupLayout is None: + self.mainLayout.addWidget(QLabel(html_utils.paragraph(groupName))) + groupBox = QGroupBox() + self.mainLayout.addWidget(groupBox) + groupLayout = QVBoxLayout() + groupBox.setLayout(groupLayout) + createdGroupLayouts[groupName] = groupLayout + else: + groupLayout.addSpacing(10) + groupLayout.addWidget(QLabel(html_utils.paragraph(treeName))) + treeWidget = widgets.TreeWidget(multiSelection=True) + treeWidget.setHeaderHidden(True) + for topLevel, children in tree.items(): + topLevelItem = widgets.TreeWidgetItem(treeWidget) + topLevelItem.setText(0, topLevel) + treeWidget.addTopLevelItem(topLevelItem) + childrenItems = [widgets.TreeWidgetItem([c]) for c in children] + topLevelItem.addChildren(childrenItems) + topLevelItem.setExpanded(True) + self.treeWidgets[treeName] = treeWidget + groupLayout.addWidget(treeWidget) + self.mainLayout.addSpacing(20) + + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + self.mainLayout.addSpacing(10) + self.mainLayout.addLayout(buttonsLayout) + + def ok_cb(self): + self.cancel = False + self.selectedItems = {} + for treeName, treeWidget in self.treeWidgets.items(): + for i in range(treeWidget.topLevelItemCount()): + topLevelItem = treeWidget.topLevelItem(i) + for j in range(topLevelItem.childCount()): + childItem = topLevelItem.child(j) + if not childItem.isSelected(): + continue + if treeName not in self.selectedItems: + self.selectedItems[treeName] = [childItem.text(0)] + else: + self.selectedItems[treeName].append(childItem.text(0)) + self.close() + + +class MultiListSelector(QBaseDialog): + def __init__( + self, + lists: dict, + groupsDescr: dict = None, + title="Lists selector", + infoTxt="", + parent=None, + ): + super().__init__(parent) + + self.setWindowTitle(title) + + self.cancel = True + mainLayout = QVBoxLayout() + + if infoTxt: + mainLayout.addWidget(QLabel(html_utils.paragraph(infoTxt))) + + self.listWidgets = {} + createdGroupLayouts = {} + for listName, listItems in lists.items(): + if groupsDescr is None: + groupName = "" + else: + groupName = groupsDescr.get(listName, "Group info missing") + groupLayout = createdGroupLayouts.get(listName, None) + if groupLayout is None: + mainLayout.addWidget(QLabel(html_utils.paragraph(groupName))) + groupBox = QGroupBox() + mainLayout.addWidget(groupBox) + groupLayout = QVBoxLayout() + groupBox.setLayout(groupLayout) + createdGroupLayouts[groupName] = groupLayout + else: + groupLayout.addSpacing(10) + groupLayout.addWidget(QLabel(html_utils.paragraph(listName))) + listWidget = widgets.listWidget() + listWidget.setSelectionMode( + QAbstractItemView.SelectionMode.ExtendedSelection + ) + listWidget.addItems(listItems) + groupLayout.addWidget(listWidget) + mainLayout.addSpacing(20) + self.listWidgets[listName] = listWidget + + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addSpacing(10) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + def ok_cb(self): + self.cancel = False + self.selectedItems = {} + for listName, listWidget in self.listWidgets.items(): + if not listWidget.selectedItems(): + continue + self.selectedItems[listName] = [ + item.text() for item in listWidget.selectedItems() + ] + self.close() + + +class selectPositionsMultiExp(QBaseDialog): + def __init__(self, expPaths: dict, infoPaths: dict = None, parent=None): + super().__init__(parent=parent) + + self.expPaths = expPaths + self.cancel = True + + mainLayout = QVBoxLayout() + + self.setWindowTitle("Select Positions to process") + + infoTxt = html_utils.paragraph( + "Select one or more Positions to process

" + "Click on experiment path to select all positions
" + "Ctrl+Click to select multiple items
" + "Shift+Click to select a range of items
", + center=True, + ) + infoLabel = QLabel(infoTxt) + + self.treeWidget = QTreeWidget() + self.treeWidget.setSelectionMode( + QAbstractItemView.SelectionMode.ExtendedSelection + ) + self.treeWidget.setHeaderHidden(True) + self.treeWidget.setFont(font) + for exp_path, positions in expPaths.items(): + pathLevels = exp_path.split(os.sep) + posFoldersInfo = None + if infoPaths is not None: + posFoldersInfo = infoPaths.get(exp_path) + if len(pathLevels) > 4: + itemText = os.path.join(*pathLevels[-4:]) + itemText = f"...{itemText}" + else: + itemText = exp_path + exp_path_item = QTreeWidgetItem([itemText]) + exp_path_item.setToolTip(0, exp_path) + exp_path_item.full_path = exp_path + self.treeWidget.addTopLevelItem(exp_path_item) + postions_items = [] + for pos in positions: + if posFoldersInfo is not None: + status = posFoldersInfo.get(pos, "") + else: + status = "" + pos_item_text = f"{pos}{status}" + pos_item = QTreeWidgetItem(exp_path_item, [pos_item_text]) + pos_item.posFoldername = pos + postions_items.append(pos_item) + exp_path_item.addChildren(postions_items) + exp_path_item.setExpanded(True) + + self.treeWidget.itemClicked.connect(self.selectAllChildren) + + buttonsLayout = QHBoxLayout() + cancelButton = widgets.cancelPushButton("Cancel") + okButton = widgets.okPushButton(" Ok ") + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(okButton) + + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.close) + + mainLayout.addWidget(infoLabel, alignment=Qt.AlignCenter) + mainLayout.addWidget(self.treeWidget) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + self.setStyleSheet(TREEWIDGET_STYLESHEET) + + def selectAllChildren(self, item, col): + if item.parent() is not None: + return + + for i in range(item.childCount()): + item.child(i).setSelected(True) + + def ok_cb(self): + if not self.treeWidget.selectedItems(): + msg = widgets.myMessageBox(wrapText=False) + txt = "You did not select any experiment/Position folder!" + msg.warning(self, "Empty selection!", html_utils.paragraph(txt)) + return + + self.cancel = False + self.selectedPaths = {} + for item in self.treeWidget.selectedItems(): + if item.parent() is None: + continue + parent = item.parent() + exp_path = parent.full_path + pos_folder = item.posFoldername + if exp_path not in self.selectedPaths: + self.selectedPaths[exp_path] = [] + self.selectedPaths[exp_path].append(pos_folder) + + self.close() + + def showEvent(self, event): + self.resize(int(self.width() * 2), self.height()) + + +class QDialogZsliceAbsent(QDialog): + def __init__(self, filename, SizeZ, filenamesWithInfo, parent=None): + self.runDataPrep = False + self.useMiddleSlice = False + self.useSameAsCh = False + + self.cancel = True + + super().__init__(parent) + self.setWindowTitle("Reference z-slice info absent") + + mainLayout = QVBoxLayout() + buttonsLayout = QGridLayout() + + txt = html_utils.paragraph( + f""" + You loaded the fluorescent file called

{filename}

+ however you never selected which z-slice
you want to use + when calculating metrics
(e.g., mean, median, amount...etc.)

+ Choose one of following options: + """, + center=True, + ) + infoLabel = QLabel(txt) + mainLayout.addWidget(infoLabel, alignment=Qt.AlignCenter) + + runDataPrepButton = QPushButton( + " Visualize the data now and select a z-slice " + ) + buttonsLayout.addWidget(runDataPrepButton, 0, 1, 1, 2) + runDataPrepButton.clicked.connect(self.runDataPrep_cb) + + useMiddleSliceButton = QPushButton( + f" Use the middle z-slice ({int(SizeZ / 2) + 1}) " + ) + buttonsLayout.addWidget(useMiddleSliceButton, 1, 1, 1, 2) + useMiddleSliceButton.clicked.connect(self.useMiddleSlice_cb) + + useSameAsChButton = QPushButton(" Use the same z-slice used for the channel: ") + useSameAsChButton.clicked.connect(self.useSameAsCh_cb) + + chNameComboBox = QComboBox() + chNameComboBox.addItems(filenamesWithInfo) + # chNameComboBox.setEditable(True) + # chNameComboBox.lineEdit().setAlignment(Qt.AlignCenter) + # chNameComboBox.lineEdit().setReadOnly(True) + self.chNameComboBox = chNameComboBox + buttonsLayout.addWidget(useSameAsChButton, 2, 1) + buttonsLayout.addWidget(chNameComboBox, 2, 2) + + buttonsLayout.setColumnStretch(0, 1) + buttonsLayout.setColumnStretch(3, 1) + buttonsLayout.setContentsMargins(10, 0, 10, 0) + + cancelButtonLayout = QHBoxLayout() + cancelButton = widgets.cancelPushButton("Cancel") + cancelButtonLayout.addStretch(1) + cancelButtonLayout.addWidget(cancelButton) + cancelButtonLayout.addStretch(1) + cancelButtonLayout.setStretch(1, 1) + cancelButton.clicked.connect(self.close) + + mainLayout.addLayout(buttonsLayout) + mainLayout.addSpacing(20) + mainLayout.addLayout(cancelButtonLayout) + mainLayout.addStretch(1) + + self.setLayout(mainLayout) + + font = QFont() + font.setPixelSize(12) + self.setFont(font) + + # self.setModal(True) + + def ok_cb(self, checked=True): + self.cancel = False + self.close() + + def useSameAsCh_cb(self, checked): + self.useSameAsCh = True + self.selectedChannel = self.chNameComboBox.currentText() + self.ok_cb() + + def useMiddleSlice_cb(self, checked): + self.useMiddleSlice = True + self.ok_cb() + + def runDataPrep_cb(self, checked): + self.runDataPrep = True + self.ok_cb() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show() + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class SetColumnNamesDialog(QBaseDialog): + def __init__(self, columnNames, categories, optionalCategories=None, parent=None): + super().__init__(parent) + + if not optionalCategories: + optionalCategories = None + + self.cancel = True + + mainLayout = QVBoxLayout() + + mainLayout.addWidget( + QLabel( + html_utils.paragraph("Assign a column to the following categories:
") + ) + ) + + self.categoriesWidgets = {} + formLayout = QFormLayout() + for row, category in enumerate(categories): + combobox = widgets.ComboBox() + combobox.addItems(columnNames) + if optionalCategories is not None: + text = f"* {category}" + else: + text = category + formLayout.addRow(text, combobox) + self.categoriesWidgets[category] = combobox + + if optionalCategories is not None: + optionalItems = ["None", *columnNames] + for row, category in enumerate(optionalCategories): + combobox = widgets.ComboBox() + combobox.addItems(optionalItems) + formLayout.addRow(category, combobox) + self.categoriesWidgets[category] = combobox + + mainLayout.addLayout(formLayout) + if optionalCategories is not None: + mainLayout.addSpacing(10) + mainLayout.addWidget( + QLabel(html_utils.paragraph("* mandatory", font_size="11px")) + ) + + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + self.setFont(font) + + def _warnNonUniqueCategories(self, category_1, category_2): + txt = html_utils.paragraph(f""" + The following categories have the same column assigned to it.

+ Columns assigned to categories must be unique.

+ Categories with the same column: + {html_utils.to_list((category_1, category_2))} + """) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "Non-unique columns", txt) + + def _checkUniqueNames(self): + self.textToCategoryMapper = {} + for category, combobox in self.categoriesWidgets.items(): + if combobox.text() == "None": + continue + + if combobox.text() not in self.textToCategoryMapper: + self.textToCategoryMapper[combobox.text()] = category + continue + + sameCategory = self.textToCategoryMapper[combobox.text()] + self._warnNonUniqueCategories(category, sameCategory) + return False + + return True + + def ok_cb(self): + proceed = self._checkUniqueNames() + if not proceed: + return + + self.selectedColumns = { + category: combobox.text() + for category, combobox in self.categoriesWidgets.items() + } + self.cancel = False + self.close() + + +class QCropTrangeTool(QBaseDialog): + sigClose = Signal() + sigTvalueChanged = Signal(int) + sigReset = Signal() + sigCrop = Signal(int, int) + + def __init__( + self, + SizeT, + cropButtonText="Apply crop", + parent=None, + addDoNotShowAgain=False, + title="Select frames range", + ): + super().__init__(parent) + + self.cancel = True + + self.setWindowFlags(Qt.Tool | Qt.WindowStaysOnTopHint) + + self.SizeT = SizeT + self.numDigits = len(str(self.SizeT)) + + self.setWindowTitle(title) + + layout = QGridLayout() + buttonsLayout = QHBoxLayout() + + self.startFrameScrollbar = widgets.sliderWithSpinBox( + spinbox_loc="left", maximum_on_label=SizeT + ) + self.startFrameScrollbar.setMaximum(SizeT, including_spinbox=True) + self.startFrameScrollbar.setMinimum(1, including_spinbox=True) + + self.endFrameScrollbar = widgets.sliderWithSpinBox( + spinbox_loc="left", maximum_on_label=SizeT + ) + self.endFrameScrollbar.setMaximum(SizeT, including_spinbox=True) + self.endFrameScrollbar.setMinimum(1, including_spinbox=True) + self.endFrameScrollbar.setValue(SizeT) + + cancelButton = widgets.cancelPushButton("Cancel") + cropButton = widgets.okPushButton(cropButtonText) + buttonsLayout.addWidget(cropButton) + buttonsLayout.addWidget(cancelButton) + + row = 0 + layout.addWidget(QLabel("Start frame n. "), row, 0, alignment=Qt.AlignRight) + layout.addWidget(self.startFrameScrollbar, row, 2) + + row += 1 + layout.setRowStretch(row, 5) + layout.addItem(QSpacerItem(10, 10), row, 0) + + row += 1 + layout.addWidget(QLabel("Stop frame n. "), row, 0, alignment=Qt.AlignRight) + layout.addWidget(self.endFrameScrollbar, row, 2) + + row += 1 + if addDoNotShowAgain: + self.doNotShowAgainCheckbox = QCheckBox("Do not ask again") + layout.addWidget( + self.doNotShowAgainCheckbox, row, 2, alignment=Qt.AlignLeft + ) + row += 1 + + layout.addItem(QSpacerItem(10, 20), row, 0) + layout.addLayout(buttonsLayout, row + 1, 2, alignment=Qt.AlignRight) + + layout.setColumnStretch(0, 0) + layout.setColumnStretch(1, 0) + layout.setColumnStretch(2, 10) + + self.setLayout(layout) + + # resetButton.clicked.connect(self.emitReset) + cropButton.clicked.connect(self.emitCrop) + cancelButton.clicked.connect(self.close) + self.startFrameScrollbar.sigValueChange.connect(self.TvalueChanged) + self.endFrameScrollbar.sigValueChange.connect(self.TvalueChanged) + + def emitReset(self): + self.sigReset.emit() + + def emitCrop(self): + self.cancel = False + low_z = self.startFrameScrollbar.value() - 1 + high_z = self.endFrameScrollbar.value() - 1 + self.sigCrop.emit(low_z, high_z) + self.close() + + def updateScrollbars(self, start_frame_i, lower_frame_i): + self.startFrameScrollbar.setValue(start_frame_i + 1) + self.endFrameScrollbar.setValue(lower_frame_i + 1) + + def TvalueChanged(self, value): + frame_i = value - 1 + self.sigTvalueChanged.emit(frame_i) + + def showEvent(self, event): + self.resize(int(self.width() * 2.0), self.height()) + + def closeEvent(self, event): + super().closeEvent(event) + self.sigClose.emit() + + +class SelectFoldersToAnalyse(QBaseDialog): + def __init__( + self, + parent=None, + preSelectedPaths=None, + onlyExpPaths=False, + scanFolderTree=True, + instructionsText="Select experiment folders to analyse", + askSelectPosFolders=False, + ): + super().__init__(parent) + + self.cancel = True + self.onlyExpPaths = onlyExpPaths + self.setWindowTitle("Select experiments to analyse") + self.scanTree = scanFolderTree + self.askSelectPosFolders = askSelectPosFolders + + mainLayout = QVBoxLayout() + + instructionsText = html_utils.paragraph( + f"{instructionsText}

" + "Drag and drop folders or click on Add folder button to " + "add as many folders " + "as needed.
", + font_size="14px", + ) + instructionsLabel = QLabel(instructionsText) + instructionsLabel.setAlignment(Qt.AlignCenter) + + infoText = html_utils.paragraph( + "A valid folder is either a Position folder, " + "or an experiment folder (containing Position_n folders),
" + "or any folder that contains multiple experiment folders.

" + "In the last case, Cell-ACDC will automatically scan the entire tree of " + "sub-directories
" + "and will add all experiments having the right folder structure.
", + font_size="12px", + ) + infoLabel = QLabel(infoText) + infoLabel.setAlignment(Qt.AlignCenter) + + self.listWidget = widgets.listWidget() + self.listWidget.setSelectionMode( + QAbstractItemView.SelectionMode.ExtendedSelection + ) + if preSelectedPaths is not None: + self.listWidget.addItems(preSelectedPaths) + + buttonsLayout = widgets.CancelOkButtonsLayout() + + delButton = widgets.delPushButton("Remove selected path(s)") + browseButton = widgets.browseFileButton( + "Add folder...", openFolder=True, start_dir=utils.getMostRecentPath() + ) + + buttonsLayout.insertWidget(3, delButton) + buttonsLayout.insertWidget(4, browseButton) + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + browseButton.sigPathSelected.connect(self.addFolderPath) + delButton.clicked.connect(self.removePaths) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addWidget(instructionsLabel) + mainLayout.addWidget(infoLabel) + mainLayout.addWidget(self.listWidget) + + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + mainLayout.addStretch(1) + + self.setLayout(mainLayout) + + self.setAcceptDrops(True) + + self.setFont(font) + + def dragEnterEvent(self, event): + event.acceptProposedAction() + + def dropEvent(self, event): + event.setDropAction(Qt.CopyAction) + for url in event.mimeData().urls(): + dropped_path = url.toLocalFile() + if os.path.isfile(dropped_path): + dropped_path = os.path.dirname(dropped_path) + + QTimer.singleShot(50, partial(self.addFolderPath, dropped_path)) + + def pathsList(self): + return [ + self.listWidget.item(i).text().replace("\\", "/") + for i in range(self.listWidget.count()) + ] + + def expFolderToPosFoldernamesMapper(self): + expPathsPosFoldernamesMapper = defaultdict(set) + for selectedPath in self.pathsList(): + pos_foldernames = utils.get_pos_foldernames( + selectedPath, check_if_is_sub_folder=True + ) + if not pos_foldernames: + images_path = utils.get_images_folderpath(selectedPath) + expPathsPosFoldernamesMapper[selectedPath].add("") + else: + expPath = load.get_exp_path(selectedPath) + expPathsPosFoldernamesMapper[expPath].update(pos_foldernames) + + expPathsPosFoldernamesMapper = { + expPath: natsorted(pos_foldernames) + for expPath, pos_foldernames in expPathsPosFoldernamesMapper.items() + } + return expPathsPosFoldernamesMapper + + def ok_cb(self): + self.cancel = False + self.paths = self.pathsList() + self.selectedExpFolderToPosFoldernamesMapper = ( + self.expFolderToPosFoldernamesMapper() + ) + self.close() + + def warnNoValidPathsFound(self, selected_path): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph(""" + The selected path (see below) does not contain any valid folder.

+ Please, make sure to select a Position folder, the Images folder + inside a Position folder, or any folder containing a Position folder + as a sub-directory.

+ Thank you for your patience!

+ Selected path: + """) + msg.warning( + self, + "Training workflow generated", + txt, + commands=(f"{selected_path}",), + path_to_browse=selected_path, + ) + + def warnNoValidExpPaths(self, selected_path): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph(""" + The selected folder does + not contain any valid experiment folders. + """) + command = selected_path.replace("\\", os.sep) + command = selected_path.replace("/", os.sep) + msg.warning( + self, + "No valid folders found", + txt, + commands=(command,), + path_to_browse=selected_path, + ) + + def parse_select_from_exp_paths(self, exp_paths: dict[os.PathLike, Iterable[str]]): + if not self.askSelectPosFolders: + return list(exp_paths.keys()) + + paths = [] + for exp_path, pos_foldernames in exp_paths.items(): + if len(pos_foldernames) == 1: + paths.append(exp_path) + continue + + informativeText = html_utils.paragraph( + "The following experiment folder

" + f"{exp_path}

" + "contains multiple Position folders.

" + "Please, select which Position folder(s) you want to analyse:
" + ) + select_folder = load.select_exp_folder() + values = select_folder.get_values_dataprep(exp_path) + select_folder.QtPrompt( + self, + values, + toggleMulti=True, + informativeText=informativeText, + selectedValues=values, + ) + if select_folder.cancel: + return + + for pos in select_folder.selected_pos: + paths.append(os.path.join(exp_path, pos)) + + return paths + + def addFolderPath(self, selected_path): + utils.addToRecentPaths(selected_path) + + folder_type = utils.determine_folder_type(selected_path) + is_pos_folder, is_images_folder, folder_path = folder_type + if is_pos_folder: + paths = [selected_path] + elif is_images_folder: + paths = [os.path.dirname(selected_path)] + elif self.scanTree: + print(f'Scanning selected folder "{selected_path}"...') + exp_paths = path.get_posfolderpaths_walk(selected_path) + if not exp_paths: + self.warnNoValidExpPaths(selected_path) + return + + paths = self.parse_select_from_exp_paths(exp_paths) + if paths is None: + return + else: + paths = [selected_path] + + if not paths: + self.warnNoValidPathsFound(selected_path) + + for selectedPath in paths: + if self.onlyExpPaths: + selectedPath = load.get_exp_path(selectedPath) + + selectedPath = selectedPath.replace("\\", "/") + if selectedPath in self.pathsList(): + print( + f"[WARNING]: The following path was already selected: " + f'"{selectedPath}"' + ) + return + + self.listWidget.addItem(selectedPath) + + def removePaths(self): + for item in self.listWidget.selectedItems(): + row = self.listWidget.row(item) + self.listWidget.takeItem(row) + + +class OverlayLabelsAppearanceDialog(QBaseDialog): + sigValuesChanged = Signal(object) + + def __init__(self, scatterPlotItem: pg.ScatterPlotItem = None, parent=None): + super().__init__(parent) + + self.cancel = True + + self.setWindowTitle("Overlay contours appearance properties") + + mainLayout = QVBoxLayout() + + formLayout = widgets.FormLayout() + + row = -1 + + row += 1 + self.colorButton = widgets.myColorButton(color=(255, 0, 0)) + self.colorButton.clicked.disconnect() + self.colorButton.clicked.connect(self.selectColor) + self.colorButton.setCursor(Qt.PointingHandCursor) + self.colorWidget = widgets.formWidget( + self.colorButton, + addInfoButton=False, + stretchWidget=False, + labelTextLeft="Symbol color: ", + parent=self, + widgetAlignment="left", + ) + if scatterPlotItem is not None: + pen = scatterPlotItem.opts["pen"] + color = pen.color() + self.colorButton.setColor(color) + formLayout.addFormWidget(self.colorWidget, row=row) + + row += 1 + self.penWidthSpinBox = widgets.SpinBox() + self.penWidthSpinBox.setMinimum(0) + self.penWidthSpinBox.setValue(2) + + self.penWidthWidget = widgets.formWidget( + self.penWidthSpinBox, + addInfoButton=False, + stretchWidget=False, + labelTextLeft="Symbol weight: ", + parent=self, + widgetAlignment="left", + ) + if scatterPlotItem is not None: + pen = scatterPlotItem.opts["pen"] + width = pen.width() + self.penWidthSpinBox.setValue(width) + formLayout.addFormWidget(self.penWidthWidget, row=row) + + row += 1 + self.opacitySlider = widgets.sliderWithSpinBox(isFloat=True, normalize=True) + self.opacitySlider.setMinimum(0) + self.opacitySlider.setMaximum(100) + self.opacitySlider.setValue(0.8) + + self.opacityWidget = widgets.formWidget( + self.opacitySlider, + addInfoButton=False, + stretchWidget=True, + labelTextLeft="Symbol opacity: ", + parent=self, + ) + if scatterPlotItem is not None: + brush = scatterPlotItem.opts["brush"] + alpha = brush.color().alpha() + opacity = alpha / 255 + self.opacitySlider.setValue(opacity) + formLayout.addFormWidget(self.opacityWidget, row=row) + + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addLayout(formLayout) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + def selectColor(self): + color = self.colorButton.color() + self.colorButton.origColor = color + self.colorButton.colorDialog.setCurrentColor(color) + self.colorButton.colorDialog.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + self.colorButton.colorDialog.open() + w = self.width() + left = self.pos().x() + colorDialogTop = self.colorButton.colorDialog.pos().y() + self.colorButton.colorDialog.move(w + left + 10, colorDialogTop) + + def getBrush(self): + r, g, b, _ = self.colorButton.color().getRgb() + alpha = round(self.opacitySlider.value() * 255) + brushColor = (r, g, b, alpha) + brush = pg.mkBrush(brushColor) + return brush + + def getPen(self): + color = self.colorButton.color() + penWidth = self.penWidthSpinBox.value() + if penWidth == 0: + return + + pen = pg.mkPen(color, width=penWidth) + return pen + + def ok_cb(self): + self.cancel = False + self.properties = {"brush": self.getBrush(), "pen": self.getPen()} + self.close() + + +class AutoSaveIntervalDialog(QBaseDialog): + sigValueChanged = Signal(float, str) + + def __init__(self, parent=None): + super().__init__(parent) + + self.cancel = True + + self.setWindowTitle("Change autosave interval") + + mainLayout = QVBoxLayout() + + self.autoSaveIntervalWidget = widgets.AutoSaveIntervalWidget(parent=self) + + mainLayout.addWidget(QLabel("Autosave interval:")) + mainLayout.addWidget(self.autoSaveIntervalWidget) + + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + def setValues(self, autoSaveIntevalValue, autoSaveIntervalUnit): + self.autoSaveIntervalWidget.spinbox.setValue(autoSaveIntevalValue) + self.autoSaveIntervalWidget.unitCombobox.setCurrentText(autoSaveIntervalUnit) + + def sizeHint(self): + defaultWidth = super().sizeHint().width() + defaultHeight = super().sizeHint().height() + return QSize(defaultWidth * 2, defaultHeight) + + def ok_cb(self): + self.cancel = False + self.sigValueChanged.emit( + self.autoSaveIntervalWidget.spinbox.value(), + self.autoSaveIntervalWidget.unitCombobox.currentText(), + ) + self.close() + +# Sibling imports (deferred to avoid import cycles) +from .general import ( + imageViewer, +) + diff --git a/cellacdc/dialogs/models.py b/cellacdc/dialogs/models.py new file mode 100644 index 000000000..c7579e463 --- /dev/null +++ b/cellacdc/dialogs/models.py @@ -0,0 +1,2265 @@ +"""Cell-ACDC dialog windows: models.""" + +import os +import sys +import re +from typing import Literal, Callable, Dict, Iterable, List, Tuple +import datetime +import pathlib +from collections import defaultdict +import zipfile +from heapq import nlargest +import matplotlib +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +from matplotlib.patches import Rectangle, Circle, PathPatch, Path +import numpy as np +import scipy.interpolate + +try: + import tkinter as tk +except Exception as err: + pass + +import cv2 +import traceback +from itertools import combinations, permutations +from collections import namedtuple +from natsort import natsorted + +# from MyWidgets import Slider, Button, MyRadioButtons +from skimage.measure import label, regionprops +from functools import partial +import skimage.filters +import skimage.measure +import skimage.morphology +import skimage.exposure +import skimage.draw +import skimage.registration +import skimage.color +import skimage.segmentation +from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk +import matplotlib.pyplot as plt +import seaborn as sns +import pandas as pd +import math +import time +import sympy as sp +import json +import html + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from qtpy import QtCore +from qtpy.QtGui import ( + QIcon, + QFontMetrics, + QKeySequence, + QFont, + QRegularExpressionValidator, + QCursor, + QKeyEvent, + QPixmap, + QFont, + QPalette, + QMouseEvent, + QColor, +) +from qtpy.QtCore import ( + Qt, + QSize, + QEvent, + Signal, + QEventLoop, + QTimer, + QRegularExpression, +) +from qtpy.QtWidgets import ( + QFileDialog, + QApplication, + QMainWindow, + QMenu, + QLabel, + QToolBar, + QScrollBar, + QWidget, + QVBoxLayout, + QLineEdit, + QPushButton, + QHBoxLayout, + QDialog, + QFormLayout, + QListWidget, + QAbstractItemView, + QButtonGroup, + QCheckBox, + QSizePolicy, + QComboBox, + QSlider, + QGridLayout, + QSpinBox, + QToolButton, + QTableView, + QTextBrowser, + QDoubleSpinBox, + QScrollArea, + QFrame, + QProgressBar, + QGroupBox, + QRadioButton, + QDockWidget, + QMessageBox, + QStyle, + QPlainTextEdit, + QSpacerItem, + QTreeWidget, + QTreeWidgetItem, + QTextEdit, + QSplashScreen, + QAction, + QListWidgetItem, + QActionGroup, + QHeaderView, + QStyledItemDelegate, +) +import qtpy.compat + +from .. import exception_handler +from .. import load, prompts, core, measurements, html_utils +from .. import is_mac, is_win, is_linux, settings_folderpath, config +from .. import preproc_recipes_path, segm_recipes_path, combine_channels_recipes_path +from .. import is_conda_env +from .. import printl +from .. import colors +from .. import issues_url +from .. import utils +from .. import qutils +from .. import _palettes +from .. import base_cca_dict +from .. import widgets +from .. import user_profile_path, promptable_models_path, models_path +from .. import features +from .. import _core +from .. import _types +from .. import plot +from .. import urls +from ..acdc_regex import float_regex, is_alphanumeric_filename, to_alphanumeric +from .. import _base_widgets +from .. import io +from .. import cca_functions +from .. import path + +POSITIVE_FLOAT_REGEX = float_regex(allow_negative=False) +TREEWIDGET_STYLESHEET = _palettes.TreeWidgetStyleSheet() +LISTWIDGET_STYLESHEET = _palettes.ListWidgetStyleSheet() +BACKGROUND_RGBA = _palettes.get_disabled_colors()["Button"] + +font = QFont() +font.setPixelSize(12) +italicFont = QFont() +italicFont.setPixelSize(12) +italicFont.setItalic(True) + +from ._base import ( + QBaseDialog, +) + +def addCustomModelMessages(QParent=None): + modelFilePath = None + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + txt = html_utils.paragraph(""" + Do you already have the acdcSegment.py file for your code + or do you need instructions on how to set-up your custom model?
+ """) + infoButton = widgets.infoPushButton(" I need instructions") + browseButton = widgets.browseFileButton(" I have the model, let me select it") + msg.information( + QParent, + "Add custom model", + txt, + buttonsTexts=("Cancel", infoButton, browseButton), + showDialog=False, + ) + browseButton.clicked.disconnect() + browseButton.clicked.connect(msg.buttonCallBack) + msg.exec_() + if msg.cancel: + return + if msg.clickedButton == infoButton: + txt = utils.get_add_custom_model_instructions() + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + msg.information( + QParent, + "Custom model instructions", + txt, + buttonsTexts=("Ok",), + path_to_browse=models_path, + browse_button_text="Open models folder...", + ) + else: + homePath = pathlib.Path.home() + modelFilePath = QFileDialog.getOpenFileName( + QParent, + "Select the acdcSegment.py file of your model", + str(homePath), + "acdcSegment.py file (*.py);;All files (*)", + )[0] + if not modelFilePath: + return + + return modelFilePath + + +def addCustomPromptModelMessages(QParent=None): + modelFilePath = None + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + txt = html_utils.paragraph(""" + Do you already have the acdcPromptSegment.py file for your code + or do you need instructions on how to set-up your custom model?
+ """) + infoButton = widgets.infoPushButton(" I need instructions") + browseButton = widgets.browseFileButton(" I have the model, let me select it") + msg.information( + QParent, + "Add custom promptable model", + txt, + buttonsTexts=("Cancel", infoButton, browseButton), + showDialog=False, + ) + browseButton.clicked.disconnect() + browseButton.clicked.connect(msg.buttonCallBack) + msg.exec_() + if msg.cancel: + return + if msg.clickedButton == infoButton: + txt = utils.get_add_custom_prompt_model_instructions() + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + msg.information( + QParent, + "Custom promptable model instructions", + txt, + buttonsTexts=("Ok",), + path_to_browse=promptable_models_path, + browse_button_text="Open promptable models folder...", + ) + else: + homePath = pathlib.Path.home() + modelFilePath = QFileDialog.getOpenFileName( + QParent, + "Select the acdcPromptSegment.py file of your model", + str(homePath), + "acdcPromptSegment.py file (*.py);;All files (*)", + )[0] + if not modelFilePath: + return + + return modelFilePath + + +class SelectPromptableModelDialog(QBaseDialog): + def __init__(self, parent=None): + self.cancel = True + super().__init__(parent) + + self.setWindowTitle("Select model for segmentation") + + mainLayout = QVBoxLayout() + + label = QLabel(html_utils.paragraph("Select model to use for segmentation: ")) + mainLayout.addWidget(label, alignment=Qt.AlignCenter) + + listBox = widgets.listWidget() + models = utils.get_list_of_promptable_models() + listBox.addItems(models) + listBox.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection) + listBox.setCurrentRow(0) + listBox.itemDoubleClicked.connect(self.ok_cb) + + self.listBox = listBox + + mainLayout.addWidget(listBox) + + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + def ok_cb(self): + self.cancel = False + self.model_name = self.listBox.currentItem().text() + self.close() + + +class QDialogSelectModel(QDialog): + def __init__(self, parent=None, addSkipSegmButton=False, customFirst=""): + self.cancel = True + super().__init__(parent) + self.setWindowTitle("Select model") + + mainLayout = QVBoxLayout() + topLayout = QVBoxLayout() + bottomLayout = QHBoxLayout() + + self.mainLayout = mainLayout + + label = QLabel(html_utils.paragraph("Select model to use for segmentation: ")) + # padding: top, left, bottom, right + label.setStyleSheet("padding:0px 0px 3px 0px;") + topLayout.addWidget(label, alignment=Qt.AlignCenter) + + listBox = widgets.listWidget() + models = utils.get_list_of_models() + + if customFirst: + try: + idx = models.index(customFirst) + models.insert(0, models.pop(idx)) + except ValueError: + print(f"Warning: {customFirst} not found in models list.") + pass + + listBox.setFont(font) + listBox.addItems(models) + addCustomModelItem = QListWidgetItem("Add custom model...") + addCustomModelItem.setFont(italicFont) + listBox.addItem(addCustomModelItem) + listBox.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection) + listBox.setCurrentRow(0) + self.listBox = listBox + listBox.itemDoubleClicked.connect(self.ok_cb) + topLayout.addWidget(listBox) + + cancelButton = widgets.cancelPushButton("Cancel") + okButton = widgets.okPushButton(" Ok ") + okButton.setShortcut(Qt.Key_Enter) + + bottomLayout.addStretch(1) + bottomLayout.addWidget(cancelButton) + bottomLayout.addSpacing(20) + if addSkipSegmButton: + skipSegmButton = widgets.SkipPushButton("Skip segmentation") + bottomLayout.addWidget(skipSegmButton) + skipSegmButton.clicked.connect(self.skipSegm) + bottomLayout.addWidget(okButton) + bottomLayout.setContentsMargins(0, 10, 0, 0) + + mainLayout.addLayout(topLayout) + mainLayout.addLayout(bottomLayout) + self.setLayout(mainLayout) + + # Connect events + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.cancel_cb) + + self.setStyleSheet(LISTWIDGET_STYLESHEET) + + def skipSegm(self): + self.cancel = False + self.selectedModel = "skip_segmentation" + self.close() + + def keyPressEvent(self, event: QKeyEvent) -> None: + if event.key() == Qt.Key_Escape: + event.ignore() + return + + super().keyPressEvent(event) + + def ok_cb(self, event): + self.clickedButton = self.sender() + self.cancel = False + item = self.listBox.currentItem() + model = item.text() + if model == "Add custom model...": + modelFilePath = addCustomModelMessages(self) + if modelFilePath is None: + return + utils.store_custom_model_path(modelFilePath) + modelName = os.path.basename(os.path.dirname(modelFilePath)) + item = QListWidgetItem(modelName) + self.listBox.addItem(item) + self.listBox.setCurrentItem(item) + elif model == "Automatic thresholding": + self.selectedModel = "thresholding" + self.close() + else: + self.selectedModel = model + self.close() + + def cancel_cb(self, event): + self.cancel = True + self.selectedModel = None + self.close() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show() + + horizontal_sb = self.listBox.horizontalScrollBar() + while horizontal_sb.isVisible(): + self.resize(self.height(), self.width() + 10) + + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class DataFrameModel(QtCore.QAbstractTableModel): + # https://stackoverflow.com/questions/44603119/how-to-display-a-pandas-data-frame-with-pyqt5-pyside2 + DtypeRole = QtCore.Qt.UserRole + 1000 + ValueRole = QtCore.Qt.UserRole + 1001 + + def __init__(self, df=pd.DataFrame(), parent=None): + super(DataFrameModel, self).__init__(parent) + self._dataframe = df + + def setDataFrame(self, dataframe): + self.beginResetModel() + self._dataframe = dataframe.copy() + self.endResetModel() + + def dataFrame(self): + return self._dataframe + + dataFrame = QtCore.Property(pd.DataFrame, fget=dataFrame, fset=setDataFrame) + + @QtCore.Slot(int, QtCore.Qt.Orientation, result=str) + def headerData( + self, + section: int, + orientation: QtCore.Qt.Orientation, + role: int = QtCore.Qt.DisplayRole, + ): + if role == QtCore.Qt.DisplayRole: + if orientation == QtCore.Qt.Horizontal: + return self._dataframe.columns[section] + else: + return str(self._dataframe.index[section]) + return QtCore.QVariant() + + def rowCount(self, parent=QtCore.QModelIndex()): + if parent.isValid(): + return 0 + return len(self._dataframe.index) + + def columnCount(self, parent=QtCore.QModelIndex()): + if parent.isValid(): + return 0 + return self._dataframe.columns.size + + def data(self, index, role=QtCore.Qt.DisplayRole): + if not index.isValid() or not ( + 0 <= index.row() < self.rowCount() + and 0 <= index.column() < self.columnCount() + ): + return QtCore.QVariant() + row = self._dataframe.index[index.row()] + col = self._dataframe.columns[index.column()] + dt = self._dataframe[col].dtype + + if role == Qt.TextAlignmentRole: + return Qt.AlignCenter + + val = self._dataframe.iloc[row][col] + if role == QtCore.Qt.DisplayRole: + return str(val) + elif role == DataFrameModel.ValueRole: + return val + if role == DataFrameModel.DtypeRole: + return dt + return QtCore.QVariant() + + def roleNames(self): + roles = { + QtCore.Qt.DisplayRole: b"display", + DataFrameModel.DtypeRole: b"dtype", + DataFrameModel.ValueRole: b"value", + } + return roles + + +class QDialogModelParams(QDialog): + def __init__( + self, + init_params, + segment_params, + model_name, + is_tracker=False, + url=None, + parent=None, + initLastParams=True, + posData=None, + channels=None, + currentChannelName=None, + segmFileEndnames=None, + df_metadata=None, + force_postprocess_2D=False, + model_module=None, + action_type="", + addPreProcessParams=True, + addPostProcessParams=True, + extraParams=None, + extraParamsTitle=None, + ini_filename=None, + add_additional_segm_params=False, + ): + self.cancel = True + super().__init__(parent) + self.channels = channels + self.is_tracker = is_tracker + self.currentChannelName = currentChannelName + self.channelCombobox = None + self.segmFileEndnames = segmFileEndnames + self.df_metadata = df_metadata + self.force_postprocess_2D = force_postprocess_2D + + self.skipSegmentation = False + if len(segment_params) > 0: + if segment_params[0].name.lower().find("skip_segmentation") != -1: + self.skipSegmentation = True + addPreProcessParams = False + else: + self.skipSegmentation = False + if ini_filename is not None: + self.ini_filename = ini_filename + elif is_tracker: + self.ini_filename = "last_params_trackers.ini" + addPreProcessParams = False + addPostProcessParams = False + else: + self.ini_filename = "last_params_segm_models.ini" + + self.addPreProcessParams = addPreProcessParams + + self.model_name = model_name + + self.setWindowTitle(f"{model_name} parameters") + + # Create main vertical layout and horizontal layout for two columns + mainLayout = QVBoxLayout() + + gridLayout = QGridLayout() + self.gridLayout = gridLayout + + loadFunc = self.loadLastSelection + + self.paramsGroupPosMapper = {} + + # LEFT COLUMN: Preprocessing params + row, col = 0, 0 + preProcessLayout = None + self.preProcessParamsWidget = None + if addPreProcessParams: + preProcessLayout = QVBoxLayout() + self.preProcessParamsWidget = PreProcessParamsWidget( + parent=self, addApplyButton=False + ) + self.preProcessParamsWidget.setChecked(False) + preProcessLayout.addWidget(self.preProcessParamsWidget) + self.preProcessParamsWidget.sigLoadRecipe.connect(self.loadPreprocRecipe) + gridLayout.addLayout(preProcessLayout, row, col, 1, 2) + self.paramsGroupPosMapper[self.preProcessParamsWidget] = (row, col) + gridLayout.addItem(QSpacerItem(10, 5), 0, col + 1) + # gridLayout.setColumnMinimumWidth(col+1, 15) + col += 2 + + # Center COLUMN: Init, Segmentation/Eval + row = 0 + self.secondColLayout = QVBoxLayout() + self.initParamsScrollArea = widgets.ScrollArea() + initParamsScrollAreaLayout = QVBoxLayout() + self.initParamsScrollArea.setVerticalLayout(initParamsScrollAreaLayout) + + initGroupBox, self.init_argsWidgets = self.createGroupParams( + init_params, "Parameters for model initialization" + ) + self.init_params = init_params + initDefaultButton = widgets.reloadPushButton("Restore default") + initLoadLastSelButton = widgets.OpenFilePushButton("Load last parameters") + initLoadLastSelButton.setIcon(QIcon(":folder-open.svg")) + initButtonsLayout = QHBoxLayout() + initButtonsLayout.addStretch(1) + initButtonsLayout.addWidget(initDefaultButton) + initButtonsLayout.addWidget(initLoadLastSelButton) + initDefaultButton.clicked.connect(self.restoreDefaultInit) + initLoadLastSelButton.clicked.connect( + partial(loadFunc, f"{self.model_name}.init", self.init_argsWidgets) + ) + + initParamsScrollAreaLayout.addWidget(initGroupBox) + + initParamsLayout = QVBoxLayout() + initParamsLayout.addWidget(QLabel(f"{initGroupBox.title()}")) + initGroupBox.setTitle("") + initParamsLayout.addWidget(self.initParamsScrollArea) + initParamsLayout.addLayout(initButtonsLayout) + self.secondColLayout.addLayout(initParamsLayout) + self.paramsGroupPosMapper[self.initParamsScrollArea] = (0, col) + + self.segmentParamsScrollArea = None + if not self.skipSegmentation: + self.segmentParamsScrollArea = widgets.ScrollArea() + segmentParamsScrollAreaLayout = QVBoxLayout() + self.segmentParamsScrollArea.setVerticalLayout( + segmentParamsScrollAreaLayout + ) + if action_type: + runGroupboxTitle = f"Parameters for {action_type}" + elif is_tracker: + runGroupboxTitle = "Parameters for tracking" + else: + runGroupboxTitle = "Parameters for segmentation" + + segmentGroupBox, self.argsWidgets = self.createGroupParams( + segment_params, runGroupboxTitle, addChannelSelector=True + ) + self.segment_params = segment_params + self.segmentGroupBox = segmentGroupBox + segmentDefaultButton = widgets.reloadPushButton("Restore default") + segmentLoadLastSelButton = widgets.OpenFilePushButton( + "Load last parameters" + ) + segmentButtonsLayout = QHBoxLayout() + segmentButtonsLayout.addStretch(1) + segmentButtonsLayout.addWidget(segmentDefaultButton) + segmentButtonsLayout.addWidget(segmentLoadLastSelButton) + segmentDefaultButton.clicked.connect(self.restoreDefaultSegment) + section = f"{self.model_name}.segment" + segmentLoadLastSelButton.clicked.connect( + partial(loadFunc, section, self.argsWidgets) + ) + segmentParamsScrollAreaLayout.addWidget(segmentGroupBox) + + segmentParamsLayout = QVBoxLayout() + segmentParamsLayout.addWidget(QLabel(f"{segmentGroupBox.title()}")) + segmentGroupBox.setTitle("") + segmentParamsLayout.addWidget(self.segmentParamsScrollArea) + segmentParamsLayout.addLayout(segmentButtonsLayout) + self.secondColLayout.addLayout(segmentParamsLayout) + self.paramsGroupPosMapper[self.segmentParamsScrollArea] = (1, col) + + gridLayout.addLayout(self.secondColLayout, row, col) + + gridLayout.addItem(QSpacerItem(10, 5), 0, col + 1) + col += 2 + + # Buttons layout (spans both columns) + buttonsLayout = QHBoxLayout() + cancelButton = widgets.cancelPushButton(" Cancel ") + okButton = widgets.okPushButton(" Ok ") + + enableLoadingSavingRecipe = not is_tracker and ( + addPreProcessParams or addPostProcessParams + ) + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + if enableLoadingSavingRecipe: + loadEntireRecipeButton = widgets.OpenFilePushButton("Load saved recipe...") + saveEntireRecipeButton = widgets.savePushButton( + "Save all parameters to recipe file..." + ) + buttonsLayout.addWidget(loadEntireRecipeButton) + buttonsLayout.addWidget(saveEntireRecipeButton) + loadEntireRecipeButton.clicked.connect(self.loadEntireRecipe) + saveEntireRecipeButton.clicked.connect(self.saveEntireRecipe) + + buttonsLayout.addWidget(okButton) + + buttonsLayout.setContentsMargins(0, 10, 0, 10) + + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.close) + + self.okButton = okButton + + # Extra params in right column + row = 0 + self.extraArgsWidgets = None + self.extraParamsScrollArea = None + if extraParams is not None: + self.extraParamsScrollArea = widgets.ScrollArea() + extraParamsScrollAreaLayout = QVBoxLayout() + self.extraParamsScrollArea.setVerticalLayout(extraParamsScrollAreaLayout) + if extraParamsTitle is None: + extraParamsTitle = "Additional parameters" + + self.extraGroupBox, self.extraArgsWidgets = self.createGroupParams( + extraParams, extraParamsTitle + ) + + extraDefaultButton = widgets.reloadPushButton("Restore default") + extraLoadLastSelButton = widgets.OpenFilePushButton("Load last parameters") + extraButtonsLayout = QHBoxLayout() + extraButtonsLayout.addStretch(1) + extraButtonsLayout.addWidget(extraDefaultButton) + extraButtonsLayout.addWidget(extraLoadLastSelButton) + extraDefaultButton.clicked.connect(self.restoreDefaultExtra) + section = f"{self.model_name}.extra" + extraLoadLastSelButton.clicked.connect( + partial(loadFunc, section, self.extraArgsWidgets) + ) + + extraParamsScrollAreaLayout.addWidget(self.extraGroupBox) + + extraParamsLayout = QVBoxLayout() + extraParamsLayout.addWidget(QLabel(f"{self.extraGroupBox.title()}")) + self.extraGroupBox.setTitle("") + extraParamsLayout.addWidget(self.extraParamsScrollArea) + extraParamsLayout.addLayout(extraButtonsLayout) + self.paramsGroupPosMapper[self.extraParamsScrollArea] = (row, col) + gridLayout.addLayout(extraParamsLayout, row, col) + row += 1 + + # Post-processing in right-most column + self.postProcessGroupbox = None + self.seeHereLabel = None + thirdColumnLayout = QVBoxLayout() + if addPostProcessParams: + # Add minimum size spinbox which is valid for all models + postProcessGroupbox = PostProcessSegmParams( + "Post-processing segmentation parameters", + posData, + force_postprocess_2D=force_postprocess_2D, + ) + postProcessGroupbox.setCheckable(True) + postProcessGroupbox.setChecked(False) + self.postProcessGroupbox = postProcessGroupbox + + thirdColumnLayout.addWidget(postProcessGroupbox) + + postProcDefaultButton = widgets.reloadPushButton("Restore default") + postProcLoadLastSelButton = widgets.OpenFilePushButton( + "Load last parameters" + ) + postProcButtonsLayout = QHBoxLayout() + postProcButtonsLayout.addStretch(1) + postProcButtonsLayout.addWidget(postProcDefaultButton) + postProcButtonsLayout.addWidget(postProcLoadLastSelButton) + postProcDefaultButton.clicked.connect(self.restoreDefaultPostprocess) + postProcLoadLastSelButton.clicked.connect(self.loadLastSelectionPostProcess) + thirdColumnLayout.addLayout(postProcButtonsLayout) + thirdColumnLayout.addSpacing(15) + + if url is not None: + self.seeHereLabel = self.createSeeHereLabel(url) + thirdColumnLayout.addWidget(self.seeHereLabel, alignment=Qt.AlignCenter) + + self.paramsGroupPosMapper[self.preProcessParamsWidget] = (row, col) + + # Additional segmentation params in right column + self.additionalSegmGroupbox = None + if add_additional_segm_params: + thirdColumnLayout.addWidget(widgets.QHLine()) + additionalSegmGroupbox = self.getAdditionalSegmParams() + thirdColumnLayout.addWidget(additionalSegmGroupbox) + self.additionalSegmGroupbox = additionalSegmGroupbox + self.paramsGroupPosMapper[self.additionalSegmGroupbox] = (row, col) + + thirdColumnLayout.addStretch(1) + gridLayout.addLayout(thirdColumnLayout, row, col) + row += 1 + + # Add everything to main layout + mainLayout.addLayout(gridLayout) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.configPars = self.readLastSelection() + if self.configPars is None: + initLoadLastSelButton.setDisabled(True) + segmentLoadLastSelButton.setDisabled(True) + if self.postProcessGroupbox is not None: + postProcLoadLastSelButton.setDisabled(True) + + if initLastParams: + initLoadLastSelButton.click() + if not self.skipSegmentation: + segmentLoadLastSelButton.click() + + if self.extraArgsWidgets is not None: + extraLoadLastSelButton.click() + + if self.postProcessGroupbox is not None: + postProcLoadLastSelButton.click() + + try: + self.connectCustomSignals(model_module) + except Exception as e: + printl(traceback.format_exc()) + + self.setLayout(mainLayout) + self.setFont(font) + # self.setModal(True) + + def warningNoSegmRecipes(self): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph( + "No segmentation recipes found!

" + "To create a segmentation recipe you need click on " + "Save all parameters to recipe file... " + "button." + ) + msg.warning(self, "No segmentation recipes found!", txt) + + def selectIniFileToLoadEntireRecipe(self): + import qtpy.compat + + recipe_filepath = qtpy.compat.getopenfilename( + parent=self, + caption="Select INI file to load entire recipe", + filters="INI (*.ini);;All Files (*)", + )[0] + if not recipe_filepath: + return + + self.loadRecipeFromFilepath(recipe_filepath) + + txt = html_utils.paragraph("Done!

Segmentation recipe loaded from:") + msg = widgets.myMessageBox() + msg.information( + self, + "Segmentation recipe loaded!", + txt, + commands=(recipe_filepath,), + path_to_browse=os.path.dirname(recipe_filepath), + ) + + print("Done. Segmentation recipe loaded from:", recipe_filepath) + + def loadEntireRecipe(self): + segm_recipes_path_model = os.path.join(segm_recipes_path, self.model_name) + + if not os.path.exists(segm_recipes_path_model): + # self.warningNoSegmRecipes() + self.selectIniFileToLoadEntireRecipe() + return + + recipe_files = os.listdir(segm_recipes_path_model) + + if not recipe_files: + # self.warningNoSegmRecipes() + self.selectIniFileToLoadEntireRecipe() + return + + headerLabels = ["Name", "Date Created"] + items = [] + for recipe_file in recipe_files: + cp = config.ConfigParser() + cp.read(os.path.join(segm_recipes_path_model, recipe_file)) + date_created = cp["info"]["created_on"] + items.append((recipe_file, date_created)) + + browseButton = widgets.browseFileButton( + "Select INI file...", + title="Select INI file to load entire recipe", + openFolder=False, + start_dir=utils.getMostRecentPath(), + ext={"INI": ".ini"}, + ) + win = QTreeDialog( + items, + headerLabels=headerLabels, + title="Select a segmentation recipe to load", + infoText="Select a segmentation recipe to load:
", + path_to_browse=segm_recipes_path_model, + additional_buttons=(browseButton,), + ) + browseButton.sigPathSelected.connect( + partial( + self.entireRecipeIniFileSelected, + selectRecipeWin=win, + sender=browseButton, + ) + ) + win.exec_() + if win.cancel or not hasattr(win, "selectedText"): + print("Loading segmentation recipe cancelled.") + return + + if win.clickedButton == browseButton: + recipe_filepath = win.selectedIniFilepath + else: + recipe_filename = win.selectedText + recipe_filepath = os.path.join(segm_recipes_path_model, recipe_filename) + + self.loadRecipeFromFilepath(recipe_filepath) + + txt = html_utils.paragraph("Done!

Segmentation recipe loaded from:") + msg = widgets.myMessageBox() + msg.information( + self, + "Segmentation recipe laoded!", + txt, + commands=(recipe_filepath,), + path_to_browse=os.path.dirname(recipe_filepath), + ) + + print("Done. Segmentation recipe loaded from:", recipe_filepath) + + def entireRecipeIniFileSelected( + self, recipe_filepath, selectRecipeWin=None, sender=None + ): + selectRecipeWin.selectedText = "None" + selectRecipeWin.clickedButton = sender + selectRecipeWin.selectedIniFilepath = recipe_filepath + selectRecipeWin.cancel = False + selectRecipeWin.close() + + def loadRecipeFromFilepath(self, recipe_filepath): + cp = config.ConfigParser() + cp.read(recipe_filepath) + + self.loadPreprocRecipe(configPars=cp) + self.loadLastSelection( + f"{self.model_name}.init", self.init_argsWidgets, configPars=cp + ) + self.loadLastSelection( + f"{self.model_name}.segment", self.argsWidgets, configPars=cp + ) + if self.extraArgsWidgets: + self.loadLastSelection( + f"{self.model_name}.extra", self.extraArgsWidgets, configPars=cp + ) + self.loadLastSelectionPostProcess(configPars=cp) + + def saveEntireRecipe(self): + segm_recipes_path_model = os.path.join(segm_recipes_path, self.model_name) + try: + existingNames = os.listdir(segm_recipes_path_model) + except FileNotFoundError: + existingNames = [] + + win = filenameDialog( + title="Filename for segmentation recipe", + basename="segmentation_recipe", + ext=".ini", + hintText="Insert a filename for the segmentation recipe:", + allowEmpty=False, + parent=self, + existingNames=existingNames, + ) + win.exec_() + if win.cancel: + return + + ini_filename = win.filename + os.makedirs(segm_recipes_path, exist_ok=True) + os.makedirs(segm_recipes_path_model, exist_ok=True) + ini_filepath = os.path.join(segm_recipes_path_model, ini_filename) + + configPars = self.getConfigPars(create_new=True) + + if hasattr(self, "reduceMemUsageToggle"): + configPars[f"{self.model_name}.additional_segm_params"] = {} + reduceMemoryUsage = self.reduceMemUsageToggle.isChecked() + option = self.reduceMemUsageToggle.label + configPars[f"{self.model_name}.additional_segm_params"][option] = str( + reduceMemoryUsage + ) + + configPars["info"] = {} + configPars["info"]["created_on"] = datetime.datetime.now().strftime( + r"%Y/%m/%d %H:%M" + ) + + with open(ini_filepath, "w") as configfile: + configPars.write(configfile) + + txt = html_utils.paragraph("Done!

Segmentation recipe saved to:") + msg = widgets.myMessageBox() + msg.information( + self, + "Segmnentation recipe saved!", + txt, + commands=(ini_filepath,), + path_to_browse=os.path.dirname(ini_filepath), + ) + + print("Done. Segmentation recipe saved to:", ini_filepath) + + def getAdditionalSegmParams(self): + additionalSegmGroupbox = QGroupBox("Additional segmentation parameters") + local_row = 0 + additionalSegmLayout = QGridLayout() + option = "Reduce memory usage" + additionalSegmLayout.addWidget( + QLabel(f"{option}: "), local_row, 0, alignment=Qt.AlignRight + ) + self.reduceMemUsageToggle = widgets.Toggle() + additionalSegmLayout.addWidget( + self.reduceMemUsageToggle, local_row, 1, 1, 2, alignment=Qt.AlignCenter + ) + self.reduceMemUsageToggle.label = option + reduceMemUsageInfoButton = widgets.infoPushButton() + additionalSegmLayout.addWidget(reduceMemUsageInfoButton, local_row, 3) + reduceMemUsageInfoButton.clicked.connect(self.showInfoReduceMemUsage) + additionalSegmLayout.setColumnStretch(0, 0) + additionalSegmLayout.setColumnStretch(1, 1) + additionalSegmLayout.setColumnStretch(3, 0) + additionalSegmGroupbox.setLayout(additionalSegmLayout) + return additionalSegmGroupbox + + def showInfoReduceMemUsage(self): + infoText = html_utils.paragraph(f""" + If you are experiencing memory issues, you can try reducing the + memory usage by toggling this option.

+ This will reduce the memory usage by segmenting timelapse data + frame-by-frame instead of all frames at once. + """) + msg = widgets.myMessageBox(wrapText=False) + msg.information(self, "Reduce memory usage", infoText) + + def loadPreprocRecipe(self, configPars=None): + if self.configPars is None and configPars is None: + return + + if configPars is None: + configPars = self.configPars + + preprocConfigPars = {} + for section in configPars.sections(): + if not section.startswith(f"{self.model_name}.preprocess"): + continue + + preprocConfigPars[section] = configPars[section] + + if not preprocConfigPars: + return + + self.preProcessParamsWidget.loadRecipe(preprocConfigPars) + + def connectCustomSignals(self, model_module): + if model_module is None: + return + + if not hasattr(model_module, "CustomSignals"): + return + + customSignals = model_module.CustomSignals() + for slot_info in customSignals.slots_info: + group = slot_info["group"] + widget_name = slot_info["widget_name"] + if group == "init": + ArgsWidgets_list = self.init_argsWidgets + else: + ArgsWidgets_list = self.argsWidgets + for argwidget in ArgsWidgets_list: + if argwidget.name == widget_name: + signal = getattr(argwidget.widget, slot_info["signal"]) + signal.connect(partial(slot_info["slot"], self)) + break + + def selectedFeaturesRange(self): + if self.postProcessGroupbox is None: + return {} + return self.postProcessGroupbox.selectedFeaturesRange() + + def groupedFeatures(self): + if self.postProcessGroupbox is None: + return {} + return self.postProcessGroupbox.groupedFeatures() + + def setChannelNames(self, chNames): + if not hasattr(self, "channelsCombobox"): + return + + items = ["None"] + items.extend(chNames) + self.channelsCombobox.addItems(items) + + def getValueFromMetadata(self, name): + try: + value = self.df_metadata.at[name, "values"] + except Exception as e: + # traceback.print_exc() + value = None + return value + + def criticalSegmFileRequiredButNoneAvailable(self): + model_name = f"{self.model_name} model" + action_txt = ( + f"Please, segment the correct channel before using {self.model_name}." + ) + if self.model_name == "skip_segmentation": + model_name = "Skipping the segmentation" + action_txt = ( + "To be able to skip the segmentation step, you need " + "create at least one segmentation file." + ) + txt = html_utils.paragraph(f""" + {model_name} + requires an additional segmentation file + but there are none available!

+ {action_txt} +

Thank you for you patience! + """) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "Segmentation file required", txt) + raise FileNotFoundError( + "Model requires segmentation file but none are available." + ) + + def checkAddSegmEndnameCombobox(self, ArgSpec, groupBoxLayout, row): + if ArgSpec.name != "Auxiliary segmentation file": + return False + + if self.segmFileEndnames is None or not self.segmFileEndnames: + self.criticalSegmFileRequiredButNoneAvailable() + + label = QLabel(f"{ArgSpec.name}: ") + groupBoxLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) + items = self.segmFileEndnames + self.segmEndnameCombobox = widgets.QCenteredComboBox() + self.segmEndnameCombobox.addItems(items) + groupBoxLayout.addWidget(self.segmEndnameCombobox, row, 1, 1, 2) + return True + + def createGroupParams(self, ArgSpecs_list, groupName, addChannelSelector=False): + ArgsWidgets_list = [] + groupBox = QGroupBox(groupName) + groupBoxLayout = QGridLayout() + + start_row = 0 + if self.is_tracker and self.channels is not None and addChannelSelector: + label = QLabel(f"Input image: ") + groupBoxLayout.addWidget(label, start_row, 0, alignment=Qt.AlignRight) + items = ["None", *self.channels] + self.channelCombobox = widgets.QCenteredComboBox() + self.channelCombobox.addItems(items) + groupBoxLayout.addWidget(self.channelCombobox, start_row, 1, 1, 2) + if self.currentChannelName is not None: + self.channelCombobox.setCurrentText(self.currentChannelName) + infoText = ( + "Some trackers require the intensity image as input.

" + "If this one does not require it, leave the selected value " + "to `None`." + ) + infoButton = self.getInfoButton("Input image", infoText) + groupBoxLayout.addWidget(infoButton, start_row, 3) + start_row += 1 + + addSecondChannelSelector = addChannelSelector + if len(ArgSpecs_list) > 0: + if addSecondChannelSelector and ArgSpecs_list[0].docstring is not None: + isSingleChannel = ( + ArgSpecs_list[0].docstring.lower().find("single channel only") != -1 + ) + if isSingleChannel: + addSecondChannelSelector = False + + isDualChannelModel = self.model_name.find("cellpose") != -1 or any( + [_types.is_second_channel_type(ArgSpec.type) for ArgSpec in ArgSpecs_list] + ) + askSecondChannel = isDualChannelModel and addSecondChannelSelector + + if askSecondChannel: + label = QLabel("Second channel (optional): ") + groupBoxLayout.addWidget(label, start_row, 0, alignment=Qt.AlignRight) + self.channelsCombobox = widgets.QCenteredComboBox() + groupBoxLayout.addWidget(self.channelsCombobox, start_row, 1, 1, 2) + infoText = ( + "Some models can merge two channels (e.g., cyto + " + "nucleus) to obtain better perfomance.\n\n" + "Select a channel as additional input to the model." + ) + infoButton = self.getInfoButton("Second channel", infoText) + groupBoxLayout.addWidget(infoButton, start_row, 3) + start_row += 1 + + exclusive_withs = dict() + default_exclusives = dict() + row_mapper = dict() + for row, ArgSpec in enumerate(ArgSpecs_list): + if _types.is_second_channel_type(ArgSpec.type): + continue + + if _types.is_widget_not_required(ArgSpec): + continue + + row = row + start_row + skip = self.checkAddSegmEndnameCombobox(ArgSpec, groupBoxLayout, row) + if skip: + continue + + arg_name = ArgSpec.name + var_name = arg_name.replace("_", " ") + var_name = f"{var_name[0].upper()}{var_name[1:]}" + label = QLabel(f"{var_name}: ") + metadata_val = self.getValueFromMetadata(ArgSpec.name) + groupBoxLayout.addWidget(label, row, 0, alignment=Qt.AlignRight) + try: + values = ArgSpec.type().values + isCustomListType = True + except Exception as err: + isCustomListType = False + + isVectorEntry = False + try: + if isinstance(ArgSpec.type(), _types.Vector): + isVectorEntry = True + except Exception as err: + pass + + isFolderPath = False + try: + if isinstance(ArgSpec.type(), _types.FolderPath): + isFolderPath = True + except Exception as err: + pass + + try: + exclusive_with = ArgSpec.type().is_exclusive_with + except Exception as err: + exclusive_with = [] + + try: + default_exclusive = ArgSpec.type().default_exclusive + except Exception as err: + default_exclusive = "" + + exclusive_withs[arg_name] = exclusive_with + default_exclusives[arg_name] = default_exclusive + row_mapper[arg_name] = row + + isCustomWidget = hasattr(ArgSpec.type, "isWidget") + + if isCustomWidget: + widget = ArgSpec.type().widget + defaultVal = ArgSpec.default + valueSetter = widget.setValue + valueGetter = widget.value + changeSig = widget.sigValueChanged + groupBoxLayout.addWidget(widget, row, 1, 1, 2) + elif isVectorEntry: + vectorLineEdit = widgets.VectorLineEdit() + vectorLineEdit.setValue(ArgSpec.default) + defaultVal = ArgSpec.default + valueSetter = widgets.VectorLineEdit.setValue + valueGetter = widgets.VectorLineEdit.value + changeSig = vectorLineEdit.valueChanged + widget = vectorLineEdit + groupBoxLayout.addWidget(vectorLineEdit, row, 1, 1, 2) + elif isFolderPath: + folderPathControl = widgets.FolderPathControl() + folderPathControl.setText(str(ArgSpec.default)) + widget = folderPathControl + defaultVal = str(ArgSpec.default) + valueSetter = widgets.FolderPathControl.setText + valueGetter = widgets.FolderPathControl.path + changeSig = widget.sigValueChanged + groupBoxLayout.addWidget(folderPathControl, row, 1, 1, 2) + elif ArgSpec.type == bool: + booleanGroup = QButtonGroup() + booleanGroup.setExclusive(True) + checkBox = widgets.Toggle() + checkBox.setChecked(ArgSpec.default) + defaultVal = ArgSpec.default + valueSetter = widgets.Toggle.setChecked + valueGetter = widgets.Toggle.isChecked + changeSig = checkBox.toggled + widget = checkBox + groupBoxLayout.addWidget( + checkBox, row, 1, 1, 2, alignment=Qt.AlignCenter + ) + elif ArgSpec.type == int: + spinBox = widgets.SpinBox() + if metadata_val is None: + spinBox.setValue(ArgSpec.default) + else: + spinBox.setValue(int(metadata_val)) + spinBox.isMetadataValue = True + defaultVal = ArgSpec.default + valueSetter = QSpinBox.setValue + valueGetter = QSpinBox.value + changeSig = spinBox.sigValueChanged + widget = spinBox + groupBoxLayout.addWidget(spinBox, row, 1, 1, 2) + elif ArgSpec.type == float: + doubleSpinBox = widgets.FloatLineEdit() + if metadata_val is None: + doubleSpinBox.setValue(ArgSpec.default) + else: + doubleSpinBox.setValue(float(metadata_val)) + doubleSpinBox.isMetadataValue = True + widget = doubleSpinBox + defaultVal = ArgSpec.default + valueSetter = widgets.FloatLineEdit.setValue + valueGetter = widgets.FloatLineEdit.value + changeSig = doubleSpinBox.valueChanged + groupBoxLayout.addWidget(doubleSpinBox, row, 1, 1, 2) + elif ArgSpec.type == os.PathLike: + filePathControl = widgets.filePathControl() + filePathControl.setText(str(ArgSpec.default)) + widget = filePathControl + defaultVal = str(ArgSpec.default) + valueSetter = widgets.filePathControl.setText + valueGetter = widgets.filePathControl.path + changeSig = filePathControl.sigValueChanged + groupBoxLayout.addWidget(filePathControl, row, 1, 1, 2) + elif isCustomListType: + items = ArgSpec.type().values + defaultVal = str(ArgSpec.default) + combobox = widgets.AlphaNumericComboBox() + combobox.addItems(items) + combobox.setCurrentValue(defaultVal) + valueSetter = widgets.AlphaNumericComboBox.setCurrentValue + valueGetter = widgets.AlphaNumericComboBox.currentValue + changeSig = combobox.currentTextChanged + widget = combobox + groupBoxLayout.addWidget(combobox, row, 1, 1, 2) + else: + lineEdit = QLineEdit() + lineEdit.setText(str(ArgSpec.default)) + lineEdit.setAlignment(Qt.AlignCenter) + widget = lineEdit + defaultVal = str(ArgSpec.default) + valueSetter = QLineEdit.setText + valueGetter = QLineEdit.text + changeSig = lineEdit.editingFinished + groupBoxLayout.addWidget(lineEdit, row, 1, 1, 2) + + if ArgSpec.desc: + infoButton = self.getInfoButton(ArgSpec.name, ArgSpec.desc) + groupBoxLayout.addWidget(infoButton, row, 3) + + argsInfo = ArgWidget( + name=ArgSpec.name, + type=ArgSpec.type, + widget=widget, + defaultVal=defaultVal, + valueSetter=valueSetter, + valueGetter=valueGetter, + changeSig=changeSig, + ) + ArgsWidgets_list.append(argsInfo) + + exclusive_group = core.connected_components_in_undirected_graph(exclusive_withs) + + for group in exclusive_group: + if len(group) == 1: + continue + for arg_name in group: + default_exclusive = default_exclusives[arg_name] + row = row_mapper[arg_name] + + argsInfo = ArgsWidgets_list[row] + valueSetter = argsInfo.valueSetter + widget = argsInfo.widget + valueGetter = argsInfo.valueGetter + + argsInfo.valueGetter = qutils.replace_certain_vals( + argsInfo.valueGetter, default_exclusive, None + ) + + for arg_name_other in group: + if arg_name == arg_name_other: + continue + row_other = row_mapper[arg_name_other] + argsInfo_other = ArgsWidgets_list[row_other] + changeSig_other = argsInfo_other.changeSig + changeSig_other.connect( + partial( + qutils.set_exclusive_valueSetter, + widget, + valueSetter, + default_exclusive, + ) + ) + + groupBoxLayout.setColumnStretch(0, 0) + groupBoxLayout.setColumnStretch(1, 1) + groupBoxLayout.setColumnStretch(3, 0) + nrows = groupBoxLayout.rowCount() + groupBoxLayout.setRowStretch(nrows, 1) + + groupBox.setLayout(groupBoxLayout) + return groupBox, ArgsWidgets_list + + def getInfoButton(self, param_name, infoText): + infoButton = widgets.infoPushButton() + infoButton.param_name = param_name + infoButton.setToolTip( + f"Click to get more info about `{param_name}` parameter..." + ) + infoButton.infoText = infoText + infoButton.clicked.connect(self.showInfoParam) + return infoButton + + def showInfoParam(self): + text = self.sender().infoText + text = text.replace("\n", "
") + text = html_utils.rst_urls_to_html(text) + text = html_utils.rst_to_html(text) + text = html_utils.paragraph(text) + param_name = self.sender().param_name + msg = widgets.myMessageBox(wrapText=False) + msg.information(self, f"Info about `{param_name}` parameter", text) + + def restoreDefaultInit(self): + for argWidget in self.init_argsWidgets: + defaultVal = argWidget.defaultVal + widget = argWidget.widget + valueSetter = argWidget.valueSetter + qutils.set_exclusive_valueSetter(widget, valueSetter, defaultVal) + + def restoreDefaultSegment(self): + for argWidget in self.argsWidgets: + defaultVal = argWidget.defaultVal + widget = argWidget.widget + valueSetter = argWidget.valueSetter + qutils.set_exclusive_valueSetter(widget, valueSetter, defaultVal) + + def restoreDefaultExtra(self): + for argWidget in self.extraArgsWidgets: + defaultVal = argWidget.defaultVal + widget = argWidget.widget + valueSetter = argWidget.valueSetter + qutils.set_exclusive_valueSetter(widget, valueSetter, defaultVal) + + def restoreDefaultPostprocess(self): + self.postProcessGroupbox.restoreDefault() + + def readLastSelection(self): + self.ini_path = os.path.join(settings_folderpath, self.ini_filename) + + if not os.path.exists(self.ini_path): + return None + + print(f"Reading last selected parameters from: {self.ini_path}") + configPars = config.ConfigParser() + configPars.read(self.ini_path) + return configPars + + def setValuesFromParams(self, init_params, segment_params, extra_params=None): + sections = { + f"{self.model_name}.init": (init_params, self.init_argsWidgets), + f"{self.model_name}.segment": (segment_params, self.argsWidgets), + } + if extra_params is not None: + sections[f"{self.model_name}.extra"] = (extra_params, self.extraArgsWidgets) + + for section, values in sections.items(): + params, argWidgetList = values + for argWidget in argWidgetList: + val = params.get(argWidget.name) + widget = argWidget.widget + if val is None: + continue + casters = [lambda x: x, int, float, str, bool] + for caster in casters: + try: + argWidget.valueSetter(widget, caster(val)) + break + except Exception as e: + continue + + def loadLastSelection(self, section, argWidgetList, checked=False, configPars=None): + if self.configPars is None and configPars is None: + return + + if configPars is None: + configPars = self.configPars + + getters = ["getboolean", "getint", "getfloat", "get"] + try: + options = configPars.options(section) + except Exception: + return + + for argWidget in argWidgetList: + option = argWidget.name + val = None + for getter in getters: + try: + val = getattr(configPars, getter)(section, option) + break + except Exception as err: + pass + widget = argWidget.widget + + if hasattr(widget, "isMetadataValue"): + continue + if val is None: + continue + + casters = [lambda x: x, int, float, str, bool] + for caster in casters: + try: + val = caster(val) + valueSetter = argWidget.valueSetter + qutils.set_exclusive_valueSetter(widget, valueSetter, val) + break + except Exception as e: + printl(traceback.format_exc()) + continue + + def loadLastSelectionPostProcess(self, checked=False, configPars=None): + if self.postProcessGroupbox is None: + return + + postProcessSection = f"{self.model_name}.postprocess" + + if isinstance(configPars, bool): + configPars = None + + if configPars is None: + configPars = self.configPars + + if postProcessSection in configPars.sections(): + try: + minSize = configPars.getint(postProcessSection, "minSize", fallback=10) + except ValueError: + minSize = 10 + + try: + minSolidity = configPars.getfloat( + postProcessSection, "minSolidity", fallback=0.5 + ) + except ValueError: + minSolidity = 0.5 + + try: + maxElongation = configPars.getfloat( + postProcessSection, "maxElongation", fallback=3 + ) + except ValueError: + maxElongation = 3 + + try: + minObjSizeZ = configPars.getint( + postProcessSection, "min_obj_no_zslices", fallback=3 + ) + except ValueError: + minObjSizeZ = 3 + + kwargs = { + "min_solidity": minSolidity, + "min_area": minSize, + "max_elongation": maxElongation, + "min_obj_no_zslices": minObjSizeZ, + } + self.postProcessGroupbox.restoreFromKwargs(kwargs) + + applyPostProcessing = configPars.getboolean( + postProcessSection, "applyPostProcessing" + ) + self.postProcessGroupbox.setChecked(applyPostProcessing) + + customPostProcessSection = f"{self.model_name}.custom_postprocess" + if postProcessSection not in configPars.sections(): + return + + selectFeaturesWidget = self.postProcessGroupbox.selectedFeaturesDialog.groupbox + selectFeaturesWidget.resetFields() + f = 0 + for col_name, value in configPars[customPostProcessSection].items(): + low, high = value.split(",") + low = low.strip() + high = high.strip() + if f > 0: + selectFeaturesWidget.addFeatureField() + + selector = selectFeaturesWidget.selectors[f] + selector.selectButton.setText(col_name) + selector.selectButton.setFlat(True) + + feature_group = measurements.get_metric_group_name(col_name) + selector.featureGroup = feature_group + + if low != "None": + try: + low_val = int(low) + except ValueError: + low_val = float(low) + + selector.lowRangeWidgets.checkbox.setChecked(True) + selector.lowRangeWidgets.spinbox.setValue(low_val) + + if high != "None": + try: + high_val = int(high) + except ValueError: + high_val = float(high) + + selector.highRangeWidgets.checkbox.setChecked(True) + selector.highRangeWidgets.spinbox.setValue(high_val) + + f += 1 + + def createSeeHereLabel(self, url): + htmlTxt = f'here' + seeHereLabel = QLabel() + seeHereLabel.setText(f""" +

+ See {htmlTxt} for details on the parameters +

+ """) + seeHereLabel.setTextFormat(Qt.RichText) + seeHereLabel.setTextInteractionFlags(Qt.TextBrowserInteraction) + seeHereLabel.setOpenExternalLinks(True) + seeHereLabel.setStyleSheet("padding:12px 0px 0px 0px;") + return seeHereLabel + + def argsWidgets_to_kwargs(self, argsWidgets): + kwargs_dict = { + argWidget.name: argWidget.valueGetter(argWidget.widget) + for argWidget in argsWidgets + } + return kwargs_dict + + def getInitKwargs(self): + init_kwargs = self.argsWidgets_to_kwargs(self.init_argsWidgets) + if hasattr(self, "segmEndnameCombobox"): + init_kwargs["segm_endname"] = self.segmEndnameCombobox.currentText() + + return init_kwargs + + def getModelKwargs(self): + if self.skipSegmentation: + return {} + + return self.argsWidgets_to_kwargs(self.argsWidgets) + + def getExtraKwargs(self): + if self.extraArgsWidgets is None: + return {} + + return self.argsWidgets_to_kwargs(self.extraArgsWidgets) + + def ok_cb(self, checked): + self.cancel = False + self.preproc_recipe = None + if self.preProcessParamsWidget is not None: + self.preproc_recipe = self.preProcessParamsWidget.recipe() + if self.preproc_recipe is None: + return + + self.init_kwargs = self.getInitKwargs() + + if self.extraArgsWidgets: + self.extra_kwargs = self.getExtraKwargs() + + self.model_kwargs = self.getModelKwargs() + self.segment_kwargs = self.model_kwargs + + if self.postProcessGroupbox is not None: + self.applyPostProcessing = self.postProcessGroupbox.isChecked() + self.standardPostProcessKwargs = self.postProcessGroupbox.kwargs() + self.secondChannelName = None + if hasattr(self, "channelsCombobox"): + self.secondChannelName = self.channelsCombobox.currentText() + if self.secondChannelName == "None": + self.secondChannelName = None + self.inputChannelName = "None" + if self.channelCombobox is not None: + self.inputChannelName = self.channelCombobox.currentText() + + self.reduceMemoryUsage = False + if hasattr(self, "reduceMemUsageToggle"): + self.reduceMemoryUsage = self.reduceMemUsageToggle.isChecked() + self.customPostProcessFeatures = self.selectedFeaturesRange() + self.customPostProcessGroupedFeatures = self.groupedFeatures() + self.saveLastSelection() + self.freePosData() + self.close() + + def freePosData(self): + if hasattr(self, "postProcessGroupbox"): + try: + for ( + selector + ) in self.postProcessGroupbox.selectedFeaturesDialog.groupbox.selectors: + qutils.hardDelete(selector) + except AttributeError: + pass + try: + qutils.hardDelete( + self.postProcessGroupbox.selectedFeaturesDialog.groupbox + ) + except AttributeError: + pass + try: + qutils.hardDelete(self.postProcessGroupbox.selectedFeaturesDialog) + except AttributeError: + pass + try: + qutils.hardDelete(self.postProcessGroupbox) + except AttributeError: + pass + + def getConfigPars(self, create_new=False): + if self.configPars is None or create_new: + configPars = config.ConfigParser() + else: + configPars = self.configPars + + if self.preProcessParamsWidget is not None: + preprocCp = self.preProcessParamsWidget.recipeConfigPars(self.model_name) + for section in preprocCp.sections(): + configPars[section] = preprocCp[section] + + configPars[f"{self.model_name}.init"] = {} + configPars[f"{self.model_name}.segment"] = {} + configPars[f"{self.model_name}.extra"] = {} + + init_kwargs = self.getInitKwargs() + model_kwargs = self.getModelKwargs() + + for key, val in init_kwargs.items(): + configPars[f"{self.model_name}.init"][key] = str(val) + for key, val in model_kwargs.items(): + configPars[f"{self.model_name}.segment"][key] = str(val) + if self.extraArgsWidgets: + extra_kwargs = self.getExtraKwargs() + for key, val in extra_kwargs.items(): + configPars[f"{self.model_name}.extra"][key] = str(val) + + configPars[f"{self.model_name}.postprocess"] = {} + if self.postProcessGroupbox is not None: + postProcKwargs = self.postProcessGroupbox.kwargs() + postProcessConfig = configPars[f"{self.model_name}.postprocess"] + postProcessConfig["minSize"] = str(postProcKwargs["min_area"]) + postProcessConfig["minSolidity"] = str(postProcKwargs["min_solidity"]) + postProcessConfig["maxElongation"] = str(postProcKwargs["max_elongation"]) + postProcessConfig["min_obj_no_zslices"] = str( + postProcKwargs["min_obj_no_zslices"] + ) + postProcessConfig["applyPostProcessing"] = str( + self.postProcessGroupbox.isChecked() + ) + + custom_postproc_section = f"{self.model_name}.custom_postprocess" + configPars[custom_postproc_section] = {} + if self.postProcessGroupbox is not None: + selectFeaturesWidget = ( + self.postProcessGroupbox.selectedFeaturesDialog.groupbox + ) + for selector in selectFeaturesWidget.selectors: + col_name = selector.selectButton.text() + lowStr = "None" + highStr = "None" + if selector.lowRangeWidgets.checkbox.isChecked(): + lowVal = selector.lowRangeWidgets.spinbox.value() + lowStr = str(lowVal) + if selector.highRangeWidgets.checkbox.isChecked(): + highVal = selector.highRangeWidgets.spinbox.value() + highStr = str(highVal) + + configPars[custom_postproc_section][col_name] = f"{lowStr}, {highStr}" + + return configPars + + def saveLastSelection(self): + self.configPars = self.getConfigPars() + with open(self.ini_path, "w") as configfile: + self.configPars.write(configfile) + + mode = "Segmentation" if not self.is_tracker else "Tracking" + + print(f'{mode} parameters saved at "{self.ini_path}"') + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show() + if self.model_name == "thresholding": + self.segmentGroupBox.setDisabled(True) + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + self.freePosData() + if hasattr(self, "loop"): + self.loop.exit() + + def cancel_cb(self, checked): + self.cancel = True + self.freePosData() + + def showEvent(self, event) -> None: + buttonHeight = self.okButton.minimumSizeHint().height() + heightInitParams = self.initParamsScrollArea.minimumHeightNoScrollbar() + heightLeft = 70 + buttonHeight + heightCenter = heightInitParams + heightRight = 0 + if self.segmentParamsScrollArea is not None: + heightSegmentParams = ( + self.segmentParamsScrollArea.minimumHeightNoScrollbar() + ) + heightCenter += heightSegmentParams + 70 + buttonHeight + + rowInitParams, _ = self.paramsGroupPosMapper[self.initParamsScrollArea] + rowSegmParams, _ = self.paramsGroupPosMapper[self.segmentParamsScrollArea] + + numInitParams = len(self.init_params) + numSegmentParams = len(self.segment_params) + + try: + segmentParamsStretch = max(1, round(numSegmentParams / numInitParams)) + except ZeroDivisionError as err: + segmentParamsStretch = 1 + self.secondColLayout.setStretch(rowInitParams, 1) + self.secondColLayout.setStretch(rowSegmParams, segmentParamsStretch) + + if self.extraParamsScrollArea is not None: + heightRight += ( + self.extraParamsScrollArea.minimumHeightNoScrollbar() + + 70 + + buttonHeight + ) + + if self.additionalSegmGroupbox is not None: + heightRight += self.additionalSegmGroupbox.minimumSizeHint().height() + heightRight += buttonHeight + if self.preProcessParamsWidget is not None: + heightPreprocParams = self.preProcessParamsWidget.minimumSizeHint().height() + heightLeft += heightPreprocParams + heightLeft += buttonHeight + if self.postProcessGroupbox is not None: + heightRight += self.postProcessGroupbox.minimumSizeHint().height() + heightRight += buttonHeight + if self.seeHereLabel is not None: + heightRight += self.seeHereLabel.minimumSizeHint().height() + height = max(heightLeft, heightRight, heightCenter) + screenHeight = self.screen().size().height() + screenGeom = self.screen().geometry() + screenLeft = screenGeom.left() + screenRight = screenGeom.right() + screenCenter = (screenLeft + screenRight) / 2 + width = self.sizeHint().width() + windowLeft = int(screenCenter - width / 2) + self.move(windowLeft, 20) + + if height >= screenHeight - 150: + height = screenHeight - 150 + self.resize(width, height) + + +class downloadModel: + def __init__(self, model_name, parent=None): + self.loop = None + self.model_name = model_name + self._parent = parent + + def download(self): + model_url = utils._model_url(self.model_name) + if model_url is None: + return + + _, model_path = utils.get_model_path(self.model_name, create_temp_dir=False) + model_name = self.model_name + model_exists = utils.check_model_exists(model_path, model_name) + if not model_exists: + self.warnDownloadModel(model_path, self.model_name) + try: + self._parent.logger.info( + f'Downloading {self.model_name} model(s) to "{model_path}"' + ) + except Exception as err: + pass + + success = utils.download_model(self.model_name) + if not success: + self.criticalDowloadFailed() + + def warnDownloadModel(self, model_path, model_name): + txt = html_utils.paragraph( + "Cell-ACDC needs to download the model " + f"{model_name}.

" + "The files will be dowloaded into the following folder:

" + f"{model_path}

" + "Progress will be displayed in the terminal.
" + ) + msg = widgets.myMessageBox() + msg.information(self._parent, "Download model", txt) + + def criticalDowloadFailed(self): + import cellacdc + + model_name = self.model_name + m = model_name.lower() + weights_filenames = getattr(cellacdc, f"{m}_weights_filenames") + url, alternative_url = utils._model_url(model_name, return_alternative=True) + url_href = f'this link' + alternative_url_href = f'this link' + _, model_path = utils.get_model_path(model_name, create_temp_dir=False) + txt = html_utils.paragraph(f""" + Automatic download of {model_name} failed.

+ Please, manually download the model weights from {url_href} or + {alternative_url_href}.

+ Next, unzip the content (or move the files if not a zip archive) + of the downloaded file into the following folder:

+ {model_path}

+ NOTE: if clicking on the link above does not work + copy one of the links below and paste it into the browser

+ {url} +

+ {alternative_url} + """) + weights_paths = [os.path.join(model_path, f) for f in weights_filenames] + weights = "\n\n".join(weights_paths) + detailsText = f"Files that {model_name} requires:\n\n{weights}" + msg = widgets.myMessageBox() + msg.critical( + self._parent, + f"Download of {model_name} failed", + txt, + detailsText=detailsText, + ) + self.close_() + + def close_(self): + return + + +class SelectAcdcDfVersionToRestore(QBaseDialog): + def __init__(self, posData, parent=None): + super().__init__(parent=parent) + + self.cancel = True + + self.setWindowTitle("Select annotations table to restore") + + mainLayout = QVBoxLayout() + + acdc_df_filename = os.path.basename(posData.acdc_output_csv_path) + instructionsLabel = html_utils.paragraph( + f"Select an older version of the {acdc_df_filename} " + "annotations table to load.

" + "The datetime refers to the time you replaced the old version with " + "a newer one.

" + ) + mainLayout.addWidget(QLabel(instructionsLabel)) + + self.savedListBox = None + if os.path.exists(posData.acdc_output_backup_zip_path): + zip_path = posData.acdc_output_backup_zip_path + self.savedArchivefilepath = zip_path + with zipfile.ZipFile(zip_path, mode="r") as zip: + csv_names = natsorted(zip.namelist(), reverse=True) + + keys = [csv_name[:-4] for csv_name in csv_names] + + self.savedKeys = keys + f = load.ISO_TIMESTAMP_FORMAT + timestamps = [datetime.datetime.strptime(key, f) for key in keys] + items = [date.strftime(r"%d %b %Y, %H:%M:%S") for date in timestamps] + mainLayout.addWidget(QLabel("Saved annotations:")) + self.savedListBox = widgets.listWidget() + self.savedListBox.addItems(items) + mainLayout.addWidget(self.savedListBox) + self.savedListBox.itemSelectionChanged.connect(self.onItemSelectionChanged) + + recovery_folderpath = posData.recoveryFolderpath() + unsaved_recovery_folderpath = os.path.join(recovery_folderpath, "never_saved") + self.neverSavedFolderpath = unsaved_recovery_folderpath + files = utils.listdir(unsaved_recovery_folderpath) + csv_files = [file for file in files if file.endswith(".csv")] + self.neverSavedListBox = None + if csv_files: + csv_names = natsorted(csv_files, reverse=True) + keys = [csv_name[:-4] for csv_name in csv_names] + self.neverSavedKeys = keys + f = load.ISO_TIMESTAMP_FORMAT + timestamps = [datetime.datetime.strptime(key, f) for key in keys] + items = [date.strftime(r"%d %b %Y, %H:%M:%S") for date in timestamps] + mainLayout.addWidget(QLabel("Never saved annotations:")) + self.neverSavedListBox = widgets.listWidget() + self.neverSavedListBox.addItems(items) + mainLayout.addWidget(self.neverSavedListBox) + self.neverSavedListBox.itemSelectionChanged.connect( + self.onItemSelectionChanged + ) + + cancelOkLayout = widgets.CancelOkButtonsLayout() + + cancelOkLayout.okButton.clicked.connect(self.ok_cb) + cancelOkLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addSpacing(20) + mainLayout.addLayout(cancelOkLayout) + + self.setLayout(mainLayout) + + self.setFont(font) + + def ok_cb(self): + self.cancel = False + try: + for i in range(self.savedListBox.count()): + item = self.savedListBox.item(i) + if item.isSelected(): + self.selectedTimestamp = item.text() + self.selectedKey = self.savedKeys[i] + self.archiveFilePath = self.savedArchivefilepath + break + except Exception as e: + pass + + try: + for i in range(self.neverSavedListBox.count()): + item = self.neverSavedListBox.item(i) + if item.isSelected(): + self.selectedTimestamp = item.text() + self.selectedKey = self.neverSavedKeys[i] + self.archiveFilePath = self.neverSavedFolderpath + break + except Exception as e: + pass + self.close() + + def onItemSelectionChanged(self): + otherListBox = ( + self.savedListBox + if self.sender() == self.neverSavedListBox + else self.neverSavedListBox + ) + if otherListBox is None: + return + for i in range(otherListBox.count()): + item = otherListBox.item(i) + item.setSelected(False) + + +class ChangeUserProfileFolderPathDialog(QBaseDialog): + def __init__(self, posData, parent=None): + super().__init__(parent=parent) + + self.cancel = True + + self.setWindowTitle("Change user profile folder path") + + mainLayout = QVBoxLayout() + + acdc_folders = load.get_all_acdc_folders(user_profile_path) + acdc_folders_format = [f" - {folder}" for folder in acdc_folders] + acdc_folders_format = "
".join(acdc_folders_format) + + txt = f""" + Current user profile path:

+ {user_profile_path}

+ The user profile contains the following Cell-ACDC folders:

+ {acdc_folders_format}

+ After clicking "Ok" you will be asked to select the folder where + you want to migrate the user profile data. + """ + + txt = html_utils.paragraph(txt) + label = QLabel(txt) + + mainLayout.addWidget(label) + + buttonsLayout = widgets.CancelOkButtonsLayout() + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + mainLayout.addStretch() + + self.setLayout(mainLayout) + + def ok_cb(self): + self.cancel = False + self.close() + + +class QInput(QBaseDialog): + def __init__(self, parent=None, title="Input"): + self.cancel = True + self.allowEmpty = True + + super().__init__(parent) + + self.setWindowTitle(title) + + self.mainLayout = QVBoxLayout() + + self.infoLabel = QLabel() + self.mainLayout.addWidget(self.infoLabel) + + promptLayout = QHBoxLayout() + self.promptLabel = QLabel() + promptLayout.addWidget(self.promptLabel) + self.lineEdit = QLineEdit() + promptLayout.addWidget(self.lineEdit) + + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + self.mainLayout.addLayout(promptLayout) + self.mainLayout.addSpacing(20) + self.mainLayout.addLayout(buttonsLayout) + + self.buttonsLayout = buttonsLayout + + self.setFont(font) + self.setLayout(self.mainLayout) + + def askText(self, prompt, infoText="", allowEmpty=False): + self.allowEmpty = allowEmpty + if infoText: + infoText = f"{infoText}
" + self.infoLabel.setText(html_utils.paragraph(infoText)) + self.promptLabel.setText(prompt) + self.exec_(resizeWidthFactor=1.5) + + def ok_cb(self): + self.answer = self.lineEdit.text() + if not self.allowEmpty and not self.answer: + msg = widgets.myMessageBox(showCentered=False) + msg.critical(self, "Empty", "Entry cannot be empty.") + return + self.cancel = False + self.close() + + +class InstallPyTorchDialog(QBaseDialog): + def __init__(self, parent=None, caller_name="Cell-ACDC"): + super().__init__(parent=parent) + + self.cancel = True + + mainLayout = QVBoxLayout() + + innerLayout = QGridLayout() + + iconLabel = QLabel(self) + standardIcon = getattr(QStyle, "SP_MessageBoxInformation") + icon = self.style().standardIcon(standardIcon) + pixmap = icon.pixmap(60, 60) + iconLabel.setPixmap(pixmap) + innerLayout.addWidget(iconLabel, 0, 0, alignment=Qt.AlignTop) + + href = html_utils.href_tag("How to install PyTorch", urls.install_pytorch) + important = html_utils.to_admonition( + """ + Should you choose to install PyTorch yourself, make sure to + activate
+ the correct acdc environment first
. + """, + admonition_type="important", + ) + + infoText = html_utils.paragraph(f""" + {caller_name} needs to install the package PyTorch.

+ Select your preferences and click ok to install it now. + You will have to confirm the installation in the terminal.

+ Alternatively, you can close {caller_name} and run the command + yourself.

+ For more details see this guide: {href}
+ {important} + """) + innerLayout.addWidget(QLabel(infoText), 0, 1) + innerLayout.addItem(QSpacerItem(10, 10), 1, 1) + + preferencesLayout = QGridLayout() + + row = 0 + self.osCombobox = QComboBox() + self.osCombobox.addItems(["Linux", "Mac", "Windows"]) + preferencesLayout.addWidget(QLabel("Your OS"), row, 0) + preferencesLayout.addWidget(self.osCombobox, row, 1) + + if is_mac: + self.osCombobox.setCurrentText("Mac") + elif is_win: + self.osCombobox.setCurrentText("Windows") + + row += 1 + self.pkgManagerCombobox = QComboBox() + self.pkgManagerCombobox.addItems(["Pip"]) + if not is_conda_env(): + self.pkgManagerCombobox.setCurrentText("Pip") + self.pkgManagerCombobox.setDisabled(True) + + preferencesLayout.addWidget(QLabel("Package manager"), row, 0) + preferencesLayout.addWidget(self.pkgManagerCombobox, row, 1) + + row += 1 + self.cmptPlatformCombobox = QComboBox() + self.cmptPlatformCombobox.addItems( + ["CPU", "CUDA 11.8 (NVIDIA GPU)", "CUDA 12.1 (NVIDIA GPU)"] + ) + + preferencesLayout.addWidget(QLabel("Compute Platform"), row, 0) + preferencesLayout.addWidget(self.cmptPlatformCombobox, row, 1) + + row += 1 + pip_prefix, conda_prefix = utils.get_pip_conda_prefix() + self.commandWidget = widgets.CopiableCommandWidget( + command=f"{pip_prefix} torch" + ) + preferencesLayout.addWidget(QLabel("Run this command: "), row, 0) + preferencesLayout.addWidget(self.commandWidget, row, 1, 1, 2) + preferencesLayout.setColumnStretch(0, 0) + preferencesLayout.setColumnStretch(1, 0) + preferencesLayout.setColumnStretch(2, 1) + + innerLayout.addLayout(preferencesLayout, 2, 1) + + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addLayout(innerLayout) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + self.osCombobox.currentTextChanged.connect(self.updateCommand) + self.pkgManagerCombobox.currentTextChanged.connect(self.updateCommand) + self.cmptPlatformCombobox.currentTextChanged.connect(self.updateCommand) + + self.updateCommand() + + def updateCommand(self, *args, **kwargs): + osText = self.osCombobox.currentText() + pkgManager = self.pkgManagerCombobox.currentText() + cmptPlatform = self.cmptPlatformCombobox.currentText() + command = utils.get_pytorch_command()[osText][pkgManager][cmptPlatform] + self.commandWidget.setCommand(command) + + def ok_cb(self): + self.command = self.commandWidget.command() + self.cancel = False + self.close() + +# Sibling imports (deferred to avoid import cycles) +from ._base import ( + ArgWidget, +) +from .general import ( + QTreeDialog, +) +from .metadata import ( + filenameDialog, +) +from .preprocess import ( + PostProcessSegmParams, + PreProcessParamsWidget, +) + diff --git a/cellacdc/dialogs/preprocess.py b/cellacdc/dialogs/preprocess.py new file mode 100644 index 000000000..742a0ebfd --- /dev/null +++ b/cellacdc/dialogs/preprocess.py @@ -0,0 +1,4000 @@ +"""Cell-ACDC dialog windows: preprocess.""" + +import os +import sys +import re +from typing import Literal, Callable, Dict, Iterable, List, Tuple +import datetime +import pathlib +from collections import defaultdict +import zipfile +from heapq import nlargest +import matplotlib +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +from matplotlib.patches import Rectangle, Circle, PathPatch, Path +import numpy as np +import scipy.interpolate + +try: + import tkinter as tk +except Exception as err: + pass + +import cv2 +import traceback +from itertools import combinations, permutations +from collections import namedtuple +from natsort import natsorted + +# from MyWidgets import Slider, Button, MyRadioButtons +from skimage.measure import label, regionprops +from functools import partial +import skimage.filters +import skimage.measure +import skimage.morphology +import skimage.exposure +import skimage.draw +import skimage.registration +import skimage.color +import skimage.segmentation +from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk +import matplotlib.pyplot as plt +import seaborn as sns +import pandas as pd +import math +import time +import sympy as sp +import json +import html + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from qtpy import QtCore +from qtpy.QtGui import ( + QIcon, + QFontMetrics, + QKeySequence, + QFont, + QRegularExpressionValidator, + QCursor, + QKeyEvent, + QPixmap, + QFont, + QPalette, + QMouseEvent, + QColor, +) +from qtpy.QtCore import ( + Qt, + QSize, + QEvent, + Signal, + QEventLoop, + QTimer, + QRegularExpression, +) +from qtpy.QtWidgets import ( + QFileDialog, + QApplication, + QMainWindow, + QMenu, + QLabel, + QToolBar, + QScrollBar, + QWidget, + QVBoxLayout, + QLineEdit, + QPushButton, + QHBoxLayout, + QDialog, + QFormLayout, + QListWidget, + QAbstractItemView, + QButtonGroup, + QCheckBox, + QSizePolicy, + QComboBox, + QSlider, + QGridLayout, + QSpinBox, + QToolButton, + QTableView, + QTextBrowser, + QDoubleSpinBox, + QScrollArea, + QFrame, + QProgressBar, + QGroupBox, + QRadioButton, + QDockWidget, + QMessageBox, + QStyle, + QPlainTextEdit, + QSpacerItem, + QTreeWidget, + QTreeWidgetItem, + QTextEdit, + QSplashScreen, + QAction, + QListWidgetItem, + QActionGroup, + QHeaderView, + QStyledItemDelegate, +) +import qtpy.compat + +from .. import exception_handler +from .. import load, prompts, core, measurements, html_utils +from .. import is_mac, is_win, is_linux, settings_folderpath, config +from .. import preproc_recipes_path, segm_recipes_path, combine_channels_recipes_path +from .. import is_conda_env +from .. import printl +from .. import colors +from .. import issues_url +from .. import utils +from .. import qutils +from .. import _palettes +from .. import base_cca_dict +from .. import widgets +from .. import user_profile_path, promptable_models_path, models_path +from .. import features +from .. import _core +from .. import _types +from .. import plot +from .. import urls +from ..acdc_regex import float_regex, is_alphanumeric_filename, to_alphanumeric +from .. import _base_widgets +from .. import io +from .. import cca_functions +from .. import path + +POSITIVE_FLOAT_REGEX = float_regex(allow_negative=False) +TREEWIDGET_STYLESHEET = _palettes.TreeWidgetStyleSheet() +LISTWIDGET_STYLESHEET = _palettes.ListWidgetStyleSheet() +BACKGROUND_RGBA = _palettes.get_disabled_colors()["Button"] + +font = QFont() +font.setPixelSize(12) +italicFont = QFont() +italicFont.setPixelSize(12) +italicFont.setItalic(True) + +from ._base import ( + QBaseDialog, +) + +class wandToleranceWidget(QFrame): + def __init__(self, parent=None): + super().__init__(parent) + + self.slider = widgets.sliderWithSpinBox(title="Tolerance") + self.slider.setMaximum(255) + self.slider._layout.setColumnStretch(2, 21) + + self.setLayout(self.slider.layout) + + +class QDialogAutomaticThresholding(QBaseDialog): + def __init__(self, parent=None, isSegm3D=True): + super().__init__(parent) + + self.cancel = True + + self.setWindowTitle("Automatic thresholding parameters") + + layout = QVBoxLayout() + formLayout = QGridLayout() + buttonsLayout = QHBoxLayout() + + row = 0 + self.sigmaGaussSpinbox = QDoubleSpinBox() + self.sigmaGaussSpinbox.setValue(1) + self.sigmaGaussSpinbox.setMaximum(2**31) + self.sigmaGaussSpinbox.setAlignment(Qt.AlignCenter) + formLayout.addWidget( + QLabel("Gaussian filter sigma (0 to ignore): "), + row, + 0, + alignment=Qt.AlignRight, + ) + formLayout.addWidget(self.sigmaGaussSpinbox, row, 1, 1, 2) + + row += 1 + self.threshMethodCombobox = QComboBox() + self.threshMethodCombobox.addItems( + ["Isodata", "Li", "Mean", "Minimum", "Otsu", "Triangle", "Yen"] + ) + formLayout.addWidget( + QLabel("Thresholding algorithm: "), row, 0, alignment=Qt.AlignRight + ) + formLayout.addWidget(self.threshMethodCombobox, row, 1, 1, 2) + + self.segment3Dcheckbox = None + if isSegm3D: + row += 1 + formLayout.addWidget( + QLabel("Segment 3D volume: "), row, 0, alignment=Qt.AlignRight + ) + group = QButtonGroup() + group.setExclusive(True) + self.segment3Dcheckbox = QRadioButton("Yes") + segmentSliceBySliceCheckbox = QRadioButton("No, segment slice-by-slice") + group.addButton(self.segment3Dcheckbox) + group.addButton(segmentSliceBySliceCheckbox) + formLayout.addWidget(self.segment3Dcheckbox, row, 1) + formLayout.addWidget(segmentSliceBySliceCheckbox, row, 2) + self.segment3Dcheckbox.setChecked(True) + + okButton = widgets.okPushButton("Ok") + cancelButton = widgets.cancelPushButton("Cancel") + helpButton = widgets.helpPushButton("Help...") + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(helpButton) + buttonsLayout.addWidget(okButton) + + layout.addLayout(formLayout) + layout.addSpacing(20) + layout.addLayout(buttonsLayout) + + okButton.clicked.connect(self.ok_cb) + helpButton.clicked.connect(self.help_cb) + cancelButton.clicked.connect(self.close) + + self.setLayout(layout) + self.setFont(font) + + self.configPars = self.loadLastSelection() + + def help_cb(self): + import webbrowser + + url = "https://scikit-image.org/docs/stable/auto_examples/applications/plot_thresholding.html" + webbrowser.open(url) + + def ok_cb(self): + self.cancel = False + self.gaussSigma = self.sigmaGaussSpinbox.value() + threshMethod = self.threshMethodCombobox.currentText().lower() + self.threshMethod = f"threshold_{threshMethod}" + self.segment_kwargs = { + "gauss_sigma": self.gaussSigma, + "threshold_method": self.threshMethod, + "segment_3D_volume": False, + } + self.reduceMemoryUsage = False + if self.segment3Dcheckbox is not None: + doSegm3D = self.segment3Dcheckbox.isChecked() + self.segment_kwargs["segment_3D_volume"] = doSegm3D + self.close() + + def loadLastSelection(self): + self.ini_path = os.path.join(settings_folderpath, "last_params_segm_models.ini") + if not os.path.exists(self.ini_path): + return + + configPars = config.ConfigParser() + configPars.read(self.ini_path) + + if "thresholding.segment" not in configPars.sections(): + return + + section = configPars["thresholding.segment"] + self.sigmaGaussSpinbox.setValue(float(section["gauss_sigma"])) + + threshold_method = section["threshold_method"] + Method = threshold_method[10:].capitalize() + self.threshMethodCombobox.setCurrentText(Method) + if self.segment3Dcheckbox is None: + return + self.segment3Dcheckbox.setChecked(section.getboolean("segment_3D_volume")) + + +class startStopFramesDialog(QBaseDialog): + def __init__( + self, + SizeT, + currentFrameNum=0, + parent=None, + windowTitle="Select frame range to segment", + ): + super().__init__(parent=parent) + + self.setWindowTitle(windowTitle) + + self.cancel = True + + layout = QVBoxLayout() + buttonsLayout = QHBoxLayout() + + self.selectFramesGroupbox = widgets.selectStartStopFrames( + SizeT, currentFrameNum=currentFrameNum, parent=parent + ) + + okButton = widgets.okPushButton("Ok") + cancelButton = widgets.cancelPushButton("Cancel") + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(okButton) + + layout.addWidget(self.selectFramesGroupbox) + layout.addLayout(buttonsLayout) + self.setLayout(layout) + + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.close) + + self.setFont(font) + + def ok_cb(self): + if self.selectFramesGroupbox.warningLabel.text(): + return + else: + self.startFrame = self.selectFramesGroupbox.startFrame_SB.value() + self.stopFrame = self.selectFramesGroupbox.stopFrame_SB.value() + self.cancel = False + self.close() + + def show(self, block=False): + super().show(block=False) + + self.resize(int(self.width() * 1.5), self.height()) + + if block: + super().show(block=True) + + +class randomWalkerDialog(QDialog): + def __init__(self, mainWindow): + super().__init__(mainWindow) + self.cancel = True + self.mainWindow = mainWindow + + if mainWindow is not None: + posData = self.mainWindow.data[self.mainWindow.pos_i] + items = [posData.filename] + else: + items = ["test"] + try: + posData = self.mainWindow.data[self.mainWindow.pos_i] + items.extend(list(posData.ol_data_dict.keys())) + except Exception as e: + pass + + self.keys = items + + self.setWindowTitle("Random walker segmentation") + + self.colors = [self.mainWindow.RWbkgrColor, self.mainWindow.RWforegrColor] + + mainLayout = QVBoxLayout() + paramsLayout = QGridLayout() + buttonsLayout = QHBoxLayout() + + self.mainWindow.clearAllItems() + + row = 0 + paramsLayout.addWidget(QLabel("Background threshold:"), row, 0) + row += 1 + self.bkgrThreshValLabel = QLabel("0.05") + paramsLayout.addWidget(self.bkgrThreshValLabel, row, 1) + self.bkgrThreshSlider = QSlider(Qt.Horizontal) + self.bkgrThreshSlider.setMinimum(1) + self.bkgrThreshSlider.setMaximum(100) + self.bkgrThreshSlider.setValue(5) + self.bkgrThreshSlider.setTickPosition(QSlider.TickPosition.TicksBelow) + self.bkgrThreshSlider.setTickInterval(10) + paramsLayout.addWidget(self.bkgrThreshSlider, row, 0) + + row += 1 + foregrQSLabel = QLabel("Foreground threshold:") + # padding: top, left, bottom, right + foregrQSLabel.setStyleSheet("font-size:13px; padding:5px 0px 0px 0px;") + paramsLayout.addWidget(foregrQSLabel, row, 0) + row += 1 + self.foregrThreshValLabel = QLabel("0.95") + paramsLayout.addWidget(self.foregrThreshValLabel, row, 1) + self.foregrThreshSlider = QSlider(Qt.Horizontal) + self.foregrThreshSlider.setMinimum(1) + self.foregrThreshSlider.setMaximum(100) + self.foregrThreshSlider.setValue(95) + self.foregrThreshSlider.setTickPosition(QSlider.TickPosition.TicksBelow) + self.foregrThreshSlider.setTickInterval(10) + paramsLayout.addWidget(self.foregrThreshSlider, row, 0) + + # Parameters link label + row += 1 + url1 = "https://scikit-image.org/docs/dev/auto_examples/segmentation/plot_random_walker_segmentation.html" + url2 = "https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.random_walker" + htmlTxt1 = f'here' + htmlTxt2 = f'here' + seeHereLabel = QLabel() + seeHereLabel.setText( + f"See {htmlTxt1} and {htmlTxt2} for details " + "about Random walker segmentation." + ) + seeHereLabel.setTextFormat(Qt.RichText) + seeHereLabel.setTextInteractionFlags(Qt.TextBrowserInteraction) + seeHereLabel.setOpenExternalLinks(True) + font = QFont() + font.setPixelSize(12) + seeHereLabel.setFont(font) + seeHereLabel.setStyleSheet("padding:12px 0px 0px 0px;") + paramsLayout.addWidget(seeHereLabel, row, 0, 1, 2) + + computeButton = QPushButton("Compute segmentation") + closeButton = QPushButton("Close") + + buttonsLayout.addWidget(computeButton, alignment=Qt.AlignRight) + buttonsLayout.addWidget(closeButton, alignment=Qt.AlignLeft) + + paramsLayout.setContentsMargins(0, 10, 0, 0) + buttonsLayout.setContentsMargins(0, 10, 0, 0) + + mainLayout.addLayout(paramsLayout) + mainLayout.addLayout(buttonsLayout) + + self.bkgrThreshSlider.sliderMoved.connect(self.bkgrSliderMoved) + self.foregrThreshSlider.sliderMoved.connect(self.foregrSliderMoved) + computeButton.clicked.connect(self.computeSegmAndPlot) + closeButton.clicked.connect(self.close) + + self.setLayout(mainLayout) + + self.getImage() + self.plotMarkers() + + def getImage(self): + img = self.mainWindow.getDisplayedImg1() + self.img = img / img.max() + self.imgRGB = (skimage.color.gray2rgb(self.img) * 255).astype(np.uint8) + + def setSize(self): + x = self.pos().x() + y = self.pos().y() + h = self.size().height() + w = self.size().width() + if w < 400: + w = 400 + self.setGeometry(x, y, w, h) + + def plotMarkers(self): + imgMin, imgMax = self.computeMarkers() + + img = self.img + + imgRGB = self.imgRGB.copy() + R, G, B = self.colors[0] + imgRGB[:, :, 0][img < imgMin] = R + imgRGB[:, :, 1][img < imgMin] = G + imgRGB[:, :, 2][img < imgMin] = B + R, G, B = self.colors[1] + imgRGB[:, :, 0][img > imgMax] = R + imgRGB[:, :, 1][img > imgMax] = G + imgRGB[:, :, 2][img > imgMax] = B + + self.mainWindow.img1.setImage(imgRGB) + + def computeMarkers(self): + bkgrThresh = self.bkgrThreshSlider.sliderPosition() / 100 + foregrThresh = self.foregrThreshSlider.sliderPosition() / 100 + img = self.img + self.markers = np.zeros(img.shape, np.uint8) + imgRange = img.max() - img.min() + imgMin = img.min() + imgRange * bkgrThresh + imgMax = img.min() + imgRange * foregrThresh + self.markers[img < imgMin] = 1 + self.markers[img > imgMax] = 2 + return imgMin, imgMax + + def computeSegm(self, checked=True): + self.mainWindow.storeUndoRedoStates(False) + self.mainWindow.titleLabel.setText("Randomly walking around... ", color="w") + img = self.img + img = skimage.exposure.rescale_intensity(img) + t0 = time.time() + lab = skimage.segmentation.random_walker(img, self.markers, mode="bf") + lab = skimage.measure.label(lab > 1) + t1 = time.time() + if len(np.unique(lab)) > 2: + lab = skimage.morphology.remove_small_objects(lab, min_size=5) + posData = self.mainWindow.data[self.mainWindow.pos_i] + posData.lab = lab + return t1 - t0 + + def computeSegmAndPlot(self): + deltaT = self.computeSegm() + + posData = self.mainWindow.data[self.mainWindow.pos_i] + + self.mainWindow.update_rp() + self.mainWindow.tracking(enforce=True) + self.mainWindow.updateAllImages() + self.mainWindow.warnEditingWithCca_df("Random Walker segmentation") + txt = f"Random Walker segmentation computed in {deltaT:.3f} s" + print("-----------------") + print(txt) + print("=================") + # self.mainWindow.titleLabel.setText(txt, color='g') + + def bkgrSliderMoved(self, intVal): + self.bkgrThreshValLabel.setText(f"{intVal / 100:.2f}") + self.plotMarkers() + + def foregrSliderMoved(self, intVal): + self.foregrThreshValLabel.setText(f"{intVal / 100:.2f}") + self.plotMarkers() + + def closeEvent(self, event): + self.mainWindow.segmModel = "" + self.mainWindow.updateAllImages() + + +class FutureFramesAction_QDialog(QDialog): + def __init__( + self, + frame_i, + last_tracked_i, + change_txt, + applyTrackingB=False, + parent=None, + addApplyAllButton=False, + ): + self.decision = None + self.last_tracked_i = last_tracked_i + super().__init__(parent) + self.setWindowTitle("Future frames action?") + + mainLayout = QVBoxLayout() + txtLayout = QVBoxLayout() + doNotShowLayout = QVBoxLayout() + buttonsLayout = QVBoxLayout() + + txt = html_utils.paragraph( + "You already visited/checked future frames " + f"{frame_i + 1}-{last_tracked_i + 1}.

" + f'The requested "{change_txt}" change might result in
' + "NON-correct segmentation/tracking for those frames.
" + ) + + txtLabel = QLabel(txt) + txtLabel.setAlignment(Qt.AlignCenter) + txtLayout.addWidget(txtLabel, alignment=Qt.AlignCenter) + + options = [ + f'Apply the "{change_txt}" only to current frame and re-initialize
' + "the future frames to the segmentation file present
" + "on the hard drive.", + "Apply only to this frame and keep the future frames as they are.", + "Apply the change to ALL visited/checked future frames.", + ] + if addApplyAllButton: + options.append( + "Apply to ALL future frames including unvisited ones." + ) + if applyTrackingB: + options.append("Repeat ONLY tracking for all future frames (RECOMMENDED)") + + infoTxt = html_utils.paragraph( + f"Choose one of the following options:" + f"{html_utils.to_list(options, ordered=True)}" + ) + + infotxtLabel = QLabel(infoTxt) + txtLayout.addWidget(infotxtLabel, alignment=Qt.AlignCenter) + + noteLayout = QHBoxLayout() + noteTxt = html_utils.paragraph( + "Only changes applied to current frame can be undone.
" + "Changes applied to future frames CANNOT be UNDONE
" + ) + noteLayout.addWidget( + QLabel(html_utils.paragraph("NOTE:")), alignment=Qt.AlignTop + ) + noteTxtLabel = QLabel(noteTxt) + noteLayout.addWidget(noteTxtLabel) + noteLayout.addStretch(1) + txtLayout.addSpacing(10) + txtLayout.addLayout(noteLayout) + + # Do not show this message again checkbox + doNotShowCheckbox = QCheckBox( + "Remember my choice and do not show this message again" + ) + doNotShowLayout.addWidget(doNotShowCheckbox) + doNotShowLayout.setContentsMargins(50, 0, 0, 10) + self.doNotShowCheckbox = doNotShowCheckbox + + apply_and_reinit_b = widgets.reloadPushButton( + " 1. Apply only to this frame and re-initialize future frames" + ) + + self.apply_and_reinit_b = apply_and_reinit_b + buttonsLayout.addWidget(apply_and_reinit_b) + + apply_and_NOTreinit_b = widgets.currentPushButton( + " 2. Apply only to this frame and keep future frames as they are" + ) + self.apply_and_NOTreinit_b = apply_and_NOTreinit_b + buttonsLayout.addWidget(apply_and_NOTreinit_b) + + apply_to_all_visited_b = widgets.futurePushButton( + " 3. Apply to all future VISITED frames" + ) + self.apply_to_all_visited_b = apply_to_all_visited_b + buttonsLayout.addWidget(apply_to_all_visited_b) + + if addApplyAllButton: + apply_to_all_b = QPushButton( + " 4. Apply to ALL future frames (including unvisted)" + ) + apply_to_all_b.setIcon(QIcon(":arrow_future_all.svg")) + self.apply_to_all_b = apply_to_all_b + buttonsLayout.addWidget(apply_to_all_b) + + self.applyTrackingButton = None + if applyTrackingB: + n = "5" if addApplyAllButton else "4" + applyTrackingButton = QPushButton( + f" {n}. Repeat ONLY tracking for all future frames" + ) + applyTrackingButton.setIcon(QIcon(":repeat-tracking.svg")) + self.applyTrackingButton = applyTrackingButton + buttonsLayout.addWidget(applyTrackingButton) + + buttonsLayout.setContentsMargins(20, 0, 20, 0) + + self.formLayout = QFormLayout() + + ButtonsGroup = QButtonGroup(self) + ButtonsGroup.addButton(apply_and_reinit_b) + ButtonsGroup.addButton(apply_and_NOTreinit_b) + ButtonsGroup.addButton(apply_to_all_visited_b) + if addApplyAllButton: + ButtonsGroup.addButton(apply_to_all_b) + if applyTrackingB: + ButtonsGroup.addButton(applyTrackingButton) + + mainLayout.addLayout(txtLayout) + mainLayout.addLayout(doNotShowLayout) + mainLayout.addLayout(buttonsLayout) + mainLayout.addLayout(self.formLayout) + mainLayout.addStretch(1) + self.mainLayout = mainLayout + self.setLayout(mainLayout) + + # Connect events + ButtonsGroup.buttonClicked.connect(self.buttonClicked) + self.ButtonsGroup = ButtonsGroup + + # self.setModal(True) + + def buttonClicked(self, button): + if button == self.apply_and_reinit_b: + self.decision = "apply_and_reinit" + self.endFrame_i = None + elif button == self.apply_and_NOTreinit_b: + self.decision = "apply_and_NOTreinit" + self.endFrame_i = None + elif button == self.apply_to_all_visited_b: + self.decision = "apply_to_all_visited" + self.endFrame_i = self.last_tracked_i + elif button == self.applyTrackingButton: + self.decision = "only_tracking" + self.endFrame_i = self.last_tracked_i + elif button == self.apply_to_all_b: + self.decision = "apply_to_all" + self.endFrame_i = self.last_tracked_i + self.close() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show() + for button in self.ButtonsGroup.buttons(): + button.setMinimumHeight(int(button.height() * 1.2)) + if hasattr(self, "apply_to_all_b"): + iconHeight = self.apply_to_all_b.iconSize().height() + self.apply_to_all_b.setIconSize(QSize(iconHeight * 2, iconHeight)) + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class PostProcessSegmParams(QGroupBox): + valueChanged = Signal(object) + editingFinished = Signal() + + def __init__( + self, + title, + posData, + useSliders=False, + parent=None, + maxSize=None, + force_postprocess_2D=False, + ): + QGroupBox.__init__(self, title, parent) + SizeZ = posData.SizeZ + self.isSegm3D = posData.isSegm3D + self.channelName = posData.user_ch_name + self.useSliders = useSliders + self.force_postprocess_2D = force_postprocess_2D + if maxSize is None: + maxSize = 2147483647 + + layout = QGridLayout() + + self.controlWidgets = [] + + row = 0 + label = QLabel("Minimum area (pixels) ") + layout.addWidget(label, row, 0, alignment=Qt.AlignRight) + + minSize_SB = widgets.PostProcessSegmWidget(1, 1000, 10, useSliders, label=label) + + txt = "Area is the total number of pixels in the segmented object." + + layout.addWidget(minSize_SB, row, 1) + infoButton = widgets.infoPushButton() + infoButton.clicked.connect(self.showInfo) + infoButton.tooltip = txt + infoButton.name = "area" + infoButton.desc = f'less than "{label.text()}"' + layout.addWidget(infoButton, row, 2) + self.minSize_SB = minSize_SB + self.controlWidgets.append(minSize_SB) + + # minSize_SB.disableThisCheckbox = QCheckBox('Disable this filter') + # layout.addWidget(minSize_SB.disableThisCheckbox, row, 3) + + row += 1 + label = QLabel("Minimum solidity (0-1) ") + layout.addWidget(label, row, 0, alignment=Qt.AlignRight) + minSolidity_DSB = widgets.PostProcessSegmWidget( + 0, 1.0, 0.5, useSliders, isFloat=True, normalize=True, label=label + ) + minSolidity_DSB.setValue(0.5) + minSolidity_DSB.setSingleStep(0.1) + self.controlWidgets.append(minSolidity_DSB) + + txt = ( + "Solidity is a measure of convexity. A solidity of 1 means " + "that the shape is fully convex (i.e., equal to the convex hull). " + "As solidity approaches 0 the object is more concave.
" + "Write 0 for ignoring this parameter." + ) + + layout.addWidget(minSolidity_DSB, row, 1) + infoButton = widgets.infoPushButton() + infoButton.clicked.connect(self.showInfo) + infoButton.tooltip = txt + infoButton.name = "solidity" + infoButton.desc = f'less than "{label.text()}"' + layout.addWidget(infoButton, row, 2) + self.minSolidity_DSB = minSolidity_DSB + + row += 1 + label = QLabel("Max elongation (1=circle) ") + layout.addWidget(label, row, 0, alignment=Qt.AlignRight) + maxElongation_DSB = widgets.PostProcessSegmWidget( + 0, 100, 3, useSliders, isFloat=True, normalize=False, label=label + ) + maxElongation_DSB.setDecimals(1) + maxElongation_DSB.setSingleStep(1.0) + + txt = ( + "Elongation is the ratio between major and minor axis lengths. " + "An elongation of 1 is like a circle.
" + "Write 0 for ignoring this parameter." + ) + + layout.addWidget(maxElongation_DSB, row, 1) + infoButton = widgets.infoPushButton() + infoButton.clicked.connect(self.showInfo) + infoButton.tooltip = txt + infoButton.name = "elongation" + infoButton.desc = f'greater than "{label.text()}"' + layout.addWidget(infoButton, row, 2) + self.maxElongation_DSB = maxElongation_DSB + self.controlWidgets.append(maxElongation_DSB) + + if self.isSegm3D: + row += 1 + label = QLabel("Minimum number of z-slices ") + layout.addWidget(label, row, 0, alignment=Qt.AlignRight) + minObjSizeZ_SB = widgets.PostProcessSegmWidget( + 0, SizeZ, 3, useSliders, isFloat=False, normalize=False, label=label + ) + + txt = "Minimum number of z-slices per object." + + layout.addWidget(minObjSizeZ_SB, row, 1) + infoButton = widgets.infoPushButton() + infoButton.clicked.connect(self.showInfo) + infoButton.tooltip = txt + infoButton.name = "number of z-slices" + infoButton.desc = f'less than "{label.text()}"' + layout.addWidget(infoButton, row, 2) + self.minObjSizeZ_SB = minObjSizeZ_SB + self.controlWidgets.append(minObjSizeZ_SB) + else: + self.minObjSizeZ_SB = widgets.NoneWidget() + + row += 1 + addCustomFeatureLayout = QHBoxLayout() + self.addCustomFeaturesButton = widgets.setPushButton( + "Select custom features for post-processing...", + ) + addCustomFeatureLayout.addWidget(self.addCustomFeaturesButton) + addCustomFeatureLayout.addStretch(1) + self.selectedFeaturesDialog = SelectFeaturesRangeDialog( + posData=posData, parent=self, force_postprocess_2D=force_postprocess_2D + ) + self.selectedFeaturesDialog.hide() + self.addCustomFeaturesButton.clicked.connect(self.selectedFeaturesDialog.show) + self.selectedFeaturesDialog.sigValueChanged.connect(self.onValueChanged) + + layout.addLayout(addCustomFeatureLayout, row, 0, 1, 2) + + layout.setColumnStretch(1, 2) + # layout.setRowStretch(row+1, 1) + + self.setLayout(layout) + + for widget in self.controlWidgets: + widget.valueChanged.connect(self.onValueChanged) + widget.editingFinished.connect(self.onEditingFinished) + + def selectedFeaturesRange(self): + return self.selectedFeaturesDialog.groupbox.selectedFeaturesRange() + + def groupedFeatures(self): + return self.selectedFeaturesDialog.groupbox.groupedFeatures() + + def restoreDefault(self): + self.minSolidity_DSB.setValue(0.5) + self.minSize_SB.setValue(10) + self.maxElongation_DSB.setValue(3) + self.minObjSizeZ_SB.setValue(3) + self.selectedFeaturesDialog.groupbox.resetFields() + + def restoreFromKwargs(self, kwargs): + for name, value in kwargs.items(): + if name == "min_solidity": + self.minSolidity_DSB.setValue(value) + elif name == "min_area": + self.minSize_SB.setValue(value) + elif name == "max_elongation": + self.maxElongation_DSB.setValue(value) + elif name == "min_obj_no_zslices": + self.minObjSizeZ_SB.setValue(value) + + def kwargs(self): + kwargs = { + "min_solidity": self.minSolidity_DSB.value(), + "min_area": self.minSize_SB.value(), + "max_elongation": self.maxElongation_DSB.value(), + "min_obj_no_zslices": self.minObjSizeZ_SB.value(), + } + return kwargs + + def onValueChanged(self, value): + self.valueChanged.emit(value) + + def onEditingFinished(self): + self.editingFinished.emit() + + def showInfo(self): + title = f"{self.sender().text()} info" + tooltip = self.sender().tooltip + name = self.sender().name + desc = self.sender().desc + txt = f""" + The post-processing step is applied to the output of the + segmentation model.

+ During this step, Cell-ACDC will remove all the objects with {name} + {desc}.

+ {tooltip} + """ + if self.isCheckable(): + note = f"""" + You can deactivate this step by un-checking the checkbox + called "Post-processing parameters". + """ + txt = f"{txt}{note}" + msg = widgets.myMessageBox(showCentered=False) + msg.information(self, title, html_utils.paragraph(txt)) + + +class PostProcessSegmDialog(QBaseDialog): + sigClosed = Signal() + sigValueChanged = Signal(object, object) + sigEditingFinished = Signal() + sigApplyToAllFutureFrames = Signal(object, object, object) + + def __init__(self, posData, mainWin=None, useSliders=True, maxSize=None): + super().__init__(mainWin) + self.cancel = True + self.mainWin = mainWin + self.isTimelapse = False + self.isMultiPos = False + if mainWin is not None: + self.isMultiPos = len(self.mainWin.data) > 1 + self.isTimelapse = self.mainWin.data[self.mainWin.pos_i].SizeT > 1 + + self.setWindowTitle("Post-processing segmentation parameters") + self.setWindowFlags(Qt.Tool | Qt.WindowStaysOnTopHint) + + mainLayout = QVBoxLayout() + buttonsLayout = QHBoxLayout() + + self.postProcessGroupbox = PostProcessSegmParams( + "Post-processing parameters", + posData, + useSliders=useSliders, + maxSize=maxSize, + parent=mainWin, + ) + + self.postProcessGroupbox.valueChanged.connect(self.valueChanged) + self.postProcessGroupbox.editingFinished.connect(self.onEditingFinished) + + if self.isTimelapse: + applyAllButton = widgets.futurePushButton("Apply to all frames...") + applyAllButton.clicked.connect(self.applyAll_cb) + applyButton = widgets.okPushButton("Apply", isDefault=False) + applyButton.clicked.connect(self.apply_cb) + elif self.isMultiPos: + applyAllButton = widgets.futurePushButton("Apply to all Positions...") + applyAllButton.clicked.connect(self.applyAll_cb) + applyButton = widgets.okPushButton("Apply", isDefault=False) + applyButton.clicked.connect(self.apply_cb) + else: + applyAllButton = widgets.okPushButton("Apply", isDefault=False) + applyAllButton.clicked.connect(self.ok_cb) + applyButton = None + + cancelButton = widgets.cancelPushButton("Cancel") + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + if applyButton is not None: + buttonsLayout.addWidget(applyButton) + buttonsLayout.addWidget(applyAllButton) + + emitEditingFinishedButton = widgets.okPushButton() + buttonsLayout.addWidget(emitEditingFinishedButton) + emitEditingFinishedButton.hide() + buttonsLayout.setContentsMargins(0, 10, 0, 0) + + mainLayout.addWidget(self.postProcessGroupbox) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + cancelButton.clicked.connect(self.cancel_cb) + + if mainWin is not None: + self.setPosData() + + def keyPressEvent(self, event) -> None: + return super().keyPressEvent(event) + + def setPosData(self): + if self.mainWin is None: + return + + self.mainWin.storeUndoRedoStates(False) + self.posData = self.mainWin.data[self.mainWin.pos_i] + # self.img.setCurrentPosIndex(self.pos_i) + # self.img.minMaxValuesMapper = self.mainWin.img1.minMaxValuesMapper + self.origLab = self.posData.lab.copy() + self.origRp = skimage.measure.regionprops(self.origLab) + self.origObjs = {obj.label: obj for obj in self.origRp} + + def valueChanged(self, value): + lab, delObjs = self.apply() + self.sigValueChanged.emit(lab, delObjs) + + def apply(self, origLab=None): + self.mainWin.warnEditingWithCca_df( + "post-processing segmentation mask", update_images=False + ) + ccaAnnotRemoved = self.mainWin.removeCcaAnnotationsCurrentFrame() + if ccaAnnotRemoved: + self.mainWin.updateAllImages() + + if origLab is None: + origLab = self.origLab.copy() + + lab, delIDs = core.post_process_segm( + origLab, return_delIDs=True, **self.postProcessGroupbox.kwargs() + ) + + if self.postProcessGroupbox.selectedFeaturesRange(): + lab, custom_delIDs = features.custom_post_process_segm( + self.posData, + self.postProcessGroupbox.groupedFeatures(), + lab, + self.posData.img_data[self.posData.frame_i], + self.posData.frame_i, + self.posData.filename, + self.posData.user_ch_name, + self.postProcessGroupbox.selectedFeaturesRange(), + return_delIDs=True, + ) + delIDs.extend(custom_delIDs) + + delObjs = {delID: self.origObjs[delID] for delID in delIDs} + return lab, delObjs + + def onEditingFinished(self): + self.sigEditingFinished.emit() + + def ok_cb(self): + self.cancel = False + self.apply() + self.onEditingFinished() + self.close() + + def apply_cb(self): + self.cancel = False + self.apply() + self.onEditingFinished() + + def applyAll_cb(self): + self.cancel = False + self.sigApplyToAllFutureFrames.emit( + self.postProcessGroupbox.kwargs(), + self.postProcessGroupbox.groupedFeatures(), + self.postProcessGroupbox.selectedFeaturesRange(), + ) + self.close() + + def cancel_cb(self): + self.cancel = True + self.close() + + def undoChanges(self): + if self.mainWin is not None: + self.posData.lab = self.origLab + self.mainWin.update_rp() + self.mainWin.updateAllImages() + + # Undo if changes were applied to all future frames + if hasattr(self, "origSegmData"): + if self.isTimelapse: + current_frame_i = self.posData.frame_i + for frame_i in range(self.posData.segmSizeT): + self.posData.frame_i = frame_i + origLab = self.origSegmData[frame_i] + lab = self.posData.allData_li[frame_i]["labels"] + if lab is None: + # Non-visited frame modify segm_data + self.posData.segm_data[frame_i] = origLab + else: + self.posData.allData_li[frame_i]["labels"] = origLab.copy() + self.posData.lab = origLab.copy() + self.mainWin.update_rp() + # Get the rest of the stored metadata based on the new lab + self.mainWin.get_data() + self.mainWin.store_data() + # Back to current frame + self.posData.frame_i = current_frame_i + self.mainWin.get_data() + self.mainWin.updateAllImages() + elif self.isMultiPos: + current_pos_i = self.mainWin.pos_i + # Apply to all future frames or future positions + for pos_i, posData in enumerate(self.mainWin.data): + self.mainWin.pos_i = pos_i + origLab = self.origSegmData[pos_i] + self.posData.allData_li[0]["labels"] = lab.copy() + # Get the rest of the stored metadata based on the new lab + self.mainWin.get_data() + self.mainWin.store_data() + # Back to current pos and current frame + self.mainWin.pos_i = current_pos_i + self.mainWin.get_data() + self.mainWin.updateAllImages() + + def show(self, block=False): + # self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show(block=False) + self.resize(int(self.width() * 1.5), self.height()) + super().show(block=block) + + def closeEvent(self, event): + self.sigClosed.emit() + if self.cancel: + self.undoChanges() + super().closeEvent(event) + + +class FunctionParamsDialog(QBaseDialog): + sigValuesChanged = Signal(dict) + + def __init__( + self, + params_argspecs, + function_name="Function", + df_metadata=None, + parent=None, + addApplyButton=False, + ): + self.cancel = True + self.df_metadata = df_metadata + + super().__init__(parent) + + self.setWindowTitle(f"{function_name} parameters") + + self.mainLayout = QVBoxLayout() + + widgetsLayout, self.argsWidgets = self.getWidgetsLayout(params_argspecs) + + buttonsLayout = widgets.CancelOkButtonsLayout() + + self.buttonsLayout = buttonsLayout + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + if addApplyButton: + applyButton = widgets.viewPushButton("Apply") + applyButton.clicked.connect(self.emitValuesChanged) + buttonsLayout.insertWidget(3, applyButton) + self.applyButton = applyButton + + self.mainLayout.addLayout(widgetsLayout) + self.mainLayout.addSpacing(20) + self.mainLayout.addLayout(buttonsLayout) + + self.setLayout(self.mainLayout) + + def emitValuesChanged(self, *args, **kwargs): + self.sigValuesChanged.emit(self.functionKwargs()) + + def functionKwargs(self): + function_kwargs = { + argWidget.name: argWidget.valueGetter(argWidget.widget) + for argWidget in self.argsWidgets + } + return function_kwargs + + def kwargWidgetMapper(self) -> Dict[str, tuple]: + kwarg_widget_mapper = { + argWidget.name: (argWidget.widget, argWidget.valueSetter) + for argWidget in self.argsWidgets + } + return kwarg_widget_mapper + + def ok_cb(self): + self.cancel = False + + self.function_kwargs = self.functionKwargs() + + self.close() + + def getValueFromMetadata(self, name): + try: + value = self.df_metadata.at[name, "values"] + except Exception as e: + # traceback.print_exc() + value = None + return value + + def getWidgetsLayout(self, params_argspecs): + widgetsLayout = QGridLayout() + ArgsWidgets_list = [] + + for row, ArgSpec in enumerate(params_argspecs): + if _types.is_widget_not_required(ArgSpec): + continue + + arg_name = ArgSpec.name + var_name = arg_name.replace("_", " ") + var_name = f"{var_name[0].upper()}{var_name[1:]}" + label = QLabel(f"{var_name}: ") + metadata_val = self.getValueFromMetadata(ArgSpec.name) + widgetsLayout.addWidget(label, row, 0, alignment=Qt.AlignLeft) + try: + values = ArgSpec.type().values + isCustomListType = True + except Exception as err: + isCustomListType = False + + isVectorEntry = False + try: + if isinstance(ArgSpec.type(), _types.Vector): + isVectorEntry = True + except Exception as err: + pass + + isFolderPath = False + try: + if isinstance(ArgSpec.type(), _types.FolderPath): + isFolderPath = True + except Exception as err: + pass + + isCustomWidget = hasattr(ArgSpec.type, "isWidget") + + if isCustomWidget: + widget = ArgSpec.type().widget + self.checkIfTypeCLassHasCastDtype(widget) + defaultVal = ArgSpec.default + valueSetter = widget.setValue + valueGetter = widget.value + widgetsLayout.addWidget(widget, row, 1, 1, 2) + try: + widget.sigValueChanged.connect(self.emitValuesChanged) + except Exception as err: + pass + elif isVectorEntry: + vectorLineEdit = widgets.VectorLineEdit() + self.checkIfTypeCLassHasCastDtype(ArgSpec.type) + vectorLineEdit.setValue(ArgSpec.default) + defaultVal = ArgSpec.default + valueSetter = widgets.VectorLineEdit.setValue + valueGetter = widgets.VectorLineEdit.value + widget = vectorLineEdit + widgetsLayout.addWidget(vectorLineEdit, row, 1, 1, 2) + widget.valueChangeFinished.connect(self.emitValuesChanged) + elif isFolderPath: + folderPathControl = widgets.FolderPathControl() + self.checkIfTypeCLassHasCastDtype(ArgSpec.type) + folderPathControl.setText(str(ArgSpec.default)) + widget = folderPathControl + defaultVal = str(ArgSpec.default) + valueSetter = widgets.FolderPathControl.setText + valueGetter = widgets.FolderPathControl.path + widgetsLayout.addWidget(folderPathControl, row, 1, 1, 2) + widget.sigValueChanged.connect(self.emitValuesChanged) + elif ArgSpec.type == bool: + booleanGroup = QButtonGroup() + booleanGroup.setExclusive(True) + checkBox = widgets.Toggle() + checkBox.setChecked(ArgSpec.default) + defaultVal = ArgSpec.default + valueSetter = widgets.Toggle.setChecked + valueGetter = widgets.Toggle.isChecked + widget = checkBox + widgetsLayout.addWidget( + checkBox, row, 1, 1, 2, alignment=Qt.AlignCenter + ) + widget.toggled.connect(self.emitValuesChanged) + elif ArgSpec.type == int: + spinBox = widgets.SpinBox() + if metadata_val is None: + spinBox.setValue(ArgSpec.default) + else: + spinBox.setValue(int(metadata_val)) + spinBox.isMetadataValue = True + defaultVal = ArgSpec.default + valueSetter = QSpinBox.setValue + valueGetter = QSpinBox.value + widget = spinBox + widgetsLayout.addWidget(spinBox, row, 1, 1, 2) + widget.sigValueChanged.connect(self.emitValuesChanged) + elif ArgSpec.type == float: + doubleSpinBox = widgets.FloatLineEdit() + if metadata_val is None: + doubleSpinBox.setValue(ArgSpec.default) + else: + doubleSpinBox.setValue(float(metadata_val)) + doubleSpinBox.isMetadataValue = True + widget = doubleSpinBox + defaultVal = ArgSpec.default + valueSetter = widgets.FloatLineEdit.setValue + valueGetter = widgets.FloatLineEdit.value + widgetsLayout.addWidget(doubleSpinBox, row, 1, 1, 2) + widget.valueChanged.connect(self.emitValuesChanged) + elif ArgSpec.type == os.PathLike: + filePathControl = widgets.filePathControl() + filePathControl.setText(str(ArgSpec.default)) + widget = filePathControl + defaultVal = str(ArgSpec.default) + valueSetter = widgets.filePathControl.setText + valueGetter = widgets.filePathControl.path + widgetsLayout.addWidget(filePathControl, row, 1, 1, 2) + widget.sigValueChanged.connect(self.emitValuesChanged) + elif isCustomListType: + items = ArgSpec.type().values + ArgSpec.type.cast_dtype = _types.to_str + defaultVal = str(ArgSpec.default) + combobox = widgets.AlphaNumericComboBox() + combobox.addItems(items) + combobox.setCurrentValue(defaultVal) + valueSetter = widgets.AlphaNumericComboBox.setCurrentValue + valueGetter = widgets.AlphaNumericComboBox.currentValue + widget = combobox + widgetsLayout.addWidget(combobox, row, 1, 1, 2) + widget.currentTextChanged.connect(self.emitValuesChanged) + else: + lineEdit = QLineEdit() + lineEdit.setText(str(ArgSpec.default)) + lineEdit.setAlignment(Qt.AlignCenter) + widget = lineEdit + defaultVal = str(ArgSpec.default) + valueSetter = QLineEdit.setText + valueGetter = QLineEdit.text + widgetsLayout.addWidget(lineEdit, row, 1, 1, 2) + widget.editingFinished.connect(self.emitValuesChanged) + + if ArgSpec.desc: + infoButton = self.getInfoButton(ArgSpec.name, ArgSpec.desc) + widgetsLayout.addWidget(infoButton, row, 3) + + argsInfo = ArgWidget( + name=ArgSpec.name, + type=ArgSpec.type, + widget=widget, + defaultVal=defaultVal, + valueSetter=valueSetter, + valueGetter=valueGetter, + ) + ArgsWidgets_list.append(argsInfo) + + widgetsLayout.setColumnStretch(0, 0) + widgetsLayout.setColumnStretch(1, 1) + widgetsLayout.setColumnStretch(3, 0) + + return widgetsLayout, ArgsWidgets_list + + def checkIfTypeCLassHasCastDtype(self, cls): + cast_dtype = getattr(cls, "cast_dtype", None) + if callable(cast_dtype): + return + + raise AttributeError( + "The custom type or widget does not have the `cast_dtype` method. " + "Please, implement it. The method should cast the value to the " + "correct type." + ) + + def getInfoButton(self, param_name, infoText): + infoButton = widgets.infoPushButton() + infoButton.param_name = param_name + infoButton.setToolTip( + f"Click to get more info about `{param_name}` parameter..." + ) + infoButton.infoText = infoText + infoButton.clicked.connect(self.showInfoParam) + return infoButton + + def showInfoParam(self): + text = self.sender().infoText + text = html_utils.rst_urls_to_html(text) + text = html_utils.rst_to_html(text) + text = html_utils.paragraph(text) + param_name = self.sender().param_name + msg = widgets.myMessageBox(wrapText=False) + msg.information(self, f"Info about `{param_name}` parameter", text) + + +class stopFrameDialog(QBaseDialog): + def __init__(self, posDatas, parent=None): + super().__init__(parent=parent) + + self.cancel = True + + self.setWindowTitle("Stop frame") + + mainLayout = QVBoxLayout() + + infoTxt = html_utils.paragraph( + "Enter a stop frame number for each of the loaded Positions", + center=True, + ) + exp_path = posDatas[0].exp_path + exp_path = os.path.normpath(exp_path).split(os.sep) + exp_path = f"...{f'{os.sep}'.join(exp_path[-4:])}" + subInfoTxt = html_utils.paragraph( + f"Experiment folder: {exp_path}", font_size="12px", center=True + ) + infoLabel = QLabel(f"{infoTxt}{subInfoTxt}") + infoLabel.setToolTip(posDatas[0].exp_path) + mainLayout.addWidget(infoLabel) + mainLayout.addSpacing(20) + + self.posDatas = posDatas + for posData in posDatas: + _layout = QHBoxLayout() + _layout.addStretch(1) + _label = QLabel(html_utils.paragraph(f"{posData.pos_foldername}")) + _layout.addWidget(_label) + + _spinBox = QSpinBox() + _spinBox.setMaximum(214748364) + _spinBox.setAlignment(Qt.AlignCenter) + _spinBox.setFont(font) + if posData.acdc_df is not None: + _val = posData.acdc_df.index.get_level_values(0).max() + 1 + else: + _val = posData.readLastUsedStopFrameNumber() + if _val is None: + _val = posData.SizeT + _spinBox.setValue(_val) + + posData.stopFrameSpinbox = _spinBox + + _layout.addWidget(_spinBox) + + viewButton = widgets.viewPushButton("Visualize...") + viewButton.clicked.connect(partial(self.viewChannelData, posData, _spinBox)) + _layout.addWidget(viewButton, alignment=Qt.AlignRight) + + _layout.addStretch(1) + + mainLayout.addLayout(_layout) + + buttonsLayout = QHBoxLayout() + + okButton = widgets.okPushButton(" Ok ") + cancelButton = widgets.cancelPushButton(" Cancel ") + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(okButton) + + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.close) + + self.setLayout(mainLayout) + + def viewChannelData(self, posData, spinBox): + self.sender().setText("Loading...") + QTimer.singleShot( + 200, partial(self._viewChannelData, posData, spinBox, self.sender()) + ) + + def _viewChannelData(self, posData, spinBox, senderButton): + chNames = posData.chNames + if len(chNames) > 1: + ch_name_selector = prompts.select_channel_name( + which_channel="segm", allow_abort=False + ) + ch_name_selector.QtPrompt( + self, chNames, "Select channel name to visualize: " + ) + if ch_name_selector.was_aborted: + return + chName = ch_name_selector.channel_name + else: + chName = chNames[0] + + channel_file_path = load.get_filename_from_channel(posData.images_path, chName) + posData.frame_i = 0 + posData.loadImgData(imgPath=channel_file_path) + self.slideshowWin = imageViewer(posData=posData, spinBox=spinBox) + self.slideshowWin.update_img() + self.slideshowWin.show() + senderButton.setText("Visualize...") + + def ok_cb(self): + self.cancel = False + for posData in self.posDatas: + stopFrameNum = posData.stopFrameSpinbox.value() + posData.stopFrameNum = stopFrameNum + self.close() + + +class DataPrepSubCropsPathsDialog(QBaseDialog): + def __init__(self, cropPaths=None, parent=None): + self.cancel = True + + super().__init__(parent=parent) + + mainLayout = QVBoxLayout() + + gridLayout = QGridLayout() + row = 0 + + if cropPaths is None: + cropPaths = {os.path.expanduser("~"): 1} + + if any([numCrops > 1 for numCrops in cropPaths.values()]): + row += 1 + gridLayout.addWidget(QLabel("Same folder for all crops:"), row, 0) + self.sameFolderPathToggle = widgets.Toggle() + gridLayout.addWidget( + self.sameFolderPathToggle, row, 1, alignment=Qt.AlignCenter + ) + self.sameFolderPathToggle.setChecked(True) + self.sameFolderPathToggle.toggled.connect(self.setSameFolderPath) + + self.windowMinWidth = 0 + minWidth = int(self.screen().size().width() / 3) + self.folderPathLineEdits = defaultdict(list) + for path, numCrops in cropPaths.items(): + row += 1 + gridLayout.addWidget(QLabel("Master Position:"), row, 0) + masterPathLabel = QLabel(f"{path}") + gridLayout.addWidget(masterPathLabel, row, 1) + + scrollArea = QScrollArea() + scrollArea.setWidgetResizable(True) + scrollAreaLayout = QGridLayout() + for i in range(numCrops): + label = QLabel(f"Crop {i + 1} folder path:") + scrollAreaLayout.addWidget(label, i, 0) + folderPathLineEdit = widgets.ElidingLineEdit() + folderPathLineEdit.label = label + folderPathLineEdit.setText(path) + scrollAreaLayout.addWidget(folderPathLineEdit, i, 1) + browseButton = widgets.browseFileButton(start_dir=path, openFolder=True) + scrollAreaLayout.addWidget(browseButton, i, 2) + browseButton.sigPathSelected.connect( + partial(self.updateFolderPath, lineEdit=folderPathLineEdit) + ) + self.folderPathLineEdits[path].append(folderPathLineEdit) + folderPathLineEdit.browseButton = browseButton + + scrollAreaLayout.setColumnStretch(0, 0) + scrollAreaLayout.setColumnStretch(1, 1) + scrollAreaLayout.setColumnStretch(2, 0) + container = QWidget() + container.setLayout(scrollAreaLayout) + scrollArea.setWidget(container) + + row += 1 + gridLayout.addWidget(scrollArea, row, 0, 1, 2) + noHorizontalScrollbarWidth = ( + container.sizeHint().width() + + scrollArea.verticalScrollBar().sizeHint().width() + + 20 + ) + if noHorizontalScrollbarWidth > self.windowMinWidth: + self.windowMinWidth = noHorizontalScrollbarWidth + + row += 1 + gridLayout.addWidget(widgets.QHLine(), row, 0, 1, 2) + + row += 1 + gridLayout.addItem(QSpacerItem(10, 10), row, 0, 1, 2) + + row += 1 + + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addLayout(gridLayout) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + def show(self, block=False): + self.resize(self.windowMinWidth, self.sizeHint().height()) + super().show(block=block) + + def setSameFolderPath(self, checked): + for masterPath, lineEdits in self.folderPathLineEdits.items(): + referencePath = lineEdits[0].text() + for lineEdit in lineEdits[1:]: + if checked: + lineEdit.setText(referencePath) + + lineEdit.setDisabled(checked) + lineEdit.browseButton.setDisabled(checked) + lineEdit.label.setDisabled(checked) + + def updateFolderPath(self, path, lineEdit=None): + lineEdit.setText(path) + lineEdit.browseButton.setStartPath(path) + + def warnFolderPathNotValid(self, cropNum, masterPath, folderPath): + text = html_utils.paragraph( + f"The following folder path for crop number {cropNum} " + "is not a valid folder or does not exist:" + ) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "Not a valid folder", text, commands=(folderPath,)) + + def askOverwritingPaths(self, overwritingPaths): + text = html_utils.paragraph( + "Data in the following paths will be overwritten with " + "cropped data.

" + "Are you sure you want to continue?" + ) + msg = widgets.myMessageBox(wrapText=False) + _, yesButton = msg.warning( + self, + "Not a valid folder", + text, + commands=overwritingPaths, + buttonsTexts=("No, let me edit paths", "Yes, overwrite"), + ) + return msg.clickedButton == yesButton + + def validatePaths(self): + for masterPath, lineEdits in self.folderPathLineEdits.items(): + for i, lineEdit in enumerate(lineEdits): + path = lineEdit.text() + if os.path.exists(path) and os.path.isdir(path): + continue + + self.warnFolderPathNotValid(i + 1, masterPath, path) + return False + + overwritingPaths = [] + for masterPath, lineEdits in self.folderPathLineEdits.items(): + masterPath = masterPath.replace("\\", "/") + if not masterPath.endswith("Images"): + continue + + for i, lineEdit in enumerate(lineEdits): + path = lineEdit.text() + path = path.replace("\\", "/") + if path == masterPath: + overwritingPaths.append(masterPath) + + if not overwritingPaths: + return True + + return self.askOverwritingPaths(overwritingPaths) + + def paths(self): + selectedPaths = {} + for masterPath, lineEdits in self.folderPathLineEdits.items(): + selectedPaths[masterPath] = [le.text() for le in lineEdits] + return selectedPaths + + def ok_cb(self): + proceed = self.validatePaths() + if not proceed: + return + + self.folderPaths = self.paths() + self.cancel = False + self.close() + + +class PreProcessParamsWidget(QWidget): + sigLoadRecipe = Signal() + sigLoadSavedRecipe = Signal() + sigValuesChanged = Signal(list) + + def __init__(self, df_metadata=None, addApplyButton=False, parent=None): + super().__init__(parent) + + mainLayout = QVBoxLayout() + + self.df_metadata = df_metadata + self.addApplyButton = addApplyButton + + groupbox = QGroupBox() + self.groupbox = groupbox + + groupbox.setTitle("Pre-processing") + groupbox.setCheckable(True) + + self.gridLayout = QGridLayout() + self.row = -1 + self.stepsWidgets = {} + + self.gridLayout.setColumnStretch(0, 0) + self.gridLayout.setColumnStretch(1, 1) + self.gridLayout.setColumnStretch(2, 0) + self.gridLayout.setColumnStretch(3, 0) + self.gridLayout.setColumnStretch(4, 0) + groupbox.setLayout(self.gridLayout) + + buttonsLayout = QGridLayout() + row = 0 + col = 0 + buttonsLayout.setColumnStretch(col, 1) + + loadRecipeButton = widgets.OpenFilePushButton("Load saved recipe...") + self.loadRecipeButton = loadRecipeButton + buttonsLayout.addWidget(loadRecipeButton, row, col + 2) + + saveRecipeButton = widgets.savePushButton("Save current recipe...") + self.saveRecipeButton = saveRecipeButton + buttonsLayout.addWidget(saveRecipeButton, row + 1, col + 2) + + loadLastRecipeButton = widgets.reloadPushButton("Load last parameters") + self.loadLastRecipeButton = loadLastRecipeButton + buttonsLayout.addWidget(loadLastRecipeButton, row, col + 1) + + self.buttonsLayout = buttonsLayout + + loadLastRecipeButton.clicked.connect(self.emitLoadRecipe) + saveRecipeButton.clicked.connect(self.saveRecipe) + loadRecipeButton.clicked.connect(self.selectAndLoadRecipe) + + mainLayout.addWidget(groupbox) + mainLayout.addSpacing(10) + mainLayout.addLayout(buttonsLayout) + + self.addStep(is_first=True) + + mainLayout.setContentsMargins(0, 0, 0, 0) + self.setLayout(mainLayout) + + def stepSizeHeightHint(self): + stepWidgets = self.stepsWidgets[1] + height = ( + stepWidgets["stepLabel"].minimumSizeHint().height() + + stepWidgets["selector"].minimumSizeHint().height() + ) + return height + + def setChecked(self, checked): + self.groupbox.setChecked(checked) + + def emitLoadRecipe(self): + self.sigLoadRecipe.emit() + + def loadRecipe(self, configPars: dict): + for stepWidgets in list(self.stepsWidgets.values()): + try: + stepWidgets["delButton"].click() + except Exception as err: + pass + + configPars = self.sortStepsConfigPars(configPars) + for s in range(1, len(configPars)): + self.stepsWidgets[1]["addButton"].click() + + for i, (section, section_items) in enumerate(configPars.items()): + step_n = i + 1 + selector = self.stepsWidgets[step_n]["selector"] + kwarg_to_value_mapper = {} + for option, value in section_items.items(): + if option == "method": + selector.setCurrentText(value) + method = value + else: + kwarg_to_value_mapper[option] = value + selector.setParams(method, kwarg_to_value_mapper) + + self.setChecked(True) + + def sortStepsConfigPars(self, configPars: dict): + sortedConfigPars = {} + sortedKeys = sorted( + configPars.keys(), key=lambda key: int(re.findall(r"step(\d+)", key)[0]) + ) + for key in sortedKeys: + sortedConfigPars[key] = configPars[key] + return sortedConfigPars + + def saveRecipeUI( + self, folder_path, ext, title, basename, hintText, default_text + ): # -> tuple[Literal[False], Literal['']] | tuple[Literal[True], Any]: + win = filenameDialog( + title=title, + basename=basename, + ext=ext, + hintText=hintText, + allowEmpty=False, + defaultEntry=default_text, + parent=self, + ) + win.exec_() + if win.cancel: + return False, "" + + self.cancel = False + filepath = win.filename + os.makedirs(folder_path, exist_ok=True) + filepath = os.path.join(folder_path, filepath) + + if os.path.exists(filepath): + proceed = self.warnExistingRecipeFile(filepath) + if not proceed: + return False, "" + + return True, filepath + + def saveRecipe(self): + recipe = self.recipe() + if recipe is None: + return + + default_text = "" + for step in recipe[:2]: + method = step["method"] + func_name = config.PREPROCESS_MAPPER[method]["function_name"] + default_text = f"{default_text}-{func_name}" + default_text = default_text.lstrip("-") + + proceed, ini_filepath = self.saveRecipeUI( + preproc_recipes_path, + ".ini", + "Filename for pre-processing recipe", + "preprocessing_recipe", + "Insert a filename for the pre-processing recipe:", + default_text, + ) + if not proceed: + return + + cp = self.recipeConfigPars("acdc") + with open(ini_filepath, "w") as configfile: + cp.write(configfile) + + self.communicateSavingRecipeFinished(ini_filepath) + + def warnExistingRecipeFile(self, ini_filename): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph( + "A file with the following name

" + f"{ini_filename}

" + "already exists.

" + "Do you want to overwrite the existing file?" + ) + noButton, yesButton = msg.warning( + self, + "File name existing", + txt, + buttonsTexts=("No, stop saving process", "Yes, overwrite existing file"), + ) + return msg.clickedButton == yesButton + + def warnNoAvailableRecipesToLoad(self): + text = html_utils.paragraph("There are no recipes saved. Sorry about that :(") + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "No recipes saved", text) + + # def selectIniFileToLoadRecipe(self): + # import qtpy.compat + # ini_filepath = qtpy.compat.getopenfilename( + # parent=self, + # caption='Select INI file to load pre-processing recipe', + # filters='INI (*.ini);;All Files (*)' + # )[0] + # if not ini_filepath: + # return + + # cp = config.ConfigParser() + # cp.read(ini_filepath) + # preprocConfigPars = {} + # for section in cp.sections(): + # if not section.startswith('acdc.preprocess'): + # continue + + # preprocConfigPars[section] = cp[section] + + # if not preprocConfigPars: + # return + + # self.loadRecipe(preprocConfigPars) + + def selectRecipeFilepath(self, recipes_path, recipe_prefix, ext_label, ext): + availableRecipes = [] + if os.path.exists(recipes_path): + for file in utils.listdir(recipes_path): + if not file.startswith(recipe_prefix): + continue + endname = file.split(f"{recipe_prefix}_")[1] + availableRecipes.append(endname) + + if not availableRecipes: + import qtpy.compat + + filepath = qtpy.compat.getopenfilename( + parent=self, + caption=f"Select {ext_label} file to load recipe", + filters=f"{ext_label} (*.{ext});;All Files (*)", + )[0] + return filepath or None + + browseButton = widgets.browseFileButton( + f"Select {ext_label} file...", + title=f"Select {ext_label} file to load recipe", + openFolder=False, + start_dir=utils.getMostRecentPath(), + ext={ext_label: f".{ext}"}, + ) + selectRecipeWin = widgets.QDialogListbox( + "Select recipe", + "Select recipe to load:\n", + availableRecipes, + multiSelection=False, + allowEmptySelection=False, + parent=self, + additionalButtons=(browseButton,), + ) + browseButton.sigPathSelected.connect( + partial( + self.recipeIniFileSelected, + selectRecipeWin=selectRecipeWin, + sender=browseButton, + ) + ) + selectRecipeWin.exec_() + if selectRecipeWin.cancel: + return None + + if selectRecipeWin.clickedButton == browseButton: + return selectRecipeWin.selectedIniFilepath + + selected_endname = selectRecipeWin.selectedItemsText[0] + filename = f"{recipe_prefix}_{selected_endname}" + return os.path.join(recipes_path, filename) + + def selectAndLoadRecipe(self): + filepath = self.selectRecipeFilepath( + preproc_recipes_path, "preprocessing_recipe", "INI", "ini" + ) + if filepath is None: + return + cp = config.ConfigParser() + cp.read(filepath) + preprocConfigPars = { + s: cp[s] for s in cp.sections() if s.startswith("acdc.preprocess") + } + if not preprocConfigPars: + return + self.loadRecipe(preprocConfigPars) + + def recipeIniFileSelected(self, ini_filepath, selectRecipeWin=None, sender=None): + selectRecipeWin.clickedButton = sender + selectRecipeWin.selectedIniFilepath = ini_filepath + selectRecipeWin.cancel = False + selectRecipeWin.close() + + def communicateSavingRecipeFinished(self, ini_filepath): + text = html_utils.paragraph("Done!

Pre-processing recipe saved to:") + msg = widgets.myMessageBox(wrapText=False) + msg.information( + self, + "Pre-processing recipe saved!", + text, + commands=(ini_filepath,), + path_to_browse=os.path.dirname(ini_filepath), + ) + + def addStep(self, is_first=False): + stepWidgets = {} + + self.row += 1 + + step_n = len(self.stepsWidgets) + 1 + label = QLabel(f"Step {step_n}: ") + self.gridLayout.addWidget(label, self.row, 0) + stepWidgets["stepLabel"] = label + + selector = widgets.PreProcessingSelector() + self.gridLayout.addWidget(selector, self.row, 1) + stepWidgets["selector"] = selector + + setParamsButton = widgets.setPushButton() + setParamsButton.setToolTip("Set step parameters") + self.gridLayout.addWidget(setParamsButton, self.row, 2) + setParamsButton.clicked.connect(partial(self.setParamsStep, selector=selector)) + stepWidgets["setParamsButton"] = setParamsButton + + infoButton = widgets.infoPushButton() + self.gridLayout.addWidget(infoButton, self.row, 3) + infoButton.clicked.connect(partial(self.showInfo, selector=selector)) + stepWidgets["infoButton"] = infoButton + + if is_first: + addButton = widgets.addPushButton() + self.gridLayout.addWidget(addButton, self.row, 4) + addButton.clicked.connect(self.addStep) + stepWidgets["addButton"] = addButton + else: + delButton = widgets.delPushButton() + self.gridLayout.addWidget(delButton, self.row, 4) + delButton.clicked.connect(self.removeStep) + delButton.step_n = step_n + stepWidgets["delButton"] = delButton + + self.row += 1 + selector.row = self.row + selector.step_n = step_n + + hline = widgets.QHLine() + self.gridLayout.addWidget(hline, self.row, 0, 1, 6) + stepWidgets["hline"] = hline + self.row += 1 + + self.stepsWidgets[step_n] = stepWidgets + + selector.sigValuesChanged.connect(self.emitValuesChanged) + selector.currentTextChanged.connect( + partial(self.clearInitKwargs, step_n=step_n) + ) + + self.resetStretch() + + def emitValuesChanged(self, functionKwargs, step_n): + self.stepsWidgets[step_n]["step_kwargs"] = functionKwargs + + recipe = self.recipe(warn=False) + if recipe is None: + return + + self.sigValuesChanged.emit(recipe) + + def clearInitKwargs(self, selected_method, step_n=0): + stepWidgets = self.stepsWidgets[step_n] + stepWidgets.pop("step_kwargs", None) + + def resetStretch(self): + for row in range(self.gridLayout.rowCount()): + self.gridLayout.setRowStretch(row, 0) + + self.gridLayout.setRowStretch(self.gridLayout.rowCount(), 1) + self.row = self.gridLayout.rowCount() - 1 + + def showInfo(self, checked=False, selector=None): + if selector is None: + return + + htmlText = selector.htmlInfo() + htmlText = html_utils.paragraph(htmlText) + + method = selector.currentText() + msg = widgets.myMessageBox(wrapText=False) + msg.information(self, f"Info about `{method}`", htmlText) + + def setParamsStep( + self, checked=False, selector: "widgets.PreProcessingSelector" = None + ): + step_n = selector.step_n + stepFunctionKwargs = selector.askSetParams( + df_metadata=self.df_metadata, addApplyButton=self.addApplyButton + ) + if stepFunctionKwargs is None: + return + + self.stepsWidgets[step_n]["step_kwargs"] = stepFunctionKwargs + + def removeStep(self, checked=False, step_n=None): + if step_n is None: + step_n = self.sender().step_n + + stepWidgets = self.stepsWidgets[step_n] + + stepWidgets["stepLabel"].hide() + self.gridLayout.removeWidget(stepWidgets["stepLabel"]) + + stepWidgets["selector"].hide() + self.gridLayout.removeWidget(stepWidgets["selector"]) + + stepWidgets["infoButton"].hide() + self.gridLayout.removeWidget(stepWidgets["infoButton"]) + + # stepWidgets['addButton'].hide() + # self.gridLayout.removeWidget(stepWidgets['addButton']) + + stepWidgets["setParamsButton"].hide() + self.gridLayout.removeWidget(stepWidgets["setParamsButton"]) + + stepWidgets["delButton"].hide() + self.gridLayout.removeWidget(stepWidgets["delButton"]) + self.row -= 1 + + stepWidgets["hline"].hide() + self.gridLayout.removeWidget(stepWidgets["hline"]) + self.row -= 1 + + self.stepsWidgets.pop(step_n) + + stepsWidgetsMapper = {1: self.stepsWidgets[1]} + for i, stepWidgets in enumerate(self.stepsWidgets.values()): + if i == 0: + continue + step_n = i + 1 + label = stepWidgets["stepLabel"] + label.setText(f"Step {step_n}: ") + stepWidgets["delButton"].step_n = step_n + stepWidgets["selector"].step_n = step_n + stepsWidgetsMapper[step_n] = stepWidgets + + self.stepsWidgets = stepsWidgetsMapper + + self.resetStretch() + + def isChecked(self): + return self.groupbox.isChecked() + + def warnStepNotInit(self, method): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph( + f"The parameters for the preprocessing step {method} " + "were not initialized.

" + "Please, click on the corresponding Set step parameters " + "button to initialize this step (cog icon).

" + "Thank you for your patience!" + ) + msg.warning(self, "Params not initialized!", txt) + + def recipe(self, warn=True): + recipe = [] + if not self.groupbox.isChecked() and self.groupbox.isCheckable(): + return recipe + + for stepWidgets in self.stepsWidgets.values(): + method = stepWidgets["selector"].currentText() + step_kwargs = stepWidgets.get("step_kwargs") + if step_kwargs is None: + if warn: + self.warnStepNotInit(method) + return + + try: + init_func = config.PREPROCESS_INIT_MAPPER[method]["function"] + init_func(**step_kwargs) + except Exception as err: + pass + + recipe.append({"method": method, "kwargs": step_kwargs}) + + return recipe + + def recipeConfigPars(self, model_name): + cp = config.ConfigParser() + if not self.groupbox.isChecked() and self.groupbox.isCheckable(): + return cp + + for s, step in enumerate(self.recipe()): + section = f"{model_name}.preprocess.step{s + 1}" + cp[section] = {} + cp[section]["method"] = step["method"] + for option, value in step["kwargs"].items(): + cp[section][option] = str(value) + return cp + + +class CombineChannelsWidget(PreProcessParamsWidget): + sigValuesChangedCombineChannels = Signal() + + def __init__(self, channel_names: Iterable[str], parent=None): + self.channel_names = channel_names + + super().__init__(parent) + + self.parent = parent + qutils.delete_widget(self.loadLastRecipeButton) + qutils.delete_widget(self.saveRecipeButton) + qutils.delete_widget(self.loadRecipeButton) + + def addStep(self, is_first=False): + stepWidgets = {} + + self.row += 1 + if is_first: + self.row += 1 + + step_n = len(self.stepsWidgets) + 1 + tooltip = "Use this text in the formula" + if is_first: + label = QLabel("Formula var") + label.setToolTip(tooltip) + self.gridLayout.addWidget(label, self.row - 1, 1) + name_edit = QLineEdit(text=f"img{step_n}") + name_edit.setToolTip(tooltip) + self.gridLayout.addWidget(name_edit, self.row, 1) + stepWidgets["name_edit"] = name_edit + name_edit.textChanged.connect(self.emitValuesChanged) + + tooltip = "Select a channel or a segmentation mask" + if is_first: + label = QLabel("Channel") + label.setToolTip(tooltip) + self.gridLayout.addWidget(label, self.row - 1, 2) + ch_selector = QComboBox() + ch_selector.setToolTip(tooltip) + ch_selector.addItems(self.channel_names) + self.gridLayout.addWidget(ch_selector, self.row, 2) + stepWidgets["selector"] = ch_selector + ch_selector.currentTextChanged.connect(self.setBinarizeCheckableAndNorm) + + # add binarisaion spinbox + tooltip = ( + "If binarize is selected, the channel will be binarized first, before applying offset and multiplier.\n" + "If inverse binarize is selected, the channel will be binerized and " + "then the logical NOT will be applied." + ) + if is_first: + label = QLabel("Binarize") + label.setToolTip(tooltip) + self.gridLayout.addWidget(label, self.row - 1, 5) + options = ["No", "binarize", "inverse binarize"] + self.binarizeCombobox = QComboBox() + self.binarizeCombobox.addItems(options) + self.binarizeCombobox.setCurrentIndex(0) + self.binarizeCombobox.setEnabled(False) + self.binarizeCombobox.setToolTip(tooltip) + self.binarizeCombobox.currentIndexChanged.connect(self.emitValuesChanged) + self.gridLayout.addWidget(self.binarizeCombobox, self.row, 5) + stepWidgets["binarize"] = self.binarizeCombobox + + tooltip = "Min value of the channel to be normalized to." + if is_first: + label = QLabel("Min val") + label.setToolTip(tooltip) + self.gridLayout.addWidget(label, self.row - 1, 6) + self.minValueSpinbox = QDoubleSpinBox() + self.minValueSpinbox.setRange(-np.inf, np.inf) + self.minValueSpinbox.setSingleStep(0.1) + self.minValueSpinbox.setValue(0) + self.minValueSpinbox.setToolTip(tooltip) + + self.minValueSpinbox.valueChanged.connect(self.emitValuesChanged) + self.gridLayout.addWidget(self.minValueSpinbox, self.row, 6) + stepWidgets["minValueSpinbox"] = self.minValueSpinbox + + tooltip = "Max value of the channel to be normalized to." + if is_first: + label = QLabel("Max val") + label.setToolTip(tooltip) + self.gridLayout.addWidget(label, self.row - 1, 7) + self.maxValueSpinbox = QDoubleSpinBox() + self.maxValueSpinbox.setRange(-np.inf, np.inf) + self.maxValueSpinbox.setSingleStep(0.1) + self.maxValueSpinbox.setValue(1) + self.maxValueSpinbox.setToolTip(tooltip) + + self.maxValueSpinbox.valueChanged.connect(self.emitValuesChanged) + self.gridLayout.addWidget(self.maxValueSpinbox, self.row, 7) + stepWidgets["maxValueSpinbox"] = self.maxValueSpinbox + + if is_first: + addButton = widgets.addPushButton() + self.gridLayout.addWidget(addButton, self.row, 8) + addButton.clicked.connect(self.addStep) + stepWidgets["addButton"] = addButton + + else: + delButton = widgets.delPushButton() + self.gridLayout.addWidget(delButton, self.row, 8) + delButton.clicked.connect(self.removeStep) + delButton.step_n = step_n + stepWidgets["delButton"] = delButton + + self.row += 1 + ch_selector.row = self.row + ch_selector.step_n = step_n + + hline = widgets.QHLine() + self.gridLayout.addWidget(hline, self.row, 0, 1, 8) + stepWidgets["hline"] = hline + self.row += 1 + + self.stepsWidgets[step_n] = stepWidgets + + self.resetStretch() + self.sigValuesChangedCombineChannels.emit() + self.setBinarizeCheckableAndNorm() + + def emitValuesChanged(self, *args): + self.sigValuesChangedCombineChannels.emit() + + def setBinarizeCheckableAndNorm(self): + for step_n, stepWidgets in self.stepsWidgets.items(): + binarizeSelector = stepWidgets["binarize"] + channel = stepWidgets["selector"].currentText() + if "segm" in channel: + binarizeSelector.setEnabled(True) + # set min and max to 0 and 1 and disable + stepWidgets["minValueSpinbox"].setValue(0) + stepWidgets["maxValueSpinbox"].setValue(1) + stepWidgets["minValueSpinbox"].setEnabled(False) + stepWidgets["maxValueSpinbox"].setEnabled(False) + else: + binarizeSelector.setEnabled(False) + binarizeSelector.setCurrentIndex(0) + # set min and max to 0 and 1 and enable + stepWidgets["minValueSpinbox"].setEnabled(True) + stepWidgets["maxValueSpinbox"].setEnabled(True) + + self.emitValuesChanged() + + def removeStep(self, checked=False, step_n=None): + if step_n is None: + step_n = self.sender().step_n + + stepWidgets = self.stepsWidgets[step_n] + + stepWidgets["name_edit"].hide() + self.gridLayout.removeWidget(stepWidgets["name_edit"]) + + stepWidgets["selector"].hide() + self.gridLayout.removeWidget(stepWidgets["selector"]) + + stepWidgets["binarize"].hide() + self.gridLayout.removeWidget(stepWidgets["binarize"]) + + stepWidgets["minValueSpinbox"].hide() + self.gridLayout.removeWidget(stepWidgets["minValueSpinbox"]) + + stepWidgets["maxValueSpinbox"].hide() + self.gridLayout.removeWidget(stepWidgets["maxValueSpinbox"]) + + stepWidgets["delButton"].hide() + self.gridLayout.removeWidget(stepWidgets["delButton"]) + + self.row -= 1 + + stepWidgets["hline"].hide() + self.gridLayout.removeWidget(stepWidgets["hline"]) + self.row -= 1 + + self.stepsWidgets.pop(step_n) + + stepsWidgetsMapper = {1: self.stepsWidgets[1]} + for i, stepWidgets in enumerate(self.stepsWidgets.values()): + if i == 0: + continue + step_n = i + 1 + stepWidgets["delButton"].step_n = step_n + stepWidgets["selector"].step_n = step_n + stepsWidgetsMapper[step_n] = stepWidgets + + self.stepsWidgets = stepsWidgetsMapper + + self.resetStretch() + self.sigValuesChangedCombineChannels.emit() + + def steps(self): + steps = {} + if not self.groupbox.isChecked() and self.groupbox.isCheckable(): + return steps + + for step_number, stepWidgets in self.stepsWidgets.items(): + name = stepWidgets["name_edit"].text() + channel = stepWidgets["selector"].currentText() + binarize = stepWidgets["binarize"].currentText() + min_val = stepWidgets["minValueSpinbox"].value() + max_val = stepWidgets["maxValueSpinbox"].value() + steps[step_number] = { + "name": name, + "channel": channel, + "binarize": binarize, + "min_val": min_val, + "max_val": max_val, + } + + steps = dict(sorted(steps.items())) + return steps + + +class FormulaEditWidget(QWidget): + sigFormulaChanged = Signal(str, bool) # formula_str, is_valid + + def __init__(self, variable_names=None, parent=None): + super().__init__(parent) + self._variable_names = variable_names or [] + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(4) + + self._edit = QLineEdit() + self._edit.setPlaceholderText("e.g. img1 + img2 * 0.5") + layout.addWidget(self._edit) + + self._status_label = QLabel() + self._status_label.setWordWrap(True) + self._status_label.setStyleSheet("font-size: 11px;") + layout.addWidget(self._status_label) + + self._edit.textChanged.connect(self._onTextChanged) + self._clearStatus() + + self.parent = parent + + def setVariableNames(self, variable_names): + """Allows setting the variables. + + Parameters + ---------- + variable_names : list + list of variable names (strings) + """ + + self._variable_names = variable_names + self._onTextChanged(self._edit.text()) + + def text(self): + """Returns the current formula text.""" + return self._edit.text() + + def setText(self, text): + """Sets the formula text.""" + self._edit.setText(text) + + def _clearStatus(self): + self._status_label.setText("") + self._status_label.setStyleSheet("font-size: 11px;") + + def _onTextChanged(self, text): + if not text.strip(): + self._clearStatus() + + success, reconstructed_str = self.checkValidity(self._variable_names) + + if success: + self._status_label.setText(f"→ {reconstructed_str}") + self._status_label.setStyleSheet("font-size: 11px; color: green;") + else: + self._status_label.setText(reconstructed_str) + self._status_label.setStyleSheet("font-size: 11px; color: red;") + + self.sigFormulaChanged.emit(text, success) + + def checkValidity(self, variable_names=None): + if variable_names is None: + variable_names = self._variable_names + formula_str = self._edit.text() + arrays = {name: 1 for name in variable_names} + success = False + reconstructed_str = "ERROR" + forb_ch = self.parent.forbiddenChannels + if forb_ch: + stepsWidgets = self.parent.combineChannelsWidget.stepsWidgets + channels = { + stepsWidget["selector"].currentText() + for stepsWidget in stepsWidgets.values() + } + if forb_ch.intersection(channels): + reconstructed_str = ( + "Channels that are forbidden are not allowed to be used!:\n" + f"{forb_ch}" + ) + return False, reconstructed_str + if formula_str == "": + reconstructed_str = "First channel is returned/applied" + return True, reconstructed_str + try: + symbols = {name: sp.Symbol(name) for name in arrays} + expr = sp.sympify(formula_str, locals=symbols) + missing = {str(s) for s in expr.free_symbols} - arrays.keys() + if missing: + reconstructed_str = f"Missing variables: {missing}" + return False, reconstructed_str + + if formula_str == "": + reconstructed_str = "" + return True, reconstructed_str + + # filter out expressions that have no variables + if not any(s.is_Symbol for s in expr.free_symbols): + reconstructed_str = "No variables used" + return False, reconstructed_str + + reconstructed_str = str(expr) + success = True + except Exception as e: + if "syntax" in str(e): + reconstructed_str = f"Syntax error" + else: + reconstructed_str = str(e) + success = False + return success, reconstructed_str + + +class InitFijiMacroDialog(QBaseDialog): + def __init__(self, parent=None): + self.cancel = True + + super().__init__(parent=parent) + + mainLayout = QVBoxLayout() + + infoLabel = QLabel( + html_utils.paragraph( + """ + Place all the raw microscopy files in a folder without any other + file
+ and provide the following information: + """ + ) + ) + mainLayout.addWidget(infoLabel) + + gridLayout = QGridLayout() + + row = 0 + label = QLabel("Files internal structure: ") + gridLayout.addWidget(label, row, 0) + self.filesStructureCombobox = QComboBox() + self.filesStructureCombobox.addItems( + [ + 'Positions (aka "series") embedded in the file', + 'Positions (aka "series") separated, one for each file', + 'Positions (aka "series") and channels separated, one for each file', + ] + ) + gridLayout.addWidget(self.filesStructureCombobox, row, 1) + self.filesStructureCombobox.currentTextChanged.connect( + self.fileStructureChanged + ) + infoButton = widgets.infoPushButton() + gridLayout.addWidget(infoButton, row, 2) + infoButton.clicked.connect(self.showInfoFileStructure) + + row += 1 + label = QLabel("Folder with raw microscopy files: ") + gridLayout.addWidget(label, row, 0) + self.folderPathLineEdit = widgets.ElidingLineEdit() + gridLayout.addWidget(self.folderPathLineEdit, row, 1) + browseButton = widgets.browseFileButton(openFolder=True) + gridLayout.addWidget(browseButton, row, 2) + browseButton.sigPathSelected.connect( + partial(self.updateFolderPath, lineEdit=self.folderPathLineEdit) + ) + self.folderPathLineEdit.textChanged.connect(self.srcFolderPathChanged) + + row += 1 + label = QLabel("Destination folder: ") + gridLayout.addWidget(label, row, 0) + self.dstfolderPathLineEdit = widgets.ElidingLineEdit() + gridLayout.addWidget(self.dstfolderPathLineEdit, row, 1) + browseButton = widgets.browseFileButton(openFolder=True) + gridLayout.addWidget(browseButton, row, 2) + browseButton.sigPathSelected.connect(self.dstfolderPathLineEdit.setText) + + row += 1 + label = QLabel("Channel(s) name: ") + gridLayout.addWidget(label, row, 0) + self.channelNamesLineEdit = widgets.alphaNumericLineEdit(additionalChars=" ,") + gridLayout.addWidget(self.channelNamesLineEdit, row, 1) + checkButton = widgets.TestPushButton("Check") + gridLayout.addWidget(checkButton, row, 3) + checkButton.clicked.connect(self.checkChannelNames) + checkButton.setDisabled(True) + self.checkButton = checkButton + infoButton = widgets.infoPushButton() + gridLayout.addWidget(infoButton, row, 2) + infoButton.clicked.connect(self.showInfoChannelName) + + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + gridLayout.setColumnStretch(0, 0) + gridLayout.setColumnStretch(1, 1) + gridLayout.setColumnStretch(2, 0) + gridLayout.setColumnStretch(3, 0) + + mainLayout.addLayout(gridLayout) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + def fileStructureChanged(self, text): + self.checkButton.setDisabled(not "channels separated" in text) + + def checkChannelNames(self, checked=False): + proceed = self.validate() + if not proceed: + return + + src_folderpath = self.folderPath() + channel_names = self.channelNames() + extension = os.listdir(src_folderpath)[0].split(".")[-1] + basenames = io.move_separate_channels_tiffs_to_pos_folders( + src_folderpath, channel_names, get_only_basenames=True, extension=extension + ) + pos_folders_texts = [] + for p, basename in enumerate(basenames): + pos_folders_texts.append(f"Position_{p + 1}: {basename}") + + pos_folders_html_list = html_utils.to_list(pos_folders_texts, ordered=True) + text = html_utils.paragraph( + "The following Position folders will be created based on the provided channel names:
" + f"{pos_folders_html_list}" + ) + msg = widgets.myMessageBox(wrapText=False) + msg.information(self, "Position folders", text) + + def srcFolderPathChanged(self, text): + if self.dstfolderPathLineEdit.text(): + return + + folderPath = self.folderPathLineEdit.text() + self.dstfolderPathLineEdit.setText(folderPath) + + def showInfoFileStructure(self): + txt = html_utils.paragraph(""" + Select whether the microscopy files contains multiple "series".

+ This typically depends on how you acquired the images at the + microscope, i.e., you generated multiple microscopy files + (e.g., snapshots), or you setup automatic acquisition of multiple + positions. + """) + msg = widgets.myMessageBox(wrapText=False) + msg.information(self, "Files structure info", txt) + + def showInfoChannelName(self): + txt = html_utils.paragraph(""" + Enter the channels name. Separate multiple channels with a comma.

+ The channel names will be used to name the individual TIFF files + (one for each channel).

+ If multiple channels are embedded in the microscopy file, make sure that you write the channels in the right order.
+ If you are unsure, open the file in Fiji first + and check the order of channels.

+ If the channels are already separated, make sure to write the + full channel name as it appears in the file, including capitalization and spaces.
+ For example, if the files are named "pos1_ch1.tif", "pos1_ch2.tif", etc., the channels names should be "ch1, ch2".

+ After providing the channel names, you can check that they are correct by clicking on the "Check" button next to the channel names field.
+ The number of Positions that will be created will be displayed alongside the basename. + """) + msg = widgets.myMessageBox(wrapText=False) + msg.information(self, "Files structure info", txt) + + def updateFolderPath(self, path, lineEdit=""): + for file in os.listdir(path): + if not is_alphanumeric_filename(file): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph( + f""" + The filename {file} contains invalid + characters.

+ Valid characters are letters, numbers, spaces, underscores + and dashes.

+ Please rename the file and try again.

+ Thank you for your patience! + """ + ) + msg.critical(self, "Invalid filename", txt, path_to_browse=path) + lineEdit.setText("") + return + + lineEdit.setText(path) + + def warnPathEmpty(self, path_name): + txt = html_utils.paragraph(f""" + {path_name} cannot be empty. + """) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "Empty folder path", txt) + + def warnSelectedPathDoesNotExist(self, path): + txt = html_utils.paragraph(""" + The selected path does not exist.

+ Selected path: + """) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "Folder path does not exist", txt, commands=(path,)) + + def warnSelectedPathNotAFolder(self, path): + txt = html_utils.paragraph(""" + The selected path is not a folder.

+ Selected path: + """) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "Selected path not a folder", txt, commands=(path,)) + + def warnMultipleExtensionsPresent(self, path, extensions): + txt = html_utils.paragraph(f""" + The selected path contains files with different extensions. +

+ Extensions present: {extensions}

+ Please, make sure that all the files in the folder have the same + extension before proceeding.

+ Selected path: + """) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "Multiple file extensions detected", txt, commands=(path,)) + + def warnChannelNamesEmpty(self): + txt = html_utils.paragraph(""" + Channel(s) name cannot be empty. + """) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "Empty channel name", txt) + + def validate(self): + path = self.folderPath() + dst_path = self.dstfolderPathLineEdit.text() + paths = { + "Source folder": path, + "Destination folder": dst_path, + } + for _path_name, _path in paths.items(): + if not _path: + self.warnPathEmpty(_path_name) + return False + + if not os.path.exists(_path): + self.warnSelectedPathDoesNotExist(_path) + return False + + if not os.path.isdir(_path): + self.warnSelectedPathNotAFolder(_path) + return False + + files = utils.listdir(path) + extensions = set([os.path.splitext(file)[1] for file in files]) + if len(extensions) > 1: + self.warnMultipleExtensionsPresent(path, extensions) + return False + + if not self.channelNamesLineEdit.text(): + self.warnChannelNamesEmpty() + return False + + return True + + def folderPath(self): + return self.folderPathLineEdit.text() + + def channelNames(self): + channel_names = self.channelNamesLineEdit.text().split(",") + channel_names = [ch.strip() for ch in channel_names] + return channel_names + + def ok_cb(self): + proceed = self.validate() + if not proceed: + return + + self.selectedFolderPath = self.folderPath() + self.filesStructure = self.filesStructureCombobox.currentText() + is_multiple_files = self.filesStructure.find("separated") != -1 + is_separate_channels = "channels separated" in self.filesStructure + dst_folderpath = self.dstfolderPathLineEdit.text() + self.init_macro_args = ( + self.folderPath(), + is_multiple_files, + is_separate_channels, + dst_folderpath, + self.channelNames(), + ) + self.cancel = False + self.close() + + +class ImageJRoisToSegmManager(QBaseDialog): + def __init__( + self, + rois_filepath, + TZYX_shape, + addUseSamePropsForNextPosButton=False, + parent=None, + ): + import roifile + + self.cancel = True + super().__init__(parent) + + self.setWindowTitle("ROI Manager") + + mainLayout = QVBoxLayout() + + rois = roifile.roiread(rois_filepath) + self.rois = {roi.name: roi for roi in rois} + + roisNamesTreeWidget = widgets.TreeWidget() + roisNamesTreeWidget.setHeaderLabels(["ROI name", "Cell_ID"]) + roisNamesTreeWidget.header().setSectionResizeMode(QHeaderView.ResizeToContents) + # roisNamesTreeWidget.header().setStretchLastSection(False) + for r, roi in enumerate(rois): + item = widgets.TreeWidgetItem() + item.setText(0, roi.name) + item.setText(1, str(r + 1)) + roisNamesTreeWidget.addTopLevelItem(item) + roisNamesTreeWidget.setSelectionMode( + QAbstractItemView.SelectionMode.ExtendedSelection + ) + roisNamesTreeWidget.selectAll() + mainLayout.addWidget(QLabel("Select ROIs to convert")) + mainLayout.addWidget(roisNamesTreeWidget) + self.roisNamesTreeWidget = roisNamesTreeWidget + mainLayout.addSpacing(10) + mainLayout.addWidget(widgets.QHLine()) + mainLayout.addSpacing(5) + + gridLayout = None + self.lowZspinbox = None + + SizeT, SizeZ, SizeY, SizeX = TZYX_shape + if SizeZ > 1: + gridLayout = QGridLayout() + self.lowZspinbox = widgets.SpinBox() + self.lowZspinbox.setMinimum(0) + self.lowZspinbox.setMaximum(SizeZ - 1) + + self.highZspinbox = widgets.SpinBox() + self.highZspinbox.setMinimum(0) + self.highZspinbox.setMaximum(SizeZ - 1) + self.highZspinbox.setValue(SizeZ - 1) + + gridLayout.addWidget(QLabel("Repeat 2D ROIs over z-range: "), 1, 0) + + gridLayout.addWidget(QLabel("Start z-slice"), 0, 1) + gridLayout.addWidget(self.lowZspinbox, 1, 1) + + gridLayout.addWidget(QLabel("Stop z-slice"), 0, 2) + gridLayout.addWidget(self.highZspinbox, 1, 2) + + if gridLayout is not None: + mainLayout.addLayout(gridLayout) + mainLayout.addSpacing(5) + mainLayout.addWidget(widgets.QHLine()) + mainLayout.addSpacing(10) + + self.rescaleRoisGroupbox = widgets.RescaleImageJroisGroupbox(TZYX_shape) + self.rescaleRoisGroupbox.setChecked(False) + mainLayout.addWidget(self.rescaleRoisGroupbox) + + buttonsLayout = widgets.CancelOkButtonsLayout() + + self.useSamePropsForNextPos = False + if addUseSamePropsForNextPosButton: + useSamePropsForNextPosButton = widgets.reloadPushButton( + "Keep the same preferences for all next Positions" + ) + buttonsLayout.insertWidget(3, useSamePropsForNextPosButton) + useSamePropsForNextPosButton.clicked.connect( + self.useSamePropsForNextPosClicked + ) + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + def useSamePropsForNextPosClicked(self): + self.useSamePropsForNextPos = True + self.ok_cb() + + def warnRoiSelectionEmpty(self): + txt = html_utils.paragraph(f""" + You did not select any ROI.

+ ROIs selection cannot be empty. Thank you for your patience! + """) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "ROIs selection empty", txt) + + def ok_cb(self): + selectedRois = self.roisNamesTreeWidget.selectedItems() + if not selectedRois: + self.useSamePropsForNextPos = False + self.warnRoiSelectionEmpty() + return + + self.IDsToRoisMapper = {} + for item in selectedRois: + roiName = item.text(0) + ID = int(item.text(1)) + self.IDsToRoisMapper[ID] = self.rois[roiName] + + numRois = self.roisNamesTreeWidget.topLevelItemCount() + self.areAllRoisSelected = len(self.IDsToRoisMapper) == numRois + + self.rescaleSizes = self.rescaleRoisGroupbox.inputOutputSizes() + self.repeatRoisZslicesRange = None + if self.lowZspinbox is not None: + self.repeatRoisZslicesRange = ( + self.lowZspinbox.value(), + self.highZspinbox.value() + 1, + ) + + self.cancel = False + self.close() + + +class ResizeUtilProps(QBaseDialog): + def __init__(self, input_path="", parent=None): + self.cancel = True + super().__init__(parent) + + self.setWindowTitle("Resize Data Properties") + + mainLayout = QVBoxLayout() + + paramsLayout = QGridLayout() + + self._input_path = input_path + + row = 0 + paramsLayout.addWidget(QLabel("Overwrite raw data: "), row, 0) + self.overwriteToggle = widgets.Toggle() + self.overwriteToggle.setChecked(True) + paramsLayout.addWidget( + self.overwriteToggle, row, 1, 1, 2, alignment=Qt.AlignCenter + ) + + row += 1 + paramsLayout.addWidget(QLabel("Folder path for resized images: "), row, 0) + self.folderPathOutControl = widgets.filePathControl( + browseFolder=True, + fileManagerTitle="Select folder where to save resized data", + elide=True, + startFolder=utils.getMostRecentPath(), + ) + self.folderPathOutControl.setDisabled(True) + paramsLayout.addWidget(self.folderPathOutControl, row, 1, 1, 2) + + row += 1 + paramsLayout.addWidget(QLabel("Text to append to files: "), row, 0) + self.textToAppendLineEdit = widgets.alphaNumericLineEdit() + self.textToAppendLineEdit.setAlignment(Qt.AlignCenter) + self.textToAppendLineEdit.setDisabled(True) + paramsLayout.addWidget(self.textToAppendLineEdit, row, 1, 1, 2) + + row += 1 + paramsLayout.addWidget(QLabel("Resize mode: "), row, 0) + self.downScaleRadioButton = QRadioButton("Downscale") + self.upScaleRadioButton = QRadioButton("Upscale") + self.downScaleRadioButton.setChecked(True) + paramsLayout.addWidget( + self.downScaleRadioButton, row, 1, alignment=Qt.AlignCenter + ) + paramsLayout.addWidget( + self.upScaleRadioButton, row, 2, alignment=Qt.AlignCenter + ) + + row += 1 + paramsLayout.addWidget(QLabel("Resize factor: "), row, 0) + self.factorSpinbox = widgets.FloatLineEdit(allowNegative=False) + self.factorSpinbox.setMinimum(1.0) + self.factorSpinbox.setValue(2.0) + paramsLayout.addWidget(self.factorSpinbox, row, 1, 1, 2) + + paramsLayout.setColumnStretch(0, 0) + paramsLayout.setVerticalSpacing(10) + + self.overwriteToggle.toggled.connect(self.overwriteToggled) + + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addLayout(paramsLayout) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + mainLayout.addStretch(1) + + # self.textToAppendLineEdit.setText(self._getDefaultTextToAppend()) + + self.setLayout(mainLayout) + + def _getDefaultTextToAppend(self): + rescale_mode = "up" if self.upScaleRadioButton.isChecked() else "down" + factor = self.factorSpinbox.value() + text = f"{rescale_mode}scaled_factor_{factor}" + return text + + def overwriteToggled(self, checked): + self.folderPathOutControl.setDisabled(checked) + self.textToAppendLineEdit.setDisabled(checked) + if checked: + text = "" + else: + text = self._getDefaultTextToAppend() + self.textToAppendLineEdit.setText(text) + + def warnFolderPathEmpty(self): + txt = html_utils.paragraph(""" + To prevent overwriting raw data the Folder path for + resized images cannot be empty. + """) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "Empty folder path", txt) + + def warnTextToAppendEmpty(self): + txt = html_utils.paragraph(""" + To prevent overwriting raw data the text to append + cannot be empty. + """) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "Empty text to append", txt) + + def ok_cb(self): + self.expFolderpathOut = self.folderPathOutControl.path() + self.textToAppend = self.textToAppendLineEdit.text() + isAccidentalOverwrite = ( + not self.overwriteToggle.isChecked() + and self.expFolderpathOut == self._input_path + and not self.textToAppend + ) + if isAccidentalOverwrite: + self.warnTextToAppendEmpty() + return + + if self.textToAppend and not self.textToAppend.startswith("_"): + self.textToAppend = f"_{self.textToAppend}" + + if self.overwriteToggle.isChecked(): + self.expFolderpathOut = None + + factor = self.factorSpinbox.value() + self.resizeFactor = ( + factor if self.upScaleRadioButton.isChecked() else 1 / factor + ) + + self.cancel = False + self.close() + + +class FucciPreprocessDialog(FunctionParamsDialog): + def __init__( + self, + channel_names, + df_metadata=None, + parent=None, + ): + + from cellacdc.preprocess import fucci_filter + + params_argspecs = utils.get_function_argspec(fucci_filter) + + super().__init__( + params_argspecs, + function_name="FUCCI pre-processing", + df_metadata=df_metadata, + parent=parent, + ) + + channelNamesLayout = QGridLayout() + + row = 0 + label = QLabel("First channel name: ") + channelNamesLayout.addWidget(label, row, 0, alignment=Qt.AlignLeft) + self.firstChNameWidget = QComboBox() + self.firstChNameWidget.addItems(channel_names) + channelNamesLayout.addWidget(self.firstChNameWidget, row, 1) + + row += 1 + label = QLabel("Second channel name: ") + channelNamesLayout.addWidget(label, row, 0, alignment=Qt.AlignLeft) + self.secondChNameWidget = QComboBox() + self.secondChNameWidget.addItems(channel_names) + self.secondChNameWidget.setCurrentText(list(channel_names)[1]) + channelNamesLayout.addWidget(self.secondChNameWidget, row, 1) + + channelNamesLayout.setColumnStretch(0, 0) + channelNamesLayout.setColumnStretch(1, 1) + + self.mainLayout.insertLayout(0, channelNamesLayout) + self.mainLayout.insertWidget(1, widgets.QHLine()) + + def ok_cb(self): + self.firstChannelName = self.firstChNameWidget.currentText() + self.secondChannelName = self.secondChNameWidget.currentText() + super().ok_cb() + + +class PreProcessRecipeDialog(QBaseDialog): + sigApplyImage = Signal(object) + sigApplyZstack = Signal(object) + sigApplyAllFrames = Signal(object) + sigApplyAllPos = Signal(object) + sigPreviewToggled = Signal(bool) + sigValuesChanged = Signal(list) + sigSavePreprocData = Signal(object) + sigClose = Signal(object) + + def __init__( + self, + isTimelapse=False, + isZstack=False, + isMultiPos=False, + df_metadata=None, + addApplyButton=False, + parent=None, + hideOnClosing=False, + ): + super().__init__(parent=parent) + + self.setWindowTitle("Pre-processing recipe") + + self.cancel = True + self.hideOnClosing = hideOnClosing + + mainLayout = QVBoxLayout() + + keepInputDataTypeLayout = QHBoxLayout() + self.keepInputDataTypeToggle = widgets.Toggle() + self.keepInputDataTypeToggle.setChecked(True) + self.keepInputDataTypeToggle.toggled.connect(self.emitValuesChanged) + + keepInputDataTypeLayout.addStretch(1) + keepInputDataTypeLayout.addWidget(QLabel("Keep input data type: ")) + keepInputDataTypeLayout.addWidget(self.keepInputDataTypeToggle) + keepInputDataTypeInfoButton = widgets.infoPushButton() + keepInputDataTypeLayout.addWidget(keepInputDataTypeInfoButton) + keepInputDataTypeInfoButton.clicked.connect(self.showInfoKeepInputDataType) + self.keepInputDataTypeLayout = keepInputDataTypeLayout + + self.preProcessParamsWidget = PreProcessParamsWidget( + df_metadata=df_metadata, addApplyButton=addApplyButton, parent=self + ) + self.preProcessParamsWidget.groupbox.setCheckable(False) + + buttonsLayout = QGridLayout() # self.preProcessParamsWidget.buttonsLayout + self.buttonsLayout = buttonsLayout + self.previewCheckbox = QCheckBox("Preview") + buttonsLayout.addWidget(self.previewCheckbox, 0, 0) + + # Relocate buttons of PreProcessParamsWidget to this dialog + pPPWBL = self.preProcessParamsWidget.buttonsLayout + loadRecipeButtIdx = pPPWBL.indexOf(self.preProcessParamsWidget.loadRecipeButton) + self.loadRecipeButton = pPPWBL.takeAt(loadRecipeButtIdx).widget() + buttonsLayout.addWidget(self.loadRecipeButton, 0, 1) + + saveRecipeButtIdx = pPPWBL.indexOf(self.preProcessParamsWidget.saveRecipeButton) + self.saveRecipeButton = pPPWBL.takeAt(saveRecipeButtIdx).widget() + buttonsLayout.addWidget(self.saveRecipeButton, 1, 1) + + loadLastRecipeButtIdx = pPPWBL.indexOf( + self.preProcessParamsWidget.loadLastRecipeButton + ) + self.loadLastRecipeButton = pPPWBL.takeAt(loadLastRecipeButtIdx).widget() + buttonsLayout.addWidget(self.loadLastRecipeButton, 1, 0) + + self.loadLastRecipeButton.hide() + + # self.cancelButton = widgets.cancelPushButton('Cancel') + # buttonsLayout.insertWidget(2, self.cancelButton) + # buttonsLayout.insertSpacing(3, 20) + + self.allButtons = [ + self.previewCheckbox, + self.loadRecipeButton, + self.saveRecipeButton, + ] + col = 3 + row = 0 + self.applyCurrentFrameButton = widgets.okPushButton("Apply to displayed image") + buttonsLayout.addWidget(self.applyCurrentFrameButton, row, col) + self.applyCurrentFrameButton.clicked.connect( + partial(self.apply, signal=self.sigApplyImage) + ) + self.allButtons.append(self.applyCurrentFrameButton) + + infoLayout = QHBoxLayout() + buttonsHeight = self.applyCurrentFrameButton.sizeHint().height() + self.loadingCircle = widgets.LoadingCircleAnimation(size=buttonsHeight) + sp = self.loadingCircle.sizePolicy() + sp.setRetainSizeWhenHidden(True) + self.loadingCircle.setSizePolicy(sp) + self.loadingCircle.setVisible(False) + infoLayout.addWidget(self.loadingCircle) + + self.infoLabel = QLabel("(Feel free to use Cell-ACDC while waiting)") + sp = self.infoLabel.sizePolicy() + sp.setRetainSizeWhenHidden(True) + self.infoLabel.setSizePolicy(sp) + self.infoLabel.hide() + infoLayout.addWidget(self.infoLabel) + + buttonsLayout.addLayout( + infoLayout, row + 1, 0, 3, 2, alignment=Qt.AlignBottom | Qt.AlignLeft + ) + + if isZstack: + row += 1 + self.applyAllZslicesButton = widgets.threeDPushButton( + "Apply to all z-slices of current image" + ) + buttonsLayout.addWidget(self.applyAllZslicesButton, row, col) + self.applyAllZslicesButton.clicked.connect(self.applyAllZslices) + self.allButtons.append(self.applyAllZslicesButton) + if isTimelapse: + row += 1 + self.applyAllFramesButton = widgets.futurePushButton("Apply to all frames") + buttonsLayout.addWidget(self.applyAllFramesButton, row, col) + self.applyAllFramesButton.clicked.connect(self.applyAllFrames) + self.allButtons.append(self.applyAllFramesButton) + if isMultiPos: + row += 1 + self.applyAllPosButton = widgets.futurePushButton("Apply to all Positions") + buttonsLayout.addWidget(self.applyAllPosButton, row, col) + self.applyAllPosButton.clicked.connect( + partial(self.apply, signal=self.sigApplyAllPos) + ) + self.allButtons.append(self.applyAllPosButton) + + row += 1 + self.savePreprocButton = widgets.savePushButton("Save pre-processed data...") + buttonsLayout.addWidget(self.savePreprocButton, row, col) + + self.allButtons.append(self.savePreprocButton) + self.savePreprocButton.clicked.connect(self.emitSignalSavePreprocData) + + self.previewCheckbox.toggled.connect(self.emitSigPreviewToggled) + self.preProcessParamsWidget.sigValuesChanged.connect(self.emitValuesChanged) + + # self.cancelButton.clicked.connect(self.close) + + mainLayout.addLayout(keepInputDataTypeLayout) + mainLayout.addSpacing(20) + mainLayout.addWidget(self.preProcessParamsWidget) + mainLayout.addLayout(buttonsLayout) + self.mainLayout = mainLayout + + self.setLayout(mainLayout) + + def applyAllZslices(self, checked=False): + # Preview needs to be turned off because we are computing on every + # z-slice + self.previewCheckbox.setChecked(False) + self.apply(signal=self.sigApplyZstack) + + def applyAllFrames(self, checked=False): + # Preview needs to be turned off because we are computing on all frames + self.previewCheckbox.setChecked(False) + self.apply(signal=self.sigApplyAllFrames) + + def emitSigPreviewToggled(self): + self.sigPreviewToggled.emit(self.previewCheckbox.isChecked()) + + def showInfoKeepInputDataType(self): + txt = html_utils.paragraph(""" + If checked, the data type of the pre-processed data will be + the same as the input data type.

+ This is useful to avoid saving the pre-processed data as + floating-point numbers (e.g., 32-bit float) which might + increase the file size.

+ We recommend keeping this option checked. + """) + msg = widgets.myMessageBox(wrapText=False) + msg.information(self, "Keep input data type", txt) + + def emitSignalSavePreprocData(self): + self.sigSavePreprocData.emit(self) + + def emitValuesChanged(self): + recipe = self.recipe(warn=False) + if recipe is None: + return + + self.sigValuesChanged.emit(recipe) + + def setDisabled(self, disabled: bool): + self.preProcessParamsWidget.setDisabled(disabled) + self.loadingCircle.setVisible(disabled) + self.infoLabel.setVisible(disabled) + for button in self.allButtons: + try: + button.setDisabled(disabled) + except RuntimeError as e: + printl(traceback.format_exc()) + printl(f"Error: {e}") + printl(f"Button: {button}") + + def apply(self, checked=False, signal: Signal = None): + recipe = self.recipe() + if recipe is None: + return + + if signal is not None: + signal.emit(recipe) + + if self.hideOnClosing: + self.setDisabled(True) + self.infoLabel.setText( + f"{self.sender().text().replace('Apply', 'Applying')}...
" + "(Feel free to use Cell-ACDC while waiting)" + ) + else: + self.ok_cb() + + def appliedFinished(self): + self.setDisabled(False) + + def recipe(self, warn=True): + recipe = self.preProcessParamsWidget.recipe(warn=warn) + if recipe is None: + return + + for step in recipe: + step["keep_input_data_type"] = self.keepInputDataTypeToggle.isChecked() + return recipe + + def recipeConfigPars(self): + return self.preProcessParamsWidget.recipeConfigPars("acdc") + + def ok_cb(self): + if self.hideOnClosing: + self.hide() + return + + self.cancel = False + self.close() + + def close(self): + super().close() + self.sigClose.emit(self) + + +class PreProcessRecipeDialogUtil(PreProcessRecipeDialog): + def __init__( + self, + channel_names: Iterable[str], + df_metadata=None, + parent=None, + ): + self.cancel = True + + super().__init__( + isTimelapse=False, + isZstack=False, + isMultiPos=False, + addApplyButton=False, + df_metadata=df_metadata, + parent=parent, + hideOnClosing=False, + ) + + self.listSelector = widgets.listWidget( + isMultipleSelection=True, minimizeHeight=True + ) + self.listSelector.addItems(channel_names) + self.listSelector.setCurrentRow(0) + + self.mainLayout.insertWidget(0, self.listSelector) + self.mainLayout.insertWidget(0, QLabel("Select channel(s) to pre-process:")) + self.mainLayout.insertSpacing(2, 10) + self.mainLayout.insertWidget(2, widgets.QHLine()) + + self.savePreprocButton.hide() + self.previewCheckbox.hide() + self.applyCurrentFrameButton.setText("Ok") + + buttonsLayout = self.preProcessParamsWidget.buttonsLayout + + saveRecipeButtonIndex = buttonsLayout.indexOf( + self.preProcessParamsWidget.saveRecipeButton + ) + + if saveRecipeButtonIndex == -1: + return + + saveRecipeButtonItem = buttonsLayout.takeAt(saveRecipeButtonIndex) + + buttonsLayout.addItem(saveRecipeButtonItem, 0, 2) + + def warnChannelSelectionEmpty(self): + txt = html_utils.paragraph(""" + You did not select any channel.

+ Channel selection cannot be empty.

+ Thank you for your patience! + """) + + def ok_cb(self): + selectedChannelItems = self.listSelector.selectedItems() + if not selectedChannelItems: + self.warnChannelSelectionEmpty() + + recipe = self.recipe() + if recipe is None: + return + + self.selectedRecipe = recipe + self.selectedChannels = [item.text() for item in selectedChannelItems] + + self.cancel = False + self.close() + + +class CombineChannelsSetupDialog(PreProcessRecipeDialog): + sigApplyImage = Signal(dict, bool, str) + sigApplyZstack = Signal(dict, bool, str) + sigApplyAllFrames = Signal(dict, bool, str) + sigApplyAllPos = Signal(dict, bool, str) + sigValuesChanged = Signal() + sigSaveAsSegmCheckboxToggled = Signal(bool) + + # sigApplyAllZslices = Signal(dict, bool, str) + # sigApplyAllFramesZslices = Signal(dict, bool, str) + + def __init__( + self, + channel_names, + df_metadata=None, + parent=None, + hideOnClosing=False, + isTimelapse=False, + isZstack=False, + isMultiPos=False, + ): + + self.combineChannelsWidget = CombineChannelsWidget(channel_names, parent=self) + self.warnExistingRecipeFile = self.combineChannelsWidget.warnExistingRecipeFile + self.communicateSavingRecipeFinished = ( + self.combineChannelsWidget.communicateSavingRecipeFinished + ) + self.saveRecipeUI = self.combineChannelsWidget.saveRecipeUI + self.selectRecipeFilepath = self.combineChannelsWidget.selectRecipeFilepath + + super().__init__( + isTimelapse=isTimelapse, + isZstack=isZstack, + isMultiPos=isMultiPos, + df_metadata=df_metadata, + parent=parent, + hideOnClosing=hideOnClosing, + ) + + self.combineChannelsWidget.sigValuesChangedCombineChannels.connect( + self.emitValuesChangedSteps + ) + + self.segm_blinked = False + self.validFormula = True # allow empty formula + self.forbiddenChannels = set() # channels that cannot be combined + + self.mainLayout.setSpacing(4) + + self.mainLayout.insertWidget(2, self.combineChannelsWidget) + self.combineChannelsWidget.groupbox.setCheckable(False) + self.combineChannelsWidget.groupbox.setTitle( + "Combine and manipulate channels and/or segmentation files" + ) + + self.formulaEditWidget = FormulaEditWidget(parent=self) + self._updateFormulaVariableNames() + self.formulaEditWidget.sigFormulaChanged.connect(self.formulaChanged) + self.formulaEditWidget.setToolTip( + 'Enter a formula to combine the channels. For example "img1 + img2 * 0.5"' + ) + self.mainLayout.insertWidget(3, self.formulaEditWidget) + + buttonsLayoutSaveGroup = QGridLayout() + + row = 0 + col = 0 + loadRecipeButton = widgets.OpenFilePushButton("Load saved recipe") + self.loadRecipeButtonComb = loadRecipeButton + buttonsLayoutSaveGroup.addWidget(loadRecipeButton, row, col) + self.loadRecipeButtonComb.clicked.connect(self.selectAndLoadRecipe) + + col += 1 + saveRecipeButton = widgets.savePushButton("Save current recipe") + self.saveRecipeButtonComb = saveRecipeButton + buttonsLayoutSaveGroup.addWidget(saveRecipeButton, row, col) + saveRecipeButton.clicked.connect(self.saveRecipe) + saveRecipeButton.setToolTip( + "Save the current recipe to a file\n" + f"Location: {combine_channels_recipes_path}" + ) + + col += 1 + loadLastRecipeButton = widgets.reloadPushButton("Load last recipe") + self.loadLastRecipeButtonComb = loadLastRecipeButton + buttonsLayoutSaveGroup.addWidget(loadLastRecipeButton, row, col) + self.mainLayout.addLayout(buttonsLayoutSaveGroup) + loadLastRecipeButton.clicked.connect(self.loadLastRecipe) + self.setLoadLastRecipe() + + loadLastRecipeButton.setContextMenuPolicy( + Qt.ContextMenuPolicy.CustomContextMenu + ) + loadLastRecipeButton.customContextMenuRequested.connect( + self._showLoadRecipeContextMenu + ) + + self.cancel = True + + self.setWindowTitle("Combine and manipulate channels and/or segmentation files") + self.preProcessParamsWidget.hide() + self.mainLayout.removeWidget(self.preProcessParamsWidget) + + self.savePreprocButton.setText("Save combined data...") + + tooltip = ( + "Save as a segmentation file, for example " + "when combining a binary mask with a segmentation mask." + ) + label = QLabel("Save as segmentation:") + self.saveAsSegmlabel = label + label.setToolTip(tooltip) + self.saveAsSegmCheckbox = widgets.Toggle() + self.saveAsSegmCheckbox.setToolTip(tooltip) + self.saveAsSegmCheckbox.setChecked(False) + self.saveAsSegmCheckbox.setEnabled(False) + self.saveAsSegmCheckbox.toggled.connect(self.emitSaveAsSegmCheckboxToggled) + + self.keepInputDataTypeLayout.insertWidget(0, label) + self.keepInputDataTypeLayout.insertWidget(1, self.saveAsSegmCheckbox) + + def setLoadLastRecipe(self): + filepath = self._lastRecipePath() + if not os.path.exists(filepath): + self.loadLastRecipeButtonComb.setEnabled(False) + + def returLoadSecondLastRecipe(self): + filepath = self._secondLastRecipePath() + if not os.path.exists(filepath): + return False + return True + + def _showLoadRecipeContextMenu(self, pos): + menu = QMenu(self) + action = menu.addAction("Load recipe from before the last one") + action.triggered.connect(self.loadPreviousRecipe) + action.setEnabled(self.returLoadSecondLastRecipe()) + menu.exec(self.loadLastRecipeButtonComb.mapToGlobal(pos)) + + def loadPreviousRecipe(self): + filepath = self._secondLastRecipePath() + if not os.path.exists(filepath): + return + + self.loadRecipe(filepath) + + def loadLastRecipe(self): + filepath = self._lastRecipePath() + if not os.path.exists(filepath): + return + + self.loadRecipe(filepath) + + def saveLastRecipe(self): + os.makedirs(combine_channels_recipes_path, exist_ok=True) + filepath = self._lastRecipePath() + + same = False + if os.path.exists(filepath): + steps_curr = self._getSaveRecipyDict() + with open(filepath, "r") as f: + steps_prev = json.load(f) + same = self._recipesMatch(steps_curr, steps_prev) + + if same: + return + + if os.path.exists(filepath): + new_filename = self._secondLastRecipePath() + if os.path.exists(new_filename): + os.remove(new_filename) + os.rename(filepath, new_filename) + self.saveRecipe(filepath=filepath) + + def _recipesMatch(self, steps_curr, steps_prev): + # Normalize current dict to strings for comparison with JSON-loaded dict + def normalize(d): + return {str(k): str(v) for k, v in d.items()} + + for raw_key in steps_curr: + key = str(raw_key) + if key not in steps_prev: + return False + if key in ("formula", "keep_input_data_type", "save_as_segm"): + if str(steps_curr[raw_key]) != str(steps_prev[key]): + return False + else: + step_dict = normalize(steps_curr[raw_key]) + step_dict_prev = steps_prev[key] + for key2, val2 in step_dict.items(): + if key2 not in step_dict_prev: + return False + if val2 != str(step_dict_prev[key2]): + return False + return True + + def _lastRecipePath(self): + return os.path.join( + combine_channels_recipes_path, ".last_combine_channels_recipe.json" + ) + + def _secondLastRecipePath(self): + return os.path.join( + combine_channels_recipes_path, ".previous_combine_channels_recipe.json" + ) + + def _getSaveRecipyDict(self): + steps = self.combineChannelsWidget.steps() # already returns a copy + formula = self.formulaEditWidget.text() + steps["formula"] = formula + steps["keep_input_data_type"] = self.keepInputDataTypeToggle.isChecked() + steps["save_as_segm"] = self.saveAsSegmCheckbox.isChecked() + return steps + + def saveRecipe(self, dummy=None, filepath=None): + os.makedirs(combine_channels_recipes_path, exist_ok=True) + + filepath_provided = filepath is not None + if not filepath_provided: + folder_content = utils.listdir(combine_channels_recipes_path) + num_recipes = len(folder_content) + default_text = f"{num_recipes + 1}" + proceed, filepath = self.saveRecipeUI( + combine_channels_recipes_path, + ".json", + "Save recipe", + "combine_channels_recipe", + "Insert a filename for the recipe:", + default_text, + ) + + if not proceed: + return + + steps = self._getSaveRecipyDict() + + with open(filepath, "w") as f: + json.dump(steps, f, indent=2) + + if not filepath_provided: + self.communicateSavingRecipeFinished(filepath) + + def selectAndLoadRecipe(self): + filepath = self.selectRecipeFilepath( + combine_channels_recipes_path, "combine_channels_recipe", "JSON", "json" + ) + if filepath is None: + return + + self.loadRecipe(filepath) + + def loadRecipe(self, filepath): + with open(filepath, "r") as f: + recipe = json.load(f) + + recipe = dict(sorted(recipe.items())) + keys_used = set() + for key, value in recipe.items(): + if key == "formula": + formula = value + continue + if key == "keep_input_data_type": + self.keepInputDataTypeToggle.setChecked(value) + continue + if key == "save_as_segm": + self.saveAsSegmCheckbox.setChecked(value) + continue + + name = value["name"] + channel = value["channel"] + binarize = value["binarize"] + min_val = float(value["min_val"]) + max_val = float(value["max_val"]) + key = int(key) + stepWidgetsNum = len(self.combineChannelsWidget.stepsWidgets) + if key > stepWidgetsNum: + self.combineChannelsWidget.addStep() + + stepWidgets = self.combineChannelsWidget.stepsWidgets[key] + idx = stepWidgets["selector"].findText(channel) + if idx == -1: + stepWidgets["selector"].addItem(channel) + # stepWidgets['selector'].forbiddenItems.add(channel) + blinker = qutils.QControlBlink(stepWidgets["selector"], qparent=self) + blinker.start() + stepWidgets["selector"].blinker = blinker + self.forbiddenChannels.add(channel) + + stepWidgets["selector"].setCurrentText(channel) + stepWidgets["name_edit"].setText(name) + stepWidgets["binarize"].setCurrentText(binarize) + stepWidgets["minValueSpinbox"].setValue(min_val) + stepWidgets["maxValueSpinbox"].setValue(max_val) + + keys_used.add(key) + + # remove extra steps + keys_present = set(range(1, len(self.combineChannelsWidget.stepsWidgets) + 1)) + extra_keys = keys_present - keys_used + extra_keys = list(extra_keys) + extra_keys.sort(reverse=True) + for key in extra_keys: + self.combineChannelsWidget.removeStep(step_n=key) + # updates key dynamically so I have to rely that missing indx are always last steps + + # update formula + self.formulaEditWidget.setText(formula) + + for stepWidgets in self.combineChannelsWidget.stepsWidgets.values(): + combo = stepWidgets["selector"] + # set forbidden channels red in all steps + for i in range(combo.count()): + item = combo.itemText(i) + if item in self.forbiddenChannels: + combo.setItemData(i, QColor("red"), Qt.ForegroundRole) + + def _updateFormulaVariableNames(self): + names = [ + stepWidgets["name_edit"].text() + for stepWidgets in self.combineChannelsWidget.stepsWidgets.values() + ] + self.formulaEditWidget.setVariableNames(names) + + def formulaChanged(self, formula_str, is_valid): + self.setButtonsEnabled(is_valid) + self.validFormula = is_valid + if is_valid: + self.sigValuesChanged.emit() + + def setButtonsEnabled(self, enabled): + for i in range(self.buttonsLayout.count()): + item = self.buttonsLayout.itemAt(i) + widget = item.widget() + if widget is None: + continue + if isinstance(widget, QPushButton): + label = widget.text().lower().rstrip().lstrip() + if "apply" in label or "save" in label or "ok" in label: + if enabled: + try: + widget.setEnabled(True) + except: + pass + else: + try: + widget.setDisabled(True) + except: + pass + + def saveAsSegm(self): + return self.saveAsSegmCheckbox.isChecked() + + def emitSaveAsSegmCheckboxToggled(self): + if self.validFormula: + self.sigSaveAsSegmCheckboxToggled.emit(self.saveAsSegm()) + + def autoCheckSaveAsSegmCheckbox(self): + any_not_seg = False + for step in self.combineChannelsWidget.steps().values(): + channel = step["channel"] + if "segm" not in channel: + any_not_seg = True + break + + if any_not_seg: + self.saveAsSegmCheckbox.setChecked(False) + self.saveAsSegmCheckbox.setEnabled(False) + else: + if not self.segm_blinked: + self.saveAsSegmCheckbox.setEnabled(True) + self.blinker = qutils.QControlBlink( + self.saveAsSegmCheckbox, qparent=self + ) + self.blinker.start() + self.segm_blinked = True + + def apply(self, checked=False, signal: Signal = None): + steps = self.combineChannelsWidget.steps() + formula = self.formulaEditWidget.text() + keep_input_dtype = self.keepInputDataTypeToggle.isChecked() + if not steps or not self.validFormula: + return + + if signal is not None: + try: + signal.emit(steps, formula) + except TypeError as err: + signal.emit(steps, keep_input_dtype, formula) + + self.saveLastRecipe() + if self.hideOnClosing: + self.setDisabled(True) + self.infoLabel.setText( + f"{self.sender().text().replace('Apply', 'Applying')}...
" + "(Feel free to use Cell-ACDC while waiting)" + ) + else: + self.ok_cb(saveLastRecipe=False) + + # Not needed anymore since now we funnel all changes to the formulaEditWidget, which then verifies the formula and + # emits a signal via formulaChangeda + # def emitValuesChanged(self): + # if not self.validFormula: + # return + # self.sigValuesChanged.emit() + + def emitValuesChangedSteps(self): + self.autoCheckSaveAsSegmCheckbox() + self._updateFormulaVariableNames() + + def ok_cb(self, dummy=None, saveLastRecipe=True): + if not self.validFormula: + return + + if saveLastRecipe: + self.saveLastRecipe() + + self.keepInputDataType = self.keepInputDataTypeToggle.isChecked() + self.selectedSteps = self.combineChannelsWidget.steps() + self.formula = self.formulaEditWidget.text() + self.cancel = False + self.close() + + +class CombineChannelsSetupDialogUtil(CombineChannelsSetupDialog): + def __init__( + self, + channel_names, + df_metadata=None, + parent=None, + ): + + super().__init__(channel_names, parent=parent, df_metadata=df_metadata) + + # add int input for number of workers + + self.mainLayout.addSpacing(20) + + qutils.hide_and_delete_layout(self.buttonsLayout) + buttonsLayout = widgets.CancelOkButtonsLayout() + self.buttonsLayout = buttonsLayout + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + self.mainLayout.addLayout(buttonsLayout) + + self.nThreadsSpinBox = QSpinBox() + self.nThreadsSpinBox.setMinimum(1) + self.nThreadsSpinBox.setValue(4) + self.nThreadsSpinBox.setToolTip("Number of threads to use for processing") + self.mainLayout.addWidget(QLabel("Number of threads:")) + self.mainLayout.addWidget(self.nThreadsSpinBox) + + +class CombineChannelsSetupDialogGUI(CombineChannelsSetupDialog): + def __init__( + self, + channel_names: Iterable[str], + df_metadata=None, + isTimelapse=False, + isZstack=False, + isMultiPos=False, + parent=None, + hideOnClosing=False, + ): + super().__init__( + channel_names, + df_metadata=df_metadata, + isTimelapse=isTimelapse, + isZstack=isZstack, + isMultiPos=isMultiPos, + parent=parent, + hideOnClosing=hideOnClosing, + ) + + # remove the preprocess buttons, we use the comb version of them + qutils.delete_widget(self.loadLastRecipeButton) + qutils.delete_widget(self.saveRecipeButton) + qutils.delete_widget(self.loadRecipeButton) + + # self.allButtons.remove(self.loadLastRecipeButton) + self.allButtons.remove(self.saveRecipeButton) + self.allButtons.remove(self.loadRecipeButton) + + self.previewCheckbox.setChecked(True) + self.saveAsSegmlabel.setText("Save and view as segmentation") + + def steps(self, return_keepInputDataType=False): + steps = self.combineChannelsWidget.steps() + formula = self.formulaEditWidget.text() + # if not return_keepInputDataType: + # return steps, formula + + keep_input_dtype = self.keepInputDataTypeToggle.isChecked() + return steps, keep_input_dtype, formula + + +class TestSegmModelInitalDialog(QBaseDialog): + def __init__(self, parent=None): + super().__init__(parent) + + self.cancel = True + + mainLayout = QVBoxLayout() + entriesLayout = widgets.FormLayout() + + row = 0 + self.startFrameNumberSpinbox = widgets.SpinBox() + self.startFrameNumberSpinbox.setMinimum(1) + + self.startFrameNumberFormWidget = widgets.formWidget( + self.startFrameNumberSpinbox, + labelTextLeft="Start frame number", + addActivateCheckbox=True, + ) + entriesLayout.addFormWidget(self.startFrameNumberFormWidget, row=row) + + row += 1 + self.stopFrameNumberSpinbox = widgets.SpinBox() + self.stopFrameNumberSpinbox.setMinimum(1) + + self.stopFrameNumberFormWidget = widgets.formWidget( + self.stopFrameNumberSpinbox, + labelTextLeft="Stop frame number", + addActivateCheckbox=True, + ) + entriesLayout.addFormWidget(self.stopFrameNumberFormWidget, row=row) + + row += 1 + self.startZsliceNumberSpinbox = widgets.SpinBox() + self.startZsliceNumberSpinbox.setMinimum(1) + + self.startZsliceNumberFormWidget = widgets.formWidget( + self.startZsliceNumberSpinbox, + labelTextLeft="Start z-slice number", + addActivateCheckbox=True, + ) + entriesLayout.addFormWidget(self.startZsliceNumberFormWidget, row=row) + + row += 1 + self.stopZsliceNumberSpinbox = widgets.SpinBox() + self.stopZsliceNumberSpinbox.setMinimum(1) + + self.stopZsliceNumberFormWidget = widgets.formWidget( + self.stopZsliceNumberSpinbox, + labelTextLeft="Stop z-slice number", + addActivateCheckbox=True, + ) + entriesLayout.addFormWidget(self.stopZsliceNumberFormWidget, row=row) + + row += 1 + + self.isTimelapseToggleFormWidget = widgets.formWidget( + widgets.Toggle(), + labelTextLeft="Is timelapse?", + stretchWidget=False, + valueGetterName="isChecked", + ) + entriesLayout.addFormWidget(self.isTimelapseToggleFormWidget, row=row) + + # self.stopFrameNumberSpinbox + # self.startZsliceNumberSpinbox + # self.stopZsliceNumberSpinbox + # self.isTimelapseToggle + + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addLayout(entriesLayout) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + def ok_cb(self): + self.cancel = False + + self.start_frame_n = self.startFrameNumberFormWidget.value() + self.stop_frame_n = self.stopFrameNumberFormWidget.value() + self.start_z_slice_n = self.startZsliceNumberFormWidget.value() + self.stop_z_slice_n = self.stopZsliceNumberFormWidget.value() + self.is_timelapse = self.isTimelapseToggleFormWidget.value() + + self.close() + +# Sibling imports (deferred to avoid import cycles) +from ._base import ( + ArgWidget, +) +from .general import ( + imageViewer, +) +from .measurements import ( + SelectFeaturesRangeDialog, +) +from .metadata import ( + filenameDialog, +) + diff --git a/cellacdc/dialogs/tracking.py b/cellacdc/dialogs/tracking.py new file mode 100644 index 000000000..b408ed538 --- /dev/null +++ b/cellacdc/dialogs/tracking.py @@ -0,0 +1,2871 @@ +"""Cell-ACDC dialog windows: tracking.""" + +import os +import sys +import re +from typing import Literal, Callable, Dict, Iterable, List, Tuple +import datetime +import pathlib +from collections import defaultdict +import zipfile +from heapq import nlargest +import matplotlib +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +from matplotlib.patches import Rectangle, Circle, PathPatch, Path +import numpy as np +import scipy.interpolate + +try: + import tkinter as tk +except Exception as err: + pass + +import cv2 +import traceback +from itertools import combinations, permutations +from collections import namedtuple +from natsort import natsorted + +# from MyWidgets import Slider, Button, MyRadioButtons +from skimage.measure import label, regionprops +from functools import partial +import skimage.filters +import skimage.measure +import skimage.morphology +import skimage.exposure +import skimage.draw +import skimage.registration +import skimage.color +import skimage.segmentation +from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk +import matplotlib.pyplot as plt +import seaborn as sns +import pandas as pd +import math +import time +import sympy as sp +import json +import html + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from qtpy import QtCore +from qtpy.QtGui import ( + QIcon, + QFontMetrics, + QKeySequence, + QFont, + QRegularExpressionValidator, + QCursor, + QKeyEvent, + QPixmap, + QFont, + QPalette, + QMouseEvent, + QColor, +) +from qtpy.QtCore import ( + Qt, + QSize, + QEvent, + Signal, + QEventLoop, + QTimer, + QRegularExpression, +) +from qtpy.QtWidgets import ( + QFileDialog, + QApplication, + QMainWindow, + QMenu, + QLabel, + QToolBar, + QScrollBar, + QWidget, + QVBoxLayout, + QLineEdit, + QPushButton, + QHBoxLayout, + QDialog, + QFormLayout, + QListWidget, + QAbstractItemView, + QButtonGroup, + QCheckBox, + QSizePolicy, + QComboBox, + QSlider, + QGridLayout, + QSpinBox, + QToolButton, + QTableView, + QTextBrowser, + QDoubleSpinBox, + QScrollArea, + QFrame, + QProgressBar, + QGroupBox, + QRadioButton, + QDockWidget, + QMessageBox, + QStyle, + QPlainTextEdit, + QSpacerItem, + QTreeWidget, + QTreeWidgetItem, + QTextEdit, + QSplashScreen, + QAction, + QListWidgetItem, + QActionGroup, + QHeaderView, + QStyledItemDelegate, +) +import qtpy.compat + +from .. import exception_handler +from .. import load, prompts, core, measurements, html_utils +from .. import is_mac, is_win, is_linux, settings_folderpath, config +from .. import preproc_recipes_path, segm_recipes_path, combine_channels_recipes_path +from .. import is_conda_env +from .. import printl +from .. import colors +from .. import issues_url +from .. import utils +from .. import qutils +from .. import _palettes +from .. import base_cca_dict +from .. import widgets +from .. import user_profile_path, promptable_models_path, models_path +from .. import features +from .. import _core +from .. import _types +from .. import plot +from .. import urls +from ..acdc_regex import float_regex, is_alphanumeric_filename, to_alphanumeric +from .. import _base_widgets +from .. import io +from .. import cca_functions +from .. import path + +POSITIVE_FLOAT_REGEX = float_regex(allow_negative=False) +TREEWIDGET_STYLESHEET = _palettes.TreeWidgetStyleSheet() +LISTWIDGET_STYLESHEET = _palettes.ListWidgetStyleSheet() +BACKGROUND_RGBA = _palettes.get_disabled_colors()["Button"] + +font = QFont() +font.setPixelSize(12) +italicFont = QFont() +italicFont.setPixelSize(12) +italicFont.setItalic(True) + +from ._base import ( + QBaseDialog, +) +from .export import ( + pdDataFrameWidget, +) +from .general import ( + QLineEditDialog, +) + +class TrackSubCellObjectsDialog(QBaseDialog): + def __init__(self, basename="", parent=None): + self.cancel = True + super().__init__(parent=parent) + + self.setWindowTitle("Track sub-cellular objects parameters") + + mainLayout = QVBoxLayout() + entriesLayout = widgets.FormLayout() + + row = 0 + infoTxt = html_utils.paragraph(""" + Select behaviour with untracked objects:

+ NOTE: this utility always create new files. + Original segmentation masks
are not modified
. + """) + options = ( + "Delete sub-cellular objects that do not belong to any cell", + "Delete cells that do not have any sub-cellular object", + "Delete both cells and sub-cellular objects without an assignment", + "Only track the objects and keep all the non-tracked objects", + ) + combobox = widgets.QCenteredComboBox() + combobox.addItems(options) + self.optionsWidget = widgets.formWidget( + combobox, + addInfoButton=True, + labelTextLeft="Tracking mode: ", + infoTxt=infoTxt, + ) + entriesLayout.addFormWidget(self.optionsWidget, row=row) + + row += 1 + infoTxt = html_utils.paragraph(""" + Re-label sub-cellular objects before assigning them to the cell.

+ Activate this option if you have merged sub-cellular objects + that must be separated, or the segmentation is a boolean mask + (i.e., semantic segmentation). + """) + self.relabelSubObjLab = widgets.formWidget( + widgets.Toggle(), + addInfoButton=True, + stretchWidget=False, + labelTextLeft="Re-label sub-cellular objects before tracking: ", + infoTxt=infoTxt, + ) + entriesLayout.addFormWidget(self.relabelSubObjLab, row=row) + + row += 1 + IoAtext = html_utils.paragraph(""" + Enter a minimum percentage (0-1) of the sub-cellular object's area
+ that MUST overlap with the parent cell to be considered belonging to a cell: + """) + spinbox = widgets.CenteredDoubleSpinbox() + spinbox.setMaximum(1) + spinbox.setValue(0.5) + spinbox.setSingleStep(0.1) + self.IoAwidget = widgets.formWidget( + spinbox, + addInfoButton=True, + labelTextLeft="IoA threshold: ", + infoTxt=IoAtext, + ) + entriesLayout.addFormWidget(self.IoAwidget, row=row) + + row += 1 + infoTxt = html_utils.paragraph(""" + The third segmentation file is the result of subtracting the + sub-cellular objects from the parent objects

+ This is useful if, for example, you need to compute measurements + only from the cytoplasm (i.e., the sub-cellular object is the nucleus). + """) + self.createThirdSegmWidget = widgets.formWidget( + widgets.Toggle(), + addInfoButton=True, + stretchWidget=False, + labelTextLeft="Create third segmentation: ", + infoTxt=infoTxt, + ) + entriesLayout.addFormWidget(self.createThirdSegmWidget, row=row) + + row += 1 + infoTxt = html_utils.paragraph(""" + Text to append at the end of the third segmentation file.

+ The third segmentation file is the result of subtracting the + sub-cellular objects from the parent objects

+ This is useful if, for example, you need to compute measurements + only from the cytoplasm (i.e., the sub-cellular object is the nucleus). + """) + lineEdit = widgets.alphaNumericLineEdit() + lineEdit.setText("difference") + lineEdit.setAlignment(Qt.AlignCenter) + self.appendTextWidget = widgets.formWidget( + lineEdit, + addInfoButton=True, + labelTextLeft="Text to append: ", + infoTxt=infoTxt, + ) + entriesLayout.addFormWidget(self.appendTextWidget, row=row) + self.appendTextWidget.setDisabled(True) + + self.createThirdSegmWidget.widget.toggled.connect(self.createThirdSegmToggled) + + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + mainLayout.addLayout(entriesLayout) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + self.setFont(font) + + def createThirdSegmToggled(self, checked): + self.appendTextWidget.setDisabled(not checked) + + def ok_cb(self): + self.cancel = False + if self.createThirdSegmWidget.widget.isChecked(): + if not self.appendTextWidget.widget.text(): + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + txt = html_utils.paragraph( + "When creating the third segmentation file, " + "the name to append cannot be empty!" + ) + msg.critical(self, "Empty name", txt) + return + + self.trackSubCellObjParams = { + "how": self.optionsWidget.widget.currentText(), + "IoA": self.IoAwidget.widget.value(), + "createThirdSegm": self.createThirdSegmWidget.widget.isChecked(), + "relabelSubObjLab": self.relabelSubObjLab.widget.isChecked(), + "thirdSegmAppendedText": self.appendTextWidget.widget.text(), + } + self.close() + + +class CellACDCTrackerParamsWin(QDialog): + def __init__(self, parent=None): + self.cancel = True + super().__init__(parent) + + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + self.setWindowTitle("Cell-ACDC tracker parameters") + + paramsLayout = QGridLayout() + paramsBox = QGroupBox() + + row = 0 + label = QLabel(html_utils.paragraph("Minimum overlap between objects")) + paramsLayout.addWidget(label, row, 0) + maxOverlapSpinbox = QDoubleSpinBox() + maxOverlapSpinbox.setAlignment(Qt.AlignCenter) + maxOverlapSpinbox.setMinimum(0) + maxOverlapSpinbox.setMaximum(1) + maxOverlapSpinbox.setSingleStep(0.1) + maxOverlapSpinbox.setValue(0.4) + self.maxOverlapSpinbox = maxOverlapSpinbox + paramsLayout.addWidget(maxOverlapSpinbox, row, 1) + infoButton = widgets.infoPushButton() + infoButton.clicked.connect(self.showInfo) + paramsLayout.addWidget(infoButton, row, 2) + paramsLayout.setColumnStretch(0, 0) + paramsLayout.setColumnStretch(1, 1) + paramsLayout.setColumnStretch(2, 0) + + cancelButton = widgets.cancelPushButton("Cancel") + okButton = widgets.okPushButton(" Ok ") + cancelButton.clicked.connect(self.cancel_cb) + okButton.clicked.connect(self.ok_cb) + + buttonsLayout = QHBoxLayout() + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(okButton) + + layout = QVBoxLayout() + infoText = html_utils.paragraph("Cell-ACDC tracker parameters") + infoLabel = QLabel(infoText) + layout.addWidget(infoLabel, alignment=Qt.AlignCenter) + layout.addSpacing(10) + paramsBox.setLayout(paramsLayout) + layout.addWidget(paramsBox) + layout.addSpacing(20) + layout.addLayout(buttonsLayout) + layout.addStretch(1) + self.setLayout(layout) + self.setFont(font) + + def showInfo(self): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph( + "Cell-ACDC tracker computes the percentage of overlap between " + "all the objects
at frame n and all the " + "objects in previous frame n-1.

" + "All objects with overlap less than " + "Minimum overlap between objects
are considered " + "new objects.

" + "Set this value to 0 if you want to force tracking of ALL the " + "objects
in the previous frame (e.g., if cells move a lot " + "between frames)" + ) + msg.information(self, "Cell-ACDC tracker info", txt) + + def ok_cb(self, checked=False): + self.cancel = False + self.params = {"IoA_thresh": self.maxOverlapSpinbox.value()} + self.close() + + def cancel_cb(self, event): + self.cancel = True + self.close() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + super().show() + self.resize(int(self.width() * 1.3), self.height()) + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class BayesianTrackerParamsWin(QDialog): + def __init__(self, segmShape, parent=None, channels=None, currentChannelName=None): + self.cancel = True + super().__init__(parent) + + self.channels = channels + self.currentChannelName = currentChannelName + + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + self.setWindowTitle("Bayesian tracker parameters") + + paramsLayout = QGridLayout() + paramsBox = QGroupBox() + + row = 0 + this_path = os.path.dirname(os.path.abspath(__file__)) + default_model_path = os.path.join( + this_path, "trackers", "BayesianTracker", "model", "cell_config.json" + ) + label = QLabel(html_utils.paragraph("Model path")) + paramsLayout.addWidget(label, row, 0) + modelPathLineEdit = QLineEdit() + start_dir = "" + if os.path.exists(default_model_path): + start_dir = os.path.dirname(default_model_path) + modelPathLineEdit.setText(default_model_path) + self.modelPathLineEdit = modelPathLineEdit + paramsLayout.addWidget(modelPathLineEdit, row, 1) + browseButton = widgets.browseFileButton( + title="Select Bayesian Tracker model file", + ext={"JSON Config": (".json",)}, + start_dir=start_dir, + ) + browseButton.sigPathSelected.connect(self.onPathSelected) + paramsLayout.addWidget(browseButton, row, 2, alignment=Qt.AlignLeft) + + if self.channels is not None: + row += 1 + label = QLabel(html_utils.paragraph("Intensity image channel: ")) + paramsLayout.addWidget(label, row, 0) + items = ["None", *self.channels] + self.channelCombobox = widgets.QCenteredComboBox() + self.channelCombobox.addItems(items) + paramsLayout.addWidget(self.channelCombobox, row, 1) + if self.currentChannelName is not None: + self.channelCombobox.setCurrentText(self.currentChannelName) + + row += 1 + label = QLabel(html_utils.paragraph("Features")) + paramsLayout.addWidget(label, row, 0) + selectFeaturesButton = widgets.setPushButton("Select features") + paramsLayout.addWidget(selectFeaturesButton, row, 1) + self.features = [] + selectFeaturesButton.clicked.connect(self.selectFeatures) + + row += 1 + label = QLabel(html_utils.paragraph("Verbose")) + paramsLayout.addWidget(label, row, 0) + verboseToggle = widgets.Toggle() + verboseToggle.setChecked(True) + self.verboseToggle = verboseToggle + paramsLayout.addWidget(verboseToggle, row, 1, alignment=Qt.AlignCenter) + + row += 1 + label = QLabel(html_utils.paragraph("Run optimizer")) + paramsLayout.addWidget(label, row, 0) + optimizeToggle = widgets.Toggle() + optimizeToggle.setChecked(True) + self.optimizeToggle = optimizeToggle + paramsLayout.addWidget(optimizeToggle, row, 1, alignment=Qt.AlignCenter) + + row += 1 + label = QLabel(html_utils.paragraph("Max search radius")) + paramsLayout.addWidget(label, row, 0) + maxSearchRadiusSpinbox = QSpinBox() + maxSearchRadiusSpinbox.setAlignment(Qt.AlignCenter) + maxSearchRadiusSpinbox.setMinimum(1) + maxSearchRadiusSpinbox.setMaximum(2147483647) + maxSearchRadiusSpinbox.setValue(50) + self.maxSearchRadiusSpinbox = maxSearchRadiusSpinbox + self.maxSearchRadiusSpinbox.setDisabled(True) + paramsLayout.addWidget(maxSearchRadiusSpinbox, row, 1) + + row += 1 + Z, Y, X = segmShape + label = QLabel(html_utils.paragraph("Tracking volume")) + paramsLayout.addWidget(label, row, 0) + volumeLineEdit = QLineEdit() + defaultVol = f" (0, {X}), (0, {Y}) " + if Z > 1: + defaultVol = f"{defaultVol}, (0, {Z}) " + volumeLineEdit.setText(defaultVol) + volumeLineEdit.setAlignment(Qt.AlignCenter) + self.volumeLineEdit = volumeLineEdit + paramsLayout.addWidget(volumeLineEdit, row, 1) + + row += 1 + label = QLabel(html_utils.paragraph("Interactive mode step size")) + paramsLayout.addWidget(label, row, 0) + stepSizeSpinbox = QSpinBox() + stepSizeSpinbox.setAlignment(Qt.AlignCenter) + stepSizeSpinbox.setMinimum(1) + stepSizeSpinbox.setMaximum(2147483647) + stepSizeSpinbox.setValue(100) + self.stepSizeSpinbox = stepSizeSpinbox + paramsLayout.addWidget(stepSizeSpinbox, row, 1) + + row += 1 + label = QLabel(html_utils.paragraph("Update method")) + paramsLayout.addWidget(label, row, 0) + updateMethodCombobox = QComboBox() + updateMethodCombobox.addItems(["EXACT", "APPROXIMATE"]) + self.updateMethodCombobox = updateMethodCombobox + self.updateMethodCombobox.currentTextChanged.connect(self.methodChanged) + paramsLayout.addWidget(updateMethodCombobox, row, 1) + + cancelButton = widgets.cancelPushButton("Cancel") + okButton = widgets.okPushButton(" Ok ") + cancelButton.clicked.connect(self.cancel_cb) + okButton.clicked.connect(self.ok_cb) + + buttonsLayout = QHBoxLayout() + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(okButton) + + layout = QVBoxLayout() + infoText = html_utils.paragraph("Bayesian Tracker parameters") + infoLabel = QLabel(infoText) + layout.addWidget(infoLabel, alignment=Qt.AlignCenter) + layout.addSpacing(10) + paramsBox.setLayout(paramsLayout) + layout.addWidget(paramsBox) + + url = "https://btrack.readthedocs.io/en/latest/index.html" + moreInfoText = html_utils.paragraph( + "Find more info on the Bayesian Tracker's " + f'home page' + ) + moreInfoLabel = QLabel(moreInfoText) + moreInfoLabel.setOpenExternalLinks(True) + layout.addWidget(moreInfoLabel, alignment=Qt.AlignCenter) + + layout.addSpacing(20) + layout.addLayout(buttonsLayout) + layout.addStretch(1) + self.setLayout(layout) + self.setFont(font) + + def selectFeatures(self): + features = measurements.get_btrack_features() + selectWin = widgets.QDialogListbox( + "Select features", + "Select features to use for tracking:\n", + features, + multiSelection=True, + parent=self, + includeSelectionHelp=True, + ) + for i in range(selectWin.listBox.count()): + item = selectWin.listBox.item(i) + if item.text() in self.features: + item.setSelected(True) + selectWin.exec_() + if selectWin.cancel: + return + self.features = selectWin.selectedItemsText + + def methodChanged(self, method): + if method == "APPROXIMATE": + self.maxSearchRadiusSpinbox.setDisabled(False) + else: + self.maxSearchRadiusSpinbox.setDisabled(True) + + def onPathSelected(self, path): + self.modelPathLineEdit.setText(path) + + def ok_cb(self, checked=False): + self.cancel = False + try: + m = re.findall(r"\((\d+), *(\d+)\)", self.volumeLineEdit.text()) + if len(m) < 2: + raise + self.volume = tuple([(int(start), int(end)) for start, end in m]) + if len(self.volume) == 2: + self.volume = (self.volume[0], self.volume[1], (-1e5, 1e5)) + except Exception as e: + self.warnNotAcceptedVolume() + return + + if not os.path.exists(self.modelPathLineEdit.text()): + self.warnNotVaidPath() + return + + self.intensityImageChannel = None + self.verbose = self.verboseToggle.isChecked() + self.max_search_radius = self.maxSearchRadiusSpinbox.value() + self.update_method = self.updateMethodCombobox.currentText() + self.model_path = os.path.normpath(self.modelPathLineEdit.text()) + self.params = { + "model_path": self.model_path, + "verbose": self.verbose, + "volume": self.volume, + "max_search_radius": self.max_search_radius, + "update_method": self.update_method, + "step_size": self.stepSizeSpinbox.value(), + "optimize": self.optimizeToggle.isChecked(), + "features": self.features, + } + if self.channels is not None: + if self.channelCombobox.currentText() != "None": + self.intensityImageChannel = self.channelCombobox.currentText() + self.close() + + def warnNotVaidPath(self): + url = "https://github.com/lowe-lab-ucl/segment-classify-track/tree/main/models" + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph( + "The model configuration file path

" + f"{self.modelPathLineEdit.text()}

" + "does not exist.

" + "You can find some pre-configured models " + f'here.' + ) + msg.critical(self, "Invalid volume", txt) + + def warnNotAcceptedVolume(self): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph( + f"{self.volumeLineEdit.text()} is not a valid volume!

" + "Valid volume is for example (0, 2048), (0, 2048)
" + "for 2D segmentation or (0, 2048), (0, 2048), (0, 2048)
" + "for 3D segmentation." + ) + msg.critical(self, "Invalid volume", txt) + + def cancel_cb(self, event): + self.cancel = True + self.close() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + super().show() + self.resize(int(self.width() * 1.3), self.height()) + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class DeltaTrackerParamsWin(QDialog): + def __init__(self, posData=None, parent=None): + self.cancel = True + super().__init__(parent) + + self.setWindowFlags(Qt.Dialog | Qt.WindowStaysOnTopHint) + self.setWindowTitle("Delta tracker parameters") + + paramsLayout = QGridLayout() + paramsBox = QGroupBox() + + row = 0 + this_path = os.path.dirname(os.path.abspath(__file__)) + default_model_path = this_path + + label = QLabel(html_utils.paragraph("Original Images path")) + paramsLayout.addWidget(label, row, 0) + modelPathLineEdit = QLineEdit() + start_dir = "" + if os.path.exists(default_model_path): + start_dir = os.path.dirname(default_model_path) + modelPathLineEdit.setText(default_model_path) + self.modelPathLineEdit = modelPathLineEdit + paramsLayout.addWidget(modelPathLineEdit, row, 1) + browseButton = widgets.browseFileButton( + title="Select Original Images", ext={"TIFF": (".tif",)}, start_dir=start_dir + ) + if posData is not None: + modelPathLineEdit.setText(posData.imgPath) + browseButton.sigPathSelected.connect(self.onPathSelected) + paramsLayout.addWidget(browseButton, row, 2, alignment=Qt.AlignLeft) + + row += 1 + label = QLabel(html_utils.paragraph("Model Type")) + paramsLayout.addWidget(label, row, 0) + updateMethodCombobox = QComboBox() + updateMethodCombobox.addItems(["2D", "mothermachine"]) + self.model_type = "2D" + self.updateMethodCombobox = updateMethodCombobox + self.updateMethodCombobox.currentTextChanged.connect(self.methodChanged) + paramsLayout.addWidget(updateMethodCombobox, row, 1) + + row += 1 + label = QLabel(html_utils.paragraph("Single Mother Machine Chamber?")) + paramsLayout.addWidget(label, row, 0) + chamberToggle = widgets.Toggle() + chamberToggle.setChecked(True) + self.chamberToggle = chamberToggle + paramsLayout.addWidget(chamberToggle, row, 1, alignment=Qt.AlignCenter) + + row += 1 + label = QLabel(html_utils.paragraph("Verbose")) + paramsLayout.addWidget(label, row, 0) + verboseToggle = widgets.Toggle() + verboseToggle.setChecked(True) + self.verboseToggle = verboseToggle + paramsLayout.addWidget(verboseToggle, row, 1, alignment=Qt.AlignCenter) + + row += 1 + label = QLabel(html_utils.paragraph("Legacy Save (.mat)")) + paramsLayout.addWidget(label, row, 0) + legacyToggle = widgets.Toggle() + legacyToggle.setChecked(False) + self.legacyToggle = legacyToggle + paramsLayout.addWidget(legacyToggle, row, 1, alignment=Qt.AlignCenter) + + row += 1 + label = QLabel(html_utils.paragraph("Pickle (.pkl)")) + paramsLayout.addWidget(label, row, 0) + pickleToggle = widgets.Toggle() + pickleToggle.setChecked(False) + self.pickleToggle = pickleToggle + paramsLayout.addWidget(pickleToggle, row, 1, alignment=Qt.AlignCenter) + + row += 1 + label = QLabel(html_utils.paragraph("Movie (.mp4) *only for 2D images")) + paramsLayout.addWidget(label, row, 0) + movieToggle = widgets.Toggle() + movieToggle.setChecked(False) + self.movieToggle = movieToggle + paramsLayout.addWidget(movieToggle, row, 1, alignment=Qt.AlignCenter) + + cancelButton = widgets.cancelPushButton("Cancel") + okButton = widgets.okPushButton(" Ok ") + cancelButton.clicked.connect(self.cancel_cb) + okButton.clicked.connect(self.ok_cb) + + buttonsLayout = QHBoxLayout() + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(okButton) + + layout = QVBoxLayout() + infoText = html_utils.paragraph("Delta Tracker parameters") + infoLabel = QLabel(infoText) + layout.addWidget(infoLabel, alignment=Qt.AlignCenter) + layout.addSpacing(10) + paramsBox.setLayout(paramsLayout) + layout.addWidget(paramsBox) + + url = "https://delta.readthedocs.io/en/latest/" + moreInfoText = html_utils.paragraph( + f'Find more info on Delta Tracker\'s home page' + ) + moreInfoLabel = QLabel(moreInfoText) + moreInfoLabel.setOpenExternalLinks(True) + layout.addWidget(moreInfoLabel, alignment=Qt.AlignCenter) + + layout.addSpacing(20) + layout.addLayout(buttonsLayout) + layout.addStretch(1) + self.setLayout(layout) + self.setFont(font) + + def methodChanged(self, method): + if method == "mothermachine": + self.model_type = "mothermachine" + + def onPathSelected(self, path): + self.modelPathLineEdit.setText(path) + + def ok_cb(self, checked=False): + self.cancel = False + + if not os.path.exists(self.modelPathLineEdit.text()): + self.warnNotVaidPath() + return + + self.verbose = self.verboseToggle.isChecked() + self.legacy = self.legacyToggle.isChecked() + self.pickle = self.pickleToggle.isChecked() + self.movie = self.movieToggle.isChecked() + self.chamber = self.chamberToggle.isChecked() + self.model_path = os.path.normpath(self.modelPathLineEdit.text()) + self.params = { + "original_images_path": self.model_path, + "verbose": self.verbose, + "legacy": self.legacy, + "pickle": self.pickle, + "movie": self.movie, + "model_type": self.model_type, + "single mothermachine chamber": self.chamber, + } + self.close() + + def cancel_cb(self, event): + self.cancel = True + self.close() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + super().show() + self.resize(int(self.width() * 1.3), self.height()) + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class GenerateMotherBudTotalTableSelectColumnsDialog(QBaseDialog): + def __init__(self, df: pd.DataFrame, parent=None): + super().__init__(parent) + + self.setWindowTitle("Select columns to combine into the output table") + + self.cancel = True + + self.columns = core.natsort_acdc_columns(df.columns) + self.operations = ( + "Sum mother and bud", + "Copy column from mother", + ) + + self.mainLayout = QVBoxLayout() + + instructionsText = html_utils.paragraph(""" + Select which columns and how you want to combine them + into the output table.
+ """) + self.mainLayout.addWidget(QLabel(instructionsText)) + + settingsLayout = QGridLayout() + + row = 0 + settingsLayout.addWidget(widgets.QHLine(), row, 0, 1, 2) + + row += 1 + settingsLayout.addWidget( + QLabel("Copy all non-selected columns from mother cell"), row, 0 + ) + self.copyAllColsToggle = widgets.Toggle() + settingsLayout.addWidget(self.copyAllColsToggle, row, 1, alignment=Qt.AlignLeft) + + row += 1 + settingsLayout.addWidget(widgets.QHLine(), row, 0, 1, 2) + + self.mainLayout.addLayout(settingsLayout) + + scrollArea = widgets.ScrollArea() + scrollArea.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn) + scrollWidget = QWidget() + scrollArea.setWidget(scrollWidget) + self.centralLayout = QGridLayout() + scrollWidget.setLayout(self.centralLayout) + + self.centralLayout.addWidget(QLabel("Grouping columns"), 0, 0) + self.centralLayout.addWidget(QLabel("Column"), 0, 1) + self.centralLayout.addWidget(QLabel("Operation"), 0, 2) + self.centralLayout.setRowStretch(0, 0) + + self.groupingColsListWidget = widgets.listWidget( + isMultipleSelection=True, + ) + self.groupingColsListWidget.addItems(self.columns) + self.centralLayout.addWidget(self.groupingColsListWidget, 1, 0, 2, 1) + + selector = widgets.ComboBox(self) + selector.addItems(self.columns) + operationCombobox = widgets.ComboBox(self) + operationCombobox.addItems(self.operations) + self.addSelectorButton = widgets.addPushButton() + + dummyButton = widgets.delPushButton() + dummyButton.setRetainSizeWhenHidden(True) + dummyButton.hide() + self.centralLayout.addWidget(dummyButton, 1, 4) + + self.centralLayout.addWidget(selector, 1, 1) + self.centralLayout.addWidget(operationCombobox, 1, 2) + self.centralLayout.addWidget(self.addSelectorButton, 1, 3) + + self.centralLayout.setRowStretch(1, 1) + self.centralLayout.setRowStretch(2, 1) + + self.selectors = {1: (selector, operationCombobox)} + + buttonsLayout = widgets.CancelOkButtonsLayout() + + saveSelectionButton = widgets.savePushButton("Save current selection") + buttonsLayout.insertWidget(3, saveSelectionButton) + + loadDefaultColsButton = widgets.reloadPushButton( + "Load default summable columns" + ) + buttonsLayout.insertWidget(4, loadDefaultColsButton) + + loadPreviousSelButton = widgets.OpenFilePushButton("Load previous selection") + buttonsLayout.insertWidget(5, loadPreviousSelButton) + + saveSelectionButton.clicked.connect(self.saveSelection) + loadDefaultColsButton.clicked.connect(self.loadDefaultCols) + loadPreviousSelButton.clicked.connect(self.loadPreviousSelection) + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + self.mainLayout.addWidget(scrollArea) + self.mainLayout.addSpacing(20) + self.mainLayout.addLayout(buttonsLayout) + + self.addSelectorButton.clicked.connect(self.addSelector) + selector.currentTextChanged.connect(self.selectorTextChanged) + + self.setLayout(self.mainLayout) + self.setFont(font) + + def saveSelection(self): + saved_selections = io.get_saved_moth_bud_tot_selections() + existing_names = set(saved_selections.keys()) + win = filenameDialog( + basename="", + ext="", + hintText="Insert a name for the current selection:", + existingNames=existing_names, + allowEmpty=False, + defaultEntry="mother_bud_total_columns_selection", + ) + win.exec_() + if win.cancel: + return + + name = win.filename + saved_selections[name] = self.selectedOptions() + io.save_moth_bud_tot_selected_options(saved_selections) + + msg = widgets.myMessageBox(wrapText=False, showCentered=False) + txt = html_utils.paragraph(f""" + Current selection saved with name {name}. + """) + msg.information(self, "Selection saved", txt) + + def loadDefaultCols(self): + from . import single_pos_index_cols + + grouping_cols = [col for col in single_pos_index_cols if col in self.columns] + self.groupingColsListWidget.setSelectedItems(grouping_cols) + + column_operation_mapper = { + col: "Sum mother and bud" for col in cca_functions.default_summable_columns + } + column_operation_mapper = { + col: op + for col, op in column_operation_mapper.items() + if col in self.columns and op in self.operations + } + self.addSelectors( + len(column_operation_mapper), + callback_on_finished=partial( + self.setSelectorValues, column_operation_mapper + ), + ) + + def loadPreviousSelection(self): + saved_selections = io.get_saved_moth_bud_tot_selections() + if not saved_selections: + msg = widgets.myMessageBox(wrapText=False, showCentered=False) + txt = html_utils.paragraph(""" + There are no saved selections. + """) + msg.warning(self, "No saved selections", txt) + return + + existing_names = natsorted(saved_selections.keys(), key=str.casefold) + + selectNameWin = widgets.QDialogListbox( + "Choose selection to load", + "Choose selection to load:\n", + existing_names, + multiSelection=False, + parent=self, + ) + selectNameWin.exec_() + if selectNameWin.cancel: + return + + self.loadOptions(saved_selections[selectNameWin.selectedItemsText[0]]) + + def resetSelectors(self, callback_on_finished=None): + self.callback_on_finished = callback_on_finished + QTimer.singleShot(1, self._removeLastSelector) + + def _removeLastSelector(self): + if len(self.selectors) == 1: + if self.callback_on_finished is not None: + self.callback_on_finished() + return + + lastRow = max(self.selectors.keys()) + lastSelector, _ = self.selectors[lastRow] + self.removeSelector(sender=lastSelector.delButton) + QTimer.singleShot(1, self._removeLastSelector) + + def addSelectors(self, number, callback_on_finished=None): + self.callback_on_finished = callback_on_finished + QTimer.singleShot(1, partial(self._addSelectorRecursive, number)) + + def _addSelectorRecursive(self, number): + if len(self.selectors) == number: + if self.callback_on_finished is not None: + self.callback_on_finished() + return + + self.addSelector() + QTimer.singleShot(1, partial(self._addSelectorRecursive, number)) + + def loadOptions(self, options: dict): + if len(self.selectors) > 1: + self.resetSelectors(callback_on_finished=partial(self.loadOptions, options)) + return + + self.copyAllColsToggle.setChecked( + options.get("do_copy_all_nonselected_columns", False) + ) + self.groupingColsListWidget.setSelectedItems( + options.get("grouping_columns", []) + ) + column_operation_mapper = options.get("column_operation_mapper", {}) + column_operation_mapper = { + col: op + for col, op in column_operation_mapper.items() + if col in self.columns and op in self.operations + } + if len(column_operation_mapper) > 1: + self.addSelectors( + len(column_operation_mapper), + callback_on_finished=partial( + self.setSelectorValues, column_operation_mapper + ), + ) + return + + self.setSelectorValues(column_operation_mapper) + + def setSelectorValues(self, column_operation_mapper): + for i, (col, op) in enumerate(column_operation_mapper.items()): + selector, operationCombobox = self.selectors[i + 1] + selector.setCurrentText(col) + operationCombobox.setCurrentText(op) + + def resetSelectorsStyles(self): + for selector, _ in self.selectors.values(): + selector.setStyleSheet("") + + def selectorTextChanged(self, text): + self.resetSelectorsStyles() + selector = self.sender() + for other_selector, _ in self.selectors.values(): + if other_selector == selector: + continue + + if selector.currentText() != other_selector.currentText(): + continue + + self.setWarningStyleSelector(selector) + self.setWarningStyleSelector(other_selector) + + def addSelector(self): + row = len(self.selectors) + 1 + + selector = widgets.ComboBox(self) + selector.addItems(self.columns) + selector.setCurrentIndex(len(self.selectors)) + operationCombobox = widgets.ComboBox(self) + operationCombobox.addItems(self.operations) + delButton = widgets.delPushButton() + selector.delButton = delButton + delButton._row = row + + self.selectors[row] = (selector, operationCombobox) + + self.centralLayout.addWidget(selector, row, 1) + self.centralLayout.addWidget(operationCombobox, row, 2) + self.centralLayout.addWidget(delButton, row, 3) + + self.centralLayout.removeWidget(self.addSelectorButton) + self.centralLayout.addWidget(self.addSelectorButton, row, 4) + + delButton.clicked.connect(self.removeSelector) + + self.centralLayout.removeWidget(self.groupingColsListWidget) + rowSpan = self.centralLayout.rowCount() + self.centralLayout.addWidget(self.groupingColsListWidget, 1, 0, rowSpan, 1) + self.centralLayout.setRowStretch(rowSpan, 1) + + selector.currentTextChanged.connect(self.selectorTextChanged) + + def removeSelector(self, checked=False, sender=None): + if sender is None: + delButton = self.sender() + else: + delButton = sender + + selector, operationCombobox = self.selectors.pop(delButton._row) + + self.centralLayout.removeWidget(selector) + self.centralLayout.removeWidget(operationCombobox) + self.centralLayout.removeWidget(delButton) + + resorted_selectors = {} + for i, (row, (sel, op)) in enumerate(self.selectors.items()): + if i == 0: + resorted_selectors[i + 1] = (sel, op) + continue + + delButton = sel.delButton + delButton._row = i + 1 + self.centralLayout.removeWidget(sel) + self.centralLayout.removeWidget(op) + self.centralLayout.removeWidget(delButton) + self.centralLayout.addWidget(sel, i + 1, 1) + self.centralLayout.addWidget(op, i + 1, 2) + self.centralLayout.addWidget(delButton, i + 1, 3) + + resorted_selectors[i + 1] = (sel, op) + + last_row = i + 1 + col = 4 if last_row > 1 else 3 + self.centralLayout.removeWidget(self.addSelectorButton) + self.centralLayout.addWidget(self.addSelectorButton, i + 1, col) + + self.selectors = resorted_selectors + + def sizeHint(self): + width = super().sizeHint().width() + height = super().sizeHint().height() + groupingColsWidth = widgets.get_min_width_for_no_scrollbar( + self.groupingColsListWidget + ) + width += groupingColsWidth + return QSize(width, height) + + def checkDuplicatedSelectedColumns(self): + for selector, _ in self.selectors.values(): + selector.setStyleSheet("background-color: none") + for other_selector, _ in self.selectors.values(): + if other_selector == selector: + continue + + if other_selector.currentText() != selector.currentText(): + continue + + self.warnDuplicatedSelectedColumns(selector, other_selector) + return False + + return True + + def setWarningStyleSelector(self, selector): + popup = selector.view() + palette = popup.palette() + text_color = palette.color(palette.ColorRole.Text) + warningStyleSheet = f""" + QComboBox {{ + color: black; + background-color: orange; /* main area */ + }} + QComboBox QAbstractItemView {{ + background-color: {text_color.name()}; + }} + """ + selector.setStyleSheet(warningStyleSheet) + + def warnDuplicatedSelectedColumns(self, selector1, selector2): + self.setWarningStyleSelector(selector1) + self.setWarningStyleSelector(selector2) + + msg = widgets.myMessageBox(wrapText=False, showCentered=False) + txt = html_utils.paragraph(f""" + The following column has been selected more than once + (highlighted in orange).

+ {selector1.currentText()}

+ Please, select each column only once.

+ Thank you for your patience! + """) + msg.warning(self, "Duplicated selection", txt) + + def checkGroupingColumnsNotSelected(self): + if self.groupingColsListWidget.selectedItems(): + return True + + return self.warnGroupingColumnsNotSelected() + + def warnGroupingColumnsNotSelected(self): + msg = widgets.myMessageBox(wrapText=False, showCentered=False) + txt = html_utils.paragraph(f""" + Are you sure you do not want to select any grouping column?

+ Grouping columns are those needed to identify each unique + Position folder. + """) + _, noButton, yesButton = msg.question( + self, + "No grouping columns selected?", + txt, + buttonsTexts=( + "Cancel", + "No, let me select grouping columns", + "Yes, I do not need grouping columns", + ), + ) + return msg.clickedButton == yesButton + + def selectedOptions(self): + selected_options = { + "grouping_columns": self.groupingColsListWidget.selectedItemsText(), + "column_operation_mapper": { + selector.currentText(): operationCombobox.currentText() + for selector, operationCombobox in self.selectors.values() + }, + "do_copy_all_nonselected_columns": self.copyAllColsToggle.isChecked(), + } + return selected_options + + def ok_cb(self): + proceed = self.checkDuplicatedSelectedColumns() + if not proceed: + return + + proceed = self.checkGroupingColumnsNotSelected() + if not proceed: + return + + self.selected_options = self.selectedOptions() + + self.cancel = False + self.close() + + +class ApplyTrackTableSelectColumnsDialog(QBaseDialog): + def __init__(self, df, parent=None): + super().__init__(parent) + + self.setWindowTitle("Select columns containing tracking info") + + self.cancel = True + self.mainLayout = QVBoxLayout() + + options = ( + '"Frame index", "Tracked IDs" and "Segmentation mask IDs"
', + '"Frame index", "Tracked IDs", "X coord. centroid", and "Y coord. centroid"', + ) + self.instructionsText = html_utils.paragraph( + f""" + Select which columns contain the tracking information.

+ You must choose one of the following combinations:
+ {html_utils.to_list(options)} + Optionally, you can provide the column name containing the parent ID.
+ This will allow you to load lineage information into Cell-ACDC. + """ + ) + self.mainLayout.addWidget(QLabel(self.instructionsText)) + + formLayout = QFormLayout() + + self.frameIndexCombobox = widgets.QCenteredComboBox() + self.frameIndexCombobox.addItems(df.columns) + self.frameIndexCheckbox = QCheckBox("1st frame is index 1") + frameIndexLayout = QHBoxLayout() + frameIndexLayout.addWidget(self.frameIndexCombobox) + frameIndexLayout.addWidget(self.frameIndexCheckbox) + frameIndexLayout.setStretch(0, 2) + frameIndexLayout.setStretch(1, 0) + formLayout.addRow("Frame index: ", frameIndexLayout) + + self.trackedIDsCombobox = widgets.QCenteredComboBox() + self.trackedIDsCombobox.addItems(df.columns) + formLayout.addRow("Tracked IDs: ", self.trackedIDsCombobox) + + items = df.columns.to_list() + items.insert(0, "None") + self.maskIDsCombobox = widgets.QCenteredComboBox() + self.maskIDsCombobox.addItems(items) + formLayout.addRow("Segmentation mask IDs: ", self.maskIDsCombobox) + + self.xCentroidCombobox = widgets.QCenteredComboBox() + self.xCentroidCombobox.addItems(items) + formLayout.addRow("X coord. centroid: ", self.xCentroidCombobox) + + self.yCentroidCombobox = widgets.QCenteredComboBox() + self.yCentroidCombobox.addItems(items) + formLayout.addRow("Y coord. centroid: ", self.yCentroidCombobox) + + self.parentIDcombobox = widgets.QCenteredComboBox() + self.parentIDcombobox.addItems(items) + formLayout.addRow("Parent ID (optional): ", self.parentIDcombobox) + + deleteUntrackedLayout = QHBoxLayout() + self.deleteUntrackedIDsToggle = widgets.Toggle() + deleteUntrackedLayout.addStretch(1) + deleteUntrackedLayout.addWidget(self.deleteUntrackedIDsToggle) + deleteUntrackedLayout.addStretch(1) + formLayout.addRow("Delete untracked IDs: ", deleteUntrackedLayout) + + buttonsLayout = widgets.CancelOkButtonsLayout() + + buttonsLayout.okButton.clicked.connect(self.ok_cb) + buttonsLayout.cancelButton.clicked.connect(self.close) + + self.mainLayout.addSpacing(30) + self.mainLayout.addLayout(formLayout) + self.mainLayout.addSpacing(20) + self.mainLayout.addLayout(buttonsLayout) + + self.setLayout(self.mainLayout) + self.setFont(font) + + def ok_cb(self): + self.cancel = False + self.frameIndexCol = self.frameIndexCombobox.currentText() + self.trackedIDsCol = self.trackedIDsCombobox.currentText() + self.maskIDsCol = self.maskIDsCombobox.currentText() + self.xCentroidCol = self.xCentroidCombobox.currentText() + self.yCentroidCol = self.yCentroidCombobox.currentText() + self.deleteUntrackedIDs = self.deleteUntrackedIDsToggle.isChecked() + if self.maskIDsCol == "None": + if self.xCentroidCol == "None" or self.yCentroidCol == "None": + self.warnInvalidSelection() + return + else: + self.xCentroidCol = "None" + self.yCentroidCol = "None" + self.parentIDcol = self.parentIDcombobox.currentText() + self.isFirstFrameOne = self.frameIndexCheckbox.isChecked() + self.close() + + def warnInvalidSelection(self): + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + msg.warning( + self, + "Invalid selection", + html_utils.paragraph( + f"Invalid selection
{self.instructionsText}" + ), + ) + + +class editCcaTableWidget(QDialog): + sigApplyChangesFutureFrames = Signal(object, int) + + def __init__( + self, + cca_df, + SizeT, + title="Edit cell cycle annotations", + parent=None, + current_frame_i=0, + ): + self.inputCca_df = cca_df + self.cancel = True + self.SizeT = SizeT + self.cca_df = None + self.current_frame_i = current_frame_i + + super().__init__(parent) + self.setWindowTitle(title) + + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + + # Layouts + mainLayout = QVBoxLayout() + headerLayout = QGridLayout() + tableLayout = QGridLayout() + buttonsLayout = QHBoxLayout() + self.scrollArea = QScrollArea() + self.viewBox = QWidget() + + # Header labels + col = 0 + row = 0 + IDsLabel = QLabel("Cell ID") + AC = Qt.AlignCenter + IDsLabel.setAlignment(AC) + headerLayout.addWidget(IDsLabel, 0, col, alignment=AC) + + col += 1 + ccsLabel = QLabel("Cell cycle stage") + ccsLabel.setAlignment(Qt.AlignCenter) + headerLayout.addWidget(ccsLabel, 0, col, alignment=AC) + + col += 1 + relIDLabel = QLabel("Relative ID") + relIDLabel.setAlignment(Qt.AlignCenter) + headerLayout.addWidget(relIDLabel, 0, col, alignment=AC) + + col += 1 + genNumLabel = QLabel("Generation number") + genNumLabel.setAlignment(Qt.AlignCenter) + headerLayout.addWidget(genNumLabel, 0, col, alignment=AC) + genNumColWidth = genNumLabel.sizeHint().width() + + col += 1 + relationshipLabel = QLabel("Relationship") + relationshipLabel.setAlignment(Qt.AlignCenter) + headerLayout.addWidget(relationshipLabel, 0, col, alignment=AC) + + col += 1 + emergFrameLabel = QLabel("Emerging frame num.") + emergFrameLabel.setAlignment(Qt.AlignCenter) + headerLayout.addWidget(emergFrameLabel, 0, col, alignment=AC) + + col += 1 + divitionFrameLabel = QLabel("Division frame num.") + divitionFrameLabel.setAlignment(Qt.AlignCenter) + headerLayout.addWidget(divitionFrameLabel, 0, col, alignment=AC) + + col += 1 + historyKnownLabel = QLabel("Is history known?") + historyKnownLabel.setAlignment(Qt.AlignCenter) + headerLayout.addWidget(historyKnownLabel, 0, col, alignment=AC) + + self.headerLayout = headerLayout + + tableLayout.setHorizontalSpacing(20) + self.tableLayout = tableLayout + + # Add buttons + cancelButton = widgets.cancelPushButton("Cancel") + moreInfoButton = widgets.helpPushButton("More info...") + moreInfoButton.setIcon(QIcon(":info.svg")) + applyToFutureFramesbutton = widgets.futurePushButton( + "Apply changes to future frames..." + ) + okButton = widgets.okPushButton("Ok") + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(moreInfoButton) + buttonsLayout.addWidget(applyToFutureFramesbutton) + buttonsLayout.addWidget(okButton) + + # Scroll area properties + self.scrollArea.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded) + self.scrollArea.setHorizontalScrollBarPolicy(Qt.ScrollBarAsNeeded) + self.scrollArea.setFrameStyle(QFrame.Shape.NoFrame) + self.scrollArea.setWidgetResizable(True) + + # Add layouts + self.viewBox.setLayout(tableLayout) + self.scrollArea.setWidget(self.viewBox) + mainLayout.addLayout(headerLayout) + mainLayout.addWidget(self.scrollArea) + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + # Populate table Layout + IDs = cca_df.index + self.IDs = IDs.to_list() + relIDsOptions = [str(ID) for ID in IDs] + relIDsOptions.insert(0, "-1") + self.IDlabels = [] + self.ccsComboBoxes = [] + self.genNumSpinBoxes = [] + self.relIDComboBoxes = [] + self.relationshipComboBoxes = [] + self.emergFrameSpinBoxes = [] + self.divisFrameSpinBoxes = [] + self.emergFrameSpinPrevValues = [] + self.divisFrameSpinPrevValues = [] + self.historyKnownCheckBoxes = [] + for row, ID in enumerate(IDs): + col = 0 + IDlabel = QLabel(f"{ID}") + IDlabel.setAlignment(Qt.AlignCenter) + tableLayout.addWidget(IDlabel, row + 1, col, alignment=AC) + self.IDlabels.append(IDlabel) + + col += 1 + ccsComboBox = QComboBox() + ccsComboBox.setFocusPolicy(Qt.StrongFocus) + ccsComboBox.installEventFilter(self) + ccsComboBox.addItems(["G1", "S/G2/M"]) + ccsValue = cca_df.at[ID, "cell_cycle_stage"] + if ccsValue == "S": + ccsValue = "S/G2/M" + + try: + ccsComboBox.setCurrentText(ccsValue) + except Exception as err: + printl(ccsValue) + printl(cca_df) + raise err + tableLayout.addWidget(ccsComboBox, row + 1, col, alignment=AC) + self.ccsComboBoxes.append(ccsComboBox) + ccsComboBox.activated.connect(self.clearComboboxFocus) + + col += 1 + relIDComboBox = QComboBox() + relIDComboBox.setFocusPolicy(Qt.StrongFocus) + relIDComboBox.installEventFilter(self) + relIDComboBox.addItems(relIDsOptions) + relIDComboBox.setCurrentText(str(cca_df.at[ID, "relative_ID"])) + tableLayout.addWidget(relIDComboBox, row + 1, col) + self.relIDComboBoxes.append(relIDComboBox) + relIDComboBox.currentIndexChanged.connect(self.setRelID) + relIDComboBox.activated.connect(self.clearComboboxFocus) + + col += 1 + genNumSpinBox = widgets.SpinBox() + genNumSpinBox.setFocusPolicy(Qt.StrongFocus) + genNumSpinBox.installEventFilter(self) + genNumSpinBox.setValue(2) + genNumSpinBox.setMaximum(2147483647) + genNumSpinBox.setAlignment(Qt.AlignCenter) + genNumSpinBox.setFixedWidth(int(genNumColWidth * 2 / 3)) + genNumSpinBox.setValue(int(cca_df.at[ID, "generation_num"])) + tableLayout.addWidget(genNumSpinBox, row + 1, col, alignment=AC) + self.genNumSpinBoxes.append(genNumSpinBox) + + col += 1 + relationshipComboBox = QComboBox() + relationshipComboBox.setFocusPolicy(Qt.StrongFocus) + relationshipComboBox.installEventFilter(self) + relationshipComboBox.addItems(["mother", "bud"]) + relationshipComboBox.setCurrentText(str(cca_df.at[ID, "relationship"])) + tableLayout.addWidget(relationshipComboBox, row + 1, col) + self.relationshipComboBoxes.append(relationshipComboBox) + relationshipComboBox.currentIndexChanged.connect( + self.relationshipChanged_cb + ) + relationshipComboBox.activated.connect(self.clearComboboxFocus) + + col += 1 + emergFrameSpinBox = widgets.SpinBox() + emergFrameSpinBox.setFocusPolicy(Qt.StrongFocus) + emergFrameSpinBox.installEventFilter(self) + emergFrameSpinBox.setMaximum(SizeT) + emergFrameSpinBox.setMinimum(-1) + emergFrameSpinBox.setValue(-1) + emergFrameSpinBox.setAlignment(Qt.AlignCenter) + emergFrameSpinBox.setFixedWidth(int(genNumColWidth * 2 / 3)) + emergFrame_i = cca_df.at[ID, "emerg_frame_i"] + val = emergFrame_i + 1 if emergFrame_i >= 0 else -1 + emergFrameSpinBox.setValue(val) + tableLayout.addWidget(emergFrameSpinBox, row + 1, col, alignment=AC) + self.emergFrameSpinBoxes.append(emergFrameSpinBox) + self.emergFrameSpinPrevValues.append(emergFrameSpinBox.value()) + emergFrameSpinBox.valueChanged.connect(self.skip0emergFrame) + + col += 1 + divisFrameSpinBox = widgets.SpinBox() + divisFrameSpinBox.setFocusPolicy(Qt.StrongFocus) + divisFrameSpinBox.installEventFilter(self) + divisFrameSpinBox.setMinimum(-1) + divisFrameSpinBox.setMaximum(SizeT) + divisFrameSpinBox.setValue(-1) + divisFrameSpinBox.setAlignment(Qt.AlignCenter) + divisFrameSpinBox.setFixedWidth(int(genNumColWidth * 2 / 3)) + divisFrame_i = int(cca_df.at[ID, "division_frame_i"]) + val = divisFrame_i + 1 if divisFrame_i >= 0 else -1 + divisFrameSpinBox.setValue(val) + tableLayout.addWidget(divisFrameSpinBox, row + 1, col, alignment=AC) + self.divisFrameSpinBoxes.append(divisFrameSpinBox) + self.divisFrameSpinPrevValues.append(divisFrameSpinBox.value()) + divisFrameSpinBox.valueChanged.connect(self.skip0divisFrame) + + col += 1 + HistoryCheckBox = QCheckBox() + HistoryCheckBox.setChecked(bool(cca_df.at[ID, "is_history_known"])) + tableLayout.addWidget(HistoryCheckBox, row + 1, col, alignment=AC) + self.historyKnownCheckBoxes.append(HistoryCheckBox) + + self.setLayout(mainLayout) + + # Connect to events + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.cancel_cb) + moreInfoButton.clicked.connect(self.moreInfo) + applyToFutureFramesbutton.clicked.connect(self.applyToFutureFrames) + + # self.setModal(True) + + def getChanges(self): + newCcaDf = self.getCca_df() + changes = {} + for row in newCcaDf.itertuples(): + ID = row.Index + for col in newCcaDf.columns: + inputValue = self.inputCca_df.at[ID, col] + newValue = getattr(row, col) + if newValue == inputValue: + continue + + if ID not in changes: + changes[ID] = {col: (inputValue, newValue)} + else: + changes[ID][col] = (inputValue, newValue) + return changes + + def applyToFutureFrames(self): + txt = "Enter up to which frame you want to apply the changes
" + win = NumericEntryDialog( + title="Stop frame", + instructions=txt, + parent=self, + minValue=1, + maxValue=self.SizeT, + currentValue=self.current_frame_i, + ) + win.exec_() + if win.cancel: + return + + stop_frame_i = win.value + changes = self.getChanges() + changes_format = utils.format_cca_manual_changes(changes) + detailsText = ( + f"Changes that will be applied from frame n. {self.current_frame_i + 1}" + f" to frame n. {stop_frame_i + 1}:\n\n{changes_format}" + ) + txt = html_utils.paragraph(""" +Use this feature with caution!

+Before propagating to future frames carefully inspect what changes +will be applied (see below).

+""") + msg = widgets.myMessageBox(wrapText=False) + msg.setDetailedText(detailsText, visible=True) + msg.warning(self, "Caution!", txt, buttonsTexts=("Yes, I am sure", "Cancel")) + if msg.cancel: + return + + self.sigApplyChangesFutureFrames.emit(changes, stop_frame_i) + + def moreInfo(self, checked=True): + desc = utils.get_cca_colname_desc() + msg = widgets.myMessageBox(parent=self) + msg.setWindowTitle("Cell cycle annotations info") + msg.setWidth(400) + msg.setIcon() + for col, txt in desc.items(): + msg.addText(html_utils.paragraph(f"{col}: {txt}")) + msg.addButton(" Ok ") + msg.exec_() + + def setRelID(self, itemIndex): + idx = self.relIDComboBoxes.index(self.sender()) + relID = self.sender().currentText() + IDofRelID = self.IDs[idx] + relIDidx = self.IDs.index(int(relID)) + relIDComboBox = self.relIDComboBoxes[relIDidx] + relIDComboBox.setCurrentText(str(IDofRelID)) + + def skip0emergFrame(self, value): + idx = self.emergFrameSpinBoxes.index(self.sender()) + prevVal = self.emergFrameSpinPrevValues[idx] + if value == 0 and value > prevVal: + self.sender().setValue(1) + self.emergFrameSpinPrevValues[idx] = 1 + elif value == 0 and value < prevVal: + self.sender().setValue(-1) + self.emergFrameSpinPrevValues[idx] = -1 + + def skip0divisFrame(self, value): + idx = self.divisFrameSpinBoxes.index(self.sender()) + prevVal = self.divisFrameSpinPrevValues[idx] + if value == 0 and value > prevVal: + self.sender().setValue(1) + self.divisFrameSpinPrevValues[idx] = 1 + elif value == 0 and value < prevVal: + self.sender().setValue(-1) + self.divisFrameSpinPrevValues[idx] = -1 + + def relationshipChanged_cb(self, itemIndex): + idx = self.relationshipComboBoxes.index(self.sender()) + ccs = self.sender().currentText() + if ccs == "bud": + self.ccsComboBoxes[idx].setCurrentText("S/G2/M") + self.genNumSpinBoxes[idx].setValue(0) + + def getCca_df(self): + ccsValues = [var.currentText() for var in self.ccsComboBoxes] + ccsValues = [val if val == "G1" else "S" for val in ccsValues] + genNumValues = [var.value() for var in self.genNumSpinBoxes] + relIDValues = [int(var.currentText()) for var in self.relIDComboBoxes] + relatValues = [var.currentText() for var in self.relationshipComboBoxes] + emergFrameValues = [ + var.value() - 1 if var.value() > 0 else -1 + for var in self.emergFrameSpinBoxes + ] + divisFrameValues = [ + var.value() - 1 if var.value() > 0 else -1 + for var in self.divisFrameSpinBoxes + ] + historyValues = [var.isChecked() for var in self.historyKnownCheckBoxes] + check_rel = [ID == relID for ID, relID in zip(self.IDs, relIDValues)] + + # Buds in S phase must have 0 as number of cycles + check_buds_S = [ + ccs == "S" and rel_ship == "bud" and not numc == 0 + for ccs, rel_ship, numc in zip(ccsValues, relatValues, genNumValues) + ] + + # Mother cells must have at least 1 as number of cycles if history known + check_mothers = [ + rel_ship == "mother" and not numc >= 1 if is_history_known else False + for rel_ship, numc, is_history_known in zip( + relatValues, genNumValues, historyValues + ) + ] + + # Buds cannot be in G1 + check_buds_G1 = [ + ccs == "G1" and rel_ship == "bud" + for ccs, rel_ship in zip(ccsValues, relatValues) + ] + + # The number of cells in S phase must be half mothers and half buds + num_moth_S = len( + [ + 0 + for ccs, rel_ship in zip(ccsValues, relatValues) + if ccs == "S" and rel_ship == "mother" + ] + ) + num_bud_S = len( + [ + 0 + for ccs, rel_ship in zip(ccsValues, relatValues) + if ccs == "S" and rel_ship == "bud" + ] + ) + + # Cells in S phase cannot have -1 as relative's ID + check_relID_S = [ + ccs == "S" and relID == -1 for ccs, relID in zip(ccsValues, relIDValues) + ] + + # Mother cells with unknown history at emergence is recommended to have + # generation number = 2 (easier downstream analysis) + check_unknown_mothers = [ + rel_ship == "mother" + and not is_history_known + and gen_num != 2 + and (emerg_frame_i == self.current_frame_i or self.current_frame_i == 0) + for rel_ship, is_history_known, gen_num, emerg_frame_i in zip( + relatValues, historyValues, genNumValues, emergFrameValues + ) + ] + + if any(check_rel): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph(""" + Some cells are mother or bud of itself!

+ Make sure that the relative ID is different from the Cell ID. + """) + msg.critical(self, "Some IDs are equal to relative ID", txt) + return None + elif any(check_unknown_mothers): + txt = html_utils.paragraph(""" + We recommend to set generation number to 2 for mother cells + with unknown history
+ that just appeared
(i.e., first cell cycle in the video).

+ While it is allowed to insert any number, knowing that these + cells start at generation number 2
+ makes downstream analysis easier.

+ What do you want to do? + """) + correctButtonText = " Fine, let me correct. " + keepButtonText = " Keep the generation number that I chose. " + buttonsTexts = (correctButtonText, keepButtonText) + msg = widgets.myMessageBox(wrapText=False, showCentered=False) + msg.warning(self, "Recommendation", txt, buttonsTexts=buttonsTexts) + if msg.cancel or msg.clickedButton == correctButtonText: + return None + elif any(check_buds_S): + msg = widgets.myMessageBox(wrapText=False) + title = "Bud in S/G2/M not in 0 Generation number" + txt = html_utils.paragraph( + "Some buds " + "in S phase do not have 0 as Generation number!
" + 'Buds in S phase must have 0 as "Generation number"' + ) + msg.critical(self, title, txt) + return None + elif any(check_mothers): + msg = widgets.myMessageBox(wrapText=False) + title = "Mother not in >=1 Generation number" + txt = html_utils.paragraph( + 'Some mother cells do not have >=1 as "Generation number"!
' + 'Mothers MUST have >1 "Generation number"' + ) + msg.critical(self, title, txt) + return None + elif any(check_buds_G1): + msg = widgets.myMessageBox(wrapText=False) + title = "Buds in G1!" + txt = html_utils.paragraph( + "Some buds are in G1 phase!

Buds MUST be in S/G2/M phase" + ) + msg.critical(self, title, txt) + return None + elif num_moth_S != num_bud_S: + msg = widgets.myMessageBox(wrapText=False) + title = "Number of mothers-buds mismatch!" + txt = html_utils.paragraph( + f'There are {num_moth_S} mother cells in "S/G2/M" phase,' + f"but there are {num_bud_S} bud cells.

" + 'The number of mothers and buds in "S/G2/M" ' + "phase must be equal!" + ) + msg.critical(self, title, txt) + return None + elif any(check_relID_S): + msg = widgets.myMessageBox(wrapText=False) + title = "Relative's ID of cells in S/G2/M = -1" + txt = html_utils.paragraph( + 'Some cells are in "S/G2/M" phase but have -1 as Relative\'s ID!
' + 'Cells in "S/G2/M" phase must have an existing ' + "ID as Relative's ID!" + ) + msg.critical(self, title, txt) + return None + + corrected_on_frame_i = self.inputCca_df["corrected_on_frame_i"] + cca_df = pd.DataFrame( + { + "cell_cycle_stage": ccsValues, + "generation_num": genNumValues, + "relative_ID": relIDValues, + "relationship": relatValues, + "emerg_frame_i": emergFrameValues, + "division_frame_i": divisFrameValues, + "is_history_known": historyValues, + "corrected_on_frame_i": corrected_on_frame_i, + "will_divide": self.inputCca_df["will_divide"], + }, + index=self.IDs, + ) + cca_df.index.name = "Cell_ID" + + # Add missing columns + for column, default in base_cca_dict.items(): + if column in cca_df.columns: + continue + + value = self.inputCca_df.get(column, default=default) + cca_df[column] = value + + # Check that every pair of cells in S are relative of each other + proceed = self.check_ID_rel_ID_mismatches(cca_df) + if not proceed: + return None + + d = dict.fromkeys(cca_df.select_dtypes(np.int64).columns, np.int32) + cca_df = cca_df.astype(d) + return cca_df + + def check_ID_rel_ID_mismatches(self, cca_df): + ID_rel_ID_mismatches = [] + for row in cca_df.itertuples(): + if row.cell_cycle_stage == "G1": + continue + + ID = row.Index + relID = row.relative_ID + relID_of_relID = cca_df.at[relID, "relative_ID"] + + if relID_of_relID != ID: + ID_rel_ID_mismatches.append((ID, relID, relID_of_relID)) + + if not ID_rel_ID_mismatches: + return True + + items = [ + f"Cell ID {ID} has relative ID = {relID}, " + f"while cell ID {relID} has relative ID = {relID_of_relID}" + for ID, relID, relID_of_relID in ID_rel_ID_mismatches + ] + title = "`ID-relative_ID` mismatches" + txt = html_utils.paragraph( + f"`ID-relative_ID` mismatches:{html_utils.to_list(items)}" + ) + msg = widgets.myMessageBox(wrapText=False) + msg.critical(self, title, txt) + return False + + def ok_cb(self, checked): + cca_df = self.getCca_df() + if cca_df is None: + return + self.cca_df = cca_df + self.cancel = False + self.close() + + def cancel_cb(self, checked): + self.cancel = True + self.close() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show() + ncols = self.tableLayout.columnCount() + maxLabelWidth = max( + [ + self.headerLayout.itemAt(j).widget().sizeHint().width() + for j in range(ncols) + ] + ) + minWidth = (maxLabelWidth + 5) * ncols + self.setMinimumWidth(minWidth) + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def eventFilter(self, object, event): + # Disable wheel scroll on widgets to allow scroll only on scrollarea + if event.type() == QEvent.Type.Wheel: + event.ignore() + return True + return False + + def clearComboboxFocus(self): + self.sender().clearFocus() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class FindIDDialog(QLineEditDialog): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.okButton.setIcon(QIcon(":magnGlass.svg")) + self.okButton.setText(" Find ") + + +class NumericEntryDialog(QBaseDialog): + def __init__( + self, + title="Entry a value", + currentValue=0, + instructions="Entry value", + parent=None, + maxValue=None, + minValue=None, + stretch=False, + ): + super().__init__(parent=parent) + self.setWindowTitle(title) + self.cancel = False + mainLayout = QVBoxLayout() + entryLayout = QHBoxLayout() + cancelOkLayout = widgets.CancelOkButtonsLayout() + cancelOkLayout.okButton.clicked.connect(self.ok_cb) + cancelOkLayout.cancelButton.clicked.connect(self.close) + + instructionsLabel = QLabel(html_utils.paragraph(instructions)) + mainLayout.addWidget(instructionsLabel) + + if type(currentValue) == int: + self.entryWidget = widgets.SpinBox() + self.entryWidget.setValue(currentValue) + self.valueGetter = "value" + if maxValue is not None: + self.entryWidget.setMaximum(maxValue) + if minValue is not None: + self.entryWidget.setMinimum(minValue) + + if stretch: + entryLayout.addWidget(self.entryWidget) + else: + entryLayout.addStretch(1) + entryLayout.addWidget(self.entryWidget) + entryLayout.addStretch(1) + + mainLayout.addLayout(entryLayout) + mainLayout.addSpacing(20) + mainLayout.addLayout(cancelOkLayout) + + self.setLayout(mainLayout) + + def ok_cb(self): + self.cancel = False + self.value = getattr(self.entryWidget, self.valueGetter)() + self.close() + + +class EditIDDialog(QDialog): + def __init__( + self, + clickedID, + IDs, + entryID=None, + doNotShowAgain=False, + parent=None, + nextUniqueID=1, + allIDs=None, + addPropagateCheckbox=False, + ): + self.assignNewID = False + self.IDs = IDs + self.clickedID = clickedID + self.cancel = True + self.how = None + self.mergeWithExistingID = True + self.doNotAskAgainExistingID = doNotShowAgain + self.allIDs = allIDs + if allIDs is None: + self.allIDs = set(self.IDs) + self.nextUniqueID = nextUniqueID + + super().__init__(parent) + self.setWindowTitle("Edit ID") + mainLayout = QVBoxLayout() + + VBoxLayout = QVBoxLayout() + msg = QLabel(f"Replace ID {clickedID} with:") + _font = QFont() + _font.setPixelSize(12) + msg.setFont(_font) + # padding: top, left, bottom, right + msg.setStyleSheet("padding:0px 0px 3px 0px;") + VBoxLayout.addWidget(msg, alignment=Qt.AlignCenter) + + entryWidget = QLineEdit() + entryWidget.setFont(_font) + entryWidget.setAlignment(Qt.AlignCenter) + self.entryWidget = entryWidget + VBoxLayout.addWidget(entryWidget) + if entryID is not None: + entryWidget.setText(str(entryID)) + entryWidget.selectAll() + + VBoxLayout.addWidget( + QLabel(f"Next unique ID = {nextUniqueID}"), alignment=Qt.AlignCenter + ) + + VBoxLayout.addWidget(widgets.QHLine()) + + self.warnExistingIDLabel = QLabel() + self.warnExistingIDLabel.setStyleSheet("color: red") + VBoxLayout.addWidget(self.warnExistingIDLabel, alignment=Qt.AlignCenter) + + note = QLabel( + "NOTE: To replace multiple IDs at once\n" + 'write "(old ID, new ID), (old ID, new ID)" etc.' + ) + note.setFont(_font) + note.setAlignment(Qt.AlignCenter) + # padding: top, left, bottom, right + note.setStyleSheet("padding:12px 0px 0px 0px;") + VBoxLayout.addWidget(note, alignment=Qt.AlignCenter) + mainLayout.addLayout(VBoxLayout) + + self.propagateCheckbox = None + if addPropagateCheckbox: + mainLayout.addSpacing(10) + self.propagateCheckbox = QCheckBox("Apply to future frames") + mainLayout.addWidget(self.propagateCheckbox) + + buttonsLayout = QHBoxLayout() + okButton = widgets.okPushButton("Ok") + cancelButton = widgets.cancelPushButton("Cancel") + applyNewIDButton = widgets.AssignNewIDButton("Assign new, unique ID") + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(applyNewIDButton) + buttonsLayout.addWidget(okButton) + + mainLayout.addSpacing(20) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + # Connect events + self.prevText = "" + entryWidget.textChanged[str].connect(self.onTextChanged) + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.cancel_cb) + applyNewIDButton.clicked.connect(self.assignNewIDclicked) + + # self.setModal(True) + + def onTextChanged(self, text): + self.warnExistingIDLabel.setText("") + try: + ID = int(text) + if ID in self.allIDs: + self.warnExistingIDLabel.setText(f"WARNING: ID {ID} was already used") + except Exception as err: + pass + + # Get inserted char + idx = self.entryWidget.cursorPosition() + if idx == 0: + return + + newChar = text[idx - 1] + + # Do nothing if user is deleting text + if idx == 0 or len(text) < len(self.prevText): + self.prevText = text + return + + # Do not allow chars except for "(", ")", "int", "," + m = re.search(r"\(|\)|\d|,", newChar) + if m is None: + self.prevText = text + text = text.replace(newChar, "") + self.entryWidget.setText(text) + return + + # Cast integers greater than uint32 machine limit + m_iter = re.finditer(r"\d+", self.entryWidget.text()) + for m in m_iter: + val = int(m.group()) + uint32_max = np.iinfo(np.uint32).max + if val > uint32_max: + text = self.entryWidget.text() + text = f"{text[: m.start()]}{uint32_max}{text[m.end() :]}" + self.entryWidget.setText(text) + + # Automatically close ( bracket + if newChar == "(": + text += ")" + self.entryWidget.setText(text) + self.prevText = text + + def _warnExistingID(self, existingID, newID): + warn_msg = html_utils.paragraph(f""" + ID {existingID} is already existing.

+ How do you want to proceed?
+ """) + msg = widgets.myMessageBox() + doNotAskAgainCheckbox = QCheckBox("Remember my choice and do not ask again") + swapButton = widgets.reloadPushButton(f"Swap {newID} with {existingID}") + mergeButton = widgets.mergePushButton(f"Merge {newID} with {existingID}") + msg.warning( + self, + "Existing ID", + warn_msg, + buttonsTexts=("Cancel", mergeButton, swapButton), + widgets=doNotAskAgainCheckbox, + ) + if msg.cancel: + return False + self.doNotAskAgainExistingID = doNotAskAgainCheckbox.isChecked() + self.mergeWithExistingID = msg.clickedButton == mergeButton + return True + + def assignNewIDclicked(self): + self.cancel = False + self.how = None + self.assignNewID = True + self.close() + + def ok_cb(self, event): + txt = self.entryWidget.text() + valid = False + + # Check validity of inserted text + try: + ID = int(txt) + how = [(self.clickedID, ID)] + if ID in self.IDs and not self.doNotAskAgainExistingID: + proceed = self._warnExistingID(self.clickedID, ID) + if not proceed: + return + valid = True + else: + valid = True + except ValueError: + pattern = r"\((\d+),\s*(\d+)\)" + fa = re.findall(pattern, txt) + if fa: + how = [(int(g[0]), int(g[1])) for g in fa] + valid = True + else: + valid = False + + if not valid: + err_msg = html_utils.paragraph( + "You entered invalid text. Valid text is either a single integer" + f" ID that will be used to replace ID {self.clickedID} " + "or a list of elements enclosed in parenthesis separated by a comma
" + "such as (5, 10), (8, 27) to replace ID 5 with ID 10 and ID 8 with ID 27" + ) + msg = widgets.myMessageBox() + msg.warning(self, "Invalid entry", err_msg) + return + + self.cancel = False + self.how = how + self.doPropagateFutureFrames = False + if self.propagateCheckbox is not None: + self.doPropagateFutureFrames = self.propagateCheckbox.isChecked() + self.close() + + def cancel_cb(self, event): + self.cancel = True + self.close() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show() + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class manualSeparateGui(QMainWindow): + def __init__( + self, + lab, + ID, + img, + fontSize="12pt", + IDcolor=[255, 255, 0], + parent=None, + loop=None, + drawMode="threepoints_arc", + ): + super().__init__(parent) + self.loop = loop + self.cancel = True + self.drawMode = drawMode + self._parent = parent + self.lab = lab.copy() + self.lab[lab != ID] = 0 + self.ID = ID + self.img = skimage.exposure.equalize_adapthist(img / img.max()) + self.IDcolor = IDcolor + self.countClicks = 0 + self.prevLabs = [] + self.prevAllCutsCoords = [] + self.labelItemsIDs = [] + self.undoIdx = 0 + self.fontSize = fontSize + self.AllCutsCoords = [] + self.setWindowTitle("Split object") + # self.setGeometry(Left, Top, 850, 800) + + self.gui_createActions() + self.gui_createMenuBar() + self.gui_createToolBars() + + self.gui_createStatusBar() + + self.gui_createGraphics() + self.gui_connectImgActions() + + self.gui_createImgWidgets() + self.gui_connectActions() + + self.updateImg() + self.zoomToObj() + + mainContainer = QWidget() + self.setCentralWidget(mainContainer) + + mainLayout = QGridLayout() + mainLayout.addWidget(self.graphLayout, 0, 0, 1, 1) + mainLayout.addLayout(self.img_Widglayout, 1, 0) + + mainContainer.setLayout(mainLayout) + + self.setWindowModality(Qt.WindowModal) + + def centerWindow(self): + parent = self._parent + if parent is not None: + # Center the window on main window + mainWinGeometry = parent.geometry() + mainWinLeft = mainWinGeometry.left() + mainWinTop = mainWinGeometry.top() + mainWinWidth = mainWinGeometry.width() + mainWinHeight = mainWinGeometry.height() + mainWinCenterX = int(mainWinLeft + mainWinWidth / 2) + mainWinCenterY = int(mainWinTop + mainWinHeight / 2) + winGeometry = self.geometry() + winWidth = winGeometry.width() + winHeight = winGeometry.height() + winLeft = int(mainWinCenterX - winWidth / 2) + winRight = int(mainWinCenterY - winHeight / 2) + self.move(winLeft, winRight) + + def gui_createActions(self): + # File actions + self.exitAction = QAction("&Exit", self) + self.helpAction = QAction("Help", self) + self.undoAction = QAction(QIcon(":undo.svg"), "Undo (Ctrl+Z)", self) + self.undoAction.setEnabled(False) + self.undoAction.setShortcut("Ctrl+Z") + + self.okAction = QAction(QIcon(":applyCrop.svg"), "Happy with that", self) + self.cancelAction = QAction(QIcon(":cancel.svg"), "Cancel", self) + + self.drawModesActionGroup = QActionGroup(self) + + self.threePointsArcAction = QAction( + QIcon(":threepoints_arc.svg"), "Separate with three-points arc", self + ) + self.threePointsArcAction.setCheckable(True) + self.threePointsArcAction.drawMode = "threepoints_arc" + self.drawModesActionGroup.addAction(self.threePointsArcAction) + + self.freeHandAction = QAction( + QIcon(":freehand.svg"), "Separate with freehand line", self + ) + self.freeHandAction.setCheckable(True) + self.freeHandAction.drawMode = "freehand" + self.drawModesActionGroup.addAction(self.freeHandAction) + + if self.drawMode == "threepoints_arc": + self.threePointsArcAction.setChecked(True) + elif self.drawMode == "freehand": + self.freeHandAction.setChecked(True) + + self.swapIDsAction = QAction(QIcon(":reload.svg"), "Swap IDs", self) + self.swapIDsAction.setToolTip('Swap the two displayed IDs\n\nShortcut: "S"') + self.swapIDsAction.setShortcut("S") + + def state(self): + return { + "is_overlay_active": self.overlayButton.isChecked(), + "is_three_points_active": self.threePointsArcAction.isChecked(), + "is_free_hand_active": self.freeHandAction.isChecked(), + } + + def show(self, block=False): + super().show() + if not block: + return + self.loop = QEventLoop(self) + self.loop.exec_() + + def setState(self, state): + if state is None: + return + self.overlayButton.setChecked(state.get("is_overlay_active", False)) + self.threePointsArcAction.setChecked(state.get("is_three_points_active", True)) + self.freeHandAction.setChecked(state.get("is_free_hand_active", False)) + + def gui_storeDrawMode(self): + self.drawMode = self.sender().drawMode + + def gui_createMenuBar(self): + menuBar = self.menuBar() + # style = "QMenuBar::item:selected { background: white; }" + # menuBar.setStyleSheet(style) + # File menu + fileMenu = QMenu("&File", self) + menuBar.addMenu(fileMenu) + + menuBar.addAction(self.helpAction) + fileMenu.addAction(self.exitAction) + + def gui_createToolBars(self): + toolbarSize = 30 + + editToolBar = QToolBar("Edit", self) + editToolBar.setIconSize(QSize(toolbarSize, toolbarSize)) + self.addToolBar(editToolBar) + + editToolBar.addAction(self.okAction) + editToolBar.addAction(self.cancelAction) + + editToolBar.addAction(self.undoAction) + + self.overlayButton = QToolButton(self) + self.overlayButton.setIcon(QIcon(":overlay.svg")) + self.overlayButton.setCheckable(True) + self.overlayButton.setToolTip("Overlay channel's image") + editToolBar.addWidget(self.overlayButton) + + editToolBar.addAction(self.threePointsArcAction) + editToolBar.addAction(self.freeHandAction) + + editToolBar.addAction(self.swapIDsAction) + + self.warnLabel = QLabel() + editToolBar.addWidget(self.warnLabel) + + def gui_connectActions(self): + self.exitAction.triggered.connect(self.close) + self.helpAction.triggered.connect(self.help) + self.okAction.triggered.connect(self.ok_cb) + self.cancelAction.triggered.connect(self.close) + self.undoAction.triggered.connect(self.undo) + self.overlayButton.toggled.connect(self.toggleOverlay) + self.imgGrad.sigLookupTableChanged.connect(self.histLUT_cb) + self.swapIDsAction.triggered.connect(self.swapIDs) + + def gui_createStatusBar(self): + self.statusbar = self.statusBar() + # Temporary message + self.statusbar.showMessage("Ready", 3000) + # Permanent widget + self.wcLabel = QLabel(f"") + self.statusbar.addPermanentWidget(self.wcLabel) + + def gui_createGraphics(self): + self.graphLayout = pg.GraphicsLayoutWidget() + + # Plot Item container for image + self.ax = pg.PlotItem() + self.ax.invertY(True) + self.ax.setAspectLocked(True) + self.ax.hideAxis("bottom") + self.ax.hideAxis("left") + self.graphLayout.addItem(self.ax, row=1, col=1) + + # Image Item + self.imgItem = pg.ImageItem(np.zeros((512, 512))) + self.ax.addItem(self.imgItem) + + # Image histogram + self.imgGrad = widgets.myHistogramLUTitem() + + # Curvature items + self.hoverLinSpace = np.linspace(0, 1, 1000) + self.hoverLinePen = pg.mkPen( + color=(200, 0, 0, 255 * 0.5), width=2, style=Qt.DashLine + ) + self.hoverCurvePen = pg.mkPen(color=(200, 0, 0, 255 * 0.5), width=3) + self.lineHoverPlotItem = pg.PlotDataItem(pen=self.hoverLinePen) + self.curvHoverPlotItem = pg.PlotDataItem(pen=self.hoverCurvePen) + self.curvAnchors = pg.ScatterPlotItem( + symbol="o", + size=9, + brush=pg.mkBrush((255, 0, 0, 50)), + pen=pg.mkPen((255, 0, 0), width=2), + hoverable=True, + hoverPen=pg.mkPen((255, 0, 0), width=3), + hoverBrush=pg.mkBrush((255, 0, 0)), + ) + self.ax.addItem(self.curvAnchors) + self.ax.addItem(self.curvHoverPlotItem) + self.ax.addItem(self.lineHoverPlotItem) + + self.freeHandItem = widgets.PlotCurveItem(pen=pg.mkPen(color="r", width=2)) + self.ax.addItem(self.freeHandItem) + + def gui_createImgWidgets(self): + self.img_Widglayout = QGridLayout() + self.img_Widglayout.setContentsMargins(50, 0, 50, 0) + + alphaScrollBar_label = QLabel("Overlay alpha ") + alphaScrollBar = QScrollBar(Qt.Horizontal) + alphaScrollBar.setFixedHeight(20) + alphaScrollBar.setMinimum(0) + alphaScrollBar.setMaximum(40) + alphaScrollBar.setValue(12) + alphaScrollBar.setToolTip( + "Control the alpha value of the overlay.\n" + "alpha=0 results in NO overlay,\n" + "alpha=1 results in only labels visible" + ) + alphaScrollBar.sliderMoved.connect(self.alphaScrollBarMoved) + self.alphaScrollBar = alphaScrollBar + self.alphaScrollBar_label = alphaScrollBar_label + self.img_Widglayout.addWidget( + alphaScrollBar_label, 0, 0, alignment=Qt.AlignCenter + ) + self.img_Widglayout.addWidget(alphaScrollBar, 0, 1, 1, 20) + self.alphaScrollBar.hide() + self.alphaScrollBar_label.hide() + + def gui_connectImgActions(self): + self.imgItem.hoverEvent = self.gui_hoverEventImg + self.imgItem.mousePressEvent = self.gui_mousePressEventImg + self.imgItem.mouseMoveEvent = self.gui_mouseDragEventImg + self.imgItem.mouseReleaseEvent = self.gui_mouseReleaseEventImg + + def gui_hoverEventImg(self, event): + # Update x, y, value label bottom right + try: + x, y = event.pos() + xdata, ydata = int(x), int(y) + _img = self.lab + Y, X = _img.shape + if xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y: + val = _img[ydata, xdata] + self.wcLabel.setText(f"(x={x:.2f}, y={y:.2f}, ID={val:.0f})") + else: + self.wcLabel.setText(f"") + except Exception as e: + self.wcLabel.setText(f"") + + if event.isExit(): + return + + self.drawHoverEvent(*event.pos()) + + def gui_mousePressEventImg(self, event): + right_click = event.button() == Qt.MouseButton.RightButton + left_click = event.button() == Qt.MouseButton.LeftButton + + dragImg = left_click + + if dragImg: + pg.ImageItem.mousePressEvent(self.imgItem, event) + + if not right_click: + return + + self.drawPressEvent(event) + + def gui_mouseDragEventImg(self, event): + pass + + def gui_mouseReleaseEventImg(self, event): + if self.countClicks == 0: + return + if self.freeHandAction.isChecked(): + self.countClicks = 0 + xx, yy = self.freeHandItem.getData() + self.setSplitCurveCoords(xx, yy) + self.splitObjectAlongCurve() + self.freeHandItem.setData([], []) + self.curvAnchors.setData([], []) + + def getSpline(self, xx, yy): + tck, u = scipy.interpolate.splprep([xx, yy], s=0, k=2) + xi, yi = scipy.interpolate.splev(self.hoverLinSpace, tck) + return xi, yi + + def drawPressEvent(self, event): + if self.freeHandAction.isChecked(): + self.countClicks = 1 + x, y = event.pos().x(), event.pos().y() + self.curvAnchors.addPoints([x], [y]) + elif self.threePointsArcAction.isChecked(): + self.threePointsArcPressEvent(event) + + def drawHoverEvent(self, x, y): + if self.freeHandAction.isChecked(): + self.freeHandHoverEvent(x, y) + elif self.threePointsArcAction.isChecked(): + self.threePointsArcHoverEvent(x, y) + + def freeHandHoverEvent(self, x, y): + if self.countClicks == 0: + return + self.freeHandItem.addPoint(int(x), int(y)) + _xx, _yy = self.freeHandItem.getData() + xx = [_xx[0], x] + yy = [_yy[0], y] + self.curvAnchors.setData(xx, yy) + + def threePointsArcHoverEvent(self, x, y): + if self.countClicks == 1: + self.lineHoverPlotItem.setData([self.x0, x], [self.y0, y]) + elif self.countClicks == 2: + xx = [self.x0, x, self.x1] + yy = [self.y0, y, self.y1] + xi, yi = self.getSpline(xx, yy) + self.curvHoverPlotItem.setData(xi, yi) + elif self.countClicks == 0: + self.curvHoverPlotItem.setData([], []) + self.lineHoverPlotItem.setData([], []) + self.curvAnchors.setData([], []) + + def threePointsArcPressEvent(self, event): + if self.countClicks == 0: + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + self.x0, self.y0 = xdata, ydata + self.curvAnchors.addPoints([xdata], [ydata]) + self.countClicks = 1 + elif self.countClicks == 1: + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + self.x1, self.y1 = xdata, ydata + self.curvAnchors.addPoints([xdata], [ydata]) + self.countClicks = 2 + elif self.countClicks == 2: + self.countClicks = 0 + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + xx = [self.x0, xdata, self.x1] + yy = [self.y0, ydata, self.y1] + xi, yi = self.getSpline(xx, yy) + yy, xx = np.round(yi).astype(int), np.round(xi).astype(int) + self.setSplitCurveCoords(xx, yy) + self.splitObjectAlongCurve() + + def setSplitCurveCoords(self, xx, yy): + self.storeUndoState() + xxCurve, yyCurve = [], [] + for i, (r0, c0) in enumerate(zip(yy, xx)): + if i == len(yy) - 1: + break + r1 = yy[i + 1] + c1 = xx[i + 1] + rr, cc, _ = skimage.draw.line_aa(r0, c0, r1, c1) + # rr, cc = skimage.draw.line(r0, c0, r1, c1) + nonzeroMask = self.lab[rr, cc] > 0 + xxCurve.extend(cc[nonzeroMask]) + yyCurve.extend(rr[nonzeroMask]) + self.AllCutsCoords.append((yyCurve, xxCurve)) + for rr, cc in self.AllCutsCoords: + self.lab[rr, cc] = 0 + self.lab = skimage.morphology.remove_small_objects(self.lab, 5) + + def histLUT_cb(self, LUTitem): + if self.overlayButton.isChecked(): + overlay = self.getOverlay() + self.imgItem.setImage(overlay) + + def swapIDs(self, checked=False): + if len(self.rp) == 1: + self.warnLabel.setText( + html_utils.paragraph( + "WARNING: Split the object before swapping IDs", font_color="red" + ) + ) + return + + self.warnLabel.setText("") + + obj1 = self.rp[0] + obj2 = self.rp[1] + + self.lab[obj1.slice][obj1.image] = obj2.label + self.lab[obj2.slice][obj2.image] = obj1.label + + self.updateImg() + + def updateImg(self): + self.updateLookuptable() + rp = skimage.measure.regionprops(self.lab) + self.rp = rp + + if self.overlayButton.isChecked(): + overlay = self.getOverlay() + self.imgItem.setImage(overlay) + else: + self.imgItem.setImage(self.lab) + + # Draw ID on centroid of each label + for labelItemID in self.labelItemsIDs: + self.ax.removeItem(labelItemID) + self.labelItemsIDs = [] + for obj in rp: + labelItemID = widgets.myLabelItem() + labelItemID.setText(f"{obj.label}", color="r", size=f"{self.fontSize}px") + y, x = obj.centroid + w, h = labelItemID.rect().right(), labelItemID.rect().bottom() + labelItemID.setPos(x - w / 2, y - h / 2) + self.labelItemsIDs.append(labelItemID) + self.ax.addItem(labelItemID) + + def zoomToObj(self): + # Zoom to object + lab_mask = (self.lab > 0).astype(np.uint8) + rp = skimage.measure.regionprops(lab_mask) + obj = rp[0] + min_row, min_col, max_row, max_col = obj.bbox + xRange = min_col - 10, max_col + 10 + yRange = max_row + 10, min_row - 10 + self.ax.setRange(xRange=xRange, yRange=yRange) + + def storeUndoState(self): + self.prevLabs.append(self.lab.copy()) + self.prevAllCutsCoords.append(self.AllCutsCoords.copy()) + self.undoIdx += 1 + self.undoAction.setEnabled(True) + + def undo(self): + self.undoIdx -= 1 + self.lab = self.prevLabs[self.undoIdx] + self.AllCutsCoords = self.prevAllCutsCoords[self.undoIdx] + self.updateImg() + if self.undoIdx == 0: + self.undoAction.setEnabled(False) + self.prevLabs = [] + self.prevAllCutsCoords = [] + + def splitObjectAlongCurve(self): + self.lab = skimage.measure.label(self.lab, connectivity=1) + + # Relabel largest object with original ID + rp = skimage.measure.regionprops(self.lab) + areas = [obj.area for obj in rp] + IDs = [obj.label for obj in rp] + maxAreaIdx = areas.index(max(areas)) + maxAreaID = IDs[maxAreaIdx] + if self.ID not in self.lab: + self.lab[self.lab == maxAreaID] = self.ID + else: + tempID = self.lab.max() + 1 + self.lab[self.lab == maxAreaID] = tempID + self.lab[self.lab == self.ID] = maxAreaID + self.lab[self.lab == tempID] = self.ID + + # Keep only the two largest objects + larger_areas = nlargest(2, areas) + larger_ids = [rp[areas.index(area)].label for area in larger_areas] + for obj in rp: + if obj.label not in larger_ids: + self.lab[tuple(obj.coords.T)] = 0 + + rp = skimage.measure.regionprops(self.lab) + + if self._parent is not None: + self._parent.setBrushID() + # Use parent window setBrushID function for all other IDs + for obj in rp: + if self._parent is None: + break + if obj.label == self.ID: + continue + posData = self._parent.data[self._parent.pos_i] + posData.brushID += 1 + self.lab[obj.slice][obj.image] = posData.brushID + + # Replace 0s on the cutting curve with IDs + self.cutLab = self.lab.copy() + for rr, cc in self.AllCutsCoords: + for y, x in zip(rr, cc): + top_row = self.cutLab[y + 1, x - 1 : x + 2] + bot_row = self.cutLab[y - 1, x - 1 : x + 1] + left_col = self.cutLab[y - 1, x - 1] + right_col = self.cutLab[y : y + 2, x + 1] + allNeigh = list(top_row) + allNeigh.extend(bot_row) + allNeigh.append(left_col) + allNeigh.extend(right_col) + newID = max(allNeigh) + self.lab[y, x] = newID + + self.rp = skimage.measure.regionprops(self.lab) + self.updateImg() + + def updateLookuptable(self): + # Lookup table + self.cmap = colors.getFromMatplotlib("viridis") + self.lut = self.cmap.getLookupTable(0, 1, self.lab.max() + 1) + self.lut[0] = [25, 25, 25] + self.lut[self.ID] = self.IDcolor + if self.overlayButton.isChecked(): + self.imgItem.setLookupTable(None) + else: + self.imgItem.setLookupTable(self.lut) + + def keyPressEvent(self, ev): + if ev.key() == Qt.Key_Escape: + self.countClicks = 0 + self.curvHoverPlotItem.setData([], []) + self.lineHoverPlotItem.setData([], []) + self.curvAnchors.setData([], []) + self.freeHandItem.setData([], []) + elif ev.key() == Qt.Key_Enter or ev.key() == Qt.Key_Return: + self.ok_cb(True) + + def getOverlay(self): + # Rescale intensity based on hist ticks values + min = self.imgGrad.gradient.listTicks()[0][1] + max = self.imgGrad.gradient.listTicks()[1][1] + img = skimage.exposure.rescale_intensity(self.img, in_range=(min, max)) + alpha = self.alphaScrollBar.value() / self.alphaScrollBar.maximum() + + # Convert img and lab to RGBs + rgb_shape = (self.lab.shape[0], self.lab.shape[1], 3) + labRGB = np.zeros(rgb_shape) + labRGB[self.lab > 0] = [1, 1, 1] + imgRGB = skimage.color.gray2rgb(img) + overlay = imgRGB * (1.0 - alpha) + labRGB * alpha + + # Color eaach label + for obj in self.rp: + rgb = self.lut[obj.label] / 255 + overlay[obj.slice][obj.image] *= rgb + + # Convert (0,1) to (0,255) + overlay = (np.clip(overlay, 0, 1) * 255).astype(np.uint8) + return overlay + + def alphaScrollBarMoved(self, alpha_int): + overlay = self.getOverlay() + self.imgItem.setImage(overlay) + + def toggleOverlay(self, checked): + if checked: + self.graphLayout.addItem(self.imgGrad, row=1, col=0) + self.alphaScrollBar.show() + self.alphaScrollBar_label.show() + else: + self.graphLayout.removeItem(self.imgGrad) + self.alphaScrollBar.hide() + self.alphaScrollBar_label.hide() + self.updateImg() + + def help(self): + msg = QMessageBox() + msg.information( + self, + "Help", + "Separate object along a curved line.\n\n" + "To draw a curved line you will need 3 right-clicks:\n\n" + "1. Right-click outside of the object --> a line appears.\n" + "2. Right-click to end the line and a curve going through the " + "mouse cursor will appear.\n" + "3. Once you are happy with the cutting curve right-click again " + "and the object will be separated along the curve.\n\n" + "Note that you can separate as many times as you want.\n\n" + "Once happy click on the green tick on top-right or " + 'cancel the process with the "X" button', + ) + + def ok_cb(self, checked): + self.cancel = False + self.close() + + def closeEvent(self, event): + if self.loop is not None: + self.loop.exit() + + +class ViewCcaTableWindow(pdDataFrameWidget): + sigUpdateCcaTable = Signal(object) + + def __init__(self, df, parent=None): + super().__init__(df, parent=parent) + + updateTableButton = widgets.reloadPushButton("Update table with visible IDs...") + buttonsLayout = QHBoxLayout() + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(updateTableButton) + + self._layout.insertLayout(0, buttonsLayout) + + updateTableButton.clicked.connect(self.emitUpdateCcaTable) + + def emitUpdateCcaTable(self): + self.sigUpdateCcaTable.emit(self) + +# Sibling imports (deferred to avoid import cycles) +from .metadata import ( + filenameDialog, +) + diff --git a/cellacdc/docs/source/citations/citations_per_year.py b/cellacdc/docs/source/citations/citations_per_year.py index be3d7b8e5..ee83488f7 100644 --- a/cellacdc/docs/source/citations/citations_per_year.py +++ b/cellacdc/docs/source/citations/citations_per_year.py @@ -9,33 +9,33 @@ cwd_path = os.path.dirname(os.path.abspath(__file__)) # Load data -csv_path = os.path.join(cwd_path, 'citations_per_year.csv') +csv_path = os.path.join(cwd_path, "citations_per_year.csv") df = pd.read_csv(csv_path) # Calculate maximum of yticks yticks_step = 5 -yticks_max = np.ceil(df['citations'].max() / yticks_step) * yticks_step +yticks_max = np.ceil(df["citations"].max() / yticks_step) * yticks_step yticks_max = int(yticks_max) # Plot -font = {'size': 13} -matplotlib.rc('font', **font) +font = {"size": 13} +matplotlib.rc("font", **font) fig, ax = plt.subplots(figsize=(6, 4)) -ax.bar(df['year'], df['citations']) -ax.set_xlabel('Year') -ax.set_ylabel('Citations') -ax.set_title('Citations per Year') -ax.set_xticks(df['year']) -ax.set_yticks(range(0, yticks_max+1, yticks_step)) +ax.bar(df["year"], df["citations"]) +ax.set_xlabel("Year") +ax.set_ylabel("Citations") +ax.set_title("Citations per Year") +ax.set_xticks(df["year"]) +ax.set_yticks(range(0, yticks_max + 1, yticks_step)) ax.grid( - True, - 'major', - axis='y', - zorder=0, - alpha=0.7, - color='gray', - linestyle='--', + True, + "major", + axis="y", + zorder=0, + alpha=0.7, + color="gray", + linestyle="--", ) ax.set_axisbelow(True) plt.tight_layout() @@ -44,6 +44,6 @@ # exit() # Save image -png_out_path = os.path.join(cwd_path, 'citations_per_year.png') +png_out_path = os.path.join(cwd_path, "citations_per_year.png") plt.savefig(png_out_path, dpi=300) -plt.close() \ No newline at end of file +plt.close() diff --git a/cellacdc/docs/source/conf.py b/cellacdc/docs/source/conf.py index 44d96ed98..671b28df4 100644 --- a/cellacdc/docs/source/conf.py +++ b/cellacdc/docs/source/conf.py @@ -5,9 +5,9 @@ # -- Project information -project = 'Cell-ACDC' +project = "Cell-ACDC" author = cellacdc.__author__ -copyright = f'{datetime.now():%Y}, {author}' +copyright = f"{datetime.now():%Y}, {author}" version = cellacdc.__version__ release = version @@ -16,47 +16,47 @@ # -- General configuration extensions = [ - 'sphinx.ext.duration', - 'sphinx.ext.doctest', - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.intersphinx', - 'sphinxcontrib.email', - 'sphinx_tabs.tabs', - 'sphinx_carousel.carousel', - 'sphinx_copybutton' + "sphinx.ext.duration", + "sphinx.ext.doctest", + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinxcontrib.email", + "sphinx_tabs.tabs", + "sphinx_carousel.carousel", + "sphinx_copybutton", ] intersphinx_mapping = { - 'python': ('https://docs.python.org/3/', None), - 'sphinx': ('https://www.sphinx-doc.org/en/master/', None), + "python": ("https://docs.python.org/3/", None), + "sphinx": ("https://www.sphinx-doc.org/en/master/", None), } -intersphinx_disabled_domains = ['std'] +intersphinx_disabled_domains = ["std"] -html_favicon = 'https://raw.githubusercontent.com/SchmollerLab/Cell_ACDC/main/cellacdc/resources/icon.ico' +html_favicon = "https://raw.githubusercontent.com/SchmollerLab/Cell_ACDC/main/cellacdc/resources/icon.ico" -templates_path = ['_templates'] +templates_path = ["_templates"] # -- Options for HTML output -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # -- Options for EPUB output -epub_show_urls = 'footnote' +epub_show_urls = "footnote" # -- My css -html_static_path = ['static'] +html_static_path = ["static"] html_css_files = [ - 'css/custom.css', + "css/custom.css", ] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path exclude_patterns = [ - '_build', - 'Thumbs.db', - '.DS_Store', - '_gui_packages.rst', - '_models_list.rst' -] \ No newline at end of file + "_build", + "Thumbs.db", + ".DS_Store", + "_gui_packages.rst", + "_models_list.rst", +] diff --git a/cellacdc/docs/source/models.rst b/cellacdc/docs/source/models.rst index 8396f3b97..1e31bb208 100644 --- a/cellacdc/docs/source/models.rst +++ b/cellacdc/docs/source/models.rst @@ -41,7 +41,7 @@ Adding a segmentation model Adding a segmentation model in a few steps: -1. Create a **new folder** with the models's name (e.g., YeastMate) inside the ``/cellacdc/models`` folder. +1. Create a **new folder** with the models's name (e.g., YeastMate) inside the ``/cellacdc/segmenters`` folder. .. tip:: If you **don't know where Cell-ACDC was installed**, open the main launcher and click on the ``Help --> About Cell-ACDC`` menu on the top menu bar. @@ -90,7 +90,7 @@ Adding a segmentation model in a few steps: The **model parameters** will be **automatically inferred from the class you created** in the ``acdcSegment.py`` file, and a widget with those parameters will pop-up. In this widget you can set the model parameters (or press Ok without changing anything if you want to go with default parameters). -Have a loot at the ``/cellacdc/models`` folder `here `__ for **examples**. You can for example see the ``__init__.py`` file `here `__ and the ``acdcSegment.py`` file `here `__ for YeaZ2. +Have a loot at the ``/cellacdc/segmenters`` folder `here `__ for **examples**. You can for example see the ``__init__.py`` file `here `__ and the ``acdcSegment.py`` file `here `__ for YeaZ2. Adding a tracker diff --git a/cellacdc/exporters.py b/cellacdc/exporters.py index 0cd190574..93c92ae5c 100644 --- a/cellacdc/exporters.py +++ b/cellacdc/exporters.py @@ -14,30 +14,31 @@ import pyqtgraph.exporters import pyqtgraph as pg -from . import transformation, printl, myutils +from . import transformation, printl, utils from . import is_mac, is_win from . import acdc_ffmpeg_path + class ImageExporter(pyqtgraph.exporters.ImageExporter): def __init__(self, item, background=(0, 0, 0, 0), dpi=100, save_pngs=True): super().__init__(item) self._save_pngs = save_pngs - + self._dpi = dpi - + # DPI using A4 width desired_width = 8.268 * dpi - self.params['width'] = desired_width - - self.parameters()['background'] = (0, 0, 0, 0) - + self.params["width"] = desired_width + + self.parameters()["background"] = (0, 0, 0, 0) + def super_export(self, filepath): super().export(filepath) - + def svg_to_image(self, svg_filepath, image_filepath): - width = self.params['width'] - height = self.params['height'] - + width = self.params["width"] + height = self.params["height"] + renderer = QSvgRenderer(svg_filepath) img = QImage(width, height, QImage.Format_ARGB32) img.fill(0) @@ -46,85 +47,83 @@ def svg_to_image(self, svg_filepath, image_filepath): renderer.render(p) p.end() - if image_filepath.endswith('.tiff'): - png_filepath = image_filepath.replace('.tiff', '.png') + if image_filepath.endswith(".tiff"): + png_filepath = image_filepath.replace(".tiff", ".png") img.save(png_filepath, quality=100) png_img = skimage.io.imread(png_filepath) skimage.io.imsave(image_filepath, png_img) - try: + try: os.remove(png_filepath) except Exception: pass else: img.save(image_filepath) - + def crop_from_mask(self, img_rgba): - if not hasattr(self.item, 'exportMaskImageItem'): + if not hasattr(self.item, "exportMaskImageItem"): return img_rgba - + crop_mask_rgba = self.item.exportMaskImageItem.image if crop_mask_rgba is None: return img_rgba - + alpha = crop_mask_rgba[..., 3] rows, cols = np.where(alpha == 0) top, bottom = rows.min(), rows.max() + 1 left, right = cols.min(), cols.max() + 1 - + pos = self.item.exportMaskImageItem.pos() x0, y0 = pos.x(), pos.y() - + view_range = self.item.viewRange() (x_min, x_max), (y_min, y_max) = view_range H, W = img_rgba.shape[:2] - + # x mapping left_px_f = (left - x_min) / (x_max - x_min) * W right_px_f = (right - x_min) / (x_max - x_min) * W - + # y mapping (PNG origin top-left) top_px_f = (y_max - top) / (y_max - y_min) * H bottom_px_f = (y_max - bottom) / (y_max - y_min) * H - - left_px = int(np.floor(left_px_f)) - right_px = int(np.ceil(right_px_f)) + + left_px = int(np.floor(left_px_f)) + right_px = int(np.ceil(right_px_f)) bottom_px = int(np.floor(bottom_px_f)) - top_px = int(np.ceil(top_px_f)) - + top_px = int(np.ceil(top_px_f)) + if left_px < 0: left_px = 0 - + if right_px > W: right_px = W - + if bottom_px < 0: bottom_px = 0 - + if top_px > H: top_px = H - + return img_rgba[bottom_px:top_px, left_px:right_px] - def export(self, filepath): + def export(self, filepath): no_ext_filepath, ext = os.path.splitext(filepath) - svg_filepath = f'{no_ext_filepath}.svg' - svg_exporter = SVGExporter(self.item) + svg_filepath = f"{no_ext_filepath}.svg" + svg_exporter = SVGExporter(self.item) svg_exporter.export(svg_filepath) self.svg_to_image(svg_filepath, filepath) - - try: + + try: os.remove(svg_filepath) except Exception as err: pass - + # Remove padding - img_rgba = skimage.io.imread(filepath) + img_rgba = skimage.io.imread(filepath) img_rgba = self.crop_from_mask(img_rgba) - - img_rgba = transformation.crop_outer_padding( - img_rgba, value=(0, 0, 0, 255) - ) + + img_rgba = transformation.crop_outer_padding(img_rgba, value=(0, 0, 0, 255)) img_rgba = transformation.crop_outer_padding( img_rgba, value=(255, 255, 255, 255) ) @@ -133,21 +132,23 @@ def export(self, filepath): skimage.io.imsave(filepath, img_rgba, check_contrast=False) img_bgr = cv2.cvtColor(img_rgba, cv2.COLOR_RGBA2BGR) - + return img_bgr + class SVGExporter(pyqtgraph.exporters.SVGExporter): def __init__(self, item): super().__init__(item) - self.parameters()['background'] = (0, 0, 0, 0) + self.parameters()["background"] = (0, 0, 0, 0) + class VideoExporter: def __init__(self, avi_filepath, fps): self.writer = None self._avi_filepath = avi_filepath self._fps = fps - self._fourcc = cv2.VideoWriter_fourcc(*'XVID') - + self._fourcc = cv2.VideoWriter_fourcc(*"XVID") + def add_frame(self, img_bgr): if self.writer is None: height, width = img_bgr.shape[:-1] @@ -155,81 +156,85 @@ def add_frame(self, img_bgr): self._avi_filepath, self._fourcc, self._fps, (width, height) ) self.writer.write(img_bgr) - + def release(self): self.writer.release() - + def avi_to_mp4(self): avi_to_mp4(self._avi_filepath) + def avi_to_mp4(in_filepath_avi, out_filepath_mp4=None): - ffmep_exec_path = myutils.download_ffmpeg() - + ffmep_exec_path = utils.download_ffmpeg() + if out_filepath_mp4 is None: - out_filepath_mp4 = in_filepath_avi.replace('.avi', '.mp4') - - ffmep_exec_path = ffmep_exec_path.replace('\\', '/') - out_filepath_mp4 = out_filepath_mp4.replace('\\', '/') - in_filepath_avi = in_filepath_avi.replace('\\', '/') - + out_filepath_mp4 = in_filepath_avi.replace(".avi", ".mp4") + + ffmep_exec_path = ffmep_exec_path.replace("\\", "/") + out_filepath_mp4 = out_filepath_mp4.replace("\\", "/") + in_filepath_avi = in_filepath_avi.replace("\\", "/") + args = [ - '-i', f'{in_filepath_avi}', '-c:v', 'libx264', - '-crf', '18', '-an', f'{out_filepath_mp4}' + "-i", + f"{in_filepath_avi}", + "-c:v", + "libx264", + "-crf", + "18", + "-an", + f"{out_filepath_mp4}", ] - + _run_ffmpeg(ffmep_exec_path, args) + def _run_ffmpeg(ffmep_exec_path, command_args): import subprocess, os - + command_args_no_quotes = [ - arg.replace('"', '').replace("'", '') for arg in command_args + arg.replace('"', "").replace("'", "") for arg in command_args ] - full_command = ' '.join(command_args_no_quotes) - full_command = f'{ffmep_exec_path} {full_command}' - - separator = '-'*100 + full_command = " ".join(command_args_no_quotes) + full_command = f"{ffmep_exec_path} {full_command}" + + separator = "-" * 100 print( - f'{separator}\n' - f'Converting to MP4 with the following command:\n\n' - f'`{full_command}`\n' - f'{separator}' + f"{separator}\n" + f"Converting to MP4 with the following command:\n\n" + f"`{full_command}`\n" + f"{separator}" ) if is_win: subprocess.check_call(full_command) return - - ffmpeg_exec_path = os.path.join(acdc_ffmpeg_path, 'ffmpeg') + + ffmpeg_exec_path = os.path.join(acdc_ffmpeg_path, "ffmpeg") if is_mac: - args_ffmpeg_executable = [f'chmod 755 {ffmpeg_exec_path}'] + args_ffmpeg_executable = [f"chmod 755 {ffmpeg_exec_path}"] subprocess.check_call(args_ffmpeg_executable, shell=True) - command_str = ' '.join(command_args) - command_no_quotes_str = ' '.join(command_args_no_quotes) - + command_str = " ".join(command_args) + command_no_quotes_str = " ".join(command_args_no_quotes) + commands_to_try = ( - ['ffmpeg', *command_args], - ['ffmpeg', *command_args_no_quotes], - f'ffmpeg {command_str}', - f'ffmpeg {command_no_quotes_str}', - [ffmpeg_exec_path, *command_args], + ["ffmpeg", *command_args], + ["ffmpeg", *command_args_no_quotes], + f"ffmpeg {command_str}", + f"ffmpeg {command_no_quotes_str}", + [ffmpeg_exec_path, *command_args], [ffmpeg_exec_path, *command_args_no_quotes], - f'{ffmpeg_exec_path} {command_str}', - f'{ffmpeg_exec_path} {command_no_quotes_str}', + f"{ffmpeg_exec_path} {command_str}", + f"{ffmpeg_exec_path} {command_no_quotes_str}", ) for command in commands_to_try: print( - f'{separator}\n' - f'Attempting conversion to MP4 with the following command:\n\n' - f'`{command}`\n' - f'{separator}' + f"{separator}\n" + f"Attempting conversion to MP4 with the following command:\n\n" + f"`{command}`\n" + f"{separator}" ) try: subprocess.check_call(command, shell=True) break except Exception as err: - print( - f'{separator}\n' - f'[ERROR]: {err}\n' - f'{separator}' - ) \ No newline at end of file + print(f"{separator}\n[ERROR]: {err}\n{separator}") diff --git a/cellacdc/features.py b/cellacdc/features.py index 852a3f16d..5b7e9acce 100644 --- a/cellacdc/features.py +++ b/cellacdc/features.py @@ -9,9 +9,10 @@ from . import measurements from . import printl + def add_rotational_volume_regionprops( - rp, PhysicalSizeY=1, PhysicalSizeX=1, logger_func=None - ): + rp, PhysicalSizeY=1, PhysicalSizeX=1, logger_func=None +): for obj in rp: vol_vox, vol_fl = _core._calc_rotational_vol( obj, PhysicalSizeY, PhysicalSizeX, logger=logger_func @@ -19,54 +20,63 @@ def add_rotational_volume_regionprops( obj.vol_vox, obj.vol_fl = vol_vox, vol_fl return rp + def filter_acdc_df_by_features_range(features_range, acdc_df): - queries = [] + queries = [] for feature_name, thresholds in features_range.items(): if feature_name not in acdc_df.columns: pass _min, _max = thresholds if _min is not None: - queries.append(f'({feature_name} > {_min})') + queries.append(f"({feature_name} > {_min})") if _max is not None: - queries.append(f'({feature_name} < {_max})') + queries.append(f"({feature_name} < {_max})") if not queries: return acdc_df - - query = ' & '.join(queries) + + query = " & ".join(queries) return acdc_df.query(query) + def _eval_equation_df(df, new_col_name, expression): try: df[new_col_name] = df.eval(expression) except Exception as error: traceback.print_exc() + def _add_combined_metrics_acdc_df(posData, df): - # Add channel specifc combined metrics (from equations and + # Add channel specifc combined metrics (from equations and # from user_path_equations sections) config = posData.combineMetricsConfig for chName in posData.loadedChNames: - posDataEquations = config['equations'] - userPathChEquations = config['user_path_equations'] + posDataEquations = config["equations"] + userPathChEquations = config["user_path_equations"] for newColName, equation in posDataEquations.items(): _eval_equation_df(df, newColName, equation) for newColName, equation in userPathChEquations.items(): _eval_equation_df(df, newColName, equation) + def get_acdc_df_features( - posData, grouped_features, lab, foregr_img, frame_i, filename, - channel, bkgrData, other_channels_foregr_imgs - ): + posData, + grouped_features, + lab, + foregr_img, + frame_i, + filename, + channel, + bkgrData, + other_channels_foregr_imgs, +): posData.fluo_bkgrData_dict[filename] = bkgrData - yx_pxl_to_um2 = posData.PhysicalSizeY*posData.PhysicalSizeX - vox_to_fl_3D = ( - posData.PhysicalSizeY*posData.PhysicalSizeX*posData.PhysicalSizeZ - ) - + yx_pxl_to_um2 = posData.PhysicalSizeY * posData.PhysicalSizeX + vox_to_fl_3D = posData.PhysicalSizeY * posData.PhysicalSizeX * posData.PhysicalSizeZ + rp = skimage.measure.regionprops(lab) isSegm3D = lab.ndim == 3 - + # Initialise DataFrame IDs = [obj.label for obj in rp] columns = [] @@ -78,110 +88,141 @@ def get_acdc_df_features( columns.extend(metrics_names) data = np.zeros((len(IDs), len(columns))) df = pd.DataFrame(columns=columns, index=IDs, data=data) - df.index.name = 'Cell_ID' + df.index.name = "Cell_ID" for category, metrics_names in grouped_features.items(): - if category == 'size': + if category == "size": df = measurements.add_size_metrics( df, rp, metrics_names, isSegm3D, yx_pxl_to_um2, vox_to_fl_3D ) - elif category == 'standard': + elif category == "standard": metrics_func, _ = measurements.standard_metrics_func() custom_func_dict = measurements.get_custom_metrics_func() - + # Get metrics to save params = measurements.get_metrics_params( metrics_names, metrics_func, custom_func_dict ) - (bkgr_metrics_params, foregr_metrics_params, - concentration_metrics_params, custom_metrics_params) = params - + ( + bkgr_metrics_params, + foregr_metrics_params, + concentration_metrics_params, + custom_metrics_params, + ) = params + # Get background masks autoBkgr_masks = measurements.get_autoBkgr_mask( lab, isSegm3D, posData, frame_i ) - + autoBkgr_mask, autoBkgr_mask_proj = autoBkgr_masks - dataPrepBkgrROI_mask = measurements.get_bkgrROI_mask( - posData, isSegm3D - ) - + dataPrepBkgrROI_mask = measurements.get_bkgrROI_mask(posData, isSegm3D) + # Get the z-slice if we have z-stacks - z = posData.zSliceSegmentation( - filename, frame_i, errors='ignore' - ) - + z = posData.zSliceSegmentation(filename, frame_i, errors="ignore") + # Get the background data bkgr_data = measurements.get_bkgr_data( - foregr_img, posData, filename, frame_i, autoBkgr_mask, z, - autoBkgr_mask_proj, dataPrepBkgrROI_mask, isSegm3D, lab + foregr_img, + posData, + filename, + frame_i, + autoBkgr_mask, + z, + autoBkgr_mask_proj, + dataPrepBkgrROI_mask, + isSegm3D, + lab, ) - + # Compute background values df = measurements.add_bkgr_values( df, bkgr_data, bkgr_metrics_params[channel], metrics_func ) - + foregr_data = measurements.get_foregr_data(foregr_img, isSegm3D, z) - + # Iterate objects and compute foreground metrics df = measurements.add_foregr_standard_metrics( - df, rp, channel, foregr_data, - foregr_metrics_params[channel], - metrics_func, isSegm3D, - lab, foregr_img, - z_slice=z + df, + rp, + channel, + foregr_data, + foregr_metrics_params[channel], + metrics_func, + isSegm3D, + lab, + foregr_img, + z_slice=z, ) df = measurements.add_concentration_metrics( df, concentration_metrics_params ) - + df = measurements.add_custom_metrics( - df, rp, channel, foregr_data, - custom_metrics_params[channel], - isSegm3D, lab, foregr_img, + df, + rp, + channel, + foregr_data, + custom_metrics_params[channel], + isSegm3D, + lab, + foregr_img, other_channels_foregr_imgs, z_slice=z, ) - - elif category == 'regionprop': + + elif category == "regionprop": try: df, rp_errors = measurements.add_regionprops_metrics( df, lab, metrics_names, logger_func=print ) except Exception as error: traceback.print_exc() - + # Remove 0s columns df = df.loc[:, (df != -2).any(axis=0)] - + return df + def add_background_metrics_names( - grouped_features, channel, isSegm3D, isZstack, isManualBackgrPresent - ): + grouped_features, channel, isSegm3D, isZstack, isManualBackgrPresent +): _, bkgr_val_desc = measurements.standard_metrics_desc( - isZstack, channel, isSegm3D=isSegm3D, - isManualBackgrPresent=isManualBackgrPresent + isZstack, + channel, + isSegm3D=isSegm3D, + isManualBackgrPresent=isManualBackgrPresent, ) backgr_metrics_names = list(bkgr_val_desc.keys()) backgr_metrics_names = [ - name for name in backgr_metrics_names - if (name.find('bkgrVal_median')!=-1 or name.find('bkgrVal_mean')!=-1) + name + for name in backgr_metrics_names + if (name.find("bkgrVal_median") != -1 or name.find("bkgrVal_mean") != -1) ] - if 'standard' not in grouped_features: - grouped_features['standard'] = {channel: backgr_metrics_names} + if "standard" not in grouped_features: + grouped_features["standard"] = {channel: backgr_metrics_names} else: for backgr_metric_name in backgr_metrics_names: - if backgr_metric_name in grouped_features['standard'][channel]: + if backgr_metric_name in grouped_features["standard"][channel]: continue - grouped_features['standard'][channel].append(backgr_metric_name) + grouped_features["standard"][channel].append(backgr_metric_name) return grouped_features + def custom_post_process_segm( - posData, grouped_features, lab, img, frame_i, filename, channel, - features_range, other_channels_foregr_imgs=None, return_delIDs=False - ): + posData, + grouped_features, + lab, + img, + frame_i, + filename, + channel, + features_range, + other_channels_foregr_imgs=None, + return_delIDs=False, +): isSegm3D = lab.ndim == 3 isZstack = posData.SizeZ > 1 bkgrData = posData.bkgrData @@ -192,8 +233,15 @@ def custom_post_process_segm( grouped_features, channel, isSegm3D, isZstack, isManualBackgrPresent ) df = get_acdc_df_features( - posData, grouped_features, lab, img, frame_i, filename, channel, - bkgrData, other_channels_foregr_imgs + posData, + grouped_features, + lab, + img, + frame_i, + filename, + channel, + bkgrData, + other_channels_foregr_imgs, ) try: filtered_df = filter_acdc_df_by_features_range(features_range, df) @@ -208,4 +256,4 @@ def custom_post_process_segm( if return_delIDs: return filtered_lab, df.index.difference(filtered_df.index).to_list() else: - return filtered_lab \ No newline at end of file + return filtered_lab diff --git a/cellacdc/fiji_macros/__init__.py b/cellacdc/fiji_macros/__init__.py index 5862b93ea..a7eb0c501 100644 --- a/cellacdc/fiji_macros/__init__.py +++ b/cellacdc/fiji_macros/__init__.py @@ -3,64 +3,66 @@ from typing import Iterable from uuid import uuid4 -from cellacdc import myutils +from cellacdc import utils from .. import acdc_fiji_path + def init_macro( - files_folderpath: os.PathLike, - is_multiple_files: bool, - is_separate_channels: bool, - dst_folderpath: os.PathLike, - channels: Iterable[str] - ): - macros_folderpath = os.path.join(acdc_fiji_path, 'macros') + files_folderpath: os.PathLike, + is_multiple_files: bool, + is_separate_channels: bool, + dst_folderpath: os.PathLike, + channels: Iterable[str], +): + macros_folderpath = os.path.join(acdc_fiji_path, "macros") os.makedirs(macros_folderpath, exist_ok=True) - + macros_template_folderpath = os.path.dirname(os.path.abspath(__file__)) if is_separate_channels: - macro_template_filename = 'multiple_files_separate_channels.ijm' + macro_template_filename = "multiple_files_separate_channels.ijm" elif is_multiple_files: - macro_template_filename = 'multiple_files.ijm' + macro_template_filename = "multiple_files.ijm" else: - macro_template_filename = 'single_file.ijm' - + macro_template_filename = "single_file.ijm" + macro_template_filepath = os.path.join( macros_template_folderpath, macro_template_filename ) - with open(macro_template_filepath, 'r') as ijm: + with open(macro_template_filepath, "r") as ijm: macro_txt = ijm.read() - + channels = [f'"{ch.strip()}"' for ch in channels] - channels_macro = ', '.join(channels) + channels_macro = ", ".join(channels) macro_txt = macro_txt.replace( - 'channels = newArray(...)', - f'channels = newArray({channels_macro})' + "channels = newArray(...)", f"channels = newArray({channels_macro})" ) - - files_path = files_folderpath.replace('\\', '/') + + files_path = files_folderpath.replace("\\", "/") files_path = f'"{files_path}/"' - macro_txt = macro_txt.replace('id = ...', f'id = {files_path}') - - dst_folderpath = dst_folderpath.replace('\\', '/') + macro_txt = macro_txt.replace("id = ...", f"id = {files_path}") + + dst_folderpath = dst_folderpath.replace("\\", "/") macro_txt = macro_txt.replace( - 'dst_folderpath = ...', f'dst_folderpath = "{dst_folderpath}"' + "dst_folderpath = ...", f'dst_folderpath = "{dst_folderpath}"' ) - - date_time = datetime.datetime.now().strftime(r'%Y-%m-%d_%H-%M-%S') + + date_time = datetime.datetime.now().strftime(r"%Y-%m-%d_%H-%M-%S") id = uuid4() - macro_filename = f'{date_time}_{id}_{macro_template_filename}' + macro_filename = f"{date_time}_{id}_{macro_template_filename}" macro_filepath = os.path.join(macros_folderpath, macro_filename) - with open(macro_filepath, 'w') as ijm: + with open(macro_filepath, "w") as ijm: ijm.write(macro_txt) - + return macro_filepath + def command_run_macro(macro_filepath): - exec_path = myutils.get_fiji_exec_folderpath() - command = f'{exec_path} -macro {macro_filepath}' + exec_path = utils.get_fiji_exec_folderpath() + command = f"{exec_path} -macro {macro_filepath}" return command + def run_macro(macro_command): - success = myutils.run_fiji_command(command=macro_command) - return success \ No newline at end of file + success = utils.run_fiji_command(command=macro_command) + return success diff --git a/cellacdc/gui.py b/cellacdc/gui.py index 7247f462e..91cca2d4f 100755 --- a/cellacdc/gui.py +++ b/cellacdc/gui.py @@ -1,217 +1,90 @@ -import gc import sys import os -import shutil -import re -import traceback -import time -from copy import deepcopy -from datetime import datetime, timedelta -import inspect -import logging -import uuid -import json -from collections import defaultdict, Counter -import psutil -import zipfile -from functools import partial -from tqdm import tqdm -from natsort import natsorted -from typing import Literal, Iterable, Dict, Any, List, Union, Tuple, Set -import time -import cv2 -import math import numpy as np -import pandas as pd -import matplotlib -import scipy.optimize -import scipy.interpolate -import scipy.ndimage -import skimage -import skimage.io -import skimage.measure -import skimage.morphology -import skimage.draw -import skimage.exposure -import skimage.transform -import skimage.segmentation - -from functools import wraps -from skimage.color import gray2rgb, gray2rgba, label2rgb - -from qtpy.QtCore import ( - Qt, QPoint, QTextStream, QSize, QRect, QRectF, - QEventLoop, QTimer, QEvent, QObject, Signal, - QThread, QMutex, QWaitCondition, QSettings, PYQT6 -) -from qtpy.QtGui import ( - QIcon, QKeySequence, QCursor, QGuiApplication, QPixmap, QColor, - QFont, QKeyEvent, QMouseEvent -) -from qtpy.QtWidgets import ( - QAction, QLabel, QPushButton, QHBoxLayout, QSizePolicy, - QMainWindow, QMenu, QToolBar, QGroupBox, QGridLayout, - QScrollBar, QCheckBox, QToolButton, QSpinBox, QButtonGroup, QActionGroup, QFileDialog, QAbstractSlider, QMessageBox, QWidget, QGridLayout, - QDockWidget, QGraphicsProxyWidget, QVBoxLayout, QRadioButton, - QSpacerItem, QScrollArea, QFormLayout, QGraphicsSceneMouseEvent +from qtpy.QtCore import Qt, QTimer, Signal +from qtpy.QtWidgets import QMainWindow, QButtonGroup, QWidget + +from . import utils, autopilot, favourite_func_metrics_csv_path, settings_folderpath +from .utils import setupLogger +from .gui_decorators import get_data_exception_handler, resetViewRange + +custom_annot_path = os.path.join(settings_folderpath, "custom_annotations.json") +shortcut_filepath = os.path.join(settings_folderpath, "shortcuts.ini") +from .mixins import ( + WhitelistGui, + DataLoading, + CanvasRightImage, + CanvasHover, + LineageInteractions, + CustomAnnotations, + MagicPrompts, + ObjectSearch, + ObjectCleanup, + SegForLostIds, + Exporting, + CombineWorker, + CurvatureTools, + DrawClearRegion, + LabelTransformTools, + DeletedRois, + Saving, + MainToolbar, + QuickSettings, + MainMenu, + Measurements, ) -import pyqtgraph as pg -pg.setConfigOption('imageAxisOrder', 'row-major') - -from warnings import simplefilter -simplefilter(action="ignore", category=pd.errors.PerformanceWarning) - -# Custom modules -from . import exception_handler, disableWindow -from . import base_cca_dict, lineage_tree_cols, lineage_tree_cols_std_val -from . import graphLayoutBkgrColor, darkBkgrColor -from . import cca_df_colnames -from . import load, prompts, apps, workers, html_utils -from . import core, myutils, dataPrep, widgets -from . import _warnings, issues_url -from . import measurements, printl -from . import colors, annotate -from . import user_manual_url -from . import recentPaths_path, settings_folderpath, settings_csv_path -from . import favourite_func_metrics_csv_path -from . import qutils, autopilot, QtScoped -from . import _palettes -from . import transformation -from . import measure -from . import cca_functions -from . import data_structure_docs_url -from . import exporters -from . import preprocess -from . import io -from . import whitelist -from . import cli -from . import is_mac -from .trackers.CellACDC import CellACDC_tracker -from .cca_functions import _calc_rot_vol -from .myutils import exec_time, setupLogger, ArgSpec -from .help import welcome, about -from .trackers.CellACDC_normal_division.CellACDC_normal_division_tracker import ( - normal_division_lineage_tree)#, reorg_sister_cells_for_export) -from . import debugutils - -from .plot import imshow -from . import gui_utils - -from . import gui_combine +np.seterr(invalid="ignore") -np.seterr(invalid='ignore') - -if os.name == 'nt': +if os.name == "nt": try: - # Set taskbar icon in windows import ctypes - myappid = 'schmollerlab.cellacdc.pyqt.v1' # arbitrary string + + myappid = "schmollerlab.cellacdc.pyqt.v1" ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(myappid) - except Exception as e: + except Exception: pass -GREEN_HEX = _palettes.green() - -custom_annot_path = os.path.join(settings_folderpath, 'custom_annotations.json') -shortcut_filepath = os.path.join(settings_folderpath, 'shortcuts.ini') - -_font = QFont() -_font.setPixelSize(11) - -font_13px = QFont() -font_13px.setPixelSize(13) - -SliderSingleStepAdd = QtScoped.SliderSingleStepAdd() -SliderSingleStepSub = QtScoped.SliderSingleStepSub() -SliderPageStepAdd = QtScoped.SliderPageStepAdd() -SliderPageStepSub = QtScoped.SliderPageStepSub() -SliderMove = QtScoped.SliderMove() - -def qt_debug_trace(): - from qtpy.QtCore import pyqtRemoveInputHook - pyqtRemoveInputHook() - import pdb; pdb.set_trace() - -def get_data_exception_handler(func): - @wraps(func) - def inner_function(self, *args, **kwargs): - try: - if func.__code__.co_argcount==1 and func.__defaults__ is None: - result = func(self) - elif func.__code__.co_argcount>1 and func.__defaults__ is None: - result = func(self, *args) - else: - result = func(self, *args, **kwargs) - except Exception as e: - try: - if self.progressWin is not None: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - except AttributeError: - pass - result = None - posData = self.data[self.pos_i] - acdc_df_filename = os.path.basename(posData.acdc_output_csv_path) - segm_filename = os.path.basename(posData.segm_npz_path) - traceback_str = traceback.format_exc() - self.logger.exception(traceback_str) - msg = widgets.myMessageBox(wrapText=False, showCentered=False) - msg.addShowInFileManagerButton(self.logs_path, txt='Show log file...') - msg.setDetailedText(traceback_str) - err_msg = html_utils.paragraph(f""" - Error in function {func.__name__}.

- One possbile explanation is that either the - {acdc_df_filename} file
- or the segmentation file {segm_filename}
- are being synchronized by a cloud service (e.g., Google Drive - or OneDrive) or they are corrupted/damaged.

- Try moving these files (one by one) outside of the - {os.path.dirname(posData.relPath)} folder -
and reloading the data.

- More details below or in the terminal/console.

- Note that the error details from this session are - also saved in the following file:

- {self.log_path}

- Please send the log file when reporting a bug, thanks! - Please restart Cell-ACDC, we apologise for any inconvenience.

- """) - - msg.critical(self, 'Critical error', err_msg) - self.is_error_state = True - raise e - return result - return inner_function - -def resetViewRange(func): - @wraps(func) - def inner_function(self, *args, **kwargs): - self.storeViewRange() - if func.__code__.co_argcount==1 and func.__defaults__ is None: - result = func(self) - elif func.__code__.co_argcount>1 and func.__defaults__ is None: - result = func(self, *args) - else: - result = func(self, *args, **kwargs) - QTimer.singleShot(200, self.resetRange) - return result - return inner_function - -class guiWin(QMainWindow, whitelist.WhitelistGUIElements, - gui_combine.CombineGuiElements, - gui_combine.CombineGUIWorker): +class guiWin( + QMainWindow, + WhitelistGui, + DataLoading, + CanvasRightImage, + CanvasHover, + LineageInteractions, + CustomAnnotations, + MagicPrompts, + ObjectSearch, + ObjectCleanup, + SegForLostIds, + Exporting, + CombineWorker, + CurvatureTools, + DrawClearRegion, + LabelTransformTools, + DeletedRois, + Saving, + MainToolbar, + QuickSettings, + MainMenu, + Measurements, +): """Main Window.""" sigClosed = Signal(object) sigExportFrame = Signal() def __init__( - self, app, parent=None, buttonToRestore=None, - mainWin=None, version=None, launcherSlot=None - ): + self, + app, + parent=None, + buttonToRestore=None, + mainWin=None, + version=None, + launcherSlot=None, + ): """Initializer.""" super().__init__(parent) @@ -219,20 +92,22 @@ def __init__( self._version = version from .trackers.YeaZ import tracking as tracking_yeaz + self.tracking_yeaz = tracking_yeaz from .config import parser_args - self.debug = parser_args['debug'] + + self.debug = parser_args["debug"] self.buttonToRestore = buttonToRestore self.launcherSlot = launcherSlot self.mainWin = mainWin self.app = app self.closeGUI = False - self._acdc_version = myutils.read_version() + self._acdc_version = utils.read_version() self.setAcceptDrops(True) - self._appName = 'Cell-ACDC' + self._appName = "Cell-ACDC" self.lineage_tree = None self.already_synced_lin_tree = set() @@ -240,48 +115,25 @@ def __init__( self.original_df_lin_tree = None self.original_df_lin_tree_i = None - def setTooltips(self): - tooltips = load.get_tooltips_from_docs() - - for key, tooltip in tooltips.items(): - setShortcut = getattr(self, key).shortcut().toString() - if 'Shortcut: ' in tooltip: - tooltip = tooltip.replace('Shortcut: ', '\nShortcut: ') - elif setShortcut != "": - tooltip = re.sub( - r'Shortcut: \"(.*)\"', - f"Shortcut: \"{setShortcut}\"", - tooltip - ) - else: - tooltip = re.sub( - r'Shortcut: \"(.*)\"', - f"Shortcut: \"No shortcut\"", - tooltip - ) - - getattr(self, key).setToolTip(tooltip) - getattr(self, key)._tooltip = tooltip - - def run(self, module='acdc_gui', logs_path=None): + def run(self, module="acdc_gui", logs_path=None): self.setWindowIcon() self.setWindowTitle() - + self.is_win = sys.platform.startswith("win") if self.is_win: - self.openFolderText = 'Show in Explorer...' + self.openFolderText = "Show in Explorer..." else: - self.openFolderText = 'Reveal in Finder...' + self.openFolderText = "Reveal in Finder..." self.is_error_state = False logger, logs_path, log_path, log_filename = setupLogger( module=module, logs_path=logs_path, caller=self._appName ) if self._version is not None: - logger.info(f'Initializing GUI v{self._version}') + logger.info(f"Initializing GUI v{self._version}") else: - logger.info(f'Initializing GUI...') - + logger.info(f"Initializing GUI...") + self.module = module self.logger = logger self.log_path = log_path @@ -307,7 +159,7 @@ def run(self, module='acdc_gui', logs_path=None): self.flag = True self.currentPropsID = 0 self.isSegm3D = False - self.newSegmEndName = '' + self.newSegmEndName = "" self.closeGUI = False self.warnKeyPressedMsg = None self.img1ChannelGradients = {} @@ -326,25 +178,22 @@ def run(self, module='acdc_gui', logs_path=None): self.whyNavigateDisabled = set() self.autoSaveTimer = QTimer() self.dirtyPointsLayerTableEndNames = set() - + self._setup_vars_combine() - if 'autoSaveIntevalValue' not in self.df_settings.index: + if "autoSaveIntevalValue" not in self.df_settings.index: autoSaveIntevalValue = 2 - autoSaveIntervalUnit = 'minutes' + autoSaveIntervalUnit = "minutes" else: autoSaveIntevalValue = float( - self.df_settings.at['autoSaveIntevalValue', 'value'] + self.df_settings.at["autoSaveIntevalValue", "value"] ) autoSaveIntervalUnit = str( - self.df_settings.at['autoSaveIntervalUnit', 'value'] + self.df_settings.at["autoSaveIntervalUnit", "value"] ) - - self.autoSaveIntevalValueUnit = ( - autoSaveIntevalValue, autoSaveIntervalUnit - ) + + self.autoSaveIntevalValueUnit = (autoSaveIntevalValue, autoSaveIntervalUnit) self.logger.info( - 'Autosave interval set to: ' - f'{autoSaveIntevalValue} {autoSaveIntervalUnit}' + f"Autosave interval set to: {autoSaveIntevalValue} {autoSaveIntervalUnit}" ) self.checkableButtons = [] @@ -352,7 +201,6 @@ def run(self, module='acdc_gui', logs_path=None): self.toolsActiveInProj3Dsegm = set() self.customAnnotDict = {} - # Keep a list of functions that are not functional in 3D, yet self.functionsNotTested3D = [] self.isSnapshot = False @@ -362,11 +210,9 @@ def run(self, module='acdc_gui', logs_path=None): self.countKeyPress = 0 self.countRightClicks = 0 self.xHoverImg, self.yHoverImg = None, None - - # Keep track on what frames the on first visit tools already ran + self.lastFrameRanOnFirstVisitTools = 0 - - # Buttons added to QButtonGroup will be mutually exclusive + self.checkableQButtonsGroup = QButtonGroup(self) self.checkableQButtonsGroup.setExclusive(False) @@ -389,7 +235,6 @@ def run(self, module='acdc_gui', logs_path=None): self.gui_connectActions() self.gui_createStatusBar() - # self.gui_createTerminalWidget() self.gui_createGraphicsPlots() self.gui_addGraphicsItems() @@ -413,32795 +258,5 @@ def run(self, module='acdc_gui', logs_path=None): self.initShortcuts() self.show() QTimer.singleShot(100, self.resizeRangeWelcomeText) - # self.installEventFilter(self) - - self.logger.info('GUI ready.') - - def initGlobalAttr(self): - self.setOverlayColors() - - self.initImgCmap() - - # Colormap - self.setLut() - - self.fluoDataChNameActions = [] - - self.splineHoverON = False - self.tempSegmentON = False - self.xyOnCtrlPressedFirstTime = None - self.typingEditID = False - self.prevAnnotOptions = None - self.ghostObject = None - self.autoContourHoverON = False - self.navigateScrollBarStartedMoving = True - self.zSliceScrollBarStartedMoving = True - self.labelRoiRunning = False - self.isRangeReset = True - self.lastManualSeparateState = None - self.editIDmergeIDs = True - self.doNotAskAgainExistingID = False - self.doubleRightClickTimeElapsed = False - self.isRealTimeTrackerInitialized = False - self.isWarningCcaIntegrity = False - self.isDoubleRightClick = False - self.isExportingVideo = False - self.pointsLayersNeverToggled = True - self.highlightedIDopts = None - self.timestampStartTimedelta = timedelta(seconds=0) - self.keptObjectsIDs = widgets.KeptObjectIDsList( - self.keptIDsLineEdit, self.keepIDsConfirmAction - ) - self._ZprojWidgersEnabledState = None - self.imgValueFormatter = 'd' - self.rawValueFormatter = 'd' - self.lastHoverID = -1 - self.annotOptionsToRestore = None - self.annotOptionsToRestoreRight = None - self.rescaleIntensChannelHowMapper = { - self.user_ch_name: 'Rescale each 2D image' - } - self.timestampDialog = None - self.scaleBarDialog = None - self.countObjsWindow = None - self.initLabelRoiModelDialog = None - - # Second channel used by cellpose - self.secondChannelName = None - - self.ax1_viewRange = None - self.measurementsWin = None - - self.model_kwargs = None - self.segmModelName = None - self.labelRoiModel = None - self.autoSegmDoNotAskAgain = False - self.labelRoiGarbageWorkers = [] - self.labelRoiActiveWorkers = [] - - self.clickedOnBud = False - self.postProcessSegmWin = None - - self.UserEnforced_DisabledTracking = False - self.UserEnforced_Tracking = False - - self.ax1BrushHoverID = 0 - - self.disabled_cca_warnings = set() - - self.last_pos_i = -1 - self.last_frame_i = -1 - - # Plots items - self.isMouseDragImg2 = False - self.isMouseDragImg1 = False - self.isMovingLabel = False - self.isRightClickDragImg1 = False - self.clickObjYc, self.clickObjXc = None, None - - self.cca_df_colnames = cca_df_colnames - self.cca_df_dtypes = [ - str, int, int, str, int, int, bool, bool, int - ] - self.cca_df_default_values = list(base_cca_dict.values()) - self.cca_df_int_cols = [ - col for col in cca_df_colnames if type(base_cca_dict[col]) == int - ] - self.lin_tree_df_bool_col = [ - col for col in cca_df_colnames - if isinstance(base_cca_dict[col], bool) - ] - - self.lin_tree_col_checks = [ - 'generation_num', - ] - - # self.lin_tree_df_colnames = set(base_cca_df.keys()) | set(lineage_tree_cols) - # # self.lin_tree_df_dtypes = [ #dk if i need this, for now ignored - # # str, int, int, str, int, int, bool, bool, int - # # ] - # self.lin_tree_df_default_values = list(base_cca_df.values()) + lineage_tree_cols_std_val - self.lin_tree_df_int_cols = [ - 'generation_num', - 'relative_ID', - 'emerg_frame_i', - 'division_frame_i', - 'corrected_on_frame_i' - ] - self.lin_tree_df_bool_col = [ - 'is_history_known', - ] - - self.lin_tree_col_checks = [ - 'generation_num', - ] - - self.lin_tree_df_colnames = self.lin_tree_df_int_cols + self.lin_tree_df_bool_col + self.lin_tree_col_checks - self.SegForLostIDsSettings = {} - - def setWindowIcon(self, icon=None): - if icon is None: - icon = QIcon(":icon.ico") - super().setWindowIcon(icon) - - def setWindowTitle(self, title=None): - if title is None: - title = f'Cell-ACDC v{self._acdc_version} - GUI' - super().setWindowTitle(title) - - def initProfileModels(self): - self.logger.info('Initiliazing profilers...') - - from ._profile.spline_to_obj import model - - self.splineToObjModel = model.Model() - - self.splineToObjModel.fit() - - def setDisabled(self, disabled:bool, keepDisabled:bool=None, force:bool=False): - if force: - if disabled: - super().setDisabled(disabled) - return - else: - self.keepDisabled = False - super().setDisabled(disabled) - return - - if keepDisabled is not None: - self.keepDisabled = keepDisabled - - if self.keepDisabled: - if disabled: - super().setDisabled(disabled) - return - else: - return - else: - super().setDisabled(disabled) - - def readRecentPaths(self, recent_paths_path=None): - # Step 0. Remove the old options from the menu - self.openRecentMenu.clear() - - # Step 1. Read recent Paths - if recent_paths_path is None: - recent_paths_path = recentPaths_path - - if os.path.exists(recent_paths_path): - df = pd.read_csv(recent_paths_path, index_col='index') - df['path'] = df['path'].str.replace('\\', '/') - df = df.drop_duplicates(subset=['path']) - df.to_csv(recent_paths_path) - if 'opened_last_on' in df.columns: - df = df.sort_values('opened_last_on', ascending=False) - recentPaths = df['path'].to_list() - else: - recentPaths = [] - - # Step 2. Dynamically create the actions - actions = [] - for path in recentPaths: - if not os.path.exists(path): - continue - action = QAction(path, self) - action.triggered.connect(partial(self.openRecentFile, path)) - actions.append(action) - - # Step 3. Add the actions to the menu - self.openRecentMenu.addActions(actions) - - def addPathToOpenRecentMenu(self, path): - for action in self.openRecentMenu.actions(): - if path == action.text(): - break - else: - action = QAction(path, self) - action.triggered.connect(partial(self.openRecentFile, path)) - - try: - firstAction = self.openRecentMenu.actions()[0] - self.openRecentMenu.insertAction(firstAction, action) - except Exception as e: - pass - - def loadLastSessionSettings(self): - self.settings_csv_path = settings_csv_path - if os.path.exists(settings_csv_path): - self.df_settings = pd.read_csv( - settings_csv_path, index_col='setting' - ) - if 'is_bw_inverted' not in self.df_settings.index: - self.df_settings.at['is_bw_inverted', 'value'] = 'No' - else: - self.df_settings.loc['is_bw_inverted'] = ( - self.df_settings.loc['is_bw_inverted'].astype(str) - ) - if 'fontSize' not in self.df_settings.index: - self.df_settings.at['fontSize', 'value'] = 12 - if 'overlayColor' not in self.df_settings.index: - self.df_settings.at['overlayColor', 'value'] = '255-255-0' - if 'how_normIntensities' not in self.df_settings.index: - raw = 'Do not normalize. Display raw image' - self.df_settings.at['how_normIntensities', 'value'] = raw - else: - idx = ['is_bw_inverted', 'fontSize', 'overlayColor', 'how_normIntensities'] - values = ['No', 12, '255-255-0', 'raw'] - self.df_settings = pd.DataFrame({ - 'setting': idx,'value': values} - ).set_index('setting') - - if 'isLabelsVisible' not in self.df_settings.index: - self.df_settings.at['isLabelsVisible', 'value'] = 'No' - - if 'isNextFrameVisible' not in self.df_settings.index: - self.df_settings.at['isNextFrameVisible', 'value'] = 'No' - - if 'isRightImageVisible' not in self.df_settings.index: - self.df_settings.at['isRightImageVisible', 'value'] = 'Yes' - - if 'manual_separate_draw_mode' not in self.df_settings.index: - col = 'manual_separate_draw_mode' - self.df_settings.at[col, 'value'] = 'threepoints_arc' - - if 'colorScheme' in self.df_settings.index: - col = 'colorScheme' - self._colorScheme = self.df_settings.at[col, 'value'] - else: - self._colorScheme = 'light' - - self.doNotShowAgainMissingCca = False - if 'doNotShowAgainMissingCca' not in self.df_settings.index: - self.df_settings.at['doNotShowAgainMissingCca', 'value'] = 'No' - else: - val = self.df_settings.at['doNotShowAgainMissingCca', 'value'] - self.doNotShowAgainMissingCca = val=='Yes' - - def dragEnterEvent(self, event): - file_path = event.mimeData().urls()[0].toLocalFile() - if os.path.isdir(file_path): - exp_path = file_path - basename = os.path.basename(file_path) - if basename.find('Position_')!=-1 or basename=='Images': - event.acceptProposedAction() - else: - event.ignore() - else: - event.acceptProposedAction() - - def dropEvent(self, event): - event.setDropAction(Qt.CopyAction) - file_path = event.mimeData().urls()[0].toLocalFile() - self.logger.info(f'Dragged and dropped path "{file_path}"') - basename = os.path.basename(file_path) - if os.path.isdir(file_path): - exp_path = file_path - self.openFolder(exp_path=exp_path) - else: - self.openFile(file_path=file_path) - - def changeEvent(self, event): - try: - self.delObjToolAction.setChecked(False) - except Exception as err: - return - - def leaveEvent(self, event): - if self.slideshowWin is not None: - posData = self.data[self.pos_i] - mainWinGeometry = self.geometry() - mainWinLeft = mainWinGeometry.left() - mainWinTop = mainWinGeometry.top() - mainWinWidth = mainWinGeometry.width() - mainWinHeight = mainWinGeometry.height() - mainWinRight = mainWinLeft+mainWinWidth - mainWinBottom = mainWinTop+mainWinHeight - - slideshowWinGeometry = self.slideshowWin.geometry() - slideshowWinLeft = slideshowWinGeometry.left() - slideshowWinTop = slideshowWinGeometry.top() - slideshowWinWidth = slideshowWinGeometry.width() - slideshowWinHeight = slideshowWinGeometry.height() - - # Determine if overlap - overlap = ( - (slideshowWinTop < mainWinBottom) and - (slideshowWinLeft < mainWinRight) - ) - - autoActivate = ( - self.isDataLoaded and not - overlap and not - posData.disableAutoActivateViewerWindow - ) - - if autoActivate: - self.slideshowWin.setFocus() - self.slideshowWin.activateWindow() - - def enterEvent(self, event): - event.accept() - if self.slideshowWin is not None: - posData = self.data[self.pos_i] - mainWinGeometry = self.geometry() - mainWinLeft = mainWinGeometry.left() - mainWinTop = mainWinGeometry.top() - mainWinWidth = mainWinGeometry.width() - mainWinHeight = mainWinGeometry.height() - mainWinRight = mainWinLeft+mainWinWidth - mainWinBottom = mainWinTop+mainWinHeight - - slideshowWinGeometry = self.slideshowWin.geometry() - slideshowWinLeft = slideshowWinGeometry.left() - slideshowWinTop = slideshowWinGeometry.top() - slideshowWinWidth = slideshowWinGeometry.width() - slideshowWinHeight = slideshowWinGeometry.height() - - # Determine if overlap - overlap = ( - (slideshowWinTop < mainWinBottom) and - (slideshowWinLeft < mainWinRight) - ) - - autoActivate = ( - self.isDataLoaded and not - overlap and not - posData.disableAutoActivateViewerWindow - ) - - if autoActivate: - # self.setFocus() - self.activateWindow() - - def isPanImageClick(self, mouseEvent, modifiers): - left_click = mouseEvent.button() == Qt.MouseButton.LeftButton - return modifiers == Qt.AltModifier and left_click - - def middleClickText(self): - if self.delObjAction is None and is_mac: - return 'Command + Left Click' - - if self.delObjAction is None: - return 'Middle Click' - - delObjKeySequence, delObjQtButton = self.delObjAction - - if delObjQtButton == Qt.MouseButton.LeftButton: - buttonName = 'Left click' - elif delObjQtButton == Qt.MouseButton.RightButton: - buttonName = 'Right click' - else: - buttonName = 'Middle click' - - if delObjKeySequence is None: - return buttonName - - return f'{delObjKeySequence.toString()} + {buttonName}' - - def isDefaultMiddleClick(self, mouseEvent, modifiers): - if is_mac: - middle_click = ( - mouseEvent.button() == Qt.MouseButton.LeftButton - and modifiers == Qt.ControlModifier - and not self.brushButton.isChecked() - ) - else: - middle_click = mouseEvent.button() == Qt.MouseButton.MiddleButton - return middle_click - - def isMiddleClick(self, mouseEvent, modifiers): - if self.delObjAction is None: - return self.isDefaultMiddleClick(mouseEvent, modifiers) - - delObjKeySequence, delObjQtButton = self.delObjAction - if delObjKeySequence is None: - # Setting only middle click on mac is allowed, however the - # delObjKeySequence is None and the tool button is never checked - isDelObjectActive = True - else: - isDelObjectActive = self.delObjToolAction.isChecked() - - mouseEventButton = self.changeRightClickToLeftOnMac(mouseEvent) - - middle_click = ( - mouseEventButton == delObjQtButton and isDelObjectActive - ) - - return middle_click - - def gui_createCursors(self): - pixmap = QPixmap(":wand_cursor.svg") - self.wandCursor = QCursor(pixmap, 16, 16) - - pixmap = QPixmap(":curv_cursor.svg") - self.curvCursor = QCursor(pixmap, 16, 16) - - pixmap = QPixmap(":addDelPolyLineRoi_cursor.svg") - self.polyLineRoiCursor = QCursor(pixmap, 16, 16) - - pixmap = QPixmap(":cross_cursor.svg") - self.addPointsCursor = QCursor(pixmap, 16, 16) - - def gui_createMenuBar(self): - menuBar = self.menuBar() - menuBar.setNativeMenuBar(False) - - # File menu - fileMenu = QMenu("&File", self) - self.fileMenu = fileMenu - menuBar.addMenu(fileMenu) - if self.debug: - fileMenu.addAction(self.createEmptyDataAction) - fileMenu.addAction(self.newAction) - fileMenu.addAction(self.newWindowAction) - fileMenu.addSeparator() - fileMenu.addAction(self.openFolderAction) - fileMenu.addAction(self.openFileAction) - # Open Recent submenu - self.openRecentMenu = fileMenu.addMenu("Open Recent") - fileMenu.addSeparator() - fileMenu.addAction(self.manageVersionsAction) - fileMenu.addAction(self.saveAction) - fileMenu.addAction(self.saveAsAction) - fileMenu.addAction(self.quickSaveAction) - fileMenu.addSeparator() - - self.exportMenu = fileMenu.addMenu('Export') - self.exportMenu.addAction(self.exportToVideoAction) - self.exportMenu.addAction(self.exportToImageAction) - fileMenu.addSeparator() - fileMenu.addAction(self.loadFluoAction) - fileMenu.addAction(self.loadPosAction) - # Separator - self.fileMenu.lastSeparator = fileMenu.addSeparator() - fileMenu.addAction(self.exitAction) - - # Edit menu - editMenu = menuBar.addMenu("&Edit") - editMenu.addSeparator() - - editMenu.addAction(self.editShortcutsAction) - editMenu.addAction(self.editTextIDsColorAction) - editMenu.addAction(self.editOverlayColorAction) - editMenu.addAction(self.manuallyEditCcaAction) - editMenu.addAction(self.enableSmartTrackAction) - editMenu.addAction(self.enableAutoZoomToCellsAction) - - # View menu - self.viewMenu = menuBar.addMenu("&View") - self.viewMenu.addSeparator() - self.viewMenu.addAction(self.viewCcaTableAction) - - # Image menu - ImageMenu = menuBar.addMenu("&Image") - ImageMenu.addSeparator() - ImageMenu.addAction(self.imgPropertiesAction) - self.defaultRescaleIntensLutMenu = ImageMenu.addMenu( - "Default method to rescale intensities (LUT)" - ) - - self.defaultRescaleIntensActionGroup = QActionGroup( - self.defaultRescaleIntensLutMenu - ) - howTexts = ( - 'Rescale each 2D image', - 'Rescale across z-stack', - 'Rescale across time frames', - 'Do no rescale, display raw image' - ) - try: - self.defaultRescaleIntensHow = ( - self.df_settings.at['default_rescale_intens_how', 'value'] - ) - except Exception as err: - self.defaultRescaleIntensHow = howTexts[0] - - for howText in howTexts: - action = QAction(howText, self.defaultRescaleIntensLutMenu) - action.setCheckable(True) - if howText == self.defaultRescaleIntensHow: - action.setChecked(True) - - self.defaultRescaleIntensActionGroup.addAction(action) - self.defaultRescaleIntensLutMenu.addAction(action) - - ImageMenu.addAction(self.addScaleBarAction) - ImageMenu.addAction(self.addTimestampAction) - - self.rescaleIntensMenu = ImageMenu.addMenu('Rescale intensities (LUT)') - - ImageMenu.addAction(self.preprocessAction) - ImageMenu.addAction(self.combineChannelsAction) - ImageMenu.addAction(self.saveLabColormapAction) - ImageMenu.addAction(self.shuffleCmapAction) - ImageMenu.addAction(self.greedyShuffleCmapAction) - ImageMenu.addAction(self.zoomToObjsAction) - ImageMenu.addAction(self.zoomOutAction) - - # Segment menu - SegmMenu = menuBar.addMenu("&Segment") - self.segmentMenu = SegmMenu - SegmMenu.addSeparator() - self.segmSingleFrameMenu = SegmMenu.addMenu('Segment displayed frame') - for action in self.segmActions: - self.segmSingleFrameMenu.addAction(action) - - self.segmSingleFrameMenu.addSeparator() - self.segmSingleFrameMenu.addAction(self.addCustomModelFrameAction) - - self.segmVideoMenu = SegmMenu.addMenu('Segment multiple frames') - for action in self.segmActionsVideo: - self.segmVideoMenu.addAction(action) - - self.segmVideoMenu.addSeparator() - self.segmVideoMenu.addAction(self.addCustomModelVideoAction) - - self.segmWithPromptableModelMenu = SegmMenu.addMenu( - 'Segment with promptable model' - ) - - self.segmWithPromptableModelMenu.addAction( - self.segmWithPromptableModelAction - ) - - self.segmWithPromptableModelMenu.addSeparator() - self.segmWithPromptableModelMenu.addAction( - self.addCustomPromptModelAction - ) - - SegmMenu.addAction(self.EditSegForLostIDsSetSettings) - SegmMenu.addAction(self.postProcessSegmAction) - SegmMenu.addAction(self.autoSegmAction) - SegmMenu.addAction(self.relabelSequentialAction) - SegmMenu.aboutToShow.connect(self.nonViewerEditMenuOpened) - - # Tracking menu - trackingMenu = menuBar.addMenu("&Tracking") - self.trackingMenu = trackingMenu - trackingMenu.addSeparator() - selectTrackAlgoMenu = trackingMenu.addMenu( - 'Select real-time tracking algorithm' - ) - for rtTrackerAction in self.trackingAlgosGroup.actions(): - selectTrackAlgoMenu.addAction(rtTrackerAction) - - trackingMenu.addAction(self.editRtTrackerParamsAction) - trackingMenu.addAction(self.repeatTrackingVideoAction) - - trackingMenu.addAction(self.repeatTrackingMenuAction) - trackingMenu.aboutToShow.connect(self.nonViewerEditMenuOpened) - - if self.mainWin is not None: - trackingMenu.addAction( - self.mainWin.applyTrackingFromTableAction - ) - trackingMenu.addAction( - self.mainWin.applyTrackingFromTrackMateXMLAction - ) - - # Measurements menu - measurementsMenu = menuBar.addMenu("&Measurements") - self.measurementsMenu = measurementsMenu - measurementsMenu.addSeparator() - measurementsMenu.addAction(self.setMeasurementsAction) - measurementsMenu.addAction(self.addCustomMetricAction) - measurementsMenu.addAction(self.addCombineMetricAction) - measurementsMenu.setDisabled(True) - - # Settings menu - self.settingsMenu = QMenu("Settings", self) - menuBar.addMenu(self.settingsMenu) - self.settingsMenu.addAction(self.invertBwAction) - self.settingsMenu.addAction(self.toggleColorSchemeAction) - self.settingsMenu.addSeparator() - self.settingsMenu.addAction(self.pxModeAction) - self.settingsMenu.addAction(self.highLowResAction) - self.settingsMenu.addAction(self.editShortcutsAction) - self.settingsMenu.addAction(self.showMirroredCursorAction) - self.settingsMenu.addSeparator() - self.settingsMenu.addAction(self.editAutoSaveIntervalAction) - self.settingsMenu.addSeparator() - - # Mode menu (actions added when self.modeComboBox is created) - self.modeMenu = menuBar.addMenu('Mode') - self.modeMenu.menuAction().setVisible(False) - - # Help menu - helpMenu = menuBar.addMenu("&Help") - helpMenu.addAction(self.openLogFileAction) - helpMenu.addAction(self.showLogFilesAction) - helpMenu.addAction(self.tipsAction) - helpMenu.addAction(self.UserManualAction) - helpMenu.addSeparator() - helpMenu.addAction(self.aboutAction) - self.helpMenu = helpMenu - - def gui_createToolBars(self): - # File toolbar - fileToolBar = self.addToolBar("File") - # fileToolBar.setIconSize(QSize(toolbarSize, toolbarSize)) - fileToolBar.setMovable(False) - - self.segmNdimIndicatorAction = fileToolBar.addWidget( - self.segmNdimIndicator - ) - self.segmNdimIndicatorAction.setVisible(False) - fileToolBar.addAction(self.newAction) - fileToolBar.addAction(self.openFolderAction) - fileToolBar.addAction(self.openFileAction) - fileToolBar.addAction(self.manageVersionsAction) - fileToolBar.addAction(self.saveAction) - fileToolBar.addAction(self.showInExplorerAction) - # fileToolBar.addAction(self.reloadAction) - fileToolBar.addAction(self.undoAction) - fileToolBar.addAction(self.redoAction) - self.fileToolBar = fileToolBar - self.setEnabledFileToolbar(False) - - self.undoAction.setEnabled(False) - self.redoAction.setEnabled(False) - - # Navigation toolbar - navigateToolBar = widgets.ToolBar("Navigation", self) - navigateToolBar.setContextMenuPolicy(Qt.PreventContextMenu) - # navigateToolBar.setIconSize(QSize(toolbarSize, toolbarSize)) - self.addToolBar(navigateToolBar) - navigateToolBar.addAction(self.findIdAction) - - navigateToolBar.addWidget(self.zoomRectButton) - - self.slideshowButton = QToolButton(self) - self.slideshowButton.setIcon(QIcon(":eye-plus.svg")) - self.slideshowButton.setCheckable(True) - self.slideshowButton.setShortcut('Ctrl+W') - navigateToolBar.addWidget(self.slideshowButton) - - navigateToolBar.addAction(self.autoPilotButton) - - # navigateToolBar.setIconSize(QSize(toolbarSize, toolbarSize)) - navigateToolBar.addAction(self.skipToNewIdAction) - - self.preprocessImageAction = QAction('Preprocess image', self) - self.preprocessImageAction.setIcon(QIcon(":filter_image.svg")) - navigateToolBar.addAction(self.preprocessImageAction) - - self.overlayButton = widgets.rightClickToolButton(parent=self) - self.overlayButton.setIcon(QIcon(":overlay.svg")) - self.overlayButton.setCheckable(True) - - self.overlayButtonAction = navigateToolBar.addWidget(self.overlayButton) - # self.checkableButtons.append(self.overlayButton) - # self.checkableQButtonsGroup.addButton(self.overlayButton) - - self.countObjsButton = QToolButton(self) - self.countObjsButton.setIcon(QIcon(":count_objects.svg")) - self.countObjsButton.setCheckable(True) - self.countObjsButton.setShortcut('Ctrl+Shift+C') - self.countObjsButtonAction = navigateToolBar.addWidget( - self.countObjsButton - ) - - self.togglePointsLayerAction = QAction('Activate points layer', self) - self.togglePointsLayerAction.setCheckable(True) - self.togglePointsLayerAction.setIcon(QIcon(":pointsLayer.svg")) - navigateToolBar.addAction(self.togglePointsLayerAction) - - self.overlayLabelsButton = widgets.rightClickToolButton(parent=self) - self.overlayLabelsButton.setIcon(QIcon(":overlay_labels.svg")) - self.overlayLabelsButton.setCheckable(True) - # self.overlayLabelsButton.setVisible(False) - self.overlayLabelsButtonAction = navigateToolBar.addWidget( - self.overlayLabelsButton - ) - self.overlayLabelsButtonAction.setVisible(False) - - self.rulerButton = QToolButton(self) - self.rulerButton.setIcon(QIcon(":ruler.svg")) - self.rulerButton.setCheckable(True) - navigateToolBar.addWidget(self.rulerButton) - self.checkableButtons.append(self.rulerButton) - self.LeftClickButtons.append(self.rulerButton) - - # fluorescence image color widget - colorsToolBar = widgets.ToolBar("Colors", self) - - self.overlayColorButton = pg.ColorButton(self, color=(230,230,230)) - self.overlayColorButton.setDisabled(True) - colorsToolBar.addWidget(self.overlayColorButton) - - self.textIDsColorButton = pg.ColorButton(self) - colorsToolBar.addWidget(self.textIDsColorButton) - - self.addToolBar(colorsToolBar) - colorsToolBar.setVisible(False) - - self.navigateToolBar = navigateToolBar - - # cca toolbar - ccaToolBar = widgets.ToolBar("Cell cycle annotations", self) - self.addToolBar(ccaToolBar) - - # Assign mother to bud button - self.assignBudMothButton = QToolButton(self) - self.assignBudMothButton.setIcon(QIcon(":assign-motherbud.svg")) - self.assignBudMothButton.setCheckable(True) - self.assignBudMothButton.setShortcut('A') - self.assignBudMothButton.setVisible(False) - self.assignBudMothButton.action = ccaToolBar.addWidget( - self.assignBudMothButton - ) - self.checkableButtons.append(self.assignBudMothButton) - self.checkableQButtonsGroup.addButton(self.assignBudMothButton) - self.functionsNotTested3D.append(self.assignBudMothButton) - - - # Set is_history_known button - self.setIsHistoryKnownButton = QToolButton(self) - self.setIsHistoryKnownButton.setIcon(QIcon(":history.svg")) - self.setIsHistoryKnownButton.setCheckable(True) - self.setIsHistoryKnownButton.setShortcut('U') - self.setIsHistoryKnownButton.setVisible(False) - self.setIsHistoryKnownButton.action = ccaToolBar.addWidget( - self.setIsHistoryKnownButton - ) - self.checkableButtons.append(self.setIsHistoryKnownButton) - self.checkableQButtonsGroup.addButton(self.setIsHistoryKnownButton) - self.functionsNotTested3D.append(self.setIsHistoryKnownButton) - - ccaToolBar.addAction(self.assignBudMothAutoAction) - ccaToolBar.addAction(self.editCcaToolAction) - ccaToolBar.addAction(self.reInitCcaAction) - ccaToolBar.setVisible(False) - self.ccaToolBar = ccaToolBar - self.functionsNotTested3D.append(self.assignBudMothAutoAction) - self.functionsNotTested3D.append(self.reInitCcaAction) - self.functionsNotTested3D.append(self.editCcaToolAction) - - # Edit toolbar - editToolBar = widgets.ToolBar("Edit", self) - editToolBar.setContextMenuPolicy(Qt.PreventContextMenu) - - self.addToolBar(editToolBar) - - self.manulAnnotToolButtons = set() - - self.brushButton = QToolButton(self) - self.brushButton.setIcon(QIcon(":brush.svg")) - self.brushButton.setCheckable(True) - editToolBar.addWidget(self.brushButton) - self.checkableButtons.append(self.brushButton) - self.LeftClickButtons.append(self.brushButton) - self.brushButton.keyPressShortcut = Qt.Key_B - self.widgetsWithShortcut['Brush'] = self.brushButton - self.manulAnnotToolButtons.add(self.brushButton) - - self.eraserButton = QToolButton(self) - self.eraserButton.setIcon(QIcon(":eraser.svg")) - self.eraserButton.setCheckable(True) - editToolBar.addWidget(self.eraserButton) - self.eraserButton.keyPressShortcut = Qt.Key_X - self.widgetsWithShortcut['Eraser'] = self.eraserButton - self.checkableButtons.append(self.eraserButton) - self.LeftClickButtons.append(self.eraserButton) - self.manulAnnotToolButtons.add(self.eraserButton) - - self.curvToolButton = QToolButton(self) - self.curvToolButton.setIcon(QIcon(":curvature-tool.svg")) - self.curvToolButton.setCheckable(True) - self.curvToolButton.setShortcut('C') - self.curvToolButton.action = editToolBar.addWidget(self.curvToolButton) - self.LeftClickButtons.append(self.curvToolButton) - # self.functionsNotTested3D.append(self.curvToolButton) - self.widgetsWithShortcut['Curvature tool'] = self.curvToolButton - # self.checkableButtons.append(self.curvToolButton) - self.manulAnnotToolButtons.add(self.curvToolButton) - - self.wandToolButton = QToolButton(self) - self.wandToolButton.setIcon(QIcon(":magic_wand.svg")) - self.wandToolButton.setCheckable(True) - self.wandToolButton.setShortcut('Ctrl+D') - self.wandToolButton.action = editToolBar.addWidget(self.wandToolButton) - self.LeftClickButtons.append(self.wandToolButton) - self.checkableButtons.append(self.eraserButton) - self.widgetsWithShortcut['Magic wand'] = self.wandToolButton - - self.magicPromptsToolButton = QToolButton(self) - self.magicPromptsToolButton.setIcon(QIcon(":magic-prompts.svg")) - self.magicPromptsToolButton.setCheckable(True) - self.magicPromptsToolButton.setShortcut('W') - self.magicPromptsToolButton.action = editToolBar.addWidget( - self.magicPromptsToolButton - ) - self.widgetsWithShortcut['Magic prompts'] = self.magicPromptsToolButton - - self.drawClearRegionButton = QToolButton(self) - self.drawClearRegionButton.setCheckable(True) - self.drawClearRegionButton.setIcon(QIcon(":clear_freehand_region.svg")) - self.widgetsWithShortcut['Clear freehand region'] = ( - self.drawClearRegionButton - ) - self.toolsActiveInProj3Dsegm.add(self.drawClearRegionButton) - - self.checkableButtons.append(self.drawClearRegionButton) - self.LeftClickButtons.append(self.drawClearRegionButton) - - self.drawClearRegionAction = editToolBar.addWidget( - self.drawClearRegionButton - ) - - self.widgetsWithShortcut['Annotate mother/daughter pairing'] = ( - self.assignBudMothButton - ) - self.widgetsWithShortcut['Annotate unknown history'] = ( - self.setIsHistoryKnownButton - ) - - self.copyLostObjButton = QToolButton(self) - self.copyLostObjButton.setIcon(QIcon(":copyContour.svg")) - self.copyLostObjButton.setCheckable(True) - self.copyLostObjButton.setShortcut('V') - self.copyLostObjButton.action = editToolBar.addWidget( - self.copyLostObjButton - ) - self.checkableButtons.append(self.copyLostObjButton) - self.checkableQButtonsGroup.addButton(self.copyLostObjButton) - self.widgetsWithShortcut['Copy lost object contour'] = ( - self.copyLostObjButton - ) - self.functionsNotTested3D.append(self.copyLostObjButton) - - self.labelRoiButton = widgets.rightClickToolButton(parent=self) - self.labelRoiButton.setIcon(QIcon(":label_roi.svg")) - self.labelRoiButton.setCheckable(True) - self.labelRoiButton.setShortcut('L') - self.labelRoiButton.action = editToolBar.addWidget(self.labelRoiButton) - self.LeftClickButtons.append(self.labelRoiButton) - self.checkableButtons.append(self.labelRoiButton) - self.checkableQButtonsGroup.addButton(self.labelRoiButton) - self.widgetsWithShortcut['Label ROI'] = self.labelRoiButton - # self.functionsNotTested3D.append(self.labelRoiButton) - - self.manualAnnotPastButton = QToolButton(self) - self.manualAnnotPastButton.setIcon(QIcon(":lock_id_annotate_future.svg")) - self.manualAnnotPastButton.setCheckable(True) - self.manualAnnotPastButton.setShortcut('Y') - self.manualAnnotPastButton.action = editToolBar.addWidget( - self.manualAnnotPastButton - ) - self.checkableButtons.append(self.manualAnnotPastButton) - self.widgetsWithShortcut['Lock ID and annotate single object'] = ( - self.manualAnnotPastButton - ) - self.functionsNotTested3D.append(self.manualAnnotPastButton) - self.manulAnnotToolButtons.add(self.manualAnnotPastButton) - - self.segmentToolAction = QAction('Segment with last used model', self) - self.segmentToolAction.setIcon(QIcon(":segment.svg")) - self.segmentToolAction.setShortcut('R') - self.widgetsWithShortcut['Repeat segmentation'] = self.segmentToolAction - editToolBar.addAction(self.segmentToolAction) - - self.segForLostIDsButton = QToolButton(self) - self.segForLostIDsButton.setIcon(QIcon(":segForLostIDs.svg")) - self.segForLostIDsAction = editToolBar.addWidget( - self.segForLostIDsButton - ) - self.segForLostIDsButton.clicked.connect( - self.segForLostIDsButtonClicked - ) - - # self.SegForLostIDsButton.setShortcut('U') - # self.widgetsWithShortcut['Unknown lineage (lineage tree)'] = self.SegForLostIDsButton - - self.manualBackgroundButton = QToolButton(self) - self.manualBackgroundButton.setIcon(QIcon(":manual_background.svg")) - self.manualBackgroundButton.setCheckable(True) - self.manualBackgroundButton.setShortcut('G') - self.LeftClickButtons.append(self.manualBackgroundButton) - self.checkableButtons.append(self.manualBackgroundButton) - self.checkableQButtonsGroup.addButton(self.manualBackgroundButton) - self.widgetsWithShortcut['Manual background'] = self.manualBackgroundButton - - self.manualBackgroundAction = editToolBar.addWidget( - self.manualBackgroundButton - ) - - self.delObjsOutSegmMaskAction = QAction( - QIcon(":del_objs_out_segm.svg"), - 'Select a segmentation file and delete all objects on the background', - self - ) - self.delObjsOutSegmMaskAction.setShortcut('I') - self.widgetsWithShortcut['Delete all objects outside segm'] = ( - self.delObjsOutSegmMaskAction - ) - editToolBar.addAction(self.delObjsOutSegmMaskAction) - - self.hullContToolButton = QToolButton(self) - self.hullContToolButton.setIcon(QIcon(":hull.svg")) - self.hullContToolButton.setCheckable(True) - self.hullContToolButton.setShortcut('O') - self.hullContToolButton.action = editToolBar.addWidget(self.hullContToolButton) - self.checkableButtons.append(self.hullContToolButton) - self.checkableQButtonsGroup.addButton(self.hullContToolButton) - self.functionsNotTested3D.append(self.hullContToolButton) - self.widgetsWithShortcut['Hull contour'] = self.hullContToolButton - - self.fillHolesToolButton = QToolButton(self) - self.fillHolesToolButton.setIcon(QIcon(":fill_holes.svg")) - self.fillHolesToolButton.setCheckable(True) - self.fillHolesToolButton.setShortcut('F') - self.fillHolesToolButton.action = editToolBar.addWidget( - self.fillHolesToolButton - ) - self.checkableButtons.append(self.fillHolesToolButton) - self.checkableQButtonsGroup.addButton(self.fillHolesToolButton) - self.functionsNotTested3D.append(self.fillHolesToolButton) - self.widgetsWithShortcut['Fill holes'] = self.fillHolesToolButton - - self.moveLabelToolButton = QToolButton(self) - self.moveLabelToolButton.setIcon(QIcon(":moveLabel.svg")) - self.moveLabelToolButton.setCheckable(True) - self.moveLabelToolButton.setShortcut('P') - self.moveLabelToolButton.action = editToolBar.addWidget(self.moveLabelToolButton) - self.checkableButtons.append(self.moveLabelToolButton) - self.checkableQButtonsGroup.addButton(self.moveLabelToolButton) - self.widgetsWithShortcut['Move label'] = self.moveLabelToolButton - - self.expandLabelToolButton = QToolButton(self) - self.expandLabelToolButton.setIcon(QIcon(":expandLabel.svg")) - self.expandLabelToolButton.setCheckable(True) - self.expandLabelToolButton.setShortcut('E') - self.expandLabelToolButton.action = editToolBar.addWidget(self.expandLabelToolButton) - self.expandLabelToolButton.hide() - self.checkableButtons.append(self.expandLabelToolButton) - self.LeftClickButtons.append(self.expandLabelToolButton) - self.checkableQButtonsGroup.addButton(self.expandLabelToolButton) - self.widgetsWithShortcut['Expand/shrink label'] = self.expandLabelToolButton - - self.editIDbutton = QToolButton(self) - self.editIDbutton.setIcon(QIcon(":edit-id.svg")) - self.editIDbutton.setCheckable(True) - self.editIDbutton.setShortcut('N') - editToolBar.addWidget(self.editIDbutton) - self.checkableButtons.append(self.editIDbutton) - self.checkableQButtonsGroup.addButton(self.editIDbutton) - self.widgetsWithShortcut['Edit ID'] = self.editIDbutton - - self.separateBudButton = QToolButton(self) - self.separateBudButton.setIcon(QIcon(":separate-bud.svg")) - self.separateBudButton.setCheckable(True) - self.separateBudButton.setShortcut('S') - self.separateBudButton.action = editToolBar.addWidget(self.separateBudButton) - self.checkableButtons.append(self.separateBudButton) - self.checkableQButtonsGroup.addButton(self.separateBudButton) - # self.functionsNotTested3D.append(self.separateBudButton) - self.widgetsWithShortcut['Separate objects'] = self.separateBudButton - - self.mergeIDsButton = QToolButton(self) - self.mergeIDsButton.setIcon(QIcon(":merge-IDs.svg")) - self.mergeIDsButton.setCheckable(True) - self.mergeIDsButton.setShortcut('M') - self.mergeIDsButton.action = editToolBar.addWidget(self.mergeIDsButton) - self.checkableButtons.append(self.mergeIDsButton) - self.checkableQButtonsGroup.addButton(self.mergeIDsButton) - # self.functionsNotTested3D.append(self.mergeIDsButton) - self.widgetsWithShortcut['Merge objects'] = self.mergeIDsButton - - self.keepIDsButton = QToolButton(self) - self.keepIDsButton.setIcon(QIcon(":keep_objects.svg")) - self.keepIDsButton.setCheckable(True) - self.keepIDsButton.action = editToolBar.addWidget(self.keepIDsButton) - self.keepIDsButton.setShortcut('K') - self.checkableButtons.append(self.keepIDsButton) - self.checkableQButtonsGroup.addButton(self.keepIDsButton) - # self.functionsNotTested3D.append(self.keepIDsButton) - self.widgetsWithShortcut['Select objects to keep'] = self.keepIDsButton - - self.whitelistIDsButton = QToolButton(self) - self.whitelistIDsButton.setIcon(QIcon(":whitelist.svg")) - self.whitelistIDsButton.setCheckable(True) - self.whitelistIDsButton.action = editToolBar.addWidget( - self.whitelistIDsButton - ) - self.whitelistIDsButton.setShortcut('Ctrl+K') - self.checkableButtons.append(self.whitelistIDsButton) - self.checkableQButtonsGroup.addButton(self.whitelistIDsButton) - self.LeftClickButtons.append(self.whitelistIDsButton) - # self.functionsNotTested3D.append(self.whitelistIDsButton) - self.widgetsWithShortcut['Select objects to add to a tracking whitelist'] = ( - self.whitelistIDsButton - ) - - self.binCellButton = QToolButton(self) - self.binCellButton.setIcon(QIcon(":bin.svg")) - self.binCellButton.setCheckable(True) - # self.binCellButton.setShortcut('R') - self.binCellButton.action = editToolBar.addWidget(self.binCellButton) - self.checkableButtons.append(self.binCellButton) - self.checkableQButtonsGroup.addButton(self.binCellButton) - # self.functionsNotTested3D.append(self.binCellButton) - - self.manualTrackingButton = QToolButton(self) - self.manualTrackingButton.setIcon(QIcon(":manual_tracking.svg")) - self.manualTrackingButton.setCheckable(True) - self.manualTrackingButton.setShortcut('T') - self.checkableQButtonsGroup.addButton(self.manualTrackingButton) - self.checkableButtons.append(self.manualTrackingButton) - self.widgetsWithShortcut['Manual tracking'] = self.manualTrackingButton - - self.ripCellButton = QToolButton(self) - self.ripCellButton.setIcon(QIcon(":rip.svg")) - self.ripCellButton.setCheckable(True) - self.ripCellButton.setShortcut('D') - self.ripCellButton.action = editToolBar.addWidget(self.ripCellButton) - self.checkableButtons.append(self.ripCellButton) - self.checkableQButtonsGroup.addButton(self.ripCellButton) - self.functionsNotTested3D.append(self.ripCellButton) - self.widgetsWithShortcut['Annotate cell as dead'] = self.ripCellButton - - editToolBar.addAction(self.addDelRoiAction) - # editToolBar.addAction(self.addDelPolyLineRoiAction) - - self.addDelPolyLineRoiAction = editToolBar.addWidget( - self.addDelPolyLineRoiButton - ) - self.addDelPolyLineRoiAction.roiType = 'polyline' - - editToolBar.addAction(self.delBorderObjAction) - self.delBorderObjAction.button = editToolBar.widgetForAction( - self.delBorderObjAction - ) - editToolBar.addAction(self.delNewObjAction) - self.delNewObjAction.button = editToolBar.widgetForAction( - self.delNewObjAction - ) - - self.addDelRoiAction.toolbar = editToolBar - self.functionsNotTested3D.append(self.addDelRoiAction) - - self.addDelPolyLineRoiAction.toolbar = editToolBar - self.functionsNotTested3D.append(self.addDelPolyLineRoiAction) - - self.delBorderObjAction.toolbar = editToolBar - self.functionsNotTested3D.append(self.delBorderObjAction) - - self.delNewObjAction.toolbar = editToolBar - # self.functionsNotTested3D.append(self.delNewObjAction) so id this doesnt work in 3d i dont know anymore - - editToolBar.addAction(self.repeatTrackingAction) - - self.manualTrackingAction = editToolBar.addWidget( - self.manualTrackingButton - ) - - self.functionsNotTested3D.append(self.repeatTrackingAction) - self.functionsNotTested3D.append(self.manualTrackingAction) - - self.reinitLastSegmFrameAction = QAction(self) - self.reinitLastSegmFrameAction.setIcon(QIcon(":reinitLastSegm.svg")) - self.reinitLastSegmFrameAction.setVisible(False) - editToolBar.addAction(self.reinitLastSegmFrameAction) - editToolBar.setVisible(False) - self.reinitLastSegmFrameAction.toolbar = editToolBar - self.functionsNotTested3D.append(self.reinitLastSegmFrameAction) - - - self.editLin_TreeBar = widgets.ToolBar("Lin Tree Edit", self) - self.editLin_TreeBar.setContextMenuPolicy(Qt.PreventContextMenu) - - self.addToolBar(self.editLin_TreeBar) - self.editLin_TreeGroup = QButtonGroup() - self.editLin_TreeGroup.setExclusive(True) - - self.findNextMotherButton = QToolButton(self) - self.findNextMotherButton.setIcon(QIcon(":magnGlass.svg")) - self.findNextMotherButton.setCheckable(True) - self.editLin_TreeBar.addWidget(self.findNextMotherButton) - self.editLin_TreeGroup.addButton(self.findNextMotherButton) - self.findNextMotherButton.setShortcut('F') - self.widgetsWithShortcut['Find next potential mother (lineage tree)'] = self.findNextMotherButton - self.unknownLineageButton = QToolButton(self) - self.unknownLineageButton.setIcon(QIcon(":history.svg")) - self.unknownLineageButton.setCheckable(True) - self.editLin_TreeBar.addWidget(self.unknownLineageButton) - self.editLin_TreeGroup.addButton(self.unknownLineageButton) - self.unknownLineageButton.setShortcut('U') - self.widgetsWithShortcut['Unknown lineage (lineage tree)'] = self.unknownLineageButton - - self.noToolLinTreeButton = QToolButton(self) - self.noToolLinTreeButton.setIcon(QIcon(":arrow_cursor.svg")) - self.noToolLinTreeButton.setCheckable(True) - self.editLin_TreeBar.addWidget(self.noToolLinTreeButton) - self.editLin_TreeGroup.addButton(self.noToolLinTreeButton) - self.noToolLinTreeButton.setShortcut('N') - self.widgetsWithShortcut['No tool (lineage tree)'] = self.noToolLinTreeButton - - self.propagateLinTreeButton = QToolButton(self) - self.propagateLinTreeButton.setIcon(QIcon(":compute.svg")) - self.editLin_TreeBar.addWidget(self.propagateLinTreeButton) - self.propagateLinTreeButton.setShortcut('P') - self.widgetsWithShortcut['Propagate (lineage tree)'] = self.propagateLinTreeButton - self.propagateLinTreeButton.clicked.connect(self.propagateLinTreeAction) - - self.viewLinTreeInfoButton = QToolButton(self) - self.viewLinTreeInfoButton.setIcon(QIcon(":addCustomAnnotation.svg")) - self.editLin_TreeBar.addWidget(self.viewLinTreeInfoButton) - self.viewLinTreeInfoButton.setShortcut('S') - self.widgetsWithShortcut['View Changes (lineage tree)'] = self.viewLinTreeInfoButton - self.viewLinTreeInfoButton.clicked.connect(self.viewLinTreeInfoAction) - - - modes_available = [ - 'Segmentation and Tracking', - 'Cell cycle analysis', - 'Viewer', - 'Custom annotations', - 'Normal division: Lineage tree' - ] - self.modeItems = modes_available - - self.modeActionGroup = QActionGroup(self.modeMenu) - for mode in self.modeItems: - action = QAction(mode) - action.setCheckable(True) - self.modeActionGroup.addAction(action) - self.modeMenu.addAction(action) - if mode == 'Viewer': - action.setChecked(True) - - self.editToolBar = editToolBar - self.editToolBar.setVisible(False) - self.navigateToolBar.setVisible(False) - self.editLin_TreeBar.setVisible(False) - - self.gui_createAnnotateToolbar() - - def gui_createAnnotateToolbar(self): - # Edit toolbar - self.annotateToolbar = widgets.ToolBar("Custom annotations", self) - self.annotateToolbar.setContextMenuPolicy(Qt.PreventContextMenu) - self.addToolBar(Qt.LeftToolBarArea, self.annotateToolbar) - self.annotateToolbar.addAction(self.loadCustomAnnotationsAction) - self.annotateToolbar.addAction(self.addCustomAnnotationAction) - self.annotateToolbar.addAction(self.viewAllCustomAnnotAction) - self.annotateToolbar.setVisible(False) - - def gui_createLazyLoader(self): - if not self.lazyLoader is None: - return - - self.lazyLoaderThread = QThread() - self.lazyLoaderMutex = QMutex() - self.lazyLoaderWaitCond = QWaitCondition() - self.waitReadH5cond = QWaitCondition() - self.readH5mutex = QMutex() - self.lazyLoader = workers.LazyLoader( - self.lazyLoaderMutex, self.lazyLoaderWaitCond, - self.waitReadH5cond, self.readH5mutex - ) - self.lazyLoader.moveToThread(self.lazyLoaderThread) - self.lazyLoader.wait = True - - self.lazyLoader.signals.finished.connect(self.lazyLoaderThread.quit) - self.lazyLoader.signals.finished.connect(self.lazyLoader.deleteLater) - self.lazyLoaderThread.finished.connect(self.lazyLoaderThread.deleteLater) - - self.lazyLoader.signals.progress.connect(self.workerProgress) - self.lazyLoader.signals.sigLoadingNewChunk.connect(self.loadingNewChunk) - self.lazyLoader.sigLoadingFinished.connect(self.lazyLoaderFinished) - self.lazyLoader.signals.critical.connect(self.lazyLoaderCritical) - self.lazyLoader.signals.finished.connect(self.lazyLoaderWorkerClosed) - - self.lazyLoaderThread.started.connect(self.lazyLoader.run) - self.lazyLoaderThread.start() - - def gui_createStoreStateWorker(self): - self.storeStateWorker = None - return - self.storeStateThread = QThread() - self.autoSaveMutex = QMutex() - self.autoSaveWaitCond = QWaitCondition() - - self.storeStateWorker = workers.StoreGuiStateWorker( - self.autoSaveMutex, self.autoSaveWaitCond - ) - - self.storeStateWorker.moveToThread(self.storeStateThread) - self.storeStateWorker.finished.connect(self.storeStateThread.quit) - self.storeStateWorker.finished.connect(self.storeStateWorker.deleteLater) - self.storeStateThread.finished.connect(self.storeStateThread.deleteLater) - - self.storeStateWorker.sigDone.connect(self.storeStateWorkerDone) - self.storeStateWorker.progress.connect(self.workerProgress) - self.storeStateWorker.finished.connect(self.storeStateWorkerClosed) - - self.storeStateThread.started.connect(self.storeStateWorker.run) - self.storeStateThread.start() - - self.logger.info('Store state worker started.') - - def storeStateWorkerDone(self): - if self.storeStateWorker.callbackOnDone is not None: - self.storeStateWorker.callbackOnDone() - self.storeStateWorker.callbackOnDone = None - - def storeStateWorkerClosed(self): - self.logger.info('Store state worker started.') - - def gui_createAutoSaveWorker(self): - if not hasattr(self, 'data'): - return - - if not self.isDataLoaded: - return - - if self.autoSaveActiveWorkers: - garbage = self.autoSaveActiveWorkers[-1] - self.autoSaveGarbageWorkers.append(garbage) - worker = garbage[0] - worker._stop() - - posData = self.data[self.pos_i] - autoSaveThread = QThread() - self.autoSaveMutex = QMutex() - self.autoSaveWaitCond = QWaitCondition() - - savedSegmData = posData.segm_data.copy() - autoSaveWorker = workers.AutoSaveWorker( - self.autoSaveMutex, self.autoSaveWaitCond, savedSegmData - ) - autoSaveWorker.isAutoSaveON = self.autoSaveToggle.isChecked() - - autoSaveWorker.moveToThread(autoSaveThread) - autoSaveWorker.finished.connect(autoSaveThread.quit) - autoSaveWorker.finished.connect(autoSaveWorker.deleteLater) - autoSaveThread.finished.connect(autoSaveThread.deleteLater) - - autoSaveWorker.sigDone.connect(self.autoSaveWorkerDone) - autoSaveWorker.progress.connect(self.workerProgress) - autoSaveWorker.finished.connect(self.autoSaveWorkerClosed) - autoSaveWorker.sigAutoSaveCannotProceed.connect( - self.turnOffAutoSaveWorker - ) - - autoSaveThread.started.connect(autoSaveWorker.run) - autoSaveThread.start() - - self.autoSaveActiveWorkers.append((autoSaveWorker, autoSaveThread)) - - self.logger.info('Autosaving worker started.') - - def autoSaveWorkerStartTimer(self, worker, posData): - self.autoSaveWorkerTimer = QTimer() - self.autoSaveWorkerTimer.timeout.connect( - partial(self.autoSaveWorkerTimerCallback, worker, posData) - ) - self.autoSaveWorkerTimer.start(150) - - def autoSaveWorkerTimerCallback(self, worker, posData): - if not self.isSaving: - self.autoSaveWorkerTimer.stop() - worker._enqueue(posData) - - def autoSaveWorkerDone(self): - self.setStatusBarLabel(log=False) - - def ccaCheckerWorkerDone(self): - self.setStatusBarLabel(log=False) - - def preprocWorkerIsQueueEmpty(self, isEmpty: bool): - if isEmpty: - self.preprocessDialog.appliedFinished() - else: - self.preprocessDialog.setDisabled(True) - self.preprocessDialog.infoLabel.setText( - 'Computing preview...
' - '(Feel free to use Cell-ACDC while waiting)' - ) - - def preprocWorkerPreviewDone( - self, processed_data: np.ndarray, - key: Tuple[int, int, Union[int, str]] - ): - pos_i, frame_i, z_slice = key - posData = self.data[pos_i] - if not hasattr(posData, 'preproc_img_data'): - posData.preproc_img_data = preprocess.PreprocessedData( - image_data=np.zeros(posData.img_data.shape) - ) - - posData.preproc_img_data[frame_i][z_slice] = processed_data - self.img1.updateMinMaxValuesPreprocessedData( - self.data, pos_i, frame_i, z_slice - ) - - self.setImageImg1() - - def preprocWorkerDone( - self, - processed_data: np.ndarray, - how: str, - ): - self.setStatusBarLabel(log=False) - self.preprocessDialog.appliedFinished() - - posData = self.data[self.pos_i] - if not hasattr(posData, 'preproc_img_data'): - posData.preproc_img_data = preprocess.PreprocessedData() - - if how == 'current_image': - if posData.SizeZ > 1: - z_slice = self.z_slice_index() - posData.preproc_img_data[posData.frame_i][z_slice] = ( - processed_data - ) - else: - posData.preproc_img_data[posData.frame_i] = processed_data - z_slice = 0 - self.img1.updateMinMaxValuesPreprocessedData( - self.data, self.pos_i, posData.frame_i, z_slice - ) - elif how == 'z_stack': - for z_slice, processed_img in enumerate(processed_data): - posData.preproc_img_data[posData.frame_i][z_slice] = ( - processed_img - ) - self.img1.updateMinMaxValuesPreprocessedData( - self.data, self.pos_i, posData.frame_i, z_slice - ) - self.img1.updateMinMaxValuesPreprocessedProjections( - self.data, self.pos_i, posData.frame_i - ) - elif how == 'all_frames': - for frame_i, processed_frame in enumerate(processed_data): - if processed_frame.ndim == 2: - processed_frame = (processed_frame,) - - for z_slice, processed_img in enumerate(processed_frame): - posData.preproc_img_data[frame_i][z_slice] = ( - processed_img - ) - self.img1.updateMinMaxValuesPreprocessedData( - self.data, self.pos_i, frame_i, z_slice - ) - self.img1.updateMinMaxValuesPreprocessedProjections( - self.data, self.pos_i, frame_i - ) - elif how == 'all_pos': - for pos_i, processed_pos_data in enumerate(processed_data): - if processed_pos_data.ndim == 2: - processed_pos_data = (processed_pos_data,) - - posData = self.data[pos_i] - if not hasattr(posData, 'preproc_img_data'): - posData.preproc_img_data = preprocess.PreprocessedData() - for z_slice, processed_img in enumerate(processed_pos_data): - posData.preproc_img_data[0][z_slice] = ( - processed_img - ) - self.img1.updateMinMaxValuesPreprocessedData( - self.data, pos_i, 0, z_slice - ) - - if posData.SizeZ > 1: - self.img1.updateMinMaxValuesPreprocessedProjections( - self.data, pos_i, frame_i - ) - - if not self.viewPreprocDataToggle.isChecked(): - self.viewPreprocDataToggle.setChecked(True) - else: - self.setImageImg1() - - def goToFrameNumber(self, frame_n): - posData = self.data[self.pos_i] - posData.frame_i = frame_n - 1 - self.get_data() - self.updateAllImages() - self.updateScrollbars() - - def warnCcaIntegrity(self, txt, category): - self.logger.warning(f'{html_utils.to_plain_text(txt)}') - - if 'disable_all' in self.disabled_cca_warnings: - return - - if category in self.disabled_cca_warnings: - return - - if txt in self.disabled_cca_warnings: - return - - if self.isWarningCcaIntegrity: - # Some other warning is still open --> avoid opening another one - return - - self.isWarningCcaIntegrity = True - disabled_warning = _warnings.warn_cca_integrity( - txt, category, self, - go_to_frame_callback=self.goToFrameNumber - ) - if disabled_warning: - self.disabled_cca_warnings.add(disabled_warning) - - self.isWarningCcaIntegrity = False - - def fixWillDivide(self, warning_txt, IDs_will_divide_wrong): - self.logger.info(warning_txt) - self.logger.info('Fixing `will_divide` information...') - - global_cca_df = self.getConcatCcaDf() - global_cca_df = ( - global_cca_df.reset_index() - .set_index(['Cell_ID', 'generation_num']) - ) - global_cca_df.loc[IDs_will_divide_wrong, 'will_divide'] = 0 - global_cca_df = ( - global_cca_df.reset_index() - .set_index(['frame_i', 'Cell_ID']) - ) - self.storeFromConcatCcaDf(global_cca_df) - - def autoSaveWorkerClosed(self, worker): - if self.autoSaveActiveWorkers: - self.logger.info('Autosaving worker closed.') - try: - self.autoSaveActiveWorkers.remove(worker) - except Exception as e: - pass - - def ccaCheckerWorkerClosed(self, worker): - self.logger.info('Cell cycle annotations integrity checker stopped.') - self.ccaCheckerRunning = False - - def preprocWorkerClosed(self, worker): - self.logger.info('Pre-processing worker stopped.') - - def gui_createMainLayout(self): - mainLayout = QGridLayout() - row, col = 0, 1 # Leave column 1 for the overlay labels gradient editor - mainLayout.addLayout(self.leftSideDocksLayout, row, col, 2, 1) - - row = 0 - col = 2 - mainLayout.addWidget(self.graphLayout, row, col, 1, 2) - mainLayout.setRowStretch(row, 2) - - col = 4 # graphLayout spans two columns - mainLayout.addWidget(self.labelsGrad, row, col) - - col = 5 - mainLayout.addLayout(self.rightSideDocksLayout, row, col, 2, 1) - - col = 2 - row += 1 - self.resizeBottomLayoutLine = widgets.VerticalResizeHline() - mainLayout.addWidget(self.resizeBottomLayoutLine, row, col, 1, 2) - self.resizeBottomLayoutLine.dragged.connect( - self.resizeBottomLayoutLineDragged - ) - self.resizeBottomLayoutLine.clicked.connect( - self.resizeBottomLayoutLineClicked - ) - self.resizeBottomLayoutLine.released.connect( - self.resizeBottomLayoutLineReleased - ) - - # row += 1 - # mainLayout.addItem(QSpacerItem(5,5), row+1, col, 1, 2) - - # row, col = 1, 2 - # mainLayout.addLayout( - # self.bottomLayout, row, col, 1, 2, alignment=Qt.AlignLeft - # ) - - row += 1 - mainLayout.addWidget(self.bottomScrollArea, row, col, 1, 2) - mainLayout.setRowStretch(row, 0) - - # row, col = 2, 1 - # mainLayout.addWidget(self.terminal, row, col, 1, 4) - # self.terminal.hide() - - return mainLayout - - def gui_createRegionPropsDockWidget(self, side=Qt.LeftDockWidgetArea): - self.propsDockWidget = QDockWidget('Cell-ACDC objects', self) - self.guiTabControl = widgets.guiTabControl(self.propsDockWidget) - - # self.guiTabControl.setFont(_font) - - self.propsDockWidget.setWidget(self.guiTabControl) - self.propsDockWidget.setFeatures( - QDockWidget.DockWidgetFeature.DockWidgetFloatable - | QDockWidget.DockWidgetFeature.DockWidgetMovable - ) - self.propsDockWidget.setAllowedAreas( - Qt.LeftDockWidgetArea | Qt.RightDockWidgetArea - ) - - self.addDockWidget(side, self.propsDockWidget) - self.propsDockWidget.hide() - - def gui_createControlsToolbar(self): - self.controlToolBars = [] - self.addToolBarBreak() - - # Edit toolbar - modeToolBar = widgets.ToolBar("Mode", self) - self.addToolBar(modeToolBar) - - self.modeComboBox = widgets.ComboBox() - self.modeComboBox.addItems(self.modeItems) - self.modeComboBoxLabel = QLabel(' Mode: ') - self.modeComboBoxLabel.setBuddy(self.modeComboBox) - modeToolBar.addWidget(self.modeComboBoxLabel) - modeToolBar.addWidget(self.modeComboBox) - modeToolBar.setVisible(False) - - self.modeToolBar = modeToolBar - - self.overlayToolbar = widgets.OverlayToolbar(parent=self) - self.addToolBar(Qt.TopToolBarArea, self.overlayToolbar) - self.overlayToolbar.setVisible(False) - self.overlayToolbar.sigSetTranspacency.connect( - self.setOverlayTransparency - ) - self.overlayToolbar.sigSetSingleChannel.connect( - self.setOverlaySingleChannel - ) - - self.autoPilotZoomToObjToolbar = widgets.ToolBar( - "Auto-zoom to objects", self - ) - self.autoPilotZoomToObjToolbar.setContextMenuPolicy(Qt.PreventContextMenu) - self.autoPilotZoomToObjToolbar.setMovable(False) - self.addToolBar(Qt.TopToolBarArea, self.autoPilotZoomToObjToolbar) - # self.autoPilotZoomToObjToolbar.setIconSize(QSize(16, 16)) - self.autoPilotZoomToObjToolbar.setVisible(False) - self.autoPilotZoomToObjToolbar.keepVisibleWhenActive = True - self.controlToolBars.append(self.autoPilotZoomToObjToolbar) - - # Highlighted ID or searched ID toolbar - self.highlightIDToolbar = widgets.HighlightedIDToolbar( - parent=self - ) - self.addToolBar(Qt.TopToolBarArea, self.highlightIDToolbar) - self.highlightIDToolbar.setVisible(False) - self.highlightIDToolbar.keepVisibleWhenActive = True - self.controlToolBars.append(self.highlightIDToolbar) - - self.highlightIDToolbar.sigIDChanged.connect( - self.setHighlighedIDfromToolbar - ) - - # Widgets toolbar - brushEraserToolBar = widgets.ToolBar("Widgets", self) - self.addToolBar(Qt.TopToolBarArea, brushEraserToolBar) - self.controlToolBars.append(brushEraserToolBar) - - self.editIDspinbox = widgets.SpinBox() - # self.editIDspinbox.setMaximum(2**32-1) - editIDLabel = QLabel(' ID: ') - self.editIDLabelAction = brushEraserToolBar.addWidget(editIDLabel) - self.editIDspinboxAction = brushEraserToolBar.addWidget( - self.editIDspinbox - ) - self.editIDLabelAction.setVisible(False) - self.editIDspinboxAction.setVisible(False) - self.editIDspinboxAction.setDisabled(True) - self.editIDLabelAction.setDisabled(True) - - brushEraserToolBar.addWidget(QLabel(' ')) - self.autoIDcheckbox = QCheckBox('Auto-ID') - self.autoIDcheckbox.setChecked(True) - self.autoIDcheckboxAction = brushEraserToolBar.addWidget(self.autoIDcheckbox) - self.autoIDcheckboxAction.setVisible(False) - - self.brushSizeSpinbox = widgets.SpinBox( - disableKeyPress=True, - allowNegative=False - ) - self.brushSizeSpinbox.setValue(4) - brushSizeLabel = QLabel(' Size: ') - brushSizeLabel.setBuddy(self.brushSizeSpinbox) - self.brushSizeLabelAction = brushEraserToolBar.addWidget(brushSizeLabel) - self.brushSizeAction = brushEraserToolBar.addWidget(self.brushSizeSpinbox) - self.brushSizeLabelAction.setVisible(False) - self.brushSizeAction.setVisible(False) - - brushEraserToolBar.addWidget(QLabel(' ')) - self.brushAutoFillCheckbox = QCheckBox('Auto-fill holes') - self.brushAutoFillAction = brushEraserToolBar.addWidget( - self.brushAutoFillCheckbox - ) - self.brushAutoFillAction.setVisible(False) - if 'brushAutoFill' in self.df_settings.index: - checked = self.df_settings.at['brushAutoFill', 'value'] == 'Yes' - self.brushAutoFillCheckbox.setChecked(checked) - - brushEraserToolBar.addWidget(QLabel(' ')) - self.brushAutoHideCheckbox = QCheckBox('Hide objects when hovering') - self.brushAutoHideAction = brushEraserToolBar.addWidget( - self.brushAutoHideCheckbox - ) - self.brushAutoHideCheckbox.setChecked(True) - self.brushAutoHideAction.setVisible(False) - if 'brushAutoHide' in self.df_settings.index: - checked = self.df_settings.at['brushAutoHide', 'value'] == 'Yes' - self.brushAutoHideCheckbox.setChecked(checked) - - brushEraserToolBar.setVisible(False) - self.brushEraserToolBar = brushEraserToolBar - - self.wandControlsToolbar = widgets.WandControlsToolbar( - parent=self - ) - - self.addToolBar(Qt.TopToolBarArea , self.wandControlsToolbar) - self.wandControlsToolbar.setVisible(False) - self.controlToolBars.append(self.wandControlsToolbar) - - separatorW = 5 - self.labelRoiToolbar = widgets.ToolBar("Magic labeller controls", self) - self.labelRoiToolbar.addWidget(QLabel('ROI n. of z-slices: ')) - self.labelRoiZdepthSpinbox = widgets.SpinBox(disableKeyPress=True) - self.labelRoiToolbar.addWidget(self.labelRoiZdepthSpinbox) - - self.labelRoiToolbar.addWidget(widgets.QHWidgetSpacer(width=separatorW)) - self.labelRoiToolbar.addWidget(widgets.QVLine()) - self.labelRoiToolbar.addWidget(widgets.QHWidgetSpacer(width=separatorW)) - - self.labelRoiReplaceExistingObjectsCheckbox = QCheckBox( - 'Remove objs. touched by new ones' - ) - self.labelRoiToolbar.addWidget(self.labelRoiReplaceExistingObjectsCheckbox) - self.labelRoiAutoClearBorderCheckbox = QCheckBox( - 'Clear ROI borders before adding new objs.' - ) - self.labelRoiAutoClearBorderCheckbox.setChecked(True) - self.labelRoiToolbar.addWidget(self.labelRoiAutoClearBorderCheckbox) - - self.labelRoiToolbar.addWidget(widgets.QHWidgetSpacer(width=separatorW)) - self.labelRoiToolbar.addWidget(widgets.QVLine()) - self.labelRoiToolbar.addWidget(widgets.QHWidgetSpacer(width=separatorW)) - - group = QButtonGroup() - group.setExclusive(True) - self.labelRoiIsRectRadioButton = QRadioButton('Rect. ROI') - self.labelRoiIsRectRadioButton.setChecked(True) - self.labelRoiIsFreeHandRadioButton = QRadioButton('Freehand ROI') - self.labelRoiIsCircularRadioButton = QRadioButton('Circular ROI') - group.addButton(self.labelRoiIsRectRadioButton) - group.addButton(self.labelRoiIsFreeHandRadioButton) - group.addButton(self.labelRoiIsCircularRadioButton) - self.labelRoiToolbar.addWidget(self.labelRoiIsRectRadioButton) - self.labelRoiToolbar.addWidget(self.labelRoiIsFreeHandRadioButton) - self.labelRoiToolbar.addWidget(self.labelRoiIsCircularRadioButton) - self.labelRoiToolbar.addWidget(QLabel(' | Radius (pixel): ')) - self.labelRoiCircularRadiusSpinbox = widgets.SpinBox(disableKeyPress=True) - self.labelRoiCircularRadiusSpinbox.setMinimum(1) - self.labelRoiCircularRadiusSpinbox.setValue(11) - self.labelRoiCircularRadiusSpinbox.setDisabled(True) - self.labelRoiToolbar.addWidget(self.labelRoiCircularRadiusSpinbox) - - self.labelRoiToolbar.addWidget(widgets.QHWidgetSpacer(width=separatorW)) - self.labelRoiToolbar.addWidget(widgets.QVLine()) - self.labelRoiToolbar.addWidget(widgets.QHWidgetSpacer(width=separatorW)) - - startFrameLabel = QLabel('Start frame n. ') - startFrameLabel.setDisabled(True) - self.labelRoiToolbar.addWidget(startFrameLabel) - self.labelRoiStartFrameNoSpinbox = widgets.SpinBox(disableKeyPress=True) - self.labelRoiStartFrameNoSpinbox.label = startFrameLabel - self.labelRoiStartFrameNoSpinbox.setValue(1) - self.labelRoiStartFrameNoSpinbox.setMinimum(1) - self.labelRoiToolbar.addWidget(self.labelRoiStartFrameNoSpinbox) - self.labelRoiStartFrameNoSpinbox.setDisabled(True) - - self.labelRoiFromCurrentFrameAction = QAction(self) - self.labelRoiFromCurrentFrameAction.setText('Segment from current frame') - self.labelRoiFromCurrentFrameAction.setIcon(QIcon(":frames_current.svg")) - self.labelRoiToolbar.addAction(self.labelRoiFromCurrentFrameAction) - self.labelRoiFromCurrentFrameAction.setDisabled(True) - - self.labelRoiToolbar.addWidget(widgets.QHWidgetSpacer(width=3)) - stopFrameLabel = QLabel(' Stop frame n. ') - stopFrameLabel.setDisabled(True) - self.labelRoiToolbar.addWidget(stopFrameLabel) - self.labelRoiStopFrameNoSpinbox = widgets.SpinBox(disableKeyPress=True) - self.labelRoiStopFrameNoSpinbox.label = stopFrameLabel - self.labelRoiStopFrameNoSpinbox.setValue(1) - self.labelRoiStopFrameNoSpinbox.setMinimum(1) - self.labelRoiToolbar.addWidget(self.labelRoiStopFrameNoSpinbox) - self.labelRoiStopFrameNoSpinbox.setDisabled(True) - - self.labelRoiToEndFramesAction = QAction(self) - self.labelRoiToEndFramesAction.setText('Segment all remaining frames') - self.labelRoiToEndFramesAction.setIcon(QIcon(":frames_end.svg")) - self.labelRoiToolbar.addAction(self.labelRoiToEndFramesAction) - self.labelRoiToEndFramesAction.setDisabled(True) - - self.labelRoiTrangeCheckbox = QCheckBox('Segment range of frames') - self.labelRoiToolbar.addWidget(self.labelRoiTrangeCheckbox) - - self.labelRoiViewCurrentModelAction = QAction(self) - self.labelRoiViewCurrentModelAction.setText( - 'View current model\'s parameters' - ) - self.labelRoiViewCurrentModelAction.setIcon(QIcon(":view.svg")) - self.labelRoiToolbar.addAction(self.labelRoiViewCurrentModelAction) - self.labelRoiViewCurrentModelAction.setDisabled(True) - - self.addToolBar(Qt.TopToolBarArea, self.labelRoiToolbar) - self.controlToolBars.append(self.labelRoiToolbar) - self.labelRoiToolbar.setVisible(False) - self.labelRoiTypesGroup = group - - self.loadLabelRoiLastParams() - - self.labelRoiTrangeCheckbox.toggled.connect( - self.labelRoiTrangeCheckboxToggled - ) - self.labelRoiReplaceExistingObjectsCheckbox.toggled.connect( - self.storeLabelRoiParams - ) - self.labelRoiIsCircularRadioButton.toggled.connect( - self.labelRoiIsCircularRadioButtonToggled - ) - self.labelRoiCircularRadiusSpinbox.valueChanged.connect( - self.updateLabelRoiCircularSize - ) - self.labelRoiCircularRadiusSpinbox.valueChanged.connect( - self.storeLabelRoiParams - ) - self.labelRoiZdepthSpinbox.valueChanged.connect( - self.storeLabelRoiParams - ) - self.labelRoiAutoClearBorderCheckbox.toggled.connect( - self.storeLabelRoiParams - ) - group.buttonToggled.connect(self.storeLabelRoiParams) - - self.labelRoiToEndFramesAction.triggered.connect( - self.labelRoiToEndFramesTriggered - ) - self.labelRoiFromCurrentFrameAction.triggered.connect( - self.labelRoiFromCurrentFrameTriggered - ) - self.labelRoiViewCurrentModelAction.triggered.connect( - self.labelRoiViewCurrentModel - ) - - self.keepIDsToolbar = widgets.ToolBar("Keep IDs controls", self) - self.keepIDsConfirmAction = QAction() - self.keepIDsConfirmAction.setIcon(QIcon(":greenTick.svg")) - self.keepIDsConfirmAction.setToolTip('Apply "keep IDs" selection') - self.keepIDsConfirmAction.setDisabled(True) - self.keepIDsToolbar.addAction(self.keepIDsConfirmAction) - self.keepIDsToolbar.addWidget(QLabel(' IDs to keep: ')) - instructionsText = ( - ' (Separate IDs by comma. Use a dash to denote a range of IDs)' - ) - instructionsLabel = QLabel(instructionsText) - self.keptIDsLineEdit = widgets.KeepIDsLineEdit( - instructionsLabel, parent=self - ) - self.keepIDsToolbar.addWidget(self.keptIDsLineEdit) - self.keepIDsToolbar.addWidget(instructionsLabel) - spacer = QWidget() - spacer.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred) - self.keepIDsToolbar.addWidget(spacer) - self.addToolBar(Qt.TopToolBarArea, self.keepIDsToolbar) - self.keepIDsToolbar.setVisible(False) - self.controlToolBars.append(self.keepIDsToolbar) - - self.keptIDsLineEdit.sigEnterPressed.connect(self.applyKeepObjects) - self.keptIDsLineEdit.sigIDsChanged.connect(self.updateKeepIDs) - self.keepIDsConfirmAction.triggered.connect(self.applyKeepObjects) - - # closeToolbarAction = QAction( - # QIcon(":cancelButton.svg"), "Close toolbar...", self - # ) - # closeToolbarAction.triggered.connect(self.closeToolbars) - # self.autoPilotZoomToObjToolbar.addAction(closeToolbarAction) - - self.autoPilotZoomToObjToolbar.addWidget(widgets.QVLine()) - self.autoPilotZoomToObjToolbar.addWidget( - widgets.QHWidgetSpacer(width=separatorW) - ) - - spinBox = widgets.SpinBox() - spinBox.setMinimum(1) - spinBox.label = QLabel(' Zoom to ID: ') - spinBox.labelAction = self.autoPilotZoomToObjToolbar.addWidget(spinBox.label) - spinBox.action = self.autoPilotZoomToObjToolbar.addWidget(spinBox) - spinBox.editingFinished.connect(self.zoomToObj) - spinBox.sigUpClicked.connect(self.autoZoomNextObj) - spinBox.sigDownClicked.connect(self.autoZoomPrevObj) - self.autoPilotZoomToObjSpinBox = spinBox - toggle = widgets.Toggle() - self.autoPilotZoomToObjToggle = toggle - toggle.toggled.connect(self.autoPilotZoomToObjToggled) - toggle.label = QLabel(' Auto-pilot: ') - tooltip = ( - 'When auto-pilot is active, you can use Up/Down arrows to ' - 'automatically zoom to the next/previous object.\n\n' - 'Alternatively, you can type the ID of the object you want to ' - 'zoom to.' - ) - toggle.label.setToolTip(tooltip) - toggle.setToolTip(tooltip) - self.autoPilotZoomToObjToolbar.addWidget(toggle.label) - self.autoPilotZoomToObjToolbar.addWidget(toggle) - - self.pointsLayersToolbars = [] - - self.pointsLayersToolbar = widgets.PointsLayersToolbar(parent=self) - self.pointsLayersToolbar.setContextMenuPolicy(Qt.PreventContextMenu) - - self.pointsLayersToolbar.sigAddPointsLayer.connect( - self.addPointsLayerTriggered - ) - - self.addToolBar(Qt.TopToolBarArea, self.pointsLayersToolbar) - - self.pointsLayersToolbar.setVisible(False) - self.pointsLayersToolbar.keepVisibleWhenActive = True - self.controlToolBars.append(self.pointsLayersToolbar) - - self.pointsLayersToolbars.append( - self.pointsLayersToolbar - ) - - self.manualTrackingToolbar = widgets.ManualTrackingToolBar( - "Manual tracking controls", self - ) - self.manualTrackingToolbar.sigIDchanged.connect(self.initGhostObject) - self.manualTrackingToolbar.sigDisableGhost.connect(self.clearGhost) - self.manualTrackingToolbar.sigClearGhostContour.connect( - self.clearGhostContour - ) - self.manualTrackingToolbar.sigClearGhostMask.connect( - self.clearGhostMask - ) - self.manualTrackingToolbar.sigGhostOpacityChanged.connect( - self.updateGhostMaskOpacity - ) - - self.addToolBar(Qt.TopToolBarArea, self.manualTrackingToolbar) - self.manualTrackingToolbar.setVisible(False) - self.controlToolBars.append(self.manualTrackingToolbar) - - self.manualBackgroundToolbar = widgets.ManualBackgroundToolBar( - "Manual background controls", self - ) - self.manualBackgroundToolbar.sigIDchanged.connect( - self.initManualBackgroundObject - ) - self.addToolBar(Qt.TopToolBarArea, self.manualBackgroundToolbar) - self.manualBackgroundToolbar.setVisible(False) - self.controlToolBars.append(self.manualBackgroundToolbar) - - # Copy lost object contour toolbar - self.copyLostObjToolbar = widgets.CopyLostObjectToolbar( - "Copy lost object controls", self - ) - for name, action in self.copyLostObjToolbar.widgetsWithShortcut.items(): - self.widgetsWithShortcut[name] = action - - self.copyLostObjToolbar.sigCopyAllObjects.connect( - self.copyAllLostObjects - ) - - self.addToolBar(Qt.TopToolBarArea, self.copyLostObjToolbar) - self.copyLostObjToolbar.setVisible(False) - # self.controlToolBars.append(self.copyLostObjToolbar) - - # Copy lost object contour toolbar - self.drawClearRegionToolbar = widgets.DrawClearRegionToolbar( - "Draw freehand region and clear objects controls", self - ) - - self.addToolBar(Qt.TopToolBarArea, self.drawClearRegionToolbar) - self.drawClearRegionToolbar.setVisible(False) - self.controlToolBars.append(self.drawClearRegionToolbar) - - try: - addNewIDToggleState = self.df_settings.at['addNewIDsWhitelistToggle', 'value'] == 'Yes' - except KeyError: - addNewIDToggleState = True - - self.whitelistIDsToolbar = widgets.WhitelistIDsToolbar( - addNewIDToggleState, self - ) - for name, action in self.whitelistIDsToolbar.widgetsWithShortcut.items(): - self.widgetsWithShortcut[name] = action - - self.addToolBar(Qt.TopToolBarArea, self.whitelistIDsToolbar) - self.whitelistIDsToolbar.setVisible(False) - self.controlToolBars.append(self.whitelistIDsToolbar) - - self.magicPromptsToolbar = widgets.MagicPromptsToolbar(self) - for name, action in self.magicPromptsToolbar.widgetsWithShortcut.items(): - self.widgetsWithShortcut[name] = action - - self.magicPromptsToolbar.sigComputeOnZoom.connect( - self.magicPromptsComputeOnZoomTriggered - ) - self.magicPromptsToolbar.sigComputeOnImage.connect( - self.magicPromptsComputeOnImageTriggered - ) - self.magicPromptsToolbar.sigInitSelectedModel.connect( - self.magicPromptsInitModel - ) - self.magicPromptsToolbar.sigViewModelParams.connect( - self.viewSetMagicPromptModelParams - ) - self.magicPromptsToolbar.sigClearPoints.connect( - partial(self.magicPromptsClearPoints, only_zoom=False) - ) - self.magicPromptsToolbar.sigClearPointsOnZmom.connect( - partial(self.magicPromptsClearPoints, only_zoom=True) - ) - self.magicPromptsToolbar.sigInterpolateZslice.connect( - self.magicPromptsInterpolateZsliceToggled - ) - - self.addToolBar(Qt.TopToolBarArea, self.magicPromptsToolbar) - self.magicPromptsToolbar.setVisible(False) - self.magicPromptsToolbar.keepVisibleWhenActive = True - self.controlToolBars.append(self.magicPromptsToolbar) - - self.promptSegmentPointsLayerToolbar = ( - widgets.PromptableModelPointsLayerToolbar(parent=self) - ) - self.promptSegmentPointsLayerToolbar.setContextMenuPolicy( - Qt.PreventContextMenu - ) - - self.addToolBar(Qt.TopToolBarArea, self.promptSegmentPointsLayerToolbar) - self.promptSegmentPointsLayerToolbar.setVisible(False) - - self.pointsLayersToolbars.append( - self.promptSegmentPointsLayerToolbar - ) - - # Second level toolbar - secondLevelToolbar = widgets.ToolBar("Second level toolbar", self) - self.addToolBar(Qt.TopToolBarArea, secondLevelToolbar) - self.delObjToolAction = QAction(self) - self.delObjToolAction.setIcon(QIcon(":del_obj_click.svg")) - self.delObjToolAction.setCheckable(True) - self.delObjToolAction.setToolTip( - 'Customisable delete object action\n\n' - 'Go to the `Settings --> Customise keyboard shortcuts...` menu ' - 'on the top menubar\n' - 'to customise the action required to delete ' - 'an object with a click.\n\n' - 'When working with 3D segmentations, to delete only the z-slice mask, hold "Shift" while clicking.' - ) - secondLevelToolbar.addAction(self.delObjToolAction) - secondLevelToolbar.setMovable(False) - self.secondLevelToolbar = secondLevelToolbar - self.secondLevelToolbar.setVisible(False) - - def gui_populateToolSettingsMenu(self): - brushHoverModeActionGroup = QActionGroup(self) - brushHoverModeActionGroup.setExclusive(True) - self.brushHoverCenterModeAction = QAction() - self.brushHoverCenterModeAction.setCheckable(True) - self.brushHoverCenterModeAction.setText( - 'Use center of the brush/eraser cursor to determine hover ID' - ) - self.brushHoverCircleModeAction = QAction() - self.brushHoverCircleModeAction.setCheckable(True) - self.brushHoverCircleModeAction.setText( - 'Use the entire circle of the brush/eraser cursor to determine hover ID' - ) - brushHoverModeActionGroup.addAction(self.brushHoverCenterModeAction) - brushHoverModeActionGroup.addAction(self.brushHoverCircleModeAction) - brushHoverModeMenu = self.settingsMenu.addMenu( - 'Brush/eraser cursor hovering mode' - ) - brushHoverModeMenu.addAction(self.brushHoverCenterModeAction) - brushHoverModeMenu.addAction(self.brushHoverCircleModeAction) - - if 'useCenterBrushCursorHoverID' not in self.df_settings.index: - self.df_settings.at['useCenterBrushCursorHoverID', 'value'] = 'Yes' - - useCenterBrushCursorHoverID = self.df_settings.at[ - 'useCenterBrushCursorHoverID', 'value' - ] == 'Yes' - self.brushHoverCenterModeAction.setChecked(useCenterBrushCursorHoverID) - self.brushHoverCircleModeAction.setChecked(not useCenterBrushCursorHoverID) - - self.brushHoverCenterModeAction.toggled.connect( - self.useCenterBrushCursorHoverIDtoggled - ) - - self.settingsMenu.addSeparator() - - keepToolActiveNames = { - 'Segment range of frames': self.labelRoiTrangeCheckbox - } - for button in self.checkableQButtonsGroup.buttons(): - if button.toolTip() == "": - toolName = "MISSING" - continue - else: - toolName = re.findall(r'Name: (.*)', button.toolTip())[0] - keepToolActiveNames[toolName] = button - - keepToolActiveNames = dict(natsorted(keepToolActiveNames.items())) - - applyToNewFrameNames = { - 'Segmenting for lost IDs': self.segForLostIDsButton, - 'Delete bordering objects': self.delBorderObjAction.button, - 'Delete newly segmented objects': self.delNewObjAction.button, - } - - allToolsList = list(keepToolActiveNames.keys()) + list(applyToNewFrameNames.keys()) - allToolsList = natsorted(allToolsList) - - menus = {} - - for toolName in allToolsList: - menuItemText = f'{toolName} tool'.replace(' ', ' ') - menus[toolName] = self.settingsMenu.addMenu(menuItemText) - - self.keepToolActiveActions = dict() - self.applyToolNewFrameActions = dict() - self.applyToolNewFrameButtons = dict() - all_checked = True - - for toolName, button in keepToolActiveNames.items(): - menu = menus[toolName] - action = QAction(button) - action.setText('Keep tool active after using it') - action.setCheckable(True) - if toolName in self.df_settings.index: - action.setChecked(True) - else: - all_checked = False - action.toggled.connect(self.keepToolActiveActionToggled) - menu.addAction(action) - self.keepToolActiveActions[toolName] = action - - for toolName, button in applyToNewFrameNames.items(): - menu = menus[toolName] - action = QAction(button) - action.setText('Apply when visitng new frame') - action.setCheckable(True) - action.toggled.connect(self.applyToolNewFrameActionToggled) - menu.addAction(action) - self.applyToolNewFrameActions[toolName] = action - self.applyToolNewFrameButtons[toolName] = button - - for toolName in self.applyToolNewFrameActions.keys(): - settingString = toolName.strip() - settingString = toolName.replace(' ', '_') - settingString = f'{settingString}_applyNewFrame' - if settingString in self.df_settings.index: - val = self.df_settings.at[settingString, 'value'] - if val == 'applyNewFrame': - self.applyToolNewFrameActions[toolName].setChecked(True) - - self.settingsMenu.addSeparator() - - self.keepAllToolsActiveToggle = QAction() - self.keepAllToolsActiveToggle.setText( - 'Keep all tools active after using them' - ) - self.keepAllToolsActiveToggle.setCheckable(True) - self.keepAllToolsActiveToggle.setChecked(all_checked) - self.keepAllToolsActiveToggle.toggled.connect( - self.keepAllToolsActiveActionToggled - ) - self.settingsMenu.addAction(self.keepAllToolsActiveToggle) - self.settingsMenu.addSeparator() - - askHowFutureFramesMenu = self.settingsMenu.addMenu( - 'Ask how to propagate changes to future frames' - ) - self.askHowFutureFramesActions = {} - askHowFutureFramesActionsKeys = ( - 'Delete ID', - 'Exclude cell from analysis', - 'Annotate cell as dead', - 'Edit ID', - 'Keep ID' - ) - for key in askHowFutureFramesActionsKeys: - askHowFutureFramesAction = QAction() - askHowFutureFramesAction.setText(f'Ask for "{key}" action') - askHowFutureFramesAction.setCheckable(True) - askHowFutureFramesAction.setChecked(True) - askHowFutureFramesAction.setDisabled(True) - askHowFutureFramesMenu.addAction(askHowFutureFramesAction) - self.askHowFutureFramesActions[key] = askHowFutureFramesAction - - warningsMenu = self.settingsMenu.addMenu('Warnings and pop-ups') - self.warnLostCellsAction = QAction() - self.warnLostCellsAction.setText('Show pop-up warning for lost cells') - self.warnLostCellsAction.setCheckable(True) - self.warnLostCellsAction.setChecked(True) - warningsMenu.addAction(self.warnLostCellsAction) - - warnEditingWithAnnotTexts = { - 'Delete ID': 'Show warning when deleting ID that has annotations', - 'Separate IDs': 'Show warning when separating IDs that have annotations', - 'Edit ID': 'Show warning when editing ID that has annotations', - 'Annotate ID as dead': - 'Show warning when annotating dead ID that has annotations', - 'Delete ID with eraser': - 'Show warning when erasing ID that has annotations', - 'Add new ID with brush tool': - 'Show warning when adding new ID (brush) that has annotations', - 'Merge IDs': - 'Show warning when merging IDs that have annotations', - 'Add new ID with curvature tool': - 'Show warning when adding new ID (curv. tool) that has annotations', - 'Add new ID with magic-wand': - 'Show warning when adding new ID (magic-wand) that has annotations', - 'Delete IDs using ROI': - 'Show warning when using ROIs to delete IDs that have annotations', - } - self.warnEditingWithAnnotActions = {} - for key, desc in warnEditingWithAnnotTexts.items(): - action = QAction() - action.setText(desc) - action.setCheckable(True) - action.setChecked(True) - action.removeAnnot = False - self.warnEditingWithAnnotActions[key] = action - warningsMenu.addAction(action) - - - def gui_createStatusBar(self): - self.statusbar = self.statusBar() - # Permanent widget - self.wcLabel = QLabel('') - self.statusbar.addPermanentWidget(self.wcLabel) - - # self.toggleTerminalButton = widgets.ToggleTerminalButton() - # self.statusbar.addWidget(self.toggleTerminalButton) - # self.toggleTerminalButton.sigClicked.connect( - # self.gui_terminalButtonClicked - # ) - - self.statusBarLabel = QLabel('') - self.statusbar.addWidget(self.statusBarLabel) - - def gui_createTerminalWidget(self): - self.terminal = widgets.QLog(logger=self.logger) - self.terminal.connect() - self.terminalDock = QDockWidget('Log', self) - - self.terminalDock.setWidget(self.terminal) - self.terminalDock.setFeatures( - QDockWidget.DockWidgetFeature.DockWidgetFloatable | QDockWidget.DockWidgetFeature.DockWidgetMovable - ) - self.terminalDock.setAllowedAreas(Qt.BottomDockWidgetArea) - self.addDockWidget(Qt.BottomDockWidgetArea, self.terminalDock) - # self.terminalDock.widget().layout().setContentsMargins(10,0,10,0) - self.terminalDock.setVisible(False) - - @resetViewRange - def gui_terminalButtonClicked(self, terminalVisible): - self.terminalDock.setVisible(terminalVisible) - - def gui_createActions(self): - # File actions - self.segmNdimIndicator = widgets.ToolButtonTextIcon(text='') - self.segmNdimIndicator.setCheckable(True) - self.segmNdimIndicator.setChecked(True) - # self.segmNdimIndicator.setDisabled(True) - - if self.debug: - self.createEmptyDataAction = QAction(self) - self.createEmptyDataAction.setText("DEBUG: Create empty data") - - self.newWindowAction = QAction("New Window", self) - - self.newAction = QAction(self) - self.newAction.setText("&New Segmentation File...") - self.newAction.setIcon(QIcon(":file-new.svg")) - self.openFolderAction = QAction( - QIcon(":folder-open.svg"), "&Load Folder...", self - ) - self.openFileAction = QAction( - QIcon(":image.svg"),"&Open Image/Video File...", self - ) - self.manageVersionsAction = QAction( - QIcon(":manage_versions.svg"), "Load Older Versions...", self - ) - self.manageVersionsAction.setDisabled(True) - self.saveAction = QAction(QIcon(":file-save.svg"), "Save", self) - self.saveAsAction = QAction("Save as...", self) - self.exportToVideoAction = QAction("&Video...", self) - self.exportToImageAction = QAction("&Image...", self) - self.quickSaveAction = QAction("Save Only Segmentation Masks", self) - self.loadFluoAction = QAction("Load Fluorescence Images...", self) - self.loadPosAction = QAction("Load Different Position...", self) - # self.reloadAction = QAction( - # QIcon(":reload.svg"), "Reload segmentation file", self - # ) - self.nextAction = QAction('Next', self) - self.prevAction = QAction('Previous', self) - self.showInExplorerAction = QAction( - QIcon(":drawer.svg"), f"&{self.openFolderText}", self - ) - self.exitAction = QAction("&Exit", self) - self.undoAction = QAction(QIcon(":undo.svg"), "Undo", self) - self.redoAction = QAction(QIcon(":redo.svg"), "Redo", self) - # String-based key sequences - self.newWindowAction.setShortcut('Ctrl+Shift+N') - self.newAction.setShortcut('Ctrl+N') - self.openFolderAction.setShortcut('Ctrl+O') - self.loadPosAction.setShortcut('Shift+P') - self.saveAsAction.setShortcut('Ctrl+Shift+S') - self.exportToVideoAction.setShortcut('Ctrl+Shift+V') - self.exportToImageAction.setShortcut('Ctrl+Shift+I') - self.saveAction.setShortcut('Ctrl+Alt+S') - self.quickSaveAction.setShortcut('Ctrl+S') - self.undoAction.setShortcut('Ctrl+Z') - self.redoAction.setShortcut('Ctrl+Y') - self.nextAction.setShortcut(Qt.Key_Right) - self.prevAction.setShortcut(Qt.Key_Left) - self.addAction(self.nextAction) - self.addAction(self.prevAction) - # Help tips - newTip = "Create a new segmentation file" - self.newAction.setStatusTip(newTip) - self.newAction.setWhatsThis("Create a new empty segmentation file") - - self.autoPilotButton = QAction(self) - self.autoPilotButton.setIcon(QIcon(":auto-pilot.svg")) - self.autoPilotButton.setCheckable(True) - self.autoPilotButton.setShortcut('Ctrl+Shift+A') - - self.findIdAction = QAction(self) - self.findIdAction.setIcon(QIcon(":find.svg")) - self.findIdAction.setShortcut('Ctrl+F') - - self.zoomRectButton = QToolButton(self) - self.zoomRectButton.setIcon(QIcon(":zoom_rect.svg")) - self.zoomRectButton.setCheckable(True) - self.zoomRectButton.setShortcut('Shift+Z') - self.LeftClickButtons.append(self.zoomRectButton) - self.checkableButtons.append(self.zoomRectButton) - self.checkableQButtonsGroup.addButton(self.zoomRectButton) - self.widgetsWithShortcut['Zoom to rectangular area'] = ( - self.zoomRectButton - ) - - self.skipToNewIdAction = QAction(self) - self.skipToNewIdAction.setIcon(QIcon(":skip_forward_new_ID.svg")) - self.skipToNewIdAction.setShortcut( - widgets.KeySequenceFromText(Qt.Key_PageUp) - ) - - self.skipToNewIdAction.setDisabled(True) - - # Edit actions - models = myutils.get_list_of_models() - models = [*models, 'local_seg'] # Add local_seg for SegForLostIDsAction - self.segmActions = [] - self.modelNames = [] - self.acdcSegment_li = [] - self.models = [] - for model_name in models: - action = QAction(f"{model_name}...") - self.segmActions.append(action) - self.modelNames.append(model_name) - self.models.append(None) - self.acdcSegment_li.append(None) - action.setDisabled(True) - - self.addCustomModelFrameAction = QAction('Add custom model...', self) - self.addCustomModelVideoAction = QAction('Add custom model...', self) - - self.segmWithPromptableModelAction = QAction( - 'Select promptable model...', self - ) - self.addCustomPromptModelAction = QAction( - 'Add custom promptable model...', self - ) - - self.segmActionsVideo = [] - for model_name in models: - action = QAction(f"{model_name}...") - self.segmActionsVideo.append(action) - action.setDisabled(True) - - self.postProcessSegmAction = QAction( - "Segmentation post-processing...", self - ) - self.postProcessSegmAction.setDisabled(True) - self.postProcessSegmAction.setCheckable(True) - - self.EditSegForLostIDsSetSettings = QAction( - "Edit settings for Segmenting lost IDs...", self - ) - self.EditSegForLostIDsSetSettings.triggered.connect( - self.SegForLostIDsSetSettings - ) - - self.repeatTrackingAction = QAction( - QIcon(":repeat-tracking.svg"), "Repeat tracking", self - ) - self.repeatTrackingAction.setShortcut('Shift+T') - self.widgetsWithShortcut['Repeat Tracking'] = self.repeatTrackingAction - - - self.editRtTrackerParamsAction = QAction( - 'Edit real-time tracker parameters...', self - ) - - self.repeatTrackingMenuAction = QAction( - 'Track current frame with real-time tracker...', self - ) - self.repeatTrackingMenuAction.setDisabled(True) - self.repeatTrackingMenuAction.setShortcut('Shift+T') - - self.repeatTrackingVideoAction = QAction( - 'Select a tracker and track multiple frames...', self - ) - self.repeatTrackingVideoAction.setDisabled(True) - self.repeatTrackingVideoAction.setShortcut('Alt+Shift+T') - - self.trackingAlgosGroup = QActionGroup(self) - self.trackWithAcdcAction = QAction('Cell-ACDC', self) - self.trackWithAcdcAction.setCheckable(True) - self.trackingAlgosGroup.addAction(self.trackWithAcdcAction) - - self.trackWithYeazAction = QAction('YeaZ', self) - self.trackWithYeazAction.setCheckable(True) - self.trackingAlgosGroup.addAction(self.trackWithYeazAction) - - rt_trackers = myutils.get_list_of_real_time_trackers() - for rt_tracker in rt_trackers: - rtTrackerAction = QAction(rt_tracker, self) - rtTrackerAction.setCheckable(True) - self.trackingAlgosGroup.addAction(rtTrackerAction) - - self.trackWithAcdcAction.setChecked(True) - aliases = myutils.aliases_real_time_trackers() - - if 'tracking_algorithm' in self.df_settings.index: - trackingAlgo = self.df_settings.at['tracking_algorithm', 'value'] - if trackingAlgo in aliases: - trackingAlgo = aliases[trackingAlgo] - if trackingAlgo == 'Cell-ACDC': - self.trackWithAcdcAction.setChecked(True) - elif trackingAlgo == 'YeaZ': - self.trackWithYeazAction.setChecked(True) - else: - for rtTrackerAction in self.trackingAlgosGroup.actions(): - if rtTrackerAction.text() == trackingAlgo: - rtTrackerAction.setChecked(True) - break - - self.setMeasurementsAction = QAction('Set measurements...') - self.addCustomMetricAction = QAction('Add custom measurement...') - self.addCombineMetricAction = QAction('Add combined measurement...') - - # Standard key sequence - # self.copyAction.setShortcut(QKeySequence.StandardKey.Copy) - # self.pasteAction.setShortcut(QKeySequence.StandardKey.Paste) - # self.cutAction.setShortcut(QKeySequence.StandardKey.Cut) - # Help actions - self.tipsAction = QAction("Tips and tricks...", self) - self.UserManualAction = QAction("User Documentation...", self) - self.openLogFileAction = QAction("Open log file...", self) - self.showLogFilesAction = QAction("Show log files...", self) - self.aboutAction = QAction("About Cell-ACDC", self) - # self.aboutAction = QAction("&About...", self) - - # Assign mother to bud button - self.assignBudMothAutoAction = QAction(self) - self.assignBudMothAutoAction.setIcon(QIcon(":autoAssign.svg")) - self.assignBudMothAutoAction.setVisible(False) - - self.editCcaToolAction = QAction(self) - self.editCcaToolAction.setIcon(QIcon(":edit_cca.svg")) - # self.editCcaToolAction.setDisabled(True) - self.editCcaToolAction.setVisible(False) - - self.reInitCcaAction = QAction(self) - self.reInitCcaAction.setIcon(QIcon(":reinitCca.svg")) - self.reInitCcaAction.setVisible(False) - - self.toggleColorSchemeAction = QAction( - 'Switch to light theme' - ) - self.gui_updateSwitchColorSchemeActionText() - - self.pxModeAction = widgets.CheckableAction( - 'Fixed size text annotations' - ) - self.pxModeAction.setChecked(True) - pxModeTooltip = ( - 'When the text annotations are with fixed size they scale relative ' - 'to the object when zooming in/out (fixed size in pixels).\n' - 'This is typically faster to render, but it makes annotations ' - 'smaller/larger when zooming in/out, respectively.\n\n' - 'Try activating it to speed up the annotation of many objects ' - 'in high resolution mode.\n\n' - 'After activating it, you might need to increase the font size ' - 'from the menu on the top menubar `Edit --> Font size`.' - ) - self.pxModeAction.setToolTip(pxModeTooltip) - - self.highLowResAction = widgets.CheckableAction( - 'High resolution text annotations' - ) - highLowResTooltip = ( - 'Resolution of the text annotations. High resolution results ' - 'in slower update of the annotations.\n' - 'Not recommended with a number of segmented objects > 500.\n\n' - ) - self.highLowResAction.setToolTip(highLowResTooltip) - - self.editAutoSaveIntervalAction = QAction( - 'Change autosave interval (minutes or frames)...', self - ) - - self.editShortcutsAction = QAction( - 'Customize keyboard shortcuts...', self - ) - self.editShortcutsAction.setShortcut('Ctrl+K') - - self.showMirroredCursorAction = QAction( - 'Show mirrored cursor on images', self - ) - self.showMirroredCursorAction.setCheckable(True) - if 'showMirroredCursor' in self.df_settings.index: - checked = self.df_settings.at['showMirroredCursor', 'value'] == 'Yes' - self.showMirroredCursorAction.setChecked(checked) - else: - self.showMirroredCursorAction.setChecked(True) - self.showMirroredCursorAction.setShortcut('Ctrl+M') - - self.editTextIDsColorAction = QAction('Text annotation color...', self) - self.editTextIDsColorAction.setDisabled(True) - - self.editOverlayColorAction = QAction('Overlay color...', self) - self.editOverlayColorAction.setDisabled(True) - - self.manuallyEditCcaAction = QAction( - 'Edit cell cycle annotations...', self - ) - self.manuallyEditCcaAction.setShortcut('Ctrl+Shift+P') - self.manuallyEditCcaAction.setDisabled(True) - - self.viewCcaTableAction = QAction( - 'View cell cycle annotations...', self - ) - self.viewCcaTableAction.setDisabled(True) - self.viewCcaTableAction.setShortcut('Ctrl+P') - - - self.addScaleBarAction = QAction('Add scale bar', self) - self.addScaleBarAction.setCheckable(True) - - self.addTimestampAction = QAction('Add timestamp', self) - self.addTimestampAction.setCheckable(True) - - self.invertBwAction = QAction('Invert black/white', self) - self.invertBwAction.setCheckable(True) - checked = self.df_settings.at['is_bw_inverted', 'value'] == 'Yes' - self.invertBwAction.setChecked(checked) - - self.shuffleCmapAction = QAction('Randomly shuffle colormap', self) - self.shuffleCmapAction.setShortcut('Shift+S') - - self.greedyShuffleCmapAction = QAction( - 'Greedily shuffle colormap', self - ) - self.greedyShuffleCmapAction.setShortcut('Alt+Shift+S') - - self.saveLabColormapAction = QAction( - 'Save labels colormap...', self - ) - - self.normalizeRawAction = QAction( - 'Do not normalize. Display raw image', self) - self.normalizeToFloatAction = QAction( - 'Convert to floating point format with values [0, 1]', self) - # self.normalizeToUbyteAction = QAction( - # 'Rescale to 8-bit unsigned integer format with values [0, 255]', self) - self.normalizeRescale0to1Action = QAction( - 'Rescale to [0, 1]', self) - self.normalizeByMaxAction = QAction( - 'Normalize by max value', self) - self.normalizeRawAction.setCheckable(True) - self.normalizeToFloatAction.setCheckable(True) - # self.normalizeToUbyteAction.setCheckable(True) - self.normalizeRescale0to1Action.setCheckable(True) - self.normalizeByMaxAction.setCheckable(True) - self.normalizeQActionGroup = QActionGroup(self) - self.normalizeQActionGroup.addAction(self.normalizeRawAction) - self.normalizeQActionGroup.addAction(self.normalizeToFloatAction) - # self.normalizeQActionGroup.addAction(self.normalizeToUbyteAction) - self.normalizeQActionGroup.addAction(self.normalizeRescale0to1Action) - self.normalizeQActionGroup.addAction(self.normalizeByMaxAction) - - self.preprocessAction = QAction( - 'Pre-processing...', self - ) - self.preprocessAction.setShortcut('Alt+Shift+P') - - self.combineChannelsAction = QAction( - 'Combine and manipulate channels and/or segmentation files...', self - ) - self.combineChannelsAction.setShortcut('Alt+Shift+C') - - self.zoomToObjsAction = QAction( - 'Zoom to objects (Shortcut: H key)', self - ) - self.zoomOutAction = QAction( - 'Zoom out (Shortcut: double press H key)', self - ) - - self.relabelSequentialAction = QAction( - 'Relabel IDs sequentially...', self - ) - self.relabelSequentialAction.setShortcut('Ctrl+L') - self.relabelSequentialAction.setDisabled(True) - - self.setLastUserNormAction() - - self.autoSegmAction = QAction( - 'Enable automatic segmentation', self) - self.autoSegmAction.setCheckable(True) - self.autoSegmAction.setDisabled(True) - - self.enableSmartTrackAction = QAction( - 'Smart handling of enabling/disabling tracking', self) - self.enableSmartTrackAction.setCheckable(True) - self.enableSmartTrackAction.setChecked(True) - - self.enableAutoZoomToCellsAction = QAction( - 'Automatic zoom to all cells when pressing "Next/Previous"', self) - self.enableAutoZoomToCellsAction.setCheckable(True) - - self.imgPropertiesAction = QAction('Properties...', self) - self.imgPropertiesAction.setDisabled(True) - - self.addDelRoiAction = QAction(self) - self.addDelRoiAction.roiType = 'rect' - self.addDelRoiAction.setIcon(QIcon(":addDelRoi.svg")) - - self.addDelPolyLineRoiButton = QToolButton(self) - self.addDelPolyLineRoiButton.setCheckable(True) - self.addDelPolyLineRoiButton.setIcon(QIcon(":addDelPolyLineRoi.svg")) - - self.checkableButtons.append(self.addDelPolyLineRoiButton) - self.LeftClickButtons.append(self.addDelPolyLineRoiButton) - - self.delBorderObjAction = QAction(self) - self.delBorderObjAction.setIcon(QIcon(":delBorderObj.svg")) - - self.delNewObjAction = QAction(self) - self.delNewObjAction.setIcon(QIcon(":delNewObj.svg")) - - self.loadCustomAnnotationsAction = QAction(self) - self.loadCustomAnnotationsAction.setIcon(QIcon(":load_annotation.svg")) - self.loadCustomAnnotationsAction.setToolTip( - 'Load previously used custom annotations' - ) - - self.addCustomAnnotationAction = QAction(self) - self.addCustomAnnotationAction.setIcon(QIcon(":addCustomAnnotation.svg")) - self.addCustomAnnotationAction.setToolTip('Add custom annotation') - # self.functionsNotTested3D.append(self.addCustomAnnotationAction) - - self.viewAllCustomAnnotAction = QAction(self) - self.viewAllCustomAnnotAction.setCheckable(True) - self.viewAllCustomAnnotAction.setIcon(QIcon(":eye.svg")) - self.viewAllCustomAnnotAction.setToolTip('Show all custom annotations') - # self.functionsNotTested3D.append(self.viewAllCustomAnnotAction) - - # self.imgGradLabelsAlphaUpAction = QAction(self) - # self.imgGradLabelsAlphaUpAction.setVisible(False) - # self.imgGradLabelsAlphaUpAction.setShortcut('Ctrl+Up') - - def gui_updateSwitchColorSchemeActionText(self): - if self._colorScheme == 'dark': - txt = 'Switch to light theme' - else: - txt = 'Switch to dark theme' - self.toggleColorSchemeAction.setText(txt) - - def gui_connectActions(self): - # Connect File actions - if self.debug: - self.createEmptyDataAction.triggered.connect(self._createEmptyData) - self.segmNdimIndicator.clicked.connect(self.segmNdimIndicatorClicked) - self.newWindowAction.triggered.connect(self.openNewWindow) - self.newAction.triggered.connect(self.newFile) - self.openFolderAction.triggered.connect(self.openFolder) - self.openFileAction.triggered.connect(self.openFile) - self.manageVersionsAction.triggered.connect(self.manageVersions) - self.saveAction.triggered.connect(self.saveData) - self.saveAsAction.triggered.connect(self.saveAsData) - self.exportToVideoAction.triggered.connect(self.exportToVideoTriggered) - self.exportToImageAction.triggered.connect(self.exportToImageTriggered) - self.quickSaveAction.triggered.connect(self.quickSave) - self.viewPreprocDataToggle.toggled.connect( - self.viewPreprocDataToggled - ) - self.viewCombineChannelDataToggle.toggled.connect( - self.viewCombineChannelDataToggled - ) - self.autoSaveToggle.toggled.connect(self.autoSaveToggled) - self.autoSaveAnnotToggle.toggled.connect(self.autoSaveAnnotToggled) - self.autoSaveIntervalDialog.sigValueChanged.connect( - self.autoSaveIntervalValueChanged - ) - self.autoSaveIntervalEditButton.clicked.connect( - self.autoSaveIntervalEdit - ) - self.ccaIntegrCheckerToggle.toggled.connect( - self.ccaIntegrCheckerToggled - ) - self.annotLostObjsToggle.toggled.connect(self.annotLostObjsToggled) - self.highLowResAction.clicked.connect(self.highLowResToggled) - self.showInExplorerAction.triggered.connect(self.showInExplorer_cb) - self.exitAction.triggered.connect(self.close) - self.undoAction.triggered.connect(self.undo) - self.redoAction.triggered.connect(self.redo) - self.nextAction.triggered.connect(self.nextActionTriggered) - self.prevAction.triggered.connect(self.prevActionTriggered) - - self.invertBwAction.toggled.connect(self.invertBw) - self.toggleColorSchemeAction.triggered.connect(self.onToggleColorScheme) - self.pxModeAction.clicked.connect(self.pxModeActionToggled) - self.editShortcutsAction.triggered.connect(self.editShortcuts_cb) - self.editAutoSaveIntervalAction.triggered.connect( - self.autoSaveIntervalEditButton.click - ) - self.showMirroredCursorAction.toggled.connect( - self.showMirroredCursorToggled - ) - - # Connect Help actions - self.tipsAction.triggered.connect(self.showTipsAndTricks) - self.UserManualAction.triggered.connect(myutils.browse_docs) - self.openLogFileAction.triggered.connect(self.openLogFile) - self.showLogFilesAction.triggered.connect(self.showLogFiles) - self.aboutAction.triggered.connect(self.showAbout) - # Connect Open Recent to dynamically populate it - # self.openRecentMenu.aboutToShow.connect(self.populateOpenRecent) - self.checkableQButtonsGroup.buttonClicked.connect(self.uncheckQButton) - - self.showPropsDockButton.sigClicked.connect(self.showPropsDockWidget) - - self.loadCustomAnnotationsAction.triggered.connect( - self.loadCustomAnnotations - ) - self.addCustomAnnotationAction.triggered.connect( - self.addCustomAnnotation - ) - self.viewAllCustomAnnotAction.toggled.connect( - self.viewAllCustomAnnot - ) - self.addCustomModelVideoAction.triggered.connect( - self.showInstructionsCustomModel - ) - self.addCustomModelFrameAction.triggered.connect( - self.showInstructionsCustomModel - ) - self.addCustomModelFrameAction.callback = self.segmFrameCallback - self.addCustomModelVideoAction.callback = self.segmVideoCallback - - self.addCustomPromptModelAction.triggered.connect( - self.showInstructionsCustomPromptModel - ) - self.segmWithPromptableModelAction.triggered.connect( - self.segmWithPromptableModelActionTriggered - ) - - def zProjLockViewToggled(self, checked): - self.updateZproj(self.zProjComboBox.currentText()) - - def rescaleIntensExportToVideoDialog(self, how, channel, setImage=True): - if channel == self.user_ch_name: - lutItem = self.imgGrad - else: - lutItem = self.overlayLayersItems[channel][1] - - for action in lutItem.rescaleActionGroup.actions(): - if action.text() == how: - action.trigger() - # self.rescaleIntensitiesLut(setImage=setImage) - break - - def customLevelsLutChanged(self, levels, imageItem=None): - imageItem.setLevels(levels) - - def getPreComputedMinMaxZstack(self, channel: str): - if channel != self.user_ch_name: - return None - - posData = self.data[self.pos_i] - zstack_min, zstack_max = np.inf, 0 - for z in range(posData.SizeZ): - key = (self.pos_i, posData.frame_i, z) - levels = self.img1.minMaxValuesMapper.get(key) - if levels is None: - return - - img_min, img_max = levels - if img_min < zstack_min: - zstack_min = img_min - - if img_max > zstack_max: - zstack_max = img_max - - return (zstack_min, zstack_max) - - # @exec_time - def rescaleIntensitiesLut( - self, - action: QAction=None, - setImage: bool=True, - imageItem=None - ): - if not self.isDataLoaded: - self.logger.info( - 'WARNING: Data is not loaded. ' - 'Intensities will be rescaled later.' - ) - return - - posData = self.data[self.pos_i] - if imageItem is None: - imageItem = self.img1 - channel = self.user_ch_name - image_data = posData.img_data - else: - channel = imageItem.channelName - _, filename = self.getPathFromChName(channel, posData) - image_data = posData.fluo_data_dict[filename] - - triggeredByUser = True - if action is None: - triggeredByUser = False - action = imageItem.lutItem.rescaleActionGroup.checkedAction() - - how = action.text() - - self.df_settings.at[f'how_rescale_intensities_{channel}', 'value'] = how - self.df_settings.to_csv(self.settings_csv_path) - - if how == 'Rescale each 2D image': - if how == self.rescaleIntensChannelHowMapper[channel]: - # No need to update since we have autoscale - return - - imageItem.setEnableAutoLevels(True) - if setImage: - imageItem.setImage(imageItem.image) - return - - lutLevelsCh = posData.lutLevels[channel] - - if how == 'Rescale across z-stack': - imageItem.setEnableAutoLevels(False) - levels_key = (how, posData.frame_i) - levels = lutLevelsCh.get(levels_key) - if levels is None: - levels = self.getPreComputedMinMaxZstack(channel) - - if levels is None: - image_zstack = image_data[posData.frame_i] - levels = (image_zstack.min(), image_zstack.max()) - lutLevelsCh[levels_key] = levels - imageItem.setLevels(levels) - elif how == 'Rescale across time frames': - imageItem.setEnableAutoLevels(False) - levels_key = (how, None) - levels = lutLevelsCh.get(levels_key) - if levels is None: - levels = (image_data.min(), image_data.max()) - - lutLevelsCh[levels_key] = levels - imageItem.setLevels(levels) - elif how == 'Choose custom levels...': - autoLevelsEnabledBefore = imageItem.autoLevelsEnabled - imageItem.setEnableAutoLevels(False) - if triggeredByUser: - current_min, current_max = imageItem.getLevels() - dtype_max = np.iinfo(image_data.dtype).max - max_value = image_data.max() - min_value = image_data.min() - win = apps.SetCustomLevelsLut( - init_min_value=current_min, - init_max_value=current_max, - maximum_max_value=max_value, - minimum_min_value=min_value, - parent=self - ) - win.sigLevelsChanged.connect( - partial(self.customLevelsLutChanged, imageItem=imageItem) - ) - win.exec_() - if win.cancel: - imageItem.setEnableAutoLevels(autoLevelsEnabledBefore) - self.logger.info('Custom LUT levels setting cancelled.') - self.updateAllImages() - return - selectedLevels = win.selectedLevels - else: - selectedLevels = imageItem.getLevels() - imageItem.setLevels(selectedLevels) - elif how == 'Do no rescale, display raw image': - imageItem.setEnableAutoLevels(False) - levels_key = (how, None) - levels = lutLevelsCh.get(levels_key) - if levels is None: - dtype_max = np.iinfo(image_data.dtype).max - levels = (0, dtype_max) - lutLevelsCh[levels_key] = levels - imageItem.setLevels(levels) - - self.rescaleIntensChannelHowMapper[channel] = how - - if setImage: - imageItem.setImage(imageItem.image) - - def onToggleColorScheme(self): - if self.toggleColorSchemeAction.text().find('light') != -1: - self._colorScheme = 'light' - setDarkModeToggleChecked = False - else: - self._colorScheme = 'dark' - setDarkModeToggleChecked = True - self.gui_updateSwitchColorSchemeActionText() - _warnings.warnRestartCellACDCcolorModeToggled( - self._colorScheme, app_name=self._appName, parent=self - ) - load.rename_qrc_resources_file(self._colorScheme) - self.statusBarLabel.setText(html_utils.paragraph( - f'Restart {self._appName} for the change to take effect', - font_color='red' - )) - self.df_settings.at['colorScheme', 'value'] = self._colorScheme - self.df_settings.to_csv(settings_csv_path) - - def showMirroredCursorToggled(self, checked): - value = 'Yes' if checked else 'No' - self.df_settings.at['showMirroredCursor', 'value'] = value - self.df_settings.to_csv(settings_csv_path) - - if not checked: - self.clearCursors() - - def clearCursors(self): - self.ax1_cursor.setData([], []) - self.ax2_cursor.setData([], []) - self.setHoverToolSymbolData( - [], [], (self.ax2_BrushCircle, self.ax1_BrushCircle), - ) - eraserCursors = ( - self.ax1_EraserCircle, self.ax2_EraserCircle, - self.ax1_EraserX, self.ax2_EraserX - ) - self.setHoverToolSymbolData([], [], eraserCursors) - - def activeEraserCircleCursors(self, isHoverImg1): - if self.showMirroredCursorAction.isChecked(): - return self.ax1_EraserCircle, self.ax2_EraserCircle - - if isHoverImg1: - return self.ax1_EraserCircle, - else: - return self.ax2_EraserCircle, - - def activeEraserXCursors(self, isHoverImg1): - if self.showMirroredCursorAction.isChecked(): - return self.ax1_EraserX, self.ax2_EraserX - - if isHoverImg1: - return self.ax1_EraserX, - else: - return self.ax2_EraserX, - - def activeBrushCircleCursors(self, isHoverImg1): - if self.showMirroredCursorAction.isChecked(): - return self.ax1_BrushCircle, self.ax2_BrushCircle - - if isHoverImg1: - return self.ax1_BrushCircle, - else: - return self.ax2_BrushCircle, - - def gui_connectEditActions(self): - self.showInExplorerAction.setEnabled(True) - self.setEnabledFileToolbar(True) - self.loadFluoAction.setEnabled(True) - self.isEditActionsConnected = True - - self.preprocessImageAction.triggered.connect( - self.preprocessAction.trigger - ) - self.combineChannelsAction.triggered.connect( - self.combineChannelsActionTriggered - ) - - self.overlayButton.toggled.connect(self.overlay_cb) - self.countObjsButton.toggled.connect(self.countObjectsCb) - self.togglePointsLayerAction.toggled.connect(self.pointsLayerToggled) - self.overlayLabelsButton.toggled.connect(self.overlayLabels_cb) - self.overlayButton.sigRightClick.connect(self.showOverlayContextMenu) - self.labelRoiButton.sigRightClick.connect(self.showLabelRoiContextMenu) - self.overlayLabelsButton.sigRightClick.connect( - self.showOverlayLabelsContextMenu - ) - self.rulerButton.toggled.connect(self.ruler_cb) - self.loadFluoAction.triggered.connect(self.loadFluo_cb) - self.loadPosAction.triggered.connect(self.loadPosTriggered) - # self.reloadAction.triggered.connect(self.reload_cb) - self.findIdAction.triggered.connect(self.findID) - self.zoomRectButton.toggled.connect(self.zoomRectActionToggled) - self.autoPilotButton.toggled.connect(self.autoPilotToggled) - self.skipToNewIdAction.triggered.connect(self.skipForwardToNewID) - self.slideshowButton.toggled.connect(self.launchSlideshow) - - self.copyLostObjButton.toggled.connect(self.copyLostObjContour_cb) - self.manualAnnotPastButton.toggled.connect( - self.manualAnnotPast_cb - ) - - self.segmSingleFrameMenu.triggered.connect(self.segmFrameCallback) - self.segmVideoMenu.triggered.connect(self.segmVideoCallback) - - self.postProcessSegmAction.toggled.connect(self.postProcessSegm) - self.autoSegmAction.toggled.connect(self.autoSegm_cb) - self.realTimeTrackingToggle.clicked.connect(self.realTimeTrackingClicked) - self.repeatTrackingAction.triggered.connect(self.repeatTracking) - self.manualTrackingButton.toggled.connect(self.manualTracking_cb) - self.manualBackgroundButton.toggled.connect(self.manualBackground_cb) - self.repeatTrackingMenuAction.triggered.connect(self.repeatTracking) - self.repeatTrackingVideoAction.triggered.connect( - self.repeatTrackingVideo - ) - for rtTrackerAction in self.trackingAlgosGroup.actions(): - rtTrackerAction.toggled.connect(self.rtTrackerActionToggled) - self.editRtTrackerParamsAction.triggered.connect( - self.initRealTimeTracker - ) - self.delObjsOutSegmMaskAction.triggered.connect( - self.delObjsOutSegmMaskActionTriggered - ) - self.mergeIDsButton.toggled.connect(self.mergeObjs_cb) - self.brushButton.toggled.connect(self.Brush_cb) - self.eraserButton.toggled.connect(self.Eraser_cb) - self.curvToolButton.toggled.connect(self.curvTool_cb) - self.wandToolButton.toggled.connect(self.wand_cb) - self.labelRoiButton.toggled.connect(self.labelRoi_cb) - self.magicPromptsToolButton.toggled.connect(self.magicPrompts_cb) - self.drawClearRegionButton.toggled.connect(self.drawClearRegion_cb) - self.reInitCcaAction.triggered.connect(self.reInitCca) - self.moveLabelToolButton.toggled.connect(self.moveLabelButtonToggled) - self.editCcaToolAction.triggered.connect( - self.manualEditCcaToolbarActionTriggered - ) - self.assignBudMothAutoAction.triggered.connect( - self.autoAssignBud_YeastMate - ) - self.keepIDsButton.toggled.connect(self.keepIDs_cb) - - self.whitelistIDsButton.toggled.connect(self.whitelistIDs_cb) - - self.whitelistIDsToolbar.sigWhitelistChanged.connect( - self.whitelistIDsChanged - ) - - self.whitelistIDsToolbar.sigWhitelistAccepted.connect( - self.whitelistIDsAccepted - ) - - self.whitelistIDsToolbar.sigViewOGIDs.connect(self.whitelistViewOGIDs) - - self.whitelistIDsToolbar.sigAddNewIDs.connect(self.whitelistAddNewIDsToggled) - - self.whitelistIDsToolbar.sigLoadOGLabs.connect(self.whitelistLoadOGLabs_cb) - - self.whitelistIDsToolbar.sigTrackOGagainstPreviousFrame.connect( - self.whitelistTrackOGagainstPreviousFrame_cb - ) - - self.expandLabelToolButton.toggled.connect(self.expandLabelCallback) - - self.reinitLastSegmFrameAction.triggered.connect( - self.reInitLastSegmFrame - ) - - - self.defaultRescaleIntensActionGroup.triggered.connect( - self.defaultRescaleIntensLutActionToggled - ) - - # self.repeatAutoCcaAction.triggered.connect(self.repeatAutoCca) - self.manuallyEditCcaAction.triggered.connect(self.manualEditCca) - self.addScaleBarAction.toggled.connect(self.addScaleBar) - self.addTimestampAction.toggled.connect(self.addTimestamp) - self.saveLabColormapAction.triggered.connect(self.saveLabelsColormap) - - self.enableSmartTrackAction.toggled.connect(self.enableSmartTrack) - # Brush/Eraser size action - self.brushSizeSpinbox.valueChanged.connect(self.brushSize_cb) - self.autoIDcheckbox.toggled.connect(self.autoIDtoggled) - # Mode - self.modeActionGroup.triggered.connect(self.changeModeFromMenu) - self.modeComboBox.sigTextChanged.connect(self.changeMode) - self.modeComboBox.activated.connect(self.clearComboBoxFocus) - self.equalizeHistPushButton.toggled.connect(self.equalizeHist) - - self.editOverlayColorAction.triggered.connect(self.toggleOverlayColorButton) - self.editTextIDsColorAction.triggered.connect(self.toggleTextIDsColorButton) - self.overlayColorButton.sigColorChanging.connect(self.changeOverlayColor) - self.overlayColorButton.sigColorChanged.connect(self.saveOverlayColor) - self.textIDsColorButton.sigColorChanging.connect(self.updateTextAnnotColor) - self.textIDsColorButton.sigColorChanged.connect(self.saveTextIDsColors) - - self.setMeasurementsAction.triggered.connect(self.showSetMeasurements) - self.addCustomMetricAction.triggered.connect(self.addCustomMetric) - self.addCombineMetricAction.triggered.connect(self.addCombineMetric) - - self.labelsGrad.colorButton.sigColorChanging.connect(self.updateBkgrColor) - self.labelsGrad.colorButton.sigColorChanged.connect(self.saveBkgrColor) - self.labelsGrad.sigGradientChangeFinished.connect(self.updateLabelsCmap) - self.labelsGrad.sigGradientChanged.connect(self.ticksCmapMoved) - self.labelsGrad.textColorButton.sigColorChanging.connect( - self.updateTextLabelsColor - ) - self.labelsGrad.textColorButton.sigColorChanged.connect( - self.saveTextLabelsColor - ) - # self.addFontSizeActions( - # self.labelsGrad.fontSizeMenu, self.setFontSizeActionChecked - # ) - - self.labelsGrad.shuffleCmapAction.triggered.connect(self.shuffle_cmap) - self.labelsGrad.greedyShuffleCmapAction.triggered.connect( - self.greedyShuffleCmap - ) - self.labelsGrad.permanentGreedyCmapAction.toggled.connect( - self.permanentGreedyCmapToggled - ) - self.shuffleCmapAction.triggered.connect(self.shuffle_cmap) - self.greedyShuffleCmapAction.triggered.connect(self.greedyShuffleCmap) - self.labelsGrad.invertBwAction.toggled.connect(self.setCheckedInvertBW) - self.labelsGrad.sigShowLabelsImgToggled.connect(self.showLabelImageItem) - self.labelsGrad.sigShowRightImgToggled.connect(self.showRightImageItem) - self.labelsGrad.sigShowNextFrameToggled.connect(self.showNextFrameImageItem) - - self.labelsGrad.defaultSettingsAction.triggered.connect( - self.restoreDefaultSettings - ) - - # self.addFontSizeActions( - # self.imgGrad.fontSizeMenu, self.setFontSizeActionChecked - # ) - self.imgGrad.invertBwAction.toggled.connect(self.setCheckedInvertBW) - self.imgGrad.textColorButton.disconnect() - self.imgGrad.textColorButton.clicked.connect( - self.editTextIDsColorAction.trigger - ) - self.imgGrad.labelsAlphaSlider.valueChanged.connect( - self.updateLabelsAlpha - ) - self.imgGrad.defaultSettingsAction.triggered.connect( - self.restoreDefaultSettings - ) - - # Drawing mode - self.drawIDsContComboBox.currentIndexChanged.connect( - self.drawIDsContComboBox_cb - ) - self.drawIDsContComboBox.activated.connect(self.clearComboBoxFocus) - - self.annotateRightHowCombobox.currentIndexChanged.connect( - self.annotateRightHowCombobox_cb - ) - self.annotateRightHowCombobox.activated.connect(self.clearComboBoxFocus) - - self.showTreeInfoCheckbox.toggled.connect(self.setAnnotInfoMode) - - # Left - self.annotIDsCheckbox.clicked.connect(self.annotOptionClicked) - self.annotCcaInfoCheckbox.clicked.connect(self.annotOptionClicked) - self.annotContourCheckbox.clicked.connect(self.annotOptionClicked) - self.annotSegmMasksCheckbox.clicked.connect(self.annotOptionClicked) - self.drawMothBudLinesCheckbox.clicked.connect(self.annotOptionClicked) - self.drawNothingCheckbox.clicked.connect(self.annotOptionClicked) - self.annotNumZslicesCheckbox.clicked.connect(self.annotOptionClicked) - - # Right - self.annotIDsCheckboxRight.clicked.connect( - self.annotOptionClickedRight) - self.annotCcaInfoCheckboxRight.clicked.connect( - self.annotOptionClickedRight) - self.annotContourCheckboxRight.clicked.connect( - self.annotOptionClickedRight) - self.annotSegmMasksCheckboxRight.clicked.connect( - self.annotOptionClickedRight) - self.drawMothBudLinesCheckboxRight.clicked.connect( - self.annotOptionClickedRight) - self.drawNothingCheckboxRight.clicked.connect( - self.annotOptionClickedRight) - self.annotNumZslicesCheckboxRight.clicked.connect( - self.annotOptionClickedRight - ) - - self.segmentToolAction.triggered.connect(self.segmentToolActionTriggered) - - self.addDelRoiAction.triggered.connect(self.addDelROI) - self.addDelPolyLineRoiButton.toggled.connect(self.addDelPolyLineRoi_cb) - self.delBorderObjAction.triggered.connect(self.delBorderObj) - self.delNewObjAction.triggered.connect(self.delNewObj) - - self.brushAutoFillCheckbox.toggled.connect(self.brushAutoFillToggled) - self.brushAutoHideCheckbox.toggled.connect(self.brushAutoHideToggled) - - self.imgGrad.sigAddScaleBar.connect(self.addScaleBarAction.setChecked) - self.imgGrad.sigAddTimestamp.connect(self.addTimestampAction.setChecked) - self.imgGrad.gradient.sigGradientChangeFinished.connect( - self.imgGradLUTfinished_cb - ) - - # self.normalizeQActionGroup.triggered.connect( - # self.normaliseIntensitiesActionTriggered - # ) - self.imgPropertiesAction.triggered.connect(self.editImgProperties) - - self.relabelSequentialAction.triggered.connect( - self.relabelSequentialCallback - ) - - self.zoomToObjsAction.triggered.connect(self.zoomToObjsActionCallback) - self.zoomOutAction.triggered.connect(self.zoomOut) - self.preprocessAction.triggered.connect(self.preprocessActionTriggered) - self.combineChannelsAction.triggered.connect(self.combineChannelsActionTriggered) - - self.viewCcaTableAction.triggered.connect(self.viewCcaTable) - - self.guiTabControl.propsQGBox.idSB.valueChanged.connect( - self.propsWidgetIDvalueChanged - ) - self.guiTabControl.highlightCheckbox.toggled.connect( - self.highlightIDonHoverCheckBoxToggled - ) - self.guiTabControl.highlightSearchedCheckbox.toggled.connect( - self.highlightSearchedIDcheckBoxToggled - ) - intensMeasurQGBox = self.guiTabControl.intensMeasurQGBox - intensMeasurQGBox.additionalMeasCombobox.currentTextChanged.connect( - self.updatePropsWidget - ) - intensMeasurQGBox.channelCombobox.currentTextChanged.connect( - self.updatePropsWidget - ) - - propsQGBox = self.guiTabControl.propsQGBox - propsQGBox.additionalPropsCombobox.currentTextChanged.connect( - self.updatePropsWidget - ) - - def gui_createShowPropsButton(self, side='left'): - self.leftSideDocksLayout = QVBoxLayout() - self.leftSideDocksLayout.setSpacing(0) - self.leftSideDocksLayout.setContentsMargins(0,0,0,0) - self.rightSideDocksLayout = QVBoxLayout() - self.rightSideDocksLayout.setSpacing(0) - self.rightSideDocksLayout.setContentsMargins(0,0,0,0) - self.showPropsDockButton = widgets.expandCollapseButton() - self.showPropsDockButton.setDisabled(True) - self.showPropsDockButton.setFocusPolicy(Qt.NoFocus) - self.showPropsDockButton.setToolTip('Show object properties') - if side == 'left': - self.leftSideDocksLayout.addWidget(self.showPropsDockButton) - else: - self.rightSideDocksLayout.addWidget(self.showPropsDockButton) - - def gui_createQuickSettingsWidgets(self): - self.quickSettingsLayout = QVBoxLayout() - self.quickSettingsGroupbox = widgets.GroupBox() - self.quickSettingsGroupbox.setTitle('Quick settings') - - layout = QFormLayout() - layout.setFieldGrowthPolicy( - QFormLayout.FieldGrowthPolicy.FieldsStayAtSizeHint - ) - layout.setFormAlignment(Qt.AlignRight | Qt.AlignVCenter) - - self.viewPreprocDataToggle = widgets.Toggle() - viewPreprocDataToggleTooltip = ( - 'View pre-processed data. See menu `Image --> Pre-processing...`\n' - 'on the top menubar.' - ) - self.viewPreprocDataToggle.setChecked(False) - self.viewPreprocDataToggle.setToolTip(viewPreprocDataToggleTooltip) - viewPreprocDataToggleLabel = QLabel('View pre-processed image') - viewPreprocDataToggleLabel.setToolTip(viewPreprocDataToggleTooltip) - layout.addRow(viewPreprocDataToggleLabel, self.viewPreprocDataToggle) - - self.viewCombineChannelDataToggle = widgets.Toggle() - viewCombineChannelDataToggleTooltip = ( - 'View combined channel. See menu `Image --> combing channels...`\n' - 'on the top menubar.' - ) - self.viewCombineChannelDataToggle.setChecked(False) - self.viewCombineChannelDataToggle.setToolTip( - viewCombineChannelDataToggleTooltip - ) - viewCombineChannelDataToggleLabel = QLabel('View combined channels') - viewCombineChannelDataToggleLabel.setToolTip( - viewCombineChannelDataToggleTooltip - ) - layout.addRow( - viewCombineChannelDataToggleLabel, - self.viewCombineChannelDataToggle - ) - - self.autoSaveToggle = widgets.Toggle() - autoSaveTooltip = ( - 'Automatically store a copy of the segmentation data ' - 'in the `.recovery` folder after every edit.' - ) - self.autoSaveToggle.setChecked(True) - self.autoSaveToggle.setToolTip(autoSaveTooltip) - autoSaveLabel = QLabel('Autosave segmentation') - autoSaveLabel.setToolTip(autoSaveTooltip) - layout.addRow(autoSaveLabel, self.autoSaveToggle) - - self.autoSaveAnnotToggle = widgets.Toggle() - autoSaveAnnotTooltip = ( - 'Automatically store a copy of the annotations (acdc_output CSV file) ' - 'in the `.recovery` folder after every edit.' - ) - self.autoSaveAnnotToggle.setChecked(True) - self.autoSaveAnnotToggle.setToolTip(autoSaveAnnotTooltip) - autoSaveAnnotLabel = QLabel('Autosave annotations') - autoSaveAnnotLabel.setToolTip(autoSaveAnnotTooltip) - layout.addRow(autoSaveAnnotLabel, self.autoSaveAnnotToggle) - - self.autoSaveIntervalEditButton = widgets.editPushButton( - flat=True, hoverable=True - ) - self.autoSaveIntervalLabel = QLabel('Autosave interval') - self.autoSaveIntervalSetTooltip() - layout.addRow( - self.autoSaveIntervalLabel, self.autoSaveIntervalEditButton - ) - - self.autoSaveIntervalDialog = apps.AutoSaveIntervalDialog(parent=self) - self.autoSaveIntervalDialog.setValues(*self.autoSaveIntevalValueUnit) - - self.ccaIntegrCheckerToggle = widgets.Toggle() - ccaIntegrCheckerToggleTooltip = ( - 'Toggle background cell cycle annotations integrity checker ON/OFF' - ) - self.ccaIntegrCheckerToggle.setChecked(False) - self.ccaIntegrCheckerToggle.setToolTip(ccaIntegrCheckerToggleTooltip) - label = QLabel('Cc annot. checker') - label.setToolTip(ccaIntegrCheckerToggleTooltip) - layout.addRow(label, self.ccaIntegrCheckerToggle) - if 'is_cca_integrity_checker_activated' in self.df_settings.index: - idx = 'is_cca_integrity_checker_activated' - val = int(self.df_settings.at[idx, 'value']) - self.ccaIntegrCheckerToggle.setChecked(not val) - - self.annotLostObjsToggle = widgets.Toggle() - annotLostObjsToggleTooltip = ( - 'Toggle annotation of lost objects mode ON/OFF' - ) - self.annotLostObjsToggle.setChecked(True) - self.annotLostObjsToggle.setToolTip(annotLostObjsToggleTooltip) - label = QLabel('Annot. lost objects') - label.setToolTip(annotLostObjsToggleTooltip) - layout.addRow(label, self.annotLostObjsToggle) - - self.realTimeTrackingToggle = widgets.Toggle() - self.realTimeTrackingToggle.setChecked(True) - self.realTimeTrackingToggle.setDisabled(True) - label = QLabel('Real-time tracking') - label.setDisabled(True) - self.realTimeTrackingToggle.label = label - layout.addRow(label, self.realTimeTrackingToggle) - - self.showAllContoursToggle = widgets.Toggle() - showAllContoursTooltip = ( - 'If active, all contours will be displayed, including inner contours' - '(e.g. holes and sub-objects)' - ) - self.showAllContoursToggle.setToolTip(showAllContoursTooltip) - showAllContourLabel = QLabel('Show all contours') - showAllContourLabel.setToolTip(showAllContoursTooltip) - layout.addRow(showAllContourLabel, self.showAllContoursToggle) - self.showAllContoursToggle.toggled.connect( - self.showAllContoursToggled - ) - - # Font size - self.fontSizeSpinBox = widgets.SpinBox() - self.fontSizeSpinBox.setMinimum(1) - self.fontSizeSpinBox.setMaximum(99) - layout.addRow('Font size', self.fontSizeSpinBox) - savedFontSize = str(self.df_settings.at['fontSize', 'value']) - if savedFontSize.find('pt') != -1: - savedFontSize = savedFontSize[:-2] - self.fontSize = int(savedFontSize) - if 'pxMode' not in self.df_settings.index: - # Users before introduction of pxMode had pxMode=False, but now - # the new default is True. This requires larger font size. - self.fontSize = 2*self.fontSize - self.df_settings.at['pxMode', 'value'] = 1 - self.df_settings.to_csv(settings_csv_path) - self.fontSizeSpinBox.setValue(self.fontSize) - self.fontSizeSpinBox.editingFinished.connect(self.changeFontSize) - self.fontSizeSpinBox.sigUpClicked.connect(self.changeFontSize) - self.fontSizeSpinBox.sigDownClicked.connect(self.changeFontSize) - - self.quickSettingsGroupbox.setLayout(layout) - self.quickSettingsLayout.addWidget(self.quickSettingsGroupbox) - self.quickSettingsLayout.addStretch(1) - - def showAllContoursToggled(self): - if not self.isDataLoaded: - return - - self.computeAllContours() - self.updateAllImages() - - def gui_createImg1Widgets(self): - # Toggle contours/ID combobox - self.drawIDsContComboBoxSegmItems = [ - 'Draw IDs and contours', - 'Draw IDs and overlay segm. masks', - 'Draw only cell cycle info', - 'Draw cell cycle info and contours', - 'Draw cell cycle info and overlay segm. masks', - 'Draw only mother-bud lines', - 'Draw only IDs', - 'Draw only contours', - 'Draw only overlay segm. masks', - 'Draw nothing' - ] - self.drawIDsContComboBox = widgets.ComboBox() - self.drawIDsContComboBox.setFont(_font) - self.drawIDsContComboBox.addItems(self.drawIDsContComboBoxSegmItems) - self.drawIDsContComboBox.setVisible(False) - - self.annotIDsCheckbox = widgets.CheckBox( - 'IDs', keyPressCallback=self.resetFocus) - self.annotCcaInfoCheckbox = widgets.CheckBox( - 'Cell cycle info', keyPressCallback=self.resetFocus) - self.annotNumZslicesCheckbox = widgets.CheckBox( - 'No. z-slices/object', keyPressCallback=self.resetFocus) - - self.annotContourCheckbox = widgets.CheckBox( - 'Contours', keyPressCallback=self.resetFocus) - self.annotSegmMasksCheckbox = widgets.CheckBox( - 'Segm. masks', keyPressCallback=self.resetFocus) - - self.drawMothBudLinesCheckbox = widgets.CheckBox( - 'Only mother-daughter line', keyPressCallback=self.resetFocus - ) - - self.drawNothingCheckbox = widgets.CheckBox( - 'Do not annotate', keyPressCallback=self.resetFocus - ) - - self.annotOptionsWidget = QWidget() - annotOptionsLayout = QHBoxLayout() - - # Show tree info checkbox - self.showTreeInfoCheckbox = widgets.CheckBox( - 'Show tree info', keyPressCallback=self.resetFocus - ) - self.showTreeInfoCheckbox.setFont(_font) - sp = self.showTreeInfoCheckbox.sizePolicy() - sp.setRetainSizeWhenHidden(True) - self.showTreeInfoCheckbox.setSizePolicy(sp) - self.showTreeInfoCheckbox.hide() - - annotOptionsLayout.addWidget(self.showTreeInfoCheckbox) - annotOptionsLayout.addWidget(QLabel(' | ')) - annotOptionsLayout.addWidget(self.annotIDsCheckbox) - annotOptionsLayout.addWidget(self.annotCcaInfoCheckbox) - annotOptionsLayout.addWidget(self.drawMothBudLinesCheckbox) - annotOptionsLayout.addWidget(self.annotNumZslicesCheckbox) - annotOptionsLayout.addWidget(QLabel(' | ')) - annotOptionsLayout.addWidget(self.annotContourCheckbox) - annotOptionsLayout.addWidget(self.annotSegmMasksCheckbox) - annotOptionsLayout.addWidget(QLabel(' | ')) - annotOptionsLayout.addWidget(self.drawNothingCheckbox) - annotOptionsLayout.addWidget(self.drawIDsContComboBox) - self.annotOptionsLayout = annotOptionsLayout - - # Toggle highlight z+-1 objects combobox - self.highlightZneighObjCheckbox = widgets.CheckBox( - 'Highlight objects in neighbouring z-slices', - keyPressCallback=self.resetFocus - ) - self.highlightZneighObjCheckbox.setFont(_font) - self.highlightZneighObjCheckbox.hide() - - annotOptionsLayout.addWidget(self.highlightZneighObjCheckbox) - self.annotOptionsWidget.setLayout(annotOptionsLayout) - - # Annotations options right image - self.annotIDsCheckboxRight = widgets.CheckBox( - 'IDs', keyPressCallback=self.resetFocus) - self.annotCcaInfoCheckboxRight = widgets.CheckBox( - 'Cell cycle info', keyPressCallback=self.resetFocus) - self.annotNumZslicesCheckboxRight = widgets.CheckBox( - 'No. z-slices/object', keyPressCallback=self.resetFocus - ) - - self.annotContourCheckboxRight = widgets.CheckBox( - 'Contours', keyPressCallback=self.resetFocus) - self.annotSegmMasksCheckboxRight = widgets.CheckBox( - 'Segm. masks', keyPressCallback=self.resetFocus) - - self.drawMothBudLinesCheckboxRight = widgets.CheckBox( - 'Only mother-daughter line', keyPressCallback=self.resetFocus - ) - - self.drawNothingCheckboxRight = widgets.CheckBox( - 'Do not annotate', keyPressCallback=self.resetFocus) - - self.annotOptionsWidgetRight = QWidget() - annotOptionsLayoutRight = QHBoxLayout() - - annotOptionsLayoutRight.addWidget(QLabel(' ')) - annotOptionsLayoutRight.addWidget(QLabel(' | ')) - annotOptionsLayoutRight.addWidget(self.annotIDsCheckboxRight) - annotOptionsLayoutRight.addWidget(self.annotCcaInfoCheckboxRight) - annotOptionsLayoutRight.addWidget(self.drawMothBudLinesCheckboxRight) - annotOptionsLayoutRight.addWidget(self.annotNumZslicesCheckboxRight) - annotOptionsLayoutRight.addWidget(QLabel(' | ')) - annotOptionsLayoutRight.addWidget(self.annotContourCheckboxRight) - annotOptionsLayoutRight.addWidget(self.annotSegmMasksCheckboxRight) - annotOptionsLayoutRight.addWidget(QLabel(' | ')) - annotOptionsLayoutRight.addWidget(self.drawNothingCheckboxRight) - self.annotOptionsLayoutRight = annotOptionsLayoutRight - - self.annotOptionsWidgetRight.setLayout(annotOptionsLayoutRight) - - # Frames scrollbar - self.navigateScrollBar = widgets.navigateQScrollBar(Qt.Horizontal) - self.navigateScrollBar.setDisabled(True) - self.navigateScrollBar.setMinimum(1) - self.navigateScrollBar.setMaximum(1) - self.navigateScrollBar.setToolTip( - 'NOTE: The maximum frame number that can be visualized with this ' - 'scrollbar\n' - 'is the last visited frame with the selected mode\n' - '(see "Mode" selector on the top-right).\n\n' - 'If the scrollbar does not move it means that you never visited\n' - 'any frame with current mode.\n\n' - 'Note that the "Viewer" mode allows you to scroll ALL frames.' - ) - t_label = QLabel('frame n. ') - t_label.setFont(_font) - self.t_label = t_label - - # z-slice scrollbars - self.zSliceScrollBar = widgets.linkedQScrollbar(Qt.Horizontal) - - self.zProjComboBox = widgets.ComboBox() - self.zProjComboBox.setFont(_font) - self.zProjComboBox.addItems([ - 'single z-slice', - 'max z-projection', - 'mean z-projection', - 'median z-proj.' - ]) - self.zProjLockViewButton = widgets.LockPushButton() - self.zProjLockViewButton.setCheckable(True) - self.zProjLockViewButton.setToolTip( - 'If active, the selected z-slice view is applied to all frames' - ) - self.zProjLockViewButton.hide() - - self.switchPlaneCombobox = widgets.SwitchPlaneCombobox() - self.switchPlaneCombobox.setToolTip( - 'Switch viewed plane' - ) - - self.zSliceOverlay_SB = widgets.ScrollBar(Qt.Horizontal) - _z_label = QLabel('Overlay z-slice ') - _z_label.setFont(_font) - _z_label.setDisabled(True) - self.overlay_z_label = _z_label - - self.zProjOverlay_CB = widgets.ComboBox() - self.zProjOverlay_CB.setFont(_font) - self.zProjOverlay_CB.addItems([ - 'single z-slice', 'max z-projection', 'mean z-projection', - 'median z-proj.', 'same as above' - ]) - self.zProjOverlay_CB.setCurrentIndex(4) - self.zSliceOverlay_SB.setDisabled(True) - - self.img1BottomGroupbox = self.gui_getImg1BottomWidgets() - - def gui_getImg1BottomWidgets(self): - bottomLeftLayout = QGridLayout() - self.bottomLeftLayout = bottomLeftLayout - container = QGroupBox('Navigate and annotate left image') - - row = 0 - bottomLeftLayout.addWidget(self.annotOptionsWidget, row, 0, 1, 4) - # bottomLeftLayout.addWidget( - # self.drawIDsContComboBox, row, 1, 1, 2, - # alignment=Qt.AlignCenter - # ) - - # bottomLeftLayout.addWidget( - # self.showTreeInfoCheckbox, row, 0, 1, 1, - # alignment=Qt.AlignCenter - # ) - - row += 1 - navWidgetsLayout = QHBoxLayout() - self.navSpinBox = widgets.SpinBox(disableKeyPress=True) - self.navSpinBox.setMinimum(1) - self.navSpinBox.setMaximum(100) - self.navSizeLabel = QLabel('/ND') - navWidgetsLayout.addWidget(self.t_label) - navWidgetsLayout.addWidget(self.navSpinBox) - navWidgetsLayout.addWidget(self.navSizeLabel) - bottomLeftLayout.addLayout( - navWidgetsLayout, row, 0, alignment=Qt.AlignRight - ) - bottomLeftLayout.addWidget(self.navigateScrollBar, row, 1, 1, 2) - sp = self.navigateScrollBar.sizePolicy() - sp.setRetainSizeWhenHidden(True) - self.navigateScrollBar.setSizePolicy(sp) - self.navSpinBox.connectValueChanged(self.navigateSpinboxValueChanged) - self.navSpinBox.editingFinished.connect( - self.navigateSpinboxEditingFinished - ) - self.navSpinBox.sigUpClicked.connect( - self.navigateSpinboxEditingFinished - ) - self.navSpinBox.sigDownClicked.connect( - self.navigateSpinboxEditingFinished - ) - - self.lastTrackedFrameLabel = QLabel() - self.lastTrackedFrameLabel.setFont(_font) - bottomLeftLayout.addWidget(self.lastTrackedFrameLabel, row, 3) - - row += 1 - zSliceCheckboxLayout = QHBoxLayout() - self.zSliceCheckbox = QCheckBox('z-slice') - self.zSliceSpinbox = widgets.SpinBox(disableKeyPress=True) - self.zSliceSpinbox.setMinimum(1) - self.SizeZlabel = QLabel('/ND') - self.zSliceCheckbox.setToolTip( - 'Activate/deactivate control of the z-slices with keyboard arrows.\n\n' - 'SHORTCUT to toggle ON/OFF: "Z" key' - ) - zSliceCheckboxLayout.addWidget(self.zSliceCheckbox) - zSliceCheckboxLayout.addWidget(self.zSliceSpinbox) - zSliceCheckboxLayout.addWidget(self.SizeZlabel) - bottomLeftLayout.addLayout( - zSliceCheckboxLayout, row, 0, alignment=Qt.AlignRight - ) - bottomLeftLayout.addWidget(self.zSliceScrollBar, row, 1, 1, 2) - bottomLeftLayout.addWidget(self.zProjComboBox, row, 3) - bottomLeftLayout.addWidget(self.zProjLockViewButton, row, 4) - bottomLeftLayout.addWidget(self.switchPlaneCombobox, row, 5) - self.zSliceSpinbox.connectValueChanged(self.onZsliceSpinboxValueChange) - self.zSliceSpinbox.editingFinished.connect(self.zSliceScrollBarReleased) - - row += 1 - bottomLeftLayout.addWidget( - self.overlay_z_label, row, 0, alignment=Qt.AlignRight - ) - bottomLeftLayout.addWidget(self.zSliceOverlay_SB, row, 1, 1, 2) - - bottomLeftLayout.addWidget(self.zProjOverlay_CB, row, 3) - - row += 1 - self.alphaScrollbarRow = row - - bottomLeftLayout.setColumnStretch(0,0) - bottomLeftLayout.setColumnStretch(1,3) - bottomLeftLayout.setColumnStretch(2,0) - - container.setLayout(bottomLeftLayout) - return container - - def gui_createLabWidgets(self): - bottomRightLayout = QVBoxLayout() - self.rightBottomGroupbox = widgets.GroupBox( - 'Annotate right image independent of left image', - keyPressCallback=self.resetFocus - ) - self.rightBottomGroupbox.setCheckable(True) - self.rightBottomGroupbox.setChecked(False) - self.rightBottomGroupbox.hide() - - self.annotateRightHowCombobox = widgets.ComboBox() - self.annotateRightHowCombobox.setFont(_font) - self.annotateRightHowCombobox.addItems(self.drawIDsContComboBoxSegmItems) - self.annotateRightHowCombobox.setCurrentIndex( - self.drawIDsContComboBox.currentIndex() - ) - self.annotateRightHowCombobox.setVisible(False) - - self.annotOptionsLayoutRight.addWidget(self.annotateRightHowCombobox) - - self.rightImageFramesScrollbar = widgets.ScrollBarWithNumericControl( - labelText='Frame n. ' - ) - self.rightImageFramesScrollbar.setVisible(False) - - bottomRightLayout.addWidget(self.annotOptionsWidgetRight) - bottomRightLayout.addWidget(self.rightImageFramesScrollbar) - bottomRightLayout.addStretch(1) - - self.rightBottomGroupbox.setLayout(bottomRightLayout) - - self.rightBottomGroupbox.toggled.connect(self.rightImageControlsToggled) - - def rightImageControlsToggled(self, checked): - if self.isDataLoading: - return - if checked: - self.annotateRightHowCombobox.setCurrentText( - self.drawIDsContComboBox.currentText() - ) - self.updateAllImages() - - def setFocusGraphics(self): - self.graphLayout.setFocus() - - def setFocusMain(self): - # on macOS with Qt6 setFocus causes crashes. Disabled for now. - return - - def resetFocus(self): - self.setFocusGraphics() - self.setFocusMain() - - def gui_createBottomWidgetsToBottomLayout(self): - # self.bottomDockWidget = QDockWidget(self) - bottomScrollArea = widgets.ScrollArea(resizeVerticalOnShow=True) - bottomScrollArea.sigLeaveEvent.connect(self.setFocusMain) - bottomWidget = QWidget() - bottomScrollAreaLayout = QVBoxLayout() - self.bottomLayout = QHBoxLayout() - self.bottomLayout.addLayout(self.quickSettingsLayout) - self.bottomLayout.addStretch(1) - self.bottomLayout.addWidget(self.img1BottomGroupbox) - self.bottomLayout.addStretch(1) - self.bottomLayout.addWidget(self.rightBottomGroupbox) - self.bottomLayout.addStretch(1) - - bottomScrollAreaLayout.addLayout(self.bottomLayout) - bottomScrollAreaLayout.addStretch(1) - - bottomWidget.setLayout(bottomScrollAreaLayout) - bottomScrollArea.setWidgetResizable(True) - bottomScrollArea.setWidget(bottomWidget) - self.bottomScrollArea = bottomScrollArea - - if 'bottom_sliders_zoom_perc' in self.df_settings.index: - val = int(self.df_settings.at['bottom_sliders_zoom_perc', 'value']) - zoom_perc = val - else: - zoom_perc = 100 - self.bottomLayoutContextMenu = QMenu('Bottom layout', self) - zoomMenu = self.bottomLayoutContextMenu.addMenu('Zoom') - actions = [] - self.bottomLayoutContextMenu.zoomActionGroup = QActionGroup(zoomMenu) - for perc in np.arange(50, 151, 10): - action = QAction(f'{perc}%', zoomMenu) - action.setCheckable(True) - if perc == zoom_perc: - action.setChecked(True) - action.toggled.connect(self.zoomBottomLayoutActionTriggered) - actions.append(action) - self.bottomLayoutContextMenu.zoomActionGroup.addAction(action) - zoomMenu.addActions(actions) - resetAction = self.bottomLayoutContextMenu.addAction( - 'Reset default height' - ) - resetAction.triggered.connect(self.resizeGui) - retainSpaceAction = self.bottomLayoutContextMenu.addAction( - 'Retain space of hidden sliders' - ) - retainSpaceAction.setCheckable(True) - if 'retain_space_hidden_sliders' in self.df_settings.index: - retainSpaceChecked = ( - self.df_settings.at['retain_space_hidden_sliders', 'value'] - == 'Yes' - ) - else: - retainSpaceChecked = True - retainSpaceAction.setChecked(retainSpaceChecked) - retainSpaceAction.toggled.connect(self.retainSpaceSlidersToggled) - self.retainSpaceSlidersAction = retainSpaceAction - self.setBottomLayoutStretch() - - def gui_resetBottomLayoutHeight(self): - self.h = self.defaultWidgetHeightBottomLayout - self.checkBoxesHeight = 14 - self.fontPixelSize = 11 - self.resizeSlidersArea() - - def gui_createGraphicsPlots(self): - self.graphLayout = pg.GraphicsLayoutWidget() - if self.invertBwAction.isChecked(): - self.graphLayout.setBackground(graphLayoutBkgrColor) - self.titleColor = 'black' - else: - self.graphLayout.setBackground(darkBkgrColor) - self.titleColor = 'white' - - self.lutItemsLayout = self.graphLayout.addLayout(row=1, col=0) - # self.lutItemsLayout.setBorder('w') - - # Left plot - self.ax1 = widgets.MainPlotItem(showWelcomeText=True) - self.ax1.invertY(True) - self.ax1.setAspectLocked(True) - self.ax1.hideAxis('bottom') - self.ax1.hideAxis('left') - self.plotsCol = 1 - self.graphLayout.addItem(self.ax1, row=1, col=1) - - # Right plot - self.ax2 = widgets.MainPlotItem() - self.ax2.setAspectLocked(True) - self.ax2.invertY(True) - self.ax2.hideAxis('bottom') - self.ax2.hideAxis('left') - # self.currentFrameLabelItem = pg.LabelItem( - # color=self.titleColor, size='13px' - # ) - self.graphLayout.addItem(self.ax2, row=1, col=2) - - def gui_addGraphicsItems(self): - # Auto image adjustment button - proxy = QGraphicsProxyWidget() - equalizeHistPushButton = QPushButton("Enhance contrast") - widthHint = equalizeHistPushButton.sizeHint().width() - equalizeHistPushButton.setMaximumWidth(widthHint) - equalizeHistPushButton.setCheckable(True) - if not self.invertBwAction.isChecked(): - equalizeHistPushButton.setStyleSheet( - 'QPushButton {background-color: #282828; color: #F0F0F0;}' - ) - self.equalizeHistPushButton = equalizeHistPushButton - proxy.setWidget(equalizeHistPushButton) - self.graphLayout.addItem(proxy, row=0, col=0) - self.equalizeHistPushButton = equalizeHistPushButton - - # Left image histogram - self.imgGrad = widgets.myHistogramLUTitem(parent=self, name='image') - self.imgGrad.restoreState(self.df_settings) - self.lutItemsLayout.addItem(self.imgGrad, row=0, col=0) - for action in self.imgGrad.rescaleActionGroup.actions(): - if action.text() == self.defaultRescaleIntensHow: - action.setChecked(True) - self.rescaleIntensMenu.addAction(action) - - # Colormap gradient widget - self.labelsGrad = widgets.labelsGradientWidget(parent=self) - try: - stateFound = self.labelsGrad.restoreState(self.df_settings) - except Exception as e: - self.logger.exception(traceback.format_exc()) - print('======================================') - self.logger.info( - 'Failed to restore previously used colormap. ' - 'Using default colormap "viridis"' - ) - self.labelsGrad.item.loadPreset('viridis') - - # Add actions to imgGrad gradient item - self.imgGrad.gradient.menu.addAction( - self.labelsGrad.showLabelsImgAction - ) - self.imgGrad.gradient.menu.addAction( - self.labelsGrad.showRightImgAction - ) - self.imgGrad.gradient.menu.addAction( - self.labelsGrad.showNextFrameAction - ) - - self.imgGrad.gradient.menu.addSeparator() - - self.imgGrad.gradient.menu.addMenu(self.exportMenu) - - # Add actions to view menu - self.viewMenu.addAction(self.labelsGrad.showLabelsImgAction) - self.viewMenu.addAction(self.labelsGrad.showRightImgAction) - - # Right image histogram - self.imgGradRight = widgets.baseHistogramLUTitem( - name='image', parent=self, gradientPosition='left' - ) - self.imgGradRight.gradient.menu.addAction( - self.labelsGrad.showLabelsImgAction - ) - self.imgGradRight.gradient.menu.addAction( - self.labelsGrad.showRightImgAction - ) - self.imgGradRight.gradient.menu.addAction( - self.labelsGrad.showNextFrameAction - ) - - self.imgGrad.setChildLutItem(self.imgGradRight) - - # Title - self.titleLabel = pg.LabelItem( - justify='center', color=self.titleColor, size='14pt' - ) - self.graphLayout.addItem(self.titleLabel, row=0, col=1, colspan=2) - - def gui_createTextAnnotColors(self, r, g, b, custom=False): - if custom: - self.objLabelAnnotRgb = (int(r), int(g), int(b)) - self.SphaseAnnotRgb = (int(r*0.9), int(r*0.9), int(b*0.9)) - self.G1phaseAnnotRgba = (int(r*0.8), int(g*0.8), int(b*0.8), 220) - else: - self.objLabelAnnotRgb = (255, 255, 255) # white - self.SphaseAnnotRgb = (229, 229, 229) - self.G1phaseAnnotRgba = (204, 204, 204, 220) - self.dividedAnnotRgb = (245, 188, 1) # orange - - self.emptyBrush = pg.mkBrush((0,0,0,0)) - self.emptyPen = pg.mkPen((0,0,0,0)) - - def gui_setTextAnnotColors(self): - self.textAnnot[0].setColors( - self.objLabelAnnotRgb, self.dividedAnnotRgb, self.SphaseAnnotRgb, - self.G1phaseAnnotRgba, self.objLostAnnotRgb, self.objLostTrackedAnnotRgb - ) - - self.textAnnot[1].setColors( - self.objLabelAnnotRgb, self.dividedAnnotRgb, self.SphaseAnnotRgb, - self.G1phaseAnnotRgba, self.objLostAnnotRgb, self.objLostTrackedAnnotRgb - ) - - - def gui_createPlotItems(self): - if 'textIDsColor' in self.df_settings.index: - rgbString = self.df_settings.at['textIDsColor', 'value'] - r, g, b = colors.rgb_str_to_values(rgbString) - self.gui_createTextAnnotColors(r, g, b, custom=True) - self.textIDsColorButton.setColor((r, g, b)) - else: - self.gui_createTextAnnotColors(0,0,0, custom=False) - - if 'labels_text_color' in self.df_settings.index: - rgbString = self.df_settings.at['labels_text_color', 'value'] - r, g, b = colors.rgb_str_to_values(rgbString) - self.ax2_textColor = (r, g, b) - else: - self.ax2_textColor = (255, 0, 0) - - self.emptyLab = np.zeros((2,2), dtype=np.uint8) - - # Right image item linked to left - self.rightImageItem = widgets.ChildImageItem( - linkedScrollbar=self.rightImageFramesScrollbar - ) - self.imgGradRight.setImageItem(self.rightImageItem) - self.ax2.addItem(self.rightImageItem) - - # Left image - self.img1 = widgets.ParentImageItem( - linkedImageItem=self.rightImageItem, - activatingActions=( - self.labelsGrad.showRightImgAction, - self.labelsGrad.showNextFrameAction - ) - ) - self.imgGrad.setImageItem(self.img1) - self.img1.lutItem = self.imgGrad - self.imgGrad.sigRescaleIntes.connect(self.rescaleIntensitiesLut) - self.ax1.addBaseImageItem(self.img1) - - # RGBA image for true transparency mode - self.rgbaImg1 = pg.ImageItem() - - # self.rgbaImg1.setImage(self.emptyLab) - - # Right image - self.img2 = widgets.labImageItem() - self.ax2.addItem(self.img2) - - self.topLayerItems = [] - self.topLayerItemsRight = [] - - self.gui_createContourPens() - self.gui_createMothBudLinePens() - - self.eraserCirclePen = pg.mkPen(width=1.5, color='r') - - # Temporary line item connecting bud to new mother - self.BudMothTempLine = pg.PlotDataItem(pen=self.NewBudMoth_Pen) - self.topLayerItems.append(self.BudMothTempLine) - - # Temporary line item connecting objects to merge - self.mergeObjsTempLine = widgets.PlotCurveItem(pen=self.redDashLinePen) - self.topLayerItems.append(self.mergeObjsTempLine) - - # Overlay segm. masks item - self.labelsLayerImg1 = widgets.BaseLabelsImageItem() - self.ax1.addItem(self.labelsLayerImg1) - - self.labelsLayerRightImg = widgets.BaseLabelsImageItem() - self.ax2.addItem(self.labelsLayerRightImg) - - # Red/green border rect item - self.GreenLinePen = pg.mkPen(color='g', width=2) - self.RedLinePen = pg.mkPen(color='r', width=2) - self.ax1BorderLine = pg.PlotDataItem() - self.topLayerItems.append(self.ax1BorderLine) - self.ax2BorderLine = pg.PlotDataItem(pen=pg.mkPen(color='r', width=2)) - self.topLayerItems.append(self.ax2BorderLine) - - # Brush/Eraser/Wand.. layer item - self.tempLayerRightImage = pg.ImageItem() - self.tempLayerImg1 = widgets.ParentImageItem( - linkedImageItem=self.tempLayerRightImage, - activatingAction=(self.labelsGrad.showRightImgAction, ) - ) - self.topLayerItems.append(self.tempLayerImg1) - self.topLayerItemsRight.append(self.tempLayerRightImage) - - # Highlighted ID layer items - self.highLightIDLayerImg1 = pg.ImageItem() - self.topLayerItems.append(self.highLightIDLayerImg1) - - # Highlighted ID layer items - self.highLightIDLayerRightImage = pg.ImageItem() - self.topLayerItemsRight.append(self.highLightIDLayerRightImage) - - # Keep IDs temp layers - self.keepIDsTempLayerRight = pg.ImageItem() - self.keepIDsTempLayerLeft = widgets.ParentImageItem( - linkedImageItem=self.keepIDsTempLayerRight, - activatingAction=self.labelsGrad.showRightImgAction - ) - self.topLayerItems.append(self.keepIDsTempLayerLeft) - self.topLayerItemsRight.append(self.keepIDsTempLayerRight) - - # Searched ID contour - self.searchedIDitemRight = pg.ScatterPlotItem() - self.searchedIDitemRight.setData( - [], [], symbol='s', pxMode=False, size=1, - brush=pg.mkBrush(color=(255,0,0,150)), - pen=pg.mkPen(width=2, color='r'), tip=None - ) - self.searchedIDitemLeft = pg.ScatterPlotItem() - self.searchedIDitemLeft.setData( - [], [], symbol='s', pxMode=False, size=1, - brush=pg.mkBrush(color=(255,0,0,150)), - pen=pg.mkPen(width=2, color='r'), tip=None - ) - self.topLayerItems.append(self.searchedIDitemLeft) - self.topLayerItemsRight.append(self.searchedIDitemRight) - - - # Brush circle img1 - self.ax1_BrushCircle = pg.ScatterPlotItem() - self.ax1_BrushCircle.setData( - [], [], symbol='o', pxMode=False, - brush=pg.mkBrush((255,255,255,50)), - pen=pg.mkPen(width=2), tip=None - ) - self.topLayerItems.append(self.ax1_BrushCircle) - - # Eraser circle img1 - self.ax1_EraserCircle = pg.ScatterPlotItem() - self.ax1_EraserCircle.setData( - [], [], symbol='o', pxMode=False, - brush=None, pen=self.eraserCirclePen, tip=None - ) - self.topLayerItems.append(self.ax1_EraserCircle) - - self.ax1_EraserX = pg.ScatterPlotItem() - self.ax1_EraserX.setData( - [], [], symbol='x', pxMode=False, size=3, - brush=pg.mkBrush(color=(255,0,0,50)), - pen=pg.mkPen(width=1, color='r'), tip=None - ) - self.topLayerItems.append(self.ax1_EraserX) - - # Brush circle img1 - self.labelRoiCircItemLeft = widgets.LabelRoiCircularItem() - self.labelRoiCircItemLeft.cleared = False - self.labelRoiCircItemLeft.setData( - [], [], symbol='o', pxMode=False, - brush=pg.mkBrush(color=(255,0,0,0)), - pen=pg.mkPen(color='r', width=2), tip=None - ) - self.labelRoiCircItemRight = widgets.LabelRoiCircularItem() - self.labelRoiCircItemRight.cleared = False - self.labelRoiCircItemRight.setData( - [], [], symbol='o', pxMode=False, - brush=pg.mkBrush(color=(255,0,0,0)), - pen=pg.mkPen(color='r', width=2), tip=None - ) - self.topLayerItems.append(self.labelRoiCircItemLeft) - self.topLayerItemsRight.append(self.labelRoiCircItemRight) - - self.ax1_binnedIDs_ScatterPlot = widgets.BaseScatterPlotItem() - self.ax1_binnedIDs_ScatterPlot.setData( - [], [], symbol='t', pxMode=False, - brush=pg.mkBrush((255,0,0,50)), size=15, - pen=pg.mkPen(width=3, color='r'), tip=None - ) - self.topLayerItems.append(self.ax1_binnedIDs_ScatterPlot) - - self.ax1_ripIDs_ScatterPlot = widgets.BaseScatterPlotItem() - self.ax1_ripIDs_ScatterPlot.setData( - [], [], symbol='x', pxMode=False, - brush=pg.mkBrush((255,0,0,50)), size=15, - pen=pg.mkPen(width=2, color='r'), tip=None - ) - self.topLayerItems.append(self.ax1_ripIDs_ScatterPlot) - - # Ruler plotItem and scatterItem - rulerPen = pg.mkPen(color='r', style=Qt.DashLine, width=2) - self.ax1_rulerPlotItem = widgets.RulerPlotItem(pen=rulerPen) - self.ax1_rulerAnchorsItem = pg.ScatterPlotItem( - symbol='o', size=9, - brush=pg.mkBrush((255,0,0,50)), - pen=pg.mkPen((255,0,0), width=2), tip=None - ) - self.topLayerItems.append(self.ax1_rulerPlotItem) - self.topLayerItems.append(self.ax1_rulerPlotItem.labelItem) - self.topLayerItems.append(self.ax1_rulerAnchorsItem) - - # Start point of polyline roi - self.ax1_point_ScatterPlot = pg.ScatterPlotItem() - self.ax1_point_ScatterPlot.setData( - [], [], symbol='o', pxMode=False, size=3, - pen=pg.mkPen(width=2, color='r'), - brush=pg.mkBrush((255,0,0,50)), tip=None - ) - self.topLayerItems.append(self.ax1_point_ScatterPlot) - - # Experimental: scatter plot to add a point marker - self.startPointPolyLineItem = pg.ScatterPlotItem() - self.startPointPolyLineItem.setData( - [], [], symbol='o', size=9, - pen=pg.mkPen(width=2, color='r'), - brush=pg.mkBrush((255,0,0,50)), - hoverable=True, hoverBrush=pg.mkBrush((255,0,0,255)), tip=None - ) - self.topLayerItems.append(self.startPointPolyLineItem) - - # Eraser circle img2 - self.ax2_EraserCircle = pg.ScatterPlotItem() - self.ax2_EraserCircle.setData( - [], [], symbol='o', pxMode=False, brush=None, - pen=self.eraserCirclePen, tip=None - ) - self.ax2.addItem(self.ax2_EraserCircle) - self.ax2_EraserX = pg.ScatterPlotItem() - self.ax2_EraserX.setData( - [], [], symbol='x', pxMode=False, size=3, - brush=pg.mkBrush(color=(255,0,0,50)), - pen=pg.mkPen(width=1.5, color='r') - ) - self.ax2.addItem(self.ax2_EraserX) - - # Brush circle img2 - self.ax2_BrushCirclePen = pg.mkPen(width=2) - self.ax2_BrushCircleBrush = pg.mkBrush((255,255,255,50)) - self.ax2_BrushCircle = pg.ScatterPlotItem() - self.ax2_BrushCircle.setData( - [], [], symbol='o', pxMode=False, - brush=self.ax2_BrushCircleBrush, - pen=self.ax2_BrushCirclePen, tip=None - ) - self.ax2.addItem(self.ax2_BrushCircle) - - # Annotated metadata markers (ScatterPlotItem) - self.ax2_binnedIDs_ScatterPlot = widgets.BaseScatterPlotItem() - self.ax2_binnedIDs_ScatterPlot.setData( - [], [], symbol='t', pxMode=False, - brush=pg.mkBrush((255,0,0,50)), size=15, - pen=pg.mkPen(width=3, color='r'), tip=None - ) - self.ax2.addItem(self.ax2_binnedIDs_ScatterPlot) - - self.ax2_ripIDs_ScatterPlot = widgets.BaseScatterPlotItem() - self.ax2_ripIDs_ScatterPlot.setData( - [], [], symbol='x', pxMode=False, - brush=pg.mkBrush((255,0,0,50)), size=15, - pen=pg.mkPen(width=2, color='r'), tip=None - ) - self.ax2.addItem(self.ax2_ripIDs_ScatterPlot) - - self.freeRoiItem = widgets.PlotCurveItem( - pen=pg.mkPen(color='r', width=2) - ) - self.topLayerItems.append(self.freeRoiItem) - - self.warnPairingItem = widgets.PlotCurveItem( - pen=pg.mkPen(color='r', width=5, style=Qt.DashLine), - pxMode=False - ) - self.topLayerItems.append(self.warnPairingItem) - - self.exportMaskImageItem = pg.ImageItem() - - self.ghostContourItemLeft = widgets.GhostContourItem(self.ax1) - self.ghostContourItemRight = widgets.GhostContourItem(self.ax2) - - self.ghostMaskItemLeft = widgets.GhostMaskItem(self.ax1) - self.ghostMaskItemRight = widgets.GhostMaskItem(self.ax2) - - self.manualBackgroundObjItem = widgets.GhostContourItem( - self.ax1, penColor='r', textColor='r' - ) - self.manualBackgroundImageItem = pg.ImageItem() - - def gui_createZoomRectItem(self): - Y, X = self.currentLab2D.shape - # Label ROI rectangle - pen = pg.mkPen('r', width=3, style=Qt.DashLine) - self.zoomRectItem = widgets.ZoomROI( - (0,0), (0,0), - maxBounds=QRectF(QRect(0,0,X,Y)), - scaleSnap=True, - translateSnap=True, - pen=pen, hoverPen=pen - ) - - def gui_createLabelRoiItem(self): - Y, X = self.currentLab2D.shape - # Label ROI rectangle - pen = pg.mkPen('r', width=3) - self.labelRoiItem = widgets.ROI( - (0,0), (0,0), - maxBounds=QRectF(QRect(0,0,X,Y)), - scaleSnap=True, - translateSnap=True, - pen=pen, hoverPen=pen - ) - - posData = self.data[self.pos_i] - if self.labelRoiZdepthSpinbox.value() == 0: - self.labelRoiZdepthSpinbox.setValue(posData.SizeZ) - self.labelRoiZdepthSpinbox.setMaximum(posData.SizeZ+1) - - def gui_createOverlayColors(self): - fluoChannels = [ch for ch in self.ch_names if ch != self.user_ch_name] - self.logger.info( - f'Number of TIFF files detected: {len(fluoChannels)}' - ) - self.overlayColors = {} - for c, ch in enumerate(fluoChannels): - if f'{ch}_rgb' in self.df_settings.index: - rgb_text = self.df_settings.at[f'{ch}_rgb', 'value'] - rgb = tuple([int(val) for val in rgb_text.split('_')]) - self.overlayColors[ch] = rgb - else: - if c >= len(self.overlayRGBs) -1: - i = c/len(fluoChannels) - additional_color_num = c - len(self.overlayRGBs) + 1 - rgbs = [ - tuple([round(c*255) for c in self.overlayCmap(i)][:3]) - for _ in range(additional_color_num) - ] - self.overlayRGBs.extend(rgbs) - rgb = colors.FLUO_CHANNELS_COLORS.get(ch, self.overlayRGBs[c]) - self.overlayColors[ch] = rgb - - def gui_createOverlayItems(self): - self.imgGrad.setAxisLabel(self.user_ch_name) - self.baseLayerToolbutton = widgets.OverlayChannelToolButton( - self.user_ch_name, self.imgGrad - ) - self.baseLayerToolbutton.setChecked(True) - self.baseLayerToolbutton.clicked.connect( - self.overlayChannelToolbuttonClicked - ) - self.allOverlayToolbuttons = { - self.user_ch_name: self.baseLayerToolbutton - } - self.allOverlayToolbuttonsByIdx = { - 0: self.baseLayerToolbutton - } - self.baseLayerToolbutton.action = ( - self.overlayToolbar.addWidget(self.baseLayerToolbutton) - ) - self.overlayLayersItems = {} - self.overlayToolbarAreChannelsChecked = {} - fluoChannels = [ch for ch in self.ch_names if ch != self.user_ch_name] - for c, ch in enumerate(fluoChannels): - overlayItems = self.getOverlayItems(ch, c+1) - self.overlayLayersItems[ch] = overlayItems - imageItem, lutItem = overlayItems[:2] - self.ax1.addItem(imageItem) - self.lutItemsLayout.addItem(lutItem, row=0, col=c+1) - toolbutton = overlayItems[3] - self.allOverlayToolbuttons[ch] = toolbutton - self.allOverlayToolbuttonsByIdx[c+1] = toolbutton - - self.overlayToolbuttonsSep = self.overlayToolbar.addSeparator() - self.plotsCol = len(self.ch_names) - - self.ax1.addImageItem(self.rgbaImg1) - - def gui_getLostObjScatterItem(self): - self.objLostAnnotRgb = (245, 184, 0) - brush = pg.mkBrush((*self.objLostAnnotRgb, 150)) - pen = pg.mkPen(self.objLostAnnotRgb, width=1) - lostObjScatterItem = pg.ScatterPlotItem( - size=self.contLineWeight+1, pen=pen, - brush=brush, pxMode=False, symbol='s' - ) - return lostObjScatterItem - - def gui_getTrackedLostObjScatterItem(self): - self.objLostTrackedAnnotRgb = (0, 255, 0) - brush = pg.mkBrush((*self.objLostTrackedAnnotRgb, 150)) - pen = pg.mkPen(self.objLostTrackedAnnotRgb, width=1) - lostObjScatterItem = pg.ScatterPlotItem( - size=self.contLineWeight+1, pen=pen, - brush=brush, pxMode=False, symbol='s' - ) - return lostObjScatterItem - - def _gui_createGraphicsItems(self): - for _posData in self.data: - _posData.allData_li = [None]*_posData.SizeT - - posData = self.data[self.pos_i] - - allIDs, posData = core.count_objects(posData, self.logger.info) - - self.highLowResAction.setChecked(True) - numItems = len(allIDs) - if numItems > 1500: - cancel, switchToLowRes = _warnings.warnTooManyItems( - self, numItems, self.progressWin - ) - if cancel: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - self.loadingDataAborted() - return - if switchToLowRes: - self.highLowResAction.setChecked(False) - else: - # Many items requires pxMode active to be fast enough - self.pxModeAction.setChecked(True) - - self.logger.info(f'Creating graphical items...') - - self.ax1_contoursImageItem = pg.ImageItem() - - self.ax1_lostObjImageItem = pg.ImageItem() - self.ax2_lostObjImageItem = pg.ImageItem() - - self.ax1_lostTrackedObjImageItem = pg.ImageItem() - self.ax2_lostTrackedObjImageItem = pg.ImageItem() - - self.ax1_oldMothBudLinesItem = pg.ScatterPlotItem( - symbol='s', pxMode=False, brush=self.oldMothBudLineBrush, - size=self.mothBudLineWeight, pen=None - ) - self.ax1_newMothBudLinesItem = pg.ScatterPlotItem( - symbol='s', pxMode=False, brush=self.newMothBudLineBrush, - size=self.mothBudLineWeight, pen=None - ) - self.ax1_lostObjScatterItem = self.gui_getLostObjScatterItem() - self.yellowContourScatterItem = self.gui_getLostObjScatterItem() - - self.ax1_lostTrackedScatterItem = self.gui_getTrackedLostObjScatterItem() - self.greenContourScatterItem = self.gui_getTrackedLostObjScatterItem() - - brush = pg.mkBrush((0,255,0,200)) - pen = pg.mkPen('g', width=1) - self.ccaFailedScatterItem = pg.ScatterPlotItem( - size=self.contLineWeight+1, pen=pen, - brush=brush, pxMode=False, symbol='s' - ) - - self.ax2_contoursImageItem = pg.ImageItem() - self.ax2_oldMothBudLinesItem = pg.ScatterPlotItem( - symbol='s', pxMode=False, brush=self.oldMothBudLineBrush, - size=self.mothBudLineWeight, pen=None - ) - self.ax2_newMothBudLinesItem = pg.ScatterPlotItem( - symbol='s', pxMode=False, brush=self.newMothBudLineBrush, - size=self.mothBudLineWeight, pen=None - ) - self.ax2_lostObjScatterItem = self.gui_getLostObjScatterItem() - self.ax2_lostTrackedScatterItem = self.gui_getTrackedLostObjScatterItem() - - self.gui_createTextAnnotItems(allIDs) # here - self.gui_setTextAnnotColors()# here - - self.setDisabledAnnotOptions(False) - - self.progressWin.mainPbar.setMaximum(0) - self.gui_addOverlayLayerItems() - self.gui_addTopLayerItems() - - self.gui_addCreatedAxesItems() - self.gui_add_ax_cursors() - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - - self.loadingDataCompleted() - - def gui_createTextAnnotItems(self, allIDs): - self.textAnnot = {} - isHighResolution = self.highLowResAction.isChecked() - pxMode = self.pxModeAction.isChecked() - for ax in range(2): - ax_textAnnot = annotate.TextAnnotations() - ax_textAnnot.initFonts(self.fontSize) - ax_textAnnot.createItems( - isHighResolution, allIDs, pxMode=pxMode - ) - self.textAnnot[ax] = ax_textAnnot - - def gui_addOverlayLayerItems(self): - for items in self.overlayLabelsItems.values(): - imageItem, contoursItem, gradItem = items - self.ax1.addItem(imageItem) - self.ax1.addItem(contoursItem) - - def gui_addTopLayerItems(self): - for item in self.topLayerItems: - self.ax1.addItem(item) - - for item in self.topLayerItemsRight: - self.ax2.addItem(item) - - # self.ax2.addItem(self.currentFrameLabelItem) - - def gui_createMothBudLinePens(self): - if 'mothBudLineSize' in self.df_settings.index: - val = self.df_settings.at['mothBudLineSize', 'value'] - self.mothBudLineWeight = int(val) - else: - self.mothBudLineWeight = 2 - - self.newMothBudlineColor = (255, 0, 0) - if 'mothBudLineColor' in self.df_settings.index: - val = self.df_settings.at['mothBudLineColor', 'value'] - rgba = colors.rgba_str_to_values(val) - self.mothBudLineColor = rgba[0:3] - else: - self.mothBudLineColor = (255,165,0) - - try: - self.imgGrad.mothBudLineColorButton.sigColorChanging.disconnect() - self.imgGrad.mothBudLineColorButton.sigColorChanged.disconnect() - except Exception as e: - pass - try: - for act in self.imgGrad.mothBudLineWightActionGroup.actions(): - act.toggled.disconnect() - except Exception as e: - pass - for act in self.imgGrad.mothBudLineWightActionGroup.actions(): - if act.lineWeight == self.mothBudLineWeight: - act.setChecked(True) - else: - act.setChecked(False) - self.imgGrad.mothBudLineColorButton.setColor(self.mothBudLineColor[:3]) - - self.imgGrad.mothBudLineColorButton.sigColorChanging.connect( - self.updateMothBudLineColour - ) - self.imgGrad.mothBudLineColorButton.sigColorChanged.connect( - self.saveMothBudLineColour - ) - for act in self.imgGrad.mothBudLineWightActionGroup.actions(): - act.toggled.connect(self.mothBudLineWeightToggled) - - # MOther-bud lines brushes - self.NewBudMoth_Pen = pg.mkPen( - color=self.newMothBudlineColor, width=self.mothBudLineWeight+1, - style=Qt.DashLine - ) - self.OldBudMoth_Pen = pg.mkPen( - color=self.mothBudLineColor, width=self.mothBudLineWeight, - style=Qt.DashLine - ) - - self.redDashLinePen = pg.mkPen( - color='r', width=2, style=Qt.DashLine - ) - - self.oldMothBudLineBrush = pg.mkBrush(self.mothBudLineColor) - self.newMothBudLineBrush = pg.mkBrush(self.newMothBudlineColor) - - def gui_createContourPens(self): - if 'contLineWeight' in self.df_settings.index: - val = self.df_settings.at['contLineWeight', 'value'] - self.contLineWeight = int(val) - else: - self.contLineWeight = 1 - if 'contLineColor' in self.df_settings.index: - val = self.df_settings.at['contLineColor', 'value'] - rgba = colors.rgba_str_to_values(val) - self.contLineColor = rgba - self.newIDlineColor = [min(255, v+50) for v in self.contLineColor] - else: - self.contLineColor = (255, 0, 0, 200) - self.newIDlineColor = (255, 0, 0, 255) - - try: - self.imgGrad.contoursColorButton.sigColorChanging.disconnect() - self.imgGrad.contoursColorButton.sigColorChanged.disconnect() - except Exception as e: - pass - try: - for act in self.imgGrad.contLineWightActionGroup.actions(): - act.toggled.disconnect() - except Exception as e: - pass - for act in self.imgGrad.contLineWightActionGroup.actions(): - if act.lineWeight == self.contLineWeight: - act.setChecked(True) - self.imgGrad.contoursColorButton.setColor(self.contLineColor[:3]) - - self.imgGrad.contoursColorButton.sigColorChanging.connect( - self.updateContColour - ) - self.imgGrad.contoursColorButton.sigColorChanged.connect( - self.saveContColour - ) - for act in self.imgGrad.contLineWightActionGroup.actions(): - act.toggled.connect(self.contLineWeightToggled) - - # Contours pens - self.oldIDs_cpen = pg.mkPen( - color=self.contLineColor, width=self.contLineWeight - ) - self.newIDs_cpen = pg.mkPen( - color=self.newIDlineColor, width=self.contLineWeight+1 - ) - self.tempNewIDs_cpen = pg.mkPen( - color='g', width=self.contLineWeight+1 - ) - - def gui_createGraphicsItems(self): - # Create enough PlotDataItems and LabelItems to draw contours and IDs. - self.progressWin = apps.QDialogWorkerProgress( - title='Creating axes items', parent=self, - pbarDesc='Creating axes items (see progress in the terminal)...' - ) - self.progressWin.show(self.app) - self.progressWin.mainPbar.setMaximum(0) - - QTimer.singleShot(50, self._gui_createGraphicsItems) - - def gui_connectGraphicsEvents(self): - self.img1.hoverEvent = self.gui_hoverEventImg1 - self.img2.hoverEvent = self.gui_hoverEventImg2 - self.img1.mousePressEvent = self.gui_mousePressEventImg1 - self.img1.mouseMoveEvent = self.gui_mouseDragEventImg1 - self.img1.mouseReleaseEvent = self.gui_mouseReleaseEventImg1 - self.img2.mousePressEvent = self.gui_mousePressEventImg2 - self.img2.mouseMoveEvent = self.gui_mouseDragEventImg2 - self.img2.mouseReleaseEvent = self.gui_mouseReleaseEventImg2 - self.rightImageItem.mousePressEvent = self.gui_mousePressRightImage - self.rightImageItem.mouseMoveEvent = self.gui_mouseDragRightImage - self.rightImageItem.mouseReleaseEvent = self.gui_mouseReleaseRightImage - self.rightImageItem.hoverEvent = self.gui_hoverEventRightImage - # self.imgGrad.gradient.showMenu = self.gui_gradientContextMenuEvent - self.imgGradRight.gradient.showMenu = self.gui_rightImageShowContextMenu - # self.imgGrad.vb.contextMenuEvent = self.gui_gradientContextMenuEvent - self.ax1.sigRangeChanged.connect(self.viewRangeChanged) - - def gui_initImg1BottomWidgets(self): - self.zSliceScrollBar.hide() - self.zProjComboBox.hide() - self.zProjLockViewButton.hide() - self.zSliceOverlay_SB.hide() - self.zProjOverlay_CB.hide() - self.overlay_z_label.hide() - self.zSliceCheckbox.hide() - self.zSliceSpinbox.hide() - self.SizeZlabel.hide() - - @exception_handler - def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent): - modifiers = QGuiApplication.keyboardModifiers() - alt = modifiers == Qt.AltModifier - shift = modifiers == Qt.ShiftModifier - shift_regardless = bool(modifiers & Qt.ShiftModifier) - isMod = alt - posData = self.data[self.pos_i] - mode = str(self.modeComboBox.currentText()) - left_click = event.button() == Qt.MouseButton.LeftButton and not alt - middle_click = self.isMiddleClick(event, modifiers) - right_click = event.button() == Qt.MouseButton.RightButton and not alt - isPanImageClick = self.isPanImageClick(event, modifiers) - eraserON = self.eraserButton.isChecked() - brushON = self.brushButton.isChecked() - separateON = self.separateBudButton.isChecked() - self.typingEditID = False - - # Drag image if neither brush or eraser are On pressed - dragImg = ( - left_click and not eraserON and not - brushON and not middle_click - ) - if isPanImageClick: - dragImg = True - - # Enable dragging of the image window like pyqtgraph original code - if dragImg: - pg.ImageItem.mousePressEvent(self.img2, event) - event.ignore() - return - - if mode == 'Viewer' and middle_click: - self.startBlinkingModeCB() - event.ignore() - return - - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - Y, X = self.get_2Dlab(posData.lab).shape - if xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y: - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - else: - return - - # Check if right click on ROI - isClickOnDelRoi = self.gui_clickedDelRoi(event, left_click, right_click) - if isClickOnDelRoi: - return - - # show gradient widget menu if none of the right-click actions are ON - # and event is not coming from image 1 - is_right_click_action_ON = any([ - b.isChecked() for b in self.checkableQButtonsGroup.buttons() - ]) - is_right_click_custom_ON = any([ - b.isChecked() for b in self.customAnnotDict.keys() - ]) - is_event_from_img1 = False - if hasattr(event, 'isImg1Sender'): - is_event_from_img1 = event.isImg1Sender - - is_only_right_click = ( - right_click and not is_right_click_action_ON and not middle_click - ) - - showLabelsGradMenu = ( - is_only_right_click and not is_event_from_img1 - ) - - if showLabelsGradMenu: - self.labelsGrad.showMenu(event) - event.ignore() - return - - editInViewerMode = ( - (is_right_click_action_ON or is_right_click_custom_ON) - and (right_click or middle_click) and mode=='Viewer' - ) - - if editInViewerMode: - self.startBlinkingModeCB() - event.ignore() - return - - # Left-click is used for brush, eraser, separate bud, curvature tool - # and magic labeller - # Brush and eraser are mutually exclusive but we want to keep the eraser - # or brush ON and disable them temporarily to allow left-click with - # separate ON - canDelete = mode == 'Segmentation and Tracking' or self.isSnapshot - - # Delete ID (set to 0) - if middle_click and canDelete: - t0 = time.perf_counter() - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - delID = self.get_2Dlab(posData.lab)[ydata, xdata] - if delID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x - ) - delID_prompt = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.
' - 'Enter here ID(s) that you want to delete

' - 'You can enter multiple IDs separated by comma', - parent=self, - allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - allowList=True, - isInteger=True - ) - delID_prompt.exec_() - if delID_prompt.cancel: - return - delIDs = delID_prompt.EntryID - else: - delIDs = [delID] - - # Ask to propagate change to all future visited frames - key = 'Delete ID' - askAction = self.askHowFutureFramesActions[key] - doNotShow = not askAction.isChecked() - (UndoFutFrames, applyFutFrames, endFrame_i, - doNotShowAgain) = self.propagateChange( - delIDs, key, doNotShow, - posData.UndoFutFrames_DelID, posData.applyFutFrames_DelID - ) - - if UndoFutFrames is None: - return - - # Store undo state before modifying stuff - self.storeUndoRedoStates(UndoFutFrames) - posData.doNotShowAgain_DelID = doNotShowAgain - posData.UndoFutFrames_DelID = UndoFutFrames - posData.applyFutFrames_DelID = applyFutFrames - includeUnvisited = posData.includeUnvisitedInfo['Delete ID'] - - delID_mask = self.deleteIDmiddleClick( - delIDs, applyFutFrames, includeUnvisited, shift=shift_regardless - ) - if delID_mask.ndim == 3: - delID_mask = delID_mask[self.z_lab()] - - if self.isSnapshot: - self.fixCcaDfAfterEdit('Delete ID') - else: - self.warnEditingWithCca_df('Delete ID', update_images=False) - - self.setImageImg2() - delROIsIDs = self.setAllTextAnnotations() - self.setAllContoursImages(delROIsIDs=delROIsIDs, compute=False) - - how = self.drawIDsContComboBox.currentText() - if how.find('overlay segm. masks') != -1: - self.labelsLayerImg1.image[delID_mask] = 0 - self.labelsLayerImg1.setImage(self.labelsLayerImg1.image) - - how_ax2 = self.getAnnotateHowRightImage() - if how_ax2.find('overlay segm. masks') != -1: - self.labelsLayerRightImg.image[delID_mask] = 0 - self.labelsLayerRightImg.setImage(self.labelsLayerRightImg.image) - - self.highlightLostNew() - - # Separate bud or objects with same ID - elif right_click and separateON: - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x) - sepID_prompt = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter here ID that you want to split', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - sepID_prompt.exec_() - if sepID_prompt.cancel: - return - else: - ID = sepID_prompt.EntryID - y, x = posData.rp[posData.IDs_idxs[ID]].centroid[-2:] - xdata, ydata = int(x), int(y) - - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - max_ID = max(posData.IDs, default=1) - - if self.isSegm3D and not shift: - z = self.zSliceScrollBar.sliderPosition() - posData.lab, splittedIDs = measure.separate_with_label( - posData.lab, posData.rp, [ID], max_ID, - click_coords_list=[(z, ydata, xdata)] - ) - success = True - # self.set_2Dlab(lab2D) - elif not shift: - result = core.split_along_convexity_defects( - ID, self.get_2Dlab(posData.lab), max_ID - ) - lab2D, success, splittedIDs = result - self.set_2Dlab(lab2D) - else: - success = False - - # If automatic bud separation was not successfull call manual one - if not success: - posData.disableAutoActivateViewerWindow = True - img = self.getDisplayedImg1() - col = 'manual_separate_draw_mode' - drawMode = self.df_settings.at[col, 'value'] - manualSep = apps.manualSeparateGui( - self.get_2Dlab(posData.lab), ID, img, - fontSize=self.fontSize, - IDcolor=self.lut[ID], - parent=self, - drawMode=drawMode - ) - manualSep.setState(self.lastManualSeparateState) - manualSep.show() - manualSep.centerWindow() - manualSep.show(block=True) - if manualSep.cancel: - posData.disableAutoActivateViewerWindow = False - if not self.separateBudButton.findChild(QAction).isChecked(): - self.separateBudButton.setChecked(False) - return - self.lastManualSeparateState = manualSep.state() - lab2D = self.get_2Dlab(posData.lab) - lab2D[manualSep.lab!=0] = manualSep.lab[manualSep.lab!=0] - self.set_2Dlab(lab2D) - splittedIDs = [obj.label for obj in manualSep.rp] - posData.disableAutoActivateViewerWindow = False - self.storeManualSeparateDrawMode(manualSep.drawMode) - - # Update data (rp, etc) - self.update_rp() - - # Repeat tracking - self.trackSubsetIDs(splittedIDs) - - if self.isSnapshot: - self.fixCcaDfAfterEdit('Separate IDs') - self.updateAllImages() - else: - self.warnEditingWithCca_df('Separate IDs') - - self.store_data() - - if not self.separateBudButton.findChild(QAction).isChecked(): - self.separateBudButton.setChecked(False) - - # Fill holes - elif right_click and self.fillHolesToolButton.isChecked(): - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x - ) - clickedBkgrID = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter here the ID that you want to ' - 'fill the holes of', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - clickedBkgrID.exec_() - if clickedBkgrID.cancel: - return - else: - ID = clickedBkgrID.EntryID - - if ID in posData.lab: - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - obj_idx = posData.IDs.index(ID) - obj = posData.rp[obj_idx] - objMask = self.getObjImage(obj.image, obj.bbox) - localFill = scipy.ndimage.binary_fill_holes(objMask) - posData.lab[self.getObjSlice(obj.slice)][localFill] = ID - - self.update_rp() - self.updateAllImages() - - if not self.fillHolesToolButton.findChild(QAction).isChecked(): - self.fillHolesToolButton.setChecked(False) - - # Hull contour - elif right_click and self.hullContToolButton.isChecked(): - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x - ) - mergeID_prompt = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter here the ID that you want to ' - 'replace with Hull contour', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - mergeID_prompt.exec_() - if mergeID_prompt.cancel: - return - else: - ID = mergeID_prompt.EntryID - - if ID in posData.lab: - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - obj_idx = posData.IDs.index(ID) - obj = posData.rp[obj_idx] - objMask = self.getObjImage(obj.image, obj.bbox) - localHull = skimage.morphology.convex_hull_image(objMask) - posData.lab[self.getObjSlice(obj.slice)][localHull] = ID - - self.update_rp() - self.updateAllImages() - - if not self.hullContToolButton.findChild(QAction).isChecked(): - self.hullContToolButton.setChecked(False) - - # Move label - elif right_click and self.moveLabelToolButton.isChecked(): - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - - x, y = event.pos().x(), event.pos().y() - self.startMovingLabel(x, y) - - # Fill holes - elif right_click and self.fillHolesToolButton.isChecked(): - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x - ) - clickedBkgrID = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter here the ID that you want to ' - 'fill the holes of', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - clickedBkgrID.exec_() - if clickedBkgrID.cancel: - return - else: - ID = clickedBkgrID.EntryID - - # Merge IDs - elif right_click and self.mergeIDsButton.isChecked(): - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x - ) - mergeID_prompt = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter here first ID that you want to merge', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - mergeID_prompt.exec_() - if mergeID_prompt.cancel: - self.mergeObjsTempLine.setData([], []) - return - else: - ID = mergeID_prompt.EntryID - - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - self.firstID = ID - - obj_idx = posData.IDs_idxs[ID] - obj = posData.rp[obj_idx] - yc, xc = self.getObjCentroid(obj.centroid) - self.clickObjYc, self.clickObjXc = int(yc), int(xc) - - # Edit ID - elif right_click and self.editIDbutton.isChecked(): - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x - ) - editID_prompt = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter here ID that you want to replace with a new one', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - editID_prompt.show(block=True) - - if editID_prompt.cancel: - return - else: - ID = editID_prompt.EntryID - - obj_idx = posData.IDs_idxs[ID] - y, x = posData.rp[obj_idx].centroid[-2:] - xdata, ydata = int(x), int(y) - - posData.disableAutoActivateViewerWindow = True - currentIDs = posData.IDs.copy() - self.setAllIDs(onlyVisited=True) - addPropagateCheckbox = ( - not self.isSnapshot - and posData.frame_i == self.navigateScrollBar.maximum() - 1 - and posData.frame_i < posData.SizeT - 1 - ) - editID = apps.EditIDDialog( - ID, posData.IDs, - doNotShowAgain=self.doNotAskAgainExistingID, - parent=self, - entryID=self.getNearestLostObjID(y, x), - nextUniqueID=self.setBrushID(return_val=True), - allIDs=posData.allIDs, - addPropagateCheckbox=addPropagateCheckbox - ) - editID.show(block=True) - if editID.cancel: - posData.disableAutoActivateViewerWindow = False - if not self.editIDbutton.findChild(QAction).isChecked(): - self.editIDbutton.setChecked(False) - return - - if editID.assignNewID: - self.assignNewIDfromClickedID(ID, event) - return - - if not self.doNotAskAgainExistingID: - self.editIDmergeIDs = editID.mergeWithExistingID - self.doNotAskAgainExistingID = editID.doNotAskAgainExistingID - - self.applyEditID( - ID, currentIDs, editID.how, x, y, - shift=shift, - doPropagateUnvisited=editID.doPropagateFutureFrames - ) - - elif (right_click or left_click) and self.keepIDsButton.isChecked(): - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x - ) - keepID_win = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter ID that you want to keep', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - keepID_win.exec_() - if keepID_win.cancel: - return - else: - ID = keepID_win.EntryID - - if ID in self.keptObjectsIDs: - self.keptObjectsIDs.remove(ID) - self.clearHighlightedText() - else: - self.keptObjectsIDs.append(ID) - self.highlightLabelID(ID) - - self.updateTempLayerKeepIDs() - - # Annotate cell as removed from the analysis - elif right_click and self.binCellButton.isChecked(): - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x - ) - binID_prompt = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter ID that you want to remove from the analysis', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - binID_prompt.exec_() - if binID_prompt.cancel: - return - else: - ID = binID_prompt.EntryID - - # Ask to propagate change to all future visited frames - key = 'Exclude cell from analysis' - askAction = self.askHowFutureFramesActions[key] - doNotShow = not askAction.isChecked() - (UndoFutFrames, applyFutFrames, endFrame_i, - doNotShowAgain) = self.propagateChange( - ID, key, doNotShow, - posData.UndoFutFrames_BinID, - posData.applyFutFrames_BinID - ) - - if UndoFutFrames is None: - # User cancelled the process - return - - posData.doNotShowAgain_BinID = doNotShowAgain - posData.UndoFutFrames_BinID = UndoFutFrames - posData.applyFutFrames_BinID = applyFutFrames - - self.current_frame_i = posData.frame_i - - # Apply Exclude cell from analysis to future frames if requested - if applyFutFrames: - # Store current data before going to future frames - self.store_data() - for i in range(posData.frame_i+1, endFrame_i+1): - posData.frame_i = i - self.get_data() - if ID in posData.binnedIDs: - posData.binnedIDs.remove(ID) - else: - posData.binnedIDs.add(ID) - self.update_rp_metadata(draw=False) - self.store_data(autosave=i==endFrame_i) - - self.app.restoreOverrideCursor() - - # Back to current frame - if applyFutFrames: - posData.frame_i = self.current_frame_i - self.get_data() - - # Store undo state before modifying stuff - self.storeUndoRedoStates(UndoFutFrames) - - if ID in posData.binnedIDs: - posData.binnedIDs.remove(ID) - else: - posData.binnedIDs.add(ID) - - self.annotate_rip_and_bin_IDs(updateLabel=True) - - # Gray out ore restore binned ID - self.updateLookuptable() - - if not self.binCellButton.findChild(QAction).isChecked(): - self.binCellButton.setChecked(False) - - # Annotate cell as dead - elif right_click and self.ripCellButton.isChecked(): - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x - ) - ripID_prompt = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter ID that you want to annotate as dead', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - ripID_prompt.exec_() - if ripID_prompt.cancel: - return - else: - ID = ripID_prompt.EntryID - - # Ask to propagate change to all future visited frames - key = 'Annotate cell as dead' - askAction = self.askHowFutureFramesActions[key] - doNotShow = not askAction.isChecked() - (UndoFutFrames, applyFutFrames, endFrame_i, - doNotShowAgain) = self.propagateChange( - ID, key, doNotShow, - posData.UndoFutFrames_RipID, - posData.applyFutFrames_RipID - ) - - if UndoFutFrames is None: - return - - posData.doNotShowAgain_RipID = doNotShowAgain - posData.UndoFutFrames_RipID = UndoFutFrames - posData.applyFutFrames_RipID = applyFutFrames - - self.current_frame_i = posData.frame_i - - # Apply Edit ID to future frames if requested - if applyFutFrames: - # Store current data before going to future frames - self.store_data() - for i in range(posData.frame_i+1, endFrame_i+1): - posData.frame_i = i - self.get_data() - if ID in posData.ripIDs: - posData.ripIDs.remove(ID) - else: - posData.ripIDs.add(ID) - self.update_rp_metadata(draw=False) - self.store_data(autosave=i==endFrame_i) - self.app.restoreOverrideCursor() - - # Back to current frame - if applyFutFrames: - posData.frame_i = self.current_frame_i - self.get_data() - - # Store undo state before modifying stuff - self.storeUndoRedoStates(UndoFutFrames) - - if ID in posData.ripIDs: - posData.ripIDs.remove(ID) - else: - posData.ripIDs.add(ID) - - self.annotate_rip_and_bin_IDs(updateLabel=True) - - # Gray out dead ID - self.updateLookuptable() - self.store_data() - - if self.isSnapshot: - self.fixCcaDfAfterEdit('Annotate ID as dead') - self.updateAllImages() - else: - self.warnEditingWithCca_df('Annotate ID as dead') - - if not self.ripCellButton.findChild(QAction).isChecked(): - self.ripCellButton.setChecked(False) - - def resetExpandLabel(self): - self.expandingID = -1 - - def expandLabelCallback(self, checked): - if checked: - self.disconnectLeftClickButtons() - self.uncheckLeftClickButtons(self.sender()) - self.connectLeftClickButtons() - self.expandFootprintSize = 1 - else: - self.clearHighlightedID() - alpha = self.imgGrad.labelsAlphaSlider.value() - self.labelsLayerImg1.setOpacity(alpha) - self.labelsLayerRightImg.setOpacity(alpha) - self.hoverLabelID = 0 - self.expandingID = 0 - self.updateAllImages() - - def expandLabel(self, dilation=True): - posData = self.data[self.pos_i] - if self.hoverLabelID == 0: - self.isExpandingLabel = False - return - - # Re-initialize label to expand when we hover on a different ID - # or we change direction - reinitExpandingLab = ( - self.expandingID != self.hoverLabelID - or dilation != self.isDilation - ) - - ID = self.hoverLabelID - - obj = posData.rp[posData.IDs.index(ID)] - - if reinitExpandingLab: - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - # hoverLabelID different from previously expanded ID --> reinit - self.isExpandingLabel = True - self.expandingID = ID - self.expandingLab = np.zeros_like(self.currentLab2D) - self.expandingLab[obj.coords[:,-2], obj.coords[:,-1]] = ID - self.expandFootprintSize = 1 - - prevCoords = (obj.coords[:,-2], obj.coords[:,-1]) - self.currentLab2D[obj.coords[:,-2], obj.coords[:,-1]] = 0 - lab_2D = self.get_2Dlab(posData.lab) - lab_2D[obj.coords[:,-2], obj.coords[:,-1]] = 0 - - footprint = skimage.morphology.disk(self.expandFootprintSize) - if dilation: - expandedLab = skimage.morphology.dilation( - self.expandingLab, footprint - ) - self.isDilation = True - else: - expandedLab = skimage.morphology.erosion( - self.expandingLab, footprint - ) - self.isDilation = False - - # Prevent expanding into neighbouring labels - expandedLab[self.currentLab2D>0] = 0 - - # Get coords of the dilated/eroded object - expandedObj = skimage.measure.regionprops(expandedLab)[0] - expandedObjCoords = (expandedObj.coords[:,-2], expandedObj.coords[:,-1]) - - # Add the dilated/erored object - self.currentLab2D[expandedObjCoords] = self.expandingID - lab_2D[expandedObjCoords] = self.expandingID - - self.set_2Dlab(lab_2D) - self.currentLab2D = lab_2D - - self.update_rp() - - if self.labelsGrad.showLabelsImgAction.isChecked(): - self.img2.setImage(img=self.currentLab2D, autoLevels=False) - - self.setTempImgExpandLabel(prevCoords, expandedObjCoords) - - def startMovingLabel(self, xPos, yPos): - posData = self.data[self.pos_i] - xdata, ydata = int(xPos), int(yPos) - lab_2D = self.get_2Dlab(posData.lab) - ID = lab_2D[ydata, xdata] - if ID == 0: - self.isMovingLabel = False - return - - posData = self.data[self.pos_i] - self.isMovingLabel = True - - self.searchedIDitemRight.setData([], []) - self.searchedIDitemLeft.setData([], []) - self.movingID = ID - self.prevMovePos = (xdata, ydata) - movingObj = posData.rp[posData.IDs.index(ID)] - self.movingObjCoords = movingObj.coords.copy() - yy, xx = movingObj.coords[:,-2], movingObj.coords[:,-1] - self.currentLab2D[yy, xx] = 0 - - def moveLabel(self, xPos, yPos): - posData = self.data[self.pos_i] - lab_2D = self.get_2Dlab(posData.lab) - Y, X = lab_2D.shape - xdata, ydata = int(xPos), int(yPos) - if xdata<0 or ydata<0 or xdata>=X or ydata>=Y: - return - - self.clearObjContour(ID=self.movingID, ax=0) - - xStart, yStart = self.prevMovePos - deltaX = xdata-xStart - deltaY = ydata-yStart - - yy, xx = self.movingObjCoords[:,-2], self.movingObjCoords[:,-1] - - if self.isSegm3D: - zz = self.movingObjCoords[:,0] - posData.lab[zz, yy, xx] = 0 - else: - posData.lab[yy, xx] = 0 - - self.movingObjCoords[:,-2] = self.movingObjCoords[:,-2]+deltaY - self.movingObjCoords[:,-1] = self.movingObjCoords[:,-1]+deltaX - - yy, xx = self.movingObjCoords[:,-2], self.movingObjCoords[:,-1] - - yy[yy<0] = 0 - xx[xx<0] = 0 - yy[yy>=Y] = Y-1 - xx[xx>=X] = X-1 - - if self.isSegm3D: - zz = self.movingObjCoords[:,0] - posData.lab[zz, yy, xx] = self.movingID - else: - posData.lab[yy, xx] = self.movingID - - self.currentLab2D = self.get_2Dlab(posData.lab) - if self.labelsGrad.showLabelsImgAction.isChecked(): - self.img2.setImage(self.currentLab2D, autoLevels=False) - - self.setTempImg1MoveLabel() - - self.prevMovePos = (xdata, ydata) - - @exception_handler - def gui_mouseDragEventImg1(self, event): - x, y = event.pos().x(), event.pos().y() - - if hasattr(self, 'scaleBar'): - if self.scaleBarDialog is not None: - self.scaleBarDialog.locCombobox.setCurrentText('Custom') - if self.scaleBar.isHighlighted() and self.scaleBar.clicked: - self.scaleBar.setLocationProperty('custom') - self.scaleBar.move(x, y) - return - - if hasattr(self, 'timestamp'): - if self.timestampDialog is not None: - self.timestampDialog.locCombobox.setCurrentText('Custom') - if self.timestamp.isHighlighted() and self.timestamp.clicked: - self.timestamp.setLocationProperty('custom') - self.timestamp.move(x, y) - return - - mode = str(self.modeComboBox.currentText()) - if mode == 'Viewer': - return - - posData = self.data[self.pos_i] - Y, X = self.get_2Dlab(posData.lab).shape - xdata, ydata = int(x), int(y) - if not myutils.is_in_bounds(xdata, ydata, X, Y): - return - - if self.isRightClickDragImg1 and self.curvToolButton.isChecked(): - self.drawAutoContour(y, x) - - # Brush dragging mouse --> keep brushing - elif self.isMouseDragImg1 and self.brushButton.isChecked(): - lab_2D = self.get_2Dlab(posData.lab) - - # t1 = time.perf_counter() - - ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata) - rrPoly, ccPoly = self.getPolygonBrush((y, x), Y, X) - - # t2 = time.perf_counter() - - diskSlice = (slice(ymin, ymax), slice(xmin, xmax)) - - # Build brush mask - mask = np.zeros(lab_2D.shape, bool) - mask[diskSlice][diskMask] = True - mask[rrPoly, ccPoly] = True - - modifiers = QGuiApplication.keyboardModifiers() - ctrl = modifiers == Qt.ControlModifier - - # t3 = time.perf_counter() - if not self.isPowerBrush() and not ctrl: - mask[lab_2D!=0] = False - self.setHoverToolSymbolColor( - xdata, ydata, self.ax2_BrushCirclePen, - (self.ax2_BrushCircle, self.ax1_BrushCircle), - self.brushButton, brush=self.ax2_BrushCircleBrush - ) - - # t4 = time.perf_counter() - - # Apply brush mask - self.applyBrushMask(mask, posData.brushID) - - self.setImageImg2(updateLookuptable=False) - - # t5 = time.perf_counter() - - lab2D = self.get_2Dlab(posData.lab) - brushMask = np.logical_and( - lab2D[diskSlice] == posData.brushID, diskMask - ) - self.setTempImg1Brush( - False, brushMask, posData.brushID, - toLocalSlice=diskSlice - ) - - # t6 = time.perf_counter() - - # printl( - # 'Brush exec times =\n' - # f' * {(t1-t0)*1000 = :.4f} ms\n' - # f' * {(t2-t1)*1000 = :.4f} ms\n' - # f' * {(t3-t2)*1000 = :.4f} ms\n' - # f' * {(t4-t3)*1000 = :.4f} ms\n' - # f' * {(t5-t4)*1000 = :.4f} ms\n' - # f' * {(t6-t5)*1000 = :.4f} ms\n' - # f' * {(t6-t0)*1000 = :.4f} ms' - # ) - - # Eraser dragging mouse --> keep erasing - elif self.isMouseDragImg1 and self.eraserButton.isChecked(): - posData = self.data[self.pos_i] - lab_2D = self.get_2Dlab(posData.lab) - rrPoly, ccPoly = self.getPolygonBrush((y, x), Y, X) - - ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata) - - diskSlice = (slice(ymin, ymax), slice(xmin, xmax)) - - # Build eraser mask - mask = np.zeros(lab_2D.shape, bool) - mask[ymin:ymax, xmin:xmax][diskMask] = True - mask[rrPoly, ccPoly] = True - - if self.eraseOnlyOneID: - mask[lab_2D!=self.erasedID] = False - self.setHoverToolSymbolColor( - xdata, ydata, self.eraserCirclePen, - (self.ax2_EraserCircle, self.ax1_EraserCircle), - self.eraserButton, hoverRGB=self.img2.lut[self.erasedID], - ID=self.erasedID - ) - - self.erasedIDs.update(lab_2D[mask]) - self.applyEraserMask(mask) - - self.setImageImg2() - - for erasedID in self.erasedIDs: - if erasedID == 0: - continue - self.erasedLab[lab_2D==erasedID] = erasedID - self.erasedLab[mask] = 0 - - eraserMask = mask[diskSlice] - self.setTempImg1Eraser(eraserMask, toLocalSlice=diskSlice) - self.setTempImg1Eraser(eraserMask, toLocalSlice=diskSlice, ax=1) - - # Move label dragging mouse --> keep moving - elif self.isMovingLabel and self.moveLabelToolButton.isChecked(): - x, y = event.pos().x(), event.pos().y() - self.moveLabel(x, y) - - # Wand dragging mouse --> keep doing the magic - elif self.isMouseDragImg1 and self.wandToolButton.isChecked(): - tol = self.getMagicWandFloodTolerance() - if self.isSegm3D: - z_slice = self.zSliceScrollBar.sliderPosition() - seed = (z_slice, ydata, xdata) - else: - seed = (ydata, xdata) - - flood_mask = skimage.segmentation.flood( - self.flood_img, seed, tolerance=tol - ) - drawUnderMask = np.logical_or( - posData.lab==0, posData.lab==posData.brushID - ) - flood_mask = np.logical_and(flood_mask, drawUnderMask) - - self.flood_mask[flood_mask] = True - - if self.wandControlsToolbar.autoFillHolesCheckbox.isChecked(): - self.flood_mask = core.binary_fill_holes(self.flood_mask) - - if self.wandControlsToolbar.useConvexHullCheckbox.isChecked(): - self.flood_mask = core.convex_hull_mask(self.flood_mask) - - self.setTempBrushMaskFromWand(self.flood_mask) - - # Label ROI dragging mouse --> draw ROI - elif self.isMouseDragImg1 and self.labelRoiButton.isChecked(): - if self.labelRoiIsRectRadioButton.isChecked(): - x0, y0 = self.labelRoiItem.pos() - w, h = (xdata-x0), (ydata-y0) - self.labelRoiItem.setSize((w, h)) - elif self.labelRoiIsFreeHandRadioButton.isChecked(): - self.freeRoiItem.addPoint(xdata, ydata) - - # Draw freehand clear region --> draw region - elif self.isMouseDragImg1 and self.drawClearRegionButton.isChecked(): - self.freeRoiItem.addPoint(xdata, ydata) - - # Label ROI dragging mouse --> draw ROI - elif self.isMouseDragImg1 and self.zoomRectButton.isChecked(): - x0, y0 = self.zoomRectItem.pos() - w, h = (xdata-x0), (ydata-y0) - self.zoomRectItem.setSize((w, h)) - - # @exec_time - def fillHolesID(self, ID, sender='brush'): - posData = self.data[self.pos_i] - if sender == 'brush': - if not self.brushAutoFillCheckbox.isChecked(): - return False - - lab2D = self.get_2Dlab(posData.lab) - mask = lab2D == ID - filledMask = scipy.ndimage.binary_fill_holes(mask) - lab2D[filledMask] = ID - - self.set_2Dlab(lab2D) - return True - return False - - def highlightIDonHoverCheckBoxToggled(self, checked): - doHighlight = ( - self.guiTabControl.highlightCheckbox.isChecked() - or self.guiTabControl.highlightSearchedCheckbox.isChecked() - ) - if not doHighlight: - self.highlightedID = 0 - self.initLookupTableLab() - else: - self.highlightedID = self.guiTabControl.propsQGBox.idSB.value() - self.highlightSearchedID(self.highlightedID, force=True) - self.updatePropsWidget(self.highlightedID) - self.updateAllImages() - - def highlightSearchedIDcheckBoxToggled(self, checked): - self.highlightIDonHoverCheckBoxToggled(checked) - if checked: - posData = self.data[self.pos_i] - self.highlightedID = self.getHighlightedID() - if self.highlightedID == 0: - return - objIdx = posData.IDs_idxs[self.highlightedID] - obj_idx = posData.IDs_idxs.get(self.highlightedID) - if obj_idx is None: - return - obj = posData.rp[objIdx] - self.goToZsliceSearchedID(obj) - - def setHighlightID(self, doHighlight): - if not doHighlight: - self.highlightedID = 0 - self.initLookupTableLab() - else: - self.highlightedID = self.guiTabControl.propsQGBox.idSB.value() - self.highlightSearchedID(self.highlightedID, force=True) - self.updatePropsWidget(self.highlightedID) - self.updateAllImages() - - def propsWidgetIDvalueChanged(self, ID): - posData = self.data[self.pos_i] - if ID == 0: - self.updatePropsWidget(int(ID)) - return - - propsQGBox = self.guiTabControl.propsQGBox - obj_idx = posData.IDs_idxs.get(ID) - if obj_idx is None: - s = f'Object ID {int(ID):d} does not exist' - propsQGBox.notExistingIDLabel.setText(s) - return - - obj = posData.rp[obj_idx] - self.goToZsliceSearchedID(obj) - self.updatePropsWidget(int(ID)) - - def updatePropsWidget(self, ID, fromHover=False): - if isinstance(ID, str): - # Function called by currentTextChanged of channelCombobox or - # additionalMeasCombobox. We set self.currentPropsID = 0 to force update - ID = self.guiTabControl.propsQGBox.idSB.value() - self.currentPropsID = -1 - - ID = int(ID) - - update = ( - self.propsDockWidget.isVisible() - and ID != 0 and ID!=self.currentPropsID - ) - if not update: - return - - posData = self.data[self.pos_i] - if not hasattr(posData, 'rp'): - return - - if posData.rp is None: - self.update_rp() - - if not posData.IDs: - # empty segmentation mask - return - - if fromHover and not self.guiTabControl.highlightCheckbox.isChecked(): - # Do not highlight on hover - return - - propsQGBox = self.guiTabControl.propsQGBox - - obj_idx = posData.IDs_idxs.get(ID) - if obj_idx is None: - s = f'Object ID {int(ID):d} does not exist' - propsQGBox.notExistingIDLabel.setText(s) - return - - propsQGBox.notExistingIDLabel.setText('') - self.currentPropsID = ID - propsQGBox.idSB.setValue(ID) - - doHighlight = ( - self.guiTabControl.highlightCheckbox.isChecked() - or self.guiTabControl.highlightSearchedCheckbox.isChecked() - ) - if doHighlight: - self.highlightSearchedID(ID) - - obj = posData.rp[obj_idx] - - if self.isSegm3D: - if self.zProjComboBox.currentText() == 'single z-slice': - local_z = self.z_lab() - obj.bbox[0] - area_pxl = np.count_nonzero(obj.image[local_z]) - else: - area_pxl = np.count_nonzero(obj.image.max(axis=0)) - else: - area_pxl = obj.area - - propsQGBox.cellAreaPxlSB.setValue(area_pxl) - - pixelSizeQGBox = self.guiTabControl.pixelSizeQGBox - PhysicalSizeX = pixelSizeQGBox.pixelWidthWidget.value() - PhysicalSizeY = pixelSizeQGBox.pixelHeightWidget.value() - PhysicalSizeZ = pixelSizeQGBox.voxelDepthWidget.value() - - yx_pxl_to_um2 = PhysicalSizeY*PhysicalSizeX - - area_um2 = area_pxl*yx_pxl_to_um2 - - propsQGBox.cellAreaUm2DSB.setValue(area_um2) - - if self.isSegm3D: - PhysicalSizeZ = posData.PhysicalSizeZ - vol_vox_3D = obj.area - vol_fl_3D = vol_vox_3D*PhysicalSizeZ*PhysicalSizeY*PhysicalSizeX - propsQGBox.cellVolVox3D_SB.setValue(vol_vox_3D) - propsQGBox.cellVolFl3D_DSB.setValue(vol_fl_3D) - - vol_vox, vol_fl = _calc_rot_vol( - obj, PhysicalSizeY, PhysicalSizeX - ) - propsQGBox.cellVolVoxSB.setValue(int(vol_vox)) - propsQGBox.cellVolFlDSB.setValue(vol_fl) - - - minor_axis_length = max(1, obj.minor_axis_length) - elongation = obj.major_axis_length/minor_axis_length - propsQGBox.elongationDSB.setValue(elongation) - - solidity = obj.solidity - propsQGBox.solidityDSB.setValue(solidity) - - additionalPropName = propsQGBox.additionalPropsCombobox.currentText() - additionalPropValue = getattr(obj, additionalPropName) - propsQGBox.additionalPropsCombobox.indicator.setValue(additionalPropValue) - - intensMeasurQGBox = self.guiTabControl.intensMeasurQGBox - selectedChannel = intensMeasurQGBox.channelCombobox.currentText() - - try: - _, filename = self.getPathFromChName(selectedChannel, posData) - image = posData.ol_data_dict[filename][posData.frame_i] - except Exception as e: - image = posData.img_data[posData.frame_i] - - if posData.SizeZ > 1 and not self.isSegm3D: - z = self.zSliceScrollBar.sliderPosition() - objData = image[z][obj.slice][obj.image] - img = self.img1.image - else: - objData = image[obj.slice][obj.image] - img = image - - intensMeasurQGBox.minimumDSB.setValue(np.min(objData)) - intensMeasurQGBox.maximumDSB.setValue(np.max(objData)) - intensMeasurQGBox.meanDSB.setValue(np.mean(objData)) - intensMeasurQGBox.medianDSB.setValue(np.median(objData)) - - funcDesc = intensMeasurQGBox.additionalMeasCombobox.currentText() - func = intensMeasurQGBox.additionalMeasCombobox.functions[funcDesc] - if funcDesc == 'Concentration': - bkgrVal = np.median(img[posData.lab == 0]) - amount = func(objData, bkgrVal, obj.area) - value = amount/vol_vox - elif funcDesc == 'Amount': - bkgrVal = np.median(img[posData.lab == 0]) - amount = func(objData, bkgrVal, obj.area) - value = amount - else: - value = func(objData) - - intensMeasurQGBox.additionalMeasCombobox.indicator.setValue(value) - - def gui_hoverEventRightImage(self, event): - try: - posData = self.data[self.pos_i] - except AttributeError: - return - - if event.isExit(): - self.resetCursors() - - self.gui_hoverEventImg1(event, isHoverImg1=False) - setMirroredCursor = ( - self.app.overrideCursor() is None and not event.isExit() - and self.showMirroredCursorAction.isChecked() - ) - if setMirroredCursor: - x, y = event.pos() - self.ax1_cursor.setData([x], [y]) - - def onCtrlPressedFirstTime(self): - x, y = self.xHoverImg, self.yHoverImg - if x is None: - self.xyOnCtrlPressedFirstTime = None - return - - xdata, ydata = int(x), int(y) - Y, X = self.currentLab2D.shape - - if not (xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y): - self.xyOnCtrlPressedFirstTime = None - return - - ID = self.currentLab2D[ydata, xdata] - if ID == 0: - self.xyOnCtrlPressedFirstTime = None - return - - self.xyOnCtrlPressedFirstTime = (xdata, ydata) - - def onCtrlReleased(self): - self.xyOnCtrlPressedFirstTime = None - - def gui_hoverEventImg1(self, event, isHoverImg1=True): - try: - posData = self.data[self.pos_i] - except AttributeError: - return - - # Update x, y, value label bottom right - if not event.isExit(): - self.xHoverImg, self.yHoverImg = event.pos() - else: - self.xHoverImg, self.yHoverImg = None, None - - if event.isExit(): - self.resetCursor() - - if not event.isExit() and self.slideshowWin is not None: - self.slideshowWin.setMirroredCursorPos(*event.pos()) - - # Alt key was released --> restore cursor - modifiers = QGuiApplication.keyboardModifiers() - cursorsInfo = self.gui_setCursor(modifiers, event) - self.highlightHoverLostObj(modifiers, event) - - drawRulerLine = ( - (self.rulerButton.isChecked() - or self.addDelPolyLineRoiButton.isChecked()) - and self.tempSegmentON and not event.isExit() - ) - if drawRulerLine: - self.drawTempRulerLine(event) - - if not event.isExit(): - x, y = event.pos() - xdata, ydata = int(x), int(y) - _img = self.img1.image - Y, X = _img.shape[:2] - if xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y: - ID = self.currentLab2D[ydata, xdata] - self.updatePropsWidget(ID, fromHover=True) - activeToolButton = self.getActiveToolButton() - hoverText = self.hoverValuesFormatted( - xdata, ydata, activeToolButton, isHoverImg1 - ) - self.checkHighlightScaleBar(x, y, activeToolButton) - self.checkHighlightTimestamp(x, y, activeToolButton) - self.wcLabel.setText(hoverText) - else: - self.clickedOnBud = False - self.BudMothTempLine.setData([], []) - self.wcLabel.setText('') - - if cursorsInfo['setKeepObjCursor']: - x, y = event.pos() - self.highlightHoverIDsKeptObj(x, y) - - if cursorsInfo['setManualTrackingCursor']: - x, y = event.pos() - # self.highlightHoverID(x, y) - self.drawManualTrackingGhost(x, y) - - if cursorsInfo['setManualBackgroundCursor']: - x, y = event.pos() - # self.highlightHoverID(x, y) - self.drawManualBackgroundObj(x, y) - - if ( - not cursorsInfo['setManualTrackingCursor'] - and not cursorsInfo['setManualBackgroundCursor'] - ): - self.clearGhost() - - setMoveLabelCursor = cursorsInfo['setMoveLabelCursor'] - setExpandLabelCursor = cursorsInfo['setExpandLabelCursor'] - if setMoveLabelCursor or setExpandLabelCursor: - x, y = event.pos() - self.updateHoverLabelCursor(x, y) - - # Draw eraser circle - if cursorsInfo['setEraserCursor']: - x, y = event.pos() - self.updateEraserCursor(x, y, isHoverImg1=isHoverImg1) - self.hideItemsHoverBrush(xy=(x, y)) - elif self.eraserButton.isChecked() and not event.isExit(): - if self.xyOnCtrlPressedFirstTime is not None: - self.updateEraserCursor( - x, y, xyLocked=self.xyOnCtrlPressedFirstTime, - isHoverImg1=isHoverImg1 - ) - self.hideItemsHoverBrush(xy=(x, y)) - else: - eraserCursors = ( - self.ax1_EraserCircle, self.ax2_EraserCircle, - self.ax1_EraserX, self.ax2_EraserX - ) - self.setHoverToolSymbolData([], [], eraserCursors) - - # Draw Brush circle - if cursorsInfo['setBrushCursor']: - x, y = event.pos() - self.updateBrushCursor(x, y, isHoverImg1=isHoverImg1) - self.hideItemsHoverBrush(xy=(x, y)) - elif cursorsInfo['setAddPointCursor']: - x, y = event.pos() - self.setHoverCircleAddPoint(x, y) - else: - self.setHoverToolSymbolData( - [], [], (self.ax2_BrushCircle, self.ax1_BrushCircle), - ) - - # Draw label ROi circular cursor - setLabelRoiCircCursor = cursorsInfo['setLabelRoiCircCursor'] - if setLabelRoiCircCursor: - x, y = event.pos() - else: - x, y = None, None - self.updateLabelRoiCircularCursor(x, y, setLabelRoiCircCursor) - - drawMothBudLine = ( - self.assignBudMothButton.isChecked() and self.clickedOnBud - and not event.isExit() - ) - if drawMothBudLine: - self.drawTempMothBudLine(event, posData) - - drawMergeObjsLine = ( - self.mergeIDsButton.isChecked() and not event.isExit() - ) - if drawMergeObjsLine: - self.drawTempMergeObjsLine(event, posData, modifiers) - - # Temporarily draw spline curve - # see https://stackoverflow.com/questions/33962717/interpolating-a-closed-curve-using-scipy - drawSpline = ( - self.curvToolButton.isChecked() and self.splineHoverON - and not event.isExit() - ) - if drawSpline: - self.hoverEventDrawSpline(event) - - setMirroredCursor = ( - self.app.overrideCursor() is None and not event.isExit() - and isHoverImg1 and self.showMirroredCursorAction.isChecked() - ) - if setMirroredCursor: - x, y = event.pos() - self.ax2_cursor.setData([x], [y]) - else: - self.ax2_cursor.setData([], []) - - return cursorsInfo - - def drawTempMothBudLine(self, event, posData): - x, y = event.pos() - y2, x2 = y, x - xdata, ydata = int(x), int(y) - y1, x1 = self.yClickBud, self.xClickBud - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == 0: - self.BudMothTempLine.setData([x1, x2], [y1, y2]) - else: - obj_idx = posData.IDs_idxs[ID] - obj = posData.rp[obj_idx] - y2, x2 = self.getObjCentroid(obj.centroid) - self.BudMothTempLine.setData([x1, x2], [y1, y2]) - - def drawTempMergeObjsLine(self, event, posData, modifiers): - if self.clickObjYc is None: - return - modifier = modifiers == Qt.ShiftModifier - x, y = event.pos() - y2, x2 = y, x - xdata, ydata = int(x), int(y) - y1, x1 = self.clickObjYc, self.clickObjXc - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID != 0: - obj_idx = posData.IDs_idxs[ID] - obj = posData.rp[obj_idx] - y2, x2 = self.getObjCentroid(obj.centroid) - - if modifier and ID > 0: - self.mergeObjsTempLine.addPoint(x2, y2) - elif not modifier: - self.mergeObjsTempLine.setData([x1, x2], [y1, y2]) - - def gui_add_ax_cursors(self): - try: - self.ax1.removeItem(self.ax1_cursor) - self.ax2.removeItem(self.ax2_cursor) - except Exception as e: - pass - - self.ax2_cursor = pg.ScatterPlotItem( - symbol='+', pxMode=True, pen=pg.mkPen('k', width=1), - brush=pg.mkBrush('w'), size=16, tip=None - ) - self.ax2.addItem(self.ax2_cursor) - - self.ax1_cursor = pg.ScatterPlotItem( - symbol='+', pxMode=True, pen=pg.mkPen('k', width=1), - brush=pg.mkBrush('w'), size=16, tip=None - ) - self.ax1.addItem(self.ax1_cursor) - - def gui_setCursor(self, modifiers, event): - noModifier = modifiers == Qt.NoModifier - shift = modifiers == Qt.ShiftModifier - ctrl = modifiers == Qt.ControlModifier - alt = modifiers == Qt.AltModifier - - # Alt key was released --> restore cursor - if self.app.overrideCursor() == Qt.SizeAllCursor and noModifier: - self.app.restoreOverrideCursor() - - setBrushCursor = ( - self.brushButton.isChecked() and not event.isExit() - and (noModifier or shift or ctrl) - ) - setEraserCursor = ( - self.eraserButton.isChecked() and not event.isExit() - and noModifier - ) - setAddDelPolyLineCursor = ( - self.addDelPolyLineRoiButton.isChecked() and not event.isExit() - and noModifier - ) - setLabelRoiCircCursor = ( - self.labelRoiButton.isChecked() and not event.isExit() - and (noModifier or shift or ctrl) - and self.labelRoiIsCircularRadioButton.isChecked() - ) - setWandCursor = ( - self.wandToolButton.isChecked() and not event.isExit() - and noModifier - ) - setLabelRoiCursor = ( - self.labelRoiButton.isChecked() and not event.isExit() - and noModifier - ) - setMoveLabelCursor = ( - self.moveLabelToolButton.isChecked() and not event.isExit() - and noModifier - ) - setExpandLabelCursor = ( - self.expandLabelToolButton.isChecked() and not event.isExit() - and noModifier - ) - setCurvCursor = ( - self.curvToolButton.isChecked() and not event.isExit() - and noModifier - ) - setKeepObjCursor = ( - self.keepIDsButton.isChecked() and not event.isExit() - and noModifier - ) - setCustomAnnotCursor = ( - self.customAnnotButton is not None and not event.isExit() - and noModifier - ) - setManualTrackingCursor = ( - self.manualTrackingButton.isChecked() - and not event.isExit() - and noModifier - ) - setManualBackgroundCursor = ( - self.manualBackgroundButton.isChecked() - and not event.isExit() - and noModifier - ) - setZoomRectCursor = ( - self.zoomRectButton.isChecked() and not event.isExit() - and noModifier - ) - setEditIDCursor = ( - self.editIDbutton.isChecked() and not event.isExit() - ) - magicPromptsON = self.magicPromptsToolButton.isChecked() - pointsLayerON = self.togglePointsLayerAction.isChecked() - addPointsByClickingButton = self.buttonAddPointsByClickingActive() - setAddPointCursor = ( - (pointsLayerON or magicPromptsON) - and addPointsByClickingButton is not None - and not event.isExit() - and noModifier - ) - overrideCursor = self.app.overrideCursor() - setPanImageCursor = alt and not event.isExit() - if setPanImageCursor and overrideCursor is None: - self.app.setOverrideCursor(Qt.SizeAllCursor) - elif setBrushCursor or setEraserCursor or setLabelRoiCircCursor: - self.app.setOverrideCursor(Qt.CrossCursor) - elif setWandCursor and overrideCursor is None: - self.app.setOverrideCursor(self.wandCursor) - elif setLabelRoiCursor and overrideCursor is None: - self.app.setOverrideCursor(Qt.CrossCursor) - elif setCurvCursor and overrideCursor is None: - self.app.setOverrideCursor(self.curvCursor) - elif setCustomAnnotCursor and overrideCursor is None: - self.app.setOverrideCursor(Qt.PointingHandCursor) - elif setAddDelPolyLineCursor: - self.app.setOverrideCursor(self.polyLineRoiCursor) - elif setCustomAnnotCursor: - x, y = event.pos() - self.highlightHoverID(x, y) - elif setKeepObjCursor and overrideCursor is None: - self.app.setOverrideCursor(Qt.PointingHandCursor) - elif setManualTrackingCursor and overrideCursor is None: - self.app.setOverrideCursor(Qt.PointingHandCursor) - elif setManualBackgroundCursor and overrideCursor is None: - self.app.setOverrideCursor(Qt.PointingHandCursor) - elif setAddPointCursor: - self.app.setOverrideCursor(self.addPointsCursor) - elif setZoomRectCursor: - self.app.setOverrideCursor(Qt.CrossCursor) - elif setEditIDCursor and overrideCursor is None: - if shift: - self.app.setOverrideCursor(Qt.CrossCursor) - else: - self.app.restoreOverrideCursor() - - return { - 'setBrushCursor': setBrushCursor, - 'setEraserCursor': setEraserCursor, - 'setAddDelPolyLineCursor': setAddDelPolyLineCursor, - 'setLabelRoiCircCursor': setLabelRoiCircCursor, - 'setWandCursor': setWandCursor, - 'setLabelRoiCursor': setLabelRoiCursor, - 'setMoveLabelCursor': setMoveLabelCursor, - 'setExpandLabelCursor': setExpandLabelCursor, - 'setCurvCursor': setCurvCursor, - 'setKeepObjCursor': setKeepObjCursor, - 'setCustomAnnotCursor': setCustomAnnotCursor, - 'setManualTrackingCursor': setManualTrackingCursor, - 'setManualBackgroundCursor': setManualBackgroundCursor, - 'setAddPointCursor': setAddPointCursor, - 'setZoomRectCursor': setZoomRectCursor, - 'setEditIDCursor': setEditIDCursor - } - - def warnAddingPointWithExistingId(self, point_id, table_endname=''): - posData = self.data[self.pos_i] - if not point_id in posData.IDs_idxs: - return True - - msg = widgets.myMessageBox(wrapText=False) - txt = (f""" - Cell ID {point_id} already exists!

- Are you sure you want to add this point? - """) - if table_endname: - txt = (f""" - The loaded table {table_endname} has point id - {point_id}. -

However, {txt} - """) - txt = html_utils.paragraph(txt) - _, _, yesButton = msg.warning( - self, f'Cell ID {point_id} already exist', txt, - buttonsTexts=( - 'Cancel', 'No, do not add', f'Yes, add point id {point_id}' - ) - ) - return msg.clickedButton == yesButton - - def gui_hoverEventImg2(self, event): - try: - posData = self.data[self.pos_i] - except AttributeError: - return - - if not event.isExit(): - self.xHoverImg, self.yHoverImg = event.pos() - else: - self.xHoverImg, self.yHoverImg = None, None - - # Cursor left image --> restore cursor - if event.isExit() and self.app.overrideCursor() is not None: - while self.app.overrideCursor() is not None: - self.app.restoreOverrideCursor() - - # Alt key was released --> restore cursor - modifiers = QGuiApplication.keyboardModifiers() - noModifier = modifiers == Qt.NoModifier - shift = modifiers == Qt.ShiftModifier - ctrl = modifiers == Qt.ControlModifier - if self.app.overrideCursor() == Qt.SizeAllCursor and noModifier: - self.app.restoreOverrideCursor() - - setBrushCursor = ( - self.brushButton.isChecked() and not event.isExit() - and (noModifier or shift or ctrl) - ) - setEraserCursor = ( - self.eraserButton.isChecked() and not event.isExit() - and noModifier - ) - setLabelRoiCircCursor = ( - self.labelRoiButton.isChecked() and not event.isExit() - and (noModifier or shift or ctrl) - and self.labelRoiIsCircularRadioButton.isChecked() - ) - if setBrushCursor or setEraserCursor or setLabelRoiCircCursor: - self.app.setOverrideCursor(Qt.CrossCursor) - - setMoveLabelCursor = ( - self.moveLabelToolButton.isChecked() and not event.isExit() - and noModifier - ) - - setExpandLabelCursor = ( - self.expandLabelToolButton.isChecked() and not event.isExit() - and noModifier - ) - - # Cursor is moving on image while Alt key is pressed --> pan cursor - alt = QGuiApplication.keyboardModifiers() == Qt.AltModifier - setPanImageCursor = alt and not event.isExit() - if setPanImageCursor and self.app.overrideCursor() is None: - self.app.setOverrideCursor(Qt.SizeAllCursor) - - setKeepObjCursor = ( - self.keepIDsButton.isChecked() and not event.isExit() - and noModifier - ) - if setKeepObjCursor and self.app.overrideCursor() is None: - self.app.setOverrideCursor(Qt.PointingHandCursor) - - # Update x, y, value label bottom right - if not event.isExit(): - x, y = event.pos() - xdata, ydata = int(x), int(y) - _img = self.currentLab2D - Y, X = _img.shape - # hoverText = self.hoverValuesFormatted(xdata, ydata) - # self.wcLabel.setText(hoverText) - else: - if self.eraserButton.isChecked() or self.brushButton.isChecked(): - self.gui_mouseReleaseEventImg2(event) - self.wcLabel.setText(f'') - - if setMoveLabelCursor or setExpandLabelCursor: - x, y = event.pos() - self.updateHoverLabelCursor(x, y) - - if setKeepObjCursor: - x, y = event.pos() - self.highlightHoverIDsKeptObj(x, y) - - # Draw eraser circle - if setEraserCursor: - x, y = event.pos() - self.updateEraserCursor(x, y, isHoverImg1=False) - else: - self.setHoverToolSymbolData( - [], [], (self.ax1_EraserCircle, self.ax2_EraserCircle, - self.ax1_EraserX, self.ax2_EraserX) - ) - - # Draw Brush circle - if setBrushCursor: - x, y = event.pos() - self.updateBrushCursor(x, y, isHoverImg1=False) - else: - self.setHoverToolSymbolData( - [], [], (self.ax2_BrushCircle, self.ax1_BrushCircle), - ) - - # Draw label ROi circular cursor - if setLabelRoiCircCursor: - x, y = event.pos() - else: - x, y = None, None - self.updateLabelRoiCircularCursor(x, y, setLabelRoiCircCursor) - - def gui_imgGradShowContextMenu(self, x, y): - if hasattr(self, 'scaleBar'): - if self.scaleBar.isHighlighted(): - self.scaleBar.showContextMenu(x, y) - return - - if hasattr(self, 'timestamp'): - if self.timestamp.isHighlighted(): - self.timestamp.showContextMenu(x, y) - return - - self.imgGrad.gradient.menu.popup(QPoint(int(x), int(y))) - - def gui_rightImageShowContextMenu(self, event): - try: - # Convert QPointF to QPoint - self.imgGradRight.gradient.menu.popup(event.screenPos().toPoint()) - except AttributeError: - self.imgGradRight.gradient.menu.popup(event.screenPos()) - - @exception_handler - def gui_mouseDragEventImg2(self, event): - posData = self.data[self.pos_i] - mode = str(self.modeComboBox.currentText()) - if mode == 'Viewer': - return - - Y, X = self.get_2Dlab(posData.lab).shape - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - if not myutils.is_in_bounds(xdata, ydata, X, Y): - return - - # Eraser dragging mouse --> keep erasing - if self.isMouseDragImg2 and self.eraserButton.isChecked(): - posData = self.data[self.pos_i] - lab_2D = self.get_2Dlab(posData.lab) - Y, X = lab_2D.shape - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - brushSize = self.brushSizeSpinbox.value() - rrPoly, ccPoly = self.getPolygonBrush((y, x), Y, X) - - ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata) - - # Build eraser mask - mask = np.zeros(lab_2D.shape, bool) - mask[ymin:ymax, xmin:xmax][diskMask] = True - mask[rrPoly, ccPoly] = True - - if self.eraseOnlyOneID: - mask[lab_2D!=self.erasedID] = False - self.setHoverToolSymbolColor( - xdata, ydata, self.eraserCirclePen, - (self.ax2_EraserCircle, self.ax1_EraserCircle), - self.eraserButton, hoverRGB=self.img2.lut[self.erasedID], - ID=self.erasedID - ) - - self.erasedIDs.update(lab_2D[mask]) - - self.applyEraserMask(mask) - self.setImageImg2(updateLookuptable=False) - - # Brush paint dragging mouse --> keep painting - if self.isMouseDragImg2 and self.brushButton.isChecked(): - posData = self.data[self.pos_i] - lab_2D = self.get_2Dlab(posData.lab) - Y, X = lab_2D.shape - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - - ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata) - rrPoly, ccPoly = self.getPolygonBrush((y, x), Y, X) - - # Build brush mask - mask = np.zeros(lab_2D.shape, bool) - mask[ymin:ymax, xmin:xmax][diskMask] = True - mask[rrPoly, ccPoly] = True - - # If user double-pressed 'b' then draw over the labels - color = self.brushButton.palette().button().color().name() - if color != self.doublePressKeyButtonColor: - mask[lab_2D!=0] = False - self.setHoverToolSymbolColor( - xdata, ydata, self.ax2_BrushCirclePen, - (self.ax2_BrushCircle, self.ax1_BrushCircle), - self.eraserButton, brush=self.ax2_BrushCircleBrush - ) - - # Apply brush mask - self.applyBrushMask(mask, self.ax2BrushID) - - self.setImageImg2() - - # Move label dragging mouse --> keep moving - elif self.isMovingLabel and self.moveLabelToolButton.isChecked(): - x, y = event.pos().x(), event.pos().y() - self.moveLabel(x, y) - - @exception_handler - def gui_mouseReleaseEventImg2(self, event): - posData = self.data[self.pos_i] - mode = str(self.modeComboBox.currentText()) - if mode == 'Viewer': - return - - Y, X = self.get_2Dlab(posData.lab).shape - try: - x, y = event.pos().x(), event.pos().y() - except Exception as e: - return - - xdata, ydata = int(x), int(y) - if not myutils.is_in_bounds(xdata, ydata, X, Y): - self.isMouseDragImg2 = False - self.updateAllImages() - return - - # Move label mouse released, update move - if self.isMovingLabel and self.moveLabelToolButton.isChecked(): - self.isMovingLabel = False - - # Update data (rp, etc) - self.update_rp() - - # Repeat tracking - self.tracking(enforce=True, assign_unique_new_IDs=False) - - self.updateAllImages() - - if not self.moveLabelToolButton.findChild(QAction).isChecked(): - self.moveLabelToolButton.setChecked(False) - - # Merge IDs - elif self.mergeIDsButton.isChecked(): - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - lab2D = self.get_2Dlab(posData.lab) - ID = lab2D[ydata, xdata] - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - lab2D, y, x - ) - mergeID_prompt = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter ID that you want to merge with ID ' - f'{self.firstID}', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - mergeID_prompt.exec_() - if mergeID_prompt.cancel: - return - else: - ID = mergeID_prompt.EntryID - obj_idx = posData.IDs_idxs[ID] - obj = posData.rp[obj_idx] - y2, x2 = self.getObjCentroid(obj.centroid) - self.mergeObjsTempLine.addPoint(x2, y2) - - xx, yy = self.mergeObjsTempLine.getData() - IDs_to_merge = lab2D[yy.astype(int), xx.astype(int)] - for ID in IDs_to_merge: - if ID == 0: - continue - posData.lab[posData.lab==ID] = self.firstID - - self.mergeObjsTempLine.setData([], []) - self.clickObjYc, self.clickObjXc = None, None - - # Update data (rp, etc) - self.update_rp() - - ask_back_prop = True - - if posData.frame_i == 0: - ask_back_prop = False - prev_IDs = [] - else: - prev_IDs = posData.allData_li[posData.frame_i-1]['IDs'] - - if all(ID not in prev_IDs for ID in IDs_to_merge): - ask_back_prop = False - - if not self.isFrameCcaAnnotated() and ask_back_prop: - proceed = self.askPropagateChangePast(f'Merge IDs {IDs_to_merge}') - if proceed: - self.propagateMergeObjsPast(IDs_to_merge) - self.whitelistPropagateIDs(only_future_frames=False, update_lab=True) # in the update_rp() call, this should also be done - - # Repeat tracking - self.tracking( - enforce=True, assign_unique_new_IDs=False, - separateByLabel=False - ) - - if self.isSnapshot: - self.fixCcaDfAfterEdit('Merge IDs') - self.updateAllImages() - else: - self.warnEditingWithCca_df('Merge IDs') - - if not self.mergeIDsButton.findChild(QAction).isChecked(): - self.mergeIDsButton.setChecked(False) - self.store_data() - - @exception_handler - def gui_mouseReleaseEventImg1(self, event): - modifiers = QGuiApplication.keyboardModifiers() - ctrl = modifiers == Qt.ControlModifier - alt = modifiers == Qt.AltModifier - right_click = event.button() == Qt.MouseButton.RightButton and not alt - - posData = self.data[self.pos_i] - mode = str(self.modeComboBox.currentText()) - if mode == 'Viewer': - return - - Y, X = self.get_2Dlab(posData.lab).shape - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - if not myutils.is_in_bounds(xdata, ydata, X, Y): - self.isMouseDragImg2 = False - self.updateAllImages() - return - - if hasattr(self, 'scaleBar'): - if self.scaleBar.isHighlighted() and self.scaleBar.clicked: - self.scaleBar.clicked = False - return - - if hasattr(self, 'timestamp'): - if self.timestamp.isHighlighted() and self.timestamp.clicked: - self.timestamp.clicked = False - return - - sendRightClickImg2 = ( - (mode=='Segmentation and Tracking' or self.isSnapshot) - and right_click - ) - if sendRightClickImg2: - # Allow right-click actions on both images - self.gui_mouseReleaseEventImg2(event) - - # Right-click curvature tool mouse release - if self.isRightClickDragImg1 and self.curvToolButton.isChecked(): - self.isRightClickDragImg1 = False - try: - self.curvToolSplineToObj(isRightClick=True) - self.update_rp() - if self.autoIDcheckbox.isChecked(): - self.trackManuallyAddedObject(posData.brushID, True) - if self.isSnapshot: - self.fixCcaDfAfterEdit('Add new ID with curvature tool') - self.updateAllImages() - else: - self.warnEditingWithCca_df('Add new ID with curvature tool') - self.clearCurvItems() - self.curvTool_cb(True) - except ValueError: - self.clearCurvItems() - self.curvTool_cb(True) - pass - - # Eraser mouse release --> update IDs and contours - elif self.isMouseDragImg1 and self.eraserButton.isChecked(): - self.isMouseDragImg1 = False - - self.clearTempBrushImage() - - # Update data (rp, etc) - self.update_rp() - - doUpdateImages = self.checkWarnDeletedIDwithEraser() - - if doUpdateImages: - self.updateAllImages() - - # Brush button mouse release - elif self.isMouseDragImg1 and self.brushButton.isChecked(): - self.isMouseDragImg1 = False - - self.clearTempBrushImage() - - self.brushReleased() - - # Wand tool release, add new object - elif self.isMouseDragImg1 and self.wandToolButton.isChecked(): - self.isMouseDragImg1 = False - - self.clearTempBrushImage() - - posData = self.data[self.pos_i] - posData.lab[self.flood_mask] = posData.brushID - - # Update data (rp, etc) - self.update_rp() - - # Repeat tracking - self.trackManuallyAddedObject(posData.brushID, self.isNewID) - - if self.isSnapshot: - self.fixCcaDfAfterEdit('Add new ID with magic-wand') - self.updateAllImages() - else: - self.warnEditingWithCca_df('Add new ID with magic-wand') - - # Label ROI mouse release --> label the ROI with labelRoiWorker - elif self.isMouseDragImg1 and self.labelRoiButton.isChecked(): - self.labelRoiRunning = True - self.app.setOverrideCursor(Qt.WaitCursor) - self.isMouseDragImg1 = False - - if self.labelRoiIsFreeHandRadioButton.isChecked(): - self.freeRoiItem.closeCurve() - - proceed = self.labelRoiCheckStartStopFrame() - if not proceed: - self.labelRoiCancelled() - return - - roiImg, self.labelRoiSlice = self.getLabelRoiImage() - - if roiImg.size == 0: - self.labelRoiCancelled() - return - - if self.labelRoiModel is None: - cancel = self.initLabelRoiModel() - if cancel: - self.labelRoiCancelled() - return - - # Restore state of button because it was maybe unchecked by - # using other tools that are allowed --> see "elif" case in - # labelRoi_cb - self.labelRoiButton.blockSignals(True) - self.labelRoiButton.setChecked(True) - self.labelRoiToolbar.setVisible(True) - self.labelRoiButton.blockSignals(False) - - roiSecondChannel = None - if self.secondChannelName is not None: - secondChannelData = self.getSecondChannelData() - roiSecondChannel = secondChannelData[self.labelRoiSlice] - - isTimelapse = self.labelRoiTrangeCheckbox.isChecked() - if isTimelapse: - start_n = self.labelRoiStartFrameNoSpinbox.value() - stop_n = self.labelRoiStopFrameNoSpinbox.value() - self.progressWin = apps.QDialogWorkerProgress( - title='ROI segmentation', parent=self, - pbarDesc=f'Segmenting frames n. {start_n} to {stop_n}...' - ) - self.progressWin.show(self.app) - self.progressWin.mainPbar.setMaximum(stop_n-start_n) - - - self.app.restoreOverrideCursor() - labelRoiWorker = self.labelRoiActiveWorkers[-1] - labelRoiWorker.start( - roiImg, posData, - roiSecondChannel=roiSecondChannel, - isTimelapse=isTimelapse - ) - self.app.setOverrideCursor(Qt.WaitCursor) - self.logger.info( - f'Magic labeller started on image ROI = {self.labelRoiSlice}...' - ) - self.titleLabel.setText('Magic labeller is doing its magic...') - self.setDisabled(True) - - # Move label mouse released, update move - elif self.isMovingLabel and self.moveLabelToolButton.isChecked(): - self.isMovingLabel = False - - # Update data (rp, etc) - self.update_rp() - - # Repeat tracking - self.tracking(enforce=True, assign_unique_new_IDs=False) - - if not self.moveLabelToolButton.findChild(QAction).isChecked(): - self.moveLabelToolButton.setChecked(False) - else: - self.updateAllImages() - - # Assign mother to bud - elif self.assignBudMothButton.isChecked() and self.clickedOnBud: - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == self.get_2Dlab(posData.lab)[self.yClickBud, self.xClickBud]: - return - - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x - ) - mothID_prompt = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter ID that you want to annotate as mother cell', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - mothID_prompt.exec_() - if mothID_prompt.cancel: - return - else: - ID = mothID_prompt.EntryID - obj_idx = posData.IDs.index(ID) - y, x = posData.rp[obj_idx].centroid - xdata, ydata = int(x), int(y) - - if self.isSnapshot: - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - - relationship = posData.cca_df.at[ID, 'relationship'] - ccs = posData.cca_df.at[ID, 'cell_cycle_stage'] - is_history_known = posData.cca_df.at[ID, 'is_history_known'] - # We allow assiging a cell in G1 as mother only on first frame - # OR if the history is unknown - if relationship == 'bud' and posData.frame_i > 0 and is_history_known: - self.assignBudMothButton.setChecked(False) - txt = html_utils.paragraph( - f'You clicked on ID {ID} which is a BUD.

' - 'To assign a bud start by clicking on the bud ' - 'and release on a cell in G1' - ) - msg = widgets.myMessageBox() - msg.critical( - self, 'Released on a bud', txt - ) - self.assignBudMothButton.setChecked(True) - return - - elif posData.frame_i == 0: - # Check that clicked bud actually is smaller that mother - # otherwise warn the user that he might have clicked first - # on a mother - budID = self.get_2Dlab(posData.lab)[self.yClickBud, self.xClickBud] - new_mothID = self.get_2Dlab(posData.lab)[ydata, xdata] - bud_obj_idx = posData.IDs.index(budID) - new_moth_obj_idx = posData.IDs.index(new_mothID) - rp_budID = posData.rp[bud_obj_idx] - rp_new_mothID = posData.rp[new_moth_obj_idx] - if rp_budID.area >= rp_new_mothID.area: - self.assignBudMothButton.setChecked(False) - msg = widgets.myMessageBox() - txt = ( - f'You clicked FIRST on ID {budID} and then on {new_mothID}.
' - f'For me this means that you want ID {budID} to be the ' - f'BUD of ID {new_mothID}.
' - f'However ID {budID} is bigger than {new_mothID} ' - f'so maybe you should have clicked FIRST on {new_mothID}?

' - 'What do you want me to do?' - ) - txt = html_utils.paragraph(txt) - swapButton, keepButton = msg.warning( - self, 'Which one is bud?', txt, - buttonsTexts=( - f'Assign ID {new_mothID} as the bud of ID {budID}', - f'Keep ID {budID} as the bud of ID {new_mothID}' - ) - ) - if msg.clickedButton == swapButton: - (xdata, ydata, - self.xClickBud, self.yClickBud) = ( - self.xClickBud, self.yClickBud, - xdata, ydata - ) - self.assignBudMothButton.setChecked(True) - - elif is_history_known and not self.clickedOnHistoryKnown: - self.assignBudMothButton.setChecked(False) - budID = self.get_2Dlab(posData.lab)[ydata, xdata] - # Allow assigning an unknown cell ONLY to another unknown cell - txt = ( - f'You started by clicking on ID {budID} which has ' - 'UNKNOWN history, but you then clicked/released on ' - f'ID {ID} which has KNOWN history.\n\n' - 'Only two cells with UNKNOWN history can be assigned as ' - 'relative of each other.' - ) - msg = QMessageBox() - msg.critical( - self, 'Released on a cell with KNOWN history', txt, msg.Ok - ) - self.assignBudMothButton.setChecked(True) - return - - self.clickedOnHistoryKnown = is_history_known - self.xClickMoth, self.yClickMoth = xdata, ydata - - if ccs != 'G1' and posData.frame_i > 0: - self.assignBudMothButton.setChecked(False) - self.onMotherNotInG1(ID) - self.assignBudMothButton.setChecked(True) - else: - self.annotateBudToDifferentMother() - - if not self.assignBudMothButton.findChild(QAction).isChecked(): - self.assignBudMothButton.setChecked(False) - - self.clickedOnBud = False - self.BudMothTempLine.setData([], []) - - # Draw clear region mouse release - elif self.isMouseDragImg1 and self.drawClearRegionButton.isChecked(): - self.isMouseDragImg1 = False - self.freeRoiItem.closeCurve() - self.clearObjsFreehandRegion() - - # Zoom rect mouse release - elif self.isMouseDragImg1 and self.zoomRectButton.isChecked(): - self.isMouseDragImg1 = False - self.zoomRectDone() - - def gui_clickedDelRoi(self, event, left_click, right_click): - posData = self.data[self.pos_i] - x, y = event.pos().x(), event.pos().y() - - # Check if right click on ROI - delROIs = ( - posData.allData_li[posData.frame_i]['delROIs_info']['rois'].copy() - ) - for r, roi in enumerate(delROIs): - ROImask = self.getDelRoiMask(roi) - if self.isSegm3D: - clickedOnROI = ROImask[self.z_lab(), int(y), int(x)] - else: - clickedOnROI = ROImask[int(y), int(x)] - raiseContextMenuRoi = right_click and clickedOnROI - dragRoi = left_click and clickedOnROI - if raiseContextMenuRoi: - self.roi_to_del = roi - self.roiContextMenu = QMenu(self) - separator = QAction(self) - separator.setSeparator(True) - self.roiContextMenu.addAction(separator) - action = QAction('Remove ROI') - action.triggered.connect(self.removeDelROI) - self.roiContextMenu.addAction(action) - try: - # Convert QPointF to QPoint - self.roiContextMenu.exec_(event.screenPos().toPoint()) - except AttributeError: - self.roiContextMenu.exec_(event.screenPos()) - return True - elif dragRoi: - event.ignore() - return True - return False - - def gui_getHoveredSegmentsPolyLineRoi(self): - posData = self.data[self.pos_i] - delROIs_info = posData.allData_li[posData.frame_i]['delROIs_info'] - segments = [] - for roi in delROIs_info['rois']: - if not isinstance(roi, pg.PolyLineROI): - continue - for seg in roi.segments: - if seg.currentPen == seg.hoverPen: - seg.roi = roi - segments.append(seg) - return segments - - def gui_getHoveredHandlesPolyLineRoi(self): - posData = self.data[self.pos_i] - delROIs_info = posData.allData_li[posData.frame_i]['delROIs_info'] - handles = [] - for roi in delROIs_info['rois']: - if not isinstance(roi, pg.PolyLineROI): - continue - for handle in roi.getHandles(): - if handle.currentPen == handle.hoverPen: - handle.roi = roi - handles.append(handle) - return handles - - @exception_handler - def gui_mousePressRightImage(self, event): - modifiers = QGuiApplication.keyboardModifiers() - ctrl = modifiers == Qt.ControlModifier - alt = modifiers == Qt.AltModifier - isMod = alt - right_click = event.button() == Qt.MouseButton.RightButton and not isMod - is_right_click_action_ON = any([ - b.isChecked() for b in self.checkableQButtonsGroup.buttons() - ]) - self.typingEditID = False - showLabelsGradMenu = right_click and not is_right_click_action_ON - if showLabelsGradMenu: - self.gui_rightImageShowContextMenu(event) - event.ignore() - else: - self.gui_mousePressEventImg1(event) - - @exception_handler - def gui_mouseDragRightImage(self, event): - self.gui_mouseDragEventImg1(event) - - @exception_handler - def gui_mouseReleaseRightImage(self, event): - self.gui_mouseReleaseEventImg1(event) - - def drawTempRulerLine(self, event): - modifiers = QGuiApplication.keyboardModifiers() - ctrl = modifiers == Qt.ControlModifier - x, y = event.pos() - x1, y1 = int(x), int(y) - xxRA, yyRA = self.ax1_rulerAnchorsItem.getData() - x0, y0 = xxRA[0], yyRA[0] - if ctrl: - x1, y1 = transformation.snap_xy_to_closest_angle( - x0, y0, x1, y1 - ) - self.ax1_rulerPlotItem.setData([x0, x1], [y0, y1]) - - @exception_handler - def gui_mousePressEventImg1(self, event: QMouseEvent): - self.typingEditID = False - modifiers = QGuiApplication.keyboardModifiers() - ctrl = modifiers == Qt.ControlModifier - alt = modifiers == Qt.AltModifier - isMod = alt - posData = self.data[self.pos_i] - mode = str(self.modeComboBox.currentText()) - isCcaMode = mode == 'Cell cycle analysis' - isCustomAnnotMode = mode == 'Custom annotations' - left_click = event.button() == Qt.MouseButton.LeftButton and not isMod - middle_click = self.isMiddleClick(event, modifiers) - right_click = event.button() == Qt.MouseButton.RightButton - isPanImageClick = self.isPanImageClick(event, modifiers) - brushON = self.brushButton.isChecked() - curvToolON = self.curvToolButton.isChecked() - histON = self.setIsHistoryKnownButton.isChecked() - eraserON = self.eraserButton.isChecked() - rulerON = self.rulerButton.isChecked() - wandON = self.wandToolButton.isChecked() and not isPanImageClick - polyLineRoiON = self.addDelPolyLineRoiButton.isChecked() - labelRoiON = self.labelRoiButton.isChecked() - keepObjON = self.keepIDsButton.isChecked() - whitelistIDsON = self.whitelistIDsButton.isChecked() - separateON = self.separateBudButton.isChecked() - addPointsByClickingButton = self.buttonAddPointsByClickingActive() - manualBackgroundON = self.manualBackgroundButton.isChecked() - magicPromptsON = self.magicPromptsToolButton.isChecked() - pointsLayerON = self.togglePointsLayerAction.isChecked() - copyContourON = ( - self.copyLostObjButton.isChecked() - and self.ax1_lostObjScatterItem.hoverLostID>0 - ) - findNextMotherButtonON = self.findNextMotherButton.isChecked() - unknownLineageButtonON = self.unknownLineageButton.isChecked() - drawClearRegionON = self.drawClearRegionButton.isChecked() - zoomRectON = self.zoomRectButton.isChecked() - - # Check if right-click on segment of polyline roi to add segment - segments = self.gui_getHoveredSegmentsPolyLineRoi() - if len(segments) == 1 and right_click: - seg = segments[0] - seg.roi.segmentClicked(seg, event) - return - - # Check if right-click on handle of polyline roi to remove it - handles = self.gui_getHoveredHandlesPolyLineRoi() - if len(handles) == 1 and right_click: - handle = handles[0] - handle.roi.removeHandle(handle) - return - - # Check if click on ROI - isClickOnDelRoi = self.gui_clickedDelRoi(event, left_click, right_click) - if isClickOnDelRoi: - return - - dragImgLeft = ( - left_click and not brushON and not histON - and not curvToolON and not eraserON and not rulerON - and not wandON and not polyLineRoiON and not labelRoiON - and not middle_click and not keepObjON and not separateON - and not manualBackgroundON and not drawClearRegionON - and addPointsByClickingButton is None and not whitelistIDsON - and not zoomRectON - ) - if isPanImageClick: - dragImgLeft = True - - is_right_click_custom_ON = any([ - b.isChecked() for b in self.customAnnotDict.keys() - ]) - - canAnnotateDivision = ( - not self.assignBudMothButton.isChecked() - and not self.setIsHistoryKnownButton.isChecked() - and not self.curvToolButton.isChecked() - and not is_right_click_custom_ON - and not labelRoiON - and not separateON - ) - - # In timelapse mode division can be annotated if isCcaMode and right-click - # while in snapshot mode with Ctrl+right-click - isAnnotateDivision = ( - (right_click and isCcaMode and canAnnotateDivision) - or (right_click and ctrl and self.isSnapshot) - ) - - isCustomAnnot = ( - (right_click or dragImgLeft) - and (isCustomAnnotMode or self.isSnapshot) - and self.customAnnotButton is not None - ) - - is_right_click_action_ON = any([ - b.isChecked() for b in self.checkableQButtonsGroup.buttons() - ]) - - isOnlyRightClick = ( - right_click and canAnnotateDivision and not isAnnotateDivision - and not isMod and not is_right_click_action_ON - and not is_right_click_custom_ON and not copyContourON - and not findNextMotherButtonON and not unknownLineageButtonON - and not middle_click - ) - - if isOnlyRightClick: - # Start timer or check if it is a double-right-click - if self.countRightClicks == 0: - self.isDoubleRightClick = False - self.countRightClicks = 1 - self.doubleRightClickTimeElapsed = False - screenPos = event.screenPos() - self._img1_click_xy = (screenPos.x(), screenPos.y()) - QTimer.singleShot(400, self.doubleRightClickTimerCallBack) - return - elif ( - self.countRightClicks == 1 - and not self.doubleRightClickTimeElapsed - ): - self.isDoubleRightClick = True - self.countRightClicks = 0 - self.editIDbutton.setChecked(True) - - # Left click actions - canCurv = ( - curvToolON and not self.assignBudMothButton.isChecked() - and not brushON and not dragImgLeft and not eraserON - and not polyLineRoiON and not labelRoiON - and addPointsByClickingButton is None - and not manualBackgroundON and not drawClearRegionON - and not magicPromptsON and not zoomRectON - ) - canBrush = ( - brushON and not curvToolON and not rulerON - and not dragImgLeft and not eraserON and not wandON - and not labelRoiON and not manualBackgroundON - and addPointsByClickingButton is None and not drawClearRegionON - and not magicPromptsON and not zoomRectON - ) - canErase = ( - eraserON and not curvToolON and not rulerON - and not dragImgLeft and not brushON and not wandON - and not polyLineRoiON and not labelRoiON - and addPointsByClickingButton is None - and not manualBackgroundON and not drawClearRegionON - and not magicPromptsON and not zoomRectON - ) - canRuler = ( - rulerON and not curvToolON and not brushON - and not dragImgLeft and not brushON and not wandON - and not polyLineRoiON and not labelRoiON - and addPointsByClickingButton is None - and not manualBackgroundON and not drawClearRegionON - and not magicPromptsON and not zoomRectON - ) - canWand = ( - wandON and not curvToolON and not brushON - and not dragImgLeft and not brushON and not rulerON - and not polyLineRoiON and not labelRoiON - and addPointsByClickingButton is None - and not manualBackgroundON and not drawClearRegionON - and not magicPromptsON and not zoomRectON - ) - canPolyLine = ( - polyLineRoiON and not wandON and not curvToolON and not brushON - and not dragImgLeft and not brushON and not rulerON - and not labelRoiON and not manualBackgroundON - and addPointsByClickingButton is None - and not drawClearRegionON and not magicPromptsON - and not zoomRectON - ) - canLabelRoi = ( - labelRoiON and not wandON and not curvToolON and not brushON - and not dragImgLeft and not brushON and not rulerON - and not polyLineRoiON and not keepObjON - and addPointsByClickingButton is None - and not manualBackgroundON and not drawClearRegionON - and not whitelistIDsON and not magicPromptsON - and not zoomRectON - ) - canKeep = ( - keepObjON and not wandON and not curvToolON and not brushON - and not dragImgLeft and not brushON and not rulerON - and not polyLineRoiON and not labelRoiON - and addPointsByClickingButton is None - and not manualBackgroundON and not drawClearRegionON - and not whitelistIDsON and not magicPromptsON - and not zoomRectON - ) - canWhitelistIDs = ( - whitelistIDsON and not wandON and not curvToolON and not brushON - and not dragImgLeft and not brushON and not rulerON - and not polyLineRoiON and not labelRoiON - and addPointsByClickingButton is None - and not manualBackgroundON and not drawClearRegionON - and not keepObjON and not magicPromptsON - and not zoomRectON - ) - canAddPoint = ( - (pointsLayerON or magicPromptsON) - and addPointsByClickingButton is not None and not wandON - and not curvToolON and not brushON - and not dragImgLeft and not brushON and not rulerON - and not polyLineRoiON and not labelRoiON and not keepObjON - and not manualBackgroundON and not drawClearRegionON - and not zoomRectON - ) - canAddManualBackgroundObj = ( - manualBackgroundON and not wandON and not curvToolON and not brushON - and not dragImgLeft and not brushON and not rulerON - and not polyLineRoiON and not labelRoiON - and addPointsByClickingButton is None - and not keepObjON and not drawClearRegionON - and not magicPromptsON and not whitelistIDsON - and not zoomRectON - ) - canDrawClearRegion = ( - drawClearRegionON and not wandON and not curvToolON and not brushON - and not dragImgLeft and not brushON and not rulerON - and not labelRoiON and not manualBackgroundON - and addPointsByClickingButton is None - and not polyLineRoiON and not magicPromptsON - and not whitelistIDsON and not zoomRectON - ) - canZoomRect = ( - zoomRectON and not curvToolON and not brushON - and not dragImgLeft and not brushON and not rulerON - and not polyLineRoiON and not labelRoiON - and addPointsByClickingButton is None - and not manualBackgroundON and not drawClearRegionON - and not wandON and not whitelistIDsON and not magicPromptsON - ) - - # Enable dragging of the image window or the scalebar - if dragImgLeft and not isCustomAnnot: - x, y = event.pos().x(), event.pos().y() - if hasattr(self, 'scaleBar'): - if self.scaleBar.isHighlighted(): - self.scaleBar.mousePressed(x, y) - return - if hasattr(self, 'timestamp'): - if self.timestamp.isHighlighted(): - self.timestamp.mousePressed(x, y) - return - pg.ImageItem.mousePressEvent(self.img1, event) - event.ignore() - return - - isAllowedActionViewer = (canAddPoint or canRuler) - - if mode == 'Viewer' and not isAllowedActionViewer: - self.startBlinkingModeCB() - event.ignore() - return - - # Allow right-click or middle-click actions on both images - eventOnImg2 = ( - ( - right_click or (middle_click and not canAddPoint) - # or (left_click and separateON) - ) - and (mode=='Segmentation and Tracking' or self.isSnapshot) - and not isAnnotateDivision and not manualBackgroundON - ) - if eventOnImg2: - event.isImg1Sender = True - self.gui_mousePressEventImg2(event) - - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - Y, X = self.get_2Dlab(posData.lab).shape - if xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y: - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - else: - return - - # Paint new IDs with brush and left click on the left image - if left_click and canBrush: - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - lab_2D = self.get_2Dlab(posData.lab) - Y, X = lab_2D.shape - - # Store undo state before modifying stuff - self.storeUndoRedoStates(False, storeOnlyZoom=True) - - ID = self.getHoverID(xdata, ydata) - - if ID > 0: - posData.brushID = ID - self.isNewID = False - else: - # Update brush ID. Take care of disappearing cells to remember - # to not use their IDs anymore in the future - self.isNewID = True - self.setBrushID() - self.updateLookuptable(lenNewLut=posData.brushID+1) - - self.brushColor = self.lut[posData.brushID]/255 - - self.yPressAx2, self.xPressAx2 = y, x - - ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata) - diskSlice = (slice(ymin, ymax), slice(xmin, xmax)) - - self.isMouseDragImg1 = True - - # Draw new objects - localLab = lab_2D[diskSlice] - mask = diskMask.copy() - if not self.isPowerBrush() and not ctrl: - mask[localLab!=0] = False - - self.applyBrushMask(mask, posData.brushID, toLocalSlice=diskSlice) - - self.setImageImg2(updateLookuptable=False) - - how = self.drawIDsContComboBox.currentText() - lab2D = self.get_2Dlab(posData.lab) - self.globalBrushMask = np.zeros(lab2D.shape, dtype=bool) - brushMask = localLab == posData.brushID - brushMask = np.logical_and(brushMask, diskMask) - self.setTempImg1Brush( - True, brushMask, posData.brushID, toLocalSlice=diskSlice - ) - - self.lastHoverID = -1 - - elif left_click and canErase: - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - lab_2D = self.get_2Dlab(posData.lab) - Y, X = lab_2D.shape - - # Store undo state before modifying stuff - self.storeUndoRedoStates(False, storeOnlyZoom=True) - - self.yPressAx2, self.xPressAx2 = y, x - # Keep a list of erased IDs got erased - self.erasedIDs = set() - - if self.xyOnCtrlPressedFirstTime is not None: - self.erasedID = self.getHoverID(*self.xyOnCtrlPressedFirstTime) - else: - self.erasedID = self.getHoverID(xdata, ydata) - - ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata) - - # Build eraser mask - mask = np.zeros(lab_2D.shape, bool) - mask[ymin:ymax, xmin:xmax][diskMask] = True - - - # If user double-pressed 'b' then erase over ALL labels - color = self.eraserButton.palette().button().color().name() - eraseOnlyOneID = ( - color != self.doublePressKeyButtonColor - and self.erasedID != 0 - ) - - self.eraseOnlyOneID = eraseOnlyOneID - - if eraseOnlyOneID: - mask[lab_2D!=self.erasedID] = False - - self.setTempImg1Eraser(mask, init=True) - self.applyEraserMask(mask) - - self.erasedIDs.update(lab_2D[mask]) - - for erasedID in self.erasedIDs: - if erasedID == 0: - continue - self.erasedLab[lab_2D==erasedID] = erasedID - - self.isMouseDragImg1 = True - - elif canAddPoint: - action = addPointsByClickingButton.action - self.storeUndoAddPoint(action) - x, y = event.pos().x(), event.pos().y() - hoveredPoints = action.scatterItem.pointsAt(event.pos()) - if len(hoveredPoints) > 0: - removed_ids = self.removeClickedPoints(action, hoveredPoints) - if not magicPromptsON: - removed_id = min(removed_ids) - addPointsByClickingButton.pointIdSpinbox.setValue(removed_id) - addPointsByClickingButton.pointIdSpinbox.removedId = ( - removed_id - ) - else: - self.restorePrevPointIdRightClick(addPointsByClickingButton) - self.drawPointsLayers(computePointsLayers=False) - else: - point_id = self.getAddedPointId( - magicPromptsON, addPointsByClickingButton, - right_click, left_click, middle_click - ) - if point_id is None: - return - - self.addClickedPoint(action, x, y, point_id) - self.drawPointsLayers(computePointsLayers=False) - - point_id = self.getClickedPointNewId( - action, point_id, - addPointsByClickingButton.pointIdSpinbox, - isMagicPrompts=magicPromptsON - ) - addPointsByClickingButton.pointIdSpinbox.setValue( - point_id, setLinkedWidget=False - ) - - elif left_click and canDrawClearRegion: - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - self.freeRoiItem.addPoint(xdata, ydata) - - self.isMouseDragImg1 = True - - elif left_click and canRuler or canPolyLine: - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - closePolyLine = ( - len(self.startPointPolyLineItem.pointsAt(event.pos())) > 0 - ) - if not self.tempSegmentON or canPolyLine: - # Keep adding anchor points for polyline - self.ax1_rulerAnchorsItem.setData([xdata], [ydata]) - self.tempSegmentON = True - else: - modifiers = QGuiApplication.keyboardModifiers() - ctrl = modifiers == Qt.ControlModifier - self.tempSegmentON = False - xxRA, yyRA = self.ax1_rulerAnchorsItem.getData() - x0, y0 = xxRA[0], yyRA[0] - if ctrl: - x1, y1 = transformation.snap_xy_to_closest_angle( - x0, y0, xdata, ydata - ) - else: - x1, y1 = xdata, ydata - lengthText = self.getRulerLengthText() - self.ax1_rulerPlotItem.setData( - [x0, x1], [y0, y1], lengthText=lengthText - ) - self.ax1_rulerAnchorsItem.setData([x0, x1], [y0, y1]) - - xxPolyLine = self.startPointPolyLineItem.getData()[0] - if canPolyLine and len(xxPolyLine) == 0: - # Create and add roi item - self.createDelPolyLineRoi() - # Add start point of polyline roi - self.startPointPolyLineItem.setData([xdata], [ydata]) - self.polyLineRoi.points.append((xdata, ydata)) - elif canPolyLine: - # Add points to polyline roi and eventually close it - if not closePolyLine: - self.polyLineRoi.points.append((xdata, ydata)) - self.addPointsPolyLineRoi(closed=closePolyLine) - if closePolyLine: - # Close polyline ROI - if len(self.polyLineRoi.getLocalHandlePositions()) == 2: - self.polyLineRoi = self.replacePolyLineRoiWithLineRoi( - self.polyLineRoi - ) - self.tempSegmentON = False - self.ax1_rulerAnchorsItem.setData([], []) - self.ax1_rulerPlotItem.setData([], []) - self.startPointPolyLineItem.setData([], []) - self.addRoiToDelRoiInfo(self.polyLineRoi) - # Call roi moving on closing ROI - self.delROImoving(self.polyLineRoi) - self.delROImovingFinished(self.polyLineRoi) - - elif left_click and canKeep: - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x - ) - keepID_win = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter ID that you want to keep', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - keepID_win.exec_() - if keepID_win.cancel: - return - else: - ID = keepID_win.EntryID - - if ID in self.keptObjectsIDs: - self.keptObjectsIDs.remove(ID) - self.clearHighlightedText() - else: - self.keptObjectsIDs.append(ID) - self.highlightLabelID(ID) - - self.updateTempLayerKeepIDs() - - elif left_click and canWhitelistIDs: - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x - ) - keepID_win = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter ID that you want to select', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - keepID_win.exec_() - if keepID_win.cancel: - return - else: - ID = keepID_win.EntryID - - posData = self.data[self.pos_i] - - if not posData.whitelist: - wl_init = False - if not hasattr(self, 'tempWhitelistIDs'): - self.tempWhitelistIDs = set() # not updated, only use in this context - current_whitelist = self.tempWhitelistIDs - else: - current_whitelist = self.tempWhitelistIDs - else: - wl_init = True - current_whitelist = posData.whitelist.get(posData.frame_i) - - if ID in current_whitelist: - current_whitelist.remove(ID) - self.removeHighlightLabelID(IDs=[ID]) - else: - current_whitelist.add(ID) - self.highlightLabelID(ID) - - self.whitelistIDsToolbar.whitelistLineEdit.setText( - current_whitelist - ) - - if wl_init: - posData.whitelist[posData.frame_i] = current_whitelist - else: - self.tempWhitelistIDs = current_whitelist - - self.whitelistUpdateTempLayer() - - elif right_click and copyContourON: - hoverLostID = self.ax1_lostObjScatterItem.hoverLostID - self.copyLostObjectMask(hoverLostID) - self.update_rp() - self.updateAllImages() - self.store_data() - - elif right_click and canCurv: - # Draw manually assisted auto contour - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - Y, X = self.get_2Dlab(posData.lab).shape - - self.autoCont_x0 = xdata - self.autoCont_y0 = ydata - self.xxA_autoCont, self.yyA_autoCont = [], [] - self.curvAnchors.addPoints([x], [y]) - img = self.getDisplayedImg1() - self.autoContObjMask = np.zeros(img.shape, np.uint8) - self.isRightClickDragImg1 = True - - elif left_click and canCurv: - # Draw manual spline - x, y = event.pos().x(), event.pos().y() - Y, X = self.get_2Dlab(posData.lab).shape - - # Check if user clicked on starting anchor again --> close spline - closeSpline = False - clickedAnchors = self.curvAnchors.pointsAt(event.pos()) - xxA, yyA = self.curvAnchors.getData() - if len(xxA)>0: - if len(xxA) == 1: - self.splineHoverON = True - x0, y0 = xxA[0], yyA[0] - if len(clickedAnchors)>0: - xA_clicked, yA_clicked = clickedAnchors[0].pos() - if x0==xA_clicked and y0==yA_clicked: - x = x0 - y = y0 - closeSpline = True - - # Add anchors - self.curvAnchors.addPoints([x], [y]) - try: - xx, yy = self.curvHoverPlotItem.getData() - self.curvPlotItem.setData(xx, yy) - except Exception as e: - # traceback.print_exc() - pass - - if closeSpline: - self.splineHoverON = False - self.curvToolSplineToObj() - self.update_rp() - if self.autoIDcheckbox.isChecked(): - self.trackManuallyAddedObject(posData.brushID, True) - if self.isSnapshot: - self.fixCcaDfAfterEdit('Add new ID with curvature tool') - self.updateAllImages() - else: - self.warnEditingWithCca_df('Add new ID with curvature tool') - self.clearCurvItems() - self.curvTool_cb(True) - - elif left_click and canWand: - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - Y, X = self.get_2Dlab(posData.lab).shape - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - - self.isNewID = False - posData.brushID = self.get_2Dlab(posData.lab)[ydata, xdata] - if posData.brushID == 0: - self.setBrushID() - self.updateLookuptable( - lenNewLut=posData.brushID+1 - ) - self.isNewID = True - self.brushColor = self.img2.lut[posData.brushID]/255 - - # NOTE: flood is on mousedrag or release - tol = self.getMagicWandFloodTolerance() - self.initFloodMaskImage() - if self.isSegm3D: - z_slice = self.zSliceScrollBar.sliderPosition() - seed = (z_slice, ydata, xdata) - else: - seed = (ydata, xdata) - - flood_mask = skimage.segmentation.flood( - self.flood_img, seed, tolerance=tol - ) - - drawUnderMask = np.logical_or( - posData.lab==0, posData.lab==posData.brushID - ) - self.flood_mask = np.logical_and(flood_mask, drawUnderMask) - - if self.wandControlsToolbar.autoFillHolesCheckbox.isChecked(): - self.flood_mask = core.binary_fill_holes(self.flood_mask) - - if self.wandControlsToolbar.useConvexHullCheckbox.isChecked(): - self.flood_mask = core.convex_hull_mask(self.flood_mask) - - self.setTempBrushMaskFromWand(self.flood_mask, init=True) - self.isMouseDragImg1 = True - - elif right_click and self.manualTrackingButton.isChecked(): - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - manualTrackID = self.manualTrackingToolbar.spinboxID.value() - clickedID = self.getClickedID( - xdata, ydata, text=f'that you want to assign to {manualTrackID}' - ) - if clickedID is None: - return - - if clickedID == manualTrackID: - self.manualTrackingToolbar.showWarning( - f'The clicked object already has ID = {manualTrackID}' - ) - return - - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - - posData = self.data[self.pos_i] - currentIDs = posData.IDs.copy() - if manualTrackID in currentIDs: - tempID = max(currentIDs) + 1 - posData.lab[posData.lab == clickedID] = tempID - posData.lab[posData.lab == manualTrackID] = clickedID - posData.lab[posData.lab == tempID] = manualTrackID - self.manualTrackingToolbar.showWarning( - f'The ID {manualTrackID} already exists --> ' - f'ID {manualTrackID} has been swapped with {clickedID}' - ) - else: - posData.lab[posData.lab == clickedID] = manualTrackID - self.manualTrackingToolbar.showInfo( - f'ID {clickedID} changed to {manualTrackID}.' - ) - - self.update_rp() - self.updateAllImages() - - elif right_click and manualBackgroundON: - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - - delID = posData.manualBackgroundLab[ydata, xdata] - if delID == 0: - return - - self.clearManualBackgroundObject(delID) - textItem = self.manualBackgroundTextItems.pop(delID) - self.ax1.removeItem(textItem) - self.setManualBackgroundImage() - - elif left_click and canAddManualBackgroundObj: - x, y = event.pos().x(), event.pos().y() - - self.addManualBackgroundObject(x, y) - self.setManualBackgroundImage() - self.setManualBackgrounNextID() - - # Label ROI mouse press - elif (left_click or right_click) and canLabelRoi: - if right_click: - # Force model initialization on mouse release - self.labelRoiModel = None - - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - - if self.labelRoiIsRectRadioButton.isChecked(): - self.labelRoiItem.setPos((xdata, ydata)) - elif self.labelRoiIsFreeHandRadioButton.isChecked(): - self.freeRoiItem.addPoint(xdata, ydata) - - self.isMouseDragImg1 = True - - # Annotate cell cycle division - elif isAnnotateDivision: - if posData.cca_df is None: - return - - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x - ) - divID_prompt = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter ID that you want to annotate as divided', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - divID_prompt.exec_() - if divID_prompt.cancel: - return - else: - ID = divID_prompt.EntryID - obj_idx = posData.IDs.index(ID) - y, x = posData.rp[obj_idx].centroid - xdata, ydata = int(x), int(y) - - if not self.isSnapshot: - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - # Annotate or undo division - self.manualCellCycleAnnotation(ID) - else: - self.undoBudMothAssignment(ID) - - # Assign bud to mother (mouse down on bud) - elif right_click and self.assignBudMothButton.isChecked(): - if self.clickedOnBud: - # NOTE: self.clickedOnBud is set to False when assigning a mother - # is successfull in mouse release event - # We still have to click on a mother - return - - if posData.cca_df is None: - return - - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x - ) - budID_prompt = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter ID of a bud you want to correct mother assignment', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - budID_prompt.exec_() - if budID_prompt.cancel: - return - else: - ID = budID_prompt.EntryID - - obj_idx = posData.IDs.index(ID) - y, x = posData.rp[obj_idx].centroid - xdata, ydata = int(x), int(y) - - relationship = posData.cca_df.at[ID, 'relationship'] - is_history_known = posData.cca_df.at[ID, 'is_history_known'] - self.clickedOnHistoryKnown = is_history_known - # We allow assiging a cell in G1 as bud only on first frame - # OR if the history is unknown - if relationship != 'bud' and posData.frame_i > 0 and is_history_known: - txt = (f'You clicked on ID {ID} which is NOT a bud.\n' - 'To assign a bud to a cell start by clicking on a bud ' - 'and release on a cell in G1') - msg = QMessageBox() - msg.critical( - self, 'Not a bud', txt, msg.Ok - ) - return - - self.clickedOnBud = True - self.xClickBud, self.yClickBud = xdata, ydata - - # Annotate (or undo) that cell has unknown history - elif right_click and self.setIsHistoryKnownButton.isChecked(): - if posData.cca_df is None: - return - - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x - ) - unknownID_prompt = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter ID that you want to annotate as ' - '"history UNKNOWN/KNOWN"', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - unknownID_prompt.exec_() - if unknownID_prompt.cancel: - return - else: - ID = unknownID_prompt.EntryID - obj_idx = posData.IDs.index(ID) - y, x = posData.rp[obj_idx].centroid - xdata, ydata = int(x), int(y) - - self.annotateIsHistoryKnown(ID) - if not self.setIsHistoryKnownButton.findChild(QAction).isChecked(): - self.setIsHistoryKnownButton.setChecked(False) - - elif isCustomAnnot: - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x - ) - clickedBkgrDialog = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter ID that you want to annotate as divided', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - clickedBkgrDialog.exec_() - if clickedBkgrDialog.cancel: - return - else: - ID = clickedBkgrDialog.EntryID - obj_idx = posData.IDs.index(ID) - y, x = posData.rp[obj_idx].centroid - xdata, ydata = int(x), int(y) - - button = self.doCustomAnnotation(ID) - if button is None: - return - - keepActive = self.customAnnotDict[button]['state']['keepActive'] - if not keepActive: - button.setChecked(False) - - elif right_click and findNextMotherButtonON: - if posData.frame_i == 0: - return - - self.find_mother_action(posData, event, ydata, xdata) - - elif right_click and unknownLineageButtonON: - if posData.frame_i == 0: - return - - self.annotate_unknown_lineage_action(posData, event, ydata, xdata) - - elif (left_click or right_click) and canZoomRect: - if left_click: - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - - self.zoomRectItem.setPos((xdata, ydata)) - - self.isMouseDragImg1 = True - else: - try: - xRange, yRange = self.zoomRectItem.getLastRange() - self.ax1.setRange( - xRange=xRange, - yRange=yRange, - padding=0 - ) - except Exception as err: - QTimer.singleShot(100, self.autoRange) - - def repeat_click_and_backup(self, posData, event, ydata, xdata): - """ - This function is part of the lin_tree edit functionality. - It handles the back up of the original self.lineage_tree.lineage_list - df and the repeated clicking on the same ID to cycle through pssible mothers. - - Parameters - ---------- - posData : cellacdc.load.loadData - The position data. - event : QtGui.QMouseEvent - The event object. - ydata : int - The y-coordinate data. - xdata : int - The x-coordinate data. - - Returns - ------- - tuple - A tuple containing the point(tuple: (x, y) coords) and ID of clicked cell. - """ - if self.original_df_lin_tree is None: - self.original_df_lin_tree = posData.allData_li[posData.frame_i]['acdc_df'].copy() - self.original_df_lin_tree_i = posData.frame_i - elif self.original_df_lin_tree_i != posData.frame_i: - self.logger.info( - '[WARNING]: !!! Original lineage tree df changed, resetting original_df_lin_tree !!!' - ) - self.original_df_lin_tree = posData.allData_li[posData.frame_i]['acdc_df'].copy() - self.original_df_lin_tree_i = posData.frame_i - - if not self.right_click_ID: - self.right_click_i = 0 - self.right_click_ID = 0 - - x, y = event.pos().x(), event.pos().y() - point = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - - if ID == 0: - return None, None - - if self.right_click_ID != ID: - self.right_click_i = 0 - self.right_click_ID = ID - self.original_mother_skipped = False - elif event.modifiers() & Qt.ShiftModifier: - self.right_click_i -= 1 - else: - self.right_click_i += 1 - - return point, ID - - def getDistanceListMissingIDs(self, point, ID): - posData = self.data[self.pos_i] - frame_i = posData.frame_i - if self.getDistanceListMissingIDsCachedFrame != frame_i: - self.distanceListMissingIDs = dict() - self.getDistanceListMissingIDsCachedFrame = frame_i - # self.store_data(autosave=False) - # self.get_data() - - if ID not in self.distanceListMissingIDs.keys(): - prev_rp = posData.allData_li[frame_i-1]['regionprops'] - relevant_rp = [ - obj for obj in prev_rp if obj.label not in posData.IDs - ] - len_relevant_rp = len(relevant_rp) - if len_relevant_rp == 0: - self.logger.info('No missing IDs found in previous frame.') - return [] - elif len_relevant_rp == 1: - self.distanceListMissingIDs[ID] = [relevant_rp[0].label] - return [relevant_rp[0].label] - else: - sorted_missing_IDs = myutils.sort_IDs_dist(relevant_rp, point=point) - self.distanceListMissingIDs[ID] = sorted_missing_IDs - return sorted_missing_IDs - else: - return self.distanceListMissingIDs[ID] - - def find_mother_action(self, posData, event, ydata, xdata): - """ - This function is part of the lin_tree edit functionality. - Associated with the right-click action of the 'findNextMotherButton' button. - Handles the right click action, which cycles through possible mothers of the clicked cell. - Changes the parent ID of the clicked cell to the next possible mother in self.lineage_tree.lineage_list. - - Parameters - ---------- - posData : cellacdc.load.loadData - The position data object. - event : QtGui.QMouseEvent - The event object. - ydata : int - The y-coordinate data. - xdata : int - The x-coordinate data. - """ - point, ID = self.repeat_click_and_backup(posData, event, ydata, xdata) - - if point is None: - return - posData = self.data[self.pos_i] - acdc_df_frame = posData.allData_li[posData.frame_i]['acdc_df'] - filtered_IDs = self.getDistanceListMissingIDs(point, ID) - if len(filtered_IDs) == 0: - self.logger.info('No mother candidates found.') - return - - i = self.right_click_i % len(filtered_IDs) - i = abs(i) # Ensure i is non-negative - new_mother = filtered_IDs[i] - - if acdc_df_frame.loc[ID]['parent_ID_tree'] == new_mother and self.original_mother_skipped == False: # if a mother is already present, skip it - self.right_click_i += 1 - self.original_mother_skipped = True - - i = self.right_click_i % len(filtered_IDs) - i = abs(i) # Ensure i is non-negative - new_mother = filtered_IDs[i] - - acdc_df_frame.at[ID, 'parent_ID_tree'] = new_mother # update mother in the df, no need to propagate or stuff lile this - # dont need to update alldata_li as acdc_df_frame is just a view - self.drawAllLineageTreeLines() - - def annotate_unknown_lineage_action(self, posData, event, ydata, xdata): - """ - This function is part of the lin_tree edit functionality. - Associated with the right-click action of the 'unknownLineageButton' button. - Annotates an unknown lineage by setting its parent ID to -1 in the lineage tree (self.lineage_tree.lineage_list) - - Parameters - ---------- - posData : cellacdc.load.loadData - The position data. - event : QtGui.QMouseEvent - The event that triggered the annotation. - ydata : int - The y-coordinate data. - xdata : int - The x-coordinate data. - """ - point, ID = self.repeat_click_and_backup(posData, event, ydata, xdata) - - if point is None: - return - posData = self.data[self.pos_i] - acdc_df_frame = posData.allData_li[posData.frame_i]['acdc_df'] - acdc_df_frame.at[ID, 'parent_ID_tree'] = -1 - self.drawAllLineageTreeLines() - - def gui_addCreatedAxesItems(self): - self.ax1.addItem(self.ax1_contoursImageItem) - self.ax1.addItem(self.ax1_lostObjImageItem) - self.ax1.addItem(self.ax1_lostTrackedObjImageItem) - self.ax1.addItem(self.ax1_oldMothBudLinesItem) - self.ax1.addItem(self.ax1_newMothBudLinesItem) - self.ax1.addItem(self.ax1_lostObjScatterItem) - self.ax1.addItem(self.ax1_lostTrackedScatterItem) - self.ax1.addItem(self.ccaFailedScatterItem) - self.ax1.addItem(self.yellowContourScatterItem) - - self.ax2.addItem(self.ax2_contoursImageItem) - self.ax2.addItem(self.ax2_lostObjImageItem) - self.ax2.addItem(self.ax2_lostTrackedObjImageItem) - self.ax2.addItem(self.ax2_oldMothBudLinesItem) - self.ax2.addItem(self.ax2_newMothBudLinesItem) - self.ax2.addItem(self.ax2_lostObjScatterItem) - - self.textAnnot[0].addToPlotItem(self.ax1) - self.textAnnot[1].addToPlotItem(self.ax2) - - self.ax1.addItem(self.exportMaskImageItem) - self.ax1.exportMaskImageItem = self.exportMaskImageItem - - def SegForLostIDsSetSettings(self): - - try: - prev_model = str(self.df_settings.at['SegForLostIDsModel', 'value']) - except KeyError: - prev_model = None - win = apps.QDialogSelectModel(parent=self, customFirst=prev_model) - win.exec_() - if win.cancel: - self.logger.info('Seg for lost IDs cancelled.') - return - base_model_name = win.selectedModel - - if base_model_name: - self.df_settings.at['SegForLostIDsModel', 'value'] = base_model_name - self.df_settings.to_csv(self.settings_csv_path) - - model_name = 'local_seg' - - idx = self.modelNames.index(model_name) - acdcSegment = self.acdcSegment_li[idx] - - try: - if acdcSegment is None or base_model_name != self.local_seg_base_model_name: - self.logger.info(f'Importing {base_model_name}...') - acdcSegment = myutils.import_segment_module(base_model_name) - self.acdcSegment_li[idx] = acdcSegment - self.local_seg_base_model_name = base_model_name - except (IndexError, ImportError, KeyError) as e: - self.logger.error(f'Error importing {base_model_name}: {e}') - return - - extra_params = ['overlap_threshold', - 'padding', - 'size_perc_diff', - 'distance_filler_growth', - 'max_iterations', - 'allow_only_tracked_cells'] - - extra_types = [float, float, float, float, int, bool] - - extra_defaults = [0.5, 0.8, 0.3, 1., 2, False] - - extra_desc = ['Overlap threshold with other already segemented cells over which newly segmented cells are discarded', - 'Padding of the box used for new segmentation around the segmentation from the previous frame', - 'Relative size difference acceptable compared to previous frames', - """Cells which are already segmented are filled with random noise sampled from background - to ensure that they don't get segmented again. - This parameter controls the additional padding around the already segmented cells.""", - """The algorithm will try and segment the maximum amount - of cells in the image by running the model several - times and filling new found cells with background noise. - How many of these iterations should be run?""", - "If no new cell IDs should be permitted (based on real time tracking)"] - - extra_ArgSpec = [] - for i, param in enumerate(extra_params): - param = ArgSpec(name=param, - default=extra_defaults[i], - type=extra_types[i], - desc=extra_desc[i], - docstring='') - - extra_ArgSpec.append(param) - - init_params, segment_params = myutils.getModelArgSpec(acdcSegment) - segment_params = [arg for arg in segment_params if arg[0] != 'diameter'] - - extraParamsTitle = 'Settings for local segmentation' - win = self.initSegmModelParams( - base_model_name, acdcSegment, init_params, segment_params, - extraParams=extra_ArgSpec, extraParamsTitle=extraParamsTitle, - initLastParams=True, ini_filename='segmentation_for_lostIDs.ini', - ) - - if win is None: - self.logger.info('Segmentation for lost IDs cancelled.') - return - - init_kwargs_new = {} - args_new = {} - for key, val in win.init_kwargs.items(): - if key in extra_params: - args_new[key] = val - else: - init_kwargs_new[key] = val - - for key, val in win.extra_kwargs.items(): - if key in extra_params: - args_new[key] = val - - self.SegForLostIDsSettings = { - 'win': win, - 'init_kwargs_new': init_kwargs_new, - 'args_new': args_new, - 'base_model_name': base_model_name, - } - - def segForLostIDsButtonClicked(self): - - self.setFrameNavigationDisabled(disable=True, why='Segmentation for lost IDs') - posData = self.data[self.pos_i] - if posData.frame_i == 0: - self.logger.info('Segmentation for lost IDs not available on first frame.') - self.setFrameNavigationDisabled(disable=False, why='Segmentation for lost IDs') - return - self.storeUndoRedoStates(False) - self.progressWin = apps.QDialogWorkerProgress( - title='Segmenting for lost IDs', parent=self, - pbarDesc=f'Segmenting for lost IDs...' - ) - self.progressWin.show(self.app) - self.progressWin.mainPbar.setMaximum(0) - - self.startSegForLostIDsWorker() - - def onSegForLostInit(self): - self.logger.info('Settings for segmentation for lost IDs not set.') - self.SegForLostIDsSetSettings() - self.SegForLostIDsWaitCond.wakeAll() - - def SegForLostIDsWorkerAskInstallModel(self, model_name): - myutils.check_install_package(model_name) - self.SegForLostIDsWaitCond.wakeAll() - - def startSegForLostIDsWorker(self): - self.SegForLostIDsMutex = QMutex() - self.SegForLostIDsWaitCond = QWaitCondition() - self._thread = QThread() - - # Initialize the worker with mutex and wait condition - self.SegForLostIDsWorker = workers.SegForLostIDsWorker( - self, self.SegForLostIDsMutex, self.SegForLostIDsWaitCond - ) - - # Connect the worker's signal to the main thread's slot - self.SegForLostIDsWorker.sigAskInit.connect(self.onSegForLostInit) - self.SegForLostIDsWorker.sigAskInstallModel.connect( - self.SegForLostIDsWorkerAskInstallModel - ) - self.SegForLostIDsWorker.sigshowImageDebug.connect( - self.showImageDebug - ) - - self.SegForLostIDsWorker.sigSegForLostIDsWorkerAskInstallGPU.connect( - self.SegForLostIDsWorkerAskInstallGPU - ) - - self.SegForLostIDsWorker.sigStoreData.connect(self.onSigStoreDataSegForLostIDsWorker) - self.SegForLostIDsWorker.sigUpdateRP.connect(self.onSigUpdateRPSegForLostIDsWorker) - # self.SegForLostIDsWorker.sigGetData.connect(self.onSigGetDataSegForLostIDsWorker) - # self.SegForLostIDsWorker.sigGet2Dlab.connect(self.onSigGet2DlabSegForLostIDsWorker) - # self.SegForLostIDsWorker.sigGetTrackedLostIDs.connect(self.onSigGetTrackedSegForLostIDsWorker) - # self.SegForLostIDsWorker.sigGetBrushID.connect(self.onSigGetBrushIDSegForLostIDsWorker) - self.SegForLostIDsWorker.sigTrackManuallyAddedObject.connect(self.onSigTrackManuallyAddedObjectSegForLostIDsWorker) - - # Move the worker to the thread - self.SegForLostIDsWorker.moveToThread(self._thread) - - # Manage thread lifecycle - self.SegForLostIDsWorker.signals.finished.connect(self._thread.quit) - self.SegForLostIDsWorker.signals.finished.connect(self.SegForLostIDsWorker.deleteLater) - self._thread.finished.connect(self._thread.deleteLater) - - # Connect other worker signals to the appropriate slots - self.SegForLostIDsWorker.signals.finished.connect(self.SegForLostIDsWorkerFinished) - self.SegForLostIDsWorker.signals.progress.connect(self.workerProgress) - self.SegForLostIDsWorker.signals.initProgressBar.connect(self.workerInitProgressbar) - self.SegForLostIDsWorker.signals.progressBar.connect(self.workerUpdateProgressbar) - self.SegForLostIDsWorker.signals.critical.connect(self.workerCritical) - - # Start the thread and worker - self._thread.started.connect(self.SegForLostIDsWorker.run) - self._thread.start() - - def SegForLostIDsWorkerAskInstallGPU(self, model_name, use_gpu): - result = myutils.check_gpu_available(model_name, use_gpu, qparent=self) - self.SegForLostIDsWorker.gpu_go = result - dont_force_cpu = myutils.check_gpu_available( - model_name, use_gpu, do_not_warn=True) - self.SegForLostIDsWorker.dont_force_cpu = dont_force_cpu - self.SegForLostIDsWaitCond.wakeAll() - - def onSigStoreDataSegForLostIDsWorker(self, autosave): - self.onSigStoreData( - self.SegForLostIDsWaitCond, autosave=autosave) - - def onSigUpdateRPSegForLostIDsWorker(self, wl_update, wl_track_og_curr): - self.onSigUpdateRP(self.SegForLostIDsWaitCond, - wl_update=wl_update, - wl_track_og_curr=wl_track_og_curr) - - # def onSigGetDataSegForLostIDsWorker(self): - # self.onSigGetData( - # self.SegForLostIDsWaitCond) - - # def onSigGet2DlabSegForLostIDsWorker(self): - # posData = self.data[self.pos_i] - # lab = self.get_2Dlab(posData.lab) - # self.SegForLostIDsWorker.lab = lab - # self.SegForLostIDsWaitCond.wakeAll() - - # def onSigGetTrackedSegForLostIDsWorker(self): - # self.SegForLostIDsWorker.trackedLostIDs = self.getTrackedLostIDs() - # self.SegForLostIDsWaitCond.wakeAll() - - # def onSigGetBrushIDSegForLostIDsWorker(self): - # self.SegForLostIDsWorker.brushID = self.setBrushID(useCurrentLab=True, return_val=True) - # self.SegForLostIDsWaitCond.wakeAll() - - def onSigTrackManuallyAddedObjectSegForLostIDsWorker(self, added_IDs, isNewID, wl_update, wl_track_og_curr): - self.trackManuallyAddedObject(added_IDs, isNewID, wl_update=wl_update, wl_track_og_curr=wl_track_og_curr) - self.SegForLostIDsWaitCond.wakeAll() - - - def onSigStoreData( - self, waitcond, pos_i=None, enforce=True, debug=False, - mainThread=True, autosave=True, store_cca_df_copy=False - ): - self.store_data(pos_i=pos_i, enforce=enforce, debug=debug, mainThread=mainThread, - autosave=autosave, store_cca_df_copy=store_cca_df_copy) - waitcond.wakeAll() - - def onSigUpdateRP(self, waitcond, draw=True, debug=False, update_IDs=True, - wl_update=True, wl_track_og_curr=False): - self.update_rp(draw=draw, debug=debug, update_IDs=update_IDs, - wl_update=wl_update, wl_track_og_curr=wl_track_og_curr) - waitcond.wakeAll() - - def onSigGetData(self, waitcond, debug=False): - self.get_data(debug=debug) - waitcond.wakeAll() - - def SegForLostIDsWorkerFinished(self): - self.updateAllImages() - self.update_rp() - self.store_data(autosave=True) - self.setFrameNavigationDisabled(disable=False, why='Segmentation for lost IDs') - - if self.progressWin is not None: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - - def showImageDebug(self, img): - imshow(img) - - def gui_raiseBottomLayoutContextMenu(self, event): - try: - # Convert QPointF to QPoint - self.bottomLayoutContextMenu.popup(event.globalPos().toPoint()) - except AttributeError: - self.bottomLayoutContextMenu.popup(event.globalPos()) - - def areContoursRequested(self, ax): - if ax == 0 and self.annotContourCheckbox.isChecked(): - return True - - if ax == 1: - if not self.labelsGrad.showRightImgAction.isChecked(): - return False - - isRightDifferentAnnot = self.rightBottomGroupbox.isChecked() - areContRequestedRight = self.annotContourCheckboxRight.isChecked() - - if isRightDifferentAnnot and areContRequestedRight: - return True - - areContRequestedLeft = self.annotContourCheckbox.isChecked() - if not isRightDifferentAnnot and areContRequestedLeft: - return True - return False - - def areMothBudLinesRequested(self, ax): - if ax == 0: - if self.annotCcaInfoCheckbox.isChecked(): - return True - if self.drawMothBudLinesCheckbox.isChecked(): - return True - else: - if not self.labelsGrad.showRightImgAction.isChecked(): - return False - - isRightDifferentAnnot = self.rightBottomGroupbox.isChecked() - areLinesRequestedRight = ( - self.annotCcaInfoCheckboxRight.isChecked() - or self.drawMothBudLinesCheckboxRight.isChecked() - ) - - if isRightDifferentAnnot and areLinesRequestedRight: - return True - - areLinesRequestedLeft = ( - self.drawMothBudLinesCheckbox.isChecked() - or self.annotCcaInfoCheckbox.isChecked() - ) - if not isRightDifferentAnnot and areLinesRequestedLeft: - return True - return False - - def getMothBudLineScatterItem(self, ax, new): - if ax == 0: - if new: - return self.ax1_newMothBudLinesItem - else: - return self.ax1_oldMothBudLinesItem - else: - if new: - return self.ax2_newMothBudLinesItem - else: - return self.ax2_oldMothBudLinesItem - - def labelRoiIsCircularRadioButtonToggled(self, checked): - if checked: - self.labelRoiCircularRadiusSpinbox.setDisabled(False) - else: - self.labelRoiCircularRadiusSpinbox.setDisabled(True) - - def pxModeActionToggled(self, checked): - self.df_settings.at['pxMode', 'value'] = int(checked) - self.df_settings.to_csv(self.settings_csv_path) - - if not self.isDataLoaded: - return - - if self.highLowResAction.isChecked(): - for ax in range(2): - self.textAnnot[ax].setPxMode(checked) - - self.updateAllImages() - - def relabelSequentialCallback(self): - mode = str(self.modeComboBox.currentText()) - if mode == 'Viewer' or mode == 'Cell cycle analysis': - self.startBlinkingModeCB() - return - - posData = self.data[self.pos_i] - selectedPos = (posData.pos_foldername, ) - if len(self.data) > 1: - selectedPos = self.askSelectPos(action='to process') - if selectedPos is None: - self.logger.info('Re-labelling process stopped.') - return - - self.store_data() - # acdc_df_concat = self.getConcatAcdcDf() - # load.store_unsaved_acdc_df( - # posData, acdc_df_concat, - # log_func=self.logger.info - # ) - # if posData.SizeT > 1: - self.progressWin = apps.QDialogWorkerProgress( - title='Re-labelling sequential', parent=self, - pbarDesc='Relabelling sequential...' - ) - self.progressWin.show(self.app) - self.progressWin.mainPbar.setMaximum(0) - self.startRelabellingWorker(selectedPos) - - # elif posData: - # self.storeUndoRedoStates(False) - # posData.lab, oldIDs, newIDs = core.relabel_sequential(posData.lab) - # # Update annotations based on relabelling - # self.update_cca_df_relabelling(posData, oldIDs, newIDs) - # self.updateAnnotatedIDs(oldIDs, newIDs, logger=self.logger.info) - # self.store_data() - # self.update_rp() - # li = list(zip(oldIDs, newIDs)) - # s = '\n'.join([str(pair).replace(',', ' -->') for pair in li]) - # s = f'IDs relabelled as follows:\n{s}' - # self.logger.info(s) - # self.updateAllImages() - - def updateAnnotatedIDs(self, oldIDs, newIDs, logger=print): - logger('Updating annotated IDs...') - posData = self.data[self.pos_i] - - mapper = dict(zip(oldIDs, newIDs)) - posData.ripIDs = set([mapper[ripID] for ripID in posData.ripIDs]) - posData.binnedIDs = set([mapper[binID] for binID in posData.binnedIDs]) - self.keptObjectsIDs = widgets.KeptObjectIDsList( - self.keptIDsLineEdit, self.keepIDsConfirmAction - ) - - customAnnotButtons = list(self.customAnnotDict.keys()) - for button in customAnnotButtons: - customAnnotValues = self.customAnnotDict[button] - annotatedIDs = customAnnotValues['annotatedIDs'][self.pos_i] - mappedAnnotIDs = {} - for frame_i, annotIDs_i in annotatedIDs.items(): - mappedIDs = [mapper[ID] for ID in annotIDs_i] - mappedAnnotIDs[frame_i] = mappedIDs - customAnnotValues['annotatedIDs'][self.pos_i] = mappedAnnotIDs - - def rtTrackerActionToggled(self, checked): - if not checked: - return - - aliases = myutils.aliases_real_time_trackers(reverse=True) - if self.sender().text() in aliases: - trackingAlgo = aliases[self.sender().text()] - else: - trackingAlgo = self.sender().text() - self.df_settings.at['tracking_algorithm', 'value'] = trackingAlgo - self.df_settings.to_csv(self.settings_csv_path) - - if self.sender().text() == 'YeaZ': - msg = widgets.myMessageBox(wrapText=False) - info_txt = html_utils.paragraph(f""" - Note that YeaZ tracking algorithm tends to be sliglhtly more accurate - overall, but it is less capable of detecting segmentation - errors.

- If you need to correct as many segmentation errors as possible - we recommend using Cell-ACDC tracking algorithm. - """) - msg.information(self, 'Info about YeaZ', info_txt) - - self.isRealTimeTrackerInitialized = False - self.initRealTimeTracker() - - def autoPilotToggled(self, checked): - self.autoPilotZoomToObjToolbar.setVisible(checked) - if checked: - self.autoPilotZoomToObjToggle.setChecked(False) - self.autoPilotZoomToObjToggle.toggle() - - def zoomRectActionToggled(self, checked): - if checked: - self.disconnectLeftClickButtons() - self.uncheckLeftClickButtons(self.sender()) - self.connectLeftClickButtons() - self.ax1.addItem(self.zoomRectItem) - else: - self.zoomRectItem.setPos((0,0)) - self.zoomRectItem.setSize((0,0)) - self.ax1.removeItem(self.zoomRectItem) - - def zoomRectDone(self): - xRange, yRange = self.ax1.viewRange() - self.zoomRectItem.storeLastRange(xRange, yRange) - - ymin, xmin, ymax, xmax = self.zoomRectItem.bbox() - - self.zoomRectItem.setPos((0,0)) - self.zoomRectItem.setSize((0,0)) - - self.ax1.setRange( - xRange=(xmin, xmax), - yRange=(ymin, ymax), - padding=0 - ) - - def zoomRectCancelled(self): - self.isMouseDragImg1 = False - self.zoomRectItem.setPos((0,0)) - self.zoomRectItem.setSize((0,0)) - - def findID(self, checked=False, ID=None): - posData = self.data[self.pos_i] - if ID is None: - searchIDdialog = apps.FindIDDialog( - title='Search object by ID', - msg='Enter object ID to find and highlight', - parent=self, - isInteger=True - ) - searchIDdialog.exec_() - if searchIDdialog.cancel: - return - - searchedID = searchIDdialog.EntryID - else: - searchedID = ID - - if searchedID in posData.IDs: - self.goToObjectID(searchedID) - return - - if posData.SizeT == 1: - self.warnIDnotFound(searchedID) - return - - if searchedID in posData.lost_IDs: - self.goToLostObjectID(searchedID) - return - - tracked_lost_IDs = self.getTrackedLostIDs() - if searchedID in tracked_lost_IDs: - self.goToAcceptedLostObjectID(searchedID) - return - - self.logger.info(f'Searching ID {searchedID} in other frames...') - - frame_i_found = self.startSearchIDworker(searchedID) - if frame_i_found is None: - self.warnIDnotFound(searchedID) - return - - self.logger.info( - f'Object ID {searchedID} found at frame n. {frame_i_found+1}.' - ) - proceed = self.askGoToFrameFoundID(searchedID, frame_i_found) - if not proceed: - return - - posData.frame_i = frame_i_found - self.get_data() - self.updateAllImages() - self.updateScrollbars() - - self.goToObjectID(searchedID) - - @disableWindow - def startSearchIDworker(self, searchedID): - posData = self.data[self.pos_i] - - desc = 'Searching ID in all frames...' - - self.progressWin = apps.QDialogWorkerProgress( - title=desc, parent=self.mainWin, pbarDesc=desc - ) - self.progressWin.mainPbar.setMaximum(posData.SizeT) - self.progressWin.show(self.app) - - self.searchIDthread = QThread() - self.searchIDworker = workers.SimpleWorker( - posData, self.searchIDworkerCallback, - func_args=(searchedID, ) - ) - self.searchIDworker.frame_i_found = None - self.searchIDworker.moveToThread(self.searchIDthread) - - self.searchIDworker.signals.finished.connect( - self.searchIDthread.quit - ) - self.searchIDworker.signals.finished.connect( - self.searchIDworker.deleteLater - ) - self.searchIDthread.finished.connect(self.searchIDthread.deleteLater) - - self.searchIDworker.signals.critical.connect( - self.searchIDworkerCritical - ) - self.searchIDworker.signals.initProgressBar.connect( - self.workerInitProgressbar - ) - self.searchIDworker.signals.progressBar.connect( - self.workerUpdateProgressbar - ) - self.searchIDworker.signals.progress.connect( - self.workerProgress - ) - self.searchIDworker.signals.finished.connect( - self.searchIDworkerFinished - ) - - self.searchIDthread.started.connect(self.searchIDworker.run) - self.searchIDthread.start() - - self.searchIDworkerLoop = QEventLoop() - self.searchIDworkerLoop.exec_() - - return self.searchIDworker.frame_i_found - - def searchIDworkerCritical(self, error): - self.searchIDworkerLoop.exit() - self.workerCritical(error) - - def searchIDworkerFinished(self): - if self.progressWin is not None: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - - self.searchIDworkerLoop.exit() - - def searchIDworkerCallback(self, posData, searchedID): - self.searchIDworker.signals.initProgressBar.emit(0) - self.setAllIDs() - self.searchIDworker.signals.initProgressBar.emit(posData.SizeT) - frame_i_found = None - for frame_i in range(len(posData.segm_data)): - if frame_i >= len(posData.allData_li): - break - lab = posData.allData_li[frame_i]['labels'] - if lab is None: - rp = skimage.measure.regionprops(posData.segm_data[frame_i]) - IDs = set([obj.label for obj in rp]) - else: - IDs = posData.allData_li[frame_i]['IDs'] - - if searchedID in IDs: - frame_i_found = frame_i - break - - self.searchIDworker.signals.progressBar.emit(1) - - self.searchIDworker.frame_i_found = frame_i_found - - def warnIDnotFound(self, searchedID): - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph(f""" - Object ID {searchedID} was not found.

- """) - msg.warning(self, f'ID {searchedID} not found', txt) - - def goToObjectID(self, ID): - posData = self.data[self.pos_i] - objIdx = posData.IDs_idxs[ID] - obj = posData.rp[objIdx] - self.goToZsliceSearchedID(obj) - - self.highlightSearchedID(ID) - propsQGBox = self.guiTabControl.propsQGBox - propsQGBox.idSB.setValue(ID) - - def goToLostObjectID(self, lostID, color=(255, 165, 0, 255)): - posData = self.data[self.pos_i] - frame_i = posData.frame_i - prev_rp = posData.allData_li[frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[frame_i-1]['IDs_idxs'] - obj = prev_rp[prev_IDs_idxs[lostID]] - self.goToZsliceSearchedID(obj) - - imageItem = self.getLostObjImageItem(0) - thickness = 1 - if not hasattr(self, 'lostObjContoursImage'): - self.initLostObjContoursImage() - else: - self.lostObjContoursImage[:] = 0 - - contours = [] - obj_contours = self.getObjContours(obj, all_external=True) - contours.extend(obj_contours) - - self.addLostObjsToLostObjImage(obj, lostID) - self.drawLostObjContoursImage( - imageItem, contours, thickness=2, color=color - ) - - def goToAcceptedLostObjectID(self, acceptedLostID): - posData = self.data[self.pos_i] - frame_i = posData.frame_i - prev_rp = posData.allData_li[frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[frame_i-1]['IDs_idxs'] - obj = prev_rp[prev_IDs_idxs[acceptedLostID]] - self.goToZsliceSearchedID(obj) - - self.updateLostTrackedContoursImage(tracked_lost_IDs=[acceptedLostID]) - - def askGoToFrameFoundID(self, searchedID, frame_i_found): - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph(f""" - Object ID {searchedID} was found at frame n. {frame_i_found+1}.

- Do you want to go to frame n. {frame_i_found+1}. - """) - noButton, yesButton = msg.information( - self, f'ID {searchedID} found at frame n. {frame_i_found+1}', txt, - buttonsTexts=( - 'No, stay on current frame', - f'Yes, go to frame n. {frame_i_found+1}' - ) - ) - return msg.clickedButton == yesButton - - def skipForwardToNewID(self): - self.progressWin = apps.QDialogWorkerProgress( - title='Searching the next frame with a new object', parent=self, - pbarDesc=f'Searching the next frame with a new object...' - ) - self.progressWin.show(self.app) - self.progressWin.mainPbar.setMaximum(0) - - self.startFindNextNewIdWorker() - - def startFindNextNewIdWorker(self): - posData = self.data[self.pos_i] - self._thread = QThread() - self.findNextNewIdWorker = workers.FindNextNewIdWorker(posData, self) - self.findNextNewIdWorker.moveToThread(self._thread) - - self.findNextNewIdWorker.signals.finished.connect(self._thread.quit) - self.findNextNewIdWorker.signals.finished.connect( - self.findNextNewIdWorker.deleteLater - ) - self._thread.finished.connect(self._thread.deleteLater) - - self.findNextNewIdWorker.signals.finished.connect( - self.findNextNewIdWorkerFinished - ) - self.findNextNewIdWorker.signals.progress.connect(self.workerProgress) - self.findNextNewIdWorker.signals.initProgressBar.connect( - self.workerInitProgressbar - ) - self.findNextNewIdWorker.signals.progressBar.connect( - self.workerUpdateProgressbar - ) - self.findNextNewIdWorker.signals.critical.connect( - self.workerCritical - ) - - self._thread.started.connect(self.findNextNewIdWorker.run) - self._thread.start() - - def findNextNewIdWorkerFinished(self, next_frame_i): - if self.progressWin is not None: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - - self.navSpinBox.setValue(next_frame_i+1) - self.framesScrollBarReleased() - - def workerProgress(self, text, loggerLevel='INFO'): # used in cca and lin tree - if self.progressWin is not None: - self.progressWin.logConsole.append(text) - self.logger.log(getattr(logging, loggerLevel), text) - - def workerFinished(self): - if self.progressWin is not None: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - self.logger.info('Worker process ended.') - self.updateAllImages() - self.titleLabel.setText('Done', color='w') - - def savePreprocWorkerFinished(self): - if self.progressWin is not None: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - - self.setStatusBarLabel() - self.logger.info('Pre-processed data saved!') - self.titleLabel.setText('Pre-processed data saved!', color='w') - - def delObjsOutSegmMaskWorkerFinished(self, result): - posData = self.data[self.pos_i] - worker, cleared_segm_data, delIDs = result - if posData.SizeT == 1: - cleared_segm_data = cleared_segm_data[np.newaxis] - - self.update_cca_df_deletedIDs(posData, delIDs) - - current_frame_i = posData.frame_i - for frame_i, cleared_lab in enumerate(cleared_segm_data): - # Store change - posData.allData_li[frame_i]['labels'] = cleared_lab - # Get the rest of the stored metadata based on the new lab - posData.frame_i = frame_i - self.get_data() - self.store_data(autosave=False) - - # Back to current frame - posData.frame_i = current_frame_i - self.get_data() - - if self.progressWin is not None: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - self.logger.info('Deleting objects outside of ROIs finished.') - self.titleLabel.setText( - 'Deleting objects outside of ROIs finished.', color='w' - ) - self.updateAllImages() - - def loadingNewChunk(self, chunk_range): - coord0_chunk, coord1_chunk = chunk_range - desc = ( - f'Loading new window, range = ({coord0_chunk}, {coord1_chunk})...' - ) - self.progressWin = apps.QDialogWorkerProgress( - title='Loading data...', parent=self, pbarDesc=desc - ) - self.progressWin.mainPbar.setMaximum(0) - self.progressWin.show(self.app) - - def lazyLoaderFinished(self): - self.logger.info('Load chunk data worker done.') - if self.lazyLoader.updateImgOnFinished: - self.updateAllImages() - - if self.progressWin is not None: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - - @exception_handler - def trackingWorkerFinished(self): - if self.progressWin is not None: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - self.logger.info('Worker process ended.') - askDisableRealTimeTracking = ( - self.trackingWorker.trackingOnNeverVisitedFrames - and self.realTimeTrackingToggle.isChecked() - ) - if askDisableRealTimeTracking: - msg = widgets.myMessageBox() - title = 'Disable real-time tracking?' - txt = ( - 'You perfomed tracking on frames that you have ' - 'never visited.

' - 'Cell-ACDC default behaviour is to track them again when you ' - 'will visit them.

' - 'However, you can overwrite this behaviour and explicitly ' - 'disable tracking for all of the frames you already tracked.

' - 'NOTE: you can reactivate real-time tracking by clicking on the ' - '"Reset last segmented frame" button on the top toolbar.

' - 'What do you want me to do?' - ) - _, disableTrackingButton = msg.information( - self, title, html_utils.paragraph(txt), - buttonsTexts=( - 'Keep real-time tracking active (recommended)', - 'Disable real-time tracking' - ) - ) - if msg.clickedButton == disableTrackingButton: - self.logger.info('Disabling real time tracking...') - self.realTimeTrackingToggle.setChecked(False) - # posData = self.data[self.pos_i] - # current_frame_i = posData.frame_i - # for frame_i in range(self.start_n-1, self.stop_n): - # posData.frame_i = frame_i - # self.get_data() - # self.store_data(autosave=frame_i==self.stop_n-1) - # posData.last_tracked_i = frame_i - # self.setNavigateScrollBarMaximum() - - # # Back to current frame - # posData.frame_i = current_frame_i - # self.get_data() - posData = self.data[self.pos_i] - self.updateAllImages() - self.titleLabel.setText('Done', color='w') - - def workerInitProgressbar(self, totalIter): - self.progressWin.mainPbar.setValue(0) - if totalIter == 1: - totalIter = 0 - self.progressWin.mainPbar.setMaximum(totalIter) - - def workerUpdateProgressbar(self, step): - self.progressWin.mainPbar.update(step) - - def workerInitInnerPbar(self, totalIter): - self.progressWin.innerPbar.setValue(0) - if totalIter == 1: - totalIter = 0 - self.progressWin.innerPbar.setMaximum(totalIter) - - def workerUpdateInnerPbar(self, step): - self.progressWin.innerPbar.update(step) - - def startTrackingWorker(self, posData, video_to_track): - self.thread = QThread() - self.trackingWorker = workers.trackingWorker( - posData, self, video_to_track - ) - self.trackingWorker.moveToThread(self.thread) - self.trackingWorker.finished.connect(self.thread.quit) - self.trackingWorker.finished.connect(self.trackingWorker.deleteLater) - self.thread.finished.connect(self.thread.deleteLater) - - # Custom signals - self.trackingWorker.signals.progress = self.trackingWorker.progress - self.trackingWorker.signals.progressBar.connect( - self.workerUpdateProgressbar - ) - self.trackingWorker.signals.initProgressBar.connect( - self.workerInitProgressbar - ) - self.trackingWorker.signals.sigInitInnerPbar.connect( - self.workerInitInnerPbar - ) - self.trackingWorker.progress.connect(self.workerProgress) - self.trackingWorker.critical.connect(self.workerCritical) - self.trackingWorker.finished.connect(self.trackingWorkerFinished) - - self.trackingWorker.debug.connect(self.workerDebug) - - self.thread.started.connect(self.trackingWorker.run) - self.thread.start() - - def startRelabellingWorker(self, posFoldernames): - self.thread = QThread() - self.worker = workers.relabelSequentialWorker(self, posFoldernames) - self.worker.moveToThread(self.thread) - self.worker.finished.connect(self.thread.quit) - self.worker.finished.connect(self.worker.deleteLater) - self.thread.finished.connect(self.thread.deleteLater) - - self.worker.progress.connect(self.workerProgress) - self.worker.critical.connect(self.workerCritical) - self.worker.finished.connect(self.workerFinished) - self.worker.finished.connect(self.relabelWorkerFinished) - - self.worker.debug.connect(self.workerDebug) - - self.thread.started.connect(self.worker.run) - self.thread.start() - - def startPostProcessSegmWorker( - self, postProcessKwargs, customPostProcessGroupedFeatures, - customPostProcessFeatures - ): - self.thread = QThread() - self.postProcessWorker = workers.PostProcessSegmWorker( - postProcessKwargs, customPostProcessGroupedFeatures, - customPostProcessFeatures, self - ) - - self.postProcessWorker.moveToThread(self.thread) - self.postProcessWorker.signals.finished.connect(self.thread.quit) - self.postProcessWorker.signals.finished.connect( - self.postProcessWorker.deleteLater - ) - self.thread.finished.connect(self.thread.deleteLater) - - self.postProcessWorker.signals.finished.connect( - self.postProcessSegmWorkerFinished - ) - self.postProcessWorker.signals.progress.connect(self.workerProgress) - self.postProcessWorker.signals.initProgressBar.connect( - self.workerInitProgressbar - ) - self.postProcessWorker.signals.progressBar.connect( - self.workerUpdateProgressbar - ) - self.postProcessWorker.signals.critical.connect( - self.workerCritical - ) - - self.thread.started.connect(self.postProcessWorker.run) - self.thread.start() - - def relabelWorkerFinished(self): - self.updateAllImages() - - def workerDebug(self, item): - tracked_video, worker = item - from cellacdc.plot import imshow - imshow(tracked_video) - worker.waitCond.wakeAll() - - def keepToolActiveActionToggled(self, checked, toolName=None): - if toolName is None: - parentToolButton = self.sender().parent() - toolName = re.findall(r'Name: (.*)', parentToolButton.toolTip())[0] - - if checked: - self.df_settings.at[toolName, 'value'] = 'keepActive' - else: - self.df_settings = self.df_settings.drop( - index=toolName, errors='ignore' - ) - self.df_settings.to_csv(self.settings_csv_path) - - def applyToolNewFrameActionToggled(self, checked, toolName=None): - if toolName is None: - parentToolButton = self.sender().parent() - toolName = re.findall(r'Name: (.*)', parentToolButton.toolTip())[0] - toolName = toolName.strip() - button = self.applyToolNewFrameButtons[toolName] - toolName = toolName.replace(' ', '_') - settingName = f'{toolName}_applyNewFrame' - if checked: - self.df_settings.at[settingName, 'value'] = 'applyNewFrame' - button.setStyleSheet(f'background-color: {GREEN_HEX}') - else: - self.df_settings = self.df_settings.drop( - index=settingName, errors='ignore' - ) - button.setStyleSheet('background-color: none') - self.df_settings.to_csv(self.settings_csv_path) - - def keepAllToolsActiveActionToggled(self, checked): - for action in self.keepToolActiveActions.values(): - action.setChecked(checked) - - data_loaded = True - if not hasattr(self, 'data'): - data_loaded = False - try: - self.labelRoiTrangeCheckbox.disconnect() - except TypeError: - pass - self.labelRoiTrangeCheckbox.setChecked(checked) # why this is not wrapped in a QAction? - - if data_loaded: - self.labelRoiTrangeCheckbox.toggled.connect( - self.labelRoiTrangeCheckboxToggled - ) - - def determineSlideshowWinPos(self): - screens = self.app.screens() - self.numScreens = len(screens) - winScreen = self.screen() - - # Center main window and determine location of slideshow window - # depending on number of screens available - if self.numScreens > 1: - for screen in screens: - if screen != winScreen: - winScreen = screen - break - - winScreenGeom = winScreen.geometry() - winScreenCenter = winScreenGeom.center() - winScreenCenterX = winScreenCenter.x() - winScreenCenterY = winScreenCenter.y() - winScreenLeft = winScreenGeom.left() - winScreenTop = winScreenGeom.top() - self.slideshowWinLeft = winScreenCenterX - int(850/2) - self.slideshowWinTop = winScreenCenterY - int(800/2) - - def nonViewerEditMenuOpened(self): - mode = str(self.modeComboBox.currentText()) - if mode == 'Viewer': - self.startBlinkingModeCB() - - def getDistantGray(self, desiredGray, bkgrGray): - isDesiredSimilarToBkgr = ( - abs(desiredGray-bkgrGray) < 0.3 - ) - if isDesiredSimilarToBkgr: - return 1-desiredGray - else: - return desiredGray - - def RGBtoGray(self, R, G, B): - # see https://stackoverflow.com/questions/17615963/standard-rgb-to-grayscale-conversion - C_linear = (0.2126*R + 0.7152*G + 0.0722*B)/255 - if C_linear <= 0.0031309: - gray = 12.92*C_linear - else: - gray = 1.055*(C_linear)**(1/2.4) - 0.055 - return gray - - def ruler_cb(self, checked): - if checked: - self.disconnectLeftClickButtons() - self.uncheckLeftClickButtons(self.sender()) - self.connectLeftClickButtons() - else: - self.tempSegmentON = False - self.ax1_rulerPlotItem.setData([], []) - self.ax1_rulerAnchorsItem.setData([], []) - - def editImgProperties(self, checked=True): - posData = self.data[self.pos_i] - posData.askInputMetadata( - len(self.data), - ask_SizeT=True, - ask_TimeIncrement=True, - ask_PhysicalSizes=True, - save=True, singlePos=True, - askSegm3D=False - ) - if hasattr(self, 'timestamp'): - self.timestamp.setSecondsPerFrame(posData.TimeIncrement) - self.updateTimestampFrame() - - if hasattr(self, 'scaleBar'): - self.scaleBar.updatePhysicalLength(posData.PhysicalSizeX) - - def setHoverToolSymbolData(self, xx, yy, ScatterItems, size=None): - if not xx: - self.ax1_lostObjScatterItem.setVisible(True) - self.ax2_lostObjScatterItem.setVisible(True) - - self.ax1_lostTrackedScatterItem.setVisible(True) - self.ax2_lostTrackedScatterItem.setVisible(True) - - for item in ScatterItems: - if size is None: - item.setData(xx, yy) - else: - item.setData(xx, yy, size=size) - - def updateLabelRoiCircularSize(self, value): - self.labelRoiCircItemLeft.setSize(value) - self.labelRoiCircItemRight.setSize(value) - - def updateLabelRoiCircularCursor(self, x, y, checked): - if not self.labelRoiButton.isChecked(): - return - if not self.labelRoiIsCircularRadioButton.isChecked(): - return - if self.labelRoiRunning: - return - - size = self.labelRoiCircularRadiusSpinbox.value() - if not checked: - xx, yy = [], [] - else: - xx, yy = [x], [y] - - if not xx and len(self.labelRoiCircItemLeft.getData()[0]) == 0: - return - - self.labelRoiCircItemLeft.setData(xx, yy, size=size) - self.labelRoiCircItemRight.setData(xx, yy, size=size) - - def getLabelRoiImage(self): - posData = self.data[self.pos_i] - - if self.labelRoiTrangeCheckbox.isChecked(): - start_frame_i = self.labelRoiStartFrameNoSpinbox.value()-1 - stop_frame_n = self.labelRoiStopFrameNoSpinbox.value() - tRangeLen = stop_frame_n-start_frame_i - else: - tRangeLen = 1 - - if tRangeLen > 1: - tRange = (start_frame_i, stop_frame_n) - else: - tRange = None - - if self.isSegm3D: - if tRangeLen > 1: - imgData = posData.img_data - else: - # Filtered data not existing - imgData = posData.img_data[posData.frame_i] - - roi_zdepth = self.labelRoiZdepthSpinbox.value() - if roi_zdepth == posData.SizeZ: - z0 = 0 - z1 = posData.SizeZ - elif roi_zdepth == 1: - z0 = self.zSliceScrollBar.sliderPosition() - z1 = z0 + 1 - else: - if roi_zdepth%2 != 0: - roi_zdepth +=1 - half_zdepth = int(roi_zdepth/2) - zc = self.zSliceScrollBar.sliderPosition() + 1 - z0 = zc-half_zdepth - z0 = z0 if z0>=0 else 0 - z1 = zc+half_zdepth - z1 = z1 if z1 1: - imgData = posData.img_data - else: - imgData = self.img1.image - - roiImg = imgData[labelRoiSlice] - if self.labelRoiIsFreeHandRadioButton.isChecked(): - mask = self.freeRoiItem.mask() - elif self.labelRoiIsCircularRadioButton.isChecked(): - mask = self.labelRoiCircItemLeft.mask() - else: - mask = None - - if mask is not None: - # Copy roiImg otherwise we are replacing minimum inside original image - roiImg = roiImg.copy() - # Fill outside of freehand roi with minimum of the ROI image - if tRangeLen > 1: - for i in range(tRangeLen): - ith_roiImg = roiImg[i] - if self.isSegm3D: - roiImg[i, :, ~mask] = ith_roiImg.min() - else: - roiImg[i, ~mask] = ith_roiImg.min() - else: - if self.isSegm3D: - roiImg[:, ~mask] = roiImg.min() - else: - roiImg[~mask] = roiImg.min() - - return roiImg, labelRoiSlice - - def getClickedID(self, xdata, ydata, text=''): - posData = self.data[self.pos_i] - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == 0: - msg = ( - 'You clicked on the background.\n' - f'Enter here the ID {text}' - ) - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), xdata, ydata - ) - clickedBkgrID = apps.QLineEditDialog( - title='Clicked on background', - msg=msg, parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - clickedBkgrID.exec_() - if clickedBkgrID.cancel: - return - else: - ID = clickedBkgrID.EntryID - return ID - - # @exec_time - def applyEditID( - self, clickedID, currentIDs, oldIDnewIDMapper, clicked_x, clicked_y, shift=False, doPropagateUnvisited=False - ): - posData = self.data[self.pos_i] - - # Ask to propagate change to all future visited frames - key = 'Edit ID' - askAction = self.askHowFutureFramesActions[key] - doNotShow = not askAction.isChecked() - (UndoFutFrames, applyFutFrames, endFrame_i, - doNotShowAgain) = self.propagateChange( - clickedID, key, doNotShow, - posData.UndoFutFrames_EditID, posData.applyFutFrames_EditID, - applyTrackingB=True - ) - - if UndoFutFrames is None: - return - - if shift and self.isSegm3D: - lab = self.get_2Dlab(posData.lab) - else: - lab = posData.lab - - # Store undo state before modifying stuff - self.storeUndoRedoStates(UndoFutFrames) - maxID = max(posData.IDs, default=0) - for old_ID, new_ID in oldIDnewIDMapper: - if new_ID in currentIDs and not self.editIDmergeIDs: - tempID = maxID + 1 - lab[lab == old_ID] = maxID + 1 - lab[lab == new_ID] = old_ID - lab[lab == tempID] = new_ID - maxID += 1 - - old_ID_idx = currentIDs.index(old_ID) - new_ID_idx = currentIDs.index(new_ID) - - # Append information for replicating the edit in tracking - # List of tuples (y, x, replacing ID) - objo = posData.rp[old_ID_idx] - yo, xo = self.getObjCentroid(objo.centroid) - objn = posData.rp[new_ID_idx] - yn, xn = self.getObjCentroid(objn.centroid) - if not math.isnan(yo) and not math.isnan(yn): - yn, xn = int(yn), int(xn) - posData.editID_info.append((yn, xn, new_ID)) - yo, xo = int(clicked_y), int(clicked_x) - posData.editID_info.append((yo, xo, old_ID)) - else: - lab[lab == old_ID] = new_ID - if new_ID > maxID: - maxID = new_ID - old_ID_idx = posData.IDs.index(old_ID) - - # Append information for replicating the edit in tracking - # List of tuples (y, x, replacing ID) - obj = posData.rp[old_ID_idx] - y, x = self.getObjCentroid(obj.centroid) - if not math.isnan(y) and not math.isnan(y): - y, x = int(y), int(x) - posData.editID_info.append((y, x, new_ID)) - - self.updateAssignedObjsAcdcTrackerSecondStep(new_ID) - - if shift and self.isSegm3D: - self.set_2Dlab(lab) - - # Update rps - self.update_rp() - - # Since we manually changed an ID we don't want to repeat tracking - self.setAllTextAnnotations() - self.highlightLostNew() - # self.checkIDsMultiContour() - - # Update colors for the edited IDs - self.updateLookuptable() - - if self.isSnapshot: - self.fixCcaDfAfterEdit('Edit ID') - self.updateAllImages() - else: - self.warnEditingWithCca_df('Edit ID', update_images=False) - - if not self.editIDbutton.findChild(QAction).isChecked(): - self.editIDbutton.setChecked(False) - - posData.disableAutoActivateViewerWindow = True - - # Perform desired action on future frames - posData.doNotShowAgain_EditID = doNotShowAgain - posData.UndoFutFrames_EditID = UndoFutFrames - posData.applyFutFrames_EditID = applyFutFrames - includeUnvisited = ( - posData.includeUnvisitedInfo['Edit ID'] - or doPropagateUnvisited - ) - - if not applyFutFrames and not doPropagateUnvisited: - return - - self.changeIDfutureFrames( - endFrame_i, oldIDnewIDMapper, includeUnvisited, - shift=shift - ) - - def getLastHoveredID(self): - if self.xHoverImg is None: - return 0 - - xdata, ydata = int(self.xHoverImg), int(self.yHoverImg) - ID = self.currentLab2D[ydata, xdata] - return ID - - def getHoverID(self, xdata, ydata, byPassShiftCheck=False): - if not hasattr(self, 'diskMask'): - return 0 - - modifiers = QGuiApplication.keyboardModifiers() - ctrl = modifiers == Qt.ControlModifier - if byPassShiftCheck: - shift = False - else: - shift = modifiers == Qt.ShiftModifier - - if self.isPowerBrush() and not ctrl: - return 0 - - if not self.autoIDcheckbox.isChecked(): - return self.editIDspinbox.value() - - ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata) - posData = self.data[self.pos_i] - lab_2D = self.get_2Dlab(posData.lab) - ID = lab_2D[ydata, xdata] - self.isHoverZneighID = False - if self.isSegm3D: - z = self.z_lab() - SizeZ = posData.lab.shape[0] - doNotLinkThroughZ = self.brushButton.isChecked() and shift - if doNotLinkThroughZ: - if self.brushHoverCenterModeAction.isChecked() or ID>0: - hoverID = ID - else: - masked_lab = lab_2D[ymin:ymax, xmin:xmax][diskMask] - hoverID = np.bincount(masked_lab).argmax() - else: - if z > 0: - ID_z_under = posData.lab[z-1, ydata, xdata] - if self.brushHoverCenterModeAction.isChecked() or ID_z_under>0: - hoverIDa = ID_z_under - else: - lab = posData.lab - masked_lab_a = lab[z-1, ymin:ymax, xmin:xmax][diskMask] - hoverIDa = np.bincount(masked_lab_a).argmax() - else: - hoverIDa = 0 - - if self.brushHoverCenterModeAction.isChecked() or ID>0: - hoverIDb = lab_2D[ydata, xdata] - else: - masked_lab_b = lab_2D[ymin:ymax, xmin:xmax][diskMask] - hoverIDb = np.bincount(masked_lab_b).argmax() - - if z < SizeZ-1: - ID_z_above = posData.lab[z+1, ydata, xdata] - if self.brushHoverCenterModeAction.isChecked() or ID_z_above>0: - hoverIDc = ID_z_above - else: - lab = posData.lab - masked_lab_c = lab[z+1, ymin:ymax, xmin:xmax][diskMask] - hoverIDc = np.bincount(masked_lab_c).argmax() - else: - hoverIDc = 0 - - if hoverIDa > 0: - hoverID = hoverIDa - self.isHoverZneighID = True - elif hoverIDb > 0: - hoverID = hoverIDb - elif hoverIDc > 0: - hoverID = hoverIDc - self.isHoverZneighID = True - else: - hoverID = 0 - else: - if self.brushButton.isChecked() and shift: - # Force new ID with brush and Shift - hoverID = 0 - elif self.brushHoverCenterModeAction.isChecked() or ID>0: - hoverID = ID - else: - masked_lab = lab_2D[ymin:ymax, xmin:xmax][diskMask] - hoverID = np.bincount(masked_lab).argmax() - - self.editIDspinbox.setValue(hoverID) - - return hoverID - - def setHoverToolSymbolColor( - self, xdata, ydata, pen, ScatterItems, button, - brush=None, hoverRGB=None, ID=None, byPassShiftCheck=False - ): - modifiers = QGuiApplication.keyboardModifiers() - if byPassShiftCheck: - shift = False - else: - shift = modifiers == Qt.ShiftModifier - - posData = self.data[self.pos_i] - Y, X = self.get_2Dlab(posData.lab).shape - if not myutils.is_in_bounds(xdata, ydata, X, Y): - return - - self.isHoverZneighID = False - if ID is None: - hoverID = self.getHoverID( - xdata, ydata, byPassShiftCheck=byPassShiftCheck - ) - else: - hoverID = ID - - if hoverID == 0: - for item in ScatterItems: - item.setPen(pen) - item.setBrush(brush) - else: - try: - rgb = self.lut[hoverID] - rgb = rgb if hoverRGB is None else hoverRGB - rgbPen = np.clip(rgb*1.1, 0, 255) - for item in ScatterItems: - item.setPen(*rgbPen, width=2) - item.setBrush(*rgb, 100) - except IndexError: - pass - - checkChangeID = ( - self.isHoverZneighID and not shift - and self.lastHoverID != hoverID - ) - if checkChangeID: - # We are hovering an ID in z+1 or z-1 - self.restoreBrushID = hoverID - # self.changeBrushID() - - self.lastHoverID = hoverID - - def isPowerBrush(self): - color = self.brushButton.palette().button().color().name() - return color == self.doublePressKeyButtonColor - - def isPowerEraser(self): - color = self.eraserButton.palette().button().color().name() - return color == self.doublePressKeyButtonColor - - def isPowerButton(self, button): - color = button.palette().button().color().name() - return color == self.doublePressKeyButtonColor - - def getCheckNormAction(self): - normalize = False - how = '' - for action in self.normalizeQActionGroup.actions(): - if action.isChecked(): - how = action.text() - normalize = True - break - return action, normalize, how - - def normalizeIntensities(self, img): - action, normalize, how = self.getCheckNormAction() - if not normalize: - return img - - if how == 'Do not normalize. Display raw image': - img = img - elif how == 'Convert to floating point format with values [0, 1]': - img = myutils.img_to_float(img) - # elif how == 'Rescale to 8-bit unsigned integer format with values [0, 255]': - # img = skimage.img_as_float(img) - # img = (img*255).astype(np.uint8) - # return img - elif how == 'Rescale to [0, 1]': - img = skimage.img_as_float(img) - img = skimage.exposure.rescale_intensity(img) - elif how == 'Normalize by max value': - img = img/np.max(img) - return img - - def removeAlldelROIsCurrentFrame(self): - posData = self.data[self.pos_i] - delROIs_info = posData.allData_li[posData.frame_i]['delROIs_info'] - rois = delROIs_info['rois'].copy() - for roi in rois: - self.ax2.removeDelRoiItem(roi) - - for item in self.ax2.items: - if isinstance(item, pg.ROI): - self.ax2.removeDelRoiItem(item) - - for item in self.ax1.items: - if isinstance(item, pg.ROI) and item != self.labelRoiItem: - self.ax1.removeDelRoiItem(item) - - def removeDelROI(self, event): - posData = self.data[self.pos_i] - - for ax in (self.ax1, self.ax2): - try: - self.ax1.removeDelRoiItem(self.roi_to_del) - except Exception as err: - pass - - delROIs_info = posData.allData_li[posData.frame_i]['delROIs_info'] - idx = delROIs_info['rois'].index(self.roi_to_del) - delROIs_info['rois'].pop(idx) - delROIs_info['delMasks'].pop(idx) - delROIs_info['delIDsROI'].pop(idx) - delROIs_info['state'].pop(idx) - - self.removeDelROIFromFutureFrames(self.roi_to_del) - self.updateAllImages() - - def removeDelROIFromFutureFrames(self, roi_to_del): - posData = self.data[self.pos_i] - - # Restore deleted IDs from already visited future frames - current_frame_i = posData.frame_i - for i in range(posData.frame_i+1, posData.SizeT): - if posData.allData_li[i]['labels'] is None: - break - - delROIs_info = posData.allData_li[i]['delROIs_info'] - try: - idx = delROIs_info['rois'].index(roi_to_del) - except IndexError: - continue - - posData.frame_i = i - idx = delROIs_info['rois'].index(roi_to_del) - if delROIs_info['delIDsROI'][idx]: - posData.lab = posData.allData_li[i]['labels'] - self.restoreAnnotDelROI(roi_to_del, enforce=True, draw=False) - posData.allData_li[i]['labels'] = posData.lab - self.get_data() - self.store_data(autosave=False) - delROIs_info['rois'].pop(idx) - delROIs_info['delMasks'].pop(idx) - delROIs_info['delIDsROI'].pop(idx) - delROIs_info['state'].pop(idx) - - if isinstance(self.roi_to_del, pg.PolyLineROI): - # PolyLine ROIs are only on ax1 - self.ax1.removeItem(self.roi_to_del) - elif not self.labelsGrad.showLabelsImgAction.isChecked(): - # Rect ROI is on ax1 because ax2 is hidden - self.ax1.removeItem(self.roi_to_del) - else: - # Rect ROI is on ax2 because ax2 is visible - self.ax2.removeItem(self.roi_to_del) - - # Back to current frame - posData.frame_i = current_frame_i - posData.lab = posData.allData_li[posData.frame_i]['labels'] - self.get_data() - self.store_data() - - def updateDelROIinFutureFrames(self, roi: pg.ROI): - posData = self.data[self.pos_i] - restore_current_frame = False - - roiState = roi.getState() - # Restore deleted IDs from already visited future frames - current_frame_i = posData.frame_i - delROIs_info = posData.allData_li[current_frame_i]['delROIs_info'] - try: - idx = delROIs_info['rois'].index(roi) - delROIs_info['state'][idx] = roiState - except Exception as err: - pass - - self.store_data() - - for i in range(posData.frame_i+1, posData.SizeT): - delROIs_info = posData.allData_li[i]['delROIs_info'] - try: - idx = delROIs_info['rois'].index(roi) - except Exception as err: - continue - delROIs_info['state'][idx] = roiState - if posData.allData_li[i]['labels'] is None: - continue - - posData.frame_i = i - posData.lab = posData.allData_li[i]['labels'] - self.restoreAnnotDelROI(roi, enforce=False, draw=False) - posData.allData_li[i]['labels'] = posData.lab - self.get_data() - self.store_data(autosave=False) - restore_current_frame = True - - if not restore_current_frame: - return - - # Back to current frame - posData.frame_i = current_frame_i - posData.lab = posData.allData_li[posData.frame_i]['labels'] - self.get_data() - self.store_data() - - # @exec_time - def getPolygonBrush(self, yxc2, Y, X): - # see https://en.wikipedia.org/wiki/Tangent_lines_to_circles - y1, x1 = self.yPressAx2, self.xPressAx2 - y2, x2 = yxc2 - R = self.brushSizeSpinbox.value() - r = R - - arcsin_den = np.sqrt((x2-x1)**2+(y2-y1)**2) - arctan_den = (x2-x1) - if arcsin_den!=0 and arctan_den!=0: - beta = np.arcsin((R-r)/arcsin_den) - gamma = -np.arctan((y2-y1)/arctan_den) - alpha = gamma-beta - x3 = x1 + r*np.sin(alpha) - y3 = y1 + r*np.cos(alpha) - x4 = x2 + R*np.sin(alpha) - y4 = y2 + R*np.cos(alpha) - - alpha = gamma+beta - x5 = x1 - r*np.sin(alpha) - y5 = y1 - r*np.cos(alpha) - x6 = x2 - R*np.sin(alpha) - y6 = y2 - R*np.cos(alpha) - - rr_poly, cc_poly = skimage.draw.polygon([y3, y4, y6, y5], - [x3, x4, x6, x5], - shape=(Y, X)) - else: - rr_poly, cc_poly = [], [] - - self.yPressAx2, self.xPressAx2 = y2, x2 - return rr_poly, cc_poly - - def get_dir_coords(self, alfa_dir, yd, xd, shape, connectivity=1): - h, w = shape - y_above = yd+1 if yd+1 < h else yd - y_below = yd-1 if yd > 0 else yd - x_right = xd+1 if xd+1 < w else xd - x_left = xd-1 if xd > 0 else xd - if alfa_dir == 0: - yy = [y_below, y_below, yd, y_above, y_above] - xx = [xd, x_right, x_right, x_right, xd] - elif alfa_dir == 45: - yy = [y_below, y_below, y_below, yd, y_above] - xx = [x_left, xd, x_right, x_right, x_right] - elif alfa_dir == 90: - yy = [yd, y_below, y_below, y_below, yd] - xx = [x_left, x_left, xd, x_right, x_right] - elif alfa_dir == 135: - yy = [y_above, yd, y_below, y_below, y_below] - xx = [x_left, x_left, x_left, xd, x_right] - elif alfa_dir == -180 or alfa_dir == 180: - yy = [y_above, y_above, yd, y_below, y_below] - xx = [xd, x_left, x_left, x_left, xd] - elif alfa_dir == -135: - yy = [y_below, yd, y_above, y_above, y_above] - xx = [x_left, x_left, x_left, xd, x_right] - elif alfa_dir == -90: - yy = [yd, y_above, y_above, y_above, yd] - xx = [x_left, x_left, xd, x_right, x_right] - else: - yy = [y_above, y_above, y_above, yd, y_below] - xx = [x_left, xd, x_right, x_right, x_right] - if connectivity == 1: - return yy[1:4], xx[1:4] - else: - return yy, xx - - def drawAutoContour(self, y2, x2): - y1, x1 = self.autoCont_y0, self.autoCont_x0 - Dy = abs(y2-y1) - Dx = abs(x2-x1) - edge = self.getDisplayedImg1() - if Dy != 0 or Dx != 0: - # NOTE: numIter takes care of any lag in mouseMoveEvent - numIter = int(round(max((Dy, Dx)))) - alfa = np.arctan2(y1-y2, x2-x1) - base = np.pi/4 - alfa_dir = round((base * round(alfa/base))*180/np.pi) - for _ in range(numIter): - y1, x1 = self.autoCont_y0, self.autoCont_x0 - yy, xx = self.get_dir_coords(alfa_dir, y1, x1, edge.shape) - a_dir = edge[yy, xx] - min_int = np.max(a_dir) - min_i = list(a_dir).index(min_int) - y, x = yy[min_i], xx[min_i] - try: - xx, yy = self.curvHoverPlotItem.getData() - except TypeError: - xx, yy = [], [] - - if xx is None or yy is None or len(xx) == 0 or len(yy) == 0: - xx, yy = [], [] - elif x == xx[-1] and y == yy[-1]: - # Do not append point equal to last point - return - - xx = np.r_[xx, x] - yy = np.r_[yy, y] - try: - self.curvHoverPlotItem.setData(xx, yy) - self.curvPlotItem.setData(xx, yy) - except TypeError: - pass - self.autoCont_y0, self.autoCont_x0 = y, x - # self.smoothAutoContWithSpline() - - def smoothAutoContWithSpline(self, n=3): - try: - xx, yy = self.curvHoverPlotItem.getData() - if xx is None or yy is None: - return - # Downsample by taking every nth coord - xxA, yyA = xx[::n], yy[::n] - rr, cc = skimage.draw.polygon(yyA, xxA) - self.autoContObjMask[rr, cc] = 1 - rp = skimage.measure.regionprops(self.autoContObjMask) - if not rp: - return - obj = rp[0] - cont = self.getObjContours(obj) - xxC, yyC = cont[:,0], cont[:,1] - xxA, yyA = xxC[::n], yyC[::n] - self.xxA_autoCont, self.yyA_autoCont = xxA, yyA - xxS, yyS = self.getSpline(xxA, yyA, per=True, appendFirst=True) - if len(xxS)>0: - self.curvPlotItem.setData(xxS, yyS) - except (TypeError, ValueError): - pass - - def updateIsHistoryKnown(): - """ - This function is called every time the user saves and it is used - for updating the status of cells where we don't know the history - - There are three possibilities: - - 1. The cell with unknown history is a BUD - --> we don't know when that bud emerged --> 'emerg_frame_i' = -1 - 2. The cell with unknown history is a MOTHER cell - --> we don't know emerging frame --> 'emerg_frame_i' = -1 - AND generation number --> we start from 'generation_num' = 2 - 3. The cell with unknown history is a CELL in G1 - --> we don't know emerging frame --> 'emerg_frame_i' = -1 - AND generation number --> we start from 'generation_num' = 2 - AND relative's ID in the previous cell cycle --> 'relative_ID' = -1 - """ - pass - - def getStatusKnownHistoryBud(self, ID): - posData = self.data[self.pos_i] - cca_df_ID = None - for i in range(posData.frame_i-1, -1, -1): - cca_df_i = self.get_cca_df(frame_i=i, return_df=True) - is_cell_existing = is_bud_existing = ID in cca_df_i.index - if not is_cell_existing: - bud_cca_dict = base_cca_dict.copy() - bud_cca_dict['cell_cycle_stage'] = 'S' - bud_cca_dict['generation_num'] = 0 - bud_cca_dict['relationship'] = 'bud' - bud_cca_dict['emerg_frame_i'] = i+1 - bud_cca_dict['is_history_known'] = True - cca_df_ID = pd.Series(bud_cca_dict) - return cca_df_ID - - def setHistoryKnowledge(self, ID, cca_df): - posData = self.data[self.pos_i] - is_history_known = cca_df.at[ID, 'is_history_known'] - if is_history_known: - cca_df.at[ID, 'is_history_known'] = False - cca_df.at[ID, 'cell_cycle_stage'] = 'G1' - cca_df.at[ID, 'generation_num'] += 2 - cca_df.at[ID, 'emerg_frame_i'] = -1 - cca_df.at[ID, 'relative_ID'] = -1 - cca_df.at[ID, 'relationship'] = 'mother' - else: - cca_df.loc[ID] = posData.ccaStatus_whenEmerged[ID] - - def annotateIsHistoryKnown(self, ID): - """ - This function is used for annotating that a cell has unknown or known - history. Cells with unknown history are for example the cells already - present in the first frame or cells that appear in the frame from - outside of the field of view. - - With this function we simply set 'is_history_known' to False. - When the users saves instead we update the entire staus of the cell - with unknown history with the function "updateIsHistoryKnown()" - """ - posData = self.data[self.pos_i] - is_history_known = posData.cca_df.at[ID, 'is_history_known'] - relID = posData.cca_df.at[ID, 'relative_ID'] - if relID in posData.cca_df.index: - relID_cca = self.getStatus_RelID_BeforeEmergence(ID, relID) - - if is_history_known: - # Save status of ID when emerged to allow undoing - statusID_whenEmerged = self.getStatusKnownHistoryBud(ID) - if statusID_whenEmerged is None: - return - posData.ccaStatus_whenEmerged[ID] = statusID_whenEmerged - - # Store cca_df for undo action - undoId = uuid.uuid4() - self.storeUndoRedoCca(posData.frame_i, posData.cca_df, undoId) - - if ID not in posData.ccaStatus_whenEmerged: - self.warnSettingHistoryKnownCellsFirstFrame(ID) - return - - self.setHistoryKnowledge(ID, posData.cca_df) - - if relID in posData.cca_df.index: - # If the cell with unknown history has a relative ID assigned to it - # we set the cca of it to the status it had BEFORE the assignment - posData.cca_df.loc[relID] = relID_cca - - # Update cell cycle info LabelItems - obj_idx = posData.IDs.index(ID) - rp_ID = posData.rp[obj_idx] - - if relID in posData.IDs: - relObj_idx = posData.IDs.index(relID) - rp_relID = posData.rp[relObj_idx] - - self.setAllTextAnnotations() - self.drawAllMothBudLines() - - self.store_cca_df() - - if self.ccaTableWin is not None: - zoomIDs = self.getZoomIDs() - self.ccaTableWin.updateTable(posData.cca_df, IDs=zoomIDs) - - # Correct future frames - for i in range(posData.frame_i+1, posData.SizeT): - cca_df_i = self.get_cca_df(frame_i=i, return_df=True) - if cca_df_i is None: - # ith frame was not visited yet - break - - self.storeUndoRedoCca(i, cca_df_i, undoId) - IDs = cca_df_i.index - if ID not in IDs: - # For some reason ID disappeared from this frame - continue - else: - self.setHistoryKnowledge(ID, cca_df_i) - if relID in IDs: - cca_df_i.loc[relID] = relID_cca - self.store_cca_df(frame_i=i, cca_df=cca_df_i, autosave=False) - - - # Correct past frames - for i in range(posData.frame_i-1, -1, -1): - cca_df_i = self.get_cca_df(frame_i=i, return_df=True) - if cca_df_i is None: - # ith frame was not visited yet - break - - self.storeUndoRedoCca(i, cca_df_i, undoId) - IDs = cca_df_i.index - if ID not in IDs: - # we reached frame where ID was not existing yet - break - else: - relID = cca_df_i.at[ID, 'relative_ID'] - self.setHistoryKnowledge(ID, cca_df_i) - if relID in IDs: - cca_df_i.loc[relID] = relID_cca - self.store_cca_df(frame_i=i, cca_df=cca_df_i, autosave=False) - - self.enqAutosave() - - def annotateWillDivide(self, ID, relID, frame_i=None): - posData = self.data[self.pos_i] - if frame_i is None: - frame_i = posData.frame_i - - # Store in the past frames that division has been annotated - for past_frame_i in range(frame_i-1, -1, -1): - past_cca_df = self.get_cca_df(frame_i=past_frame_i, return_df=True) - if past_cca_df is None: - return - - if ID not in past_cca_df.index: - # ID is a bud and is not emerged yet here - return - - if frame_i-1 == past_frame_i: - # Get generation number at first iteration - gen_num = past_cca_df.at[ID, 'generation_num'] - - if past_cca_df.at[ID, 'generation_num'] != gen_num: - # ID is a mother and the cell cycle is finished here - return - - past_cca_df.at[ID, 'will_divide'] = 1 - past_cca_df.at[relID, 'will_divide'] = 1 - - self.store_cca_df( - cca_df=past_cca_df, frame_i=past_frame_i, autosave=False - ) - - def annotateDivisionFutureFramesSwapMothers( - self, cca_df_at_future_division, mothIDofDisappearedBud, frame_i - ): - """This method is called as part of `guiWin.swapMothers`. - - It annotates cell division and propagates that to future frames to the - mother cell that stops having the correct bud because division between - wrong bud and other wrong mother was annotated in the future. - - Parameters - ---------- - cca_df_at_future_division : pd.DataFrame - _description_ - mothIDofDisappearedBud : int - Mother ID of the disappeared bud - frame_i : int - Frame since when the mother ID stops having the correct bud because - the correct bud was assigned as divided from the wrong mother - """ - posData = self.data[self.pos_i] - - relativeIDofMothID = cca_df_at_future_division.at[ - mothIDofDisappearedBud, 'relative_ID' - ] - if relativeIDofMothID not in cca_df_at_future_division.index: - # Also wrong bud ID disappeared - return - - relativeIDofMothIDrelationship = cca_df_at_future_division.at[ - relativeIDofMothID, 'relationship' - ] - if relativeIDofMothIDrelationship != 'bud': - # The wrong bud ID is a cell in G1 from future cycle --> - # the actual wrong bud ID disappeared too. - return - - wrongBudID = relativeIDofMothID - - self.annotateDivision( - cca_df_at_future_division, mothIDofDisappearedBud, wrongBudID, - frame_i=frame_i - ) - cca_df_at_future_division.at[ - mothIDofDisappearedBud, 'corrected_on_frame_i'] = frame_i - self.store_cca_df( - frame_i=frame_i, cca_df=cca_df_at_future_division, autosave=False - ) - - ccaStatusToRestore = cca_df_at_future_division.loc[mothIDofDisappearedBud] - for future_i in range(frame_i+1, posData.SizeT): - # Get cca_df for ith frame from allData_li - cca_df_i = self.get_cca_df(frame_i=future_i, return_df=True) - if cca_df_i is None: - # ith frame was not visited yet - break - - ccs = cca_df_i.at[mothIDofDisappearedBud, 'cell_cycle_stage'] - if ccs == 'G1': - # Mother cell in G1 again, stop correcting - break - - cca_df_i.loc[mothIDofDisappearedBud] = ccaStatusToRestore - cca_df_i.at[mothIDofDisappearedBud, 'corrected_on_frame_i'] = frame_i - - self.store_cca_df(frame_i=future_i, cca_df=cca_df_i, autosave=False) - - def annotateDivision(self, cca_df, ID, relID, frame_i=None): - # Correct as follows: - # For frame_i > 0 --> assign to G1 and +1 on generation number - # For frame == 0 --> reinitialize to unknown cells - posData = self.data[self.pos_i] - if frame_i is None: - frame_i = posData.frame_i - - self.annotateWillDivide(ID, relID) - - store = False - cca_df.at[ID, 'cell_cycle_stage'] = 'G1' - cca_df.at[relID, 'cell_cycle_stage'] = 'G1' - - if frame_i > 0: - gen_num_clickedID = cca_df.at[ID, 'generation_num'] - cca_df.at[ID, 'generation_num'] += 1 - cca_df.at[ID, 'division_frame_i'] = frame_i - gen_num_relID = cca_df.at[relID, 'generation_num'] - cca_df.at[relID, 'generation_num'] = gen_num_relID+1 - cca_df.at[relID, 'division_frame_i'] = frame_i - if gen_num_clickedID < gen_num_relID: - cca_df.at[ID, 'relationship'] = 'mother' - else: - cca_df.at[relID, 'relationship'] = 'mother' - else: - cca_df.at[ID, 'generation_num'] = 2 - cca_df.at[relID, 'generation_num'] = 2 - - cca_df.at[ID, 'division_frame_i'] = -1 - cca_df.at[relID, 'division_frame_i'] = -1 - - cca_df.at[ID, 'relationship'] = 'mother' - cca_df.at[relID, 'relationship'] = 'mother' - - store = True - return store - - def undoDivisionAnnotation(self, cca_df, ID, relID): - # Correct as follows: - # If G1 then correct to S and -1 on generation number - store = False - cca_df.at[ID, 'cell_cycle_stage'] = 'S' - gen_num_clickedID = cca_df.at[ID, 'generation_num'] - cca_df.at[ID, 'generation_num'] -= 1 - cca_df.at[ID, 'division_frame_i'] = -1 - cca_df.at[relID, 'cell_cycle_stage'] = 'S' - gen_num_relID = cca_df.at[relID, 'generation_num'] - cca_df.at[relID, 'generation_num'] -= 1 - cca_df.at[relID, 'division_frame_i'] = -1 - if gen_num_clickedID < gen_num_relID: - cca_df.at[ID, 'relationship'] = 'bud' - else: - cca_df.at[relID, 'relationship'] = 'bud' - cca_df.at[ID, 'will_divide'] = 0 - cca_df.at[relID, 'will_divide'] = 0 - store = True - return store - - def undoBudMothAssignment(self, ID): - posData = self.data[self.pos_i] - relID = posData.cca_df.at[ID, 'relative_ID'] - ccs = posData.cca_df.at[ID, 'cell_cycle_stage'] - if ccs == 'G1': - return - posData.cca_df.at[ID, 'relative_ID'] = -1 - posData.cca_df.at[ID, 'generation_num'] = 2 - posData.cca_df.at[ID, 'cell_cycle_stage'] = 'G1' - posData.cca_df.at[ID, 'relationship'] = 'mother' - if relID in posData.cca_df.index: - posData.cca_df.at[relID, 'relative_ID'] = -1 - posData.cca_df.at[relID, 'generation_num'] = 2 - posData.cca_df.at[relID, 'cell_cycle_stage'] = 'G1' - posData.cca_df.at[relID, 'relationship'] = 'mother' - - obj_idx = posData.IDs.index(ID) - relObj_idx = posData.IDs.index(relID) - rp_ID = posData.rp[obj_idx] - rp_relID = posData.rp[relObj_idx] - - self.store_cca_df() - - # Update cell cycle info LabelItems - self.setAllTextAnnotations() - - if self.ccaTableWin is not None: - zoomIDs = self.getZoomIDs() - self.ccaTableWin.updateTable(posData.cca_df, IDs=zoomIDs) - - @exception_handler - def manualCellCycleAnnotation(self, ID): - """ - This function is used for both annotating division or undoing the - annotation. It can be called on any frame. - - If we annotate division (right click on a cell in S) then it will - check if there are future frames to correct. - Frames to correct are those frames where both the mother and the bud - are annotated as S phase cells. - In this case we assign all those frames to G1, relationship to mother, - and +1 generation number - - If we undo the annotation (right click on a cell in G1) then it will - correct both past and future annotated frames (if present). - Frames to correct are those frames where both the mother and the bud - are annotated as G1 phase cells. - In this case we assign all those frames to G1, relationship back to - bud, and -1 generation number - """ - posData = self.data[self.pos_i] - - # Store cca_df for undo action - undoId = uuid.uuid4() - self.storeUndoRedoCca(posData.frame_i, posData.cca_df, undoId) - - # Correct current frame - clicked_ccs = posData.cca_df.at[ID, 'cell_cycle_stage'] - relID = posData.cca_df.at[ID, 'relative_ID'] - - if relID not in posData.IDs: - return - - if clicked_ccs == 'G1' and posData.frame_i == 0: - # We do not allow undoing division annotation on first frame - return - - if clicked_ccs == 'G1': - issue_frame_i = self.checkDivisionCanBeUndone(ID, relID) - if issue_frame_i is not None: - _warnings.warnDivisionAnnotationCannotBeUndone( - ID, relID, issue_frame_i, qparent=self - ) - return - - if clicked_ccs == 'S': - self.annotateDivision(posData.cca_df, ID, relID) - self.store_cca_df() - else: - self.undoDivisionAnnotation(posData.cca_df, ID, relID) - self.store_cca_df() - - # Update cell cycle info LabelItems - self.ax1_newMothBudLinesItem.setData([], []) - self.ax1_oldMothBudLinesItem.setData([], []) - self.ax2_newMothBudLinesItem.setData([], []) - self.ax2_oldMothBudLinesItem.setData([], []) - self.drawAllMothBudLines() - self.setAllTextAnnotations() - - if self.ccaTableWin is not None: - zoomIDs = self.getZoomIDs() - self.ccaTableWin.updateTable(posData.cca_df, IDs=zoomIDs) - - # Correct future frames - for future_i in range(posData.frame_i+1, posData.SizeT): - cca_df_i = self.get_cca_df(frame_i=future_i, return_df=True) - if cca_df_i is None: - # ith frame was not visited yet - break - - self.storeUndoRedoCca(future_i, cca_df_i, undoId) - IDs = cca_df_i.index - if ID not in IDs: - # For some reason ID disappeared from this frame - continue - - ccs = cca_df_i.at[ID, 'cell_cycle_stage'] - relID = cca_df_i.at[ID, 'relative_ID'] - if clicked_ccs == 'S': - if ccs == 'G1': - # Cell is in G1 in the future again so stop annotating - break - self.annotateDivision(cca_df_i, ID, relID) - self.store_cca_df( - frame_i=future_i, cca_df=cca_df_i, autosave=False - ) - elif ccs == 'S': - # Cell is in S in the future again so stop undoing (break) - # also leave a 1 frame duration G1 to avoid a continuous - # S phase - self.annotateDivision(cca_df_i, ID, relID) - self.store_cca_df( - frame_i=future_i, cca_df=cca_df_i, autosave=False - ) - break - else: - self.undoDivisionAnnotation(cca_df_i, ID, relID) - self.store_cca_df( - frame_i=future_i, cca_df=cca_df_i, autosave=False - ) - - # Correct past frames - for past_i in range(posData.frame_i-1, -1, -1): - cca_df_i = self.get_cca_df(frame_i=past_i, return_df=True) - if ID not in cca_df_i.index or relID not in cca_df_i.index: - # Bud did not exist at frame_i = i - break - - self.storeUndoRedoCca(past_i, cca_df_i, undoId) - ccs = cca_df_i.at[ID, 'cell_cycle_stage'] - relID = cca_df_i.at[ID, 'relative_ID'] - if ccs == 'S': - # We correct only those frames in which the ID was in 'G1' - break - else: - store = self.undoDivisionAnnotation(cca_df_i, ID, relID) - self.store_cca_df( - frame_i=past_i, cca_df=cca_df_i, autosave=False - ) - - self.enqAutosave() - - def warnMotherNotEligible(self, new_mothID, budID, i, why): - if why == 'not_G1_in_the_future': - err_msg = html_utils.paragraph(f""" - The requested cell in G1 (ID={new_mothID}) - at future frame {i+1} has a bud assigned to it, - therefore it cannot be assigned as the mother - of bud ID {budID}.

- You can assign a cell as the mother of bud ID {budID} - only if this cell is in G1 for the - entire life of the bud.

- One possible solution is to click on "cancel", go to - frame {i+1} and assign the bud of cell {new_mothID} - to another cell.\n' - A second solution is to assign bud ID {budID} to cell - {new_mothID} anyway by clicking "Apply".

- However to ensure correctness of - future assignments Cell-ACDC will delete any cell cycle - information from frame {i+1} to the end. Therefore, you - will have to visit those frames again.

- The deletion of cell cycle information - CANNOT BE UNDONE! - Saved data is not changed of course.

- Apply assignment or cancel process? - """) - applyButton = widgets.okPushButton(isDefault=False) - applyButton.setText('Apply and remove future annotations') - msg = widgets.myMessageBox() - _, applyButton = msg.warning( - self, 'Cell not eligible', err_msg, - buttonsTexts=('Cancel', applyButton) - ) - cancel = msg.cancel - apply = msg.clickedButton == applyButton - elif why == 'not_G1_in_the_past': - err_msg = html_utils.paragraph(f""" - The requested cell in G1 - (ID={new_mothID}) at past frame {i+1} - has a bud assigned to it, therefore it cannot be - assigned as mother of bud ID {budID}.
- You can assign a cell as the mother of bud ID {budID} - only if this cell is in G1 for the entire life of the bud.
- One possible solution is to first go to frame {i+1} and - assign the bud of cell {new_mothID} to another cell. - """) - msg = widgets.myMessageBox() - msg.warning( - self, 'Cell not eligible', err_msg - ) - cancel = msg.cancel - apply = False - elif why == 'single_frame_G1_duration': - err_msg = html_utils.paragraph(f""" - Assigning bud ID {budID} to cell ID {new_mothID} would result - in no G1 phase at all between previous cell cycle and - current cell cycle (see frame n. {i+1}).

- - The solution is to annotate division on cell ID {new_mothID} - on any frame before the frame number {i+1}, and then - proceed to correcting the bud assignment.

- - This will gurantee a G1 duration for the cell {new_mothID} - of at least 1 frame.

- Thank you for your patience! - """) - msg = widgets.myMessageBox() - msg.warning( - self, 'Cell not eligible', err_msg - ) - cancel = msg.cancel - apply = False - return cancel, apply - - def warnSettingHistoryKnownCellsFirstFrame(self, ID): - txt = html_utils.paragraph(f""" - Cell ID {ID} is a cell that is present since the first - frame.

- These cells already have history UNKNOWN assigned and the - history status cannot be changed. - """) - msg = widgets.myMessageBox(wrapText=False) - msg.warning( - self, 'First frame cells', txt - ) - - def checkMothEligibility(self, budID, new_mothID): - """ - Check that the new mother is in G1 for the entire life of the bud - and that the G1 duration is > than 1 frame - """ - last_cca_frame_i = self.navigateScrollBar.maximum()-1 - posData = self.data[self.pos_i] - eligible = True - - # Check future frames - G1_duration_future = 0 - for future_i in range(posData.frame_i, posData.SizeT): - cca_df_i = self.get_cca_df(frame_i=future_i, return_df=True) - - if cca_df_i is None: - # ith frame was not visited yet - break - - if budID not in cca_df_i.index: - # Bud disappeared - break - - is_still_bud = cca_df_i.at[budID, 'relationship'] == 'bud' - if not is_still_bud: - break - - ccs = cca_df_i.at[new_mothID, 'cell_cycle_stage'] - if ccs != 'G1': - cancel, apply = self.warnMotherNotEligible( - new_mothID, budID, future_i, 'not_G1_in_the_future' - ) - if apply: - self.resetCcaFuture(future_i) - break - isG1singleFrame = G1_duration_future == 1 - isFutureFrameNotLastAnnot = future_i != last_cca_frame_i - if cancel or (isG1singleFrame and isFutureFrameNotLastAnnot): - eligible = False - return eligible - - G1_duration_future += 1 - - # Check past frames - for past_i in range(posData.frame_i-1, -1, -1): - # Get cca_df for ith frame from allData_li - cca_df_i = self.get_cca_df(frame_i=past_i, return_df=True) - - is_bud_existing = budID in cca_df_i.index - is_moth_existing = new_mothID in cca_df_i.index - - if not is_moth_existing: - # Mother not existing because it appeared from outside FOV - break - - ccs = cca_df_i.at[new_mothID, 'cell_cycle_stage'] - if ccs != 'G1' and is_bud_existing: - # Requested mother not in G1 in the past - # during the life of the bud (is_bud_existing = True) - self.warnMotherNotEligible( - new_mothID, budID, past_i, 'not_G1_in_the_past' - ) - eligible = False - return eligible - - if not is_bud_existing: - # Bud stop existing --> check that mother is still in G1 - if ccs != 'G1': - eligible = False - self.warnMotherNotEligible( - new_mothID, budID, past_i, 'single_frame_G1_duration' - ) - break - - return eligible - - def checkMothersExcludedOrDead(self): - try: - posData = self.data[self.pos_i] - buds_df = posData.cca_df[ - (posData.cca_df.relationship == 'bud') - & (posData.cca_df.emerg_frame_i == posData.frame_i) - ] - acdc_df_i = posData.allData_li[posData.frame_i]['acdc_df'] - moth_df = acdc_df_i.loc[buds_df.relative_ID.to_list()] - excluded_df = moth_df[ - (moth_df.is_cell_dead > 0) | (moth_df.is_cell_excluded > 0) - ] - excludedMothIDs = excluded_df.index.to_list() - if not excludedMothIDs: - self.stopBlinkingPairItem() - return True - budIDsOfExcludedMoth = excluded_df.relative_ID.to_list() - proceed = self.warnDeadOrExcludedMothers( - budIDsOfExcludedMoth, excludedMothIDs - ) - return proceed - except Exception as e: - self.logger.info(traceback.format_exc()) - print('-'*100) - self.logger.warning( - 'Checking if mother cell is excluded or dead failed.' - ) - print('^'*100) - return False - - def checkDivisionCanBeUndone(self, ID, relID): - """Check that division annotation can be undone (see Notes section) - - Parameters - ---------- - ID : int - Cell ID of the clicked cell in G1 - relID : _type_ - Relative ID of the cell that was clicked - - Notes - ----- - Division annotation can be undone only if `relID` is also in G1 for the - entire duration of the correction - """ - posData = self.data[self.pos_i] - - ccs_relID = posData.cca_df.at[relID, 'cell_cycle_stage'] - if ccs_relID == 'S': - return posData.frame_i - - # Check future frames - for future_i in range(posData.frame_i+1, posData.SizeT): - cca_df_i = self.get_cca_df(frame_i=future_i, return_df=True) - if cca_df_i is None: - # ith frame was not visited yet - break - - ccs_relID = cca_df_i.at[relID, 'cell_cycle_stage'] - if ccs_relID == 'S': - return future_i - - # Check past frames - for past_i in range(posData.frame_i-1, -1, -1): - cca_df_i = self.get_cca_df(frame_i=past_i, return_df=True) - if ID not in cca_df_i.index or relID not in cca_df_i.index: - # Bud did not exist at frame_i = i - break - - ccs = cca_df_i.at[ID, 'cell_cycle_stage'] - if ccs == 'S': - break - - ccs_relID = cca_df_i.at[relID, 'cell_cycle_stage'] - if ccs_relID == 'S': - return future_i - - - def stopBlinkingPairItem(self): - self.ax1_newMothBudLinesItem.setOpacity(1.0) - self.ax1_oldMothBudLinesItem.setOpacity(1.0) - - self.warnPairingItem.setData([], []) - try: - self.blinkPairingItemTimer.stop() - except Exception as e: - pass - - def warnDeadOrExcludedMothers(self, budIDs, mothIDs): - self.startBlinkingPairingItem(budIDs, mothIDs) - msg = widgets.myMessageBox(wrapText=False) - pairings = [ - f'Mother ID {mID} --> bud ID {bID}' - for mID, bID in zip(mothIDs, budIDs) - ] - txt = html_utils.paragraph(f""" - The mother cell in the following mother-bud pairings - (blinking line on the image) is
- excluded from the analysis or dead: - {html_utils.to_list(pairings)} - """) - msg.warning( - self, 'Mother cell is excluded or dead', txt, - buttonsTexts=('Cancel', 'Ok') - ) - return not msg.cancel - - def startBlinkingPairingItem(self, budIDs, mothIDs): - self.ax1_newMothBudLinesItem.setOpacity(0.2) - self.ax1_oldMothBudLinesItem.setOpacity(0.2) - - posData = self.data[self.pos_i] - acdc_df_i = posData.allData_li[posData.frame_i]['acdc_df'] - - # Blink one pairing at the time (the first found) - xc_b = acdc_df_i.loc[budIDs[0], 'x_centroid'] - yc_b = acdc_df_i.loc[budIDs[0], 'y_centroid'] - - xc_m = acdc_df_i.loc[mothIDs[0], 'x_centroid'] - yc_m = acdc_df_i.loc[mothIDs[0], 'y_centroid'] - - self.warnPairingItem.setData([xc_b, xc_m], [yc_b, yc_m]) - - self.blinkPairingItemTimer = QTimer() - self.blinkPairingItemTimer.flag = True - self.blinkPairingItemTimer.timeout.connect(self.blinkPairingItem) - self.blinkPairingItemTimer.start(300) - - def blinkPairingItem(self): - if self.blinkPairingItemTimer.flag: - opacity = 0.3 - self.blinkPairingItemTimer.flag = False - else: - opacity = 1.0 - self.blinkPairingItemTimer.flag = True - self.warnPairingItem.setOpacity(opacity) - - def getStatus_RelID_BeforeEmergence(self, budID, curr_mothID): - posData = self.data[self.pos_i] - # Get status of the current mother before it had budID assigned to it - cca_status_before_bud_emerg = None - for i in range(posData.frame_i-1, -1, -1): - # Get cca_df for ith frame from allData_li - cca_df_i = self.get_cca_df(frame_i=i, return_df=True) - - is_bud_existing = budID in cca_df_i.index - if not is_bud_existing: - # Bud was not emerged yet - if curr_mothID in cca_df_i.index: - cca_status_before_bud_emerg = cca_df_i.loc[curr_mothID] - return cca_status_before_bud_emerg - else: - # The bud emerged together with the mother because - # they appeared together from outside of the fov - # and they were trated as new IDs bud in S0 - bud_cca_dict = base_cca_dict.copy() - bud_cca_dict['cell_cycle_stage'] = 'S' - bud_cca_dict['generation_num'] = 0 - bud_cca_dict['relationship'] = 'bud' - bud_cca_dict['emerg_frame_i'] = i+1 - bud_cca_dict['is_history_known'] = True - cca_status_before_bud_emerg = pd.Series(bud_cca_dict) - return cca_status_before_bud_emerg - - # Mother did not have a status before bud emergence because it was - # already paired with bud at first frame --> reinit to default - cca_status_before_bud_emerg = ( - core.getBaseCca_df([curr_mothID]).loc[curr_mothID] - ) - return cca_status_before_bud_emerg - - - def annotateBudToDifferentMother(self): - """ - This function is used for correcting automatic mother-bud assignment. - - It can be called at any frame of the bud life. - - There are three cells involved: bud, current mother, new mother. - - Eligibility: - - User clicked first on a bud (checked at click time) - - User released mouse button on a cell in G1 (checked at release time) - - The new mother MUST be in G1 for all the frames of the bud life - --> if not warn - - The new mother MUST have appeared in current frame OR be already - in G1 in previous frame, otherwise there would be no G1 cycle - - Result: - - The bud only changes relative ID to the new mother - - The new mother changes relative ID and stage to 'S' - - The old mother changes its entire status to the status it had - before being assigned to the clicked bud - """ - posData = self.data[self.pos_i] - lab2D = self.get_2Dlab(posData.lab) - budID = lab2D[self.yClickBud, self.xClickBud] - new_mothID = lab2D[self.yClickMoth, self.xClickMoth] - - if budID == new_mothID: - return - - if not self.isSnapshot: - eligible = self.checkMothEligibility(budID, new_mothID) - if not eligible: - return - - budEligible = self.checkChangeMotherBudEligible( - budID, posData.frame_i - ) - if not budEligible: - return - - # Allow partial initialization of cca_df with mouse - if posData.frame_i == 0: - newMothCcs = posData.cca_df.at[new_mothID, 'cell_cycle_stage'] - if not newMothCcs == 'G1': - err_msg = ( - 'You are assigning the bud to a cell that is not in G1!' - ) - msg = QMessageBox() - msg.critical( - self, 'New mother not in G1!', err_msg, msg.Ok - ) - return - # Store cca_df for undo action - undoId = uuid.uuid4() - self.storeUndoRedoCca(0, posData.cca_df, undoId) - currentRelID = posData.cca_df.at[budID, 'relative_ID'] - if currentRelID in posData.cca_df.index: - posData.cca_df.at[currentRelID, 'relative_ID'] = -1 - posData.cca_df.at[currentRelID, 'generation_num'] = 2 - posData.cca_df.at[currentRelID, 'cell_cycle_stage'] = 'G1' - posData.cca_df.at[budID, 'relationship'] = 'bud' - posData.cca_df.at[budID, 'generation_num'] = 0 - posData.cca_df.at[budID, 'relative_ID'] = new_mothID - posData.cca_df.at[budID, 'cell_cycle_stage'] = 'S' - posData.cca_df.at[new_mothID, 'relative_ID'] = budID - posData.cca_df.at[new_mothID, 'generation_num'] = 2 - posData.cca_df.at[new_mothID, 'cell_cycle_stage'] = 'S' - self.updateAllImages() - self.store_cca_df() - return - - curr_mothID = posData.cca_df.at[budID, 'relative_ID'] - if curr_mothID in posData.cca_df.index: - curr_moth_cca = self.getStatus_RelID_BeforeEmergence( - budID, curr_mothID - ) - - # Store cca_df for undo action - undoId = uuid.uuid4() - self.storeUndoRedoCca(posData.frame_i, posData.cca_df, undoId) - - # Correct current frames and update LabelItems - posData.cca_df.at[budID, 'relative_ID'] = new_mothID - posData.cca_df.at[budID, 'generation_num'] = 0 - posData.cca_df.at[budID, 'relative_ID'] = new_mothID - posData.cca_df.at[budID, 'relationship'] = 'bud' - posData.cca_df.at[budID, 'corrected_on_frame_i'] = posData.frame_i - posData.cca_df.at[budID, 'cell_cycle_stage'] = 'S' - - posData.cca_df.at[new_mothID, 'relative_ID'] = budID - posData.cca_df.at[new_mothID, 'cell_cycle_stage'] = 'S' - posData.cca_df.at[new_mothID, 'relationship'] = 'mother' - - - if curr_mothID in posData.cca_df.index: - # Cells with UNKNOWN history has relative's ID = -1 - # which is not an existing cell - posData.cca_df.loc[curr_mothID] = curr_moth_cca - - self.updateAllImages() - - # self.checkMultiBudMoth(draw=True) - self.store_cca_df() - proceed = self.checkMothersExcludedOrDead() - if not proceed: - # User clicked on cancel in the message box - self.UndoCca() - return - - if self.ccaTableWin is not None: - zoomIDs = self.getZoomIDs() - self.ccaTableWin.updateTable(posData.cca_df, IDs=zoomIDs) - - # Correct future frames - for i in range(posData.frame_i+1, posData.SizeT): - # Get cca_df for ith frame from allData_li - cca_df_i = self.get_cca_df(frame_i=i, return_df=True) - if cca_df_i is None: - # ith frame was not visited yet - break - - IDs = cca_df_i.index - if budID not in IDs or new_mothID not in IDs: - # For some reason ID disappeared from this frame - continue - - self.storeUndoRedoCca(i, cca_df_i, undoId) - bud_relationship = cca_df_i.at[budID, 'relationship'] - bud_ccs = cca_df_i.at[budID, 'cell_cycle_stage'] - - if bud_relationship == 'mother' and bud_ccs == 'S': - # The bud at the ith frame budded itself --> stop - break - - cca_df_i.at[budID, 'relative_ID'] = new_mothID - cca_df_i.at[budID, 'generation_num'] = 0 - cca_df_i.at[budID, 'relative_ID'] = new_mothID - cca_df_i.at[budID, 'relationship'] = 'bud' - cca_df_i.at[budID, 'cell_cycle_stage'] = 'S' - - newMoth_bud_ccs = cca_df_i.at[new_mothID, 'cell_cycle_stage'] - if newMoth_bud_ccs == 'G1': - # Assign bud to new mother only if the new mother is in G1 - # This can happen if the bud already has a G1 annotated - cca_df_i.at[new_mothID, 'relative_ID'] = budID - cca_df_i.at[new_mothID, 'cell_cycle_stage'] = 'S' - cca_df_i.at[new_mothID, 'relationship'] = 'mother' - - if curr_mothID in cca_df_i.index: - # Cells with UNKNOWN history has relative's ID = -1 - # which is not an existing cell - cca_df_i.loc[curr_mothID] = curr_moth_cca - - self.store_cca_df(frame_i=i, cca_df=cca_df_i, autosave=False) - - # Correct past frames - for i in range(posData.frame_i-1, -1, -1): - # Get cca_df for ith frame from allData_li - cca_df_i = self.get_cca_df(frame_i=i, return_df=True) - - is_bud_existing = budID in cca_df_i.index - if not is_bud_existing: - # Bud was not emerged yet - break - - self.storeUndoRedoCca(i, cca_df_i, undoId) - cca_df_i.at[budID, 'relative_ID'] = new_mothID - cca_df_i.at[budID, 'generation_num'] = 0 - cca_df_i.at[budID, 'relative_ID'] = new_mothID - cca_df_i.at[budID, 'relationship'] = 'bud' - cca_df_i.at[budID, 'cell_cycle_stage'] = 'S' - - cca_df_i.at[new_mothID, 'relative_ID'] = budID - cca_df_i.at[new_mothID, 'cell_cycle_stage'] = 'S' - cca_df_i.at[new_mothID, 'relationship'] = 'mother' - - if curr_mothID in cca_df_i.index: - # Cells with UNKNOWN history has relative's ID = -1 - # which is not an existing cell - cca_df_i.loc[curr_mothID] = curr_moth_cca - - self.store_cca_df(frame_i=i, cca_df=cca_df_i, autosave=False) - - self.enqAutosave() - - def onMotherNotInG1(self, mothID): - txt = html_utils.paragraph( - f'You clicked on ID={mothID} which is NOT in G1

' - 'Do you want to proceed with swapping the mother cells?

' - 'NOTE: To assign a bud start by clicking on the bud ' - 'and release on a cell in G1' - ) - msg = widgets.myMessageBox() - swapMothersButton = widgets.reloadPushButton('Swap mother cells') - _, swapMothersButton = msg.warning( - self, 'Released on a cell NOT in G1', txt, - buttonsTexts=('Cancel', swapMothersButton) - ) - if msg.cancel: - return - - pairings = self.checkSwapMothersEligibility() - if pairings is None: - self.logger.info('Swapping mothers is not possible.') - return - - self.swapMothers(*pairings) - - def _checkBudFutureNoDivision(self, budID, start_frame_i): - posData = self.data[self.pos_i] - - future_i = start_frame_i - for future_i in range(start_frame_i, posData.SizeT): - if future_i == 0: - continue - - # Get cca_df for ith frame from allData_li - cca_df_i = self.get_cca_df(frame_i=future_i, return_df=True) - if cca_df_i is None: - # ith frame was not visited yet - return - - if budID not in cca_df_i.index: - # Bud disappears in the future --> fine - return - - ccs = cca_df_i.at[budID, 'cell_cycle_stage'] - if ccs == 'G1': - return future_i, cca_df_i.at[budID, 'relative_ID'] - - def warnBudAnnotatedDividedInFuture( - self, budID, motherID, future_division_frame_i, - action='swap mother cells' - ): - posData = self.data[self.pos_i] - - txt = html_utils.paragraph(f""" - Bud ID {budID} is annotated as divided from mother ID {motherID} - at frame n. {future_division_frame_i+1},
- therefore it is not possible to {action}.

- We recommend reinitializing cell cycle annotations on any - frame
between frames number {posData.frame_i+1} and - {future_division_frame_i} before attempting to {action}.

- Thank you for your patience! - """) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, f'{action} not possible'.title(), txt) - return - - def _checkMothInG1beforeBudEmergence( - self, motherID, budID, wrongBudID, start_frame_i - ): - """Check that mother is in G1 on the frame before bud emergence - - Parameters - ---------- - motherID : int - ID of mother cell - budID : int - ID of bud - start_frame_i : int - Frame index from which to start checking in the past - """ - for past_i in range(start_frame_i, -1, -1): - cca_df_i = self.get_cca_df(frame_i=past_i, return_df=True) - if budID not in cca_df_i.index: - if cca_df_i.at[motherID, 'cell_cycle_stage'] == 'G1': - return - - budID_prev_cycle = cca_df_i.at[motherID, 'relative_ID'] - if budID_prev_cycle != wrongBudID: - return past_i + 1 - - break - - def warnMotherNotAtLeastOneFrameG1(self, budID, motherID, frame_no_G1): - posData = self.data[self.pos_i] - - txt = html_utils.paragraph(f""" - Assigning bud ID {budID} to cell ID {motherID} cannot be - done because cell ID {motherID} is not in G1 at frame n. - {frame_no_G1}.

- This would result in no G1 phase between previous cell cycle of - cell ID {motherID} and current one. - This is unfortunately not allowed.

- One possible solution is to annotate division on cell ID - {motherID} on any frame before frame n. {frame_no_G1}.

- Thank you for your patience! - """) - msg = widgets.myMessageBox(wrapText=False) - msg.warning(self, 'Swap mothers not possible', txt) - return - - def checkChangeMotherBudEligible(self, budID, frame_i): - result = self._checkBudFutureNoDivision(budID, frame_i) - if result is None: - return True - - self.warnBudAnnotatedDividedInFuture( - budID, *result, action='change mother cell' - ) - return False - - def checkSwapMothersEligibility(self): - posData = self.data[self.pos_i] - - lab2D = self.get_2Dlab(posData.lab) - budID = lab2D[self.yClickBud, self.xClickBud] - otherMothID = lab2D[self.yClickMoth, self.xClickMoth] - mothID = posData.cca_df.at[budID, 'relative_ID'] - otherBudID = posData.cca_df.at[otherMothID, 'relative_ID'] - - for _budID in (budID, otherBudID): - result = self._checkBudFutureNoDivision( - _budID, posData.frame_i - ) - if result is None: - continue - - self.warnBudAnnotatedDividedInFuture(_budID, *result) - return - - correct_pairings = { - otherBudID: mothID, budID: otherMothID - } - wrong_pairings = { - mothID: budID, otherMothID: otherBudID - } - for correctBudID, correctMothID in correct_pairings.items(): - wrongBudID = wrong_pairings[correctMothID] - frame_no_G1 = self._checkMothInG1beforeBudEmergence( - correctMothID, correctBudID, wrongBudID, posData.frame_i - ) - if frame_no_G1 is None: - continue - - self.warnMotherNotAtLeastOneFrameG1( - correctBudID, correctMothID, frame_no_G1 - ) - return - - return budID, otherBudID, otherMothID, mothID - - @exception_handler - def swapMothers(self, budID, otherBudID, otherMothID, mothID): - posData = self.data[self.pos_i] - - # Store cca_df for undo action - undoId = uuid.uuid4() - self.storeUndoRedoCca(posData.frame_i, posData.cca_df, undoId) - - self.logger.info( - f'Swapping assignments (requested at frame n. {posData.frame_i+1}):\n' - f' * Bud ID {budID} --> mother ID {otherMothID}\n' - f' * Bud ID {otherBudID} --> mother ID {mothID}' - ) - - correct_pairings = { - otherBudID: mothID, - budID: otherMothID - } - - for correct_budID, correct_mothID in correct_pairings.items(): - posData.cca_df.at[correct_budID, 'relative_ID'] = correct_mothID - posData.cca_df.at[correct_mothID, 'relative_ID'] = correct_budID - posData.cca_df.at[correct_budID, 'corrected_on_frame_i'] = posData.frame_i - posData.cca_df.at[correct_mothID, 'corrected_on_frame_i'] = posData.frame_i - self.store_cca_df() - - # Correct past frames - corrected_budIDs_past = set() - for past_i in range(posData.frame_i-1, -1, -1): - if len(corrected_budIDs_past) == 2: - break - - for correct_budID, correct_mothID in correct_pairings.items(): - # Get cca_df for ith frame from allData_li - cca_df_i = self.get_cca_df(frame_i=past_i, return_df=True) - - if correct_budID in corrected_budIDs_past: - continue - - if correct_budID not in cca_df_i.index: - # Bud does not exist anymore in the past - corrected_budIDs_past.add(correct_budID) - - if len(corrected_budIDs_past) < 2: - self.restoreMotherToBeforeWrongBudWasAssignedToIt( - correct_mothID, cca_df_i, past_i - ) - continue - - cca_df_i.at[correct_budID, 'relative_ID'] = correct_mothID - cca_df_i.at[correct_mothID, 'relative_ID'] = correct_budID - cca_df_i.at[correct_budID, 'corrected_on_frame_i'] = posData.frame_i - cca_df_i.at[correct_mothID, 'corrected_on_frame_i'] = posData.frame_i - - # Set mother cell cycle stage to S in case it is not - if cca_df_i.at[correct_mothID, 'cell_cycle_stage'] == 'G1': - cca_df_i.at[correct_mothID, 'cell_cycle_stage'] = 'S' - # cca_df_i.at[correct_mothID, 'generation_num'] -= 1 - - self.store_cca_df( - frame_i=past_i, cca_df=cca_df_i, autosave=False - ) - - # Correct future frames - corrected_budIDs_future = set() - for future_i in range(posData.frame_i+1, posData.SizeT): - if len(corrected_budIDs_future) == 2: - break - - # Get cca_df for ith frame from allData_li - cca_df_i = self.get_cca_df(frame_i=future_i, return_df=True) - if cca_df_i is None: - # ith frame was not visited yet - break - - for correct_budID, correct_mothID in correct_pairings.items(): - if correct_budID in corrected_budIDs_future: - # Bud already corrected in the future - continue - - if correct_budID not in cca_df_i.index: - # Bud disappeared in the future - corrected_budIDs_future.add(correct_budID) - continue - - ccs_bud = cca_df_i.at[correct_budID, 'cell_cycle_stage'] - if ccs_bud == 'G1': - # Bud divided in the future, annotate division between - # correct mother and wrong bud and then stop correcting - if correct_budID not in corrected_budIDs_future: - corrected_budIDs_future.add(correct_budID) - - if len(corrected_budIDs_future) < 2: - self.annotateDivisionFutureFramesSwapMothers( - cca_df_i, correct_mothID, future_i - ) - continue - - cca_df_i.at[correct_budID, 'relative_ID'] = correct_mothID - cca_df_i.at[correct_mothID, 'relative_ID'] = correct_budID - cca_df_i.at[correct_budID, 'corrected_on_frame_i'] = posData.frame_i - cca_df_i.at[correct_mothID, 'corrected_on_frame_i'] = posData.frame_i - - # Set mother cell cycle stage to S in case it is not - if cca_df_i.at[correct_mothID, 'cell_cycle_stage'] == 'G1': - cca_df_i.at[correct_mothID, 'cell_cycle_stage'] = 'S' - # cca_df_i.at[correct_mothID, 'generation_num'] -= 1 - - self.store_cca_df(frame_i=future_i, cca_df=cca_df_i, autosave=False) - - self.updateAllImages() - - def restoreMotherToBeforeWrongBudWasAssignedToIt( - self, mothIDofDisappearedBud, - cca_df_at_correct_bud_ID_disappearance, - frame_i - ): - """This method is called as part of `guiWin.swapMothers`. - - Parameters - ---------- - mothIDofDisappearedBud : int - Mother ID of the disappeared bud - cca_df_at_correct_bud_ID_disappearance : pd.DataFrame - Cell cycle annotations DataFrame when the correct bud ID stopped - existing (before emergence) - frame_i : int - Frame index when the correct bud ID stopped existing - (before emergence) - - Note - ---- - It restores the mother cell cycle annotations to the status it had - before the wrong bud was assigned to it. - - We need to do it only if the swapMothers past frames loop is still - iterating to correct the other bud. - - We also need to do this only if the wrong bud ID is actually a bud. - - When we swap mothers in the past frames it can be that the correct bud - ID stops existing (before emergence). In this case the correct mother - still has the wrong bud assigned to ID so we need to restore the status - it had before the wrong bud was assigned to it. - - To determine the status we go back until the wrong bud disappear. That - is the frame before the wrong bud was assigned to the mother we want to - correct. This is the status we want to restore. - - When we go back in time it could be that the wrong bud never disappears - becuase it is already emerged at frame 0. In this case the status we - want to restore at is the default G1 status at frame 0. - """ - relativeIDofMothID = cca_df_at_correct_bud_ID_disappearance.at[ - mothIDofDisappearedBud, 'relative_ID' - ] - if relativeIDofMothID not in cca_df_at_correct_bud_ID_disappearance.index: - # Also wrong bud ID disappeared - return - - relativeIDofMothIDrelationship = cca_df_at_correct_bud_ID_disappearance.at[ - relativeIDofMothID, 'relationship' - ] - if relativeIDofMothIDrelationship != 'bud': - # The wrong bud ID is a cell in G1 from previous cycle --> - # the actual wrong bud ID disappeared too. - return - - wrongBudID = relativeIDofMothID - - mothCcaBeforeWrongBudID = base_cca_dict - # Search in the past for status of mother before wrong bud emerged - for past_i in range(frame_i, -1, -1): - cca_df_i = self.get_cca_df(frame_i=past_i, return_df=True) - if wrongBudID not in cca_df_i.index: - mothCcaBeforeWrongBudID = cca_df_i.loc[mothIDofDisappearedBud] - break - - # Restore in past frames the correct mother status - for past_i in range(frame_i, -1, -1): - cca_df_i = self.get_cca_df(frame_i=past_i, return_df=True) - if wrongBudID in cca_df_i.index: - cca_df_i.loc[mothIDofDisappearedBud] = mothCcaBeforeWrongBudID - cca_df_i.at[mothIDofDisappearedBud, 'corrected_on_frame_i'] = frame_i - self.store_cca_df( - frame_i=past_i, cca_df=cca_df_i, autosave=False - ) - else: - break - - def getClosedSplineCoords(self): - xxS, yyS = self.curvPlotItem.getData() - bbox_area = (xxS.max()-xxS.min())*(yyS.max()-yyS.min()) - if bbox_area < 26_000: - # Using 1000 is fast enough according to profiling - return xxS, yyS - - optimalSpaceSize = self.splineToObjModel.predict( - bbox_area, max_exec_time=150 - ) - if optimalSpaceSize >= 1000: - # Using 1000 is fast enough according to model - return xxS, yyS - - if optimalSpaceSize < 100: - # Do not allow a rough spline - optimalSpaceSize = 100 - - # Get spline with optimal space size so that exec time - # or skimage.draw.polygon is less than 150 ms - xx, yy = self.curvAnchors.getData() - resolutionSpace = np.linspace(0, 1, int(optimalSpaceSize)) - xxS, yyS = self.getSpline( - xx, yy, resolutionSpace=resolutionSpace, per=True - ) - return xxS, yyS - - - def getSpline(self, xx, yy, resolutionSpace=None, per=False, appendFirst=False): - # Remove duplicates - valid = np.where(np.abs(np.diff(xx)) + np.abs(np.diff(yy)) > 0) - xx = np.r_[xx[valid], xx[-1]] - yy = np.r_[yy[valid], yy[-1]] - if appendFirst: - xx = np.r_[xx, xx[0]] - yy = np.r_[yy, yy[0]] - per = True - - # Interpolate splice - if resolutionSpace is None: - resolutionSpace = self.hoverLinSpace - k = 2 if len(xx) == 3 else 3 - - try: - tck, u = scipy.interpolate.splprep( - [xx, yy], s=0, k=k, per=per - ) - xi, yi = scipy.interpolate.splev(resolutionSpace, tck) - return xi, yi - except (ValueError, TypeError): - # Catch errors where we know why splprep fails - return [], [] - - def uncheckQButton(self, button): - # Manual exclusive where we allow to uncheck all buttons - for b in self.checkableQButtonsGroup.buttons(): - if b != button: - b.setChecked(False) - - def delBorderObj(self, checked): - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - - posData = self.data[self.pos_i] - posData.lab = skimage.segmentation.clear_border( - posData.lab, buffer_size=1 - ) - oldIDs = posData.IDs.copy() - self.update_rp() - removedIDs = [ID for ID in oldIDs if ID not in posData.IDs] - if posData.cca_df is not None: - posData.cca_df = posData.cca_df.drop(index=removedIDs) - self.store_data() - self.updateAllImages() - - def delNewObj(self, checked): - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - - posData = self.data[self.pos_i] - frame_i = posData.frame_i - - if frame_i == 0: - return - - prev_IDs = posData.allData_li[frame_i-1]['IDs'] - curr_IDs = posData.IDs - new_IDs = list(set(curr_IDs) - set(prev_IDs)) - - lab = posData.lab - del_mask = np.isin(lab, new_IDs) - lab[del_mask] = 0 - posData.lab = lab - - self.update_rp() - - if posData.cca_df is not None: - posData.cca_df = posData.cca_df.drop(index=new_IDs) - self.store_data() - self.updateAllImages() - - def brushAutoFillToggled(self, checked): - val = 'Yes' if checked else 'No' - self.df_settings.at['brushAutoFill', 'value'] = val - self.df_settings.to_csv(self.settings_csv_path) - - def brushAutoHideToggled(self, checked): - val = 'Yes' if checked else 'No' - self.df_settings.at['brushAutoHide', 'value'] = val - self.df_settings.to_csv(self.settings_csv_path) - - def brushReleased(self): - posData = self.data[self.pos_i] - self.fillHolesID(posData.brushID, sender='brush') - - # Update data (rp, etc) - self.update_rp(update_IDs=self.isNewID,) - - # Repeat tracking - if self.autoIDcheckbox.isChecked(): - self.trackManuallyAddedObject(posData.brushID, self.isNewID) - else: - self.update_rp(update_IDs=posData.brushID not in posData.IDs_idxs) - - # Update images - if self.isNewID: - editTxt = 'Add new ID with brush tool' - if self.isSnapshot: - self.fixCcaDfAfterEdit(editTxt) - self.updateAllImages() - else: - self.warnEditingWithCca_df(editTxt) - else: - self.updateAllImages() - - self.isNewID = False - - def addDelROI(self, event): - roi, key = self.createDelROI() - self.addRoiToDelRoiInfo(roi) - if not self.labelsGrad.showLabelsImgAction.isChecked(): - self.ax1.addDelRoiItem(roi, key) - else: - self.ax2.addDelRoiItem(roi, key) - self.applyDelROIimg1(roi, init=True) - self.applyDelROIimg1(roi, init=True, ax=1) - - if self.isSnapshot: - self.fixCcaDfAfterEdit('Delete IDs using ROI') - self.updateAllImages() - else: - self.warnEditingWithCca_df( - 'Delete IDs using ROI', get_cancelled=True - ) - - def replacePolyLineRoiWithLineRoi(self, roi): - x0, y0 = roi.pos().x(), roi.pos().y() - (_, point1), (_, point2) = roi.getLocalHandlePositions() - xr1, yr1 = point1.x(), point1.y() - xr2, yr2 = point2.x(), point2.y() - x1, y1 = xr1+x0, yr1+y0 - x2, y2 = xr2+x0, yr2+x0 - lineRoi = pg.LineROI((x1, y1), (x2, y2), width=0.5) - lineRoi.handleSize = 7 - self.ax1.removeItem(self.polyLineRoi) - self.ax1.addItem(lineRoi) - lineRoi.removeHandle(2) - # Connect closed ROI - lineRoi.sigRegionChanged.connect(self.delROImoving) - lineRoi.sigRegionChangeFinished.connect(self.delROImovingFinished) - return lineRoi - - def addRoiToDelRoiInfo(self, roi: pg.ROI): - posData = self.data[self.pos_i] - for i in range(posData.frame_i, posData.SizeT): - delROIs_info = posData.allData_li[i]['delROIs_info'] - delROIs_info['rois'].append(roi) - delROIs_info['state'].append(roi.getState()) - delROIs_info['delMasks'].append(np.zeros_like(self.currentLab2D)) - delROIs_info['delIDsROI'].append(set()) - - def addDelPolyLineRoi_cb(self, checked): - if checked: - self.disconnectLeftClickButtons() - self.uncheckLeftClickButtons(self.addDelPolyLineRoiButton) - self.connectLeftClickButtons() - if self.isSnapshot: - self.fixCcaDfAfterEdit('Delete IDs using ROI') - self.updateAllImages() - else: - self.warnEditingWithCca_df('Delete IDs using ROI') - else: - self.tempSegmentON = False - self.ax1_rulerPlotItem.setData([], []) - self.ax1_rulerAnchorsItem.setData([], []) - self.startPointPolyLineItem.setData([], []) - while self.app.overrideCursor() is not None: - self.app.restoreOverrideCursor() - - def createDelPolyLineRoi(self): - Y, X = self.currentLab2D.shape - self.polyLineRoi = pg.PolyLineROI( - [], rotatable=False, - removable=True, - pen=pg.mkPen(color='r') - ) - self.polyLineRoi.handleSize = 7 - self.polyLineRoi.points = [] - key = uuid.uuid4() - self.ax1.addDelRoiItem(self.polyLineRoi, key) - - def addPointsPolyLineRoi(self, closed=False): - self.polyLineRoi.setPoints(self.polyLineRoi.points, closed=closed) - if not closed: - return - - # Connect closed ROI - self.polyLineRoi.sigRegionChanged.connect(self.delROImoving) - self.polyLineRoi.sigRegionChangeFinished.connect(self.delROImovingFinished) - - def getViewRange(self): - Y, X = self.img1.image.shape[:2] - xRange, yRange = self.ax1.viewRange() - xmin = 0 if xRange[0] < 0 else xRange[0] - ymin = 0 if yRange[0] < 0 else yRange[0] - - xmax = X if xRange[1] >= X else xRange[1] - ymax = Y if yRange[1] >= Y else yRange[1] - return int(ymin), int(ymax), int(xmin), int(xmax) - - def createDelROI(self, xl=None, yb=None, w=32, h=32, anchors=None): - posData = self.data[self.pos_i] - if xl is None: - xRange, yRange = self.ax1.viewRange() - xl = 0 if xRange[0] < 0 else xRange[0] - yb = 0 if yRange[0] < 0 else yRange[0] - Y, X = self.currentLab2D.shape - if anchors is None: - roi = widgets.DelROI( - [xl, yb], [w, h], - rotatable=False, - removable=True, - pen=pg.mkPen(color='r'), - maxBounds=QRectF(QRect(0,0,X,Y)) - ) - ## handles scaling horizontally around center - roi.addScaleHandle([1, 0.5], [0, 0.5]) - roi.addScaleHandle([0, 0.5], [1, 0.5]) - - ## handles scaling vertically from opposite edge - roi.addScaleHandle([0.5, 0], [0.5, 1]) - roi.addScaleHandle([0.5, 1], [0.5, 0]) - - ## handles scaling both vertically and horizontally - roi.addScaleHandle([1, 1], [0, 0]) - roi.addScaleHandle([0, 0], [1, 1]) - roi.addScaleHandle([0, 1], [1, 0]) - roi.addScaleHandle([1, 0], [0, 1]) - - roi.handleSize = 7 - roi.sigRegionChanged.connect(self.delROImoving) - roi.sigRegionChanged.connect(self.delROIstartedMoving) - roi.sigRegionChangeFinished.connect(self.delROImovingFinished) - - key = uuid.uuid4() - - return roi, key - - def delROIstartedMoving(self, roi): - self.clearLostObjContoursItems() - - def clearLostObjContoursItems(self): - self.ax1_lostObjScatterItem.setData([], []) - self.ax2_lostObjScatterItem.setData([], []) - - self.ax1_lostTrackedScatterItem.setData([], []) - self.ax2_lostTrackedScatterItem.setData([], []) - - self.ax2_lostObjImageItem.clear() - self.ax2_lostTrackedObjImageItem.clear() - - self.ax1_lostObjImageItem.clear() - self.ax1_lostTrackedObjImageItem.clear() - - def delROImoving(self, roi): - roi.setPen(color=(255,255,0)) - # First bring back IDs if the ROI moved away - self.restoreAnnotDelROI(roi) - self.setImageImg2() - self.applyDelROIimg1(roi) - self.applyDelROIimg1(roi, ax=1) - - def delROImovingFinished(self, roi: pg.ROI): - roi.setPen(color='r') - self.update_rp() - self.updateAllImages() - QTimer.singleShot( - 300, partial(self.updateDelROIinFutureFrames, roi) - ) - - def restoreAnnotDelROI(self, roi, enforce=True, draw=True): - posData = self.data[self.pos_i] - ROImask = self.getDelRoiMask(roi) - delROIs_info = posData.allData_li[posData.frame_i]['delROIs_info'] - try: - idx = delROIs_info['rois'].index(roi) - except Exception as err: - return - - delMask = delROIs_info['delMasks'][idx] - delIDs = delROIs_info['delIDsROI'][idx] - overlapROIdelIDs = np.unique(delMask[ROImask]) - lab2D = self.get_2Dlab(posData.lab) - restoredIDs = set() - for ID in delIDs: - if ID in overlapROIdelIDs and not enforce: - continue - - restoredIDs.add(ID) - - delMaskID = delMask==ID - self.currentLab2D[delMaskID] = ID - lab2D[delMaskID] = ID - - if draw: - self.restoreDelROIimg1(delMaskID, ID, ax=0) - self.restoreDelROIimg1(delMaskID, ID, ax=1) - - delMask[delMaskID] = 0 - - delROIs_info['delIDsROI'][idx] = delIDs - restoredIDs - self.set_2Dlab(lab2D) - self.update_rp() - - def restoreDelROIimg1(self, delMaskID, delID, ax=0): - if ax == 0: - how = self.drawIDsContComboBox.currentText() - else: - how = self.getAnnotateHowRightImage() - - if how.find('nothing') != -1: - return - - if how.find('contours') != -1: - rp_delmask = skimage.measure.regionprops(delMaskID.astype(np.uint8)) - if len(rp_delmask) > 0: - obj = rp_delmask[0] - self.addObjContourToContoursImage(obj=obj, ax=ax) - elif how.find('overlay segm. masks') != -1: - if ax == 0: - self.labelsLayerImg1.setImage( - self.currentLab2D, autoLevels=False - ) - else: - self.labelsLayerRightImg.setImage( - self.currentLab2D, autoLevels=False - ) - - def getDelRoisIDs(self): - posData = self.data[self.pos_i] - if posData.frame_i > 0: - prev_lab = posData.allData_li[posData.frame_i-1]['labels'] - allDelIDs = set() - for roi in posData.allData_li[posData.frame_i]['delROIs_info']['rois']: - if ( - not self.ax1.isDelRoiItemPresent(roi) - and not self.ax2.isDelRoiItemPresent(roi) - ): - continue - - ROImask = self.getDelRoiMask(roi) - delIDs = posData.lab[ROImask] - allDelIDs.update(delIDs) - if posData.frame_i > 0: - delIDsPrevFrame = prev_lab[ROImask] - allDelIDs.update(delIDsPrevFrame) - return allDelIDs - - def getStoredDelRoiIDs(self, frame_i=None): - posData = self.data[self.pos_i] - if frame_i is None: - frame_i = posData.frame_i - allDelIDs = set() - delROIs_info = posData.allData_li[frame_i]['delROIs_info'] - delIDs_rois = delROIs_info['delIDsROI'] - for delIDs in delIDs_rois: - allDelIDs.update(delIDs) - return allDelIDs - - # @exec_time - def getDelROIlab(self, input_lab_2D=None): - posData = self.data[self.pos_i] - if self.delRoiLab is None: - self.initDelRoiLab() - - out_lab = self.delRoiLab - if input_lab_2D is None: - out_lab[:] = self.get_2Dlab(posData.lab, force_z=False) - else: - out_lab[:] = input_lab_2D - - allDelIDs = set() - # Iterate rois and delete IDs - for roi in posData.allData_li[posData.frame_i]['delROIs_info']['rois']: - if ( - not self.ax1.isDelRoiItemPresent(roi) - and not self.ax2.isDelRoiItemPresent(roi) - ): - continue - ROImask = self.getDelRoiMask(roi) - delROIs_info = posData.allData_li[posData.frame_i]['delROIs_info'] - idx = delROIs_info['rois'].index(roi) - delObjROImask = delROIs_info['delMasks'][idx] - delIDsROI = delROIs_info['delIDsROI'][idx] - delROIlabRp = skimage.measure.regionprops(out_lab) - for delObj in delROIlabRp: - isDelObj = np.any(ROImask[delObj.slice][delObj.image]) - if not isDelObj: - continue - - delObjROImask[delObj.slice][delObj.image] = delObj.label - out_lab[delObj.slice][delObj.image] = 0 - - delIDsROI.add(delObj.label) - allDelIDs.add(delObj.label) - - # Keep a mask of deleted IDs to bring them back when roi moves - delROIs_info['delMasks'][idx] = delObjROImask - delROIs_info['delIDsROI'][idx] = delIDsROI - - # printl( - # f't1-t0: {(t1-t0)*1000:.3f} ms,', - # f't2-t1: {(t2-t1)*1000:.3f} ms,', - # f't3-t2: {(t3-t2)*1000:.3f} ms,', - # # f't4-t3: {(t4-t3)*1000:.3f} ms,', - # # f't5-t4: {(t5-t4)*1000:.3f} ms,', - # # f't6-t5: {(t6-t5)*1000:.3f} ms', - # sep='\n' - # ) - - return allDelIDs, out_lab - - def getDelRoiMask(self, roi, posData=None, z_slice=None): - if posData is None: - posData = self.data[self.pos_i] - if z_slice is None: - z_slice = self.z_lab() - ROImask = np.zeros(posData.lab.shape, bool) - if isinstance(roi, pg.PolyLineROI): - r, c = [], [] - x0, y0 = roi.pos().x(), roi.pos().y() - for _, point in roi.getLocalHandlePositions(): - xr, yr = point.x(), point.y() - r.append(int(yr+y0)) - c.append(int(xr+x0)) - if not r or not c: - return ROImask - - if len(r) == 2: - rr, cc, val = skimage.draw.line_aa(r[0], c[0], r[1], c[1]) - else: - rr, cc = skimage.draw.polygon(r, c, shape=self.currentLab2D.shape) - - Y, X = self.currentLab2D.shape - rr = rr[(rr>=0) & (rr=0) & (cc{descr} {channel}
: value={value:{ff}}' - ) - return txt - - def _addOverlayHoverValuesFormatted(self, txt, xdata, ydata): - posData = self.data[self.pos_i] - if posData.ol_data is None: - return txt - - for filename in posData.ol_data: - chName = myutils.get_chname_from_basename( - filename, posData.basename, remove_ext=False - ) - if chName not in self.checkedOverlayChannels: - continue - - raw_overlay_img = self.getRawImage(filename=filename) - raw_overlay_value = raw_overlay_img[ydata, xdata] - # raw_overlay_max_value = raw_overlay_img.max() - - raw_txt = self._channelHoverValues('Raw', chName, raw_overlay_value) - - txt = f'{txt} | {raw_txt}' - return txt - - def getActiveToolButton(self): - for button in self.LeftClickButtons: - if button.isChecked(): - return button - - def getConcatAcdcDf(self): - acdc_dfs = [] - keys = [] - posData = self.data[self.pos_i] - for frame_i, data_dict in enumerate(posData.allData_li): - lab = data_dict['labels'] - if lab is None: - break - - acdc_df = data_dict['acdc_df'] - if acdc_df is None: - break - - acdc_dfs.append(acdc_df) - keys.append(frame_i) - - if not acdc_dfs: - return - - return pd.concat(acdc_dfs, keys=keys, names=['frame_i']) - - - def checkHighlightTimestamp(self, x, y, activeToolButton): - if not hasattr(self, 'timestamp'): - return - - if not self.addTimestampAction.isChecked(): - return - - if activeToolButton is not None: - return - - if hasattr(self, 'scaleBar'): - if self.scaleBar.isHighlighted(): - return - - ymin, xmin, ymax, xmax = self.timestamp.bbox() - if x < xmin: - self.timestamp.setHighlighted(False) - return - - if x > xmax: - self.timestamp.setHighlighted(False) - return - - if y < ymin: - self.timestamp.setHighlighted(False) - return - - if y > ymax: - self.timestamp.setHighlighted(False) - return - - self.timestamp.setHighlighted(True) - - def checkHighlightScaleBar(self, x, y, activeToolButton): - if not hasattr(self, 'scaleBar'): - return - - if not self.addScaleBarAction.isChecked(): - return - - if activeToolButton is not None: - return - - ymin, xmin, ymax, xmax = self.scaleBar.bbox() - if x < xmin: - self.scaleBar.setHighlighted(False) - return - - if x > xmax: - self.scaleBar.setHighlighted(False) - return - - if y < ymin: - self.scaleBar.setHighlighted(False) - return - - if y > ymax: - self.scaleBar.setHighlighted(False) - return - - self.scaleBar.setHighlighted(True) - - def getMouseDataCoordsRightImage(self): - text = self.wcLabel.text() - if not text: - return - - ax_idx = int(re.findall(r'\(ax(\d)\)', text)[0]) - if ax_idx == 0: - return - - coords = re.findall(r'x=(\d+), y=(\d+) \|', text)[0] - - return tuple([int(val) for val in coords]) - - def updateValuesStatusBar(self): - (xl, xr), (yt, yb) = self.ax1ViewRange(integers=True) - W = round(xr - xl) - H = round(yb - yt) - txt = self.wcLabel.text() - pattern = ( - r'W=.*?, H=.*? \| ' - r'x_left=.*?, y_top=.*? \| ' - r'x_right=.*?, y_bottom=.*? \| ' - ) - replacing = ( - f'W={W:d}, H={H:d} | ' - f'x_left={xl:d}, y_top={yt:d} | ' - f'x_right={xr:d}, y_bottom={yb:d} | ' - ) - txt = re.sub(pattern, replacing, txt) - self.wcLabel.setText(txt) - - def hoverValuesFormatted(self, xdata, ydata, activeToolButton, is_ax0): - (xl, xr), (yt, yb) = self.ax1ViewRange(integers=True) - W = round(xr - xl) - H = round(yb - yt) - ax_idx = 0 if is_ax0 else 1 - txt = ( - f'x={xdata:d}, y={ydata:d} | ' - f'W={W:d}, H={H:d} | ' - f'x_left={xl:d}, y_top={yt:d} | ' - f'x_right={xr:d}, y_bottom={yb:d} | ' - f'(ax{ax_idx})' - ) - if activeToolButton == self.rulerButton: - txt = self._addRulerMeasurementText(txt) - return txt - elif activeToolButton is not None: - return txt - - posData = self.data[self.pos_i] - - raw_img = self.getRawImage() - raw_value = raw_img[ydata, xdata] - # raw_max_value = raw_img.max() - - ch = self.user_ch_name - raw_txt = self._channelHoverValues('Raw', ch, raw_value) - - txt = f'{txt} | {raw_txt}' - - txt = self._addOverlayHoverValuesFormatted(txt, xdata, ydata) - - ID = self.currentLab2D[ydata, xdata] - maxID = max(posData.IDs, default=0) - - num_obj = len(posData.IDs) - lab_txt = ( - f'Objects: ID={ID}, max ID={maxID}, ' - f'num. of objects={num_obj}' - ) - txt = f'{txt} | {lab_txt}' - - txt = self._addRulerMeasurementText(txt) - return txt - - def getRulerLengthText(self): - text = self.wcLabel.text() - lengthText = re.findall(r'length = (.*)\)', text)[0] - lengthText = lengthText.replace('pxl', 'pixels') - return f'{lengthText})' - - def _addRulerMeasurementText(self, txt): - posData = self.data[self.pos_i] - xx, yy = self.ax1_rulerPlotItem.getData() - if xx is None: - return txt - - lenPxl = math.sqrt((xx[0]-xx[1])**2 + (yy[0]-yy[1])**2) - depthAxes = self.switchPlaneCombobox.depthAxes() - if depthAxes != 'z': - pxlToUm = posData.PhysicalSizeZ - else: - pxlToUm = posData.PhysicalSizeX - - length_txt = ( - f'length = {int(lenPxl)} pxl ({lenPxl*pxlToUm:.2f} μm)' - ) - txt = f'{txt} | Measurement: {length_txt}' - return txt - - def updateImageValueFormatter(self): - if self.img1.image is not None: - dtype = self.img1.image.dtype - n_digits = len(str(int(self.img1.image.max()))) - self.imgValueFormatter = myutils.get_number_fstring_formatter( - dtype, precision=abs(n_digits-5) - ) - - rawImgData = self.data[self.pos_i].img_data - dtype = rawImgData.dtype - n_digits = len(str(int(rawImgData.max()))) - self.rawValueFormatter = myutils.get_number_fstring_formatter( - dtype, precision=abs(n_digits-5) - ) - - def normaliseIntensitiesActionTriggered(self, action): - how = action.text() - self.df_settings.at['how_normIntensities', 'value'] = how - self.df_settings.to_csv(self.settings_csv_path) - self.updateAllImages() - self.updateImageValueFormatter() - - def setLastUserNormAction(self): - how = self.df_settings.at['how_normIntensities', 'value'] - for action in self.normalizeQActionGroup.actions(): - if action.text() == how: - action.setChecked(True) - break - - def saveLabelsColormap(self): - self.labelsGrad.saveColormap() - - def addFontSizeActions(self, menu, slot): - fontActionGroup = QActionGroup(self) - fontActionGroup.setExclusive(True) - for fontSize in range(4,27): - action = QAction(self) - action.setText(str(fontSize)) - action.setCheckable(True) - if fontSize == self.fontSize: - action.setChecked(True) - fontActionGroup.addAction(action) - menu.addAction(action) - action.triggered.connect(slot) - return fontActionGroup - - @exception_handler - def changeFontSize(self): - fontSize = self.fontSizeSpinBox.value() - if fontSize == self.fontSize: - return - - self.fontSize = fontSize - - self.df_settings.at['fontSize', 'value'] = self.fontSize - self.df_settings.to_csv(self.settings_csv_path) - - self.setAllIDs() - posData = self.data[self.pos_i] - for ax in range(2): - self.textAnnot[ax].changeFontSize(self.fontSize) - if self.highLowResAction.isChecked(): - self.setAllTextAnnotations() - else: - self.updateAllImages() - - def enableZstackWidgets(self, enabled): - if enabled: - myutils.setRetainSizePolicy(self.zSliceScrollBar) - myutils.setRetainSizePolicy(self.zProjComboBox) - myutils.setRetainSizePolicy(self.zSliceOverlay_SB) - myutils.setRetainSizePolicy(self.zProjOverlay_CB) - myutils.setRetainSizePolicy(self.overlay_z_label) - self.zSliceScrollBar.setDisabled(False) - self.zProjComboBox.show() - if self.data[self.pos_i].SizeT > 1: - self.zProjLockViewButton.show() - self.zSliceScrollBar.show() - self.zSliceCheckbox.show() - self.zSliceSpinbox.show() - self.switchPlaneCombobox.show() - self.switchPlaneCombobox.setDisabled(False) - self.SizeZlabel.show() - else: - myutils.setRetainSizePolicy(self.zSliceScrollBar, retain=False) - myutils.setRetainSizePolicy(self.zProjComboBox, retain=False) - myutils.setRetainSizePolicy(self.zSliceOverlay_SB, retain=False) - myutils.setRetainSizePolicy(self.zProjOverlay_CB, retain=False) - myutils.setRetainSizePolicy(self.overlay_z_label, retain=False) - self.zSliceScrollBar.setDisabled(True) - self.zProjComboBox.hide() - self.zProjComboBox.hide() - self.zSliceScrollBar.hide() - self.zSliceCheckbox.hide() - self.zSliceSpinbox.hide() - self.SizeZlabel.hide() - self.switchPlaneCombobox.hide() - self.switchPlaneCombobox.setDisabled(True) - - self.imgGrad.rescaleAcrossZstackAction.setDisabled(not enabled) - for ch, overlayItems in self.overlayLayersItems.items(): - lutItem = overlayItems[1] - lutItem.rescaleAcrossZstackAction.setDisabled(not enabled) - - def reInitCca(self): - if not self.isSnapshot: - txt = html_utils.paragraph( - 'If you decide to continue ALL cell cycle annotations from ' - 'this frame to the end will be erased from current session ' - '(saved data is not touched of course).

' - 'To annotate future frames again you will have to revisit them.

' - 'Do you want to continue?' - ) - msg = widgets.myMessageBox() - msg.warning( - self, 'Re-initialize annnotations?', txt, - buttonsTexts=('Cancel', 'Yes') - ) - posData = self.data[self.pos_i] - if msg.cancel: - return - - # Reset all future frames - self.resetCcaFuture(posData.frame_i+1) - if posData.frame_i == 0: - # Reset everything since we are on first frame - posData.cca_df = self.getBaseCca_df() - self.store_data() - self.updateAllImages() - self.navigateScrollBar.setMaximum(posData.frame_i+1) - self.navSpinBox.setMaximum(posData.frame_i+1) - else: - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - - posData = self.data[self.pos_i] - posData.cca_df = self.getBaseCca_df() - self.store_data() - self.updateAllImages() - - - def repeatAutoCca(self): - # Do not allow automatic bud assignment if there are future - # frames that already contain anotations - posData = self.data[self.pos_i] - next_df = posData.allData_li[posData.frame_i+1]['acdc_df'] - if next_df is not None: - if 'cell_cycle_stage' in next_df.columns: - msg = QMessageBox() - warn_cca = msg.critical( - self, 'Future visited frames detected!', - 'Automatic bud assignment CANNOT be performed becasue ' - 'there are future frames that already contain cell cycle ' - 'annotations. The behaviour in this case cannot be predicted.\n\n' - 'We suggest assigning the bud manually OR use the ' - '"Re-initialize cell cycle annotations" button which properly ' - 're-initialize future frames.', - msg.Ok - ) - return - - correctedAssignIDs = ( - posData.cca_df[posData.cca_df['corrected_on_frame_i']>=0].index - ) - NeverCorrectedAssignIDs = [ - ID for ID in posData.new_IDs if ID not in correctedAssignIDs - ] - - # Store cca_df temporarily if attempt_auto_cca fails - posData.cca_df_beforeRepeat = posData.cca_df.copy() - - if not all(NeverCorrectedAssignIDs): - notEnoughG1Cells, proceed = self.attempt_auto_cca() - if notEnoughG1Cells or not proceed: - posData.cca_df = posData.cca_df_beforeRepeat - else: - self.updateAllImages() - return - - msg = QMessageBox() - msg.setIcon(msg.Question) - msg.setText( - 'Do you want to automatically assign buds to mother cells for ' - 'ALL the new cells in this frame (excluding cells with unknown history) ' - 'OR only the cells where you never clicked on?' - ) - msg.setDetailedText( - f'New cells that you never touched:\n\n{NeverCorrectedAssignIDs}') - enforceAllButton = QPushButton('ALL new cells') - b = QPushButton('Only cells that I never corrected assignment') - msg.addButton(b, msg.YesRole) - msg.addButton(enforceAllButton, msg.NoRole) - msg.exec_() - if msg.clickedButton() == enforceAllButton: - notEnoughG1Cells, proceed = self.attempt_auto_cca(enforceAll=True) - else: - notEnoughG1Cells, proceed = self.attempt_auto_cca() - if notEnoughG1Cells or not proceed: - posData.cca_df = posData.cca_df_beforeRepeat - else: - self.updateAllImages() - - def manualEditCcaToolbarActionTriggered(self): - self.manualEditCca() - - def askGet2Dor3Dimage(self): - txt = html_utils.paragraph(""" - Do you want to test the denoising on the visualized 2D image or - on the entire 3D z-stack? - """) - msg = widgets.myMessageBox(wrapText=False) - _, use3Dbutton, use2Dbutton = msg.question( - self, '3D denoising?', txt, - buttonsTexts=('Cancel', 'Denoise 3D z-stack', 'Denoise 2D image') - ) - if msg.cancel: - return - - if msg.clickedButton == use3Dbutton: - posData = self.data[self.pos_i] - zslice = self.zSliceScrollBar.sliderPosition() - return posData.img_data[posData.frame_i, zslice] - else: - return self.getDisplayedImg1() - - def manualEditCca(self, checked=True): - posData = self.data[self.pos_i] - editCcaWidget = apps.editCcaTableWidget( - posData.cca_df, posData.SizeT, current_frame_i=posData.frame_i, - parent=self - ) - editCcaWidget.sigApplyChangesFutureFrames.connect( - self.applyManualCcaChangesFutureFrames - ) - editCcaWidget.exec_() - if editCcaWidget.cancel: - return - posData.cca_df = editCcaWidget.cca_df - self.store_cca_df() - # self.checkMultiBudMoth() - self.updateAllImages() - - @exception_handler - def applyManualCcaChangesFutureFrames(self, changes, stop_frame_i): - self.store_data(autosave=False) - posData = self.data[self.pos_i] - undoId = uuid.uuid4() - for i in range(posData.frame_i, stop_frame_i): - cca_df_i = self.get_cca_df(frame_i=i, return_df=True) - if cca_df_i is None: - # ith frame was not visited yet - break - - self.storeUndoRedoCca(i, cca_df_i, undoId) - - for ID, changes_ID in changes.items(): - if ID not in cca_df_i.index: - continue - for col, (oldValue, newValue) in changes_ID.items(): - cca_df_i.at[ID, col] = newValue - self.store_cca_df(frame_i=i, cca_df=cca_df_i, autosave=False) - self.get_data() - self.updateAllImages() - - def annotateRightHowCombobox_cb(self, idx): - how = self.annotateRightHowCombobox.currentText() - saveSettings = True - if hasattr(self.annotateRightHowCombobox, 'saveSettings'): - saveSettings = self.annotateRightHowCombobox.saveSettings - - if saveSettings: - self.df_settings.at['how_draw_right_annotations', 'value'] = how - self.df_settings.to_csv(self.settings_csv_path) - - mode = self.modeComboBox.currentText() - isCcaAnnot = ( - self.annotCcaInfoCheckboxRight.isChecked() and - mode != 'Normal division: Lineage tree' - ) - isIDAnnot = (self.annotIDsCheckboxRight.isChecked() or ( - self.annotCcaInfoCheckboxRight.isChecked() and - mode == 'Normal division: Lineage tree' - )) - self.textAnnot[1].setCcaAnnot( - isCcaAnnot - ) - - self.textAnnot[1].setLabelAnnot( - isIDAnnot - ) - if not self.isDataLoading: - self.updateAllImages() - - def drawIDsContComboBox_cb(self, idx): - how = self.drawIDsContComboBox.currentText() - saveSettings = True - if hasattr(self.drawIDsContComboBox, 'saveSettings'): - saveSettings = self.drawIDsContComboBox.saveSettings - - if saveSettings: - self.df_settings.at['how_draw_annotations', 'value'] = how - self.df_settings.to_csv(self.settings_csv_path) - - mode = self.modeComboBox.currentText() - isCcaAnnot = ( - self.annotCcaInfoCheckbox.isChecked() and - mode != 'Normal division: Lineage tree' - ) - isIDAnnot = (self.annotIDsCheckbox.isChecked() or ( - self.annotCcaInfoCheckbox.isChecked() and - mode == 'Normal division: Lineage tree' - )) - self.textAnnot[0].setCcaAnnot( - isCcaAnnot - ) - - self.textAnnot[0].setLabelAnnot( - isIDAnnot - ) - - if not self.isDataLoading: - self.updateAllImages() - - if self.eraserButton.isChecked(): - self.setTempImg1Eraser(None, init=True) - - def mousePressColorButton(self, event): - posData = self.data[self.pos_i] - items = list(self.checkedOverlayChannels) - if len(items)>1: - selectFluo = widgets.QDialogListbox( - 'Select image', - 'Select which fluorescence image you want to update the color of\n', - items, multiSelection=False, parent=self - ) - selectFluo.exec_() - keys = selectFluo.selectedItemsText - if selectFluo.cancel or not keys: - return - else: - self.overlayColorButton.channel = keys[0] - else: - self.overlayColorButton.channel = items[0] - self.overlayColorButton.selectColor() - - def setEnabledCcaToolbar(self, enabled=False): - self.manuallyEditCcaAction.setDisabled(False) - self.viewCcaTableAction.setDisabled(False) - self.ccaToolBar.setVisible(enabled) - for action in self.ccaToolBar.actions(): - button = self.ccaToolBar.widgetForAction(action) - action.setVisible(enabled) - button.setEnabled(enabled) - - # def setEnabledCcaToolbar(self, enabled=False): - # self.manuallyEditCcaAction.setDisabled(False) - # self.viewCcaTableAction.setDisabled(False) - # self.ccaToolBar.setVisible(enabled) - # for action in self.ccaToolBar.actions(): - # button = self.ccaToolBar.widgetForAction(action) - # action.setVisible(enabled) - # button.setEnabled(enabled) - - def setEnabledEditToolbarButton(self, enabled=False): - for action in self.segmActions: - action.setEnabled(enabled) - - for action in self.segmActionsVideo: - action.setEnabled(enabled) - - self.relabelSequentialAction.setEnabled(enabled) - self.repeatTrackingMenuAction.setEnabled(enabled) - self.repeatTrackingVideoAction.setEnabled(enabled) - self.postProcessSegmAction.setEnabled(enabled) - self.autoSegmAction.setEnabled(enabled) - self.editToolBar.setVisible(enabled) - mode = self.modeComboBox.currentText() - ccaON = mode == 'Cell cycle analysis' - for action in self.editToolBar.actions(): - button = self.editToolBar.widgetForAction(action) - # Keep binCellButton active in cca mode - if button==self.binCellButton and not enabled and ccaON: - action.setVisible(True) - button.setEnabled(True) - else: - action.setVisible(enabled) - button.setEnabled(enabled) - if not enabled: - self.setUncheckedAllButtons() - - def setEnabledFileToolbar(self, enabled): - for action in self.fileToolBar.actions(): - button = self.fileToolBar.widgetForAction(action) - if action == self.openFolderAction or action == self.newAction: - continue - if action == self.manageVersionsAction: - continue - if action == self.openFileAction: - continue - action.setEnabled(enabled) - button.setEnabled(enabled) - - def reconnectUndoRedo(self): - try: - self.undoAction.triggered.disconnect() - self.redoAction.triggered.disconnect() - except Exception as e: - pass - mode = self.modeComboBox.currentText() - if mode == 'Segmentation and Tracking' or mode == 'Snapshot': - self.undoAction.triggered.connect(self.undo) - self.redoAction.triggered.connect(self.redo) - elif mode == 'Cell cycle analysis': - self.undoAction.triggered.connect(self.UndoCca) - elif mode == 'Custom annotations': - self.undoAction.triggered.connect(self.undoCustomAnnotation) - else: - self.undoAction.setDisabled(True) - self.redoAction.setDisabled(True) - - def enableSizeSpinbox(self, enabled): - self.brushSizeLabelAction.setVisible(enabled) - self.brushSizeAction.setVisible(enabled) - self.brushAutoFillAction.setVisible(enabled) - self.brushAutoHideAction.setVisible(enabled) - self.brushEraserToolBar.setVisible(enabled) - self.disableNonFunctionalButtons() - - def reload_cb(self): - posData = self.data[self.pos_i] - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - labData = np.load(posData.segm_npz_path) - # Keep compatibility with .npy and .npz files - try: - lab = labData['arr_0'][posData.frame_i] - except Exception as e: - lab = labData[posData.frame_i] - posData.segm_data[posData.frame_i] = lab.copy() - self.get_data() - self.tracking() - self.updateAllImages() - - def clearComboBoxFocus(self, mode): - # Remove focus from modeComboBox to avoid the key_up changes its value - self.sender().clearFocus() - try: - self.timer.stop() - self.modeComboBox.setStyleSheet('background-color: none') - except Exception as e: - pass - - def updateModeMenuAction(self): - self.modeActionGroup.triggered.disconnect() - for action in self.modeActionGroup.actions(): - if action.text() != self.modeComboBox.currentText(): - continue - action.setChecked(True) - break - self.modeActionGroup.triggered.connect(self.changeModeFromMenu) - - def changeModeFromMenu(self, action): - self.modeComboBox.setCurrentText(action.text()) - - def restorePrevAnnotOptions(self): - if self.prevAnnotOptions is None: - return - self.restoreAnnotOptions_ax1(options=self.prevAnnotOptions) - self.setDrawAnnotComboboxText() - self.prevAnnotOptions = None - - def uncheckAllButtonsFromButtonGroup(self, buttonGroup): - for button in buttonGroup.buttons(): - if not button.isCheckable(): - continue - - if not button.isChecked(): - continue - - button.setChecked(False) - - @disableWindow - def changeMode(self, text): - self.reconnectUndoRedo() - self.updateModeMenuAction() - self.clearCustomAnnot() - posData = self.data[self.pos_i] - mode = text - prevMode = self.modeComboBox.previousText() - self.annotateToolbar.setVisible(False) - if prevMode != 'Viewer': - self.store_data(autosave=True) - - self.copyLostObjButton.setChecked(False) - self.stopCcaIntegrityCheckerWorker() - self.setAutoSaveSegmentationEnabled(False) - self.setAutoSaveAnnotationsEnabled(False) - if prevMode == 'Normal division: Lineage tree': - self.askLineageTreeChanges() - self.lineage_tree = None - self.editLin_TreeBar.setVisible(False) - self.uncheckAllButtonsFromButtonGroup(self.editLin_TreeGroup) - - elif prevMode == 'Cell cycle analysis': - self.setEnabledCcaToolbar(enabled=False) - - if mode == 'Segmentation and Tracking': - self.setAutoSaveSegmentationEnabled(True) - self.setSwitchViewedPlaneDisabled(True) - self.trackingMenu.setDisabled(False) - self.modeToolBar.setVisible(True) - self.lastTrackedFrameLabel.setText('') - self.initSegmTrackMode() - self.setEnabledEditToolbarButton(enabled=True) - self.addExistingDelROIs() - self.isFirstTimeOnNextFrame() - self.setEnabledCcaToolbar(enabled=False) - self.clearComputedContours() - self.realTimeTrackingToggle.setDisabled(False) - self.realTimeTrackingToggle.label.setDisabled(False) - if posData.cca_df is not None: - self.store_cca_df() - self.restorePrevAnnotOptions() - self.whitelistViewOGIDs(False) - elif mode == 'Cell cycle analysis': - self.setAutoSaveAnnotationsEnabled(True) - self.setSwitchViewedPlaneDisabled(True) - self.startCcaIntegrityCheckerWorker() - proceed = self.initCca() - if proceed: - self.applyDelROIs() - self.modeToolBar.setVisible(True) - self.realTimeTrackingToggle.setDisabled(True) - self.realTimeTrackingToggle.label.setDisabled(True) - self.computeAllContours() - # RAWR!!!!! - # self.computeAllObjToObjCostPairs() - if proceed: - self.setEnabledEditToolbarButton(enabled=False) - if self.isSnapshot: - self.editToolBar.setVisible(True) - self.setEnabledCcaToolbar(enabled=True) - self.removeAlldelROIsCurrentFrame() - self.setAnnotOptionsCcaMode() - self.clearGhost() - elif mode == 'Viewer': - self.autoSaveTimer.stop() - self.setSwitchViewedPlaneDisabled(False) - self.modeToolBar.setVisible(True) - self.realTimeTrackingToggle.setDisabled(True) - self.realTimeTrackingToggle.label.setDisabled(True) - self.setEnabledEditToolbarButton(enabled=False) - self.setEnabledCcaToolbar(enabled=False) - self.removeAlldelROIsCurrentFrame() - self.setStatusBarLabel() - self.navigateScrollBar.setMaximum(posData.SizeT) - self.navSpinBox.setMaximum(posData.SizeT) - self.clearGhost() - self.computeAllContours() - elif mode == 'Custom annotations': - self.setAutoSaveAnnotationsEnabled(True) - self.setSwitchViewedPlaneDisabled(True) - self.modeToolBar.setVisible(True) - self.realTimeTrackingToggle.setDisabled(True) - self.realTimeTrackingToggle.label.setDisabled(True) - self.setEnabledEditToolbarButton(enabled=False) - self.setEnabledCcaToolbar(enabled=False) - self.removeAlldelROIsCurrentFrame() - self.annotateToolbar.setVisible(True) - self.clearGhost() - self.doCustomAnnotation(0) - self.computeAllContours() - elif mode == 'Snapshot': - self.setAutoSaveAnnotationsEnabled(True) - self.setSwitchViewedPlaneDisabled(False) - self.reconnectUndoRedo() - self.setEnabledSnapshotMode() - self.doCustomAnnotation(0) - self.clearComputedContours() - elif mode == 'Normal division: Lineage tree': # Mode activation for lineage tree - # self.startLinTreeIntegrityCheckerWorker() # need to replace (postponed) - proceed = self.initLinTree() - self.setEnabledCcaToolbar(enabled=False) - self.setNavigateScrollBarMaximum() - if proceed: - self.applyDelROIs() - self.modeToolBar.setVisible(True) - self.realTimeTrackingToggle.setDisabled(True) - self.realTimeTrackingToggle.label.setDisabled(True) - if proceed: - self.setAutoSaveAnnotationsEnabled(True) - self.setEnabledEditToolbarButton(enabled=False) - if self.isSnapshot: - self.editToolBar.setVisible(True) - self.removeAlldelROIsCurrentFrame() - self.setAnnotOptionsLin_treeMode() - self.clearGhost() - self.editLin_TreeBar.setVisible(True) - - self.disableNonFunctionalButtons() - - def disableEditingViewPlaneNotXY(self): - posData = self.data[self.pos_i] - self.manuallyEditCcaAction.setDisabled(True) - for action in self.segmActions: - action.setDisabled(True) - if posData.SizeT == 1: - self.segmVideoMenu.setDisabled(True) - self.postProcessSegmAction.setDisabled(True) - self.autoSegmAction.setDisabled(True) - self.ccaToolBar.setVisible(False) - self.editToolBar.setVisible(False) - for action in self.ccaToolBar.actions(): - button = self.editToolBar.widgetForAction(action) - if button is not None: - button.setDisabled(True) - action.setVisible(False) - for action in self.editToolBar.actions(): - button = self.editToolBar.widgetForAction(action) - action.setVisible(False) - if button is not None: - button.setDisabled(True) - - def setEnabledSnapshotMode(self): - posData = self.data[self.pos_i] - self.manuallyEditCcaAction.setDisabled(False) - self.viewCcaTableAction.setDisabled(False) - for action in self.segmActions: - action.setDisabled(False) - - self.segmVideoMenu.setDisabled(True) - self.trackingMenu.setDisabled(True) - self.modeToolBar.setVisible(False) - - self.relabelSequentialAction.setDisabled(False) - self.postProcessSegmAction.setDisabled(False) - self.autoSegmAction.setDisabled(False) - self.ccaToolBar.setVisible(True) - self.editToolBar.setVisible(True) - self.reinitLastSegmFrameAction.setVisible(False) - for action in self.ccaToolBar.actions(): - button = self.ccaToolBar.widgetForAction(action) - if button == self.assignBudMothButton: - button.setDisabled(False) - action.setVisible(True) - elif action == self.reInitCcaAction: - action.setVisible(True) - elif action == self.assignBudMothAutoAction and posData.SizeT==1: - action.setVisible(True) - for action in self.editToolBar.actions(): - button = self.editToolBar.widgetForAction(action) - action.setVisible(True) - button.setEnabled(True) - self.realTimeTrackingToggle.setDisabled(True) - self.realTimeTrackingToggle.label.setDisabled(True) - self.repeatTrackingAction.setVisible(False) - self.manualTrackingAction.setVisible(False) - button = self.editToolBar.widgetForAction(self.repeatTrackingAction) - button.setDisabled(True) - button = self.editToolBar.widgetForAction(self.manualTrackingAction) - button.setDisabled(True) - self.disableNonFunctionalButtons() - self.reinitLastSegmFrameAction.setVisible(False) - - def launchSlideshow(self): - posData = self.data[self.pos_i] - self.determineSlideshowWinPos() - if self.slideshowButton.isChecked(): - self.slideshowWin = apps.imageViewer( - parent=self, - button_toUncheck=self.slideshowButton, - linkWindow=posData.SizeT > 1, - enableOverlay=True, - enableMirroredCursor=True - ) - self.slideshowWin.img.minMaxValuesMapper = ( - self.img1.minMaxValuesMapper - ) - self.slideshowWin.img.setCurrentPosIndex(self.pos_i) - h = self.drawIDsContComboBox.size().height() - self.slideshowWin.framesScrollBar.setFixedHeight(h) - self.slideshowWin.overlayButton.setChecked( - self.overlayButton.isChecked() - ) - self.slideshowWin.sigHoveringImage.connect( - self.setMirroredCursorFromSecondWindow - ) - if posData.SizeZ > 1: - z_slice = self.zSliceScrollBar.sliderPosition() - self.slideshowWin.img.setCurrentZsliceIndex(z_slice) - self.slideshowWin.zSliceScrollBar.setSliderPosition(z_slice) - self.slideshowWin.z_label.setText( - f'z-slice {z_slice+1:02}/{posData.SizeZ}' - ) - self.slideshowWin.update_img() - self.slideshowWin.show( - left=self.slideshowWinLeft, top=self.slideshowWinTop - ) - else: - self.slideshowWin.close() - self.slideshowWin = None - - def setMirroredCursorFromSecondWindow(self, x, y): - if x is None: - xx, yy = [], [] - else: - xx, yy = [x], [y] - self.ax1_cursor.setData(xx, yy) - if not self.isTwoImageLayout: - return - self.ax2_cursor.setData(xx, yy) - - def goToZsliceSearchedID(self, obj): - if not self.isSegm3D: - return - - current_z = self.z_lab() - nearest_nonzero_z = core.nearest_nonzero_z_idx_from_z_centroid( - obj, current_z=current_z - ) - if nearest_nonzero_z == current_z: - self.drawPointsLayers(computePointsLayers=True) - return - - self.zSliceScrollBar.setSliderPosition(nearest_nonzero_z) - self.update_z_slice(nearest_nonzero_z) - - def disconnectLeftClickButtons(self): - for button in self.LeftClickButtons: - try: - button.toggled.disconnect() - except Exception as e: - # Not all the LeftClickButtons have toggled connected - pass - - def uncheckLeftClickButtons(self, sender): - for button in self.LeftClickButtons: - if button != sender: - button.setChecked(False) - - if button != self.labelRoiButton: - # self.labelRoiButton is disconnected so we manually call uncheck - self.labelRoi_cb(False) - self.secondLevelToolbar.setVisible(True) - for toolbar in self.controlToolBars: - try: - toolbar.keepVisibleWhenActive - if toolbar.isVisible(): - self.secondLevelToolbar.setVisible(False) - continue - except: - pass - toolbar.setVisible(False) - - self.enableSizeSpinbox(False) - if sender is not None: - self.keepIDsButton.setChecked(False) - - def connectLeftClickButtonsPointsLayersToolbar(self): - for toolbar in self.pointsLayersToolbars: - for action in toolbar.actions()[1:]: - if not hasattr(action, 'layerTypeIdx'): - continue - if action.layerTypeIdx != 4: - continue - action.button.toggled.connect( - self.addPointsByClickingButtonToggled - ) - - def connectLeftClickButtons(self): - self.brushButton.toggled.connect(self.Brush_cb) - self.curvToolButton.toggled.connect(self.curvTool_cb) - self.rulerButton.toggled.connect(self.ruler_cb) - self.eraserButton.toggled.connect(self.Eraser_cb) - self.wandToolButton.toggled.connect(self.wand_cb) - self.labelRoiButton.toggled.connect(self.labelRoi_cb) - self.magicPromptsToolButton.toggled.connect(self.magicPrompts_cb) - self.drawClearRegionButton.toggled.connect(self.drawClearRegion_cb) - self.expandLabelToolButton.toggled.connect(self.expandLabelCallback) - self.addDelPolyLineRoiButton.toggled.connect(self.addDelPolyLineRoi_cb) - self.manualBackgroundButton.toggled.connect(self.manualBackground_cb) - self.whitelistIDsButton.toggled.connect(self.whitelistIDs_cb) - self.zoomRectButton.toggled.connect(self.zoomRectActionToggled) - self.connectLeftClickButtonsPointsLayersToolbar() - - def brushSize_cb(self, value): - self.ax2_EraserCircle.setSize(value*2) - self.ax1_BrushCircle.setSize(value*2) - self.ax2_BrushCircle.setSize(value*2) - self.ax1_EraserCircle.setSize(value*2) - self.ax2_EraserX.setSize(value) - self.ax1_EraserX.setSize(value) - self.setDiskMask() - - def autoIDtoggled(self, checked): - self.editIDspinboxAction.setDisabled(checked) - self.editIDLabelAction.setDisabled(checked) - if not checked and self.editIDspinbox.value() == 0: - newID = self.setBrushID(return_val=True) - self.editIDspinbox.setValue(newID) - - def wand_cb(self, checked): - posData = self.data[self.pos_i] - if checked: - self.disconnectLeftClickButtons() - self.uncheckLeftClickButtons(self.wandToolButton) - self.connectLeftClickButtons() - self.wandControlsToolbar.setVisible(True) - # self.secondLevelToolbar.setVisible(False) - else: - self.resetCursors() - # self.secondLevelToolbar.setVisible(True) - self.wandControlsToolbar.setVisible(False) - - def magicPrompts_cb(self, checked): - if checked: - self.disconnectLeftClickButtons() - self.uncheckLeftClickButtons(self.magicPromptsToolButton) - self.connectLeftClickButtons() - self.magicPromptsToolbar.setVisible(True) - self.promptSegmentPointsLayerToolbar.setVisible(True) - if not self.promptSegmentPointsLayerToolbar.isPointsLayerInit: - self.addPointsLayerTriggered( - toolbar=self.promptSegmentPointsLayerToolbar - ) - else: - self.resetCursors() - self.promptSegmentPointsLayerToolbar.setVisible(False) - self.magicPromptsToolbar.setVisible(False) - - def copyLostObjContour_cb(self, checked): - self.copyLostObjToolbar.setVisible(checked) - - self.ax1_lostObjScatterItem.hoverLostID = 0 - if not checked: - return - - self.lostObjImage = np.zeros_like(self.currentLab2D) - self.updateLostContoursImage(0) - - def manualAnnotPast_cb(self, checked): - posData = self.data[self.pos_i] - if checked: - for _ in range(3): - self.onEscape( - buttonsToNotUncheck=[self.manualAnnotPastButton], - doAutoRange=False - ) - - self.brushButton.setChecked(True) - self.store_data() - self.manualAnnotState = { - 'editID': self.editIDspinbox.value(), - 'isAutoID': self.autoIDcheckbox.isChecked(), - 'doWarnLostObj': self.warnLostCellsAction.isChecked(), - } - self.autoIDcheckbox.setChecked(False) - self.warnLostCellsAction.setChecked(False) - hoverID = self.getLastHoveredID() - if hoverID == 0: - win = apps.QLineEditDialog( - title='Not hovering any ID', - msg='You are not hovering on any ID.\n' - 'Enter the ID that you want to lock.', - parent=self, - isInteger=True, - defaultTxt=self.setBrushID(return_val=True) - ) - win.exec_() - if win.cancel: - self.manualAnnotPastButton.setChecked(False) - return - hoverID = win.EntryID - self.logger.info( - 'Setting manual annotation for ID = ' - f'{hoverID}, at frame n. {posData.frame_i+1}' - ) - self.editIDspinbox.setValue(hoverID) - try: - obj_idx = posData.IDs_idxs[hoverID] - obj = posData.rp[obj_idx] - radius = 0.9 * obj.minor_axis_length / 2 # math.sqrt(obj.area/math.pi)*0.9 - self.brushSizeSpinbox.setValue(round(radius)) - except Exception as err: - pass - - self.manualAnnotState['frame_i_to_restore'] = posData.frame_i - self.manualAnnotState['last_tracked_i'] = ( - self.navigateScrollBar.maximum()-1 - ) - self.ax1.sigRangeChanged.connect(self.highlightManualAnnotMode) - self.ax1.setHighlighted(True, color='green') - else: - self.setStatusBarLabel() - self.autoIDcheckbox.setChecked(self.manualAnnotState['isAutoID']) - self.editIDspinbox.setValue(self.manualAnnotState['editID']) - self.warnLostCellsAction.setChecked( - self.manualAnnotState['doWarnLostObj'] - ) - frame_to_restore = self.manualAnnotState.get('frame_i_to_restore') - if frame_to_restore is None: - return - - self.store_data() - self.store_manual_annot_data() - - last_tracked_i_to_restore = self.manualAnnotState['last_tracked_i'] - self.manualAnnotRestoreLastTrackedFrame(last_tracked_i_to_restore) - - self.logger.info( - f'Restoring view to frame n. {posData.frame_i+1}...' - ) - posData.frame_i = frame_to_restore - self.get_data() - self.updateAllImages() - self.updateScrollbars() - self.ax1.sigRangeChanged.disconnect() - self.ax1.setHighlighted(False) - QTimer.singleShot(150, self.autoRange) - - self.setManualAnnotModeEnabledTools(checked) - - def copyLostObjectMask(self, ID: int): - posData = self.data[self.pos_i] - mask = self.lostObjImage == ID - lab2D = self.get_2Dlab(posData.lab) - lab2D[mask] = ID - self.lostObjImage[mask] = 0 - self.set_2Dlab(lab2D) - - def highlightManualAnnotMode(self, viewBox, viewRange): - self.ax1.setHighlighted(True) - - def updateHighlightedAxis(self): - if not self.manualAnnotPastButton.isChecked(): - return - - frame_to_restore = self.manualAnnotState.get('frame_i_to_restore') - posData = self.data[self.pos_i] - if posData.frame_i == frame_to_restore: - color = 'green' - elif posData.frame_i < frame_to_restore: - color = 'gold' - else: - color = 'red' - - self.ax1.setHighlightingRectItemsColor(color) - - def updateLostNewCurrentIDs(self): - posData = self.data[self.pos_i] - - prev_IDs = self.getPrevFrameIDs() - tracked_lost_IDs = self.getTrackedLostIDs() - curr_IDs = posData.IDs - curr_delRoiIDs = self.getStoredDelRoiIDs() - prev_delRoiIDs = self.getStoredDelRoiIDs(frame_i=posData.frame_i-1) - lost_IDs = [ - ID for ID in prev_IDs if ID not in curr_IDs - and ID not in prev_delRoiIDs and ID not in tracked_lost_IDs - ] - new_IDs = [ - ID for ID in curr_IDs if ID not in prev_IDs - and ID not in curr_delRoiIDs - ] - IDs_with_holes = [] - posData.lost_IDs = lost_IDs - posData.new_IDs = new_IDs - posData.old_IDs = prev_IDs - posData.IDs = curr_IDs - - out = ( - lost_IDs, new_IDs, IDs_with_holes, tracked_lost_IDs, curr_delRoiIDs - ) - return out - - def _copyAllLostObjects_navigateToFrame(self, frame_i): - posData = self.data[self.pos_i] - self.store_data(mainThread=False, autosave=False) - - posData.frame_i = frame_i - self.get_data() - self.tracking(wl_update=False) - self.currentLab2D = self.get_2Dlab(posData.lab) - self.update_rp() - self.updateLostNewCurrentIDs() - self.store_data(mainThread=False, autosave=False) - - self.lostObjContoursImage[:] = 0 - self.lostObjImage[:] = 0 - prev_rp = posData.allData_li[frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[frame_i-1]['IDs_idxs'] # need to change this when merging with opt. - for lostID in posData.lost_IDs: - obj = prev_rp[prev_IDs_idxs[lostID]] - self.addLostObjsToLostObjImage(obj, lostID, force=True) - - def _copyAllLostObjects_returnToFrame(self, frame_i): - posData = self.data[self.pos_i] - self.store_data(autosave=False, mainThread=False) - posData.frame_i = frame_i - self.get_data() - - def _copyAllLostObjects_refreshRp(self): - self.update_rp(draw=False, wl_update=False) # need to change this when merging with opt. - - @disableWindow - def copyAllLostObjects(self, for_future_frame_n, max_overlap_perc): - if not self.copyLostObjButton.isChecked(): - return - - posData = self.data[self.pos_i] - - desc = 'Copying all lost objects...' - - self.progressWin = apps.QDialogWorkerProgress( - title=desc, parent=self.mainWin, pbarDesc=desc - ) - self.progressWin.mainPbar.setMaximum(for_future_frame_n+1) - self.progressWin.show(self.app) - - self.copyAllLostObjectsThread = QThread() - - self.copyAllLostObjectsWorker = workers.CopyAllLostObjectsWorker( - self, posData, for_future_frame_n, max_overlap_perc - ) - self.copyAllLostObjectsWorker.moveToThread(self.copyAllLostObjectsThread) - - self.copyAllLostObjectsWorker.navigateToFrame.connect( - self._copyAllLostObjects_navigateToFrame, - Qt.BlockingQueuedConnection - ) - self.copyAllLostObjectsWorker.returnToFrame.connect( - self._copyAllLostObjects_returnToFrame, - Qt.BlockingQueuedConnection - ) - self.copyAllLostObjectsWorker.copyLostObjectMask.connect( - self.copyLostObjectMask, - Qt.BlockingQueuedConnection - ) - self.copyAllLostObjectsWorker.refreshRp.connect( - self._copyAllLostObjects_refreshRp, - Qt.BlockingQueuedConnection - ) - self.copyAllLostObjectsWorker.progressBar.connect( - self.workerUpdateProgressbar - ) - self.copyAllLostObjectsWorker.critical.connect( - self.copyAllLostObjectsWorkerCritical - ) - self.copyAllLostObjectsWorker.finished.connect( - self.copyAllLostObjectsThread.quit - ) - self.copyAllLostObjectsWorker.finished.connect( - self.copyAllLostObjectsWorker.deleteLater - ) - self.copyAllLostObjectsThread.finished.connect( - self.copyAllLostObjectsThread.deleteLater - ) - self.copyAllLostObjectsWorker.finished.connect( - self.copyAllLostObjectsWorkerFinished - ) - - self.copyAllLostObjectsThread.started.connect( - self.copyAllLostObjectsWorker.run - ) - self.copyAllLostObjectsThread.start() - - self.copyAllLostObjectsWorkerLoop = QEventLoop() - self.copyAllLostObjectsWorkerLoop.exec_() - - def copyAllLostObjectsWorkerCritical(self, error): - self.copyAllLostObjectsWorkerLoop.exit() - self.workerCritical(error) - - def copyAllLostObjectsWorkerFinished(self, output): - if self.progressWin is not None: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - - if output.get('doReinitLastSegmFrame', False): - self.reInitLastSegmFrame( - from_frame_i=output.get('last_visited_frame_i'), - updateImages=False, - force=True - ) - - if output.get('overlap_warning', False): - self.blinker = qutils.QControlBlink( - self.copyLostObjToolbar.maxOverlapNumberControl, - qparent=self.mainWin - ) - self.blinker.start() - - self.copyAllLostObjectsWorkerLoop.exit() - self.update_rp() - self.updateAllImages() - self.store_data() - - def labelRoiTrangeCheckboxToggled(self, checked): - disabled = not checked - self.labelRoiStartFrameNoSpinbox.setDisabled(disabled) - self.labelRoiStopFrameNoSpinbox.setDisabled(disabled) - self.labelRoiStartFrameNoSpinbox.label.setDisabled(disabled) - self.labelRoiStopFrameNoSpinbox.label.setDisabled(disabled) - self.labelRoiToEndFramesAction.setDisabled(disabled) - self.labelRoiFromCurrentFrameAction.setDisabled(disabled) - - if disabled: - return - - posData = self.data[self.pos_i] - - self.labelRoiStartFrameNoSpinbox.setValue(posData.frame_i+1) - self.labelRoiStopFrameNoSpinbox.setValue(posData.SizeT) - - def drawClearRegion_cb(self, checked): - posData = self.data[self.pos_i] - if checked: - self.disconnectLeftClickButtons() - self.uncheckLeftClickButtons(self.drawClearRegionButton) - self.connectLeftClickButtons() - - self.drawClearRegionToolbar.setVisible(checked) - - if not self.isSegm3D: - self.drawClearRegionToolbar.setZslicesControlEnabled(False) - return - - if not checked: - return - - self.drawClearRegionToolbar.setZslicesControlEnabled( - True, SizeZ=posData.SizeZ - ) - - def labelRoi_cb(self, checked): - posData = self.data[self.pos_i] - if checked: - self.disconnectLeftClickButtons() - self.uncheckLeftClickButtons(self.labelRoiButton) - self.connectLeftClickButtons() - - self.labelRoiStartFrameNoSpinbox.setMaximum(posData.SizeT) - self.labelRoiStopFrameNoSpinbox.setMaximum(posData.SizeT) - - if self.labelRoiActiveWorkers: - lastActiveWorker = self.labelRoiActiveWorkers[-1] - self.labelRoiGarbageWorkers.append(lastActiveWorker) - lastActiveWorker.finished.emit() - self.logger.info('Collected garbage w5orker (magic labeller).') - - self.labelRoiToolbar.setVisible(True) - if self.isSegm3D: - self.labelRoiZdepthSpinbox.setDisabled(False) - else: - self.labelRoiZdepthSpinbox.setDisabled(True) - - # Start thread and pause it - self.labelRoiThread = QThread() - self.labelRoiMutex = QMutex() - self.labelRoiWaitCond = QWaitCondition() - - labelRoiWorker = workers.LabelRoiWorker(self) - - labelRoiWorker.moveToThread(self.labelRoiThread) - labelRoiWorker.finished.connect(self.labelRoiThread.quit) - labelRoiWorker.finished.connect(labelRoiWorker.deleteLater) - self.labelRoiThread.finished.connect( - self.labelRoiThread.deleteLater - ) - - labelRoiWorker.finished.connect(self.labelRoiWorkerFinished) - labelRoiWorker.sigLabellingDone.connect(self.labelRoiDone) - labelRoiWorker.sigProgressBar.connect(self.workerUpdateProgressbar) - - labelRoiWorker.progress.connect(self.workerProgress) - labelRoiWorker.critical.connect(self.workerCritical) - - self.labelRoiActiveWorkers.append(labelRoiWorker) - - self.labelRoiThread.started.connect(labelRoiWorker.run) - self.labelRoiThread.start() - - # Add the rectROI to ax1 - self.ax1.addItem(self.labelRoiItem) - elif self.initLabelRoiModelDialog is not None: - # User is using other tools while the dialog is still open - # --> we allow this because it's useful to be able to use - # the ruler or check things --> do nothing - pass - else: - self.labelRoiToolbar.setVisible(False) - - for worker in self.labelRoiActiveWorkers: - worker._stop() - while self.app.overrideCursor() is not None: - self.app.restoreOverrideCursor() - - self.labelRoiItem.setPos((0,0)) - self.labelRoiItem.setSize((0,0)) - self.freeRoiItem.clear() - self.ax1.removeItem(self.labelRoiItem) - self.updateLabelRoiCircularCursor(None, None, False) - - def clearObjsFreehandRegion(self): - self.logger.info('Clearing objects inside freehand region...') - - # Store undo state before modifying stuff - self.storeUndoRedoStates(False, storeImage=False, storeOnlyZoom=True) - - posData = self.data[self.pos_i] - zRange = None - if self.isSegm3D: - zProjHow = self.zProjComboBox.currentText() - isZslice = zProjHow == 'single z-slice' - if isZslice: - z_slice = self.z_lab() - zRange = self.drawClearRegionToolbar.zRange( - z_slice, posData.SizeZ - ) - else: - zRange = (0, posData.SizeZ) - - regionSlice = self.freeRoiItem.slice(zRange=zRange) - mask = self.freeRoiItem.mask() - - regionLab = posData.lab[(...,) + regionSlice].copy() - - clearBorders = ( - self.drawClearRegionToolbar - .clearOnlyEnclosedObjsRadioButton.isChecked() - ) - if clearBorders: - if regionLab.ndim == 2: - regionLab = transformation.clear_objects_not_in_mask( - regionLab, mask - ) - regionRp = skimage.measure.regionprops(regionLab) - for obj in regionRp: - if np.all(mask[obj.slice][obj.image]): - continue - - regionLab[obj.slice][obj.image] = 0 - else: - for z, regionLab_z in enumerate(regionLab): - regionLab[z] = transformation.clear_objects_not_in_mask( - regionLab_z, mask - ) - else: - regionLab[..., ~mask] = 0 - - regionRp = skimage.measure.regionprops(regionLab) - clearIDs = [obj.label for obj in regionRp] - - if not clearIDs: - if clearBorders: - self.logger.warning( - 'None of the objects in the freehand region are ' - 'fully enclosed' - ) - else: - self.logger.warning( - 'None of the objects are touching the freehand region' - ) - return - - self.deleteIDmiddleClick(clearIDs, False, False) - self.update_cca_df_deletedIDs(posData, clearIDs) - - self.freeRoiItem.clear() - - self.updateAllImages() - - def labelRoiWorkerFinished(self): - self.logger.info('Magic labeller closed.') - worker = self.labelRoiActiveWorkers.pop(-1) - - def indexRoiLab(self, roiLab, roiLabSlice, lab, brushID): - # Delete only objects touching borders in X and Y not in Z - if self.labelRoiAutoClearBorderCheckbox.isChecked(): - mask = np.zeros(roiLab.shape, dtype=bool) - mask[..., 1:-1, 1:-1] = True - roiLab = skimage.segmentation.clear_border(roiLab, mask=mask) - - roiLabMask = roiLab>0 - roiLab[roiLabMask] += (brushID-1) - if self.labelRoiReplaceExistingObjectsCheckbox.isChecked(): - IDs_touched_by_new_objects = np.unique(lab[roiLabSlice][roiLabMask]) - for ID in IDs_touched_by_new_objects: - lab[lab==ID] = 0 - - lab[roiLabSlice][roiLabMask] = roiLab[roiLabMask] - return lab - - @exception_handler - def labelRoiDone(self, roiSegmData, isTimeLapse): - self.setDisabled(False) - - posData = self.data[self.pos_i] - self.setBrushID() - - if isTimeLapse: - self.progressWin.mainPbar.setMaximum(0) - self.progressWin.mainPbar.setValue(0) - current_frame_i = posData.frame_i - start_frame_i = self.labelRoiStartFrameNoSpinbox.value() - 1 - for i, roiLab in enumerate(roiSegmData): - frame_i = start_frame_i + i - lab = posData.allData_li[frame_i]['labels'] - store = True - if lab is None: - if frame_i >= len(posData.segm_data): - lab = np.zeros_like(posData.segm_data[0]) - posData.segm_data = np.append( - posData.segm_data, lab[np.newaxis], axis=0 - ) - else: - lab = posData.segm_data[frame_i] - store = False - roiLabSlice = self.labelRoiSlice[1:] - lab = self.indexRoiLab( - roiLab, roiLabSlice, lab, posData.brushID - ) - if store: - posData.frame_i = frame_i - posData.allData_li[frame_i]['labels'] = lab.copy() - self.get_data() - self.store_data(autosave=False) - - # Back to current frame - posData.frame_i = current_frame_i - self.get_data() - else: - roiLab = roiSegmData - posData.lab = self.indexRoiLab( - roiLab, self.labelRoiSlice, posData.lab, posData.brushID - ) - - self.update_rp() - - # Repeat tracking - if self.autoIDcheckbox.isChecked(): - self.tracking(enforce=True, assign_unique_new_IDs=False) - - self.store_data() - self.updateAllImages() - - self.labelRoiItem.setPos((0,0)) - self.labelRoiItem.setSize((0,0)) - self.freeRoiItem.clear() - self.logger.info('Magic labeller done!') - self.app.restoreOverrideCursor() - - self.labelRoiRunning = False - if self.progressWin is not None: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - - uncheckLabelRoiTRange = ( - self.labelRoiTrangeCheckbox.isChecked() - and not self.labelRoiTrangeCheckbox.findChild(QAction).isChecked() - ) - if uncheckLabelRoiTRange: - self.labelRoiTrangeCheckbox.setChecked(False) - - def restoreHoverObjBrush(self): - posData = self.data[self.pos_i] - if self.ax1BrushHoverID in posData.IDs: - obj_idx = posData.IDs_idxs[self.ax1BrushHoverID] - obj = posData.rp[obj_idx] - if not self.isObjVisible(obj.bbox): - return - - self.addObjContourToContoursImage(obj=obj, ax=0) - self.addObjContourToContoursImage(obj=obj, ax=1) - - def hideItemsHoverBrush(self, xy=None, ID=None, force=False): - if xy is not None: - x, y = xy - if x is None: - return - - xdata, ydata = int(x), int(y) - Y, X = self.currentLab2D.shape - - if not (xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y): - return - - if not self.brushAutoHideCheckbox.isChecked() and not force: - return - - posData = self.data[self.pos_i] - size = self.brushSizeSpinbox.value()*2 - - if xy is not None: - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - - if self.ax1_lostObjScatterItem.isVisible(): - self.ax1_lostObjScatterItem.setVisible(False) - - if self.ax1_lostTrackedScatterItem.isVisible(): - self.ax1_lostTrackedScatterItem.setVisible(False) - - if self.ax2_lostObjScatterItem.isVisible(): - self.ax2_lostObjScatterItem.setVisible(False) - - if self.ax2_lostTrackedScatterItem.isVisible(): - self.ax2_lostTrackedScatterItem.setVisible(False) - - # Restore ID previously hovered - if ID != self.ax1BrushHoverID and not self.isMouseDragImg1: - try: - self.restoreHoverObjBrush() - except Exception as e: - self.ax1BrushHoverID = 0 - return - - # Hide items hover ID - if ID != 0: - self.clearObjContour(ID=ID, ax=0) - self.clearObjContour(ID=ID, ax=1) - self.ax1BrushHoverID = ID - else: - self.ax1BrushHoverID = 0 - - def updateBrushCursor(self, x, y, isHoverImg1=True): - if x is None: - return - - xdata, ydata = int(x), int(y) - _img = self.currentLab2D - Y, X = _img.shape - - if not (xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y): - return - - size = self.brushSizeSpinbox.value()*2 - self.setHoverToolSymbolData( - [x], [y], self.activeBrushCircleCursors(isHoverImg1), - size=size - ) - self.setHoverToolSymbolColor( - xdata, ydata, self.ax2_BrushCirclePen, - self.activeBrushCircleCursors(isHoverImg1), - self.brushButton, brush=self.ax2_BrushCircleBrush - ) - - def moveLabelButtonToggled(self, checked): - if not checked: - self.hoverLabelID = 0 - self.highlightedID = 0 - self.highLightIDLayerImg1.clear() - self.highLightIDLayerRightImage.clear() - self.setHighlightID(False) - - def setAllIDs(self, onlyVisited=False): - for posData in self.data: - posData.allIDs = set() - for frame_i in range(len(posData.segm_data)): - if frame_i >= len(posData.allData_li): - break - lab = posData.allData_li[frame_i]['labels'] - if lab is None and onlyVisited: - break - - if lab is None: - rp = skimage.measure.regionprops(posData.segm_data[frame_i]) - else: - rp = posData.allData_li[frame_i]['regionprops'] - posData.allIDs.update([obj.label for obj in rp]) - - def countObjectsTimelapse(self): - if self.countObjsWindow is None: - activeCategories = { - 'In current frame', - 'In all visited frames', - 'In entire video', - 'Unique objects in all visited frames', - 'Unique objects in entire video' - } - else: - activeCategories = self.countObjsWindow.activeCategories() - - posData = self.data[self.pos_i] - allCategoryCountMapper = posData.countObjectsInSegmTimelapse( - activeCategories - ) - if self.countObjsWindow is None: - return allCategoryCountMapper - - categoryCountMapper = {} - for category in activeCategories: - categoryCountMapper[category] = allCategoryCountMapper[category] - - return categoryCountMapper - - - def countObjectsSnapshots(self): - posData = self.data[self.pos_i] - if self.countObjsWindow is None: - activeCategories = { - 'In current position', - 'In all visited positions (current session)', - 'In all visited positions (previous sessions)', - 'In all loaded positions', - } - if self.isSegm3D: - activeCategories.add('In current z-slice') - else: - activeCategories = self.countObjsWindow.activeCategories() - - numObjectsCurrentPos = len(posData.IDs) - numObjectsAllPos = 0 - numObjectsVisitedPosPrevious = 0 - numObjectsVisitedPosCurrent = 0 - numObjectsCurrentZslice = None - if 'In current z-slice' in activeCategories: - numObjectsCurrentZslice = len( - skimage.measure.regionprops(self.currentLab2D) - ) - - for pos_i, _posData in enumerate(self.data): - IDs = _posData.allData_li[0]['IDs'] - if os.path.exists(_posData.acdc_output_csv_path): - numObjectsVisitedPosPrevious += len(IDs) - if IDs: - numObjs = len(IDs) - numObjectsAllPos += len(IDs) - else: - lab = _posData.segm_data[0] - rp = skimage.measure.regionprops(lab) - numObjs = len(rp) - numObjectsAllPos += numObjs - - if _posData.visited: - numObjectsVisitedPosCurrent += numObjs - - allCategoryCountMapper = { - 'In current position': numObjectsCurrentPos, - 'In all visited positions (current session)': - numObjectsVisitedPosCurrent, - 'In all visited positions (previous sessions)': - numObjectsVisitedPosPrevious, - 'In all loaded positions': numObjectsAllPos, - } - if numObjectsCurrentZslice is not None: - allCategoryCountMapper['In current z-slice'] = ( - numObjectsCurrentZslice - ) - - if self.countObjsWindow is None: - return allCategoryCountMapper - - categoryCountMapper = {} - for category in activeCategories: - categoryCountMapper[category] = allCategoryCountMapper[category] - - return categoryCountMapper - - def countObjects(self): - self.logger.info('Counting objects...') - - posData = self.data[self.pos_i] - if posData.SizeT > 1: - return self.countObjectsTimelapse() - - return self.countObjectsSnapshots() - - - def updateObjectCounts(self): - if self.countObjsWindow is None: - return - - if not self.countObjsWindow.isVisible(): - return - - if not self.countObjsWindow.livePreviewCheckbox.isChecked(): - return - - categoryCountMapper = self.countObjects() - self.countObjsWindow.updateCounts(categoryCountMapper) - - def keepIDs_cb(self, checked): - if checked: - self.highlightedLab = np.zeros_like(self.currentLab2D) - if self.annotCcaInfoCheckbox.isChecked(): - self.annotCcaInfoCheckbox.setChecked(False) - self.annotIDsCheckbox.setChecked(True) - self.setDrawAnnotComboboxText() - self.uncheckLeftClickButtons(None) - self.initKeepObjLabelsLayers() - self.setAllIDs() - else: - # restore items to non-grayed out - self.clearTempBrushImage() - alpha = self.imgGrad.labelsAlphaSlider.value() - self.labelsLayerImg1.setOpacity(alpha) - self.labelsLayerRightImg.setOpacity(alpha) - self.ax1_contoursImageItem.setOpacity(1.0) - self.ax2_contoursImageItem.setOpacity(1.0) - self.ax1_lostObjImageItem.setOpacity(1.0) - self.ax2_lostObjImageItem.setOpacity(1.0) - self.ax1_lostTrackedObjImageItem.setOpacity(1.0) - self.ax2_lostTrackedObjImageItem.setOpacity(1.0) - - self.keepIDsToolbar.setVisible(checked) - self.highlightedIDopts = None - self.keptObjectsIDs = widgets.KeptObjectIDsList( - self.keptIDsLineEdit, self.keepIDsConfirmAction - ) - self.updateAllImages() - - # QTimer.singleShot(300, self.autoRange) - - def get_curr_lab(self, curr_lab: np.ndarray|None = None, frame_i: int|None = None): - """Get the current labels for the position data. Hirarchically checks: - 1. If `curr_lab` is provided, use it. - 2. If `posData.lab` is not None, use it. - 3. If `posData.allData_li[frame_i]['labels']` exists, use it. - 4. If `posData.segm_data[frame_i]` exists, use it. - - If frame_i is None, uses the current frame index from `posData`. - - Parameters - ---------- - curr_lab : np.ndarray, optional - Current labels for the position data if it should be checked - if its not None first, by default None - frame_i : int, optional - Frame index to use for retrieving labels, by default None - - Returns - ------- - np.ndarray - Current labels for the position data - """ - posData = self.data[self.pos_i] - if frame_i is None: - frame_i = posData.frame_i - - if curr_lab is None and frame_i == posData.frame_i: - curr_lab = posData.lab - - if curr_lab is None: - try: - curr_lab = posData.allData_li[frame_i]['labels'].copy() - except: - pass - - if curr_lab is None: - try: - curr_lab = posData.segm_data[frame_i].copy() - except: - pass - - return curr_lab - - def setFrameNavigationDisabled(self, disable: bool, why: str): - """Disables the frame navigation buttons and scrollbar. - This is used when the user is not allowed to navigate through frames - Call again to unlock it again. Also sets tooltips to inform the user - - Parameters - ---------- - disable : bool - if the navigation should be disabled - why : str - the reason for disabeling the navigation. - """ - - if disable: - self.whyNavigateDisabled.add(why) - else: - try: - self.whyNavigateDisabled.remove(why) - except KeyError: - pass - - if len(self.whyNavigateDisabled) == 0: - disable = False - else: - disable = True - - # Apply the disable/enable state - self.prevAction.setDisabled(disable) - self.nextAction.setDisabled(disable) - self.navigateScrollBar.setDisabled(disable) - - # Set appropriate tooltip - if not disable: - self.navigateScrollBar.setToolTip( - 'NOTE: The maximum frame number that can be visualized with this ' - 'scrollbar\n' - 'is the last visited frame with the selected mode\n' - '(see "Mode" selector on the top-right).\n\n' - 'If the scrollbar does not move it means that you never visited\n' - 'any frame with current mode.\n\n' - 'Note that the "Viewer" mode allows you to scroll ALL frames.' - ) - return - - txt = f'Frame navigation disabled: {self.whyNavigateDisabled}' - self.logger.info(txt) - self.navigateScrollBar.setToolTip(txt) - - def delObjsOutSegmMaskActionTriggered(self): - posData = self.data[self.pos_i] - segm_files = load.get_segm_files(posData.images_path) - existingSegmEndnames = load.get_endnames( - posData.basename, segm_files - ) - selectSegmWin = widgets.QDialogListbox( - 'Select segmentation file', - 'Select segmentation file to use as ROI:\n', - existingSegmEndnames, multiSelection=False, parent=self - ) - selectSegmWin.exec_() - if selectSegmWin.cancel: - self.logger.info('Delete objects process cancelled.') - return - - selectedSegmEndname = selectSegmWin.selectedItemsText[0] - - self.startDelObjsOutSegmMaskWorker(selectedSegmEndname) - - def startDelObjsOutSegmMaskWorker(self, selectedSegmEndname): - self.store_data(autosave=False) - posData = self.data[self.pos_i] - segm_data = np.squeeze(self.getStoredSegmData()) - - self.progressWin = apps.QDialogWorkerProgress( - title='Deleting objects outside of ROIs', parent=self, - pbarDesc='Deleting objects outside of ROIs...' - ) - self.progressWin.show(self.app) - self.progressWin.mainPbar.setMaximum(0) - - self.thread = QThread() - self.worker = workers.DelObjectsOutsideSegmROIWorker( - selectedSegmEndname, segm_data, posData.images_path - ) - self.worker.moveToThread(self.thread) - self.worker.finished.connect(self.thread.quit) - self.worker.finished.connect(self.worker.deleteLater) - self.thread.finished.connect(self.thread.deleteLater) - - self.worker.progress.connect(self.workerProgress) - self.worker.critical.connect(self.workerCritical) - self.worker.finished.connect(self.delObjsOutSegmMaskWorkerFinished) - - self.worker.debug.connect(self.workerDebug) - - self.thread.started.connect(self.worker.run) - self.thread.start() - - def storeViewRange(self): - if not hasattr(self, 'isRangeReset'): - return - - if not self.isRangeReset: - return - self.ax1_viewRange = self.ax1.viewRange() - self.isRangeReset = False - - def mergeObjs_cb(self, checked): - if not checked: - self.mergeObjsTempLine.setData([], []) - - def Brush_cb(self, checked): - if checked: - self.typingEditID = False - self.setDiskMask() - self.setHoverToolSymbolData( - [], [], (self.ax1_EraserCircle, self.ax2_EraserCircle, - self.ax1_EraserX, self.ax2_EraserX) - ) - self.updateBrushCursor(self.xHoverImg, self.yHoverImg) - self.setBrushID() - - self.disconnectLeftClickButtons() - self.uncheckLeftClickButtons(self.sender()) - c = self.defaultToolBarButtonColor - self.eraserButton.setStyleSheet(f'background-color: {c}') - self.connectLeftClickButtons() - self.setFocusGraphics() - else: - self.ax1_lostObjScatterItem.setVisible(True) - self.ax2_lostObjScatterItem.setVisible(True) - self.ax1_lostTrackedScatterItem.setVisible(True) - self.ax2_lostTrackedScatterItem.setVisible(True) - - self.setHoverToolSymbolData( - [], [], (self.ax2_BrushCircle, self.ax1_BrushCircle), - ) - self.resetCursors() - - self.showEditIDwidgets(checked) - self.enableSizeSpinbox(checked) - - def showEditIDwidgets(self, visible): - self.editIDLabelAction.setVisible(visible) - self.editIDspinboxAction.setVisible(visible) - self.autoIDcheckboxAction.setVisible(visible) - showToolbar = ( - visible - or self.brushSizeAction.isVisible() - or self.brushAutoFillAction.isVisible() - or self.brushAutoHideAction.isVisible() - ) - self.brushEraserToolBar.setVisible(showToolbar) - - def resetCursors(self): - self.ax1_cursor.setData([], []) - self.ax2_cursor.setData([], []) - while self.app.overrideCursor() is not None: - self.app.restoreOverrideCursor() - - def setDiskMask(self): - brushSize = self.brushSizeSpinbox.value() - # diam = brushSize*2 - # center = (brushSize, brushSize) - # diskShape = (diam+1, diam+1) - # diskMask = np.zeros(diskShape, bool) - # rr, cc = skimage.draw.disk(center, brushSize+1, shape=diskShape) - # diskMask[rr, cc] = True - self.diskMask = skimage.morphology.disk(brushSize, dtype=bool) - - def getDiskMask(self, xdata, ydata): - Y, X = self.currentLab2D.shape[-2:] - - brushSize = self.brushSizeSpinbox.value() - yBottom, xLeft = ydata-brushSize, xdata-brushSize - yTop, xRight = ydata+brushSize+1, xdata+brushSize+1 - - if xLeft<0: - if yBottom<0: - # Disk mask out of bounds top-left - diskMask = self.diskMask.copy() - diskMask = diskMask[-yBottom:, -xLeft:] - yBottom = 0 - elif yTop>Y: - # Disk mask out of bounds bottom-left - diskMask = self.diskMask.copy() - diskMask = diskMask[0:Y-yBottom, -xLeft:] - yTop = Y - else: - # Disk mask out of bounds on the left - diskMask = self.diskMask.copy() - diskMask = diskMask[:, -xLeft:] - xLeft = 0 - - elif xRight>X: - if yBottom<0: - # Disk mask out of bounds top-right - diskMask = self.diskMask.copy() - diskMask = diskMask[-yBottom:, 0:X-xLeft] - yBottom = 0 - elif yTop>Y: - # Disk mask out of bounds bottom-right - diskMask = self.diskMask.copy() - diskMask = diskMask[0:Y-yBottom, 0:X-xLeft] - yTop = Y - else: - # Disk mask out of bounds on the right - diskMask = self.diskMask.copy() - diskMask = diskMask[:, 0:X-xLeft] - xRight = X - - elif yBottom<0: - # Disk mask out of bounds on top - diskMask = self.diskMask.copy() - diskMask = diskMask[-yBottom:] - yBottom = 0 - - elif yTop>Y: - # Disk mask out of bounds on bottom - diskMask = self.diskMask.copy() - diskMask = diskMask[0:Y-yBottom] - yTop = Y - - else: - # Disk mask fully inside the image - diskMask = self.diskMask - - return yBottom, xLeft, yTop, xRight, diskMask - - def setBrushID(self, useCurrentLab=True, return_val=False): - # Make sure that the brushed ID is always a new one based on - # already visited frames - posData = self.data[self.pos_i] - wl_init = posData.whitelist and posData.whitelist.whitelistIDs - if useCurrentLab: - IDs_tot = set(posData.IDs) - if wl_init: - try: - IDs_tot.update(posData.whitelist.originalLabsIDs[posData.frame_i]) - except: - pass - try: - if posData.whitelist.whitelistIDs[posData.frame_i]: - IDs_tot.update(posData.whitelist.whitelistIDs[posData.frame_i]) - except: - pass - newID = max(IDs_tot, default=0) - else: - newID = 0 - for frame_i, storedData in enumerate(posData.allData_li): - if frame_i == posData.frame_i: - continue - lab = storedData['labels'] - if lab is not None: - rp = storedData['regionprops'] - IDs_tot = {obj.label for obj in rp} - if wl_init: - if self.whitelistCheckOriginalLabels(warning=False, frame_i=frame_i): - IDs_tot.update(posData.whitelist.originalLabsIDs[frame_i]) - if posData.whitelist.whitelistIDs[frame_i]: - IDs_tot.update(posData.whitelist.whitelistIDs[frame_i]) - _max = max(IDs_tot, default=0) - if _max > newID: - newID = _max - else: - break - - for y, x, manual_ID in posData.editID_info: - if manual_ID > newID: - newID = manual_ID - posData.brushID = newID+1 - if return_val: - return posData.brushID - - @disableWindow - def equalizeHist(self, checked=True): - self.img1.useEqualized = checked - - if not checked: - self.updateAllImages() - return - - self.logger.info('Equalizing image histogram...') - for pos_i, _posData in enumerate(self.data): - n_dim_img = _posData.img_data.ndim - _posData.equalized_img_data = preprocess.PreprocessedData() - for frame_i, img_frame in enumerate(_posData.img_data): - if n_dim_img == 4: - for z, img_z in enumerate(img_frame): - eq_img = skimage.exposure.equalize_adapthist(img_z) - _posData.equalized_img_data[frame_i][z] = eq_img - self.img1.updateMinMaxValuesEqualizedData( - self.data, pos_i, frame_i, z - ) - self.img1.updateMinMaxValuesEqualizedDataProjections( - self.data, pos_i, frame_i - ) - else: - eq_img = skimage.exposure.equalize_adapthist(img_frame) - _posData.equalized_img_data[frame_i] = eq_img - self.img1.updateMinMaxValuesEqualizedData( - self.data, pos_i, frame_i, None - ) - - self.updateAllImages() - - def curvTool_cb(self, checked): - posData = self.data[self.pos_i] - if checked: - self.disconnectLeftClickButtons() - self.uncheckLeftClickButtons(self.curvToolButton) - self.connectLeftClickButtons() - self.hoverLinSpace = np.linspace(0, 1, 1000) - self.curvPlotItem = pg.PlotDataItem(pen=self.newIDs_cpen) - self.curvHoverPlotItem = pg.PlotDataItem(pen=self.oldIDs_cpen) - self.curvAnchors = pg.ScatterPlotItem( - symbol='o', size=9, - brush=pg.mkBrush((255,0,0,50)), - pen=pg.mkPen((255,0,0), width=2), - hoverable=True, hoverPen=pg.mkPen((255,0,0), width=3), - hoverBrush=pg.mkBrush((255,0,0)), tip=None - ) - self.ax1.addItem(self.curvAnchors) - self.ax1.addItem(self.curvPlotItem) - self.ax1.addItem(self.curvHoverPlotItem) - self.splineHoverON = True - posData.curvPlotItems.append(self.curvPlotItem) - posData.curvAnchorsItems.append(self.curvAnchors) - posData.curvHoverItems.append(self.curvHoverPlotItem) - else: - self.splineHoverON = False - self.isRightClickDragImg1 = False - self.clearCurvItems() - while self.app.overrideCursor() is not None: - self.app.restoreOverrideCursor() - - self.showEditIDwidgets(checked) - - def updateHoverLabelCursor(self, x, y): - if x is None: - self.hoverLabelID = 0 - return - - xdata, ydata = int(x), int(y) - Y, X = self.currentLab2D.shape - if not (xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y): - return - - ID = self.currentLab2D[ydata, xdata] - self.hoverLabelID = ID - - if ID == 0: - if self.highlightedID != 0: - self.updateAllImages() - self.highlightedID = 0 - return - - if self.app.overrideCursor() != Qt.SizeAllCursor: - self.app.setOverrideCursor(Qt.SizeAllCursor) - - if not self.isMovingLabel: - self.highlightSearchedID(ID) - - def updateEraserCursor(self, x, y, xyLocked=None, isHoverImg1=True): - if x is None: - return - - xdata, ydata = int(x), int(y) - _img = self.currentLab2D - Y, X = _img.shape - - if not (xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y): - return - - size = self.brushSizeSpinbox.value()*2 - self.setHoverToolSymbolData( - [x], [y], self.activeEraserCircleCursors(isHoverImg1), - size=size - ) - self.setHoverToolSymbolData( - [x], [y], self.activeEraserXCursors(isHoverImg1), - size=int(size/2) - ) - - isMouseDrag = ( - self.isMouseDragImg1 or self.isMouseDragImg2 - ) - if isMouseDrag: - return - - if xyLocked is not None: - xdata, ydata = xyLocked - - self.setHoverToolSymbolColor( - xdata, ydata, self.eraserCirclePen, - self.activeEraserCircleCursors(isHoverImg1), - self.eraserButton, hoverRGB=None - ) - - def Eraser_cb(self, checked): - if checked: - self.setDiskMask() - self.setHoverToolSymbolData( - [], [], (self.ax2_BrushCircle, self.ax1_BrushCircle), - ) - self.updateEraserCursor(self.xHoverImg, self.yHoverImg) - self.disconnectLeftClickButtons() - self.uncheckLeftClickButtons(self.sender()) - c = self.defaultToolBarButtonColor - self.brushButton.setStyleSheet(f'background-color: {c}') - self.connectLeftClickButtons() - else: - self.setHoverToolSymbolData( - [], [], (self.ax1_EraserCircle, self.ax2_EraserCircle, - self.ax1_EraserX, self.ax2_EraserX) - ) - self.resetCursors() - self.updateAllImages() - - self.showEditIDwidgets(checked) - self.enableSizeSpinbox(checked) - - def storeCurrentAnnotOptions_ax1(self, return_value=False): - if self.annotOptionsToRestore is not None: - return - - checkboxes = [ - 'annotIDsCheckbox', - 'annotCcaInfoCheckbox', - 'annotContourCheckbox', - 'annotSegmMasksCheckbox', - 'drawMothBudLinesCheckbox', - 'annotNumZslicesCheckbox', - 'drawNothingCheckbox', - ] - annotOptions = {} - for checkboxName in checkboxes: - checkbox = getattr(self, checkboxName) - annotOptions[checkboxName] = checkbox.isChecked() - if return_value: - return annotOptions - self.annotOptionsToRestore = annotOptions - - def storeCurrentAnnotOptions_ax2(self): - if self.annotOptionsToRestoreRight is not None: - return - - checkboxes = [ - 'annotIDsCheckboxRight', - 'annotCcaInfoCheckboxRight', - 'annotContourCheckboxRight', - 'annotSegmMasksCheckboxRight', - 'drawMothBudLinesCheckboxRight', - 'annotNumZslicesCheckboxRight', - 'drawNothingCheckboxRight', - ] - self.annotOptionsToRestoreRight = {} - for checkboxName in checkboxes: - checkbox = getattr(self, checkboxName) - self.annotOptionsToRestoreRight[checkboxName] = checkbox.isChecked() - - def restoreAnnotOptions_ax1(self, options=None): - if options is None and not hasattr(self, 'annotOptionsToRestore'): - return - - if options is None: - options = self.annotOptionsToRestore - - if options is None: - return - - for option, state in options.items(): - checkbox = getattr(self, option) - checkbox.setChecked(state) - - self.setDrawAnnotComboboxText() - self.annotOptionsToRestore = None - - def restoreAnnotOptions_ax2(self): - if not hasattr(self, 'annotOptionsToRestoreRight'): - return - - if self.annotOptionsToRestoreRight is None: - return - - for option, state in self.annotOptionsToRestoreRight.items(): - checkbox = getattr(self, option) - checkbox.setChecked(state) - - self.setDrawAnnotComboboxTextRight() - self.annotOptionsToRestoreRight = None - - def setDrawNothingAnnotations(self): - self.storeCurrentAnnotOptions_ax1() - self.storeCurrentAnnotOptions_ax2() - self.drawNothingCheckbox.setChecked(True) - self.annotOptionClicked( - sender=self.drawNothingCheckbox, saveSettings=False) - self.drawNothingCheckboxRight.setChecked(True) - self.annotOptionClickedRight( - sender=self.drawNothingCheckboxRight, saveSettings=False - ) - - def restoreAnnotationsOptions(self): - self.restoreAnnotOptions_ax1() - self.restoreAnnotOptions_ax2() - - def onDoubleSpaceBar(self): - how = self.drawIDsContComboBox.currentText() - if how.find('nothing') == -1: - self.storeCurrentAnnotOptions_ax1() - self.drawNothingCheckbox.setChecked(True) - self.annotOptionClicked( - sender=self.drawNothingCheckbox, saveSettings=False - ) - else: - self.restoreAnnotOptions_ax1() - - how = self.annotateRightHowCombobox.currentText() - if how.find('nothing') == -1: - self.storeCurrentAnnotOptions_ax2() - self.drawNothingCheckboxRight.setChecked(True) - self.annotOptionClickedRight( - sender=self.drawNothingCheckboxRight, saveSettings=False - ) - else: - self.restoreAnnotOptions_ax2() - - - def resizeBottomLayoutLineClicked(self, event): - pass - - def resizeBottomLayoutLineDragged(self, event): - if not self.img1BottomGroupbox.isVisible(): - return - newBottomLayoutHeight = self.bottomScrollArea.minimumHeight() - event.y() - self.bottomScrollArea.setFixedHeight(newBottomLayoutHeight) - - def resizeBottomLayoutLineReleased(self): - QTimer.singleShot(100, self.autoRange) - - def mousePressEvent(self, event) -> None: - if event.button() == Qt.MouseButton.RightButton: - pos = self.resizeBottomLayoutLine.mapFromGlobal(event.globalPos()) - if pos.y()>=0: - self.gui_raiseBottomLayoutContextMenu(event) - return super().mousePressEvent(event) - - def zoomBottomLayoutActionTriggered(self, checked): - if not checked: - return - perc = int(re.findall(r'(\d+)%', self.sender().text())[0]) - if perc != 100: - fontSizeFactor = perc/100 - heightFactor = perc/100 - self.resizeSlidersArea( - fontSizeFactor=fontSizeFactor, heightFactor=heightFactor - ) - else: - self.gui_resetBottomLayoutHeight() - self.df_settings.at['bottom_sliders_zoom_perc', 'value'] = perc - self.df_settings.to_csv(self.settings_csv_path) - QTimer.singleShot(150, self.resizeGui) - - def defaultRescaleIntensLutActionToggled(self, action): - how = action.text() - for rescaleIntensAction in self.imgGrad.rescaleActionGroup.actions(): - if how == rescaleIntensAction.text(): - rescaleIntensAction.setChecked(True) - rescaleIntensAction.trigger() - break - - for channel, items in self.overlayLayersItems.items(): - lutItem = items[1] - for rescaleIntensAction in lutItem.rescaleActionGroup.actions(): - if how == rescaleIntensAction.text(): - rescaleIntensAction.setChecked(True) - rescaleIntensAction.trigger() - break - - self.df_settings.at['default_rescale_intens_how', 'value'] = how - self.df_settings.to_csv(self.settings_csv_path) - - def retainSpaceSlidersToggled(self, checked): - if checked: - self.df_settings.at['retain_space_hidden_sliders', 'value'] = 'Yes' - else: - self.df_settings.at['retain_space_hidden_sliders', 'value'] = 'No' - self.df_settings.to_csv(self.settings_csv_path) - if not self.zSliceScrollBar.isEnabled(): - retainSpaceZ = False - else: - retainSpaceZ = checked - myutils.setRetainSizePolicy(self.zSliceScrollBar, retain=retainSpaceZ) - myutils.setRetainSizePolicy(self.zProjComboBox, retain=retainSpaceZ) - myutils.setRetainSizePolicy(self.zSliceOverlay_SB, retain=retainSpaceZ) - myutils.setRetainSizePolicy(self.zProjOverlay_CB, retain=retainSpaceZ) - myutils.setRetainSizePolicy(self.overlay_z_label, retain=retainSpaceZ) - - QTimer.singleShot(200, self.resizeGui) - - def resizeLeaveSpaceTerminalBelow(self): - self.setWindowState(Qt.WindowMaximized) - QTimer.singleShot(200, self._resizeLeaveSpaceTerminalBelow) - - def _resizeLeaveSpaceTerminalBelow(self): - geometry = self.geometry() - left = geometry.left() - top = geometry.top() - width = geometry.width() - height = geometry.height() - self.setGeometry(left, top+10, width, height-200) - - def checkSetDelObjActionActive(self, event): - if self.delObjAction is None and self.is_win: - return - - if self.delObjAction is None: - # On mac we check for Key_Control - if event.key() == Qt.Key_Control: - self.delObjToolAction.setChecked(True) - return - - delObjKeySequence, delObjQtButton = self.delObjAction - keySequenceText = widgets.QKeyEventToString(event).rstrip('+') - - if delObjKeySequence is None: - # self.delObjToolAction.setChecked(True) - return - - delObjKeySequenceText = widgets.macShortcutToWindows( - delObjKeySequence.toString() - ) - keySequenceText = widgets.macShortcutToWindows(keySequenceText) - - # printl( - # delObjKeySequence.toString(), - # keySequenceText, - # delObjKeySequenceText - # ) - - if keySequenceText == delObjKeySequenceText: - self.delObjToolAction.setChecked(True) - - def changeRightClickToLeftOnMac(self, mouseEvent): - button = mouseEvent.button() - if not is_mac: - return button - - delObjKeySequence, delObjQtButton = self.delObjAction - if delObjKeySequence is None: - return button - - if not delObjKeySequence.toString() == 'Control': - return button - - if button != Qt.MouseButton.RightButton: - return button - - if delObjQtButton == Qt.MouseButton.LeftButton: - # On mac, pressing "Control" and clicking with left button changes - # it to a right click button --> here, left click is required for - # delete object --> force return of left click - return Qt.MouseButton.LeftButton - - return button - - - def checkTriggerKeyPressShortcuts(self, event: QKeyEvent): - isBrushKey = event.key() == self.brushButton.keyPressShortcut - isEraserKey = event.key() == self.eraserButton.keyPressShortcut - if isBrushKey or isEraserKey: - return isBrushKey, isEraserKey - - modifierText = widgets.modifierKeyToText(event.modifiers()) - for widget in self.widgetsWithShortcut.values(): - if not hasattr(widget, 'keyPressShortcut'): - continue - - if event.key() == widget.keyPressShortcut: - if widget.isCheckable(): - widget.setChecked(True) - else: - widget.trigger() - continue - - shortcutText = widget.keyPressShortcut.toString() - try: - mod, key = shortcutText.split('+') - if modifierText == mod and event.key() == QKeySequence(key): - widget.trigger() - - except Exception as e: - pass - - return isBrushKey, isEraserKey - - def _temp_debug(self, id=None): - posData = self.data[self.pos_i] - imshow(posData.lab, annotate_labels_idxs=[0]) - - def checkOverlayToolbuttonClicked(self, event): - success = False - try: - n = int(event.text()) - toolbutton = self.allOverlayToolbuttonsByIdx.get(n, None) - toolbutton.click() - success = True - except Exception as e: - # printl(traceback.format_exc()) - success = False - return success - - def keyPressCheckSetSpinboxValue(self, event, spinbox): - """Check if the key pressed is a digit and set the spinbox value - accordingly.""" - try: - n = int(event.text()) - if self.typingEditID: - value = int(f'{spinbox.value()}{n}') - else: - value = n - self.typingEditID = True - spinbox.setValue(value) - - try: - spinbox.timer.stop() - except Exception as err: - pass - - spinbox.timer = QTimer(spinbox) - spinbox.timer.timeout.connect( - self.editingSpinboxValueTimerCallback - ) - spinbox.timer.start(2000) - spinbox.timer.setSingleShot(True) - success = True - except Exception as e: - # printl(traceback.format_exc()) - success = False - return success - - def editingSpinboxValueTimerCallback(self): - self.typingEditID = False - - @exception_handler - def keyPressEvent(self, ev): - ctrl = ev.modifiers() == Qt.ControlModifier - if ctrl and ev.key() == Qt.Key_D: - self.resizeLeaveSpaceTerminalBelow() - return - - if ev.key() == Qt.Key_Q and self.debug: - try: - from . import _q_debug - _q_debug.q_debug(self) - except Exception as err: - printl(traceback.format_exc()) - printl('[ERROR]: Error with "_qdebug" module. See Traceback above.') - pass - - if not self.isDataLoaded: - self.logger.warning( - 'Data not loaded yet. Key pressing events are not connected.' - ) - return - - if ev.key() == Qt.Key_Control: - if not ctrl: - self.wasCtrlPressedFirstTime = True - self.onCtrlPressedFirstTime() - - if ev.key() == Qt.Key_PageDown: - self.onKeyPageDown() - - if ev.key() == Qt.Key_PageUp: - self.onKeyPageUp() - - if ev.key() == Qt.Key_Home: - self.onKeyHome() - - if ev.key() == Qt.Key_End: - self.onKeyEnd() - - modifiers = ev.modifiers() - isAltModifier = modifiers == Qt.AltModifier - isCtrlModifier = modifiers == Qt.ControlModifier - isShiftModifier = modifiers == Qt.ShiftModifier - - self.checkSetDelObjActionActive(ev) - - self.isZmodifier = ( - ev.key()== Qt.Key_Z and not isAltModifier - and not isCtrlModifier and not isShiftModifier - ) - if isShiftModifier: - if self.brushButton.isChecked(): - # Force default brush symbol with shift down - self.setHoverToolSymbolColor( - 1, 1, self.ax2_BrushCirclePen, - (self.ax2_BrushCircle, self.ax1_BrushCircle), - self.brushButton, brush=self.ax2_BrushCircleBrush, - ID=0 - ) - if self.isSegm3D: - self.changeBrushID() - - isAnyModifier = isAltModifier or isCtrlModifier or isShiftModifier - if not isAnyModifier and self.overlayButton.isChecked(): - isButtonClicked = self.checkOverlayToolbuttonClicked(ev) - if isButtonClicked: - return - - isBrushActive = ( - self.brushButton.isChecked() or self.eraserButton.isChecked() - ) - isManualTrackingActive = self.manualTrackingButton.isChecked() - isManualBackgroundActive = self.manualBackgroundButton.isChecked() - isTypingIDFunctionChecked = False - if self.brushButton.isChecked() and not self.autoIDcheckbox.isChecked(): - success = self.keyPressCheckSetSpinboxValue(ev, self.editIDspinbox) - isTypingIDFunctionChecked = True - - if isManualTrackingActive: - isTypingIDFunctionChecked = self.keyPressCheckSetSpinboxValue( - ev, self.manualTrackingToolbar.spinboxID - ) - - elif isManualBackgroundActive: - isTypingIDFunctionChecked = self.keyPressCheckSetSpinboxValue( - ev, self.manualBackgroundToolbar.spinboxID - ) - - addPointsByClickingButton = self.buttonAddPointsByClickingActive() - if ( - addPointsByClickingButton is not None - and addPointsByClickingButton.toolbar.isVisible() - ): - isTypingIDFunctionChecked = self.keyPressCheckSetSpinboxValue( - ev, addPointsByClickingButton.rightClickIDSpinbox - ) - - isBrushKey, isEraserKey = self.checkTriggerKeyPressShortcuts(ev) - isExpandLabelActive = self.expandLabelToolButton.isChecked() - isWandActive = self.wandToolButton.isChecked() - isLabelRoiCircActive = ( - self.labelRoiButton.isChecked() - and self.labelRoiIsCircularRadioButton.isChecked() - ) - how = self.drawIDsContComboBox.currentText() - isOverlaySegm = how.find('overlay segm. masks') != -1 - if ev.key()==Qt.Key_Up and not isCtrlModifier: - self.keyUpCallback( - isBrushActive, isWandActive, isExpandLabelActive, - isLabelRoiCircActive - ) - elif ev.key()==Qt.Key_Down and not isCtrlModifier: - self.keyDownCallback( - isBrushActive, isWandActive, isExpandLabelActive, - isLabelRoiCircActive - ) - elif ev.key() == Qt.Key_Enter or ev.key() == Qt.Key_Return: - if isTypingIDFunctionChecked: - self.typingEditID = False - elif self.keepIDsButton.isChecked(): - self.keepIDsConfirmAction.trigger() - elif ev.key() == Qt.Key_Escape: - self.onEscape(isTypingIDFunctionChecked=isTypingIDFunctionChecked) - elif isAltModifier: - isCursorSizeAll = self.app.overrideCursor() == Qt.SizeAllCursor - # Alt is pressed while cursor is on images --> set SizeAllCursor - if self.xHoverImg is not None and not isCursorSizeAll: - self.app.setOverrideCursor(Qt.SizeAllCursor) - elif isCtrlModifier and isOverlaySegm: - if ev.key() == Qt.Key_Up: - val = self.imgGrad.labelsAlphaSlider.value() - delta = 5/self.imgGrad.labelsAlphaSlider.maximum() - val = val+delta - self.imgGrad.labelsAlphaSlider.setValue(val, emitSignal=True) - elif ev.key() == Qt.Key_Down: - val = self.imgGrad.labelsAlphaSlider.value() - delta = 5/self.imgGrad.labelsAlphaSlider.maximum() - val = val-delta - self.imgGrad.labelsAlphaSlider.setValue(val, emitSignal=True) - elif ev.key() == self.zoomOutKeyValue: - self.zoomToCells(enforce=True) - if self.countKeyPress == 0: - self.isKeyDoublePress = False - self.countKeyPress = 1 - self.doubleKeyTimeElapsed = False - self.Button = None - QTimer.singleShot(400, self.doubleKeyTimerCallBack) - elif self.countKeyPress == 1 and not self.doubleKeyTimeElapsed: - self.ax1.autoRange() - self.isKeyDoublePress = True - self.countKeyPress = 0 - elif ev.key() == Qt.Key_Space: - if self.countKeyPress == 0: - # Single press --> wait that it's not double press - self.isKeyDoublePress = False - self.countKeyPress = 1 - self.doubleKeyTimeElapsed = False - QTimer.singleShot(300, self.doubleKeySpacebarTimerCallback) - elif self.countKeyPress == 1 and not self.doubleKeyTimeElapsed: - self.isKeyDoublePress = True - # Double press --> toggle draw nothing - self.onDoubleSpaceBar() - self.countKeyPress = 0 - elif isBrushKey or isEraserKey: - if isBrushKey: - self.Button = self.brushButton - else: - self.Button = self.eraserButton - - if not self.Button.isVisible(): - return - - if self.countKeyPress == 0: - # If first time clicking B activate brush and start timer - # to catch double press of B - if not self.Button.isChecked(): - self.uncheck = False - self.Button.setChecked(True) - else: - self.uncheck = True - self.countKeyPress = 1 - self.isKeyDoublePress = False - self.doubleKeyTimeElapsed = False - - QTimer.singleShot(400, self.doubleKeyTimerCallBack) - elif self.countKeyPress == 1 and not self.doubleKeyTimeElapsed: - self.isKeyDoublePress = True - color = self.Button.palette().button().color().name() - if color == self.doublePressKeyButtonColor: - c = self.defaultToolBarButtonColor - else: - c = self.doublePressKeyButtonColor - self.Button.setStyleSheet(f'background-color: {c}') - self.countKeyPress = 0 - if self.xHoverImg is not None: - xdata, ydata = int(self.xHoverImg), int(self.yHoverImg) - if isBrushKey: - self.setHoverToolSymbolColor( - xdata, ydata, self.ax2_BrushCirclePen, - (self.ax2_BrushCircle, self.ax1_BrushCircle), - self.brushButton, brush=self.ax2_BrushCircleBrush - ) - elif isEraserKey: - self.setHoverToolSymbolColor( - xdata, ydata, self.eraserCirclePen, - (self.ax2_EraserCircle, self.ax1_EraserCircle), - self.eraserButton - ) - - def doubleRightClickTimerCallBack(self): - if self.isDoubleRightClick: - self.doubleRightClickTimeElapsed = False - return - self.doubleRightClickTimeElapsed = True - self.countRightClicks = 0 - - # Time to double right click on img1 expired --> single right-click - self.gui_imgGradShowContextMenu(*self._img1_click_xy) - - def doubleKeyTimerCallBack(self): - if self.isKeyDoublePress: - self.doubleKeyTimeElapsed = False - return - self.doubleKeyTimeElapsed = True - self.countKeyPress = 0 - if self.Button is None: - return - - isBrushChecked = self.Button.isChecked() - if isBrushChecked and self.uncheck: - self.Button.setChecked(False) - c = self.defaultToolBarButtonColor - self.Button.setStyleSheet(f'background-color: {c}') - - def doubleKeySpacebarTimerCallback(self): - if self.isKeyDoublePress: - self.doubleKeyTimeElapsed = False - return - self.doubleKeyTimeElapsed = True - self.countKeyPress = 0 - - # # Spacebar single press --> toggle next visualization - # currentIndex = self.drawIDsContComboBox.currentIndex() - # nItems = self.drawIDsContComboBox.count() - # nextIndex = currentIndex+1 - # if nextIndex < nItems: - # self.drawIDsContComboBox.setCurrentIndex(nextIndex) - # else: - # self.drawIDsContComboBox.setCurrentIndex(0) - - def updateBrushCursorOnShiftRelease(self): - xdata, ydata = int(self.xHoverImg), int(self.yHoverImg) - self.setHoverToolSymbolColor( - xdata, ydata, self.ax2_BrushCirclePen, - (self.ax2_BrushCircle, self.ax1_BrushCircle), - self.brushButton, brush=self.ax2_BrushCircleBrush, - byPassShiftCheck=True - ) - if self.isSegm3D: - self.changeBrushID() - - def onShiftReleased(self): - if self.brushButton.isChecked() and self.xHoverImg is not None: - self.updateBrushCursorOnShiftRelease() - - def keyReleaseEvent(self, ev): - if self.app.overrideCursor() == Qt.SizeAllCursor: - self.app.restoreOverrideCursor() - if ev.key() == Qt.Key_Control: - self.onCtrlReleased() - elif ev.key() == Qt.Key_Shift: - self.onShiftReleased() - - canRepeat = ( - ev.key() == Qt.Key_Left - or ev.key() == Qt.Key_Right - or ev.key() == Qt.Key_Up - or ev.key() == Qt.Key_Down - or ev.key() == Qt.Key_Control - or ev.key() == Qt.Key_Backspace - or self.delObjToolAction.isChecked() - ) - - if canRepeat and ev.isAutoRepeat(): - return - - self.delObjToolAction.setChecked(False) - - if ev.isAutoRepeat() and not ev.key() == Qt.Key_Z: - if self.warnKeyPressedMsg is not None: - return - self.warnKeyPressedMsg = widgets.myMessageBox( - showCentered=False, wrapText=False - ) - txt = html_utils.paragraph(f""" - Please, do not keep the key "{ev.text().upper()}" - pressed.

- It confuses me :)

- Thanks! - """) - self.warnKeyPressedMsg.warning(self, 'Release the key, please', txt) - self.warnKeyPressedMsg = None - elif ev.isAutoRepeat() and ev.key() == Qt.Key_Z and self.isZmodifier: - self.zKeptDown = True - elif ev.key() == Qt.Key_Z and self.isZmodifier: - posData = self.data[self.pos_i] - self.isZmodifier = False - if not self.zKeptDown and posData.SizeZ > 1: - self.zSliceCheckbox.setChecked(not self.zSliceCheckbox.isChecked()) - self.zKeptDown = False - - def setUncheckedAllButtons(self, buttonsToNotUncheck=None): - self.clickedOnBud = False - if buttonsToNotUncheck is None: - buttonsToNotUncheck = set() - - try: - self.BudMothTempLine.setData([], []) - except Exception as e: - pass - for button in self.checkableButtons: - if button in buttonsToNotUncheck: - continue - button.setChecked(False) - - if self.countObjsButton not in buttonsToNotUncheck: - self.countObjsButton.setChecked(False) - self.splineHoverON = False - self.tempSegmentON = False - self.isRightClickDragImg1 = False - self.clearCurvItems(removeItems=False) - - def setUncheckedAllCustomAnnotButtons(self): - for button in self.customAnnotDict.keys(): - button.setChecked(False) - - def askPropagateChangePast(self, change_txt): - txt = html_utils.paragraph(f""" - Do you want to propagate the change "{change_txt}" to the past frames? - """) - msg = widgets.myMessageBox(wrapText=False) - yesButton, _ = msg.question( - self, 'Propagate change to past frames', txt, - buttonsTexts=('Yes', 'No') - ) - return msg.clickedButton == yesButton - - def propagateMergeObjsPast(self, IDs_to_merge): - self.store_data(autosave=False) - posData = self.data[self.pos_i] - current_frame_i = posData.frame_i - for past_frame_i in range(posData.frame_i-1, -1, -1): - posData.frame_i = past_frame_i - self.get_data() - - IDs = posData.allData_li[past_frame_i]['IDs'] - stop_loop = False - for ID in IDs_to_merge: - if ID not in IDs: - stop_loop = True - break - - if ID == 0: - continue - posData.lab[posData.lab==ID] = self.firstID - self.update_rp() - - self.store_data(autosave=False) - - if stop_loop: - break - - posData.frame_i = current_frame_i - self.get_data() - - def propagateChange( - self, modID, modTxt, doNotShow, UndoFutFrames, - applyFutFrames, applyTrackingB=False, force=False - ): - """ - This function determines whether there are already visited future frames - that contains "modID". If so, it triggers a pop-up asking the user - what to do (propagate change to future frames o not) - """ - posData = self.data[self.pos_i] - # Do not check the future for the last frame - if posData.frame_i+1 == posData.SizeT: - # No future frames to propagate the change to - return False, False, None, doNotShow - - includeUnvisited = posData.includeUnvisitedInfo.get(modTxt, False) - areFutureIDs_affected = [] - # Get number of future frames already visited and check if future - # frames has an ID affected by the change - last_tracked_i_found = False - segmSizeT = len(posData.segm_data) - for i in range(posData.frame_i+1, segmSizeT): - if posData.allData_li[i]['labels'] is None: - if not last_tracked_i_found: - # We set last tracked frame at -1 first None found - last_tracked_i = i - 1 - last_tracked_i_found = True - if not includeUnvisited: - # Stop at last visited frame since includeUnvisited = False - break - else: - lab = posData.segm_data[i] - else: - lab = posData.allData_li[i]['labels'] - - if modID in lab: - areFutureIDs_affected.append(True) - - if not last_tracked_i_found: - # All frames have been visited in segm&track mode - last_tracked_i = posData.SizeT - 1 - - if last_tracked_i == posData.frame_i and not includeUnvisited: - # No future frames to propagate the change to - return False, False, None, doNotShow - - if not areFutureIDs_affected and not force: - # There are future frames but they are not affected by the change - return UndoFutFrames, False, None, doNotShow - - # Ask what to do unless the user has previously checked doNotShowAgain - if doNotShow: - endFrame_i = last_tracked_i - if applyFutFrames and not UndoFutFrames and modTxt == 'Edit ID': - self.whitelistSyncIDsOG(frame_is=range(posData.frame_i, endFrame_i+1)) - return UndoFutFrames, applyFutFrames, endFrame_i, doNotShow - else: - addApplyAllButton = ( - modTxt == 'Delete ID' or modTxt == 'Edit ID' - or modTxt == 'Assign new ID' - ) - ffa = apps.FutureFramesAction_QDialog( - posData.frame_i+1, last_tracked_i, modTxt, - applyTrackingB=applyTrackingB, parent=self, - addApplyAllButton=addApplyAllButton - ) - ffa.exec_() - decision = ffa.decision - - if decision is None: - return None, None, None, doNotShow - - endFrame_i = ffa.endFrame_i - doNotShowAgain = ffa.doNotShowCheckbox.isChecked() - askAction = self.askHowFutureFramesActions[modTxt] - askAction.setChecked( not doNotShowAgain) - askAction.setDisabled(False) - - self.onlyTracking = False - if decision == 'apply_and_reinit': - UndoFutFrames = True - applyFutFrames = False - elif decision == 'apply_and_NOTreinit': - UndoFutFrames = False - applyFutFrames = False - elif decision == 'apply_to_all_visited': - UndoFutFrames = False - applyFutFrames = True - elif decision == 'only_tracking': - UndoFutFrames = False - applyFutFrames = True - self.onlyTracking = True - elif decision == 'apply_to_all': - UndoFutFrames = False - applyFutFrames = True - posData.includeUnvisitedInfo[modTxt] = True - - if applyFutFrames and not UndoFutFrames and modTxt == 'Edit ID': - self.whitelistSyncIDsOG(frame_is=range(posData.frame_i, endFrame_i+1)) - return UndoFutFrames, applyFutFrames, endFrame_i, doNotShowAgain - - def addCcaState(self, frame_i, cca_df, undoId): - posData = self.data[self.pos_i] - posData.UndoRedoCcaStates[frame_i].insert( - 0, {'id': undoId, 'cca_df': cca_df.copy()} - ) - - def addCurrentState(self, storeImage=False, storeOnlyZoom=False): - posData = self.data[self.pos_i] - if posData.cca_df is not None: - cca_df = posData.cca_df.copy() - else: - cca_df = None - - if storeImage: - image = self.img1.image.copy() - else: - image = None - - if storeOnlyZoom: - labels, crop_slice = transformation.crop_2D( - self.currentLab2D, self.ax1.viewRange(), tolerance=10, - return_copy=False - ) - if self.isSegm3D: - z = self.z_lab(checkIfProj=True) - if z is None: - z_slice = slice(0, len(posData.lab)) - crop_slice = (z_slice, *crop_slice) - labels = posData.lab[crop_slice].copy() - else: - z_slice = z - crop_slice = (z_slice, *crop_slice) - labels = labels.copy() - else: - labels = labels.copy() - else: - labels = posData.lab.copy() - crop_slice = None - - state = { - 'image': image, - 'labels': labels, - 'editID_info': posData.editID_info.copy(), - 'binnedIDs': posData.binnedIDs.copy(), - 'keptObejctsIDs': self.keptObjectsIDs.copy(), - 'ripIDs': posData.ripIDs.copy(), - 'cca_df': cca_df, - 'crop_slice': crop_slice - } - posData.UndoRedoStates[posData.frame_i].insert(0, state) - - # posData.storedLab = np.array(posData.lab, order='K', copy=True) - # self.storeStateWorker.callbackOnDone = callbackOnDone - # self.storeStateWorker.enqueue(posData, self.img1.image) - - def getCurrentState(self): - posData = self.data[self.pos_i] - i = posData.frame_i - c = self.UndoCount - state = posData.UndoRedoStates[i][c] - if state['image'] is None: - image_left = None - else: - image_left = state['image'].copy() - - crop_slice = state['crop_slice'] - if crop_slice is None: - posData.lab = state['labels'].copy() - elif self.isSegm3D: - z_slice, slice_y, slice_x = crop_slice - posData.lab[..., z_slice, slice_y, slice_x] = state['labels'].copy() - else: - slice_y, slice_x = crop_slice - posData.lab[..., slice_y, slice_x] = state['labels'].copy() - - posData.editID_info = state['editID_info'].copy() - posData.binnedIDs = state['binnedIDs'].copy() - posData.ripIDs = state['ripIDs'].copy() - self.keptObjectsIDs = state['keptObejctsIDs'].copy() - cca_df = state['cca_df'] - if cca_df is not None: - posData.cca_df = state['cca_df'].copy() - else: - posData.cca_df = None - return image_left - - def storeLabelRoiParams(self, value=None, checked=True): - checkedRoiType = self.labelRoiTypesGroup.checkedButton().text() - circRoiRadius = self.labelRoiCircularRadiusSpinbox.value() - roiZdepth = self.labelRoiZdepthSpinbox.value() - autoClearBorder = self.labelRoiAutoClearBorderCheckbox.isChecked() - clearBorder = 'Yes' if autoClearBorder else 'No' - self.df_settings.at['labelRoi_checkedRoiType', 'value'] = checkedRoiType - self.df_settings.at['labelRoi_circRoiRadius', 'value'] = circRoiRadius - self.df_settings.at['labelRoi_roiZdepth', 'value'] = roiZdepth - self.df_settings.at['labelRoi_autoClearBorder', 'value'] = clearBorder - self.df_settings.at['labelRoi_replaceExistingObjects', 'value'] = ( - 'Yes' if self.labelRoiReplaceExistingObjectsCheckbox.isChecked() - else 'No' - ) - self.df_settings.to_csv(self.settings_csv_path) - - def loadLabelRoiLastParams(self): - idx = 'labelRoi_checkedRoiType' - if idx in self.df_settings.index: - checkedRoiType = self.df_settings.at[idx, 'value'] - for button in self.labelRoiTypesGroup.buttons(): - if button.text() == checkedRoiType: - button.setChecked(True) - break - - idx = 'labelRoi_circRoiRadius' - if idx in self.df_settings.index: - circRoiRadius = self.df_settings.at[idx, 'value'] - self.labelRoiCircularRadiusSpinbox.setValue(int(circRoiRadius)) - - idx = 'labelRoi_roiZdepth' - if idx in self.df_settings.index: - roiZdepth = self.df_settings.at[idx, 'value'] - self.labelRoiZdepthSpinbox.setValue(int(roiZdepth)) - - idx = 'labelRoi_autoClearBorder' - if idx in self.df_settings.index: - clearBorder = self.df_settings.at[idx, 'value'] - checked = clearBorder == 'Yes' - self.labelRoiAutoClearBorderCheckbox.setChecked(checked) - - idx = 'labelRoi_replaceExistingObjects' - if idx in self.df_settings.index: - val = self.df_settings.at[idx, 'value'] - checked = val == 'Yes' - self.labelRoiReplaceExistingObjectsCheckbox.setChecked(checked) - - if self.labelRoiIsCircularRadioButton.isChecked(): - self.labelRoiCircularRadiusSpinbox.setDisabled(False) - - # @exec_time - def storeUndoRedoStates( - self, UndoFutFrames, storeImage=False, storeOnlyZoom=False - ): - posData = self.data[self.pos_i] - if UndoFutFrames: - # Since we modified current frame all future frames that were already - # visited are not valid anymore. Undo changes there - self.reInitLastSegmFrame(updateImages=False) - - # Keep only 5 Undo/Redo states - if len(posData.UndoRedoStates[posData.frame_i]) > 5: - posData.UndoRedoStates[posData.frame_i].pop(-1) - - # Restart count from the most recent state (index 0) - # NOTE: index 0 is most recent state before doing last change - self.UndoCount = 0 - self.undoAction.setEnabled(True) - self.addCurrentState( - storeImage=storeImage, storeOnlyZoom=storeOnlyZoom - ) - - def storeUndoRedoCca(self, frame_i, cca_df, undoId): - if self.isSnapshot: - # For snapshot mode we don't store anything because we have only - # segmentation undo action active - return - """ - Store current cca_df along with a unique id to know which cca_df needs - to be restored - """ - - posData = self.data[self.pos_i] - - # Restart count from the most recent state (index 0) - # NOTE: index 0 is most recent state before doing last change - self.UndoCcaCount = 0 - self.undoAction.setEnabled(True) - - self.addCcaState(frame_i, cca_df, undoId) - - # Keep only 10 Undo/Redo states - if len(posData.UndoRedoCcaStates[frame_i]) > 10: - posData.UndoRedoCcaStates[frame_i].pop(-1) - - def undoCustomAnnotation(self): - pass - - def UndoCca(self): - posData = self.data[self.pos_i] - # Undo current ccaState - storeState = False - if self.UndoCount == 0: - undoId = uuid.uuid4() - self.addCcaState(posData.frame_i, posData.cca_df, undoId) - storeState = True - - - # Get previously stored state - self.UndoCount += 1 - currentCcaStates = posData.UndoRedoCcaStates[posData.frame_i] - prevCcaState = currentCcaStates[self.UndoCount] - posData.cca_df = prevCcaState['cca_df'] - self.store_cca_df() - self.updateAllImages() - - # Check if we have undone all states - if len(currentCcaStates) > self.UndoCount: - # There are no states left to undo for current frame_i - self.undoAction.setEnabled(False) - - # Undo all past and future frames that has a last status inserted - # when modyfing current frame - prevStateId = prevCcaState['id'] - for frame_i in range(0, posData.SizeT): - if storeState: - cca_df_i = self.get_cca_df(frame_i=frame_i, return_df=True) - if cca_df_i is None: - break - # Store current state to enable redoing it - self.addCcaState(frame_i, cca_df_i, undoId) - - CcaStates_i = posData.UndoRedoCcaStates[frame_i] - if len(CcaStates_i) <= self.UndoCount: - # There are no states to undo for frame_i - continue - - CcaState_i = CcaStates_i[self.UndoCount] - id_i = CcaState_i['id'] - if id_i != prevStateId: - # The id of the state in frame_i is different from current frame - continue - - cca_df_i = CcaState_i['cca_df'] - self.store_cca_df(frame_i=frame_i, cca_df=cca_df_i, autosave=False) - - self.resetWillDivideInfo() - self.enqAutosave() - - def undo(self): - addPointsByClickingButton = self.buttonAddPointsByClickingActive() - if addPointsByClickingButton is not None: - done = self.undoAddPoint(addPointsByClickingButton.action) - if done: - return - - if self.UndoCount == 0: - # Store current state to enable redoing it - self.addCurrentState() - - posData = self.data[self.pos_i] - # Get previously stored state - if self.UndoCount < len(posData.UndoRedoStates[posData.frame_i])-1: - self.UndoCount += 1 - # Since we have undone then it is possible to redo - self.redoAction.setEnabled(True) - - # Restore state - image_left = self.getCurrentState() - self.update_rp() - self.updateAllImages(image=image_left) - self.store_data() - - if not self.UndoCount < len(posData.UndoRedoStates[posData.frame_i])-1: - # We have undone all available states - self.undoAction.setEnabled(False) - - if self.whitelistIDsButton.isChecked(): - self.whitelistHighlightIDs() - - def redo(self): - posData = self.data[self.pos_i] - # Get previously stored state - if self.UndoCount > 0: - self.UndoCount -= 1 - # Since we have redone then it is possible to undo - self.undoAction.setEnabled(True) - - # Restore state - image_left = self.getCurrentState() - self.update_rp() - self.updateAllImages(image=image_left) - self.store_data() - - if not self.UndoCount > 0: - # We have redone all available states - self.redoAction.setEnabled(False) - - if self.whitelistIDsButton.isChecked(): - self.whitelistHighlightIDs() - - def realTimeTrackingClicked(self, checked): - # Event called ONLY if the user click on Disable tracking - # NOT called if setChecked is called. This allows to keep track - # of the user choice. This way user con enforce tracking - # NOTE: I know two booleans doing the same thing is overkill - # but the code is more readable when we actually need them - - posData = self.data[self.pos_i] - isRealTimeTrackingDisabled = not checked - - # Turn off smart tracking - self.enableSmartTrackAction.toggled.disconnect() - self.enableSmartTrackAction.setChecked(False) - if isRealTimeTrackingDisabled: - self.UserEnforced_DisabledTracking = True - self.UserEnforced_Tracking = False - else: - txt = html_utils.paragraph(""" - - Do you want to keep tracking always active including on already - visited frames?

- Note: To re-activate automatic handling of tracking go to
- Edit --> Smart handling of enabling/disabling tracking. - - """) - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - yesButton, noButton = msg.question( - self, 'Keep tracking always active?', txt, - buttonsTexts=('Yes', 'No') - ) - if msg.clickedButton == yesButton: - self.repeatTracking() - self.UserEnforced_DisabledTracking = False - self.UserEnforced_Tracking = True - else: - self.enableSmartTrackAction.setChecked(True) - - @exception_handler - def repeatTrackingVideo(self, checked=False): - posData = self.data[self.pos_i] - win = widgets.selectTrackerGUI( - posData.SizeT, currentFrameNo=posData.frame_i+1 - ) - win.exec_() - if win.cancel: - self.logger.info('Tracking aborted.') - return - - trackerName = win.selectedItemsText[0] - start_n = win.startFrame - stop_n = win.stopFrame - video_to_track = posData.segm_data - for frame_i in range(start_n-1, stop_n): - data_dict = posData.allData_li[frame_i] - lab = data_dict['labels'] - if lab is None: - break - - video_to_track[frame_i] = lab - video_to_track = video_to_track[start_n-1:stop_n] - - self.logger.info(f'Importing {trackerName} tracker...') - self.tracker, self.track_params, init_params = myutils.init_tracker( - posData, trackerName, qparent=self, return_init_params=True - ) - if self.track_params is None: - self.logger.info('Tracking aborted.') - return - - warningText = myutils.validate_tracker_input( - self.tracker, video_to_track - ) - if warningText is not None: - self.logger.info(warningText) - self.warnTrackerInputNotValid(trackerName, warningText) - return - - if 'image_channel_name' in self.track_params: - # Remove the channel name since it was already loaded in init_tracker - del self.track_params['image_channel_name'] - - track_params_log = { - key: value for key, value in self.track_params.items() - if key != 'image' - } - self.logger.info( - 'Tracking parameters:\n\n' - f'Initialization parameters: {init_params}\n' - f'Track parameters: {track_params_log}' - ) - - last_cca_i = self.get_last_cca_frame_i() - if start_n-2 <= last_cca_i and start_n>1: - proceed = self.warnRepeatTrackingVideoWithAnnotations( - last_cca_i, start_n - ) - if not proceed: - self.logger.info('Tracking aborted.') - return - - self.logger.info(f'Removing annotations from frame n. {start_n}.') - self.resetCcaFuture(start_n-1) - - self.start_n = start_n - self.stop_n = stop_n - - info_txt = f'Tracking from frame n. {start_n} to {stop_n}...' - self.logger.info(info_txt) - - self.progressWin = apps.QDialogWorkerProgress( - title='Tracking', parent=self, pbarDesc=info_txt - ) - self.progressWin.show(self.app) - self.progressWin.mainPbar.setMaximum(stop_n-start_n) - self.startTrackingWorker(posData, video_to_track) - - def warnTrackerInputNotValid(self, trackerName, warningText): - msg = widgets.myMessageBox(wrapText=False) - txt = warningText.replace('\n', '
') - txt = html_utils.paragraph( - f'{txt}

' - 'Tracking process will be cancelled. Thank you for your patience!' - ) - msg.warning(self, 'Invalid input for tracker', txt) - - def repeatTracking(self): - posData = self.data[self.pos_i] - prev_lab = self.get_2Dlab(posData.lab).copy() - self.tracking(enforce=True, DoManualEdit=False) - if posData.editID_info: - editedIDsInfo = { - posData.lab[y,x]:newID - for y, x, newID in posData.editID_info - if posData.lab[y,x] != newID - } - editedIDsInfoItems = [ - f'ID {oldID} --> {newID}' - for oldID, newID in editedIDsInfo.items() - ] - editIDul = html_utils.to_list(editedIDsInfoItems) - msg = widgets.myMessageBox() - txt = html_utils.paragraph(f""" - You requested to repeat tracking but there are manually - edited IDs (see edited IDs in the details section below) -

- Do you want to keep these edits or ignore them? - """) - keepManualEditButton = widgets.okPushButton( - 'Keep manually edited IDs' - ) - ignoreButton = widgets.noPushButton( - 'Ignore manually edited IDs' - ) - msg.question( - self, 'Repeat tracking mode', txt, - buttonsTexts=(keepManualEditButton, ignoreButton), - detailsText=editIDul - ) - if msg.cancel: - return - if msg.clickedButton == keepManualEditButton: - allIDs = [obj.label for obj in posData.rp] - lab2D = self.get_2Dlab(posData.lab) - self.manuallyEditTracking(lab2D, allIDs) - self.update_rp() - self.setAllTextAnnotations() - self.highlightLostNew() - # self.checkIDsMultiContour() - else: - posData.editID_info = [] - if np.any(posData.lab != prev_lab): - if self.isSnapshot: - self.fixCcaDfAfterEdit('Repeat tracking') - self.updateAllImages() - else: - self.warnEditingWithCca_df('Repeat tracking') - else: - self.updateAllImages() - - def updateGhostMaskOpacity(self, alpha_percentage=None): - if alpha_percentage is None: - alpha_percentage = ( - self.manualTrackingToolbar.ghostMaskOpacitySpinbox.value() - ) - alpha = alpha_percentage/100 - self.ghostMaskItemLeft.setOpacity(alpha) - self.ghostMaskItemRight.setOpacity(alpha) - - def addManualTrackingItems(self): - self.ghostContourItemLeft.addToPlotItem() - self.ghostContourItemRight.addToPlotItem() - - self.ghostMaskItemLeft.addToPlotItem() - self.ghostMaskItemRight.addToPlotItem() - - Y, X = self.img1.image.shape[:2] - self.ghostMaskItemLeft.initImage((Y, X)) - self.ghostMaskItemRight.initImage((Y, X)) - - self.updateGhostMaskOpacity() - - def removeManualTrackingItems(self): - self.ghostContourItemLeft.removeFromPlotItem() - self.ghostContourItemRight.removeFromPlotItem() - - self.ghostMaskItemLeft.removeFromPlotItem() - self.ghostMaskItemRight.removeFromPlotItem() - - def addManualBackgroundItems(self): - self.manualBackgroundObjItem.addToPlotItem() - self.ax1.addItem(self.manualBackgroundImageItem) - - def removeManualBackgroundItems(self): - self.manualBackgroundObjItem.removeFromPlotItem() - self.ax1.removeItem(self.manualBackgroundImageItem) - - def resetManualBackgroundSpinboxID(self): - if not self.manualBackgroundButton.isChecked(): - self.manualBackgroundObj = None - return - - posData = self.data[self.pos_i] - minID = min(posData.IDs, default=0) - self.manualBackgroundToolbar.spinboxID.setValue(minID) - - def initManualBackgroundObject(self, ID=None): - if not self.manualBackgroundButton.isChecked(): - self.manualBackgroundObj = None - return - - if ID is None: - ID = self.manualBackgroundToolbar.spinboxID.value() - - posData = self.data[self.pos_i] - if ID not in posData.IDs: - self.manualBackgroundObj = None - self.manualBackgroundToolbar.showWarning( - f'The ID {ID} does not exist' - ) - self.manualBackgroundObjItem.clear() - return - - ID_idx = posData.IDs_idxs[ID] - self.manualBackgroundObj = posData.rp[ID_idx] - - self.manualBackgroundToolbar.clearInfoText() - self.manualBackgroundObj.contour = self.getObjContours( - self.manualBackgroundObj, local=True - ) - xx_contour = self.manualBackgroundObj.contour[:,0] - yy_contour = self.manualBackgroundObj.contour[:,1] - self.manualBackgroundObj.xx_contour = xx_contour - self.manualBackgroundObj.yy_contour = yy_contour - - def initGhostObject(self, ID=None): - mode = self.modeComboBox.currentText() - if mode != 'Segmentation and Tracking': - self.ghostObject = None - return - - if not self.manualTrackingButton.isChecked(): - self.ghostObject = None - return - - if not self.manualTrackingToolbar.showGhostCheckbox.isChecked(): - self.ghostObject = None - return - - if ID is None: - ID = self.manualTrackingToolbar.spinboxID.value() - - posData = self.data[self.pos_i] - if posData.frame_i == 0: - self.ghostObject = None - return - - prevFrameRp = posData.allData_li[posData.frame_i-1]['regionprops'] - if prevFrameRp is None: - self.ghostObject = None - return - - for obj in prevFrameRp: - if obj.label != ID: - continue - self.ghostObject = obj - break - else: - self.ghostObject = None - self.manualTrackingToolbar.showWarning( - f'The ID {ID} does not exist in previous frame ' - '--> starting a new track.' - ) - return - - self.manualTrackingToolbar.clearInfoText() - - self.ghostObject.contour = self.getObjContours( - self.ghostObject, local=True - ) - self.ghostObject.xx_contour = self.ghostObject.contour[:,0] - self.ghostObject.yy_contour = self.ghostObject.contour[:,1] - - self.ghostMaskItemLeft.initLookupTable(self.lut[ID]) - self.ghostMaskItemRight.initLookupTable(self.lut[ID]) - - def clearGhost(self): - self.clearGhostContour() - self.clearGhostMask() - - def clearManualBackgroundAnnotations(self): - try: - for textItem in self.manualBackgroundTextItems.values(): - textItem.setText('') - except Exception as error: - pass - - def clearGhostContour(self): - self.ghostContourItemLeft.clear() - self.ghostContourItemRight.clear() - self.manualBackgroundObjItem.clear() - - def clearGhostMask(self): - self.ghostMaskItemLeft.clear() - self.ghostMaskItemRight.clear() - - @disableWindow - def _importInitMagicPromptModel( - self, model_name, posData, win, acdcPromptSegment, toolbar - ): - self.logger.info(f'Initializing promptable model {model_name}...') - init_kwargs = win.init_kwargs - model = myutils.init_prompt_segm_model( - acdcPromptSegment, posData, win.init_kwargs - ) - toolbar.model = model - toolbar.model_segment_kwargs = win.model_kwargs - toolbar.model_name = model_name - toolbar.viewModelParamsAction.setDisabled(False) - - self.magicPromptsToolbar.setInitializedModel( - init_kwargs, toolbar.model_segment_kwargs - ) - - self.logger.info( - f'Promptable model {model_name} successfully initialised!' - ) - - @exception_handler - def magicPromptsInitModel( - self, - model_name, - acdcPromptSegment, - init_argspecs, - segment_argspecs, - help_url, - toolbar, - ): - posData = self.data[self.pos_i] - - out = prompts.init_prompt_model_params( - posData, model_name, init_argspecs, segment_argspecs, - help_url=help_url, qparent=self, init_last_params=True - ) - win = out.get('win') - if win.cancel: - self.logger.info( - f'Initialization of {model_name} promptable model cancelled.' - ) - return - - self._importInitMagicPromptModel( - model_name, posData, win, acdcPromptSegment, toolbar - ) - - def viewSetMagicPromptModelParams( - self, - model_name, - acdcPromptSegment, - init_argspecs, - segment_argspecs, - help_url, - init_kwargs, - segment_kwargs, - toolbar - ): - posData = self.data[self.pos_i] - - init_argspecs = myutils.setDefaultValueArgSpecsFromKwargs( - init_argspecs, init_kwargs - ) - segment_argspecs = myutils.setDefaultValueArgSpecsFromKwargs( - segment_argspecs, segment_kwargs - ) - - out = prompts.init_prompt_model_params( - posData, model_name, init_argspecs, segment_argspecs, - help_url=help_url, qparent=self, init_last_params=False - ) - win = out.get('win') - if win.cancel: - return - - if win.model_kwargs != segment_kwargs or win.init_kwargs != init_kwargs: - self._importInitMagicPromptModel( - model_name, posData, win, acdcPromptSegment, toolbar - ) - - def getMagicPromptsInputs(self, toolbar): - if not self.promptSegmentPointsLayerToolbar.isPointsLayerInit: - _warnings.warnPromptSegmentPointsLayerNotInit(qparent=self) - return - - if not self.magicPromptsToolbar.viewModelParamsAction.isEnabled(): - _warnings.warnPromptSegmentModelNotInit(qparent=self) - return - - posData = self.data[self.pos_i] - image = self.getDisplayedZstack() - df_points = self.promptSegmentPointsLayerToolbar.pointsLayerDf( - posData, isSegm3D=self.isSegm3D - ) - - self.logger.info( - f'Starting {toolbar.model_name} promptable segmentation with the ' - f'following prompts:\n\n{df_points}' - ) - - return image, df_points - - @disableWindow - def magicPromptsComputeOnZoomTriggered(self, toolbar): - inputs = self.getMagicPromptsInputs(toolbar) - if inputs is None: - self.logger.info( - '"Computing promptable segmentation on zoom" process cancelled.' - ) - return - - posData = self.data[self.pos_i] - image, df_points = inputs - - ((xmin, xmax), (ymin, ymax)) = self.ax1.viewRange() - Y, X = image.shape[-2:] - - xmin = int(max(0, xmin)) - xmax = int(min(X, xmax)) - ymin = int(max(0, ymin)) - ymax = int(min(Y, ymax)) - - self.logger.info( - f'Zoom range: xmin={xmin}, xmax={xmax}, ymin={ymin}, ymax={ymax}' - ) - - zoom_slice = (slice(ymin, ymax), slice(xmin, xmax)) - - image = image[..., ymin:ymax, xmin:xmax] - image_origin = (0, ymin, xmin) - - df_points = df_points[df_points['y'] >= ymin] - df_points = df_points[df_points['x'] >= xmin] - df_points = df_points[df_points['y'] < ymax] - df_points = df_points[df_points['x'] < xmax] - - df_points['y'] -= ymin - df_points['x'] -= xmin - - df_points = df_points[ df_points['frame_i'] == posData.frame_i] - - self.logger.info( - f'Image origin = {image_origin}\n' - f'Image shape = {image.shape}' - ) - - self.startMagicPromptsWorkerAndWait( - image, df_points, toolbar.model, toolbar.model_segment_kwargs, - image_origin=image_origin, zoom_slice=zoom_slice - ) - - def magicPromptsInterpolateZsliceToggled(self, checked): - # See 'self.promptSegmentPointsLayerToolbar.addPointsZslicesInterpolation' - self.promptSegmentPointsLayerToolbar.doAddPointsZslicesInterpolation = ( - checked - ) - - def magicPromptsClearPoints(self, toolbar, only_zoom=False): - posData = self.data[self.pos_i] - scatterItem = self.promptSegmentPointsLayerToolbar.scatterItem() - action = scatterItem.action - - pointsDataPos = action.pointsData.get(self.pos_i) - if pointsDataPos is None: - return - - framePointsData = action.pointsData[self.pos_i].pop( - posData.frame_i, None - ) - if framePointsData is None: - return - - if not only_zoom: - scatterItem.clear() - return - - ((xmin, xmax), (ymin, ymax)) = self.ax1.viewRange() - Y, X = posData.img_data.shape[-2:] - - xmin = int(max(0, xmin)) - xmax = int(min(X, xmax)) - ymin = int(max(0, ymin)) - ymax = int(min(Y, ymax)) - - if 'x' in framePointsData: - newFramePointsData = {'x': [], 'y': [], 'id': []} - xx = framePointsData['x'] - yy = framePointsData['y'] - ids = framePointsData['id'] - for x, y, point_id in zip(xx, yy, ids): - if x < xmin or x >= xmax or y < ymin or y >= ymax: - newFramePointsData['x'].append(x) - newFramePointsData['y'].append(y) - newFramePointsData['id'].append(point_id) - else: - newFramePointsData = {} - for z, zSliceFramePointsData in framePointsData.items(): - newFramePointsData[z] = {'x': [], 'y': [], 'id': []} - xx = zSliceFramePointsData['x'] - yy = zSliceFramePointsData['y'] - ids = zSliceFramePointsData['id'] - for x, y, point_id in zip(xx, yy, ids): - if x < xmin or x >= xmax or y < ymin or y >= ymax: - newFramePointsData[z]['x'].append(x) - newFramePointsData[z]['y'].append(y) - newFramePointsData[z]['id'].append(point_id) - - action.pointsData[self.pos_i][posData.frame_i] = newFramePointsData - self.drawPointsLayers() - - @disableWindow - def magicPromptsComputeOnImageTriggered(self, toolbar): - inputs = self.getMagicPromptsInputs(toolbar) - if inputs is None: - self.logger.info( - '"Computing promptable segmentation on entire image" ' - 'process cancelled.' - ) - return - - image, df_points = inputs - - self.startMagicPromptsWorkerAndWait( - image, df_points, toolbar.model, toolbar.model_segment_kwargs - ) - - def startMagicPromptsWorkerAndWait( - self, image, df_points, model, model_segment_kwargs, - image_origin=(0, 0, 0), zoom_slice=None - ): - desc = ( - 'Running promptable segmentation model...' - ) - self.logger.info(desc) - posData = self.data[self.pos_i] - - self.progressWin = apps.QDialogWorkerProgress( - title=desc, parent=self, pbarDesc=desc - ) - self.progressWin.mainPbar.setMaximum(0) - self.progressWin.show(self.app) - - self.magicPromptsThread = QThread() - self.magicPromptsWorker = workers.MagicPromptsWorker( - posData, image, df_points, model, model_segment_kwargs, - image_origin=image_origin, - global_image=posData.img_data[posData.frame_i] - ) - - self.magicPromptsWorker.moveToThread( - self.magicPromptsThread - ) - - self.magicPromptsWorker.signals.finished.connect( - self.magicPromptsThread.quit - ) - self.magicPromptsWorker.signals.finished.connect( - self.magicPromptsWorker.deleteLater - ) - self.magicPromptsThread.finished.connect( - self.magicPromptsThread.deleteLater - ) - - self.magicPromptsWorker.signals.critical.connect( - self.magicPromptsWorkerCritical - ) - self.magicPromptsWorker.signals.initProgressBar.connect( - self.workerInitProgressbar - ) - self.magicPromptsWorker.signals.progressBar.connect( - self.workerUpdateProgressbar - ) - self.magicPromptsWorker.signals.progress.connect( - self.workerProgress - ) - self.magicPromptsWorker.signals.finished.connect( - partial(self.magicPromptsWorkerFinished, zoom_slice=zoom_slice) - ) - - self.magicPromptsThread.started.connect( - self.magicPromptsWorker.run - ) - self.magicPromptsThread.start() - - self.magicPromptsWorkerLoop = QEventLoop() - self.magicPromptsWorkerLoop.exec_() - - def magicPromptsWorkerCritical(self, error): - self.magicPromptsWorkerLoop.exit() - self.workerCritical(error) - - def magicPromptsWorkerFinished(self, output, zoom_slice=None): - if self.progressWin is not None: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - self.magicPromptsWorkerLoop.exit() - - lab_new, lab_union, lab_interesection = output - - posData = self.data[self.pos_i] - - is_zoom = True - if zoom_slice is None: - zoom_slice = (slice(None), slice(None)) - is_zoom = False - - img = ( - posData.img_data[posData.frame_i][..., zoom_slice[0], zoom_slice[1]] - ) - images = [img, img, img, img] - labels_overlays = [ - posData.lab[..., zoom_slice[0], zoom_slice[1]], - lab_new[..., zoom_slice[0], zoom_slice[1]], - lab_union[..., zoom_slice[0], zoom_slice[1]], - lab_interesection[..., zoom_slice[0], zoom_slice[1]], - ] - labels_overlays_lut = self.getLabelsImageLut() - labels_overlays_luts = [ - labels_overlays_lut, - labels_overlays_lut, - labels_overlays_lut, - labels_overlays_lut, - ] - axis_titles = [ - 'Original masks', - 'New masks', - 'Union of original and new masks', - 'Intersection of original and new masks' - ] - - from cellacdc.plot import imshow - promptSegmResultsWindow = imshow( - *images, - labels_overlays=labels_overlays, - labels_overlays_luts=labels_overlays_luts, - axis_titles=axis_titles, - window_title='Promptable segmentation results', - figure_title='Ctrl+Click to select the result to use', - annotate_labels_idxs=[0, 1, 2, 3], - selectable_images=True, - max_ncols=2, - lut='gray', - infer_rgb=False - ) - if promptSegmResultsWindow.selected_idx is None: - self.logger.info( - 'Selection of the promptable model segmentation ' - 'result cancelled.' - ) - return - - if promptSegmResultsWindow.selected_idx == 0: - self.logger.info( - 'No selection of a promptable model segmentation ' - 'result was made' - ) - return - - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - - results = (None, lab_new, lab_union, lab_interesection) - selected_idx = promptSegmResultsWindow.selected_idx - zoom_out_lab = results[selected_idx][..., zoom_slice[0], zoom_slice[1]] - zoom_out_lab_mask = zoom_out_lab > 0 - - lab = posData.allData_li[posData.frame_i]['labels'] - lab[..., zoom_slice[0], zoom_slice[1]][zoom_out_lab_mask] = ( - zoom_out_lab[zoom_out_lab_mask] - ) - - posData.allData_li[posData.frame_i]['labels'] = lab - self.get_data() - self.store_data(autosave=False) - self.updateAllImages() - - def manualTracking_cb(self, checked): - self.manualTrackingToolbar.setVisible(checked) - if checked: - self.realTimeTrackingToggle.previousStatus = ( - self.realTimeTrackingToggle.isChecked() - ) - self.realTimeTrackingToggle.setChecked(False) - self.UserEnforced_DisabledTracking_previousStatus = ( - self.UserEnforced_DisabledTracking - ) - self.UserEnforced_Tracking_previousStatus = ( - self.UserEnforced_Tracking - ) - - self.UserEnforced_DisabledTracking = True - self.UserEnforced_Tracking = False - self.initGhostObject() - self.addManualTrackingItems() - else: - self.realTimeTrackingToggle.setChecked( - self.realTimeTrackingToggle.previousStatus - ) - self.UserEnforced_DisabledTracking = ( - self.UserEnforced_DisabledTracking_previousStatus - ) - self.UserEnforced_Tracking = ( - self.UserEnforced_Tracking_previousStatus - ) - self.removeManualTrackingItems() - self.clearGhost() - - def manualBackground_cb(self, checked): - if checked: - posData = self.data[self.pos_i] - minID = min(posData.IDs, default=0) - if minID == self.manualBackgroundToolbar.spinboxID.value(): - self.initManualBackgroundObject() - else: - self.manualBackgroundToolbar.spinboxID.setValue(minID) - # self.initManualBackgroundObject() - # self.initManualBackgroundImage() - self.addManualBackgroundItems() - self.disconnectLeftClickButtons() - self.uncheckLeftClickButtons(self.manualBackgroundButton) - self.connectLeftClickButtons() - self.updateAllImages() - else: - self.removeManualTrackingItems() - self.clearGhost() - self.clearManualBackgroundAnnotations() - self.manualBackgroundToolbar.setVisible(checked) - - def autoSegm_cb(self, checked): - if checked: - self.askSegmParam = True - # Ask which model - models = myutils.get_list_of_models() - win = widgets.QDialogListbox( - 'Select model', - 'Select model to use for segmentation: ', - models, - multiSelection=False, - parent=self - ) - win.exec_() - if win.cancel: - return - model_name = win.selectedItemsText[0] - self.segmModelName = model_name - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - self.updateAllImages() - self.computeSegm() - self.askSegmParam = False - else: - self.segmModelName = None - - def postProcessSegm(self, checked): - if self.isSegm3D: - SizeZ = max([posData.SizeZ for posData in self.data]) - else: - SizeZ = None - if checked: - posData = self.data[self.pos_i] - self.postProcessSegmWin = apps.PostProcessSegmDialog( - posData, mainWin=self - ) - self.postProcessSegmWin.sigClosed.connect( - self.postProcessSegmWinClosed - ) - self.postProcessSegmWin.sigValueChanged.connect( - self.postProcessSegmValueChanged - ) - self.postProcessSegmWin.sigEditingFinished.connect( - self.postProcessSegmEditingFinished - ) - self.postProcessSegmWin.sigApplyToAllFutureFrames.connect( - self.postProcessSegmApplyToAllFutureFrames - ) - self.postProcessSegmWin.show() - self.postProcessSegmWin.valueChanged(None) - else: - self.postProcessSegmWin.close() - self.postProcessSegmWin = None - - def postProcessSegmApplyToAllFutureFrames( - self, postProcessKwargs, - customPostProcessGroupedFeatures, - customPostProcessFeatures - ): - proceed = self.warnEditingWithCca_df( - 'post-processing segmentation', update_images=False - ) - if not proceed: - self.logger.info('Post-processing segmentation cancelled.') - return - - self.progressWin = apps.QDialogWorkerProgress( - title='Post-processing segmentation', parent=self, - pbarDesc=f'Post-processing segmentation masks...' - ) - self.progressWin.show(self.app) - self.progressWin.mainPbar.setMaximum(0) - - self.startPostProcessSegmWorker( - postProcessKwargs, customPostProcessGroupedFeatures, - customPostProcessFeatures - ) - - def postProcessSegmEditingFinished(self): - self.update_rp() - self.store_data() - self.updateAllImages() - - def postProcessSegmWorkerFinished(self): - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - self.get_data() - self.updateAllImages() - self.titleLabel.setText('Post-processing segmentation done!', color='w') - self.logger.info('Post-processing segmentation done!') - - def postProcessSegmWinClosed(self): - self.postProcessSegmWin = None - self.postProcessSegmAction.toggled.disconnect() - self.postProcessSegmAction.setChecked(False) - self.postProcessSegmAction.toggled.connect(self.postProcessSegm) - - def postProcessSegmValueChanged(self, lab, delObjs: dict): - for delObj in delObjs.values(): - self.clearObjContour(obj=delObj, ax=0) - self.clearObjContour(obj=delObj, ax=1) - - posData = self.data[self.pos_i] - - labelsToSkip = {} - for ID in posData.IDs: - if ID in delObjs: - labelsToSkip[ID] = True - continue - - restoreObj = self.postProcessSegmWin.origObjs[ID] - self.addObjContourToContoursImage(obj=restoreObj, ax=0) - self.addObjContourToContoursImage(obj=restoreObj, ax=1) - - # self.setAllTextAnnotations(labelsToSkip=labelsToSkip) - - posData.lab = lab - self.setImageImg2() - if self.annotSegmMasksCheckbox.isChecked(): - self.labelsLayerImg1.setImage(self.currentLab2D, autoLevels=False) - if self.annotSegmMasksCheckboxRight.isChecked(): - self.labelsLayerRightImg.setImage(self.currentLab2D, autoLevels=False) - - def readSavedCustomAnnot(self): - tempAnnot = {} - if os.path.exists(custom_annot_path): - self.logger.info('Loading saved custom annotations...') - tempAnnot = load.read_json( - custom_annot_path, logger_func=self.logger.info - ) - - posData = self.data[self.pos_i] - self.savedCustomAnnot = tempAnnot - for pos_i, posData in enumerate(self.data): - self.savedCustomAnnot = { - **self.savedCustomAnnot, **posData.customAnnot - } - - def addCustomAnnotButtonAllLoadedPos(self): - allPosCustomAnnot = {} - for pos_i, posData in enumerate(self.data): - self.addCustomAnnotationSavedPos(pos_i=pos_i) - allPosCustomAnnot = {**allPosCustomAnnot, **posData.customAnnot} - for posData in self.data: - posData.customAnnot = allPosCustomAnnot - - def addCustomAnnotationSavedPos(self, pos_i=None): - if pos_i is None: - pos_i = self.pos_i - - posData = self.data[pos_i] - for name, annotState in posData.customAnnot.items(): - # Check if button is already present and update only annotated IDs - buttons = [b for b in self.customAnnotDict.keys() if b.name==name] - if buttons: - toolButton = buttons[0] - allAnnotedIDs = self.customAnnotDict[toolButton]['annotatedIDs'] - allAnnotedIDs[pos_i] = posData.customAnnotIDs.get(name, {}) - continue - - try: - symbol = re.findall(r"\'(.+)\'", annotState['symbol'])[0] - except Exception as e: - self.logger.info(traceback.format_exc()) - symbol = 'o' - - symbolColor = QColor(*annotState['symbolColor']) - shortcut = annotState['shortcut'] - if shortcut is not None: - keySequence = widgets.macShortcutToWindows(shortcut) - keySequence = widgets.KeySequenceFromText(keySequence) - else: - keySequence = None - toolTip = myutils.getCustomAnnotTooltip(annotState) - keepActive = annotState.get('keepActive', True) - isHideChecked = annotState.get('isHideChecked', True) - - toolButton, action = self.addCustomAnnotationButton( - symbol, symbolColor, keySequence, toolTip, name, - keepActive, isHideChecked - ) - allPosAnnotIDs = [ - pos.customAnnotIDs.get(name, defaultdict(list)) - for pos in self.data - ] - self.customAnnotDict[toolButton] = { - 'action': action, - 'state': annotState, - 'annotatedIDs': allPosAnnotIDs - } - - self.addCustomAnnnotScatterPlot(symbolColor, symbol, toolButton) - - def addCustomAnnotationButton( - self, symbol, symbolColor, keySequence, toolTip, annotName, - keepActive, isHideChecked - ): - toolButton = widgets.customAnnotToolButton( - symbol, symbolColor, parent=self, keepToolActive=keepActive, - isHideChecked=isHideChecked - ) - toolButton.setCheckable(True) - self.checkableQButtonsGroup.addButton(toolButton) - if keySequence is not None: - toolButton.setShortcut(keySequence) - toolButton.setToolTip(toolTip) - toolButton.name = annotName - toolButton.toggled.connect(self.customAnnotButtonToggled) - toolButton.sigRemoveAction.connect(self.removeCustomAnnotButton) - toolButton.sigKeepActiveAction.connect(self.customAnnotKeepActive) - toolButton.sigHideAction.connect(self.customAnnotHide) - toolButton.sigModifyAction.connect(self.customAnnotModify) - action = self.annotateToolbar.addWidget(toolButton) - return toolButton, action - - def addCustomAnnnotScatterPlot( - self, symbolColor, symbol, toolButton - ): - # Add scatter plot item - symbolColorBrush = [0, 0, 0, 50] - symbolColorBrush[:3] = symbolColor.getRgb()[:3] - scatterPlotItem = widgets.CustomAnnotationScatterPlotItem() - scatterPlotItem.setData( - [], [], symbol=symbol, pxMode=False, - brush=pg.mkBrush(symbolColorBrush), size=15, - pen=pg.mkPen(width=3, color=symbolColor), - hoverable=True, hoverBrush=pg.mkBrush(symbolColor), - tip=None - ) - scatterPlotItem.sigHovered.connect(self.customAnnotHovered) - scatterPlotItem.button = toolButton - self.customAnnotDict[toolButton]['scatterPlotItem'] = scatterPlotItem - self.ax1.addItem(scatterPlotItem) - - def addCustomAnnotationItems( - self, symbol, symbolColor, keySequence, toolTip, name, - keepActive, isHideChecked, state - ): - toolButton, action = self.addCustomAnnotationButton( - symbol, symbolColor, keySequence, toolTip, name, - keepActive, isHideChecked - ) - - self.customAnnotDict[toolButton] = { - 'action': action, - 'state': state, - 'annotatedIDs': [defaultdict(list) for _ in range(len(self.data))] - } - - # Save custom annotation to cellacdc/temp/custom_annotations.json - state_to_save = state.copy() - state_to_save['symbolColor'] = tuple(symbolColor.getRgb()) - self.savedCustomAnnot[name] = state_to_save - for posData in self.data: - posData.customAnnot[name] = state_to_save - - # Add scatter plot item - self.addCustomAnnnotScatterPlot(symbolColor, symbol, toolButton) - - customAnnotButton = self.customAnnotDict[toolButton] - allPosAnnotatedIDs = customAnnotButton['annotatedIDs'] - # Add 0s column to acdc_df - for pos_i, posData in enumerate(self.data): - for frame_i, data_dict in enumerate(posData.allData_li): - acdc_df = data_dict['acdc_df'] - if acdc_df is None: - continue - if name not in acdc_df.columns: - acdc_df[name] = 0 - else: - acdc_df[name] = acdc_df[name].astype(int) - acdc_df_annot = acdc_df[acdc_df[name] == 1].reset_index() - annot_IDs = acdc_df_annot['Cell_ID'].to_list() - allPosAnnotatedIDs[pos_i][frame_i].extend(annot_IDs) - - if posData.acdc_df is not None: - if name not in posData.acdc_df.columns: - posData.acdc_df[name] = 0 - else: - posData.acdc_df[name] = posData.acdc_df[name].astype(int) - acdc_df_annot = ( - posData.acdc_df[posData.acdc_df[name] == 1] - .reset_index() - ) - annot_IDs = acdc_df_annot['Cell_ID'].to_list() - allPosAnnotatedIDs[pos_i][frame_i].extend(annot_IDs) - - def customAnnotHovered(self, scatterPlotItem, points, event): - # Show tool tip when hovering an annotation with annotation name and ID - vb = scatterPlotItem.getViewBox() - if vb is None: - return - if len(points) > 0: - posData = self.data[self.pos_i] - point = points[0] - x, y = point.pos().x(), point.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - vb.setToolTip( - f'Annotation name: {scatterPlotItem.button.name}\n' - f'ID = {ID}' - ) - else: - vb.setToolTip('') - - def loadCustomAnnotations(self): - items = list(self.savedCustomAnnot.keys()) - if len(items) == 0: - msg = widgets.myMessageBox() - txt = html_utils.paragraph(""" - There are no custom annotations saved.

- Click on "Add custom annotation" button to start adding new - annotations. - """) - msg.warning(self, 'No annotations saved', txt) - return - - self.selectAnnotWin = widgets.QDialogListbox( - 'Load previously used custom annotation(s)', - 'Select annotations to load:', items, - additionalButtons=('Delete selected annnotations', ), - parent=self, multiSelection=True - ) - for button in self.selectAnnotWin._additionalButtons: - button.disconnect() - button.clicked.connect(self.deleteSavedAnnotation) - self.selectAnnotWin.exec_() - if self.selectAnnotWin.cancel: - return - - for selectedAnnotName in self.selectAnnotWin.selectedItemsText: - selectedAnnot = self.savedCustomAnnot[selectedAnnotName] - - symbol = selectedAnnot['symbol'] - symbol = re.findall(r"\'(.+)\'", symbol)[0] - symbolColor = selectedAnnot['symbolColor'] - symbolColor = pg.mkColor(symbolColor) - keySequence = widgets.KeySequenceFromText(selectedAnnot['shortcut']) - Type = selectedAnnot['type'] - toolTip = ( - f'Name: {selectedAnnotName}\n\n' - f'Type: {Type}\n\n' - f'Usage: activate the button and RIGHT-CLICK on cell to annotate\n\n' - f'Description: {selectedAnnot["description"]}\n\n' - f'Shortcut: "{keySequence}"' - ) - keepActive = selectedAnnot['keepActive'] - isHideChecked = selectedAnnot['isHideChecked'] - state = { - 'type': Type, - 'name': selectedAnnotName, - 'symbol': selectedAnnot['symbol'], - 'shortcut': selectedAnnot['shortcut'], - 'description': selectedAnnot["description"], - 'keepActive': keepActive, - 'isHideChecked': isHideChecked, - 'symbolColor': symbolColor - } - self.addCustomAnnotationItems( - symbol, symbolColor, keySequence, toolTip, selectedAnnotName, - keepActive, isHideChecked, state - ) - for pos_i, posData in enumerate(self.data): - posData.customAnnot[selectedAnnotName] = selectedAnnot - - self.saveCustomAnnot() - - def deleteSavedAnnotation(self): - for item in self.selectAnnotWin.listBox.selectedItems(): - name = item.text() - self.savedCustomAnnot.pop(name) - self.deleteSelectedAnnot(self.selectAnnotWin.listBox.selectedItems()) - items = list(self.savedCustomAnnot.keys()) - self.selectAnnotWin.listBox.clear() - self.selectAnnotWin.listBox.addItems(items) - - def addCustomAnnotation(self): - self.readSavedCustomAnnot() - - self.addAnnotWin = apps.customAnnotationDialog( - self.savedCustomAnnot, parent=self - ) - self.addAnnotWin.sigDeleteSelecAnnot.connect(self.deleteSelectedAnnot) - self.addAnnotWin.exec_() - if self.addAnnotWin.cancel: - self.logger.info('Custom annotation process cancelled.') - return - - symbol = self.addAnnotWin.symbol - symbolColor = self.addAnnotWin.state['symbolColor'] - keySequence = self.addAnnotWin.shortcutWidget.widget.keySequence - toolTip = self.addAnnotWin.toolTip - name = self.addAnnotWin.state['name'] - keepActive = self.addAnnotWin.state.get('keepActive', True) - isHideChecked = self.addAnnotWin.state.get('isHideChecked', True) - - proceed = self.checkNameExists(name) - if not proceed: - self.logger.info('Custom annotation process cancelled.') - return - - self.addCustomAnnotationItems( - symbol, symbolColor, keySequence, toolTip, name, - keepActive, isHideChecked, self.addAnnotWin.state - ) - self.saveCustomAnnot() - self.doCustomAnnotation(0) - - def askCustomAnnotationNameExists(self, name): - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph(f""" - The annotationa called {name} already exists in the - acdc_output CSV file.

- If you continue, this column will be used to initialize - pre-annotated objects.

- Do you want to continue? - """ - ) - noButton, yesButton = msg.question( - self, 'Custom annotation name already exists', txt, - buttonsTexts=('No, stop process', 'Yes, use existing column') - ) - return msg.clickedButton == yesButton - - - def checkNameExists(self, name): - posData = self.data[self.pos_i] - for frame_i, data_dict in enumerate(posData.allData_li): - acdc_df = data_dict['acdc_df'] - if acdc_df is None: - continue - if name in acdc_df.columns: - return self.askCustomAnnotationNameExists(name) - - if posData.acdc_df is not None and name in posData.acdc_df.columns: - return self.askCustomAnnotationNameExists(name) - - return True - - - def viewAllCustomAnnot(self, checked): - if not checked: - # Clear all annotations before showing only checked - for button in self.customAnnotDict.keys(): - self.clearScatterPlotCustomAnnotButton(button) - self.doCustomAnnotation(0) - - def clearScatterPlotCustomAnnotButton(self, button): - scatterPlotItem = self.customAnnotDict[button]['scatterPlotItem'] - scatterPlotItem.setData([], []) - - def saveCustomAnnot(self, only_temp=False): - if not hasattr(self, 'savedCustomAnnot'): - return - - if not self.savedCustomAnnot: - return - - # Save to cell acdc temp path - with open(custom_annot_path, mode='w') as file: - json.dump(self.savedCustomAnnot, file, indent=2) - - if only_temp: - return - - self.logger.info('Saving custom annotations parameters...') - # Save to pos path - for _posData in self.data: - _posData.saveCustomAnnotationParams() - - def customAnnotKeepActive(self, button): - self.customAnnotDict[button]['state']['keepActive'] = button.keepToolActive - - def customAnnotHide(self, button): - self.customAnnotDict[button]['state']['isHideChecked'] = button.isHideChecked - clearAnnot = ( - not button.isChecked() and button.isHideChecked - and not self.viewAllCustomAnnotAction.isChecked() - ) - if clearAnnot: - # User checked hide annot with the button not active --> clear - self.clearScatterPlotCustomAnnotButton(button) - elif not button.isChecked(): - # User uncheked hide annot with the button not active --> show - self.doCustomAnnotation(0) - - def deleteSelectedAnnot(self, itemsToDelete): - self.saveCustomAnnot(only_temp=True) - - def customAnnotModify(self, button): - state = self.customAnnotDict[button]['state'] - self.addAnnotWin = apps.customAnnotationDialog( - self.savedCustomAnnot, state=state - ) - self.addAnnotWin.sigDeleteSelecAnnot.connect(self.deleteSelectedAnnot) - self.addAnnotWin.exec_() - if self.addAnnotWin.cancel: - return - - # Rename column if existing - posData = self.data[self.pos_i] - acdc_df = posData.allData_li[posData.frame_i]['acdc_df'] - if acdc_df is not None: - old_name = self.customAnnotDict[button]['state']['name'] - new_name = self.addAnnotWin.state['name'] - acdc_df = acdc_df.rename(columns={old_name: new_name}) - posData.allData_li[posData.frame_i]['acdc_df'] = acdc_df - - self.customAnnotDict[button]['state'] = self.addAnnotWin.state - - name = self.addAnnotWin.state['name'] - state_to_save = self.addAnnotWin.state.copy() - symbolColor = self.addAnnotWin.state['symbolColor'] - state_to_save['symbolColor'] = tuple(symbolColor.getRgb()) - self.savedCustomAnnot[name] = self.addAnnotWin.state - self.saveCustomAnnot() - - symbol = self.addAnnotWin.symbol - symbolColor = self.customAnnotDict[button]['state']['symbolColor'] - button.setColor(symbolColor) - button.update() - symbolColorBrush = [0, 0, 0, 50] - symbolColorBrush[:3] = symbolColor.getRgb()[:3] - scatterPlotItem = self.customAnnotDict[button]['scatterPlotItem'] - xx, yy = scatterPlotItem.getData() - if xx is None: - xx, yy = [], [] - scatterPlotItem.setData( - xx, yy, symbol=symbol, pxMode=False, - brush=pg.mkBrush(symbolColorBrush), size=15, - pen=pg.mkPen(width=3, color=symbolColor) - ) - - def doCustomAnnotation(self, ID): - mode = self.modeComboBox.currentText() - if not self.isSnapshot and mode != 'Custom annotations': - # Do not show annotations if timelapse and mode not annotations - return - - if self.switchPlaneCombobox.depthAxes() != 'z': - return - - # NOTE: pass 0 for ID to not add - posData = self.data[self.pos_i] - if self.viewAllCustomAnnotAction.isChecked(): - # User requested to show all annotations --> iterate all buttons - # Unless it actively clicked to annotate --> avoid annotating object - # with all the annotations present - buttons = list(self.customAnnotDict.keys()) - else: - # Annotate if the button is active or isHideChecked is False - buttons = [ - b for b in self.customAnnotDict.keys() - if (b.isChecked() or not b.isHideChecked) - ] - if not buttons: - return - - for button in buttons: - annotatedIDs = ( - self.customAnnotDict[button]['annotatedIDs'][self.pos_i] - ) - annotIDs_frame_i = annotatedIDs.get(posData.frame_i, []) - state = self.customAnnotDict[button]['state'] - acdc_df = posData.allData_li[posData.frame_i]['acdc_df'] - - if button.isChecked() and ID > 0: - # Annotate only if existing ID and the button is checked - if ID in annotIDs_frame_i: - annotIDs_frame_i.remove(ID) - acdc_df.at[ID, state['name']] = 0 - elif ID != 0: - annotIDs_frame_i.append(ID) - - annotPerButton = self.customAnnotDict[button] - allAnnotedIDs = annotPerButton['annotatedIDs'] - posAnnotedIDs = allAnnotedIDs[self.pos_i] - posAnnotedIDs[posData.frame_i] = annotIDs_frame_i - - if acdc_df is None: - self.store_data(autosave=False) - acdc_df = posData.allData_li[posData.frame_i]['acdc_df'] - - xx, yy = [], [] - for annotID in annotIDs_frame_i: - if annotID not in posData.IDs_idxs: - continue - - obj_idx = posData.IDs_idxs[annotID] - obj = posData.rp[obj_idx] - acdc_df.at[annotID, state['name']] = 1 - if not self.isObjVisible(obj.bbox): - continue - y, x = self.getObjCentroid(obj.centroid) - xx.append(x) - yy.append(y) - - scatterPlotItem = self.customAnnotDict[button]['scatterPlotItem'] - scatterPlotItem.setData(xx, yy) - - posData.allData_li[posData.frame_i]['acdc_df'] = acdc_df - - # if self.highlightedID != 0: - # self.highlightedID = 0 - # self.setHighlightID(False) - - if buttons: - return buttons[0] - - def removeCustomAnnotButton(self, button, askHow=True, save=True): - if askHow: - msg = widgets.myMessageBox() - txt = html_utils.paragraph(""" - Do you want to remove also the column with annotations or - only the annotation button?
- """) - _, removeOnlyButton, removeColButton = msg.question( - self, 'Remove only button?', txt, - buttonsTexts=( - 'Cancel', 'Remove only button', - ' Remove also column with annotations ' - ) - ) - if msg.cancel: - return - removeOnlyButton = msg.clickedButton == removeOnlyButton - else: - removeOnlyButton = True - - name = self.customAnnotDict[button]['state']['name'] - # remove annotation from position - for posData in self.data: - try: - posData.customAnnot.pop(name) - posData.saveCustomAnnotationParams() - except KeyError as e: - # Current pos doesn't have any annotation button. Continue - continue - - if posData.acdc_df is None: - continue - - if removeOnlyButton: - continue - - posData.acdc_df = posData.acdc_df.drop( - columns=name, errors='ignore' - ) - for frame_i, data_dict in enumerate(posData.allData_li): - acdc_df = data_dict['acdc_df'] - if acdc_df is None: - continue - acdc_df = acdc_df.drop(columns=name, errors='ignore') - posData.allData_li[frame_i]['acdc_df'] = acdc_df - - self.clearScatterPlotCustomAnnotButton(button) - - action = self.customAnnotDict[button]['action'] - self.annotateToolbar.removeAction(action) - self.checkableQButtonsGroup.removeButton(button) - self.customAnnotDict.pop(button) - # self.savedCustomAnnot.pop(name) - - self.saveCustomAnnot(only_temp=True) - - def customAnnotButtonToggled(self, checked): - if checked: - self.customAnnotButton = self.sender() - # Uncheck the other buttons - for button in self.customAnnotDict.keys(): - if button == self.sender(): - continue - - button.toggled.disconnect() - self.clearScatterPlotCustomAnnotButton(button) - button.setChecked(False) - button.toggled.connect(self.customAnnotButtonToggled) - self.doCustomAnnotation(0) - else: - self.customAnnotButton = None - button = self.sender() - clearAnnotation = ( - button.isHideChecked - or not self.viewAllCustomAnnotAction.isChecked() - ) - if clearAnnotation: - self.clearScatterPlotCustomAnnotButton(button) - self.setHighlightID(False) - self.resetCursor() - - def resetCursor(self): - if self.app.overrideCursor() is not None: - while self.app.overrideCursor() is not None: - self.app.restoreOverrideCursor() - - def segmFrameCallback(self, action): - if action == self.addCustomModelFrameAction: - return - - idx = self.segmActions.index(action) - model_name = self.modelNames[idx] - self.repeatSegm(model_name=model_name, askSegmParams=True) - - def segmVideoCallback(self, action): - if action == self.addCustomModelVideoAction: - return - - posData = self.data[self.pos_i] - win = apps.startStopFramesDialog( - posData.SizeT, currentFrameNum=posData.frame_i+1 - ) - win.exec_() - if win.cancel: - self.logger.info('Segmentation on multiple frames aborted.') - return - - idx = self.segmActionsVideo.index(action) - model_name = self.modelNames[idx] - self.repeatSegmVideo(model_name, win.startFrame, win.stopFrame) - - def segmentToolActionTriggered(self): - if self.segmModelName is None: - win = apps.QDialogSelectModel(parent=self) - win.exec_() - if win.cancel: - self.logger.info('Repeat segmentation cancelled.') - return - model_name = win.selectedModel - self.repeatSegm( - model_name=model_name, askSegmParams=True - ) - else: - self.repeatSegm(model_name=self.segmModelName) - - def initSegmModelParams( - self, model_name, acdcSegment, init_params, segment_params, - is_label_roi=False, initLastParams=False, - extraParams=None, extraParamsTitle=None,ini_filename=None - - ): - posData = self.data[self.pos_i] - try: - url = acdcSegment.url_help() - except AttributeError: - url = None - - text_if_cancelled = 'Segmentation process cancelled.' - out = prompts.init_segm_model_params( - posData, model_name, init_params, segment_params, - help_url=url, qparent=self, init_last_params=initLastParams, - check_sam_embeddings=not is_label_roi, is_gui_caller=True, - extraParams=extraParams,extraParamsTitle=extraParamsTitle, - ini_filename=ini_filename, - ) - if out.get('load_sam_embeddings', False): - self.logger.info('Loading Segment Anything image embeddings...') - for _posData in self.data: - _posData.loadSamEmbeddings(logger_func=None) - text_if_cancelled = 'SAM embeddings loaded.' - - win = out.get('win') - if win is None: - self.logger.info(text_if_cancelled) - self.titleLabel.setText(text_if_cancelled) - return - - if win.cancel: - self.logger.info(text_if_cancelled) - self.titleLabel.setText(text_if_cancelled) - return - - if model_name != 'thresholding': - self.model_kwargs = win.model_kwargs - - return win - - @exception_handler - def repeatSegm( - self, model_name='', askSegmParams=False, is_label_roi=False - ): - if model_name == 'thresholding': - # thresholding model is stored as 'Automatic thresholding' - # at line of code `models.append('Automatic thresholding')` - model_name = 'Automatic thresholding' - - idx = self.modelNames.index(model_name) - # Ask segm parameters if not already set - # and not called by segmSingleFrameMenu (askSegmParams=False) - if not askSegmParams: - askSegmParams = self.model_kwargs is None - - self.downloadWin = apps.downloadModel(model_name, parent=self) - self.downloadWin.download() - - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - - if model_name == 'Automatic thresholding': - # Automatic thresholding is the name of the models as stored - # in self.modelNames, but the actual model is called thresholding - # (see cellacdc/models/thresholding) - model_name = 'thresholding' - - posData = self.data[self.pos_i] - # Check if model needs to be imported - acdcSegment = self.acdcSegment_li[idx] - if acdcSegment is None: - self.logger.info(f'Importing {model_name}...') - acdcSegment = myutils.import_segment_module(model_name) - self.acdcSegment_li[idx] = acdcSegment - - # Ask parameters if the user clicked on the action - # Otherwise this function is called by "computeSegm" function and - # we use loaded parameters - if askSegmParams: - if self.app.overrideCursor() == Qt.WaitCursor: - self.app.restoreOverrideCursor() - self.segmModelName = model_name - # Read all models parameters - init_params, segment_params = myutils.getModelArgSpec(acdcSegment) - # Prompt user to enter the model parameters - try: - url = acdcSegment.url_help() - except AttributeError: - url = None - - self.preproc_recipe = None - initLastParams = True - if model_name == 'thresholding': - win = apps.QDialogAutomaticThresholding( - parent=self, isSegm3D=self.isSegm3D - ) - win.exec_() - if win.cancel: - return - self.model_kwargs = win.segment_kwargs - thresh_method = self.model_kwargs['threshold_method'] - gauss_sigma = self.model_kwargs['gauss_sigma'] - segment_params = myutils.insertModelArgSpec( - segment_params, 'threshold_method', thresh_method - ) - segment_params = myutils.insertModelArgSpec( - segment_params, 'gauss_sigma', gauss_sigma - ) - initLastParams = False - - win = self.initSegmModelParams( - model_name, acdcSegment, init_params, segment_params, - is_label_roi=is_label_roi, - initLastParams=initLastParams - ) - if win is None: - return - - self.standardPostProcessKwargs = win.standardPostProcessKwargs - self.customPostProcessFeatures = win.customPostProcessFeatures - self.customPostProcessGroupedFeatures = ( - win.customPostProcessGroupedFeatures - ) - self.applyPostProcessing = win.applyPostProcessing - self.secondChannelName = win.secondChannelName - self.preproc_recipe = win.preproc_recipe - - myutils.log_segm_params( - model_name, win.init_kwargs, win.model_kwargs, - logger_func=self.logger.info, - preproc_recipe=win.preproc_recipe, - apply_post_process=self.applyPostProcessing, - standard_postprocess_kwargs=self.standardPostProcessKwargs, - custom_postprocess_features=self.customPostProcessFeatures - ) - - use_gpu = win.init_kwargs.get('gpu', False) - proceed = myutils.check_gpu_available(model_name, use_gpu, qparent=self) - if not proceed: - self.logger.info('Segmentation process cancelled.') - self.titleLabel.setText('Segmentation process cancelled.') - return - - model = myutils.init_segm_model( - acdcSegment, posData, win.init_kwargs - ) - if model is None: - self.logger.info('Segmentation process cancelled.') - self.titleLabel.setText('Segmentation process cancelled.') - return - try: - model.setupLogger(self.logger) - except Exception as e: - pass - self.models[idx] = model - model.model_name = model_name - else: - model = self.models[idx] - - if is_label_roi: - return model - - self.titleLabel.setText( - f'Segmenting with {model_name}... ' - '(check progress in terminal/console)', color=self.titleColor - ) - - post_process_params = { - 'applied_postprocessing': self.applyPostProcessing - } - post_process_params = { - **post_process_params, - **self.standardPostProcessKwargs, - **self.customPostProcessFeatures - } - if askSegmParams: - posData.saveSegmHyperparams( - model_name, win.init_kwargs, win.model_kwargs, - post_process_params=post_process_params, - preproc_recipe=self.preproc_recipe - ) - - if self.askRepeatSegment3D: - self.segment3D = False - if self.isSegm3D and self.askRepeatSegment3D: - msg = widgets.myMessageBox(showCentered=False) - msg.addDoNotShowAgainCheckbox(text='Do not ask again') - txt = html_utils.paragraph( - 'Do you want to segment the entire z-stack or only the ' - 'current z-slice?' - ) - _, segment3DButton, _ = msg.question( - self, '3D segmentation?', txt, - buttonsTexts=( - 'Cancel', 'Segment 3D z-stack', 'Segment 2D z-slice' - ) - ) - if msg.cancel: - self.titleLabel.setText('Segmentation process aborted.') - self.logger.info('Segmentation process aborted.') - return - self.segment3D = msg.clickedButton == segment3DButton - if msg.doNotShowAgainCheckbox.isChecked(): - self.askRepeatSegment3D = False - - if self.askZrangeSegm3D: - self.z_range = None - if self.isSegm3D and self.segment3D and self.askZrangeSegm3D: - idx = (posData.filename, posData.frame_i) - try: - orignal_z = posData.segmInfo_df.at[idx, 'z_slice_used_gui'] - except ValueError as e: - orignal_z = posData.segmInfo_df.loc[idx, 'z_slice_used_gui'].iloc[0] - selectZtool = apps.QCropZtool( - posData.SizeZ, parent=self, cropButtonText='Ok', - addDoNotShowAgain=True, title='Select z-slice range to segment' - ) - selectZtool.sigZvalueChanged.connect(self.selectZtoolZvalueChanged) - selectZtool.sigCrop.connect(selectZtool.close) - selectZtool.exec_() - self.update_z_slice(orignal_z) - if selectZtool.cancel: - self.titleLabel.setText('Segmentation process aborted.') - self.logger.info('Segmentation process aborted.') - return - startZ = selectZtool.lowerZscrollbar.value() - stopZ = selectZtool.upperZscrollbar.value() - self.z_range = (startZ, stopZ) - if selectZtool.doNotShowAgainCheckbox.isChecked(): - self.askZrangeSegm3D = False - - secondChannelData = None - if self.secondChannelName is not None: - secondChannelData = self.getSecondChannelData() - - self.titleLabel.setText( - f'{model_name} is thinking... ' - '(check progress in terminal/console)', color=self.titleColor - ) - - self.model = model - - self.segmWorkerMutex = QMutex() - self.segmWorkerWaitCond = QWaitCondition() - self.thread = QThread() - self.worker = workers.segmWorker( - self, - secondChannelData=secondChannelData, - mutex=self.segmWorkerMutex, - waitCond=self.segmWorkerWaitCond - ) - self.worker.z_range = self.z_range - self.worker.moveToThread(self.thread) - self.worker.finished.connect(self.thread.quit) - self.worker.finished.connect(self.worker.deleteLater) - if self.debug: - self.worker.debug.connect(self.debugSegmWorker) - self.thread.finished.connect(self.thread.deleteLater) - - # Custom signals - self.worker.critical.connect(self.workerCritical) - self.worker.finished.connect(self.segmWorkerFinished) - - self.thread.started.connect(self.worker.run) - self.thread.start() - - def debugSegmWorker(self, to_debug): - img, _lab, lab = to_debug - printl(img.shape, _lab.shape, lab.shape) - imshow(img, _lab, lab) - self.segmWorkerWaitCond.wakeAll() - - def selectZtoolZvalueChanged(self, whichZ, z): - self.update_z_slice(z) - - @exception_handler - def repeatSegmVideo(self, model_name, startFrameNum, stopFrameNum): - if model_name == 'thresholding': - # thresholding model is stored as 'Automatic thresholding' - # at line of code `models.append('Automatic thresholding')` - model_name = 'Automatic thresholding' - - idx = self.modelNames.index(model_name) - - self.downloadWin = apps.downloadModel(model_name, parent=self) - self.downloadWin.download() - - if model_name == 'Automatic thresholding': - # Automatic thresholding is the name of the models as stored - # in self.modelNames, but the actual model is called thresholding - # (see cellacdc/models/thresholding) - model_name = 'thresholding' - - posData = self.data[self.pos_i] - # Check if model needs to be imported - acdcSegment = self.acdcSegment_li[idx] - if acdcSegment is None: - self.logger.info(f'Importing {model_name}...') - acdcSegment = myutils.import_segment_module(model_name) - self.acdcSegment_li[idx] = acdcSegment - - # Read all models parameters - init_params, segment_params = myutils.getModelArgSpec(acdcSegment) - # Prompt user to enter the model parameters - try: - url = acdcSegment.url_help() - except AttributeError: - url = None - - if model_name == 'thresholding': - autoThreshWin = apps.QDialogAutomaticThresholding( - parent=self, isSegm3D=self.isSegm3D - ) - autoThreshWin.exec_() - if autoThreshWin.cancel: - return - - win = self.initSegmModelParams( - model_name, acdcSegment, init_params, segment_params - ) - if win is None: - return - - self.standardPostProcessKwargs = win.standardPostProcessKwargs - self.customPostProcessFeatures = win.customPostProcessFeatures - self.customPostProcessGroupedFeatures = ( - win.customPostProcessGroupedFeatures - ) - self.applyPostProcessing = win.applyPostProcessing - self.preproc_recipe = win.preproc_recipe - - myutils.log_segm_params( - model_name, win.init_kwargs, win.model_kwargs, - logger_func=self.logger.info, - preproc_recipe=win.preproc_recipe, - apply_post_process=self.applyPostProcessing, - standard_postprocess_kwargs=self.standardPostProcessKwargs, - custom_postprocess_features=self.customPostProcessFeatures - ) - - secondChannelData = None - if win.secondChannelName is not None: - secondChannelData = self.getSecondChannelData() - - use_gpu = win.init_kwargs.get('gpu', False) - proceed = myutils.check_gpu_available(model_name, use_gpu, qparent=self) - if not proceed: - self.logger.info('Segmentation process cancelled.') - self.titleLabel.setText('Segmentation process cancelled.') - return - - model = myutils.init_segm_model(acdcSegment, posData, win.init_kwargs) - if model is None: - self.logger.info('Segmentation process cancelled.') - self.titleLabel.setText('Segmentation process cancelled.') - return - try: - model.setupLogger(self.logger) - except Exception as e: - pass - - self.extendSegmDataIfNeeded(stopFrameNum) - self.reInitLastSegmFrame( - from_frame_i=startFrameNum-1, updateImages=False - ) - - self.titleLabel.setText( - f'{model_name} is thinking... ' - '(check progress in terminal/console)', color=self.titleColor - ) - - self.progressWin = apps.QDialogWorkerProgress( - title='Segmenting video', parent=self, - pbarDesc=f'Segmenting from frame n. {startFrameNum} to {stopFrameNum}...' - ) - self.progressWin.show(self.app) - self.progressWin.mainPbar.setMaximum(stopFrameNum-startFrameNum) - - self.thread = QThread() - self.worker = workers.segmVideoWorker( - posData, win, model, startFrameNum, stopFrameNum - ) - self.worker.secondChannelData = secondChannelData - self.worker.moveToThread(self.thread) - self.worker.finished.connect(self.thread.quit) - self.worker.finished.connect(self.worker.deleteLater) - self.thread.finished.connect(self.thread.deleteLater) - - # Custom signals - self.worker.critical.connect(self.workerCritical) - self.worker.finished.connect(self.segmVideoWorkerFinished) - self.worker.progressBar.connect(self.workerUpdateProgressbar) - self.worker.progress.connect(self.workerProgress) - - self.thread.started.connect(self.worker.run) - self.thread.start() - - def segmVideoWorkerFinished(self, exec_time): - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - - self.activateAnnotations() - - self.get_data() - self.tracking(enforce=True) - self.updateAllImages() - - txt = f'Done. Segmentation computed in {exec_time:.3f} s' - self.logger.info('-----------------') - self.logger.info(txt) - self.logger.info('=================') - self.titleLabel.setText(txt, color='g') - - @exception_handler - def lazyLoaderCritical(self, error): - if self.progressWin is not None: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - self.lazyLoader.pause() - raise error - - def ccaIntegrityWorkerCritical(self, error): - try: - raise error - except Exception as err: - self.logger.exception(traceback.format_exc()) - - href = f'GitHub page' - txt = html_utils.paragraph(f""" - Unfortunately the experimental feature - check cell cycle annotations integrity raised a - critical error.

- Cell-ACDC will now disable this feature to allow you to keep - using the software.

- However, we kindly ask you to report the issue on our - {href}, thank you very much!

- Please, include the log file when reporting the issue.

- Log file location: - """) - msg = widgets.myMessageBox(wrapText=False) - msg.warning( - self, 'Experimental feature error', txt, - commands=(self.log_path,), - path_to_browse=self.logs_path - ) - self.disableCcaIntegrityChecker() - - @exception_handler - def workerCritical(self, out: Tuple[QObject, Exception]): - self.setDisabled(False) - try: - worker, error = out - except TypeError as err: - error = out - if self.progressWin is not None: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - self.logger.info(error) - try: - worker.thread().quit() - worker.deleteLater() - worker.thread().deleteLater() - except Exception as err: - # Worker already closed - pass - raise error - - def workerLog(self, text): - self.logger.info(text) - - def saveDataWorkerCritical(self, error): - self.logger.warning( - 'Saving process stopped because of critical error.' - ) - self.saveWin.aborted = True - self.worker.finished.emit() - self.workerCritical(error) - - def lazyLoaderWorkerClosed(self): - if self.lazyLoader.salute: - self.logger.info('Cell-ACDC GUI closed.') - self.sigClosed.emit(self) - - self.lazyLoader = None - - def segmWorkerFinished(self, lab, exec_time): - posData = self.data[self.pos_i] - - if posData.segmInfo_df is not None and posData.SizeZ>1: - idx = (posData.filename, posData.frame_i) - posData.segmInfo_df.at[idx, 'resegmented_in_gui'] = True - - if lab.ndim == 2 and self.isSegm3D: - self.set_2Dlab(lab) - else: - posData.lab = lab.copy() - - self.activateAnnotations() - - self.update_rp(wl_update=False) - self.tracking(enforce=True, against_next=posData.frame_i==0) - - if self.isSnapshot: - self.fixCcaDfAfterEdit('Repeat segmentation') - self.updateAllImages() - else: - self.warnEditingWithCca_df('Repeat segmentation') - - txt = f'Done. Segmentation computed in {exec_time:.3f} s' - self.logger.info('-----------------') - self.logger.info(txt) - self.logger.info('=================') - self.titleLabel.setText(txt, color='g') - self.checkIfAutoSegm() - - QTimer.singleShot(200, self.resizeGui) - def activateAnnotations(self): - if self.annotContourCheckbox.isChecked(): - return - if self.annotSegmMasksCheckbox.isChecked(): - return - - self.annotSegmMasksCheckbox.setChecked(True) - self.setDrawAnnotComboboxText() - - # @exec_time - def getDisplayedImg1(self): - return self.img1.image - - def getDisplayedZstack(self): - posData = self.data[self.pos_i] - return posData.img_data[posData.frame_i] - - def autoAssignBud_YeastMate(self): - if not self.is_win: - txt = ( - 'YeastMate is available only on Windows OS.' - 'We are working on expading support also on macOS and Linux.\n\n' - 'Thank you for your patience!' - ) - msg = QMessageBox() - msg.critical( - self, 'Supported only on Windows', txt, msg.Ok - ) - return - - - model_name = 'YeastMate' - idx = self.modelNames.index(model_name) - - self.titleLabel.setText( - f'{model_name} is thinking... ' - '(check progress in terminal/console)', color=self.titleColor - ) - - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - - posData = self.data[self.pos_i] - # Check if model needs to be imported - acdcSegment = self.acdcSegment_li[idx] - if acdcSegment is None: - acdcSegment = myutils.import_segment_module(model_name) - self.acdcSegment_li[idx] = acdcSegment - - # Read all models parameters - init_params, segment_params = myutils.getModelArgSpec(acdcSegment) - # Prompt user to enter the model parameters - try: - url = acdcSegment.url_help() - except AttributeError: - url = None - - _SizeZ = None - if self.isSegm3D: - _SizeZ = posData.SizeZ - win = apps.QDialogModelParams( - init_params, - segment_params, - model_name, - url=url, - posData=posData, - df_metadata=posData.metadata_df - ) - win.exec_() - if win.cancel: - self.titleLabel.setText('Segmentation aborted.') - return - - use_gpu = win.init_kwargs.get('gpu', False) - proceed = myutils.check_gpu_available(model_name, use_gpu, qparent=self) - if not proceed: - self.logger.info('Segmentation process cancelled.') - self.titleLabel.setText('Segmentation process cancelled.') - return - - self.model_kwargs = win.model_kwargs - model = myutils.init_segm_model(acdcSegment, posData, win.init_kwargs) - if model is None: - self.logger.info('Segmentation process cancelled.') - self.titleLabel.setText('Segmentation process cancelled.') - return - try: - model.setupLogger(self.logger) - except Exception as e: - pass - - self.models[idx] = model - - img = self.getDisplayedImg1() - - posData.cca_df = model.predictCcaState(img, posData.lab) - self.store_data() - self.updateAllImages() - - self.titleLabel.setText('Budding event prediction done.', color='g') - - def isNavigateActionOnNextFrame(self): - posData = self.data[self.pos_i] - if posData.SizeT == 1: - return False - - ax1_coords = self.getMouseDataCoordsRightImage() - if ax1_coords is None: - return False - - if not self.labelsGrad.showNextFrameAction.isEnabled(): - return False - - if not self.labelsGrad.showNextFrameAction.isChecked(): - return - - # Mouse is on right image and next frame action is checked - return True - - def rightImageFramesScrollbarValueChanged(self, value): - img = self.nextFrameImage(current_frame_i=value-2) - self.img1.linkedImageItem.frame_i = value - self.img1.linkedImageItem.setImage(img) - - def nextActionTriggered(self): - if self.isNavigateActionOnNextFrame(): - self.rightImageFramesScrollbar.setValue( - self.rightImageFramesScrollbar.value()+1 - ) - return - - stepAddAction = QAbstractSlider.SliderAction.SliderSingleStepAdd - if self.zKeptDown or self.zSliceCheckbox.isChecked(): - self.zSliceScrollBar.triggerAction(stepAddAction) - else: - self.navigateScrollBar.triggerAction(stepAddAction) - - def prevActionTriggered(self): - if self.isNavigateActionOnNextFrame(): - self.rightImageFramesScrollbar.setValue( - self.rightImageFramesScrollbar.value()-1 - ) - return - - stepSubAction = QAbstractSlider.SliderAction.SliderSingleStepSub - if self.zKeptDown or self.zSliceCheckbox.isChecked(): - self.zSliceScrollBar.triggerAction(stepSubAction) - else: - self.navigateScrollBar.triggerAction(stepSubAction) - - def resetNavigateScrollbar(self): - try: - self.navigateScrollBar.blockSignals(True) - self.navigateScrollBar.actionTriggered.disconnect() - self.navigateScrollBar.sliderReleased.disconnect() - self.navigateScrollBar.sliderMoved.disconnect() - # self.navigateScrollBar.valueChanged.disconnect() - self.navigateScrollBar.setSliderPosition(self.navSpinBox.value()) - except Exception as e: - if "disconnect()" not in str(e): - printl(e) - pass - - self.navigateScrollBar.blockSignals(False) - self.navigateScrollBar.actionTriggered.connect( - self.framesScrollBarActionTriggered - ) - self.navigateScrollBar.sliderReleased.connect(self.framesScrollBarReleased) - self.navigateScrollBar.sliderMoved.connect(self.framesScrollBarMoved) - - @exception_handler - def next_cb(self): - if self.isSnapshot: - self.next_pos() - else: - self.next_frame() - if self.curvToolButton.isChecked(): - self.curvTool_cb(True) - - self.updatePropsWidget('') - - @exception_handler - def prev_cb(self): - if self.isSnapshot: - self.prev_pos() - else: - self.prev_frame() - if self.curvToolButton.isChecked(): - self.curvTool_cb(True) - - self.updatePropsWidget('') - - def zoomOut(self): - self.ax1.autoRange() - - def preprocessActionTriggered(self): - self.preprocessDialog.show() - self.preprocessDialog.raise_() - self.preprocessDialog.activateWindow() - self.preprocessDialog.emitSigPreviewToggled() - - def zoomToObjsActionCallback(self): - self.zoomToCells(enforce=True) - - def zoomToCells(self, enforce=False): - if not self.enableAutoZoomToCellsAction.isChecked() and not enforce: - return - - posData = self.data[self.pos_i] - lab_mask = (self.currentLab2D>0).astype(np.uint8) - rp = skimage.measure.regionprops(lab_mask) - if not rp: - Y, X = lab_mask.shape - xRange = -0.5, X+0.5 - yRange = -0.5, Y+0.5 - else: - obj = rp[0] - min_row, min_col, max_row, max_col = self.getObjBbox(obj.bbox) - xRange = min_col-10, max_col+10 - yRange = max_row+10, min_row-10 - - self.ax1.setRange(xRange=xRange, yRange=yRange) - - def viewCcaTable(self): - posData = self.data[self.pos_i] - zoomIDs = self.getZoomIDs() - - df = posData.allData_li[posData.frame_i]['acdc_df'] - current_cca_df = posData.cca_df - if zoomIDs is not None: - df = df.loc[zoomIDs] - current_cca_df = current_cca_df.loc[zoomIDs] - - for column in current_cca_df.columns: - header = ( - '================================================\n' - f'CURRENT vs STORED `{column}` column' - f'for frame number {posData.frame_i+1}:\n' - ) - df_compare = current_cca_df[[column]].copy() - df_compare[f'STORED_{column}'] = df[column] - text = f'{header}{df_compare}' - self.logger.info(text) - - if 'cell_cycle_stage' in df.columns: - cca_df = df[self.cca_df_colnames] - cca_df = cca_df.merge( - current_cca_df, how='outer', left_index=True, right_index=True, - suffixes=('_STORED', '_CURRENT') - ) - cca_df = cca_df.reindex(sorted(cca_df.columns), axis=1) - num_cols = len(cca_df.columns) - for j in range(0,num_cols,2): - df_j_x = cca_df.iloc[:,j] - df_j_y = cca_df.iloc[:,j+1] - if any(df_j_x!=df_j_y): - self.logger.info('------------------------') - self.logger.info('DIFFERENCES:') - diff_df = cca_df.iloc[:,j:j+2] - diff_mask = diff_df.iloc[:,0]!=diff_df.iloc[:,1] - self.logger.info(diff_df[diff_mask]) - else: - cca_df = None - self.logger.info(cca_df) - self.logger.info('========================') - if current_cca_df is None: - return - if current_cca_df.empty: - msg = widgets.myMessageBox() - txt = html_utils.paragraph( - 'Cell cycle annotations\' table is empty.
' - ) - msg.warning(self, 'Table empty', txt) - return - - df = posData.add_tree_cols_to_cca_df( - current_cca_df, frame_i=posData.frame_i - ) - if self.ccaTableWin is None: - self.ccaTableWin = apps.ViewCcaTableWindow(df, parent=self) - self.ccaTableWin.show() - self.ccaTableWin.setGeometryWindow() - self.ccaTableWin.sigUpdateCcaTable.connect( - self.onSigUpdateCcaTableWindow - ) - else: - self.ccaTableWin.setFocus() - self.ccaTableWin.activateWindow() - self.ccaTableWin.updateTable(current_cca_df) - - def updateScrollbars(self): - self.updateItemsMousePos() - self.updateFramePosLabel() - posData = self.data[self.pos_i] - navPos = self.pos_i+1 if self.isSnapshot else posData.frame_i+1 - self.navigateScrollBar.setSliderPosition(navPos) - if posData.SizeZ > 1: - self.updateZsliceScrollbar(posData.frame_i) - idx = (posData.filename, posData.frame_i) - self.zSliceScrollBar.setMaximum(posData.SizeZ-1) - self.zSliceSpinbox.setMaximum(posData.SizeZ) - self.SizeZlabel.setText(f'/{posData.SizeZ}') - - def updateItemsMousePos(self): - if self.brushButton.isChecked(): - self.updateBrushCursor(self.xHoverImg, self.yHoverImg) - - if self.eraserButton.isChecked(): - self.updateEraserCursor(self.xHoverImg, self.yHoverImg) - - @exception_handler - def postProcessing(self): - if self.postProcessSegmWin is None: - return - - self.postProcessSegmWin.setPosData() - posData = self.data[self.pos_i] - lab, delIDs = self.postProcessSegmWin.apply() - if posData.allData_li[posData.frame_i]['labels'] is None: - posData.lab = lab.copy() - self.update_rp() - else: - posData.allData_li[posData.frame_i]['labels'] = lab - self.get_data() - - def preprocessDialogRecipeChanged(self, recipe):# why does this need the recepie as an arg - recipe = self.preprocessDialog.recipe() - if recipe is None: - self.logger.warning('Pre-processing recipe not initialized yet.') - return - - self.updatePreprocessPreview(recipe=recipe) - - def debugShowImg(self, img): - imshow(img) - - def preprocessDialogSavePreprocessedData(self, dialog): - posData = self.data[self.pos_i] - - try: - posData.preprocessedDataArray() - except TypeError as e: - if 'Not all frames have been processed.' in str(e): - msg = widgets.myMessageBox() - txt = html_utils.paragraph( - 'Not all frames have been processed.
' - 'Please process all frames before saving.' - ) - msg.warning(self, 'Process all data before saving', txt) - return - - - helpText = ( - """ - The preprocessed image file will be saved with a different - file name.

- Insert a name to append to the end of the new file name. The rest of - the name will be the same as the original file. - """ - ) - - - win = apps.filenameDialog( - basename=f'{posData.basename}{self.user_ch_name}', - ext=".tif", - hintText='Insert a name for the preprocessed image file:', - defaultEntry='preprocessed', - helpText=helpText, - allowEmpty=False, - parent=dialog - ) - win.exec_() - if win.cancel: - return - - appendedText = win.entryText - - self.progressWin = apps.QDialogWorkerProgress( - title='Saving pre-processed image(s)', - parent=self, - pbarDesc='Saving pre-processed image(s)' - ) - self.progressWin.show(self.app) - self.progressWin.mainPbar.setMaximum(0) - - self.statusBarLabel.setText('Saving pre-processed data...') - - self.savePreprocWorker = workers.SaveProcessedDataWorker( - self.data, appendedText, ext=".tif" - ) - - self.savePreprocThread = QThread() - self.savePreprocWorker.moveToThread(self.savePreprocThread) - self.savePreprocWorker.signals.finished.connect( - self.savePreprocThread.quit - ) - self.savePreprocWorker.signals.finished.connect( - self.savePreprocWorker.deleteLater - ) - self.savePreprocThread.finished.connect( - self.savePreprocThread.deleteLater - ) - - self.savePreprocWorker.signals.critical.connect( - self.workerCritical - ) - self.savePreprocWorker.signals.initProgressBar.connect( - self.workerInitProgressbar - ) - self.savePreprocWorker.signals.progressBar.connect( - self.workerUpdateProgressbar - ) - self.savePreprocWorker.signals.progress.connect( - self.workerProgress - ) - self.savePreprocWorker.signals.finished.connect( - self.savePreprocWorkerFinished - ) - - self.savePreprocThread.started.connect( - self.savePreprocWorker.run - ) - self.savePreprocThread.start() - - - def preprocessEnqueueCurrentImage(self, recipe): - posData = self.data[self.pos_i] - func = core.preprocess_image_from_recipe - image_data = self.getImage(raw=True) - if posData.SizeZ > 1: - z_slice = self.z_slice_index() - else: - z_slice = 0 - - recipe = core.validate_multidimensional_recipe(recipe) - - key = (self.pos_i, posData.frame_i, z_slice) - self.preprocWorker.enqueue( - func, - image_data, - recipe, - key - ) - - def getChData(self, requ_ch=None, pos_i=None): - if not pos_i: - pos_i = self.pos_i - - posData = self.data[pos_i] - - if not requ_ch: - requ_ch = set(self.ch_names) - else: - requ_ch = set(requ_ch) - - posData.setLoadedChannelNames() - - loaded_channels = set(posData.loadedChNames) - missing_channels = requ_ch - loaded_channels - - self.loadFluo_cb(fluo_channels=missing_channels) - - def updatePreprocessPreview(self, *args, **kwargs): - force = kwargs.get('force', False) - - if not self.preprocessDialog.isVisible() and not force: - return - - if not self.preprocessDialog.previewCheckbox.isChecked() and not force: - return - - if kwargs.get('recipe') is None: - recipe = self.preprocessDialog.recipe() - else: - recipe = kwargs.get('recipe') - - if recipe is None: - self.logger.warning('Pre-processing recipe not initialized yet.') - return - - txt = 'Pre-processing current image...' - self.logger.info(txt) - self.statusBarLabel.setText(txt) - - self.preprocessEnqueueCurrentImage(recipe) - - def next_pos(self): - self.store_data(debug=True, autosave=False) - prev_pos_i = self.pos_i - if self.pos_i < self.num_pos-1: - self.pos_i += 1 - self.updateSegmDataAutoSaveWorker() - else: - self.logger.info('You reached last position.') - self.pos_i = 0 - self.updatePos() - - def resetManualBackgroundItems(self): - self.initManualBackgroundImage() - self.resetManualBackgroundSpinboxID() - self.drawManualTrackingGhost(self.xHoverImg, self.yHoverImg) - self.drawManualBackgroundObj(self.xHoverImg, self.yHoverImg) - - def clearUndoQueue(self): - posData = self.data[self.pos_i] - self.UndoCount = 0 - self.redoAction.setEnabled(False) - self.undoAction.setEnabled(False) - posData.UndoRedoStates = [[] for _ in range(posData.SizeT)] - posData.UndoRedoCcaStates = [[] for _ in range(posData.SizeT)] - if hasattr(self, 'undoAddPointQueueMapper'): - self.undoAddPointQueueMapper = defaultdict(list) - - def updatePos(self): - self.clearUndoQueue() - self.setStatusBarLabel() - self.checkManageVersions() - self.removeAlldelROIsCurrentFrame() - self.resetManualBackgroundItems() - proceed_cca, never_visited = self.get_data(debug=False) - self.pointsLayerLoadedDfsToData() - self.flushDirtyPointsLayersAutosave() - self.initContoursImage() - self.initDelRoiLab() - self.initTextAnnot() - self.postProcessing() - self.updateScrollbars() - self.updatePreprocessPreview() - self.updateCombineChannelsPreview() - self.updateAllImages() - self.computeSegm() - self.zoomOut() - self.restartZoomAutoPilot() - self.initManualBackgroundObject() - self.updateObjectCounts() - self.updateItemsMousePos() - - def prev_pos(self): - self.store_data(debug=False, autosave=False) - prev_pos_i = self.pos_i - if self.pos_i > 0: - self.pos_i -= 1 - self.updateSegmDataAutoSaveWorker() - else: - self.logger.info('You reached first position.') - self.pos_i = self.num_pos-1 - self.updatePos() - - def updateViewerWindow(self): - if self.slideshowWin is None: - return - - if self.slideshowWin.linkWindow is None: - return - - if not self.slideshowWin.linkWindowCheckbox.isChecked(): - return - - posData = self.data[self.pos_i] - self.slideshowWin.frame_i = posData.frame_i - self.slideshowWin.update_img() - - def warnLostObjects(self, do_warn=True): - if not do_warn: - return True - - if not self.warnLostCellsAction.isChecked(): - return True - - mode = str(self.modeComboBox.currentText()) - if not mode == 'Segmentation and Tracking': - return True - - posData = self.data[self.pos_i] - if not posData.lost_IDs: - return True - - frame_i = posData.frame_i - try: - accepted_lost_IDs = posData.accepted_lost_IDs.get(frame_i, []) - already_accepted_lost = ( - Counter(accepted_lost_IDs) == Counter(posData.lost_IDs) - ) - except AttributeError as err: - already_accepted_lost = False - - if already_accepted_lost: - return True - - self.nextAction.setDisabled(True) - self.prevAction.setDisabled(True) - self.navigateScrollBar.setDisabled(True) - - msg = widgets.myMessageBox() - warn_msg = html_utils.paragraph( - 'Current frame (compared to previous frame) ' - 'has lost the following cells:

' - f'{posData.lost_IDs}

' - 'Are you sure you want to continue?
' - ) - checkBox = QCheckBox('Do not show again') - noButton, yesButton = msg.warning( - self, 'Lost cells!', warn_msg, - buttonsTexts=('No', 'Yes'), - widgets=checkBox - ) - doNotWarnLostCells = not checkBox.isChecked() - self.warnLostCellsAction.setChecked(doNotWarnLostCells) - if msg.clickedButton == noButton: - self.nextAction.setDisabled(False) - self.prevAction.setDisabled(False) - self.navigateScrollBar.setDisabled(False) - return False - - self.nextAction.setDisabled(False) - self.prevAction.setDisabled(False) - self.navigateScrollBar.setDisabled(False) - if not hasattr(posData, 'accepted_lost_IDs'): - posData.accepted_lost_IDs = {} - if frame_i not in posData.accepted_lost_IDs: - posData.accepted_lost_IDs[frame_i] = [] - - posData.accepted_lost_IDs[frame_i].extend(posData.lost_IDs) - # This section is adding the lost cells to tracked_lost_centroids... TBH I dont know why this wasnt done in the first place - prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[posData.frame_i-1]['IDs_idxs'] - accepted_lost_centroids = { - tuple(int(val) for val in prev_rp[prev_IDs_idxs[ID]].centroid) - for ID in posData.lost_IDs - } - try: - posData.tracked_lost_centroids[frame_i] = ( - posData.tracked_lost_centroids[frame_i] | (accepted_lost_centroids) - ) - except KeyError: - posData.tracked_lost_centroids[frame_i] = accepted_lost_centroids - return True - - def askInitCcaFirstFrame(self): - mode = str(self.modeComboBox.currentText()) - if mode != 'Cell cycle analysis': - return True - - posData = self.data[self.pos_i] - if posData.frame_i != 0: - return True - - editCcaWidget = apps.editCcaTableWidget( - posData.cca_df, posData.SizeT, parent=self, - title='Initialize cell cycle annotations' - ) - editCcaWidget.sigApplyChangesFutureFrames.connect( - self.applyManualCcaChangesFutureFrames - ) - editCcaWidget.exec_() - if editCcaWidget.cancel: - self.resetNavigateFramesScrollbar() - return False - - if posData.cca_df is not None: - is_cca_same_as_stored = ( - (posData.cca_df == editCcaWidget.cca_df).all(axis=None) - ) - if not is_cca_same_as_stored: - reinit_cca = self.warnEditingWithCca_df( - 'Re-initialize cell cyle annotations first frame', - return_answer=True - ) - if reinit_cca: - self.resetCcaFuture(0) - - posData.cca_df = editCcaWidget.cca_df - self.store_cca_df() - - return True - - def askInitLinTreeFirstFrame(self): - mode = str(self.modeComboBox.currentText()) - if mode != 'Normal division: Lineage tree': - return True - - posData = self.data[self.pos_i] - if posData.frame_i != 0: - return True - - if self.lineage_tree is None: - self.initLinTree() - - return True - - def checkIfFutureFrameManualAnnotPastFrames(self): - if not self.manualAnnotPastButton.isChecked(): - return True - - posData = self.data[self.pos_i] - frame_to_restore = self.manualAnnotState.get('frame_i_to_restore') - if posData.frame_i <= frame_to_restore: - return True - - warn_txt = ( - 'WARNING: Cannot navigate to future frames while in ' - 'manual annotation mode.' - ) - self.logger.info(warn_txt) - self.statusBarLabel.setText(f'

{warn_txt}

') - - return False - - # @exec_time - def next_frame(self, warn=True): - proceed = self.checkIfFutureFrameManualAnnotPastFrames() - if not proceed: - return - - proceed = self.askInitCcaFirstFrame() - if not proceed: - return - - proceed = self.askInitLinTreeFirstFrame() - if not proceed: - return - - mode = str(self.modeComboBox.currentText()) - posData = self.data[self.pos_i] - - if posData.frame_i >= posData.SizeT-1: - # Store data for current frame - if mode != 'Viewer': - self.store_data(debug=False) - msg = 'You reached the last segmented frame!' - self.logger.info(msg) - self.titleLabel.setText(msg, color=self.titleColor) - return - - proceed = self.warnLostObjects() - if not proceed: - self.resetNavigateScrollbar() - return - - # Store data for current frame - if mode != 'Viewer': - self.store_data(debug=False) - - self.askLineageTreeChanges() - posData.frame_i += 1 - self.removeAlldelROIsCurrentFrame() - proceed_cca, never_visited = self.get_data() - if not proceed_cca: - posData.frame_i -= 1 - self.get_data() - self.logger.info( - 'No data for current frame. ' - ) - return - - if mode == 'Segmentation and Tracking' or self.isSnapshot: - self.addExistingDelROIs() - - self.updatePreprocessPreview() - self.updateCombineChannelsPreview() - self.postProcessing() - self.tracking(storeUndo=True, wl_update=False) - notEnoughG1Cells, proceed = self.attempt_auto_cca() - if notEnoughG1Cells or not proceed: - posData.frame_i -= 1 - self.get_data() - self.setAllTextAnnotations() - self.logger.info( - 'Not enough G1 cells to compute cell cycle annotations.' - ) - return - - self.store_zslices_rp() - self.resetExpandLabel() - self.updateAllImages() - self.updateHighlightedAxis() - self.updateViewerWindow() - self.updateLastVisitedFrame(last_visited_frame_i=posData.frame_i-1) - self.setNavigateScrollBarMaximum() - self.updateScrollbars() - self.computeSegm() - self.initGhostObject() - self.whitelistPropagateIDs() - self.zoomToCells() - self.updateItemsMousePos() - self.updateObjectCounts() - - self.apply_tools_on_new_frame() - - def apply_tools_on_new_frame(self): - mode = str(self.modeComboBox.currentText()) - if mode != 'Segmentation and Tracking': - return - posData = self.data[self.pos_i] - if not (posData.last_tracked_i <= posData.frame_i) or posData.frame_i == self.lastFrameRanOnFirstVisitTools: - return - - self.lastFrameRanOnFirstVisitTools = posData.frame_i - for name, checkbox in self.applyToolNewFrameActions.items(): - if not checkbox.isChecked(): - continue - - tool_button = self.applyToolNewFrameButtons[name] - try: - if hasattr(tool_button, 'click'): - tool_button.click() - elif hasattr(tool_button, 'trigger'): - tool_button.trigger() - else: - printl( - f"Warning: {name} has no click or trigger method" - ) - except Exception as e: - self.logger.info(f"Error applying tool {name}: {e}") - - @disableWindow - def get_difference_table(self, return_css_separated=False, return_differece=False): - - if self.original_df_lin_tree is None: - return - - posData = self.data[self.pos_i] - - new_df = posData.allData_li[posData.frame_i]['acdc_df'] - original_df = self.original_df_lin_tree.copy() - - if original_df.equals(new_df): - return - - compare_columns = ['parent_ID_tree'] - - new_df = new_df[original_df.columns] - new_df = myutils.checked_reset_index_Cell_ID(new_df) - new_df = new_df[compare_columns] - new_df = new_df.sort_index() - original_df = myutils.checked_reset_index_Cell_ID(original_df) - original_df = original_df[compare_columns] - original_df = original_df.sort_index() - - differences = original_df.compare(new_df) - if differences.empty: - return - - differences = myutils.checked_reset_index_Cell_ID(differences) - - differences = differences['parent_ID_tree'] - differences = differences.reset_index() - - txt = """ - - - - - """ - - for diff in differences.itertuples(): - ID = str(int(diff.Cell_ID)) - old_parent = str(int(diff.self)) - new_parent = str(int(diff.other)) - - txt += f''' - - - - ''' - txt += '
IDold parent -->new parent
{ID}{old_parent}{new_parent}
' - - css = r''' - - ''' - if return_css_separated and not return_differece: - return css, txt - elif return_css_separated and return_differece: - return css, txt, differences - elif not return_css_separated and return_differece: - return txt, differences - else: - txt = css + html_utils.paragraph(txt) - return txt - - def viewLinTreeInfoAction(self): - mode = str(self.modeComboBox.currentText()) - if mode != 'Normal division: Lineage tree': - self.logger.info('This action is only available in the "Normal division: Lineage tree" mode.') - return - - if not self.lineage_tree: - self.logger.info('No lineage tree found.') - return - - posData = self.data[self.pos_i] - - if self.original_df_lin_tree_i != posData.frame_i: - # could be that this is not entirley true and self.curr_original_df_i just didnt get set right though! - txt_changes = '
No changes were made in this frame.

' - - else: - result = self.get_difference_table(return_css_separated=True) - - if result is None: - txt_changes = 'No changes were made in this frame.' - else: - css, txt_changes = result - - txt_changes = 'Changes made in this frame:' + txt_changes + '

' - - cells_with_parent, orphan_cells, lost_cells = self.lineage_tree.export_lin_tree_info(posData.frame_i) - - if orphan_cells == []: - txt_orphan_cells = 'No orphan Cells!' - else: - txt_orphan_cells = ', '.join([str(cell) for cell in orphan_cells]) - txt_orphan = f'Orphan cells:
{txt_orphan_cells}

' - - lost_cells = list(lost_cells) - if lost_cells == []: - txt_lost_cells = 'No lost Cells!' - else: - txt_lost_cells = ', '.join([str(cell) for cell in lost_cells]) - txt_lost = f'Lost cells:
{txt_lost_cells}

' - - if cells_with_parent == []: - table_cells_with_parent = '
No cells with parents!' - else: - table_cells_with_parent = """ - - - - """ - - for cell, parent in cells_with_parent: - table_cells_with_parent += f''' - - - ''' - table_cells_with_parent += '
Parent IDID
{parent}{cell}
' - - txt_cells_with_parents = f'Cells with parents:{table_cells_with_parent}

' - - css = r''' - - ''' - - txt = css + html_utils.paragraph(txt_changes + txt_orphan + txt_lost + txt_cells_with_parents) - - msg = widgets.myMessageBox() - msg.information(self, - 'lineage tree information', - txt - ) - - @disableWindow - def askLineageTreeChanges(self): - """ - Asks the user for changes in the lineage tree. - - This method is called when the user selects the 'Normal division: Lineage tree' mode. - It compared the backed up df (self.original_df from repeat_click_and_backup) with the current df (self.lineage_tree.export_df(posData.frame_i)) and propts the user to keep, propagate or discard the changes. - - """ - mode = str(self.modeComboBox.currentText()) - if mode != 'Normal division: Lineage tree': - return - - if not self.lineage_tree: - return - - posData = self.data[self.pos_i] - - if self.original_df_lin_tree_i is not None and self.original_df_lin_tree_i != posData.frame_i: - printl("!This should not happen!") - self.store_data(autosave=False) - og_frame = posData.frame_i - posData.frame_i = self.original_df_lin_tree_i - self.get_data() - self.logger.info('Lineage tree changes were not propagated, going back to original frame.') - self.askLineageTreeChanges() - self.store_data(autosave=False) - posData.frame_i = og_frame - self.get_data() - return - - result = self.get_difference_table(return_css_separated=True, return_differece=True) - if result is None: - self.original_df_lin_tree = None - self.original_df_lin_tree_i = None - return - - css, txt, differences = result - changed_IDs = differences['Cell_ID'].unique() - - if posData.frame_i == max(self.lineage_tree.frames_for_dfs): - # here we can just propagate the cahnged. This is super fast, since there is no recursion, no children and fast finding of parents - self.lineage_tree.propagate(posData.frame_i, relevant_cells=changed_IDs) - self.original_df_lin_tree = None - self.original_df_lin_tree_i = None - return - - txt = txt + 'Do you want to keep, propgagte or discard the changes?' - txt = css + html_utils.paragraph('Changes made in this frame
' + txt) - - msg = widgets.myMessageBox() - - propagate_btn, discard_btn, _ = msg.question(self, - 'Changes in lineage tree', - txt, - buttonsTexts=('Propagate', 'Discard', 'Cancel'),) - - if msg.clickedButton == propagate_btn: - self.lineage_tree.propagate(posData.frame_i, relevant_cells=changed_IDs) - self.original_df_lin_tree = None - self.original_df_lin_tree_i = None - self.logger.info('Lineage tree propagated.') - - elif msg.clickedButton == discard_btn: - posData.allData_li[posData.frame_i]['acdc_df'] = self.original_df_lin_tree.copy() - self.original_df_lin_tree = None - self.original_df_lin_tree_i = None - self.logger.info('Lineage tree changes discarded.') - - - elif msg.cancel: - # Go back to current frame - msg = widgets.myMessageBox() - txt = html_utils.paragraph(''' - Changes were kept but not propagated! - Please make sure to come back and propagate them, - otherwise your table might be inconsistent! - There is a button for this next to the edit buttons. - Please also do not visit new frames! - - ''') - msg.warning(self, 'Changes kept but not propagated!', txt) - self.original_df_lin_tree = None - self.original_df_lin_tree_i = None - self.logger.info('Lineage tree changes discarded.') - - def manualAnnotRestoreLastTrackedFrame(self, last_tracked_i_to_restore): - if self.navigateScrollBar.maximum()-1 <= last_tracked_i_to_restore: - return - - posData = self.data[self.pos_i] - for frame_i in range(last_tracked_i_to_restore+1, posData.SizeT): - data_frame_i = myutils.get_empty_stored_data_dict() - - data_frame_i['manually_edited_lab'] = ( - posData.allData_li[frame_i]['manually_edited_lab'] - ) - - posData.allData_li[frame_i] = data_frame_i - - self.navigateScrollBar.setMaximum(last_tracked_i_to_restore+1) - self.navSpinBox.setMaximum(last_tracked_i_to_restore+1) - - def setNavigateScrollBarMaximum(self): - posData = self.data[self.pos_i] - mode = str(self.modeComboBox.currentText()) - if mode == 'Segmentation and Tracking': - if posData.last_tracked_i is not None: - if posData.frame_i > posData.last_tracked_i: - self.navigateScrollBar.setMaximum(posData.frame_i+1) - self.navSpinBox.setMaximum(posData.frame_i+1) - else: - self.navigateScrollBar.setMaximum(posData.last_tracked_i+1) - self.navSpinBox.setMaximum(posData.last_tracked_i+1) - else: - self.navigateScrollBar.setMaximum(posData.frame_i+1) - self.navSpinBox.setMaximum(posData.frame_i+1) - - self.updateLastCheckedFrameWidgets(self.navSpinBox.maximum()-1) - elif mode == 'Cell cycle analysis': - if posData.frame_i > self.last_cca_frame_i: - self.navigateScrollBar.setMaximum(posData.frame_i+1) - self.navSpinBox.setMaximum(posData.frame_i+1) - else: - self.navigateScrollBar.setMaximum(self.last_cca_frame_i+1) - self.navSpinBox.setMaximum(self.last_cca_frame_i+1) - self.lastTrackedFrameLabel.setText( - f'Last cc annot. frame n. = {self.navSpinBox.maximum()}' - ) - elif mode == 'Normal division: Lineage tree': - if self.lineage_tree is None: - self.navigateScrollBar.setMaximum(posData.frame_i+1) - self.navSpinBox.setMaximum(posData.frame_i+1) - else: - if self.lineage_tree.frames_for_dfs: - i = max(self.lineage_tree.frames_for_dfs) - else: - i = 0 - self.navigateScrollBar.setMaximum(i+1) - self.navSpinBox.setMaximum(i+1) - - # @exec_time - def prev_frame(self): - posData = self.data[self.pos_i] - if posData.frame_i <= 0: - msg = 'You reached the first frame!' - self.logger.info(msg) - self.titleLabel.setText(msg, color=self.titleColor) - return - - # Store data for current frame - mode = str(self.modeComboBox.currentText()) - if mode != 'Viewer': - self.store_data(debug=False) - - self.removeAlldelROIsCurrentFrame() - self.askLineageTreeChanges() - posData.frame_i -= 1 - _, never_visited = self.get_data() - - if mode == 'Segmentation and Tracking' or self.isSnapshot: - self.addExistingDelROIs() - - self.resetExpandLabel() - self.updatePreprocessPreview() - self.updateCombineChannelsPreview() - self.postProcessing() - self.tracking() - self.whitelistPropagateIDs(update_lab=True) - self.updateAllImages() - self.updateScrollbars() - self.updateHighlightedAxis() - self.zoomToCells() - self.initGhostObject() - self.updateViewerWindow() - self.updateItemsMousePos() - self.updateObjectCounts() - - def loadSelectedData(self, user_ch_file_paths, user_ch_name): - data = [] - numPos = len(user_ch_file_paths) - self.user_ch_file_paths = user_ch_file_paths - - self.logger.info(f'Reading {user_ch_name} channel metadata...') - # Get information from first loaded position - posData = load.loadData(user_ch_file_paths[0], user_ch_name, log_func=self.logger.info) - posData.getBasenameAndChNames(qparent=self) - posData.buildPaths() - - if posData.ext != '.h5': - self.lazyLoader.salute = False - self.lazyLoader.exit = True - self.lazyLoaderWaitCond.wakeAll() - self.waitReadH5cond.wakeAll() - - # Get end name of every existing segmentation file - existingSegmEndNames = set() - for filePath in user_ch_file_paths: - _posData = load.loadData(filePath, user_ch_name, log_func=self.logger.info) - _posData.getBasenameAndChNames(qparent=self) - segm_files = load.get_segm_files(_posData.images_path) - _existingEndnames = load.get_endnames( - _posData.basename, segm_files - ) - existingSegmEndNames.update(_existingEndnames) - - selectedSegmEndName = '' - self.newSegmEndName = '' - if self.isNewFile or not existingSegmEndNames: - self.isNewFile = True - # Remove the 'segm_' part to allow filenameDialog to check if - # a new file is existing (since we only ask for the part after - # 'segm_') - existingEndNames = [ - n.replace('segm', '', 1).replace('_', '', 1) - for n in existingSegmEndNames - ] - if posData.basename.endswith('_'): - basename = f'{posData.basename}segm' - else: - basename = f'{posData.basename}_segm' - win = apps.filenameDialog( - basename=basename, - hintText='Insert a filename for the segmentation file:', - existingNames=existingEndNames - ) - win.exec_() - if win.cancel: - self.loadingDataAborted() - return - self.newSegmEndName = win.entryText - else: - if len(existingSegmEndNames) > 0: - win = apps.SelectSegmFileDialog( - existingSegmEndNames, self.exp_path, parent=self, - addNewFileButton=True, basename=posData.basename - ) - win.exec_() - if win.cancel: - self.loadingDataAborted() - return - if win.newSegmEndName is None: - selectedSegmEndName = win.selectedItemText - self.AutoPilotProfile.storeSelectedSegmFile( - selectedSegmEndName - ) - else: - self.newSegmEndName = win.newSegmEndName - self.isNewFile = True - elif len(existingSegmEndNames) == 1: - selectedSegmEndName = list(existingSegmEndNames)[0] - - posData.loadImgData() - - required_ram = posData.getBytesImageData() - if required_ram >= 5e8: - # Disable autosave for data > 500MB - self.autoSaveToggle.setChecked(False) - - proceed = self.checkMemoryRequirements(required_ram) - if not proceed: - self.loadingDataAborted() - return - - posData.loadOtherFiles( - load_segm_data=True, - load_metadata=True, - create_new_segm=self.isNewFile, - new_endname=self.newSegmEndName, - end_filename_segm=selectedSegmEndName, - ) - self.selectedSegmEndName = selectedSegmEndName - self.labelBoolSegm = posData.labelBoolSegm - posData.labelSegmData() - - print('') - self.logger.info( - f'Segmentation filename: {posData.segm_npz_path}' - ) - - proceed = posData.askInputMetadata( - self.num_pos, - ask_SizeT=self.num_pos==1, - ask_TimeIncrement=True, - ask_PhysicalSizes=True, - singlePos=False, - save=True, - warnMultiPos=True - ) - if not proceed: - self.loadingDataAborted() - return - - self.AutoPilotProfile.storeOkAskInputMetadata() - - if posData.isSegm3D is None: - self.isSegm3D = False - else: - self.isSegm3D = posData.isSegm3D - self.SizeT = posData.SizeT - self.SizeZ = posData.SizeZ - self.TimeIncrement = posData.TimeIncrement - self.PhysicalSizeZ = posData.PhysicalSizeZ - self.PhysicalSizeY = posData.PhysicalSizeY - self.PhysicalSizeX = posData.PhysicalSizeX - self.loadSizeS = posData.loadSizeS - self.loadSizeT = posData.loadSizeT - self.loadSizeZ = posData.loadSizeZ - - self.overlayLabelsItems = {} - self.drawModeOverlayLabelsChannels = {} - - self.existingSegmEndNames = existingSegmEndNames - self.createOverlayLabelsContextMenu(existingSegmEndNames) - self.overlayLabelsButtonAction.setVisible(True) - self.createOverlayLabelsItems(existingSegmEndNames) - self.disableNonFunctionalButtons() - - self.isH5chunk = ( - posData.ext == '.h5' - and (self.loadSizeT != self.SizeT - or self.loadSizeZ != self.SizeZ) - ) - - required_ram = posData.checkH5memoryFootprint()*self.loadSizeS - if required_ram > 0: - proceed = self.checkMemoryRequirements(required_ram) - if not proceed: - self.loadingDataAborted() - return - - if posData.SizeT == 1: - self.isSnapshot = True - else: - self.isSnapshot = False - - self.progressWin = apps.QDialogWorkerProgress( - title='Loading data...', parent=self, - pbarDesc=f'Loading "{user_ch_file_paths[0]}"...' - ) - self.progressWin.show(self.app) - - func = partial( - self.startLoadDataWorker, user_ch_file_paths, user_ch_name, - posData - ) - - - QTimer.singleShot(150, func) - - def setManualAnnotModeEnabledTools(self, enabled): - for action in self.editToolBar.actions(): - toolButton = self.editToolBar.widgetForAction(action) - if toolButton in self.manulAnnotToolButtons: - continue - - toolButton.setDisabled(enabled) - action.setDisabled(enabled) - - def disableNonFunctionalButtons(self): - if not self.isSegm3D: - return - - for item in self.functionsNotTested3D: - if hasattr(item, 'action'): - toolButton = item - action = toolButton.action - toolButton.setDisabled(True) - elif hasattr(item, 'toolbar'): - toolbar = item.toolbar - action = item - toolButton = toolbar.widgetForAction(action) - toolButton.setDisabled(True) - else: - action = item - action.setDisabled(True) - - @exception_handler - def startLoadDataWorker( - self, user_ch_file_paths, user_ch_name, firstPosData - ): - self.funcDescription = 'loading data' - - self.guiTabControl.propsQGBox.idSB.setValue(0) - - self.thread = QThread() - self.loadDataMutex = QMutex() - self.loadDataWaitCond = QWaitCondition() - - self.loadDataWorker = workers.loadDataWorker( - self, user_ch_file_paths, user_ch_name, firstPosData - ) - - self.loadDataWorker.moveToThread(self.thread) - self.loadDataWorker.signals.finished.connect(self.thread.quit) - self.loadDataWorker.signals.finished.connect( - self.loadDataWorker.deleteLater - ) - self.thread.finished.connect(self.thread.deleteLater) - - self.loadDataWorker.signals.finished.connect( - self.loadDataWorkerFinished - ) - self.loadDataWorker.signals.progress.connect(self.workerProgress) - self.loadDataWorker.signals.initProgressBar.connect( - self.workerInitProgressbar - ) - self.loadDataWorker.signals.progressBar.connect( - self.workerUpdateProgressbar - ) - self.loadDataWorker.signals.critical.connect( - self.workerCritical - ) - self.loadDataWorker.signals.dataIntegrityCritical.connect( - self.loadDataWorkerDataIntegrityCritical - ) - self.loadDataWorker.signals.dataIntegrityWarning.connect( - self.loadDataWorkerDataIntegrityWarning - ) - self.loadDataWorker.signals.sigPermissionError.connect( - self.workerPermissionError - ) - self.loadDataWorker.signals.sigWarnMismatchSegmDataShape.connect( - self.askMismatchSegmDataShape - ) - self.loadDataWorker.signals.sigRecovery.connect( - self.askRecoverNotSavedData - ) - - self.thread.started.connect(self.loadDataWorker.run) - self.thread.start() - - def askRecoverNotSavedData(self, posData): - last_modified_time_unsaved = 'NEVER' - if os.path.exists(posData.segm_npz_temp_path): - recovered_file_path = posData.segm_npz_temp_path - if os.path.exists(posData.segm_npz_path): - last_modified_time_unsaved = ( - datetime.fromtimestamp( - os.path.getmtime(posData.segm_npz_path) - ).strftime("%a %d. %b. %y - %H:%M:%S") - ) - else: - posData.setTempPaths() - if os.path.exists(posData.unsaved_acdc_df_autosave_path): - zip_path = posData.unsaved_acdc_df_autosave_path - with zipfile.ZipFile(zip_path, mode='r') as zip: - csv_names = natsorted(set(zip.namelist())) - iso_key = csv_names[-1][:-4] - most_recent_unsaved_acdc_df_datetime = datetime.strptime( - iso_key, load.ISO_TIMESTAMP_FORMAT - ) - last_modified_time_unsaved = ( - most_recent_unsaved_acdc_df_datetime - ).strftime("%a %d. %b. %y - %H:%M:%S") - - if os.path.exists(posData.acdc_output_csv_path): - acdc_df_mtime = os.path.getmtime(posData.acdc_output_csv_path) - timestamp = datetime.fromtimestamp(acdc_df_mtime) - last_modified_time_saved = timestamp.strftime( - "%a %d. %b. %y - %H:%M:%S" - ) - else: - last_modified_time_saved = 'Null' - - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - txt = html_utils.paragraph(""" - Cell-ACDC detected unsaved data.

- Do you want to load and recover the unsaved data or - load the data that was last saved by the user? - """) - details = (f""" - The unsaved data was created on {last_modified_time_unsaved}\n\n - The user saved the data last time on {last_modified_time_saved} - """) - msg.setDetailedText(details) - loadUnsavedButton = widgets.reloadPushButton('Recover unsaved data') - loadSavedButton = widgets.savePushButton('Load saved data') - infoButton = widgets.infoPushButton('More info...') - loadSafeNpzButton = '' - if posData.isSafeNpzOverwritePresent(): - loadSafeNpzButton = widgets.reloadPushButton( - 'Load .safe.npz file from crash' - ) - buttons = ( - loadSavedButton, loadUnsavedButton, loadSafeNpzButton, - infoButton - ) - else: - buttons = (loadSavedButton, loadUnsavedButton, infoButton) - msg.question( - self.progressWin, 'Recover unsaved data?', txt, - buttonsTexts=('Cancel', *buttons), - showDialog=False - ) - infoButton.disconnect() - infoButton.clicked.connect(partial(self.showInfoAutosave, posData)) - msg.exec_() - if msg.cancel: - self.loadDataWorker.abort = True - elif msg.clickedButton == loadUnsavedButton: - self.loadDataWorker.loadUnsaved = True - elif msg.clickedButton == loadSafeNpzButton: - self.loadDataWorker.loadSafeOverwriteNpz = True - - self.loadDataWorker.waitCond.wakeAll() - # self.AutoPilotProfile.storeLoadSavedData() - - def showInfoAutosave(self, posData): - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - txt = (f""" - Cell-ACDC either detected unsaved data in a previous session and it - stored it because the Autosave
- function was active, or it crashed during saving.

- You can toggle Autosave ON and OFF from the menu on the top menubar - File --> Autosave. - """) - txt = (f""" - {txt}

- If Cell-ACDC crashed during saving, the segmentation file ending - with .new.npz
- is present and you might be able to recover the data from there. - """) - - txt = (f""" - {txt}

- You can find additional recovered data in the following folder: - """) - txt = html_utils.paragraph(txt) - msg.information( - self, 'Autosave info', txt, - path_to_browse=posData.recoveryFolderPath, - commands=(posData.recoveryFolderPath,) - ) - - def askMismatchSegmDataShape(self, posData): - msg = widgets.myMessageBox(wrapText=False) - title = 'Segm. data shape mismatch' - f = '3D' if self.isSegm3D else '2D' - f = f'{f} over time' if posData.SizeT > 1 else f - r = '2D' if self.isSegm3D else '3D' - r = f'{r} over time' if posData.SizeT > 1 else r - text = html_utils.paragraph(f""" - The segmentation masks of the first Position that you loaded is - {f},
- while {posData.pos_foldername} is {r}.

- The loaded segmentation masks must be either all 3D - or all 2D.

- Do you want to skip loading this position or cancel the process? - """) - _, skipPosButton = msg.warning( - self, title, text, buttonsTexts=('Cancel', 'Skip this Position') - ) - if skipPosButton == msg.clickedButton: - self.loadDataWorker.skipPos = True - self.loadDataWorker.waitCond.wakeAll() - - def workerPermissionError(self, txt, waitCond): - msg = widgets.myMessageBox(parent=self) - msg.setIcon(iconName='SP_MessageBoxCritical') - msg.setWindowTitle('Permission denied') - msg.addText(txt) - msg.addButton(' Ok ') - msg.exec_() - waitCond.wakeAll() - - def loadDataWorkerDataIntegrityCritical(self): - errTitle = 'All loaded positions contains frames over time!' - self.titleLabel.setText(errTitle, color='r') - - msg = widgets.myMessageBox(parent=self) - - err_msg = html_utils.paragraph(f""" - {errTitle}.

- To load data that contains frames over time you have to select - only ONE position. - """) - msg.setIcon(iconName='SP_MessageBoxCritical') - msg.setWindowTitle('Loaded multiple positions with frames!') - msg.addText(err_msg) - msg.addButton('Ok') - msg.show(block=True) - - @exception_handler - def loadDataWorkerFinished(self, data): - self.funcDescription = 'loading data worker finished' - if self.progressWin is not None: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - - if data is None or data=='abort': - self.loadingDataAborted() - return - - if data[0].onlyEditMetadata: - self.loadingDataAborted() - return - - self.pos_i = 0 - self.data = data - self.gui_createGraphicsItems() - return True - - def checkManageVersions(self): - posData = self.data[self.pos_i] - posData.setTempPaths(createFolder=False) - loaded_acdc_df_filename = os.path.basename(posData.acdc_output_csv_path) - - if os.path.exists(posData.recoveryFolderpath()): - self.manageVersionsAction.setDisabled(False) - self.manageVersionsAction.setToolTip( - f'Load an older version of the `{loaded_acdc_df_filename}` file ' - '(table with annotations and measurements).' - ) - else: - self.manageVersionsAction.setDisabled(True) - - def preprocessPreviewToggled(self, checked): - self.viewPreprocDataToggle.setChecked(checked) - self.updatePreprocessPreview() - - - - def preprocessCurrentImage(self, recipe: List[Dict[str, Any]], *args): - txt = 'Pre-processing current image...' - self.logger.info(txt) - self.statusBarLabel.setText(txt) - - func = core.preprocess_image_from_recipe - recipe = core.validate_multidimensional_recipe(recipe) - - image_data = self.getImage(raw=True) - self.preprocWorker.setupJob( - func, - image_data, - recipe, - 'current_image' - ) - - self.preprocWorker.wakeUp() - - def preprocessZStack(self, recipe: List[Dict[str, Any]], *args): - txt = 'Pre-processing z-stack...' - self.statusBarLabel.setText(txt) - self.logger.info(txt) - - posData = self.data[self.pos_i] - func = core.preprocess_zstack_from_recipe - recipe = core.validate_multidimensional_recipe( - recipe, apply_to_all_frames=False - ) - image_data = posData.img_data[posData.frame_i] - self.preprocWorker.setupJob( - func, - image_data, - recipe, - 'z_stack' - ) - - self.preprocWorker.wakeUp() - - def preprocessAllFrames(self, recipe: List[Dict[str, Any]]): - txt = 'Pre-processing all frames...' - self.logger.info(txt) - self.statusBarLabel.setText(txt) - - posData = self.data[self.pos_i] - func = core.preprocess_video_from_recipe - image_data = posData.img_data - self.preprocWorker.setupJob( - func, - image_data, - recipe, - 'all_frames' - ) - self.preprocWorker.wakeUp() - - def preprocessAllPos(self, recipe: List[Dict[str, Any]]): - txt = 'Pre-processing all Positions...' - self.logger.info(txt) - self.statusBarLabel.setText(txt) - - func = core.preprocess_multi_pos_from_recipe - recipe = core.validate_multidimensional_recipe( - recipe, apply_to_all_frames=False - ) - image_data = [posData.img_data[0] for posData in self.data] - self.preprocWorker.setupJob( - func, - image_data, - recipe, - 'all_pos' - ) - - self.preprocWorker.wakeUp() - - def setupPreprocessing(self): - posData = self.data[self.pos_i] - if self.preprocessDialog is not None: - self.preprocessDialog.close() - - self.preprocessDialog = apps.PreProcessRecipeDialog( - isTimelapse=posData.SizeT>1, - isZstack=posData.SizeZ>1, - isMultiPos=len(self.data)>1, - df_metadata=posData.metadata_df, - hideOnClosing=True, - addApplyButton=True, - parent=self - ) - self.doPreviewPreprocImage = False - self.preprocessDialog.sigApplyImage.connect( - self.preprocessCurrentImage - ) - self.preprocessDialog.sigApplyZstack.connect( - self.preprocessZStack - ) - self.preprocessDialog.sigApplyAllFrames.connect( - self.preprocessAllFrames - ) - self.preprocessDialog.sigApplyAllPos.connect( - self.preprocessAllPos - ) - self.preprocessDialog.sigPreviewToggled.connect( - self.preprocessPreviewToggled - ) - self.preprocessDialog.sigValuesChanged.connect( - self.preprocessDialogRecipeChanged - ) - self.preprocessDialog.sigSavePreprocData.connect( - self.preprocessDialogSavePreprocessedData - ) - - if self.preprocWorker is not None: - return - - self.preprocThread = QThread() - self.preprocMutex = QMutex() - self.preprocWaitCond = QWaitCondition() - - self.preprocWorker = workers.CustomPreprocessWorkerGUI( - self.preprocMutex, self.preprocWaitCond - ) - - self.preprocWorker.moveToThread(self.preprocThread) - self.preprocWorker.signals.finished.connect(self.preprocThread.quit) - self.preprocWorker.signals.finished.connect( - self.preprocWorker.deleteLater - ) - self.preprocThread.finished.connect(self.preprocThread.deleteLater) - - self.preprocWorker.sigDone.connect(self.preprocWorkerDone) - self.preprocWorker.sigIsQueueEmpty.connect( - self.preprocWorkerIsQueueEmpty - ) - self.preprocWorker.sigPreviewDone.connect(self.preprocWorkerPreviewDone) - self.preprocWorker.signals.progress.connect(self.workerProgress) - self.preprocWorker.signals.critical.connect(self.workerCritical) - self.preprocWorker.signals.finished.connect(self.preprocWorkerClosed) - - self.preprocThread.started.connect(self.preprocWorker.run) - self.preprocThread.start() - - self.logger.info('Pre-processing worker started.') - - def preprocWorkerCritical(self, error): - self.preprocessDialog.appliedFinished() - self.workerCritical(error) - - @exception_handler - def loadingDataCompleted(self): - self.isDataLoading = True - posData = self.data[self.pos_i] - - files_format = '\n'.join([ - f' - {file}' for file in posData.images_folder_files - ]) - sep = '-'*100 - self.logger.info( - f'{sep}\nFiles present in the first Position folder loaded:\n\n' - f'{files_format}\n{sep}' - ) - self.logger.info(f'Basename of the first Position: {posData.basename}') - self.secondLevelToolbar.setVisible(True) - self.updateImageValueFormatter() - self.checkManageVersions() - self.initManualBackgroundImage() - self.initPixelSizePropsDockWidget() - - self.setWindowTitle( - f'Cell-ACDC v{self._acdc_version} - GUI - "{posData.exp_path}"' - ) - - self.setupPreprocessing() - self.setupCombiningChannels() - - if self.isSegm3D: - self.segmNdimIndicator.setText('3D') - else: - self.segmNdimIndicator.setText('2D') - - self.segmNdimIndicatorAction.setVisible(True) - - self.guiTabControl.addChannels([posData.user_ch_name]) - self.showPropsDockButton.setDisabled(False) - - self.bottomScrollArea.show() - self.gui_createStoreStateWorker() - self.init_segmInfo_df() - self.connectScrollbars() - self.initPosAttr() - - self.logger.info('Pre-computing min and max values of the images...') - self.img1.preComputedMinMaxValues(self.data) - self.img2.minMaxValuesMapper = self.img1.minMaxValuesMapper - - self.initMetrics() - self.initFluoData() - self.createChannelNamesActions() - self.addActionsLutItemContextMenu(self.imgGrad) - - # Scrollbar for opacity of img1 (when overlaying) - self.img1.alphaScrollbar = self.addAlphaScrollbar( - self.user_ch_name, self.img1 - ) - - self.navigateScrollBar.setSliderPosition(posData.frame_i+1) - - # Connect events at the end of loading data process - self.gui_connectGraphicsEvents() - if not self.isEditActionsConnected: - self.gui_connectEditActions() - self.normalizeToFloatAction.setChecked(True) - - self.navSpinBox.connectValueChanged(self.navigateSpinboxValueChanged) - - self.setFramesSnapshotMode() - if self.isSnapshot: - self.navSizeLabel.setText(f'/{len(self.data)}') - else: - self.navSizeLabel.setText(f'/{posData.SizeT}') - - self.enableZstackWidgets(posData.SizeZ > 1) - # self.showHighlightZneighCheckbox() - - self.exportToVideoAction.setDisabled( - posData.SizeZ == 1 and posData.SizeT == 1 - ) - - self.img1BottomGroupbox.show() - - isLabVisible = self.df_settings.at['isLabelsVisible', 'value'] == 'Yes' - isRightImgVisible = ( - self.df_settings.at['isRightImageVisible', 'value'] == 'Yes' - ) - isNextFrameVisible = ( - self.df_settings.at['isNextFrameVisible', 'value'] == 'Yes' - ) - isNextFrameActive = ( - isNextFrameVisible and self.labelsGrad.showNextFrameAction.isEnabled() - ) - self.updateScrollbars() - self.openFolderAction.setEnabled(True) - self.editTextIDsColorAction.setDisabled(False) - self.imgPropertiesAction.setEnabled(True) - self.navigateToolBar.setVisible(True) - self.labelsGrad.showLabelsImgAction.setChecked(isLabVisible) - self.labelsGrad.showRightImgAction.setChecked(isRightImgVisible) - self.labelsGrad.showNextFrameAction.setChecked(isNextFrameActive) - if isRightImgVisible or isNextFrameActive: - self.rightBottomGroupbox.setChecked(True) - - isTwoImagesLayout = ( - isRightImgVisible or isLabVisible or isNextFrameActive - ) - self.setTwoImagesLayout(isTwoImagesLayout) - - self.setBottomLayoutStretch() - - if isNextFrameActive: - self.rightBottomGroupbox.show() - self.rightBottomGroupbox.setChecked(True) - self.drawNothingCheckboxRight.click() - - self.readSavedCustomAnnot() - self.addCustomAnnotButtonAllLoadedPos() - self.setStatusBarLabel() - - self.initLookupTableLab() - if self.invertBwAction.isChecked() and not self.invertBwAlreadyCalledOnce: - self.invertBw(True) - self.restoreSavedSettings() - - self.initContoursImage() - self.initTextAnnot() - self.initDelRoiLab() - - self.update_rp() - self.updateAllImages() - if posData.SizeT > 1: - self.rightImageFramesScrollbar.setValueNoSignal(posData.frame_i+2) - self.setMetricsFunc() - - self.gui_createLabelRoiItem() - self.gui_createZoomRectItem() - - self.titleLabel.setText( - 'Data successfully loaded.', - color=self.titleColor - ) - - self.disableNonFunctionalButtons() - self.setVisible3DsegmWidgets() - - if len(self.data) == 1 and posData.SizeZ > 1 and posData.SizeT == 1: - self.zSliceCheckbox.setChecked(True) - else: - self.zSliceCheckbox.setChecked(False) - - self.labelRoiCircItemLeft.setImageShape(self.currentLab2D.shape) - self.labelRoiCircItemRight.setImageShape(self.currentLab2D.shape) - - self.retainSpaceSlidersToggled(self.retainSpaceSlidersAction.isChecked()) - - self.stopAutomaticLoadingPos() - self.viewAllCustomAnnotAction.setChecked(True) - - self.updateImageValueFormatter() - - posData.loadWhitelist() - - self.setFocusGraphics() - self.setFocusMain() - - # Overwrite axes viewbox context menu - self.ax1.vb.menu = self.imgGrad.gradient.menu - self.ax2.vb.menu = self.labelsGrad.menu - - QTimer.singleShot(200, self.resizeGui) - - self.isDataLoaded = True - self.isDataLoading = False - - self.initImgGradRescaleIntensitiesHowPreference() - - self.rescaleIntensitiesLut(setImage=False) - - self.gui_createAutoSaveWorker() - - def initImgGradRescaleIntensitiesHowPreference(self): - posData = self.data[self.pos_i] - channelName = posData.user_ch_name - if f'how_rescale_intensities_{channelName}' not in self.df_settings.index: - return - - how = self.df_settings.at[ - f'how_rescale_intensities_{channelName}', 'value' - ] - self.imgGrad.setRescaleIntensitiesHow(how) - - def removeAxLimits(self): - self.ax1.vb.state['limits']['xLimits'] = [-1E307, +1E307] - self.ax1.vb.state['limits']['yLimits'] = [-1E307, +1E307] - - def resizeGui(self): - self.ax1.vb.state['limits']['xRange'] = [None, None] - self.ax1.vb.state['limits']['yRange'] = [None, None] - self.autoRange() - if self.ax1.getViewBox().state['limits']['xRange'][0] is not None: - self.bottomScrollArea._resizeVertical() - return - (xmin, xmax), (ymin, ymax) = self.ax1.viewRange() - maxYRange = int((ymax-ymin)*1.5) - maxXRange = int((xmax-xmin)*1.5) - self.ax1.setLimits( - maxYRange=maxYRange, - maxXRange=maxXRange - ) - self.bottomScrollArea._resizeVertical() - QTimer.singleShot(200, self.autoRange) - - def setVisible3DsegmWidgets(self): - self.annotNumZslicesCheckbox.setVisible(self.isSegm3D) - self.annotNumZslicesCheckboxRight.setVisible(self.isSegm3D) - if not self.isSegm3D: - self.annotNumZslicesCheckbox.setChecked(False) - self.annotNumZslicesCheckboxRight.setChecked(False) - - def showHighlightZneighCheckbox(self): - if self.isSegm3D: - layout = self.bottomLeftLayout - # layout.addWidget(self.annotOptionsWidget, 0, 1, 1, 2) - # # layout.removeWidget(self.drawIDsContComboBox) - # # layout.addWidget(self.drawIDsContComboBox, 0, 1, 1, 1, - # # alignment=Qt.AlignCenter - # # ) - # layout.addWidget(self.highlightZneighObjCheckbox, 0, 2, 1, 2, - # alignment=Qt.AlignRight - # ) - self.highlightZneighObjCheckbox.show() - self.highlightZneighObjCheckbox.setChecked(True) - self.highlightZneighObjCheckbox.toggled.connect( - self.highlightZneighLabels_cb - ) - - def restoreSavedSettings(self): - if 'how_draw_annotations' in self.df_settings.index: - how = self.df_settings.at['how_draw_annotations', 'value'] - self.drawIDsContComboBox.setCurrentText(how) - else: - self.drawIDsContComboBox.setCurrentText('Draw IDs and contours') - - if 'how_draw_right_annotations' in self.df_settings.index: - how = self.df_settings.at['how_draw_right_annotations', 'value'] - self.annotateRightHowCombobox.setCurrentText(how) - else: - self.annotateRightHowCombobox.setCurrentText( - 'Draw IDs and overlay segm. masks' - ) - - if 'addNewIDsWhitelistToggle' in self.df_settings.index: - self.addNewIDsWhitelistToggle = ( - self.df_settings.at['addNewIDsWhitelistToggle', 'value'] - ) == 'Yes' - else: - self.addNewIDsWhitelistToggle = True - - self.drawAnnotCombobox_to_options() - self.drawIDsContComboBox_cb(0) - self.annotateRightHowCombobox_cb(0) - - def uncheckAnnotOptions(self, left=True, right=True): - # Left - if left: - self.annotIDsCheckbox.setChecked(False) - self.annotCcaInfoCheckbox.setChecked(False) - self.annotContourCheckbox.setChecked(False) - self.annotSegmMasksCheckbox.setChecked(False) - self.drawMothBudLinesCheckbox.setChecked(False) - self.drawNothingCheckbox.setChecked(False) - - # Right - if right: - self.annotIDsCheckboxRight.setChecked(False) - self.annotCcaInfoCheckboxRight.setChecked(False) - self.annotContourCheckboxRight.setChecked(False) - self.annotSegmMasksCheckboxRight.setChecked(False) - self.drawMothBudLinesCheckboxRight.setChecked(False) - self.drawNothingCheckboxRight.setChecked(False) - - def setDisabledAnnotOptions(self, disabled): - # Left - self.annotIDsCheckbox.setDisabled(disabled) - self.annotCcaInfoCheckbox.setDisabled(disabled) - self.annotContourCheckbox.setDisabled(disabled) - # self.annotSegmMasksCheckbox.setDisabled(disabled) - self.drawMothBudLinesCheckbox.setDisabled(disabled) - # self.drawNothingCheckbox.setDisabled(disabled) - - # Right - self.annotIDsCheckboxRight.setDisabled(disabled) - self.annotCcaInfoCheckboxRight.setDisabled(disabled) - self.annotContourCheckboxRight.setDisabled(disabled) - # self.annotSegmMasksCheckboxRight.setDisabled(disabled) - self.drawMothBudLinesCheckboxRight.setDisabled(disabled) - # self.drawNothingCheckboxRight.setDisabled(disabled) - - def drawAnnotCombobox_to_options(self): - self.uncheckAnnotOptions() - - # Left - how = self.drawIDsContComboBox.currentText() - if how.find('IDs') != -1: - self.annotIDsCheckbox.setChecked(True) - if how.find('cell cycle info') != -1: - self.annotCcaInfoCheckbox.setChecked(True) - if how.find('contours') != -1: - self.annotContourCheckbox.setChecked(True) - if how.find('segm. masks') != -1: - self.annotSegmMasksCheckbox.setChecked(True) - if how.find('mother-bud lines') != -1: - self.drawMothBudLinesCheckbox.setChecked(True) - if how.find('nothing') != -1: - self.drawNothingCheckbox.setChecked(True) - - # Right - how = self.annotateRightHowCombobox.currentText() - if how.find('IDs') != -1: - self.annotIDsCheckboxRight.setChecked(True) - if how.find('cell cycle info') != -1: - self.annotCcaInfoCheckboxRight.setChecked(True) - if how.find('contours') != -1: - self.annotContourCheckboxRight.setChecked(True) - if how.find('segm. masks') != -1: - self.annotSegmMasksCheckboxRight.setChecked(True) - if how.find('mother-bud lines') != -1: - self.drawMothBudLinesCheckboxRight.setChecked(True) - if how.find('nothing') != -1: - self.drawNothingCheckboxRight.setChecked(True) - - def setStatusBarLabel(self, log=True): - self.statusbar.clearMessage() - posData = self.data[self.pos_i] - segmentedChannelname = posData.filename[len(posData.basename):] - segmFilename = os.path.basename(posData.segm_npz_path) - segmEndName = segmFilename[len(posData.basename):] - txt = ( - f'{posData.pos_foldername} || ' - f'Basename: {posData.basename} || ' - f'Segmented channel: {segmentedChannelname} || ' - f'Segmentation file name: {segmEndName}' - ) - mode = str(self.modeComboBox.currentText()) - if log: - self.logger.info(txt) - self.statusBarLabel.setText(txt) - - def autoRange(self): - if self.labelsGrad.showLabelsImgAction.isChecked(): - self.ax2.autoRange() - self.ax1.autoRange() - - def resetRange(self): - if self.ax1_viewRange is None: - return - xRange, yRange = self.ax1_viewRange - if self.labelsGrad.showLabelsImgAction.isChecked(): - self.ax2.vb.setRange(xRange=xRange, yRange=yRange) - self.ax1.vb.setRange(xRange=xRange, yRange=yRange) - self.ax1_viewRange = None - self.isRangeReset = True - - def setFramesSnapshotMode(self): - self.measurementsMenu.setDisabled(False) - self.setPermanentGreedyCmapPreferences() - if self.isSnapshot: - self.realTimeTrackingToggle.setDisabled(True) - self.realTimeTrackingToggle.label.setDisabled(True) - try: - self.drawIDsContComboBox.currentIndexChanged.disconnect() - except Exception as e: - pass - - self.imgGrad.rescaleAcrossTimeAction.setDisabled(True) - self.repeatTrackingAction.setDisabled(True) - self.manualTrackingAction.setDisabled(True) - self.logger.info('Setting GUI mode to "Snapshots"...') - self.modeComboBox.clear() - self.modeComboBox.addItems(['Snapshot']) - self.modeComboBox.setDisabled(True) - self.modeMenu.menuAction().setVisible(False) - self.drawIDsContComboBox.clear() - self.drawIDsContComboBox.addItems(self.drawIDsContComboBoxSegmItems) - self.drawIDsContComboBox.setCurrentIndex(1) - self.modeToolBar.setVisible(False) - self.skipToNewIdAction.setVisible(False) - self.skipToNewIdAction.setDisabled(True) - self.modeComboBox.setCurrentText('Snapshot') - self.annotateToolbar.setVisible(True) - self.labelsGrad.showNextFrameAction.setDisabled(True) - self.drawIDsContComboBox.currentIndexChanged.connect( - self.drawIDsContComboBox_cb - ) - self.showTreeInfoCheckbox.hide() - self.rightImageFramesScrollbar.setVisible(False) - self.rightImageFramesScrollbar.setDisabled(True) - if not self.isSegm3D: - self.manualBackgroundAction.setVisible(True) - self.manualBackgroundAction.setDisabled(False) - else: - self.manualBackgroundAction.setVisible(False) - self.manualBackgroundAction.setDisabled(True) - self.manualAnnotPastButton.setDisabled(True) - self.manualAnnotPastButton.action.setDisabled(True) - self.manualAnnotPastButton.setVisible(False) - self.manualAnnotPastButton.action.setVisible(False) - self.copyLostObjButton.setDisabled(True) - self.copyLostObjButton.action.setDisabled(True) - self.copyLostObjButton.setVisible(False) - self.copyLostObjButton.action.setVisible(False) - self.segForLostIDsAction.setVisible(False) - self.segForLostIDsAction.setDisabled(True) - self.delNewObjAction.setVisible(False) - self.delNewObjAction.setDisabled(True) - else: - self.imgGrad.rescaleAcrossTimeAction.setDisabled(False) - self.annotateToolbar.setVisible(False) - self.realTimeTrackingToggle.setDisabled(False) - self.repeatTrackingAction.setDisabled(False) - self.manualTrackingAction.setDisabled(False) - self.modeComboBox.setDisabled(False) - self.modeMenu.menuAction().setVisible(True) - self.skipToNewIdAction.setVisible(True) - self.skipToNewIdAction.setDisabled(False) - try: - self.modeComboBox.activated.disconnect() - self.modeComboBox.sigTextChanged.disconnect() - self.drawIDsContComboBox.currentIndexChanged.disconnect() - except Exception as e: - pass - # traceback.print_exc() - self.modeComboBox.clear() - self.modeComboBox.addItems(self.modeItems) - self.drawIDsContComboBox.clear() - self.drawIDsContComboBox.addItems(self.drawIDsContComboBoxSegmItems) - self.modeComboBox.sigTextChanged.connect(self.changeMode) - self.modeComboBox.activated.connect(self.clearComboBoxFocus) - self.drawIDsContComboBox.currentIndexChanged.connect( - self.drawIDsContComboBox_cb) - self.modeComboBox.setCurrentText('Viewer') - self.showTreeInfoCheckbox.show() - self.manualBackgroundAction.setVisible(False) - self.manualBackgroundAction.setDisabled(True) - self.labelsGrad.showNextFrameAction.setDisabled(False) - self.manualAnnotPastButton.setDisabled(False) - self.manualAnnotPastButton.action.setDisabled(False) - self.manualAnnotPastButton.setVisible(True) - self.manualAnnotPastButton.action.setVisible(True) - self.copyLostObjButton.setDisabled(False) - self.copyLostObjButton.action.setDisabled(False) - self.copyLostObjButton.setVisible(True) - self.copyLostObjButton.action.setVisible(True) - self.segForLostIDsAction.setVisible(True) - self.segForLostIDsAction.setDisabled(False) - self.delNewObjAction.setVisible(True) - self.delNewObjAction.setDisabled(False) - - for ch, overlayItems in self.overlayLayersItems.items(): - lutItem = overlayItems[1] - lutItem.rescaleAcrossTimeAction.setDisabled(self.isSnapshot) - - def checkIfAutoSegm(self): - """ - If there are any frame or position with empty segmentation mask - ask whether automatic segmentation should be turned ON - """ - if self.autoSegmAction.isChecked(): - return - if self.autoSegmDoNotAskAgain: - return - - ask = False - for posData in self.data: - if posData.SizeT > 1: - for lab in posData.segm_data: - if not np.any(lab): - ask = True - txt = 'frames' - break - else: - if not np.any(posData.segm_data): - ask = True - txt = 'positions' - break - - if not ask: - return - - questionTxt = html_utils.paragraph( - f'Some or all loaded {txt} contain empty segmentation masks.

' - 'Do you want to activate automatic segmentation* ' - f'when visiting these {txt}?

' - '* Automatic segmentation can always be turned ON/OFF from the menu
' - ' Edit --> Segmentation --> Enable automatic segmentation

' - f'NOTE: you can automatically segment all {txt} using the
' - ' segmentation module.' - ) - msg = widgets.myMessageBox(wrapText=False) - noButton, yesButton = msg.question( - self, 'Automatic segmentation?', questionTxt, - buttonsTexts=('No', 'Yes') - ) - if msg.clickedButton == yesButton: - self.autoSegmAction.setChecked(True) - else: - self.autoSegmDoNotAskAgain = True - self.autoSegmAction.setChecked(False) - - def init_segmInfo_df(self): - for posData in self.data: - if posData is None: - # posData is None when computing measurements with the utility - # and with timelapse data - continue - posData.init_segmInfo_df() - - def connectScrollbars(self): - self.t_label.show() - self.navigateScrollBar.show() - self.navigateScrollBar.setDisabled(False) - - if self.data[0].SizeZ > 1: - self.enableZstackWidgets(True) - self.zSliceScrollBar.setMaximum(self.data[0].SizeZ-1) - self.zSliceSpinbox.setMaximum(self.data[0].SizeZ) - self.SizeZlabel.setText(f'/{self.data[0].SizeZ}') - try: - self.zSliceScrollBar.actionTriggered.disconnect() - self.zSliceScrollBar.sliderReleased.disconnect() - self.zProjComboBox.currentTextChanged.disconnect() - self.zProjComboBox.activated.disconnect() - self.switchPlaneCombobox.sigPlaneChanged.disconnect() - self.zProjLockViewButton.toggled.disconnect() - except Exception as e: - pass - self.zSliceScrollBar.actionTriggered.connect( - self.zSliceScrollBarActionTriggered - ) - self.zSliceScrollBar.sliderReleased.connect( - self.zSliceScrollBarReleased - ) - self.zProjComboBox.currentTextChanged.connect(self.updateZproj) - self.zProjComboBox.activated.connect(self.clearComboBoxFocus) - self.switchPlaneCombobox.sigPlaneChanged.connect( - self.switchViewedPlane - ) - self.zProjLockViewButton.toggled.connect(self.zProjLockViewToggled) - - posData = self.data[self.pos_i] - if posData.SizeT == 1: - self.t_label.setText('Position n.') - self.navigateScrollBar.setMinimum(1) - self.navigateScrollBar.setMaximum(len(self.data)) - self.navigateScrollBar.setAbsoluteMaximum(len(self.data)) - self.navSpinBox.setMaximum(len(self.data)) - self.navigateScrollBar.connectEvents({ - 'sliderMoved': self.PosScrollBarMoved, - 'sliderReleased': self.PosScrollBarReleased, - 'actionTriggered': self.PosScrollBarAction - }) - else: - self.navigateScrollBar.setMinimum(1) - self.navigateScrollBar.setAbsoluteMaximum(posData.SizeT) - self.rightImageFramesScrollbar.setMinimum(1) - self.rightImageFramesScrollbar.setMaximum(posData.SizeT) - if posData.last_tracked_i is not None: - self.navigateScrollBar.setMaximum(posData.last_tracked_i+1) - self.navSpinBox.setMaximum(posData.last_tracked_i+1) - self.t_label.setText('Frame n.') - self.navigateScrollBar.connectEvents({ - 'sliderMoved': self.framesScrollBarMoved, - 'sliderReleased': self.framesScrollBarReleased, - 'actionTriggered': self.framesScrollBarActionTriggered - }) - self.rightImageFramesScrollbar.connectValueChanged( - self.rightImageFramesScrollbarValueChanged - ) - - def zSliceScrollBarActionTriggered(self, action): - singleMove = ( - action == SliderSingleStepAdd - or action == SliderSingleStepSub - or action == SliderPageStepAdd - or action == SliderPageStepSub - ) - if singleMove: - self.update_z_slice(self.zSliceScrollBar.sliderPosition()) - elif action == SliderMove: - if self.zSliceScrollBarStartedMoving and self.isSegm3D: - self.clearAx1Items(onlyHideText=True) - self.clearAx2Items(onlyHideText=True) - posData = self.data[self.pos_i] - idx = (posData.filename, posData.frame_i) - z = self.zSliceScrollBar.sliderPosition() - if self.switchPlaneCombobox.depthAxes() == 'z': - posData.segmInfo_df.at[idx, 'z_slice_used_gui'] = z - self.zSliceSpinbox.setValueNoEmit(z+1) - img = self._getImageupdateAllImages(None) - self.img1.setCurrentZsliceIndex(z) - self.img1.setImage( - img, next_frame_image=self.nextFrameImage(), - scrollbar_value=posData.frame_i+2 - ) - try: - self.setOverlayImages() - except Exception as err: - pass - - if self.labelsGrad.showLabelsImgAction.isChecked(): - self.img2.setImage(posData.lab, z=z, autoLevels=False) - self.updateViewerWindow() - self.setTextAnnotZsliceScrolling() - self.setGraphicalAnnotZsliceScrolling() - self.setOverlayLabelsItems() - self.drawPointsLayers(computePointsLayers=False) - self.zSliceScrollBarStartedMoving = False - self.highlightSearchedID(self.highlightedID, force=True) - - def zSliceScrollBarReleased(self): - self.clearTempBrushImage() - self.zSliceScrollBarStartedMoving = True - self.update_z_slice(self.zSliceScrollBar.sliderPosition()) - - def setSwitchViewedPlaneDisabled(self, disabled): - posData = self.data[self.pos_i] - if posData.SizeZ == 1: - return - - self.switchPlaneCombobox.setDisabled(disabled) - if disabled: - self.switchPlaneCombobox.setCurrentIndex(0) - - def _setViewRangeSwitchPlane(self, previousPlane): - posData = self.data[self.pos_i] - SizeZ = posData.SizeZ - SizeY, SizeX = self.img1.image.shape[:2] - currentPlane = self.switchPlaneCombobox.plane() - if previousPlane == 'xy': - if currentPlane == 'zy': - self.ax1.setRange(xRange=self.yRangePrev) - unusedRange = np.clip(self.xRangePrev, 0, SizeX) - elif currentPlane == 'zx': - self.ax1.setRange(xRange=self.xRangePrev) - unusedRange = np.clip(self.yRangePrev, 0, SizeY) - elif previousPlane == 'zy': - if currentPlane == 'xy': - self.ax1.setRange(yRange=self.xRangePrev) - unusedRange = np.clip(self.yRangePrev, 0, SizeZ) - elif currentPlane == 'zx': - self.ax1.setRange(yRange=self.yRangePrev) - unusedRange = np.clip(self.xRangePrev, 0, SizeY) - elif previousPlane == 'zx': - if currentPlane == 'xy': - self.ax1.setRange(xRange=self.xRangePrev) - unusedRange = np.clip(self.yRangePrev, 0, SizeZ) - elif currentPlane == 'zy': - self.ax1.setRange(yRange=self.yRangePrev) - unusedRange = np.clip(self.xRangePrev, 0, SizeX) - - sliceValue = round((unusedRange[0] + unusedRange[1])/2) - self.zSliceScrollBar.setSliderPosition(sliceValue) - self.update_z_slice(self.zSliceScrollBar.sliderPosition()) - - def setViewRangeSwitchPlane(self, previousPlane): - self.autoRange() - QTimer.singleShot( - 100, partial(self._setViewRangeSwitchPlane, previousPlane) - ) - - def switchViewedPlane(self, previousPlane, currentPlane): - posData = self.data[self.pos_i] - self.xRangePrev, self.yRangePrev = self.ax1.viewRange() - self.zSlicePrev = self.zSliceScrollBar.sliderPosition() - - self.zProjComboBox.setCurrentText('single z-slice') - depthAxes = self.switchPlaneCombobox.depthAxes() - self.onEscape() - self.initDelRoiLab() - if depthAxes != 'z': - # Disable projections on plane that is not xy - self.zProjComboBox.setCurrentText('single z-slice') - self.zProjComboBox.setDisabled(True) - - # Clear annotations - self.clearAllItems() - self.setHighlightID(False) - - # Disable annotations on a plane that is not yz - self.setDrawNothingAnnotations() - self.setDisabledAnnotCheckBoxesLeft(True) - self.setDisabledAnnotCheckBoxesRight(True) - self.setEnabledAnnotCheckBoxesLeftZdepthAxes() - self.overlayButtonPrevState = self.overlayButton.isChecked() - self.overlayButton.setChecked(False) - self.overlayButton.setDisabled(True) - else: - self.zProjComboBox.setDisabled(False) - self.restoreAnnotationsOptions() - self.setDisabledAnnotCheckBoxesLeft(False) - self.setDisabledAnnotCheckBoxesRight(False) - self.overlayButton.setDisabled(False) - if self.overlayButtonPrevState: - self.overlayButton.setChecked(self.overlayButtonPrevState) - self.updateZsliceScrollbar(posData.frame_i) - - SizeY, SizeX = posData.img_data[posData.frame_i].shape[-2:] - - if depthAxes != 'z' and self.isSnapshot: - # Disable editing when the plane is not xy - self.disableEditingViewPlaneNotXY() - elif self.isSnapshot: - # Re-enable editing in snapshot mode when the plane is xy - self.setEnabledSnapshotMode() - - if depthAxes == 'z': - maxSliceNum = posData.SizeZ - elif depthAxes == 'y': - maxSliceNum = SizeY - else: - maxSliceNum = SizeX - - maxSliceText = f'/{maxSliceNum}' - self.SizeZlabel.setText(maxSliceText) - self.zSliceCheckbox.setText(f'{depthAxes}-slice') - self.zSliceScrollBar.setMaximum(maxSliceNum-1) - self.zSliceSpinbox.setMaximum(maxSliceNum) - - self.initContoursImage() - self.updateAllImages() - QTimer.singleShot( - 200, partial(self.setViewRangeSwitchPlane, previousPlane) - ) - - def onZsliceSpinboxValueChange(self, value): - self.zSliceScrollBar.setSliderPosition(value-1) - - def update_z_slice(self, z): - posData = self.data[self.pos_i] - if self.switchPlaneCombobox.depthAxes() == 'z': - if self.zProjLockViewButton.isChecked(): - idx = [ - (posData.filename, frame_i) - for frame_i in range(posData.SizeT) - ] - else: - idx = [ - (posData.filename, frame_i) - for frame_i in range(posData.frame_i, posData.SizeT) - ] - posData.segmInfo_df.loc[idx, 'z_slice_used_gui'] = z - - self.updatePreprocessPreview() - self.updateCombineChannelsPreview() - self.highlightedID = self.getHighlightedID() - self.updateAllImages( - computePointsLayers=False, - computeContours=False, - updateLookuptable=True - ) - self.updateItemsMousePos() - if self.isSegm3D: - self.updateObjectCounts() - - def updateOverlayZslice(self, z): - self.setOverlayImages() - - def updateOverlayZproj(self, how): - if how.find('max') != -1 or how == 'same as above': - self.overlay_z_label.setDisabled(True) - self.zSliceOverlay_SB.setDisabled(True) - else: - self.overlay_z_label.setDisabled(False) - self.zSliceOverlay_SB.setDisabled(False) - self.setOverlayImages() - - def updateZproj(self, how): - for p, posData in enumerate(self.data[self.pos_i:]): - if self.zProjLockViewButton.isChecked(): - idx = [ - (posData.filename, frame_i) - for frame_i in range(posData.SizeT) - ] - else: - idx = [(posData.filename, posData.frame_i)] - posData.segmInfo_df.loc[idx, 'which_z_proj_gui'] = how - posData.segmInfo_df.to_csv(posData.segmInfo_df_csv_path) - - posData = self.data[self.pos_i] - if how == 'single z-slice': - self.zSliceScrollBar.setDisabled(False) - self.zSliceSpinbox.setDisabled(False) - self.zSliceCheckbox.setDisabled(False) - self.setZprojDisabled(False) - self.update_z_slice(self.zSliceScrollBar.sliderPosition()) - else: - self.zSliceScrollBar.setDisabled(True) - self.zSliceSpinbox.setDisabled(True) - self.zSliceCheckbox.setDisabled(True) - self.setZprojDisabled(self.isSegm3D) - self.updateAllImages() - - def setZprojDisabled(self, disabled, storePrevState=False): - self.combineChannelsAction.setDisabled(disabled) - for action in self.editToolBar.actions(): - button = self.editToolBar.widgetForAction(action) - if button == self.eraserButton: - continue - - if button in self.toolsActiveInProj3Dsegm: - continue - - try: - tooltip = button.toolTip() - prefix = 'WARNING: Disabled due to projection mode\n\n' - if disabled: - if not tooltip.startswith(prefix): - button.setToolTip(prefix + tooltip) - else: - if tooltip.startswith(prefix): - button.setToolTip(tooltip[len(prefix):]) - except: - pass - action.setDisabled(disabled) - try: - button.setChecked(False) - except Exception as err: - pass - - def clearAx2Items(self, onlyHideText=False): - self.ax2_binnedIDs_ScatterPlot.clear() - self.ax2_ripIDs_ScatterPlot.clear() - self.ax2_contoursImageItem.clear() - self.ax2_lostObjImageItem.clear() - self.ax2_lostTrackedObjImageItem.clear() - self.textAnnot[1].clear() - self.ax2_newMothBudLinesItem.setData([], []) - self.ax2_oldMothBudLinesItem.setData([], []) - self.ax2_lostObjScatterItem.setData([], []) - - def clearAx1Items(self, onlyHideText=False): - self.ax1_binnedIDs_ScatterPlot.clear() - self.ax1_ripIDs_ScatterPlot.clear() - self.labelsLayerImg1.clear() - self.labelsLayerRightImg.clear() - self.keepIDsTempLayerLeft.clear() - self.keepIDsTempLayerRight.clear() - self.highLightIDLayerImg1.clear() - self.highLightIDLayerRightImage.clear() - self.searchedIDitemLeft.clear() - self.searchedIDitemRight.clear() - self.ax1_contoursImageItem.clear() - self.ax1_lostObjImageItem.clear() - self.ax1_lostTrackedObjImageItem.clear() - self.textAnnot[0].clear() - self.ax1_newMothBudLinesItem.setData([], []) - self.ax1_oldMothBudLinesItem.setData([], []) - self.ax1_lostObjScatterItem.setData([], []) - self.ax1_lostTrackedScatterItem.setData([], []) - self.ccaFailedScatterItem.setData([], []) - self.yellowContourScatterItem.setData([], []) - - self.clearPointsLayers() - - self.clearOverlayLabelsItems() - self.clearManualBackgroundAnnotations() - self.clearCustomAnnot() - - def clearPointsLayers(self): - for toolbar in self.pointsLayersToolbars: - for action in toolbar.actions()[1:]: - try: - action.scatterItem.clear() - except Exception as e: - continue - - def clearOverlayLabelsItems(self): - for segmEndname, drawMode in self.drawModeOverlayLabelsChannels.items(): - items = self.overlayLabelsItems[segmEndname] - imageItem, contoursItem, gradItem = items - imageItem.clear() - contoursItem.clear() - - def clearAllItems(self): - self.clearAx1Items() - self.clearAx2Items() - - def clearCustomAnnot(self): - for button in self.customAnnotDict.keys(): - scatterPlotItem = self.customAnnotDict[button]['scatterPlotItem'] - scatterPlotItem.setData([], []) - - def clearCurvItems(self, removeItems=True): - try: - posData = self.data[self.pos_i] - curvItems = zip(posData.curvPlotItems, - posData.curvAnchorsItems, - posData.curvHoverItems) - for plotItem, curvAnchors, hoverItem in curvItems: - plotItem.setData([], []) - curvAnchors.setData([], []) - hoverItem.setData([], []) - if removeItems: - self.ax1.removeItem(plotItem) - self.ax1.removeItem(curvAnchors) - self.ax1.removeItem(hoverItem) - - if removeItems: - posData.curvPlotItems = [] - posData.curvAnchorsItems = [] - posData.curvHoverItems = [] - except AttributeError: - # traceback.print_exc() - pass - - # @exec_time - def curvToolSplineToObj(self, xxA=None, yyA=None, isRightClick=False): - posData = self.data[self.pos_i] - # Store undo state before modifying stuff - self.storeUndoRedoStates(False, storeOnlyZoom=True) - - if isRightClick: - xxS, yyS = self.curvPlotItem.getData() - if xxS is None: - self.setUncheckedAllButtons() - return - self.smoothAutoContWithSpline() - - xxS, yyS = self.getClosedSplineCoords() - - if self.autoIDcheckbox.isChecked(): - self.setBrushID() - curvToolID = posData.brushID - else: - curvToolID = self.editIDspinbox.value() - posData.brushID = curvToolID - - if curvToolID <= 0: - self.setBrushID() - curvToolID = posData.brushID - - lab2D = self.get_2Dlab(posData.lab).copy() - newIDMask = np.zeros(lab2D.shape, bool) - rr, cc = skimage.draw.polygon(yyS, xxS, shape=lab2D.shape) - newIDMask[rr, cc] = True - newIDMask[lab2D!=0] = False - lab2D[newIDMask] = curvToolID - self.set_2Dlab(lab2D) - self.currentLab2D = lab2D - - def addFluoChNameContextMenuAction(self, ch_name): - posData = self.data[self.pos_i] - allTexts = [ - action.text() for action in self.chNamesQActionGroup.actions() - ] - if ch_name not in allTexts: - action = QAction(self) - action.setText(ch_name) - action.setCheckable(True) - self.chNamesQActionGroup.addAction(action) - action.setChecked(True) - self.fluoDataChNameActions.append(action) - - def computeSegm(self, force=False): - posData = self.data[self.pos_i] - mode = str(self.modeComboBox.currentText()) - if mode == 'Viewer' or mode == 'Cell cycle analysis': - return - - if np.any(posData.lab) and not force: - # Do not compute segm if there is already a mask - return - - if not self.autoSegmAction.isChecked(): - return - - self.repeatSegm(model_name=self.segmModelName) - - def initImgCmap(self): - if not 'img_cmap' in self.df_settings.index: - self.df_settings.at['img_cmap', 'value'] = 'grey' - self.imgCmapName = self.df_settings.at['img_cmap', 'value'] - self.imgCmap = self.imgGrad.cmaps[self.imgCmapName] - if self.imgCmapName != 'grey': - # To ensure mapping to colors we need to normalize image - self.normalizeByMaxAction.setChecked(True) - - def initMetricsToSave(self, posData): - self._measurements_kernel._init_metrics_to_save(posData) - - def initMetrics(self): - self.logger.info('Initializing measurements...') - posData = self.data[self.pos_i] - self._measurements_kernel = cli.ComputeMeasurementsKernel( - self.logger, self.log_path, False - ) - self._measurements_kernel.init_args( - posData.chNames, posData.getSegmEndname() - ) - self._measurements_kernel._init_metrics(posData, self.isSegm3D) - - def initPosAttr(self): - exp_path = self.data[self.pos_i].exp_path - pos_foldernames = myutils.get_pos_foldernames(exp_path) - if len(pos_foldernames) == 1: - self.loadPosAction.setDisabled(True) - else: - self.loadPosAction.setDisabled(False) - - for p, posData in enumerate(self.data): - self.pos_i = p - posData.curvPlotItems = [] - posData.curvAnchorsItems = [] - posData.curvHoverItems = [] - posData.trackedLostIDs = set() - - posData.HDDmaxID = np.max(posData.segm_data) - - # Decision on what to do with changes to future frames attr - posData.doNotShowAgain_EditID = False - posData.UndoFutFrames_EditID = False - posData.applyFutFrames_EditID = False - - posData.doNotShowAgain_RipID = False - posData.UndoFutFrames_RipID = False - posData.applyFutFrames_RipID = False - - posData.doNotShowAgain_DelID = False - posData.UndoFutFrames_DelID = False - posData.applyFutFrames_DelID = False - - posData.doNotShowAgain_keepID = False - posData.UndoFutFrames_keepID = False - posData.applyFutFrames_keepID = False - - posData.doNotShowAgainAssignNewID = False - posData.UndoFutFramesAssignNewID = False - posData.applyFutFramesAssignNewID = False - - posData.includeUnvisitedInfo = { - 'Delete ID': False, 'Edit ID': False, 'Keep ID': False - } - - posData.loadTrackedLostCentroids() - posData.acdcTracker2stepsAnnotInfo = {} - - posData.doNotShowAgain_BinID = False - posData.UndoFutFrames_BinID = False - posData.applyFutFrames_BinID = False - - posData.disableAutoActivateViewerWindow = False - posData.new_IDs = [] - posData.lost_IDs = [] - posData.multiBud_mothIDs = [2] - posData.UndoRedoStates = [[] for _ in range(posData.SizeT)] - posData.UndoRedoCcaStates = [[] for _ in range(posData.SizeT)] - - posData.ol_data_dict = {} - posData.ol_data = None - - posData.ol_labels_data = None - - missing_frames = posData.SizeT - len(posData.allData_li) - if missing_frames > 0: - posData.allData_li.extend([None] * missing_frames) - for i in range(posData.SizeT): - if posData.allData_li[i] is None: - posData.allData_li[i] = ( - myutils.get_empty_stored_data_dict() - ) - - posData.lutLevels = {channel: {} for channel in self.ch_names} - - posData.ccaStatus_whenEmerged = {} - - posData.frame_i = 0 - posData.brushID = 0 - posData.binnedIDs = set() - posData.ripIDs = set() - posData.cca_df = None - if posData.last_tracked_i is not None: - last_tracked_num = posData.last_tracked_i+1 - # Load previous session data - # Keep track of which ROIs have already been added - # in previous frame - delROIshapes = [[] for _ in range(posData.SizeT)] - for i in range(last_tracked_num): - posData.frame_i = i - self.get_data(debug=True) - self.store_data( - enforce=True, autosave=False, store_cca_df_copy=True - ) - - # Ask whether to resume from last frame - if last_tracked_num>1: - msg = widgets.myMessageBox() - txt = html_utils.paragraph( - 'Cell-ACDC detected a previous session ended ' - f'at frame {last_tracked_num}.

' - f'Do you want to resume from frame ' - f'{last_tracked_num}?' - ) - noButton, yesButton = msg.question( - self, 'Start from last session?', txt, - buttonsTexts=(' No ', 'Yes') - ) - self.AutoPilotProfile.storeClickMessageBox( - 'Start from last session?', msg.clickedButton.text() - ) - if msg.clickedButton == yesButton: - posData.frame_i = posData.last_tracked_i - self.lastFrameRanOnFirstVisitTools = posData.frame_i - else: - posData.frame_i = 0 - - posData.img_data_min_max = ( - posData.img_data.min(), posData.img_data.max() - ) - - # Back to first position - self.pos_i = 0 - self.get_data(debug=False) - self.store_data(autosave=False) - # self.updateAllImages() - - # Link Y and X axis of both plots to scroll zoom and pan together - self.ax2.vb.setYLink(self.ax1.vb) - self.ax2.vb.setXLink(self.ax1.vb) - - self.setAllIDs() - - def navigateSpinboxValueChanged(self, value): - self.navigateScrollBar.setSliderPosition(value) - if self.isSnapshot: - self.PosScrollBarMoved(value) - else: - self.navigateScrollBarStartedMoving = True - self.framesScrollBarMoved(value) - - def navigateSpinboxEditingFinished(self): - if self.isSnapshot: - self.PosScrollBarReleased() - else: - self.framesScrollBarReleased() - - def PosScrollBarAction(self, action): - if action == SliderSingleStepAdd: - self.next_cb() - elif action == SliderSingleStepSub: - self.prev_cb() - elif action == SliderPageStepAdd: - self.PosScrollBarReleased() - elif action == SliderPageStepSub: - self.PosScrollBarReleased() - - def PosScrollBarMoved(self, pos_n): - if self.navigateScrollBarStartedMoving: - self.store_data() - - self.pos_i = pos_n-1 - self.updateFramePosLabel() - proceed_cca, never_visited = self.get_data() - self.updateAllImages() - self.setStatusBarLabel() - self.navigateScrollBarStartedMoving = False - - def PosScrollBarReleased(self): - self.navigateScrollBarStartedMoving = True - if self.pos_i == self.navigateScrollBar.sliderPosition()-1: - # Slider released without changing value --> do nothing - return - - self.pos_i = self.navigateScrollBar.sliderPosition()-1 - self.updateFramePosLabel() - self.updatePos() - - def resetNavigateFramesScrollbar(self, frame_i=None): - posData = self.data[self.pos_i] - if frame_i is None: - frame_i = posData.frame_i - - self.navigateScrollBar.setValueNoSignal(frame_i+1) - - def framesScrollBarActionTriggered(self, action): - if action == SliderSingleStepAdd: - # Clicking on dialogs triggered by next_cb might trigger - # pressEvent of navigateQScrollBar, avoid that - self.navigateScrollBar.disableCustomPressEvent() - self.next_cb() - QTimer.singleShot(100, self.navigateScrollBar.enableCustomPressEvent) - elif action == SliderSingleStepSub: - self.prev_cb() - elif action == SliderPageStepAdd: - self.framesScrollBarReleased(do_store_data=True) - elif action == SliderPageStepSub: - self.framesScrollBarReleased(do_store_data=True) - - def framesScrollBarMoved(self, frame_n): - if self.navigateScrollBarStartedMoving: - mode = str(self.modeComboBox.currentText()) - if mode != 'Viewer': - self.store_data(debug=False) - - posData = self.data[self.pos_i] - posData.frame_i = frame_n-1 - if posData.allData_li[posData.frame_i]['labels'] is None: - if posData.frame_i < len(posData.segm_data): - posData.lab = posData.segm_data[posData.frame_i] - else: - posData.lab = np.zeros_like(posData.segm_data[0]) - else: - posData.lab = posData.allData_li[posData.frame_i]['labels'] - - self.setImageImg1() - if self.overlayButton.isChecked(): - self.setOverlayImages() - - if self.navigateScrollBarStartedMoving: - self.clearAllItems() - - self.navSpinBox.setValueNoEmit(posData.frame_i+1) - if self.labelsGrad.showLabelsImgAction.isChecked(): - self.img2.setImage(posData.lab, z=self.z_lab(), autoLevels=False) - self.updateLookuptable() - self.updateFramePosLabel() - self.updateViewerWindow() - self.updateTimestampFrame() - self.updateHighlightedAxis() - self.navigateScrollBarStartedMoving = False - - def framesScrollBarReleased(self, do_store_data=False): - posData = self.data[self.pos_i] - if posData.frame_i == self.navigateScrollBar.sliderPosition()-1: - # Slider released without changing value --> do nothing - return - - mode = str(self.modeComboBox.currentText()) - if mode != 'Viewer' and do_store_data: - self.store_data(debug=False) - - self.navigateScrollBarStartedMoving = True - posData.frame_i = self.navigateScrollBar.sliderPosition()-1 - self.updateFramePosLabel() - proceed_cca, never_visited = self.get_data() - self.updateAllImages() - - def unstore_data(self): - posData = self.data[self.pos_i] - posData.allData_li[posData.frame_i] = myutils.get_empty_stored_data_dict() - - def getStoredSegmData(self): - posData = self.data[self.pos_i] - segm_data = [] - for data_frame_i in posData.allData_li: - lab = data_frame_i['labels'] - if lab is None: - break - segm_data.append(lab) - return np.array(segm_data) - - def trackNewIDtoNewIDsFutureFrame(self, newID, newIDmask): - posData = self.data[self.pos_i] - try: - nextLab = posData.allData_li[posData.frame_i+1]['labels'] - except IndexError: - # This is last frame --> there are no future frames - return - - if nextLab is None: - return - - newID_lab = np.zeros_like(posData.lab) - newID_lab[newIDmask] = newID - newLab_rp = [posData.rp[posData.IDs_idxs[newID]]] - newLab_IDs = [newID] - nextRp = posData.allData_li[posData.frame_i+1]['regionprops'] - - tracked_lab = self.trackFrame( - nextLab, nextRp, newID_lab, newLab_rp, newLab_IDs, - assign_unique_new_IDs=False - ) - trackedID = tracked_lab[newID_lab>0][0] - if trackedID == newID: - # Object does not exist in future frame --> do not track - return - - if posData.IDs_idxs.get(trackedID) is not None: - # Tracked ID already exists --> do not track to avoid merging - return - - return trackedID - - def store_manual_annot_data( - self, posData=None, data_frame_i=None - ): - if posData is None: - posData = self.data[self.pos_i] - - if data_frame_i is None: - data_frame_i = posData.allData_li[posData.frame_i] - - if not self.isSegm3D: - lab = [posData.lab] - else: - lab = posData.lab - - for z, lab_2D in enumerate(lab): - data_frame_i['manually_edited_lab']['lab'][z] = lab_2D - - # data_frame_i['manually_edited_lab']['zoom_slice'] = zoom_slice - - @exception_handler - def store_data( - self, pos_i=None, enforce=True, debug=False, mainThread=True, - autosave=True, store_cca_df_copy=False - ): - pos_i = self.pos_i if pos_i is None else pos_i - posData = self.data[pos_i] - if posData.frame_i < 0: - # In some cases we set frame_i = -1 and then call next_frame - # to visualize frame 0. In that case we don't store data - # for frame_i = -1 - return - - mode = str(self.modeComboBox.currentText()) - - if mode == 'Viewer' and not enforce: - return - - # if not mainThread: - # self.lin_tree_ask_changes() - - allData_li = posData.allData_li[posData.frame_i] - allData_li['regionprops'] = posData.rp.copy() - allData_li['labels'] = posData.lab.copy() - allData_li['IDs'] = posData.IDs.copy() - allData_li['manualBackgroundLab'] = ( - posData.manualBackgroundLab - ) - allData_li['IDs_idxs'] = ( - posData.IDs_idxs.copy() - ) - if self.manualAnnotPastButton.isChecked(): - self.store_manual_annot_data( - posData=posData, data_frame_i=allData_li - ) - - self.store_zslices_rp() - - # Store dynamic metadata - is_cell_dead_li = [False]*len(posData.rp) - is_cell_excluded_li = [False]*len(posData.rp) - IDs = [0]*len(posData.rp) - xx_centroid = [0]*len(posData.rp) - yy_centroid = [0]*len(posData.rp) - if self.isSegm3D: - zz_centroid = [0]*len(posData.rp) - areManuallyEdited = [0]*len(posData.rp) - editedNewIDs = [vals[2] for vals in posData.editID_info] - for i, obj in enumerate(posData.rp): - is_cell_dead_li[i] = obj.dead - is_cell_excluded_li[i] = obj.excluded - IDs[i] = obj.label - try: - xx_centroid[i] = int(self.getObjCentroid(obj.centroid)[1]) - yy_centroid[i] = int(self.getObjCentroid(obj.centroid)[0]) - except Exception as err: - printl(obj, obj.centroid, obj.label, posData.frame_i) - if self.isSegm3D: - zz_centroid[i] = int(obj.centroid[0]) - if obj.label in editedNewIDs: - areManuallyEdited[i] = 1 - - posData.STOREDmaxID = max(IDs, default=0) - - acdc_df = allData_li['acdc_df'] - if acdc_df is None: - allData_li['acdc_df'] = pd.DataFrame( - { - 'Cell_ID': IDs, - 'is_cell_dead': is_cell_dead_li, - 'is_cell_excluded': is_cell_excluded_li, - 'x_centroid': xx_centroid, - 'y_centroid': yy_centroid, - 'was_manually_edited': areManuallyEdited - } - ).set_index('Cell_ID') - - if self.isSegm3D: - allData_li['acdc_df']['z_centroid'] = ( - zz_centroid - ) - else: - # Filter or add IDs that were not stored yet - acdc_df = acdc_df.drop(columns=['time_seconds'], errors='ignore') - acdc_df = acdc_df.reindex(IDs, fill_value=0) - acdc_df['is_cell_dead'] = is_cell_dead_li - acdc_df['is_cell_excluded'] = is_cell_excluded_li - acdc_df['x_centroid'] = xx_centroid - acdc_df['y_centroid'] = yy_centroid - if self.isSegm3D: - acdc_df['z_centroid'] = zz_centroid - acdc_df['was_manually_edited'] = areManuallyEdited - allData_li['acdc_df'] = acdc_df - - if mainThread: - self.pointsLayerDataToDf(posData) - - self.store_cca_df( - pos_i=pos_i, mainThread=mainThread, autosave=autosave, - store_cca_df_copy=store_cca_df_copy - ) - - def nearest_point_2Dyx(self, points, all_others): - """ - Given 2D array of [y, x] coordinates points and all_others return the - [y, x] coordinates of the two points (one from points and one from all_others) - that have the absolute minimum distance - """ - # Compute 3D array where each ith row of each kth page is the element-wise - # difference between kth row of points and ith row in all_others array. - # (i.e. diff[k,i] = points[k] - all_others[i]) - diff = points[:, np.newaxis] - all_others - # Compute 2D array of distances where - # dist[i, j] = euclidean dist (points[i],all_others[j]) - dist = np.linalg.norm(diff, axis=2) - # Compute i, j indexes of the absolute minimum distance - i, j = np.unravel_index(dist.argmin(), dist.shape) - nearest_point = all_others[j] - point = points[i] - min_dist = np.min(dist) - return min_dist, nearest_point - - def isCurrentFrameCcaVisited(self): - posData = self.data[self.pos_i] - curr_df = posData.allData_li[posData.frame_i]['acdc_df'] - return curr_df is not None and 'cell_cycle_stage' in curr_df.columns - - def warnScellsGone(self, ScellsIDsGone, frame_i): - msg = widgets.myMessageBox() - text = html_utils.paragraph(f""" - In the next frame the followning cells' IDs in S/G2/M - (highlighted with a yellow contour) will disappear:

- {ScellsIDsGone}

- If the cell does not exist you might have deleted it at some point. - If that's the case, then try to go to some previous frames and reset - the cell cycle annotations there (button on the top toolbar).

- These cells are either buds or mother whose related IDs will not - disappear. This is likely due to cell division happening in - previous frame and the divided bud or mother will be - washed away.

- If you decide to continue these cells will be automatically - annotated as divided at frame number {frame_i}.

- Do you want to continue? - """) - _, yesButton, noButton = msg.warning( - self, 'Cells in "S/G2/M" disappeared!', text, - buttonsTexts=('Cancel', 'Yes', 'No') - ) - return msg.clickedButton == yesButton - - def checkScellsGone(self): - """Check if there are cells in S phase whose relative disappear in - current frame. Allow user to choose between automatically assign - division to these cells or cancel and not visit the frame. - - Returns - ------- - bool - False if there are no cells disappeared or the user decided - to accept automatic division. - """ - automaticallyDividedIDs = [] - - mode = str(self.modeComboBox.currentText()) - if mode.find('Cell cycle') == -1: - # No cell cycle analysis mode --> do nothing - return False, automaticallyDividedIDs - - posData = self.data[self.pos_i] - - if posData.allData_li[posData.frame_i]['labels'] is None: - # Frame never visited/checked in segm mode --> autoCca_df will raise - # a critical message - return False, automaticallyDividedIDs - - # Check if there are S cells that either only mother or only - # bud disappeared and automatically assign division to it - # or abort visiting this frame - prev_acdc_df = posData.allData_li[posData.frame_i-1]['acdc_df'] - prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - prev_cca_df = prev_acdc_df[self.cca_df_colnames].copy() - - ScellsIDsGone = [] - for ccSeries in prev_cca_df.itertuples(): - ID = ccSeries.Index - ccs = ccSeries.cell_cycle_stage - if ccs != 'S': - continue - - relID = ccSeries.relative_ID - if relID == -1: - continue - - # Check is relID is gone while ID stays - if relID not in posData.IDs and ID in posData.IDs: - ScellsIDsGone.append(relID) - - if not ScellsIDsGone: - # No cells in S that disappears --> do nothing - return False, automaticallyDividedIDs - - self.highlightNewIDs_ccaFailed(ScellsIDsGone, rp=prev_rp) - proceed = self.warnScellsGone(ScellsIDsGone, posData.frame_i) - self.clearLostObjContoursItems() - - if not proceed: - return True, automaticallyDividedIDs - - for IDgone in ScellsIDsGone: - relID = prev_cca_df.at[IDgone, 'relative_ID'] - self.annotateDisappearedBeforeDivision(relID, IDgone, prev_cca_df) - self.annotateDivision( - prev_cca_df, IDgone, relID, frame_i=posData.frame_i-1 - ) - self.annotateDivisionCurrentFrameRelativeIDgone(relID) - automaticallyDividedIDs.append(relID) - - self.store_cca_df(frame_i=posData.frame_i-1, cca_df=prev_cca_df) - - return False, automaticallyDividedIDs - - def annotateDivisionCurrentFrameRelativeIDgone(self, IDwhoseRelativeIsGone): - posData = self.data[self.pos_i] - if posData.cca_df is None: - return - ID = IDwhoseRelativeIsGone - posData.cca_df.at[ID, 'generation_num'] += 1 - posData.cca_df.at[ID, 'division_frame_i'] = posData.frame_i-1 - posData.cca_df.at[ID, 'relationship'] = 'mother' - - def annotateDisappearedBeforeDivision( - self, relID, IDgone, cca_df, frame_i=None - ): - posData = self.data[self.pos_i] - gen_num = cca_df.at[relID, 'generation_num'] - if frame_i is None: - frame_i = posData.frame_i - - for past_frame_i in range(frame_i-1, -1, -1): - past_cca_df = self.get_cca_df(frame_i=past_frame_i, return_df=True) - if past_cca_df is None: - return - - try: - if past_cca_df.at[relID, 'generation_num'] != gen_num: - # ID is a mother and the cell cycle is finished here - return - except Exception as err: - # Bud stops existing --> stop process - return - - past_cca_df.at[IDgone, 'disappears_before_division'] = 1 - past_cca_df.at[relID, 'daughter_disappears_before_division'] = 1 - - self.store_cca_df( - cca_df=past_cca_df, frame_i=past_frame_i, autosave=False - ) - - @exception_handler - def attempt_auto_cca(self, enforceAll=False): - mode = str(self.modeComboBox.currentText()) - posData = self.data[self.pos_i] - - if mode == 'Cell cycle analysis': - notEnoughG1Cells, proceed = self.autoCca_df( - enforceAll=enforceAll - ) - if not proceed: - return notEnoughG1Cells, proceed - - # mode = str(self.modeComboBox.currentText()) - if posData.cca_df is None: # ??? - notEnoughG1Cells = False - proceed = True - return notEnoughG1Cells, proceed - if posData.cca_df.isna().any(axis=None): - raise ValueError('Cell cycle analysis table contains NaNs') - # self.checkMultiBudMoth() - proceed = self.checkMothersExcludedOrDead() - return notEnoughG1Cells, proceed - - elif mode == 'Normal division: Lineage tree': - self.autoLinTree_df() - notEnoughG1Cells = False - proceed = True - return notEnoughG1Cells, proceed - - else: - notEnoughG1Cells = False - proceed = True - return notEnoughG1Cells, proceed - - - - def highlightIDs(self, IDs, pen): - pass - - def warnFrameNeverVisitedSegmMode(self): - msg = widgets.myMessageBox() - warn_cca = msg.critical( - self, 'Next frame NEVER visited', - 'Next frame was never visited in "Segmentation and Tracking"' - 'mode.\n You cannot perform cell cycle analysis on frames' - 'where segmentation and/or tracking errors were not' - 'checked/corrected.\n\n' - 'Switch to "Segmentation and Tracking" mode ' - 'and check/correct next frame,\n' - 'before attempting cell cycle analysis again', - ) - return False - - def checkCcaPastFramesNewIDs(self): - posData = self.data[self.pos_i] - if not posData.new_IDs: - return - - found_cca_df_IDs = [] - for frame_i in range(posData.frame_i-2, -1, -1): - acdc_df = posData.allData_li[frame_i]['acdc_df'] - cca_df_i = acdc_df[self.cca_df_colnames] - intersect_idx = cca_df_i.index.intersection(posData.new_IDs) - cca_df_i = cca_df_i.loc[intersect_idx] - if cca_df_i.empty: - continue - found_cca_df_IDs.append(cca_df_i) - - # Remove IDs found in past frames from new_IDs list - newIDs = np.array(posData.new_IDs, dtype=np.uint32) - mask_index = np.in1d(newIDs, cca_df_i.index) - posData.new_IDs = list(newIDs[~mask_index]) - if not posData.new_IDs: - return found_cca_df_IDs - return found_cca_df_IDs - - def initMissingFramesCca(self, last_cca_frame_i, current_frame_i): - self.logger.info( - 'Initialising cell cycle annotations of missing past frames...' - ) - posData = self.data[self.pos_i] - current_frame_i = posData.frame_i - - annotated_cca_dfs = [] - for frame_i in range(last_cca_frame_i+1): - acdc_df = posData.allData_li[frame_i]['acdc_df'] - if 'cell_cycle_stage' in acdc_df.columns: - continue - - acdc_df[self.cca_df_colnames] = '' - - annotated_cca_dfs = [ - posData.allData_li[i]['acdc_df'][self.cca_df_colnames] - for i in range(last_cca_frame_i+1) - ] - keys = range(last_cca_frame_i+1) - names = ['frame_i', 'Cell_ID'] - annotated_cca_df = ( - pd.concat(annotated_cca_dfs, keys=keys, names=names) - .reset_index() - .set_index(['Cell_ID', 'frame_i']) - .sort_index() - ) - - last_annotated_cca_df = annotated_cca_df.groupby(level=0).last() - cca_df_colnames = self.cca_df_colnames - pbar = tqdm(total=current_frame_i-last_cca_frame_i+1, ncols=100) - for frame_i in range(last_cca_frame_i, current_frame_i+1): - posData.frame_i = frame_i - self.get_data() - cca_df = self.getBaseCca_df() - - idx = last_annotated_cca_df.index.intersection(cca_df.index) - cca_df.loc[idx, cca_df_colnames] = last_annotated_cca_df.loc[idx] - - self.store_cca_df(cca_df=cca_df, frame_i=frame_i, autosave=False) - pbar.update() - pbar.close() - - posData.frame_i = current_frame_i - self.get_data() - - def initMissingFramesLinTree(self, current_frame_i): # done Need to add partially missing previous frames and loading - """ - When not starting from the first frame, automatically creates lineage tree dfs for all "skipped" frames and initializes the tree if not done so before. - - Parameters - ---------- - current_frame_i : int - The index of the current frame. - - Returns - ------- - None - - Notes - ----- - This method initializes the lineage tree annotations of missing past frames. If the lineage tree has not been initialized before, it creates a new lineage tree based on the labels of the first frame. It then iterates over the missing frames and updates the lineage tree with the labels and region properties of each frame. - """ - - self.logger.info( - 'Initialising lineage tree annotations of missing past frames...' - ) - - self.store_data(autosave=False) - self.get_data() - - posData = self.data[self.pos_i] - current_frame_i = posData.frame_i - - if not self.lineage_tree: # init lin tree if not done already - self.lineage_tree = normal_division_lineage_tree(gui=self) # here frame_i!=0 - - missing_frames = list(range(current_frame_i+1)) - present_frames = list(self.lineage_tree.frames_for_dfs) if self.lineage_tree else [] - present_frames = [] if not present_frames else present_frames # deal with None - missing_frames = [frame_i for frame_i in missing_frames if frame_i not in present_frames] - missing_frames.sort() - - for frame_i in missing_frames: - lab = posData.allData_li[frame_i]['labels'] - prev_lab = posData.allData_li[frame_i-1]['labels'] - rp = posData.allData_li[frame_i]['regionprops'] - prev_rp = posData.allData_li[frame_i-1]['regionprops'] - # i might need to change this if I need support for only partially missing frames... Although I probably never have to care about that though - self.lineage_tree.real_time(frame_i, lab, prev_lab, rp=rp, prev_rp=prev_rp) - - posData.frame_i = current_frame_i - self.store_data() - - def _getCcaCostMatrix( - self, numCellsG1, numNewCells, IDsCellsG1, newIDs_contours - ): - posData = self.data[self.pos_i] - dataDict = posData.allData_li[posData.frame_i] - dist_matrix_df = dataDict.get('obj_to_obj_dist_cost_matrix_df') - if dist_matrix_df is None: - cost = np.full((numCellsG1, numNewCells), np.inf) - for obj in posData.rp: - ID = obj.label - try: - i = IDsCellsG1.index(ID) - except ValueError: - continue - - cont = self.getObjContours(obj) - i = IDsCellsG1.index(ID) - - # Get distance from cell in G1 and all other new cells - for j, newID_cont in enumerate(newIDs_contours): - min_dist, nearest_xy = self.nearest_point_2Dyx( - cont, newID_cont - ) - cost[i, j] = min_dist - - return cost - - cost = dist_matrix_df.loc[IDsCellsG1, posData.new_IDs].values - - return cost - - def autoCca_df(self, enforceAll=False): - """ - Assign each bud to a mother with scipy linear sum assignment - (Hungarian or Munkres algorithm). First we build a cost matrix where - each (i, j) element is the minimum distance between bud i and mother j. - Then we minimize the cost of assigning each bud to a mother, and finally - we write the assignment info into cca_df - """ - proceed = True - notEnoughG1Cells = False - ScellsGone = False - - posData = self.data[self.pos_i] - - # Skip cca if not the right mode - mode = str(self.modeComboBox.currentText()) - if mode.find('Cell cycle') == -1: - return notEnoughG1Cells, proceed - - - # Make sure that this is a visited frame in segmentation tracking mode - if posData.allData_li[posData.frame_i]['labels'] is None: - proceed = self.warnFrameNeverVisitedSegmMode() - return notEnoughG1Cells, proceed - - # Determine if this is the last visited frame for repeating - # bud assignment on non manually correct (corrected_on_frame_i>0) buds. - # The idea is that the user could have assigned division on a cell - # by going previous and we want to check if this cell could be a - # "better" mother for those non manually corrected buds - curr_df = posData.allData_li[posData.frame_i]['acdc_df'] - isLastVisitedAgain = self.isLastVisitedAgainCca( - curr_df, enforceAll=enforceAll - ) - - frameAlreadyAnnotated = ( - posData.cca_df is not None - and not enforceAll - and not isLastVisitedAgain - ) - # Use stored cca_df and do not modify it with automatic stuff - if frameAlreadyAnnotated: - return notEnoughG1Cells, proceed - - # Keep only correctedAssignIDs if requested - # For the last visited frame we perform assignment again only on - # IDs where we didn't manually correct assignment - correctedAssignIDs = set() - if isLastVisitedAgain and not enforceAll: - try: - correctedAssignIDs = curr_df[ - curr_df['corrected_on_frame_i']>0 - ].index - except Exception as e: - correctedAssignIDs = [] - posData.new_IDs = [ - ID for ID in posData.new_IDs - if ID not in correctedAssignIDs - ] - - # Check if new IDs exist some time in the past - found_cca_df_IDs = self.checkCcaPastFramesNewIDs() - - # Check if there are some S cells that disappeared - abort, automaticallyDividedIDs = self.checkScellsGone() - if abort: - notEnoughG1Cells = False - proceed = False - return notEnoughG1Cells, proceed - - # Get previous dataframe - acdc_df = posData.allData_li[posData.frame_i-1]['acdc_df'] - prev_cca_df = acdc_df[self.cca_df_colnames].copy() - - if posData.cca_df is None: - posData.cca_df = prev_cca_df.copy() - else: - posData.cca_df = curr_df[self.cca_df_colnames].copy() - - # concatenate new IDs found in past frames (before frame_i-1) - if found_cca_df_IDs is not None: - cca_df = pd.concat([posData.cca_df, *found_cca_df_IDs]) - unique_idx = ~cca_df.index.duplicated(keep='first') - posData.cca_df = cca_df[unique_idx] - - # If there are no new IDs we are done - if not posData.new_IDs: - proceed = True - self.store_cca_df() - return notEnoughG1Cells, proceed - - # Get cells in G1 (exclude dead) and check if there are enough cells in G1 - try: - prev_df_G1 = prev_cca_df[prev_cca_df['cell_cycle_stage']=='G1'] - prev_df_G1 = prev_df_G1[~acdc_df.loc[prev_df_G1.index]['is_cell_dead']] - IDsCellsG1 = set(prev_df_G1.index) - except Exception as err: - IDsCellsG1 = set() - - if isLastVisitedAgain or enforceAll: - # If we are repeating auto cca for last visited frame - # then we also add the cells in G1 that appears in current frame - # and we remove the ones that are already in S in current frame - # if they were manually corrected (i.e., they cannot be mother). - # Note that potential mother cells must be either appearing in - # current frame or in G1 also at previous frame. - # If we would consider cells that are in G1 at current frame - # but not in previous frame, assigning a bud to it would - # result in no G1 at all for the mother cell. - df_G1 = posData.cca_df[posData.cca_df['cell_cycle_stage']=='G1'] - current_G1_IDs = df_G1.index - new_cell_G1 = [ - ID for ID in current_G1_IDs if ID not in prev_cca_df.index - ] - IDsCellsG1.update(new_cell_G1) - cells_S_current = posData.cca_df[ - (posData.cca_df['cell_cycle_stage']=='S') - & (posData.cca_df['corrected_on_frame_i']==posData.frame_i) - ].index - IDsCellsG1 = IDsCellsG1 - set(cells_S_current) - - # Remove cells that disappeared - IDsCellsG1 = [ID for ID in IDsCellsG1 if ID in posData.IDs] - - numCellsG1 = len(IDsCellsG1) - numNewCells = len(posData.new_IDs) - if numCellsG1 < numNewCells: - notEnoughG1Cells, proceed = self.handleNoCellsInG1( - numCellsG1, numNewCells - ) - return notEnoughG1Cells, proceed - - # Compute new IDs contours - newIDs_contours = [] - for obj in posData.rp: - ID = obj.label - if ID in posData.new_IDs: - cont = self.getObjContours(obj) - newIDs_contours.append(cont) - - # Compute cost matrix - cost = self._getCcaCostMatrix( - numCellsG1, numNewCells, IDsCellsG1, newIDs_contours - ) - - # Run hungarian (munkres) assignment algorithm - row_idx, col_idx = scipy.optimize.linear_sum_assignment(cost) - - # New mother cells - newMothIDs = {IDsCellsG1[i] for i in row_idx} - - # Assign buds to mothers - for i, j in zip(row_idx, col_idx): - mothID = IDsCellsG1[i] - budID = posData.new_IDs[j] - - relID = None - # If we are repeating assignment for the bud then we also have to - # correct the possibily wrong mother --> it goes back to - # G1 if it's not a mother that we assign now - if budID in posData.cca_df.index: - relID = posData.cca_df.at[budID, 'relative_ID'] - if relID in prev_cca_df.index and relID not in newMothIDs: - posData.cca_df.loc[relID] = prev_cca_df.loc[relID] - - posData.cca_df.at[mothID, 'relative_ID'] = budID - posData.cca_df.at[mothID, 'cell_cycle_stage'] = 'S' - - bud_cca_dict = base_cca_dict.copy() - bud_cca_dict['cell_cycle_stage'] = 'S' - bud_cca_dict['generation_num'] = 0 - bud_cca_dict['relative_ID'] = mothID - bud_cca_dict['relationship'] = 'bud' - bud_cca_dict['emerg_frame_i'] = posData.frame_i - bud_cca_dict['is_history_known'] = True - bud_cca_dict['corrected_on_frame_i'] = -1 - posData.cca_df.loc[budID] = pd.Series(bud_cca_dict) - - # Keep only existing IDs - posData.cca_df = posData.cca_df.loc[posData.IDs] - - self.store_cca_df() - proceed = True - return notEnoughG1Cells, proceed - - def autoLinTree_df(self, enforceAll=False): - """Automatically generates a lineage tree dataframe. - - This method generates a lineage tree dataframe based on the current mode and data. - It checks if the mode is set to 'Normal division: Lineage tree' and if the current frame - is not already processed. If the conditions are met, it retrieves the necessary data - from the current position data and previous position data, and passes it to the - `real_time` method of the `lineage_tree` object. Finally, it converts the lineage tree - to an ACDC dataframe and adds the current frame to the set of frames that have been - processed. - - Parameters - ---------- - enforceAll : bool, optional - If True, enforces processing of all frames, even if they have been processed before. - If False, only processes frames that have not been processed before. Default is False. - - Returns - ------- - bool - True if there are not enough G1 cells for lineage tree generation, False otherwise. - bool - True if the lineage tree generation should proceed, False otherwise. - """ - proceed = True - notEnoughG1Cells = False - mode = str(self.modeComboBox.currentText()) - - # Skip if not the right mode - if mode != 'Normal division: Lineage tree': - return notEnoughG1Cells, proceed - - posData = self.data[self.pos_i] - frame_i = posData.frame_i - - if frame_i in self.lineage_tree.frames_for_dfs: - return notEnoughG1Cells, proceed - - # Make sure that this is a visited frame in segmentation tracking mode - if posData.allData_li[frame_i]['labels'] is None: # may need to change this - proceed = self.warnFrameNeverVisitedSegmMode() - return notEnoughG1Cells, proceed - - self.store_data(autosave=False) - self.get_data() - lab = posData.lab - prev_lab = posData.allData_li[frame_i-1]['labels'] - rp = posData.rp - prev_rp = posData.allData_li[frame_i-1]['regionprops'] - - self.lineage_tree.real_time(frame_i, lab, prev_lab, rp=rp, prev_rp=prev_rp) - self.store_data() - - def getObjBbox(self, obj_bbox): - if self.isSegm3D and len(obj_bbox)==6: - obj_bbox = (obj_bbox[1], obj_bbox[2], obj_bbox[4], obj_bbox[5]) - return obj_bbox - else: - return obj_bbox - - def z_lab(self, checkIfProj=False): - if checkIfProj and self.zProjComboBox.currentText() != 'single z-slice': - return - - if not self.isSegm3D: - return - - posData = self.data[self.pos_i] - - idx = self.zSliceScrollBar.sliderPosition() - - # ensure idx doesnt exceed the number of z-slices of the position - idx_z = min(idx, posData.SizeZ-1) - - if not self.switchPlaneCombobox.isEnabled(): - return idx_z - - depthAxes = self.switchPlaneCombobox.depthAxes() - if depthAxes == 'z': - return idx_z - elif depthAxes == 'y': - idx_y = min(idx, posData.SizeY-1) - return (slice(None), idx_y) - else: - idx_x = min(idx, posData.SizeX-1) - return (slice(None), slice(None), idx_x) - - def get_2Dlab(self, lab, force_z=True): - if self.isSegm3D: - if force_z: - return lab[self.z_lab()] - zProjHow = self.zProjComboBox.currentText() - isZslice = zProjHow == 'single z-slice' - if isZslice: - return lab[self.z_lab()] - else: - return lab.max(axis=0) - else: - return lab - - # @exec_time - def applyEraserMask(self, mask): - posData = self.data[self.pos_i] - if self.isSegm3D: - zProjHow = self.zProjComboBox.currentText() - isZslice = zProjHow == 'single z-slice' - if isZslice: - posData.lab[self.z_lab(), mask] = 0 - else: - posData.lab[:, mask] = 0 - else: - posData.lab[mask] = 0 - - def changeBrushID(self): - """Function called when pressing or releasing shift - """ - if not self.isSegm3D: - # Changing brush ID with shift is only for 3D segm - return - - if not self.brushButton.isChecked(): - # Brush if not active - return - - if not self.isMouseDragImg2 and not self.isMouseDragImg1: - # Mouse is not brushing at the moment - return - - posData = self.data[self.pos_i] - forceNewObj = not self.isNewID - - if forceNewObj: - # Shift is down --> force new object with brush - # e.g., 24 --> 28: - # 24 is hovering ID that we store as self.prevBrushID - # 24 object becomes 28 that is the new posData.brushID - self.isNewID = True - self.changedID = posData.brushID - self.restoreBrushID = posData.brushID - # Set a new ID - self.setBrushID() - else: - # Shift released or hovering on ID in z+-1 - # --> restore brush ID from before shift was pressed or from - # when we started brushing from outside an object - # but we hovered on ID in z+-1 while dragging. - # We change the entire 28 object to 24 so before changing the - # brush ID back to 24 we builg the mask with 28 to change it to 24 - self.isNewID = False - self.changedID = posData.brushID - # Restore ID - posData.brushID = self.restoreBrushID - - brushID = posData.brushID - brushIDmask = self.get_2Dlab(posData.lab) == self.changedID - self.applyBrushMask(brushIDmask, brushID) - if self.isMouseDragImg1: - self.brushColor = self.lut[posData.brushID]/255 - self.setTempImg1Brush(True, brushIDmask, posData.brushID) - - def applyBrushMask(self, mask, ID, toLocalSlice=None): - posData = self.data[self.pos_i] - if self.isSegm3D: - zProjHow = self.zProjComboBox.currentText() - isZslice = zProjHow == 'single z-slice' - if isZslice: - if toLocalSlice is not None: - toLocalSlice = (self.z_lab(), *toLocalSlice) - posData.lab[toLocalSlice][mask] = ID - else: - posData.lab[self.z_lab()][mask] = ID - else: - if toLocalSlice is not None: - for z in range(len(posData.lab)): - _slice = (z, *toLocalSlice) - posData.lab[_slice][mask] = ID - else: - posData.lab[:, mask] = ID - else: - if toLocalSlice is not None: - posData.lab[toLocalSlice][mask] = ID - else: - posData.lab[mask] = ID - - def assignNewIDfromClickedID( - self, clickedID: int, event: QGraphicsSceneMouseEvent - ): - posData = self.data[self.pos_i] - x, y = event.pos().x(), event.pos().y() - newID = self.setBrushID(return_val=True) - mapper = [(clickedID, newID)] - self.applyEditID(clickedID, posData.IDs.copy(), mapper, x, y) - - def get_2Drp(self, lab=None): - if self.isSegm3D: - if lab is None: - # self.currentLab2D is defined at self.setImageImg2() - lab = self.currentLab2D - lab = self.get_2Dlab(lab) - rp = skimage.measure.regionprops(lab) - return rp - else: - return self.data[self.pos_i].rp - - def set_2Dlab(self, lab2D, lab3D=None): - posData = self.data[self.pos_i] - - if lab3D is None: - lab3D = posData.lab - - if self.isSegm3D: - zProjHow = self.zProjComboBox.currentText() - isZslice = zProjHow == 'single z-slice' - if isZslice: - lab3D[self.z_lab()] = lab2D - else: - lab3D[:] = lab2D - else: - if lab3D.shape == lab2D.shape: - lab3D[...] = lab2D - else: - posData.lab = lab2D - - def get_labels( - self, - from_store=False, - frame_i=None, - return_existing=False, - return_copy=True - ): - """Get the labels array. - - Parameters - ---------- - from_store : bool, optional - If True load the labels array from the stored posData.allData_li, - i.e., from RAM. Default is False - frame_i : int, optional - If None, use the current frame index. Default is None - return_existing : bool, optional - If True, the second return element will be a boolean that - is True if the labels array was found stored in `posData.allData_li`. - Default is False - return_copy : bool, optional - If True returns a copy of the labels array - - Returns - ------- - numpy.ndarray or tuple of (numpy.ndarray, bool) - The first element is the labels array requested. If `return_existing` - is True then this method also returns a second boolean element that - is True if the labels array was found in in `posData.allData_li`. - - Note - ---- - - If `from_store` is True then this method will try to get the stored - labels array. If any error occurs then the returned labels are the - saved ones in the segmentation file (i.e., from hard drive). - - """ - posData = self.data[self.pos_i] - if frame_i is None: - frame_i = posData.frame_i - - existing = True - if from_store: - try: - labels = posData.allData_li[frame_i]['labels'] - if labels is None: - from_store = False - except Exception as err: - from_store = False - - if not from_store: - try: - labels = posData.segm_data[frame_i] - except IndexError: - existing = False - # Visting a frame that was not segmented --> empty masks - if self.isSegm3D: - shape = (posData.SizeZ, posData.SizeY, posData.SizeX) - else: - shape = (posData.SizeY, posData.SizeX) - labels = np.zeros(shape, dtype=np.uint32) - return_copy = False - - if return_copy: - labels = labels.copy() - - if return_existing: - return labels, existing - else: - return labels - - def addYXcentroidToDf(self, df): - posData = self.data[self.pos_i] - for obj in posData.rp: - y_centroid = int(self.getObjCentroid(obj.centroid)[0]) - x_centroid = int(self.getObjCentroid(obj.centroid)[1]) - df.at[obj.label, 'y_centroid'] = y_centroid - df.at[obj.label, 'x_centroid'] = x_centroid - return df - - def _get_editID_info(self, df): - if 'was_manually_edited' not in df.columns: - return [] - - if 'y_centroid' not in df.columns or 'x_centroid' not in df.columns: - df = self.addYXcentroidToDf(df) - - manually_edited_df = df[df['was_manually_edited'] > 0] - editID_info = [ - (row.y_centroid, row.x_centroid, row.Index) - for row in manually_edited_df.itertuples() - ] - return editID_info - - def apply_manual_edits_to_lab_if_needed(self, lab): - posData = self.data[self.pos_i] - data_frame_i = posData.allData_li[posData.frame_i] - edited_lab_dict = data_frame_i['manually_edited_lab']['lab'] - if not edited_lab_dict: - return lab - - # zoom_slice = data_frame_i['manually_edited_lab']['zoom_slice'] - for z, lab_edited in edited_lab_dict.items(): - if not self.isSegm3D: - # lab[zoom_slice] = lab_edited - lab = lab_edited - break - - lab[z] = lab_edited - - # lab[z, zoom_slice[0], zoom_slice[1]] = zoom_lab - - return lab - - def _get_data_unvisited(self, posData, debug=False, lin_tree_init=True,): - posData.editID_info = [] - proceed_cca = True - never_visited = True - if str(self.modeComboBox.currentText()) == 'Cell cycle analysis': - # Warn that we are visiting a frame that was never segm-checked - # on cell cycle analysis mode - msg = widgets.myMessageBox() - txt = html_utils.paragraph( - 'Segmentation and Tracking was never checked from ' - f'frame {posData.frame_i+1} onwards.

' - 'To ensure correct cell cell cycle analysis you have to ' - 'first visit the frames after ' - f'{posData.frame_i+1} with "Segmentation and Tracking" mode.' - ) - warn_cca = msg.critical( - self, 'Never checked segmentation on requested frame', txt - ) - proceed_cca = False - return proceed_cca, never_visited - - elif str(self.modeComboBox.currentText()) == 'Normal division: Lineage tree': - # Warn that we are visiting a frame that was never segm-checked - # on cell cycle analysis mode - msg = widgets.myMessageBox() - txt = html_utils.paragraph( - 'Segmentation and Tracking was never checked from ' - f'frame {posData.frame_i+1} onwards.

' - 'To ensure correct lineage tree analysis you have to ' - 'first visit the frames after ' - f'{posData.frame_i+1} with "Segmentation and Tracking" mode.' - ) - warn_cca = msg.critical(#??? - self, 'Never checked segmentation on requested frame', txt - ) - proceed_cca = False - return proceed_cca, never_visited - - # Requested frame was never visited before. Load from HDD - labels = self.get_labels() - posData.lab = self.apply_manual_edits_to_lab_if_needed( - labels - ) - posData.rp = skimage.measure.regionprops(posData.lab) - self.setManualBackgroundLab() - - if posData.acdc_df is not None: - frames = posData.acdc_df.index.get_level_values(0) - if posData.frame_i in frames: - # Since there was already segmentation metadata from - # previous closed session add it to current metadata - df = posData.acdc_df.loc[posData.frame_i].copy() - binnedIDs_df = df[df['is_cell_excluded']>0] - binnedIDs = set(binnedIDs_df.index).union(posData.binnedIDs) - posData.binnedIDs = binnedIDs - ripIDs_df = df[df['is_cell_dead']>0] - ripIDs = set(ripIDs_df.index).union(posData.ripIDs) - posData.ripIDs = ripIDs - posData.editID_info.extend(self._get_editID_info(df)) - # Load cca df into current metadata - if 'cell_cycle_stage' in df.columns: - cca_cols = df.columns.intersection(self.cca_df_colnames) - cca_df = df[cca_cols].dropna() - if cca_df.empty: - df = df.drop( - columns=self.cca_df_colnames, errors='ignore' - ) - else: - df = df.loc[cca_df.index] - cols = self.cca_df_int_cols - df[cols] = df[cols].astype('Int64') - - i = posData.frame_i - posData.allData_li[i]['acdc_df'] = df.copy() - - if self.lineage_tree is None and lin_tree_init: - self.initLinTree() - - self.get_cca_df() - - return proceed_cca, never_visited - - def _get_data_visited(self, posData, debug=False, lin_tree_init=True,): - # Requested frame was already visited. Load from RAM. - never_visited = False - posData.lab = self.get_labels(from_store=True) - posData.rp = skimage.measure.regionprops(posData.lab) - df = posData.allData_li[posData.frame_i]['acdc_df'] - if df is None: - posData.binnedIDs = set() - posData.ripIDs = set() - posData.editID_info = [] - else: - try: - binnedIDs_df = df[df['is_cell_excluded']>0] - except Exception as err: - df = myutils.fix_acdc_df_dtypes(df) - binnedIDs_df = df[df['is_cell_excluded']>0] - posData.binnedIDs = set(binnedIDs_df.index) - ripIDs_df = df[df['is_cell_dead']>0] - posData.ripIDs = set(ripIDs_df.index) - posData.editID_info = self._get_editID_info(df) - self.setManualBackgroundLab(load_from_store=True, debug=debug) - if self.lineage_tree is None and lin_tree_init: - self.initLinTree() - - self.get_cca_df(debug=debug) - - return True, never_visited - - @get_data_exception_handler - def get_data(self, debug=False, lin_tree_init=True): - posData = self.data[self.pos_i] - proceed_cca = True - never_visited = False - if posData.frame_i > 2: - # Remove undo states from 4 frames back to avoid memory issues - posData.UndoRedoStates[posData.frame_i-4] = [] - # Check if current frame contains undo states (not empty list) - if posData.UndoRedoStates[posData.frame_i]: - self.undoAction.setDisabled(False) - elif posData.UndoRedoCcaStates[posData.frame_i]: - self.undoAction.setDisabled(False) - else: - self.undoAction.setDisabled(True) - self.UndoCount = 0 - # If stored labels is None then it is the first time we visit this frame - if posData.allData_li[posData.frame_i]['labels'] is None: - proceed_cca, never_visited = self._get_data_unvisited( - posData, lin_tree_init=lin_tree_init, - ) - if not proceed_cca: - return proceed_cca, never_visited - else: - proceed_cca, never_visited = self._get_data_visited( - posData, lin_tree_init=lin_tree_init, debug=debug - ) - - self.update_rp_metadata(draw=False) - posData.IDs = [obj.label for obj in posData.rp] - posData.IDs_idxs = { - ID:i for ID, i in zip(posData.IDs, range(len(posData.IDs))) - } - self.get_zslices_rp() - self.pointsLayerDfsToData(posData) - return proceed_cca, never_visited - - def addIDBaseCca_df(self, posData, ID): - if ID <= 0: - # When calling update_cca_df_deletedIDs we add relative IDs - # but they could be -1 for cells in G1 - return - - _zip = zip( - self.cca_df_colnames, - self.cca_df_default_values, - ) - if posData.cca_df.empty: - posData.cca_df = pd.DataFrame( - {col: val for col, val in _zip}, - index=[ID] - ) - else: - for col, val in _zip: - posData.cca_df.at[ID, col] = val - self.store_cca_df() - - def getBaseCca_df(self, with_tree_cols=False): - posData = self.data[self.pos_i] - IDs = [obj.label for obj in posData.rp] - cca_df = core.getBaseCca_df(IDs, with_tree_cols=with_tree_cols) - return cca_df - - def get_last_tracked_i(self): - posData = self.data[self.pos_i] - last_tracked_i = 0 - for frame_i, data_dict in enumerate(posData.allData_li): - lab = data_dict['labels'] - if lab is None and frame_i == 0: - last_tracked_i = 0 - break - elif lab is None: - last_tracked_i = frame_i-1 - break - else: - last_tracked_i = posData.segmSizeT-1 - return last_tracked_i - - def get_last_cca_frame_i(self): - posData = self.data[self.pos_i] - - i = 0 - # Determine last annotated frame index - for i, dict_frame_i in enumerate(posData.allData_li): - df = dict_frame_i['acdc_df'] - if df is None: - break - elif 'cell_cycle_stage' not in df.columns: - break - - last_cca_frame_i = i if i==0 or i+1==len(posData.allData_li) else i-1 - - return last_cca_frame_i - - def initSegmTrackMode(self): - posData = self.data[self.pos_i] - last_tracked_i = self.get_last_tracked_i() - - if posData.frame_i > last_tracked_i: - # Prompt user to go to last tracked frame - msg = widgets.myMessageBox() - txt = html_utils.paragraph( - f'The last visited frame in "Segmentation and Tracking mode" ' - f'is frame {last_tracked_i+1}.\n\n' - f'We recommend to resume from that frame.

' - 'How do you want to proceed?' - ) - goToButton, stayButton = msg.warning( - self, 'Go to last visited frame?', txt, - buttonsTexts=( - f'Resume from frame {last_tracked_i+1} (RECOMMENDED)', - f'Stay on current frame {posData.frame_i+1}' - ) - ) - if msg.clickedButton == goToButton: - posData.frame_i = last_tracked_i - self.lastFrameRanOnFirstVisitTools = posData.frame_i - self.get_data() - self.updateAllImages() - self.updateScrollbars() - else: - last_tracked_i = posData.frame_i - current_frame_i = posData.frame_i - self.lastFrameRanOnFirstVisitTools = posData.frame_i - self.logger.info( - f'Storing data up until frame n. {current_frame_i+1}...' - ) - pbar = tqdm(total=current_frame_i+1, ncols=100) - for i in range(current_frame_i): - posData.frame_i = i - self.get_data() - self.store_data(autosave=i==current_frame_i-1) - pbar.update() - pbar.close() - - posData.frame_i = current_frame_i - self.get_data() - - self.highlightLostNew() - self.updateLastCheckedFrameWidgets(last_tracked_i) - - self.isFirstTimeOnNextFrame() - self.initRealTimeTracker() - - def updateLastCheckedFrameWidgets(self, last_tracked_i): - self.navigateScrollBar.setMaximum(last_tracked_i+1) - self.navSpinBox.setMaximum(last_tracked_i+1) - self.lastTrackedFrameLabel.setText( - f'Last checked frame n. = {last_tracked_i+1}' - ) - - @exception_handler - def initCca(self): - posData = self.data[self.pos_i] - last_tracked_i = self.get_last_tracked_i() - defaultMode = 'Viewer' - if last_tracked_i == 0: - txt = html_utils.paragraph( - 'On this dataset either you never checked that the segmentation ' - 'and tracking are correct or you did not save yet.

' - 'If you already visited some frames with "Segmentation and Tracking" ' - 'mode save data before switching to "Cell cycle analysis mode".

' - 'Otherwise you first have to check (and eventually correct) some frames ' - 'in "Segmentation and Tracking" mode before proceeding ' - 'with cell cycle analysis.') - msg = widgets.myMessageBox() - msg.critical( - self, 'Tracking was never checked', txt - ) - self.modeComboBox.setCurrentText(defaultMode) - return - - proceed = True - - last_cca_frame_i = self.get_last_cca_frame_i() - if last_cca_frame_i == 0: - # Remove undoable actions from segmentation mode - posData.UndoRedoStates[0] = [] - self.undoAction.setEnabled(False) - self.redoAction.setEnabled(False) - - if posData.frame_i > last_cca_frame_i: - # Prompt user to go to last annotated frame - msg = widgets.myMessageBox() - txt = html_utils.paragraph(f""" - The last annotated frame is frame {last_cca_frame_i+1}.

- Do you want to restart cell cycle analysis from frame - {last_cca_frame_i+1}?
- """) - _, goToFrameButton, stayButton = msg.warning( - self, 'Go to last annotated frame?', txt, - buttonsTexts=( - 'Cancel', f'Yes, go to frame {last_cca_frame_i+1}', - 'No, stay on current frame') - ) - if goToFrameButton == msg.clickedButton: - self.addMissingIDs_cca_df(posData) - self.store_cca_df() - msg = 'Looking good!' - self.last_cca_frame_i = last_cca_frame_i - posData.frame_i = last_cca_frame_i - self.titleLabel.setText(msg, color=self.titleColor) - self.get_data() - self.addMissingIDs_cca_df(posData) - self.store_cca_df() - self.updateAllImages() - self.updateScrollbars() - elif stayButton == msg.clickedButton: - self.addMissingIDs_cca_df(posData) - self.store_cca_df() - self.initMissingFramesCca(last_cca_frame_i, posData.frame_i) - last_cca_frame_i = posData.frame_i - msg = 'Cell cycle analysis initialised!' - self.titleLabel.setText(msg, color='g') - elif msg.cancel: - msg = 'Cell cycle analysis aborted.' - self.logger.info(msg) - self.titleLabel.setText(msg, color=self.titleColor) - self.modeComboBox.setCurrentText(defaultMode) - proceed = False - return - elif posData.frame_i < last_cca_frame_i: - # Prompt user to go to last annotated frame - msg = widgets.myMessageBox() - txt = html_utils.paragraph(f""" - The last annotated frame is frame {last_cca_frame_i+1}.

- Do you want to restart cell cycle analysis from frame - {last_cca_frame_i+1}?
- """) - yesButton, noButton, _ = msg.question( - self, 'Go to last annotated frame?', txt, - buttonsTexts=('Yes', 'No', 'Cancel') - ) - if msg.cancel: - msg = 'Cell cycle analysis aborted.' - self.logger.info(msg) - self.titleLabel.setText(msg, color=self.titleColor) - self.modeComboBox.setCurrentText(defaultMode) - proceed = False - return - - self.addMissingIDs_cca_df(posData) - if msg.clickedButton == yesButton: - self.addMissingIDs_cca_df(posData) - msg = 'Looking good!' - self.titleLabel.setText(msg, color=self.titleColor) - self.last_cca_frame_i = last_cca_frame_i - posData.frame_i = last_cca_frame_i - self.get_data() - self.addMissingIDs_cca_df(posData) - self.store_cca_df() - self.updateAllImages() - self.updateScrollbars() - else: - self.get_data() - self.addMissingIDs_cca_df(posData) - self.store_cca_df() - - self.last_cca_frame_i = last_cca_frame_i - - self.navigateScrollBar.setMaximum(last_cca_frame_i+1) - self.navSpinBox.setMaximum(last_cca_frame_i+1) - self.lastTrackedFrameLabel.setText( - f'Last cc annot. frame n. = {last_cca_frame_i+1}' - ) - - if posData.cca_df is None: - posData.cca_df = self.getBaseCca_df() - self.store_cca_df() - msg = 'Cell cycle analysis initialized!' - self.logger.info(msg) - self.titleLabel.setText(msg, color=self.titleColor) - else: - self.get_cca_df() - - self.enqCcaIntegrityChecker() - - return proceed - @exception_handler - def initLinTree(self, force=False): - """ - Initializes the lineage tree analysis. - - This method checks if the tracking has been previously checked and saved. If not, it displays a message to the user. - It also prompts the user to go to the last annotated frame and restart the lineage tree analysis if necessary. - Finally, it initializes the necessary data structures and updates the GUI. - - Returns - ------- - proceed : bool - True if the initialization is successful, nothing otherwise. - """ - - if not force and self.lineage_tree is not None: - return - - mode = str(self.modeComboBox.currentText()) - if mode != 'Normal division: Lineage tree' and not force: - return - - posData = self.data[self.pos_i] - last_tracked_i = self.get_last_tracked_i() - defaultMode = 'Viewer' - if last_tracked_i == 0: - # Display message to the user - txt = html_utils.paragraph( - 'On this dataset either you never checked that the segmentation ' - 'and tracking are correct or you did not save yet.

' - 'If you already visited some frames with "Segmentation and Tracking" ' - 'mode save data before switching to "Normal division: Lineage Tree".

' - 'Otherwise you first have to check (and eventually correct) some frames ' - 'in "Segmentation and Tracking" mode before proceeding ' - 'with lineage tree analysis.') - msg = widgets.myMessageBox() - msg.critical( - self, 'Tracking was never checked', txt - ) - self.modeComboBox.setCurrentText(defaultMode) - return - - proceed = True - last_lin_tree_frame_i = 0 - # Determine last annotated frame index - for i, dict_frame_i in enumerate(posData.allData_li): - df = dict_frame_i['acdc_df'] - if (df is None or - 'generation_num_tree' not in df.columns - or df['generation_num_tree'].isin([np.nan, 0]).all() - ): - break - else: - last_lin_tree_frame_i = i - - if last_lin_tree_frame_i == 0: - # Remove undoable actions from segmentation mode - posData.UndoRedoStates[0] = [] - self.undoAction.setEnabled(False) - self.redoAction.setEnabled(False) - - if posData.frame_i > last_lin_tree_frame_i: - # Prompt user to go to last annotated frame - msg = widgets.myMessageBox() - txt = html_utils.paragraph(f""" - The last annotated frame is frame {last_lin_tree_frame_i+1}.

- Do you want to restart lineage tree analysis from frame - {last_lin_tree_frame_i+1}?
- """) - _, yesButton, stayButton = msg.warning( - self, 'Go to last annotated frame?', txt, - buttonsTexts=( - 'Cancel', f'Yes, go to frame {last_lin_tree_frame_i+1}', - 'No, stay on current frame') - ) - if yesButton == msg.clickedButton: - msg = 'Looking good!' - self.last_lin_tree_frame_i = last_lin_tree_frame_i - posData.frame_i = last_lin_tree_frame_i - self.titleLabel.setText(msg, color=self.titleColor) - self.get_data(lin_tree_init=False) - self.updateAllImages() # i dont think I need to change this - self.updateScrollbars() # i dont think I need to change this - elif stayButton == msg.clickedButton: - self.initMissingFramesLinTree(posData.frame_i) #!!! - last_lin_tree_frame_i = posData.frame_i - msg = 'Lineage tree analysis initialised!' - self.titleLabel.setText(msg, color='g') - elif msg.cancel: - msg = 'Lineage tree analysis aborted.' - self.logger.info(msg) - self.titleLabel.setText(msg, color=self.titleColor) - self.modeComboBox.setCurrentText(defaultMode) - proceed = False - return - - elif posData.frame_i < last_lin_tree_frame_i: - # Prompt user to go to last annotated frame - msg = widgets.myMessageBox() - txt = html_utils.paragraph(f""" - The last annotated frame is frame {last_lin_tree_frame_i+1}.

- Do you want to restart lineage tree analysis from frame - {last_lin_tree_frame_i+1}?
- """) - goTo_last_annotated_frame_i = msg.question( - self, 'Go to last annotated frame?', txt, - buttonsTexts=('Yes', 'No', 'Cancel') - )[0] - if goTo_last_annotated_frame_i == msg.clickedButton: - msg = 'Looking good!' - self.titleLabel.setText(msg, color=self.titleColor) - self.last_lin_tree_frame_i = last_lin_tree_frame_i - posData.frame_i = last_lin_tree_frame_i - self.get_data(lin_tree_init=False) - self.updateAllImages() # i dont think I need to change this - self.updateScrollbars() # i dont think I need to change this - elif msg.cancel: - msg = 'Lineage tree analysis aborted.' - self.logger.info(msg) - self.titleLabel.setText(msg, color=self.titleColor) - self.modeComboBox.setCurrentText(defaultMode) - proceed = False - return - else: - self.get_data(lin_tree_init=False) - - self.last_lin_tree_frame_i = last_lin_tree_frame_i - - self.navigateScrollBar.setMaximum(last_lin_tree_frame_i+1) - self.navSpinBox.setMaximum(last_lin_tree_frame_i+1) - - if self.lineage_tree is None or force: - self.store_data(autosave=False) - self.get_data(lin_tree_init=False) - self.lineage_tree = normal_division_lineage_tree(gui=self) - - msg = 'Lineage tree analysis initialized!' - self.logger.info(msg) - self.titleLabel.setText(msg, color=self.titleColor) - - return proceed - - @disableWindow - def propagateLinTreeAction(self, dummy_for_button=None): - """ - Propagates the lineage tree based on the current frame_i. Used in self.propagateLinTreeButton. - """ - posData = self.data[self.pos_i] - self.lineage_tree.propagate(posData.frame_i) - if posData.frame_i == self.original_df_lin_tree_i: - self.original_df_lin_tree = posData.allData_li[posData.frame_i]['acdc_df'].copy() - - self.logger.info('Lineage tree propagated.') - - def isCcaCheckerChecking(self): - if not self.ccaCheckerRunning: - return False - - return self.ccaIntegrityCheckerWorker.isChecking - - def getConcatCcaDf(self): - posData = self.data[self.pos_i] - cca_dfs = [] - keys = [] - for frame_i in range(0, posData.SizeT): - cca_df = self.get_cca_df(frame_i=frame_i, return_df=True) - if cca_df is None: - break - - cca_dfs.append(cca_df) - keys.append(frame_i) - - if not cca_dfs: - return - - global_cca_df = pd.concat(cca_dfs, keys=keys, names=['frame_i']) - return global_cca_df - - def storeFromConcatCcaDf(self, global_cca_df): - posData = self.data[self.pos_i] - for frame_i in range(0, posData.SizeT): - try: - cca_df = global_cca_df.loc[frame_i] - except KeyError as err: - break - - self.store_cca_df(frame_i=frame_i, cca_df=cca_df, autosave=False) - - self.get_cca_df() - - def resetWillDivideInfo(self): - global_cca_df = self.getConcatCcaDf() - if global_cca_df is None: - return - - global_cca_df = load._fix_will_divide(global_cca_df) - self.storeFromConcatCcaDf(global_cca_df) - - def ccaCheckerStopChecking(self): - if not self.ccaCheckerRunning: - return - - self.ccaIntegrityCheckerWorker.clearQueue() - - if self.ccaIntegrityCheckerWorker.isChecking: - self.ccaIntegrityCheckerWorker.abortChecking = True - - def updateLastVisitedFrame(self, last_visited_frame_i=None): - if last_visited_frame_i is None: - posData = self.data[self.pos_i] - last_visited_frame_i = posData.frame_i - - mode = str(self.modeComboBox.currentText()) - if mode == 'Viewer': - return - elif mode == 'Segmentation and Tracking': - posData = self.data[self.pos_i] - if posData.last_tracked_i >= last_visited_frame_i: - return - posData.last_tracked_i = last_visited_frame_i - elif mode == 'Cell cycle analysis': - if self.last_cca_frame_i >= last_visited_frame_i: - return - self.last_cca_frame_i = last_visited_frame_i - - def resetCcaFuture(self, from_frame_i): - posData = self.data[self.pos_i] - self.last_cca_frame_i = from_frame_i-1 - self.ccaCheckerStopChecking() - - self.setNavigateScrollBarMaximum() - for i in range(from_frame_i, posData.SizeT): - posData.allData_li[i].pop('cca_df', None) - posData.allData_li[i].pop('cca_df_checker', None) - - df = posData.allData_li[i]['acdc_df'] - if df is None: - # No more saved info to delete - break - - if 'cell_cycle_stage' not in df.columns: - # No cell cycle info present - continue - - df = df.drop(columns=self.cca_df_colnames) - posData.allData_li[i]['acdc_df'] = df - - if posData.acdc_df is not None: - frames = posData.acdc_df.index.get_level_values(0) - if from_frame_i in frames: - posData.acdc_df = posData.acdc_df.loc[:from_frame_i] - - self.resetWillDivideInfo() - - def removeCcaAnnotationsCurrentFrame(self): - posData = self.data[self.pos_i] - posData.cca_df = None - - posData.allData_li[posData.frame_i].pop('cca_df', None) - posData.allData_li[posData.frame_i].pop('cca_df_checker', None) - - df = posData.allData_li[posData.frame_i]['acdc_df'] - if df is None: - # No more saved info to delete - return False - - if 'cell_cycle_stage' not in df.columns: - # No cell cycle info present - return False - - df = df.drop(columns=self.cca_df_colnames) - posData.allData_li[posData.frame_i]['acdc_df'] = df - - return True - - def resetFutureCcaColCurrentFrame(self): - posData = self.data[self.pos_i] - - cca_df_S_mask = posData.cca_df.cell_cycle_stage == 'S' - posData.cca_df.loc[cca_df_S_mask, 'will_divide'] = 0 - - mothers_mask = ( - (posData.cca_df.relationship == 'mother') - & cca_df_S_mask - ) - bud_mask = posData.cca_df.relationship == 'bud' - - posData.cca_df.loc[mothers_mask, 'daughter_disappears_before_division'] = 0 - posData.cca_df.loc[bud_mask, 'disappears_before_division'] = 0 - - cca_df = self.get_cca_df(frame_i=posData.frame_i, return_df=True) - if cca_df is not None: - cca_df_S_mask = cca_df.cell_cycle_stage == 'S' - cca_df.loc[cca_df_S_mask, 'will_divide'] = 0 - - mothers_mask = ( - (cca_df.relationship == 'mother') - & cca_df_S_mask - ) - bud_mask = cca_df.relationship == 'bud' - - cca_df.loc[mothers_mask, 'daughter_disappears_before_division'] = 0 - cca_df.loc[bud_mask, 'disappears_before_division'] = 0 - - self.store_data() - - def resetLin_tree_future(self): - posData = self.data[self.pos_i] - frame_i = posData.frame_i - - for i in range(frame_i, posData.SizeT): - if self.lineage_tree is not None: - self.lineage_tree.frames_for_dfs.discard(frame_i) - df = posData.allData_li[i]['acdc_df'] - # reste lineage tree columns - if df is None: - continue - df = df.drop(columns=lineage_tree_cols, errors='ignore') - posData.allData_li[i]['acdc_df'] = df - - def get_cca_df(self, frame_i=None, return_df=False, debug=False): - # cca_df is None unless the metadata contains cell cycle annotations - # NOTE: cell cycle annotations are either from the current session - # or loaded from HDD in "initPosAttr" with a .question to the user - posData = self.data[self.pos_i] - cca_df = None - i = posData.frame_i if frame_i is None else frame_i - df = posData.allData_li[i]['acdc_df'] - if df is not None: - if 'cell_cycle_stage' in df.columns: - cca_df = df[self.cca_df_colnames].copy() - - if cca_df is None and self.isSnapshot: - cca_df = self.getBaseCca_df() - posData.cca_df = cca_df - - if cca_df is not None: - cca_df = cca_df.dropna() - - if return_df: - return cca_df - else: - posData.cca_df = cca_df - - def changeIDfutureFrames( - self, endFrame_i, oldIDnewIDMapper, includeUnvisited, - shift=False - ): - posData = self.data[self.pos_i] - self.current_frame_i = posData.frame_i - - # Store data for current frame - self.store_data() - if endFrame_i is None: - self.app.restoreOverrideCursor() - return - - segmSizeT = len(posData.segm_data) - for i in range(posData.frame_i+1, segmSizeT): - lab = posData.allData_li[i]['labels'] - if lab is None and not includeUnvisited: - self.enqAutosave() - break - - if lab is not None: - # Visited frame - posData.frame_i = i - self.get_data(lin_tree_init=False) - if shift and self.isSegm3D: - lab = self.get_2Dlab(posData.lab) - else: - lab = posData.lab - - if self.onlyTracking: - self.tracking(enforce=True) - elif not posData.IDs: - continue - else: - maxID = max(posData.IDs, default=0) + 1 - for old_ID, new_ID in oldIDnewIDMapper: - if new_ID in lab: - tempID = maxID + 1 # lab.max() + 1 - lab[lab == old_ID] = tempID - lab[lab == new_ID] = old_ID - lab[lab == tempID] = new_ID - maxID += 1 - else: - lab[lab == old_ID] = new_ID - - if shift and self.isSegm3D: - self.set_2Dlab(lab) - - self.update_rp(draw=False) - self.store_data(autosave=i==endFrame_i) - elif includeUnvisited: - # Unvisited frame (includeUnvisited = True) - lab = posData.segm_data[i] - if shift and self.isSegm3D: - lab = self.get_2Dlab(lab) - else: - lab = lab - - for old_ID, new_ID in oldIDnewIDMapper: - if new_ID in lab: - tempID = lab.max() + 1 - lab[lab == old_ID] = tempID - lab[lab == new_ID] = old_ID - lab[lab == tempID] = new_ID - else: - lab[lab == old_ID] = new_ID - - if shift and self.isSegm3D: - posData.segm_data[i][self.z_lab()] = lab - - # Back to current frame - posData.frame_i = self.current_frame_i - self.get_data() - self.app.restoreOverrideCursor() - - def unstore_cca_df(self): - posData = self.data[self.pos_i] - acdc_df = posData.allData_li[posData.frame_i]['acdc_df'] - for col in self.cca_df_colnames: - if col not in acdc_df.columns: - continue - acdc_df.drop(col, axis=1, inplace=True) - - def store_cca_df_checker(self, posData, frame_i, cca_df): - if not self.ccaCheckerRunning: - return - - if cca_df is None: - return - - posData.allData_li[frame_i]['cca_df_checker'] = cca_df.copy() - - def store_cca_df( - self, pos_i=None, frame_i=None, cca_df=None, mainThread=True, - autosave=True, store_cca_df_copy=False - ): - pos_i = self.pos_i if pos_i is None else pos_i - posData = self.data[pos_i] - i = posData.frame_i if frame_i is None else frame_i - if cca_df is None: - cca_df = posData.cca_df - if self.ccaTableWin is not None and mainThread: - zoomIDs = self.getZoomIDs() - self.ccaTableWin.updateTable(posData.cca_df, IDs=zoomIDs) - - acdc_df = posData.allData_li[i]['acdc_df'] - if acdc_df is None: - current_frame_i = None - if frame_i is not None and frame_i != posData.frame_i: - current_frame_i = posData.frame_i - posData.frame_i = frame_i - self.get_data() - self.store_data() - acdc_df = posData.allData_li[i]['acdc_df'] - if current_frame_i is not None: - # Back to current frame - posData.frame_i = current_frame_i - self.get_data(debug=False) - - if 'cell_cycle_stage' in acdc_df.columns: - # Cell cycle info already present --> overwrite with new - acdc_df[self.cca_df_colnames] = cca_df[self.cca_df_colnames] - posData.allData_li[i]['acdc_df'] = acdc_df - elif cca_df is not None: - df = acdc_df.drop(cca_df.columns, axis=1, errors='ignore') - df = df.join(cca_df, how='left') - posData.allData_li[i]['acdc_df'] = df - - # Store copy for cca integrity worker - self.store_cca_df_checker(posData, i, cca_df) - - if store_cca_df_copy and cca_df is not None: - posData.allData_li[i]['cca_df'] = cca_df.copy() - - if autosave: - self.enqAutosave() - self.enqCcaIntegrityChecker() - - # def lin_tree_to_acdc_df(self, force_all=False, ignore=set(), force=set(), specific=set()): - # """ - # Syncs the lineage tree DataFrame with the acdc_df DataFrame. By default, it will only try to sync frames which have not been synced before. - # This can be changed using the optional arguments. - - # Parameters - # ---------- - # force_all : bool, optional - # If True, forces synchronization for all frames. Defaults to False. - # ignore : set, optional - # Set of frames to ignore during synchronization. Defaults to set(). - # force : set, optional - # Set of frames to force synchronization. Defaults to set(). - # specific : set, optional - # Set of frames to specifically synchronize. In this case it will ignore all other inputs and sync those no matter what. Defaults to set(). - # """ - - # if self.lineage_tree is None: - # return - - # # df_for_sync = [] - # # lineage_copy = self.lineage_tree.lineage_list.copy() - # lin_tree_set = self.lineage_tree.frames_for_dfs.copy() - - # if not force_all and not specific: - # dont_sync = self.already_synced_lin_tree - # dont_sync = {frame for frame in dont_sync if not frame in force} - # dont_sync.update(ignore) - - # lin_tree_set = lin_tree_set.difference(dont_sync) - - # if specific: - # lin_tree_set = lin_tree_set.intersection(specific) - - - # if lin_tree_set == []: - # return - - # posData = self.data[self.pos_i] - - # lin_tree_colnames = None - # self.store_data(autosave=False) - # for frame_i in lin_tree_set: - # acdc_df = posData.allData_li[frame_i]['acdc_df'] - - # lin_tree_df = self.lineage_tree.export_df(frame_i) - # if lin_tree_colnames is None: - # lin_tree_colnames = lin_tree_df.columns - - # acdc_df.loc[lin_tree_df.index, lin_tree_colnames] = lin_tree_df[lin_tree_colnames] - - # try: - # try: - # if (acdc_df['generation_num'] == 2).all() and not (acdc_df['generation_num_tree'].isna().all()): # check if generation_num is all just the default value and if yes, replace it with the tree values - # acdc_df['generation_num'] = acdc_df['generation_num_tree'] - # except KeyError: - # acdc_df['generation_num'] = acdc_df['generation_num_tree'] - # except Exception as e: - # self.logger.error(f'Error while syncing generation_num from lineage tree: {e} \n please save and restart') - - # posData.allData_li[frame_i]['acdc_df'] = acdc_df - # self.already_synced_lin_tree.add(frame_i) - - def turnOffAutoSaveWorker(self): - self.autoSaveToggle.setChecked(False) - - def autoSaveTimerTimedOut(self): - if not hasattr(self, 'data'): - # This happes when the self.autoSaveTimer times out after - # the GUI has been closed --> we simply ignore it - self.autoSaveTimer.stop() - return - - self.autoSaveTimer.stop() - self.flushDirtyPointsLayersAutosave() - self._enqueueAutoSave() - - def autoSaveTimerCountFrames(self): - if not hasattr(self, 'data'): - # This happes when the self.autoSaveTimer times out after - # the GUI has been closed --> we simply ignore it - return - - posData = self.data[self.pos_i] - autoSaveIntevalValue, autoSaveIntervalUnit = ( - self.autoSaveIntevalValueUnit - ) - isTimeToAutoSave = ( - abs(posData.frame_i - self.autoSaveTimeStartFrameIdx) - >= autoSaveIntevalValue - ) - if not isTimeToAutoSave: - return - - self.autoSaveTimeStartFrameIdx = posData.frame_i - self.flushDirtyPointsLayersAutosave() - self._enqueueAutoSave() - - def enqAutosave(self): - mode = str(self.modeComboBox.currentText()) - if mode == 'Viewer': - if self.statusBarLabel.text().endswith('Autosaving...'): - self.statusBarLabel.setText( - self.statusBarLabel.text().replace(' | Autosaving...', '') - ) - return - - if not self.autoSaveActiveWorkers: - self.gui_createAutoSaveWorker() - - if not self.autoSaveActiveWorkers: - return - - if self.autoSaveTimer.isActive(): - return - - self._enqueueAutoSave() - autoSaveIntevalValue, autoSaveIntervalUnit = ( - self.autoSaveIntevalValueUnit - ) - if autoSaveIntevalValue == 0: - return - - try: - self.autoSaveTimer.timeout.disconnect() - except Exception as err: - pass - - - if autoSaveIntervalUnit == 'minutes': - autosave_interval_ms = round(autoSaveIntevalValue*60*1000) - self.autoSaveTimer.timeout.connect(self.autoSaveTimerTimedOut) - self.autoSaveTimer.start(autosave_interval_ms) - else: - self.startAutoSaveEveryNframesTimer() - - def startAutoSaveEveryNframesTimer(self): - posData = self.data[self.pos_i] - self.autoSaveTimeStartFrameIdx = posData.frame_i - self.autoSaveTimer.timeout.connect( - self.autoSaveTimerCountFrames - ) - self.autoSaveTimer.start(500) - - def _enqueueAutoSave(self): - if not self.statusBarLabel.text().endswith('Autosaving...'): - self.statusBarLabel.setText( - f'{self.statusBarLabel.text()} | Autosaving...' - ) - - timestamp = datetime.now().strftime(r'%H:%M:%S.%f')[:-3] - self.logger.info(f'Autosaving... - {timestamp}') - - posData = self.data[self.pos_i] - worker, thread = self.autoSaveActiveWorkers[-1] - worker.enqueue(posData) - - def enqCcaIntegrityChecker(self): - if not self.ccaCheckerRunning: - return - posData = self.data[self.pos_i] - self.ccaIntegrityCheckerWorker.enqueue(posData) - - def drawAllMothBudLines(self): - posData = self.data[self.pos_i] - for obj in posData.rp: - self.drawObjMothBudLines(obj, posData, ax=0) - self.drawObjMothBudLines(obj, posData, ax=1) - - def drawObjMothBudLines(self, obj, posData, ax=0): - areMothBudLinesRequested = self.areMothBudLinesRequested(ax) - if not areMothBudLinesRequested: - return - - if posData.cca_df is None: - return - - mode = str(self.modeComboBox.currentText()) - if mode == 'Normal division: Lineage Tree': - return - - ID = obj.label - try: - cca_df_ID = posData.cca_df.loc[ID] - except KeyError: - return - - isObjVisible = self.isObjVisible(obj.bbox) - if not isObjVisible: - return - - ccs_ID = cca_df_ID['cell_cycle_stage'] - if ccs_ID == 'G1': - return - - relationship = cca_df_ID['relationship'] - if relationship != 'bud': - return - - emerg_frame_i = cca_df_ID['emerg_frame_i'] - isNew = emerg_frame_i == posData.frame_i - scatterItem = self.getMothBudLineScatterItem(ax, isNew) - relative_ID = cca_df_ID['relative_ID'] - - try: - relative_rp_idx = posData.IDs_idxs[relative_ID] - except KeyError: - return - - relative_ID_obj = posData.rp[relative_rp_idx] - y1, x1 = self.getObjCentroid(obj.centroid) - y2, x2 = self.getObjCentroid(relative_ID_obj.centroid) - xx, yy = core.get_line(y1, x1, y2, x2, dashed=True) - scatterItem.addPoints(xx, yy) - - def clearAllCellToCellLines(self): - self.ax1_newMothBudLinesItem.setData([], []) - self.ax1_oldMothBudLinesItem.setData([], []) - self.ax2_newMothBudLinesItem.setData([], []) - self.ax2_oldMothBudLinesItem.setData([], []) - - def drawAllLineageTreeLines(self): - """ - Draw all lineage tree lines on the GUI. - - This method retrieves the lineage tree data and draws the lineage tree lines - connecting cells and their respective mothers when the mother has split. - """ - if self.lineage_tree is None: - return - - if len(self.lineage_tree.frames_for_dfs) < 2: - return - - self.clearAllCellToCellLines() - posData = self.data[self.pos_i] - frame_i = posData.frame_i - lin_tree_df = posData.allData_li[frame_i]['acdc_df'] - lin_tree_df_prev = posData.allData_li[frame_i-1]['acdc_df'] - rp = posData.rp - prev_rp = posData.allData_li[frame_i-1]['regionprops'] - - self.setTitleText() - - new_cells = lin_tree_df.index.difference(lin_tree_df_prev.index) # I could use this for the if already but this is probably faster for frames where nothing changes - if new_cells.shape[0] == 0: - return - - for ax in (0, 1): - if not self.areMothBudLinesRequested(ax): - continue - - for ID in new_cells: - curr_obj = myutils.get_obj_by_label(rp, ID) - lin_tree_df_ID = lin_tree_df.loc[ID] - - # lin_tree_df_mother_ID = lin_tree_df_prev.loc[lin_tree_df_ID["parent_ID_tree"]] - if lin_tree_df_ID["parent_ID_tree"] == -1: # make sure that new obj where the parents are not known get skipped - continue - - mother_obj = myutils.get_obj_by_label(prev_rp, lin_tree_df_ID["parent_ID_tree"]) - - emerg_frame_i = lin_tree_df_ID["emerg_frame_i"] - isNew = emerg_frame_i == frame_i - - self.drawObjLin_TreeMothBudLines(ax, curr_obj, mother_obj, isNew, ID=ID) - - def drawObjLin_TreeMothBudLines(self, ax, obj, mother_obj, isNew, ID=None): - """ - Draw moth-bud lines between an object and its mother object. - - Parameters - ---------- - ax : cellacdc.widgets.MainPlotItem - The Cell-ACDC GUI axes object to draw on. - obj : Object - The object for which to draw the moth-bud lines. - mother_obj : Object - The mother object to connect with. - isNew : bool - Indicates whether the object is new or not. - ID : int, optional - The ID of the object, by default None. - """ - if not self.areMothBudLinesRequested(ax): - return - - if not ID: - ID = obj.label - - isObjVisible = self.isObjVisible(obj.bbox) - - if not isObjVisible: - return - - scatterItem = self.getMothBudLineScatterItem(ax, isNew) - - y1, x1 = self.getObjCentroid(obj.centroid) - y2, x2 = self.getObjCentroid(mother_obj.centroid) - xx, yy = core.get_line(y1, x1, y2, x2, dashed=True) - scatterItem.addPoints(xx, yy) - - def getObjCentroid(self, obj_centroid): - if self.isSegm3D: - depthAxes = self.switchPlaneCombobox.depthAxes() - zc, yc, xc = obj_centroid - if depthAxes == 'z': - return yc, xc - elif depthAxes == 'y': - return zc, xc - else: - return zc, yc - else: - return obj_centroid - - def getAnnotateHowRightImage(self): - if not self.labelsGrad.showRightImgAction.isChecked(): - return 'nothing' - - if self.rightBottomGroupbox.isChecked(): - how = self.annotateRightHowCombobox.currentText() - else: - how = self.drawIDsContComboBox.currentText() - return how - - def getObjOptsSegmLabels(self, obj): - if not self.labelsGrad.showLabelsImgAction.isChecked(): - return - - objOpts = self.getObjTextAnnotOpts(obj, 'Draw only IDs', ax=1) - return objOpts - - def store_zslices_rp(self, force_update=False): - if not self.isSegm3D: - return - - posData = self.data[self.pos_i] - are_zslices_rp_stored = ( - posData.allData_li[posData.frame_i].get('z_slices_rp') is not None - ) - if force_update or not are_zslices_rp_stored: - self._update_zslices_rp() - - posData.allData_li[posData.frame_i]['z_slices_rp'] = posData.zSlicesRp - - def removeObjectFromRp(self, delID): - posData = self.data[self.pos_i] - rp = [] - IDs = [] - IDs_idxs = {} - idx = 0 - for obj in posData.rp: - if obj.label == delID: - continue - rp.append(obj) - IDs.append(obj.label) - IDs_idxs[obj.label] = idx - idx += 1 - - posData.rp = rp - posData.IDs = IDs - posData.IDs_idxs = IDs_idxs - - if not self.isSegm3D: - return - - zSlicesRp = {} - for z, zSliceRp in posData.zSlicesRp.items(): - if delID in zSliceRp: - continue - - zSlicesRp[z] = zSlicesRp - - posData.zSlicesRp = zSlicesRp - self.store_zslices_rp(force_update=True) - - def get_zslices_rp(self): - if not self.isSegm3D: - return - - posData = self.data[self.pos_i] - self.store_zslices_rp() - posData.zSlicesRp = posData.allData_li[posData.frame_i]['z_slices_rp'] - - # @exec_time - def _update_zslices_rp(self): - if not self.isSegm3D: - return - - posData = self.data[self.pos_i] - posData.zSlicesRp = {} - for z, lab2d in enumerate(posData.lab): - lab2d_rp = skimage.measure.regionprops(lab2d) - posData.zSlicesRp[z] = {obj.label:obj for obj in lab2d_rp} - - def instructHowDeleteID(self): - if 'showInfoDeleteObject' not in self.df_settings.index: - self.df_settings.at['showInfoDeleteObject', 'value'] = 'Yes' - - showInfoDeleteObject = ( - self.df_settings.at['showInfoDeleteObject', 'value'] == 'Yes' - ) - if not showInfoDeleteObject: - return - - actionText = self.middleClickText() - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph( - 'You have deleted an object using the eraser tool.

' - 'Did you know that you can use the "Delete object" action
' - 'to delete an object with a single click?

' - f'To do so, use the following action: {actionText}

' - 'Note: You can also set a custom shortcut by going to the menu
' - 'Settings --> Customise keyboard shortcuts....' - ) - doNotShowAgainCheckbox = QCheckBox('Do not show again') - msg.information( - self, 'Delete objects with single click', txt, - widgets=doNotShowAgainCheckbox - ) - - showInfoDeleteObjectValue = ( - 'No' if doNotShowAgainCheckbox.isChecked() else 'Yes' - ) - self.df_settings.at['showInfoDeleteObject', 'value'] = ( - showInfoDeleteObjectValue - ) - self.df_settings.to_csv(settings_csv_path) - - - def checkWarnDeletedIDwithEraser(self): - posData = self.data[self.pos_i] - - for ID in self.erasedIDs: - if ID == 0: - continue - if ID in posData.IDs_idxs: - continue - - self.instructHowDeleteID() - - if self.isSnapshot: - self.fixCcaDfAfterEdit('Delete ID with eraser') - self.updateAllImages() - else: - self.warnEditingWithCca_df('Delete ID with eraser') - - return True - - return False - - @exception_handler - def update_rp( - self, draw=True, debug=False, update_IDs=True, - wl_update=True, wl_track_og_curr=False,wl_update_lab=False - ): - - posData = self.data[self.pos_i] - # Update rp for current posData.lab (e.g. after any change) - - if wl_update: - if self.whitelistOriginalIDs is None: - old_IDs = posData.allData_li[posData.frame_i]['IDs'].copy() # for whitelist stuff - else: - old_IDs = self.whitelistOriginalIDs.copy() - self.whitelistOriginalIDs = None - elif self.whitelistOriginalIDs is None: - self.whitelist_old_IDs = posData.allData_li[posData.frame_i]['IDs'].copy() - - posData.rp = skimage.measure.regionprops(posData.lab) - if update_IDs: - IDs = [] - IDs_idxs = {} - for idx, obj in enumerate(posData.rp): - IDs.append(obj.label) - IDs_idxs[obj.label] = idx - posData.IDs = IDs - posData.IDs_idxs = IDs_idxs - self.update_rp_metadata(draw=draw) - self.store_zslices_rp(force_update=True) - - if not wl_update: - return - - # Update tracking whitelist - accepted_lost_centroids = self.getTrackedLostIDs() - new_IDs = posData.IDs - added_IDs = set(new_IDs) - set(old_IDs) - removed_IDs = ( - set(old_IDs) - - set(new_IDs) - - set(accepted_lost_centroids) - ) - - self.whitelistPropagateIDs( - IDs_to_add=added_IDs, IDs_to_remove=removed_IDs, - curr_frame_only=True, IDs_curr=new_IDs, - track_og_curr=wl_track_og_curr, - curr_lab=posData.lab, curr_rp=posData.rp, - update_lab=wl_update_lab - ) - - def extendLabelsLUT(self, lenNewLut): - posData = self.data[self.pos_i] - # Build a new lut to include IDs > than original len of lut - if lenNewLut > len(self.lut): - numNewColors = lenNewLut-len(self.lut) - # Index original lut - _lut = np.zeros((lenNewLut, 3), np.uint8) - _lut[:len(self.lut)] = self.lut - # Pick random colors and append them at the end to recycle them - randomIdx = np.random.randint(0,len(self.lut),size=numNewColors) - for i, idx in enumerate(randomIdx): - rgb = self.lut[idx] - _lut[len(self.lut)+i] = rgb - self.lut = _lut - self.initLabelsImageItems() - return True - return False - - def initLookupTableLab(self): - self.img2.setLookupTable(self.lut) - self.img2.setLevels([0, len(self.lut)]) - self.initLabelsImageItems() - - def getLabelsImageLut(self): - lut = np.zeros((len(self.lut), 4), dtype=np.uint8) - lut[:,-1] = 255 - lut[:,:-1] = self.lut - lut[0] = [0,0,0,0] - return lut - - def initLabelsImageItems(self): - lut = self.getLabelsImageLut() - self.labelsLayerImg1.setLevels([0, len(lut)]) - self.labelsLayerRightImg.setLevels([0, len(lut)]) - self.labelsLayerImg1.setLookupTable(lut) - self.labelsLayerRightImg.setLookupTable(lut) - alpha = self.imgGrad.labelsAlphaSlider.value() - self.labelsLayerImg1.setOpacity(alpha) - self.labelsLayerRightImg.setOpacity(alpha) - - def initKeepObjLabelsLayers(self): - lut = np.zeros((len(self.lut), 4), dtype=np.uint8) - lut[:,:-1] = self.lut - lut[:,-1:] = 255 - lut[0] = [0,0,0,0] - self.keepIDsTempLayerLeft.setLevels([0, len(lut)]) - self.keepIDsTempLayerLeft.setLookupTable(lut) - - - def updateTempLayerKeepIDs(self): - if not self.keepIDsButton.isChecked(): - return - - keptLab = np.zeros_like(self.currentLab2D) - - posData = self.data[self.pos_i] - for obj in posData.rp: - if obj.label not in self.keptObjectsIDs: - continue - - if not self.isObjVisible(obj.bbox): - continue - - _slice = self.getObjSlice(obj.slice) - _objMask = self.getObjImage(obj.image, obj.bbox) - - keptLab[_slice][_objMask] = obj.label - - self.keepIDsTempLayerLeft.setImage(keptLab, autoLevels=False) - - def highlightLabelID(self, ID, ax=0): - posData = self.data[self.pos_i] - try: - obj = posData.rp[posData.IDs_idxs[ID]] - except KeyError: - return - - self.textAnnot[ax].highlightObject(obj) - - def _keepObjects(self, keepIDs=None, lab=None, rp=None): - posData = self.data[self.pos_i] - if lab is None: - lab = posData.lab - - if rp is None: - rp = posData.rp - - if keepIDs is None: - keepIDs = self.keptObjectsIDs - - for obj in rp: - if obj.label in keepIDs: - continue - - lab[obj.slice][obj.image] = 0 - - return lab - - def clearHighlightedText(self): - pass - - def removeHighlightLabelID(self, IDs=None, ax=0): - posData = self.data[self.pos_i] - if IDs is None: - IDs = posData.IDs - - for ID in IDs: - obj = posData.rp[posData.IDs_idxs[ID]] - self.textAnnot[ax].removeHighlightObject(obj) - - def updateKeepIDs(self, IDs): - posData = self.data[self.pos_i] - - self.clearHighlightedText() - - isAnyIDnotExisting = False - # Check if IDs from line edit are present in current keptObjectIDs list - for ID in IDs: - if ID not in posData.allIDs: - isAnyIDnotExisting = True - continue - if ID not in self.keptObjectsIDs: - self.keptObjectsIDs.append(ID, editText=False) - self.highlightLabelID(ID) - - # Check if IDs in current keptObjectsIDs are present in IDs from line edit - for ID in self.keptObjectsIDs: - if ID not in posData.allIDs: - isAnyIDnotExisting = True - continue - if ID not in IDs: - self.keptObjectsIDs.remove(ID, editText=False) - - self.updateTempLayerKeepIDs() - if isAnyIDnotExisting: - self.keptIDsLineEdit.warnNotExistingID() - else: - self.keptIDsLineEdit.setInstructionsText() - - @exception_handler - def applyKeepObjects(self): - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - - self._keepObjects() - self.highlightHoverIDsKeptObj(0, 0, hoverID=0) - - posData = self.data[self.pos_i] - - self.update_rp() - # Repeat tracking - self.tracking(enforce=True, assign_unique_new_IDs=False) - - if self.isSnapshot: - self.fixCcaDfAfterEdit('Deleted non-selected objects') - self.updateAllImages() - self.keptObjectsIDs = widgets.KeptObjectIDsList( - self.keptIDsLineEdit, self.keepIDsConfirmAction - ) - return - else: - removeAnnot = self.warnEditingWithCca_df( - 'Deleted non-selected objects', get_answer=True - ) - if not removeAnnot: - # We can propagate changes only if the user agrees on - # removing annotations - return - - self.current_frame_i = posData.frame_i - if posData.frame_i > 0: - txt = html_utils.paragraph(""" - Do you want to remove un-kept objects in the past frames too? - """) - msg = widgets.myMessageBox(wrapText=False, showCentered=False) - _, _, applyToPastButton = msg.question( - self, 'Propagate to past frames?', txt, - buttonsTexts=('Cancel', 'No', 'Yes, apply to past frames') - ) - if msg.cancel: - return - if msg.clickedButton == applyToPastButton: - self.store_data() - self.logger.info('Applying keep objects to past frames...') - if not removeAnnot and posData.cca_df is not None: - delIDs = [ - ID for ID in posData.cca_df.index - if ID not in posData.IDs - ] - self.update_cca_df_deletedIDs(posData, delIDs) - - for i in tqdm(range(posData.frame_i), ncols=100): - lab = posData.allData_li[i]['labels'] - rp = posData.allData_li[i]['regionprops'] - keepLab = self._keepObjects(lab=lab, rp=rp) - # Store change - posData.allData_li[i]['labels'] = keepLab.copy() - # Get the rest of the stored metadata based on the new lab - posData.frame_i = i - self.get_data() - self.store_data(autosave=False) - - posData.frame_i = self.current_frame_i - self.get_data() - - # Ask to propagate change to all future visited frames - key = 'Keep ID' - askAction = self.askHowFutureFramesActions[key] - doNotShow = not askAction.isChecked() - (UndoFutFrames, applyFutFrames, endFrame_i, - doNotShowAgain) = self.propagateChange( - self.keptObjectsIDs, key, doNotShow, - posData.UndoFutFrames_keepID, posData.applyFutFrames_keepID, - force=True, applyTrackingB=True - ) - - if UndoFutFrames is None: - # Empty keep object list - self.keptObjectsIDs = widgets.KeptObjectIDsList( - self.keptIDsLineEdit, self.keepIDsConfirmAction - ) - return - - posData.doNotShowAgain_keepID = doNotShowAgain - posData.UndoFutFrames_keepID = UndoFutFrames - posData.applyFutFrames_keepID = applyFutFrames - includeUnvisited = posData.includeUnvisitedInfo['Keep ID'] - - if applyFutFrames: - self.store_data() - - self.logger.info('Applying to future frames...') - pbar = tqdm(total=posData.SizeT-posData.frame_i-1, ncols=100) - segmSizeT = len(posData.segm_data) - if not removeAnnot and posData.cca_df is not None: - delIDs = [ - ID for ID in posData.cca_df.index - if ID not in posData.IDs - ] - self.update_cca_df_deletedIDs(posData, delIDs) - - for i in range(posData.frame_i+1, segmSizeT): - lab = posData.allData_li[i]['labels'] - if lab is None and not includeUnvisited: - self.enqAutosave() - pbar.update(posData.SizeT-i) - break - - rp = posData.allData_li[i]['regionprops'] - - if lab is not None: - keepLab = self._keepObjects(lab=lab, rp=rp) - # Store change - posData.allData_li[i]['labels'] = keepLab.copy() - # Get the rest of the stored metadata based on the new lab - posData.frame_i = i - self.get_data() - self.store_data(autosave=False) - elif includeUnvisited: - # Unvisited frame (includeUnvisited = True) - lab = posData.segm_data[i] - rp = skimage.measure.regionprops(lab) - keepLab = self._keepObjects(lab=lab, rp=rp) - posData.segm_data[i] = keepLab - - pbar.update() - pbar.close() - - # Back to current frame - if applyFutFrames: - posData.frame_i = self.current_frame_i - self.get_data() - - self.keptObjectsIDs = widgets.KeptObjectIDsList( - self.keptIDsLineEdit, self.keepIDsConfirmAction - ) - - def updateLookuptable(self, lenNewLut=None, delIDs=None): - posData = self.data[self.pos_i] - if lenNewLut is None: - try: - if delIDs is None: - IDs = posData.IDs - else: - # Remove IDs removed with ROI from LUT - IDs = [ID for ID in posData.IDs if ID not in delIDs] - lenNewLut = max(IDs, default=0) + 1 - except ValueError: - # Empty segmentation mask - lenNewLut = 1 - # Build a new lut to include IDs > than original len of lut - updateLevels = self.extendLabelsLUT(lenNewLut) - lut = self.lut.copy() - - try: - # lut = self.lut[:lenNewLut].copy() - for ID in posData.binnedIDs: - lut[ID] = lut[ID]*0.2 - - for ID in posData.ripIDs: - lut[ID] = lut[ID]*0.2 - except Exception as e: - err_str = traceback.format_exc() - print('='*30) - self.logger.info(err_str) - print('='*30) - - if updateLevels: - self.img2.setLevels([0, len(lut)]) - - if self.keepIDsButton.isChecked(): - lut = np.round(lut*0.3).astype(np.uint8) - keptLut = np.round(lut[self.keptObjectsIDs]/0.3).astype(np.uint8) - lut[self.keptObjectsIDs] = keptLut - - self.img2.setLookupTable(lut) - - # @exec_time - def update_rp_metadata(self, draw=True): - posData = self.data[self.pos_i] - # Add to rp dynamic metadata (e.g. cells annotated as dead) - for i, obj in enumerate(posData.rp): - ID = obj.label - obj.excluded = ID in posData.binnedIDs - obj.dead = ID in posData.ripIDs - - def annotate_rip_and_bin_IDs(self, updateLabel=False): - depthAxes = self.switchPlaneCombobox.depthAxes() - if self.switchPlaneCombobox.isEnabled() and depthAxes != 'z': - return - - posData = self.data[self.pos_i] - binnedIDs_xx = [] - binnedIDs_yy = [] - ripIDs_xx = [] - ripIDs_yy = [] - for obj in posData.rp: - obj.excluded = obj.label in posData.binnedIDs - obj.dead = obj.label in posData.ripIDs - if not self.isObjVisible(obj.bbox): - continue - - if obj.excluded: - y, x = self.getObjCentroid(obj.centroid) - binnedIDs_xx.append(x) - binnedIDs_yy.append(y) - if updateLabel: - self.getObjOptsSegmLabels(obj) - how = self.drawIDsContComboBox.currentText() - - if obj.dead: - y, x = self.getObjCentroid(obj.centroid) - ripIDs_xx.append(x) - ripIDs_yy.append(y) - if updateLabel: - self.getObjOptsSegmLabels(obj) - how = self.drawIDsContComboBox.currentText() - - self.ax2_binnedIDs_ScatterPlot.setData(binnedIDs_xx, binnedIDs_yy) - self.ax2_ripIDs_ScatterPlot.setData(ripIDs_xx, ripIDs_yy) - self.ax1_binnedIDs_ScatterPlot.setData(binnedIDs_xx, binnedIDs_yy) - self.ax1_ripIDs_ScatterPlot.setData(ripIDs_xx, ripIDs_yy) - - def loadNonAlignedFluoChannel(self, fluo_path): - posData = self.data[self.pos_i] - if posData.filename.find('aligned') != -1: - filename, _ = os.path.splitext(os.path.basename(fluo_path)) - path = f'.../{posData.pos_foldername}/Images/{filename}_aligned.npz' - msg = widgets.myMessageBox() - msg.critical( - self, 'Aligned fluo channel not found!', - 'Aligned data for fluorescence channel not found!\n\n' - f'You loaded aligned data for the cells channel, therefore ' - 'loading NON-aligned fluorescence data is not allowed.\n\n' - 'Run the script "dataPrep.py" to create the following file:\n\n' - f'{path}' - ) - return None - fluo_data = np.squeeze(skimage.io.imread(fluo_path)) - return fluo_data - - def load_fluo_data(self, fluo_path, isGuiThread=True): - self.logger.info(f'Loading fluorescence image data from "{fluo_path}"...') - bkgrData = None - posData = self.data[self.pos_i] - # Load overlay frames and align if needed - filename = os.path.basename(fluo_path) - filename_noEXT, ext = os.path.splitext(filename) - if ext == '.npy' or ext == '.npz': - fluo_data = np.load(fluo_path) - try: - fluo_data = np.squeeze(fluo_data['arr_0']) - except Exception as e: - fluo_data = np.squeeze(fluo_data) - - # Load background data - bkgrData_path = os.path.join( - posData.images_path, f'{filename_noEXT}_bkgrRoiData.npz' - ) - if os.path.exists(bkgrData_path): - bkgrData = np.load(bkgrData_path) - elif ext == '.tif' or ext == '.tiff': - aligned_filename = f'{filename_noEXT}_aligned.npz' - aligned_path = os.path.join(posData.images_path, aligned_filename) - if os.path.exists(aligned_path): - fluo_data = np.load(aligned_path)['arr_0'] - - # Load background data - bkgrData_path = os.path.join( - posData.images_path, f'{aligned_filename}_bkgrRoiData.npz' - ) - if os.path.exists(bkgrData_path): - bkgrData = np.load(bkgrData_path) - else: - fluo_data = self.loadNonAlignedFluoChannel(fluo_path) - if fluo_data is None: - return None, None - - # Load background data - bkgrData_path = os.path.join( - posData.images_path, f'{filename_noEXT}_bkgrRoiData.npz' - ) - if os.path.exists(bkgrData_path): - bkgrData = np.load(bkgrData_path) - elif isGuiThread: - txt = html_utils.paragraph( - f'File format {ext} is not supported!\n' - 'Choose either .tif or .npz files.' - ) - msg = widgets.myMessageBox() - msg.critical(self, 'File not supported', txt) - return None, None - - return fluo_data, bkgrData - - def setOverlayColors(self): - self.overlayRGBs = [ - (255, 255, 0), - (252, 72, 254), - (49, 222, 134), - (22, 108, 27) - ] - self.overlayCmap = matplotlib.colormaps['hsv'] - self.overlayRGBs.extend( - [tuple([round(c*255) for c in self.overlayCmap(i)][:3]) - for i in np.linspace(0,1,8)] - ) - - def getFileExtensions(self, images_path): - alignedFound = any([f.find('_aligned.np')!=-1 - for f in myutils.listdir(images_path)]) - if alignedFound: - extensions = ( - 'Aligned channels (*npz *npy);; Tif channels(*tiff *tif)' - ';;All Files (*)' - ) - else: - extensions = ( - 'Tif channels(*tiff *tif);; All Files (*)' - ) - return extensions - - def loadOverlayData(self, ol_channels, addToExisting=False): - posData = self.data[self.pos_i] - for ol_ch in ol_channels: - if ol_ch not in list(posData.loadedFluoChannels): - # Requested channel was never loaded --> load it at first - # iter i == 0 - success = self.loadFluo_cb(fluo_channels=[ol_ch]) - if not success: - return False - - lastChannelName = ol_channels[-1] - for action in self.fluoDataChNameActions: - if action.text() == lastChannelName: - action.setChecked(True) - - for p, posData in enumerate(self.data): - if addToExisting: - ol_data = posData.ol_data - else: - ol_data = {} - for i, ol_ch in enumerate(ol_channels): - _, filename = self.getPathFromChName(ol_ch, posData) - ol_data[filename] = ( - posData.ol_data_dict[filename].copy() - ) - self.addFluoChNameContextMenuAction(ol_ch) - posData.ol_data = ol_data - - return True - - def askSelectOverlayChannel(self): - ch_names = [ch for ch in self.ch_names if ch != self.user_ch_name] - selectFluo = widgets.QDialogListbox( - 'Select channel', - 'Select channel names to overlay:\n', - ch_names, multiSelection=True, parent=self - ) - selectFluo.exec_() - if selectFluo.cancel: - return - - return selectFluo.selectedItemsText - - def overlayLabels_cb(self, checked, selectedLabelsEndnames=None): - if checked: - if not self.drawModeOverlayLabelsChannels: - if selectedLabelsEndnames is None: - selectedLabelsEndnames = self.askLabelsToOverlay() - if selectedLabelsEndnames is None: - self.logger.info('Overlay labels cancelled.') - self.overlayLabelsButton.setChecked(False) - return - for selectedEndname in selectedLabelsEndnames: - self.loadOverlayLabelsData(selectedEndname) - for action in self.overlayLabelsContextMenu.actions(): - if not action.isCheckable(): - continue - if action.text() == selectedEndname: - action.setChecked(True) - lastSelectedName = selectedLabelsEndnames[-1] - for action in self.selectOverlayLabelsActionGroup.actions(): - if action.text() == lastSelectedName: - action.setChecked(True) - self.updateAllImages() - - def askLabelsToOverlay(self): - selectOverlayLabels = widgets.QDialogListbox( - 'Select segmentation to overlay', - 'Select segmentation file to overlay:\n', - natsorted(self.existingSegmEndNames), - multiSelection=True, - parent=self - ) - selectOverlayLabels.exec_() - if selectOverlayLabels.cancel: - return - - return selectOverlayLabels.selectedItemsText - - def closeToolbars(self): - for toolbar in self.sender().toolbars: - toolbar.setVisible(False) - for action in toolbar.actions(): - try: - action.button.setChecked(False) - except Exception as e: - pass - - def askSaveAddedPoints(self): - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph( - 'Do you want to save the annotated points?' - ) - _, noButton, yesButton = msg.question( - self, 'Save?', txt, - buttonsTexts=('Cancel', 'No', 'Yes') - ) - if msg.clickedButton != yesButton: - return - - for toolbar in self.pointsLayersToolbars: - for action in self.pointsLayersToolbar.actions(): - try: - if 'Save annotated' in action.text(): - action.trigger() - except Exception as err: - pass - - def pointsLayerToggled(self, checked): - if not checked: - for action in self.pointsLayersToolbar.actions(): - try: - if 'Save annotated' in action.text(): - self.askSaveAddedPoints() - break - except Exception as err: - pass - self.pointsLayersToolbar.setVisible(checked) - self.autoPilotZoomToObjToolbar.setVisible(checked) - if self.pointsLayersNeverToggled: - self.pointsLayersToolbar.sigAddPointsLayer.emit() - self.pointsLayersNeverToggled = False - QTimer.singleShot(200, self.autoRange) - - def addPointsLayerTriggered(self, checked=False, toolbar=None): - if toolbar is None: - toolbar = self.pointsLayersToolbar - - if self.addPointsWin is not None: - self.logger.info( - 'Add points layer window is already open. Cannot add now.' - ) - return - - onlyMouseClicks = toolbar == self.promptSegmentPointsLayerToolbar - posData = self.data[self.pos_i] - self.addPointsWin = apps.AddPointsLayerDialog( - channelNames=posData.chNames, - imagesPath=posData.images_path, - hideCentroidsSection=onlyMouseClicks, - hideWeightedCentroidsSection=onlyMouseClicks, - hideFromTableSection=onlyMouseClicks, - hideManualEntrySection=onlyMouseClicks, - hideWithMouseClicksSection=False, - parent=self, - ) - cmap = matplotlib.colormaps['gist_rainbow'] - i = np.random.default_rng(seed=123).uniform() - for action in toolbar.actions()[1:]: - if not hasattr(action, 'layerTypeIdx'): - continue - rgb = [round(c*255) for c in cmap(i)][:3] - self.addPointsWin.appearanceGroupbox.colorButton.setColor(rgb) - break - - self.addPointsWin.sigCriticalReadTable.connect(self.logger.info) - self.addPointsWin.sigLoadedTable.connect(self.logLoadedTablePointsLayer) - self.addPointsWin.sigClosed.connect( - partial(self.addPointsLayer, toolbar=toolbar) - ) - self.addPointsWin.sigCheckClickEntryTableEndnameExists.connect( - self.checkClickEntryTableEndnameExists - ) - self.addPointsWin.show() - if self.addPointsWin.clickEntryRadiobutton.isChecked(): - QTimer.singleShot( - 200, - partial( - self.addPointsWin.sigCheckClickEntryTableEndnameExists.emit, - self.addPointsWin.clickEntryTableEndname.text(), - False - ) - ) - - def logLoadedTablePointsLayer(self, df, filename: str): - separator = f'-'*100 - header = f'First 10 rows of loaded table - "{filename}":' - footer = f'Number of points: {len(df)}' - text = ( - f'{separator}\n' - f'{header}\n\n' - f'{df.head(10)}\n\n' - f'{footer}\n' - f'{separator}' - ) - if filename: - text = f'{text}\nFilename: {filename}' - self.logger.info(text) - - def buttonAddPointsByClickingActive(self): - for toolbar in self.pointsLayersToolbars: - for action in toolbar.actions()[1:]: - if not hasattr(action, 'layerTypeIdx'): - continue - if action.layerTypeIdx == 4 and action.button.isChecked(): - return action.button - - def setupAddPointsByClicking(self, toolButton, isLoadedDf, toolbar): - self.LeftClickButtons.append(toolButton) - posData = self.data[self.pos_i] - tableEndName = self.addPointsWin.clickEntryTableEndnameText - if isLoadedDf is not None: - posData = self.data[self.pos_i] - tableEndName = tableEndName[len(posData.basename):] - self.loadClickEntryDfs(tableEndName) - - toolButton.toolbar = toolbar - toolButton.clickEntryTableEndName = tableEndName - self.checkableQButtonsGroup.addButton(toolButton) - toolButton.toggled.connect(self.addPointsByClickingButtonToggled) - - self.addPointsByClickingButtonToggled(sender=toolButton) - - toolButton.setToolTip(tableEndName) - - pointIdSpinbox = widgets.SpinBox() - pointIdSpinbox.setMinimum(0) - pointIdSpinbox.setValue(1) - pointIdSpinbox.label = QLabel(' Left-click ID: ') - pointIdSpinbox.labelAction = toolbar.addWidget(pointIdSpinbox.label) - if toolbar == self.promptSegmentPointsLayerToolbar: - newID = self.setBrushID(return_val=True) - pointIdSpinbox.setValue(newID) - pointIdSpinbox.setReadOnly(True) - pointIdSpinbox.setToolTip( - 'The ids added with left-click cannot be manually edited. ' - 'They are always a new, non-existing id.' - ) - - toolButton.actions.append(pointIdSpinbox.labelAction) - pointIdSpinbox.action = toolbar.addWidget(pointIdSpinbox) - toolButton.actions.append(pointIdSpinbox.action) - pointIdSpinbox.toolButton = toolButton - toolButton.pointIdSpinbox = pointIdSpinbox - - rightClickIDSpinbox = widgets.SpinBox() - pointIdSpinbox.setLinkedValueWidget(rightClickIDSpinbox) - rightClickIDSpinbox.setMaximumWidth(pointIdSpinbox.sizeHint().width()) - rightClickIDSpinbox.setValue(pointIdSpinbox.value()) - rightClickIDSpinbox.setMinimum(0) - rightClickIDSpinbox.label = QLabel(' | Right-click ID: ') - rightClickIDSpinbox.labelAction = toolbar.addWidget( - rightClickIDSpinbox.label - ) - toolButton.actions.append(rightClickIDSpinbox.labelAction) - rightClickIDSpinbox.action = toolbar.addWidget(rightClickIDSpinbox) - toolButton.actions.append(rightClickIDSpinbox.action) - rightClickIDSpinbox.toolButton = toolButton - toolButton.rightClickIDSpinbox = rightClickIDSpinbox - - saveToolbutton = widgets.SavePointsLayerButton( - tableEndName, parent=self - ) - saveToolbutton.sigRenameTableAction.connect( - self.updatePointsLayerClickEntryTableEndname - ) - saveToolbutton.sigLeftClick.connect(self.savePointsAddedByClicking) - saveAction = toolbar.addWidget(saveToolbutton) - saveToolbutton.action = saveAction - saveAction.saveToolbutton = saveToolbutton - saveAction.toolButton = toolButton - toolButton.saveAction = saveAction - toolButton.saveToolbutton = saveToolbutton - - toolButton.actions.append(saveAction) - - vlineAction = toolbar.addWidget(widgets.QVLine()) - spacerAction = toolbar.addWidget( - widgets.QHWidgetSpacer(width=5) - ) - - toolButton.actions.append(vlineAction) - toolButton.actions.append(spacerAction) - - action = toolButton.action - scatterItem = action.scatterItem - scatterItem.sigHoverEntered.connect( - self.addPointsByClickingScatterItemHoverEntered - ) - - self.pointsLayerClicksDfsToData(posData, toolbar=toolbar) - - def storeUndoAddPoint(self, action): - if not hasattr(self, 'undoAddPointQueueMapper'): - self.undoAddPointQueueMapper = defaultdict(list) - - posData = self.data[self.pos_i] - pointsDataPos = action.pointsData.get(self.pos_i) - if pointsDataPos is None: - return - - state = deepcopy(pointsDataPos) - self.undoAddPointQueueMapper[action].append(state) - self.undoAction.setEnabled(True) - - def undoAddPoint(self, action): - undoAddPointQueue = self.undoAddPointQueueMapper.get(action) - if undoAddPointQueue is None: - return False - - if len(undoAddPointQueue) == 0: - return False - - posData = self.data[self.pos_i] - state = undoAddPointQueue.pop(-1) - action.pointsData[self.pos_i] = state - self.markPointsLayerDirty(action=action) - - self.drawPointsLayers(computePointsLayers=False) - - if len(self.undoAddPointQueueMapper[action]) == 0: - self.undoAction.setEnabled(True) - - return True - - def getAddedPointId( - self, isMagicPrompts, addPointsByClickingButton, - right_click, left_click, middle_click - ): - action = addPointsByClickingButton.action - if right_click: - id = addPointsByClickingButton.rightClickIDSpinbox.value() - elif left_click: - id = addPointsByClickingButton.pointIdSpinbox.value() - id = self.getClickedPointNewId( - action, id, addPointsByClickingButton.pointIdSpinbox, - isMagicPrompts=isMagicPrompts - ) - if isMagicPrompts: - proceed = self.warnAddingPointWithExistingId(id) - if not proceed: - return - - addPointsByClickingButton.pointIdSpinbox.setValue(id) - elif middle_click: - id = 0 - - return id - - def addPointsByClickingScatterItemHoverEntered(self, item, points, event): - point = points[0] - point_id = point.data() - toolButton = item.action.button - toolButton.rightClickIDSpinbox.prevId = ( - toolButton.rightClickIDSpinbox.value() - ) - toolButton.rightClickIDSpinbox.setValue(point_id) - - def autoPilotZoomToObjToggled(self, checked): - if not checked: - self.zoomOut() - return - - posData = self.data[self.pos_i] - if not posData.IDs: - self.logger.info('There are no objects in current segmentation mask') - return - self.autoPilotZoomToObjSpinBox.setValue(posData.IDs[0]) - self.zoomToObj(posData.rp[0]) - - def savePointsAddedByClickingFromEndname(self, tableEndName, recovery=False): - self.pointsLayerDataToDf(self.data[self.pos_i]) - for posData in self.data: - if not posData.basename.endswith('_'): - basename = f'{posData.basename}_' - else: - basename = posData.basename - tableFilename = f'{basename}{tableEndName}.csv' - if recovery: - tableFilepath = os.path.join( - posData.recoveryFolderpath(), tableFilename - ) - else: - tableFilepath = os.path.join(posData.images_path, tableFilename) - df = posData.clickEntryPointsDfs.get(tableEndName) - if df is None: - continue - df = df.sort_values(['frame_i', 'Cell_ID']) - df.to_csv(tableFilepath, index=False) - - def markPointsLayerDirty(self, tableEndName=None, action=None): - if tableEndName is None and action is not None: - tableEndName = getattr(action, 'clickEntryTableEndName', None) - - if tableEndName is None: - addPointsByClickingButton = self.buttonAddPointsByClickingActive() - if addPointsByClickingButton is None: - return - tableEndName = addPointsByClickingButton.clickEntryTableEndName - - self.dirtyPointsLayerTableEndNames.add(tableEndName) - - def flushDirtyPointsLayersAutosave(self): - if not self.dirtyPointsLayerTableEndNames: - return - - for tableEndName in tuple(self.dirtyPointsLayerTableEndNames): # avoid runtime error - self.savePointsAddedByClickingFromEndname( - tableEndName, recovery=True - ) - - self.dirtyPointsLayerTableEndNames.clear() - - @exception_handler - def savePointsAddedByClicking(self, button, event): - sender = button.action - toolButton = sender.toolButton - tableEndName = toolButton.clickEntryTableEndName - - self.logger.info(f'Saving _{tableEndName}.csv table...') - - self.savePointsAddedByClickingFromEndname(tableEndName) - - self.logger.info(f'{tableEndName}.csv saved!') - self.titleLabel.setText(f'{tableEndName}.csv saved!', color='g') - - def updatePointsLayerClickEntryTableEndname( - self, saveToolbutton, table_endname - ): - saveAction = saveToolbutton.action - toolButton = saveAction.toolButton - toolButton.clickEntryTableEndName = table_endname - - self.logger.info( - f'Done. Click entry table endname updated to "{table_endname}"' - ) - - def pointsLayerDfsToData(self, posData): - self.pointsLayerClicksDfsToData(posData) - - def pointsLayerLoadedDfsToData(self): - posData = self.data[self.pos_i] - for toolbar in self.pointsLayersToolbars: - for action in toolbar.actions()[1:]: - if not hasattr(action, 'loadedDfInfo'): - continue - - if action.loadedDfInfo is None: - continue - - endname = action.loadedDfInfo.get('endname') - if endname is None: - continue - - filename = f'{posData.basename}{endname}' - filepath = os.path.join(posData.images_path, filename) - if not os.path.exists(filepath): - action.pointsData[self.pos_i] = {} - - df = load.load_df_points_layer(filepath) - action.pointsData[self.pos_i] = ( - load.loaded_df_to_points_data( - df, action.loadedDfInfo['t'], action.loadedDfInfo['z'], - action.loadedDfInfo['y'], action.loadedDfInfo['x'] - ) - ) - self.logLoadedTablePointsLayer(df, filename=filename) - - def setPointsLayerLoadedDfEndanme(self, action): - if action.loadedDfInfo is None: - return - - posData = self.data[self.pos_i] - images_path = posData.images_path.replace('\\', '/') - - df_folderpath = os.path.dirname( - action.loadedDfInfo['filepath'].replace('\\', '/') - ) - - if images_path != df_folderpath: - return - - df_filename = os.path.basename(action.loadedDfInfo['filepath']) - - if not df_filename.startswith(posData.basename): - return - - endname = df_filename[len(posData.basename):] - action.loadedDfInfo['endname'] = endname - - action.button.setToolTip(endname) - - def pointsLayerClicksDfsToData(self, posData, toolbar=None): - if toolbar is None: - toolbar = self.pointsLayersToolbar - - for action in toolbar.actions()[1:]: - if not hasattr(action, 'button'): - continue - - if not hasattr(action.button, 'clickEntryTableEndName'): - continue - tableEndName = action.button.clickEntryTableEndName - action.pointsData[self.pos_i] = {} - if posData.clickEntryPointsDfs.get(tableEndName) is None: - continue - - df = posData.clickEntryPointsDfs[tableEndName] - - if posData.SizeZ > 1 and df['z'].isna().any(): - self.warnLoadedPointsTableIsNot3D(tableEndName) - return - - for frame_i, df_frame in df.groupby('frame_i'): - action.pointsData[self.pos_i][frame_i] = {} - if posData.SizeZ > 1: - for z, df_zlice in df_frame.groupby('z'): - xx = df_zlice['x'].to_list() - yy = df_zlice['y'].to_list() - ids = df_zlice['id'].to_list() - action.pointsData[self.pos_i][frame_i][z] = { - 'x': xx, 'y': yy, 'id': ids - } - else: - xx = df_frame['x'].to_list() - yy = df_frame['y'].to_list() - ids = df_frame['id'].to_list() - action.pointsData[self.pos_i][frame_i] = { - 'x': xx, 'y': yy, 'id': ids - } - - def pointsLayerDataToDf(self, posData, getOnlyActive=False, toolbar=None): - df = None - for toolbar in self.pointsLayersToolbars: - for action in toolbar.actions()[1:]: - if not hasattr(action, 'button'): - continue - if not hasattr(action.button, 'clickEntryTableEndName'): - continue - - tableEndName = action.button.clickEntryTableEndName - if getOnlyActive and not action.button.isChecked(): - continue - - df = toolbar.fromActionToDataFrame( - action, posData, isSegm3D=self.isSegm3D - ) - posData.clickEntryPointsDfs[tableEndName] = df - return df - - def restartZoomAutoPilot(self): - if not self.autoPilotZoomToObjToggle.isChecked(): - return - - posData = self.data[self.pos_i] - if not posData.IDs: - return - - self.autoPilotZoomToObjSpinBox.setValue(posData.IDs[0]) - self.zoomToObj(posData.rp[0]) - - def resizeRangeWelcomeText(self): - xRange, yRange = self.ax1.viewRange() - deltaX = xRange[1] - xRange[0] - deltaY = yRange[1] - yRange[0] - self.ax1.setXRange(0, deltaX) - self.ax1.setYRange(0, deltaY) - self.ax1.setLimits( - xMin=0, xMax=deltaX, yMin=0, yMax=deltaY - ) - # self.ax1.setXRange(0, 0) - # self.ax1.setYRange(0, 0) - - def zoomToObj(self, obj=None): - if not hasattr(self, 'data'): - return - posData = self.data[self.pos_i] - if obj is None: - ID = self.sender().value() - try: - ID_idx = posData.IDs_idxs[ID] - obj = obj = posData.rp[ID_idx] - except Exception as e: - self.logger.warning( - f'ID {ID} does not exist (add points by clicking)' - ) - - if obj is None: - return - - self.goToZsliceSearchedID(obj) - min_row, min_col, max_row, max_col = self.getObjBbox(obj.bbox) - xRange = min_col-5, max_col+5 - yRange = max_row+5, min_row-5 - - self.ax1.setRange(xRange=xRange, yRange=yRange) - - def addPointsByClickingButtonToggled(self, checked=True, sender=None): - if sender is None: - sender = self.sender() - if not sender.isChecked(): - action = sender.action - action.scatterItem.setVisible(False) - return - - self.disconnectLeftClickButtons() - self.uncheckLeftClickButtons(sender) - self.connectLeftClickButtons() - action = sender.action - action.scatterItem.setVisible(True) - self.ax1_BrushCircle.setBrush(action.brushColor) - self.ax1_BrushCircle.setPen(action.penColor) - - def autoZoomNextObj(self): - self.sender().setValue(self.sender().value() - 1) - self.pointsLayerAutoPilot('next') - self.setFocusMain() - self.setFocusGraphics() - - def autoZoomPrevObj(self): - self.sender().setValue(self.sender().value() + 1) - self.pointsLayerAutoPilot('prev') - self.setFocusMain() - self.setFocusGraphics() - - def pointsLayerAutoPilot(self, direction): - if not self.autoPilotZoomToObjToggle.isChecked(): - return - ID = self.autoPilotZoomToObjSpinBox.value() - posData = self.data[self.pos_i] - if not posData.IDs: - return - - try: - ID_idx = posData.IDs_idxs[ID] - if direction == 'next': - nextID_idx = ID_idx + 1 - else: - nextID_idx = ID_idx - 1 - obj = posData.rp[nextID_idx] - except Exception as e: - self.logger.info( - f'Auto-pilot restarted from first ID' - ) - obj = posData.rp[0] - - self.autoPilotZoomToObjSpinBox.setValue(obj.label) - self.zoomToObj(obj) - - def getClickEntryTableFilepaths(self, posData, tableEndName): - if posData.basename.endswith('_'): - basename = posData.basename - else: - basename = f'{posData.basename}_' - - csv_filename = f'{basename}{tableEndName}' - if not csv_filename.endswith('.csv'): - csv_filename = f'{csv_filename}.csv' - - filepath = os.path.join(posData.images_path, csv_filename) - recovery_filepath = os.path.join( - posData.images_path, 'recovery', csv_filename - ) - return filepath, recovery_filepath - - def getClickEntryNewerRecoveryFilepaths(self, tableEndName): - newer_recovery_filepaths = [] - for posData in self.data: - filepath, recovery_filepath = self.getClickEntryTableFilepaths( - posData, tableEndName - ) - if not os.path.exists(filepath) or not os.path.exists(recovery_filepath): - continue - - if os.path.getmtime(recovery_filepath) <= os.path.getmtime(filepath) + 15: # add a 15 second tolerance - continue - - newer_recovery_filepaths.append((filepath, recovery_filepath)) - - return newer_recovery_filepaths - - def askLoadNewerRecoveryClickEntryDfs( - self, tableEndName, newer_recovery_filepaths - ): - if not newer_recovery_filepaths: - return False - - num_tables = len(newer_recovery_filepaths) - filepath, recovery_filepath = newer_recovery_filepaths[0] - main_timestamp = datetime.fromtimestamp( - os.path.getmtime(filepath) - ).strftime('%a %d. %b. %y - %H:%M:%S') - recovery_timestamp = datetime.fromtimestamp( - os.path.getmtime(recovery_filepath) - ).strftime('%a %d. %b. %y - %H:%M:%S') - - if num_tables == 1: - text = html_utils.paragraph( - f'A newer recovery version of {tableEndName}.csv ' - 'was found.

' - f'Main table save date: {main_timestamp}
' - f'Recovery save date: {recovery_timestamp}

' - 'Do you want to load the newer recovery version?' - ) - else: - text = html_utils.paragraph( - f'Newer recovery versions of {tableEndName}.csv ' - f'were found for {num_tables} positions.

' - f'Example main table save date: {main_timestamp}
' - f'Example recovery save date: {recovery_timestamp}

' - 'Do you want to load the newer recovery version where available?' - ) - - msg = widgets.myMessageBox(wrapText=False) - _, yesButton, _ = msg.warning( - self.addPointsWin, 'Newer recovery table found', text, - buttonsTexts=('Cancel', 'Yes, load newer recovery', 'No, load main table') - ) - return msg.clickedButton == yesButton - - def checkClickEntryTableEndnameExists(self, tableEndName, forceLoading): - doesTableExists = False - for posData in self.data: - filepath, _ = self.getClickEntryTableFilepaths(posData, tableEndName) - if os.path.exists(filepath): - doesTableExists = True - break - - if not doesTableExists: - return - - if not forceLoading: - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph( - f'The table {tableEndName}.csv already exists!

' - 'Do you want to load it?' - ) - _, yesButton, _ = msg.warning( - self.addPointsWin, 'Table exists!', txt, - buttonsTexts=('Cancel', 'Yes, load it', 'No, let me enter a new name') - ) - if msg.clickedButton != yesButton: - return - - newer_recovery_filepaths = self.getClickEntryNewerRecoveryFilepaths( - tableEndName - ) - load_recovery_if_newer = self.askLoadNewerRecoveryClickEntryDfs( - tableEndName, newer_recovery_filepaths - ) - - self.loadClickEntryDfs( - tableEndName, loadRecoveryIfNewer=load_recovery_if_newer - ) - - def checkLoadedTableIds(self, toolbar): - if toolbar != self.promptSegmentPointsLayerToolbar: - return True - - for posData in self.data: - for tableEndName, df in posData.clickEntryPointsDfs.items(): - for point_id in df['id'].values: - if point_id in posData.IDs_idxs: - proceed = self.warnAddingPointWithExistingId( - point_id, table_endname=tableEndName - ) - return proceed - - return True - - @exception_handler - def addPointsLayer(self, toolbar=None): - proceed = self.checkLoadedTableIds(toolbar) - - if self.addPointsWin.cancel or not proceed: - self.addPointsWin = None - self.logger.info('Adding points layer cancelled.') - return - - if toolbar is None: - toolbar = self.pointsLayersToolbar - - symbol = self.addPointsWin.symbol - color = self.addPointsWin.color - pointSize = self.addPointsWin.pointSize - zRadius = int((self.addPointsWin.zHeight-1)/2) - r,g,b,a = color.getRgb() - - scatterItem = widgets.PointsScatterPlotItem( - [], [], ax=self.ax1, symbol=symbol, pxMode=False, size=pointSize, - brush=pg.mkBrush(color=(r,g,b,100)), - pen=pg.mkPen(width=2, color=(r,g,b)), - hoverable=True, hoverBrush=pg.mkBrush((r,g,b,200)), - tip=None, show_data_as_tip=True - ) - self.ax1.addItem(scatterItem) - - toolButton = widgets.PointsLayerToolButton(symbol, color, parent=self) - toolButton.actions = [] - toolButton.setCheckable(True) - toolButton.setChecked(True) - if self.addPointsWin.keySequence is not None: - toolButton.setShortcut(self.addPointsWin.keySequence) - toolButton.toggled.connect(self.pointLayerToolbuttonToggled) - toolButton.sigEditAppearance.connect(self.editPointsLayerAppearance) - toolButton.sigShowIdsToggled.connect(self.showPointsLayerIdsToggled) - toolButton.sigRemove.connect( - partial(self.removePointsLayer, toolbar=toolbar) - ) - - action = toolbar.addWidget(toolButton) - action.state = self.addPointsWin.state() - - toolButton.action = action - action.brushColor = (r,g,b,100) - action.brushColorId0 = ( - *colors.hex_to_rgb( - colors.lighten_color( - np.array(action.brushColor)/255, 0.3 - ) - ), 100 - ) - action.penColor = (r,g,b) - action.penColorId0 = colors.lighten_color( - np.array(action.penColor)/255, 0.3 - ) - action.pointSize = pointSize - action.zRadius = zRadius - action.button = toolButton - action.scatterItem = scatterItem - scatterItem.action = action - action.layerType = self.addPointsWin.layerType - action.layerTypeIdx = self.addPointsWin.layerTypeIdx - action.loadedDf = self.addPointsWin.loadedDf - posData = self.data[self.pos_i] - action.pointsData = {} - action.pointsData[self.pos_i] = self.addPointsWin.pointsData - action.snapToMax = False - action.loadedDfInfo = self.addPointsWin.loadedDfInfo - self.setPointsLayerLoadedDfEndanme(action) - - if self.addPointsWin.layerType.startswith('Click to annotate point'): - action.snapToMax = self.addPointsWin.snapToMaxToggle.isChecked() - isLoadedDf = self.addPointsWin.clickEntryIsLoadedDf - self.setupAddPointsByClicking( - toolButton, isLoadedDf, toolbar=toolbar - ) - if self.addPointsWin.autoPilotToggle.isChecked(): - self.autoPilotZoomToObjToggle.setChecked(True) - - weighingChannel = self.addPointsWin.weighingChannel - self.loadPointsLayerWeighingData(action, weighingChannel) - - self.drawPointsLayers() - - if toolbar == self.promptSegmentPointsLayerToolbar: - self.promptSegmentPointsLayerToolbar.isPointsLayerInit = True - self.magicPromptsToolbar.clearPointsAction.setDisabled(False) - self.magicPromptsToolbar.clearPointsActionOnZoom.setDisabled(False) - QTimer.singleShot( - 200, self.magicPromptsToolbar.selectModelAction.trigger - ) - - self.addPointsWin = None - - def loadClickEntryDfs(self, tableEndName, loadRecoveryIfNewer=False): - for posData in self.data: - filepath, recovery_filepath = self.getClickEntryTableFilepaths( - posData, tableEndName - ) - - if loadRecoveryIfNewer: - recovery_exists = os.path.exists(recovery_filepath) - main_exists = os.path.exists(filepath) - if ( - recovery_exists - and ( - not main_exists - or os.path.getmtime(recovery_filepath) - > os.path.getmtime(filepath) + 15 - ) - ): - filepath = recovery_filepath - elif not main_exists: - continue - - if not os.path.exists(filepath): - continue - - self.logger.info(f'Loading points from "{filepath}"...') - df = pd.read_csv(filepath) - if 'id' not in df.columns: - df['id'] = range(1, len(df)+1) - posData.clickEntryPointsDfs[tableEndName] = df - - try: - self.addPointsWin.loadButton.confirmAction() - except Exception as err: - pass - - def removeClickedPoints(self, action, points): - posData = self.data[self.pos_i] - framePointsData = action.pointsData[self.pos_i][posData.frame_i] - if posData.SizeZ > 1: - zProjHow = self.zProjComboBox.currentText() - if zProjHow != 'single z-slice': - _warnings.warnCannotAddRemovePointsProjection() - return - zSlice = self.zSliceScrollBar.sliderPosition() - else: - zSlice = None - - removed_ids = [] - for point in points: - pos = point.pos() - x, y = pos.x(), pos.y() - if zSlice is not None: - zSliceRad = action.zRadius - sliceFramePointsData = [framePointsData[z] for z in range( - zSlice-zSliceRad, zSlice+zSliceRad+1 - ) if z in framePointsData.keys()] - else: - sliceFramePointsData = [framePointsData] - - - for sliceFramePointsData in sliceFramePointsData: - if point.data() in sliceFramePointsData['id']: - sliceFramePointsData['x'].remove(x) - sliceFramePointsData['y'].remove(y) - sliceFramePointsData['id'].remove(point.data()) - removed_ids.append(point.data()) - - if removed_ids: - self.markPointsLayerDirty(action=action) - - return removed_ids - - def restorePrevPointIdRightClick(self, addPointsByClickingButton): - # Try to restore the id that was there before hovering - # because the hovering was required only to delete the - # point - try: - prevId = addPointsByClickingButton.rightClickIDSpinbox.prevId - addPointsByClickingButton.rightClickIDSpinbox.setValue(prevId) - except Exception as err: - addPointsByClickingButton.rightClickIDSpinbox.prevId = None - - def getClickedPointNewId( - self, action, current_id, pointIdSpinbox, isMagicPrompts=False - ): - removed_id = getattr(pointIdSpinbox, 'removedId', None) - if removed_id is not None: - pointIdSpinbox.removedId = None - return removed_id - - posData = self.data[self.pos_i] - if isMagicPrompts: - is_already_new = self.isPointIdAlreadyNew(current_id, action) - if is_already_new: - return current_id - - new_ID = self.setBrushID(return_val=True) - new_id = max(current_id, new_ID) + 1 - return new_id - else: - pointsDataPos = action.pointsData.get(self.pos_i) - if pointsDataPos is None: - return 1 - - framePointsData = pointsDataPos.get(posData.frame_i) - if framePointsData is None: - return 1 - if posData.SizeZ > 1: - new_id = 1 - for z_data in framePointsData.values(): - max_id = max(z_data.get('id', 0), default=0) + 1 - if max_id > new_id: - new_id = max_id - else: - new_id = max(framePointsData.get('id', 0), default=0) + 1 - if current_id >= new_id: - return current_id - return new_id - - def setHoverCircleAddPoint(self, x, y): - addPointsByClickingButton = self.buttonAddPointsByClickingActive() - if addPointsByClickingButton is None: - return - action = addPointsByClickingButton.action - self.setHoverToolSymbolData( - [x], [y], (self.ax1_BrushCircle,), - size=action.pointSize - ) - - def isPointIdAlreadyNew(self, point_id, action): - posData = self.data[self.pos_i] - if point_id in posData.IDs_idxs: - return False - - is_ID = point_id in posData.IDs_idxs - pointsDataPos = action.pointsData.get(self.pos_i) - if pointsDataPos is None: - return not is_ID - - framePointsData = pointsDataPos.get(posData.frame_i) - if framePointsData is None: - return not is_ID - - if 'x' not in framePointsData: - is_id_already_added = False - for z, z_data in framePointsData.items(): - if point_id in z_data['id']: - is_id_already_added = True - break - else: - is_id_already_added = point_id in framePointsData['id'] - - is_already_new = not is_ID and not is_id_already_added - return is_already_new - - def addClickedPoint(self, action, x, y, id): - x, y = round(x, 2), round(y, 2) - posData = self.data[self.pos_i] - pointsDataPos = action.pointsData.get(self.pos_i) - if pointsDataPos is None: - action.pointsData[self.pos_i] = {} - - framePointsData = action.pointsData[self.pos_i].get(posData.frame_i) - if action.snapToMax: - radius = round(action.pointSize/2) - rr, cc = skimage.draw.disk((round(y), round(x)), radius) - idx_max = (self.img1.image[rr, cc]).argmax() - y, x = rr[idx_max], cc[idx_max] - - if framePointsData is None: - if posData.SizeZ > 1: - zSlice = self.zSliceScrollBar.sliderPosition() - action.pointsData[self.pos_i][posData.frame_i] = { - zSlice: {'x': [x], 'y': [y], 'id': [id]} - } - else: - action.pointsData[self.pos_i][posData.frame_i] = { - 'x': [x], 'y': [y], 'id': [id] - } - else: - if posData.SizeZ > 1: - zSlice = self.zSliceScrollBar.sliderPosition() - z_data = framePointsData.get(zSlice) - if z_data is None: - framePointsData[zSlice] = {'x': [x], 'y': [y], 'id': [id]} - else: - framePointsData[zSlice]['x'].append(x) - framePointsData[zSlice]['y'].append(y) - framePointsData[zSlice]['id'].append(id) - action.pointsData[self.pos_i][posData.frame_i] = ( - framePointsData - ) - else: - pointsDataPos = action.pointsData[self.pos_i] - framePointsData = pointsDataPos[posData.frame_i] - framePointsData['x'].append(x) - framePointsData['y'].append(y) - framePointsData['id'].append(id) - - self.markPointsLayerDirty(action=action) - - def showPointsLayerIdsToggled(self, button, checked): - button.action.scatterItem.drawIds = checked - self.drawPointsLayers() - - def removePointsLayer(self, button, toolbar=None): - button.setChecked(False) - button.action.scatterItem.setData([], []) - button.action.loadedDfInfo = None - self.ax1.removeItem(button.action.scatterItem) - toolbar.removeAction(button.action) - for action in button.actions: - toolbar.removeAction(action) - - if toolbar == self.promptSegmentPointsLayerToolbar: - self.promptSegmentPointsLayerToolbar.isPointsLayerInit = False - - def editPointsLayerAppearance(self, button): - win = apps.EditPointsLayerAppearanceDialog(parent=self) - win.restoreState(button.action.state) - win.exec_() - if win.cancel: - return - - symbol = win.symbol - color = win.color - pointSize = win.pointSize - zRadius = int((win.zHeight-1)/2) - r,g,b,a = color.getRgb() - - scatterItem = button.action.scatterItem - scatterItem.opts['hoverBrush'] = pg.mkBrush((r,g,b,200)) - scatterItem.setSymbol(symbol, update=False) - scatterItem.setBrush(pg.mkBrush(color=(r,g,b,100)), update=False) - scatterItem.setPen(pg.mkPen(width=2, color=(r,g,b)), update=False) - scatterItem.setSize(pointSize, update=True) - - button.action.brushColor = (r,g,b,100) - button.action.penColor = (r,g,b) - button.action.pointSize = pointSize - button.action.zRadius = zRadius - - button.action.state = win.state() - - def loadPointsLayerWeighingData(self, action, weighingChannel): - if not weighingChannel: - return - - self.logger.info(f'Loading "{weighingChannel}" weighing data...') - action.weighingData = [] - for p, posData in enumerate(self.data): - if weighingChannel == posData.user_ch_name: - wData = posData.img_data - action.weighingData.append(wData) - continue - - path, filename = self.getPathFromChName(weighingChannel, posData) - if path is None: - self.criticalFluoChannelNotFound(weighingChannel, posData) - action.weighingData = [] - return - - if filename in posData.fluo_data_dict: - # Weighing data already loaded as additional fluo channel - wData = posData.fluo_data_dict[filename] - else: - # Weighing data never loaded --> load now - wData, _ = self.load_fluo_data(path) - if posData.SizeT == 1: - wData = wData[np.newaxis] - action.weighingData.append(wData) - - def pointLayerToolbuttonToggled(self, checked): - action = self.sender().action - action.scatterItem.setVisible(checked) - - def getCentroidsPointsData(self, action): - # Centroids (either weighted or not) - # NOTE: if user requested to draw from table we load that in - # apps.AddPointsLayerDialog.ok_cb() - posData = self.data[self.pos_i] - action.pointsData[self.pos_i] = {posData.frame_i: {}} - if hasattr(action, 'weighingData'): - lab = posData.lab - img = action.weighingData[self.pos_i][posData.frame_i] - rp = skimage.measure.regionprops(lab, intensity_image=img) - attr = 'weighted_centroid' - else: - rp = posData.rp - attr = 'centroid' - for i, obj in enumerate(rp): - centroid = getattr(obj, attr) - if len(centroid) == 3: - zc, yc, xc = centroid - z_int = round(zc) - if z_int not in action.pointsData[self.pos_i][posData.frame_i]: - action.pointsData[self.pos_i][posData.frame_i][z_int] = { - 'x': [xc], 'y': [yc], 'id': [obj.label] - } - else: - z_data = action.pointsData[self.pos_i][posData.frame_i][z_int] - z_data['x'].append(xc) - z_data['y'].append(yc) - z_data['id'].append(obj.label) - else: - yc, xc = centroid - if 'y' not in action.pointsData[self.pos_i][posData.frame_i]: - action.pointsData[self.pos_i][posData.frame_i]['y'] = [yc] - action.pointsData[self.pos_i][posData.frame_i]['x'] = [xc] - action.pointsData[self.pos_i][posData.frame_i]['id'] = ( - [obj.label] - ) - else: - action.pointsData[self.pos_i][posData.frame_i]['y'].append(yc) - action.pointsData[self.pos_i][posData.frame_i]['x'].append(xc) - action.pointsData[self.pos_i][posData.frame_i]['id'].append( - obj.label - ) - - def drawPointsLayers(self, computePointsLayers=True): - posData = self.data[self.pos_i] - for toolbar in self.pointsLayersToolbars: - for action in toolbar.actions()[1:]: - if not hasattr(action, 'layerTypeIdx'): - continue - - if action.layerTypeIdx < 2 and computePointsLayers: - self.getCentroidsPointsData(action) - - if not action.button.isChecked(): - continue - - frames = action.pointsData.get(self.pos_i, set()) - if posData.frame_i not in frames: - if action.layerTypeIdx != 4: - self.logger.info( - f'Frame number {posData.frame_i+1} does not have any ' - f'"{action.layerType}" point to display.' - ) - continue - - framePointsData = action.pointsData[self.pos_i][posData.frame_i] - - if 'x' not in framePointsData: - # 3D points - zProjHow = self.zProjComboBox.currentText() - isZslice = ( - zProjHow == 'single z-slice' and posData.SizeZ > 1 - ) - if isZslice: - xx, yy, ids, data = [], [], [], [] - zSlice = self.zSliceScrollBar.sliderPosition() - zRadius = action.zRadius - zRange = range(zSlice-zRadius, zSlice+zRadius+1) - for z in zRange: - z_data = framePointsData.get(z) - if z_data is None: - continue - xx.extend(z_data['x']) - yy.extend(z_data['y']) - ids.extend(z_data['id']) - try: - data.extend(z_data['data']) - except KeyError as err: - # data is needed only for loaded tables - pass - else: - xx, yy, ids, data = [], [], [], [] - # z-projection --> draw all points - for z, z_data in framePointsData.items(): - xx.extend(z_data['x']) - yy.extend(z_data['y']) - ids.extend(z_data['id']) - try: - data.extend(z_data['data']) - except KeyError as err: - # data is needed only for loaded tables - pass - else: - # 2D segmentation - xx = framePointsData['x'] - yy = framePointsData['y'] - ids = framePointsData['id'] - try: - data = framePointsData['data'] - except KeyError as err: - # data is needed only for loaded tables - pass - - brushColors = [ - action.brushColor if id != 0 else action.brushColorId0 - for id in ids - ] - brushes = [pg.mkBrush(color) for color in brushColors] - - pensColor = [ - action.penColor if id != 0 else action.penColorId0 - for id in ids - ] - pens = [pg.mkPen(color) for color in pensColor] - - if action.layerTypeIdx == 2: - # For loaded table show the rest of the table as a tooltip - data = data - show_data_as_tip = True - else: - data = ids - show_data_as_tip = False - - xx = np.array(xx) # + 0.5 - yy = np.array(yy) # + 0.5 - - action.scatterItem.show_data_as_tip = show_data_as_tip - action.scatterItem.setData( - xx, yy, data=data, brush=brushes, pen=pens - ) - - def setOverlaySingleChannel(self, *args, **kwargs): - if self.overlayToolbar.isSingleChannel(): - self.overlayToolbarAreChannelsChecked = { - channel:toolbutton.isChecked() - for channel, toolbutton in self.allOverlayToolbuttons.items() - } - firstActiveToolbutton = [ - toolbutton for toolbutton in self.allOverlayToolbuttons.values() - if toolbutton.isChecked() - ][0] - firstActiveToolbutton.click() - else: - for ch, checked in self.overlayToolbarAreChannelsChecked.items(): - toolbutton = self.allOverlayToolbuttons[ch] - toolbutton.setChecked(checked) - - self.setOverlayItemsOpacities() - - def updateTransparentOverlayRgba(self, *args, **kwargs): - self.setOverlayImages() - - def setOverlayTransparency(self, transparent: bool): - opacity = float(transparent) - opacity = opacity if opacity < 1.0 else 0.999 - self.rgbaImg1.setOpacity(opacity) - - if transparent: - self.img1.setOpacity(0.001, applyToLinked=False) - self.imgGrad.sigLookupTableChanged.connect( - self.updateTransparentOverlayRgba - ) - self.imgGrad.sigLevelsChanged.connect( - self.updateTransparentOverlayRgba - ) - - for channel, items in self.overlayLayersItems.items(): - imageItem, lutItem, alphaSB = items[:3] - if transparent: - alphaSB.valueChanged.disconnect() - alphaSB.valueChanged.connect( - self.updateTransparentOverlayRgba - ) - lutItem.sigLookupTableChanged.connect( - self.updateTransparentOverlayRgba - ) - lutItem.sigLevelsChanged.connect( - self.updateTransparentOverlayRgba - ) - imageItem.setOpacity(0) - - if not transparent: - self.setOverlayItemsOpacities() - - self.setOverlayImages() - - def overlay_cb(self, checked): - self.overlayToolbar.setVisible(checked) - - self.UserNormAction, _, _ = self.getCheckNormAction() - posData = self.data[self.pos_i] - if checked: - if posData.ol_data is None: - selectedChannels = self.askSelectOverlayChannel() - if selectedChannels is None: - self.overlayButton.toggled.disconnect() - self.overlayButton.setChecked(False) - self.overlayButton.toggled.connect(self.overlay_cb) - return - - success = self.loadOverlayData(selectedChannels) - if not success: - return False - lastChannel = selectedChannels[-1] - self.setCheckedOverlayContextMenusActions(selectedChannels) - imageItem = self.overlayLayersItems[lastChannel][0] - self.setOpacityOverlayLayersItems(None, imageItem=imageItem) - self.setOverlayChannelsToolbuttonsChecked() - - self.setRetainSizePolicyLutItems() - self.normalizeRescale0to1Action.setChecked(True) - - self.updateAllImages() - self.updateImageValueFormatter() - self.enableOverlayWidgets(True) - else: - self.img1.setOpacity(1.0) - self.updateAllImages() - self.updateImageValueFormatter() - self.enableOverlayWidgets(False) - self.clearOverlayImageItems() - - - self.setOverlayItemsVisible() - - def countObjectsCb(self, checked): - if self.countObjsWindow is None: - categoryCountMapper = self.countObjects() - self.countObjsWindow = apps.ObjectCountDialog( - categoryCountMapper=categoryCountMapper, - parent=self, - data=self.data - ) - self.countObjsWindow.sigShowEvent.connect(self.updateObjectCounts) - self.countObjsWindow.sigUpdateCounts.connect(self.updateObjectCounts) - - if checked: - self.countObjsWindow.show() - else: - self.countObjsWindow.hide() - - def showLabelRoiContextMenu(self, event): - menu = QMenu(self.labelRoiButton) - action = QAction('Re-initialize magic labeller model...') - action.triggered.connect(self.initLabelRoiModel) - menu.addAction(action) - menu.exec_(QCursor.pos()) - - def initLabelRoiModel(self): - self.app.restoreOverrideCursor() - # Ask which model - self.initLabelRoiModelDialog = apps.QDialogSelectModel(parent=self) - self.initLabelRoiModelDialog.exec_() - if self.initLabelRoiModelDialog.cancel: - self.logger.info('Magic labeller aborted.') - self.initLabelRoiModelDialog = None - return True - self.app.setOverrideCursor(Qt.WaitCursor) - model_name = self.initLabelRoiModelDialog.selectedModel - self.labelRoiModel = self.repeatSegm( - model_name=model_name, askSegmParams=True, - is_label_roi=True - ) - if self.labelRoiModel is None: - self.initLabelRoiModelDialog = None - return True - self.labelRoiViewCurrentModelAction.setDisabled(False) - self.initLabelRoiModelDialog = None - return False - - def showOverlayContextMenu(self, event): - if not self.overlayButton.isChecked(): - return - - self.overlayContextMenu.exec_(QCursor.pos()) - - def showOverlayLabelsContextMenu(self, event): - if not self.overlayLabelsButton.isChecked(): - return - - self.overlayLabelsContextMenu.exec_(QCursor.pos()) - - def showInstructionsCustomModel(self): - modelFilePath = apps.addCustomModelMessages(self) - if modelFilePath is None: - self.logger.info('Adding custom model process stopped.') - return - - myutils.store_custom_model_path(modelFilePath) - modelName = os.path.basename(os.path.dirname(modelFilePath)) - customModelAction = QAction(modelName) - self.segmSingleFrameMenu.addAction(customModelAction) - self.segmActions.append(customModelAction) - self.segmActionsVideo.append(customModelAction) - self.modelNames.append(modelName) - self.models.append(None) - self.sender().callback(customModelAction) - - def showInstructionsCustomPromptModel(self): - modelFilePath = apps.addCustomPromptModelMessages(QParent=self) - if modelFilePath is None: - self.logger.info('Adding custom promptable model process stopped.') - return - - myutils.store_custom_promptable_model_path(modelFilePath) - - msg = widgets.myMessageBox(wrapText=False) - info_txt = html_utils.paragraph(f""" - Done!

- The custom promptable model has been added to the list of models.

- Use the Magic prompts button (top toolbar) to use it.

- Have fun! - """) - msg.information(self, 'Custom promptable model added', info_txt) - - def segmWithPromptableModelActionTriggered(self): - self.blinker = qutils.QControlBlink( - self.magicPromptsToolButton, qparent=self - ) - self.blinker.start() - - def setCheckedOverlayContextMenusActions(self, channelNames): - for action in self.overlayContextMenu.actions(): - if action.text() in channelNames: - action.setChecked(True) - self.checkedOverlayChannels.add(action.text()) - - def enableOverlayWidgets(self, enabled): - posData = self.data[self.pos_i] - if enabled: - self.overlayColorButton.setDisabled(False) - self.editOverlayColorAction.setDisabled(False) - - if posData.SizeZ == 1: - return - - self.zSliceOverlay_SB.setMaximum(posData.SizeZ-1) - if self.zProjOverlay_CB.currentText().find('max') != -1: - self.overlay_z_label.setDisabled(True) - self.zSliceOverlay_SB.setDisabled(True) - else: - z = self.zSliceOverlay_SB.sliderPosition() - self.overlay_z_label.setText(f'Overlay z-slice {z+1:02}/{posData.SizeZ}') - self.zSliceOverlay_SB.setDisabled(False) - self.overlay_z_label.setDisabled(False) - self.zSliceOverlay_SB.show() - self.overlay_z_label.show() - self.zProjOverlay_CB.show() - self.zSliceOverlay_SB.valueChanged.connect(self.updateOverlayZslice) - self.zProjOverlay_CB.currentTextChanged.connect(self.updateOverlayZproj) - self.zProjOverlay_CB.activated.connect(self.clearComboBoxFocus) - else: - self.zSliceOverlay_SB.setDisabled(True) - self.zSliceOverlay_SB.hide() - self.overlay_z_label.hide() - self.zProjOverlay_CB.hide() - self.overlayColorButton.setDisabled(True) - self.editOverlayColorAction.setDisabled(True) - - if posData.SizeZ == 1: - return - - self.zSliceOverlay_SB.valueChanged.disconnect() - self.zProjOverlay_CB.currentTextChanged.disconnect() - self.zProjOverlay_CB.activated.disconnect() - - - def criticalFluoChannelNotFound(self, fluo_ch, posData): - msg = widgets.myMessageBox(showCentered=False) - ls = "\n".join(myutils.listdir(posData.images_path)) - msg.setDetailedText( - f'Files present in the {posData.relPath} folder:\n' - f'{ls}' - ) - title = 'Requested channel data not found!' - txt = html_utils.paragraph( - f'The folder {posData.pos_path} ' - 'does not contain ' - 'either one of the following files:

' - f'{posData.basename}{fluo_ch}.tif
' - f'{posData.basename}{fluo_ch}_aligned.npz

' - 'Data loading aborted.' - ) - msg.addShowInFileManagerButton(posData.images_path) - okButton = msg.warning( - self, title, txt, buttonsTexts=('Ok') - ) - - def imgGradLUTfinished_cb(self): - posData = self.data[self.pos_i] - ticks = self.imgGrad.gradient.listTicks() - - self.img1ChannelGradients[self.user_ch_name] = { - 'ticks': [(x, t.color.getRgb()) for t,x in ticks], - 'mode': 'rgb' - } - - self.df_settings = self.imgGrad.saveState(self.df_settings) - self.df_settings.to_csv(self.settings_csv_path) - - def updateContColour(self, colorButton): - color = colorButton.color().getRgb() - self.df_settings.at['contLineColor', 'value'] = str(color) - self._updateContColour(color) - self.updateAllImages() - - def _updateContColour(self, color): - self.gui_createContourPens() - for items in self.overlayLayersItems.values(): - lutItem = items[1] - lutItem.contoursColorButton.setColor(color) - - def saveContColour(self, colorButton): - self.df_settings.to_csv(self.settings_csv_path) - - def updateMothBudLineColour(self, colorButton): - color = colorButton.color().getRgb() - self.df_settings.at['mothBudLineColor', 'value'] = str(color) - self._updateMothBudLineColour(color) - self.updateAllImages() - - def _updateMothBudLineColour(self, color): - self.gui_createMothBudLinePens() - self.ax1_newMothBudLinesItem.setBrush(self.newMothBudLineBrush) - self.ax1_oldMothBudLinesItem.setBrush(self.oldMothBudLineBrush) - self.ax2_newMothBudLinesItem.setBrush(self.newMothBudLineBrush) - self.ax2_oldMothBudLinesItem.setBrush(self.oldMothBudLineBrush) - for items in self.overlayLayersItems.values(): - lutItem = items[1] - lutItem.mothBudLineColorButton.setColor(color) - - def saveMothBudLineColour(self, colorButton): - self.df_settings.to_csv(self.settings_csv_path) - - def contLineWeightToggled(self, checked=True): - if not checked: - return - self.imgGrad.uncheckContLineWeightActions() - w = self.sender().lineWeight - self.df_settings.at['contLineWeight', 'value'] = w - self.df_settings.to_csv(self.settings_csv_path) - self._updateContLineThickness() - self.updateAllImages() - - def _updateContLineThickness(self): - self.gui_createContourPens() - for act in self.imgGrad.contLineWightActionGroup.actions(): - if act == self.sender(): - act.setChecked(True) - act.toggled.connect(self.contLineWeightToggled) - - def mothBudLineWeightToggled(self, checked=True): - if not checked: - return - self.imgGrad.uncheckContLineWeightActions() - w = self.sender().lineWeight - self.df_settings.at['mothBudLineSize', 'value'] = w - self.df_settings.to_csv(self.settings_csv_path) - self._updateMothBudLineSize(w) - self.updateAllImages() - - def _updateMothBudLineSize(self, size): - self.gui_createMothBudLinePens() - - for act in self.imgGrad.mothBudLineWightActionGroup.actions(): - if act == self.sender(): - act.setChecked(True) - act.toggled.connect(self.mothBudLineWeightToggled) - - self.ax1_oldMothBudLinesItem.setSize(size) - self.ax1_newMothBudLinesItem.setSize(size) - self.ax2_oldMothBudLinesItem.setSize(size) - self.ax2_newMothBudLinesItem.setSize(size) - - def getOlImg(self, key, frame_i=None): - posData = self.data[self.pos_i] - if frame_i is None: - frame_i = posData.frame_i - - img = posData.ol_data[key][frame_i] - if posData.SizeZ > 1: - zProjHow = self.zProjOverlay_CB.currentText() - z = self.zSliceOverlay_SB.sliderPosition() - if zProjHow == 'same as above': - zProjHow = self.zProjComboBox.currentText() - z = self.zSliceScrollBar.sliderPosition() - reconnect = False - try: - self.zSliceOverlay_SB.valueChanged.disconnect() - reconnect = True - except TypeError: - pass - self.zSliceOverlay_SB.setSliderPosition(z) - if reconnect: - self.zSliceOverlay_SB.valueChanged.connect( - self.updateOverlayZslice - ) - if zProjHow == 'single z-slice': - self.overlay_z_label.setText(f'Overlay z-slice {z+1:02}/{posData.SizeZ}') - ol_img = img[z].copy() - elif zProjHow == 'max z-projection': - ol_img = img.max(axis=0) - elif zProjHow == 'mean z-projection': - ol_img = img.mean(axis=0) - elif zProjHow == 'median z-proj.': - ol_img = np.median(img, axis=0) - else: - ol_img = img.copy() - - return ol_img - - def setTextAnnotZsliceScrolling(self): - pass - - def setGraphicalAnnotZsliceScrolling(self): - posData = self.data[self.pos_i] - if self.isSegm3D: - self.currentLab2D = posData.lab[self.z_lab()] - self.setOverlaySegmMasks() - self.doCustomAnnotation(0) - self.update_rp_metadata() - else: - self.currentLab2D = posData.lab - self.setOverlaySegmMasks() - self.updateContoursImage(0) - self.updateContoursImage(1) - - def initContoursImage(self): - posData = self.data[self.pos_i] - z_slice = self.z_lab() - img = posData.img_data[posData.frame_i] - Y, X = img[z_slice].shape[-2:] - - self.contoursImage = np.zeros((Y, X, 4), dtype=np.uint8) - - def initDelRoiLab(self): - posData = self.data[self.pos_i] - z_slice = self.z_lab() - img = posData.img_data[posData.frame_i] - Y, X = img[z_slice].shape[-2:] - - self.delRoiLab = np.zeros((Y, X), dtype=np.uint32) - - def initLostObjContoursImage(self): - posData = self.data[self.pos_i] - z_slice = self.z_lab() - img = posData.img_data[posData.frame_i] - Y, X = img[z_slice].shape[-2:] - - self.lostObjContoursImage = np.zeros((Y, X, 4), dtype=np.uint8) - - def initExportMaskImage(self): - posData = self.data[self.pos_i] - z_slice = self.z_lab() - img = posData.img_data[posData.frame_i] - Y, X = img[z_slice].shape[-2:] - - self.exportMaskImage = np.zeros((Y, X, 4), dtype=np.uint8) - - def initLostTrackedObjContoursImage(self): - posData = self.data[self.pos_i] - z_slice = self.z_lab() - img = posData.img_data[posData.frame_i] - Y, X = img[z_slice].shape[-2:] - - self.lostTrackedObjContoursImage = np.zeros((Y, X, 4), dtype=np.uint8) - - def initManualBackgroundImage(self): - posData = self.data[self.pos_i] - if hasattr(posData, 'lab'): - Y, X = posData.lab.shape[-2:] - else: - Y, X = posData.img_data.shape[-2:] - if not hasattr(self, 'manualBackgroundTextItems'): - self.manualBackgroundTextItems = {} - posData.manualBackgroundImage = np.zeros((Y, X, 4), dtype=np.uint8) - if posData.manualBackgroundLab is None: - posData.manualBackgroundLab = np.zeros((Y, X), dtype=np.uint32) - - def initTextAnnot(self, force=False): - posData = self.data[self.pos_i] - if hasattr(posData, 'lab'): - Y, X = posData.lab.shape[-2:] - else: - Y, X = posData.img_data.shape[-2:] - self.textAnnot[0].initItem((Y, X)) - self.textAnnot[1].initItem((Y, X)) - - def getObjContours( - self, obj, all_external=False, local=False, force_calc=True, - include_internal=False - ): - posData = self.data[self.pos_i] - dataDict = posData.allData_li[posData.frame_i] - allContours = dataDict.get('contours') - if allContours is not None and not force_calc: - z = self.z_lab() - key = (obj.label, str(z), all_external, local) - contours = allContours.get(key) - if contours is not None: - return contours - - obj_image = self.getObjImage(obj.image, obj.bbox).astype(np.uint8) - obj_bbox = self.getObjBbox(obj.bbox) - try: - contours = core.get_obj_contours( - obj_image=obj_image, - obj_bbox=obj_bbox, - local=local, - all_external=all_external - ) - except Exception as e: - if all_external: - contours = [] - else: - contours = None - self.logger.warning( - f'Object ID {obj.label} contours drawing failed. ' - f'(bounding box = {obj.bbox})' - ) - return contours - - def clearComputedContours(self): - for posData in self.data: - for frame_i, dataDict in enumerate(posData.allData_li): - dataDict['contours'] = {} - - def _computeAllContours2D( - self, dataDict, obj, z, obj_bbox, include_internal=False - ): - obj_image = self.getObjImage(obj.image, obj.bbox, z_slice=z) - if obj_image is None: - return - - all_external = False - local = False - contours = core.get_obj_contours( - obj_image=obj_image, - obj_bbox=obj_bbox, - local=local, - all_external=all_external - ) - key = (obj.label, str(z), all_external, local) - dataDict['contours'][key] = contours - - all_external = True - local = False - contours = core.get_obj_contours( - obj_image=obj_image, - obj_bbox=obj_bbox, - local=local, - all_external=all_external, - all=include_internal - ) - key = (obj.label, str(z), all_external, local) - dataDict['contours'][key] = contours - - return dataDict - - def computeAllContours(self): - self.logger.info('Computing all contours...') - posData = self.data[self.pos_i] - zz = [None] - if self.isSegm3D: - zz.extend(range(posData.SizeZ)) - - include_internal = self.showAllContoursToggle.isChecked() - for frame_i, dataDict in enumerate(posData.allData_li): - lab = dataDict['labels'] - if lab is None: - break - - rp = dataDict['regionprops'] - if rp is None: - rp = skimage.measure.regionprops(lab) - - dataDict['contours'] = {} - for obj in rp: - obj_bbox = self.getObjBbox(obj.bbox) - for z in zz: - if not self.isObjVisible(obj.bbox, z_slice=z): - continue - - try: - self._computeAllContours2D( - dataDict, obj, z, obj_bbox, - include_internal=include_internal - ) - except Exception as err: - # Contours computation fails on weird objects - pass - - def computeAllObjToObjCostPairs(self): - desc = ( - 'Computing all object-to-object cost matrices...' - ) - self.logger.info(desc) - posData = self.data[self.pos_i] - - - self.progressWin = apps.QDialogWorkerProgress( - title=desc, parent=self, pbarDesc=desc - ) - self.progressWin.mainPbar.setMaximum(0) - self.progressWin.show(self.app) - - self.computeAllObjCostPairsThread = QThread() - self.computeAllObjCostPairsWorker = workers.SimpleWorker( - posData, self._computeAllObjToObjCostPairs - ) - - self.computeAllObjCostPairsWorker.moveToThread( - self.computeAllObjCostPairsThread - ) - - self.computeAllObjCostPairsWorker.signals.finished.connect( - self.computeAllObjCostPairsThread.quit - ) - self.computeAllObjCostPairsWorker.signals.finished.connect( - self.computeAllObjCostPairsWorker.deleteLater - ) - self.computeAllObjCostPairsThread.finished.connect( - self.computeAllObjCostPairsThread.deleteLater - ) - - self.computeAllObjCostPairsWorker.signals.critical.connect( - self.computeAllObjCostPairsWorkerCritical - ) - self.computeAllObjCostPairsWorker.signals.initProgressBar.connect( - self.workerInitProgressbar - ) - self.computeAllObjCostPairsWorker.signals.progressBar.connect( - self.workerUpdateProgressbar - ) - self.computeAllObjCostPairsWorker.signals.progress.connect( - self.workerProgress - ) - self.computeAllObjCostPairsWorker.signals.finished.connect( - self.computeAllObjCostPairsWorkerFinished - ) - - self.computeAllObjCostPairsThread.started.connect( - self.computeAllObjCostPairsWorker.run - ) - self.computeAllObjCostPairsThread.start() - - self.computeAllObjCostPairsWorkerLoop = QEventLoop() - self.computeAllObjCostPairsWorkerLoop.exec_() - - def _computeAllObjToObjCostPairs(self, posData): - self.computeAllObjCostPairsWorker.signals.initProgressBar.emit( - len(posData.allData_li) - ) - for frame_i, dataDict in enumerate(posData.allData_li): - if frame_i == 0: - continue - - rp = dataDict['regionprops'] - if rp is None: - break - - prev_rp = posData.allData_li[frame_i-1]['regionprops'] - dist_matrix = core._compute_all_obj_to_obj_contour_dist_pairs( - dataDict['contours'], rp, - prev_rp=prev_rp, - restrict_search=True - ) - dataDict['obj_to_obj_dist_cost_matrix_df'] = dist_matrix - self.computeAllObjCostPairsWorker.signals.progressBar.emit(1) - self.computeAllObjCostPairsWorker.signals.initProgressBar.emit(0) - - def computeAllObjCostPairsWorkerCritical(self, error): - self.computeAllObjCostPairsWorkerLoop.exit() - self.workerCritical(error) - - def computeAllObjCostPairsWorkerFinished(self, output): - if self.progressWin is not None: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - self.computeAllObjCostPairsWorkerLoop.exit() - - def setOverlaySegmMasks(self, force=False, forceIfNotActive=False): - if not hasattr(self, 'currentLab2D'): - return - - how = self.drawIDsContComboBox.currentText() - isOverlaySegmLeftActive = how.find('overlay segm. masks') != -1 - - how_ax2 = self.getAnnotateHowRightImage() - isOverlaySegmRightActive = ( - how_ax2.find('overlay segm. masks') != -1 - and self.labelsGrad.showRightImgAction.isChecked() - ) - - isOverlaySegmActive = ( - isOverlaySegmLeftActive or isOverlaySegmRightActive - or force - ) - if not isOverlaySegmActive and not forceIfNotActive: - return - - alpha = self.imgGrad.labelsAlphaSlider.value() - if alpha == 0: - return - - posData = self.data[self.pos_i] - maxID = max(posData.IDs, default=0) - - if maxID >= len(self.lut): - self.extendLabelsLUT(maxID+10) - - currentLab2D = self.currentLab2D - if isOverlaySegmLeftActive: - self.labelsLayerImg1.setImage(currentLab2D, autoLevels=False) - - if isOverlaySegmRightActive: - self.labelsLayerRightImg.setImage(currentLab2D, autoLevels=False) - - def getObject2DimageFromZ(self, z, obj): - posData = self.data[self.pos_i] - z_min = obj.bbox[0] - local_z = z - z_min - if local_z >= posData.SizeZ or local_z < 0: - return - return obj.image[local_z] - - def getObject2DsliceFromZ(self, z, obj): - posData = self.data[self.pos_i] - z_min = obj.bbox[0] - local_z = z - z_min - if local_z >= posData.SizeZ or local_z < 0: - return - return obj.image[local_z] - - def isObjVisible(self, obj_bbox, debug=False, z_slice=None): - if z_slice is None: - z_slice = self.z_lab() - - if self.isSegm3D: - zProjHow = self.zProjComboBox.currentText() - isZslice = zProjHow == 'single z-slice' - if not isZslice: - # required a projection --> all obj are visible - return True - - depthAxes = self.switchPlaneCombobox.depthAxes() - - min_z, min_y, min_x, max_z, max_y, max_x = obj_bbox - if depthAxes == 'z': - min_val, max_val = min_z, max_z - val = z_slice - elif depthAxes == 'y': - min_val, max_val = min_y, max_y - val = z_slice[-1] - else: - min_val, max_val = min_x, max_x - val = z_slice[-1] - - if val >= min_val and val < max_val: - return True - else: - return False - else: - return True - - def getObjImage(self, obj_image, obj_bbox, z_slice=None): - if self.isSegm3D and len(obj_bbox)==6: - zProjHow = self.zProjComboBox.currentText() - isZslice = zProjHow == 'single z-slice' - if not isZslice: - # required a projection - return obj_image.max(axis=0) - - min_z = obj_bbox[0] - if z_slice is None: - z_slice = self.z_lab() - if isinstance(z_slice, tuple): - z_slice = z_slice[-1] - - local_z = z_slice - min_z - try: - obi_image_2d = obj_image[local_z] - except Exception as err: - obi_image_2d = None - return obi_image_2d - else: - return obj_image - - def getObjSlice(self, obj_slice): - if self.isSegm3D: - return obj_slice[1:3] - else: - return obj_slice - - def setOverlayImages(self, frame_i=None): - if not self.overlayButton.isChecked(): - return - - posData = self.data[self.pos_i] - if posData.ol_data is None: - return - - rgba_imgs_info = {} - for filename in posData.ol_data: - chName = myutils.get_chname_from_basename( - filename, posData.basename, remove_ext=False - ) - if chName not in self.checkedOverlayChannels: - continue - - items = self.overlayLayersItems[chName] - imageItem, lutItem, alphaSB = items[:3] - - ol_img = self.getOlImg(filename, frame_i=frame_i) - - if self.overlayToolbar.isTransparent(): - toolbutton = items[3] - if not toolbutton.isChecked(): - continue - alpha_val = alphaSB.value()/alphaSB.maximum() - ol_img = skimage.exposure.rescale_intensity( - ol_img, out_range=(0.0, 1.0) - ) - out_range_min, out_range_max = lutItem.getLevels() - rgba_imgs_info[chName] = (ol_img, alpha_val, lutItem) - else: - self.rescaleIntensitiesLut(setImage=False, imageItem=imageItem) - imageItem.setImage(ol_img) - - if not self.overlayToolbar.isTransparent(): - return - - alpha_values = [] - images = [] - luts = [] - for channel, info in rgba_imgs_info.items(): - ol_img, alpha_val, lutItem = info - alpha_values.append(alpha_val) - images.append(ol_img) - luts.append(lutItem.gradient.getLookupTable(256, alpha=255)/255) - - weights = colors.hierarchical_weights(alpha_values) - - if self.baseLayerToolbutton.isChecked(): - image1 = self._getImageupdateAllImages() - image1 = skimage.exposure.rescale_intensity( - image1, out_range=(0.0, 1.0) - ) - images.append(image1) - baseLut = ( - self.imgGrad.gradient.getLookupTable(256, alpha=255)/255 - ) - luts.append(baseLut) - - images_rgba = [] - for img, lut in zip(images, luts): - rgba = colors.grayscale_apply_lut(img, lut) - images_rgba.append(rgba) - - rgba_merge = colors.hierarchical_blend(images_rgba, weights) - self.rgbaImg1.setImage(rgba_merge) - - def getOpacitiesFromAlphaScrollbarValues(self): - alpha_values = [] - activeOverlayImageItems = [] - for items in self.overlayLayersItems.values(): - imgItem, lutItem, alphaSB = items[:3] - _toolbutton = alphaSB.toolbutton - if not _toolbutton.isChecked() or not _toolbutton.isVisible(): - continue - - alpha_values.append(alphaSB.value()/alphaSB.maximum()) - activeOverlayImageItems.append(imgItem) - - opacities = colors.hierarchical_weights(alpha_values)[::-1] - channel_opacity_mapper = {} - for i, imgItem in enumerate(activeOverlayImageItems): - channel_opacity_mapper[imgItem.channelName] = opacities[i+1] - - channel_opacity_mapper[self.user_ch_name] = opacities[0] - - return channel_opacity_mapper - - def initShortcuts(self): - from . import config - cp = config.ConfigParser() - if os.path.exists(shortcut_filepath): - cp.read(shortcut_filepath) - - if 'keyboard.shortcuts' not in cp: - cp['keyboard.shortcuts'] = {} - - if cp.has_option('keyboard.shortcuts', 'Zoom out'): - zoomOutKeyValueStr = cp['keyboard.shortcuts']['Zoom out'] - try: - self.zoomOutKeyValue = int(zoomOutKeyValueStr) - except Exception as err: - self.logger.warning( - f'{zoomOutKeyValueStr} is not a valid key ' - 'zooming out action. Restoring default key "H".' - ) - - if 'delete_object.action' not in cp: - self.delObjAction = None - else: - delObjKeySequenceText = cp['delete_object.action']['Key sequence'] - delObjButtonText = cp['delete_object.action']['Mouse button'] - delObjQtButton = ( - Qt.MouseButton.LeftButton if delObjButtonText == 'Left click' - else Qt.MouseButton.MiddleButton - ) - if not delObjKeySequenceText: - delObjKeySequence = None - else: - delObjKeySequence = widgets.KeySequenceFromText( - delObjKeySequenceText - ) - self.delObjToolAction.setChecked(True) - self.delObjAction = delObjKeySequence, delObjQtButton - - shortcuts = {} - for name, widget in self.widgetsWithShortcut.items(): - if name not in cp.options('keyboard.shortcuts'): - if hasattr(widget, 'keyPressShortcut'): - key = widget.keyPressShortcut - shortcut = widgets.KeySequenceFromText(key) - else: - shortcut = widget.shortcut() - shortcut_text = shortcut.toString() - cp['keyboard.shortcuts'][name] = shortcut_text - else: - shortcut_text = cp['keyboard.shortcuts'][name] - shortcut = widgets.KeySequenceFromText(shortcut_text) - - shortcuts[name] = (shortcut_text, shortcut) - self.setShortcuts(shortcuts, save=False) - with open(shortcut_filepath, 'w') as ini: - cp.write(ini) - - def setShortcuts(self, shortcuts: dict, save=True): - for name, (text, shortcut) in shortcuts.items(): - widget = self.widgetsWithShortcut[name] - if shortcut is None: - shortcut = QKeySequence() - if hasattr(widget, 'keyPressShortcut'): - widget.keyPressShortcut = shortcut - else: - widget.setShortcut(shortcut) - s = widget.toolTip() - toolTip = re.sub(r'Shortcut: "(.*)"', f'Shortcut: "{text}"', s) - widget.setToolTip(toolTip) - - if not save: - return - - from . import config - cp = config.ConfigParser() - if os.path.exists(shortcut_filepath): - cp.read(shortcut_filepath) - - if 'keyboard.shortcuts' not in cp: - cp['keyboard.shortcuts'] = {} - - for name, (text, shortcut) in shortcuts.items(): - cp['keyboard.shortcuts'][name] = text - - cp['keyboard.shortcuts']['Zoom out'] = str(self.zoomOutKeyValue) - - if self.delObjAction is None: - with open(shortcut_filepath, 'w') as ini: - cp.write(ini) - return - - delObjKeySequence, delObjQtButton = self.delObjAction - try: - if delObjKeySequence is None: - delObjKeySequenceText = '' - else: - delObjKeySequenceText = delObjKeySequence.toString() - - delObjKeySequenceText = ( - delObjKeySequenceText - .encode('ascii', 'ignore') - .decode('utf-8') - ) - delObjButtonText = ( - 'Left click' if delObjQtButton == Qt.MouseButton.LeftButton - else 'Middle click' - ) - cp['delete_object.action'] = { - 'Key sequence': delObjKeySequenceText, - 'Mouse button': delObjButtonText - } - except Exception as err: - self.logger.warning( - f'{delObjKeySequence} is not a valid keys sequence for ' - 'deleting objects. Setting default action' - ) - self.delObjAction = None - cp.remove_section('delete_object.action') - - with open(shortcut_filepath, 'w') as ini: - cp.write(ini) - - def editShortcuts_cb(self): - if is_mac: - delObjKeySequenceText = 'Ctrl' - delObjButtonText = 'Left click' - else: - delObjKeySequenceText = '' - delObjButtonText = 'Middle click' - - if self.delObjAction is not None: - delObjKeySequence, delObjQtButton = self.delObjAction - if delObjKeySequence is None: - delObjKeySequenceText = '' - else: - delObjKeySequenceText = delObjKeySequence.toString() - delObjKeySequenceText = ( - delObjKeySequenceText.encode('ascii', 'ignore').decode('utf-8') - ) - delObjButtonText = ( - 'Left click' if delObjQtButton == Qt.MouseButton.LeftButton - else 'Middle click' - ) - - win = apps.ShortcutEditorDialog( - self.widgetsWithShortcut, - delObjectKey=delObjKeySequenceText, - delObjectButton=delObjButtonText, - zoomOutKeyValue=self.zoomOutKeyValue, - parent=self - ) - win.exec_() - if win.cancel: - return - - self.delObjAction = win.delObjAction - self.zoomOutKeyValue = win.zoomOutKeyValue - self.setShortcuts(win.customShortcuts) - - def toggleOverlayColorButton(self, checked=True): - self.mousePressColorButton(None) - - def toggleTextIDsColorButton(self, checked=True): - self.textIDsColorButton.selectColor() - - def updateTextAnnotColor(self, button): - r, g, b = np.array(self.textIDsColorButton.color().getRgb()[:3]) - self.imgGrad.textColorButton.setColor((r, g, b)) - for items in self.overlayLayersItems.values(): - lutItem = items[1] - lutItem.textColorButton.setColor((r, g, b)) - self.gui_createTextAnnotColors(r,g,b, custom=True) - self.gui_setTextAnnotColors() - self.updateAllImages() - - def saveTextIDsColors(self, button): - self.df_settings.at['textIDsColor', 'value'] = self.objLabelAnnotRgb - self.df_settings.to_csv(self.settings_csv_path) - - def setLut(self, shuffle=True): - self.lut = self.labelsGrad.item.colorMap().getLookupTable(0,1,255) - if shuffle: - np.random.shuffle(self.lut) - - # Insert background color - if 'labels_bkgrColor' in self.df_settings.index: - rgbString = self.df_settings.at['labels_bkgrColor', 'value'] - try: - r, g, b = rgbString - except Exception as e: - r, g, b = colors.rgb_str_to_values(rgbString) - else: - r, g, b = 25, 25, 25 - self.df_settings.at['labels_bkgrColor', 'value'] = (r, g, b) - - self.lut = np.insert(self.lut, 0, [r, g, b], axis=0) - - def useCenterBrushCursorHoverIDtoggled(self, checked): - if checked: - self.df_settings.at['useCenterBrushCursorHoverID', 'value'] = 'Yes' - else: - self.df_settings.at['useCenterBrushCursorHoverID', 'value'] = 'No' - self.df_settings.to_csv(self.settings_csv_path) - - def shuffle_cmap(self): - np.random.shuffle(self.lut[1:]) - self.initLabelsImageItems() - self.updateAllImages() - - def setPermanentGreedyCmapPreferences(self): - if self.isSnapshot: - option_name = 'permanent_greedy_lut_snapshots' - else: - option_name = 'permanent_greedy_lut_timelapse' - - if option_name not in self.df_settings.index: - return - - checked = self.df_settings.at[option_name, 'value'] == 'yes' - self.labelsGrad.permanentGreedyCmapAction.setChecked(checked) - - def permanentGreedyCmapToggled(self, checked): - if checked: - settings_value = 'yes' - else: - self.setLut() - self.updateLookuptable() - self.initLabelsImageItems() - settings_value = 'no' - - self.updateAllImages() - - if self.isSnapshot: - option_name = 'permanent_greedy_lut_snapshots' - else: - option_name = 'permanent_greedy_lut_timelapse' - - self.df_settings.at[option_name, 'value'] = settings_value - self.df_settings.to_csv(self.settings_csv_path) - - def greedyShuffleCmap(self, updateImages=True): - lut = self.labelsGrad.item.colorMap().getLookupTable(0,1,255) - greedy_lut = colors.get_greedy_lut(self.currentLab2D, lut) - self.lut = greedy_lut - self.initLabelsImageItems() - if updateImages: - self.updateAllImages() - - def highlightZneighLabels_cb(self, checked): - if checked: - pass - else: - pass - - def setTwoImagesLayout(self, isTwoImages): - self.isTwoImageLayout = isTwoImages - if isTwoImages: - self.graphLayout.removeItem(self.titleLabel) - self.graphLayout.addItem(self.titleLabel, row=0, col=1, colspan=2) - # self.mainLayout.setAlignment(self.bottomLayout, Qt.AlignLeft) - self.ax2.show() - self.ax2.vb.setYLink(self.ax1.vb) - self.ax2.vb.setXLink(self.ax1.vb) - else: - self.graphLayout.removeItem(self.titleLabel) - self.graphLayout.addItem(self.titleLabel, row=0, col=1) - # self.mainLayout.setAlignment(self.bottomLayout, Qt.AlignCenter) - self.ax2.hide() - oldLink = self.ax2.vb.linkedView(self.ax1.vb.YAxis) - try: - oldLink.sigYRangeChanged.disconnect() - oldLink.sigXRangeChanged.disconnect() - except TypeError: - pass - - def showNextFrameImageItem(self, checked): - self.rightImageFramesScrollbar.setVisible(checked) - self.rightImageFramesScrollbar.setDisabled(not checked) - self.setTwoImagesLayout(checked) - if checked: - self.df_settings.at['isNextFrameVisible', 'value'] = 'Yes' - self.df_settings.at['isRightImageVisible', 'value'] = 'No' - self.df_settings.at['isLabelsVisible', 'value'] = 'No' - self.graphLayout.addItem( - self.imgGradRight, row=1, col=self.plotsCol+2 - ) - self.rightBottomGroupbox.show() - self.rightBottomGroupbox.setChecked(True) - self.drawNothingCheckboxRight.click() - if not self.isDataLoading: - self.updateAllImages() - else: - self.clearAx2Items() - self.rightBottomGroupbox.hide() - self.df_settings.at['isNextFrameVisible', 'value'] = 'No' - try: - self.graphLayout.removeItem(self.imgGradRight) - except Exception: - return - self.rightImageItem.clear() - - self.df_settings.to_csv(self.settings_csv_path) - - QTimer.singleShot(300, self.resizeGui) - - self.setBottomLayoutStretch() - - - def showRightImageItem(self, checked): - self.rightImageFramesScrollbar.setVisible(not checked) - self.rightImageFramesScrollbar.setDisabled(checked) - self.setTwoImagesLayout(checked) - if checked: - self.df_settings.at['isRightImageVisible', 'value'] = 'Yes' - self.df_settings.at['isNextFrameVisible', 'value'] = 'No' - self.df_settings.at['isLabelsVisible', 'value'] = 'No' - self.graphLayout.addItem( - self.imgGradRight, row=1, col=self.plotsCol+2 - ) - self.rightBottomGroupbox.show() - if not self.isDataLoading: - self.updateAllImages() - else: - self.clearAx2Items() - self.rightBottomGroupbox.hide() - self.df_settings.at['isRightImageVisible', 'value'] = 'No' - try: - self.graphLayout.removeItem(self.imgGradRight) - except Exception: - return - self.rightImageItem.clear() - - self.df_settings.to_csv(self.settings_csv_path) - - QTimer.singleShot(300, self.resizeGui) - - self.setBottomLayoutStretch() - - def showLabelImageItem(self, checked): - self.rightImageFramesScrollbar.setVisible(not checked) - self.rightImageFramesScrollbar.setDisabled(checked) - self.setTwoImagesLayout(checked) - self.setAnnotOptionsRightImageLabelsDisabled(checked) - if checked: - self.df_settings.at['isLabelsVisible', 'value'] = 'Yes' - self.df_settings.at['isNextFrameVisible', 'value'] = 'No' - self.df_settings.at['isRightImageVisible', 'value'] = 'No' - self.rightBottomGroupbox.show() - self.rightBottomGroupbox.setChecked(True) - if not self.isDataLoading: - self.updateAllImages() - else: - self.clearAx2Items() - self.img2.clear() - self.df_settings.at['isLabelsVisible', 'value'] = 'No' - self.rightBottomGroupbox.hide() - self.moveDelRoisToLeft() - - self.df_settings.to_csv(self.settings_csv_path) - QTimer.singleShot(200, self.resizeGui) - - self.setBottomLayoutStretch() - - def setAnnotOptionsRightImageLabelsDisabled(self, disabled): - self.annotContourCheckboxRight.setDisabled(disabled) - self.annotSegmMasksCheckboxRight.setDisabled(disabled) - if disabled: - self.annotSegmMasksCheckboxRight.setChecked(False) - self.annotSegmMasksCheckboxRight.setChecked(False) - self.annotIDsCheckboxRight.setChecked(True) - - def moveDelRoisToLeft(self): - # Move del ROIs to the left image - for posData in self.data: - delROIs_info = posData.allData_li[posData.frame_i]['delROIs_info'] - for roi in delROIs_info['rois']: - if not self.ax2.isDelRoiItemPresent(roi): - continue - - self.ax1.addDelRoiItem(roi, roi.key) - self.ax2.removeDelRoiItem(roi) - - def setBottomLayoutStretch(self): - if ( - self.labelsGrad.showRightImgAction.isChecked() - or self.labelsGrad.showNextFrameAction.isChecked() - ): - # Equally share space between the two control groupboxes - self.bottomLayout.setStretch(1, 1) - self.bottomLayout.setStretch(2, 5) - self.bottomLayout.setStretch(3, 1) - self.bottomLayout.setStretch(4, 5) - self.bottomLayout.setStretch(5, 1) - elif self.labelsGrad.showLabelsImgAction.isChecked(): - # Left control takes only left space - self.bottomLayout.setStretch(1, 1) - self.bottomLayout.setStretch(2, 5) - self.bottomLayout.setStretch(3, 5) - self.bottomLayout.setStretch(4, 1) - self.bottomLayout.setStretch(5, 1) - else: - # Left control takes all the space - self.bottomLayout.setStretch(1, 3) - self.bottomLayout.setStretch(2, 10) - self.bottomLayout.setStretch(3, 1) - self.bottomLayout.setStretch(4, 1) - self.bottomLayout.setStretch(5, 1) - - def setCheckedInvertBW(self, checked): - self.invertBwAction.setChecked(checked) - - def ticksCmapMoved(self, gradient): - pass - # posData = self.data[self.pos_i] - # self.setLut(posData, shuffle=False) - # self.updateLookuptable() - - def updateLabelsCmap(self, gradient): - self.setLut() - self.updateLookuptable() - self.initLabelsImageItems() - - self.df_settings = self.labelsGrad.saveState(self.df_settings) - self.df_settings.to_csv(self.settings_csv_path) - - self.updateAllImages() - - def updateBkgrColor(self, button): - color = button.color().getRgb()[:3] - self.lut[0] = color - self.updateLookuptable() - - def updateTextLabelsColor(self, button): - self.ax2_textColor = button.color().getRgb()[:3] - posData = self.data[self.pos_i] - if posData.rp is None: - return - - for obj in posData.rp: - self.getObjOptsSegmLabels(obj) - - def saveTextLabelsColor(self, button): - color = button.color().getRgb()[:3] - self.df_settings.at['labels_text_color', 'value'] = color - self.df_settings.to_csv(self.settings_csv_path) - - def saveBkgrColor(self, button): - color = button.color().getRgb()[:3] - self.df_settings.at['labels_bkgrColor', 'value'] = color - self.df_settings.to_csv(self.settings_csv_path) - self.updateAllImages() - - def changeOverlayColor(self, button): - rgb = button.color().getRgb()[:3] - lutItem = self.overlayLayersItems[button.channel][1] - self.initColormapOverlayLayerItem(rgb, lutItem) - lutItem.overlayColorButton.setColor(rgb) - - def saveOverlayColor(self, button): - rgb = button.color().getRgb()[:3] - rgb_text = '_'.join([str(val) for val in rgb]) - self.df_settings.at[f'{button.channel}_rgb', 'value'] = rgb_text - self.df_settings.to_csv(self.settings_csv_path) - - def getImageDataFromFilename(self, filename): - posData = self.data[self.pos_i] - if filename == posData.filename: - return posData.img_data[posData.frame_i] - else: - return posData.ol_data_dict.get(filename) - - def z_slice_index(self): - posData = self.data[self.pos_i] - if posData.SizeZ == 1: - return None - zProjHow = self.zProjComboBox.currentText() - if zProjHow != 'single z-slice': - return zProjHow - - axis_slice = self.zSliceScrollBar.sliderPosition() - if self.switchPlaneCombobox.depthAxes() == 'x': - z_slice = ( - slice(None, None, None), slice(None, None, None), axis_slice - ) - elif self.switchPlaneCombobox.depthAxes() == 'y': - z_slice = ( - slice(None, None, None), axis_slice - ) - else: - z_slice = axis_slice - - return z_slice - - def get_2Dimg_from_3D(self, imgData, isLayer0=True, frame_i=None): - posData = self.data[self.pos_i] - if frame_i is None: - frame_i = posData.frame_i - if frame_i < 0: - frame_i = 0 - frame_i = posData.frame_i = 0 - - axis_slice = self.zSliceScrollBar.sliderPosition() - if self.switchPlaneCombobox.depthAxes() == 'x': - return imgData[:, :, axis_slice].copy() - elif self.switchPlaneCombobox.depthAxes() == 'y': - return imgData[:, axis_slice].copy() - - idx = (posData.filename, frame_i) - zProjHow_L0 = self.zProjComboBox.currentText() - if isLayer0: - try: - z = posData.segmInfo_df.at[idx, 'z_slice_used_gui'] - except ValueError as e: - z = posData.segmInfo_df.loc[idx, 'z_slice_used_gui'].iloc[0] - zProjHow = zProjHow_L0 - else: - z = self.zSliceOverlay_SB.sliderPosition() - zProjHow_L1 = self.zProjOverlay_CB.currentText() - if zProjHow_L1 == 'same as above': - zProjHow = zProjHow_L0 - else: - zProjHow = zProjHow_L1 - - if zProjHow == 'single z-slice': - img = imgData[z] #.copy() - elif zProjHow == 'max z-projection': - img = imgData.max(axis=0) - elif zProjHow == 'mean z-projection': - img = imgData.mean(axis=0) - elif zProjHow == 'median z-proj.': - img = np.median(imgData, axis=0) - return img - - def updateZsliceScrollbar(self, frame_i): - posData = self.data[self.pos_i] - if self.switchPlaneCombobox.depthAxes() != 'z': - return - - idx = (posData.filename, frame_i) - try: - z = posData.segmInfo_df.at[idx, 'z_slice_used_gui'] - except ValueError as e: - z = posData.segmInfo_df.loc[idx, 'z_slice_used_gui'].iloc[0] - try: - zProjHow = posData.segmInfo_df.at[idx, 'which_z_proj_gui'] - except ValueError as e: - zProjHow = posData.segmInfo_df.loc[idx, 'which_z_proj_gui'].iloc[0] - - self.zProjComboBox.setCurrentText(zProjHow) - - reconnect = False - try: - self.zSliceScrollBar.actionTriggered.disconnect() - self.zSliceScrollBar.sliderReleased.disconnect() - reconnect = True - except TypeError: - pass - self.zSliceScrollBar.setSliderPosition(z) - if reconnect: - self.zSliceScrollBar.actionTriggered.connect( - self.zSliceScrollBarActionTriggered - ) - self.zSliceScrollBar.sliderReleased.connect( - self.zSliceScrollBarReleased - ) - self.zSliceSpinbox.setValueNoEmit(z+1) - - def getRawImage(self, frame_i=None, filename=None): - posData = self.data[self.pos_i] - if frame_i is None: - frame_i = posData.frame_i - if filename is None: - rawImgData = posData.img_data[frame_i] - isLayer0 = True - else: - rawImgData = posData.ol_data[filename][frame_i] - isLayer0 = False - if posData.SizeZ > 1: - rawImg = self.get_2Dimg_from_3D(rawImgData, isLayer0=isLayer0) - else: - rawImg = rawImgData - return rawImg - - def getRawImageLayer0(self, frame_i): - posData = self.data[self.pos_i] - - if posData.SizeZ > 1: - img = posData.img_data[frame_i] - self.updateZsliceScrollbar(frame_i) - img = self.get_2Dimg_from_3D(img) - else: - img = posData.img_data[frame_i].copy() - - if img.ndim == 2: - return img - if img.ndim == 3 and img.shape[-1] in (3, 4): - return img - - raise ValueError( - 'Raw image for display must be 2D (Y, X) or RGB/A (Y, X, 3 or 4); ' - f'got shape={getattr(img, "shape", None)}, ndim={getattr(img, "ndim", None)} ' - f'for frame_i={frame_i} (metadata SizeT={posData.SizeT}, SizeZ={posData.SizeZ}). ' - 'Check that metadata SizeT/SizeZ matches the loaded array (e.g. squeezed TIFF vs CSV).' - ) - - def initFloodMaskImage(self): - posData = self.data[self.pos_i] - self.flood_img = posData.img_data[posData.frame_i] - if not self.isSegm3D and posData.SizeZ > 1: - self.flood_img = self.get_2Dimg_from_3D(self.flood_img) - return - - def getMagicWandFloodTolerance(self): - tol_perc = self.wandControlsToolbar.toleranceSpinbox.value() - if tol_perc == 0: - return - - posData = self.data[self.pos_i] - _min, _max = posData.img_data_min_max - tol_fraction = tol_perc/100 - tol = (_max - _min) * tol_fraction - - return tol - - def getImage(self, frame_i=None, raw=False): - posData = self.data[self.pos_i] - if frame_i is None: - frame_i = posData.frame_i - - if raw: - return self.getRawImageLayer0(frame_i) - - if self.viewPreprocDataToggle.isChecked(): - try: - img = posData.preproc_img_data[frame_i] - if posData.SizeZ == 1: - return np.array(img) - - self.updateZsliceScrollbar(frame_i) - z_slice = self.z_slice_index() - img = img[z_slice] - return img - except Exception as err: - # self.logger.warning( - # 'Pre-processed image not existing --> returning raw image' - # ) - return self.getRawImageLayer0(frame_i) - - viewCombinedImageData = ( - self.viewCombineChannelDataToggle.isChecked() - and self.combineDialog is not None - and not self.combineDialog.saveAsSegm() - ) - - if viewCombinedImageData: - try: - img = posData.combine_img_data[frame_i] - if posData.SizeZ == 1: - return np.array(img) - - self.updateZsliceScrollbar(frame_i) - z_slice = self.z_slice_index() - img = img[z_slice] - return img - except Exception as err: - # self.logger.warning( - # 'combined image not existing --> returning raw image' - # ) - return self.getRawImageLayer0(frame_i) - - if self.equalizeHistPushButton.isChecked(): - img = posData.equalized_img_data[frame_i] - if posData.SizeZ == 1: - return np.array(img) - - self.updateZsliceScrollbar(frame_i) - z_slice = self.z_slice_index() - img = img[z_slice] - return img - - return self.getRawImageLayer0(frame_i) - - def setImageImg2(self, updateLookuptable=True, set_image=True): - posData = self.data[self.pos_i] - mode = str(self.modeComboBox.currentText()) - if mode == 'Segmentation and Tracking' or self.isSnapshot: - # self.addExistingDelROIs() - allDelIDs, lab2D = self.getDelROIlab() - else: - lab2D = self.get_2Dlab(posData.lab, force_z=False) - allDelIDs = set() - - self.currentLab2D = lab2D - if self.labelsGrad.permanentGreedyCmapAction.isChecked() and updateLookuptable: - self.greedyShuffleCmap(updateImages=False) - - if self.labelsGrad.showLabelsImgAction.isChecked() and set_image: - self.img2.setImage(lab2D, z=self.z_lab(), autoLevels=False) - - if updateLookuptable: - self.updateLookuptable(delIDs=allDelIDs) - - def applyDelROIimg1(self, roi, init=False, ax=0): - if ax == 0: - how = self.drawIDsContComboBox.currentText() - else: - how = self.getAnnotateHowRightImage() - - if ax == 1 and not self.labelsGrad.showRightImgAction.isChecked(): - return - - if init and how.find('contours') == -1: - self.setOverlaySegmMasks(force=True) - return - - posData = self.data[self.pos_i] - delROIs_info = posData.allData_li[posData.frame_i]['delROIs_info'] - try: - idx = delROIs_info['rois'].index(roi) - except Exception as err: - try: - ax.removeDelRoiItem(roi) - except Exception as err: - pass - return - delIDs = delROIs_info['delIDsROI'][idx] - delMask = delROIs_info['delMasks'][idx] - if how.find('nothing') != -1: - return - elif how.find('contours') != -1: - self.updateContoursImage(ax=ax) - - if not delIDs: - return - - if how.find('overlay segm. masks') != -1: - lab = self.currentLab2D.copy() - lab[delMask > 0] = 0 - if ax == 0: - self.labelsLayerImg1.setImage(lab, autoLevels=False) - else: - self.labelsLayerRightImg.setImage(lab, autoLevels=False) - - self.setAllTextAnnotations(labelsToSkip={ID:True for ID in delIDs}) - - def applyDelROIs(self): - self.logger.info('Applying deletion ROIs (if present)...') - - for posData in self.data: - self.current_frame_i = posData.frame_i - for frame_i in range(posData.SizeT): - lab = posData.allData_li[frame_i]['labels'] - if lab is None: - break - delROIs_info = posData.allData_li[frame_i]['delROIs_info'] - delIDs_rois = delROIs_info['delIDsROI'] - if not delIDs_rois: - continue - for delIDs in delIDs_rois: - for delID in delIDs: - lab[lab==delID] = 0 - posData.allData_li[frame_i]['labels'] = lab - # Get the rest of the metadata and store data based on the new lab - posData.frame_i = frame_i - self.get_data() - self.store_data(autosave=False) - - # Back to current frame - posData.frame_i = self.current_frame_i - self.get_data() - - def initTempLayerBrush(self, ID, ax=0): - if ax == 0: - how = self.drawIDsContComboBox.currentText() - else: - how = self.getAnnotateHowRightImage() - - self.hideItemsHoverBrush(ID=ID, force=True) - Y, X = self.img1.image.shape[:2] - tempImage = np.zeros((Y, X), dtype=np.uint32) - if how.find('contours') != -1: - tempImage[self.currentLab2D==ID] = ID - self.brushImage = tempImage.copy() - self.brushContourImage = np.zeros((Y, X, 4), dtype=np.uint8) - color = self.imgGrad.contoursColorButton.color() - self.brushContoursRgba = color.getRgb() - opacity = 1.0 - else: - opacity = self.imgGrad.labelsAlphaSlider.value() - color = self.lut[ID] - lut = np.zeros((2, 4), dtype=np.uint8) - lut[1,-1] = 255 - lut[1,:-1] = color - self.tempLayerImg1.setLookupTable(lut) - self.tempLayerImg1.setOpacity(opacity) - self.tempLayerImg1.setImage(tempImage, force_set_linked=True) - - def _setTempImageBrushContour(self): - pass - - def setTempBrushMaskFromWand(self, flood_mask, init=False): - if not np.any(flood_mask): - return - - posData = self.data[self.pos_i] - mask = np.logical_or( - flood_mask, - posData.lab==posData.brushID - ) - if mask.ndim == 3: - z_slice = self.zSliceScrollBar.sliderPosition() - mask = mask[z_slice] - - self.setTempImg1Brush(init, mask, posData.brushID) - - # @exec_time - def setTempImg1Brush(self, init: bool, mask, ID, toLocalSlice=None, ax=0): - if init: - self.initTempLayerBrush(ID, ax=ax) - - if self.annotContourCheckbox.isChecked(): - brushImage = self.brushImage - else: - brushImage = self.tempLayerImg1.image - - if toLocalSlice is None: - brushImage[mask] = ID - else: - brushImage[toLocalSlice][mask] = ID - - if self.annotContourCheckbox.isChecked(): - try: - obj = skimage.measure.regionprops(brushImage)[0] - except IndexError: - return - objContour = [self.getObjContours(obj)] - # objContour = core.get_obj_contours( - # obj_image=(brushImage>0).astype(np.uint8), local=True - # ) - self.brushContourImage[:] = 0 - img = self.brushContourImage - color = self.brushContoursRgba - cv2.drawContours(img, objContour, -1, color, 1) - self.tempLayerImg1.setImage(img, force_set_linked=True) - else: - self.tempLayerImg1.setImage(brushImage, force_set_linked=True) - - def getLabelsLayerImage(self, ax=0): - if ax == 0: - return self.labelsLayerImg1.image - else: - return self.labelsLayerRightImg.image - - def clearObjFromMask(self, image, mask, toLocalSlice=None): - if mask is None: - return image - - if toLocalSlice is None: - image[mask] = 0 - else: - image[toLocalSlice][mask] = 0 - - return image - - # @exec_time - def setTempImg1Eraser(self, mask, init=False, toLocalSlice=None, ax=0): - if init: - self.erasedLab = np.zeros_like(self.currentLab2D) - - if ax == 0: - how = self.drawIDsContComboBox.currentText() - else: - how = self.getAnnotateHowRightImage() - - if ax == 1 and not self.labelsGrad.showRightImgAction.isChecked(): - return - - if how.find('contours') != -1: - self.clearObjFromMask( - self.contoursImage, mask, toLocalSlice=toLocalSlice - ) - erasedRp = skimage.measure.regionprops(self.erasedLab) - for obj in erasedRp: - self.addObjContourToContoursImage(obj=obj, ax=ax) - elif how.find('overlay segm. masks') != -1: - labelsImage = self.getLabelsLayerImage(ax=ax) - self.clearObjFromMask(labelsImage, mask, toLocalSlice=toLocalSlice) - if ax == 0: - self.labelsLayerImg1.setImage( - self.labelsLayerImg1.image, autoLevels=False - ) - else: - self.labelsLayerRightImg.setImage( - self.labelsLayerRightImg.image, autoLevels=False - ) - - def _setTempImgExpandLabelSegmMasks(self, prevCoords, ax=0): - # Remove previous overlaid mask - labelsImage = self.getLabelsLayerImage(ax=ax) - labelsImage[prevCoords] = 0 - - # Overlay new moved mask - labelsImage[prevCoords] = self.expandingID - - if ax == 0: - self.labelsLayerImg1.setImage( - self.labelsLayerImg1.image, autoLevels=False) - else: - self.labelsLayerRightImg.setImage( - self.labelsLayerRightImg.image, autoLevels=False) - - def _setTempImgExpandLabelContours(self, prevCoords, ax=0): - self.contoursImage[prevCoords] = [0,0,0,0] - currentLab2Drp = skimage.measure.regionprops(self.currentLab2D) - for obj in currentLab2Drp: - if obj.label == self.expandingID: - # self.clearObjContour(obj=obj, ax=ax) - self.addObjContourToContoursImage(obj=obj, ax=ax, force=True) - break - - def setTempImgExpandLabel(self, prevCoords, expandedObjCoords, ax=0): - if ax == 0: - how = self.drawIDsContComboBox.currentText() - else: - how = self.getAnnotateHowRightImage() - - self._setTempImgExpandLabelContours(prevCoords, ax=ax) - - # if how.find('overlay segm. masks') != -1: - # self._setTempImgExpandLabelSegmMasks(ax=ax) - # else: - # self._setTempImgExpandLabelContours(ax=ax) - - def setTempImg1MoveLabel(self, ax=0): - if ax == 0: - how = self.drawIDsContComboBox.currentText() - else: - how = self.getAnnotateHowRightImage() - - if how.find('contours') != -1: - currentLab2Drp = skimage.measure.regionprops(self.currentLab2D) - for obj in currentLab2Drp: - if obj.label == self.movingID: - self.addObjContourToContoursImage(obj=obj, ax=ax) - break - elif how.find('overlay segm. masks') != -1: - if ax == 0: - self.labelsLayerImg1.setImage(self.currentLab2D, autoLevels=False) - self.highLightIDLayerImg1.image[:] = 0 - mask = self.currentLab2D==self.movingID - self.highLightIDLayerImg1.image[mask] = self.movingID - highlightedImage = self.highLightIDLayerImg1.image - self.highLightIDLayerImg1.setImage(highlightedImage) - else: - self.labelsLayerRightImg.setImage( - self.currentLab2D, autoLevels=False - ) - self.highLightIDLayerRightImage.image[:] = 0 - mask = self.currentLab2D==self.movingID - self.highLightIDLayerRightImage.image[mask] = self.movingID - highlightedImage = self.highLightIDLayerRightImage.image - self.highLightIDLayerRightImage.setImage(highlightedImage) - - def addMissingIDs_cca_df(self, posData): - base_cca_df = self.getBaseCca_df() - if posData.cca_df is None: - posData.cca_df = base_cca_df - return - - posData.cca_df = posData.cca_df.combine_first(base_cca_df) - - def update_cca_df_relabelling(self, posData, oldIDs, newIDs): - relIDs = posData.cca_df['relative_ID'] - posData.cca_df['relative_ID'] = relIDs.replace(oldIDs, newIDs) - mapper = dict(zip(oldIDs, newIDs)) - posData.cca_df = posData.cca_df.rename(index=mapper) - - def update_cca_df_deletedIDs( - self, posData, deletedIDs, dropInPast=True, dropInFuture=True - ): - if posData.cca_df is None: - return - - # Store cca_df for undo action - undoId = uuid.uuid4() - self.storeUndoRedoCca(posData.frame_i, posData.cca_df, undoId) - - try: - relIDs = ( - posData.cca_df.reindex(deletedIDs, fill_value=-1) - ['relative_ID'] - ) - except KeyError as err: - return - - posData.cca_df = posData.cca_df.drop(deletedIDs, errors='ignore') - if self.isSnapshot: - self.update_cca_df_newIDs(posData, relIDs) - else: - self.updateCcaDfDeletedIDsTimelapse( - posData, relIDs, deletedIDs, undoId, dropInPast, dropInFuture - ) - - @disableWindow - def updateCcaDfDeletedIDsTimelapse( - self, posData, relIDsOfDelIDs, deletedIDs, undoId, - dropInPast, dropInFuture - ): - # Get status of the relIDs (of deleted IDs) to restore - relIDsCcaStatus = {} - for relID in relIDsOfDelIDs: - try: - ccs = posData.cca_df.at[relID, 'cell_cycle_stage'] - relationship = posData.cca_df.at[relID, 'relationship'] - except Exception as err: - continue - - ccaStatus = core.getBaseCca_df([relID]).loc[relID] - if relationship == 'mother' and ccs == 'S': - for past_frame_i in range(posData.frame_i-1, -1, -1): - cca_df_i = self.get_cca_df( - frame_i=past_frame_i, return_df=True - ) - ccs_past = cca_df_i.at[relID, 'cell_cycle_stage'] - if ccs_past == 'G1': - ccaStatus = cca_df_i.loc[relID] - break - - posData.cca_df.loc[relID] = ccaStatus - self.store_data(autosave=False) - relIDsCcaStatus[relID] = ccaStatus - - for fut_frame_i in range(posData.frame_i+1, posData.SizeT): - cca_df_i = self.get_cca_df(frame_i=fut_frame_i, return_df=True) - if cca_df_i is None: - # ith frame was not visited yet - break - - self.storeUndoRedoCca(fut_frame_i, cca_df_i, undoId) - - if dropInFuture: - cca_df_i = cca_df_i.drop(deletedIDs, errors='ignore') - else: - for delID in deletedIDs: - dataDict = posData.allData_li[fut_frame_i] - delIDexists = dataDict['IDs_idxs'].get(delID, False) - if not delIDexists: - continue - - cca_df_i.loc[delID] = core.getBaseCca_df([delID]).loc[delID] - - areRelIDsPresent = False - for relID in relIDsOfDelIDs: - try: - ccs = cca_df_i.at[relID, 'cell_cycle_stage'] - relationship = cca_df_i.at[relID, 'relationship'] - ccaStatus = relIDsCcaStatus[relID] - cca_df_i.loc[relID] = ccaStatus - areRelIDsPresent = True - except Exception as err: - continue - - if not areRelIDsPresent: - break - - self.store_cca_df( - frame_i=fut_frame_i, cca_df=cca_df_i, autosave=False - ) - - # Correct past frames - for past_frame_i in range(posData.frame_i-1, -1, -1): - cca_df_i = self.get_cca_df(frame_i=past_frame_i, return_df=True) - if cca_df_i is None: - # ith frame was not visited yet - break - - self.storeUndoRedoCca(past_frame_i, cca_df_i, undoId) - if dropInPast: - cca_df_i = cca_df_i.drop(deletedIDs, errors='ignore') - else: - for delID in deletedIDs: - dataDict = posData.allData_li[past_frame_i] - delIDexists = dataDict['IDs_idxs'].get(delID, False) - if not delIDexists: - continue - - cca_df_i.loc[delID] = core.getBaseCca_df([delID]).loc[delID] - - areRelIDsPresent = False - for relID in relIDsOfDelIDs: - try: - ccs = cca_df_i.at[relID, 'cell_cycle_stage'] - relationship = cca_df_i.at[relID, 'relationship'] - ccaStatus = relIDsCcaStatus[relID] - cca_df_i.loc[relID] = ccaStatus - areRelIDsPresent = True - except Exception as err: - continue - - if not areRelIDsPresent: - break - - self.store_cca_df( - frame_i=past_frame_i, cca_df=cca_df_i, autosave=False - ) - - def update_cca_df_newIDs(self, posData, new_IDs): - for newID in new_IDs: - self.addIDBaseCca_df(posData, newID) - - def update_cca_df_snapshots(self, editTxt, posData): - cca_df = posData.cca_df - cca_df_IDs = cca_df.index - if editTxt == 'Delete ID': - deleted_IDs = [ID for ID in cca_df_IDs if ID not in posData.IDs] - self.update_cca_df_deletedIDs(posData, deleted_IDs) - - elif editTxt == 'Separate IDs': - new_IDs = [ID for ID in posData.IDs if ID not in cca_df_IDs] - self.update_cca_df_newIDs(posData, new_IDs) - deleted_IDs = [ID for ID in cca_df_IDs if ID not in posData.IDs] - self.update_cca_df_deletedIDs(posData, deleted_IDs) - - elif editTxt == 'Edit ID': - new_IDs = [ID for ID in posData.IDs if ID not in cca_df_IDs] - self.update_cca_df_newIDs(posData, new_IDs) - old_IDs = [ID for ID in cca_df_IDs if ID not in posData.IDs] - self.update_cca_df_deletedIDs(posData, old_IDs) - - elif editTxt == 'Annotate ID as dead': - return - - elif editTxt == 'Deleted non-selected objects': - deleted_IDs = [ID for ID in cca_df_IDs if ID not in posData.IDs] - self.update_cca_df_deletedIDs(posData, deleted_IDs) - - elif editTxt == 'Delete ID with eraser': - deleted_IDs = [ID for ID in cca_df_IDs if ID not in posData.IDs] - self.update_cca_df_deletedIDs(posData, deleted_IDs) - - elif editTxt == 'Add new ID with brush tool': - new_IDs = [ID for ID in posData.IDs if ID not in cca_df_IDs] - self.update_cca_df_newIDs(posData, new_IDs) - - elif editTxt == 'Merge IDs': - deleted_IDs = [ID for ID in cca_df_IDs if ID not in posData.IDs] - self.update_cca_df_deletedIDs(posData, deleted_IDs) - - elif editTxt == 'Add new ID with curvature tool': - new_IDs = [ID for ID in posData.IDs if ID not in cca_df_IDs] - self.update_cca_df_newIDs(posData, new_IDs) - - elif editTxt == 'Delete IDs using ROI': - deleted_IDs = [ID for ID in cca_df_IDs if ID not in posData.IDs] - self.update_cca_df_deletedIDs(posData, deleted_IDs) - - elif editTxt == 'Repeat segmentation': - posData.cca_df = self.getBaseCca_df() - - def fixCcaDfAfterEdit(self, editTxt): - posData = self.data[self.pos_i] - if posData.cca_df is not None: - # For snapshot mode we fix or reinit cca_df depending on the edit - self.update_cca_df_snapshots(editTxt, posData) - self.store_data() - - def isFrameCcaAnnotated(self): - posData = self.data[self.pos_i] - acdc_df = posData.allData_li[posData.frame_i]['acdc_df'] - if acdc_df is None: - return False - - return 'cell_cycle_stage' in acdc_df.columns - - def warnEditingWithCca_df( - self, editTxt, return_answer=False, get_answer=False, - get_cancelled=False, update_images=True - ): - # Function used to warn that the user is editing in "Segmentation and - # Tracking" mode a frame that contains cca annotations. - # Ask whether to remove annotations from all future frames - if self.isSnapshot: - return True - - posData = self.data[self.pos_i] - acdc_df = posData.allData_li[posData.frame_i]['acdc_df'] - - if acdc_df is None and self.lineage_tree is None: - if update_images: - self.updateAllImages() - return True - - cell_cycle_stage_present = ( - acdc_df is not None and 'cell_cycle_stage' in acdc_df.columns - ) - lineage_tree_present = ( - self.lineage_tree is not None or 'parent_ID_tree' in acdc_df.columns - ) - if not cell_cycle_stage_present and not lineage_tree_present: - if update_images: - self.updateAllImages() - return True - - action = self.warnEditingWithAnnotActions.get(editTxt, None) - if action is not None and not action.isChecked(): - # user has checked that he does not want to be asked again AND he doesnt want to delete - if update_images: - self.updateAllImages() - return True - - msg = widgets.myMessageBox() - warn_type = 'cell cycle annotations' if cell_cycle_stage_present else 'lineage tree annotations' - txt = html_utils.paragraph( - f'You modified a frame that has {warn_type}.

' - f'The change "{editTxt}" most likely makes the ' - 'annotations wrong.

' - 'If you really want to apply this change we reccommend to remove' - f'ALL {warn_type}
' - 'from current frame to the end.

' - 'What do you want to do?' - ) - if action is not None: - checkBox = QCheckBox('Remember my choice and do not ask again') - else: - checkBox = None - - dropDelIDsNoteText = ( - '' if editTxt.find('Delete') == -1 else ' (drop removed IDs)' - ) - _, removeAnnotButton, _ = msg.warning( - self, 'Edited segmentation with annotations!', txt, - buttonsTexts=( - 'Cancel', - 'Remove annotations from future frames (RECOMMENDED)', - f'Do not remove annotations{dropDelIDsNoteText}' - ), widgets=checkBox - ) - if msg.cancel: - if get_cancelled: - return 'cancelled' - removeAnnotations = False - return removeAnnotations - - if action is not None: - action.setChecked(not checkBox.isChecked()) - action.removeAnnot = msg.clickedButton == removeAnnotButton - - if return_answer: - return msg.clickedButton == removeAnnotButton - - if (msg.clickedButton == removeAnnotButton) and cell_cycle_stage_present: - self.resetFutureCcaColCurrentFrame() - self.resetCcaFuture(posData.frame_i+1) - self.updateAllImages() - elif (msg.clickedButton == removeAnnotButton) and lineage_tree_present: - self.resetLin_tree_future() - self.updateAllImages() - else: - if dropDelIDsNoteText and posData.cca_df is not None: - delIDs = [ - ID for ID in posData.cca_df.index if ID not in posData.IDs - ] - self.update_cca_df_deletedIDs( - posData, delIDs, dropInPast=False - ) - self.addMissingIDs_cca_df(posData) - self.updateAllImages() - self.store_data() - # if action is not None: - # if action.removeAnnot: - # self.store_data() - # posData.frame_i -= 1 - # self.get_data() - # if lineage_tree_present: - # self.resetLin_tree_future() - # self.resetCcaFuture(posData.frame_i) - # self.next_frame() - - if get_answer: - return msg.clickedButton == removeAnnotButton - else: - return True - - def warnRepeatTrackingVideoOnVisitedFrames(self, last_tracked_i, start_n): - msg = widgets.myMessageBox() - txt = html_utils.paragraph( - 'You are repeating tracking on frames that have already ' - 'been visited/tracked before.

' - 'This will very likely make the annotations wrong.

' - 'If you really want to repeat tracking on the frames before ' - f'{last_tracked_i+1} the annotations from frame ' - f'{start_n} to frame {last_tracked_i+1} ' - 'will be removed.

' - 'Do you want to continue?' - ) - noButton, yesButton = msg.warning( - self, 'Repating tracking with annotations!', txt, - buttonsTexts=( - ' No, stop tracking and keep annotations.', - ' Yes, repeat tracking and DELETE annotations.' - ) - ) - if msg.cancel: - return False - - if msg.clickedButton == noButton: - return False - else: - return True - - def warnRepeatTrackingVideoWithAnnotations(self, last_tracked_i, start_n): - msg = widgets.myMessageBox() - txt = html_utils.paragraph( - 'You are repeating tracking on frames that have cell cycle ' - 'annotations.

' - 'This will very likely make the annotations wrong.

' - 'If you really want to repeat tracking on the frames before ' - f'{last_tracked_i+1} the annotations from frame ' - f'{start_n} to frame {last_tracked_i+1} ' - 'will be removed.

' - 'Do you want to continue?' - ) - noButton, yesButton = msg.warning( - self, 'Repating tracking with annotations!', txt, - buttonsTexts=( - ' No, stop tracking and keep annotations.', - ' Yes, repeat tracking and DELETE annotations.' - ) - ) - if msg.cancel: - return False - - if msg.clickedButton == noButton: - return False - else: - return True - - def setDelRoiState(self, roi: pg.ROI, state): - roi.sigRegionChanged.disconnect() - roi.sigRegionChangeFinished.disconnect() - roi.setState(state) - roi.sigRegionChanged.connect(self.delROImoving) - roi.sigRegionChangeFinished.connect(self.delROImovingFinished) - - def addExistingDelROIs(self): - posData = self.data[self.pos_i] - delROIs_info = posData.allData_li[posData.frame_i]['delROIs_info'] - isAx2hidden = not self.labelsGrad.showLabelsImgAction.isChecked() - - for r, roi in enumerate(delROIs_info['rois']): - if isinstance(roi, pg.PolyLineROI) or isAx2hidden: - # PolyLine ROIs are only on ax1 - self.ax1.addDelRoiItem(roi, roi.key) - else: - # Rect ROI is on ax2 because ax2 is visible - self.ax2.addDelRoiItem(roi, roi.key) - - self.setDelRoiState(roi, delROIs_info['state'][r]) - - def updateFramePosLabel(self): - if self.isSnapshot: - posData = self.data[self.pos_i] - self.navSpinBox.setValueNoEmit(self.pos_i+1) - else: - posData = self.data[0] - self.navSpinBox.setValueNoEmit(posData.frame_i+1) - - def highlightHoverID(self, x, y, hoverID=None): - if hoverID is None: - try: - hoverID = self.currentLab2D[int(y), int(x)] - except IndexError: - return - - if hoverID == 0: - return - - posData = self.data[self.pos_i] - objIdx = posData.IDs_idxs[hoverID] - obj = posData.rp[objIdx] - self.goToZsliceSearchedID(obj) - self.highlightSearchedID(hoverID) - - def grayOutHighlightedLabels(self, nonGrayedIDs=None, alpha=None): - if nonGrayedIDs is None: - nonGrayedIDs = set() - - posData = self.data[self.pos_i] - if alpha is None: - alpha = self.imgGrad.labelsAlphaSlider.value() - - if not hasattr(self, 'highlightedLab'): - self.highlightedLab = np.zeros_like(self.currentLab2D) - else: - self.highlightedLab[:] = 0 - - lut = np.zeros((2, 4), dtype=np.uint8) - for _obj in posData.rp: - if not self.isObjVisible(_obj.bbox): - continue - if _obj.label not in nonGrayedIDs: - continue - _slice = self.getObjSlice(_obj.slice) - _objMask = self.getObjImage(_obj.image, _obj.bbox) - self.highlightedLab[_slice][_objMask] = _obj.label - rgb = self.lut[_obj.label].copy() - lut[1, :-1] = rgb - # Set alpha to 0.7 - lut[1, -1] = 178 - - return lut - - def grayOutOverlaySegm(self, ax=0): - if ax == 0: - how = self.drawIDsContComboBox.currentText() - else: - how = self.getAnnotateHowRightImage() - - isOverlaySegmActive = how.find('segm. masks') != -1 - if not isOverlaySegmActive: - return - - grayedLut = self.grayOutHighlightedLabels() - - def highlightHoverIDsKeptObj(self, x, y, hoverID=None): - if hoverID is None: - try: - hoverID = self.currentLab2D[int(y), int(x)] - except IndexError: - return - - self.highlightSearchedID(hoverID, greyOthers=False) - - if hoverID == 0 and self.highlightedID == 0: - return - - if hoverID == 0 and self.highlightedID != 0: - self.clearHighlightedKeepIDs() - for ID in self.keptObjectsIDs: - self.highlightLabelID(ID) - return - - posData = self.data[self.pos_i] - try: - objIdx = posData.IDs_idxs[hoverID] - except KeyError as err: - return - - obj = posData.rp[objIdx] - self.goToZsliceSearchedID(obj) - - for ID in self.keptObjectsIDs: - self.highlightLabelID(ID) - - def getHighlightedID(self): - if self.highlightedID > 0: - return self.highlightedID - - doHighlight = ( - self.propsDockWidget.isVisible() - and ( - self.guiTabControl.highlightCheckbox.isChecked() - or self.guiTabControl.highlightSearchedCheckbox.isChecked() - ) - ) - if not doHighlight: - return 0 - - return self.guiTabControl.propsQGBox.idSB.value() - - def clearHighlightedKeepIDs(self): - self.setAllTextAnnotations() - self.highlightedID = 0 - self.searchedIDitemRight.setData([], []) - self.searchedIDitemLeft.setData([], []) - self.highLightIDLayerImg1.clear() - self.highLightIDLayerRightImage.clear() - - def setHighlighedIDfromToolbar(self, ID: int): - self.findID(ID=ID) - - def highlightSearchedID(self, ID, force=False, greyOthers=True): - self.highlightIDToolbar.setIDNoSignals(ID) - - if ID == 0: - self.highlightIDToolbar.setVisible(False) - return - - if ID == self.highlightedID and not force: - return - - doHighlight = ( - self.propsDockWidget.isVisible() - and ( - self.guiTabControl.highlightCheckbox.isChecked() - or self.guiTabControl.highlightSearchedCheckbox.isChecked() - ) - ) - if doHighlight: - self.highlightedID = self.guiTabControl.propsQGBox.idSB.value() - ID = self.highlightedID - - if self.highlightedID > 0: - self.clearHighlightedText() - - self.searchedIDitemRight.setData([], []) - self.searchedIDitemLeft.setData([], []) - - posData = self.data[self.pos_i] - - self.highlightedID = ID - self.highlightIDToolbar.setVisible(True) - - objIdx = posData.IDs_idxs.get(ID) - if objIdx is None: - return - - obj = posData.rp[objIdx] - isObjVisible = self.isObjVisible(obj.bbox) - if not isObjVisible: - return - - if greyOthers: - self.textAnnot[0].grayOutAnnotations() - self.textAnnot[1].grayOutAnnotations() - - how_ax1 = self.drawIDsContComboBox.currentText() - how_ax2 = self.getAnnotateHowRightImage() - isOverlaySegm_ax1 = how_ax1.find('segm. masks') != -1 - isOverlaySegm_ax2 = how_ax2.find('segm. masks') != -1 - alpha = self.imgGrad.labelsAlphaSlider.value() - - if isOverlaySegm_ax1 or isOverlaySegm_ax2: - grayedLut = self.grayOutHighlightedLabels( - nonGrayedIDs={obj.label}, - alpha=alpha - ) - - cont = None - contours = None - if isOverlaySegm_ax1: - self.highLightIDLayerImg1.setLookupTable(grayedLut) - self.highLightIDLayerImg1.setImage(self.highlightedLab) - self.labelsLayerImg1.setOpacity(alpha/3) - else: - contours = self.getObjContours(obj, all_external=True) - for cont in contours: - self.searchedIDitemLeft.addPoints(cont[:,0]+0.5, cont[:,1]+0.5) - - if isOverlaySegm_ax2: - self.highLightIDLayerRightImage.setLookupTable(grayedLut) - self.highLightIDLayerRightImage.setImage(self.highlightedLab) - self.labelsLayerRightImg.setOpacity(alpha/3) - else: - if contours is None: - contours = self.getObjContours(obj, all_external=True) - for cont in contours: - self.searchedIDitemRight.addPoints(cont[:,0]+0.5, cont[:,1]+0.5) - - # Gray out all IDs excpet searched one - lut = self.lut.copy() # [:max(posData.IDs)+1] - lut[:ID] = lut[:ID]*0.2 - lut[ID+1:] = lut[ID+1:]*0.2 - self.img2.setLookupTable(lut) - - # Highlight text - self.highlightLabelID(ID, ax=0) - self.highlightLabelID(ID, ax=1) - - def _drawGhostContour(self, x, y): - if self.ghostObject is None: - return - - ID = self.ghostObject.label - yc, xc = self.ghostObject.local_centroid - Dx = x-xc - Dy = y-yc - xx = self.ghostObject.xx_contour + Dx - yy = self.ghostObject.yy_contour + Dy - self.ghostContourItemLeft.setData( - xx, yy, fontSize=self.fontSize, ID=ID, y_cursor=y, x_cursor=x - ) - self.ghostContourItemRight.setData( - xx, yy, fontSize=self.fontSize, ID=ID, y_cursor=y, x_cursor=x - ) - - def _drawManualBackgroundObjContour(self, x, y): - if self.manualBackgroundObj is None: - return - - ID = self.manualBackgroundObj.label - yc, xc = self.manualBackgroundObj.local_centroid - Dx = x-xc - Dy = y-yc - xx = self.manualBackgroundObj.xx_contour + Dx - yy = self.manualBackgroundObj.yy_contour + Dy - self.manualBackgroundObjItem.setData( - xx, yy, fontSize=self.fontSize, ID=ID, y_cursor=y, x_cursor=x - ) - - def _drawGhostMask(self, x, y): - if self.ghostObject is None: - return - - self.clearGhostMask() - ID = self.ghostObject.label - h, w = self.ghostObject.image.shape[-2:] - yc, xc = self.ghostObject.local_centroid - Dx = int(x-xc) - Dy = int(y-yc) - bbox = ((Dy, Dy+h), (Dx, Dx+w)) - - Y, X = self.currentLab2D.shape - slices = myutils.get_slices_local_into_global_arr(bbox, (Y, X)) - slice_global_to_local, slice_crop_local = slices - - obj_image = self.ghostObject.image[slice_crop_local] - - self.ghostMaskItemLeft.image[slice_global_to_local][obj_image] = ID - self.ghostMaskItemLeft.updateGhostImage( - fontSize=self.fontSize, ID=ID, y_cursor=y, x_cursor=x - ) - - self.ghostMaskItemRight.image[slice_global_to_local][obj_image] = ID - self.ghostMaskItemRight.updateGhostImage( - fontSize=self.fontSize, ID=ID, y_cursor=y, x_cursor=x - ) - - def drawManualBackgroundObj(self, x, y): - if x is None or y is None: - self.clearGhost() - return - - self._drawManualBackgroundObjContour(x, y) - - def drawManualTrackingGhost(self, x, y): - if not self.manualTrackingToolbar.showGhostCheckbox.isChecked(): - return - - if x is None or y is None: - self.clearGhost() - return - - if self.manualTrackingToolbar.ghostContourRadiobutton.isChecked(): - self._drawGhostContour(x, y) - else: - self._drawGhostMask(x, y) - - def restoreDefaultSettings(self): - df = self.df_settings - df.at['contLineWeight', 'value'] = 1 - df.at['mothBudLineSize', 'value'] = 1 - df.at['mothBudLineColor', 'value'] = (255, 165, 0, 255) - df.at['contLineColor', 'value'] = (205, 0, 0, 220) - - self._updateContColour((205, 0, 0, 220)) - self._updateMothBudLineColour((255, 165, 0, 255)) - self._updateMothBudLineSize(1) - self._updateContLineThickness() - - df.at['overlaySegmMasksAlpha', 'value'] = 0.3 - df.at['img_cmap', 'value'] = 'grey' - self.imgCmap = self.imgGrad.cmaps['grey'] - self.imgCmapName = 'grey' - self.labelsGrad.item.loadPreset('viridis') - df.at['labels_bkgrColor', 'value'] = (25, 25, 25) - - if df.at['is_bw_inverted', 'value'] == 'Yes': - self.invertBw(update=False) - - df = df[~df.index.str.contains('lab_cmap')] - df.to_csv(self.settings_csv_path) - self.imgGrad.restoreState(df) - for items in self.overlayLayersItems.values(): - lutItem = items[1] - lutItem.restoreState(df) - - self.labelsGrad.saveState(df) - self.labelsGrad.restoreState(df, loadCmap=False) - - self.df_settings.to_csv(self.settings_csv_path) - self.updateAllImages() - - def updateLabelsAlpha(self, value): - self.df_settings.at['overlaySegmMasksAlpha', 'value'] = value - self.df_settings.to_csv(self.settings_csv_path) - if self.keepIDsButton.isChecked(): - value = value/3 - self.labelsLayerImg1.setOpacity(value) - self.labelsLayerRightImg.setOpacity(value) - - - def _getImageupdateAllImages(self, image=None): - if image is not None: - return image - - img = self.getImage() - return img - - def setImageImg1(self, image=None): - img = self._getImageupdateAllImages(image=image) - posData = self.data[self.pos_i] - self.img1.setCurrentPosIndex(self.pos_i) - self.img1.setCurrentFrameIndex(posData.frame_i) - if posData.SizeZ > 1: - zProjHow = self.zProjComboBox.currentText() - if zProjHow == 'single z-slice': - z = self.zSliceScrollBar.sliderPosition() - else: - z = zProjHow - - self.img1.setCurrentZsliceIndex(z) - - self.img1.setImage( - img, next_frame_image=self.nextFrameImage(), - scrollbar_value=posData.frame_i+2 - ) - - def getContoursImageItem(self, ax, force=False): - if not self.areContoursRequested(ax) and not force: - return - - if ax == 0: - return self.ax1_contoursImageItem - else: - return self.ax2_contoursImageItem - - def getLostObjImageItem(self, ax): - if ax == 0: - return self.ax1_lostObjImageItem - else: - return self.ax1_lostTrackedObjImageItem - - def getLostTrackedObjImageItem(self, ax): - if ax == 0: - return self.ax1_lostTrackedObjImageItem - else: - return self.ax2_lostTrackedObjImageItem - - def setManualBackgroundImage(self): - if not self.manualBackgroundButton.isChecked(): - return - - posData = self.data[self.pos_i] - if not hasattr(posData, 'manualBackgroundImage'): - self.initManualBackgroundImage() - - contours = [] - for obj in skimage.measure.regionprops(posData.manualBackgroundLab): - obj_contours = self.getObjContours(obj, all_external=True) - contours.extend(obj_contours) - textItem = self.manualBackgroundTextItems[obj.label] - textItem.setText(f'{obj.label}') - self.ax1.addItem(textItem) - yc, xc = obj.centroid - textItem.setPos(xc, yc) - - cv2.drawContours( - posData.manualBackgroundImage, contours, -1, (255, 0, 0, 200), 1 - ) - self.manualBackgroundImageItem.setImage(posData.manualBackgroundImage) - - def setManualBackgrounNextID(self): - posData = self.data[self.pos_i] - currentID = self.manualBackgroundObj.label - idx = posData.IDs_idxs[currentID] - next_idx = idx + 1 - if next_idx >= len(posData.IDs): - return - next_ID = posData.IDs[next_idx] - self.manualBackgroundToolbar.spinboxID.setValue(next_ID) - - def clearManualBackgroundObject(self, ID): - posData = self.data[self.pos_i] - mask = posData.manualBackgroundLab==ID - posData.manualBackgroundImage[mask, :] = 0 - posData.manualBackgroundLab[mask] = 0 - - def addManualBackgroundObject(self, x, y): - posData = self.data[self.pos_i] - - if not hasattr(self, 'manualBackgroundObj'): - self.initManualBackgroundObject() - - Y, X = self.currentLab2D.shape - ymin, xmin, ymax, xmax = self.manualBackgroundObj.bbox - width, height = xmax-xmin, ymax-ymin - yc, xc = self.manualBackgroundObj.local_centroid - xstart, ystart = round(x-xc), round(y-yc) - xstart = xstart if xstart >= 0 else 0 - ystart = ystart if ystart >= 0 else 0 - - xend = xstart+width - yend = ystart+height - xend = xend if xend <= X else X - yend = yend if yend <= Y else Y - - width = xend-xstart - height = yend-ystart - - obj_image = self.manualBackgroundObj.image[:height, :width] - obj_slice = (slice(ystart, yend), slice(xstart, xend)) - ID = self.manualBackgroundObj.label - self.clearManualBackgroundObject(ID) - posData.manualBackgroundLab[obj_slice][obj_image] = ID - - if ID in self.manualBackgroundTextItems: - self.manualBackgroundTextItems[ID].setPos(x, y) - return - - textItem = pg.TextItem( - text=str(ID), color='r', anchor=(0.5, 0.5) - ) - textItem.setFont(font_13px) - textItem.setPos(x, y) - self.manualBackgroundTextItems[ID] = textItem - - self.ax1.addItem(textItem) - - def setManualBackgroundLab(self, load_from_store=False, debug=True): - posData = self.data[self.pos_i] - if posData.manualBackgroundLab is None: - self.initManualBackgroundImage() - - for obj in skimage.measure.regionprops(posData.manualBackgroundLab): - textItem = pg.TextItem(text='', color='r', anchor=(0.5, 0.5)) - if obj.label in self.manualBackgroundTextItems: - continue - self.manualBackgroundTextItems[obj.label] = textItem - - def updateContoursImage(self, ax, delROIsIDs=None, compute=True): - imageItem = self.getContoursImageItem(ax) - if imageItem is None: - return - - if not hasattr(self, 'contoursImage'): - self.initContoursImage() - else: - self.contoursImage[:] = 0 - - contours = [] - for obj in skimage.measure.regionprops(self.currentLab2D): - obj_contours = self.getObjContours( - obj, - all_external=True, - force_calc=compute, - include_internal=self.showAllContoursToggle.isChecked() - ) - contours.extend(obj_contours) - - thickness = self.contLineWeight - color = self.contLineColor - self.setContoursImage(imageItem, contours, thickness, color) - - def setContoursImage(self, imageItem, contours, thickness, color): - cv2.drawContours(self.contoursImage, contours, -1, color, thickness) - imageItem.setImage(self.contoursImage) - - def getObjFromID(self, ID): - posData = self.data[self.pos_i] - try: - idx = posData.IDs_idxs[ID] - except KeyError as e: - # Object already cleared - return - - obj = posData.rp[idx] - return obj - - def setLostObjectContour(self, obj): - allContours = self.getObjContours(obj, all_external=True) - for objContours in allContours: - xx = objContours[:,0] + 0.5 - yy = objContours[:,1] + 0.5 - data = [obj.label]*len(xx) - self.ax1_lostObjScatterItem.addPoints(xx, yy, data=data) - self.ax2_lostObjScatterItem.addPoints(xx, yy) - - def setTrackedLostObjectContour(self, obj): - if self.isExportingVideo: - return - - allContours = self.getObjContours(obj, all_external=True) - for objContours in allContours: - xx = objContours[:,0] + 0.5 - yy = objContours[:,1] + 0.5 - data = [obj.label]*len(xx) - self.ax1_lostTrackedScatterItem.addPoints(xx, yy, data=data) - self.ax2_lostTrackedScatterItem.addPoints(xx, yy) - - def updateLostContoursImage(self, ax, draw=True, delROIsIDs=None): - if draw: - imageItem = self.getLostObjImageItem(ax) - if imageItem is None: - return - - if not hasattr(self, 'lostObjContoursImage'): - self.initLostObjContoursImage() - else: - self.lostObjContoursImage[:] = 0 - - if delROIsIDs is None: - delROIsIDs = set() - - posData = self.data[self.pos_i] - prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[posData.frame_i-1]['IDs_idxs'] - if posData.whitelist is not None and posData.whitelist.whitelistIDs is not None: - whitelist = posData.whitelist.whitelistIDs[posData.frame_i-1] - else: - whitelist = None - - contours = [] - for lostID in posData.lost_IDs: - if lostID in delROIsIDs or (whitelist is not None and lostID not in whitelist): - continue - - obj = prev_rp[prev_IDs_idxs[lostID]] - if not self.isObjVisible(obj.bbox): - continue - - obj_contours = self.getObjContours(obj, all_external=True) - - if ax == 0: - self.addLostObjsToLostObjImage(obj, lostID) - - contours.extend(obj_contours) - - if not draw: - return - - self.drawLostObjContoursImage(imageItem, contours) - - def drawLostObjContoursImage( - self, imageItem, contours, - thickness=1, - color=(255, 165, 0, 255) # orange - ): - img = self.lostObjContoursImage - cv2.drawContours(img, contours, -1, color, thickness) - imageItem.setImage(img) - - def updateLostTrackedContoursImage( - self, ax, delROIsIDs=None, tracked_lost_IDs=None - ): - imageItem = self.getLostTrackedObjImageItem(ax) - if imageItem is None: - return - - if not hasattr(self, 'lostTrackedObjContoursImage'): - self.initLostTrackedObjContoursImage() - else: - self.lostTrackedObjContoursImage[:] = 0 - - if delROIsIDs is None: - delROIsIDs = set() - - posData = self.data[self.pos_i] - if tracked_lost_IDs is None: - tracked_lost_IDs = self.getTrackedLostIDs() - - prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[posData.frame_i-1]['IDs_idxs'] - contours = [] - for tracked_lost_ID in tracked_lost_IDs: - if tracked_lost_ID in delROIsIDs: - continue - - obj = prev_rp[prev_IDs_idxs[tracked_lost_ID]] - if not self.isObjVisible(obj.bbox): - continue - - obj_contours = self.getObjContours(obj, all_external=True) - contours.extend(obj_contours) - - self.drawLostTrackedObjContoursImage(imageItem, contours) - - def drawLostTrackedObjContoursImage(self, imageItem, contours): - thickness = 1 - color = (0, 255, 0, 255) # green - img = self.lostTrackedObjContoursImage - cv2.drawContours(img, contours, -1, color, thickness) - imageItem.setImage(img) - - def getNearestLostObjID(self, y, x): - if not self.annotLostObjsToggle.isChecked(): - return - - posData = self.data[self.pos_i] - if not posData.lost_IDs: - return - - prev_lab = posData.allData_li[posData.frame_i-1]['labels'] - if prev_lab is None: - return - - # if not hasattr(self, 'lostObjContoursImage'): - # self.store_data() - # posData.frame_i -= 1 - # self.get_data() - # self.store_data() - # posData.frame_i += 1 - # self.get_data() - # self.updateLostNewCurrentIDs() - # self.updateLostContoursImage(ax=0) - # self.updateLostContoursImage(ax=1) - # self.updateLostNewCurrentIDs() - - yy, xx, _ = np.nonzero(self.lostObjContoursImage) - lostObjsContourMask = np.zeros(self.currentLab2D.shape, dtype=bool) - lostObjsContourMask[yy.astype(int), xx.astype(int)] = True - - # Add accepted lost IDs - try: - yy, xx, _ = np.nonzero(self.lostTrackedObjContoursImage) - lostObjsContourMask[yy.astype(int), xx.astype(int)] = True - except Exception as err: - pass - - _, y_nearest, x_nearest = core.nearest_nonzero_2D( - lostObjsContourMask, y, x, return_coords=True - ) - nearest_ID = self.get_2Dlab(prev_lab)[y_nearest, x_nearest] - - if nearest_ID == 0: - return - - return nearest_ID - - def setCcaIssueContour(self, obj): - objContours = self.getObjContours(obj, all_external=True) - for cont in objContours: - xx = cont[:,0] + 0.5 - yy = cont[:,1] + 0.5 - self.ax1_lostObjScatterItem.addPoints(xx, yy) - self.textAnnot[0].addObjAnnotation( - obj, 'lost_object', f'{obj.label}?', False - ) - - def isLastVisitedAgainCca(self, curr_df, enforceAll=False): - # Determine if this is the last visited frame for repeating - # bud assignment on non manually corrected_on_frame_i buds. - # The idea is that the user could have assigned division on a cell - # by going previous and we want to check if this cell could be a - # "better" mother for those non manually corrected buds - posData = self.data[self.pos_i] - if curr_df is None: - return False - - if 'cell_cycle_stage' not in curr_df.columns: - return False - - if enforceAll: - return False - - lastVisited = False - posData.new_IDs = [ - ID for ID in posData.new_IDs - if curr_df.at[ID, 'is_history_known'] - and curr_df.at[ID, 'cell_cycle_stage'] == 'S' - ] - if posData.frame_i+1 < posData.SizeT: - next_df = posData.allData_li[posData.frame_i+1]['acdc_df'] - if next_df is None: - lastVisited = True - else: - if 'cell_cycle_stage' not in next_df.columns: - lastVisited = True - else: - lastVisited = True - - return lastVisited - - def highlightNewCellNotEnoughG1cells(self, IDsCellsG1): - posData = self.data[self.pos_i] - for obj in posData.rp: - if obj.label not in IDsCellsG1: - continue - objContours = self.getObjContours(obj) - if objContours is not None: - xx = objContours[:,0] + 0.5 - yy = objContours[:,1] + 0.5 - self.ccaFailedScatterItem.addPoints(xx, yy) - self.textAnnot[0].addObjAnnotation( - obj, 'green', f'{obj.label}?', False - ) - - def handleNoCellsInG1(self, numCellsG1, numNewCells): - posData = self.data[self.pos_i] - self.highlightNewCellNotEnoughG1cells(posData.new_IDs) - continueAnyway = _warnings.warnNotEnoughG1Cells( - numCellsG1, posData.frame_i, numNewCells, qparent=self - ) - if continueAnyway: - notEnoughG1Cells = False - proceed = True - # Annotate the new IDs with unknown history - for ID in posData.new_IDs: - posData.cca_df.loc[ID] = pd.Series(base_cca_dict) - cca_df_ID = self.getStatusKnownHistoryBud(ID) - posData.ccaStatus_whenEmerged[ID] = cca_df_ID - else: - notEnoughG1Cells = True - proceed = False - - # Clear new cells annotations - self.ccaFailedScatterItem.setData([], []) - return notEnoughG1Cells, proceed - - def addObjContourToContoursImage( - self, ID=0, obj=None, ax=0, thickness=None, color=None, - force=False - ): - imageItem = self.getContoursImageItem(ax, force=force) - if imageItem is None: - return - - if obj is None: - obj = self.getObjFromID(ID) - if obj is None: - return - - contours = self.getObjContours(obj, all_external=True) - if thickness is None: - thickness = self.contLineWeight - if color is None: - color = self.contLineColor - - self.setContoursImage(imageItem, contours, thickness, color) - - def clearObjContour( - self, ID=0, obj=None, ax=0, debug=False, updateImage=True - ): - imageItem = self.getContoursImageItem(ax) - if imageItem is None: - return - - if ID > 0: - self.contoursImage[self.currentLab2D==ID] = [0,0,0,0] - else: - obj_slice = self.getObjSlice(obj.slice) - obj_image = self.getObjImage(obj.image, obj.bbox) - self.contoursImage[obj_slice][obj_image] = [0,0,0,0] - - if not updateImage: - return - - imageItem.setImage(self.contoursImage) - - def clearAnnotItems(self): - self.textAnnot[0].clear() - self.textAnnot[1].clear() - - # @exec_time - def setAllTextAnnotations(self, labelsToSkip=None): - delROIsIDs = self.setLostNewOldPrevIDs() - posData = self.data[self.pos_i] - self.textAnnot[0].setAnnotations( - posData=posData, - labelsToSkip=labelsToSkip, - isVisibleCheckFunc=self.isObjVisible, - highlightedID=self.highlightedID, - delROIsIDs=delROIsIDs, - annotateLost=self.annotLostObjsToggle.isChecked(), - getCurrentZfunc=self.z_lab, - getObjCentroidFunc=self.getObjCentroid - ) - self.textAnnot[1].setAnnotations( - posData=posData, labelsToSkip=labelsToSkip, - isVisibleCheckFunc=self.isObjVisible, - highlightedID=self.highlightedID, - delROIsIDs=delROIsIDs, - annotateLost=self.annotLostObjsToggle.isChecked(), - getCurrentZfunc=self.z_lab, - getObjCentroidFunc=self.getObjCentroid - ) - self.textAnnot[0].update() - self.textAnnot[1].update() - return delROIsIDs - - def setAllContoursImages(self, delROIsIDs=None, compute=True): - if compute: - self.computeAllContours() - self.updateContoursImage(ax=0, delROIsIDs=delROIsIDs, compute=compute) - self.updateContoursImage(ax=1, delROIsIDs=delROIsIDs, compute=compute) - - def setAllLostObjContoursImage(self, delROIsIDs=None): - self.updateLostContoursImage(ax=0, delROIsIDs=None) - self.updateLostContoursImage(ax=1, delROIsIDs=None) - - def setAllLostTrackedObjContoursImage(self, delROIsIDs=None): - self.updateLostTrackedContoursImage(ax=0, delROIsIDs=None) - self.updateLostTrackedContoursImage(ax=1, delROIsIDs=None) - - def nextFrameImage(self, current_frame_i=None): - if not self.labelsGrad.showNextFrameAction.isEnabled(): - return - - if not self.labelsGrad.showNextFrameAction.isChecked(): - return - - posData = self.data[self.pos_i] - if current_frame_i is None: - current_frame_i = posData.frame_i - - next_frame_i = current_frame_i + 1 - if next_frame_i >= len(posData.img_data): - img = posData.img_data[-1] - else: - img = posData.img_data[next_frame_i] - - if posData.SizeZ > 1: - img = self.get_2Dimg_from_3D(img, isLayer0=True) - - # img = self.normalizeIntensities(img) - - return img - - def onKeyHome(self): - self.zSliceScrollBar.triggerAction( - QAbstractSlider.SliderAction.SliderSingleStepAdd - ) - - def onKeyEnd(self): - self.zSliceScrollBar.triggerAction( - QAbstractSlider.SliderAction.SliderSingleStepSub - ) - - def onKeyPageUp(self): - isAutoPilotActive = ( - self.autoPilotZoomToObjToggle.isChecked() - and self.autoPilotZoomToObjToolbar.isVisible() - ) - if isAutoPilotActive: - self.pointsLayerAutoPilot('next') - elif self.zSliceScrollBar.isVisible(): - self.zSliceScrollBar.triggerAction( - QAbstractSlider.SliderAction.SliderSingleStepAdd - ) - - def onKeyPageDown(self): - isAutoPilotActive = ( - self.autoPilotZoomToObjToggle.isChecked() - and self.autoPilotZoomToObjToolbar.isVisible() - ) - if isAutoPilotActive: - self.pointsLayerAutoPilot('prev') - elif self.zSliceScrollBar.isVisible(): - self.zSliceScrollBar.triggerAction( - QAbstractSlider.SliderAction.SliderSingleStepAdd - ) - - def keyUpCallback( - self, isBrushActive, isWandActive, isExpandLabelActive, - isLabelRoiCircActive - ): - isAutoPilotActive = ( - self.autoPilotZoomToObjToggle.isChecked() - and self.autoPilotZoomToObjToolbar.isVisible() - ) - if isBrushActive: - brushSize = self.brushSizeSpinbox.value() - self.brushSizeSpinbox.setValue(brushSize+1) - elif isWandActive: - wandTolerance = self.wandControlsToolbar.toleranceSpinbox.value() - self.wandControlsToolbar.toleranceSpinbox.setValue(wandTolerance+1) - elif isExpandLabelActive: - self.expandLabel(dilation=True) - self.expandFootprintSize += 1 - elif isLabelRoiCircActive: - val = self.labelRoiCircularRadiusSpinbox.value() - self.labelRoiCircularRadiusSpinbox.setValue(val+1) - elif isAutoPilotActive: - self.pointsLayerAutoPilot('next') - else: - self.zSliceScrollBar.triggerAction( - QAbstractSlider.SliderAction.SliderSingleStepAdd - ) - - def keyDownCallback( - self, isBrushActive, isWandActive, isExpandLabelActive, - isLabelRoiCircActive - ): - isAutoPilotActive = ( - self.autoPilotZoomToObjToggle.isChecked() - and self.autoPilotZoomToObjToolbar.isVisible() - ) - if isBrushActive: - brushSize = self.brushSizeSpinbox.value() - self.brushSizeSpinbox.setValue(brushSize-1) - elif isWandActive: - wandTolerance = self.wandControlsToolbar.toleranceSpinbox.value() - self.wandControlsToolbar.toleranceSpinbox.setValue(wandTolerance-1) - elif isExpandLabelActive: - self.expandLabel(dilation=False) - self.expandFootprintSize += 1 - elif isLabelRoiCircActive: - val = self.labelRoiCircularRadiusSpinbox.value() - self.labelRoiCircularRadiusSpinbox.setValue(val-1) - elif isAutoPilotActive: - self.pointsLayerAutoPilot('prev') - elif self.isNavigateActionOnNextFrame(): - posData = self.data[self.pos_i] - self.rightImageFramesScrollbar.setValue(posData.frame_i+2) - else: - self.zSliceScrollBar.triggerAction( - QAbstractSlider.SliderAction.SliderSingleStepSub - ) - - # @exec_time - @exception_handler - def updateAllImages( - self, image=None, computePointsLayers=True, computeContours=True, - updateLookuptable=True - ): - self.clearAllItems() - - posData = self.data[self.pos_i] - - self.last_pos_i = self.pos_i - self.last_frame_i = posData.frame_i - - self.rescaleIntensitiesLut(setImage=False) - - self.setImageImg1(image=image) - self.setImageImg2(updateLookuptable=updateLookuptable) - - self.setOverlayImages() - - self.setOverlayLabelsItems() - self.setOverlaySegmMasks() - - if self.slideshowWin is not None: - self.slideshowWin.frame_i = posData.frame_i - self.slideshowWin.update_img() - - # self.update_rp() - - # Annotate ID and draw contours - delROIsIDs = self.setAllTextAnnotations() - self.setAllContoursImages( - delROIsIDs=delROIsIDs, compute=False - ) - - mode = self.modeComboBox.currentText() - self.drawAllMothBudLines() - if mode == 'Normal division: Lineage tree': - self.drawAllLineageTreeLines() - - self.highlightLostNew() - - if self.ccaTableWin is not None: # need to add for lin tree, later - zoomIDs = self.getZoomIDs() - self.ccaTableWin.updateTable(posData.cca_df, IDs=zoomIDs) - - self.doCustomAnnotation(0) - - self.annotate_rip_and_bin_IDs() - self.updateTempLayerKeepIDs() - self.whitelistUpdateTempLayer() - self.drawPointsLayers(computePointsLayers=computePointsLayers) - self.setManualBackgroundImage() - self.annotateAssignedObjsAcdcTrackerSecondStep() - - self.highlightSearchedID(self.highlightedID, force=True) - self.updateTimestampFrame() - - posData.visited = True - - def updateTimestampFrame(self): - if not hasattr(self, 'timestamp'): - return - - if not self.addTimestampAction.isChecked(): - return - - posData = self.data[self.pos_i] - self.timestamp.setText(posData.frame_i) - - def deleteIDFromLab( - self, lab, delID, frame_i=None, delMask=None, shift=False - ): - posData = self.data[self.pos_i] - frame_i = posData.frame_i if frame_i is None else frame_i - - if shift and self.isSegm3D: - lab3D = lab - delMask3D = delMask - lab = self.get_2Dlab(lab) - if delMask is not None: - delMask = self.get_2Dlab(delMask) - rp = skimage.measure.regionprops(lab) - IDs_idxs = {obj.label: idx for idx, obj in enumerate(rp)} - else: - if frame_i==posData.frame_i: - rp = posData.rp - IDs_idxs = posData.IDs_idxs - else: - rp = posData.allData_li[frame_i]['regionprops'] - IDs_idxs = posData.allData_li[frame_i]['IDs_idxs'] - - if isinstance(delID, int): - delID = [delID] - - is_any_id_present = False - for _delID in delID: - if _delID in IDs_idxs: - is_any_id_present = True - break - - if not is_any_id_present: - return lab, delMask - - if delMask is None: - delMask = np.zeros(lab.shape, dtype=bool) - else: - delMask[:] = False - - for _delID in delID: - idx = IDs_idxs.get(_delID, None) - if idx is None: - continue - obj = rp[idx] - delMask[obj.slice][obj.image] = True - lab[delMask] = 0 - - if shift and self.isSegm3D: - self.set_2Dlab(lab, lab3D=lab3D) - lab = lab3D - if delMask3D is not None: - self.set_2Dlab(delMask, lab3D=delMask3D) - delMask = delMask3D - - return lab, delMask - - def removeStoredContours(self, delID, frame_i=None, z_slice=None): - posData = self.data[self.pos_i] - - if frame_i is None: - frame_i = posData.frame_i - - dataDict = posData.allData_li[posData.frame_i] - try: - newContours = {} - for key, contours in dataDict['contours'].items(): - ID = key[0] - if ID == delID: - continue - - if z_slice is not None: - z_slice_i = key[1] - if z_slice_i != z_slice: - continue - - newContours[key] = contours - - dataDict['contours'] = newContours - except KeyError as err: - pass - - @disableWindow - def deleteIDmiddleClick( - self, delIDs: Iterable, applyFutFrames, includeUnvisited, - shift=False - ): - self.clearHighlightedID() - - posData = self.data[self.pos_i] - current_frame_i = posData.frame_i - - # Apply Delete ID to future frames if requested - if applyFutFrames: - delMask = np.zeros(posData.lab.shape, dtype=bool) - # Store current data before going to future frames - self.store_data() - segmSizeT = len(posData.segm_data) - for i in range(posData.frame_i+1, segmSizeT): - lab = posData.allData_li[i]['labels'] - if lab is None and not includeUnvisited: - self.enqAutosave() - break - - if lab is not None: - # Visited frame - lab, _ = self.deleteIDFromLab( - lab, delIDs, frame_i=i, delMask=delMask, shift=shift - ) - - # Store change - posData.allData_li[i]['labels'] = lab - # Get the rest of the stored metadata based on the new lab - posData.frame_i = i - self.get_data() - self.store_data(autosave=False) - elif includeUnvisited: - # Unvisited frame (includeUnvisited = True) - lab = posData.segm_data[i] - lab, _ = self.deleteIDFromLab( - lab, delIDs, frame_i=i, delMask=delMask, shift=shift - ) - - # Back to current frame - if applyFutFrames: - posData.frame_i = current_frame_i - self.get_data() - - z_slice = None - if shift and self.isSegm3D: - z_slice = self.z_lab() - - posData.lab, delID_mask = self.deleteIDFromLab( - posData.lab, delIDs, shift=shift - ) - for _delID in delIDs: - self.clearObjContour(ID=_delID, ax=0) - self.clearObjContour(ID=_delID, ax=1) - if z_slice is None: - self.removeObjectFromRp(_delID) - self.removeStoredContours(_delID, z_slice=z_slice) - - if shift and self.isSegm3D: - self.update_rp() - - self.store_data(autosave=False) - self.whitelistPropagateIDs(IDs_to_remove=delIDs, curr_frame_only=(not applyFutFrames)) - return delID_mask - - def hideOverlayLabelsItems(self, specific=None): - if specific is None: - specific = self.overlayLabelsItems.keys() - for segmEndname in specific: - imageItem, contoursItem, gradItem = self.overlayLabelsItems[segmEndname] - imageItem.setVisible(False) - contoursItem.setVisible(False) - gradItem.setVisible(False) - - def showOverlayLabelsItems(self, specific=None): - if specific is None: - specific = self.overlayLabelsItems.keys() - for segmEndname in specific: - imageItem, contoursItem, gradItem = self.overlayLabelsItems[segmEndname] - drawMode = self.drawModeOverlayLabelsChannels[segmEndname] - if drawMode == 'Draw contours': - contoursItem.setVisible(True) - elif drawMode == 'Overlay labels': - imageItem.setVisible(True) - gradItem.setVisible(True) - - def setOverlayLabelsItems(self, specific=None): - if not self.overlayLabelsButton.isChecked(): - self.hideOverlayLabelsItems(specific=specific) - return - - if specific is None: - specific = self.drawModeOverlayLabelsChannels.keys() - - for segmEndname in specific: - drawMode = self.drawModeOverlayLabelsChannels[segmEndname] - ol_lab = self.getOverlayLabelsData(segmEndname) - items = self.overlayLabelsItems[segmEndname] - imageItem, contoursItem, gradItem = items - contoursItem.clear() - if drawMode == 'Draw contours': - for obj in skimage.measure.regionprops(ol_lab): - contours = self.getObjContours( - obj, all_external=True - ) - for cont in contours: - contoursItem.addPoints(cont[:,0]+0.5, cont[:,1]+0.5) - elif drawMode == 'Overlay labels': - imageItem.setImage(ol_lab, autoLevels=False) - self.showOverlayLabelsItems(specific=specific) - - def getOverlayLabelsData(self, segmEndname): - posData = self.data[self.pos_i] - - if posData.ol_labels_data is None: - self.loadOverlayLabelsData(segmEndname) - elif segmEndname not in posData.ol_labels_data: - self.loadOverlayLabelsData(segmEndname) - - comb_seg = False - if 'combined segm.' == segmEndname: - comb_seg = True - if not self.isSegm3D: - zStackImg = self.data[0].SizeZ > 1 - if zStackImg: - selected_z_stack = self.zSliceScrollBar.sliderPosition() - else: - selected_z_stack = 0 - out = posData.ol_labels_data['combined segm.'][posData.frame_i][selected_z_stack] - return out.astype(np.uint32) - - if self.isSegm3D: - zProjHow = self.zProjComboBox.currentText() - isZslice = zProjHow == 'single z-slice' - if isZslice: - z = self.zSliceScrollBar.sliderPosition() - ol_lab = posData.ol_labels_data[segmEndname][posData.frame_i][z] - if comb_seg: - ol_lab = ol_lab.astype(np.uint32) - return ol_lab - else: - ol_lab = posData.ol_labels_data[segmEndname][posData.frame_i].max(axis=0) - if comb_seg: - ol_lab = ol_lab.astype(np.uint32) - return ol_lab - else: - return posData.ol_labels_data[segmEndname][posData.frame_i] - - def loadOverlayLabelsData(self, segmEndname, pos_i=None): - if pos_i is None: - pos_i = self.pos_i - posData = self.data[pos_i] - - if posData.ol_labels_data is None: - posData.ol_labels_data = {} - if segmEndname == 'combined segm.': - posData.ol_labels_data['combined segm.'] = posData.combine_img_data - return - filePath, filename = load.get_path_from_endname( - segmEndname, posData.images_path - ) - self.logger.info(f'Loading "{segmEndname}.npz"...') - labelsData = np.load(filePath)['arr_0'] - if posData.SizeT == 1: - labelsData = labelsData[np.newaxis] - if self.isSegm3D and labelsData.ndim == 3: - # 2D segm --> stack to 3D - T, Y, X = labelsData.shape - repeat = [labelsData]*posData.SizeZ - labelsData = np.stack(repeat, axis=1) - - - posData.ol_labels_data[segmEndname] = labelsData - - def startBlinkingModeCB(self): - try: - self.timer.stop() - self.stopBlinkTimer.stop() - except Exception as e: - pass - if self.rulerButton.isChecked(): - return - self.timer = QTimer(self) - self.timer.timeout.connect(self.blinkModeComboBox) - self.timer.start(200) - self.stopBlinkTimer = QTimer(self) - self.stopBlinkTimer.timeout.connect(self.stopBlinkingCB) - self.stopBlinkTimer.start(2000) - - def blinkModeComboBox(self): - if self.flag: - self.modeComboBox.setStyleSheet('background-color: orange') - else: - self.modeComboBox.setStyleSheet('background-color: none') - self.flag = not self.flag - - def stopBlinkingCB(self): - self.timer.stop() - self.modeComboBox.setStyleSheet('background-color: none') - - def highlightNewIDs_ccaFailed(self, IDsWithIssue, rp=None): - if rp is None: - posData = self.data[self.pos_i] - rp = posData.rp - for obj in rp: - if obj.label not in IDsWithIssue: - continue - self.setCcaIssueContour(obj) - - # @exec_time - def highlightLostNew(self): - if self.modeComboBox.currentText() == 'Viewer': - return - - posData = self.data[self.pos_i] - delROIsIDs = self.getDelRoisIDs() - - # self.setAllContoursImages(delROIsIDs=delROIsIDs) - if posData.frame_i == 0: - return - - if not self.annotLostObjsToggle.isChecked(): - return - - prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - - if prev_rp is None: - return - - self.setAllLostObjContoursImage(delROIsIDs=delROIsIDs) - self.setAllLostTrackedObjContoursImage(delROIsIDs=delROIsIDs) - - def addLostObjsToLostObjImage(self, lostObj, lostID, force=False): - if not force: - if not self.copyLostObjButton.isChecked(): - return - - obj_slice = self.getObjSlice(lostObj.slice) - obj_image = self.getObjImage(lostObj.image, lostObj.bbox) - self.lostObjImage[obj_slice][obj_image] = lostID - - def highlightHoverLostObj(self, modifiers, event): - noModifier = modifiers == Qt.NoModifier - if not noModifier: - return - - if not self.copyLostObjButton.isChecked(): - return - - if event.isExit(): - return - - posData = self.data[self.pos_i] - x, y = event.pos() - xdata, ydata = int(x), int(y) - try: - hoverLostID = self.lostObjImage[ydata, xdata] - except IndexError: - return - - self.ax1_lostObjScatterItem.hoverLostID = hoverLostID - if hoverLostID == 0: - self.ax1_lostObjScatterItem.setSize(self.contLineWeight+1) - self.ax1_lostObjScatterItem.setData([], []) - else: - prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[posData.frame_i-1]['IDs_idxs'] - lostObj = prev_rp[prev_IDs_idxs[hoverLostID]] - obj_contours = self.getObjContours(lostObj, all_external=True) - for cont in obj_contours: - xx = cont[:,0] - yy = cont[:,1] - self.ax1_lostObjScatterItem.addPoints(xx, yy) - self.ax1_lostObjScatterItem.setSize(self.contLineWeight+2) - - def annotLostObjsToggled(self, checked): - if not self.isDataLoaded: - return - self.updateAllImages() - - def getPrevFrameIDs(self, current_frame_i=None): - posData = self.data[self.pos_i] - if current_frame_i is None: - current_frame_i = posData.frame_i - - if current_frame_i is None: - return [] - - prev_frame_i = current_frame_i - 1 - prevIDs = posData.allData_li[prev_frame_i]['IDs'] - - if prevIDs: - return prevIDs - - # IDs in previous frame were not stored --> load prev lab from HDD - prev_lab = self.get_labels( - from_store=False, - frame_i=prev_frame_i, - return_copy=False - ) - rp = skimage.measure.regionprops(prev_lab) - prevIDs = [obj.label for obj in rp] - return prevIDs - - # @exec_time - def setLostNewOldPrevIDs(self): - posData = self.data[self.pos_i] - if posData.frame_i == 0: - posData.lost_IDs = [] - posData.new_IDs = [] - posData.old_IDs = [] - # posData.multiContIDs = set() - self.titleLabel.setText('Looking good!', color=self.titleColor) - return [] - - # elif self.modeComboBox.currentText() == 'Viewer': - # pass - - out = self.updateLostNewCurrentIDs() - lost_IDs, new_IDs, IDs_with_holes, tracked_lost_IDs, curr_delRoiIDs = ( - out - ) - self.setTitleText( - lost_IDs, new_IDs, IDs_with_holes, tracked_lost_IDs - ) - return curr_delRoiIDs - - - def setTitleFormatter(self, htmlTxt_li, htmlTxtFull_li, pretxt, color, IDs): - if not IDs: - return htmlTxt_li, htmlTxtFull_li - - if isinstance(IDs, set): - IDs = list(IDs) - - trim_IDs = myutils.get_trimmed_list(IDs) - txt = f'{pretxt}: {trim_IDs}' - txt_full = f'{pretxt}:
{IDs}' - - txt = f'{txt}' - txt_full = f'{txt_full}' - - htmlTxt_li.append(txt) - htmlTxtFull_li.append(txt_full) - - return htmlTxt_li, htmlTxtFull_li - - def setTitleText( - self, lost_IDs=None, new_IDs=None, IDs_with_holes=None, - tracked_lost_IDs=None - ): - if self.manualAnnotPastButton.isChecked(): - lockedID = self.editIDspinbox.value() - frame_to_restore = self.manualAnnotState.get('frame_i_to_restore') - txt = ( - f'Locked ID {lockedID} ' - f'since frame n. {frame_to_restore+1}' - ) - htmlTxt = f'{txt}' - self.titleLabel.setText(htmlTxt) - return - - mode = self.modeComboBox.currentText() - try: - posData = self.data[self.pos_i] - posData.segm_data[posData.frame_i] - prev_segmented = True - except IndexError: - prev_segmented = False - - if prev_segmented: - htmlTxt_li = [] - htmlTxtFull_li = [] - else: - htmlTxt = f'Never segmented frame. ' - self.titleLabel.setText(htmlTxt) - self.titleLabel.setToolTip(htmlTxt) - return - - if mode != 'Normal division: Lineage tree': - htmlTxt_li, htmlTxtFull_li = self.setTitleFormatter( - htmlTxt_li, htmlTxtFull_li, 'IDs lost', 'orange', lost_IDs - ) - htmlTxt_li, htmlTxtFull_li = self.setTitleFormatter( - htmlTxt_li, htmlTxtFull_li, 'New IDs', 'red', new_IDs - ) - htmlTxt_li, htmlTxtFull_li = self.setTitleFormatter( - htmlTxt_li, htmlTxtFull_li, 'Acc. IDs lost', 'green', - tracked_lost_IDs - ) - - for i, htmlTxtFull in enumerate(htmlTxtFull_li): - htmlTxtFull_li[i] = htmlTxtFull.replace('Acc.', 'Accepted') - - htmlTxt_li, htmlTxtFull_li = self.setTitleFormatter( - htmlTxt_li, htmlTxtFull_li, 'IDs with holes', 'red', - IDs_with_holes - ) - else: - try: - cells_with_parent, orphan_cells, lost_cells = self.lineage_tree.export_lin_tree_info(posData.frame_i) - except IndexError or KeyError: - title = 'Processing lineage tree...' - htmlTxt = f'{title}' - self.titleLabel.setText(htmlTxt) - self.titleLabel.setToolTip(htmlTxt) - return - except AttributeError: - title = 'Lineage tree still initializing...' - htmlTxt = f'{title}' - self.titleLabel.setText(htmlTxt) - self.titleLabel.setToolTip(htmlTxt) - return - - parent_cell_txt_raw = [] - if cells_with_parent: - # aggregate same parents - parent_cell_groups = dict() - for cell, parent in cells_with_parent: - if parent not in parent_cell_groups: - parent_cell_groups[parent] = [] - parent_cell_groups[parent].append(cell) - for parent, daughters in parent_cell_groups.items(): - cells_str = ','.join([str(daughter) for daughter in daughters]) - parent_cell_txt_raw.append(f'({parent}>{cells_str})') - - htmlTxt_li, htmlTxtFull_li = self.setTitleFormatter( - htmlTxt_li, htmlTxtFull_li, 'New w/out mother', 'red', - orphan_cells - ) - htmlTxt_li, htmlTxtFull_li = self.setTitleFormatter( - htmlTxt_li, htmlTxtFull_li, 'Lost', 'yellow', lost_cells - ) - htmlTxt_li, htmlTxtFull_li = self.setTitleFormatter( - htmlTxt_li, htmlTxtFull_li, 'Parent > Cell', 'green', - parent_cell_txt_raw - ) - - if not htmlTxt_li: - title = 'Looking good' - htmlTxt = f'{title}' - self.titleLabel.setText(htmlTxt) - self.titleLabel.setToolTip(htmlTxt) - return - - htmlTxt = ', '.join(htmlTxt_li) - htmlTxtFull = '
'.join(htmlTxtFull_li) - - self.titleLabel.setText(htmlTxt) - self.titleLabel.setToolTip(htmlTxtFull) - - def separateByLabelling(self, lab, rp, maxID=None): - """ - Label each single object in posData.lab and if the result is more than - one object then we insert the separated object into posData.lab - """ - setRp = False - posData = self.data[self.pos_i] - if maxID is None: - maxID = max(posData.IDs, default=1) - for obj in rp: - lab_obj = skimage.measure.label(obj.image) - rp_lab_obj = skimage.measure.regionprops(lab_obj) - if len(rp_lab_obj)<=1: - continue - lab_obj += maxID - _slice = obj.slice # self.getObjSlice(obj.slice) - _objMask = obj.image # self.getObjImage(obj.image) - lab[_slice][_objMask] = lab_obj[_objMask] - setRp = True - maxID += 1 - return setRp - - def isFirstTimeOnNextFrame(self): - posData = self.data[self.pos_i] - posData.last_tracked_i = self.navigateScrollBar.maximum()-1 - return posData.frame_i > posData.last_tracked_i - - def trackManuallyAddedObject( - self, added_IDs: List[int] | int | Set[int], isNewID: bool, - wl_update:bool=True, wl_track_og_curr:bool=False - ): - """Track object added manually on frame that was already visited. - - Parameters - ---------- - added_IDs : int | list of int | set - ID or IDs of the object added manually - isNewID : bool - If True, the added object is new - - Notes - ----- - This method tracks the new added object against the previous frame - labels. If the ID determined by tracking is different from `added_ID` - (meaning that tracking thinks the new ID should be changed to the - tracked ID) and the tracked ID is not already existing (which would - otherwise causing merging) we assign the tracked ID to the object with - `added_ID`. - - If instead the tracked ID is the same as `added_ID` we are dealing - with a truly new object. In this case we want to try tracking it against - the next frame (since the next frame was already validated). - As before, we assign the tracked ID (against the next frame) only if - not already existing in current frame (to avoid merging). - """ - if self.isSnapshot: - return - - if not isNewID: - return - - if isinstance(added_IDs, int): - added_IDs = [added_IDs] - - posData = self.data[self.pos_i] - tracked_lab = self.tracking( - enforce=True, assign_unique_new_IDs=False, return_lab=True, - IDs=added_IDs - ) - self.clearAssignedObjsSecondStep() - if tracked_lab is None: - return - - # Track only new object - prevIDs = posData.allData_li[posData.frame_i-1]['IDs'] - - # mask = np.zeros(posData.lab.shape, dtype=bool) - update_rp = False - - for added_ID in added_IDs: - # try: - # obj = posData.rp[added_ID] # ID not present - # mask[obj.slice][obj.image] = True - - # except IndexError as err: - mask = posData.lab == added_ID - try: - trackedID = tracked_lab[mask][0] - except IndexError as err: - # added_ID is not present - continue - - isTrackedIDalreadyPresentAndNotNew = ( - posData.IDs_idxs.get(trackedID) is not None - and added_ID != trackedID - ) - if isTrackedIDalreadyPresentAndNotNew: - continue - - isTrackedIDinPrevIDs = trackedID in prevIDs - if isTrackedIDinPrevIDs: - posData.lab[mask] = trackedID - else: - # New object where we can try to track against next frame - trackedID = self.trackNewIDtoNewIDsFutureFrame(added_ID, mask) - if trackedID is None: - self.clearAssignedObjsSecondStep() - continue - posData.lab[mask] = trackedID - - self.keepOnlyNewIDAssignedObjsSecondStep(trackedID) - update_rp = True - - if update_rp: - self.update_rp(wl_update=wl_update) - - def trackFrameCustomTracker( - self, prev_lab, currentLab, IDs=None, unique_ID=None - ): - if unique_ID is None: - unique_ID = self.setBrushID() - try: - tracked_result = self.realTimeTracker.track_frame( - prev_lab, currentLab, - unique_ID=unique_ID, - IDs=IDs, - **self.track_frame_params, - ) - except TypeError as err: - if str(err).find('an unexpected keyword argument \'unique_ID\'') != -1: - try: - tracked_result = self.realTimeTracker.track_frame( - prev_lab, currentLab, IDs=IDs, - **self.track_frame_params - ) - except TypeError as err: - if str(err).find('an unexpected keyword argument \'IDs\'') != -1: - tracked_result = self.realTimeTracker.track_frame( - prev_lab, currentLab, - **self.track_frame_params) - else: - raise err - elif str(err).find('an unexpected keyword argument \'IDs\'') != -1: - try: - tracked_result = self.realTimeTracker.track_frame( - prev_lab, currentLab, - unique_ID=unique_ID, - **self.track_frame_params - ) - except TypeError as err: - if str(err).find('an unexpected keyword argument \'unique_ID\'') != -1: - tracked_result = self.realTimeTracker.track_frame( - prev_lab, currentLab, - **self.track_frame_params - ) - else: - raise err - else: - raise err - return tracked_result - - def trackFrame( - self, prev_lab, prev_rp, curr_lab, curr_rp, curr_IDs, - assign_unique_new_IDs=True, IDs=None, unique_ID=None - ): - if self.trackWithAcdcAction.isChecked(): - tracked_result = CellACDC_tracker.track_frame( - prev_lab, prev_rp, curr_lab, curr_rp, - IDs_curr_untracked=curr_IDs, - setBrushID_func=self.setBrushID, - posData=self.data[self.pos_i], - assign_unique_new_IDs=assign_unique_new_IDs, - IDs=IDs, - unique_ID=unique_ID - ) - elif self.trackWithYeazAction.isChecked(): - tracked_result = self.tracking_yeaz.correspondence( - prev_lab, curr_lab, use_modified_yeaz=True, - use_scipy=True - ) - else: - tracked_result = self.trackFrameCustomTracker( - prev_lab, curr_lab, IDs=IDs, unique_ID=unique_ID - ) - - # Check if tracker also returns additional info - if isinstance(tracked_result, tuple): - tracked_lab, tracked_lost_IDs = tracked_result - self.handleAdditionalInfoRealTimeTracker(prev_rp, tracked_lost_IDs) - else: - tracked_lab = tracked_result - - return tracked_lab - - def clearAssignedObjsSecondStep(self): - posData = self.data[self.pos_i] - posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = None - - def trackSubsetIDs(self, subsetIDs: Iterable[int]): - posData = self.data[self.pos_i] - if posData.frame_i == 0: - return - - subsetLab = np.zeros_like(posData.lab) - for subsetID in subsetIDs: - subsetLab[posData.lab == subsetID] = subsetID - - prev_lab = posData.allData_li[posData.frame_i-1]['labels'] - prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - tracked_lab = self.trackFrame( - prev_lab, prev_rp, posData.lab, posData.rp, posData.IDs, - assign_unique_new_IDs=True - ) - doUpdateRp = False - for subsetID in subsetIDs: - subsetIDmask = posData.lab == subsetID - trackedID = tracked_lab[subsetIDmask][0] - if trackedID == subsetID: - continue - - is_manually_edited = False - for y, x, new_ID in posData.editID_info: - if new_ID == subsetID: - # Do not track because it was manually edited - break - - posData.lab[subsetIDmask] = tracked_lab[subsetIDmask] - doUpdateRp = True - - if not doUpdateRp: - return - - self.update_rp() - - def doSkipTracking(self, against_next: bool, enforce: bool): - if self.isSnapshot: - return True - - mode = str(self.modeComboBox.currentText()) - if mode != 'Segmentation and Tracking': - return True - - if self.UserEnforced_DisabledTracking: - return True - - if not self.realTimeTrackingToggle.isChecked(): - return True - - posData = self.data[self.pos_i] - if against_next: - reference_lab = posData.allData_li[posData.frame_i+1]['labels'] - if reference_lab is None: - # Next frame never visited --> cannot track against next - return True - - if posData.frame_i == posData.SizeT - 1: - # Last frame --> cannot track against next - return True - - else: - # check that we are not on the last frame - if posData.frame_i == 0: - return True - - if enforce or self.UserEnforced_Tracking: - # Enforce even if not last visited frame - return False - - is_first_time_on_next_frame = self.isFirstTimeOnNextFrame() - skip_tracking = not is_first_time_on_next_frame - - return skip_tracking - - - # @exec_time - @exception_handler - def tracking( - self, enforce=False, DoManualEdit=True, - storeUndo=False, prev_lab=None, prev_rp=None, - return_lab=False, assign_unique_new_IDs=True, - separateByLabel=True, wl_update=True, - IDs=None, against_next=False, - ): - posData = self.data[self.pos_i] - - if self.doSkipTracking(against_next, enforce): - self.setLostNewOldPrevIDs() - return - - """Tracking starts here""" - staturBarLabelText = self.statusBarLabel.text() - self.statusBarLabel.setText('Tracking...') - - if storeUndo: - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - - # First separate by labelling - if separateByLabel: - maxID = max(posData.IDs, default=1) - setRp = core.split_connected_components( - posData.lab, rp=posData.rp, max_ID=maxID - ) - if setRp: - self.update_rp(wl_update=wl_update, ) - - if prev_lab is None: - if not against_next: - prev_lab = posData.allData_li[posData.frame_i-1]['labels'] - else: - prev_lab = posData.allData_li[posData.frame_i+1]['labels'] - if prev_rp is None: - if not against_next: - prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - else: - prev_rp = posData.allData_li[posData.frame_i+1]['regionprops'] - - unique_ID = None - if posData.frame_i < self.get_last_tracked_i(): - unique_ID = self.setBrushID(return_val=True) - - tracked_lab = self.trackFrame( - prev_lab, prev_rp, posData.lab, posData.rp, posData.IDs, - assign_unique_new_IDs=assign_unique_new_IDs, IDs=IDs, - unique_ID=unique_ID - ) - - if DoManualEdit: - # Correct tracking with manually changed IDs - rp = skimage.measure.regionprops(tracked_lab) - IDs = [obj.label for obj in rp] - self.manuallyEditTracking(tracked_lab, IDs) - - if return_lab: - QTimer.singleShot(50, partial( - self.statusBarLabel.setText, staturBarLabelText - )) - return tracked_lab - - # Update labels, regionprops and determine new and lost IDs - posData.lab = tracked_lab - self.update_rp(wl_update=wl_update, ) - self.setAllTextAnnotations() - QTimer.singleShot(50, partial( - self.statusBarLabel.setText, staturBarLabelText - )) - - def handleAdditionalInfoRealTimeTracker(self, prev_rp, *args): - if self._rtTrackerName == 'CellACDC_normal_division': - tracked_lost_IDs = args[0] - self.setTrackedLostCentroids(prev_rp, tracked_lost_IDs) - elif self._rtTrackerName == 'CellACDC_2steps': - if args[0] is None: - return - posData = self.data[self.pos_i] - posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = args[0] - - def keepOnlyNewIDAssignedObjsSecondStep(self, trackedID): - posData = self.data[self.pos_i] - annotInfo = posData.acdcTracker2stepsAnnotInfo.get(posData.frame_i) - - if annotInfo is None: - return - - new_objs_1st_step, lost_objs_1st_step = annotInfo - correct_new_objs, correct_lost_objs = [], [] - for lostObj, newObj in zip(lost_objs_1st_step, new_objs_1st_step): - newObj_ID = posData.lab[newObj.slice][newObj.image][0] - if newObj_ID != trackedID: - continue - - correct_new_objs.append(newObj) - correct_lost_objs.append(lostObj) - - if not correct_new_objs: - posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = None - else: - posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = ( - correct_new_objs, correct_lost_objs - ) - # self.annotateAssignedObjsAcdcTrackerSecondStep() - - def updateAssignedObjsAcdcTrackerSecondStep(self, newID): - posData = self.data[self.pos_i] - annotInfo = posData.acdcTracker2stepsAnnotInfo.get(posData.frame_i) - if annotInfo is None: - return - - new_objs_1st_step, lost_objs_1st_step = annotInfo - correct_new_objs, correct_lost_objs = [], [] - for lostObj, newObj in zip(lost_objs_1st_step, new_objs_1st_step): - newObj_ID = posData.lab[newObj.slice][newObj.image][0] - if newObj_ID == newID: - # The ID of the new object tracked with 2nd step was - # manually edit --> do not annotate its linking to lost obj anymore - continue - correct_new_objs.append(newObj) - correct_lost_objs.append(lostObj) - - if not correct_new_objs: - posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = None - else: - posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = ( - correct_new_objs, correct_lost_objs - ) - self.annotateAssignedObjsAcdcTrackerSecondStep() - - - def annotateAssignedObjsAcdcTrackerSecondStep(self): - posData = self.data[self.pos_i] - annotInfo = posData.acdcTracker2stepsAnnotInfo.get(posData.frame_i) - if annotInfo is None: - return - - new_objs_1st_step, lost_objs_1st_step = annotInfo - for lostObj, newObj in zip(lost_objs_1st_step, new_objs_1st_step): - allContours = self.getObjContours(lostObj, all_external=True) - for objContours in allContours: - isObjVisible = self.isObjVisible(newObj.bbox) - if not isObjVisible: - continue - xx = objContours[:,0] + 0.5 - yy = objContours[:,1] + 0.5 - self.yellowContourScatterItem.addPoints(xx, yy) - - y1, x1 = self.getObjCentroid(lostObj.centroid) - y2, x2 = self.getObjCentroid(newObj.centroid) - xx, yy = core.get_line(y1, x1, y2, x2, dashed=False) - self.ax1_oldMothBudLinesItem.addPoints(xx, yy) - - posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = None - - def setTrackedLostCentroids(self, prev_rp, tracked_lost_IDs): - """Store centroids of those IDs the tracker decided is fine to lose - (e.g., upon standard cell division the ID of the mother is fine) - - Parameters - ---------- - prev_rp : skimage.measure.RegionProperties - List of region properties of the object in previous frame - tracked_lost_IDs : iterable - List-like container of the IDs that is fine to lose from previous - frame to current frame - - Note - ---- - This function stores the centroids because the user could change IDs - in multiple ways. Storing centroids is more robust. - """ - posData = self.data[self.pos_i] - frame_i = posData.frame_i - - for obj in prev_rp: - if obj.label not in tracked_lost_IDs: - continue - - int_centroid = tuple([int(val) for val in obj.centroid]) - try: - posData.tracked_lost_centroids[frame_i].add(int_centroid) - except KeyError: - posData.tracked_lost_centroids[frame_i] = {int_centroid} - - def getTrackedLostIDs(self, prev_lab=None, IDs_in_frames=None, frame_i=None): - trackedLostIDs = set() - posData = self.data[self.pos_i] - if self.isExportingVideo: - posData.trackedLostIDs = trackedLostIDs - return trackedLostIDs - - retrackedLostcent = set() - if frame_i is None: - frame_i = posData.frame_i - - if prev_lab is None: - prev_lab = self.get_labels( - from_store=True, - frame_i=posData.frame_i-1, - return_existing=False, - return_copy=False - ) - - if IDs_in_frames is None: - IDs_in_frames = posData.IDs - - try: - tracked_lost_centroids = posData.tracked_lost_centroids[frame_i] - except KeyError: - tracked_lost_centroids = set() - - for centroid in tracked_lost_centroids: - if len(centroid) < 3 and prev_lab.ndim == 3: - # Ignore wrongly stored centroids - continue - - ID = prev_lab[centroid] - if ID == 0: - continue - - if ID in IDs_in_frames: - retrackedLostcent.add(centroid) - continue - - trackedLostIDs.add(ID) - - posData.tracked_lost_centroids[frame_i] = ( - tracked_lost_centroids - retrackedLostcent - ) - posData.trackedLostIDs = trackedLostIDs - - return trackedLostIDs - - def manuallyEditTracking(self, tracked_lab, allIDs): - posData = self.data[self.pos_i] - infoToRemove = [] - # Correct tracking with manually changed IDs - maxID = max(allIDs, default=1) - for y, x, new_ID in posData.editID_info: - old_ID = tracked_lab[y, x] - if old_ID == 0 or old_ID == new_ID: - infoToRemove.append((y, x, new_ID)) - continue - if new_ID in allIDs: - tempID = maxID+1 - tracked_lab[tracked_lab == old_ID] = tempID - tracked_lab[tracked_lab == new_ID] = old_ID - tracked_lab[tracked_lab == tempID] = new_ID - else: - tracked_lab[tracked_lab == old_ID] = new_ID - if new_ID > maxID: - maxID = new_ID - for info in infoToRemove: - posData.editID_info.remove(info) - - def warnReinitLastSegmFrame(self): - current_frame_n = self.navigateScrollBar.value() - msg = widgets.myMessageBox() - txt = html_utils.paragraph(f""" - Are you sure you want to re-initialize the last visited and - validated frame to number {current_frame_n}?

- WARNING: If you save, all annotations after frame number - {current_frame_n} will be lost! - """) - msg.warning( - self, 'WARNING: Potential loss of data', txt, - buttonsTexts=('Cancel', 'Yes, I am sure') - ) - return msg.cancel - - def extendSegmDataIfNeeded(self, stopFrameNum): - posData = self.data[self.pos_i] - segmSizeT = len(posData.segm_data) - if stopFrameNum <= segmSizeT: - return - numFramesToAdd = stopFrameNum - segmSizeT - posData.allData_li.extend( - [myutils.get_empty_stored_data_dict() for i in range(numFramesToAdd)] - ) - lab_shape = posData.segm_data[0].shape - shapeToAdd = (numFramesToAdd, *lab_shape) - additionalSegmData = np.zeros(shapeToAdd, dtype=posData.segm_data.dtype) - extendedSegmData = np.concatenate((posData.segm_data, additionalSegmData)) - posData.segm_data = extendedSegmData - - def reInitLastSegmFrame( - self, checked=True, from_frame_i=None, updateImages=True, - force=False - ): - if not force: - cancel = self.warnReinitLastSegmFrame() - if cancel: - self.logger.info( - 'Re-initialization of last validated frame cancelled.' - ) - return - - posData = self.data[self.pos_i] - if from_frame_i is None: - from_frame_i = posData.frame_i - - self.lastFrameRanOnFirstVisitTools = posData.frame_i - - self.updateLastCheckedFrameWidgets(from_frame_i) - posData.last_tracked_i = from_frame_i - self.navigateScrollBar.setMaximum(from_frame_i+1) - self.navSpinBox.setMaximum(from_frame_i+1) - # self.navigateScrollBar.setMinimum(1) - - # posData.tracked_lost_centroids[from_frame_i-1] = set() - for i in range(from_frame_i, posData.SizeT): - if posData.allData_li[i]['labels'] is None: - break - - posData.segm_data[i] = posData.allData_li[i]['labels'] - posData.allData_li[i] = myutils.get_empty_stored_data_dict() - - posData.tracked_lost_centroids[i] = set() - posData.acdcTracker2stepsAnnotInfo.pop(i, None) - - if posData.acdc_df is not None: - frames = posData.acdc_df.index.get_level_values(0) - if from_frame_i in frames: - posData.acdc_df = posData.acdc_df.loc[:from_frame_i] - - self.removeAlldelROIsCurrentFrame() - - if not updateImages: - return - - self.updateAllImages() - - def resetAcceptedLostIDs(self, from_frame_i=None): - posData = self.data[self.pos_i] - if from_frame_i is None: - from_frame_i = posData.frame_i - - posData.tracked_lost_centroids[from_frame_i-1] = set() - for i in range(from_frame_i, posData.SizeT): - posData.tracked_lost_centroids[i] = set() - - def removeAllItems(self): - self.ax1.clear() - self.ax2.clear() - try: - self.chNamesQActionGroup.removeAction(self.userChNameAction) - except Exception as e: - pass - try: - posData = self.data[self.pos_i] - for action in self.fluoDataChNameActions: - self.chNamesQActionGroup.removeAction(action) - except Exception as e: - pass - try: - self.overlayButton.setChecked(False) - except Exception as e: - pass - - if hasattr(self, 'contoursImage'): - self.initContoursImage() - - def createUserChannelNameAction(self): - self.userChNameAction = QAction(self) - self.userChNameAction.setCheckable(True) - self.userChNameAction.setText(self.user_ch_name) - - def createChannelNamesActions(self): - # LUT histogram channel name context menu actions - self.chNamesQActionGroup = QActionGroup(self) - self.chNamesQActionGroup.addAction(self.userChNameAction) - posData = self.data[self.pos_i] - for action in self.fluoDataChNameActions: - self.chNamesQActionGroup.addAction(action) - action.setChecked(False) - - self.userChNameAction.setChecked(True) - - for action in self.overlayContextMenu.actions(): - action.setChecked(False) - - def restoreDefaultColors(self): - try: - color = self.defaultToolBarButtonColor - self.overlayButton.setStyleSheet(f'background-color: {color}') - except AttributeError: - # traceback.print_exc() - pass - - @exception_handler - def _createEmptyData(self): - self.MostRecentPath = self.getMostRecentPath() - exp_path = QFileDialog.getExistingDirectory( - self, - 'Select experiment folder where to create empty data', - self.MostRecentPath - ) - if not exp_path: - return - - pos_path = os.path.join(exp_path, 'Position_1') - images_path = os.path.join(pos_path, 'Images') - if os.path.exists(images_path): - raise FileExistsError(f'The following path already exists "{images_path}"') - - os.makedirs(images_path, exist_ok=True) - - basename = 'test_empty_' - tif_filename = f'{basename}channel_1.tif' - tif_filepath = os.path.join(images_path, tif_filename) - empty_img = np.zeros((256,256), dtype=np.uint8) - empty_img[0,0] = 255 - skimage.io.imsave(tif_filepath, empty_img) - - metadata_filename = f'{basename}metadata.csv' - metadata_filepath = os.path.join(images_path, metadata_filename) - df_metadata = pd.DataFrame({ - 'Description': ['basename'], - 'values': [basename] - }) - df_metadata.to_csv(metadata_filepath, index=False) - - self.isNewFile = True - self._openFolder(exp_path=images_path) - - - def segmNdimIndicatorClicked(self): - ndimText = self.segmNdimIndicator.text() - if ndimText == '2D': - alternativeNdimText = '3D' - toggleText = 'activate' - else: - alternativeNdimText = '2D' - toggleText = 'de-activate' - msg = widgets.myMessageBox(wrapText=False) - important_txt = (""" - The toggle to activate 3D segmentation is visible only when - the Number of z-slices is greater than 1. - """) - txt = html_utils.paragraph(f""" - This indicator shows that you are working with {ndimText} - segmentation masks.

- - If instead, you want to work with {alternativeNdimText} segmentation, - you need to initialize a new segmentation file.

- - To do so, go the menu on the top menubar File --> - New Segmentation File... and,
- at the dialog where you insert the metadata (Number of z-slices, - pixel size, etc.),
- {toggleText} the parameter called Work with 3D - segmentation masks (z-stack)
- as indicated in the screenshot below
. - {html_utils.to_admonition(important_txt, admonition_type='note')} -
- """) - msg.information( - self, 'Segmentation nmber of dimensions info', txt, - image_paths=':toggle_3D_screenshot.png' - ) - self.segmNdimIndicator.setChecked(True) - - def newFile(self): - self.newSegmEndName = '' - self.isNewFile = True - msg = widgets.myMessageBox(parent=self, showCentered=False) - msg.setWindowTitle('File or folder?') - msg.addText(html_utils.paragraph(f""" - Do you want to load an image file or Position - folder(s)? - """)) - loadPosButton = QPushButton('Load Position folder', msg) - loadPosButton.setIcon(QIcon(":folder-open.svg")) - loadFileButton = QPushButton('Load image file', msg) - loadFileButton.setIcon(QIcon(":image.svg")) - helpButton = widgets.helpPushButton('Help...') - msg.addButton(helpButton) - helpButton.disconnect() - helpButton.clicked.connect(self.helpNewFile) - msg.addCancelButton(connect=True) - msg.addButton(loadFileButton) - msg.addButton(loadPosButton) - loadPosButton.setDefault(True) - msg.exec_() - if msg.cancel: - return - - if msg.clickedButton == loadPosButton: - self._openFolder() - else: - self._openFile() - - def openNewWindow(self): - self.logger.info('Opening a new window...') - if self.launcherSlot is not None: - self.launcherSlot() - return - - winClass = self.__class__ - win = winClass( - self.app, parent=self, mainWin=self.mainWin, version=self._version - ) - win.run() - self.newWindows.append(win) - - def helpNewFile(self): - msg = widgets.myMessageBox(showCentered=False) - href = f'user manual' - txt = html_utils.paragraph(f""" - Cell-ACDC can open both a single image file or files structured - into Position folders.

- If you are just testing out you can load a single image file, but - in general we reccommend structuring your data into Position - folders.

- More info about Position folders in the {href} at the section - called "Create required data structure from microscopy file(s)". - """) - msg.information( - self, 'Help on Position folders', txt - ) - - def openFile(self, checked=False, file_path=None): - self.logger.info(f'Opening FILE "{file_path}"') - - self.isNewFile = False - self._openFile(file_path=file_path) - - def manageVersions(self): - posData = self.data[self.pos_i] - selectVersion = apps.SelectAcdcDfVersionToRestore(posData, parent=self) - selectVersion.exec_() - - if selectVersion.cancel: - return - - undoId = uuid.uuid4() - if posData.cca_df is not None: - self.storeUndoRedoCca(posData.frame_i, posData.cca_df, undoId) - - selectedTime = selectVersion.selectedTimestamp - - self.modeComboBox.setCurrentText('Viewer') - self.logger.info(f'Loading file from {selectedTime}...') - - acdc_df = load.read_acdc_df_from_archive( - selectVersion.archiveFilePath, selectVersion.selectedKey - ) - posData.acdc_df = acdc_df - frames = acdc_df.index.get_level_values(0) - last_visited_frame_i = frames.max() - current_frame_i = posData.frame_i - pbar = tqdm(total=last_visited_frame_i+1, ncols=100) - for frame_i in range(last_visited_frame_i+1): - posData.frame_i = frame_i - self.get_data() - if posData.cca_df is not None: - self.storeUndoRedoCca(posData.frame_i, posData.cca_df, undoId) - if posData.allData_li[frame_i]['labels'] is None: - pbar.update() - continue - - if frame_i not in frames: - acdc_df_i = pd.DataFrame(columns=acdc_df.columns) - acdc_df_i.drop(self.cca_df_colnames, axis=1, errors='ignore') - acdc_df_i.index.name = 'Cell_ID' - else: - acdc_df_i = acdc_df.loc[frame_i].dropna(axis=1, how='all') - - posData.allData_li[frame_i]['acdc_df'] = acdc_df_i - pbar.update() - pbar.close() - - # Back to current frame - posData.frame_i = current_frame_i - self.get_data(debug=False) - self.updateAllImages() - self.logger.info('Annotations correctly recovered.') - - def askUserChannelName(self, filename_no_ext, ext): - help_txt = html_utils.paragraph(f""" - Cell-ACDC requires that every image file has a basename and some - additional text, typically the channel name.

- The basename will be common to all created files, while the additional text is used to identify the image files. - """) - - basename = filename_no_ext - underscore_splits = filename_no_ext.split('_') - if len(underscore_splits) > 1: - channel_name = underscore_splits[-1] - basename = '_'.join(underscore_splits[:-1]) - else: - channel_name = 'channel_1' - - txt = html_utils.paragraph(f""" - Provide some text (e.g., the channel name) to append at the end of the image file. - """) - win = apps.filenameDialog( - basename=basename, - ext=ext, - hintText=txt, - defaultEntry=channel_name, - helpText=help_txt, - allowEmpty=False, - parent=self, - title='Provide channel name for image file', - ) - win.exec_() - if win.cancel: - return False, '' - - return True, win.entryText - - def warnUserCreationImagesFolder(self, images_path, ext): - msg = widgets.myMessageBox(wrapText=False) - txt = (f""" - Cell-ACDC requires a specific folder structure to load the data.

- Specifically, it requires the image(s) to be located in a - folder called Images.

- The file format of the images must be TIFF or NPZ - (.tif or .npz extension).

- You can choose to let Cell-ACDC create the required data structure - from your file,
- or you can stop the - process and manually place the image(s) into a folder called - Images.

- If you choose to proceed, Cell-ACDC will create the following - folder: - {images_path} -
- """) - - if ext == '.tif' or ext == '.npz': - txt = f'{txt}How do you want to proceed?' - else: - txt = f'{txt}Do you want to proceed?' - txt = html_utils.paragraph(txt) - - if ext == '.tif' or ext == '.npz': - copyButton = widgets.copyPushButton( - 'Copy the image into the new folder' - ) - moveButton = widgets.movePushButton( - 'Move the image into the new folder' - ) - _, copyButton, moveButton = msg.information( - self, 'Creating Images folder', txt, - buttonsTexts=('Cancel', copyButton, moveButton) - ) - if msg.cancel: - return False, None - - if msg.clickedButton == copyButton: - return True, True - elif msg.clickedButton == moveButton: - return True, False - - else: - msg.information( - self, 'Creating Images folder', txt, - buttonsTexts=('Cancel', 'Yes, proceed') - ) - if msg.cancel: - return False, None - - return True, True - - @exception_handler - def _openFile(self, file_path=None): - """ - Function used for loading an image file directly. - """ - if file_path is None: - self.MostRecentPath = self.getMostRecentPath() - file_path = QFileDialog.getOpenFileName( - self, 'Select image file', self.MostRecentPath, - "Image/Video Files (*.png *.tif *.tiff *.jpg *.jpeg *.mov *.avi *.mp4)" - ";;All Files (*)")[0] - if not file_path: - return - - filename, ext = os.path.splitext(os.path.basename(file_path)) - ext = ext.lower() - dirpath = os.path.dirname(file_path) - dirname = os.path.basename(dirpath) - filename = filename.rstrip('_') - channel_name = None - do_copy = True - if dirname != 'Images': - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - acdc_folder = f'{timestamp}_acdc' - exp_path = os.path.join(dirpath, acdc_folder, 'Images') - proceed, do_copy = self.warnUserCreationImagesFolder(exp_path, ext) - if not proceed: - self.logger.info('Loading image file cancelled.') - return - - proceed, channel_name = self.askUserChannelName( - filename, '.tif' - ) - if not proceed: - self.logger.info('Loading image file cancelled.') - return - - os.makedirs(exp_path, exist_ok=True) - else: - exp_path = dirpath - - if channel_name is not None: - # Check if user wants to use the existing channel name - underscore_splits = filename.split('_') - if len(underscore_splits) > 1: - default_ch_name = underscore_splits[-1] - if channel_name == default_ch_name: - filename = '_'.join(underscore_splits[:-1]) - - basename = f'{filename}_' - new_filename = f'{filename}_{channel_name}{ext}' - df_metadata = pd.DataFrame({ - 'Description': ['basename'], - 'values': [basename] - }) - metadata_csv_filename = f'{basename}metadata.csv' - metadata_csv_filepath = os.path.join( - exp_path, metadata_csv_filename - ) - df_metadata.to_csv(metadata_csv_filepath, index=False) - else: - new_filename = f'{filename}{ext}' - - if do_copy: - action_text = 'Copying' - else: - action_text = 'Moving' - - if ext == '.tif' or ext == '.npz': - new_filepath = os.path.join(exp_path, new_filename) - if not os.path.exists(new_filepath): - self.logger.info(f'{action_text} file to Images folder...') - if do_copy: - shutil.copy2(file_path, new_filepath) - else: - shutil.move(file_path, new_filepath) - self._openFolder(exp_path=exp_path, imageFilePath=new_filepath) - else: - self.logger.info(f'{action_text} file to .tif format...') - data = load.loadData(file_path, '', log_func=self.logger.info) - data.loadImgData() - img = data.img_data - if img.ndim == 3 and (img.shape[-1] == 3 or img.shape[-1] == 4): - self.logger.info('Converting RGB image to grayscale...') - if img.shape[-1] == 3: - data.img_data = skimage.color.rgb2gray(data.img_data) - else: - data.img_data = cv2.cvtColor( - data.img_data, cv2.COLOR_RGBA2GRAY - ) - data.img_data = skimage.img_as_ubyte(data.img_data) - new_filename_no_ext, ext = os.path.splitext(new_filename) - tif_filename = f'{new_filename_no_ext}.tif' - tif_path = os.path.join(exp_path, tif_filename) - if data.img_data.ndim == 3: - SizeT = data.img_data.shape[0] - SizeZ = 1 - elif data.img_data.ndim == 4: - SizeT = data.img_data.shape[0] - SizeZ = data.img_data.shape[1] - else: - SizeT = 1 - SizeZ = 1 - is_imageJ_dtype = ( - data.img_data.dtype == np.uint8 - or data.img_data.dtype == np.uint32 - or data.img_data.dtype == np.uint32 - or data.img_data.dtype == np.float32 - ) - if not is_imageJ_dtype: - data.img_data = skimage.img_as_ubyte(data.img_data) - - myutils.to_tiff(tif_path, data.img_data) - self._openFolder(exp_path=exp_path, imageFilePath=tif_path) - - def criticalNoTifFound(self, images_path): - err_title = 'No .tif files found in folder.' - err_msg = html_utils.paragraph( - 'The following folder

' - f'{images_path}

' - 'does not contain .tif or .h5 files.

' - 'Only .tif or .h5 files can be loaded with "Open Folder" button.

' - 'Try with File --> Open image/video file... ' - 'and directly select the file you want to load.' - ) - msg = widgets.myMessageBox() - msg.addShowInFileManagerButton(images_path) - msg.critical(self, err_title, err_msg) - - def reinitStoredSegmModels(self): - self.models = [None]*len(self.models) - - def checkAskSavePointsLayers(self): - for toolbar in self.pointsLayersToolbars: - for action in toolbar.actions()[1:]: - if not hasattr(action, 'layerTypeIdx'): - continue - if action.layerTypeIdx != 4: - continue - - scatterItem = action.scatterItem - xx, yy = scatterItem.getData() - - if xx is None or len(xx) == 0: - toolButton = action.button - tableEndName = toolButton.clickEntryTableEndName - # Check in other loaded pos - are_there_points_to_save = False - for pos_i, _posData in enumerate(self.data): - if pos_i == self.pos_i: - continue - - df = _posData.clickEntryPointsDfs.get(tableEndName) - if df is None: - continue - - are_there_points_to_save = True - break - - if not are_there_points_to_save: - continue - - cancel = self.askSavePointsLayer(action) - if cancel: - return cancel - - return False - - def askSavePointsLayer(self, action): - toolButton = action.button - tableEndName = toolButton.clickEntryTableEndName - saveAction = toolButton.saveAction - - txt = html_utils.paragraph(f""" - Do you want to save the points you added - (table called {tableEndName}.csv)? - """ - ) - msg = widgets.myMessageBox(wrapText=False) - _, _, saveButton = msg.question( - self, 'Save points layer?', txt, - buttonsTexts=('Cancel', 'No, do not save', 'Yes, save points') - ) - if msg.clickedButton == saveButton: - self.savePointsAddedByClicking(saveAction.saveToolbutton, None) - - return msg.cancel - - def removeOverlayItems(self): - self.lutItemsLayout.clear() - - try: - for toolbutton in self.allOverlayToolbuttonsByIdx.values(): - self.overlayToolbar.removeAction(toolbutton.action) - - self.overlayToolbuttonsSep.removeFromToolbar() - except Exception as err: - pass - - def clearOverlayImageItems(self): - for items in self.overlayLayersItems.values(): - imageItem = items[0] - imageItem.clear() - - self.rgbaImg1.clear() - - def reInitGui(self): - cancel = self.checkAskSavePointsLayers() - if cancel: - return False - - if self.overlayToolbar.isTransparent(): - self.overlayToolbar.setTransparent(False) - - self.secondLevelToolbar.setVisible(False) - - self.gui_createLazyLoader() - - try: - self.navSpinBox.valueChanged.disconnect() - except Exception as e: - pass - - try: - self.scaleBar.removeFromAxis(self.ax1) - except Exception as e: - pass - - self.lineage_tree = None - self.getDistanceListMissingIDsCachedFrame = None - self.isZmodifier = False - self.zKeptDown = False - self.askRepeatSegment3D = True - self.askZrangeSegm3D = True - self.isDataLoaded = False - self.retainSizeLutItems = False - self.setMeasWinState = None - self.addPointsWin = None - self.delRoiLab = None - self.showPropsDockButton.setDisabled(True) - self.removeOverlayItems() - self.lutItemsLayout.addItem(self.imgGrad, row=0, col=0) - - self.reinitWidgetsPos() - self.removeAllItems() - self.reinitCustomAnnot() - self.reinitPointsLayers() - self.gui_createPlotItems() - self.setUncheckedAllButtons() - self.setUncheckedPointsLayers() - self.restoreDefaultColors() - self.reinitStoredSegmModels() - self.removeAxLimits() - self.curvToolButton.setChecked(False) - - self.wandControlsToolbar.setVisible(False) - self.wandToolButton.setChecked(False) - self.segmNdimIndicatorAction.setVisible(False) - - self.navigateToolBar.hide() - self.ccaToolBar.hide() - self.editToolBar.hide() - self.brushEraserToolBar.hide() - self.modeToolBar.hide() - - self.modeComboBox.setCurrentText('Viewer') - - alpha = self.imgGrad.labelsAlphaSlider.value() - self.labelsLayerImg1.setOpacity(alpha) - self.labelsLayerRightImg.setOpacity(alpha) - self.lastTrackedFrameLabel.setText('') - - self.promptSegmentPointsLayerToolbar.isPointsLayerInit = False - - for action in self.askHowFutureFramesActions.values(): - action.setChecked(True) - action.setDisabled(True) - - return True - - def reinitPointsLayers(self): - for toolbar in self.pointsLayersToolbars: - for action in toolbar.actions()[1:]: - toolbar.removeAction(action) - toolbar.setVisible(False) - self.autoPilotZoomToObjToolbar.setVisible(False) - - def reinitWidgetsPos(self): - pass - # try: - # # self.highlightZneighObjCheckbox will be connected in - # # self.showHighlightZneighCheckbox() - # self.highlightZneighObjCheckbox.toggled.disconnect() - # except Exception as e: - # pass - # layout = self.bottomLeftLayout - # self.highlightZneighObjCheckbox.hide() - # try: - # layout.removeWidget(self.highlightZneighObjCheckbox) - # except Exception as e: - # pass - # self.highlightZneighObjCheckbox.hide() - # # layout.addWidget( - # # self.drawIDsContComboBox, 0, 1, 1, 2, - # # alignment=Qt.AlignCenter - # # ) - - def reinitCustomAnnot(self): - buttons = list(self.customAnnotDict.keys()) - for button in buttons: - self.clearScatterPlotCustomAnnotButton(button) - action = self.customAnnotDict[button]['action'] - self.annotateToolbar.removeAction(action) - self.checkableQButtonsGroup.removeButton(button) - self.customAnnotDict.pop(button) - # self.savedCustomAnnot.pop(name) - - self.saveCustomAnnot(only_temp=True) - - def loadingDataAborted(self): - self.openFolderAction.setEnabled(True) - self.titleLabel.setText('Loading data aborted.') - - def cleanUpOnError(self): - self.onEscape() - caller = 'Cell-ACDC' - if self.module.startswith('spotmax'): - caller = 'spotMAX' - txt = f'WARNING: {caller} is in error state. Please, restart.' - _hl = '*'*100 - self.titleLabel.setText(txt, color='r') - self.logger.info(f'{_hl}\n{txt}\n{_hl}') - - def openFolder( - self, checked=False, exp_path=None, imageFilePath='' - ): - if exp_path is None: - self.logger.info('Asking to select a folder path...') - else: - self.logger.info(f'Opening FOLDER "{exp_path}"...') - - self.isNewFile = False - if hasattr(self, 'data') and self.titleLabel.text != 'Saved!': - msg = widgets.myMessageBox() - txt = html_utils.paragraph( - 'Do you want to save before loading another dataset?' - ) - _, no, yes = msg.question( - self, 'Save?', txt, - buttonsTexts=('Cancel', 'No', 'Yes') - ) - if msg.clickedButton == yes: - func = partial(self._openFolder, exp_path, imageFilePath) - cancel = self.saveData(finishedCallback=func) - return - elif msg.cancel: - self.store_data() - return - else: - self.store_data(autosave=False) - - self._openFolder( - exp_path=exp_path, imageFilePath=imageFilePath - ) - - def addToRecentPaths(self, path, logger=None): - myutils.addToRecentPaths(path, logger=self.logger) - - def getMostRecentPath(self): - return myutils.getMostRecentPath() - - @exception_handler - def _openFolder( - self, checked=False, exp_path=None, imageFilePath='' - ): - """Main function to load data. - - Parameters - ---------- - checked : bool - kwarg needed because openFolder can be called by openFolderAction. - exp_path : string or None - Path selected by the user either directly, through openFile, - or drag and drop image file. - imageFilePath : string - Path of the image file that was either drag and dropped or opened - from File --> Open image/video file (openFileAction). - - Returns - ------- - None - """ - - if exp_path is None: - self.MostRecentPath = self.getMostRecentPath() - exp_path = QFileDialog.getExistingDirectory( - self, - 'Select experiment folder containing Position_n folders ' - 'or specific Position_n folder', - self.MostRecentPath - ) - - if not exp_path: - self.openFolderAction.setEnabled(True) - return - - proceed = self.reInitGui() - if not proceed: - self.openFolderAction.setEnabled(True) - return - - self.openFolderAction.setEnabled(False) - - if self.slideshowWin is not None: - self.slideshowWin.close() - - if self.ccaTableWin is not None: - self.ccaTableWin.close() - - self.exp_path = exp_path - self.logger.info(f'Loading from {self.exp_path}') - self.addToRecentPaths(exp_path, logger=self.logger) - self.addPathToOpenRecentMenu(exp_path) - - folder_type = myutils.determine_folder_type(exp_path) - is_pos_folder, is_images_folder, exp_path = folder_type - - self.titleLabel.setText('Loading data...', color=self.titleColor) - - skip_channels = [] - ch_name_selector = prompts.select_channel_name( - which_channel='segm', allow_abort=False - ) - user_ch_name = None - if not is_pos_folder and not is_images_folder and not imageFilePath: - images_paths = self._loadFromExperimentFolder(exp_path) - if not images_paths: - self.loadingDataAborted() - return - - elif is_pos_folder and not imageFilePath: - pos_foldername = os.path.basename(exp_path) - exp_path = os.path.dirname(exp_path) - images_paths = [os.path.join(exp_path, pos_foldername, 'Images')] - - elif is_images_folder and not imageFilePath: - images_paths = [exp_path] - pos_path = os.path.dirname(exp_path) - exp_path = os.path.dirname(pos_path) - - elif imageFilePath: - # images_path = exp_path because called by openFile func - filenames = myutils.listdir(exp_path) - ch_names, basenameNotFound = ( - ch_name_selector.get_available_channels(filenames, exp_path) - ) - filename = os.path.basename(imageFilePath) - self.ch_names = ch_names - user_ch_name = [ - chName for chName in ch_names if filename.find(chName)!=-1 - ][0] - images_paths = [exp_path] - pos_path = os.path.dirname(exp_path) - exp_path = os.path.dirname(pos_path) - - self.images_paths = images_paths - - # Get info from first position selected - images_path = self.images_paths[0] - filenames = myutils.listdir(images_path) - if ch_name_selector.is_first_call and user_ch_name is None: - ch_names, _ = ch_name_selector.get_available_channels( - filenames, images_path - ) - self.ch_names = ch_names - if not ch_names: - self.openFolderAction.setEnabled(True) - self.criticalNoTifFound(images_path) - return - if len(ch_names) > 1: - CbLabel='Select channel name to load: ' - ch_name_selector.QtPrompt( - self, ch_names, CbLabel=CbLabel - ) - if ch_name_selector.was_aborted: - self.openFolderAction.setEnabled(True) - return - skip_channels.extend([ - ch for ch in ch_names if ch!=ch_name_selector.channel_name - ]) - else: - ch_name_selector.channel_name = ch_names[0] - ch_name_selector.setUserChannelName() - user_ch_name = ch_name_selector.user_ch_name - else: - # File opened directly with self.openFile - ch_name_selector.channel_name = user_ch_name - - user_ch_file_paths = [] - not_allowed_ends = ['btrack_tracks.h5'] - for images_path in self.images_paths: - channel_file_path = load.get_filename_from_channel( - images_path, user_ch_name, skip_channels=skip_channels, - not_allowed_ends=not_allowed_ends, logger=self.logger.info - ) - if not channel_file_path: - self.criticalImgPathNotFound(images_path) - return - user_ch_file_paths.append(channel_file_path) - - ch_name_selector.setUserChannelName() - self.user_ch_name = user_ch_name - self.img1.channelName = user_ch_name - - self.AutoPilotProfile.storeSelectedChannel(self.user_ch_name) - - self.initGlobalAttr() - self.createOverlayContextMenu() - self.createUserChannelNameAction() - self.gui_createOverlayColors() - self.gui_createOverlayItems() - lastRow = self.bottomLeftLayout.rowCount() - self.bottomLeftLayout.setRowStretch(lastRow+1, 1) - - self.num_pos = len(user_ch_file_paths) - proceed = self.loadSelectedData(user_ch_file_paths, user_ch_name) - if not proceed: - self.openFolderAction.setEnabled(True) - return - - def _loadFromExperimentFolder(self, exp_path): - select_folder = load.select_exp_folder() - values = select_folder.get_values_segmGUI(exp_path) - if not values: - self.criticalInvalidPosFolder(exp_path) - self.openFolderAction.setEnabled(True) - return [] - - if len(values) > 1: - select_folder.QtPrompt(self, values, allow_cancel=False) - if select_folder.cancel: - return [] - else: - select_folder.cancel = False - select_folder.selected_pos = select_folder.pos_foldernames - - images_paths = [] - for pos in select_folder.selected_pos: - images_paths.append(os.path.join(exp_path, pos, 'Images')) - return images_paths - - def criticalInvalidPosFolder(self, exp_path): - href = html_utils.href_tag('here', data_structure_docs_url) - txt = html_utils.paragraph(f""" - The selected folder:

- - {exp_path}

- - is not a valid folder.

- - Select a folder that contains the Position_n folders, - or a specific Position.

- - If you are trying to load a single image file go to - File --> Open image/video file....

- - To load a folder containing multiple .tif files the folder must - be called either Position_n
- (with n being an integer) or Images.

- - For more information about the correct folder structure see {href}. - """) - msg = widgets.myMessageBox(wrapText=False) - helpButton = widgets.helpPushButton('Help...') - msg.addButton(helpButton) - helpButton.clicked.disconnect() - helpButton.clicked.connect( - partial(myutils.browse_url, data_structure_docs_url) - ) - msg.addShowInFileManagerButton(exp_path) - msg.critical( - self, 'Incompatible folder', txt - ) - - def createOverlayContextMenu(self): - ch_names = [ch for ch in self.ch_names if ch != self.user_ch_name] - self.overlayContextMenu = QMenu() - self.overlayContextMenu.addSeparator() - self.checkedOverlayChannels = set() - for chName in ch_names: - action = QAction(chName, self.overlayContextMenu) - action.setCheckable(True) - action.toggled.connect(self.overlayChannelToggled) - self.overlayContextMenu.addAction(action) - - def createOverlayLabelsContextMenu(self, segmEndnames): - self.overlayLabelsContextMenu = QMenu() - self.overlayLabelsContextMenu.addSeparator() - self.drawModeOverlayLabelsChannels = {} - segmEndnames_extended = list(segmEndnames.copy()) - segmEndnames_extended = ['combined segm.'] + segmEndnames_extended - for segmEndname in segmEndnames_extended: - action = QAction(segmEndname, self.overlayLabelsContextMenu) - if segmEndname == 'combined segm.': - action.setCheckable(False) - self.combineSegmViewToggle = action - else: - action.setCheckable(True) - action.toggled.connect(self.addOverlayLabelsToggled) - self.overlayLabelsContextMenu.addAction(action) - - self.overlayLabelsContextMenu.addSeparator() - action = QAction('Edit appearance...', self.overlayLabelsContextMenu) - action.triggered.connect(self.editOverlayLabelsAppearance) - self.overlayLabelsContextMenu.addAction(action) - - def editOverlayLabelsAppearance(self, *args): - segmEndname = list(self.overlayLabelsItems.keys())[0] - contoursItem = self.overlayLabelsItems[segmEndname][1] - win = apps.OverlayLabelsAppearanceDialog( - scatterPlotItem=contoursItem, parent=self - ) - win.exec_() - if win.cancel: - return - - brush = win.properties['brush'] - pen = win.properties['pen'] - for items in self.overlayLabelsItems.values(): - imageItem, contoursItem, gradItem = items - contoursItem.setBrush(brush, update=False) - contoursItem.setPen(pen) - - def createOverlayLabelsItems(self, segmEndnames): - selectActionGroup = QActionGroup(self) - segmEndnames_extended = list(segmEndnames.copy()) - segmEndnames_extended = ['combined segm.'] + segmEndnames_extended - for segmEndname in segmEndnames_extended: - action = QAction(segmEndname) - if segmEndname == 'combined segm.': - action.setCheckable(False) - else: - action.setCheckable(True) - action.toggled.connect(self.setOverlayLabelsItemsVisible) - selectActionGroup.addAction(action) - self.selectOverlayLabelsActionGroup = selectActionGroup - - self.overlayLabelsItems = {} - for segmEndname in segmEndnames_extended: - imageItem = pg.ImageItem() - - gradItem = widgets.overlayLabelsGradientWidget( - imageItem, selectActionGroup, segmEndname - ) - gradItem.hide() - gradItem.drawModeActionGroup.triggered.connect( - self.overlayLabelsDrawModeToggled - ) - self.mainLayout.addWidget(gradItem, 0, 0) - - contoursItem = pg.ScatterPlotItem() - color = colors.get_complementary_color(self.contLineColor) - r, g, b, a = colors.rgba_str_to_values(color) - qcolor = QColor(r, g, b, a) - contoursItem.setData( - [], [], symbol='s', pxMode=False, size=self.contLineWeight*2, - brush=pg.mkBrush(color=qcolor), - pen=pg.mkPen(width=3, color=qcolor), tip=None - ) - - items = (imageItem, contoursItem, gradItem) - self.overlayLabelsItems[segmEndname] = items - - def addOverlayLabelsToggled(self, checked, name=None): - if name is None: - name = self.sender().text() - if checked: - gradItem = self.overlayLabelsItems[name][-1] - drawMode = gradItem.drawModeActionGroup.checkedAction().text() - self.drawModeOverlayLabelsChannels[name] = drawMode - else: - self.drawModeOverlayLabelsChannels.pop(name) - self.hideOverlayLabelsItems(specific=[name]) - self.setOverlayLabelsItems() - - def overlayLabelsDrawModeToggled(self, action): - segmEndname = action.segmEndname - drawMode = action.text() - if segmEndname in self.drawModeOverlayLabelsChannels: - self.drawModeOverlayLabelsChannels[segmEndname] = drawMode - self.setOverlayLabelsItems() - - def overlayChannelToggled(self, checked): - # Action toggled from overlayButton context menu - channelName = self.sender().text() - posData = self.data[self.pos_i] - if checked: - if channelName not in posData.loadedFluoChannels: - self.loadOverlayData([channelName], addToExisting=True) - else: - _, filename = self.getPathFromChName(channelName, posData) - posData.ol_data[filename] = ( - posData.ol_data_dict[filename].copy() - ) - - self.checkedOverlayChannels.add(channelName) - else: - self.checkedOverlayChannels.remove(channelName) - imageItem = self.overlayLayersItems[channelName][0] - imageItem.clear() - - self.setOverlayChannelsToolbuttonsChecked() - self.setOverlayItemsVisible() - self.setRetainSizePolicyLutItems() - self.updateAllImages() - - @exception_handler - def loadDataWorkerDataIntegrityWarning(self, pos_foldername): - err_msg = ( - 'WARNING: Segmentation mask file ("..._segm.npz") not found. ' - 'You could run segmentation module first.' - ) - self.workerProgress(err_msg, 'INFO') - self.titleLabel.setText(err_msg, color='r') - abort = False - msg = widgets.myMessageBox(parent=self) - warn_msg = html_utils.paragraph(f""" - The folder {pos_foldername} does not contain a - pre-computed segmentation mask.

- You can continue with a blank mask or cancel and - pre-compute the mask with the segmentation module.

- Do you want to continue? - """) - msg.setIcon(iconName='SP_MessageBoxWarning') - msg.setWindowTitle('Segmentation file not found') - msg.addText(warn_msg) - msg.addButton('Ok') - continueWithBlankSegm = msg.addButton(' Cancel ') - msg.show(block=True) - if continueWithBlankSegm == msg.clickedButton: - abort = True - self.loadDataWorker.abort = abort - self.loadDataWaitCond.wakeAll() - - def warnMemoryNotSufficient(self, total_ram, available_ram, required_ram): - total_ram = myutils._bytes_to_GB(total_ram) - available_ram = myutils._bytes_to_GB(available_ram) - required_ram = myutils._bytes_to_GB(required_ram) - required_perc = round(100*required_ram/available_ram) - msg = widgets.myMessageBox() - txt = html_utils.paragraph(f""" - The total amount of data that you requested to load is about - {required_ram:.2f} GB ({required_perc}% of the available memory) - but there are only {available_ram:.2f} GB available.

- For optimal operation, we recommend loading maximum 30% - of the available memory. To do so, try to close open apps to - free up some memory. Another option is to crop the images - using the data prep module.

- If you choose to continue, the system might freeze - or your OS could simply kill the process.

- What do you want to do? - """) - cancelButton, continueButton = msg.warning( - self, 'Memory not sufficient', txt, - buttonsTexts=('Cancel', 'Continue anyway') - ) - if msg.clickedButton == continueButton: - # Disable autosaving since it would keep a copy of the data and - # we cannot afford it with low memory - self.autoSaveToggle.setChecked(False) - return True - else: - return False - - def checkMemoryRequirements(self, required_ram): - memory = psutil.virtual_memory() - total_ram = memory.total - available_ram = memory.available - if required_ram/available_ram > 0.3: - proceed = self.warnMemoryNotSufficient( - total_ram, available_ram, required_ram - ) - return proceed - else: - return True - - def criticalImgPathNotFound(self, images_path): - self.logger.info( - 'The following folder does not contain valid image files: ' - f'"{images_path}"\n\n' - 'Check that all the positions loaded contain the same channel name. ' - 'Make sure to double check for spelling mistakes or types in the ' - 'channel names.' - ) - msg = widgets.myMessageBox() - msg.addShowInFileManagerButton(images_path) - err_msg = html_utils.paragraph(f""" - The folder

- {images_path}

- does not contain any valid image file!

- Valid file formats are .h5, .tif, _aligned.h5, _aligned.npz. - """) - okButton = msg.critical( - self, 'No valid files found!', err_msg, buttonsTexts=('Ok',) - ) - - def initRealTimeTracker(self, force=False): - for rtTrackerAction in self.trackingAlgosGroup.actions(): - if rtTrackerAction.isChecked(): - break - - aliases = myutils.aliases_real_time_trackers(reverse=True) - - rtTracker = rtTrackerAction.text() - rtTracker_txt = rtTracker - - if rtTracker in aliases: - rtTracker = aliases[rtTracker] - - if rtTracker == 'Cell-ACDC': - return - if rtTracker == 'YeaZ': - return - - if self.isRealTimeTrackerInitialized and not force: - return - - self.logger.info(f'Initializing {rtTracker_txt} tracker...') - self._rtTrackerName = rtTracker - posData = self.data[self.pos_i] - realTimeTracker, track_frame_params = myutils.init_tracker( - posData, rtTracker, qparent=self, realTime=True - ) - if realTimeTracker is None: - self.logger.info(f'{rtTracker} tracker initialization cancelled.') - return - - self.realTimeTracker = realTimeTracker - self.track_frame_params = track_frame_params - self.logger.info(f'{rtTracker} tracker successfully initialized.') - if 'image_channel_name' in self.track_frame_params: - # Remove the channel name since it was already loaded in init_tracker - del self.track_frame_params['image_channel_name'] - - def initFluoData(self): - if len(self.ch_names) <= 1: - return - - if 'ask_load_fluo_at_init' in self.df_settings.index: - if self.df_settings.at['ask_load_fluo_at_init', 'value'] == 'No': - return - msg = widgets.myMessageBox(allowClose=False) - txt = ( - 'Do you also want to load fluorescence images?
' - 'You can load as many channels as you want.

' - 'If you load fluorescence images then the software will ' - 'calculate metrics for each loaded fluorescence channel ' - 'such as min, max, mean, quantiles, etc. ' - 'of each segmented object.

' - 'NOTE: You can always load them later from the menu ' - 'File --> Load fluorescence images... or when you set ' - 'measurements from the menu ' - 'Measurements --> Set measurements...' - ) - msg.addDoNotShowAgainCheckbox(text="Don't ask again") - no, yes = msg.question( - self, 'Load fluorescence images?', html_utils.paragraph(txt), - buttonsTexts=('No', 'Yes') - ) - if msg.doNotShowAgainCheckbox.isChecked(): - self.df_settings.at['ask_load_fluo_at_init', 'value'] = 'No' - self.df_settings.to_csv(self.settings_csv_path) - if msg.clickedButton == yes: - self.loadFluo_cb(None) - self.AutoPilotProfile.storeClickMessageBox( - 'Load fluorescence images?', msg.clickedButton.text() - ) - - def getPathFromChName(self, chName, posData): - ls = myutils.listdir(posData.images_path) - endnames = {f[len(posData.basename):]:f for f in ls} - validEnds = ['_aligned.npz', '_aligned.h5', '.h5', '.tif', '.npz'] - for end in validEnds: - files = [ - filename for endname, filename in endnames.items() - if endname == f'{chName}{end}' - ] - if files: - filename = files[0] - break - else: - self.criticalFluoChannelNotFound(chName, posData) - self.app.restoreOverrideCursor() - return None, None - - fluo_path = os.path.join(posData.images_path, filename) - filename, _ = os.path.splitext(filename) - return fluo_path, filename - - def loadPosTriggered(self): - if not self.isDataLoaded: - return - - self.startAutomaticLoadingPos() - - def startAutomaticLoadingPos(self): - self.AutoPilot = autopilot.AutoPilot(self) - self.AutoPilot.execLoadPos() - - def stopAutomaticLoadingPos(self): - if self.AutoPilot is None: - return - - if self.AutoPilot.timer.isActive(): - self.AutoPilot.timer.stop() - self.AutoPilot = None - - def startCcaIntegrityCheckerWorker(self): - if not hasattr(self, 'data'): - return - - if not self.isDataLoaded: - return - - if not self.ccaIntegrCheckerToggle.isChecked(): - return - - ccaCheckerThread = QThread() - self.ccaCheckerMutex = QMutex() - self.ccaCheckerWaitCond = QWaitCondition() - - worker = workers.CcaIntegrityCheckerWorker( - self.ccaCheckerMutex, self.ccaCheckerWaitCond - ) - self.ccaIntegrityCheckerWorker = worker - self.ccaCheckerThread = ccaCheckerThread - - worker.moveToThread(ccaCheckerThread) - worker.finished.connect(ccaCheckerThread.quit) - worker.finished.connect(worker.deleteLater) - ccaCheckerThread.finished.connect(ccaCheckerThread.deleteLater) - - worker.sigDone.connect(self.ccaCheckerWorkerDone) - worker.progress.connect(self.workerProgress) - worker.critical.connect(self.ccaIntegrityWorkerCritical) - worker.finished.connect(self.ccaCheckerWorkerClosed) - worker.sigWarning.connect(self.warnCcaIntegrity) - worker.sigFixWillDivide.connect(self.fixWillDivide) - - ccaCheckerThread.started.connect(worker.run) - ccaCheckerThread.start() - - self.ccaCheckerRunning = True - - self.initCcaIntegrityChecker() - - self.logger.info('Cell cycle annotations integrity checker started.') - - def initCcaIntegrityChecker(self): - posData = self.data[self.pos_i] - for frame_i, data_frame_i in enumerate(posData.allData_li): - lab = data_frame_i['labels'] - if lab is None: - break - - cca_df = self.get_cca_df(frame_i, return_df=True) - self.store_cca_df_checker(posData, frame_i, cca_df) - - self.enqCcaIntegrityChecker() - - def initCcaIntegrityChecker(self): - posData = self.data[self.pos_i] - for frame_i, data_frame_i in enumerate(posData.allData_li): - lab = data_frame_i['labels'] - if lab is None: - break - - cca_df = self.get_cca_df(frame_i, return_df=True) - self.store_cca_df_checker(posData, frame_i, cca_df) - - self.enqCcaIntegrityChecker() - - def disableCcaIntegrityChecker(self): - self.stopCcaIntegrityCheckerWorker() - - def stopCcaIntegrityCheckerWorker(self): - try: - self.ccaIntegrityCheckerWorker._stop() - except Exception as err: - pass - - def loadFluo_cb(self, checked=True, fluo_channels=None): - if fluo_channels is None: - posData = self.data[self.pos_i] - ch_names = [ - ch for ch in self.ch_names if ch != self.user_ch_name - and ch not in posData.loadedFluoChannels - ] - if not ch_names: - msg = widgets.myMessageBox() - txt = html_utils.paragraph( - 'You already loaded ALL channels.

' - 'To change the overlaid channel ' - 'right-click on the overlay button.' - ) - msg.information(self, 'All channels are loaded', txt) - return False - selectFluo = widgets.QDialogListbox( - 'Select channel to load', - 'Select channel names to load:\n', - ch_names, multiSelection=True, parent=self - ) - selectFluo.exec_() - - if selectFluo.cancel: - return False - - fluo_channels = selectFluo.selectedItemsText - self.AutoPilotProfile.storeLoadedFluoChannels(fluo_channels) - - for p, posData in enumerate(self.data): - # posData.ol_data = None - for fluo_ch in fluo_channels: - fluo_path, filename = self.getPathFromChName(fluo_ch, posData) - if fluo_path is None: - self.criticalFluoChannelNotFound(fluo_ch, posData) - return False - fluo_data, bkgrData = self.load_fluo_data(fluo_path) - if fluo_data is None: - return False - posData.loadedFluoChannels.add(fluo_ch) - - if posData.SizeT == 1: - fluo_data = fluo_data[np.newaxis] - - posData.fluo_data_dict[filename] = fluo_data - posData.fluo_bkgrData_dict[filename] = bkgrData - posData.ol_data_dict[filename] = fluo_data.copy() - - self.overlayButton.setStyleSheet(f'background-color: {GREEN_HEX}') - self.guiTabControl.addChannels([ - posData.user_ch_name, *posData.loadedFluoChannels - ]) - return True - - def labelRoiCancelled(self): - self.labelRoiRunning = False - self.app.restoreOverrideCursor() - self.labelRoiItem.setPos((0,0)) - self.labelRoiItem.setSize((0,0)) - self.freeRoiItem.clear() - self.logger.info('Magic labeller process cancelled.') - - def labelRoiCheckStartStopFrame(self): - if not self.labelRoiTrangeCheckbox.isChecked(): - return True - - start_n = self.labelRoiStartFrameNoSpinbox.value() - stop_n = self.labelRoiStopFrameNoSpinbox.value() - if start_n <= stop_n: - return True - - self.blinker = qutils.QControlBlink( - self.labelRoiStopFrameNoSpinbox, - qparent=self - ) - self.blinker.start() - msg = widgets.myMessageBox() - txt = html_utils.paragraph(""" - Stop frame number is less than start frame number!

- What do you want to do? - """) - msg.warning( - self, 'Stop frame number lower than start', txt, - buttonsTexts=('Cancel', 'Segment only current frame') - ) - if msg.cancel: - return False - - posData = self.data[self.pos_i] - self.labelRoiStartFrameNoSpinbox.setValue(posData.frame_i+1) - self.labelRoiStopFrameNoSpinbox.setValue(posData.frame_i+1) - - def getSecondChannelData(self): - if self.secondChannelName is None: - return - - posData = self.data[self.pos_i] - - fluo_ch = self.secondChannelName - fluo_path, filename = self.getPathFromChName(fluo_ch, posData) - if filename in posData.fluo_data_dict: - fluo_data = posData.fluo_data_dict[filename] - else: - fluo_data, bkgrData = self.load_fluo_data(fluo_path) - posData.fluo_data_dict[filename] = fluo_data - posData.fluo_bkgrData_dict[filename] = bkgrData - - if self.labelRoiTrangeCheckbox.isChecked(): - start_frame_i = self.labelRoiStartFrameNoSpinbox.value()-1 - stop_frame_n = self.labelRoiStopFrameNoSpinbox.value() - tRangeLen = stop_frame_n-start_frame_i - else: - tRangeLen = 1 - - if tRangeLen > 1: - # fluo_img_data = fluo_data[start_frame_i:stop_frame_n] - if self.isSegm3D or posData.SizeZ == 1: - return fluo_data - else: - T, Z, Y, X = fluo_data.shape - secondChannelData = np.zeros((T, Y, X), dtype=fluo_data.dtype) - for frame_i, fluo_img in enumerate(fluo_data): - secondChannelData[frame_i] = self.get_2Dimg_from_3D( - fluo_data, frame_i=frame_i - ) - return secondChannelData - else: - if posData.SizeT > 1: - fluo_img_data = fluo_data[posData.frame_i] - else: - fluo_img_data = fluo_data - - if self.isSegm3D or posData.SizeZ == 1: - return fluo_img_data - else: - return self.get_2Dimg_from_3D(fluo_img_data) - - def addActionsLutItemContextMenu(self, lutItem): - lutItem.gradient.menu.addSection('Visible channels: ') - for action in self.overlayContextMenu.actions(): - if action.isSeparator(): - continue - lutItem.gradient.menu.addAction(action) - lutItem.gradient.menu.addSeparator() - - annotationMenu = lutItem.gradient.menu.addMenu('Annotations settings') - ID_menu = annotationMenu.addMenu('IDs') - self.annotSettingsIDmenu = QActionGroup(annotationMenu) - labID_action = QAction("Show label's ID") - labID_action.setCheckable(True) - labID_action.setChecked(True) - labID_action.toggled.connect(self.annotLabelIDtreeToggled) - treeID_action = QAction("Show tree's ID") - treeID_action.setCheckable(True) - treeID_action.toggled.connect(self.annotLabelIDtreeToggled) - self.annotSettingsIDmenu.addAction(labID_action) - self.annotSettingsIDmenu.addAction(treeID_action) - ID_menu.addAction(labID_action) - ID_menu.addAction(treeID_action) - - ID_menu = annotationMenu.addMenu('Generation number') - self.annotSettingsGenNumMenu = QActionGroup(annotationMenu) - gen_num_action = QAction("Show default generation number") - gen_num_action.setCheckable(True) - gen_num_action.setChecked(True) - gen_num_action.toggled.connect(self.annotGenNumTreeToggled) - tree_gen_num_action = QAction("Show tree generation number") - tree_gen_num_action.setCheckable(True) - tree_gen_num_action.toggled.connect(self.annotGenNumTreeToggled) - self.annotSettingsGenNumMenu.addAction(gen_num_action) - self.annotSettingsGenNumMenu.addAction(tree_gen_num_action) - ID_menu.addAction(gen_num_action) - ID_menu.addAction(tree_gen_num_action) - - def annotGenNumTreeToggled(self, checked): - self.textAnnot[0].setGenNumTreeAnnotationsEnabled(checked) - - def annotLabelIDtreeToggled(self, checked): - self.textAnnot[0].setLabelTreeAnnotationsEnabled(checked) - - def setAnnotInfoMode(self, checked): - if checked: - for action in self.annotSettingsIDmenu.actions(): - if action.text().find('tree') != -1: - self.textAnnot[0].setLabelTreeAnnotationsEnabled(True) - action.setChecked(True) - break - for action in self.annotSettingsGenNumMenu.actions(): - if action.text().find('tree') != -1: - self.textAnnot[0].setGenNumTreeAnnotationsEnabled(True) - action.setChecked(True) - break - else: - for action in self.annotSettingsIDmenu.actions(): - if action.text().find('tree') == -1: - action.setChecked(False) - self.textAnnot[0].setLabelTreeAnnotationsEnabled(False) - break - for action in self.annotSettingsGenNumMenu.actions(): - if action.text().find('tree') == -1: - action.setChecked(False) - self.textAnnot[0].setGenNumTreeAnnotationsEnabled(False) - break - self.setAllTextAnnotations() - - def annotOptionClicked(self, clicked=True, sender=None, saveSettings=True): - if sender is None: - sender = self.sender() - # First manually set exclusive with uncheckable - clickedIDs = sender == self.annotIDsCheckbox - clickedCca = sender == self.annotCcaInfoCheckbox - clickedMBline = sender == self.drawMothBudLinesCheckbox - if self.annotIDsCheckbox.isChecked() and clickedIDs: - if self.annotCcaInfoCheckbox.isChecked(): - self.annotCcaInfoCheckbox.setChecked(False) - if self.drawMothBudLinesCheckbox.isChecked(): - self.drawMothBudLinesCheckbox.setChecked(False) - - if self.annotCcaInfoCheckbox.isChecked() and clickedCca: - if self.annotIDsCheckbox.isChecked(): - self.annotIDsCheckbox.setChecked(False) - if self.drawMothBudLinesCheckbox.isChecked(): - self.drawMothBudLinesCheckbox.setChecked(False) - - if self.drawMothBudLinesCheckbox.isChecked() and clickedMBline: - if self.annotIDsCheckbox.isChecked(): - self.annotIDsCheckbox.setChecked(False) - if self.annotCcaInfoCheckbox.isChecked(): - self.annotCcaInfoCheckbox.setChecked(False) - - clickedCont = sender == self.annotContourCheckbox - clickedSegm = sender == self.annotSegmMasksCheckbox - if self.annotContourCheckbox.isChecked() and clickedCont: - if self.annotSegmMasksCheckbox.isChecked(): - self.annotSegmMasksCheckbox.setChecked(False) - - if self.annotSegmMasksCheckbox.isChecked() and clickedSegm: - if self.annotContourCheckbox.isChecked(): - self.annotContourCheckbox.setChecked(False) - - clickedDoNot = sender == self.drawNothingCheckbox - if clickedDoNot: - self.annotIDsCheckbox.setChecked(False) - self.annotCcaInfoCheckbox.setChecked(False) - self.annotContourCheckbox.setChecked(False) - self.annotSegmMasksCheckbox.setChecked(False) - self.drawMothBudLinesCheckbox.setChecked(False) - self.annotNumZslicesCheckbox.setChecked(False) - else: - self.drawNothingCheckbox.setChecked(False) - - if sender == self.annotNumZslicesCheckbox: - self.annotIDsCheckbox.setChecked(True) - self.drawNothingCheckbox.setChecked(False) - - self.setDrawAnnotComboboxText(saveSettings=saveSettings) - - def setDisabledAnnotCheckBoxesLeft(self, disabled): - self.annotIDsCheckbox.setDisabled(disabled) - self.annotCcaInfoCheckbox.setDisabled(disabled) - self.annotContourCheckbox.setDisabled(disabled) - self.annotSegmMasksCheckbox.setDisabled(disabled) - self.drawMothBudLinesCheckbox.setDisabled(disabled) - self.annotNumZslicesCheckbox.setDisabled(disabled) - self.drawNothingCheckbox.setDisabled(disabled) - - def setEnabledAnnotCheckBoxesLeftZdepthAxes(self): - if not self.isSegm3D: - return - - self.annotIDsCheckbox.setDisabled(False) - self.annotContourCheckbox.setDisabled(False) - self.annotIDsCheckbox.setChecked(True) - self.annotContourCheckbox.setChecked(True) - - self.annotOptionClicked( - sender=self.annotIDsCheckbox, saveSettings=False) - - def setDisabledAnnotCheckBoxesRight(self, disabled): - self.annotIDsCheckboxRight.setDisabled(disabled) - self.annotCcaInfoCheckboxRight.setDisabled(disabled) - self.annotContourCheckboxRight.setDisabled(disabled) - self.annotSegmMasksCheckboxRight.setDisabled(disabled) - self.drawMothBudLinesCheckboxRight.setDisabled(disabled) - self.annotNumZslicesCheckboxRight.setDisabled(disabled) - self.drawNothingCheckboxRight.setDisabled(disabled) - - def annotOptionClickedRight( - self, clicked=True, sender=None, saveSettings=True - ): - if sender is None: - sender = self.sender() - # First manually set exclusive with uncheckable - clickedIDs = sender == self.annotIDsCheckboxRight - clickedCca = sender == self.annotCcaInfoCheckboxRight - clickedMBline = sender == self.drawMothBudLinesCheckboxRight - if self.annotIDsCheckboxRight.isChecked() and clickedIDs: - if self.annotCcaInfoCheckboxRight.isChecked(): - self.annotCcaInfoCheckboxRight.setChecked(False) - if self.drawMothBudLinesCheckboxRight.isChecked(): - self.drawMothBudLinesCheckboxRight.setChecked(False) - - if self.annotCcaInfoCheckboxRight.isChecked() and clickedCca: - if self.annotIDsCheckboxRight.isChecked(): - self.annotIDsCheckboxRight.setChecked(False) - if self.drawMothBudLinesCheckboxRight.isChecked(): - self.drawMothBudLinesCheckboxRight.setChecked(False) - - if self.drawMothBudLinesCheckboxRight.isChecked() and clickedMBline: - if self.annotIDsCheckboxRight.isChecked(): - self.annotIDsCheckboxRight.setChecked(False) - if self.annotCcaInfoCheckboxRight.isChecked(): - self.annotCcaInfoCheckboxRight.setChecked(False) - - clickedCont = sender == self.annotContourCheckboxRight - clickedSegm = sender == self.annotSegmMasksCheckboxRight - if self.annotContourCheckboxRight.isChecked() and clickedCont: - if self.annotSegmMasksCheckboxRight.isChecked(): - self.annotSegmMasksCheckboxRight.setChecked(False) - - if self.annotSegmMasksCheckboxRight.isChecked() and clickedSegm: - if self.annotContourCheckboxRight.isChecked(): - self.annotContourCheckboxRight.setChecked(False) - - clickedDoNot = sender == self.drawNothingCheckboxRight - if clickedDoNot: - self.annotIDsCheckboxRight.setChecked(False) - self.annotCcaInfoCheckboxRight.setChecked(False) - self.annotContourCheckboxRight.setChecked(False) - self.annotSegmMasksCheckboxRight.setChecked(False) - self.drawMothBudLinesCheckboxRight.setChecked(False) - self.annotNumZslicesCheckboxRight.setChecked(False) - else: - self.drawNothingCheckboxRight.setChecked(False) - - if sender == self.annotNumZslicesCheckboxRight: - self.annotIDsCheckboxRight.setChecked(True) - self.drawNothingCheckboxRight.setChecked(False) - - self.setDrawAnnotComboboxTextRight(saveSettings=saveSettings) - - def setAnnotOptionsCcaMode(self): - self.prevAnnotOptions = self.storeCurrentAnnotOptions_ax1( - return_value=True - ) - self.annotCcaInfoCheckbox.setChecked(True) - self.annotIDsCheckbox.setChecked(False) - self.drawMothBudLinesCheckbox.setChecked(False) - self.setDrawAnnotComboboxText() - - def setAnnotOptionsLin_treeMode(self): - # self.prevAnnotOptions = self.storeCurrentAnnotOptions_ax1( - # return_value=True - # ) - self.annotCcaInfoCheckbox.setChecked(True) - self.annotIDsCheckbox.setChecked(False) - self.drawMothBudLinesCheckbox.setChecked(False) - self.setDrawAnnotComboboxText() - self.showTreeInfoCheckbox.setChecked(True) - - def setDrawAnnotComboboxText(self, saveSettings=True): - if self.annotIDsCheckbox.isChecked(): - if self.annotContourCheckbox.isChecked(): - t = 'Draw IDs and contours' - elif self.annotSegmMasksCheckbox.isChecked(): - t = 'Draw IDs and overlay segm. masks' - else: - t = 'Draw only IDs' - - elif self.annotCcaInfoCheckbox.isChecked(): - if self.annotContourCheckbox.isChecked(): - t = 'Draw cell cycle info and contours' - elif self.annotSegmMasksCheckbox.isChecked(): - t = 'Draw cell cycle info and overlay segm. masks' - else: - t = 'Draw only cell cycle info' - - elif self.annotSegmMasksCheckbox.isChecked(): - t = 'Draw only overlay segm. masks' - - elif self.annotContourCheckbox.isChecked(): - t = 'Draw only contours' - - elif self.drawMothBudLinesCheckbox.isChecked(): - t = 'Draw only mother-bud lines' - - elif self.drawNothingCheckbox.isChecked(): - t = 'Draw nothing' - else: - t = 'Draw nothing' - - if t == self.drawIDsContComboBox.currentText(): - self.drawIDsContComboBox_cb(0) - - self.drawIDsContComboBox.saveSettings = saveSettings - self.drawIDsContComboBox.setCurrentText(t) - - def setDrawAnnotComboboxTextRight(self, saveSettings=True): - if self.annotIDsCheckboxRight.isChecked(): - if self.annotContourCheckboxRight.isChecked(): - t = 'Draw IDs and contours' - elif self.annotSegmMasksCheckboxRight.isChecked(): - t = 'Draw IDs and overlay segm. masks' - else: - t = 'Draw only IDs' - - elif self.annotCcaInfoCheckboxRight.isChecked(): - if self.annotContourCheckboxRight.isChecked(): - t = 'Draw cell cycle info and contours' - elif self.annotSegmMasksCheckboxRight.isChecked(): - t = 'Draw cell cycle info and overlay segm. masks' - else: - t = 'Draw only cell cycle info' - - elif self.annotSegmMasksCheckboxRight.isChecked(): - t = 'Draw only overlay segm. masks' - - elif self.annotContourCheckboxRight.isChecked(): - t = 'Draw only contours' - - elif self.drawMothBudLinesCheckboxRight.isChecked(): - t = 'Draw only mother-bud lines' - - elif self.drawNothingCheckboxRight.isChecked(): - t = 'Draw nothing' - else: - t = 'Draw nothing' - - if t == self.annotateRightHowCombobox.currentText(): - self.annotateRightHowCombobox_cb(0) - - self.annotateRightHowCombobox.saveSettings = saveSettings - self.annotateRightHowCombobox.setCurrentText(t) - - def getOverlayItems(self, channelName, index): - imageItem = widgets.OverlayImageItem() - imageItem.setOpacity(0.5) - imageItem.channelName = channelName - - lutItem = widgets.myHistogramLUTitem( - parent=self, name='image', axisLabel=channelName - ) - imageItem.lutItem = lutItem - for action in lutItem.rescaleActionGroup.actions(): - if action.text() == self.defaultRescaleIntensHow: - action.setChecked(True) - break - - lutItem.removeAddScaleBarAction() - lutItem.removeAddTimestampAction() - lutItem.restoreState(self.df_settings) - lutItem.setImageItem(imageItem) - lutItem.vb.raiseContextMenu = lambda x: None - initColor = self.overlayColors[channelName] - self.initColormapOverlayLayerItem(initColor, lutItem) - lutItem.addOverlayColorButton(initColor, channelName) - lutItem.initColor = initColor - lutItem.hide() - - lutItem.overlayColorButton.sigColorChanging.connect( - self.changeOverlayColor - ) - lutItem.overlayColorButton.sigColorChanged.connect( - self.saveOverlayColor - ) - - lutItem.invertBwAction.toggled.connect(self.setCheckedInvertBW) - - lutItem.contoursColorButton.disconnect() - lutItem.contoursColorButton.clicked.connect( - self.imgGrad.contoursColorButton.click - ) - for act in lutItem.contLineWightActionGroup.actions(): - act.toggled.connect(self.contLineWeightToggled) - - lutItem.mothBudLineColorButton.disconnect() - lutItem.mothBudLineColorButton.clicked.connect( - self.imgGrad.mothBudLineColorButton.click - ) - for act in lutItem.mothBudLineWightActionGroup.actions(): - act.toggled.connect(self.mothBudLineWeightToggled) - - lutItem.textColorButton.disconnect() - lutItem.textColorButton.clicked.connect( - self.editTextIDsColorAction.trigger - ) - - lutItem.defaultSettingsAction.triggered.connect( - self.restoreDefaultSettings - ) - lutItem.labelsAlphaSlider.valueChanged.connect( - self.setValueLabelsAlphaSlider - ) - lutItem.sigRescaleIntes.connect( - partial(self.rescaleIntensitiesLut, imageItem=imageItem) - ) - if f'how_rescale_intensities_{channelName}' in self.df_settings.index: - how = self.df_settings.at[ - f'how_rescale_intensities_{channelName}', 'value' - ] - lutItem.setRescaleIntensitiesHow(how) - - self.rescaleIntensChannelHowMapper[channelName] = ( - 'Rescale each 2D image' - ) - - self.addActionsLutItemContextMenu(lutItem) - - alphaScrollBar = self.addAlphaScrollbar(channelName, imageItem) - - toolbutton = widgets.OverlayChannelToolButton( - channelName, lutItem, shortcut=str(index) - ) - toolbutton.action = self.overlayToolbar.addWidget(toolbutton) - toolbutton.setVisible(False) - - toolbutton.clicked.connect(self.overlayChannelToolbuttonClicked) - - alphaScrollBar.toolbutton = toolbutton - - return imageItem, lutItem, alphaScrollBar, toolbutton - - def addAlphaScrollbar(self, channelName, imageItem): - alphaScrollBar = widgets.ScrollBar(Qt.Horizontal) - imageItem.alphaScrollBar = alphaScrollBar - alphaScrollBar.channelName = channelName - - label = QLabel(f'Alpha {channelName}') - label.setFont(_font) - label.hide() - alphaScrollBar.imageItem = imageItem - alphaScrollBar.label = label - alphaScrollBar.setFixedHeight(self.h) - alphaScrollBar.hide() - alphaScrollBar.setMinimum(0) - alphaScrollBar.setMaximum(40) - alphaScrollBar.setValue(20) - alphaScrollBar.setToolTip( - f'Control the alpha value of the overlaid channel {channelName}.\n' - 'alpha=0 results in NO overlay,\n' - 'alpha=1 results in only fluorescence data visible' - ) - self.bottomLeftLayout.addWidget( - alphaScrollBar.label, self.alphaScrollbarRow, 0, - alignment=Qt.AlignRight - ) - self.bottomLeftLayout.addWidget( - alphaScrollBar, self.alphaScrollbarRow, 1, 1, 2 - ) - - alphaScrollBar.valueChanged.connect( - partial(self.setOpacityOverlayLayersItems, scrollbar=alphaScrollBar) - ) - - self.alphaScrollbarRow += 1 - return alphaScrollBar - - def setValueLabelsAlphaSlider(self, value): - self.imgGrad.labelsAlphaSlider.setValue(value) - self.updateLabelsAlpha(value) - - def setOverlayLabelsItemsVisible(self, checked): - for _segmEndname, drawMode in self.drawModeOverlayLabelsChannels.items(): - items = self.overlayLabelsItems[_segmEndname] - gradItem = items[-1] - gradItem.hide() - - if checked: - segmEndname = self.sender().text() - gradItem = self.overlayLabelsItems[segmEndname][-1] - gradItem.show() - - def setRetainSizePolicyLutItems(self): - if not self.retainSizeLutItems: - return - for channel, items in self.overlayLayersItems.items(): - _, lutItem, alphaSB = items[:3] - myutils.setRetainSizePolicy(lutItem, retain=True) - QTimer.singleShot(300, self.autoRange) - - def setOverlayChannelsToolbuttonsChecked(self): - for channel, items in self.overlayLayersItems.items(): - _, lutItem, alphaSB, toolbutton = items[:4] - toolbutton.setChecked( - not self.overlayToolbar.isSingleChannel() - and channel in self.checkedOverlayChannels - ) - - def setOverlayItemsVisible(self): - for channel, items in self.overlayLayersItems.items(): - _, lutItem, alphaSB, toolbutton = items[:4] - lutItem.hide() - alphaSB.hide() - alphaSB.label.hide() - toolbutton.setVisible(False) - - if not self.overlayButton.isChecked(): - return - - for channel, items in self.overlayLayersItems.items(): - _, lutItem, alphaSB, toolbutton = items[:4] - if channel in self.checkedOverlayChannels: - lutItem.show() - alphaSB.show() - alphaSB.label.show() - toolbutton.setVisible(True) - - def overlayChannelToolbuttonClicked(self, checked=False, toolbutton=None): - if toolbutton is None: - toolbutton = self.sender() - - n_checked_buttons = ( - sum([b.isChecked() for b in self.allOverlayToolbuttons.values()]) - ) - - channelName = toolbutton.channelName() - - if n_checked_buttons == 0 or self.overlayToolbar.isSingleChannel(): - # At least one button must be checked - toolbutton.setChecked(True) - - if self.overlayToolbar.isSingleChannel(): - # Exclusive buttons - for channel, otherToolbutton in self.allOverlayToolbuttons.items(): - if channel == channelName: - continue - - otherToolbutton.setChecked(False) - - if self.overlayToolbar.isTransparent(): - self.setOverlayImages() - return - - self.setOverlayItemsOpacities() - - def setOverlayItemsOpacities(self): - n_checked_buttons = ( - sum([b.isChecked() for b in self.allOverlayToolbuttons.values()]) - ) - - isSingleChannel = ( - self.overlayToolbar.isSingleChannel() - or n_checked_buttons == 1 - ) - - channel_opacity_mapper = self.getOpacitiesFromAlphaScrollbarValues() - - # Set opacity of every layer accordingly - for channel, otherToolbutton in self.allOverlayToolbuttons.items(): - if channel == self.user_ch_name: - otherImageItem = self.img1 - alphaScrollbar = None - # alpha_value = channel_opacity_mapper[channel] - else: - otherItems = self.overlayLayersItems[channel] - otherImageItem = otherItems[0] - alphaScrollbar = otherItems[2] - # alpha_value = alphaScrollbar.value()/alphaScrollbar.maximum() - - if otherToolbutton.isChecked() and isSingleChannel: - op_val = 1.0 - elif otherToolbutton.isChecked(): - op_val = channel_opacity_mapper[channel] - else: - op_val = 0.0 - - if op_val == 0: - op_val = 0.01 - - op_val = op_val if op_val < 1.0 else 0.999 - - otherImageItem.setOpacity(op_val, applyToLinked=False) - - if alphaScrollbar is None: - continue - - alphaScrollbar.setDisabled(bool(op_val == 0)) - - def initColormapOverlayLayerItem(self, foregrColor, lutItem): - if self.invertBwAction.isChecked(): - bkgrColor = (255,255,255,255) - else: - bkgrColor = (0,0,0,255) - gradient = colors.get_pg_gradient((bkgrColor, foregrColor)) - lutItem.setGradient(gradient) - - def setOpacityOverlayLayersItems(self, value, imageItem=None, scrollbar=None): - if scrollbar is None: - scrollbar = imageItem.alphaScrollBar - - channel = scrollbar.channelName - toolbutton = self.allOverlayToolbuttons[channel] - if not toolbutton.isChecked() or not toolbutton.isVisible(): - return - - if value is None: - value = scrollbar.value() - - if imageItem is None: - imageItem = scrollbar.imageItem - alpha = value/scrollbar.maximum() - elif value > 1: - alpha = value/scrollbar.maximum() - else: - alpha = value - - alpha_values = [] - activeOverlayImageItems = [] - for items in self.overlayLayersItems.values(): - imgItem, lutItem, alphaSB = items[:3] - _toolbutton = alphaSB.toolbutton - if alphaSB.channelName == channel: - alpha_values.append(alpha) - elif not _toolbutton.isChecked() or not _toolbutton.isVisible(): - continue - else: - alpha_values.append(alphaSB.value()/alphaSB.maximum()) - - activeOverlayImageItems.append(imgItem) - - opacities = colors.hierarchical_weights(alpha_values)[::-1] - - for i, imgItem in enumerate(activeOverlayImageItems): - imgItem.setOpacity(opacities[i+1]) - - self.img1.setOpacity(opacities[0], applyToLinked=False) - - def showInExplorer_cb(self): - posData = self.data[self.pos_i] - path = posData.images_path - myutils.showInExplorer(path) - - def zSliceAbsent(self, filename, posData): - self.app.restoreOverrideCursor() - SizeZ = posData.SizeZ - chNames = posData.chNames - filenamesPresent = posData.segmInfo_df.index.get_level_values(0).unique() - chNamesPresent = [ - ch for ch in chNames - for file in filenamesPresent - if file.endswith(ch) or file.endswith(f'{ch}_aligned') - ] - win = apps.QDialogZsliceAbsent(filename, SizeZ, chNamesPresent) - win.exec_() - if win.cancel: - self.worker.abort = True - self.waitCond.wakeAll() - return - if win.useMiddleSlice: - user_ch_name = filename[len(posData.basename):] - for _posData in self.data: - if _posData is None: - continue - _, filename = self.getPathFromChName(user_ch_name, _posData) - df = myutils.getDefault_SegmInfo_df(_posData, filename) - _posData.segmInfo_df = pd.concat([df, _posData.segmInfo_df]) - unique_idx = ~_posData.segmInfo_df.index.duplicated() - _posData.segmInfo_df = _posData.segmInfo_df[unique_idx] - _posData.segmInfo_df.to_csv(_posData.segmInfo_df_csv_path) - elif win.useSameAsCh: - user_ch_name = filename[len(posData.basename):] - for _posData in self.data: - if _posData is None: - continue - _, srcFilename = self.getPathFromChName( - win.selectedChannel, _posData - ) - cellacdc_df = _posData.segmInfo_df.loc[srcFilename].copy() - _, dstFilename = self.getPathFromChName(user_ch_name, _posData) - if dstFilename is None: - self.worker.abort = True - self.waitCond.wakeAll() - return - dst_df = myutils.getDefault_SegmInfo_df(_posData, dstFilename) - for z_info in cellacdc_df.itertuples(): - frame_i = z_info.Index - zProjHow = z_info.which_z_proj - if zProjHow == 'single z-slice': - src_idx = (srcFilename, frame_i) - if _posData.segmInfo_df.at[src_idx, 'resegmented_in_gui']: - col = 'z_slice_used_gui' - else: - col = 'z_slice_used_dataPrep' - z_slice = _posData.segmInfo_df.at[src_idx, col] - dst_idx = (dstFilename, frame_i) - dst_df.at[dst_idx, 'z_slice_used_dataPrep'] = z_slice - dst_df.at[dst_idx, 'z_slice_used_gui'] = z_slice - _posData.segmInfo_df = pd.concat([dst_df, _posData.segmInfo_df]) - unique_idx = ~_posData.segmInfo_df.index.duplicated() - _posData.segmInfo_df = _posData.segmInfo_df[unique_idx] - _posData.segmInfo_df.to_csv(_posData.segmInfo_df_csv_path) - elif win.runDataPrep: - user_ch_file_paths = [] - user_ch_name = filename[len(self.data[self.pos_i].basename):] - for _posData in self.data: - if _posData is None: - continue - user_ch_path = load.get_filename_from_channel( - _posData.images_path, user_ch_name - ) - if user_ch_path is None: - self.worker.abort = True - self.waitCond.wakeAll() - return - user_ch_file_paths.append(user_ch_path) - exp_path = os.path.dirname(_posData.pos_path) - - dataPrepWin = dataPrep.dataPrepWin() - dataPrepWin.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - dataPrepWin.titleText = ( - """ - Select z-slice (or projection) for each frame/position.
- Once happy, close the window. - """) - dataPrepWin.show() - dataPrepWin.initLoading() - dataPrepWin.SizeT = self.data[0].SizeT - dataPrepWin.SizeZ = self.data[0].SizeZ - dataPrepWin.metadataAlreadyAsked = True - self.logger.info(f'Loading channel {user_ch_name} data...') - dataPrepWin.loadFiles( - exp_path, user_ch_file_paths, user_ch_name - ) - dataPrepWin.startAction.setDisabled(True) - dataPrepWin.onlySelectingZslice = True - - loop = QEventLoop(self) - dataPrepWin.loop = loop - loop.exec_() - - self.waitCond.wakeAll() - - def showSetMeasurements(self, checked=False, qparent=None): - qparent = qparent if qparent is not None else self - if self.measurementsWin is not None: - self.measurementsWin.show() - self.measurementsWin.raise_() - self.measurementsWin.activateWindow() - return - - try: - df_favourite_funcs = pd.read_csv(favourite_func_metrics_csv_path) - favourite_funcs = df_favourite_funcs['favourite_func_name'].to_list() - except Exception as e: - favourite_funcs = None - - posData = self.data[self.pos_i] - allPos_acdc_df_cols = set() - for _posData in self.data: - for frame_i, data_dict in enumerate(_posData.allData_li): - acdc_df = data_dict['acdc_df'] - if acdc_df is None: - continue - - allPos_acdc_df_cols.update(acdc_df.columns) - loadedChNames = posData.setLoadedChannelNames(returnList=True) - posData.fluo_data_dict.pop(self.user_ch_name, None) - if self.user_ch_name not in loadedChNames: - loadedChNames.insert(0, self.user_ch_name) - notLoadedChNames = [c for c in self.ch_names if c not in loadedChNames] - self.notLoadedChNames = notLoadedChNames - self.measurementsWin = apps.SetMeasurementsDialog( - loadedChNames, notLoadedChNames, posData.SizeZ > 1, self.isSegm3D, - favourite_funcs=favourite_funcs, - allPos_acdc_df_cols=list(allPos_acdc_df_cols), - acdc_df_path=posData.images_path, posData=posData, - addCombineMetricCallback=self.addCombineMetric, - allPosData=self.data, - parent=qparent, - state=self.setMeasWinState - ) - self.measurementsWin.sigCancel.connect(self.setMeasurementsCancelled) - self.measurementsWin.sigClosed.connect(self.setMeasurements) - self.measurementsWin.show() - - def setMeasurementsCancelled(self): - self.measurementsWin = None - - def setMeasurements(self): - posData = self.data[self.pos_i] - if self.measurementsWin.delExistingCols: - self.logger.info('Removing existing unchecked measurements...') - delCols = self.measurementsWin.existingUncheckedColnames - delRps = self.measurementsWin.existingUncheckedRps - delCols_format = [f' * {colname}' for colname in delCols] - delRps_format = [f' * {colname}' for colname in delRps] - delCols_format.extend(delRps_format) - delCols_format = '\n'.join(delCols_format) - self.logger.info(delCols_format) - for _posData in self.data: - for frame_i, data_dict in enumerate(_posData.allData_li): - acdc_df = data_dict['acdc_df'] - if acdc_df is None: - continue - - acdc_df = acdc_df.drop(columns=delCols, errors='ignore') - for col_rp in delRps: - drop_df_rp = acdc_df.filter(regex=fr'{col_rp}.*', axis=1) - drop_cols_rp = drop_df_rp.columns - acdc_df = acdc_df.drop(columns=drop_cols_rp, errors='ignore') - _posData.allData_li[frame_i]['acdc_df'] = acdc_df - self.setMeasWinState = self.measurementsWin.state() - self.logger.info('Setting measurements...') - self._setMetrics(self.measurementsWin) - self.logger.info('Metrics successfully set.') - self.measurementsWin = None - - def _setMetrics(self, measurementsWin): - self._measurements_kernel.set_metrics_from_set_measurements_dialog( - measurementsWin - ) - for ch in self._measurements_kernel.chNamesToProcess: - if ch not in self.notLoadedChNames: - continue - - success = self.loadFluo_cb(fluo_channels=[ch]) - if not success: - continue - - def addCustomMetric(self, checked=False): - txt = measurements.add_metrics_instructions() - metrics_path = measurements.metrics_path - msg = widgets.myMessageBox() - msg.addShowInFileManagerButton(metrics_path, 'Show example...') - title = 'Add custom metrics instructions' - msg.information(self, title, txt, buttonsTexts=('Ok',)) - - def addCombineMetric(self): - posData = self.data[self.pos_i] - isZstack = posData.SizeZ > 1 - win = apps.combineMetricsEquationDialog( - self.ch_names, isZstack, self.isSegm3D, parent=self - ) - win.sigOk.connect(self.saveCombineMetricsToPosData) - win.exec_() - win.sigOk.disconnect() - - def saveCombineMetricsToPosData(self, window): - for posData in self.data: - equationsDict, isMixedChannels = window.getEquationsDict() - for newColName, equation in equationsDict.items(): - posData.addEquationCombineMetrics( - equation, newColName, isMixedChannels - ) - posData.saveCombineMetrics() - - if self.measurementsWin is None: - return - - self.measurementsWinState = self.measurementsWin.state() - self.measurementsWin.close() - self.showSetMeasurements() - self.measurementsWin.restoreState(self.measurementsWinState) - - def labelRoiToEndFramesTriggered(self): - posData = self.data[self.pos_i] - self.labelRoiStopFrameNoSpinbox.setValue(posData.SizeT) - - def labelRoiFromCurrentFrameTriggered(self): - posData = self.data[self.pos_i] - self.labelRoiStartFrameNoSpinbox.setValue(posData.frame_i+1) - - def labelRoiViewCurrentModel(self): - from . import config - ini_path = os.path.join( - settings_folderpath, 'last_params_segm_models.ini' - ) - configPars = config.ConfigParser() - configPars.read(ini_path) - model_name = self.labelRoiModel.model_name - txt = f'Model: {model_name}' - SECTION = f'{model_name}.init' - txt = f'{txt}

[Initialization parameters]
' - for option in configPars.options(SECTION): - value = configPars[SECTION][option] - param_txt = f'{option} = {value}
' - txt = f'{txt}{param_txt}' - - SECTION = f'{model_name}.segment' - txt = f'{txt}
[Segmentation parameters]
' - for option in configPars.options(SECTION): - value = configPars[SECTION][option] - param_txt = f'{option} = {value}
' - txt = f'{txt}{param_txt}' - - win = apps.ViewTextDialog(txt, parent=self) - win.exec_() - - def setMetricsFunc(self): - posData = self.data[self.pos_i] - self._measurements_kernel._set_metrics_func_from_posData(posData) - - def getLastTrackedFrame(self, posData): - last_tracked_i = 0 - for frame_i, data_dict in enumerate(posData.allData_li): - lab = data_dict['labels'] - if lab is None: - frame_i -= 1 - break - if frame_i > 0: - return frame_i - else: - return last_tracked_i - - def computeVolumeRegionprop(self): - if 'cell_vol_vox' not in self._measurements_kernel.sizeMetricsToSave: - return - - # We compute the cell volume in the main thread because calling - # skimage.transform.rotate in a separate thread causes crashes - # with segmentation fault on macOS. I don't know why yet. - self.logger.info('Computing cell volume...') - end_i = self.save_until_frame_i - pos_iter = tqdm(self.data, ncols=100) - for p, posData in enumerate(pos_iter): - if self.posToSave is not None: - if posData.pos_foldername not in self.posToSave: - continue - - PhysicalSizeY = posData.PhysicalSizeY - PhysicalSizeX = posData.PhysicalSizeX - frame_iter = tqdm( - posData.allData_li[:end_i+1], ncols=100, position=1, leave=False - ) - for frame_i, data_dict in enumerate(frame_iter): - lab = data_dict['labels'] - if lab is None: - break - rp = data_dict['regionprops'] - obj_iter = tqdm(rp, ncols=100, position=2, leave=False) - for i, obj in enumerate(obj_iter): - vol_vox, vol_fl = _calc_rot_vol( - obj, PhysicalSizeY, PhysicalSizeX - ) - obj.vol_vox = vol_vox - obj.vol_fl = vol_fl - posData.allData_li[frame_i]['regionprops'] = rp - - def askSaveOriginalSegm(self, isQuickSave=False): - if isQuickSave: - return "", True, True - - posData = self.data[self.pos_i] - if not posData.whitelist: - return "", True, True - - help_txt = html_utils.paragraph(f""" - You have whitelisted IDs in the current position.
- Do you want to save the not whitelisted segmentation data
- This will allow you to revisit the original segmentation.
- """) - - txt = html_utils.paragraph(f""" - You have whitelisted IDs in the current position.
- Do you want to save the not whitelisted segmentation data?
- """) - - found_files = load.get_segm_files(posData.images_path) - existingEndnames = load.get_endnames( - posData.basename, found_files - ) - - segmFilename = os.path.basename(posData.segm_npz_path) - segmFilename = f"{segmFilename[:-4]}_not_whitelisted" - win = apps.filenameDialog( - basename=posData.basename, - hintText=txt, - defaultEntry=segmFilename, - existingNames=existingEndnames, - helpText=help_txt, - allowEmpty=False, - parent=self, - title='Save not whitelisted segmentation data', - addDoNotSaveButton=True - ) - win.exec_() - if win.cancel: - return "", False, True - if win.doNotSave: - return "", True, True - return win.entryText, True, False - - def askSaveLastVisitedCcaMode(self, isQuickSave=False): - posData = self.data[self.pos_i] - current_frame_i = posData.frame_i - frame_i = 0 - last_tracked_i = 0 - self.save_until_frame_i = 0 - if self.isSnapshot: - return True - - for frame_i, data_dict in enumerate(posData.allData_li): - lab = data_dict['labels'] - if lab is None: - frame_i -= 1 - break - - self.save_until_frame_i = frame_i - self.save_cca_until_frame_i = frame_i - self.last_tracked_i = frame_i - - if isQuickSave: - return True - - last_cca_frame_i = self.navigateScrollBar.maximum()-1 - # Ask to save last visited frame or not - txt = html_utils.paragraph(f""" - You annotated the cell cycle stages up - until frame number {last_cca_frame_i+1}.

- Enter up to which frame number you want to save the - cell cycle annotations: - """) - lastFrameDialog = apps.QLineEditDialog( - title='Last annoated frame number to save', - defaultTxt=str(last_cca_frame_i+1), - msg=txt, parent=self, allowedValues=(1, last_cca_frame_i+1), - warnLastFrame=True, isInteger=True, stretchEntry=False, - lastVisitedFrame=last_cca_frame_i+1, - ) - lastFrameDialog.exec_() - if lastFrameDialog.cancel: - return False - - last_save_cca_frame_i = lastFrameDialog.enteredValue - 1 - - if last_save_cca_frame_i < last_cca_frame_i: - self.resetCcaFuture(last_cca_frame_i) - - self.save_cca_until_frame_i = last_save_cca_frame_i - - return True - - def askSaveLastVisitedSegmMode(self, isQuickSave=False): - posData = self.data[self.pos_i] - current_frame_i = posData.frame_i - frame_i = 0 - last_tracked_i = 0 - self.save_until_frame_i = 0 - self.save_cca_until_frame_i = 0 - if self.isSnapshot: - return True - - for frame_i, data_dict in enumerate(posData.allData_li): - lab = data_dict['labels'] - if lab is None: - frame_i -= 1 - break - - if isQuickSave: - self.save_until_frame_i = frame_i - self.save_cca_until_frame_i = frame_i - self.last_tracked_i = frame_i - return True - - # Ask to save last visited frame or not - txt = html_utils.paragraph(f""" - You visualised and corrected segmentation and tracking data up - until frame number {frame_i+1}.

- Enter up to which frame number you want to save data: - """) - lastFrameDialog = apps.QLineEditDialog( - title='Last frame number to save', defaultTxt=str(frame_i+1), - msg=txt, parent=self, allowedValues=(1, posData.SizeT), - warnLastFrame=True, isInteger=True, stretchEntry=False, - lastVisitedFrame=frame_i+1, - ) - lastFrameDialog.exec_() - if lastFrameDialog.cancel: - return False - - self.save_until_frame_i = lastFrameDialog.enteredValue - 1 - self.save_cca_until_frame_i = self.save_until_frame_i - if self.save_until_frame_i > frame_i: - self.logger.info( - f'Storing frames {frame_i+1}-{self.save_until_frame_i+1}...' - ) - current_frame_i = posData.frame_i - # User is requesting to save past the last visited frame --> - # store data as if they were visited - for i in range(frame_i+1, self.save_until_frame_i+1): - posData.frame_i = i - self.get_data() - self.store_data(autosave=False) - - # Go back to current frame - posData.frame_i = current_frame_i - self.get_data() - last_tracked_i = self.save_until_frame_i - - self.last_tracked_i = last_tracked_i - return True - - def askSaveMetrics(self): - txt = html_utils.paragraph( - """ - Do you also want to save the measurements - (e.g., cell volume, mean, amount etc.)?

- - You can find more information by clicking on the - "Set measurements" button below
- where you will be able to select which measurements - you want to save.

- If you already set the measurements and you want to save them click "Yes".

- - NOTE: Saving metrics might be slow, - we recommend doing it only when you need it.
- """) - msg = widgets.myMessageBox( - parent=self, resizeButtons=False, wrapText=False - ) - setMeasurementsButton = widgets.setPushButton('Set measurements...') - _, yesButton, noButton, _ = msg.question( - self, 'Save measurements?', txt, - buttonsTexts=('Cancel', 'Yes', 'No', setMeasurementsButton), - showDialog=False - ) - setMeasurementsButton.disconnect() - setMeasurementsButton.clicked.connect( - partial( - self.showSetMeasurements, - qparent=msg, - ) - ) - msg.exec_() - save_metrics = msg.clickedButton == yesButton - return save_metrics, msg.cancel - - def askSelectPos(self, action='to save'): - last_pos = 1 - for p, posData in enumerate(self.data): - acdc_df = posData.allData_li[0]['acdc_df'] - if acdc_df is None: - last_pos = p - break - else: - last_pos = len(self.data) - - items = [posData.pos_foldername for posData in self.data] - selectPosWin = widgets.QDialogListbox( - f'Select Positions {action}', f'Select Positions {action}:\n', - items, multiSelection=True, parent=self, - preSelectedItems=items[:last_pos] - ) - selectPosWin.exec_() - if selectPosWin.cancel: - return - - return selectPosWin.selectedItemsText - - def askPosToSave(self): - return self.askSelectPos() - - def saveMetricsCritical(self, traceback_format): - print('\n====================================') - self.logger.exception(traceback_format) - print('====================================\n') - self.logger.info('Warning: calculating metrics failed see above...') - print('------------------------------') - - msg = widgets.myMessageBox(wrapText=False) - err_msg = html_utils.paragraph(f""" - Error while saving metrics.

- More details below or in the terminal/console.

- Note that the error details from this session are also saved - in the file
- {self.log_path}

- Please send the log file when reporting a bug, thanks! - Please restart Cell-ACDC, we apologise for any inconvenience.

- - """) - msg.addShowInFileManagerButton(self.logs_path, txt='Show log file...') - msg.setDetailedText(traceback_format, visible=True) - msg.critical(self, 'Critical error while saving metrics', err_msg) - - self.is_error_state = True - self.waitCond.wakeAll() - - def saveAsData(self, checked=True): - try: - posData = self.data[self.pos_i] - except AttributeError: - return - - existingFilenames = set() - for _posData in self.data: - segm_files = load.get_segm_files(_posData.images_path) - _existingEndnames = load.get_endnames( - _posData.basename, segm_files - ) - existingFilenames.update([ - f'{_posData.basename}{endname}.npz' - for endname in _existingEndnames - ]) - posData = self.data[self.pos_i] - if posData.basename.endswith('_'): - basename = f'{posData.basename}segm' - else: - basename = f'{posData.basename}_segm' - win = apps.filenameDialog( - basename=basename, - hintText='Insert a filename for the segmentation file:
', - existingNames=existingFilenames - ) - win.exec_() - if win.cancel: - return - - for posData in self.data: - posData.setFilePaths(new_endname=win.entryText) - - self.setStatusBarLabel() - self.saveData() - - def startExportToVideoWorker(self, preferences): - self.isExportingVideo = True - self.isTransparent = self.overlayToolbar.isTransparent() - if not self.isTransparent: - # SVG export works only with RGBA not with setOpacity - # --> only true transparency mode can be used - self.overlayToolbar.setTransparent(True) - - self.setDisabled(True) - - self.progressWin = apps.QDialogWorkerProgress( - title='Exporting to video', parent=self.mainWin, - pbarDesc='Exporting to video...' - ) - self.progressWin.show(self.app) - self.exportToVideoStopNavVarNum = preferences['stop_nav_var_num'] - self.numFramesExported = 0 - self.progressWin.mainPbar.setMaximum( - preferences['stop_nav_var_num'] - - preferences['start_nav_var_num'] + 1 - ) - self.exportToVideoPreferences = preferences - - self.store_data() - posData = self.data[self.pos_i] - if self.exportToVideoPreferences['is_timelapse']: - # Go to requested start frame - posData.frame_i = preferences['start_nav_var_num'] - 1 - self.get_data() - self.updateAllImages() - self.exportToVideoNavVarIdxToRestore = posData.frame_i - else: - self.update_z_slice(preferences['start_nav_var_num'] - 1) - self.exportToVideoNavVarIdxToRestore = ( - self.zSliceScrollBar.sliderPosition() - ) - self.exportToVideoCurrentNavVarIdx = ( - preferences['start_nav_var_num'] - 1 - ) - - self.exportToVideoImageExporter = exporters.ImageExporter( - self.ax1, - save_pngs=preferences['save_pngs'], - dpi=preferences['dpi'] - ) - self.exportToVideoExporter = exporters.VideoExporter( - preferences['avi_filepath'], preferences['fps'] - ) - - QTimer.singleShot(200, self.updateAndExportFrame) - - def updateAndExportFrame(self): - didVideoExporterFinish = ( - self.exportToVideoCurrentNavVarIdx - == self.exportToVideoStopNavVarNum - ) - if didVideoExporterFinish: - self.progressWin.mainPbar.setMaximum(0) - self.progressWin.mainPbar.setValue(0) - QTimer.singleShot(50, self.exportingFramesFinished) - return - - posData = self.data[self.pos_i] - if self.exportToVideoPreferences['is_timelapse']: - self.goToFrameNumber(self.exportToVideoCurrentNavVarIdx+1) - else: - self.update_z_slice(self.exportToVideoCurrentNavVarIdx) - - success = self.exportFrame() - if success is None: - self.exportingVideoCritical() - return - - self.exportToVideoCurrentNavVarIdx += 1 - self.progressWin.mainPbar.update(1) - - QTimer.singleShot(50, self.updateAndExportFrame) - - @exception_handler - def exportFrame(self): - nd = self.exportToVideoPreferences['num_digits'] - idx = str(self.exportToVideoCurrentNavVarIdx).zfill(nd) - filename = self.exportToVideoPreferences['filename'] - png_filename = f'{idx}_{filename}.png' - pngs_folderpath = self.exportToVideoPreferences['pngs_folderpath'] - - png_filepath = os.path.join(pngs_folderpath, png_filename) - img_bgr = self.exportToVideoImageExporter.export(png_filepath) - self.exportToVideoExporter.add_frame(img_bgr) - return True - - def exportingVideoCritical(self): - self.setDisabled(False) - - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - - self.logger.info('Exporting video process failed.') - - def exportingFramesFinished(self): - if not self.exportToVideoPreferences['save_pngs']: - self.logger.info('Removing PNGs...') - try: - shutil.rmtree(self.exportToVideoPreferences['pngs_folderpath']) - except Exception as err: - pass - - self.logger.info('Saving video...') - - self.exportToVideoExporter.release() - - # Run ffmpeg new process - conversion_to_mp4_successful = True - if self.exportToVideoPreferences['filepath'].endswith('.mp4'): - try: - self.exportToVideoExporter.avi_to_mp4() - try: - os.remove(self.exportToVideoPreferences['avi_filepath']) - except Exception as err: - pass - except Exception as err: - self.logger.exception(traceback.format_exc()) - self.logger.info( - 'Conversion to MP4 failed. See traceback above.' - ) - conversion_to_mp4_successful = False - self.exportToVideoPreferences['filepath'] = ( - self.exportToVideoExporter._avi_filepath - ) - - self.exportToVideoFinished(conversion_to_mp4_successful) - - def exportToVideoFinished(self, conversion_to_mp4_successful): - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - - # Back to current frame - if self.exportToVideoPreferences['is_timelapse']: - posData = self.data[self.pos_i] - posData.frame_i = self.exportToVideoNavVarIdxToRestore - self.get_data() - self.store_data() - self.updateAllImages() - self.navigateScrollBar.setSliderPosition(posData.frame_i+1) - self.navSpinBox.setValue(posData.frame_i+1) - else: - self.update_z_slice(self.exportToVideoNavVarIdxToRestore) - - self.setDisabled(False) - self.isExportingVideo = False - - if not self.isTransparent: - # True transparency mode was activated programmatically - # --> restore what the user had before starting to export - self.overlayToolbar.setTransparent(False) - - prompts.exportToVideoFinished( - self.exportToVideoPreferences, conversion_to_mp4_successful, - qparent=self - ) - - def exportAddScaleBar(self, checked): - self.addScaleBarAction.setChecked(checked) - - def exportToVideoAddTimestamp(self, checked): - self.addTimestampAction.setChecked(checked) - - def askTimelapseOrZslicesVideo(self): - txt = html_utils.paragraph(""" - Do you want to record a video of scrolling through the z-slices or - a Timelapse video? - """) - msg = widgets.myMessageBox(wrapText=False) - _, timelapseButton = msg.question( - self, 'Z-slices or Timelapse video?', txt, - buttonsTexts=('Z-slices', 'Timelapse') - ) - if msg.cancel: - return - - return msg.clickedButton == timelapseButton - - def exportToVideoTriggered(self): - posData = self.data[self.pos_i] - - doTimelapseVideo = posData.SizeT > 1 - if posData.SizeT > 1 and posData.SizeZ > 1: - doTimelapseVideo = self.askTimelapseOrZslicesVideo() - - if doTimelapseVideo is None: - self.logger.info('Export to video process cancelled') - return - - channels = [self.user_ch_name, *self.checkedOverlayChannels] - mode = 'timelapse' if doTimelapseVideo else 'z_slices' - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - filename = f'{timestamp}_acdc_exported_{mode}_video' - win = apps.ExportToVideoParametersDialog( - channels, - parent=self, - startFolderpath=posData.pos_path, - startFilename=filename, - startFrameNum=posData.frame_i+1, - SizeT=posData.SizeT, - SizeZ=posData.SizeZ, - isTimelapseVideo=doTimelapseVideo, - isScaleBarPresent=self.addScaleBarAction.isChecked(), - isTimestampPresent=self.addTimestampAction.isChecked(), - rescaleIntensChannelHowMapper=self.rescaleIntensChannelHowMapper - ) - win.sigAddScaleBar.connect(self.exportAddScaleBar) - win.sigAddTimestamp.connect(self.exportToVideoAddTimestamp) - win.sigRescaleIntensLut.connect(self.rescaleIntensExportToVideoDialog) - win.exec_() - if win.cancel: - self.logger.info('Export to video process cancelled') - return - - cancel = _warnings.warnExportToVideo(qparent=self) - if cancel: - self.logger.info('Export to video process cancelled') - return - - self.startExportToVideoWorker(win.selected_preferences) - - def setExportMaskImage(self, viewRange): - if not hasattr(self, 'exportMaskImage'): - self.initExportMaskImage() - else: - self.exportMaskImage[:] = 0 - - xRange, yRange = viewRange - x0, x1 = map(round, xRange) - y0, y1 = map(round, yRange) - - if self.invertBwAction.isChecked(): - self.exportMaskImage[:, :, :3] = 255 - - if x0 > 0: - self.exportMaskImage[:, :x0, 3] = 255 - if x1 < self.exportMaskImage.shape[1]: - self.exportMaskImage[:, x1:, 3] = 255 - if y0 > 0: - self.exportMaskImage[:y0, :, 3] = 255 - if y1 < self.exportMaskImage.shape[0]: - self.exportMaskImage[y1:, :, 3] = 255 - - self.exportMaskImageItem.setImage(self.exportMaskImage) - - def setViewRangeFromExportToImageDialog(self, viewRange, win=None): - xRange, yRange = viewRange - # self.ax1.sigRangeChanged.disconnect(self.viewRangeChanged) - self.ax1.setRange(xRange=xRange, yRange=yRange) - # self.ax1.sigRangeChanged.connect(self.viewRangeChanged) - # self.viewRangeChanged( - # self.ax1.vb, viewRange, updateExportMaskImage=False - # ) - self.setExportMaskImage(viewRange) - - def getZoomIDs(self, viewRange=None): - if viewRange is None: - viewRange = self.ax1.viewRange() - - lab = self.currentLab2D - Y, X = lab.shape - ((xmin, xmax), (ymin, ymax)) = viewRange - if xmin <= 0 and ymin <= 0 and xmax >= X and ymax >= Y: - posData = self.data[self.pos_i] - return None - - xmin = xmin if xmin >= 0 else 0 - ymin = ymin if ymin >= 0 else 0 - xmax = xmax if xmax < X else X - ymax = ymax if ymax < Y else Y - - zoomSlice = ( - slice(round(ymin), round(ymax)), - slice(round(xmin), round(xmax)), - ) - - zoomLab = skimage.segmentation.clear_border(lab[zoomSlice]) - zoomRp = skimage.measure.regionprops(zoomLab) - zoomIDs = [obj.label for obj in zoomRp] - return zoomIDs - - def onSigUpdateCcaTableWindow(self, *args): - if not self.isDataLoaded: - return - - if self.ccaTableWin is None: - return - - viewRange = self.ax1.viewRange() - posData = self.data[self.pos_i] - zoomIDs = self.getZoomIDs(viewRange=viewRange) - - self.ccaTableWin.updateTable(posData.cca_df, IDs=zoomIDs) - - @disableWindow - def exportToImage(self, preferences): - filepath = preferences['filepath'] - self.logger.info(f'Saving image to "{filepath}"...') - - if filepath.endswith('.svg'): - exporter = exporters.SVGExporter(self.ax1) - else: - exporter = exporters.ImageExporter(self.ax1, dpi=preferences['dpi']) - exporter.export(filepath) - self.logger.info(f'Image saved.') - - self.setDisabled(False) - self.exportMaskImage[:] = 0 - self.exportMaskImageItem.setImage(self.exportMaskImage) - prompts.exportToImageFinished(filepath, qparent=self) - - def exportToImageTriggered(self): - posData = self.data[self.pos_i] - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - filename = f'{timestamp}_acdc_exported_image' - win = apps.ExportToImageParametersDialog( - parent=self, - startFolderpath=posData.pos_path, - startFilename=filename, - startViewRange=self.ax1.viewRange(), - isScaleBarPresent=self.addScaleBarAction.isChecked(), - ) - win.sigAddScaleBar.connect(self.exportAddScaleBar) - win.sigRangeChanged.connect( - partial(self.setViewRangeFromExportToImageDialog, win=win) - ) - # self.ax1.vb.sigRangeChanged.connect( - # win.updateViewRangeExportToImageDialog - # ) - self.setExportMaskImage(self.ax1.viewRange()) - self.exportToImageWindow = win - win.exec_() - # self.ax1.vb.sigRangeChanged.disconnect() - if win.cancel: - self.exportMaskImage[:] = 0 - self.exportMaskImageItem.setImage(self.exportMaskImage) - self.exportToImageWindow = None - self.logger.info('Export to image process cancelled') - return - - isTransparent = self.overlayToolbar.isTransparent() - if not isTransparent: - # SVG export works only with RGBA not with setOpacity - # --> only true transparency mode can be used - self.overlayToolbar.setTransparent(True) - - self.exportToImage(win.selected_preferences) - self.exportToImageWindow = None - - if not isTransparent: - self.overlayToolbar.setTransparent(False) - - def saveDataPermissionError(self, err_msg): - self.setDisabled(False, keepDisabled=False) - self.activateWindow() - msg = QMessageBox() - msg.critical(self, 'Permission denied', err_msg, msg.Ok) - self.waitCond.wakeAll() - - def saveDataProgress(self, text): - self.logger.info(text) - self.saveWin.progressLabel.setText(text) - - def saveDataCustomMetricsCritical(self, traceback_format, func_name): - self.setDisabled(False, keepDisabled=False) - self.activateWindow() - self.logger.info('') - _hl = '====================================' - self.logger.info(f'{_hl}\n{traceback_format}\n{_hl}') - self.worker.customMetricsErrors[func_name] = traceback_format - - def saveDataCombinedMetricsMissingColumn(self, error_msg, func_name): - self.setDisabled(False, keepDisabled=False) - self.activateWindow() - self.logger.info('') - warning = f'[WARNING]: {error_msg}. Metric {func_name} was skipped.' - _hl = '====================================' - self.logger.info(f'{_hl}\n{warning}\n{_hl}') - self.worker.customMetricsErrors[func_name] = warning - - def saveDataAddMetricsCritical(self, traceback_format, error_message): - self.setDisabled(False, keepDisabled=False) - self.activateWindow() - self.logger.info('') - _hl = '====================================' - self.logger.info(f'{_hl}\n{traceback_format}\n{_hl}') - self.worker.addMetricsErrors[error_message] = traceback_format - - def saveDataRegionPropsCritical(self, traceback_format, error_message): - self.setDisabled(False, keepDisabled=False) - self.activateWindow() - self.logger.info('') - _hl = '====================================' - self.logger.info(f'{_hl}\n{traceback_format}\n{_hl}') - self.worker.regionPropsErrors[error_message] = traceback_format - - def saveDataUpdateMetricsPbar(self, max, step): - if max > 0: - self.saveWin.metricsQPbar.setMaximum(max) - self.saveWin.metricsQPbar.setValue(0) - self.saveWin.metricsQPbar.setValue( - self.saveWin.metricsQPbar.value()+step - ) - - def saveDataUpdatePbar(self, step, max=-1, exec_time=0.0): - if max >= 0: - self.saveWin.QPbar.setMaximum(max) - else: - self.saveWin.QPbar.setValue(self.saveWin.QPbar.value()+step) - steps_left = self.saveWin.QPbar.maximum()-self.saveWin.QPbar.value() - seconds = round(exec_time*steps_left) - ETA = myutils.seconds_to_ETA(seconds) - self.saveWin.ETA_label.setText(f'ETA: {ETA}') - - def quickSave(self): - self.saveData(isQuickSave=True) - - def checkMissingCca(self): - proceed = True - ignore = False - doNotShowAgain = False - if not self.doNotShowAgainMissingCca: - return proceed, ignore, doNotShowAgain - - missing_cca_items = [] - for posData in self.data: - for frame_i, data_dict in enumerate(posData.allData_li): - acdc_df = data_dict['acdc_df'] - if acdc_df is None: - continue - - if 'cell_cycle_stage' not in acdc_df.columns: - continue - - cca_df = acdc_df[cca_df_colnames] - if cca_df.isnull().values.any(): - i = frame_i if not self.isSnapshot else None - missing_cca_items.append((cca_df, posData, i)) - - if not missing_cca_items: - return proceed, ignore, doNotShowAgain - - proceed = False - ignore, doNotShowAgain =_warnings.warnMissingCca( - missing_cca_items, qparent=self - ) - - if doNotShowAgain: - self.df_settings.at['doNotShowAgainMissingCca', 'value'] = 'Yes' - self.df_settings.to_csv(self.settings_csv_path) - - return proceed, ignore, doNotShowAgain - - def warnDifferentSegmChannel( - self, loaded_channel, segm_channel_hyperparams, segmEndName - ): - txt = html_utils.paragraph(f""" - You loaded the segmentation file ending with _{segmEndName}.npz - which corresponds to the channel - {segm_channel_hyperparams}.

- However, in this session you loaded the channel - {loaded_channel}.

- If you proceed with saving, the segmentation file ending with - _{segmEndName}.npz will be OVERWRITTEN.

- Are you sure you want to proceed? - """) - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - msg.warning( - self, 'WARNING: Potential for data loss', txt, - buttonsTexts=('Cancel', 'Yes') - ) - return msg.cancel - - def waitAutoSaveWorker(self, worker): - if worker.isFinished or worker.isPaused or len(worker.dataQ) == 0: - self.waitAutoSaveWorkerLoop.exit() - self.waitAutoSaveWorkerTimer.stop() - self.setStatusBarLabel(log=False) - - @exception_handler - def saveData(self, checked=False, finishedCallback=None, isQuickSave=False): - self.setDisabled(True, keepDisabled=True) - - self.askLineageTreeChanges() - - self.store_data(autosave=False) - self.applyDelROIs() - self.store_data() - self._isQuickSave = isQuickSave - - # Wait autosave worker to finish - for worker, thread in self.autoSaveActiveWorkers: - self.logger.info('Stopping autosaving process...') - self.statusBarLabel.setText('Stopping autosaving process...') - worker.stop() - self.waitAutoSaveWorkerTimer = QTimer() - self.waitAutoSaveWorkerTimer.timeout.connect( - partial(self.waitAutoSaveWorker, worker) - ) - self.waitAutoSaveWorkerTimer.start(100) - self.waitAutoSaveWorkerLoop = QEventLoop() - self.waitAutoSaveWorkerLoop.exec_() - - self.titleLabel.setText( - 'Saving data... (check progress in the terminal)', - color=self.titleColor - ) - - # Check channel name correspondence to warn - posData = self.data[self.pos_i] - lastSegmChannel, segmEndName = posData.getSegmentedChannelHyperparams() - if lastSegmChannel != self.user_ch_name and lastSegmChannel: - cancel = self.warnDifferentSegmChannel( - self.user_ch_name, lastSegmChannel, segmEndName - ) - if cancel: - self.cancelSavingInitialisation() - self.setDisabled(False, keepDisabled=False) - self.activateWindow() - return True - posData.updateSegmentedChannelHyperparams(self.user_ch_name) - - # Check missing cca annotations in snaphots - proceed, ignore, self.doNotShowAgainMissingCca = self.checkMissingCca() - if not proceed and not ignore: - self.cancelSavingInitialisation() - self.setDisabled(False, keepDisabled=False) - self.activateWindow() - return - - self.save_metrics = False - if not isQuickSave: - self.save_metrics, cancel = self.askSaveMetrics() - if cancel: - self.cancelSavingInitialisation() - self.setDisabled(False, keepDisabled=False) - self.activateWindow() - return True - - self.posToSave = None - if self.isSnapshot and not isQuickSave and len(self.data) > 1: - self.posToSave = self.askPosToSave() - if self.posToSave is None: - self.cancelSavingInitialisation() - self.setDisabled(False, keepDisabled=False) - self.activateWindow() - return True - - if isQuickSave: - # Quick save only current pos - self.posToSave = {self.data[self.pos_i].pos_foldername} - - if self.isSnapshot: - self.store_data(mainThread=False) - - mode = self.modeComboBox.currentText() - if mode == 'Cell cycle analysis': - proceed = self.askSaveLastVisitedCcaMode(isQuickSave=isQuickSave) - if not proceed: - self.cancelSavingInitialisation() - self.setDisabled(False, keepDisabled=False) - self.activateWindow() - return True - else: - proceed = self.askSaveLastVisitedSegmMode(isQuickSave=isQuickSave) - if not proceed: - self.cancelSavingInitialisation() - self.setDisabled(False, keepDisabled=False) - self.activateWindow() - return True - - append_name_og_whitelist, proceed, do_not_save_og_whitelist = self.askSaveOriginalSegm(isQuickSave=isQuickSave) - if not proceed: - self.cancelSavingInitialisation() - self.setDisabled(False, keepDisabled=False) - self.activateWindow() - return True - - if self.save_metrics or mode == 'Cell cycle analysis': - self.computeVolumeRegionprop() - - infoTxt = html_utils.paragraph( - f'Saving {self.exp_path}...
', font_size='14px' - ) - - self.saveWin = apps.QDialogPbar( - parent=self, title='Saving data', infoTxt=infoTxt - ) - self.saveWin.setFont(_font) - # if not self.save_metrics: - self.saveWin.metricsQPbar.hide() - self.saveWin.progressLabel.setText('Preparing data...') - self.saveWin.show() - - # Set up separate thread for saving and show progress bar widget - self.mutex = QMutex() - self.waitCond = QWaitCondition() - self.thread = QThread() - self.worker = workers.saveDataWorker(self) - self.worker.mode = mode - self.worker.isQuickSave = isQuickSave - self.worker.append_name_og_whitelist = append_name_og_whitelist - self.worker.do_not_save_og_whitelist = do_not_save_og_whitelist - - self.worker.moveToThread(self.thread) - - self.worker.finished.connect(self.thread.quit) - self.worker.finished.connect(self.worker.deleteLater) - self.thread.finished.connect(self.thread.deleteLater) - - # Custom signals - self.worker.finished.connect(self.saveDataFinished) - if finishedCallback is not None: - self.worker.finished.connect(finishedCallback) - self.worker.progress.connect(self.saveDataProgress) - self.worker.sigLog.connect(self.workerLog) - self.worker.progressBar.connect(self.saveDataUpdatePbar) - # self.worker.metricsPbarProgress.connect(self.saveDataUpdateMetricsPbar) - self.worker.critical.connect(self.saveDataWorkerCritical) - self.worker.customMetricsCritical.connect( - self.saveDataCustomMetricsCritical - ) - self.worker.sigCombinedMetricsMissingColumn.connect( - self.saveDataCombinedMetricsMissingColumn - ) - self.worker.addMetricsCritical.connect(self.saveDataAddMetricsCritical) - self.worker.regionPropsCritical.connect( - self.saveDataRegionPropsCritical - ) - self.worker.criticalPermissionError.connect(self.saveDataPermissionError) - self.worker.askZsliceAbsent.connect(self.zSliceAbsent) - self.worker.sigDebug.connect(self._workerDebug) - - self.thread.started.connect(self.worker.run) - - self.thread.start() - - return False - - def _workerDebug(self, stuff_to_debug): - pass - # from acdctools.plot import imshow - # lab, frame_i, autoBkgr_masks = stuff_to_debug - # autoBkgr_mask, autoBkgr_mask_proj = autoBkgr_masks - # imshow(lab, autoBkgr_mask) - # self.worker.waitCond.wakeAll() - - def changeTextResolution(self): - mode = 'high' if self.highLowResAction.isChecked() else 'low' - self.logger.info( - f'Switching to {mode} for the text annnotations...' - ) - self.pxModeAction.setDisabled(not self.highLowResAction.isChecked()) - if not self.isDataLoaded: - return - - self.setAllIDs() - posData = self.data[self.pos_i] - allIDs = posData.allIDs - img_shape = self.img1.image.shape[:2] - self.textAnnot[0].changeResolution(mode, allIDs, self.ax1, img_shape) - self.textAnnot[1].changeResolution(mode, allIDs, self.ax2, img_shape) - self.updateAllImages() - - def highLowResToggled(self, clicked=True): - self.changeTextResolution() - - def autoSaveClose(self): - for worker, thread in self.autoSaveActiveWorkers: - worker._stop() - - def viewPreprocDataToggled(self, checked): - self.img1.setUsePreprocessed(checked) - self.setImageImg1() - - if self.viewCombineChannelDataToggle.isChecked(): - self.viewCombineChannelDataToggle.toggled.disconnect() - self.viewCombineChannelDataToggle.setChecked(False) - self.viewCombineChannelDataToggle.toggled.connect( - self.viewCombineChannelDataToggled - ) - - def setAutoSaveSegmentationEnabled(self, enabled): - if not self.autoSaveActiveWorkers: - return - - worker, thread = self.autoSaveActiveWorkers[-1] - - if enabled: - worker.isAutoSaveON = self.autoSaveToggle.isChecked() - else: - worker.isAutoSaveON = False - - def setAutoSaveAnnotationsEnabled(self, enabled): - if not self.autoSaveActiveWorkers: - return - - worker, thread = self.autoSaveActiveWorkers[-1] - - if enabled: - worker.isAutoSaveAnnotON = self.autoSaveToggle.isChecked() - else: - worker.isAutoSaveAnnotON = False - - def autoSaveToggled(self, checked): - if not self.autoSaveActiveWorkers: - self.gui_createAutoSaveWorker() - - if not self.autoSaveActiveWorkers: - return - - worker, thread = self.autoSaveActiveWorkers[-1] - - mode = self.modeComboBox.currentText() - if mode != 'Segmentation and Tracking': - # Autosaving segmentation makes sense only in - # "Segmentation and Tracking" mode - checked = False - - worker.isAutoSaveON = checked - - def autoSaveAnnotToggled(self, checked): - if not self.autoSaveActiveWorkers: - self.gui_createAutoSaveWorker() - - if not self.autoSaveActiveWorkers: - return - - worker, thread = self.autoSaveActiveWorkers[-1] - - mode = self.modeComboBox.currentText() - if mode != 'Viewer': - # No reason to save in viewer mode - checked = False - - worker.isAutoSaveAnnotON = checked - - def autoSaveIntervalEdit(self): - self.autoSaveIntervalDialog.show() - self.autoSaveIntervalDialog.raise_() - self.autoSaveIntervalDialog.activateWindow() - - def autoSaveIntervalValueChanged( - self, value: float, unit: Literal['minutes', 'frames'] - ): - self.autoSaveIntevalValueUnit = (value, unit) - self.autoSaveTimer.stop() - - self.df_settings.at['autoSaveIntevalValue', 'value'] = str(value) - self.df_settings.at['autoSaveIntervalUnit', 'value'] = unit - self.df_settings.to_csv(settings_csv_path) - - self.logger.info( - f'Autosave interval changed to: {value} {unit}' - ) - self.autoSaveIntervalSetTooltip() - - if unit == 'frames': - self.startAutoSaveEveryNframesTimer() - - def autoSaveIntervalSetTooltip(self): - value, unit = self.autoSaveIntevalValueUnit - autoSaveIntervalEditTooltip = ( - 'Change autosave interval to every N frames or minutes\n\n' - f'Current autosave interval: {value} {unit}' - ) - self.autoSaveIntervalLabel.setToolTip(autoSaveIntervalEditTooltip) - self.autoSaveIntervalEditButton.setToolTip(autoSaveIntervalEditTooltip) - - def ccaIntegrCheckerToggled(self, checked): - self.df_settings.at['is_cca_integrity_checker_activated', 'value'] = ( - int(checked) - ) - self.df_settings.to_csv(self.settings_csv_path) - mode = self.modeComboBox.currentText() - if mode != 'Cell cycle analysis': - return - - if checked: - self.startCcaIntegrityCheckerWorker() - else: - self.disableCcaIntegrityChecker() - - def warnErrorsCustomMetrics(self): - win = apps.ComputeMetricsErrorsDialog( - self.worker.customMetricsErrors, self.logs_path, - log_type='custom_metrics', parent=self - ) - win.exec_() - - def warnErrorsAddMetrics(self): - win = apps.ComputeMetricsErrorsDialog( - self.worker.addMetricsErrors, self.logs_path, - log_type='standard_metrics', parent=self - ) - win.exec_() - - def warnErrorsRegionProps(self): - win = apps.ComputeMetricsErrorsDialog( - self.worker.regionPropsErrors, self.logs_path, - log_type='region_props', parent=self - ) - win.exec_() - - def askConcatenate(self): - if self.mainWin is None: - return - - if self._isQuickSave: - return - - if 'showAskConcatenate' not in self.df_settings.index: - self.df_settings.at['showAskConcatenate', 'value'] = 'Yes' - - showAskConcatenate = ( - self.df_settings.at['showAskConcatenate', 'value'] == 'Yes' - ) - if not showAskConcatenate: - return - - txt = html_utils.paragraph(f""" - Do you want to concatenate the `acdc_output.csv` tables from - multiple Positions into one single CSV file?
- """) - doNotShowAgainCheckbox = QCheckBox('Do not show again') - msg = widgets.myMessageBox(wrapText=False) - noButton, yesButton = msg.question( - self, 'Concatenate tables?', txt, - buttonsTexts=('No', 'Yes'), - widgets=doNotShowAgainCheckbox - ) - showAskConcatenate = ( - 'No' if doNotShowAgainCheckbox.isChecked() else 'Yes' - ) - self.df_settings.at['showAskConcatenate', 'value'] = ( - showAskConcatenate - ) - self.df_settings.to_csv(settings_csv_path) - - if not msg.clickedButton == yesButton: - return - - txt = html_utils.paragraph(f""" - To concatenate the `acdc_output.csv` tables from - multiple Positions and multiple experiments
- launch the concatenation utility from the top menubar of the Cell-ACDC main launcher:

- Utilities --> Concatenate --> Concatenate acdc output tables from multiple Positions and experiments.... - """) - msg = widgets.myMessageBox(wrapText=False) - msg.information(self, 'How to concatenate tables', txt) - - def updateSegmDataAutoSaveWorker(self): - # Update savedSegmData in autosave worker - posData = self.data[self.pos_i] - for worker, thread in self.autoSaveActiveWorkers: - worker.savedSegmData = posData.segm_data.copy() - - def saveDataFinished(self): - self.setDisabled(False, keepDisabled=False) - self.activateWindow() - if self.saveWin.aborted or self.worker.abort: - self.titleLabel.setText('Saving process cancelled.', color='r') - elif self._isQuickSave: - self.titleLabel.setText('Saved segmentation file and annotations') - else: - self.titleLabel.setText('Saved!') - self.saveWin.workerFinished = True - self.saveWin.close() - - if not self.closeGUI: - # Update savedSegmData in autosave worker - self.updateSegmDataAutoSaveWorker() - - if self.worker.addMetricsErrors: - self.warnErrorsAddMetrics() - if self.worker.regionPropsErrors: - self.warnErrorsRegionProps() - if self.worker.customMetricsErrors: - self.warnErrorsCustomMetrics() - - self.checkManageVersions() - - self.askConcatenate() - - if self.closeGUI: - salute_string = myutils.get_salute_string() - msg = widgets.myMessageBox() - txt = html_utils.paragraph( - 'Data saved!. The GUI will now close.

' - f'{salute_string}' - ) - msg.information(self, 'Data saved', txt) - self.close() - - def copyContent(self): - pass - - def pasteContent(self): - pass - - def cutContent(self): - pass - - def showAbout(self): - self.aboutWin = about.QDialogAbout(parent=self) - self.aboutWin.show() - - def openLogFile(self): - self.logger.info(f'Opening log file "{self.log_path}"...') - myutils.showInExplorer(self.log_path) - - def showLogFiles(self): - log_files_path = os.path.dirname(self.log_path) - self.logger.info(f'Opening log files folder "{log_files_path}"...') - myutils.showInExplorer(log_files_path) - - def showTipsAndTricks(self): - self.welcomeWin = welcome.welcomeWin() - self.welcomeWin.showAndSetSize() - self.welcomeWin.showPage(self.welcomeWin.quickStartItem) - - def about(self): - pass - - def openRecentFile(self, path): - self.logger.info(f'Opening recent folder: {path}') - self.addToRecentPaths(path, logger=self.logger) - self.openFolder(exp_path=path) - - def _waitCloseAutoSaveWorker(self): - didWorkersFinished = [True] - for worker, thread in self.autoSaveActiveWorkers: - if worker.isFinished: - didWorkersFinished.append(True) - else: - didWorkersFinished.append(False) - if all(didWorkersFinished): - self.waitCloseAutoSaveWorkerLoop.stop() - - def cancelSavingInitialisation(self): - self.titleLabel.setText( - 'Saving data process cancelled.', color=self.titleColor - ) - self.closeGUI = False - - @disableWindow - def askSaveOnClosing(self, event): - if not self.saveAction.isEnabled(): - return True - if self.titleLabel.text == 'Saved!': - return True - if not self.isDataLoaded: - return True - - msg = widgets.myMessageBox() - txt = html_utils.paragraph('Do you want to save before closing?') - _, noButton, yesButton = msg.question( - self, 'Save?', txt, - buttonsTexts=('Cancel', 'No', 'Yes') - ) - if msg.cancel: - event.ignore() - return False - - if msg.clickedButton == yesButton: - self.closeGUI = True - QTimer.singleShot(100, self.saveAction.trigger) - event.ignore() - return False - return True - - def clearMemory(self): - if not hasattr(self, 'data'): - return - self.logger.info('Clearing memory...') - for posData in self.data: - try: - del posData.img_data - except Exception as e: - pass - try: - del posData.segm_data - except Exception as e: - pass - try: - del posData.ol_data_dict - except Exception as e: - pass - try: - del posData.fluo_data_dict - except Exception as e: - pass - try: - del posData.ol_data - except Exception as e: - pass - del self.data - - def setUncheckedPointsLayers(self): - self.togglePointsLayerAction.setChecked(False) - self.magicPromptsToolButton.setChecked(False) - - def clearHighlightedID(self): - self.highlightIDToolbar.setVisible(False) - - try: - self.updateLostContoursImage(ax=0, delROIsIDs=None) - except Exception as err: - pass - - if self.highlightedID == 0: - return - - self.highlightedID = 0 - self.guiTabControl.highlightCheckbox.setChecked(False) - self.guiTabControl.highlightSearchedCheckbox.setChecked(False) - self.setHighlightID(False) - - def onEscape( - self, - isTypingIDFunctionChecked=False, - buttonsToNotUncheck=None, - doAutoRange=True - ): - if buttonsToNotUncheck is None: - buttonsToNotUncheck = set() - - if self.keepIDsButton.isChecked() and self.keptObjectsIDs: - self.keptObjectsIDs = widgets.KeptObjectIDsList( - self.keptIDsLineEdit, self.keepIDsConfirmAction - ) - self.highlightHoverIDsKeptObj(0, 0, hoverID=0) - QTimer.singleShot(300, self.autoRange) - return - - if self.brushButton.isChecked() and self.typingEditID: - self.autoIDcheckbox.setChecked(True) - self.typingEditID = False - QTimer.singleShot(300, self.autoRange) - return - - if isTypingIDFunctionChecked and self.typingEditID: - self.typingEditID = False - QTimer.singleShot(300, self.autoRange) - return - - if self.labelRoiButton.isChecked() and self.isMouseDragImg1: - self.isMouseDragImg1 = False - self.labelRoiItem.setPos((0,0)) - self.labelRoiItem.setSize((0,0)) - self.freeRoiItem.clear() - QTimer.singleShot(300, self.autoRange) - return - - if self.zoomRectButton.isChecked(): - self.zoomRectCancelled() - QTimer.singleShot(300, self.autoRange) - return - - self.setUncheckedAllButtons(buttonsToNotUncheck=buttonsToNotUncheck) - self.setUncheckedAllCustomAnnotButtons() - self.setUncheckedPointsLayers() - self.clearTempBrushImage() - self.isMouseDragImg1 = False - self.typingEditID = False - self.clearHighlightedID() - try: - self.polyLineRoi.clearPoints() - except Exception as e: - pass - - if doAutoRange: - QTimer.singleShot(11, self.autoRange) - - def clearTempBrushImage(self, forceClearLinked=True): - if not hasattr(self, 'tempLayerImg1'): - return - - self.tempLayerImg1.setImage( - self.emptyLab, force_set_linked=forceClearLinked - ) - - try: - self.brushContourImage[:] = 0 - except Exception as err: - pass - - try: - self.brushImage[:] = 0 - except Exception as err: - pass - - def askCloseAllWindows(self): - txt = html_utils.paragraph(""" - There are other open windows that were created from this window. -

- If you proceed, the other windows will be closed too.
- """) - msg = widgets.myMessageBox(wrapText=False) - msg.warning( - self, 'Open windows', txt, - buttonsTexts=('Cancel', 'Ok, close now') - ) - return msg.cancel - - def stopPreprocWorker(self): - self.logger.info('Closing pre-processing worker...') - try: - self.preprocWorker.stop() - except Exception as err: - pass - - def closeEvent(self, event): - self.setDisabled(False) - cancel = self.checkAskSavePointsLayers() - if cancel: - event.ignore() - return - - self.onEscape() - self.saveWindowGeometry() - - if self.newWindows: - cancel = self.askCloseAllWindows() - if cancel: - event.ignore() - return - - for window in self.newWindows: - window.close() - - if self.slideshowWin is not None: - self.slideshowWin.close() - if self.ccaTableWin is not None: - self.ccaTableWin.close() - - proceed = self.askSaveOnClosing(event) - if not proceed: - event.ignore() - return - - self.autoSaveClose() - - if self.autoSaveActiveWorkers: - progressWin = apps.QDialogWorkerProgress( - title='Closing autosaving worker', parent=self, - pbarDesc='Closing autosaving worker...' - ) - progressWin.show(self.app) - progressWin.mainPbar.setMaximum(0) - self.waitCloseAutoSaveWorkerLoop = qutils.QWhileLoop( - self._waitCloseAutoSaveWorker, period=250 - ) - self.waitCloseAutoSaveWorkerLoop.exec_() - progressWin.workerFinished = True - progressWin.close() - - self.stopPreprocWorker() - self.stopCombineWorker() - self.stopCcaIntegrityCheckerWorker() - - # Close the inifinte loop of the thread - if self.lazyLoader is not None: - self.lazyLoader.exit = True - self.lazyLoaderWaitCond.wakeAll() - self.waitReadH5cond.wakeAll() - - if self.storeStateWorker is not None: - # Close storeStateWorker - self.storeStateWorker._stop() - while self.storeStateWorker.isFinished: - time.sleep(0.05) - - # Block main thread while separate threads closes - time.sleep(0.1) - - self.clearMemory() - - self.logger.info('Closing GUI logger...') - self.logger.close() - - if self.lazyLoader is None: - self.sigClosed.emit(self) - - gc.collect() - - def storeManualSeparateDrawMode(self, mode): - self.df_settings.at['manual_separate_draw_mode', 'value'] = mode - self.df_settings.to_csv(self.settings_csv_path) - - def readSettings(self): - settings = QSettings('schmollerlab', 'acdc_gui') - if settings.value('geometry') is not None: - self.restoreGeometry(settings.value("geometry")) - # self.restoreState(settings.value("windowState")) - - def saveWindowGeometry(self): - settings = QSettings('schmollerlab', 'acdc_gui') - settings.setValue("geometry", self.saveGeometry()) - # settings.setValue("windowState", self.saveState()) - - def storeDefaultAndCustomColors(self): - c = self.overlayButton.palette().button().color().name() - self.defaultToolBarButtonColor = c - self.doublePressKeyButtonColor = '#fa693b' - - def initPixelSizePropsDockWidget(self): - posData = self.data[self.pos_i] - PhysicalSizeX = posData.PhysicalSizeX - PhysicalSizeY = posData.PhysicalSizeY - PhysicalSizeZ = posData.PhysicalSizeZ - self.guiTabControl.initPixelSize( - PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ - ) - - def showPropsDockWidget(self, checked=False): - if self.showPropsDockButton.isExpand: - self.propsDockWidget.setVisible(False) - self.setHighlightID(False) - else: - self.highlightedID = self.guiTabControl.propsQGBox.idSB.value() - if self.isSegm3D: - self.guiTabControl.propsQGBox.cellVolVox3D_SB.show() - self.guiTabControl.propsQGBox.cellVolVox3D_SB.label.show() - self.guiTabControl.propsQGBox.cellVolFl3D_DSB.show() - self.guiTabControl.propsQGBox.cellVolFl3D_DSB.label.show() - else: - self.guiTabControl.propsQGBox.cellVolVox3D_SB.hide() - self.guiTabControl.propsQGBox.cellVolVox3D_SB.label.hide() - self.guiTabControl.propsQGBox.cellVolFl3D_DSB.hide() - self.guiTabControl.propsQGBox.cellVolFl3D_DSB.label.hide() - - self.propsDockWidget.setVisible(True) - self.propsDockWidget.setEnabled(True) - self.updateAllImages() - - def showEvent(self, event): - if self.mainWin is not None: - if not self.mainWin.isMinimized(): - return - self.mainWin.showAllWindows() - # self.setFocus() - self.activateWindow() - - def super_show(self): - super().show() - - def show(self): - self.setFont(_font) - QMainWindow.show(self) - - self.setWindowState(Qt.WindowNoState) - self.setWindowState(Qt.WindowActive) - self.raise_() - - self.readSettings() - self.storeDefaultAndCustomColors() - - self.h = self.navSpinBox.size().height() - fontSizeFactor = None - heightFactor = None - if 'bottom_sliders_zoom_perc' in self.df_settings.index: - val = int(self.df_settings.at['bottom_sliders_zoom_perc', 'value']) - if val != 100: - fontSizeFactor = val/100 - heightFactor = val/100 - - self.defaultWidgetHeightBottomLayout = self.h - self.checkBoxesHeight = 14 - self.fontPixelSize = 11 - self.defaultBottomLayoutHeight = self.img1BottomGroupbox.height() - - self.bottomLayout.setStretch(0, 0) - self.bottomLayout.addSpacing(self.quickSettingsGroupbox.width()) - self.resizeSlidersArea( - fontSizeFactor=fontSizeFactor, heightFactor=heightFactor - ) - self.bottomScrollArea.hide() - - self.gui_initImg1BottomWidgets() - self.img1BottomGroupbox.hide() - - w = self.showPropsDockButton.width() - h = self.showPropsDockButton.height() - - self.showPropsDockButton.setMaximumWidth(15) - self.showPropsDockButton.setMaximumHeight(120) - - for toolbar in self.controlToolBars: - toolbar.setMinimumHeight( - self.secondLevelToolbar.sizeHint().height() - ) - - self.graphLayout.setFocus() - - def resizeSlidersArea(self, fontSizeFactor=None, heightFactor=None): - global _font - if heightFactor is None: - self.newCheckBoxesHeight = self.checkBoxesHeight - self.newHeight = self.h - else: - self.newHeight = round(self.h*heightFactor) - self.newCheckBoxesHeight = round(self.checkBoxesHeight*heightFactor) - - if fontSizeFactor is None: - newFontSize = self.fontPixelSize - else: - newFontSize = round(self.fontPixelSize*fontSizeFactor) - newFont = QFont() - newFont.setPixelSize(newFontSize) - _font = newFont - self.zProjComboBox.setFont(newFont) - self.t_label.setFont(newFont) - self.zProjOverlay_CB.setFont(newFont) - self.annotateRightHowCombobox.setFont(newFont) - self.drawIDsContComboBox.setFont(newFont) - self.showTreeInfoCheckbox.setFont(newFont) - self.highlightZneighObjCheckbox.setFont(newFont) - self.navSpinBox.setFont(newFont) - self.zSliceSpinbox.setFont(newFont) - self.SizeZlabel.setFont(newFont) - self.navSizeLabel.setFont(newFont) - self.overlay_z_label.setFont(newFont) - self.img1BottomGroupbox.setFont(newFont) - self.rightBottomGroupbox.setFont(newFont) - try: - self.img1.alphaScrollbar.label.setFont(newFont) - except Exception as e: - pass - for i in range(self.annotOptionsLayout.count()): - widget = self.annotOptionsLayout.itemAt(i).widget() - widget.setFont(newFont) - for i in range(self.annotOptionsLayoutRight.count()): - widget = self.annotOptionsLayoutRight.itemAt(i).widget() - widget.setFont(newFont) - try: - for channel, items in self.overlayLayersItems.items(): - alphaScrollbar = items[2] - alphaScrollbar.label.setFont(newFont) - except: - pass - QTimer.singleShot(100, self._resizeSlidersArea) - - def _resizeSlidersArea(self): - self.navigateScrollBar.setFixedHeight(self.newHeight) - self.zSliceScrollBar.setFixedHeight(self.newHeight) - self.zSliceOverlay_SB.setFixedHeight(self.newHeight) - self.zProjComboBox.setFixedHeight(self.newHeight) - self.zProjOverlay_CB.setFixedHeight(self.newHeight) - self.navSpinBox.setFixedHeight(self.newHeight) - self.zSliceSpinbox.setFixedHeight(self.newHeight) - try: - self.img1.alphaScrollbar.setFixedHeight(self.newHeight) - except Exception as e: - pass - try: - for channel, items in self.overlayLayersItems.items(): - alphaScrollbar = items[2] - alphaScrollbar.setFixedHeight(self.newHeight) - except: - pass - checkBoxStyleSheet = ( - 'QCheckBox::indicator {' - f'width: {self.newCheckBoxesHeight}px;' - f'height: {self.newCheckBoxesHeight}px' - '}' - ) - for i in range(self.annotOptionsLayout.count()): - widget = self.annotOptionsLayout.itemAt(i).widget() - if isinstance(widget, QCheckBox): - widget.setStyleSheet(checkBoxStyleSheet) - for i in range(self.annotOptionsLayoutRight.count()): - widget = self.annotOptionsLayoutRight.itemAt(i).widget() - if isinstance(widget, QCheckBox): - widget.setStyleSheet(checkBoxStyleSheet) - self.zSliceCheckbox.setStyleSheet(checkBoxStyleSheet) - - def resizeEvent(self, event): - if hasattr(self, 'ax1'): - self.ax1.autoRange() - - def hoverEventDrawSpline(self, event): - x, y = event.pos() - xx, yy = self.curvAnchors.getData() - hoverAnchors = self.curvAnchors.pointsAt(event.pos()) - per = False - # If we are hovering the starting point we generate - # a closed spline - if len(xx) < 2: - return - - if len(hoverAnchors)>0: - xA_hover, yA_hover = hoverAnchors[0].pos() - if xx[0]==xA_hover and yy[0]==yA_hover: - per=True - if per: - # Append start coords and close spline - xx = np.r_[xx, xx[0]] - yy = np.r_[yy, yy[0]] - xi, yi = self.getSpline(xx, yy, per=per) - # self.curvPlotItem.setData([], []) - else: - # Append mouse coords - xx = np.r_[xx, x] - yy = np.r_[yy, y] - xi, yi = self.getSpline(xx, yy, per=per) - self.curvHoverPlotItem.setData(xi, yi) - - def updateViewRangeExportToImage(self, viewRange): - if self.exportToImageWindow is None: - return - - # prevViewRange = self.exportToImageWindow.viewRange() - prevViewRange = self._viewRange - prevXRange = prevViewRange[0] - prevYRange = prevViewRange[1] - currXRange = viewRange[0] - currYRange = viewRange[1] - - prevX0, prevX1 = prevXRange - currX0, currX1 = currXRange - prevY0, prevY1 = prevYRange - currY0, currY1 = currYRange - - deltaX = currX0 - prevX0 - deltaY = currY0 - prevY0 - - winViewRange = self.exportToImageWindow.viewRange() - winXRange = winViewRange[0] - winYRange = winViewRange[1] - winX0, winX1 = winXRange - winY0, winY1 = winYRange - - newX0 = winX0 + deltaX - newX1 = winX1 + deltaX - newY0 = winY0 + deltaY - newY1 = winY1 + deltaY - - self.exportToImageWindow.setViewRange( - (newX0, newX1), (newY0, newY1), emitSignal=False - ) - - def viewRangeChanged(self, viewBox, viewRange, updateExportImageMask=True): - # self.updateViewRangeExportToImage(viewRange) - self.updateValuesStatusBar() - - if hasattr(self, 'scaleBar'): - isScaleBarMoveWithZoom = ( - self.scaleBar.properties()['move_with_zoom'] - ) - else: - isScaleBarMoveWithZoom = False - doMoveScaleBar = ( - self.scaleBarDialog is not None or isScaleBarMoveWithZoom - ) - if doMoveScaleBar: - self.scaleBar.updatePosViewRangeChanged(viewRange) - - if hasattr(self, 'timestamp'): - isTimestampMoveWithZoom = ( - self.timestamp.properties()['move_with_zoom'] - ) - else: - isTimestampMoveWithZoom = False - - doMoveTimestamp = ( - self.timestampDialog is not None or isTimestampMoveWithZoom - ) - if doMoveTimestamp: - self.timestamp.updatePosViewRangeChanged(viewRange) - - self._viewRange = viewRange + self.logger.info("GUI ready.") diff --git a/cellacdc/gui_decorators.py b/cellacdc/gui_decorators.py new file mode 100644 index 000000000..72dbb09d7 --- /dev/null +++ b/cellacdc/gui_decorators.py @@ -0,0 +1,81 @@ +"""Decorators shared by guiWin and mixins.""" + +from __future__ import annotations + +import os +import traceback +from functools import wraps + +from qtpy.QtCore import QTimer + +from . import html_utils, widgets + + +def get_data_exception_handler(func): + @wraps(func) + def inner_function(self, *args, **kwargs): + try: + if func.__code__.co_argcount == 1 and func.__defaults__ is None: + result = func(self) + elif func.__code__.co_argcount > 1 and func.__defaults__ is None: + result = func(self, *args) + else: + result = func(self, *args, **kwargs) + except Exception as e: + try: + if self.progressWin is not None: + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + except AttributeError: + pass + result = None + posData = self.data[self.pos_i] + acdc_df_filename = os.path.basename(posData.acdc_output_csv_path) + segm_filename = os.path.basename(posData.segm_npz_path) + traceback_str = traceback.format_exc() + self.logger.exception(traceback_str) + msg = widgets.myMessageBox(wrapText=False, showCentered=False) + msg.addShowInFileManagerButton(self.logs_path, txt="Show log file...") + msg.setDetailedText(traceback_str) + err_msg = html_utils.paragraph(f""" + Error in function {func.__name__}.

+ One possbile explanation is that either the + {acdc_df_filename} file
+ or the segmentation file {segm_filename}
+ are being synchronized by a cloud service (e.g., Google Drive + or OneDrive) or they are corrupted/damaged.

+ Try moving these files (one by one) outside of the + {os.path.dirname(posData.relPath)} folder +
and reloading the data.

+ More details below or in the terminal/console.

+ Note that the error details from this session are + also saved in the following file:

+ {self.log_path}

+ Please send the log file when reporting a bug, thanks! + Please restart Cell-ACDC, we apologise for any inconvenience.

+ + """) + + msg.critical(self, "Critical error", err_msg) + self.is_error_state = True + raise e + return result + + return inner_function + + +def resetViewRange(func): + @wraps(func) + def inner_function(self, *args, **kwargs): + self.storeViewRange() + if func.__code__.co_argcount == 1 and func.__defaults__ is None: + result = func(self) + elif func.__code__.co_argcount > 1 and func.__defaults__ is None: + result = func(self, *args) + else: + result = func(self, *args, **kwargs) + QTimer.singleShot(200, self.resetRange) + return result + + return inner_function diff --git a/cellacdc/gui_utils.py b/cellacdc/gui_utils.py index 9fa55d5e8..5f703b3a3 100644 --- a/cellacdc/gui_utils.py +++ b/cellacdc/gui_utils.py @@ -1,5 +1,6 @@ import numpy as np + def nearest_ID_to_centroid(a, y, x, max_iterations=10, distance_threshold=5): """ Return cell ID by checking `max_iterations` nearest non-zero pixels @@ -8,11 +9,11 @@ def nearest_ID_to_centroid(a, y, x, max_iterations=10, distance_threshold=5): """ if not isinstance(a, np.ndarray): a = np.array(a) # Ensure a is a numpy array - + r, c = np.nonzero(a) if r.size == 0: return None - + distances = np.linalg.norm(np.array([r, c]).T - np.array([y, x]), axis=1) sorted_indices = np.argsort(distances) sorted_IDs = a[r, c][sorted_indices] diff --git a/cellacdc/help/about.py b/cellacdc/help/about.py index 0b0239241..f9c7e3722 100755 --- a/cellacdc/help/about.py +++ b/cellacdc/help/about.py @@ -5,102 +5,112 @@ from functools import partial from qtpy.QtWidgets import ( - QDialog, QLabel, QGridLayout, QHBoxLayout, QSpacerItem, QApplication, - QVBoxLayout + QDialog, + QLabel, + QGridLayout, + QHBoxLayout, + QSpacerItem, + QApplication, + QVBoxLayout, ) from qtpy.QtGui import QPixmap from qtpy.QtCore import Qt from qtpy import QtCore -from ..myutils import read_version, get_date_from_version -from ..myutils import get_pip_install_cellacdc_version_command -from ..myutils import get_git_pull_checkout_cellacdc_version_commands -from ..myutils import get_info_version_text -from .. import widgets, myutils +from ..utils import read_version, get_date_from_version +from ..utils import get_pip_install_cellacdc_version_command +from ..utils import get_git_pull_checkout_cellacdc_version_commands +from ..utils import get_info_version_text +from .. import widgets, utils from .. import html_utils, printl from .. import cellacdc_path + class QDialogAbout(QDialog): def __init__(self, parent=None): super().__init__(parent) self.setWindowFlags(Qt.Dialog) - self.setWindowTitle('About Cell-ACDC') + self.setWindowTitle("About Cell-ACDC") layout = QGridLayout() - + version = read_version() release_date = get_date_from_version(version) - + py_ver = sys.version_info - python_version = f'{py_ver.major}.{py_ver.minor}.{py_ver.micro}' + python_version = f"{py_ver.major}.{py_ver.minor}.{py_ver.micro}" titleLabel = QLabel() - txt = (f""" + txt = f"""

Cell-ACDC (Analysis of the Cell Division Cycle)

- """) - + """ + info_txts = get_info_version_text(cli_formatted_text=False) for info_txt in info_txts: paragraph = html_utils.paragraph(info_txt) - txt = f'{txt}{paragraph}' + txt = f"{txt}{paragraph}" titleLabel.setText(txt) titleLabel.setText(txt) - + # '{next_version}.dev{distance}+{scm letter}{revision hash}' command, command_github = get_pip_install_cellacdc_version_command( version=version ) - commandLabel = QLabel(html_utils.paragraph( - f'To install this specific version ' - f'on a new environment or to upgrade/downgrade in an ' - 'environment where you already have Cell-ACDC
' - 'installed with pip run the following command:' - )) - commandWidget = widgets.CopiableCommandWidget( - command=command, font_size='11px' + commandLabel = QLabel( + html_utils.paragraph( + f"To install this specific version " + f"on a new environment or to upgrade/downgrade in an " + "environment where you already have Cell-ACDC
" + "installed with pip run the following command:" + ) ) - + commandWidget = widgets.CopiableCommandWidget(command=command, font_size="11px") + if command_github is not None: - commandLabelGh = QLabel(html_utils.paragraph( - f'If the command above fails, it means that this ' - f'specific version was not released on PyPi yet.

' - 'In that case, you need to run the following command instead:' - )) + commandLabelGh = QLabel( + html_utils.paragraph( + f"If the command above fails, it means that this " + f"specific version was not released on PyPi yet.

" + "In that case, you need to run the following command instead:" + ) + ) commandGhWidget = widgets.CopiableCommandWidget( - command=command_github, font_size='11px' + command=command_github, font_size="11px" ) - + commandWidgetsGit = [] git_commands = get_git_pull_checkout_cellacdc_version_commands(version) if git_commands: - commandLabelGit = QLabel(html_utils.paragraph( - f'

To upgrade/downgrade the Cell-ACDC version in an ' - 'environment where you installed it by first cloning with ' - 'git
' - 'run the following commands one by one:' - )) + commandLabelGit = QLabel( + html_utils.paragraph( + f"

To upgrade/downgrade the Cell-ACDC version in an " + "environment where you installed it by first cloning with " + "git
" + "run the following commands one by one:" + ) + ) for command in git_commands: commandWidgetsGit.append( - widgets.CopiableCommandWidget(command=command, font_size='11px') + widgets.CopiableCommandWidget(command=command, font_size="11px") ) - + iconPixmap = QPixmap(":icon.ico") h = 128 - iconPixmap = iconPixmap.scaled(h,h, aspectRatioMode=Qt.KeepAspectRatio) + iconPixmap = iconPixmap.scaled(h, h, aspectRatioMode=Qt.KeepAspectRatio) iconLabel = QLabel() iconLabel.setPixmap(iconPixmap) - github_url = r'https://github.com/SchmollerLab/Cell_ACDC' + github_url = r"https://github.com/SchmollerLab/Cell_ACDC" infoLabel = QLabel() - infoLabel.setTextInteractionFlags(Qt.TextBrowserInteraction); - infoLabel.setOpenExternalLinks(True); + infoLabel.setTextInteractionFlags(Qt.TextBrowserInteraction) + infoLabel.setOpenExternalLinks(True) txt = html_utils.paragraph(f"""
More info on our home page.
""") @@ -108,28 +118,27 @@ def __init__(self, parent=None): installedLayout = QHBoxLayout() installedLabel = QLabel() - txt = html_utils.paragraph(f""" + txt = html_utils.paragraph( + f""" Installed in: {cellacdc_path} - """, font_size='12px') + """, + font_size="12px", + ) installedLabel.setText(txt) installedLabel.setTextInteractionFlags(Qt.TextSelectableByMouse) - - self.copyCellACDCpathButton = widgets.copyPushButton('Copy path') - self.copyCellACDCpathButton.clicked.connect( - self.copyCellACDCpath - ) - + + self.copyCellACDCpathButton = widgets.copyPushButton("Copy path") + self.copyCellACDCpathButton.clicked.connect(self.copyCellACDCpath) + self.showHowToInstallButton = widgets.helpPushButton( - 'How to install this version' - ) - self.showHowToInstallButton.clicked.connect( - self.showHotToInstallInstructions + "How to install this version" ) + self.showHowToInstallButton.clicked.connect(self.showHotToInstallInstructions) button = widgets.showInFileManagerButton( - myutils.get_open_filemaneger_os_string() + utils.get_open_filemaneger_os_string() ) - func = partial(myutils.showInExplorer, cellacdc_path) + func = partial(utils.showInExplorer, cellacdc_path) button.clicked.connect(func) installedLayout.addWidget(installedLabel) installedLayout.addWidget(self.copyCellACDCpathButton) @@ -140,74 +149,77 @@ def __init__(self, parent=None): row = 0 layout.addWidget(iconLabel, row, 0) layout.addWidget(titleLabel, row, 1, alignment=Qt.AlignLeft) - + row += 1 layout.addWidget(infoLabel, row, 1, alignment=Qt.AlignLeft) - + row += 1 - layout.setColumnStretch(2,1) - layout.addItem(QSpacerItem(10,20), row, 1) - + layout.setColumnStretch(2, 1) + layout.addItem(QSpacerItem(10, 20), row, 1) + row += 1 - layout.setRowStretch(row,1) - + layout.setRowStretch(row, 1) + row += 1 layout.addLayout(installedLayout, row, 0, 1, 3) - + row += 1 self.howToInstallDialog = QDialog(self) - self.howToInstallDialog.setWindowTitle( - f'How to install Cell-ACDC v{version}' - ) + self.howToInstallDialog.setWindowTitle(f"How to install Cell-ACDC v{version}") howToInstallLayout = QVBoxLayout() self.howToInstallDialog.setLayout(howToInstallLayout) - - howToInstallOkButton = widgets.okPushButton(' Ok ') + + howToInstallOkButton = widgets.okPushButton(" Ok ") buttonsLayout = QHBoxLayout() buttonsLayout.addStretch(1) buttonsLayout.addWidget(howToInstallOkButton) howToInstallOkButton.clicked.connect(self.howToInstallDialog.close) - + howToInstallLayout.addWidget(commandLabel, alignment=Qt.AlignLeft) howToInstallLayout.addWidget(commandWidget, alignment=Qt.AlignLeft) - + if command_github is not None: howToInstallLayout.addWidget(commandLabelGh, alignment=Qt.AlignLeft) howToInstallLayout.addWidget(commandGhWidget, alignment=Qt.AlignLeft) - + if git_commands: howToInstallLayout.addWidget(commandLabelGit, alignment=Qt.AlignLeft) for widget in commandWidgetsGit: howToInstallLayout.addWidget(widget, alignment=Qt.AlignLeft) - + howToInstallLayout.addSpacing(20) - importantText = html_utils.to_admonition(""" + importantText = html_utils.to_admonition( + """ Whenever you run commands with pip make sure to FIRST activate the correct environment (e.g. with conda activate acdc
) - """, admonition_type='important') - + """, + admonition_type="important", + ) + howToInstallLayout.addWidget(QLabel(importantText)) - + # layout.addWidget(self.howToInstallWidget, row, 0, 1, 3) howToInstallLayout.addLayout(buttonsLayout) self.howToInstallDialog.hide() - + self.setLayout(layout) - + def copyCellACDCpath(self): cb = QApplication.clipboard() cb.clear(mode=cb.Clipboard) cb.setText(cellacdc_path, mode=cb.Clipboard) - + def showHotToInstallInstructions(self): self.howToInstallDialog.show() + def _test(): import sys from qtpy.QtWidgets import QStyleFactory, QApplication + app = QApplication(sys.argv) - app.setStyle(QStyleFactory.create('Fusion')) + app.setStyle(QStyleFactory.create("Fusion")) win = QDialogAbout() win.show() app.exec_() diff --git a/cellacdc/help/welcome.py b/cellacdc/help/welcome.py index 6cafa9cf2..8b5deef68 100755 --- a/cellacdc/help/welcome.py +++ b/cellacdc/help/welcome.py @@ -5,37 +5,51 @@ import pandas as pd import numpy as np -from qtpy.QtGui import ( - QIcon, QFont, QFontMetrics, QPixmap, QPalette, QColor -) -from qtpy.QtCore import ( - Qt, QSize, QEvent, Signal, QObject, QThread, QTimer -) +from qtpy.QtGui import QIcon, QFont, QFontMetrics, QPixmap, QPalette, QColor +from qtpy.QtCore import Qt, QSize, QEvent, Signal, QObject, QThread, QTimer from qtpy.QtWidgets import ( - QApplication, QWidget, QGridLayout, QTextEdit, QPushButton, - QListWidget, QListWidgetItem, QCheckBox, QFrame, QStyleFactory, - QLabel, QTreeWidget, QTreeWidgetItem, QTreeWidgetItemIterator, - QScrollArea, QComboBox, QHBoxLayout, QToolButton, QMainWindow, - QProgressBar, QAction + QApplication, + QWidget, + QGridLayout, + QTextEdit, + QPushButton, + QListWidget, + QListWidgetItem, + QCheckBox, + QFrame, + QStyleFactory, + QLabel, + QTreeWidget, + QTreeWidgetItem, + QTreeWidgetItemIterator, + QScrollArea, + QComboBox, + QHBoxLayout, + QToolButton, + QMainWindow, + QProgressBar, + QAction, ) script_path = os.path.dirname(os.path.realpath(__file__)) -from .. import gui, dataStruct, myutils, cite_url, html_utils, urls, widgets +from .. import gui, dataStruct, utils, cite_url, html_utils, urls, widgets from .. import _palettes # NOTE: Enable icons from .. import cellacdc_path, settings_folderpath -if os.name == 'nt': +if os.name == "nt": try: # Set taskbar icon in windows import ctypes - myappid = 'schmollerlab.cellacdc.pyqt.v1' # arbitrary string + + myappid = "schmollerlab.cellacdc.pyqt.v1" # arbitrary string ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(myappid) except Exception as e: pass + class downloadWorker(QObject): finished = Signal() progress = Signal(int, int) @@ -45,11 +59,10 @@ def __init__(self, which): self.which = which def run(self): - self.exp_path = myutils.download_examples( - self.which, progress=self.progress - ) + self.exp_path = utils.download_examples(self.which, progress=self.progress) self.finished.emit() + class QHLine(QFrame): def __init__(self): super(QHLine, self).__init__() @@ -63,7 +76,7 @@ def __init__(self, parent=None, mainWin=None, app=None): self.mainWin = mainWin self.app = app super().__init__(parent) - self.setWindowTitle('Welcome') + self.setWindowTitle("Welcome") self.setWindowIcon(QIcon(":icon.ico")) self.loadSettings() @@ -103,23 +116,23 @@ def __init__(self, parent=None, mainWin=None, app=None): # self.setDebuggingTools() def setDebuggingTools(self): - self.debugButton = QPushButton('debug') + self.debugButton = QPushButton("debug") self.debugButton.clicked.connect(self.debug) self.mainLayout.addWidget(self.debugButton, 2, 0) # self.debugAction.hide() def loadSettings(self): - csv_path = os.path.join(settings_folderpath, 'settings.csv') + csv_path = os.path.join(settings_folderpath, "settings.csv") if os.path.exists(csv_path): - self.df_settings = pd.read_csv(csv_path, index_col='setting') - if 'showWelcomeGuide' not in self.df_settings.index: - self.df_settings.at['showWelcomeGuide', 'value'] = 'Yes' + self.df_settings = pd.read_csv(csv_path, index_col="setting") + if "showWelcomeGuide" not in self.df_settings.index: + self.df_settings.at["showWelcomeGuide", "value"] = "Yes" else: - idx = ['showWelcomeGuide'] - values = ['Yes'] - self.df_settings = pd.DataFrame({'setting': idx, - 'value': values} - ).set_index('setting') + idx = ["showWelcomeGuide"] + values = ["Yes"] + self.df_settings = pd.DataFrame( + {"setting": idx, "value": values} + ).set_index("setting") self.df_settings.to_csv(csv_path) self.df_settings_path = csv_path @@ -129,13 +142,13 @@ def addtreeSelector(self): treeSelector.setFrameStyle(QFrame.Shape.NoFrame) self.welcomeItem = QTreeWidgetItem(treeSelector) - self.welcomeItem.setIcon(0, QIcon(':home.svg')) - self.welcomeItem.setText(0, 'Welcome') + self.welcomeItem.setIcon(0, QIcon(":home.svg")) + self.welcomeItem.setText(0, "Welcome") treeSelector.addTopLevelItem(self.welcomeItem) self.quickStartItem = QTreeWidgetItem(treeSelector) - self.quickStartItem.setIcon(0, QIcon(':quickStart.svg')) - self.quickStartItem.setText(0, 'Quick Start') + self.quickStartItem.setIcon(0, QIcon(":quickStart.svg")) + self.quickStartItem.setText(0, "Quick Start") treeSelector.addTopLevelItem(self.quickStartItem) # self.settingsItem = QTreeWidgetItem(treeSelector) @@ -144,18 +157,17 @@ def addtreeSelector(self): # treeSelector.addTopLevelItem(self.settingsItem) self.manualItem = QTreeWidgetItem(treeSelector) - self.manualItem.setIcon(0, QIcon(':book.svg')) + self.manualItem.setIcon(0, QIcon(":book.svg")) # textLabel = QLabel() # textLabel.setText(""" #

# User Manual #

# """) - self.manualItem.setText(0, 'User Manual') + self.manualItem.setText(0, "User Manual") treeSelector.addTopLevelItem(self.manualItem) # treeSelector.setItemWidget(self.manualItem, 0, textLabel) - # self.manualDataPrepItem = QTreeWidgetItem(self.manualItem) # self.manualDataPrepItem.setText(0, ' Data Prep module') # self.manualItem.addChild(self.manualDataPrepItem) @@ -167,8 +179,8 @@ def addtreeSelector(self): # self.manualItem.addChild(self.manualGUIItem) self.contributeItem = QTreeWidgetItem(treeSelector) - self.contributeItem.setIcon(0, QIcon(':contribute.svg')) - self.contributeItem.setText(0, 'Contribute') + self.contributeItem.setIcon(0, QIcon(":contribute.svg")) + self.contributeItem.setText(0, "Contribute") treeSelector.addTopLevelItem(self.contributeItem) # treeSelector.setSpacing(3) @@ -189,7 +201,6 @@ def treeItemChanged(self, currentItem, prevItem=None): else: frame.hide() - def addWelcomePage(self): self.welcomeFrame = QFrame(self) welcomeLayout = QGridLayout() @@ -202,8 +213,7 @@ def addWelcomePage(self): # welcomeTextWidget.setFrameStyle(QFrame.Shape.NoFrame) # welcomeTextWidget.viewport().setAutoFillBackground(False) - htmlTxt = ( - """ + htmlTxt = """ @@ -252,37 +262,37 @@ def addWelcomePage(self): """ - ) # welcomeTextWidget.setHtml(htmlTxt) welcomeTextWidget.setText(htmlTxt) welcomeTextWidget.linkActivated.connect(self.linkActivated_cb) - welcomeLayout.addWidget(welcomeTextWidget, 0, 0, 1, 5, - alignment=Qt.AlignTop) + welcomeLayout.addWidget(welcomeTextWidget, 0, 0, 1, 5, alignment=Qt.AlignTop) - startWizardButton = QPushButton(' Launch Wizard') - startWizardButton.setIcon(QIcon(':wizard.svg')) + startWizardButton = QPushButton(" Launch Wizard") + startWizardButton.setIcon(QIcon(":wizard.svg")) startWizardButton.clicked.connect(self.launchDataStruct) welcomeLayout.addWidget(startWizardButton, 1, 0) - testMyImageButton = QPushButton(' Test segmentation with my image/video') - testMyImageButton.setIcon(QIcon(':image.svg')) + testMyImageButton = QPushButton(" Test segmentation with my image/video") + testMyImageButton.setIcon(QIcon(":image.svg")) testMyImageButton.clicked.connect(self.openGUIsingleImage) welcomeLayout.addWidget(testMyImageButton, 1, 1) testTimeLapseButton = QPushButton( - text='Download and test with a time-lapse example') - testTimeLapseButton.setIcon(QIcon(':download.svg')) + text="Download and test with a time-lapse example" + ) + testTimeLapseButton.setIcon(QIcon(":download.svg")) testTimeLapseButton.clicked.connect(self.testTimeLapseExample) welcomeLayout.addWidget(testTimeLapseButton, 1, 2) test3DzStackButton = QPushButton( - text='Download and test with a 3D z-stack example') - test3DzStackButton.setIcon(QIcon(':download.svg')) + text="Download and test with a 3D z-stack example" + ) + test3DzStackButton.setIcon(QIcon(":download.svg")) test3DzStackButton.clicked.connect(self.test3DzStacksExample) welcomeLayout.addWidget(test3DzStackButton, 1, 3) @@ -308,13 +318,12 @@ def addQuickStartPage(self): QuickStartLayout = QGridLayout() - fs = 13 # font size + fs = 13 # font size row = 0 QuickStartTextWidget = QLabel() - htmlHead = ( - """ + htmlHead = """ @@ -330,10 +339,8 @@ def addQuickStartPage(self): """ - ) - htmlTxt = ( - f""" + htmlTxt = f""" {htmlHead}
@@ -349,18 +356,15 @@ def addQuickStartPage(self): """ - ) QuickStartTextWidget.setText(htmlTxt) QuickStartTextWidget.linkActivated.connect(self.linkActivated_cb) - QuickStartLayout.addWidget(QuickStartTextWidget, row, 0, - alignment=Qt.AlignTop) + QuickStartLayout.addWidget(QuickStartTextWidget, row, 0, alignment=Qt.AlignTop) row += 1 QuickStartTextWidget = QLabel() - htmlTxt = ( - f""" + htmlTxt = f""" {htmlHead}
@@ -390,19 +394,16 @@ def addQuickStartPage(self): """ - ) QuickStartTextWidget.setText(htmlTxt) QuickStartTextWidget.linkActivated.connect(self.linkActivated_cb) - QuickStartLayout.addWidget(QuickStartTextWidget, row, 0, - alignment=Qt.AlignTop) + QuickStartLayout.addWidget(QuickStartTextWidget, row, 0, alignment=Qt.AlignTop) row += 1 QS_tipTxtLabel = QLabel() - htmlTxt = ( - f""" + htmlTxt = f""" {htmlHead}
@@ -416,14 +417,12 @@ def addQuickStartPage(self):

""" - ) QS_tipTxtLabel.setText(htmlTxt) - QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, - alignment=Qt.AlignTop) + QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, alignment=Qt.AlignTop) row += 1 - pixmap = QPixmap(':toolbar.png') + pixmap = QPixmap(":toolbar.png") label = QLabel() # padding: top, left, bottom, right label.setStyleSheet("padding:5px 0px 10px 40px;") @@ -433,8 +432,7 @@ def addQuickStartPage(self): row += 1 QS_tipTxtLabel = QLabel() - htmlTxt = ( - f""" + htmlTxt = f""" {htmlHead}
@@ -447,18 +445,15 @@ def addQuickStartPage(self):

""" - ) QS_tipTxtLabel.setText(htmlTxt) - QS_tipTxtLabel.setStyleSheet('padding-bottom: 10px') - QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, - alignment=Qt.AlignTop) + QS_tipTxtLabel.setStyleSheet("padding-bottom: 10px") + QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, alignment=Qt.AlignTop) row += 1 QS_tipTxtLabel = QLabel() - htmlTxt = ( - f""" + htmlTxt = f""" {htmlHead}
@@ -472,14 +467,12 @@ def addQuickStartPage(self):

""" - ) QS_tipTxtLabel.setText(htmlTxt) - QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, - alignment=Qt.AlignTop) + QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, alignment=Qt.AlignTop) row += 1 - pixmap = QPixmap(':toolTip.png') + pixmap = QPixmap(":toolTip.png") label = QLabel() label.setStyleSheet("padding:5px 0px 10px 40px;") label.setPixmap(pixmap) @@ -488,8 +481,7 @@ def addQuickStartPage(self): row += 1 QS_tipTxtLabel = QLabel() - htmlTxt = ( - f""" + htmlTxt = f""" {htmlHead}
@@ -503,18 +495,15 @@ def addQuickStartPage(self):

""" - ) QS_tipTxtLabel.setText(htmlTxt) - QS_tipTxtLabel.setStyleSheet('padding-bottom: 10px') - QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, - alignment=Qt.AlignTop) + QS_tipTxtLabel.setStyleSheet("padding-bottom: 10px") + QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, alignment=Qt.AlignTop) row += 1 QS_tipTxtLabel = QLabel() - htmlTxt = ( - f""" + htmlTxt = f""" {htmlHead}
@@ -556,18 +545,15 @@ def addQuickStartPage(self):

""" - ) QS_tipTxtLabel.setText(htmlTxt) - QS_tipTxtLabel.setStyleSheet('padding-bottom: 10px') - QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, - alignment=Qt.AlignTop) + QS_tipTxtLabel.setStyleSheet("padding-bottom: 10px") + QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, alignment=Qt.AlignTop) row += 1 QS_tipTxtLabel = QLabel() - htmlTxt = ( - f""" + htmlTxt = f""" {htmlHead}
@@ -580,19 +566,15 @@ def addQuickStartPage(self):

""" - ) QS_tipTxtLabel.setText(htmlTxt) - QS_tipTxtLabel.setStyleSheet('padding-bottom: 10px') - QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, - alignment=Qt.AlignTop) + QS_tipTxtLabel.setStyleSheet("padding-bottom: 10px") + QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, alignment=Qt.AlignTop) row += 1 QS_tipTxtLabel = QLabel() - - htmlTxt = ( - f""" + htmlTxt = f""" {htmlHead}
@@ -606,21 +588,19 @@ def addQuickStartPage(self):

""" - ) QS_tipTxtLabel.setText(htmlTxt) - QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, - alignment=Qt.AlignTop) + QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, alignment=Qt.AlignTop) row += 1 modeComboBox = QComboBox() - modeComboBox.addItems(['Segmentation and Tracking', - 'Cell cycle analysis', - 'Viewer']) - modeComboBox.setCurrentText('Viewer') + modeComboBox.addItems( + ["Segmentation and Tracking", "Cell cycle analysis", "Viewer"] + ) + modeComboBox.setCurrentText("Viewer") modeComboBox.setFocusPolicy(Qt.StrongFocus) modeComboBox.installEventFilter(self) - modeComboBoxLabel = QLabel(' Mode: ') + modeComboBoxLabel = QLabel(" Mode: ") layout = QHBoxLayout() layout.addWidget(modeComboBoxLabel) layout.addWidget(modeComboBox) @@ -631,8 +611,7 @@ def addQuickStartPage(self): row += 1 QS_tipTxtLabel = QLabel() - htmlTxt = ( - f""" + htmlTxt = f""" {htmlHead}
@@ -646,18 +625,14 @@ def addQuickStartPage(self):

""" - ) QS_tipTxtLabel.setText(htmlTxt) - QS_tipTxtLabel.setStyleSheet('padding-bottom: 10px') - QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, - alignment=Qt.AlignTop) - + QS_tipTxtLabel.setStyleSheet("padding-bottom: 10px") + QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, alignment=Qt.AlignTop) row += 1 QS_tipTxtLabel = QLabel() - htmlTxt = ( - f""" + htmlTxt = f""" {htmlHead}
@@ -671,27 +646,22 @@ def addQuickStartPage(self):

""" - ) QS_tipTxtLabel.setText(htmlTxt) - QS_tipTxtLabel.setStyleSheet('padding-bottom: 8px') + QS_tipTxtLabel.setStyleSheet("padding-bottom: 8px") viewerButton = QToolButton() - viewerButton.setIcon(QIcon(':eye-plus.svg')) - viewerButton.setIconSize(QSize(24, 24)); - + viewerButton.setIcon(QIcon(":eye-plus.svg")) + viewerButton.setIconSize(QSize(24, 24)) layout = QHBoxLayout() layout.addWidget(QS_tipTxtLabel, alignment=Qt.AlignBottom) layout.addWidget(viewerButton) layout.addStretch(1) QuickStartLayout.addLayout(layout, row, 0) - - row += 1 QS_tipTxtLabel = QLabel() - htmlTxt = ( - f""" + htmlTxt = f""" {htmlHead}
@@ -706,17 +676,14 @@ def addQuickStartPage(self):

""" - ) QS_tipTxtLabel.setText(htmlTxt) - QS_tipTxtLabel.setStyleSheet('padding-top: 2px') - QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, - alignment=Qt.AlignTop) + QS_tipTxtLabel.setStyleSheet("padding-top: 2px") + QuickStartLayout.addWidget(QS_tipTxtLabel, row, 0, alignment=Qt.AlignTop) - row +=1 + row += 1 QuickStartTextWidget = QLabel() - htmlTxt = ( - f""" + htmlTxt = f""" {htmlHead}
@@ -729,29 +696,28 @@ def addQuickStartPage(self): """ - ) QuickStartTextWidget.setText(htmlTxt) - QuickStartLayout.addWidget(QuickStartTextWidget, row, 0, - alignment=Qt.AlignTop) + QuickStartLayout.addWidget(QuickStartTextWidget, row, 0, alignment=Qt.AlignTop) row += 1 layout = QHBoxLayout() - testMyImage = QPushButton( - text='Test segmentation with my image/video') - testMyImage.setIcon(QIcon(':image.svg')) + testMyImage = QPushButton(text="Test segmentation with my image/video") + testMyImage.setIcon(QIcon(":image.svg")) layout.addWidget(testMyImage) testMyImage.clicked.connect(self.openGUIsingleImage) testTimeLapseButton = QPushButton( - text='Download and test with a time-lapse example') - testTimeLapseButton.setIcon(QIcon(':download.svg')) + text="Download and test with a time-lapse example" + ) + testTimeLapseButton.setIcon(QIcon(":download.svg")) layout.addWidget(testTimeLapseButton) testTimeLapseButton.clicked.connect(self.testTimeLapseExample) test3DzStackButton = QPushButton( - text='Download and test with a 3D z-stack example') - test3DzStackButton.setIcon(QIcon(':download.svg')) + text="Download and test with a 3D z-stack example" + ) + test3DzStackButton.setIcon(QIcon(":download.svg")) layout.addWidget(test3DzStackButton) test3DzStackButton.clicked.connect(self.test3DzStacksExample) @@ -775,9 +741,9 @@ def addManualPage(self): manualLayout = QGridLayout() openManualButton = widgets.showInFileManagerButton( - ' Download and open user manual... ' + " Download and open user manual... " ) - openManualButton.clicked.connect(myutils.browse_docs) + openManualButton.clicked.connect(utils.browse_docs) buttonLayout = QHBoxLayout() buttonLayout.addWidget(openManualButton) @@ -794,15 +760,15 @@ def addContributePage(self): layout = QGridLayout() - contribute_href = html_utils.href_tag('here', urls.contribute_url) - github_href = html_utils.href_tag('GitHub page', urls.github_url) - issues_href = html_utils.href_tag('Issues', urls.issues_url) - forum_href = html_utils.href_tag('Discussions', urls.forum_url) - resources_href = html_utils.href_tag('here', urls.resources_url) - my_contact_href = html_utils.href_tag('my email', urls.my_contact_url) - user_manual_href = html_utils.href_tag('User Manual', urls.user_manual_url) + contribute_href = html_utils.href_tag("here", urls.contribute_url) + github_href = html_utils.href_tag("GitHub page", urls.github_url) + issues_href = html_utils.href_tag("Issues", urls.issues_url) + forum_href = html_utils.href_tag("Discussions", urls.forum_url) + resources_href = html_utils.href_tag("here", urls.resources_url) + my_contact_href = html_utils.href_tag("my email", urls.my_contact_url) + user_manual_href = html_utils.href_tag("User Manual", urls.user_manual_url) - text = (f""" + text = f"""

Here at Cell-ACDC we want to keep a community-centred approach.

@@ -833,7 +799,7 @@ def addContributePage(self): Additional resources {resources_href}.

- """) + """ label = QLabel() label.setText(text) @@ -844,28 +810,27 @@ def addContributePage(self): self.mainLayout.addWidget(self.contributeFrame, 0, 1) self.itemsDict[self.contributeItem.text(0)] = self.contributeFrame - def linkActivated_cb(self, link): - if link == 'DataPrepMore': + if link == "DataPrepMore": pass - elif link == 'paper': + elif link == "paper": url = cite_url webbrowser.open(url) - elif link == 'tweet': - url = 'https://twitter.com/frank_pado/status/1443957038841794561?s=20' + elif link == "tweet": + url = "https://twitter.com/frank_pado/status/1443957038841794561?s=20" webbrowser.open(url) - elif link == 'segmMore': + elif link == "segmMore": pass - elif link == 'guiMore': + elif link == "guiMore": pass - elif link == 'quickStart': + elif link == "quickStart": self.showPage(self.quickStartItem) - elif link == 'userManual': + elif link == "userManual": self.showPage(self.manualItem) def addShowGuideCheckbox(self): - checkBox = QCheckBox('Show Welcome Guide when opening Cell-ACDC') - checked = self.df_settings.at['showWelcomeGuide', 'value'] == 'Yes' + checkBox = QCheckBox("Show Welcome Guide when opening Cell-ACDC") + checked = self.df_settings.at["showWelcomeGuide", "value"] == "Yes" checkBox.setChecked(checked) self.mainLayout.addWidget(checkBox, 1, 1, alignment=Qt.AlignRight) @@ -873,13 +838,13 @@ def addShowGuideCheckbox(self): def showWelcomeGuideCheckBox_cb(self, state): if state == 0: - show = 'No' + show = "No" else: - show = 'Yes' - self.df_settings.loc['showWelcomeGuide'] = ( - self.df_settings.loc['showWelcomeGuide'].astype(str) - ) - self.df_settings.at['showWelcomeGuide', 'value'] = show + show = "Yes" + self.df_settings.loc["showWelcomeGuide"] = self.df_settings.loc[ + "showWelcomeGuide" + ].astype(str) + self.df_settings.at["showWelcomeGuide", "value"] = show self.saveSettings() def saveSettings(self): @@ -893,10 +858,8 @@ def openGUIsingleImage(self): You will then be asked to select an image file (e.g., .tif or .png), or a video file (e.g., .avi). """) - msg.information( - self, 'Test with my image', txt - ) - + msg.information(self, "Test with my image", txt) + if self.mainWin is not None: self.mainWin.launchGui() guiWin = self.mainWin.guiWins[-1] @@ -910,9 +873,7 @@ def openGUIfolder(self, exp_path): if self.mainWin is not None: self.mainWin.launchGui() guiWin = self.mainWin.guiWins[-1] - QTimer.singleShot( - 200, partial(guiWin.openFolder, exp_path=exp_path) - ) + QTimer.singleShot(200, partial(guiWin.openFolder, exp_path=exp_path)) else: self.guiWin = gui.guiWin(self.app) self.guiWin.showAndSetSize() @@ -927,14 +888,12 @@ def addPbar(self): self.welcomeLayout.addWidget(self.QPbar, 3, 0, 1, 3) def testTimeLapseExample(self, checked=True): - _, example_path, _, _ = myutils.get_examples_path('time_lapse_2D') - txt = ( - f""" + _, example_path, _, _ = utils.get_examples_path("time_lapse_2D") + txt = f"""


Downloading example to {example_path}...

""" - ) self.infoTextWidget.setText(txt) if self.QPbar is None: @@ -943,7 +902,7 @@ def testTimeLapseExample(self, checked=True): self.QPbar.setVisible(True) self.thread = QThread() - self.worker = downloadWorker('time_lapse_2D') + self.worker = downloadWorker("time_lapse_2D") self.worker.moveToThread(self.thread) self.worker.progress.connect(self.downloadProgress) self.worker.finished.connect(self.thread.quit) @@ -958,7 +917,7 @@ def downloadProgress(self, file_size, len_chunk): if file_size != -1: self.QPbar.setMaximum(file_size) elif len_chunk != -1: - self.QPbar.setValue(self.QPbar.value()+len_chunk) + self.QPbar.setValue(self.QPbar.value() + len_chunk) elif len_chunk == 0: self.QPbar.setValue(self.QPbar.maximum()) @@ -971,40 +930,37 @@ def downloadExampleWorkerFinished(self): Do you want to open it in the GUI? """) _, yesButton = msg.question( - self, 'Open downloaded dataset?', txt, - buttonsTexts=('No, thanks', 'Yes, please, open the GUI'), + self, + "Open downloaded dataset?", + txt, + buttonsTexts=("No, thanks", "Yes, please, open the GUI"), commands=(self.worker.exp_path,), - path_to_browse=self.worker.exp_path + path_to_browse=self.worker.exp_path, ) self.infoTextWidget.setText( - '
Example downloaded to ' - f'{self.worker.exp_path}.
' + f"
Example downloaded to {self.worker.exp_path}.
" ) if msg.clickedButton == yesButton: self.openGUIexample() - + def openGUIexample(self): - txt = ( - f""" + txt = f"""


Example downloaded to {self.worker.exp_path}.
Opening GUI...

""" - ) self.infoTextWidget.setText(txt) self.openGUIfolder(self.worker.exp_path) def test3DzStacksExample(self, checked=True): - _, example_path, _, _ = myutils.get_examples_path('snapshots_3D') - txt = ( - f""" + _, example_path, _, _ = utils.get_examples_path("snapshots_3D") + txt = f"""


Downloading example to {example_path}...

""" - ) self.infoTextWidget.setText(txt) if self.QPbar is None: @@ -1013,7 +969,7 @@ def test3DzStacksExample(self, checked=True): self.QPbar.setVisible(True) self.thread = QThread() - self.worker = downloadWorker('snapshots_3D') + self.worker = downloadWorker("snapshots_3D") self.worker.moveToThread(self.thread) self.worker.progress.connect(self.downloadProgress) self.worker.finished.connect(self.thread.quit) @@ -1030,7 +986,7 @@ def debug(self): def showAndSetSize(self): font = QFont() font.setPixelSize(13) - font.setFamily('Ubuntu') + font.setFamily("Ubuntu") self.treeSelector.setFont(font) self.showPage(self.quickStartItem) @@ -1044,7 +1000,7 @@ def showAndSetSize(self): def resizeScrollbar(self): if self.quickStartScrollArea.horizontalScrollBar().isVisible(): - self.resize(self.width()+5, self.height()) + self.resize(self.width() + 5, self.height()) else: self.timer.stop() self.moveWindow() @@ -1065,21 +1021,20 @@ def moveWindow(self): left = screenLeft + 10 top = screenTop + 70 width = w - height = int(h*Dh) - if height > 0.9*screenHeight: - height = int(0.9*screenHeight) + height = int(h * Dh) + if height > 0.9 * screenHeight: + height = int(0.9 * screenHeight) self.setGeometry(left, top, width, height) if self.mainWin is not None: mainWinWidth = self.mainWin.width() - welcomeWinRight = left+width - if welcomeWinRight+mainWinWidth > screenRight: + welcomeWinRight = left + width + if welcomeWinRight + mainWinWidth > screenRight: # The right edge of the welcome window is out of screen bounds # Keep it in the screen - welcomeWinRight = screenRight-mainWinWidth + welcomeWinRight = screenRight - mainWinWidth self.mainWin.move(welcomeWinRight, top) - def showPage(self, currentItem): self.treeSelector.setCurrentItem(currentItem, 0) @@ -1090,11 +1045,12 @@ def eventFilter(self, object, event): return True return False -if __name__ == '__main__': + +if __name__ == "__main__": app = QApplication([]) win = welcomeWin(app=app) win.showAndSetSize() win.showPage(win.welcomeItem) # win.showPage(win.quickStartItem) - app.setStyle(QStyleFactory.create('Fusion')) + app.setStyle(QStyleFactory.create("Fusion")) sys.exit(app.exec_()) diff --git a/cellacdc/html_utils.py b/cellacdc/html_utils.py index 8bc832553..84003a729 100755 --- a/cellacdc/html_utils.py +++ b/cellacdc/html_utils.py @@ -4,17 +4,18 @@ import sys import textwrap -from . import GUI_INSTALLED, myutils +from . import GUI_INSTALLED, utils from ._palettes import ( - _get_highligth_header_background_rgba, _get_highligth_text_background_rgba + _get_highligth_header_background_rgba, + _get_highligth_text_background_rgba, ) from .colors import rgb_uint_to_html_hex if GUI_INSTALLED: from matplotlib.colors import to_hex -is_mac = sys.platform == 'darwin' +is_mac = sys.platform == "darwin" RST_NOTE_DIR_RGBA = _get_highligth_header_background_rgba() RST_NOTE_DIR_HEX_COLOR = rgb_uint_to_html_hex(RST_NOTE_DIR_RGBA[:3]) @@ -23,78 +24,84 @@ RST_NOTE_TXT_HEX_COLOR = rgb_uint_to_html_hex(RST_NOTE_TXT_RGBA[:3]) ADMONITION_TYPES = ( - 'topic', - 'admonition', - 'attention', - 'caution', - 'danger', - 'error', - 'hint', - 'important', - 'note', - 'seealso', - 'tip', - 'todo', - 'warning', - 'versionadded', - 'versionchanged', - 'deprecated' + "topic", + "admonition", + "attention", + "caution", + "danger", + "error", + "hint", + "important", + "note", + "seealso", + "tip", + "todo", + "warning", + "versionadded", + "versionchanged", + "deprecated", ) -HTML_TAGS = ( - 'code', 'i', 'b', 'br' -) +HTML_TAGS = ("code", "i", "b", "br") + def _tag(tag_info='p style="font-size:10px"'): def wrapper(func): @wraps(func) def inner(text): - tag = tag_info.split(' ')[0] - text = f'<{tag_info}>{text}' + tag = tag_info.split(" ")[0] + text = f"<{tag_info}>{text}" return text + return inner + return wrapper + def tag(text, tag_info='p style="font-size:10pt"'): - tag = tag_info.split(' ')[0] - text = f'<{tag_info}>{text}' + tag = tag_info.split(" ")[0] + text = f"<{tag_info}>{text}" return text + def to_plain_text(html_text): - html_text = re.sub(r' +', ' ', html_text) - html_text = html_text.replace('\n ', '\n') - html_text = html_text.strip('\n') - html_text = html_text.replace('', '`') - html_text = html_text.replace('', '`') - html_text = html_text.replace('
', '\n') - html_text = html_text.replace('
  • ', '\n * ') - html_text = re.sub(r'', '', html_text) - html_text = re.sub(r'<.+>', '', html_text) - html_text = html_text.strip('\n') + html_text = re.sub(r" +", " ", html_text) + html_text = html_text.replace("\n ", "\n") + html_text = html_text.strip("\n") + html_text = html_text.replace("", "`") + html_text = html_text.replace("", "`") + html_text = html_text.replace("
    ", "\n") + html_text = html_text.replace("
  • ", "\n * ") + html_text = re.sub(r"", "", html_text) + html_text = re.sub(r"<.+>", "", html_text) + html_text = html_text.strip("\n") return html_text + def href_tag(text, url): txt = tag(text, tag_info=f'a href="{url}"') return txt + def to_list(items, ordered=False): - list_tag = 'ol' if ordered else 'ul' - items_txt = ''.join([f'
  • {item}
  • ' for item in items]) + list_tag = "ol" if ordered else "ul" + items_txt = "".join([f"
  • {item}
  • " for item in items]) txt = tag(items_txt, tag_info=list_tag) return txt -def span(text, color='r', font_size=None, bold=False): + +def span(text, color="r", font_size=None, bold=False): span_text = f'{text}' if color is not None: try: c = to_hex(color) except Exception as e: - if color == 'r': - c = 'red' - elif color == 'g': - c = 'green' - elif color == 'k': - c = 'black' + if color == "r": + c = "red" + elif color == "g": + c = "green" + elif color == "k": + c = "black" else: c = color span_text = f'{text}' @@ -104,10 +111,11 @@ def span(text, color='r', font_size=None, bold=False): span_text = span_text.replace('">', f'; font-weight:bold;">') return span_text + def css_head(txt): # if is_mac: # txt = txt.replace(',', ', ') - s = (f""" + s = f""" @@ -115,21 +123,23 @@ def css_head(txt): {txt} - """) + """ return s + def html_body(txt): if is_mac: - txt = txt.replace(',', ', ') - s = (f""" + txt = txt.replace(",", ", ") + s = f""" {txt} - """) + """ return s -def paragraph(txt, font_size='13px', font_color=None, wrap=True, center=False): + +def paragraph(txt, font_size="13px", font_color=None, wrap=True, center=False): # if is_mac: # # Qt < 5.15.3 has a bug on macOS and the space after comma and perdiod # # are super small. Force a non-breaking space (except for 'e.g.,'). @@ -140,248 +150,251 @@ def paragraph(txt, font_size='13px', font_color=None, wrap=True, center=False): # txt = txt.replace('i. e. ', 'i.e.') # txt = txt.replace('etc. )', 'etc.)') if not wrap: - txt = txt.replace(' ', ' ') + txt = txt.replace(" ", " ") if font_color is None: - s = (f""" + s = f"""

    {txt}

    - """) + """ else: - s = (f""" + s = f"""

    {txt}

    - """) + """ if center: s = re.sub(r'

    ', r'

    ', s) return s + def rst_urls_to_html(rst_text): - links = re.findall(r'`(.*) ?<(.*)>`_', rst_text) + links = re.findall(r"`(.*) ?<(.*)>`_", rst_text) html_text = rst_text for text, link in links: if not text: text = link repl = href_tag(text.rstrip(), link) - pattern = fr'`{text} ?<{link}>`_' + pattern = rf"`{text} ?<{link}>`_" html_text = re.sub(pattern, repl, html_text) return html_text + def rst_to_html(rst_text, parse_urls=False, keep_spacing=False): if parse_urls: rst_text = rst_urls_to_html(rst_text) - valid_chars = r'[,A-Za-z0-9μ\-\.=_ \<\>\(\)\\\&;]' - html_text = re.sub(rf'\`\`([^\`]*)\`\`', r'\1', rst_text) - html_text = re.sub(rf'\`([^\`]*)\`', r'\1', html_text) - html_text = html_text.replace('<', '<').replace('>', '>') - + valid_chars = r"[,A-Za-z0-9μ\-\.=_ \<\>\(\)\\\&;]" + html_text = re.sub(rf"\`\`([^\`]*)\`\`", r"\1", rst_text) + html_text = re.sub(rf"\`([^\`]*)\`", r"\1", html_text) + html_text = html_text.replace("<", "<").replace(">", ">") + # Insert back the allowed html tags as actual tags for html_tag in HTML_TAGS: - html_text = html_text.replace(f'<{html_tag}>', f'<{html_tag}>') - html_text = html_text.replace(f'</{html_tag}>', f'') - - html_text = html_text.replace('\n', '
    ') + html_text = html_text.replace(f"<{html_tag}>", f"<{html_tag}>") + html_text = html_text.replace(f"</{html_tag}>", f"") + + html_text = html_text.replace("\n", "
    ") if keep_spacing: - html_text = re.sub( - r'(\s\s+)', lambda m: ' '*len(m.group(0)), html_text - ) + html_text = re.sub(r"(\s\s+)", lambda m: " " * len(m.group(0)), html_text) return html_text + def rst_docstring_filter_args(rst_doc, args_to_keep): - start_idx = rst_doc.find('Parameters') + start_idx = rst_doc.find("Parameters") before_params_text = rst_doc[:start_idx] - start_params_idx = before_params_text.rfind('\n') + 1 - - params_text = rst_doc[start_params_idx:] + start_params_idx = before_params_text.rfind("\n") + 1 + + params_text = rst_doc[start_params_idx:] numls = len(params_text) - len(params_text.lstrip()) - ul = ' '*numls + '-'*len('Parameters') - section = ' '*numls + 'Parameters' - section_header = f'{section}\n{ul}\n' - - found_end = re.findall(r'\n *\n', params_text) + ul = " " * numls + "-" * len("Parameters") + section = " " * numls + "Parameters" + section_header = f"{section}\n{ul}\n" + + found_end = re.findall(r"\n *\n", params_text) if not found_end: stop_idx = None else: stop_idx = params_text.find(found_end[0]) - + params_text = params_text[:stop_idx] filtered_params_text = params_text - found_args = re.findall(r'([A-Za-z0-9_]+) \: (.*)', params_text) + found_args = re.findall(r"([A-Za-z0-9_]+) \: (.*)", params_text) for a, (arg_name, arg_dtype) in enumerate(found_args): if arg_name in args_to_keep: continue - - arg_doc = f' {arg_name} : {arg_dtype}' + + arg_doc = f" {arg_name} : {arg_dtype}" start_idx = filtered_params_text.find(arg_doc) + 1 - - if a+1 == len(found_args): + + if a + 1 == len(found_args): stop_idx = None else: - next_arg, next_arg_type = found_args[a+1] - next_arg_doc = f' {next_arg} : {next_arg_type}' + next_arg, next_arg_type = found_args[a + 1] + next_arg_doc = f" {next_arg} : {next_arg_type}" stop_idx = filtered_params_text.find(next_arg_doc) - + text_to_remove = filtered_params_text[start_idx:stop_idx] - filtered_params_text = filtered_params_text.replace(text_to_remove, '') - - filtered_params_text = filtered_params_text.rstrip().rstrip('\n') + filtered_params_text = filtered_params_text.replace(text_to_remove, "") + + filtered_params_text = filtered_params_text.rstrip().rstrip("\n") filtered_doc = rst_doc.replace(params_text, filtered_params_text) return filtered_doc + def rst_docstring_to_html(rst_doc: str, args_subset=None): html_text = rst_doc # ignore lines which start with a # - html_new = '' - for line in html_text.split('\n'): + html_new = "" + for line in html_text.split("\n"): try: first_char = line.lstrip()[0] except IndexError: - first_char = '' - - if first_char == '#': + first_char = "" + + if first_char == "#": continue - html_new += line + '\n' + html_new += line + "\n" html_text = html_new - + if args_subset is not None: html_text = rst_docstring_filter_args(html_text, args_subset) - + # Replace args with indented `bold : italic` - found_args = re.findall(r'([A-Za-z0-9_]+) \: (.*)', html_text) + found_args = re.findall(r"([A-Za-z0-9_]+) \: (.*)", html_text) for a, (arg_name, arg_dtype) in enumerate(found_args): - arg_doc = f' {arg_name} : {arg_dtype}' + arg_doc = f" {arg_name} : {arg_dtype}" html_text = html_text.replace( - arg_doc, - f'
      {arg_name} : {arg_dtype}', + arg_doc, + f"
      {arg_name} : {arg_dtype}", ) - + # Indent description of arg more admon_sections = [] - found_sections = re.findall(r'([A-Za-z ]+)\n *[\-]+\n', rst_doc) + found_sections = re.findall(r"([A-Za-z ]+)\n *[\-]+\n", rst_doc) for s, section in enumerate(found_sections): section_lstrip = section.lstrip() - section_admon = section_lstrip.replace(' ', '').lower() + section_admon = section_lstrip.replace(" ", "").lower() if section_admon in ADMONITION_TYPES: admon_sections.append(section) continue - + numls = len(section) - len(section_lstrip) - ul = ' '*numls + '-'*len(section_lstrip) - section_header = f'{section}\n{ul}\n' + ul = " " * numls + "-" * len(section_lstrip) + section_header = f"{section}\n{ul}\n" start_idx = html_text.find(section_header) + len(section_header) - if s+1 == len(found_sections): + if s + 1 == len(found_sections): stop_idx = None else: - next_section = found_sections[s+1] + next_section = found_sections[s + 1] stop_idx = html_text.find(next_section) - + section_text = html_text[start_idx:stop_idx] section_indented = re.sub( - r'(\n\s\s+)', '
        ', section_text + r"(\n\s\s+)", "
        ", section_text ) - + html_text = list(html_text) html_text[start_idx:stop_idx] = section_indented - html_text = ''.join(html_text) - + html_text = "".join(html_text) + # Replace section header with 16px bold html for section in found_sections: if section in admon_sections: continue - - section_lstrip = section.lstrip() + + section_lstrip = section.lstrip() numls = len(section) - len(section_lstrip) - ul = ' '*numls + '-'*len(section_lstrip) + ul = " " * numls + "-" * len(section_lstrip) html_text = html_text.replace( - f'{section}\n{ul}', - span(section.strip(), font_size='16px', color=None, bold=True) + f"{section}\n{ul}", + span(section.strip(), font_size="16px", color=None, bold=True), ) - + # Replace admonition sections with html table for admon_section in admon_sections: - section_lstrip = admon_section.lstrip() + section_lstrip = admon_section.lstrip() numls = len(admon_section) - len(section_lstrip) - ul = ' '*numls + '-'*len(section_lstrip) - section_header = f'{admon_section}\n{ul}\n' - + ul = " " * numls + "-" * len(section_lstrip) + section_header = f"{admon_section}\n{ul}\n" + start_idx = html_text.find(section_header) + len(section_header) section_text = html_text[start_idx:] - found_end = re.findall(r'\n *\n', section_text) + found_end = re.findall(r"\n *\n", section_text) if not found_end: stop_idx = None else: stop_idx = section_text.find(found_end[0]) - + section_text = section_text[:stop_idx] html_admon = to_admonition(section_text, admonition_type=section_lstrip) html_text = html_text.replace(section_text, html_admon) - html_text = html_text.replace(section_header, '') - + html_text = html_text.replace(section_header, "") + # Replace last charachaters to html html_text = rst_urls_to_html(html_text) - html_text = html_text.replace('\n', '
    ') - html_text = re.sub(rf'\`\`([^\`]*)\`\`', r'\1', html_text) - html_text = re.sub(rf'\`([^\`]*)\`', r'\1', html_text) - + html_text = html_text.replace("\n", "
    ") + html_text = re.sub(rf"\`\`([^\`]*)\`\`", r"\1", html_text) + html_text = re.sub(rf"\`([^\`]*)\`", r"\1", html_text) + return html_text -def to_admonition(text, admonition_type='note'): - if text.find('
    ') == -1: + +def to_admonition(text, admonition_type="note"): + if text.find("
    ") == -1: wrapped_list = textwrap.wrap(text, width=130) - text = '
    '.join(wrapped_list) + text = "
    ".join(wrapped_list) title = admonition_type.capitalize() title_row = tag( - f'! {title}', - tag_info=f'tr bgcolor="{RST_NOTE_DIR_HEX_COLOR}"' + f"! {title}", tag_info=f'tr bgcolor="{RST_NOTE_DIR_HEX_COLOR}"' ) text_row = tag( - f'{text}', - tag_info=f'tr bgcolor="{RST_NOTE_TXT_HEX_COLOR}"' + f"{text}", tag_info=f'tr bgcolor="{RST_NOTE_TXT_HEX_COLOR}"' ) admonition_html = ( - '' - f'{title_row}{text_row}' - '

    ' + "" + f"{title_row}{text_row}" + "

    " ) return admonition_html + def to_note(note_text): - note_html = to_admonition(note_text, admonition_type='note') + note_html = to_admonition(note_text, admonition_type="note") return note_html + # Syntax highlighting html -func_color = (111/255,66/255,205/255) # purplish -kwargs_color = (208/255,88/255,9/255) # reddish/orange -class_color = (215/255,58/255,73/255) # reddish -blue_color = (0/255,92/255,197/255) # blueish -class_sh = span('class', color=class_color) -def_sh = span('def', color=class_color) -if_sh = span('if', color=class_color) -elif_sh = span('elif', color=class_color) -kwargs_sh = span('**kwargs', color=kwargs_color) -Model_sh = span('Model', color=func_color) -segment_sh = span('segment', color=func_color) -add_prompt_sh = span('add_prompt', color=func_color) -predict_sh = span('predict', color=func_color) -CV_sh = span('CV', color=func_color) -init_sh = span('__init__', color=blue_color) -myModel_sh = span('MyModel', color=func_color) -return_sh = span('return', color=class_color) -equal_sh = span('=', color=class_color) -open_par_sh = span('(', color=blue_color) -close_par_sh = span(')', color=blue_color) -image_sh = span('image', color=kwargs_color) -from_sh = span('from', color=class_color) -import_sh = span('import', color=class_color) -is_not_sh = span('is not', color=class_color) -np_mean_sh = span('np.mean', color=class_color) -np_std_sh = span('np.std', color=class_color) +func_color = (111 / 255, 66 / 255, 205 / 255) # purplish +kwargs_color = (208 / 255, 88 / 255, 9 / 255) # reddish/orange +class_color = (215 / 255, 58 / 255, 73 / 255) # reddish +blue_color = (0 / 255, 92 / 255, 197 / 255) # blueish +class_sh = span("class", color=class_color) +def_sh = span("def", color=class_color) +if_sh = span("if", color=class_color) +elif_sh = span("elif", color=class_color) +kwargs_sh = span("**kwargs", color=kwargs_color) +Model_sh = span("Model", color=func_color) +segment_sh = span("segment", color=func_color) +add_prompt_sh = span("add_prompt", color=func_color) +predict_sh = span("predict", color=func_color) +CV_sh = span("CV", color=func_color) +init_sh = span("__init__", color=blue_color) +myModel_sh = span("MyModel", color=func_color) +return_sh = span("return", color=class_color) +equal_sh = span("=", color=class_color) +open_par_sh = span("(", color=blue_color) +close_par_sh = span(")", color=blue_color) +image_sh = span("image", color=kwargs_color) +from_sh = span("from", color=class_color) +import_sh = span("import", color=class_color) +is_not_sh = span("is not", color=class_color) +np_mean_sh = span("np.mean", color=class_color) +np_std_sh = span("np.std", color=class_color) import textwrap -table_style_header = textwrap.dedent('''\ +table_style_header = textwrap.dedent("""\ -''') \ No newline at end of file +""") diff --git a/cellacdc/info.py b/cellacdc/info.py index dc177a4f6..c44afa508 100644 --- a/cellacdc/info.py +++ b/cellacdc/info.py @@ -1,14 +1,16 @@ from . import urls, html_utils -forum_href = html_utils.href_tag('forum page', urls.forum_url) +forum_href = html_utils.href_tag("forum page", urls.forum_url) utilsInfo = { - 'Convert _segm.npz file(s) to ImageJ ROIs...': (f""" + "Convert _segm.npz file(s) to ImageJ ROIs...": ( + f""" Not documented yet. You can ask help about utilities on our {forum_href}.

    Thank you for your patience! - """), - - 'Create connected 3D segmentation mask from z-slices segmentation...': (f""" + """ + ), + "Create connected 3D segmentation mask from z-slices segmentation...": ( + f""" This utility is used to create a 3D segmentation mask by projecting the center z-slice of the 3D objects to their own z-boundaries.

    @@ -16,15 +18,17 @@ a "cylindrical" object,
    where the largest z-slice is projected up and down to the max and min z-slice. - """), - - 'Track sub-cellular objects (assign same ID as the cell they belong to)...': (f""" + """ + ), + "Track sub-cellular objects (assign same ID as the cell they belong to)...": ( + f""" Not documented yet. You can ask help about utilities on our {forum_href}.

    Thank you for your patience! - """), - - 'Apply tracking info from tabular data...': (f""" + """ + ), + "Apply tracking info from tabular data...": ( + f""" This utility is used to load the information of an external tracker into Cell-ACDC.

    @@ -44,41 +48,48 @@ Note that to use this utility you need to have a Cell-ACDC compatible segmentation file. - """), - - 'Create required data structure from image files...': (f""" + """ + ), + "Create required data structure from image files...": ( + f""" Not documented yet. You can ask help about utilities on our {forum_href}.

    Thank you for your patience! - """), - - 'Re-apply data prep steps to selected channels...': (f""" + """ + ), + "Re-apply data prep steps to selected channels...": ( + f""" Not documented yet. You can ask help about utilities on our {forum_href}.

    Thank you for your patience! - """), - - 'Concatenate acdc output tables from multiple Positions...': (f""" + """ + ), + "Concatenate acdc output tables from multiple Positions...": ( + f""" Not documented yet. You can ask help about utilities on our {forum_href}.

    Thank you for your patience! - """), - - 'Compute measurements for one or more experiments...': (f""" + """ + ), + "Compute measurements for one or more experiments...": ( + f""" Not documented yet. You can ask help about utilities on our {forum_href}.

    Thank you for your patience! - """), - - 'Combine measurements from multiple segmentation files...': (f""" + """ + ), + "Combine measurements from multiple segmentation files...": ( + f""" Not documented yet. You can ask help about utilities on our {forum_href}.

    Thank you for your patience! - """), - - 'Add lineage tree table to one or more experiments...': (f""" + """ + ), + "Add lineage tree table to one or more experiments...": ( + f""" Not documented yet. You can ask help about utilities on our {forum_href}.

    Thank you for your patience! - """) -} \ No newline at end of file + """ + ), +} diff --git a/cellacdc/io.py b/cellacdc/io.py index 96edb34db..ad7d1b33d 100644 --- a/cellacdc/io.py +++ b/cellacdc/io.py @@ -14,91 +14,94 @@ import numpy as np import skimage.io -from . import path, load, myutils, printl +from . import path, load, utils, printl from . import moth_bud_tot_selected_columns_filepath from . import saved_measurements_selections_folderpath from . import config + def get_saved_measurements_selections(): if not os.path.exists(saved_measurements_selections_folderpath): return [] - + return list(os.listdir(saved_measurements_selections_folderpath)) + def save_measurements_selections( - selected_measurements_filename, selected_measurements_dict - ): - os.makedirs( - saved_measurements_selections_folderpath, exist_ok=True - ) - + selected_measurements_filename, selected_measurements_dict +): + os.makedirs(saved_measurements_selections_folderpath, exist_ok=True) + configPars = config.ConfigParser() for section, values in selected_measurements_dict.items(): configPars[section] = {} for option, value in values.items(): configPars[section][option] = str(value) - + ini_filepath = os.path.join( saved_measurements_selections_folderpath, selected_measurements_filename ) - with open(ini_filepath, 'w') as configfile: + with open(ini_filepath, "w") as configfile: configPars.write(configfile) - + return ini_filepath + def read_measurements_selections(selected_measurements_filename): ini_filepath = os.path.join( saved_measurements_selections_folderpath, selected_measurements_filename ) - + cp = config.ConfigParser() cp.read(ini_filepath) - + return dict(cp) + def get_saved_moth_bud_tot_selections(): if not os.path.exists(moth_bud_tot_selected_columns_filepath): return {} - + with open(moth_bud_tot_selected_columns_filepath) as file: json_data = json.load(file) - + return json_data + def save_moth_bud_tot_selected_options(selected_options): - with open(moth_bud_tot_selected_columns_filepath, mode='w') as file: + with open(moth_bud_tot_selected_columns_filepath, mode="w") as file: json.dump(selected_options, file, indent=2) + def get_filepath_from_channel_name(images_path, channel_name): - h5_aligned_path = '' - h5_path = '' - npz_aligned_path = '' - img_path = '' - is_segm_ch = channel_name.find('segm') != -1 - segm_npy_path = '' - segm_npz_path = '' + h5_aligned_path = "" + h5_path = "" + npz_aligned_path = "" + img_path = "" + is_segm_ch = channel_name.find("segm") != -1 + segm_npy_path = "" + segm_npz_path = "" for file in path.listdir(images_path): filepath = os.path.join(images_path, file) if file.endswith(channel_name): return filepath - is_segm_npz_file = is_segm_ch and file.endswith(f'{channel_name}.npz') - is_segm_npy_file = is_segm_ch and file.endswith(f'{channel_name}.npy') + is_segm_npz_file = is_segm_ch and file.endswith(f"{channel_name}.npz") + is_segm_npy_file = is_segm_ch and file.endswith(f"{channel_name}.npy") if is_segm_npz_file: segm_npz_path = filepath if is_segm_npy_file: segm_npy_path = filepath - if file.endswith(f'{channel_name}_aligned.h5'): + if file.endswith(f"{channel_name}_aligned.h5"): h5_aligned_path = filepath - elif file.endswith(f'{channel_name}.h5'): + elif file.endswith(f"{channel_name}.h5"): h5_path = filepath - elif file.endswith(f'{channel_name}_aligned.npz'): + elif file.endswith(f"{channel_name}_aligned.npz"): npz_aligned_path = filepath - elif ( - file.endswith(f'{channel_name}.tif') - or file.endswith(f'{channel_name}.npz') - ): + elif file.endswith(f"{channel_name}.tif") or file.endswith( + f"{channel_name}.npz" + ): img_path = filepath - + if segm_npz_path: return segm_npz_path elif segm_npy_path: @@ -112,83 +115,85 @@ def get_filepath_from_channel_name(images_path, channel_name): elif img_path: return img_path else: - return '' + return "" + def _validate_filename(filename: str, is_path=False): if is_path: - pattern = r'[A-Za-z0-9_\\\/\:\.\-]+' + pattern = r"[A-Za-z0-9_\\\/\:\.\-]+" else: - pattern = r'[A-Za-z0-9_\.\-]+' + pattern = r"[A-Za-z0-9_\.\-]+" m = list(re.finditer(pattern, filename)) invalid_matches = [] for i, valid_chars in enumerate(m): start_idx, stop_idx = valid_chars.span() - if i == len(m)-1: + if i == len(m) - 1: invalid_chars = filename[stop_idx:] else: - next_valid_chars = m[i+1] + next_valid_chars = m[i + 1] start_next_idx = next_valid_chars.span()[0] invalid_chars = filename[stop_idx:start_next_idx] if invalid_chars: invalid_matches.append(invalid_chars) return set(invalid_matches) + def get_filename_cli( - question='Insert a filename', logger_func=print, check_exists=False, - is_path=False - ): + question="Insert a filename", logger_func=print, check_exists=False, is_path=False +): while True: filename = input(f'{question} (type "q" to cancel): ') - if filename.lower() == 'q': + if filename.lower() == "q": return - + if not is_path: invalid = _validate_filename(filename, is_path=is_path) if invalid: logger_func( - f'[ERROR]: The filename contains invalid charachters: {invalid}' - 'Valid charachters are letters, numbers, underscore, full stop, and hyphen.\n' + f"[ERROR]: The filename contains invalid charachters: {invalid}" + "Valid charachters are letters, numbers, underscore, full stop, and hyphen.\n" ) continue if check_exists and not os.path.exists(filename): - logger_func( - f'[ERROR] The provided path "{filename}" does not exist.' - ) + logger_func(f'[ERROR] The provided path "{filename}" does not exist.') continue return filename + def save_image_data(filepath, img_data): - if filepath.endswith('.h5'): + if filepath.endswith(".h5"): load.save_to_h5(filepath, img_data) - elif filepath.endswith('.npz'): + elif filepath.endswith(".npz"): savez_compressed(filepath, img_data) - elif filepath.endswith('.npy'): + elif filepath.endswith(".npy"): np.save(filepath, img_data) else: - myutils.to_tiff(filepath, img_data) + utils.to_tiff(filepath, img_data) return np.squeeze(img_data) + def savez_compressed(filepath, *args, safe=True, **kwargs): if not safe: np.savez_compressed(filepath, *args, **kwargs) - return - + return + if not os.path.exists(filepath): np.savez_compressed(filepath, *args, **kwargs) return - + try: pathlib.Path(filepath).unlink() - temp_filepath = filepath.replace('.npz', '.new.npz') + temp_filepath = filepath.replace(".npz", ".new.npz") np.savez_compressed(temp_filepath, *args, **kwargs) os.replace(temp_filepath, filepath) except PermissionError as err: np.savez_compressed(filepath, *args, **kwargs) -def rename_files_replace_invalid_chars(files, src_path, replacement_char='_'): + +def rename_files_replace_invalid_chars(files, src_path, replacement_char="_"): renamed_files = [] for file in files: invalid_chars = _validate_filename(file, is_path=False) @@ -202,17 +207,18 @@ def rename_files_replace_invalid_chars(files, src_path, replacement_char='_'): renamed_files.append(new_file) return renamed_files + def move_separate_channels_tiffs_to_pos_folders( - tiffs_folderpath: os.PathLike, - channel_names: Sequence[str], - get_only_basenames=False, - extension='.tif' - ): + tiffs_folderpath: os.PathLike, + channel_names: Sequence[str], + get_only_basenames=False, + extension=".tif", +): basenames = set() - for file in myutils.listdir(tiffs_folderpath): + for file in utils.listdir(tiffs_folderpath): if not file.endswith(extension): continue - + filename_no_ext = os.path.splitext(file)[0] for channel in channel_names: splits = filename_no_ext.split(channel) @@ -220,33 +226,33 @@ def move_separate_channels_tiffs_to_pos_folders( basename = splits[0] basenames.add(basename) break - - basenames = natsorted(basenames) - + + basenames = natsorted(basenames) + if get_only_basenames: return basenames - + for p, basename in enumerate(basenames): - pos_folderpath = os.path.join(tiffs_folderpath, f'Position_{p+1}') - images_path = os.path.join(pos_folderpath, 'Images') - + pos_folderpath = os.path.join(tiffs_folderpath, f"Position_{p + 1}") + images_path = os.path.join(pos_folderpath, "Images") + os.makedirs(images_path, exist_ok=True) - for file in myutils.listdir(tiffs_folderpath): + for file in utils.listdir(tiffs_folderpath): if not file.startswith(basename): continue - + src_filepath = os.path.join(tiffs_folderpath, file) - if file.endswith('.tif'): + if file.endswith(".tif"): dst_filepath = os.path.join(images_path, file) shutil.move(src_filepath, dst_filepath) - elif file.endswith('_metadata.csv'): - dst_filename = f'{basename}metadata.csv' + elif file.endswith("_metadata.csv"): + dst_filename = f"{basename}metadata.csv" dst_filepath = os.path.join(images_path, dst_filename) - df_metadata = pd.read_csv(src_filepath, index_col='Description') - df_metadata.at['basename', 'values'] = basename + df_metadata = pd.read_csv(src_filepath, index_col="Description") + df_metadata.at["basename", "values"] = basename df_metadata.to_csv(dst_filepath) try: os.remove(src_filepath) except Exception as err: pass - return True \ No newline at end of file + return True diff --git a/cellacdc/load.py b/cellacdc/load.py index 62a336464..392de1b03 100755 --- a/cellacdc/load.py +++ b/cellacdc/load.py @@ -23,13 +23,14 @@ import skimage import skimage.io -import skimage.measure - +import skimage.measure + import warnings -warnings.simplefilter(action='ignore', category=FutureWarning) + +warnings.simplefilter(action="ignore", category=FutureWarning) from . import prompts -from . import myutils, measurements, config +from . import utils, measurements, config from . import base_cca_dict, base_acdc_df, html_utils, settings_folderpath from . import cca_df_colnames, printl from . import ignore_exception, cellacdc_path @@ -46,11 +47,10 @@ if GUI_INSTALLED: from qtpy import QtGui from qtpy.QtCore import Qt, QRect, QRectF - from qtpy.QtWidgets import ( - QApplication, QMessageBox, QFileDialog - ) + from qtpy.QtWidgets import QApplication, QMessageBox, QFileDialog import pyqtgraph as pg - pg.setConfigOption('imageAxisOrder', 'row-major') + + pg.setConfigOption("imageAxisOrder", "row-major") from . import apps from . import widgets from . import qrc_resources_path, qrc_resources_light_path @@ -58,90 +58,97 @@ from . import whitelist acdc_df_bool_cols = [ - 'is_cell_dead', - 'is_cell_excluded', - 'is_history_known', + "is_cell_dead", + "is_cell_excluded", + "is_history_known", ] -acdc_df_str_cols = {'cell_cycle_stage': str, 'relationship': str} +acdc_df_str_cols = {"cell_cycle_stage": str, "relationship": str} acdc_df_int_cols = { - 'frame_i': int, - 'Cell_ID': int, - 'generation_num': int, - 'emerg_frame_i': int, - 'division_frame_i': int, - 'generation_num_tree': int, - 'parent_ID_tree': int, - 'root_ID_tree': int, - 'sister_ID_tree': int, - 'num_objects': int, + "frame_i": int, + "Cell_ID": int, + "generation_num": int, + "emerg_frame_i": int, + "division_frame_i": int, + "generation_num_tree": int, + "parent_ID_tree": int, + "root_ID_tree": int, + "sister_ID_tree": int, + "num_objects": int, } acdc_df_dtype_id_checker_mapper = { - 'float': pd.api.types.is_float_dtype, - 'string': pd.api.types.is_string_dtype, - 'object': pd.api.types.is_object_dtype, - 'bool': pd.api.types.is_bool_dtype, + "float": pd.api.types.is_float_dtype, + "string": pd.api.types.is_string_dtype, + "object": pd.api.types.is_object_dtype, + "bool": pd.api.types.is_bool_dtype, } -additional_metadata_path = os.path.join(settings_folderpath, 'additional_metadata.json') -last_entries_metadata_path = os.path.join(settings_folderpath, 'last_entries_metadata.csv') -last_selected_measurements_ini_path = os.path.join( - settings_folderpath, 'last_selected_measurements.ini' +additional_metadata_path = os.path.join(settings_folderpath, "additional_metadata.json") +last_entries_metadata_path = os.path.join( + settings_folderpath, "last_entries_metadata.csv" ) -channel_file_formats = ( - '_aligned.h5', '.h5', '_aligned.npz', '.tif' +last_selected_measurements_ini_path = os.path.join( + settings_folderpath, "last_selected_measurements.ini" ) -ISO_TIMESTAMP_FORMAT = r'iso%Y%m%d%H%M%S' +channel_file_formats = ("_aligned.h5", ".h5", "_aligned.npz", ".tif") +ISO_TIMESTAMP_FORMAT = r"iso%Y%m%d%H%M%S" + class FileNameError(Exception): pass + def _pd_cast_float_and_bool_to_int(df, col, _): - df[col] = df[col].astype("Int64") # preserves NA values + df[col] = df[col].astype("Int64") # preserves NA values return df + def _pd_cast_string_to_int(df, col, not_nan_mask): - df[col] = (df[col].astype(str).str.lower() == 'true').astype("Int64") + df[col] = (df[col].astype(str).str.lower() == "true").astype("Int64") df.loc[~not_nan_mask, col] = pd.NA return df + acdc_df_dtype_id_func_mapper = { - 'float': _pd_cast_float_and_bool_to_int, - 'string': _pd_cast_string_to_int, - 'object': _pd_cast_string_to_int, - 'bool': _pd_cast_float_and_bool_to_int, + "float": _pd_cast_float_and_bool_to_int, + "string": _pd_cast_string_to_int, + "object": _pd_cast_string_to_int, + "bool": _pd_cast_float_and_bool_to_int, } -def read_json(json_path, logger_func=print, desc='custom annotations'): + +def read_json(json_path, logger_func=print, desc="custom annotations"): json_data = {} try: with open(json_path) as file: json_data = json.load(file) except Exception as e: - print('****************************') + print("****************************") logger_func(traceback.format_exc()) - print('****************************') - logger_func(f'json path: {json_path}') - print('----------------------------') + print("****************************") + logger_func(f"json path: {json_path}") + print("----------------------------") logger_func(f'Error while reading saved "{desc}". See above') - print('============================') + print("============================") return json_data + def remove_duplicates_file(filepath): if not os.path.exists(filepath): return - with open(filepath, 'r') as file: + with open(filepath, "r") as file: first_line = file.readline() rest_of_text = file.read() duplicate_first_line_idx = rest_of_text.find(first_line) if duplicate_first_line_idx == -1: return - unique_text = f'{first_line}{rest_of_text[:duplicate_first_line_idx]}' - with open(filepath, 'w') as file: + unique_text = f"{first_line}{rest_of_text[:duplicate_first_line_idx]}" + with open(filepath, "w") as file: file.write(unique_text) + def to_csv_through_temp(df, csv_path): filename = os.path.basename(csv_path) with tempfile.TemporaryDirectory() as temp_dir: @@ -149,15 +156,16 @@ def to_csv_through_temp(df, csv_path): df.to_csv(tmp_filepath) shutil.copy2(tmp_filepath, csv_path) + def get_all_acdc_folders(user_profile_path): - models = myutils.get_list_of_models() - acdc_folders = [f'acdc-{model}' for model in models] - acdc_folders.append('acdc-java') - acdc_folders.append('.acdc-logs') - acdc_folders.append('.acdc-settings') - acdc_folders.append('acdc-manual') - acdc_folders.append('acdc-metrics') - acdc_folders.append('acdc-examples') + models = utils.get_list_of_models() + acdc_folders = [f"acdc-{model}" for model in models] + acdc_folders.append("acdc-java") + acdc_folders.append(".acdc-logs") + acdc_folders.append(".acdc-settings") + acdc_folders.append("acdc-manual") + acdc_folders.append("acdc-metrics") + acdc_folders.append("acdc-examples") existing_acdc_folders = [] for file in os.listdir(user_profile_path): filepath = os.path.join(user_profile_path, file) @@ -168,72 +176,76 @@ def get_all_acdc_folders(user_profile_path): existing_acdc_folders.append(file) return existing_acdc_folders + def write_json(json_data, json_path, indent=2): - with open(json_path, mode='w') as file: + with open(json_path, mode="w") as file: json.dump(json_data, file, indent=indent) + def read_last_selected_set_measurements(logger_func=print): if not os.path.exists(last_selected_measurements_ini_path): return {} - + cp = config.ConfigParser() cp.read(last_selected_measurements_ini_path) - + return cp + def write_last_selected_set_measurements(last_selected_meas: dict[str, dict]): configPars = config.ConfigParser() for section, values in last_selected_meas.items(): configPars[section] = {} for option, value in values.items(): configPars[section][option] = str(value) - - with open(last_selected_measurements_ini_path, 'w') as configfile: + + with open(last_selected_measurements_ini_path, "w") as configfile: configPars.write(configfile) + def migrate_models_paths(dst_path): - models = myutils.get_list_of_models() - user_profile_path = dst_path.replace('\\', '/') + models = utils.get_list_of_models() + user_profile_path = dst_path.replace("\\", "/") for model in models: - model_path = os.path.join(models_path, model, 'model') - weight_location_txt_path = os.path.join( - model_path, 'weights_location_path.txt' - ) + model_path = os.path.join(models_path, model, "model") + weight_location_txt_path = os.path.join(model_path, "weights_location_path.txt") if not os.path.exists(weight_location_txt_path): continue - with open(weight_location_txt_path, 'r') as txt: + with open(weight_location_txt_path, "r") as txt: model_location = os.path.expanduser(txt.read()) - model_location = model_location.replace('\\', '/') + model_location = model_location.replace("\\", "/") model_folder = os.path.basename(model_location) model_location = os.path.join(user_profile_path, model_folder) - model_location = model_location.replace('\\', '/') - with open(weight_location_txt_path, 'w') as txt: + model_location = model_location.replace("\\", "/") + with open(weight_location_txt_path, "w") as txt: txt.write(model_location) + def save_workflow_to_config( - filepath, - ini_items: dict, - paths: list[str], - stop_frame_nums: list[int], - type='segment' - ): - paths = [path.replace('\\', '/') for path in paths] - paths_param = '\n'.join(paths) - paths_param = f'\n{paths_param}' + filepath, + ini_items: dict, + paths: list[str], + stop_frame_nums: list[int], + type="segment", +): + paths = [path.replace("\\", "/") for path in paths] + paths_param = "\n".join(paths) + paths_param = f"\n{paths_param}" configPars = config.ConfigParser() - configPars['paths_info'] = {'paths': paths_param} - - stop_frames_param = '\n'.join([str(n) for n in stop_frame_nums]) - stop_frames_param = f'\n{stop_frames_param}' - configPars['paths_info']['stop_frame_numbers'] = stop_frames_param - + configPars["paths_info"] = {"paths": paths_param} + + stop_frames_param = "\n".join([str(n) for n in stop_frame_nums]) + stop_frames_param = f"\n{stop_frames_param}" + configPars["paths_info"]["stop_frame_numbers"] = stop_frames_param + for section, options in ini_items.items(): configPars[section] = {} for option, value in options.items(): configPars[section][option] = str(value) - with open(filepath, 'w') as configfile: + with open(filepath, "w") as configfile: configPars.write(configfile) + def read_segm_workflow_from_config(filepath) -> dict: configPars = config.ConfigParser() configPars.read(filepath) @@ -242,20 +254,20 @@ def read_segm_workflow_from_config(filepath) -> dict: options = dict(configPars[section]) ini_items[section] = {} for option, value in options.items(): - if section == 'paths_info' or section == 'paths_to_segment': - value_list = value.strip('\n').strip().split('\n') - if option == 'paths': + if section == "paths_info" or section == "paths_to_segment": + value_list = value.strip("\n").strip().split("\n") + if option == "paths": abs_paths = [] folderpath = os.path.dirname(filepath) for path in value_list: if os.path.exists(path): abs_paths.append(path) continue - - abs_path = f'{folderpath}{os.sep}{path}' + + abs_path = f"{folderpath}{os.sep}{path}" if not os.path.exists(abs_path): raise FileNotFoundError( - 'The following path to analyse does not exist:' + "The following path to analyse does not exist:" f'\n\n"{path}"\n' ) @@ -265,322 +277,340 @@ def read_segm_workflow_from_config(filepath) -> dict: else: ini_items[section][option] = value_list continue - if value == 'False': + if value == "False": value = False - elif value == 'True': + elif value == "True": value = True - elif value == 'None': + elif value == "None": value = None - elif option == 'SizeT' or option == 'SizeZ': + elif option == "SizeT" or option == "SizeZ": value = int(value) - - if section == 'standard_postprocess_features' and value is not None: + + if section == "standard_postprocess_features" and value is not None: for _type in (int, float, str): try: value = _type(value) break except Exception as e: continue - - elif section == 'custom_postprocess_features': - low, high = value.strip().strip('(').strip(')').split(',') - if low.strip().lower() == 'none': + + elif section == "custom_postprocess_features": + low, high = value.strip().strip("(").strip(")").split(",") + if low.strip().lower() == "none": low = None else: low = float(low) - if high.strip().lower() == 'none': + if high.strip().lower() == "none": high = None else: high = float(high) value = (low, high) - + ini_items[section][option] = value return ini_items + def get_images_paths(folder_path): - folder_type = myutils.determine_folder_type(folder_path) - is_pos_folder, is_images_folder, folder_path = folder_type + folder_type = utils.determine_folder_type(folder_path) + is_pos_folder, is_images_folder, folder_path = folder_type if not is_pos_folder and not is_images_folder: - pos_foldernames = myutils.get_pos_foldernames(folder_path) + pos_foldernames = utils.get_pos_foldernames(folder_path) images_paths = [ - os.path.join(folder_path, pos, 'Images') for pos in pos_foldernames + os.path.join(folder_path, pos, "Images") for pos in pos_foldernames ] elif is_pos_folder: - images_paths = [os.path.join(folder_path, 'Images')] + images_paths = [os.path.join(folder_path, "Images")] elif is_images_folder: images_paths = [folder_path] return images_paths + def read_config_metrics(ini_path): configPars = config.ConfigParser() configPars.read(ini_path) - if 'equations' not in configPars: - configPars['equations'] = {} + if "equations" not in configPars: + configPars["equations"] = {} + + if "mixed_channels_equations" not in configPars: + configPars["mixed_channels_equations"] = {} - if 'mixed_channels_equations' not in configPars: - configPars['mixed_channels_equations'] = {} + if "user_path_equations" not in configPars: + configPars["user_path_equations"] = {} - if 'user_path_equations' not in configPars: - configPars['user_path_equations'] = {} - return configPars + def add_configPars_metrics(configPars_ref, configPars2_to_add): - configPars_ref['equations'] = { - **configPars2_to_add['equations'], **configPars_ref['equations'] + configPars_ref["equations"] = { + **configPars2_to_add["equations"], + **configPars_ref["equations"], } - configPars_ref['mixed_channels_equations'] = { - **configPars2_to_add['mixed_channels_equations'], - **configPars_ref['mixed_channels_equations'] + configPars_ref["mixed_channels_equations"] = { + **configPars2_to_add["mixed_channels_equations"], + **configPars_ref["mixed_channels_equations"], } - configPars_ref['user_path_equations'] = { - **configPars2_to_add['user_path_equations'], - **configPars_ref['user_path_equations'] + configPars_ref["user_path_equations"] = { + **configPars2_to_add["user_path_equations"], + **configPars_ref["user_path_equations"], } keep_user_path_equations = { - key:val for key, val in configPars_ref['user_path_equations'].items() - if key not in configPars_ref['equations'] - } - configPars_ref['user_path_equations'] = keep_user_path_equations + key: val + for key, val in configPars_ref["user_path_equations"].items() + if key not in configPars_ref["equations"] + } + configPars_ref["user_path_equations"] = keep_user_path_equations return configPars_ref -def h5py_iter(g, prefix=''): + +def h5py_iter(g, prefix=""): for key, item in g.items(): path = os.path.join(prefix, key) - if isinstance(item, h5py.Dataset): # test for dataset + if isinstance(item, h5py.Dataset): # test for dataset yield (path, item) - elif isinstance(item, h5py.Group): # test for group (go down) + elif isinstance(item, h5py.Group): # test for group (go down) yield from h5py_iter(item, path) + def h5dump_to_arr(h5path): data_dict = {} - with h5py.File(h5path, 'r') as f: - for (path, dset) in h5py_iter(f): + with h5py.File(h5path, "r") as f: + for path, dset in h5py_iter(f): data_dict[dset.name] = dset[()] sorted_keys = natsorted(data_dict.keys()) arr = np.array([data_dict[key] for key in sorted_keys]) return arr + def save_to_h5(dst_filepath, data): filename = os.path.basename(dst_filepath) tempDir = tempfile.mkdtemp() tempFilepath = os.path.join(tempDir, filename) - chunks = [1]*data.ndim + chunks = [1] * data.ndim chunks[-2:] = data.shape[-2:] - with h5py.File(tempFilepath, 'w') as h5f: + with h5py.File(tempFilepath, "w") as h5f: dataset = h5f.create_dataset( - 'data', data.shape, dtype=data.dtype, - chunks=chunks, shuffle=False + "data", data.shape, dtype=data.dtype, chunks=chunks, shuffle=False ) dataset[:] = data shutil.move(tempFilepath, dst_filepath) shutil.rmtree(tempDir) -def load_segm_file(images_path, end_name_segm_file='segm', return_path=False): - if not end_name_segm_file.endswith('.npz'): - end_name_segm_file = f'{end_name_segm_file}.npz' - + +def load_segm_file(images_path, end_name_segm_file="segm", return_path=False): + if not end_name_segm_file.endswith(".npz"): + end_name_segm_file = f"{end_name_segm_file}.npz" + found_files = [ - file for file in myutils.listdir(images_path) + file + for file in utils.listdir(images_path) if file.endswith(end_name_segm_file) ] try: if len(found_files) == 0: segm_data = None - segm_filepath = '' + segm_filepath = "" elif len(found_files) == 1: segm_filepath = os.path.join(images_path, found_files[0]) - segm_data = np.load(segm_filepath)['arr_0'].astype(np.uint32) + segm_data = np.load(segm_filepath)["arr_0"].astype(np.uint32) else: found_files.sort(key=len) segm_filepath = os.path.join(images_path, found_files[0]) - segm_data = np.load(segm_filepath)['arr_0'].astype(np.uint32) + segm_data = np.load(segm_filepath)["arr_0"].astype(np.uint32) except OSError as e: - if str(e).find("[Errno 22] Invalid argument") != -1 and segm_filepath.find("OneDrive") != -1: + if ( + str(e).find("[Errno 22] Invalid argument") != -1 + and segm_filepath.find("OneDrive") != -1 + ): print(traceback.print_exc()) - raise OSError("If the file is online only, and syncing is disabled, this file cannot be accessed.") + raise OSError( + "If the file is online only, and syncing is disabled, this file cannot be accessed." + ) else: raise e - + if return_path: return segm_data, segm_filepath else: return segm_data + def get_tzyx_shape(images_path): df_metadata = load_metadata_df(images_path) - channel = df_metadata.at['channel_0_name', 'values'] + channel = df_metadata.at["channel_0_name", "values"] img_filepath = get_filename_from_channel(images_path, channel) img_data = load_image_file(img_filepath) if img_data.ndim == 4: return img_data.shape - - SizeZ = int(df_metadata.at['SizeZ', 'values']) - SizeT = int(df_metadata.at['SizeT', 'values']) + + SizeZ = int(df_metadata.at["SizeZ", "values"]) + SizeT = int(df_metadata.at["SizeT", "values"]) YX = img_data.shape[-2:] return (SizeT, SizeZ, *YX) - + + def load_metadata_df(images_path): - for file in myutils.listdir(images_path): - if not file.endswith('metadata.csv'): + for file in utils.listdir(images_path): + if not file.endswith("metadata.csv"): continue filepath = os.path.join(images_path, file) parse_metadata_csv_file(filepath) - return pd.read_csv(filepath).set_index('Description') + return pd.read_csv(filepath).set_index("Description") + def _add_will_divide_column(acdc_df): - if 'cell_cycle_stage' not in acdc_df.columns: + if "cell_cycle_stage" not in acdc_df.columns: return acdc_df - if 'will_divide' in acdc_df.columns: + if "will_divide" in acdc_df.columns: return acdc_df - acdc_df['will_divide'] = np.nan - last_index_cca_df = acdc_df[['cell_cycle_stage']].last_valid_index() + acdc_df["will_divide"] = np.nan + last_index_cca_df = acdc_df[["cell_cycle_stage"]].last_valid_index() cca_df = acdc_df.loc[:last_index_cca_df, cca_df_colnames].reset_index() - cca_df['will_divide'] = 0.0 + cca_df["will_divide"] = 0.0 cca_df_buds = cca_df.query('relationship == "bud"') - for budID, bud_cca_df in cca_df_buds.groupby('Cell_ID'): - all_gen_nums = cca_df.query(f'Cell_ID == {budID}')['generation_num'] + for budID, bud_cca_df in cca_df_buds.groupby("Cell_ID"): + all_gen_nums = cca_df.query(f"Cell_ID == {budID}")["generation_num"] if not (all_gen_nums > 0).any(): # bud division is annotated in the future - continue + continue - cca_df.loc[bud_cca_df.index, 'will_divide'] = 1 - - mothID = int(bud_cca_df['relative_ID'].iloc[0]) - first_frame_bud = bud_cca_df['frame_i'].iloc[0] + cca_df.loc[bud_cca_df.index, "will_divide"] = 1 + + mothID = int(bud_cca_df["relative_ID"].iloc[0]) + first_frame_bud = bud_cca_df["frame_i"].iloc[0] gen_num_moth = cca_df.query( - f'(frame_i == {first_frame_bud}) & (Cell_ID == {mothID})' - )['generation_num'].iloc[0] - - mothMask = ( - (cca_df['Cell_ID'] == mothID) - & (cca_df['generation_num'] == gen_num_moth) + f"(frame_i == {first_frame_bud}) & (Cell_ID == {mothID})" + )["generation_num"].iloc[0] + + mothMask = (cca_df["Cell_ID"] == mothID) & ( + cca_df["generation_num"] == gen_num_moth ) - cca_df.loc[mothMask, 'will_divide'] = 1 - - cca_df = cca_df.set_index(['frame_i', 'Cell_ID']) + cca_df.loc[mothMask, "will_divide"] = 1 + + cca_df = cca_df.set_index(["frame_i", "Cell_ID"]) acdc_df.loc[cca_df.index, cca_df.columns] = cca_df return acdc_df + def _fix_corrected_assignment_i(acdc_df: pd.DataFrame): - """Replaces the column 'corrected_assignment' with the newer + """Replaces the column 'corrected_assignment' with the newer 'corrected_on_frame_i' Parameters ---------- acdc_df : pd.DataFrame - Annotations and metrics dataframe (from the `acdc_output` CSV file) + Annotations and metrics dataframe (from the `acdc_output` CSV file) with ['frame_i', 'Cell_ID'] as index Returns ------- pd.DataFrame - acdc_df with correct `corrected_on_frame_i` and `corrected_assignment` + acdc_df with correct `corrected_on_frame_i` and `corrected_assignment` removed. - """ - - if 'corrected_assignment' not in acdc_df.columns: + """ + + if "corrected_assignment" not in acdc_df.columns: return acdc_df - - if 'corrected_on_frame_i' in acdc_df.columns: - if (acdc_df['corrected_on_frame_i'] > -1).any(): - acdc_df = acdc_df.drop( - columns='corrected_assignment', errors='ignore' - ) + + if "corrected_on_frame_i" in acdc_df.columns: + if (acdc_df["corrected_on_frame_i"] > -1).any(): + acdc_df = acdc_df.drop(columns="corrected_assignment", errors="ignore") return acdc_df - + for ID, df in acdc_df.groupby(level=1): # df = df[['corrected_assignment']].sort_index() - df['block'] = ( - df['corrected_assignment'].shift(1) != df['corrected_assignment'] - ).astype(int).cumsum() - df = df[df['corrected_assignment']>0] - for block, df_block in df.reset_index().groupby('block'): - corr_on_frame_i = df_block['frame_i'].min() - df_block = df_block.set_index(['frame_i', 'Cell_ID']) + df["block"] = ( + (df["corrected_assignment"].shift(1) != df["corrected_assignment"]) + .astype(int) + .cumsum() + ) + df = df[df["corrected_assignment"] > 0] + for block, df_block in df.reset_index().groupby("block"): + corr_on_frame_i = df_block["frame_i"].min() + df_block = df_block.set_index(["frame_i", "Cell_ID"]) corr_on_index = df_block.index - acdc_df.loc[corr_on_index, 'corrected_on_frame_i'] = corr_on_frame_i - + acdc_df.loc[corr_on_index, "corrected_on_frame_i"] = corr_on_frame_i + # acdc_df['corrected_on_frame_i'] = acdc_df['corrected_on_frame_i'].astype(int) - acdc_df = acdc_df.drop(columns='corrected_assignment') - + acdc_df = acdc_df.drop(columns="corrected_assignment") + return acdc_df + def _fix_will_divide(acdc_df): - """Resetting annotaions in GUI sometimes does not fully reset `will_divide` - column. Here we set `will_divide` back to 0 for those cells whose + """Resetting annotaions in GUI sometimes does not fully reset `will_divide` + column. Here we set `will_divide` back to 0 for those cells whose next generation does not exist (division was not annotated) Parameters ---------- acdc_df : pd.DataFrame - Annotations and metrics dataframe (from the `acdc_output` CSV file) + Annotations and metrics dataframe (from the `acdc_output` CSV file) with ['frame_i', 'Cell_ID'] as index Returns ------- pd.DataFrame acdc_df with `will_divide` corrected. - """ - if 'cell_cycle_stage' not in acdc_df.columns: + """ + if "cell_cycle_stage" not in acdc_df.columns: return acdc_df - - required_cols = ['frame_i', 'Cell_ID', 'generation_num', 'will_divide'] - - cca_df_mask = ~acdc_df['cell_cycle_stage'].isna() + + required_cols = ["frame_i", "Cell_ID", "generation_num", "will_divide"] + + cca_df_mask = ~acdc_df["cell_cycle_stage"].isna() cca_df = acdc_df[cca_df_mask].reset_index()[required_cols] - - IDs_will_divide_wrong = ( - cca_functions.get_IDs_gen_num_will_divide_wrong(cca_df) - ) + + IDs_will_divide_wrong = cca_functions.get_IDs_gen_num_will_divide_wrong(cca_df) if not IDs_will_divide_wrong: return acdc_df - - cca_df = cca_df.reset_index().set_index(['Cell_ID', 'generation_num']) - cca_df.loc[IDs_will_divide_wrong, 'will_divide'] = 0 + + cca_df = cca_df.reset_index().set_index(["Cell_ID", "generation_num"]) + cca_df.loc[IDs_will_divide_wrong, "will_divide"] = 0 cca_df = cca_df.reset_index() acdc_df = acdc_df.reset_index() - cca_df = cca_df.set_index(['frame_i', 'Cell_ID']) - acdc_df = acdc_df.set_index(['frame_i', 'Cell_ID']) - + cca_df = cca_df.set_index(["frame_i", "Cell_ID"]) + acdc_df = acdc_df.set_index(["frame_i", "Cell_ID"]) + cca_df_index = cca_df_mask[cca_df_mask].index - acdc_df.loc[cca_df_index, 'will_divide'] = cca_df['will_divide'] - + acdc_df.loc[cca_df_index, "will_divide"] = cca_df["will_divide"] + return acdc_df + def _add_missing_columns(acdc_df): - if 'is_cell_excluded' not in acdc_df.columns: - acdc_df['is_cell_excluded'] = 0 - - if 'is_cell_dead' not in acdc_df.columns: - acdc_df['is_cell_dead'] = 0 - - if 'cell_cycle_stage' not in acdc_df.columns: + if "is_cell_excluded" not in acdc_df.columns: + acdc_df["is_cell_excluded"] = 0 + + if "is_cell_dead" not in acdc_df.columns: + acdc_df["is_cell_dead"] = 0 + + if "cell_cycle_stage" not in acdc_df.columns: return acdc_df - - last_index_cca_df = acdc_df[['cell_cycle_stage']].last_valid_index() - + + last_index_cca_df = acdc_df[["cell_cycle_stage"]].last_valid_index() + for col, default in base_cca_dict.items(): - if col == 'will_divide': + if col == "will_divide": # Already taken care by _add_will_divide_column continue - + if col in acdc_df.columns: continue - + acdc_df[col] = np.nan acdc_df.loc[:last_index_cca_df, col] = default - + return acdc_df + def _ensure_acdc_df_latest_compatibility(acdc_df): acdc_df = _parse_loaded_acdc_df(acdc_df) acdc_df = _add_missing_columns(acdc_df) @@ -589,10 +619,11 @@ def _ensure_acdc_df_latest_compatibility(acdc_df): acdc_df = _fix_corrected_assignment_i(acdc_df) return acdc_df + def _parse_loaded_acdc_df(acdc_df): - acdc_df = acdc_df.set_index(['frame_i', 'Cell_ID']).sort_index() + acdc_df = acdc_df.set_index(["frame_i", "Cell_ID"]).sort_index() # remove duplicates saved by mistake or bugs - duplicated = acdc_df.index.duplicated(keep='first') + duplicated = acdc_df.index.duplicated(keep="first") acdc_df = acdc_df[~duplicated] acdc_df = pd_bool_and_float_to_int_to_str( acdc_df, acdc_df_bool_cols, colsToCastInt=[], inplace=True @@ -600,16 +631,17 @@ def _parse_loaded_acdc_df(acdc_df): acdc_df = pd_int_to_bool(acdc_df, acdc_df_bool_cols) return acdc_df + def _remove_redundant_columns(acdc_df): - acdc_df = acdc_df.drop(columns=['index', 'level_0'], errors='ignore') + acdc_df = acdc_df.drop(columns=["index", "level_0"], errors="ignore") return acdc_df + def read_acdc_df_csv(acdc_df_filepath, index_col=None): - acdc_df = pd.read_csv( - acdc_df_filepath, dtype=acdc_df_str_cols, index_col=index_col - ) + acdc_df = pd.read_csv(acdc_df_filepath, dtype=acdc_df_str_cols, index_col=index_col) return acdc_df + def _load_acdc_df_file(acdc_df_file_path): acdc_df = read_acdc_df_csv(acdc_df_file_path) acdc_df = _remove_redundant_columns(acdc_df) @@ -618,25 +650,25 @@ def _load_acdc_df_file(acdc_df_file_path): acdc_df[acdc_df_drop_cca.columns] = acdc_df_drop_cca except KeyError: pass - + acdc_df = _ensure_acdc_df_latest_compatibility(acdc_df) return acdc_df + def load_acdc_df_file( - images_path, - end_name_acdc_df_file='acdc_output', - return_path=False - ): - if not end_name_acdc_df_file.endswith('.csv'): - end_name_acdc_df_file = f'{end_name_acdc_df_file}.csv' - + images_path, end_name_acdc_df_file="acdc_output", return_path=False +): + if not end_name_acdc_df_file.endswith(".csv"): + end_name_acdc_df_file = f"{end_name_acdc_df_file}.csv" + found_files = [ - file for file in myutils.listdir(images_path) + file + for file in utils.listdir(images_path) if file.endswith(end_name_acdc_df_file) ] if len(found_files) == 0: acdc_df = None - acdc_df_file_path = '' + acdc_df_file_path = "" elif len(found_files) == 1: acdc_df_file_path = os.path.join(images_path, found_files[0]) acdc_df = _load_acdc_df_file(acdc_df_file_path).reset_index() @@ -644,52 +676,52 @@ def load_acdc_df_file( found_files.sort(key=len) acdc_df_file_path = os.path.join(images_path, found_files[0]) acdc_df = _load_acdc_df_file(acdc_df_file_path).reset_index() - + if return_path: return acdc_df, acdc_df_file_path else: return acdc_df + def save_acdc_df_file( - acdc_df, csv_path, custom_annot_columns=None, - last_cca_frame_i=None - ): + acdc_df, csv_path, custom_annot_columns=None, last_cca_frame_i=None +): if custom_annot_columns is not None: new_order_cols = [*sorted_cols, *custom_annot_columns] else: new_order_cols = sorted_cols - + for col in new_order_cols.copy(): if col in acdc_df.columns: continue new_order_cols.remove(col) - + for col in acdc_df.columns: if col in new_order_cols: continue new_order_cols.append(col) - + acdc_df = acdc_df[new_order_cols] - + if last_cca_frame_i is not None: - max_frame_i = acdc_df.index.get_level_values('frame_i').max() + max_frame_i = acdc_df.index.get_level_values("frame_i").max() if last_cca_frame_i < max_frame_i: - acdc_df.loc[last_cca_frame_i+1:, cca_df_colnames] = pd.NA - + acdc_df.loc[last_cca_frame_i + 1 :, cca_df_colnames] = pd.NA + try: acdc_df.to_csv(csv_path) except Exception as err: - printl(f'[WARNING]: {err}') + printl(f"[WARNING]: {err}") return + def store_copy_acdc_df(posData, acdc_output_csv_path, log_func=printl): try: if not os.path.exists(acdc_output_csv_path): return - - df = ( - pd.read_csv(acdc_output_csv_path, dtype=acdc_df_str_cols) - .set_index(['frame_i', 'Cell_ID']) + + df = pd.read_csv(acdc_output_csv_path, dtype=acdc_df_str_cols).set_index( + ["frame_i", "Cell_ID"] ) posData.setTempPaths() zip_path = posData.acdc_output_backup_zip_path @@ -697,188 +729,186 @@ def store_copy_acdc_df(posData, acdc_output_csv_path, log_func=printl): except Exception as e: log_func(traceback.format_exc()) + def _copy_acdc_dfs_to_temp_archive( - zip_path, temp_zip_path, csv_names, compression_opts - ): - if not os.path.exists(zip_path): + zip_path, temp_zip_path, csv_names, compression_opts +): + if not os.path.exists(zip_path): return - - with zipfile.ZipFile(zip_path, mode='r') as zip: + + with zipfile.ZipFile(zip_path, mode="r") as zip: for csv_name in csv_names: with warnings.catch_warnings(): - warnings.simplefilter("ignore") - acdc_df = pd.read_csv( - zip.open(csv_name), dtype=acdc_df_str_cols - ) + warnings.simplefilter("ignore") + acdc_df = pd.read_csv(zip.open(csv_name), dtype=acdc_df_str_cols) acdc_df = _ensure_acdc_df_latest_compatibility(acdc_df) acdc_df = pd_bool_and_float_to_int_to_str(acdc_df, inplace=False) - compression_opts['archive_name'] = csv_name - acdc_df.to_csv( - temp_zip_path, compression=compression_opts - ) + compression_opts["archive_name"] = csv_name + acdc_df.to_csv(temp_zip_path, compression=compression_opts) + def _store_acdc_df_archive(zip_path, acdc_df_to_store): csv_names = [] if os.path.exists(zip_path): - with zipfile.ZipFile(zip_path, mode='r') as zip: + with zipfile.ZipFile(zip_path, mode="r") as zip: csv_names = natsorted(set(zip.namelist())) - + new_key = datetime.now().strftime(ISO_TIMESTAMP_FORMAT) - csv_name = f'{new_key}.csv' + csv_name = f"{new_key}.csv" if csv_name in csv_names: # Do not save duplicates within the same second return - + if len(csv_names) > 20: # Delete oldest df and resave remaining 19 csv_names.pop(0) - + zip_filename = os.path.basename(zip_path) - temp_zip_filename = zip_filename.replace('.csv', '_temp.csv') + temp_zip_filename = zip_filename.replace(".csv", "_temp.csv") temp_dirpath = tempfile.mkdtemp() temp_zip_path = os.path.join(temp_dirpath, temp_zip_filename) - compression_opts = {'method': 'zip', 'compresslevel': zipfile.ZIP_STORED} - _copy_acdc_dfs_to_temp_archive( - zip_path, temp_zip_path, csv_names, compression_opts - ) - - - compression_opts['archive_name'] = csv_name + compression_opts = {"method": "zip", "compresslevel": zipfile.ZIP_STORED} + _copy_acdc_dfs_to_temp_archive(zip_path, temp_zip_path, csv_names, compression_opts) + + compression_opts["archive_name"] = csv_name acdc_df = pd_bool_and_float_to_int_to_str(acdc_df_to_store, inplace=False) acdc_df.to_csv(temp_zip_path, compression=compression_opts) shutil.move(temp_zip_path, zip_path) shutil.rmtree(temp_dirpath) + def store_unsaved_acdc_df(recovery_folderpath, df, log_func=printl): new_key = datetime.now().strftime(ISO_TIMESTAMP_FORMAT) - csv_name = f'{new_key}.csv' - unsaved_recovery_folderpath = os.path.join( - recovery_folderpath, 'never_saved' - ) + csv_name = f"{new_key}.csv" + unsaved_recovery_folderpath = os.path.join(recovery_folderpath, "never_saved") if not os.path.exists(unsaved_recovery_folderpath): os.mkdir(unsaved_recovery_folderpath) - - files = myutils.listdir(unsaved_recovery_folderpath) - csv_files = [file for file in files if file.endswith('.csv')] + + files = utils.listdir(unsaved_recovery_folderpath) + csv_files = [file for file in files if file.endswith(".csv")] if len(files) > 20: csv_files = natsorted(csv_files) files_to_remove = csv_files[:-20] for file_to_remove in files_to_remove: os.remove(os.path.join(unsaved_recovery_folderpath, file_to_remove)) - + csv_path = os.path.join(unsaved_recovery_folderpath, csv_name) df.to_csv(csv_path) + def get_last_stored_unsaved_acdc_df_filepath(recovery_folderpath): if not os.path.exists(recovery_folderpath): return - - unsaved_recovery_folderpath = os.path.join( - recovery_folderpath, 'never_saved' - ) + + unsaved_recovery_folderpath = os.path.join(recovery_folderpath, "never_saved") if not os.path.exists(unsaved_recovery_folderpath): return - - files = myutils.listdir(unsaved_recovery_folderpath) - csv_files = [file for file in files if file.endswith('.csv')] + + files = utils.listdir(unsaved_recovery_folderpath) + csv_files = [file for file in files if file.endswith(".csv")] if not csv_files: return - + csv_files = natsorted(csv_files) csv_name = csv_files[-1] - + return os.path.join(unsaved_recovery_folderpath, csv_name) + def get_last_stored_unsaved_acdc_df(recovery_folderpath): if not os.path.exists(recovery_folderpath): return - - unsaved_recovery_folderpath = os.path.join( - recovery_folderpath, 'never_saved' - ) + + unsaved_recovery_folderpath = os.path.join(recovery_folderpath, "never_saved") if not os.path.exists(unsaved_recovery_folderpath): return - - files = myutils.listdir(unsaved_recovery_folderpath) - csv_files = [file for file in files if file.endswith('.csv')] + + files = utils.listdir(unsaved_recovery_folderpath) + csv_files = [file for file in files if file.endswith(".csv")] if not csv_files: return - + csv_files = natsorted(csv_files) csv_name = csv_files[-1] acdc_df = pd.read_csv(os.path.join(unsaved_recovery_folderpath, csv_name)) acdc_df = _ensure_acdc_df_latest_compatibility(acdc_df) - + return acdc_df + def read_acdc_df_from_archive(archive_path, key): - if not key.endswith('.csv'): - csv_name = f'{key}.csv' + if not key.endswith(".csv"): + csv_name = f"{key}.csv" else: csv_name = key - - if archive_path.endswith('.zip'): - with zipfile.ZipFile(archive_path, 'r') as zip: + + if archive_path.endswith(".zip"): + with zipfile.ZipFile(archive_path, "r") as zip: acdc_df = pd.read_csv(zip.open(csv_name)) else: - csv_path = os.path.join(archive_path, f'{key}.csv') + csv_path = os.path.join(archive_path, f"{key}.csv") acdc_df = pd.read_csv(csv_path) - + acdc_df = _ensure_acdc_df_latest_compatibility(acdc_df) return acdc_df + def get_user_ch_paths(images_paths, user_ch_name): user_ch_file_paths = [] for images_path in images_paths: img_aligned_found = False - for filename in myutils.listdir(images_path): - if filename.find(f'{user_ch_name}_aligned.np') != -1: - img_path_aligned = f'{images_path}/{filename}' + for filename in utils.listdir(images_path): + if filename.find(f"{user_ch_name}_aligned.np") != -1: + img_path_aligned = f"{images_path}/{filename}" img_aligned_found = True - elif filename.find(f'{user_ch_name}.tif') != -1: - img_path_tif = f'{images_path}/{filename}' + elif filename.find(f"{user_ch_name}.tif") != -1: + img_path_tif = f"{images_path}/{filename}" if img_aligned_found: img_path = img_path_aligned else: img_path = img_path_tif user_ch_file_paths.append(img_path) - print(f'Loading {img_path}...') + print(f"Loading {img_path}...") return user_ch_file_paths + def get_acdc_output_files(images_path): - ls = myutils.listdir(images_path) + ls = utils.listdir(images_path) acdc_output_files = [ - file for file in ls - if file.find('acdc_output') != -1 and file.endswith('.csv') + file for file in ls if file.find("acdc_output") != -1 and file.endswith(".csv") ] return acdc_output_files + def get_segm_files(images_path): - ls = myutils.listdir(images_path) + ls = utils.listdir(images_path) segm_files = [ - file for file in ls if file.endswith('segm.npz') - or file.find('segm_raw_postproc') != -1 - or file.endswith('segm_raw.npz') - or (file.endswith('.npz') and file.find('segm') != -1) - or file.endswith('_segm.npy') + file + for file in ls + if file.endswith("segm.npz") + or file.find("segm_raw_postproc") != -1 + or file.endswith("segm_raw.npz") + or (file.endswith(".npz") and file.find("segm") != -1) + or file.endswith("_segm.npy") ] - return segm_files + return segm_files + def get_segm_endnames_from_exp_path(exp_path, pos_foldernames=None): if pos_foldernames is None: - pos_foldernames = myutils.get_pos_foldernames(exp_path) - + pos_foldernames = utils.get_pos_foldernames(exp_path) + existingEndNames = set() for p, pos in enumerate(pos_foldernames): pos_path = os.path.join(exp_path, pos) - images_path = os.path.join(pos_path, 'Images') - basename, chNames = myutils.getBasenameAndChNames(images_path) + images_path = os.path.join(pos_path, "Images") + basename, chNames = utils.getBasenameAndChNames(images_path) # Use first found channel, it doesn't matter for metrics for chName in chNames: - filePath = myutils.getChannelFilePath(images_path, chName) + filePath = utils.getChannelFilePath(images_path, chName) if filePath: break else: @@ -889,54 +919,59 @@ def get_segm_endnames_from_exp_path(exp_path, pos_foldernames=None): _posData = loadData(filePath, chName) _posData.getBasenameAndChNames() found_files = get_segm_files(_posData.images_path) - _existingEndnames = get_endnames( - _posData.basename, found_files - ) + _existingEndnames = get_endnames(_posData.basename, found_files) existingEndNames.update(_existingEndnames) - + return existingEndNames -def get_files_with(images_path: os.PathLike, with_text: str, ext: str=None): - ls = myutils.listdir(images_path) + +def get_files_with(images_path: os.PathLike, with_text: str, ext: str = None): + ls = utils.listdir(images_path) found_files = [] for file in ls: if file.find(with_text) == -1: continue - + if ext is not None and not file.endswith(ext): continue - + found_files.append(file) - + return found_files + def load_segmInfo_df(pos_path): - images_path = os.path.join(pos_path, 'Images') - for file in myutils.listdir(images_path): - if file.endswith('segmInfo.csv'): + images_path = os.path.join(pos_path, "Images") + for file in utils.listdir(images_path): + if file.endswith("segmInfo.csv"): csv_path = os.path.join(images_path, file) df = pd.read_csv(csv_path) - df = df.set_index(['filename', 'frame_i']).sort_index() + df = df.set_index(["filename", "frame_i"]).sort_index() df = df[~df.index.duplicated()] return df + def get_filename_from_channel( - images_path, channel_name, not_allowed_ends=None, logger=None, - basename=None, skip_channels=None - ): + images_path, + channel_name, + not_allowed_ends=None, + logger=None, + basename=None, + skip_channels=None, +): if not_allowed_ends is None: not_allowed_ends = tuple() if skip_channels is None: skip_channels = tuple() if basename is None: - basename = '' - - channel_filepath = '' - h5_aligned_path = '' - h5_path = '' - npz_aligned_path = '' - tif_path = '' - for file in myutils.listdir(images_path): + basename = "" + + channel_filepath = "" + h5_aligned_path = "" + h5_path = "" + npz_aligned_path = "" + tif_path = "" + for file in utils.listdir(images_path): isValidEnd = True for not_allowed_end in not_allowed_ends: if file.endswith(not_allowed_end): @@ -944,104 +979,110 @@ def get_filename_from_channel( break if not isValidEnd: continue - + is_channel_to_skip = False for channel_to_skip in skip_channels: for ff in channel_file_formats: - if file.endswith(f'{basename}{channel_to_skip}{ff}'): + if file.endswith(f"{basename}{channel_to_skip}{ff}"): is_channel_to_skip = channel_name not in file break if is_channel_to_skip: break - + if is_channel_to_skip: continue channelDataPath = os.path.join(images_path, file) - if file == f'{basename}{channel_name}': + if file == f"{basename}{channel_name}": channel_filepath = channelDataPath - elif file.endswith(f'{basename}{channel_name}_aligned.h5'): + elif file.endswith(f"{basename}{channel_name}_aligned.h5"): h5_aligned_path = channelDataPath - elif file.endswith(f'{basename}{channel_name}.h5'): + elif file.endswith(f"{basename}{channel_name}.h5"): h5_path = channelDataPath - elif file.endswith(f'{basename}{channel_name}_aligned.npz'): + elif file.endswith(f"{basename}{channel_name}_aligned.npz"): npz_aligned_path = channelDataPath - elif file.endswith(f'{basename}{channel_name}.tif'): + elif file.endswith(f"{basename}{channel_name}.tif"): tif_path = channelDataPath - + if channel_filepath: if logger is not None: - logger(f'Using channel file ({channel_filepath})...') + logger(f"Using channel file ({channel_filepath})...") return channel_filepath elif h5_aligned_path: if logger is not None: - logger(f'Using .h5 aligned file ({h5_aligned_path})...') + logger(f"Using .h5 aligned file ({h5_aligned_path})...") return h5_aligned_path elif h5_path: if logger is not None: - logger(f'Using .h5 file ({h5_path})...') + logger(f"Using .h5 file ({h5_path})...") return h5_path elif npz_aligned_path: if logger is not None: - logger(f'Using .npz aligned file ({npz_aligned_path})...') + logger(f"Using .npz aligned file ({npz_aligned_path})...") return npz_aligned_path elif tif_path: if logger is not None: - logger(f'Using .tif file ({tif_path})...') + logger(f"Using .tif file ({tif_path})...") return tif_path else: - return '' + return "" + def imread(path): - if path.endswith('.tif') or path.endswith('.tiff'): + if path.endswith(".tif") or path.endswith(".tiff"): return tifffile.imread(path) else: return skimage.io.imread(path) + def load_image_file(filepath): - if filepath.endswith('.h5'): - with h5py.File(filepath, 'r') as h5f: - img_data = h5f['data'][()] - elif filepath.endswith('.npz'): + if filepath.endswith(".h5"): + with h5py.File(filepath, "r") as h5f: + img_data = h5f["data"][()] + elif filepath.endswith(".npz"): with np.load(filepath) as archive: files = archive.files img_data = archive[files[0]] - elif filepath.endswith('.npy'): + elif filepath.endswith(".npy"): img_data = np.load(filepath) else: img_data = imread(filepath) return np.squeeze(img_data) + def load_image_data_from_channel(images_path: os.PathLike, channel_name: str): filepath = get_filename_from_channel(images_path, channel_name) return load_image_file(filepath) + def get_endnames(basename, files): endnames = [] for f in files: filename, _ = os.path.splitext(f) - endname = filename[len(basename):] + endname = filename[len(basename) :] endnames.append(endname) return endnames + def get_filepath_from_endname(images_path, endname): channel_filepath = get_filename_from_channel(images_path, endname) if channel_filepath: return channel_filepath - - for file in myutils.listdir(images_path): + + for file in utils.listdir(images_path): if file.endswith(endname): - return os.path.join(images_path, file) + return os.path.join(images_path, file) - for file in myutils.listdir(images_path): + for file in utils.listdir(images_path): file_noext, ext = os.path.splitext(file) if file_noext.endswith(endname): - return os.path.join(images_path, file) - - return '' + return os.path.join(images_path, file) + + return "" + def get_exp_path(path): - folder_type = myutils.determine_folder_type(path) + folder_type = utils.determine_folder_type(path) is_pos_folder, is_images_folder, _ = folder_type if is_pos_folder: exp_path = os.path.dirname(path) @@ -1051,60 +1092,64 @@ def get_exp_path(path): exp_path = path return exp_path + def get_endname_from_channels(filename, channels): endname = None for ch in channels: - ch_aligned = f'{ch}_aligned' - m = re.search(fr'{ch}(.\w+)*$', filename) - m_aligned = re.search(fr'{ch_aligned}(.\w+)*$', filename) + ch_aligned = f"{ch}_aligned" + m = re.search(rf"{ch}(.\w+)*$", filename) + m_aligned = re.search(rf"{ch_aligned}(.\w+)*$", filename) if m_aligned is not None: return endname elif m is not None: return endname + def get_endname_from_filepath(filepath, allow_empty=False): parent_folderpath = os.path.dirname(filepath) - if not parent_folderpath.endswith('Images'): - return - + if not parent_folderpath.endswith("Images"): + return + filename = os.path.basename(filepath) filename_noext, ext = os.path.splitext(filename) - images_files = myutils.listdir(parent_folderpath) + images_files = utils.listdir(parent_folderpath) basename = os.path.commonprefix(images_files) - endname = filename_noext[len(basename):] + endname = filename_noext[len(basename) :] if not endname: - endname = basename.split('_')[-1] - + endname = basename.split("_")[-1] + return endname - + def get_endnames_from_basename(basename, filenames): - return [os.path.splitext(f)[0][len(basename):] for f in filenames] + return [os.path.splitext(f)[0][len(basename) :] for f in filenames] + def get_path_from_endname(end_name, images_path, ext=None): if ext is None: - end_name, ext = myutils.remove_known_extension(end_name) - - if os.path.exists(os.path.join(images_path, f'{end_name}{ext}')): - return os.path.join(images_path, f'{end_name}{ext}') - - basename = os.path.commonprefix(myutils.listdir(images_path)) - searched_file = f'{basename}{end_name}{ext}' - for file in myutils.listdir(images_path): + end_name, ext = utils.remove_known_extension(end_name) + + if os.path.exists(os.path.join(images_path, f"{end_name}{ext}")): + return os.path.join(images_path, f"{end_name}{ext}") + + basename = os.path.commonprefix(utils.listdir(images_path)) + searched_file = f"{basename}{end_name}{ext}" + for file in utils.listdir(images_path): filename, ext = os.path.splitext(file) if file == searched_file: return os.path.join(images_path, file), file elif filename == searched_file: return os.path.join(images_path, file), file - - for file in myutils.listdir(images_path): + + for file in utils.listdir(images_path): filename, ext = os.path.splitext(file) if file.endswith(end_name): return os.path.join(images_path, file), file elif filename.endswith(end_name): return os.path.join(images_path, file), file - - return '', '' + + return "", "" + def pd_int_to_bool(acdc_df, colsToCast=None): if colsToCast is None: @@ -1116,15 +1161,12 @@ def pd_int_to_bool(acdc_df, colsToCast=None): continue return acdc_df + def pd_bool_and_float_to_int_to_str( - acdc_df, - colsToCastBool=None, - colsToCastInt=None, - csv_path=None, - inplace=True - ): - """Converts boolean columns to 0s and 1s, float columns to integers, - and then to "string" to ensure smooth saving to CSV. Save to CSV if + acdc_df, colsToCastBool=None, colsToCastInt=None, csv_path=None, inplace=True +): + """Converts boolean columns to 0s and 1s, float columns to integers, + and then to "string" to ensure smooth saving to CSV. Save to CSV if `csv_path` is not None. Parameters @@ -1147,30 +1189,24 @@ def pd_bool_and_float_to_int_to_str( """ if not inplace: acdc_df = acdc_df.copy() - + if colsToCastBool is None: colsToCastBool = acdc_df_bool_cols - + if colsToCastInt is None: colsToCastInt = acdc_df_int_cols additional_sister_cols = [ - col for col in acdc_df.columns if col.startswith('sister_ID_tree') + col for col in acdc_df.columns if col.startswith("sister_ID_tree") ] - additional_sister_cols = { - col: int for col in additional_sister_cols - } + additional_sister_cols = {col: int for col in additional_sister_cols} colsToCastInt = ({**colsToCastInt, **additional_sister_cols}).keys() - + for col in colsToCastInt: try: series = acdc_df[col] notna_idx = series.notna() acdc_df[col] = ( - acdc_df[col] - .astype(float) - .fillna(0) - .astype(int) - .astype("string") + acdc_df[col].astype(float).fillna(0).astype(int).astype("string") ) acdc_df.loc[~notna_idx, col] = "" except KeyError: @@ -1178,20 +1214,20 @@ def pd_bool_and_float_to_int_to_str( except Exception as e: printl(col) traceback.print_exc() - + for col in colsToCastBool: try: series = acdc_df[col] - notna_idx = (series.notna()) & (series != '') + notna_idx = (series.notna()) & (series != "") notna_series = series.loc[notna_idx] dtype_id = None for dtype_id, dtype_checker in acdc_df_dtype_id_checker_mapper.items(): if dtype_checker(notna_series): break - + if dtype_id is None: break - + casting_func = acdc_df_dtype_id_func_mapper[dtype_id] acdc_df = casting_func(acdc_df, col, notna_idx) except KeyError: @@ -1199,82 +1235,85 @@ def pd_bool_and_float_to_int_to_str( except Exception as e: printl(col) traceback.print_exc() - + if csv_path is not None: acdc_df.to_csv(csv_path) - + return acdc_df + def parse_metadata_csv_file(csv_filepath): - with open(csv_filepath, 'r') as file: + with open(csv_filepath, "r") as file: txt = file.read() - - lines = txt.split('\n') + + lines = txt.split("\n") for l, line in enumerate(lines.copy()): - is_channel_name_line = re.search(r'channel_\d+_name', line) - if line.startswith('basename') or is_channel_name_line: - parts = line.split(',') + is_channel_name_line = re.search(r"channel_\d+_name", line) + if line.startswith("basename") or is_channel_name_line: + parts = line.split(",") if len(parts) == 2: continue - + if parts[1].startswith('"') and parts[-1].endswith('"'): continue - + quoted_value = f'"{"".join(parts[1:])}"' - parsed_line = f'{parts[0]},{quoted_value}' + parsed_line = f"{parts[0]},{quoted_value}" lines[l] = parsed_line - - with open(csv_filepath, 'w') as file: - file.write('\n'.join(lines)) + + with open(csv_filepath, "w") as file: + file.write("\n".join(lines)) + def get_posData_metadata(images_path, basename): # First check if metadata.csv already has the channel names - for file in myutils.listdir(images_path): - if file.endswith('metadata.csv'): + for file in utils.listdir(images_path): + if file.endswith("metadata.csv"): metadata_csv_path = os.path.join(images_path, file) parse_metadata_csv_file(metadata_csv_path) - df_metadata = pd.read_csv(metadata_csv_path).set_index('Description') + df_metadata = pd.read_csv(metadata_csv_path).set_index("Description") break else: - df_metadata = ( - pd.DataFrame( - columns=['Description', 'values']).set_index('Description') - ) - if basename.endswith('_'): + df_metadata = pd.DataFrame(columns=["Description", "values"]).set_index( + "Description" + ) + if basename.endswith("_"): basename = basename[:-1] - metadata_csv_path = os.path.join(images_path, f'{basename}_metadata.csv') + metadata_csv_path = os.path.join(images_path, f"{basename}_metadata.csv") return df_metadata, metadata_csv_path + def is_pos_prepped(images_path): - filenames = myutils.listdir(images_path) + filenames = utils.listdir(images_path) for filename in filenames: - if filename.endswith('dataPrepROIs_coords.csv'): + if filename.endswith("dataPrepROIs_coords.csv"): return True - elif filename.endswith('dataPrep_bkgrROIs.json'): + elif filename.endswith("dataPrep_bkgrROIs.json"): return True - elif filename.endswith('aligned.npz'): + elif filename.endswith("aligned.npz"): return True - elif filename.endswith('align_shift.npy'): + elif filename.endswith("align_shift.npy"): return True - elif filename.endswith('bkgrRoiData.npz'): + elif filename.endswith("bkgrRoiData.npz"): return True return False + def is_bkgrROIs_present(images_path): - filenames = myutils.listdir(images_path) + filenames = utils.listdir(images_path) for filename in filenames: - if filename.endswith('dataPrep_bkgrROIs.json'): + if filename.endswith("dataPrep_bkgrROIs.json"): return True - elif filename.endswith('bkgrRoiData.npz'): + elif filename.endswith("bkgrRoiData.npz"): return True return False + class loadData: def __init__( - self, imgPath, user_ch_name, - relPathDepth=3, QParent=None, log_func=None - ): + self, imgPath, user_ch_name, relPathDepth=3, QParent=None, log_func=None + ): self.fluo_data_dict = {} self.fluo_bkgrData_dict = {} self.bkgrROIs = [] @@ -1285,7 +1324,7 @@ def __init__( self.images_path = os.path.dirname(imgPath) self.images_folder_files = os.listdir(self.images_path) self.pos_path = os.path.dirname(self.images_path) - self.spotmax_out_path = os.path.join(self.pos_path, 'spotMAX_output') + self.spotmax_out_path = os.path.join(self.pos_path, "spotMAX_output") self.exp_path = os.path.dirname(self.pos_path) self.pos_foldername = os.path.basename(self.pos_path) self.pos_num = self.getPosNum() @@ -1298,75 +1337,75 @@ def __init__( self.frame_i = 0 self.clickEntryPointsDfs = {} path_li = os.path.normpath(imgPath).split(os.sep) - self.relPath = f'{f"{os.sep}".join(path_li[-relPathDepth:])}' + self.relPath = f"{f'{os.sep}'.join(path_li[-relPathDepth:])}" filename_ext = os.path.basename(imgPath) self.filename_ext = filename_ext self.filename, self.ext = os.path.splitext(filename_ext) self._additionalMetadataValues = None self.loadLastEntriesMetadata() self.attempFixBasenameBug() - self.non_aligned_ext = '.tif' - if filename_ext.endswith('aligned.npz'): - for file in myutils.listdir(self.images_path): - if file.endswith(f'{user_ch_name}.h5'): - self.non_aligned_ext = '.h5' + self.non_aligned_ext = ".tif" + if filename_ext.endswith("aligned.npz"): + for file in utils.listdir(self.images_path): + if file.endswith(f"{user_ch_name}.h5"): + self.non_aligned_ext = ".h5" break self.tracked_lost_centroids = None - if not hasattr(self, 'whitelist'): + if not hasattr(self, "whitelist"): self.whitelist = None self.log_func = log_func def attempFixBasenameBug(self): - r'''Attempt removing _s(\d+)_ from filenames if not present in basename - + r"""Attempt removing _s(\d+)_ from filenames if not present in basename + This was a bug introduced when saving the basename with data structure, it was not saving the _s(\d+)_ part. - ''' + """ try: - ls = myutils.listdir(self.images_path) + ls = utils.listdir(self.images_path) for file in ls: - if file.endswith('metadata.csv'): + if file.endswith("metadata.csv"): metadata_csv_path = os.path.join(self.images_path, file) break else: return - + parse_metadata_csv_file(metadata_csv_path) - df_metadata = pd.read_csv(metadata_csv_path).set_index('Description') + df_metadata = pd.read_csv(metadata_csv_path).set_index("Description") try: - basename = df_metadata.at['basename', 'values'] + basename = df_metadata.at["basename", "values"] except Exception as e: return - - numPos = len(myutils.get_pos_foldernames(self.exp_path)) + + numPos = len(utils.get_pos_foldernames(self.exp_path)) numPosDigits = len(str(numPos)) - s0p = str(self.pos_num+1).zfill(numPosDigits) + s0p = str(self.pos_num + 1).zfill(numPosDigits) - if basename.endswith(f'_s{s0p}_'): + if basename.endswith(f"_s{s0p}_"): return - + for file in ls: - endname = file[len(basename):] - if not endname.startswith(f's{s0p}_'): + endname = file[len(basename) :] + if not endname.startswith(f"s{s0p}_"): continue - fixed_endname = endname[len(f's{s0p}_'):] - fixed_filename = f'{basename}{fixed_endname}' + fixed_endname = endname[len(f"s{s0p}_") :] + fixed_filename = f"{basename}{fixed_endname}" fixed_filepath = os.path.join(self.images_path, fixed_filename) filepath = os.path.join(self.images_path, file) - hidden_filepath = os.path.join(self.images_path, f'.{file}') + hidden_filepath = os.path.join(self.images_path, f".{file}") shutil.copy2(filepath, fixed_filepath) try: os.rename(filepath, hidden_filepath) except Exception as e: pass - + except Exception as e: traceback.print_exc() - + def isPrepped(self): return is_pos_prepped(self.images_path) - + def isBkgrROIpresent(self): return is_bkgrROIs_present(self.images_path) @@ -1375,9 +1414,9 @@ def setLoadedChannelNames(self, returnList=False): loadedChNames = [] for key in fluo_keys: - chName = key[len(self.basename):] - if chName.endswith('_aligned'): - aligned_idx = chName.find('_aligned') + chName = key[len(self.basename) :] + if chName.endswith("_aligned"): + aligned_idx = chName.find("_aligned") chName = chName[:aligned_idx] loadedChNames.append(chName) @@ -1388,7 +1427,7 @@ def setLoadedChannelNames(self, returnList=False): def getPosNum(self): try: - pos_num = int(re.findall(r'Position_(\d+)', self.pos_foldername))[0] + pos_num = int(re.findall(r"Position_(\d+)", self.pos_foldername))[0] except Exception: pos_num = 0 return pos_num @@ -1397,78 +1436,77 @@ def loadLastEntriesMetadata(self): if not os.path.exists(settings_folderpath): self.last_md_df = None return - csv_path = os.path.join(settings_folderpath, 'last_entries_metadata.csv') + csv_path = os.path.join(settings_folderpath, "last_entries_metadata.csv") if not os.path.exists(csv_path): self.last_md_df = None else: parse_metadata_csv_file(csv_path) - self.last_md_df = pd.read_csv(csv_path).set_index('Description') + self.last_md_df = pd.read_csv(csv_path).set_index("Description") def saveLastEntriesMetadata(self): if not os.path.exists(settings_folderpath): return self.metadata_df.to_csv(last_entries_metadata_path) - + def getCustomAnnotColumnNames(self): - if not hasattr(self, 'customAnnot'): - return - + if not hasattr(self, "customAnnot"): + return + return natsorted(self.customAnnot.keys()) - + def saveCustomAnnotationParams(self): - if not hasattr(self, 'customAnnot'): - return - + if not hasattr(self, "customAnnot"): + return + if not self.customAnnot: return - - with open(self.custom_annot_json_path, mode='w') as file: + + with open(self.custom_annot_json_path, mode="w") as file: json.dump(self.customAnnot, file, indent=2) def addYXcentroidColsIfMissing(self, show_progress=False): if not self.segmFound: return - + if not self.acdc_df_found: return - + is_centroid_present = ( - 'y_centroid' in self.acdc_df.columns - and 'x_centroid' in self.acdc_df.columns + "y_centroid" in self.acdc_df.columns + and "x_centroid" in self.acdc_df.columns ) if is_centroid_present: return - + segm_data = self.segm_data if self.SizeT == 1: segm_data = (segm_data,) - - last_frame_i = self.acdc_df.reset_index()['frame_i'].max() + + last_frame_i = self.acdc_df.reset_index()["frame_i"].max() if show_progress: pbar = tqdm( - total=last_frame_i+1, - desc='Adding centroid columns to acdc_df' + total=last_frame_i + 1, desc="Adding centroid columns to acdc_df" ) - - for frame_i in range(last_frame_i+1): + + for frame_i in range(last_frame_i + 1): lab = segm_data[frame_i] rp = skimage.measure.regionprops(lab) for obj in rp: ID = obj.label y_centroid, x_centroid = obj.centroid[-2:] - self.acdc_df.loc[(frame_i, ID), 'y_centroid'] = y_centroid - self.acdc_df.loc[(frame_i, ID), 'x_centroid'] = x_centroid - + self.acdc_df.loc[(frame_i, ID), "y_centroid"] = y_centroid + self.acdc_df.loc[(frame_i, ID), "x_centroid"] = x_centroid + if show_progress: pbar.update(1) - + if show_progress: pbar.close() - + self.acdc_df.to_csv(self.acdc_output_csv_path) - + def getBasenameAndChNames(self, useExt=None, qparent=None): - ls = myutils.listdir(self.images_path) + ls = utils.listdir(self.images_path) selector = prompts.select_channel_name() self.chNames, _ = selector.get_available_channels( ls, self.images_path, useExt=useExt @@ -1479,41 +1517,39 @@ def getBasenameAndChNames(self, useExt=None, qparent=None): filename, _ = os.path.splitext(file) if filename != self.basename: continue - - sep = '*'*100 + + sep = "*" * 100 error_text = ( f'The file "{file}" has the same name as ' - f'the basename of all other files.\n\n' - f'Please, rename the file to include something ' + f"the basename of all other files.\n\n" + f"Please, rename the file to include something " f'after "{self.basename}", e.g., "{self.basename}_channel_name".' ) if qparent is not None: - html_error_text = f'[WARNING]: {error_text}' - html_error_text = html_error_text.replace('\n', '
    ') + html_error_text = f"[WARNING]: {error_text}" + html_error_text = html_error_text.replace("\n", "
    ") html_error_text = ( - html_error_text.replace( - f'"{file}"', f'{file}' - ).replace( - f'"{self.basename}"', f'{self.basename}' - ).replace( - f'"{self.basename}_channel_name"', - f'{self.basename}_channel_name' + html_error_text.replace(f'"{file}"', f"{file}") + .replace(f'"{self.basename}"', f"{self.basename}") + .replace( + f'"{self.basename}_channel_name"', + f"{self.basename}_channel_name", ) ) html_error_text = html_utils.paragraph(html_error_text) msg = widgets.myMessageBox(wrapText=False) - msg.warning(qparent, 'Rename files', html_error_text) - - raise FileNameError(f'\n\n{sep}\n[ERROR]: {error_text}') + msg.warning(qparent, "Rename files", html_error_text) + + raise FileNameError(f"\n\n{sep}\n[ERROR]: {error_text}") def loadImgData(self, imgPath=None, signals=None): if imgPath is None: imgPath = self.imgPath self.z0_window = 0 self.t0_window = 0 - if self.ext == '.h5': - self.h5f = h5py.File(imgPath, 'r') - self.dset = self.h5f['data'] + if self.ext == ".h5": + self.h5f = h5py.File(imgPath, "r") + self.dset = self.h5f["data"] self.img_data_shape = self.dset.shape readH5 = self.loadSizeT is not None and self.loadSizeZ is not None if not readH5: @@ -1524,33 +1560,33 @@ def loadImgData(self, imgPath=None, signals=None): is3Dt = self.SizeZ == 1 and self.SizeT > 1 is2D = self.SizeZ == 1 and self.SizeT == 1 if is4D: - midZ = int(self.SizeZ/2) - halfZLeft = int(self.loadSizeZ/2) - halfZRight = self.loadSizeZ-halfZLeft - z0 = midZ-halfZLeft - z1 = midZ+halfZRight + midZ = int(self.SizeZ / 2) + halfZLeft = int(self.loadSizeZ / 2) + halfZRight = self.loadSizeZ - halfZLeft + z0 = midZ - halfZLeft + z1 = midZ + halfZRight self.z0_window = z0 self.t0_window = 0 - self.img_data = self.dset[:self.loadSizeT, z0:z1] + self.img_data = self.dset[: self.loadSizeT, z0:z1] elif is3Dz: - midZ = int(self.SizeZ/2) - halfZLeft = int(self.loadSizeZ/2) - halfZRight = self.loadSizeZ-halfZLeft - z0 = midZ-halfZLeft - z1 = midZ+halfZRight + midZ = int(self.SizeZ / 2) + halfZLeft = int(self.loadSizeZ / 2) + halfZRight = self.loadSizeZ - halfZLeft + z0 = midZ - halfZLeft + z1 = midZ + halfZRight self.z0_window = z0 self.img_data = np.squeeze(self.dset[z0:z1]) elif is3Dt: self.t0_window = 0 - self.img_data = np.squeeze(self.dset[:self.loadSizeT]) + self.img_data = np.squeeze(self.dset[: self.loadSizeT]) elif is2D: self.img_data = np.squeeze(self.dset[:]) - elif self.ext == '.npz': - self.img_data = np.squeeze(np.load(imgPath)['arr_0']) + elif self.ext == ".npz": + self.img_data = np.squeeze(np.load(imgPath)["arr_0"]) self.dset = self.img_data self.img_data_shape = self.img_data.shape - elif self.ext == '.npy': + elif self.ext == ".npy": self.img_data = np.squeeze(np.load(imgPath)) self.dset = self.img_data self.img_data_shape = self.img_data.shape @@ -1562,7 +1598,7 @@ def loadImgData(self, imgPath=None, signals=None): except Exception as err: traceback.print_exc() self.criticalExtNotValid(signals=signals) - elif self.ext in VIDEO_EXTENSIONS: + elif self.ext in VIDEO_EXTENSIONS: try: self.img_data = self._loadVideo(imgPath) self.dset = self.img_data @@ -1572,11 +1608,11 @@ def loadImgData(self, imgPath=None, signals=None): self.criticalExtNotValid(signals=signals) else: self.criticalExtNotValid(signals=signals) - + def loadChannelData(self, channelName): if channelName == self.user_ch_name: return self.img_data - + dataPath = get_filename_from_channel(self.images_path, channelName) if dataPath: data = load_image_file(dataPath) @@ -1586,24 +1622,21 @@ def loadChannelData(self, channelName): def init_segmInfo_df(self): if self.SizeZ > 1 and self.segmInfo_df is not None: - if 'z_slice_used_gui' not in self.segmInfo_df.columns: - self.segmInfo_df['z_slice_used_gui'] = ( - self.segmInfo_df['z_slice_used_dataPrep'] - ) - if 'which_z_proj_gui' not in self.segmInfo_df.columns: - self.segmInfo_df['which_z_proj_gui'] = ( - self.segmInfo_df['which_z_proj'] - ) - self.segmInfo_df['resegmented_in_gui'] = False + if "z_slice_used_gui" not in self.segmInfo_df.columns: + self.segmInfo_df["z_slice_used_gui"] = self.segmInfo_df[ + "z_slice_used_dataPrep" + ] + if "which_z_proj_gui" not in self.segmInfo_df.columns: + self.segmInfo_df["which_z_proj_gui"] = self.segmInfo_df["which_z_proj"] + self.segmInfo_df["resegmented_in_gui"] = False self.segmInfo_df.to_csv(self.segmInfo_df_csv_path) NO_segmInfo = ( - self.segmInfo_df is None - or self.filename not in self.segmInfo_df.index + self.segmInfo_df is None or self.filename not in self.segmInfo_df.index ) if NO_segmInfo and self.SizeZ > 1: filename = self.filename - df = myutils.getDefault_SegmInfo_df(self, filename) + df = utils.getDefault_SegmInfo_df(self, filename) if self.segmInfo_df is None: self.segmInfo_df = df else: @@ -1611,7 +1644,7 @@ def init_segmInfo_df(self): unique_idx = ~self.segmInfo_df.index.duplicated() self.segmInfo_df = self.segmInfo_df[unique_idx] self.segmInfo_df.to_csv(self.segmInfo_df_csv_path) - + def _loadVideo(self, path): video = cv2.VideoCapture(path) num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) @@ -1626,247 +1659,243 @@ def _loadVideo(self, path): def countObjectsInSegmTimelapse(self, categories: set[str] | list[str]): numObjsCurrentFrame = len(self.IDs) - + uniqueIDsVisited = None uniqueIDsAll = None numObjsVisitedFrames = None numObjsTotal = None - if 'Unique objects in all visited frames' in categories: + if "Unique objects in all visited frames" in categories: uniqueIDsVisited = set() - - if 'Unique objects in entire video' in categories: + + if "Unique objects in entire video" in categories: uniqueIDsAll = set() - - if 'In all visited frames' in categories: + + if "In all visited frames" in categories: numObjsVisitedFrames = 0 - - if 'In entire video' in categories: + + if "In entire video" in categories: numObjsTotal = 0 - + for frame_i in range(len(self.segm_data)): - lab = self.allData_li[frame_i]['labels'] + lab = self.allData_li[frame_i]["labels"] if lab is not None: - IDsFrame = self.allData_li[frame_i]['IDs'] - + IDsFrame = self.allData_li[frame_i]["IDs"] + if uniqueIDsVisited is not None: uniqueIDsVisited.update(IDsFrame) - + if uniqueIDsAll is not None: uniqueIDsAll.update(IDsFrame) - + numObjsFrame = len(IDsFrame) - + if numObjsVisitedFrames is not None: numObjsVisitedFrames += numObjsFrame - + if numObjsTotal is not None: numObjsTotal += numObjsFrame else: lab = self.segm_data[frame_i] - + if numObjsTotal is not None or numObjsTotal is not None: rp = skimage.measure.regionprops(self.segm_data[frame_i]) - + if numObjsTotal is not None: numObjsTotal += len(rp) - + if uniqueIDsAll is not None: uniqueIDsAll.update([obj.label for obj in rp]) - + numUniqueObjsVisitedFrames = None if uniqueIDsVisited is not None: numUniqueObjsVisitedFrames = len(uniqueIDsVisited) - + numUniqueObjsTotal = None if uniqueIDsAll is not None: numUniqueObjsTotal = len(uniqueIDsAll) - + allCategoryCountMapper = { - 'In current frame': numObjsCurrentFrame, - 'In all visited frames': numObjsVisitedFrames, - 'In entire video': numObjsTotal, - 'Unique objects in all visited frames': numUniqueObjsVisitedFrames, - 'Unique objects in entire video': numUniqueObjsTotal + "In current frame": numObjsCurrentFrame, + "In all visited frames": numObjsVisitedFrames, + "In entire video": numObjsTotal, + "Unique objects in all visited frames": numUniqueObjsVisitedFrames, + "Unique objects in entire video": numUniqueObjsTotal, } - + return allCategoryCountMapper - + def countObjectsInSegmSnapshots(self, categories: set[str] | list[str]): - if hasattr(self, 'IDs'): + if hasattr(self, "IDs"): numObjs = len(self.IDs) else: lab = np.squeeze(self.segm_data) rp = skimage.measure.regionprops(lab) numObjs = len(rp) - - mapper = { - 'In current position': numObjs - } - + + mapper = {"In current position": numObjs} + return mapper - - def countObjectsInSegm(self, categories: set[str] | list[str] | None=None): + + def countObjectsInSegm(self, categories: set[str] | list[str] | None = None): if self.SizeT > 1: if categories is None: - categories = ['In entire video'] - + categories = ["In entire video"] + return self.countObjectsInSegmTimelapse(categories) else: if categories is None: - categories = ['In current position'] - + categories = ["In current position"] + return self.countObjectsInSegmSnapshots(categories) - + def saveObjCounts(self, countMapper: dict[str, int]): df = pd.DataFrame(countMapper, index=[0]) segmFilename = os.path.basename(self.segm_npz_path) - segmEndname = segmFilename[len(self.basename):] - dfCountEndname = ( - segmEndname - .replace('segm', 'acdc_objects_count') - .replace('.npz', '.csv') + segmEndname = segmFilename[len(self.basename) :] + dfCountEndname = segmEndname.replace("segm", "acdc_objects_count").replace( + ".npz", ".csv" ) - - dfCountFilename = f'{self.basename}{dfCountEndname}' + + dfCountFilename = f"{self.basename}{dfCountEndname}" dfCountFilepath = os.path.join(self.images_path, dfCountFilename) - + df.to_csv(dfCountFilepath, index=False) - + return dfCountEndname - - + def detectMultiSegmNpz( - self, multiPos=False, signals=None, - mutex=None, waitCond=None, askMultiSegmFunc=None, - newEndFilenameSegm='' - ): + self, + multiPos=False, + signals=None, + mutex=None, + waitCond=None, + askMultiSegmFunc=None, + newEndFilenameSegm="", + ): if newEndFilenameSegm: - return '', newEndFilenameSegm, False + return "", newEndFilenameSegm, False segm_files = get_segm_files(self.images_path) if askMultiSegmFunc is None: return segm_files - is_multi_npz = len(segm_files)>0 + is_multi_npz = len(segm_files) > 0 if is_multi_npz and askMultiSegmFunc is not None: askMultiSegmFunc(segm_files, self, waitCond) - endFilename = self.selectedItemText[len(self.basename):] + endFilename = self.selectedItemText[len(self.basename) :] return self.selectedItemText, endFilename, self.cancel - elif len(segm_files)==1: + elif len(segm_files) == 1: segmFilename = segm_files[0] - endFilename = segmFilename[len(self.basename):] + endFilename = segmFilename[len(self.basename) :] return segm_files[0], endFilename, False else: - return '', '', False + return "", "", False def readLastUsedStopFrameNumber(self): - if not hasattr(self, 'metadata_df'): + if not hasattr(self, "metadata_df"): return - + if self.metadata_df is None: return - + try: - stop_frame_num = int(self.metadata_df.at['stop_frame_num', 'values']) + stop_frame_num = int(self.metadata_df.at["stop_frame_num", "values"]) except Exception as err: stop_frame_num = None - + return stop_frame_num - + def getSamEmbeddingsPath(self): - sam_embed_filename = ( - f'{self.basename}_{self.user_ch_name}_sam_embeddings.pt' - ) + sam_embed_filename = f"{self.basename}_{self.user_ch_name}_sam_embeddings.pt" sam_embeddings_path = os.path.join(self.images_path, sam_embed_filename) return sam_embeddings_path - + def storeSamEmbeddings(self, samAcdcSegment, frame_i=0, z=0): # See here how to save embeddings # https://github.com/facebookresearch/segment-anything/issues/217 - - if not hasattr(self, 'sam_embeddings'): + + if not hasattr(self, "sam_embeddings"): self.sam_embeddings = {} - + if frame_i not in self.sam_embeddings: self.sam_embeddings[frame_i] = {} - - if hasattr(samAcdcSegment.model, 'predictor'): + + if hasattr(samAcdcSegment.model, "predictor"): predictor = samAcdcSegment.model.predictor else: predictor = samAcdcSegment.model - + embedding = { - 'original_size': predictor.original_size, - 'input_size': predictor.input_size, - 'features': predictor.features, - 'is_image_set': True, + "original_size": predictor.original_size, + "input_size": predictor.input_size, + "features": predictor.features, + "is_image_set": True, } self.sam_embeddings[frame_i][z] = embedding - - def saveSamEmbeddings(self, logger_func=print): - if not hasattr(self, 'sam_embeddings'): - return - - logger_func( - f'\nSaving SAM image embeddings to "{self.sam_embeddings_path}"...' - ) + + def saveSamEmbeddings(self, logger_func=print): + if not hasattr(self, "sam_embeddings"): + return + + logger_func(f'\nSaving SAM image embeddings to "{self.sam_embeddings_path}"...') import torch + torch.save(self.sam_embeddings, self.sam_embeddings_path) - + def loadSamEmbeddings(self, force_reload=False, logger_func=None): - if hasattr(self, 'sam_embeddings') and not force_reload: - return - + if hasattr(self, "sam_embeddings") and not force_reload: + return + if not os.path.exists(self.sam_embeddings_path): return - + if logger_func is not None: logger_func( f'\nLoading SAM image embeddings from "{self.sam_embeddings_path}"...' ) - + import torch + self.sam_embeddings = torch.load(self.sam_embeddings_path) - + def getSamEmbeddings(self, frame_i=0, z=0): - if not hasattr(self, 'sam_embeddings'): - return - + if not hasattr(self, "sam_embeddings"): + return + frame_embeddings = self.sam_embeddings.get(frame_i) if frame_embeddings is None: return - + img_embeddings = frame_embeddings.get(z) if img_embeddings is None: return - + return img_embeddings - - + def loadOtherFiles( - self, - load_segm_data=True, - create_new_segm=False, - load_acdc_df=False, - load_shifts=False, - loadSegmInfo=False, - load_delROIsInfo=False, - load_bkgr_data=False, - loadBkgrROIs=False, - load_last_tracked_i=False, - load_metadata=False, - load_dataPrep_ROIcoords=False, - load_customAnnot=False, - load_customCombineMetrics=False, - load_manual_bkgr_lab=False, - load_dataprep_free_roi=False, - getTifPath=False, - end_filename_segm='', - new_endname='', - labelBoolSegm=None, - load_whitelistIDs=False, - ): + self, + load_segm_data=True, + create_new_segm=False, + load_acdc_df=False, + load_shifts=False, + loadSegmInfo=False, + load_delROIsInfo=False, + load_bkgr_data=False, + loadBkgrROIs=False, + load_last_tracked_i=False, + load_metadata=False, + load_dataPrep_ROIcoords=False, + load_customAnnot=False, + load_customCombineMetrics=False, + load_manual_bkgr_lab=False, + load_dataprep_free_roi=False, + getTifPath=False, + end_filename_segm="", + new_endname="", + labelBoolSegm=None, + load_whitelistIDs=False, + ): self.segmFound = False if load_segm_data else None self.acdc_df_found = False if load_acdc_df else None self.shiftsFound = False if load_shifts else None @@ -1883,36 +1912,37 @@ def loadOtherFiles( self.dataPrepFreeRoiPoints = [] self.labelBoolSegm = labelBoolSegm self.bkgrDataExists = False - ls = myutils.listdir(self.images_path) - + ls = utils.listdir(self.images_path) + if end_filename_segm: - end_filename_segm = end_filename_segm.replace('.npz', '') + end_filename_segm = end_filename_segm.replace(".npz", "") linked_acdc_filename = None if end_filename_segm and load_acdc_df: # Check if there is an acdc_output file linked to selected .npz - _acdc_df_end_fn = end_filename_segm.replace('segm', 'acdc_output') - _acdc_df_end_fn = f'{_acdc_df_end_fn}.csv' + _acdc_df_end_fn = end_filename_segm.replace("segm", "acdc_output") + _acdc_df_end_fn = f"{_acdc_df_end_fn}.csv" self._acdc_df_end_fn = _acdc_df_end_fn - _linked_acdc_fn = f'{self.basename}{_acdc_df_end_fn}' + _linked_acdc_fn = f"{self.basename}{_acdc_df_end_fn}" acdc_df_path = os.path.join(self.images_path, _linked_acdc_fn) self.acdc_output_csv_path = acdc_df_path linked_acdc_filename = _linked_acdc_fn - - if not hasattr(self, 'basename'): + + if not hasattr(self, "basename"): self.getBasenameAndChNames() dataPrepFreeRoiPath = self.dataPrepFreeRoiPath() dataPrepFreeRoiFilename = os.path.basename(dataPrepFreeRoiPath) - + for file in ls: filePath = os.path.join(self.images_path, file) filename, segmExt = os.path.splitext(file) - endName = filename[len(self.basename):] + endName = filename[len(self.basename) :] loadMetadata = ( - load_metadata and file.endswith('metadata.csv') - and not file.endswith('segm_metadata.csv') + load_metadata + and file.endswith("metadata.csv") + and not file.endswith("segm_metadata.csv") ) if new_endname: @@ -1923,10 +1953,10 @@ def loadOtherFiles( elif end_filename_segm: # Load the segmentation file selected by the user self._segm_end_fn = end_filename_segm - is_segm_file = endName == end_filename_segm and segmExt == '.npz' + is_segm_file = endName == end_filename_segm and segmExt == ".npz" else: # Load default segmentation file - is_segm_file = file.endswith('segm.npz') + is_segm_file = file.endswith("segm.npz") if linked_acdc_filename is not None: is_acdc_df_file = file == linked_acdc_filename @@ -1935,13 +1965,13 @@ def loadOtherFiles( # do not load acdc_df file is_acdc_df_file = False else: - is_acdc_df_file = file.endswith('acdc_output.csv') + is_acdc_df_file = file.endswith("acdc_output.csv") is_acdc_df_file = file == linked_acdc_filename - + if load_dataprep_free_roi and file == dataPrepFreeRoiFilename: self.loadDataPrepFreeRoi() - + if load_segm_data and is_segm_file and not create_new_segm: self.segmFound = True self.segm_npz_path = filePath @@ -1956,16 +1986,16 @@ def loadOtherFiles( if squeezed_arr.shape != self.segm_data.shape: self.segm_data = squeezed_arr io.savez_compressed(filePath, squeezed_arr) - elif getTifPath and file.find(f'{self.user_ch_name}.tif')!=-1: + elif getTifPath and file.find(f"{self.user_ch_name}.tif") != -1: self.tif_path = filePath self.TifPathFound = True elif load_acdc_df and is_acdc_df_file and not create_new_segm: self.acdc_df_found = True self.loadAcdcDf(filePath) - elif load_shifts and file.endswith('align_shift.npy'): + elif load_shifts and file.endswith("align_shift.npy"): self.shiftsFound = True self.loaded_shifts = np.load(filePath) - elif loadSegmInfo and file.endswith('segmInfo.csv'): + elif loadSegmInfo and file.endswith("segmInfo.csv"): self.segmInfoFound = True try: remove_duplicates_file(filePath) @@ -1973,77 +2003,77 @@ def loadOtherFiles( printl(filePath) printl(traceback.format_exc()) df = pd.read_csv(filePath).dropna() - # In some old versions, there was a bug that removed the - # 'filename', and the 'frame_i' column names, so - # we check if they are not present and rename the + # In some old versions, there was a bug that removed the + # 'filename', and the 'frame_i' column names, so + # we check if they are not present and rename the # 'Unnamed: 0' and 'Unnamed: 1' to filename and frame_i - if 'Unnamed: 0' in df.columns and 'Unnamed: 1' in df.columns: + if "Unnamed: 0" in df.columns and "Unnamed: 1" in df.columns: df = df.rename( - columns={ - 'Unnamed: 0': 'filename', - 'Unnamed: 1': 'frame_i' - } + columns={"Unnamed: 0": "filename", "Unnamed: 1": "frame_i"} ) - if 'filename' not in df.columns: - df['filename'] = self.filename - df = df.set_index(['filename', 'frame_i']).sort_index() + if "filename" not in df.columns: + df["filename"] = self.filename + df = df.set_index(["filename", "frame_i"]).sort_index() df = df[~df.index.duplicated()] self.segmInfo_df = df.sort_index() self.segmInfo_df.to_csv(filePath) - elif load_delROIsInfo and file.endswith('delROIsInfo.npz'): + elif load_delROIsInfo and file.endswith("delROIsInfo.npz"): self.delROIsInfoFound = True self.delROIsInfo_npz = np.load(filePath) - elif file.endswith(f'{self.filename}_bkgrRoiData.npz'): + elif file.endswith(f"{self.filename}_bkgrRoiData.npz"): self.bkgrDataExists = True if load_bkgr_data: self.bkgrDataFound = True self.bkgrData = np.load(filePath) - elif loadBkgrROIs and file.endswith('dataPrep_bkgrROIs.json'): + elif loadBkgrROIs and file.endswith("dataPrep_bkgrROIs.json"): self.bkgrROisFound = True with open(filePath) as json_fp: bkgROIs_states = json.load(json_fp) - if hasattr(self, 'img_data'): + if hasattr(self, "img_data"): for roi_state in bkgROIs_states: Y, X = self.img_data.shape[-2:] roi = pg.ROI( - [0, 0], [1, 1], + [0, 0], + [1, 1], rotatable=False, removable=False, - pen=pg.mkPen(color=(150,150,150)), - maxBounds=QRectF(QRect(0,0,X,Y)), + pen=pg.mkPen(color=(150, 150, 150)), + maxBounds=QRectF(QRect(0, 0, X, Y)), scaleSnap=True, - translateSnap=True + translateSnap=True, ) roi.setState(roi_state) self.bkgrROIs.append(roi) - elif load_dataPrep_ROIcoords and file.endswith('dataPrepROIs_coords.csv'): + elif load_dataPrep_ROIcoords and file.endswith("dataPrepROIs_coords.csv"): df = pd.read_csv(filePath) - if 'roi_id' not in df.columns: - df['roi_id'] = 0 - if 'description' in df.columns and 'value' in df.columns: - df = df.set_index(['roi_id', 'description']) + if "roi_id" not in df.columns: + df["roi_id"] = 0 + if "description" in df.columns and "value" in df.columns: + df = df.set_index(["roi_id", "description"]) self.dataPrep_ROIcoordsFound = True self.dataPrep_ROIcoords = df elif loadMetadata: self.metadataFound = True remove_duplicates_file(filePath) parse_metadata_csv_file(filePath) - self.metadata_df = pd.read_csv(filePath).set_index('Description') - elif load_customAnnot and file.endswith('custom_annot_params.json'): + self.metadata_df = pd.read_csv(filePath).set_index("Description") + elif load_customAnnot and file.endswith("custom_annot_params.json"): self.customAnnotFound = True self.customAnnot = read_json(filePath) - elif load_customCombineMetrics and file.endswith('custom_combine_metrics.ini'): + elif load_customCombineMetrics and file.endswith( + "custom_combine_metrics.ini" + ): self.combineMetricsFound = True self.setCombineMetricsConfig(ini_path=filePath) if self.metadataFound is not None and self.metadataFound: self.extractMetadata() - + # Check if there is the old segm.npy if not self.segmFound and not create_new_segm: for file in ls: - is_segm_npy = file.endswith('segm.npy') + is_segm_npy = file.endswith("segm.npy") filePath = os.path.join(self.images_path, file) if load_segm_data and is_segm_npy and not self.segmFound: self.segmFound = True @@ -2053,8 +2083,7 @@ def loadOtherFiles( self.last_tracked_i_found = True try: self.last_tracked_i = max( - self.acdc_df.index.get_level_values(0), - default=None + self.acdc_df.index.get_level_values(0), default=None ) except AttributeError as e: # traceback.print_exc() @@ -2069,117 +2098,114 @@ def loadOtherFiles( if load_whitelistIDs: self.loadWhitelist() - + def checkAndFixZsliceSegmInfo(self): - if not hasattr(self, 'segmInfo_df'): + if not hasattr(self, "segmInfo_df"): return - + if self.segmInfo_df is None: return - - if not hasattr(self, 'SizeZ'): + + if not hasattr(self, "SizeZ"): return - + if self.SizeZ == 1: return - - middleZslice = int(self.SizeZ/2) - + + middleZslice = int(self.SizeZ / 2) + try: - mask = self.segmInfo_df['z_slice_used_dataPrep'] >= self.SizeZ + mask = self.segmInfo_df["z_slice_used_dataPrep"] >= self.SizeZ valid_idx = mask[mask].index - self.segmInfo_df.loc[valid_idx, 'z_slice_used_dataPrep'] = middleZslice + self.segmInfo_df.loc[valid_idx, "z_slice_used_dataPrep"] = middleZslice except Exception as err: pass - + try: - mask = self.segmInfo_df['z_slice_used_gui'] >= self.SizeZ + mask = self.segmInfo_df["z_slice_used_gui"] >= self.SizeZ valid_idx = mask[mask].index - self.segmInfo_df.loc[valid_idx, 'z_slice_used_gui'] = middleZslice + self.segmInfo_df.loc[valid_idx, "z_slice_used_gui"] = middleZslice except Exception as err: pass - + def loadMostRecentUnsavedAcdcDf(self): acdc_df = get_last_stored_unsaved_acdc_df(self.recoveryFolderpath()) if acdc_df is None: return self.acdc_df = acdc_df self.acdc_df_found = True - self.last_tracked_i = max( - self.acdc_df.index.get_level_values(0), - default=None - ) - + self.last_tracked_i = max(self.acdc_df.index.get_level_values(0), default=None) + def loadAcdcDf(self, filePath, updatePaths=True, return_df=False): acdc_df = _load_acdc_df_file(filePath) if acdc_df.empty: self.acdc_df_found = False return - + if updatePaths: self.acdc_df = acdc_df self.acdc_df_found = True self.last_tracked_i = max( - self.acdc_df.index.get_level_values(0), - default=None + self.acdc_df.index.get_level_values(0), default=None ) if return_df: return acdc_df - + def dataPrepFreeRoiPath(self): dataPrepFreeRoiPath = os.path.join( - self.images_path, f'{self.basename}dataPrepFreeRoi.npz' + self.images_path, f"{self.basename}dataPrepFreeRoi.npz" ) return dataPrepFreeRoiPath - + def saveDataPrepFreeRoi( - self, - roiItem: 'widgets.PlotCurveItem', - logger_func=print, - local_mask=None, bbox=None - ): + self, + roiItem: "widgets.PlotCurveItem", + logger_func=print, + local_mask=None, + bbox=None, + ): dataPrepFreeRoiPath = self.dataPrepFreeRoiPath() - + logger_func(f'\nSaving free ROI to file "{dataPrepFreeRoiPath}"...') - + if local_mask is None: local_mask = roiItem.mask() - + if bbox is None: bbox = roiItem.bbox() - + y0, x0, y1, x1 = bbox - key = f'{x0}_{y0}_{x1}_{y1}' + key = f"{x0}_{y0}_{x1}_{y1}" data = {key: local_mask} np.savez_compressed(dataPrepFreeRoiPath, **data) - + def removeDataPrepFreeRoi(self, logger_func=print): self.dataPrepFreeRoiPoints = [] dataPrepFreeRoiPath = self.dataPrepFreeRoiPath() if not os.path.exists(dataPrepFreeRoiPath): return - + logger_func(f'\nRemoving free ROI file "{dataPrepFreeRoiPath}"...') os.remove(dataPrepFreeRoiPath) - + def loadDataPrepFreeRoi(self, logger_func=print): self.dataPrepFreeRoiPoints = [] dataPrepFreeRoiPath = self.dataPrepFreeRoiPath() if not os.path.exists(dataPrepFreeRoiPath): return - + logger_func(f'\nLoading free ROI from file "{dataPrepFreeRoiPath}"...') archive = np.load(dataPrepFreeRoiPath) key = archive.files[0] - x0, y0, x1, y1 = [int(coord) for coord in key.split('_')] + x0, y0, x1, y1 = [int(coord) for coord in key.split("_")] mask = archive[key] obj = skimage.measure.regionprops(mask.astype(np.uint8))[0] contours = core.get_obj_contours(obj=obj, only_longest_contour=False) self.dataPrepFreeRoiPoints = contours + (int(x0), int(y0)) self.dataPrepFreeRoiLocalMask = mask - self.dataPrepFreeRoiSlice = (slice(y0, y1+1), slice(x0, x1+1)) + self.dataPrepFreeRoiSlice = (slice(y0, y1 + 1), slice(x0, x1 + 1)) self.dataPrepFreeRoiBbox = (y0, x0, y1, x1) - + def clearSegmObjsDataPrepFreeRoi(self, segm_data, is_timelapse=True): local_mask = self.dataPrepFreeRoiLocalMask local_slice = self.dataPrepFreeRoiSlice @@ -2193,7 +2219,7 @@ def clearSegmObjsDataPrepFreeRoi(self, segm_data, is_timelapse=True): for obj in rp: if not np.any(delMask[obj.slice][obj.image]): continue - + lab[obj.slice][obj.image] = 0 segm_data[i] = lab else: @@ -2204,17 +2230,18 @@ def clearSegmObjsDataPrepFreeRoi(self, segm_data, is_timelapse=True): for obj in rp: if not np.any(delMask[obj.slice][obj.image]): continue - + lab[obj.slice][obj.image] = 0 segm_data = lab - + return segm_data - + def getSpotmaxSingleSpotsfiles(self): from spotmax import DFs_FILENAMES - spotmax_files = myutils.listdir(self.spotmax_out_path) + + spotmax_files = utils.listdir(self.spotmax_out_path) patterns = [ - filename.replace('*rn*', '').replace('*desc*', '') + filename.replace("*rn*", "").replace("*desc*", "") for filename in DFs_FILENAMES.values() ] valid_files = [] @@ -2222,7 +2249,7 @@ def getSpotmaxSingleSpotsfiles(self): filepath = os.path.join(self.spotmax_out_path, file) if not os.path.isfile(filepath): continue - if file.endswith('aggregated.csv'): + if file.endswith("aggregated.csv"): continue for pattern in patterns: if file.find(pattern) != -1: @@ -2230,24 +2257,26 @@ def getSpotmaxSingleSpotsfiles(self): else: continue valid_files.append(file) - + return reversed(valid_files) def askBooleanSegm(self): segmFilename = os.path.basename(self.segm_npz_path) msg = widgets.myMessageBox() txt = html_utils.paragraph( - f'The loaded segmentation file

    ' + f"The loaded segmentation file

    " f'"{segmFilename}"

    ' - 'has boolean data type.

    ' - 'To correctly load it, Cell-ACDC needs to convert it ' - 'to integer data type.

    ' - 'Do you want to label the mask to separate the objects ' - '(recommended) or do you want to keep one single object?
    ' + "has boolean data type.

    " + "To correctly load it, Cell-ACDC needs to convert it " + "to integer data type.

    " + "Do you want to label the mask to separate the objects " + "(recommended) or do you want to keep one single object?
    " ) - LabelButton, _ = msg.question( - self.parent, 'Boolean segmentation mask?', txt, - buttonsTexts=('Label (recommended)', 'Keep single object') + LabelButton, _ = msg.question( + self.parent, + "Boolean segmentation mask?", + txt, + buttonsTexts=("Label (recommended)", "Keep single object"), ) if msg.clickedButton == LabelButton: self.labelBoolSegm = True @@ -2273,64 +2302,62 @@ def labelSegmData(self): self.segm_data = self.segm_data.astype(np.uint32) def setFilePaths(self, new_endname): - if self.basename.endswith('_'): + if self.basename.endswith("_"): basename = self.basename else: - basename = f'{self.basename}_' + basename = f"{self.basename}_" if new_endname: - segm_new_filename = f'{basename}segm_{new_endname}.npz' - acdc_output_filename = f'{basename}acdc_output_{new_endname}.csv' + segm_new_filename = f"{basename}segm_{new_endname}.npz" + acdc_output_filename = f"{basename}acdc_output_{new_endname}.csv" else: - segm_new_filename = f'{basename}segm.npz' - acdc_output_filename = f'{basename}acdc_output.csv' - + segm_new_filename = f"{basename}segm.npz" + acdc_output_filename = f"{basename}acdc_output.csv" + filePath = os.path.join(self.images_path, segm_new_filename) self.segm_npz_path = filePath filePath = os.path.join(self.images_path, acdc_output_filename) self.acdc_output_csv_path = filePath - - def fromTrackerToAcdcDf( - self, tracker, tracked_video, save=False, start_frame_i=0 - ): - cca_dfs_attr = hasattr(tracker, 'cca_dfs') - cca_dfs_auto_attr = hasattr(tracker, 'cca_dfs_auto') - if hasattr(tracker, 'tracked_lost_centroids'): + def fromTrackerToAcdcDf(self, tracker, tracked_video, save=False, start_frame_i=0): + cca_dfs_attr = hasattr(tracker, "cca_dfs") + cca_dfs_auto_attr = hasattr(tracker, "cca_dfs_auto") + + if hasattr(tracker, "tracked_lost_centroids"): self.saveTrackedLostCentroids(tracker.tracked_lost_centroids) if not cca_dfs_attr and not cca_dfs_auto_attr: return - + if cca_dfs_attr: - end_frame_i = start_frame_i+len(tracker.cca_dfs) + end_frame_i = start_frame_i + len(tracker.cca_dfs) keys = list(range(start_frame_i, end_frame_i)) - acdc_df = pd.concat(tracker.cca_dfs, keys=keys, names=['frame_i']) + acdc_df = pd.concat(tracker.cca_dfs, keys=keys, names=["frame_i"]) else: - end_frame_i = start_frame_i+len(tracker.cca_dfs_auto) + end_frame_i = start_frame_i + len(tracker.cca_dfs_auto) keys = list(range(start_frame_i, end_frame_i)) - acdc_df = pd.concat(tracker.cca_dfs_auto, keys=keys, names=['frame_i']) + acdc_df = pd.concat(tracker.cca_dfs_auto, keys=keys, names=["frame_i"]) - acdc_df['is_cell_dead'] = 0 - acdc_df['is_cell_excluded'] = 0 - acdc_df['was_manually_edited'] = 0 - acdc_df['x_centroid'] = 0 - acdc_df['y_centroid'] = 0 + acdc_df["is_cell_dead"] = 0 + acdc_df["is_cell_excluded"] = 0 + acdc_df["was_manually_edited"] = 0 + acdc_df["x_centroid"] = 0 + acdc_df["y_centroid"] = 0 for i, lab in enumerate(tracked_video): frame_i = start_frame_i + i rp = skimage.measure.regionprops(lab) for obj in rp: centroid = obj.centroid yc, xc = obj.centroid[-2:] - acdc_df.at[(frame_i, obj.label), 'x_centroid'] = int(xc) - acdc_df.at[(frame_i, obj.label), 'y_centroid'] = int(yc) + acdc_df.at[(frame_i, obj.label), "x_centroid"] = int(xc) + acdc_df.at[(frame_i, obj.label), "y_centroid"] = int(yc) if len(centroid) == 3: - if 'z_centroid' not in acdc_df.columns: - acdc_df['z_centroid'] = 0 + if "z_centroid" not in acdc_df.columns: + acdc_df["z_centroid"] = 0 zc = obj.centroid[0] - acdc_df.at[(frame_i, obj.label), 'z_centroid'] = int(zc) + acdc_df.at[(frame_i, obj.label), "z_centroid"] = int(zc) if not save: return acdc_df @@ -2343,29 +2370,29 @@ def fromTrackerToAcdcDf( acdc_df.to_csv(self.acdc_output_auto_csv_path) def getAcdcDfEndname(self): - if not hasattr(self, 'acdc_output_csv_path'): + if not hasattr(self, "acdc_output_csv_path"): return - - if not hasattr(self, 'basename'): + + if not hasattr(self, "basename"): return - + filename = os.path.basename(self.acdc_output_csv_path) filename, _ = os.path.splitext(filename) - endname = filename[len(self.basename):].lstrip('_') + endname = filename[len(self.basename) :].lstrip("_") return endname - + def getSegmEndname(self): - if not hasattr(self, 'segm_npz_path'): + if not hasattr(self, "segm_npz_path"): return - - if not hasattr(self, 'basename'): + + if not hasattr(self, "basename"): return - + filename = os.path.basename(self.segm_npz_path) filename, _ = os.path.splitext(filename) - endname = filename[len(self.basename):].lstrip('_') + endname = filename[len(self.basename) :].lstrip("_") return endname - + def getCustomAnnotatedIDs(self): self.customAnnotIDs = {} @@ -2387,7 +2414,7 @@ def getCustomAnnotatedIDs(self): self.acdc_df[name] = 0 for frame_i, df in self.acdc_df.groupby(level=0): series = df[name] - series = series[series>0] + series = series[series > 0] annotatedIDs = list(series.index.get_level_values(1).unique()) self.customAnnotIDs[name][frame_i] = annotatedIDs @@ -2395,12 +2422,14 @@ def isCropped(self): if self.dataPrep_ROIcoords is None: return False df = self.dataPrep_ROIcoords - _isCropped = any([ - df_roi.at[(roi_id, 'cropped'), 'value'] > 0 - for roi_id, df_roi in df.groupby(level=0) - ]) + _isCropped = any( + [ + df_roi.at[(roi_id, "cropped"), "value"] > 0 + for roi_id, df_roi in df.groupby(level=0) + ] + ) return _isCropped - + def getIsSegm3D(self): if self.SizeZ == 1: return False @@ -2411,7 +2440,7 @@ def getIsSegm3D(self): if not self.segmFound: return - if hasattr(self, 'img_data'): + if hasattr(self, "img_data"): return self.segm_data.ndim == self.img_data.ndim else: if self.SizeT > 1: @@ -2420,142 +2449,132 @@ def getIsSegm3D(self): return self.segm_data.ndim == 3 def getBytesImageData(self): - if not hasattr(self, 'img_data'): + if not hasattr(self, "img_data"): return 0 - + return sys.getsizeof(self.img_data) - + def extractMetadata(self): - self.metadata_df['values'] = self.metadata_df['values'].astype(str) - if 'SizeT' in self.metadata_df.index: - self.SizeT = float(self.metadata_df.at['SizeT', 'values']) + self.metadata_df["values"] = self.metadata_df["values"].astype(str) + if "SizeT" in self.metadata_df.index: + self.SizeT = float(self.metadata_df.at["SizeT", "values"]) self.SizeT = int(self.SizeT) - elif self.last_md_df is not None and 'SizeT' in self.last_md_df.index: - self.SizeT = float(self.last_md_df.at['SizeT', 'values']) + elif self.last_md_df is not None and "SizeT" in self.last_md_df.index: + self.SizeT = float(self.last_md_df.at["SizeT", "values"]) self.SizeT = int(self.SizeT) else: self.SizeT = 1 self.SizeZ_found = False - if 'SizeZ' in self.metadata_df.index: - self.SizeZ = float(self.metadata_df.at['SizeZ', 'values']) + if "SizeZ" in self.metadata_df.index: + self.SizeZ = float(self.metadata_df.at["SizeZ", "values"]) self.SizeZ = int(self.SizeZ) self.SizeZ_found = True - elif self.last_md_df is not None and 'SizeZ' in self.last_md_df.index: - self.SizeZ = float(self.last_md_df.at['SizeZ', 'values']) + elif self.last_md_df is not None and "SizeZ" in self.last_md_df.index: + self.SizeZ = float(self.last_md_df.at["SizeZ", "values"]) self.SizeZ = int(self.SizeZ) else: self.SizeZ = 1 - if 'SizeY' in self.metadata_df.index: - self.SizeY = float(self.metadata_df.at['SizeY', 'values']) + if "SizeY" in self.metadata_df.index: + self.SizeY = float(self.metadata_df.at["SizeY", "values"]) self.SizeY = int(self.SizeY) - self.SizeX = float(self.metadata_df.at['SizeX', 'values']) + self.SizeX = float(self.metadata_df.at["SizeX", "values"]) self.SizeX = int(self.SizeX) else: - if hasattr(self, 'img_data_shape'): + if hasattr(self, "img_data_shape"): self.SizeY, self.SizeX = self.img_data_shape[-2:] else: self.SizeY, self.SizeX = 1, 1 self.isSegm3D = False - if hasattr(self, 'segm_npz_path'): + if hasattr(self, "segm_npz_path"): segmEndName = self.getSegmEndname() - isSegm3Dkey = f'{segmEndName}_isSegm3D' + isSegm3Dkey = f"{segmEndName}_isSegm3D" if isSegm3Dkey in self.metadata_df.index: - isSegm3D = str(self.metadata_df.at[isSegm3Dkey, 'values']) - self.isSegm3D = isSegm3D.lower() == 'true' + isSegm3D = str(self.metadata_df.at[isSegm3Dkey, "values"]) + self.isSegm3D = isSegm3D.lower() == "true" - if 'TimeIncrement' in self.metadata_df.index: - self.TimeIncrement = float( - self.metadata_df.at['TimeIncrement', 'values'] - ) - elif self.last_md_df is not None and 'TimeIncrement' in self.last_md_df.index: - self.TimeIncrement = float(self.last_md_df.at['TimeIncrement', 'values']) + if "TimeIncrement" in self.metadata_df.index: + self.TimeIncrement = float(self.metadata_df.at["TimeIncrement", "values"]) + elif self.last_md_df is not None and "TimeIncrement" in self.last_md_df.index: + self.TimeIncrement = float(self.last_md_df.at["TimeIncrement", "values"]) else: self.TimeIncrement = 1 - if 'PhysicalSizeX' in self.metadata_df.index: - self.PhysicalSizeX = float( - self.metadata_df.at['PhysicalSizeX', 'values'] - ) - elif self.last_md_df is not None and 'PhysicalSizeX' in self.last_md_df.index: - self.PhysicalSizeX = float(self.last_md_df.at['PhysicalSizeX', 'values']) + if "PhysicalSizeX" in self.metadata_df.index: + self.PhysicalSizeX = float(self.metadata_df.at["PhysicalSizeX", "values"]) + elif self.last_md_df is not None and "PhysicalSizeX" in self.last_md_df.index: + self.PhysicalSizeX = float(self.last_md_df.at["PhysicalSizeX", "values"]) else: self.PhysicalSizeX = 1 - if 'PhysicalSizeY' in self.metadata_df.index: - self.PhysicalSizeY = float( - self.metadata_df.at['PhysicalSizeY', 'values'] - ) - elif self.last_md_df is not None and 'PhysicalSizeY' in self.last_md_df.index: - self.PhysicalSizeY = float(self.last_md_df.at['PhysicalSizeY', 'values']) + if "PhysicalSizeY" in self.metadata_df.index: + self.PhysicalSizeY = float(self.metadata_df.at["PhysicalSizeY", "values"]) + elif self.last_md_df is not None and "PhysicalSizeY" in self.last_md_df.index: + self.PhysicalSizeY = float(self.last_md_df.at["PhysicalSizeY", "values"]) else: self.PhysicalSizeY = 1 - if 'PhysicalSizeZ' in self.metadata_df.index: - self.PhysicalSizeZ = float( - self.metadata_df.at['PhysicalSizeZ', 'values'] - ) - elif self.last_md_df is not None and 'PhysicalSizeZ' in self.last_md_df.index: - self.PhysicalSizeZ = float(self.last_md_df.at['PhysicalSizeZ', 'values']) + if "PhysicalSizeZ" in self.metadata_df.index: + self.PhysicalSizeZ = float(self.metadata_df.at["PhysicalSizeZ", "values"]) + elif self.last_md_df is not None and "PhysicalSizeZ" in self.last_md_df.index: + self.PhysicalSizeZ = float(self.last_md_df.at["PhysicalSizeZ", "values"]) else: self.PhysicalSizeZ = 1 - if 'LensNA' in self.metadata_df.index: - self.numAperture = float( - self.metadata_df.at['LensNA', 'values'] - ) + if "LensNA" in self.metadata_df.index: + self.numAperture = float(self.metadata_df.at["LensNA", "values"]) else: self.numAperture = 1.4 - - emWavelenMask = self.metadata_df.index.str.contains(r'_emWavelen') + + emWavelenMask = self.metadata_df.index.str.contains(r"_emWavelen") df_emWavelens = self.metadata_df[emWavelenMask] self.emWavelens = {} try: for channel_i_emWavelen, emWavelen in df_emWavelens.itertuples(): - channel_i_name = channel_i_emWavelen.replace('_emWavelen', '_name') - chName = self.metadata_df.at[channel_i_name, 'values'] + channel_i_name = channel_i_emWavelen.replace("_emWavelen", "_name") + chName = self.metadata_df.at[channel_i_name, "values"] self.emWavelens[chName] = float(emWavelen) except Exception as e: pass - + self._additionalMetadataValues = {} for name in self.metadata_df.index: - if name.startswith('__') and len(name) > 2: - value = self.metadata_df.at[name, 'values'] + if name.startswith("__") and len(name) > 2: + value = self.metadata_df.at[name, "values"] self._additionalMetadataValues[name] = value - + if not self._additionalMetadataValues: # Load metadata values saved in temp folder if os.path.exists(additional_metadata_path): self._additionalMetadataValues = read_json( - additional_metadata_path, desc='additional metadata' + additional_metadata_path, desc="additional metadata" ) def saveIsSegm3Dmetadata(self, segm_npz_path): segmFilename = os.path.basename(segm_npz_path) segmFilename = os.path.splitext(segmFilename)[0] - segmEndName = segmFilename[len(self.basename):] - isSegm3Dkey = f'{segmEndName}_isSegm3D' - self.metadata_df.at[isSegm3Dkey, 'values'] = self.isSegm3D + segmEndName = segmFilename[len(self.basename) :] + isSegm3Dkey = f"{segmEndName}_isSegm3D" + self.metadata_df.at[isSegm3Dkey, "values"] = self.isSegm3D self.metadata_df.to_csv(self.metadata_csv_path) - + def additionalMetadataValues(self): additionalMetadataValues = {} for name in self.metadata_df.index: - if name.startswith('__'): - value = self.metadata_df.at[name, 'values'] - key = name.replace('__', '', 1) + if name.startswith("__"): + value = self.metadata_df.at[name, "values"] + key = name.replace("__", "", 1) additionalMetadataValues[key] = value return additionalMetadataValues - + def add_tree_cols_to_cca_df(self, cca_df, frame_i=None): cca_df = cca_df.sort_index().reset_index() if self.acdc_df is None: return cca_df - + if frame_i is not None: df = self.acdc_df.loc[frame_i].sort_index().reset_index() else: @@ -2563,10 +2582,10 @@ def add_tree_cols_to_cca_df(self, cca_df, frame_i=None): cols = cca_df.columns.to_list() for col in df.columns: - if not col.endswith('tree'): + if not col.endswith("tree"): continue - ref_col = col[:col.find('_tree')] + ref_col = col[: col.find("_tree")] if ref_col in cols: ref_col_idx = cols.index(ref_col) + 1 else: @@ -2576,20 +2595,20 @@ def add_tree_cols_to_cca_df(self, cca_df, frame_i=None): cca_df[col] = df[col] else: cca_df.insert(ref_col_idx, col, df[col]) - + return cca_df - + def getManualBackgroudDataFilepath(self): segmFilename = os.path.basename(self.segm_npz_path) - segmEndname = segmFilename[len(self.basename):] - manualBackgrEndname = segmEndname.replace('segm', 'manualBackground') - manualBackgrFilename = f'{self.basename}{manualBackgrEndname}' + segmEndname = segmFilename[len(self.basename) :] + manualBackgrEndname = segmEndname.replace("segm", "manualBackground") + manualBackgrFilename = f"{self.basename}{manualBackgrEndname}" filepath = os.path.join(self.images_path, manualBackgrFilename) return filepath def saveManualBackgroundData(self, data: np.ndarray): if data is None: - return + return filepath = self.getManualBackgroudDataFilepath() io.savez_compressed(filepath, data) @@ -2600,30 +2619,30 @@ def loadManualBackgroundData(self): return archive = np.load(filepath) self.manualBackgroundLab = archive[archive.files[0]] - + def setNotFoundData(self): if self.segmFound is not None and not self.segmFound: self.segm_data = None # Segmentation file not found and a specifc one was requested # --> set the path - if hasattr(self, '_segm_end_fn'): - if self.basename.endswith('_'): + if hasattr(self, "_segm_end_fn"): + if self.basename.endswith("_"): basename = self.basename else: - basename = f'{self.basename}_' + basename = f"{self.basename}_" base_path = os.path.join(self.images_path, basename) - self.segm_npz_path = f'{base_path}{self._segm_end_fn}.npz' + self.segm_npz_path = f"{base_path}{self._segm_end_fn}.npz" if self.acdc_df_found is not None and not self.acdc_df_found: self.acdc_df = None # Set the file path for selected acdc_output.csv file # since it was not found - if hasattr(self, '_acdc_df_end_fn'): - if self.basename.endswith('_'): + if hasattr(self, "_acdc_df_end_fn"): + if self.basename.endswith("_"): basename = self.basename else: - basename = f'{self.basename}_' + basename = f"{self.basename}_" base_path = os.path.join(self.images_path, basename) - self.acdc_output_csv_path = f'{base_path}{self._acdc_df_end_fn}' + self.acdc_output_csv_path = f"{base_path}{self._acdc_df_end_fn}" if self.shiftsFound is not None and not self.shiftsFound: self.loaded_shifts = None if self.segmInfoFound is not None and not self.segmInfoFound: @@ -2638,7 +2657,10 @@ def setNotFoundData(self): if self.bkgrDataExists: # Do not load bkgrROIs if bkgrDataFound to avoid addMetrics to use it self.bkgrROIs = [] - if self.dataPrep_ROIcoordsFound is not None and not self.dataPrep_ROIcoordsFound: + if ( + self.dataPrep_ROIcoordsFound is not None + and not self.dataPrep_ROIcoordsFound + ): self.dataPrep_ROIcoords = None if self.last_tracked_i_found is not None and not self.last_tracked_i_found: self.last_tracked_i = None @@ -2656,7 +2678,7 @@ def setNotFoundData(self): if self.metadataFound: return - if hasattr(self, 'img_data'): + if hasattr(self, "img_data"): if self.img_data.ndim == 3: if len(self.img_data) > 49: self.SizeT, self.SizeZ = len(self.img_data), 1 @@ -2668,7 +2690,7 @@ def setNotFoundData(self): self.SizeT, self.SizeZ = 1, 1 else: self.SizeT, self.SizeZ = 1, 1 - + try: self.SizeY, self.SizeX = self.img_data_shape[-2:] except Exception as e: @@ -2692,127 +2714,114 @@ def setNotFoundData(self): # self.SizeT = int(self.last_md_df.at['SizeT', 'values']) # if 'SizeZ' in self.last_md_df.index and self.SizeZ == 1: # self.SizeZ = int(self.last_md_df.at['SizeZ', 'values']) - if 'TimeIncrement' in self.last_md_df.index: - self.TimeIncrement = float( - self.last_md_df.at['TimeIncrement', 'values'] - ) - if 'PhysicalSizeX' in self.last_md_df.index: - self.PhysicalSizeX = float( - self.last_md_df.at['PhysicalSizeX', 'values'] - ) - if 'PhysicalSizeY' in self.last_md_df.index: - self.PhysicalSizeY = float( - self.last_md_df.at['PhysicalSizeY', 'values'] - ) - if 'PhysicalSizeZ' in self.last_md_df.index: - self.PhysicalSizeZ = float( - self.last_md_df.at['PhysicalSizeZ', 'values'] - ) + if "TimeIncrement" in self.last_md_df.index: + self.TimeIncrement = float(self.last_md_df.at["TimeIncrement", "values"]) + if "PhysicalSizeX" in self.last_md_df.index: + self.PhysicalSizeX = float(self.last_md_df.at["PhysicalSizeX", "values"]) + if "PhysicalSizeY" in self.last_md_df.index: + self.PhysicalSizeY = float(self.last_md_df.at["PhysicalSizeY", "values"]) + if "PhysicalSizeZ" in self.last_md_df.index: + self.PhysicalSizeZ = float(self.last_md_df.at["PhysicalSizeZ", "values"]) def preprocessedDataArray(self, check_integrity=True): - if not hasattr(self, 'preproc_img_data'): + if not hasattr(self, "preproc_img_data"): return - + preprocess_data = [] for frame_i, raw_img in enumerate(self.img_data): preprocess_img = self.preproc_img_data.get(frame_i) if preprocess_img is None: if check_integrity: - raise TypeError( - 'Not all frames have been processed.' - ) + raise TypeError("Not all frames have been processed.") else: continue - + preprocess_img = np.squeeze(preprocess_img) - preprocess_data.append(preprocess_img) - + preprocess_data.append(preprocess_img) + preprocess_data_arr = np.array(preprocess_data) return preprocess_data_arr - + def combinedChannelsDataArray(self, check_integrity=True): - if not hasattr(self, 'combine_img_data'): + if not hasattr(self, "combine_img_data"): return - + combined_channels_data = [] for frame_i, raw_img in enumerate(self.img_data): combined_channels_img = self.combine_img_data.get(frame_i) if combined_channels_img is None: if check_integrity: - raise TypeError( - 'Not all frames have been processed.' - ) + raise TypeError("Not all frames have been processed.") else: continue - + combined_channels_img = np.squeeze(combined_channels_img) - combined_channels_data.append(combined_channels_img) - + combined_channels_data.append(combined_channels_img) + combined_channels_data_arr = np.array(combined_channels_data) return combined_channels_data_arr - + def addEquationCombineMetrics(self, equation, colName, isMixedChannels): - section = 'mixed_channels_equations' if isMixedChannels else 'equations' + section = "mixed_channels_equations" if isMixedChannels else "equations" self.combineMetricsConfig[section][colName] = equation - def setCombineMetricsConfig(self, ini_path=''): + def setCombineMetricsConfig(self, ini_path=""): if ini_path: configPars = config.ConfigParser() configPars.read(ini_path) else: configPars = config.ConfigParser() - if 'equations' not in configPars: - configPars['equations'] = {} + if "equations" not in configPars: + configPars["equations"] = {} - if 'mixed_channels_equations' not in configPars: - configPars['mixed_channels_equations'] = {} + if "mixed_channels_equations" not in configPars: + configPars["mixed_channels_equations"] = {} - if 'user_path_equations' not in configPars: - configPars['user_path_equations'] = {} + if "user_path_equations" not in configPars: + configPars["user_path_equations"] = {} # Append channel specific equations from the user_profile_path ini file - userPathChEquations = configPars['user_path_equations'] + userPathChEquations = configPars["user_path_equations"] for chName in self.chNames: - chName_equations = measurements.get_user_combine_metrics_equations( - chName - ) + chName_equations = measurements.get_user_combine_metrics_equations(chName) chName_equations = { - key:val for key, val in chName_equations.items() - if key not in configPars['equations'] + key: val + for key, val in chName_equations.items() + if key not in configPars["equations"] } userPathChEquations = {**userPathChEquations, **chName_equations} - configPars['user_path_equations'] = userPathChEquations + configPars["user_path_equations"] = userPathChEquations # Append mixed channels equations from the user_profile_path ini file - configPars['mixed_channels_equations'] = { - **configPars['mixed_channels_equations'], - **measurements.get_user_combine_mixed_channels_equations() + configPars["mixed_channels_equations"] = { + **configPars["mixed_channels_equations"], + **measurements.get_user_combine_mixed_channels_equations(), } self.combineMetricsConfig = configPars def saveCombineMetrics(self): - with open(self.custom_combine_metrics_path, 'w') as configfile: + with open(self.custom_combine_metrics_path, "w") as configfile: self.combineMetricsConfig.write(configfile) - + def saveClickEntryPointsDfs(self): for tableEndName, df in self.clickEntryPointsDfs.items(): - if not self.basename.endswith('_'): - basename = f'{self.basename}_' + if not self.basename.endswith("_"): + basename = f"{self.basename}_" else: basename = self.basename - tableFilename = f'{basename}{tableEndName}.csv' + tableFilename = f"{basename}{tableEndName}.csv" tableFilepath = os.path.join(self.images_path, tableFilename) - df = df.sort_values(['frame_i', 'Cell_ID']) + df = df.sort_values(["frame_i", "Cell_ID"]) df.to_csv(tableFilepath, index=False) def check_acdc_df_integrity(self): check = ( - self.acdc_df_found is not None # acdc_df was laoded if present - and self.acdc_df is not None # acdc_df was present - and self.segmFound is not None # segm data was loaded if present - and self.segm_data is not None # segm data was present + self.acdc_df_found is not None # acdc_df was laoded if present + and self.acdc_df is not None # acdc_df was present + and self.segmFound is not None # segm data was loaded if present + and self.segm_data is not None # segm data was present ) if check: if self.SizeT > 1: @@ -2846,177 +2855,172 @@ def _fix_acdc_df(self, lab, frame_i=0): continue self.acdc_df.at[idx, col] = val y, x = obj.centroid - self.acdc_df.at[idx, 'x_centroid'] = x - self.acdc_df.at[idx, 'y_centroid'] = y + self.acdc_df.at[idx, "x_centroid"] = x + self.acdc_df.at[idx, "y_centroid"] = y def getSegmEndname(self): segmFilename = os.path.basename(self.segm_npz_path) segmFilename = os.path.splitext(segmFilename)[0] - segmEndName = segmFilename[len(self.basename):] + segmEndName = segmFilename[len(self.basename) :] return segmEndName - + def getSegmentedChannelHyperparams(self): run_num = self.getSegmHyperparamsNewRunNumber() cp = config.ConfigParser() if os.path.exists(self.segm_hyperparams_ini_path): cp.read(self.segm_hyperparams_ini_path) segmEndName = self.getSegmEndname() - metadata_section = f'{segmEndName}.metadata.run_number_{run_num}' + metadata_section = f"{segmEndName}.metadata.run_number_{run_num}" section = segmEndName - option = 'segmented_channel' - channel_name = cp.get( - metadata_section, option, fallback=self.user_ch_name - ) + option = "segmented_channel" + channel_name = cp.get(metadata_section, option, fallback=self.user_ch_name) return channel_name, segmEndName else: - return self.user_ch_name, '' - + return self.user_ch_name, "" + def getSegmHyperparamsNewRunNumber(self): run_num = 1 if not os.path.exists(self.segm_hyperparams_ini_path): return run_num - + cp = config.ConfigParser() cp.read(self.segm_hyperparams_ini_path) segmEndName = self.getSegmEndname() - metadata_section = f'{segmEndName}.metadata' + metadata_section = f"{segmEndName}.metadata" for section in cp.sections(): if section.startswith(metadata_section): run_num += 1 - + return run_num - + def updateSegmentedChannelHyperparams(self, channelName): if not os.path.exists(self.segm_hyperparams_ini_path): return - + cp = config.ConfigParser() cp.read(self.segm_hyperparams_ini_path) segmEndName = self.getSegmEndname() run_num = self.getSegmHyperparamsNewRunNumber() - metadata_section = f'{segmEndName}.metadata.run_number_{run_num}' + metadata_section = f"{segmEndName}.metadata.run_number_{run_num}" if metadata_section not in cp.sections(): return - - option = 'segmented_channel' + + option = "segmented_channel" cp[metadata_section][option] = channelName - with open(self.segm_hyperparams_ini_path, 'w') as configfile: + with open(self.segm_hyperparams_ini_path, "w") as configfile: cp.write(configfile) def saveSegmHyperparams( - self, model_name, init_kwargs, segment_kwargs, - post_process_params=None, - preproc_recipe=None - ): + self, + model_name, + init_kwargs, + segment_kwargs, + post_process_params=None, + preproc_recipe=None, + ): cp = config.ConfigParser() if os.path.exists(self.segm_hyperparams_ini_path): cp.read(self.segm_hyperparams_ini_path) - + segmEndName = self.getSegmEndname() - + # Remove old sections if present cp.remove_section(segmEndName) - + segm_filename = os.path.basename(self.segm_npz_path) - + run_num = self.getSegmHyperparamsNewRunNumber() - metadata_section = f'{segmEndName}.metadata' - metadata_section = f'{metadata_section}.run_number_{run_num}' + metadata_section = f"{segmEndName}.metadata" + metadata_section = f"{metadata_section}.run_number_{run_num}" cp[metadata_section] = {} - - cp[metadata_section]['segmentation_filename'] = segm_filename - cp[metadata_section]['segmented_channel'] = self.user_ch_name - now = datetime.now().strftime(r'%Y-%m-%d %H:%M:%S.%u') - cp[metadata_section]['segmented_on'] = now - cp[metadata_section]['model_name'] = model_name - - init_section = f'{segmEndName}.init' - init_section = f'{init_section}.run_number_{run_num}' - + + cp[metadata_section]["segmentation_filename"] = segm_filename + cp[metadata_section]["segmented_channel"] = self.user_ch_name + now = datetime.now().strftime(r"%Y-%m-%d %H:%M:%S.%u") + cp[metadata_section]["segmented_on"] = now + cp[metadata_section]["model_name"] = model_name + + init_section = f"{segmEndName}.init" + init_section = f"{init_section}.run_number_{run_num}" + cp[init_section] = {} for key, value in init_kwargs.items(): cp[init_section][key] = str(value) - - segment_section = f'{segmEndName}.segment' - segment_section = f'{segment_section}.run_number_{run_num}' + + segment_section = f"{segmEndName}.segment" + segment_section = f"{segment_section}.run_number_{run_num}" cp[segment_section] = {} for key, value in segment_kwargs.items(): cp[segment_section][key] = str(value) if post_process_params is not None: - post_process_section = f'{segmEndName}.postprocess' - post_process_section = f'{post_process_section}.run_number_{run_num}' + post_process_section = f"{segmEndName}.postprocess" + post_process_section = f"{post_process_section}.run_number_{run_num}" cp[post_process_section] = {} for key, value in post_process_params.items(): cp[post_process_section][key] = str(value) if preproc_recipe is not None: - preproc_ini_items = config.preprocess_recipe_to_ini_items( - preproc_recipe - ) + preproc_ini_items = config.preprocess_recipe_to_ini_items(preproc_recipe) for preproc_section, section_items in preproc_ini_items.items(): - segm_preproc_section = f'{segmEndName}.{preproc_section}' - segm_preproc_section = ( - f'{segm_preproc_section}.run_number_{run_num}' - ) + segm_preproc_section = f"{segmEndName}.{preproc_section}" + segm_preproc_section = f"{segm_preproc_section}.run_number_{run_num}" cp[segm_preproc_section] = {} for key, value in section_items.items(): cp[segm_preproc_section][key] = str(value) - - with open(self.segm_hyperparams_ini_path, 'w') as configfile: + + with open(self.segm_hyperparams_ini_path, "w") as configfile: cp.write(configfile) - + def isRecoveredAcdcDfPresent(self): recovery_folderpath = self.recoveryFolderpath() - unsaved_recovery_folderpath = os.path.join( - recovery_folderpath, 'never_saved' - ) + unsaved_recovery_folderpath = os.path.join(recovery_folderpath, "never_saved") if not os.path.exists(unsaved_recovery_folderpath): return - - files = myutils.listdir(unsaved_recovery_folderpath) - csv_files = [file for file in files if file.endswith('.csv')] + + files = utils.listdir(unsaved_recovery_folderpath) + csv_files = [file for file in files if file.endswith(".csv")] if not csv_files: return - + if not os.path.exists(self.acdc_output_csv_path): acdc_df_mtime = 0 else: acdc_df_mtime = os.path.getmtime(self.acdc_output_csv_path) - + acdc_df_mdatetime = datetime.fromtimestamp(acdc_df_mtime) - + csv_files = natsorted(csv_files) iso_key = csv_files[-1][:-4] most_recent_unsaved_acdc_df_datetime = datetime.strptime( iso_key, ISO_TIMESTAMP_FORMAT ) return most_recent_unsaved_acdc_df_datetime > acdc_df_mdatetime - + def isSafeNpzOverwritePresent(self): - if not hasattr(self, 'segm_npz_path'): + if not hasattr(self, "segm_npz_path"): return False - - safe_npz_path = self.segm_npz_path.replace('.npz', '.new.npz') + + safe_npz_path = self.segm_npz_path.replace(".npz", ".new.npz") return os.path.exists(safe_npz_path) - + def getSafeNpzOverwritePath(self): - if not hasattr(self, 'segm_npz_path'): + if not hasattr(self, "segm_npz_path"): return - - safe_npz_path = self.segm_npz_path.replace('.npz', '.new.npz') + + safe_npz_path = self.segm_npz_path.replace(".npz", ".new.npz") return safe_npz_path - + def recoveryFolderpath(self, create_if_missing=True): - recovery_folder = os.path.join(self.images_path, 'recovery') + recovery_folder = os.path.join(self.images_path, "recovery") if not os.path.exists(recovery_folder) and create_if_missing: os.mkdir(recovery_folder) return recovery_folder - + def setTempPaths(self, createFolder=True): - temp_folder = os.path.join(self.images_path, 'recovery') + temp_folder = os.path.join(self.images_path, "recovery") self.recoveryFolderPath = temp_folder if not os.path.exists(temp_folder) and createFolder: os.mkdir(temp_folder) @@ -3024,63 +3028,59 @@ def setTempPaths(self, createFolder=True): acdc_df_filename = os.path.basename(self.acdc_output_csv_path) self.segm_npz_temp_path = os.path.join(temp_folder, segm_filename) self.acdc_output_backup_zip_path = os.path.join( - temp_folder, acdc_df_filename.replace('.csv', '.zip') - ) - unsaved_acdc_df_filename = acdc_df_filename.replace( - '.csv', '_autosave.zip' + temp_folder, acdc_df_filename.replace(".csv", ".zip") ) + unsaved_acdc_df_filename = acdc_df_filename.replace(".csv", "_autosave.zip") self.unsaved_acdc_df_autosave_path = os.path.join( temp_folder, unsaved_acdc_df_filename ) - + def buildPaths(self): - if self.basename.endswith('_'): + if self.basename.endswith("_"): basename = self.basename else: - basename = f'{self.basename}_' + basename = f"{self.basename}_" base_path = os.path.join(self.images_path, basename) - self.slice_used_align_path = f'{base_path}slice_used_alignment.csv' - self.slice_used_segm_path = f'{base_path}slice_segm.csv' - self.align_npz_path = f'{base_path}{self.user_ch_name}_aligned.npz' - self.align_old_path = f'{base_path}phc_aligned.npy' - self.align_shifts_path = f'{base_path}align_shift.npy' - self.segm_npz_path = f'{base_path}segm.npz' - self.last_tracked_i_path = f'{base_path}last_tracked_i.txt' - self.acdc_output_csv_path = f'{base_path}acdc_output.csv' - self.segmInfo_df_csv_path = f'{base_path}segmInfo.csv' - self.delROIs_info_path = f'{base_path}delROIsInfo.npz' - self.dataPrepROI_coords_path = f'{base_path}dataPrepROIs_coords.csv' + self.slice_used_align_path = f"{base_path}slice_used_alignment.csv" + self.slice_used_segm_path = f"{base_path}slice_segm.csv" + self.align_npz_path = f"{base_path}{self.user_ch_name}_aligned.npz" + self.align_old_path = f"{base_path}phc_aligned.npy" + self.align_shifts_path = f"{base_path}align_shift.npy" + self.segm_npz_path = f"{base_path}segm.npz" + self.last_tracked_i_path = f"{base_path}last_tracked_i.txt" + self.acdc_output_csv_path = f"{base_path}acdc_output.csv" + self.segmInfo_df_csv_path = f"{base_path}segmInfo.csv" + self.delROIs_info_path = f"{base_path}delROIsInfo.npz" + self.dataPrepROI_coords_path = f"{base_path}dataPrepROIs_coords.csv" # self.dataPrepBkgrValues_path = f'{base_path}dataPrep_bkgrValues.csv' - self.dataPrepBkgrROis_path = f'{base_path}dataPrep_bkgrROIs.json' - self.metadata_csv_path = f'{base_path}metadata.csv' - self.mot_events_path = f'{base_path}mot_events' - self.mot_metrics_csv_path = f'{base_path}mot_metrics' - self.raw_segm_npz_path = f'{base_path}segm_raw.npz' - self.raw_postproc_segm_path = f'{base_path}segm_raw_postproc' - self.post_proc_mot_metrics = f'{base_path}post_proc_mot_metrics' - self.segm_hyperparams_ini_path = f'{base_path}segm_hyperparams.ini' - self.custom_annot_json_path = f'{base_path}custom_annot_params.json' - self.custom_combine_metrics_path = ( - f'{base_path}custom_combine_metrics.ini' + self.dataPrepBkgrROis_path = f"{base_path}dataPrep_bkgrROIs.json" + self.metadata_csv_path = f"{base_path}metadata.csv" + self.mot_events_path = f"{base_path}mot_events" + self.mot_metrics_csv_path = f"{base_path}mot_metrics" + self.raw_segm_npz_path = f"{base_path}segm_raw.npz" + self.raw_postproc_segm_path = f"{base_path}segm_raw_postproc" + self.post_proc_mot_metrics = f"{base_path}post_proc_mot_metrics" + self.segm_hyperparams_ini_path = f"{base_path}segm_hyperparams.ini" + self.custom_annot_json_path = f"{base_path}custom_annot_params.json" + self.custom_combine_metrics_path = f"{base_path}custom_combine_metrics.ini" + self.sam_embeddings_path = f"{base_path}{self.user_ch_name}_sam_embeddings.pt" + self.tracked_lost_centroids_json_path = ( + f"{base_path}tracked_lost_centroids.json" ) - self.sam_embeddings_path =( - f'{base_path}{self.user_ch_name}_sam_embeddings.pt' - ) - self.tracked_lost_centroids_json_path = f'{base_path}tracked_lost_centroids.json' - self.acdc_output_auto_csv_path = f'{base_path}acdc_output_auto.csv' - + self.acdc_output_auto_csv_path = f"{base_path}acdc_output_auto.csv" + def get_btrack_export_path(self): - btrack_path = self.segm_npz_path.replace('.npz', '.h5') - btrack_path = btrack_path.replace('_segm', '_btrack_tracks') + btrack_path = self.segm_npz_path.replace(".npz", ".h5") + btrack_path = btrack_path.replace("_segm", "_btrack_tracks") return btrack_path - + def get_tracker_export_path(self, trackerName, ext): - tracker_path = self.segm_npz_path.replace('_segm', f'_{trackerName}_tracks') - tracker_path = tracker_path.replace('.npz', ext) + tracker_path = self.segm_npz_path.replace("_segm", f"_{trackerName}_tracks") + tracker_path = tracker_path.replace(".npz", ext) return tracker_path def setBlankSegmData(self, SizeT, SizeZ, SizeY, SizeX): - if not hasattr(self, 'img_data'): + if not hasattr(self, "img_data"): self.segm_data = None return @@ -3100,18 +3100,18 @@ def loadAllImgPaths(self): npy_paths = [] npz_paths = [] basename = self.basename[0:-1] - for filename in myutils.listdir(self.images_path): + for filename in utils.listdir(self.images_path): file_path = os.path.join(self.images_path, filename) f, ext = os.path.splitext(filename) - m = re.match(fr'{basename}.*\.tif', filename) + m = re.match(rf"{basename}.*\.tif", filename) if m is not None: tif_paths.append(file_path) # Search for npy fluo data - npy = f'{f}_aligned.npy' - npz = f'{f}_aligned.npz' + npy = f"{f}_aligned.npy" + npz = f"{f}_aligned.npz" npy_found = False npz_found = False - for name in myutils.listdir(self.images_path): + for name in utils.listdir(self.images_path): _path = os.path.join(self.images_path, name) if name == npy: npy_paths.append(_path) @@ -3128,15 +3128,15 @@ def loadAllImgPaths(self): self.npz_paths = npz_paths def checkH5memoryFootprint(self): - if self.ext != '.h5': + if self.ext != ".h5": return 0 else: Y, X = self.dset.shape[-2:] - size = self.loadSizeT*self.loadSizeZ*Y*X + size = self.loadSizeT * self.loadSizeZ * Y * X itemsize = self.dset.dtype.itemsize - required_memory = size*itemsize + required_memory = size * itemsize return required_memory - + def _warnMultiPosTimeLapse(self, SizeT_metadata): txt = html_utils.paragraph(f""" You are trying to load multiple Positions of what it seems to be @@ -3150,47 +3150,62 @@ def _warnMultiPosTimeLapse(self, SizeT_metadata): """) msg = widgets.myMessageBox(wrapText=False, showCentered=False) _, noButton, yesButton = msg.warning( - self.parent, 'WARNING: Edinting saved metadata', txt, - buttonsTexts=('Cancel', 'No, stop the process', 'Yes, proceed anyway') + self.parent, + "WARNING: Edinting saved metadata", + txt, + buttonsTexts=("Cancel", "No, stop the process", "Yes, proceed anyway"), ) return msg.clickedButton == yesButton def askInputMetadata( - self, numPos, - ask_SizeT=False, - ask_TimeIncrement=False, - ask_PhysicalSizes=False, - singlePos=False, - save=False, - askSegm3D=True, - forceEnableAskSegm3D=False, - warnMultiPos=False - ): + self, + numPos, + ask_SizeT=False, + ask_TimeIncrement=False, + ask_PhysicalSizes=False, + singlePos=False, + save=False, + askSegm3D=True, + forceEnableAskSegm3D=False, + warnMultiPos=False, + ): from . import apps + SizeZ_metadata = None SizeT_metadata = None - if hasattr(self, 'metadataFound'): + if hasattr(self, "metadataFound"): if self.metadataFound: SizeT_metadata = self.SizeT SizeZ_metadata = self.SizeZ - if SizeT_metadata>1 and numPos>1 and warnMultiPos: + if SizeT_metadata > 1 and numPos > 1 and warnMultiPos: proceed_anyway = self._warnMultiPosTimeLapse(SizeT_metadata) if not proceed_anyway: return False - - basename = '' - if hasattr(self, 'basename'): + + basename = "" + if hasattr(self, "basename"): basename = self.basename metadataWin = apps.QDialogMetadata( - self.SizeT, self.SizeZ, self.TimeIncrement, - self.PhysicalSizeZ, self.PhysicalSizeY, self.PhysicalSizeX, - ask_SizeT, ask_TimeIncrement, ask_PhysicalSizes, - parent=self.parent, font=apps.font, imgDataShape=self.img_data_shape, - posData=self, singlePos=singlePos, askSegm3D=askSegm3D, + self.SizeT, + self.SizeZ, + self.TimeIncrement, + self.PhysicalSizeZ, + self.PhysicalSizeY, + self.PhysicalSizeX, + ask_SizeT, + ask_TimeIncrement, + ask_PhysicalSizes, + parent=self.parent, + font=apps.font, + imgDataShape=self.img_data_shape, + posData=self, + singlePos=singlePos, + askSegm3D=askSegm3D, additionalValues=self._additionalMetadataValues, - forceEnableAskSegm3D=forceEnableAskSegm3D, - SizeT_metadata=SizeT_metadata, SizeZ_metadata=SizeZ_metadata, - basename=basename + forceEnableAskSegm3D=forceEnableAskSegm3D, + SizeT_metadata=SizeT_metadata, + SizeZ_metadata=SizeZ_metadata, + basename=basename, ) metadataWin.exec_() if metadataWin.cancel: @@ -3218,21 +3233,21 @@ def askInputMetadata( self._additionalMetadataValues = metadataWin._additionalValues if save: self.saveMetadata(additionalMetadata=metadataWin._additionalValues) - + metadataWin.deleteLater() return True - - def zSliceSegmentation(self, filename, frame_i, errors='raise'): + + def zSliceSegmentation(self, filename, frame_i, errors="raise"): if self.SizeZ > 1: idx = (filename, frame_i) try: - if self.segmInfo_df.at[idx, 'resegmented_in_gui']: - col = 'z_slice_used_gui' + if self.segmInfo_df.at[idx, "resegmented_in_gui"]: + col = "z_slice_used_gui" else: - col = 'z_slice_used_dataPrep' + col = "z_slice_used_dataPrep" z = self.segmInfo_df.at[idx, col] except Exception as err: - if errors == 'raise': + if errors == "raise": raise err else: return round(self.SizeZ / 2) @@ -3251,19 +3266,17 @@ def metadataToCsv(self, signals=None, mutex=None, waitCond=None): try: self.metadata_df.to_csv(self.metadata_csv_path) except PermissionError: - print('='*20) + print("=" * 20) traceback.print_exc() - print('='*20) + print("=" * 20) permissionErrorTxt = html_utils.paragraph( - f'The below file is open in another app (Excel maybe?).

    ' - f'{self.metadata_csv_path}

    ' + f"The below file is open in another app (Excel maybe?).

    " + f"{self.metadata_csv_path}

    " 'Close file and then press "Ok".' ) if signals is None: msg = widgets.myMessageBox(self.parent) - msg.warning( - self, 'Permission denied', permissionErrorTxt - ) + msg.warning(self, "Permission denied", permissionErrorTxt) self.metadata_df.to_csv(self.metadata_csv_path) else: mutex.lock() @@ -3273,46 +3286,45 @@ def metadataToCsv(self, signals=None, mutex=None, waitCond=None): self.metadata_df.to_csv(self.metadata_csv_path) def saveMetadata( - self, signals=None, mutex=None, waitCond=None, - additionalMetadata=None - ): + self, signals=None, mutex=None, waitCond=None, additionalMetadata=None + ): segmEndName = self.getSegmEndname() - isSegm3Dkey = f'{segmEndName}_isSegm3D' + isSegm3Dkey = f"{segmEndName}_isSegm3D" if self.metadata_df is None: metadata_dict = { - 'SizeT': self.SizeT, - 'SizeZ': self.SizeZ, - 'SizeY': self.SizeY, - 'SizeX': self.SizeX, - 'TimeIncrement': self.TimeIncrement, - 'PhysicalSizeZ': self.PhysicalSizeZ, - 'PhysicalSizeY': self.PhysicalSizeY, - 'PhysicalSizeX': self.PhysicalSizeX, - isSegm3Dkey: self.isSegm3D + "SizeT": self.SizeT, + "SizeZ": self.SizeZ, + "SizeY": self.SizeY, + "SizeX": self.SizeX, + "TimeIncrement": self.TimeIncrement, + "PhysicalSizeZ": self.PhysicalSizeZ, + "PhysicalSizeY": self.PhysicalSizeY, + "PhysicalSizeX": self.PhysicalSizeX, + isSegm3Dkey: self.isSegm3D, } if additionalMetadata is not None: metadata_dict = {**metadata_dict, **additionalMetadata} for key in list(metadata_dict.keys()): - if key.startswith('__') and key not in additionalMetadata: + if key.startswith("__") and key not in additionalMetadata: metadata_dict.pop(key) - self.metadata_df = pd.DataFrame(metadata_dict, index=['values']).T - self.metadata_df.index.name = 'Description' + self.metadata_df = pd.DataFrame(metadata_dict, index=["values"]).T + self.metadata_df.index.name = "Description" else: - self.metadata_df.at['SizeT', 'values'] = self.SizeT - self.metadata_df.at['SizeZ', 'values'] = self.SizeZ - self.metadata_df.at['TimeIncrement', 'values'] = self.TimeIncrement - self.metadata_df.at['PhysicalSizeZ', 'values'] = self.PhysicalSizeZ - self.metadata_df.at['PhysicalSizeY', 'values'] = self.PhysicalSizeY - self.metadata_df.at['PhysicalSizeX', 'values'] = self.PhysicalSizeX - self.metadata_df.at[isSegm3Dkey, 'values'] = self.isSegm3D + self.metadata_df.at["SizeT", "values"] = self.SizeT + self.metadata_df.at["SizeZ", "values"] = self.SizeZ + self.metadata_df.at["TimeIncrement", "values"] = self.TimeIncrement + self.metadata_df.at["PhysicalSizeZ", "values"] = self.PhysicalSizeZ + self.metadata_df.at["PhysicalSizeY", "values"] = self.PhysicalSizeY + self.metadata_df.at["PhysicalSizeX", "values"] = self.PhysicalSizeX + self.metadata_df.at[isSegm3Dkey, "values"] = self.isSegm3D if additionalMetadata is not None: for name, value in additionalMetadata.items(): - self.metadata_df.at[name, 'values'] = value + self.metadata_df.at[name, "values"] = value idx_to_drop = [] for name in self.metadata_df.index: - if name.startswith('__') and name not in additionalMetadata: + if name.startswith("__") and name not in additionalMetadata: idx_to_drop.append(name) self.metadata_df = self.metadata_df.drop(idx_to_drop) @@ -3323,137 +3335,171 @@ def saveMetadata( pass if additionalMetadata is not None: try: - with open(additional_metadata_path, mode='w') as file: + with open(additional_metadata_path, mode="w") as file: json.dump(additionalMetadata, file, indent=2) except PermissionError: pass def criticalExtNotValid(self, signals=None): - err_title = f'File extension {self.ext} not valid.' + err_title = f"File extension {self.ext} not valid." err_msg = ( - f'The requested file {self.relPath}\n' - 'has an invalid extension.\n\n' - 'Valid extensions are .tif, .tiff, .npy or .npz' + f"The requested file {self.relPath}\n" + "has an invalid extension.\n\n" + "Valid extensions are .tif, .tiff, .npy or .npz" ) if self.parent is None: - print('-------------------------') + print("-------------------------") print(err_msg) - print('-------------------------') + print("-------------------------") raise FileNotFoundError(err_title) elif signals is None: - print('-------------------------') + print("-------------------------") print(err_msg) - print('-------------------------') + print("-------------------------") msg = QMessageBox() msg.critical(self.parent, err_title, err_msg, msg.Ok) return None elif signals is not None: raise FileNotFoundError(err_title) - - def saveTrackedLostCentroids(self, tracked_lost_centroids_list=None, _tracked_lost_centroids_list=None): - if not (self.tracked_lost_centroids or tracked_lost_centroids_list or _tracked_lost_centroids_list): + def saveTrackedLostCentroids( + self, tracked_lost_centroids_list=None, _tracked_lost_centroids_list=None + ): + + if not ( + self.tracked_lost_centroids + or tracked_lost_centroids_list + or _tracked_lost_centroids_list + ): return if _tracked_lost_centroids_list is not None: tracked_lost_centroids_list = _tracked_lost_centroids_list elif tracked_lost_centroids_list is not None: - tracked_lost_centroids_list = {k: v for k, v in tracked_lost_centroids_list.items()} + tracked_lost_centroids_list = { + k: v for k, v in tracked_lost_centroids_list.items() + } else: - tracked_lost_centroids_list = {k: list(v) for k, v in self.tracked_lost_centroids.items()} + tracked_lost_centroids_list = { + k: list(v) for k, v in self.tracked_lost_centroids.items() + } # printl(tracked_lost_centroids_list) try: - with open(self.tracked_lost_centroids_json_path, 'w') as json_file: + with open(self.tracked_lost_centroids_json_path, "w") as json_file: json.dump(tracked_lost_centroids_list, json_file, indent=4) except PermissionError: - print('='*20) + print("=" * 20) traceback.print_exc() - print('='*20) + print("=" * 20) permissionErrorTxt = html_utils.paragraph( - f'The below file is open in another app (Excel maybe?).

    ' - f'{self.tracked_lost_centroids_json_path}

    ' + f"The below file is open in another app (Excel maybe?).

    " + f"{self.tracked_lost_centroids_json_path}

    " 'Close file and then press "Ok", or press "Cancel" to abort.' ) msg = widgets.myMessageBox(self.parent) msg.warning( - self, 'Permission denied', permissionErrorTxt, buttonsTexts=('Cancel', 'Ok') + self, + "Permission denied", + permissionErrorTxt, + buttonsTexts=("Cancel", "Ok"), ) if msg.cancel: return - - self.saveTrackedLostCentroids(_tracked_lost_centroids_list=tracked_lost_centroids_list) + + self.saveTrackedLostCentroids( + _tracked_lost_centroids_list=tracked_lost_centroids_list + ) def loadTrackedLostCentroids(self): try: - with open(self.tracked_lost_centroids_json_path, 'r') as json_file: + with open(self.tracked_lost_centroids_json_path, "r") as json_file: tracked_lost_centroids_list = json.load(json_file) - self.tracked_lost_centroids = {int(k): {tuple(int(val) for val in centroid) for centroid in v} for k, v in tracked_lost_centroids_list.items()} + self.tracked_lost_centroids = { + int(k): {tuple(int(val) for val in centroid) for centroid in v} + for k, v in tracked_lost_centroids_list.items() + } except FileNotFoundError: # print(f"No file found at {self.tracked_lost_centroids_json_path}") self.tracked_lost_centroids = { - frame_i:set() for frame_i in range(self.SizeT) - } + frame_i: set() for frame_i in range(self.SizeT) + } except PermissionError: - print('='*20) + print("=" * 20) traceback.print_exc() - print('='*20) + print("=" * 20) permissionErrorTxt = html_utils.paragraph( - f'The below file is open in another app (Excel maybe?).

    ' - f'{self.tracked_lost_centroids_json_path}

    ' + f"The below file is open in another app (Excel maybe?).

    " + f"{self.tracked_lost_centroids_json_path}

    " 'Close file and then press "Ok", or press "Cancel" to abort.' ) msg = widgets.myMessageBox(self.parent) msg.warning( - self, 'Permission denied', permissionErrorTxt, buttonsTexts=('Cancel', 'Ok') + self, + "Permission denied", + permissionErrorTxt, + buttonsTexts=("Cancel", "Ok"), ) if msg.cancel: self.tracked_lost_centroids = { - frame_i:set() for frame_i in range(self.SizeT) - } + frame_i: set() for frame_i in range(self.SizeT) + } return - + self.loadTrackedLostCentroids() - + def loadWhitelist(self): self.whitelist = whitelist.Whitelist( total_frames=self.SizeT, ) - whitelist_path = self.segm_npz_path.replace('.npz', '_whitelistIDs.json') - new_centroids_path = self.segm_npz_path.replace('.npz', '_new_centroids.json') + whitelist_path = self.segm_npz_path.replace(".npz", "_whitelistIDs.json") + new_centroids_path = self.segm_npz_path.replace(".npz", "_new_centroids.json") success = self.whitelist.load( - whitelist_path, new_centroids_path, self.segm_data, self.allData_li, + whitelist_path, + new_centroids_path, + self.segm_data, + self.allData_li, ) if self.log_func and success: filename = os.path.basename(whitelist_path) - self.log_func(f'Loaded whitelist from file: {filename}') + self.log_func(f"Loaded whitelist from file: {filename}") if not success: self.whitelist = None - + class select_exp_folder: def __init__(self): self.exp_path = None def QtPrompt( - self, parentQWidget, values, - current=0, title='Select Position folder', - CbLabel="Select folder to load:", - showinexplorer_button=False, full_paths=None, - allow_cancel=True, show=False, toggleMulti=False, - allowMultiSelection=True, - informativeText='', - selectedValues=None - ): + self, + parentQWidget, + values, + current=0, + title="Select Position folder", + CbLabel="Select folder to load:", + showinexplorer_button=False, + full_paths=None, + allow_cancel=True, + show=False, + toggleMulti=False, + allowMultiSelection=True, + informativeText="", + selectedValues=None, + ): from . import apps + font = QtGui.QFont() font.setPixelSize(13) win = apps.QtSelectItems( - title, values, informativeText, CbLabel=CbLabel, + title, + values, + informativeText, + CbLabel=CbLabel, parent=parentQWidget, - showInFileManagerPath=self.exp_path + showInFileManagerPath=self.exp_path, ) win.setFont(font) toFront = win.windowState() & ~Qt.WindowMinimized | Qt.WindowActive @@ -3474,42 +3520,42 @@ def QtPrompt( ] def append_last_cca_frame(self, acdc_df, text): - if 'cell_cycle_stage' not in acdc_df.columns: + if "cell_cycle_stage" not in acdc_df.columns: return text - + try: - colnames = ['frame_i', *cca_df_colnames] + colnames = ["frame_i", *cca_df_colnames] cca_df = acdc_df[colnames].dropna() except Exception as e: return text - last_cca_frame_i = max(cca_df['frame_i'], default=None) + last_cca_frame_i = max(cca_df["frame_i"], default=None) if last_cca_frame_i is None: return text - to_append = f', last cc annotated frame: {last_cca_frame_i+1})' - text = text.replace(')', to_append) + to_append = f", last cc annotated frame: {last_cca_frame_i + 1})" + text = text.replace(")", to_append) return text - + def get_values_segmGUI(self, exp_path): self.exp_path = exp_path - pos_foldernames = myutils.get_pos_foldernames(exp_path) + pos_foldernames = utils.get_pos_foldernames(exp_path) self.pos_foldernames = pos_foldernames values = [] for pos in pos_foldernames: last_tracked_i_found = False pos_path = os.path.join(exp_path, pos) - images_path = os.path.join(pos_path, 'Images') - filenames = myutils.listdir(images_path) + images_path = os.path.join(pos_path, "Images") + filenames = utils.listdir(images_path) for filename in filenames: - if filename.find('acdc_output.csv') != -1: + if filename.find("acdc_output.csv") != -1: last_tracked_i_found = True acdc_df_path = os.path.join(images_path, filename) acdc_df = _load_acdc_df_file(acdc_df_path).reset_index() - last_tracked_i = acdc_df['frame_i'].max() + last_tracked_i = acdc_df["frame_i"].max() break - + if last_tracked_i_found: - text = f'{pos} (Last tracked frame: {last_tracked_i+1})' + text = f"{pos} (Last tracked frame: {last_tracked_i + 1})" text = self.append_last_cca_frame(acdc_df, text) values.append(text) else: @@ -3519,7 +3565,7 @@ def get_values_segmGUI(self, exp_path): def get_values_dataprep(self, exp_path): self.exp_path = exp_path - pos_foldernames = myutils.get_pos_foldernames(exp_path) + pos_foldernames = utils.get_pos_foldernames(exp_path) self.pos_foldernames = pos_foldernames values = [] for pos in pos_foldernames: @@ -3529,53 +3575,53 @@ def get_values_dataprep(self, exp_path): is_roi_info_present = False are_zslices_selected = False pos_path = os.path.join(exp_path, pos) - images_path = os.path.join(pos_path, 'Images') - filenames = myutils.listdir(images_path) + images_path = os.path.join(pos_path, "Images") + filenames = utils.listdir(images_path) for filename in filenames: - if filename.endswith('dataPrepROIs_coords.csv'): + if filename.endswith("dataPrepROIs_coords.csv"): is_roi_info_present = True filepath = os.path.join(images_path, filename) - df = pd.read_csv(filepath, index_col='description') - is_cropped = (df.loc[['cropped'], 'value'] > 0).any() - elif filename.endswith('dataPrep_bkgrROIs.json'): + df = pd.read_csv(filepath, index_col="description") + is_cropped = (df.loc[["cropped"], "value"] > 0).any() + elif filename.endswith("dataPrep_bkgrROIs.json"): is_bkgr_roi_info_present = True - elif filename.endswith('aligned.npz'): + elif filename.endswith("aligned.npz"): is_aligned = True - elif filename.endswith('align_shift.npy'): + elif filename.endswith("align_shift.npy"): is_aligned = True - elif filename.endswith('bkgrRoiData.npz'): + elif filename.endswith("bkgrRoiData.npz"): is_cropped = True - elif filename.endswith('segmInfo.csv'): + elif filename.endswith("segmInfo.csv"): are_zslices_selected = True - + is_bkgr_roi_info_present is_cropped is_roi_info_present - - info_txt = f'{pos} (' + + info_txt = f"{pos} (" if are_zslices_selected: - info_txt = f'{info_txt} z-slices selected,' + info_txt = f"{info_txt} z-slices selected," if is_aligned: - info_txt = f'{info_txt} aligned,' + info_txt = f"{info_txt} aligned," if is_roi_info_present: - info_txt = f'{info_txt} ROI info present,' + info_txt = f"{info_txt} ROI info present," if is_bkgr_roi_info_present: - info_txt = f'{info_txt} bkgr ROI info present,' + info_txt = f"{info_txt} bkgr ROI info present," if is_cropped: - info_txt = f'{info_txt} cropped' - - if info_txt.endswith('('): + info_txt = f"{info_txt} cropped" + + if info_txt.endswith("("): values.append(pos) else: - values.append(f'{info_txt})') + values.append(f"{info_txt})") self.values = values return values def get_values_cca(self, exp_path): self.exp_path = exp_path - pos_foldernames = natsorted(myutils.listdir(exp_path)) + pos_foldernames = natsorted(utils.listdir(exp_path)) pos_foldernames = [ - pos for pos in pos_foldernames if re.match(r'^Position_(\d+)', pos) + pos for pos in pos_foldernames if re.match(r"^Position_(\d+)", pos) ] self.pos_foldernames = pos_foldernames values = [] @@ -3583,21 +3629,20 @@ def get_values_cca(self, exp_path): cc_stage_found = False pos_path = os.path.join(exp_path, pos) if os.path.isdir(pos_path): - images_path = f'{exp_path}/{pos}/Images' - filenames = myutils.listdir(images_path) + images_path = f"{exp_path}/{pos}/Images" + filenames = utils.listdir(images_path) for filename in filenames: - if filename.find('cc_stage.csv') != -1: + if filename.find("cc_stage.csv") != -1: cc_stage_found = True - cc_stage_path = f'{images_path}/{filename}' + cc_stage_path = f"{images_path}/{filename}" cca_df = pd.read_csv( - cc_stage_path, index_col=['frame_i', 'Cell_ID'] - ) - last_analyzed_frame_i = ( - cca_df.index.get_level_values(0).max() + cc_stage_path, index_col=["frame_i", "Cell_ID"] ) + last_analyzed_frame_i = cca_df.index.get_level_values(0).max() if cc_stage_found: - values.append(f'{pos} (Last analyzed frame: ' - f'{last_analyzed_frame_i})') + values.append( + f"{pos} (Last analyzed frame: {last_analyzed_frame_i})" + ) else: values.append(pos) self.values = values @@ -3619,21 +3664,21 @@ def on_closing(self): self.root.quit() self.root.destroy() if self.allow_abort: - exit('Execution aborted by the user') + exit("Execution aborted by the user") def load_shifts(parent_path, basename=None): shifts_found = False shifts = None if basename is None: - for filename in myutils.listdir(parent_path): - if filename.find('align_shift.npy')>0: + for filename in utils.listdir(parent_path): + if filename.find("align_shift.npy") > 0: shifts_found = True shifts_path = os.path.join(parent_path, filename) shifts = np.load(shifts_path) else: - align_shift_fn = f'{basename}_align_shift.npy' - if align_shift_fn in myutils.listdir(parent_path): + align_shift_fn = f"{basename}_align_shift.npy" + if align_shift_fn in utils.listdir(parent_path): shifts_found = True shifts_path = os.path.join(parent_path, align_shift_fn) shifts = np.load(shifts_path) @@ -3641,6 +3686,7 @@ def load_shifts(parent_path, basename=None): shifts = None return shifts, shifts_found + class OMEXML_image: def __init__(self, Pixels, ome_schema): if Pixels is None: @@ -3649,19 +3695,23 @@ def __init__(self, Pixels, ome_schema): node = Pixels.attrib self.Pixels = OMEXML_Pixels(Pixels, node, ome_schema) + class OMEXML_objective: def __init__(self) -> None: self.LensNA = 1.4 + class OMEXML_intrument: def __init__(self): self.Objective = OMEXML_objective() + class OMEXML_Channel: def __init__(self, Channel) -> None: - self.Name = Channel.attrib.get('Name', '') + self.Name = Channel.attrib.get("Name", "") self.node = Channel.attrib + class OMEXML_Pixels: def __init__(self, Pixels, node, ome_schema) -> None: self.node = node @@ -3675,70 +3725,72 @@ def __init__(self, Pixels, node, ome_schema) -> None: self.PhysicalSizeY = 1.0 self.PhysicalSizeZ = 1.0 else: - self.SizeZ = node.get('SizeZ', 1) - self.SizeT = node.get('SizeT', 1) - self.SizeC = node.get('SizeC', 1) - self.PhysicalSizeX = node.get('PhysicalSizeX', 1.0) - self.PhysicalSizeY = node.get('PhysicalSizeY', 1.0) - self.PhysicalSizeZ = node.get('PhysicalSizeZ', 1.0) - + self.SizeZ = node.get("SizeZ", 1) + self.SizeT = node.get("SizeT", 1) + self.SizeC = node.get("SizeC", 1) + self.PhysicalSizeX = node.get("PhysicalSizeX", 1.0) + self.PhysicalSizeY = node.get("PhysicalSizeY", 1.0) + self.PhysicalSizeZ = node.get("PhysicalSizeZ", 1.0) + def Channel(self, channel_index=0): - Channel = self.Pixels.findall(f'{self.ome_schema}Channel')[channel_index] + Channel = self.Pixels.findall(f"{self.ome_schema}Channel")[channel_index] return OMEXML_Channel(Channel) + class OMEXML: def __init__(self, ometiff_filepath): self.filepath = ometiff_filepath self.read_omexml_string() self.parse_metadata() - + def read_omexml_string(self): with TiffFile(self.filepath) as tif: return tif.ome_metadata - + def parse_metadata(self): self.omexml_string = self.read_omexml_string() self.root = ET.fromstring(self.omexml_string) - self.ome_schema = re.findall(r'({.+})OME', self.root.tag)[0] - + self.ome_schema = re.findall(r"({.+})OME", self.root.tag)[0] + def instrument(self): instrument = OMEXML_intrument() - instrument_xml = self.root.find(f'{self.ome_schema}Instrument') + instrument_xml = self.root.find(f"{self.ome_schema}Instrument") if instrument_xml is None: return instrument - objective_xml = instrument_xml.find(f'{self.ome_schema}Objective') + objective_xml = instrument_xml.find(f"{self.ome_schema}Objective") if objective_xml is None: return instrument - LensNA = objective_xml.attrib.get('LensNA') + LensNA = objective_xml.attrib.get("LensNA") if LensNA is None: return instrument instrument.Objective.LensNA = LensNA return instrument def get_image_count(self): - return len(self.root.findall(f'{self.ome_schema}Image')) + return len(self.root.findall(f"{self.ome_schema}Image")) def image(self): - Image = self.root.find(f'{self.ome_schema}Image') - Pixels = Image.find(f'{self.ome_schema}Pixels') + Image = self.root.find(f"{self.ome_schema}Image") + Pixels = Image.find(f"{self.ome_schema}Pixels") image = OMEXML_image(Pixels, self.ome_schema) - image.Name = Image.attrib.get('Name', '') + image.Name = Image.attrib.get("Name", "") return image + def _restructure_multi_files_multi_pos( - src_path, dst_path, action='copy', signals=None, logger=print - ): + src_path, dst_path, action="copy", signals=None, logger=print +): if signals is not None: signals.initProgressBar.emit(0) - logger('Scanning files...') + logger("Scanning files...") files = list(os.listdir(src_path)) files = [f for f in files if os.path.isfile(os.path.join(src_path, f))] - + # Group files with same starting string with all possible splits files_scanned = list(files) groups = {} for f, file in enumerate(files): - splits = file.split('_') + splits = file.split("_") current_split = splits[0] for split in splits[1:]: for other_file in files_scanned: @@ -3747,9 +3799,9 @@ def _restructure_multi_files_multi_pos( groups[current_split] = {other_file} else: groups[current_split].add(other_file) - current_split = f'{current_split}_{split}' + current_split = f"{current_split}_{split}" files_scanned.pop(0) - + # Determine the keys of duplicated groups keys_duplicates = {} keys_scanned = list(groups.keys()) @@ -3763,11 +3815,11 @@ def _restructure_multi_files_multi_pos( else: keys_duplicates[key].add(other_key) keys_scanned.pop(0) - + # Get unique splits and sort them by length unique_splits = {max(splits, key=len) for splits in keys_duplicates.values()} unique_splits = sorted(list(unique_splits), key=len) - + # Get groups of files sharing the same starting groups_files = {} for split in unique_splits: @@ -3777,54 +3829,54 @@ def _restructure_multi_files_multi_pos( groups_files[split] = {file} else: groups_files[split].add(file) - + # Sort the files according to exp and pos splits - groups_n_splits = {len(split.split('_')):set() for split in groups_files} + groups_n_splits = {len(split.split("_")): set() for split in groups_files} for split in groups_files: - n_splits = len(split.split('_')) + n_splits = len(split.split("_")) groups_n_splits[n_splits].add(split) - + sorted_n_splits = sorted(groups_n_splits.keys()) n_splits_exp, n_splits_pos = sorted_n_splits[-2:] - final_structure = {} + final_structure = {} for split_exp in groups_n_splits[n_splits_exp]: exp_folder_path = os.path.join(dst_path, split_exp) exp_files = groups_files[split_exp] - pos_splits = groups_n_splits[n_splits_pos] + pos_splits = groups_n_splits[n_splits_pos] for exp_file in exp_files: p = 1 for pos_split in pos_splits: if not pos_split.startswith(split_exp): continue try: - pos_n = pos_split.split('_')[-1] + pos_n = pos_split.split("_")[-1] pos_n = int(pos_n) except Exception as e: pos_n = p - pos_path = os.path.join(exp_folder_path, f'Position_{pos_n}') - images_path = os.path.join(pos_path, 'Images') + pos_path = os.path.join(exp_folder_path, f"Position_{pos_n}") + images_path = os.path.join(pos_path, "Images") final_structure[images_path] = [] if not os.path.exists(images_path): os.makedirs(images_path, exist_ok=True) for file in files: if not file.startswith(pos_split): - continue + continue final_structure[images_path].append(file) - + p += 1 - + # Move or copy the files if signals is not None: signals.initProgressBar.emit(len(files)) - action_str = 'Copying' if action=='copy' else 'Moving' - logger(f'{action_str} files...') - pbar = tqdm(total=len(files), ncols=100, unit='file') + action_str = "Copying" if action == "copy" else "Moving" + logger(f"{action_str} files...") + pbar = tqdm(total=len(files), ncols=100, unit="file") for images_path, files in final_structure.items(): for file in files: dst_file = os.path.join(images_path, file) src_file = os.path.join(src_path, file) try: - if action == 'copy': + if action == "copy": shutil.copy2(src_file, dst_file) else: shutil.move(src_file, dst_file) @@ -3834,44 +3886,51 @@ def _restructure_multi_files_multi_pos( if signals is not None: signals.progressBar.emit(1) pbar.close() - - action_str = 'copied' if action=='copy' else 'moved' + + action_str = "copied" if action == "copy" else "moved" logger(f'Done! Files {action_str} and restructured into "{src_path}"') + def get_all_svg_icons_aliases(sort=True): from . import resources_filepath - with open(resources_filepath, 'r') as resources_file: + + with open(resources_filepath, "r") as resources_file: resources_txt = resources_file.read() - + aliases = re.findall(r'', resources_txt) if sort: aliases = natsorted(aliases) return aliases + def get_all_buttons_names(sort=True): - widgets_filepath = os.path.join(cellacdc_path, 'widgets.py') - with open(widgets_filepath, 'r') as py_file: + widgets_filepath = os.path.join(cellacdc_path, "widgets.py") + with open(widgets_filepath, "r") as py_file: txt = py_file.read() - - all_buttons_names = re.findall(r'class (\w+)\(Q?PushButton\):', txt) + + all_buttons_names = re.findall(r"class (\w+)\(Q?PushButton\):", txt) if sort: all_buttons_names = natsorted(all_buttons_names) return all_buttons_names -def rename_qrc_resources_file(scheme='light'): + +def rename_qrc_resources_file(scheme="light"): os.remove(qrc_resources_path) - - if scheme == 'dark' and os.path.exists(qrc_resources_dark_path): + + if scheme == "dark" and os.path.exists(qrc_resources_dark_path): shutil.copyfile(qrc_resources_dark_path, qrc_resources_path) - elif scheme == 'light' and os.path.exists(qrc_resources_light_path): + elif scheme == "light" and os.path.exists(qrc_resources_light_path): shutil.copyfile(qrc_resources_light_path, qrc_resources_path) -def autoLineBreak(text, length): #automatic line breaking for tooltips. Keeps indentation with spaces and preexisting line breaks + +def autoLineBreak( + text, length +): # automatic line breaking for tooltips. Keeps indentation with spaces and preexisting line breaks lines = [] current_line = [] # Split the text into lines while preserving existing newline characters - existing_lines = text.split('\n') + existing_lines = text.split("\n") for existing_line in existing_lines: # Calculate the indentation for the current line @@ -3879,22 +3938,25 @@ def autoLineBreak(text, length): #automatic line breaking for tooltips. Keeps in words = existing_line.lstrip().split() # Split each line into words for word in words: - if len(' '.join(current_line + [word])) + indent <= length: + if len(" ".join(current_line + [word])) + indent <= length: current_line.append(word) else: - lines.append(' ' * indent + ' '.join(current_line)) + lines.append(" " * indent + " ".join(current_line)) current_line = [word] if current_line: # Add any remaining words as the last line - lines.append(' ' * indent + ' '.join(current_line)) + lines.append(" " * indent + " ".join(current_line)) # Reset the current line for the next existing line current_line = [] - return '\n'.join(lines) + return "\n".join(lines) -def format_bullet_points(text): #indentation for bullet points in tooltips. Implementation not robust - lines = text.split('\n') + +def format_bullet_points( + text, +): # indentation for bullet points in tooltips. Implementation not robust + lines = text.split("\n") formatted_lines = [] indent = False indentNo = 0 @@ -3914,19 +3976,21 @@ def format_bullet_points(text): #indentation for bullet points in tooltips. Impl formatted_lines.append(formatted_line) - return '\n'.join(formatted_lines) + return "\n".join(formatted_lines) + -def format_number_list(text): #indentation for number points in tooltips. Implementation not robust - lines = text.split('\n') +def format_number_list( + text, +): # indentation for number points in tooltips. Implementation not robust + lines = text.split("\n") formatted_lines = [] indent = False indentNo = 0 for line in lines: - if line.strip().startswith(( - "0. ", "1. ", "2. ", "3. ", "4. ", - "5. ", "6. ", "7. ", "8. ", "9. " - )): + if line.strip().startswith( + ("0. ", "1. ", "2. ", "3. ", "4. ", "5. ", "6. ", "7. ", "8. ", "9. ") + ): indent = True formatted_line = line indentNo = len(line) - len(line.lstrip()) @@ -3940,10 +4004,10 @@ def format_number_list(text): #indentation for number points in tooltips. Implem formatted_lines.append(formatted_line) - return '\n'.join(formatted_lines) + return "\n".join(formatted_lines) -def get_tooltips_from_docs(): +def get_tooltips_from_docs(): # gets tooltips for GUI from .\Cell_ACDC\docs\source\tooltips.rst var_pattern = r"\|(\S*)\|" shortcut_pattern = r"\*\*(\".*\")\):\*\*" @@ -3951,17 +4015,26 @@ def get_tooltips_from_docs(): if not os.path.exists(tooltips_rst_filepath): return {} - + with open(tooltips_rst_filepath, "r") as file: lines = file.readlines() new_lines = [] for line in lines: - if not (line.startswith("..") or line.startswith(" :target:") or line.startswith(" :alt:") or line.startswith(" :width:") or line.startswith(" :height:") or line==""): + if not ( + line.startswith("..") + or line.startswith(" :target:") + or line.startswith(" :alt:") + or line.startswith(" :width:") + or line.startswith(" :height:") + or line == "" + ): new_lines.append(line) lines = new_lines - non_empty_lines = [line.replace("\n", "") for line in lines if line.strip()] #also removes \n from lines + non_empty_lines = [ + line.replace("\n", "") for line in lines if line.strip() + ] # also removes \n from lines lines = non_empty_lines tipdict = {} @@ -3977,7 +4050,7 @@ def get_tooltips_from_docs(): if shortcut: shortcut = shortcut.group(1) else: - shortcut = "\"No shortcut\"" + shortcut = '"No shortcut"' desc = line.split("):**")[1].lstrip(" ") @@ -3992,7 +4065,7 @@ def get_tooltips_from_docs(): followMatch = re.search(var_pattern, followLine) if followMatch or followLine.startswith("* **"): break - else: + else: descList.append(followLine) if descList != []: @@ -4000,9 +4073,7 @@ def get_tooltips_from_docs(): descList.pop(-1) descList.pop(-1) - for entry in descList: - entry = entry.replace("| ", "") if entry.startswith(" " * 4): @@ -4024,87 +4095,93 @@ def get_tooltips_from_docs(): tipdict[name] = f"Name: {title}\nShortcut: {shortcut}\n\n{desc}" return tipdict + def save_df_to_csv_temp_path(df, csv_filename, **to_csv_kwargs): tempDir = tempfile.mkdtemp() tempFilepath = os.path.join(tempDir, csv_filename) df.to_csv(tempFilepath, **to_csv_kwargs) return tempFilepath + def loaded_df_to_points_data(df, t_col, z_col, y_col, x_col): points_data = {} - if 'id' not in df.columns: - df['id'] = '' - - if t_col != 'None': + if "id" not in df.columns: + df["id"] = "" + + if t_col != "None": grouped = df.groupby(t_col) else: grouped = [(0, df)] - + for frame_i, df_frame in grouped: - if z_col != 'None': + if z_col != "None": df_frame[z_col] = df_frame[z_col].round().astype(int) # Use integer z zz = df_frame[z_col] - points_data[frame_i] = {} + points_data[frame_i] = {} for z in zz.values: df_z = df_frame[df_frame[z_col] == z] z_int = round(z) if z_int in points_data[frame_i]: continue points_data[frame_i][z_int] = { - 'x': df_z[x_col].to_list(), - 'y': df_z[y_col].to_list(), - 'id': df_z['id'].to_list(), - 'data': [row.to_string() for _, row in df_z.iterrows()] + "x": df_z[x_col].to_list(), + "y": df_z[y_col].to_list(), + "id": df_z["id"].to_list(), + "data": [row.to_string() for _, row in df_z.iterrows()], } else: points_data[frame_i] = { - 'x': df[x_col].to_list(), - 'y': df[y_col].to_list(), - 'id': df['id'].to_list(), - 'data': [row.to_string() for _, row in df.iterrows()] + "x": df[x_col].to_list(), + "y": df[y_col].to_list(), + "id": df["id"].to_list(), + "data": [row.to_string() for _, row in df.iterrows()], } return points_data + def load_df_points_layer(filepath): df = None - if filepath.endswith('.csv'): + if filepath.endswith(".csv"): df = pd.read_csv(filepath) - elif filepath.endswith('.h5'): + elif filepath.endswith(".h5"): with pd.HDFStore(filepath) as h5: keys = h5.keys() dfs = [h5.get(key) for key in keys] - df = pd.concat(dfs, keys=keys, names=['h5_key']) + df = pd.concat(dfs, keys=keys, names=["h5_key"]) return df + def get_unique_exp_paths(paths: List): unique_exp_paths = set() for path in paths: exp_path = get_exp_path(path) - unique_exp_paths.add(exp_path.replace('\\', '/')) + unique_exp_paths.add(exp_path.replace("\\", "/")) return unique_exp_paths + def search_filepath_in_pos_path_from_endname( - pos_path, endname, include_spotmax_out=False - ): - images_path = os.path.join(pos_path, 'Images') - spotmax_out_path = os.path.join(pos_path, 'spotMAX_output') + pos_path, endname, include_spotmax_out=False +): + images_path = os.path.join(pos_path, "Images") + spotmax_out_path = os.path.join(pos_path, "spotMAX_output") if include_spotmax_out and os.path.exists(spotmax_out_path): for sm_file in os.listdir(spotmax_out_path): if endname == sm_file: return os.path.join(spotmax_out_path, sm_file) - - images_files = myutils.listdir(images_path) + + images_files = utils.listdir(images_path) sample_filepath = os.path.join(images_path, images_files[0]) - posData = loadData(sample_filepath, '') + posData = loadData(sample_filepath, "") posData.getBasenameAndChNames() - to_match = f'{posData.basename}{endname}' + to_match = f"{posData.basename}{endname}" for file in images_files: if file == to_match: return os.path.join(images_path, file) + def search_filepath_from_endname(exp_path, endname, include_spotmax_out=False): - pos_foldernames = myutils.get_pos_foldernames(exp_path) + pos_foldernames = utils.get_pos_foldernames(exp_path) for pos in pos_foldernames: pos_path = os.path.join(exp_path, pos) filepath = search_filepath_in_pos_path_from_endname( @@ -4112,40 +4189,33 @@ def search_filepath_from_endname(exp_path, endname, include_spotmax_out=False): ) return filepath -def askOpenCsvFile( - title='Open CSV file', - start_dir=None, - qparent=None - ): + +def askOpenCsvFile(title="Open CSV file", start_dir=None, qparent=None): if start_dir is None: - start_dir = myutils.getMostRecentPath() - - file_types = f'CSV files (*.csv);;All Files (*)' - + start_dir = utils.getMostRecentPath() + + file_types = f"CSV files (*.csv);;All Files (*)" + fileDialog = QFileDialog.getOpenFileName - args = ( - qparent, - title, - start_dir, - file_types - ) + args = (qparent, title, start_dir, file_types) file_path = fileDialog(*args) if not isinstance(file_path, str): file_path = file_path[0] return file_path + def read_measurements_workflow_from_config(filepath): configPars = config.ConfigParser() configPars.read(filepath) options_that_are_lists = { - 'channels', - 'calc_for_each_zslice_channels', - 'size_metrics_to_save', - 'regionprops_to_save', - 'channel_indipendent_custom_metrics_to_save', - 'mixed_combine_metrics_to_skip', - 'channel_names_to_skip', - 'channel_names_to_process' + "channels", + "calc_for_each_zslice_channels", + "size_metrics_to_save", + "regionprops_to_save", + "channel_indipendent_custom_metrics_to_save", + "mixed_combine_metrics_to_skip", + "channel_names_to_skip", + "channel_names_to_process", } ini_items = {} for section in configPars.sections(): @@ -4153,23 +4223,23 @@ def read_measurements_workflow_from_config(filepath): ini_items[section] = {} for option, value in options.items(): is_list = ( - section == 'paths_info' + section == "paths_info" or option in options_that_are_lists - or option.startswith('metrics_to_skip_') - or option.startswith('metrics_to_save_') + or option.startswith("metrics_to_skip_") + or option.startswith("metrics_to_save_") ) if is_list: if value: - value = value.strip('\n').strip().split('\n') + value = value.strip("\n").strip().split("\n") else: value = [] ini_items[section][option] = value continue - - if value.lower() == 'false': + + if value.lower() == "false": value = False - elif value.lower() == 'true': + elif value.lower() == "true": value = True - + ini_items[section][option] = value - return ini_items \ No newline at end of file + return ini_items diff --git a/cellacdc/measure.py b/cellacdc/measure.py index b18b7c646..eecc87f12 100644 --- a/cellacdc/measure.py +++ b/cellacdc/measure.py @@ -3,10 +3,13 @@ import skimage.transform import skimage.measure + def rotational_volume( - obj: skimage.measure._regionprops.RegionProperties, - PhysicalSizeY=1.0, PhysicalSizeX=1.0, vox_to_fl=None - ): + obj: skimage.measure._regionprops.RegionProperties, + PhysicalSizeY=1.0, + PhysicalSizeX=1.0, + vox_to_fl=None, +): """Given the region properties of a 2D or 3D object (from skimage.measure.regionprops). calculate the rotation volume as described in the Supplementary information of https://www.nature.com/articles/s41467-020-16764-x @@ -21,7 +24,7 @@ def rotational_volume( PhysicalSizeX : float, optional Physical size of the pixel in the X-diretion in micrometer/pixel. By default 1.0 - + Returns ------- tuple @@ -29,11 +32,11 @@ def rotational_volume( Notes ------- - For 3D objects we take the max projection. + For 3D objects we take the max projection. We convert PhysicalSizeY and PhysicalSizeX to float because when they are read from csv they might be a string value. - """ + """ if obj.image.ndim == 3: obj_image = obj.image.max(axis=0) obj_rp = skimage.measure.regionprops(obj_image.astype(np.uint8))[0] @@ -41,17 +44,21 @@ def rotational_volume( else: obj_image = obj.image obj_orientation = obj.orientation - + if vox_to_fl is None: - vox_to_fl = float(PhysicalSizeY)*(float(PhysicalSizeX)**2) - + vox_to_fl = float(PhysicalSizeY) * (float(PhysicalSizeX) ** 2) + rotate_ID_img = skimage.transform.rotate( - obj_image.astype(np.uint8), -(obj_orientation*180/np.pi), - resize=True, order=3, preserve_range=True + obj_image.astype(np.uint8), + -(obj_orientation * 180 / np.pi), + resize=True, + order=3, + preserve_range=True, ) - radii = np.sum(rotate_ID_img, axis=1)/2 - vol_vox = np.sum(np.pi*(radii**2)) - return vol_vox, float(vol_vox*vox_to_fl) + radii = np.sum(rotate_ID_img, axis=1) / 2 + vol_vox = np.sum(np.pi * (radii**2)) + return vol_vox, float(vol_vox * vox_to_fl) + def separate_with_label(lab, rp, IDs_to_separate, maxID, click_coords_list=None): separate_lab = lab.copy() @@ -80,7 +87,7 @@ def separate_with_label(lab, rp, IDs_to_separate, maxID, click_coords_list=None) click_y_local = yclick - ymin click_x_local = xclick - xmin id_to_keep = label_obj[click_y_local, click_x_local] - + separate_lab[obj.slice][obj.image] = 0 separateIDs = [] for sub_obj_idx, sub_obj in enumerate(label_obj_rp): @@ -91,4 +98,3 @@ def separate_with_label(lab, rp, IDs_to_separate, maxID, click_coords_list=None) separate_lab[obj.slice][sub_obj.slice][sub_obj.image] = new_ID separateIDs.append(new_ID) return separate_lab, separateIDs - \ No newline at end of file diff --git a/cellacdc/measurements.py b/cellacdc/measurements.py index 51404e67c..f2ff15930 100755 --- a/cellacdc/measurements.py +++ b/cellacdc/measurements.py @@ -15,60 +15,62 @@ from . import core, base_cca_dict, cca_df_colnames, html_utils, config, printl from . import user_profile_path, cca_functions -skimage_rp_url = 'https://scikit-image.org/docs/0.18.x/api/skimage.measure.html#skimage.measure.regionprops' +skimage_rp_url = "https://scikit-image.org/docs/0.18.x/api/skimage.measure.html#skimage.measure.regionprops" import warnings + warnings.filterwarnings("ignore", message="Failed to get convex hull image.") warnings.filterwarnings("ignore", message="divide by zero encountered in long_scalars") warnings.filterwarnings("ignore", message="Mean of empty slice.") warnings.filterwarnings("ignore", message="invalid value encountered in double_scalars") -acdc_metrics_path = os.path.join(user_profile_path, 'acdc-metrics') +acdc_metrics_path = os.path.join(user_profile_path, "acdc-metrics") if not os.path.exists(acdc_metrics_path): os.makedirs(acdc_metrics_path, exist_ok=True) sys.path.append(acdc_metrics_path) -combine_metrics_ini_path = os.path.join(acdc_metrics_path, 'combine_metrics.ini') +combine_metrics_ini_path = os.path.join(acdc_metrics_path, "combine_metrics.ini") cellacdc_path = os.path.dirname(os.path.abspath(__file__)) -metrics_path = os.path.join(cellacdc_path, 'metrics') +metrics_path = os.path.join(cellacdc_path, "metrics") -how_3D_to_2D_pattern = r'zSlice|3D|maxProj|meanProj|(?=\s*$)' +how_3D_to_2D_pattern = r"zSlice|3D|maxProj|meanProj|(?=\s*$)" # Copy metrics to acdc-metrics user path for file in os.listdir(metrics_path): - if not file.endswith('.py'): + if not file.endswith(".py"): continue src = os.path.join(metrics_path, file) dst = os.path.join(acdc_metrics_path, file) shutil.copy(src, dst) PROPS_DTYPES = { - 'label': int, - 'major_axis_length': float, - 'minor_axis_length': float, - 'eccentricity': float, - 'circularity': float, - 'roundness': float, - 'aspect_ratio': float, - 'inertia_tensor_eigvals': tuple, - 'equivalent_diameter': float, - 'moments': np.ndarray, - 'area': int, - 'solidity': float, - 'extent': float, - 'inertia_tensor': np.ndarray, - 'filled_area': int, - 'centroid': tuple, - 'bbox_area': int, - 'local_centroid': tuple, - 'convex_area': int, - 'euler_number': int, - 'moments_normalized': np.ndarray, - 'moments_central': np.ndarray, - 'bbox': tuple + "label": int, + "major_axis_length": float, + "minor_axis_length": float, + "eccentricity": float, + "circularity": float, + "roundness": float, + "aspect_ratio": float, + "inertia_tensor_eigvals": tuple, + "equivalent_diameter": float, + "moments": np.ndarray, + "area": int, + "solidity": float, + "extent": float, + "inertia_tensor": np.ndarray, + "filled_area": int, + "centroid": tuple, + "bbox_area": int, + "local_centroid": tuple, + "convex_area": int, + "euler_number": int, + "moments_normalized": np.ndarray, + "moments_central": np.ndarray, + "bbox": tuple, } + def getMetricsFunc(posData): metrics_func, all_metrics_names = standard_metrics_func() total_metrics = len(metrics_func) @@ -81,35 +83,40 @@ def getMetricsFunc(posData): # defined in loadData.setCombineMetricsConfig method for key, section in posData.combineMetricsConfig.items(): total_metrics += len(section) - + out = ( - metrics_func, all_metrics_names, custom_func_dict, total_metrics, - ch_indipend_custom_func_dict + metrics_func, + all_metrics_names, + custom_func_dict, + total_metrics, + ch_indipend_custom_func_dict, ) return out + def get_metric_group_name(col_name: str): size_metrics_names = set(get_size_metrics_desc(True, True).keys()) if col_name in size_metrics_names: - return 'size' - + return "size" + props_names = get_props_names() if col_name in props_names: - return 'regionprop' + return "regionprop" ch_indip_custom_metrics_names = _get_ch_indipendent_custom_metrics_names() if col_name in ch_indip_custom_metrics_names: - return 'ch_indipend_custom_metric' - + return "ch_indipend_custom_metric" + ch_indip_custom_metrics_names = _get_ch_indipendent_custom_metrics_names() if col_name in ch_indip_custom_metrics_names: - return 'mixed_channels' + return "mixed_channels" standard_metrics_names = set(_get_metrics_names().keys()) for col in standard_metrics_names: - if f'_{col_name}' in col: - channel_name = col_name.split(f'_{col_name}')[0] - return {'standard': channel_name} + if f"_{col_name}" in col: + channel_name = col_name.split(f"_{col_name}")[0] + return {"standard": channel_name} + def get_all_metrics_names(include_custom=True): all_metrics_names = [] @@ -126,93 +133,98 @@ def get_all_metrics_names(include_custom=True): all_metrics_names.extend(props_names) return all_metrics_names + def get_all_acdc_df_colnames(include_custom=True): all_acdc_df_colnames = get_all_metrics_names(include_custom=include_custom) - all_acdc_df_colnames.append('frame_i') - all_acdc_df_colnames.append('time_seconds') - all_acdc_df_colnames.append('Cell_ID') + all_acdc_df_colnames.append("frame_i") + all_acdc_df_colnames.append("time_seconds") + all_acdc_df_colnames.append("Cell_ID") all_acdc_df_colnames.extend(cca_df_colnames) additional_colnames = [ - 'is_cell_dead', - 'is_cell_excluded', - 'x_centroid', - 'y_centroid', - 'was_manually_edited' + "is_cell_dead", + "is_cell_excluded", + "x_centroid", + "y_centroid", + "was_manually_edited", ] all_acdc_df_colnames.extend(additional_colnames) return all_acdc_df_colnames + def get_user_combine_metrics_equations(chName, isSegm3D=False): _, equations = channel_combine_metrics_desc(chName, isSegm3D=isSegm3D) return equations + def get_custom_metrics_func(): scripts = os.listdir(acdc_metrics_path) custom_func_dict = {} for file in scripts: - if file == '__init__.py': + if file == "__init__.py": continue module_name, ext = os.path.splitext(file) - if ext != '.py': + if ext != ".py": # print(f'The file {file} is not a python file. Ignoring it.') continue - if module_name == 'combine_metrics_example': + if module_name == "combine_metrics_example": # Ignore the example continue - if module_name == 'channel_indipendent_metric_example': + if module_name == "channel_indipendent_metric_example": # Ignore the example continue try: module = import_module(module_name) - if not getattr(module, 'CALCULATE_FOR_EACH_CHANNEL', True): + if not getattr(module, "CALCULATE_FOR_EACH_CHANNEL", True): continue - + func = getattr(module, module_name) custom_func_dict[module_name] = func except Exception: traceback.print_exc() return custom_func_dict + def get_channel_indipendent_custom_metrics_func(): scripts = os.listdir(acdc_metrics_path) custom_func_dict = {} for file in scripts: - if file == '__init__.py': + if file == "__init__.py": continue module_name, ext = os.path.splitext(file) - if ext != '.py': + if ext != ".py": # print(f'The file {file} is not a python file. Ignoring it.') continue - if module_name == 'combine_metrics_example': + if module_name == "combine_metrics_example": # Ignore the example continue - if module_name == 'channel_indipendent_metric_example': + if module_name == "channel_indipendent_metric_example": # Ignore the example continue try: module = import_module(module_name) - if getattr(module, 'CALCULATE_FOR_EACH_CHANNEL', True): + if getattr(module, "CALCULATE_FOR_EACH_CHANNEL", True): continue - + func = getattr(module, module_name) custom_func_dict[module_name] = func except Exception: traceback.print_exc() return custom_func_dict + def read_saved_user_combine_config(): configPars = _get_saved_user_combine_config() if configPars is None: configPars = config.ConfigParser() - if 'equations' not in configPars: - configPars['equations'] = {} + if "equations" not in configPars: + configPars["equations"] = {} - if 'mixed_channels_equations' not in configPars: - configPars['mixed_channels_equations'] = {} + if "mixed_channels_equations" not in configPars: + configPars["mixed_channels_equations"] = {} - if 'channelLess_equations' not in configPars: - configPars['channelLess_equations'] = {} + if "channelLess_equations" not in configPars: + configPars["channelLess_equations"] = {} return configPars @@ -222,7 +234,7 @@ def _get_saved_user_combine_config(): configPars = None for file in files: module_name, ext = os.path.splitext(file) - if ext != '.ini': + if ext != ".ini": continue filePath = os.path.join(acdc_metrics_path, file) @@ -230,44 +242,50 @@ def _get_saved_user_combine_config(): configPars.read(filePath) return configPars + def add_user_combine_metrics(configPars, equation, colName, isMixedChannels): - section = 'mixed_channels_equations' if isMixedChannels else 'equations' + section = "mixed_channels_equations" if isMixedChannels else "equations" if section not in configPars: configPars[section] = {} configPars[section][colName] = equation return configPars + def add_channelLess_combine_metrics(configPars, equation, equation_name, terms): - if 'channelLess_equations' not in configPars: - configPars['channelLess_equations'] = {} - terms = ','.join(terms) - equation_terms = f'{equation};{terms}' - configPars['channelLess_equations'][equation_name] = equation_terms + if "channelLess_equations" not in configPars: + configPars["channelLess_equations"] = {} + terms = ",".join(terms) + equation_terms = f"{equation};{terms}" + configPars["channelLess_equations"][equation_name] = equation_terms return configPars + def save_common_combine_metrics(configPars): - with open(combine_metrics_ini_path, 'w') as configfile: + with open(combine_metrics_ini_path, "w") as configfile: configPars.write(configfile) + def _get_custom_metrics_names(): custom_func_dict = get_custom_metrics_func() keys = custom_func_dict.keys() - custom_metrics_names = {func_name:func_name for func_name in keys} + custom_metrics_names = {func_name: func_name for func_name in keys} return custom_metrics_names + def _get_ch_indipendent_custom_metrics_names(): custom_func_dict = get_channel_indipendent_custom_metrics_func() keys = custom_func_dict.keys() - custom_metrics_names = {func_name:func_name for func_name in keys} + custom_metrics_names = {func_name: func_name for func_name in keys} return custom_metrics_names + def ch_indipend_custom_metrics_desc(isZstack, isSegm3D=False): how_3Dto2D, how_3Dto2D_desc = get_how_3Dto2D(isZstack, isSegm3D) custom_metrics_names = _get_ch_indipendent_custom_metrics_names() custom_metrics_desc = {} for how, how_desc in zip(how_3Dto2D, how_3Dto2D_desc): for func_name, func_desc in custom_metrics_names.items(): - metric_name = f'{func_name}{how}' + metric_name = f"{func_name}{how}" if isZstack: note_txt = html_utils.paragraph(f""" {_get_zStack_note(how_desc)} @@ -276,7 +294,7 @@ def ch_indipend_custom_metrics_desc(isZstack, isSegm3D=False): converting 3D to 2D {how_desc} """) else: - note_txt = '' + note_txt = "" desc = html_utils.paragraph(f""" {func_desc} is a custom defined measurement.

    @@ -285,19 +303,19 @@ def ch_indipend_custom_metrics_desc(isZstack, isSegm3D=False): {note_txt} """) custom_metrics_desc[metric_name] = desc - + return custom_metrics_desc + def custom_metrics_desc( - isZstack, chName, posData=None, isSegm3D=False, - return_combine=False - ): + isZstack, chName, posData=None, isSegm3D=False, return_combine=False +): how_3Dto2D, how_3Dto2D_desc = get_how_3Dto2D(isZstack, isSegm3D) custom_metrics_names = _get_custom_metrics_names() custom_metrics_desc = {} for how, how_desc in zip(how_3Dto2D, how_3Dto2D_desc): for func_name, func_desc in custom_metrics_names.items(): - metric_name = f'{chName}_{func_name}{how}' + metric_name = f"{chName}_{func_name}{how}" if isZstack: note_txt = html_utils.paragraph(f""" {_get_zStack_note(how_desc)} @@ -306,7 +324,7 @@ def custom_metrics_desc( converting 3D to 2D {how_desc} """) else: - note_txt = '' + note_txt = "" desc = html_utils.paragraph(f""" {func_desc} is a custom defined measurement.

    @@ -326,13 +344,14 @@ def custom_metrics_desc( else: return custom_metrics_desc + def channel_combine_metrics_desc(chName, posData=None, isSegm3D=False): combine_metrics_configPars = read_saved_user_combine_config() how_3Dto2D, how_3Dto2D_desc = get_how_3Dto2D(True, isSegm3D) - combine_metrics = combine_metrics_configPars['equations'] + combine_metrics = combine_metrics_configPars["equations"] if posData is not None: - posDataEquations = posData.combineMetricsConfig['equations'] + posDataEquations = posData.combineMetricsConfig["equations"] combine_metrics = {**combine_metrics, **posDataEquations} combine_metrics_desc = {} all_metrics_names = get_all_metrics_names() @@ -354,7 +373,7 @@ def channel_combine_metrics_desc(chName, posData=None, isSegm3D=False): how_desc = how_3Dto2D_present[0] note_txt = html_utils.paragraph(f"""{_get_zStack_note(how_desc)}""") else: - note_txt = '' + note_txt = "" desc = html_utils.paragraph(f""" {metric_name} is a custom combined measurement that is the @@ -365,14 +384,14 @@ def channel_combine_metrics_desc(chName, posData=None, isSegm3D=False): combine_metrics_desc[metric_name] = desc equations[metric_name] = equation - channelLess_combine_metrics = combine_metrics_configPars['channelLess_equations'] + channelLess_combine_metrics = combine_metrics_configPars["channelLess_equations"] for name, equation_terms in channelLess_combine_metrics.items(): - channelLess_equation, terms = equation_terms.split(';') - _colNames = terms.split(',') - metric_name = f'{chName}_{name}' + channelLess_equation, terms = equation_terms.split(";") + _colNames = terms.split(",") + metric_name = f"{chName}_{name}" equation = channelLess_equation for _col in _colNames: - equation = equation.replace(_col, f'{chName}{_col}') + equation = equation.replace(_col, f"{chName}{_col}") if not any([metric in equation for metric in all_metrics_names]): # Equation does not contain any of the available metrics --> Skip it @@ -385,7 +404,7 @@ def channel_combine_metrics_desc(chName, posData=None, isSegm3D=False): how_desc = how_3Dto2D_present[0] note_txt = html_utils.paragraph(f"""{_get_zStack_note(how_desc)}""") else: - note_txt = '' + note_txt = "" desc = html_utils.paragraph(f""" {metric_name} is a custom combined measurement that is the @@ -398,14 +417,17 @@ def channel_combine_metrics_desc(chName, posData=None, isSegm3D=False): return combine_metrics_desc, equations + def get_user_combine_mixed_channels_equations(isSegm3D=False): _, equations = _combine_mixed_channels_desc(isSegm3D=isSegm3D) return equations + def get_combine_mixed_channels_desc(isSegm3D=False): desc, _ = _combine_mixed_channels_desc(isSegm3D=isSegm3D) return desc + def _combine_mixed_channels_desc(isSegm3D=False, configPars=None): if configPars is None: configPars = _get_saved_user_combine_config() @@ -415,7 +437,7 @@ def _combine_mixed_channels_desc(isSegm3D=False, configPars=None): equations = {} mixed_channels_desc = {} how_3Dto2D, how_3Dto2D_desc = get_how_3Dto2D(True, isSegm3D) - mixed_channels_combine_metrics = configPars['mixed_channels_equations'] + mixed_channels_combine_metrics = configPars["mixed_channels_equations"] all_metrics_names = get_all_metrics_names() equations = {} for name, equation in mixed_channels_combine_metrics.items(): @@ -431,7 +453,7 @@ def _combine_mixed_channels_desc(isSegm3D=False, configPars=None): how_desc = how_3Dto2D_present[0] note_txt = html_utils.paragraph(f"""{_get_zStack_note(how_desc)}""") else: - note_txt = '' + note_txt = "" desc = html_utils.paragraph(f""" {metric_name} is a custom combined measurement that is the @@ -443,6 +465,7 @@ def _combine_mixed_channels_desc(isSegm3D=False, configPars=None): equations[metric_name] = equation return mixed_channels_desc, equations + def combine_mixed_channels_desc(posData=None, isSegm3D=False, available_cols=None): desc, equations = _combine_mixed_channels_desc(isSegm3D=isSegm3D) if posData is None: @@ -454,13 +477,13 @@ def combine_mixed_channels_desc(posData=None, isSegm3D=False, available_cols=Non ) all_desc = {**desc, **pos_desc} all_equations = {**equations, **pos_equations} - + if available_cols is not None: # Check that user folder combine metrics have the right columns available_desc = {} available_equations = {} for name, equation in all_equations.items(): - cols = re.findall(r'[A-Za-z0-9]+_[A-Za-z0-9_]+', equation) + cols = re.findall(r"[A-Za-z0-9]+_[A-Za-z0-9_]+", equation) if all([col in available_cols for col in cols]): available_desc[name] = all_desc[name] available_equations[name] = equation @@ -468,34 +491,40 @@ def combine_mixed_channels_desc(posData=None, isSegm3D=False, available_cols=Non else: return all_desc, all_equations + def _um3(): - return 'µm3' + return "µm3" + def _um2(): - return 'µm2' + return "µm2" + def _um(): - return 'µ' + return "µ" + def _fl(): - return 'fl' + return "fl" + def _get_zStack_note(how_desc): - s = (f""" + s = f""" NOTE: since you loaded 3D z-stacks, Cell-ACDC needs to convert the z-stacks to 2D images {how_desc} for this metric.
    This is specified in the name of the column.

    - """) + """ return s + def get_size_metrics_desc(isSegm3D, is_timelapse): - url = 'https://www.nature.com/articles/s41467-020-16764-x#Sec16' + url = "https://www.nature.com/articles/s41467-020-16764-x#Sec16" size_metrics = { - 'cell_area_pxl': html_utils.paragraph(""" + "cell_area_pxl": html_utils.paragraph(""" Area of the segmented object in pixels, i.e., total number of pixels in the object. """), - 'cell_vol_vox': html_utils.paragraph(f""" + "cell_vol_vox": html_utils.paragraph(f""" Estimated volume of the segmented object in voxels.


    To calculate object volume based on 2D masks, the object is first aligned along its major axis.

    @@ -521,13 +550,13 @@ def get_size_metrics_desc(isSegm3D, is_timelapse): (see in the Standard measurements group) and it cannot be unchecked.

    """), - 'cell_area_um2': html_utils.paragraph(f""" + "cell_area_um2": html_utils.paragraph(f""" Area of the segmented object in {_um2()}, i.e., total number of pixels in the object.

    Conversion from pixels to {_um2()} is perfomed using the provided pixel size. """), - 'cell_vol_fl': html_utils.paragraph(f""" + "cell_vol_fl": html_utils.paragraph(f""" Estimated volume of the segmented object in {_um3()}.


    To calculate object volume based on 2D masks, the object is first @@ -556,11 +585,11 @@ def get_size_metrics_desc(isSegm3D, is_timelapse): by the concentration metric that you requested to save (see in the Standard measurements group) and it cannot be unchecked.

    - """) + """), } if isSegm3D: size_metrics_3D = { - 'cell_vol_vox_3D': html_utils.paragraph(f""" + "cell_vol_vox_3D": html_utils.paragraph(f""" Volume of the segmented object in voxels.

    This is given by the total number of voxels inside the object.

    @@ -569,7 +598,7 @@ def get_size_metrics_desc(isSegm3D, is_timelapse): (see in the Standard measurements group) and it cannot be unchecked.

    """), - 'cell_vol_fl_3D': html_utils.paragraph(f""" + "cell_vol_fl_3D": html_utils.paragraph(f""" Volume of the segmented object in {_fl()}.

    This is given by the total number of voxels inside the object multiplied by the voxel volume.

    @@ -589,61 +618,59 @@ def get_size_metrics_desc(isSegm3D, is_timelapse): size_metrics = {**size_metrics, **size_metrics_3D} if is_timelapse: velocity_metrics = { - 'velocity_pixel': html_utils.paragraph(f""" + "velocity_pixel": html_utils.paragraph(f""" Velocity in [pixel/frame] of the segmented object between previous and current frame. """), - 'velocity_um': html_utils.paragraph(f""" + "velocity_um": html_utils.paragraph(f""" Velocity in [{_um()}/frame] of the segmented object between previous and current frame. - """) + """), } size_metrics = {**size_metrics, **velocity_metrics} return size_metrics + def get_how_3Dto2D(isZstack, isSegm3D): - how_3Dto2D = ['_maxProj', '_meanProj', '_zSlice'] if isZstack else [''] + how_3Dto2D = ["_maxProj", "_meanProj", "_zSlice"] if isZstack else [""] if isSegm3D: - how_3Dto2D.append('_3D') + how_3Dto2D.append("_3D") how_3Dto2D_desc = [ - 'using a max projection', - 'using a mean projection (recommended for confocal imaging)', - 'using the z-slice you used for segmentation ' - '(recommended for epifluorescence imaging)' - 'NOTE: if segmentation mask is 3D, Cell-ACDC will use the ' - 'center z-slice of each object.', - 'using 3D data' + "using a max projection", + "using a mean projection (recommended for confocal imaging)", + "using the z-slice you used for segmentation " + "(recommended for epifluorescence imaging)" + "NOTE: if segmentation mask is 3D, Cell-ACDC will use the " + "center z-slice of each object.", + "using 3D data", ] return how_3Dto2D, how_3Dto2D_desc + def standard_metrics_desc( - isZstack, chName, isManualBackgrPresent=False, isSegm3D=False - ): + isZstack, chName, isManualBackgrPresent=False, isSegm3D=False +): how_3Dto2D, how_3Dto2D_desc = get_how_3Dto2D(isZstack, isSegm3D) - metrics_names = _get_metrics_names( - is_manual_bkgr_present=isManualBackgrPresent - ) - bkgr_val_names = _get_bkgr_val_names( - is_manual_bkgr_present=isManualBackgrPresent - ) + metrics_names = _get_metrics_names(is_manual_bkgr_present=isManualBackgrPresent) + bkgr_val_names = _get_bkgr_val_names(is_manual_bkgr_present=isManualBackgrPresent) metrics_desc = {} bkgr_val_desc = {} for how, how_desc in zip(how_3Dto2D, how_3Dto2D_desc): for func_name, func_desc in metrics_names.items(): - metric_name = f'{chName}_{func_name}{how}' + metric_name = f"{chName}_{func_name}{how}" if isZstack: - note_txt = (f""" + note_txt = f""" {_get_zStack_note(how_desc)} Example: {metric_name} is the {func_desc.lower()} of the {chName} signal after converting 3D to 2D {how_desc} - """) + """ else: - note_txt = '' + note_txt = "" - if func_desc == 'Amount': + if func_desc == "Amount": amount_formula = _get_amount_formula_str(func_name) - amount_desc = (f""" + amount_desc = f""" Amount is the background corrected (subtracted) total fluorescence intensity, which is usually the best proxy for the amount of the tagged molecule, e.g., @@ -655,32 +682,32 @@ def standard_metrics_desc( where _obj refers to the pixels inside the segmented object.

    - """) - main_desc = f'{func_desc} computed from' - elif func_desc == 'Concentration': - amount_desc = (""" + """ + main_desc = f"{func_desc} computed from" + elif func_desc == "Concentration": + amount_desc = """ Concentration is given by Amount/cell_volume, where amount is the background corrected (subtracted) total fluorescence intensity. Amount is usually the best proxy for the amount of the tagged molecule, e.g., protein amount.

    - """) - main_desc = f'{func_desc} computed from' + """ + main_desc = f"{func_desc} computed from" else: - amount_desc = '' - main_desc = f'{func_desc} computed from' + amount_desc = "" + main_desc = f"{func_desc} computed from" - if func_name == 'amount_autoBkgr': - bkgr_desc = (""" + if func_name == "amount_autoBkgr": + bkgr_desc = """ autoBkgr means that the background value used to correct the intensities is computed as the median of ALL the pixels outside of the segmented objects (i.e., pixels with ID 0 in the segmentation mask)

    - """) - elif func_name == 'amount_dataPrepBkgr': - bkgr_desc = (""" + """ + elif func_name == "amount_dataPrepBkgr": + bkgr_desc = """ dataPrepBkgr means that the background value used to correct the intensities is computed as the median of the pixels from the pixels inside the rectangular @@ -688,17 +715,17 @@ def standard_metrics_desc( data prep module (module 1.).

    Note taht this metric is grayed out and it cannot be selected if the selection of the background ROIs was not performed. - """) - elif func_name.find('_manualBkgr') != -1: - bkgr_desc = (""" + """ + elif func_name.find("_manualBkgr") != -1: + bkgr_desc = """ manualBkgr means that the background value used to correct the intensities is computed as the mean of the pixels from the pixels inside each background objects that you selected in the GUI module (module 3).

    - """) + """ else: - bkgr_desc = '' + bkgr_desc = "" desc = html_utils.paragraph(f""" {main_desc} the pixels inside @@ -707,34 +734,34 @@ def standard_metrics_desc( """) metrics_desc[metric_name] = desc - median_note = (""" + median_note = """ Note that this value might be grayed out because it is required by the corresponding amount metric that you requested to save (see above in the Standard measurements group) and it cannot be unchecked.

    - """) + """ for bkgr_name, bkgr_desc in bkgr_val_names.items(): - bkgr_colname = f'{chName}_{bkgr_name}{how}' + bkgr_colname = f"{chName}_{bkgr_name}{how}" if isZstack: - note_txt = (f""" + note_txt = f""" {_get_zStack_note(how_desc)} Example: {bkgr_colname} is the {bkgr_desc.lower()} of the {chName} background after converting 3D to 2D {how_desc} - """) + """ else: - note_txt = '' + note_txt = "" - if bkgr_name.find('autoBkgr') != -1: - bkgr_type_desc = (""" + if bkgr_name.find("autoBkgr") != -1: + bkgr_type_desc = """ autoBkgr means that the background value is computed from ALL the pixels outside of the segmented objects (i.e., pixels with ID 0 in the segmentation mask)

    - """) + """ else: - bkgr_type_desc = (""" + bkgr_type_desc = """ dataPrepBkgr means that the background value is computed from the pixels inside the rectangular background ROIs that you selected in the @@ -742,9 +769,9 @@ def standard_metrics_desc( Note taht this metric is grayed out and it cannot be selected if the selection of the background ROIs was not performed.

    - """) - if bkgr_name.find('bkgrVal_median') != -1: - bkgr_type_desc = f'{bkgr_type_desc}{median_note}' + """ + if bkgr_name.find("bkgrVal_median") != -1: + bkgr_type_desc = f"{bkgr_type_desc}{median_note}" bkgr_final_desc = html_utils.paragraph(f""" {bkgr_desc} of the background intensities.

    @@ -754,38 +781,36 @@ def standard_metrics_desc( return metrics_desc, bkgr_val_desc + def get_conc_keys(amount_colname): conc_key_vox = re.sub( - r'amount_([A-Za-z]+)', - r'concentration_\1_from_vol_vox', - amount_colname + r"amount_([A-Za-z]+)", r"concentration_\1_from_vol_vox", amount_colname ) - conc_key_fl = conc_key_vox.replace('from_vol_vox', 'from_vol_fl') + conc_key_fl = conc_key_vox.replace("from_vol_vox", "from_vol_fl") return conc_key_vox, conc_key_fl + def classify_acdc_df_colnames(acdc_df, channels): standard_funcs = _get_metrics_names() size_metrics_desc = get_size_metrics_desc(True, True) props_names = get_props_names() - foregr_metrics = {ch:[] for ch in channels} - bkgr_metrics = {ch:[] for ch in channels} - custom_metrics = {ch:[] for ch in channels} + foregr_metrics = {ch: [] for ch in channels} + bkgr_metrics = {ch: [] for ch in channels} + custom_metrics = {ch: [] for ch in channels} size_metrics = [] props_metrics = [] for col in acdc_df.columns: for ch in channels: - if col.startswith(f'{ch}_'): + if col.startswith(f"{ch}_"): # Channel specific metric - if col.find('_bkgrVal_') != -1: + if col.find("_bkgrVal_") != -1: # Bkgr metric bkgr_metrics[ch].append(col) else: # Foregr metric - is_standard = any( - [col.find(f'_{f}') != -1 for f in standard_funcs] - ) + is_standard = any([col.find(f"_{f}") != -1 for f in standard_funcs]) if is_standard: # Standard metric foregr_metrics[ch].append(col) @@ -801,70 +826,74 @@ def classify_acdc_df_colnames(acdc_df, channels): elif col in props_names: # Regionprop metric props_metrics.append(col) - + metrics = { - 'foregr': foregr_metrics, - 'bkgr': bkgr_metrics, - 'custom': custom_metrics, - 'size': size_metrics, - 'props': props_metrics + "foregr": foregr_metrics, + "bkgr": bkgr_metrics, + "custom": custom_metrics, + "size": size_metrics, + "props": props_metrics, } return metrics + def _get_metrics_names(is_manual_bkgr_present=False): metrics_names = { - 'mean': 'Mean', - 'sum': 'Sum', - 'amount_autoBkgr': 'Amount', - 'amount_dataPrepBkgr': 'Amount', - 'amount_manualBkgr': 'Amount', - 'mean_manualBkgr': 'Mean', - 'concentration_autoBkgr_from_vol_vox': 'Concentration', - 'concentration_dataPrepBkgr_from_vol_vox': 'Concentration', - 'concentration_autoBkgr_from_vol_fl': 'Concentration', - 'concentration_dataPrepBkgr_from_vol_fl': 'Concentration', - 'median': 'Median', - 'min': 'Minimum', - 'max': 'Maximum', - 'q25': '25 percentile', - 'q75': '75 percentile', - 'q05': '5 percentile', - 'q95': '95 percentile', + "mean": "Mean", + "sum": "Sum", + "amount_autoBkgr": "Amount", + "amount_dataPrepBkgr": "Amount", + "amount_manualBkgr": "Amount", + "mean_manualBkgr": "Mean", + "concentration_autoBkgr_from_vol_vox": "Concentration", + "concentration_dataPrepBkgr_from_vol_vox": "Concentration", + "concentration_autoBkgr_from_vol_fl": "Concentration", + "concentration_dataPrepBkgr_from_vol_fl": "Concentration", + "median": "Median", + "min": "Minimum", + "max": "Maximum", + "q25": "25 percentile", + "q75": "75 percentile", + "q05": "5 percentile", + "q95": "95 percentile", } return metrics_names + def _get_amount_formula_str(func_name): - if func_name.find('manualBkgr') != -1: - formula = 'amount = (mean_obj - mean_background)*area_obj' + if func_name.find("manualBkgr") != -1: + formula = "amount = (mean_obj - mean_background)*area_obj" else: - formula = 'amount = (mean_obj - median_background)*area_obj' + formula = "amount = (mean_obj - median_background)*area_obj" return formula + def _get_bkgr_val_names(is_manual_bkgr_present=False): bkgr_val_names = { - 'autoBkgr_bkgrVal_median': 'Median', - 'autoBkgr_bkgrVal_mean': 'Mean', - 'autoBkgr_bkgrVal_q75': '75 percentile', - 'autoBkgr_bkgrVal_q25': '25 percentile', - 'autoBkgr_bkgrVal_q95': '95 percentile', - 'autoBkgr_bkgrVal_q05': '5 percentile', - 'dataPrepBkgr_bkgrVal_median': 'Median', - 'dataPrepBkgr_bkgrVal_mean': 'Mean', - 'dataPrepBkgr_bkgrVal_q75': '75 percentile', - 'dataPrepBkgr_bkgrVal_q25': '25 percentile', - 'dataPrepBkgr_bkgrVal_q95': '95 percentile', - 'dataPrepBkgr_bkgrVal_q05': '5 percentile', + "autoBkgr_bkgrVal_median": "Median", + "autoBkgr_bkgrVal_mean": "Mean", + "autoBkgr_bkgrVal_q75": "75 percentile", + "autoBkgr_bkgrVal_q25": "25 percentile", + "autoBkgr_bkgrVal_q95": "95 percentile", + "autoBkgr_bkgrVal_q05": "5 percentile", + "dataPrepBkgr_bkgrVal_median": "Median", + "dataPrepBkgr_bkgrVal_mean": "Mean", + "dataPrepBkgr_bkgrVal_q75": "75 percentile", + "dataPrepBkgr_bkgrVal_q25": "25 percentile", + "dataPrepBkgr_bkgrVal_q95": "95 percentile", + "dataPrepBkgr_bkgrVal_q05": "5 percentile", } if is_manual_bkgr_present: - bkgr_val_names['manualBkgr_bkgrVal_median'] = 'Median' - bkgr_val_names['manualBkgr_bkgrVal_mean'] = 'Mean' - bkgr_val_names['manualBkgr_bkgrVal_q75'] = '75 percentile' - bkgr_val_names['manualBkgr_bkgrVal_q25'] = '25 percentile' - bkgr_val_names['manualBkgr_bkgrVal_q95'] = '95 percentile' - bkgr_val_names['manualBkgr_bkgrVal_q05'] = '5 percentile' + bkgr_val_names["manualBkgr_bkgrVal_median"] = "Median" + bkgr_val_names["manualBkgr_bkgrVal_mean"] = "Mean" + bkgr_val_names["manualBkgr_bkgrVal_q75"] = "75 percentile" + bkgr_val_names["manualBkgr_bkgrVal_q25"] = "25 percentile" + bkgr_val_names["manualBkgr_bkgrVal_q95"] = "95 percentile" + bkgr_val_names["manualBkgr_bkgrVal_q05"] = "5 percentile" return bkgr_val_names + def _get_props_info_txt(): txt = html_utils.paragraph(f""" Morphological properties are calculated using the function @@ -875,6 +904,7 @@ def _get_props_info_txt(): """) return txt + def _get_info_circularity(): info_txt = html_utils.paragraph(f""" Circularity is defined as the ratio between @@ -888,6 +918,7 @@ def _get_info_circularity(): """) return info_txt + def _get_info_roundness(): info_txt = html_utils.paragraph(f""" Roundness is defined as the ratio between @@ -898,11 +929,12 @@ def _get_info_roundness(): You can find more details about the major axis and the area here scikit-image regionprops. """) - + return info_txt + def _get_info_aspect_ratio(): - + info_txt = html_utils.paragraph(f""" Aspect ratio is defined as the ratio between the major and minor axis of the object.

    @@ -912,9 +944,10 @@ def _get_info_aspect_ratio(): You can find more details about major and minor axis here scikit-image regionprops. """) - + return info_txt + def get_props_info_txt_mapper(isSegm3D=False): skimage_desc = _get_props_info_txt() if isSegm3D: @@ -922,19 +955,19 @@ def get_props_info_txt_mapper(isSegm3D=False): else: props_names = get_props_names() mapper = {prop: skimage_desc for prop in props_names} - - mapper['circularity'] = _get_info_circularity() - mapper['roundness'] = _get_info_roundness() - mapper['aspect_ratio'] = _get_info_aspect_ratio() - + + mapper["circularity"] = _get_info_circularity() + mapper["roundness"] = _get_info_roundness() + mapper["aspect_ratio"] = _get_info_aspect_ratio() + return mapper + def _is_numeric_dtype(dtype): - is_numeric = ( - dtype is float or dtype is int - ) + is_numeric = dtype is float or dtype is int return is_numeric + def get_bkgrROI_mask(posData, isSegm3D): if posData.bkgrROIs: ROI_bkgrMask = np.zeros(posData.lab.shape, bool) @@ -943,18 +976,17 @@ def get_bkgrROI_mask(posData, isSegm3D): xl, yl = [int(round(c)) for c in roi.pos()] w, h = [int(round(c)) for c in roi.size()] if isSegm3D: - ROI_bkgrMask[:, yl:yl+h, xl:xl+w] = True + ROI_bkgrMask[:, yl : yl + h, xl : xl + w] = True else: - ROI_bkgrMask[yl:yl+h, xl:xl+w] = True + ROI_bkgrMask[yl : yl + h, xl : xl + w] = True else: ROI_bkgrMask = None return ROI_bkgrMask + def get_autoBkgr_mask(lab, isSegm3D, posData, frame_i): autoBkgr_mask = lab == 0 - autoBkgr_mask = _mask_0valued_pixels_from_alignment( - autoBkgr_mask, frame_i, posData - ) + autoBkgr_mask = _mask_0valued_pixels_from_alignment(autoBkgr_mask, frame_i, posData) if isSegm3D: autoBkgr_mask_proj = lab.max(axis=0) == 0 autoBkgr_mask_proj = _mask_0valued_pixels_from_alignment( @@ -962,15 +994,16 @@ def get_autoBkgr_mask(lab, isSegm3D, posData, frame_i): ) else: autoBkgr_mask_proj = autoBkgr_mask - + return autoBkgr_mask, autoBkgr_mask_proj + def regionprops_table(labels, props, logger_func=None): rp = skimage.measure.regionprops(labels) - if 'label' not in props: - props = ('label', *props) - - empty_metric = [None]*len(rp) + if "label" not in props: + props = ("label", *props) + + empty_metric = [None] * len(rp) rp_table = {} error_ids = {} pbar = tqdm(total=len(props), ncols=100, leave=False) @@ -987,15 +1020,15 @@ def regionprops_table(labels, props, logger_func=None): rp_table[prop][o] = metric elif _type == tuple: for m, val in enumerate(metric): - prop_1d = f'{prop}-{m}' + prop_1d = f"{prop}-{m}" if prop_1d not in rp_table: rp_table[prop_1d] = empty_metric.copy() rp_table[prop_1d][o] = val elif _type == np.ndarray: for i, val in enumerate(metric.flatten()): indices = np.unravel_index(i, metric.shape) - s = '-'.join([str(idx) for idx in indices]) - prop_1d = f'{prop}-{s}' + s = "-".join([str(idx) for idx in indices]) + prop_1d = f"{prop}-{s}" if prop_1d not in rp_table: rp_table[prop_1d] = empty_metric.copy() rp_table[prop_1d][o] = val @@ -1005,61 +1038,66 @@ def regionprops_table(labels, props, logger_func=None): printl(format_exception) else: logger_func(format_exception) - + if prop not in error_ids: - error_ids[prop] = {'ids': [obj.label], 'error': e} + error_ids[prop] = {"ids": [obj.label], "error": e} else: - error_ids[prop]['ids'].append(obj.label) + error_ids[prop]["ids"].append(obj.label) pbar.update(1) return rp_table, error_ids + def get_btrack_features(): features = ( - 'area', - 'major_axis_length', - 'minor_axis_length', - 'equivalent_diameter', - 'solidity', - 'extent', - 'filled_area', - 'bbox_area', - 'convex_area', - 'euler_number', - 'orientation' + "area", + "major_axis_length", + "minor_axis_length", + "equivalent_diameter", + "solidity", + "extent", + "filled_area", + "bbox_area", + "convex_area", + "euler_number", + "orientation", ) return features + def get_non_measurements_cols(colnames, metrics_colnames): non_metrics_colnames = [] for col in colnames: if col in metrics_colnames: continue non_metrics_colnames.append(col) - + non_metrics_non_rp_colnames = [] props = get_props_names() # Remove composite regionprops for col in non_metrics_colnames: for prop in props: - match = re.match(rf'{prop}-\d', col) + match = re.match(rf"{prop}-\d", col) if match is not None: break - match = re.match(rf'{col}-\d-\d', col) + match = re.match(rf"{col}-\d-\d", col) if match is not None: break else: non_metrics_non_rp_colnames.append(col) - return non_metrics_non_rp_colnames - + return non_metrics_non_rp_colnames + + def get_props_names_3D(): props_3D = list(PROPS_DTYPES.keys()) - props_3D.remove('solidity') - props_3D.remove('eccentricity') + props_3D.remove("solidity") + props_3D.remove("eccentricity") return props_3D + def get_props_names(): return list(PROPS_DTYPES.keys()) + def _try_metric_func(func, *args): try: val = func(*args) @@ -1067,6 +1105,7 @@ def _try_metric_func(func, *args): val = np.nan return val + def _quantile(arr, q): try: val = np.quantile(arr, q=q) @@ -1074,151 +1113,159 @@ def _quantile(arr, q): val = np.nan return val + def _amount(arr, bkgr, area): try: - val = (np.mean(arr)-bkgr)*area + val = (np.mean(arr) - bkgr) * area except Exception as e: val = np.nan return val + def _mean_corrected(arr, bkgr): try: - val = np.mean(arr)-bkgr + val = np.mean(arr) - bkgr except Exception as e: val = np.nan return val -def get_obj_size_metric( - col_name, obj, isSegm3D, yx_pxl_to_um2, vox_to_fl_3D - ): - if col_name == 'cell_area_pxl': + +def get_obj_size_metric(col_name, obj, isSegm3D, yx_pxl_to_um2, vox_to_fl_3D): + if col_name == "cell_area_pxl": if isSegm3D: return np.count_nonzero(obj.image.max(axis=0)) else: return obj.area - elif col_name == 'cell_area_um2': + elif col_name == "cell_area_um2": if isSegm3D: - return np.count_nonzero(obj.image.max(axis=0))*yx_pxl_to_um2 + return np.count_nonzero(obj.image.max(axis=0)) * yx_pxl_to_um2 else: - return obj.area*yx_pxl_to_um2 - elif col_name == 'cell_vol_vox': - if not hasattr(obj, 'vol_vox'): + return obj.area * yx_pxl_to_um2 + elif col_name == "cell_vol_vox": + if not hasattr(obj, "vol_vox"): PhysicalSizeY = PhysicalSizeX = np.sqrt(yx_pxl_to_um2) vol_vox, vol_fl = cca_functions._calc_rot_vol( obj, PhysicalSizeY, PhysicalSizeX ) obj.vol_vox, obj.vol_fl = vol_vox, vol_fl return obj.vol_vox - elif col_name == 'cell_vol_fl': - if not hasattr(obj, 'vol_fl'): + elif col_name == "cell_vol_fl": + if not hasattr(obj, "vol_fl"): PhysicalSizeY = PhysicalSizeX = np.sqrt(yx_pxl_to_um2) vol_vox, vol_fl = cca_functions._calc_rot_vol( obj, PhysicalSizeY, PhysicalSizeX ) obj.vol_vox, obj.vol_fl = vol_vox, vol_fl return obj.vol_fl - elif col_name == 'cell_vol_vox_3D': + elif col_name == "cell_vol_vox_3D": return obj.area - elif col_name == 'cell_vol_fl_3D': - return obj.area*vox_to_fl_3D + elif col_name == "cell_vol_fl_3D": + return obj.area * vox_to_fl_3D + def get_foregr_data(foregr_img, isSegm3D, z): isZstack = foregr_img.ndim == 3 foregr_data = {} if isSegm3D: - foregr_data['3D'] = foregr_img - + foregr_data["3D"] = foregr_img + if isZstack: - foregr_data['maxProj'] = foregr_img.max(axis=0) - foregr_data['meanProj'] = foregr_img.mean(axis=0) - foregr_data['zSlice'] = foregr_img[z] - foregr_data[''] = foregr_img + foregr_data["maxProj"] = foregr_img.max(axis=0) + foregr_data["meanProj"] = foregr_img.mean(axis=0) + foregr_data["zSlice"] = foregr_img[z] + foregr_data[""] = foregr_img return foregr_data + def get_cell_volumes_areas(df): try: - cell_vol_vox = df['cell_vol_vox'].to_list() + cell_vol_vox = df["cell_vol_vox"].to_list() except Exception as e: - cell_vol_vox = [np.nan]*len(df) - + cell_vol_vox = [np.nan] * len(df) + try: - cell_vol_fl = df['cell_vol_fl'].to_list() + cell_vol_fl = df["cell_vol_fl"].to_list() except Exception as e: - cell_vol_fl = [np.nan]*len(df) - + cell_vol_fl = [np.nan] * len(df) + try: - cell_vol_vox_3D = df['cell_vol_vox_3D'].to_list() + cell_vol_vox_3D = df["cell_vol_vox_3D"].to_list() except Exception as e: - cell_vol_vox_3D = [np.nan]*len(df) - + cell_vol_vox_3D = [np.nan] * len(df) + try: - cell_vol_fl_3D = df['cell_vol_fl_3D'].to_list() + cell_vol_fl_3D = df["cell_vol_fl_3D"].to_list() except Exception as e: - cell_vol_fl_3D = [np.nan]*len(df) - + cell_vol_fl_3D = [np.nan] * len(df) + try: - cell_area_pxl = df['cell_area_pxl'].to_list() + cell_area_pxl = df["cell_area_pxl"].to_list() except Exception as e: - cell_area_pxl = [np.nan]*len(df) - + cell_area_pxl = [np.nan] * len(df) + try: - cell_area_um2 = df['cell_vol_fl_3D'].to_list() + cell_area_um2 = df["cell_vol_fl_3D"].to_list() except Exception as e: - cell_area_um2 = [np.nan]*len(df) - + cell_area_um2 = [np.nan] * len(df) + items = ( - cell_vol_vox, cell_vol_fl, cell_vol_vox_3D, cell_vol_fl_3D, - cell_area_pxl, cell_area_um2 + cell_vol_vox, + cell_vol_fl, + cell_vol_vox_3D, + cell_vol_fl_3D, + cell_area_pxl, + cell_area_um2, ) return items + def get_bkgrVals(df, channel, how, ID, bkgr_type=None): try: if how: - autoBkgr_col = f'{channel}_autoBkgr_bkgrVal_median_{how}' + autoBkgr_col = f"{channel}_autoBkgr_bkgrVal_median_{how}" else: - autoBkgr_col = f'{channel}_autoBkgr_bkgrVal_median' + autoBkgr_col = f"{channel}_autoBkgr_bkgrVal_median" autoBkgrVal = df.at[ID, autoBkgr_col] except Exception as e: autoBkgrVal = np.nan - + try: if how: - dataPrepBkgr_col = f'{channel}_dataPrepBkgr_bkgrVal_median_{how}' + dataPrepBkgr_col = f"{channel}_dataPrepBkgr_bkgrVal_median_{how}" else: - dataPrepBkgr_col = f'{channel}_dataPrepBkgr_bkgrVal_median' + dataPrepBkgr_col = f"{channel}_dataPrepBkgr_bkgrVal_median" dataPrepBkgrVal = df.at[ID, dataPrepBkgr_col] except Exception as e: dataPrepBkgrVal = np.nan if bkgr_type is None: return autoBkgrVal, dataPrepBkgrVal - - if bkgr_type.find('dataPrep') != -1: + + if bkgr_type.find("dataPrep") != -1: return dataPrepBkgrVal else: return autoBkgrVal + def get_manualBkgr_bkgrVal(df, channel, how, ID): try: if how: - bkgr_col = f'{channel}_manualBkgr_bkgrVal_mean_{how}' + bkgr_col = f"{channel}_manualBkgr_bkgrVal_mean_{how}" else: - bkgr_col = f'{channel}_dataPrepBkgr_bkgrVal_mean' + bkgr_col = f"{channel}_dataPrepBkgr_bkgrVal_mean" bkgrVal = df.at[ID, bkgr_col] except Exception as e: bkgrVal = np.nan return bkgrVal + def get_foregr_obj_array(foregr_arr, obj, isSegm3D, z_slice=None, how=None): if foregr_arr.ndim == 3 and isSegm3D: # 3D mask on 3D data return foregr_arr[obj.slice][obj.image], obj.area elif foregr_arr.ndim == 2 and isSegm3D: # 3D mask on 2D data - use_proj = ( - z_slice is None or how is None or how != 'zSlice' - ) + use_proj = z_slice is None or how is None or how != "zSlice" obj_slice = obj.slice[1:3] if use_proj: obj_image = obj.image.max(axis=0) @@ -1237,60 +1284,72 @@ def get_foregr_obj_array(foregr_arr, obj, isSegm3D, z_slice=None, how=None): # 2D mask on 2D data return foregr_arr[obj.slice][obj.image], obj.area + def _mask_0valued_pixels_from_alignment(bkgr_mask, frame_i, posData): if posData.loaded_shifts is None: # Not aligned --> there are no 0-valued pixels return bkgr_mask - + if posData.dataPrep_ROIcoords is not None: df_roi = posData.dataPrep_ROIcoords.loc[0] - is_cropped = int(df_roi.at['cropped', 'value']) + is_cropped = int(df_roi.at["cropped", "value"]) if is_cropped: # Do not mask 0valued pixels if image was cropped return bkgr_mask - + shifts = posData.loaded_shifts[frame_i] dy, dx = shifts - if dy>0: + if dy > 0: bkgr_mask[..., :dy, :] = False - elif dy<0: + elif dy < 0: bkgr_mask[..., dy:, :] = False - if dx>0: + if dx > 0: bkgr_mask[..., :dx] = False - elif dx<0: + elif dx < 0: bkgr_mask[..., dx:] = False - + return bkgr_mask + def get_bkgr_data( - foregr_img, posData, filename, frame_i, autoBkgr_mask, z, - autoBkgr_mask_proj, dataPrepBkgrROI_mask, isSegm3D, lab - ): + foregr_img, + posData, + filename, + frame_i, + autoBkgr_mask, + z, + autoBkgr_mask_proj, + dataPrepBkgrROI_mask, + isSegm3D, + lab, +): isZstack = foregr_img.ndim == 3 bkgr_data = {} """Auto Background""" - bkgr_data['autoBkgr'] = { - '': 0, 'maxProj': 0, 'meanProj': 0, 'zSlice': 0, '3D': 0 - } + bkgr_data["autoBkgr"] = {"": 0, "maxProj": 0, "meanProj": 0, "zSlice": 0, "3D": 0} if isZstack: if isSegm3D: autoBkr_3D = foregr_img[autoBkgr_mask] - bkgr_data['autoBkgr']['3D'] = autoBkr_3D + bkgr_data["autoBkgr"]["3D"] = autoBkr_3D autoBkgr_maxP = foregr_img.max(axis=0)[autoBkgr_mask_proj] autoBkgr_meanP = foregr_img.mean(axis=0)[autoBkgr_mask_proj] autoBkgr_zSlice = foregr_img[int(z)][autoBkgr_mask_proj] - bkgr_data['autoBkgr']['maxProj'] = autoBkgr_maxP - bkgr_data['autoBkgr']['meanProj'] = autoBkgr_meanP - bkgr_data['autoBkgr']['zSlice'] = autoBkgr_zSlice + bkgr_data["autoBkgr"]["maxProj"] = autoBkgr_maxP + bkgr_data["autoBkgr"]["meanProj"] = autoBkgr_meanP + bkgr_data["autoBkgr"]["zSlice"] = autoBkgr_zSlice else: autoBkgr_data = foregr_img[autoBkgr_mask] - bkgr_data['autoBkgr'][''] = autoBkgr_data + bkgr_data["autoBkgr"][""] = autoBkgr_data """DataPrep Background""" bkgr_archive = posData.fluo_bkgrData_dict[filename] - bkgr_data['dataPrepBkgr'] = { - '': [], 'maxProj': [], 'meanProj': [], 'zSlice': [], '3D': [] + bkgr_data["dataPrepBkgr"] = { + "": [], + "maxProj": [], + "meanProj": [], + "zSlice": [], + "3D": [], } dataPrepBkgr_present = False if bkgr_archive is not None: @@ -1306,11 +1365,11 @@ def get_bkgr_data( if isSegm3D: bkgrRoi_3D = bkgrRoi_data else: - bkgrRoi = bkgrRoi_data - dataPrepBkgr_present = True + bkgrRoi = bkgrRoi_data + dataPrepBkgr_present = True elif dataPrepBkgrROI_mask is not None: # Get background data from the bkgr ROI mask - dataPrepBkgrROI_mask = np.logical_and(dataPrepBkgrROI_mask, lab==0) + dataPrepBkgrROI_mask = np.logical_and(dataPrepBkgrROI_mask, lab == 0) if isZstack: if isSegm3D: bkgrRoi_3D = foregr_img[dataPrepBkgrROI_mask] @@ -1319,39 +1378,39 @@ def get_bkgr_data( dataPrepBkgrROI_mask_2D = dataPrepBkgrROI_mask bkgrRoi_maxP = foregr_img.max(axis=0)[dataPrepBkgrROI_mask_2D] bkgrRoi_meanP = foregr_img.mean(axis=0)[dataPrepBkgrROI_mask_2D] - bkgrRoi_zSlice = foregr_img[z][dataPrepBkgrROI_mask_2D] + bkgrRoi_zSlice = foregr_img[z][dataPrepBkgrROI_mask_2D] else: bkgrRoi = foregr_img[dataPrepBkgrROI_mask] - dataPrepBkgr_present = True - + dataPrepBkgr_present = True + if isZstack and dataPrepBkgr_present: # Note: we do not try to exclude 0-valued pixels, see issue #285 - bkgr_data['dataPrepBkgr']['maxProj'].extend(bkgrRoi_maxP) - bkgr_data['dataPrepBkgr']['meanProj'].extend(bkgrRoi_meanP) - bkgr_data['dataPrepBkgr']['zSlice'].extend(bkgrRoi_zSlice) + bkgr_data["dataPrepBkgr"]["maxProj"].extend(bkgrRoi_maxP) + bkgr_data["dataPrepBkgr"]["meanProj"].extend(bkgrRoi_meanP) + bkgr_data["dataPrepBkgr"]["zSlice"].extend(bkgrRoi_zSlice) if isSegm3D: - bkgr_data['dataPrepBkgr']['3D'].extend(bkgrRoi_3D) + bkgr_data["dataPrepBkgr"]["3D"].extend(bkgrRoi_3D) elif dataPrepBkgr_present: - bkgr_data['dataPrepBkgr'][''].extend(bkgrRoi) - + bkgr_data["dataPrepBkgr"][""].extend(bkgrRoi) + return bkgr_data - + def standard_metrics_func(): metrics_func = { - 'sum': lambda arr: _try_metric_func(np.sum, arr), - 'amount_autoBkgr': lambda arr, bkgr, area: _amount(arr, bkgr, area), - 'amount_dataPrepBkgr': lambda arr, bkgr, area: _amount(arr, bkgr, area), - 'amount_manualBkgr': lambda arr, bkgr, area: _amount(arr, bkgr, area), - 'mean_manualBkgr': lambda arr, bkgr, area: _mean_corrected(arr, bkgr), - 'mean': lambda arr: _try_metric_func(np.mean, arr), - 'median': lambda arr: _try_metric_func(np.median, arr), - 'min': lambda arr: _try_metric_func(np.min, arr), - 'max': lambda arr: _try_metric_func(np.max, arr), - 'q25': lambda arr: _quantile(arr, 0.25), - 'q75': lambda arr: _quantile(arr, 0.75), - 'q05': lambda arr: _quantile(arr, 0.05), - 'q95': lambda arr: _quantile(arr, 0.95) + "sum": lambda arr: _try_metric_func(np.sum, arr), + "amount_autoBkgr": lambda arr, bkgr, area: _amount(arr, bkgr, area), + "amount_dataPrepBkgr": lambda arr, bkgr, area: _amount(arr, bkgr, area), + "amount_manualBkgr": lambda arr, bkgr, area: _amount(arr, bkgr, area), + "mean_manualBkgr": lambda arr, bkgr, area: _mean_corrected(arr, bkgr), + "mean": lambda arr: _try_metric_func(np.mean, arr), + "median": lambda arr: _try_metric_func(np.median, arr), + "min": lambda arr: _try_metric_func(np.min, arr), + "max": lambda arr: _try_metric_func(np.max, arr), + "q25": lambda arr: _quantile(arr, 0.25), + "q75": lambda arr: _quantile(arr, 0.75), + "q05": lambda arr: _quantile(arr, 0.05), + "q95": lambda arr: _quantile(arr, 0.95), } all_metrics_names = list(_get_metrics_names().keys()) @@ -1360,10 +1419,11 @@ def standard_metrics_func(): return metrics_func, all_metrics_names + def add_metrics_instructions(): - url = 'https://github.com/SchmollerLab/Cell_ACDC/issues' + url = "https://github.com/SchmollerLab/Cell_ACDC/issues" href = f'here' - rp_url = f'https://scikit-image.org/docs/stable/api/skimage.measure.html#skimage.measure.regionprops' + rp_url = f"https://scikit-image.org/docs/stable/api/skimage.measure.html#skimage.measure.regionprops" rp_href = f'skimage.measure.regionproperties' def_sh = html_utils.def_sh CV_sh = html_utils.CV_sh @@ -1377,8 +1437,9 @@ def add_metrics_instructions(): return_sh = html_utils.return_sh is_not_sh = html_utils.is_not_sh args_sh = html_utils.span( - 'signal, autoBkgr, dataPrepBkgr, objectRp, correct_with_bkgr=False, ' - 'which_bkgr="auto"', color=html_utils.kwargs_color + "signal, autoBkgr, dataPrepBkgr, objectRp, correct_with_bkgr=False, " + 'which_bkgr="auto"', + color=html_utils.kwargs_color, ) s = html_utils.paragraph(f""" To add custom metrics to the acdc_output.csv @@ -1416,6 +1477,7 @@ def add_metrics_instructions(): """) return s + def _get_combine_metrics_examples_list(): examples = [ """ @@ -1445,12 +1507,13 @@ def _get_combine_metrics_examples_list(): ch1_minus_ch2_mean
    with the result of the subtraction between the channel_1 signal's mean and the channel_2 signal's mean. - """ + """, ] return examples + def get_combine_metrics_help_txt(): - pandas_eval_url = 'https://pandas.pydata.org/docs/reference/api/pandas.eval.html' + pandas_eval_url = "https://pandas.pydata.org/docs/reference/api/pandas.eval.html" examples = _get_combine_metrics_examples_list() txt = html_utils.paragraph(f""" This dialog allows you to write an equation that will be used to @@ -1501,7 +1564,7 @@ def get_combine_metrics_help_txt(): Cell-ACDC uses the Python package pandas to evaluate the expression.
    You can read more about it - {html_utils.href_tag('here', pandas_eval_url)}

    + {html_utils.href_tag("here", pandas_eval_url)}

    The equations will be saved to both the loaded Position folder
    @@ -1514,71 +1577,80 @@ def get_combine_metrics_help_txt(): """) return txt + def add_concentration_metrics(df, concentration_metrics_params): for col, (func_name, how) in concentration_metrics_params.items(): - idx = col.find('_from_vol_') + idx = col.find("_from_vol_") amount_col = col[:idx] - amount_col = amount_col.replace('concentration_', 'amount_') + amount_col = amount_col.replace("concentration_", "amount_") if how: - amount_col = f'{amount_col}_{how}' + amount_col = f"{amount_col}_{how}" - if col.find('from_vol_vox') != -1: + if col.find("from_vol_vox") != -1: try: - if how == '3D': - cell_vol_values = df['cell_vol_vox_3D'] + if how == "3D": + cell_vol_values = df["cell_vol_vox_3D"] else: - cell_vol_values = df['cell_vol_vox'] - concentration_values = df[amount_col]/cell_vol_values + cell_vol_values = df["cell_vol_vox"] + concentration_values = df[amount_col] / cell_vol_values except Exception as e: concentration_values = np.nan df[col] = concentration_values - elif col.find('from_vol_fl') != -1: + elif col.find("from_vol_fl") != -1: try: - if how == '3D': - cell_vol_values = df['cell_vol_fl_3D'] + if how == "3D": + cell_vol_values = df["cell_vol_fl_3D"] else: - cell_vol_values = df['cell_vol_fl'] - concentration_values = df[amount_col]/cell_vol_values + cell_vol_values = df["cell_vol_fl"] + concentration_values = df[amount_col] / cell_vol_values except Exception as e: concentration_values = np.nan df[col] = concentration_values return df + def add_size_metrics( - df, rp, size_metrics_to_save, isSegm3D, yx_pxl_to_um2, vox_to_fl_3D, - calc_size_for_each_zslice=False - ): + df, + rp, + size_metrics_to_save, + isSegm3D, + yx_pxl_to_um2, + vox_to_fl_3D, + calc_size_for_each_zslice=False, +): for o, obj in enumerate(tqdm(rp, ncols=100, leave=False)): for col in size_metrics_to_save: - val = get_obj_size_metric( - col, obj, isSegm3D, yx_pxl_to_um2, vox_to_fl_3D - ) + val = get_obj_size_metric(col, obj, isSegm3D, yx_pxl_to_um2, vox_to_fl_3D) df.at[obj.label, col] = val - + if not calc_size_for_each_zslice: continue - + z0 = obj.bbox[0] for local_z, obj_img_z in enumerate(obj.image): z_slice = z0 + local_z area_pxl_z = np.count_nonzero(obj_img_z) - area_pxl_zslice_col = f'cell_area_pxl_zslice{z_slice}' + area_pxl_zslice_col = f"cell_area_pxl_zslice{z_slice}" df.at[obj.label, area_pxl_zslice_col] = area_pxl_z - - area_um2_z = area_pxl_z*yx_pxl_to_um2 - area_um2_zslice_col = f'cell_area_um2_zslice{z_slice}' + + area_um2_z = area_pxl_z * yx_pxl_to_um2 + area_um2_zslice_col = f"cell_area_um2_zslice{z_slice}" df.at[obj.label, area_um2_zslice_col] = area_um2_z return df + def add_ch_indipend_custom_metrics( - df: pd.DataFrame, - rp, all_channels_foregr_data, - ch_indipend_custom_func_params, - isSegm3D, lab, all_channels_foregr_imgs, - all_channels_z_slices=None, - text_to_append_to_col='', - customMetricsCritical=None - ): + df: pd.DataFrame, + rp, + all_channels_foregr_data, + ch_indipend_custom_func_params, + isSegm3D, + lab, + all_channels_foregr_imgs, + all_channels_z_slices=None, + text_to_append_to_col="", + customMetricsCritical=None, +): for o, obj in enumerate(tqdm(rp, ncols=100, leave=False)): ID = obj.label for col, (custom_func, how) in ch_indipend_custom_func_params.items(): @@ -1589,111 +1661,149 @@ def add_ch_indipend_custom_metrics( if all_channels_z_slices is not None: z_slice = all_channels_z_slices[channel] else: - z_slice = None + z_slice = None foregr_arr = foregr_data.get(how) if foregr_arr is None: continue - + foregr_obj_arr, obj_area = get_foregr_obj_array( foregr_arr, obj, isSegm3D, z_slice=z_slice, how=how ) - - autoBkgrVal, dataPrepBkgrVal = get_bkgrVals( - df, channel, how, ID - ) - + + autoBkgrVal, dataPrepBkgrVal = get_bkgrVals(df, channel, how, ID) + all_channels_obj_intens[channel] = foregr_obj_arr all_channels_autoBkgr[channel] = autoBkgrVal all_channels_dataPrepBkgr[channel] = dataPrepBkgrVal - - metrics_values = df.to_dict('list') + + metrics_values = df.to_dict("list") items = get_cell_volumes_areas(df) - (cell_vols_vox, cell_vols_fl, cell_vols_vox_3D, cell_vols_fl_3D, - cell_areas_pxl, cell_areas_um2) = items + ( + cell_vols_vox, + cell_vols_fl, + cell_vols_vox_3D, + cell_vols_fl_3D, + cell_areas_pxl, + cell_areas_um2, + ) = items custom_error, custom_val, custom_col_name = ( get_ch_indipend_custom_metric_value( - custom_func, - all_channels_obj_intens, - all_channels_autoBkgr, - all_channels_dataPrepBkgr, - obj, o, - metrics_values, - cell_vols_vox, cell_vols_fl, cell_areas_pxl, - cell_areas_um2, - all_channels_foregr_imgs, - lab, - isSegm3D, + custom_func, + all_channels_obj_intens, + all_channels_autoBkgr, + all_channels_dataPrepBkgr, + obj, + o, + metrics_values, + cell_vols_vox, + cell_vols_fl, + cell_areas_pxl, + cell_areas_um2, + all_channels_foregr_imgs, + lab, + isSegm3D, col, - cell_vols_vox_3D=cell_vols_vox_3D, - cell_vols_fl_3D=cell_vols_fl_3D + cell_vols_vox_3D=cell_vols_vox_3D, + cell_vols_fl_3D=cell_vols_fl_3D, ) ) if custom_col_name is None: - df.at[ID, f'{col}{text_to_append_to_col}'] = custom_val + df.at[ID, f"{col}{text_to_append_to_col}"] = custom_val else: for custom_col, value in zip(custom_col_name, custom_val): - df.at[ID, f'{custom_col}{text_to_append_to_col}'] = value - + df.at[ID, f"{custom_col}{text_to_append_to_col}"] = value + if customMetricsCritical is not None and custom_error: customMetricsCritical.emit(custom_error, col) - + return df def add_custom_metrics( - df: pd.DataFrame, - rp, channel, foregr_data, custom_metrics_params, - isSegm3D, lab, foregr_img, other_channels_foregr_imgs, - z_slice=None, text_to_append_to_col='', - customMetricsCritical=None - ): + df: pd.DataFrame, + rp, + channel, + foregr_data, + custom_metrics_params, + isSegm3D, + lab, + foregr_img, + other_channels_foregr_imgs, + z_slice=None, + text_to_append_to_col="", + customMetricsCritical=None, +): for o, obj in enumerate(tqdm(rp, ncols=100, leave=False)): - for col, (custom_func, how) in custom_metrics_params.items(): + for col, (custom_func, how) in custom_metrics_params.items(): foregr_arr = foregr_data.get(how) if foregr_arr is None: continue - + foregr_obj_arr, obj_area = get_foregr_obj_array( foregr_arr, obj, isSegm3D, z_slice=z_slice, how=how ) ID = obj.label autoBkgrVal, dataPrepBkgrVal = get_bkgrVals(df, channel, how, ID) - metrics_values = df.to_dict('list') + metrics_values = df.to_dict("list") items = get_cell_volumes_areas(df) - (cell_vols_vox, cell_vols_fl, cell_vols_vox_3D, cell_vols_fl_3D, - cell_areas_pxl, cell_areas_um2) = items + ( + cell_vols_vox, + cell_vols_fl, + cell_vols_vox_3D, + cell_vols_fl_3D, + cell_areas_pxl, + cell_areas_um2, + ) = items custom_error, custom_val, custom_col_name = get_custom_metric_value( - custom_func, foregr_obj_arr, autoBkgrVal, dataPrepBkgrVal, obj, - o, metrics_values, cell_vols_vox, cell_vols_fl, cell_areas_pxl, - cell_areas_um2, foregr_img, lab, isSegm3D, - other_channels_foregr_imgs, col, - cell_vols_vox_3D=cell_vols_vox_3D, - cell_vols_fl_3D=cell_vols_fl_3D + custom_func, + foregr_obj_arr, + autoBkgrVal, + dataPrepBkgrVal, + obj, + o, + metrics_values, + cell_vols_vox, + cell_vols_fl, + cell_areas_pxl, + cell_areas_um2, + foregr_img, + lab, + isSegm3D, + other_channels_foregr_imgs, + col, + cell_vols_vox_3D=cell_vols_vox_3D, + cell_vols_fl_3D=cell_vols_fl_3D, ) if custom_col_name is None: - df.at[ID, f'{col}{text_to_append_to_col}'] = custom_val + df.at[ID, f"{col}{text_to_append_to_col}"] = custom_val else: for custom_col, value in zip(custom_col_name, custom_val): - df.at[ID, f'{custom_col}{text_to_append_to_col}'] = value - + df.at[ID, f"{custom_col}{text_to_append_to_col}"] = value + if customMetricsCritical is not None and custom_error: customMetricsCritical.emit(custom_error, col) - + return df + def add_foregr_standard_metrics( - df, rp, channel, foregr_data, - foregr_metrics_params, - metrics_func, isSegm3D, - lab, foregr_img, - z_slice=None, - manualBackgrRp=None, - customMetricsCritical=None, - text_to_append_to_col='' - ): + df, + rp, + channel, + foregr_data, + foregr_metrics_params, + metrics_func, + isSegm3D, + lab, + foregr_img, + z_slice=None, + manualBackgrRp=None, + customMetricsCritical=None, + text_to_append_to_col="", +): if manualBackgrRp is not None: manualBackgrRp = {obj.label for obj in manualBackgrRp} - custom_errors = '' + custom_errors = "" # Iterate objects and compute foreground metrics for o, obj in enumerate(tqdm(rp, ncols=100, leave=False)): for col, (func_name, how) in foregr_metrics_params.items(): @@ -1701,14 +1811,14 @@ def add_foregr_standard_metrics( foregr_arr = foregr_data.get(how) if foregr_arr is None: continue - + foregr_obj_arr, obj_area = get_foregr_obj_array( foregr_arr, obj, isSegm3D, z_slice=z_slice, how=how ) - is_manual_bkgr_metric = func_name.find('manualBkgr') != -1 - is_amount_metric = func_name.find('amount_') != -1 + is_manual_bkgr_metric = func_name.find("manualBkgr") != -1 + is_amount_metric = func_name.find("amount_") != -1 if is_amount_metric and not is_manual_bkgr_metric: - bkgr_type = func_name[len('amount_'):] + bkgr_type = func_name[len("amount_") :] try: bkgr_val = get_bkgrVals( df, channel, how, obj.label, bkgr_type=bkgr_type @@ -1724,100 +1834,118 @@ def add_foregr_standard_metrics( else: func = metrics_func[func_name] val = func(foregr_obj_arr) - df.at[obj.label, f'{col}{text_to_append_to_col}'] = val + df.at[obj.label, f"{col}{text_to_append_to_col}"] = val return df + def add_bkgr_values( - df, bkgr_data, bkgr_metrics_params, metrics_func, - manualBackgrRp=None, foregr_data=None, - text_to_append_to_col='' - ): + df, + bkgr_data, + bkgr_metrics_params, + metrics_func, + manualBackgrRp=None, + foregr_data=None, + text_to_append_to_col="", +): # Compute background values for col, (bkgr_type, func_name, how) in bkgr_metrics_params.items(): bkgr_func = metrics_func[func_name] - if bkgr_type == 'manualBkgr': - add_manual_bkgr_values( - manualBackgrRp, foregr_data, df, col, how, bkgr_func - ) + if bkgr_type == "manualBkgr": + add_manual_bkgr_values(manualBackgrRp, foregr_data, df, col, how, bkgr_func) continue bkgr_arr = bkgr_data[bkgr_type].get(how) if bkgr_arr is None: continue - + bkgr_val = bkgr_func(bkgr_arr) - df[f'{col}{text_to_append_to_col}'] = bkgr_val + df[f"{col}{text_to_append_to_col}"] = bkgr_val return df + def add_manual_bkgr_values(manualBackgrRp, foregr_data, df, col, how, bkgr_func): if manualBackgrRp is None: return if foregr_data is None: return - + foregr_img = foregr_data.get(how) if foregr_img is None: return - + for obj in manualBackgrRp: bkgr_obj_arr = foregr_img[obj.slice][obj.image] bkgr_val = bkgr_func(bkgr_obj_arr) df.at[obj.label, col] = bkgr_val + def add_regionprops_metrics(df, lab, regionprops_to_save, logger_func=None): if not regionprops_to_save: return df, [] - if 'label' not in regionprops_to_save: - regionprops_to_save = ('label', *regionprops_to_save) + if "label" not in regionprops_to_save: + regionprops_to_save = ("label", *regionprops_to_save) rp_table, rp_errors = regionprops_table( lab, regionprops_to_save, logger_func=logger_func ) - df_rp = pd.DataFrame(rp_table).set_index('label') - df_rp.index.name = 'Cell_ID' + df_rp = pd.DataFrame(rp_table).set_index("label") + df_rp.index.name = "Cell_ID" # Drop regionprops that were already calculated in a prev session - df = df.drop(columns=df_rp.columns, errors='ignore') + df = df.drop(columns=df_rp.columns, errors="ignore") df = df.join(df_rp) return df, rp_errors + def get_custom_metric_value( - custom_func, foregr_obj_arr, autoBkgrVal, dataPrepBkgrVal, obj, - i, metrics_values, cell_vols_vox, cell_vols_fl, cell_areas_pxl, - cell_areas_um2, foregr_img, lab, isSegm3D, - other_channels_foregr_imgs: Dict[str, np.ndarray], col_name, - cell_vols_vox_3D=None, - cell_vols_fl_3D=None - ): + custom_func, + foregr_obj_arr, + autoBkgrVal, + dataPrepBkgrVal, + obj, + i, + metrics_values, + cell_vols_vox, + cell_vols_fl, + cell_areas_pxl, + cell_areas_um2, + foregr_img, + lab, + isSegm3D, + other_channels_foregr_imgs: Dict[str, np.ndarray], + col_name, + cell_vols_vox_3D=None, + cell_vols_fl_3D=None, +): base_args = (foregr_obj_arr, autoBkgrVal, dataPrepBkgrVal) - - metrics_obj = {key:mm[i] for key, mm in metrics_values.items()} - metrics_obj['cell_vol_vox'] = cell_vols_vox[i] - metrics_obj['cell_vol_fl'] = cell_vols_fl[i] - metrics_obj['cell_area_pxl'] = cell_areas_pxl[i] - metrics_obj['cell_area_um2'] = cell_areas_um2[i] + + metrics_obj = {key: mm[i] for key, mm in metrics_values.items()} + metrics_obj["cell_vol_vox"] = cell_vols_vox[i] + metrics_obj["cell_vol_fl"] = cell_vols_fl[i] + metrics_obj["cell_area_pxl"] = cell_areas_pxl[i] + metrics_obj["cell_area_um2"] = cell_areas_um2[i] if isSegm3D and cell_vols_vox_3D is not None and cell_vols_fl_3D is not None: - metrics_obj['cell_vol_vox_3D'] = cell_vols_vox_3D[i] - metrics_obj['cell_vol_fl_3D'] = cell_vols_fl_3D[i] - + metrics_obj["cell_vol_vox_3D"] = cell_vols_vox_3D[i] + metrics_obj["cell_vol_fl_3D"] = cell_vols_fl_3D[i] + additional_args_kwargs = ( ((), {}), - ((obj,), {}), + ((obj,), {}), ((obj, metrics_obj), {}), - ((obj, metrics_obj, foregr_img, lab), {'isSegm3D': isSegm3D}), - ) + ((obj, metrics_obj, foregr_img, lab), {"isSegm3D": isSegm3D}), + ) error = None for args, kwargs in additional_args_kwargs: try: custom_val = custom_func(*base_args, *args, **kwargs) - return '', custom_val, None + return "", custom_val, None except TypeError as err: - if 'required positional arguments' in str(err): + if "required positional arguments" in str(err): continue except Exception as error: return traceback.format_exc(), np.nan, None - + # Test if custom metric function requires the other channels images custom_vals_vs_other_ch = [] col_names = [] @@ -1825,92 +1953,102 @@ def get_custom_metric_value( other_channel_foregr_img = {other_channel: other_ch_img} try: custom_val = custom_func( - *base_args, obj, metrics_obj, foregr_img, lab, - other_channel_foregr_img, isSegm3D=isSegm3D + *base_args, + obj, + metrics_obj, + foregr_img, + lab, + other_channel_foregr_img, + isSegm3D=isSegm3D, ) custom_vals_vs_other_ch.append(custom_val) - col_names.append(f'{col_name}_vs_{other_channel}') + col_names.append(f"{col_name}_vs_{other_channel}") except Exception as error: return traceback.format_exc(), np.nan, None - - return '', custom_vals_vs_other_ch, col_names + + return "", custom_vals_vs_other_ch, col_names + def get_ch_indipend_custom_metric_value( - custom_func, - all_channels_obj_intens, - all_channels_autoBkgr, - all_channels_dataPrepBkgr, - obj, i, - metrics_values, cell_vols_vox, cell_vols_fl, cell_areas_pxl, - cell_areas_um2, - all_channels_foregr_imgs, - lab, - isSegm3D, - col_name, - cell_vols_vox_3D=None, - cell_vols_fl_3D=None - ): + custom_func, + all_channels_obj_intens, + all_channels_autoBkgr, + all_channels_dataPrepBkgr, + obj, + i, + metrics_values, + cell_vols_vox, + cell_vols_fl, + cell_areas_pxl, + cell_areas_um2, + all_channels_foregr_imgs, + lab, + isSegm3D, + col_name, + cell_vols_vox_3D=None, + cell_vols_fl_3D=None, +): base_args = ( - all_channels_obj_intens, - all_channels_autoBkgr, - all_channels_dataPrepBkgr + all_channels_obj_intens, + all_channels_autoBkgr, + all_channels_dataPrepBkgr, ) - - metrics_obj = {key:mm[i] for key, mm in metrics_values.items()} - metrics_obj['cell_vol_vox'] = cell_vols_vox[i] - metrics_obj['cell_vol_fl'] = cell_vols_fl[i] - metrics_obj['cell_area_pxl'] = cell_areas_pxl[i] - metrics_obj['cell_area_um2'] = cell_areas_um2[i] + + metrics_obj = {key: mm[i] for key, mm in metrics_values.items()} + metrics_obj["cell_vol_vox"] = cell_vols_vox[i] + metrics_obj["cell_vol_fl"] = cell_vols_fl[i] + metrics_obj["cell_area_pxl"] = cell_areas_pxl[i] + metrics_obj["cell_area_um2"] = cell_areas_um2[i] if isSegm3D and cell_vols_vox_3D is not None and cell_vols_fl_3D is not None: - metrics_obj['cell_vol_vox_3D'] = cell_vols_vox_3D[i] - metrics_obj['cell_vol_fl_3D'] = cell_vols_fl_3D[i] + metrics_obj["cell_vol_vox_3D"] = cell_vols_vox_3D[i] + metrics_obj["cell_vol_fl_3D"] = cell_vols_fl_3D[i] additional_args_kwargs = ( ((), {}), - ((obj,), {}), + ((obj,), {}), ((obj, metrics_obj), {}), - ((obj, metrics_obj, all_channels_foregr_imgs, lab), {'isSegm3D': isSegm3D}), - ) + ((obj, metrics_obj, all_channels_foregr_imgs, lab), {"isSegm3D": isSegm3D}), + ) traceback_text = None for args, kwargs in additional_args_kwargs: try: custom_val = custom_func(*base_args, *args, **kwargs) - return '', custom_val, None + return "", custom_val, None except TypeError as err: - if 'required positional arguments' in str(err): + if "required positional arguments" in str(err): continue except Exception as error: traceback_text = traceback.format_exc() - + return traceback_text, np.nan, None + def get_channel_indipend_custom_metrics_params( - ch_indipend_custom_func_dict, ch_indipend_custom_metric_cols - ): + ch_indipend_custom_func_dict, ch_indipend_custom_metric_cols +): custom_metrics_params = {} - + for col in ch_indipend_custom_metric_cols: for metric, custom_func in ch_indipend_custom_func_dict.items(): - custom_pattern = ( - rf'({metric})_?({how_3D_to_2D_pattern}*)' - ) + custom_pattern = rf"({metric})_?({how_3D_to_2D_pattern}*)" m = re.findall(custom_pattern, col) if m: - # Metric is a standard metric + # Metric is a standard metric func_name, how = m[0] custom_metrics_params[col] = (custom_func, how) break - + return custom_metrics_params + def get_metrics_params(all_channels_metrics, metrics_func, custom_func_dict): channel_names = list(all_channels_metrics.keys()) - bkgr_metrics_params = {ch:{} for ch in channel_names} - foregr_metrics_params = {ch:{} for ch in channel_names} + bkgr_metrics_params = {ch: {} for ch in channel_names} + foregr_metrics_params = {ch: {} for ch in channel_names} concentration_metrics_params = {} - custom_metrics_params = {ch:{} for ch in channel_names} - az = r'[A-Za-z0-9]' - bkgrVal_pattern = fr'_({az}+)_bkgrVal_({az}+)_?({az}*)$' + custom_metrics_params = {ch: {} for ch in channel_names} + az = r"[A-Za-z0-9]" + bkgrVal_pattern = rf"_({az}+)_bkgrVal_({az}+)_?({az}*)$" for channel_name, columns in all_channels_metrics.items(): for col in columns: @@ -1918,31 +2056,29 @@ def get_metrics_params(all_channels_metrics, metrics_func, custom_func_dict): if m: # The metric is a bkgrVal metric bkgr_type, func_name, how = m[0] - bkgr_metrics_params[channel_name][col] = ( - bkgr_type, func_name, how - ) + bkgr_metrics_params[channel_name][col] = (bkgr_type, func_name, how) continue - + is_standard_foregr = False for metric in metrics_func: foregr_pattern = ( - rf'{channel_name}_({metric})_?({how_3D_to_2D_pattern}*)$' + rf"{channel_name}_({metric})_?({how_3D_to_2D_pattern}*)$" ) m = re.findall(foregr_pattern, col) if m: - # Metric is a standard metric + # Metric is a standard metric func_name, how = m[0] foregr_metrics_params[channel_name][col] = (func_name, how) is_standard_foregr = True break - + if is_standard_foregr: continue # Metric is concentration - conc_pattern = rf'concentration_{az}+_from_vol_[a-z]+' + conc_pattern = rf"concentration_{az}+_from_vol_[a-z]+" conc_metric_pattern = ( - rf'{channel_name}_({conc_pattern})_?({how_3D_to_2D_pattern}*)' + rf"{channel_name}_({conc_pattern})_?({how_3D_to_2D_pattern}*)" ) m = re.findall(conc_metric_pattern, col) if m: @@ -1952,21 +2088,24 @@ def get_metrics_params(all_channels_metrics, metrics_func, custom_func_dict): for metric, custom_func in custom_func_dict.items(): custom_pattern = ( - rf'{channel_name}_({metric})_?({how_3D_to_2D_pattern}*)' + rf"{channel_name}_({metric})_?({how_3D_to_2D_pattern}*)" ) m = re.findall(custom_pattern, col) if m: - # Metric is a standard metric + # Metric is a standard metric func_name, how = m[0] custom_metrics_params[channel_name][col] = (custom_func, how) break - + params = ( - bkgr_metrics_params, foregr_metrics_params, - concentration_metrics_params, custom_metrics_params + bkgr_metrics_params, + foregr_metrics_params, + concentration_metrics_params, + custom_metrics_params, ) return params + def get_regionprops_columns(existing_colnames, selected_props_names): selected_rp_cols = [] for col in existing_colnames: @@ -1974,38 +2113,36 @@ def get_regionprops_columns(existing_colnames, selected_props_names): if selected_prop == col: selected_rp_cols.append(col) continue - m = re.match(fr'{selected_prop}-\d', col) + m = re.match(rf"{selected_prop}-\d", col) if m is not None: selected_rp_cols.append(col) return selected_rp_cols + def calc_circularity(obj): if obj.image.ndim == 3: - raise TypeError( - 'Circularity can only be calculated for 2D objects.' - ) - + raise TypeError("Circularity can only be calculated for 2D objects.") + circularity = 4 * np.pi * obj.area / pow(obj.perimeter, 2) return circularity + def calc_roundness(obj): if obj.image.ndim == 3: - raise TypeError( - 'Roundness can only be calculated for 2D objects.' - ) - + raise TypeError("Roundness can only be calculated for 2D objects.") + roundness = 4 * obj.area / np.pi / pow(obj.major_axis_length, 2) return roundness + def calc_aspect_ratio(obj): if obj.image.ndim == 3: - raise TypeError( - 'Roundness can only be calculated for 2D objects.' - ) - + raise TypeError("Roundness can only be calculated for 2D objects.") + roundness = obj.major_axis_length / obj.minor_axis_length return roundness + def calc_additional_regionprops(obj): if obj.image.ndim == 3: circularity_sum = 0 @@ -2028,11 +2165,11 @@ def calc_additional_regionprops(obj): aspect_ratio = calc_aspect_ratio(obj) else: raise TypeError( - 'Additional regionprops can be calculated only for 2D or 3D objects.' + "Additional regionprops can be calculated only for 2D or 3D objects." ) - + obj.circularity = circularity obj.roundness = roundness obj.aspect_ratio = aspect_ratio - - return obj \ No newline at end of file + + return obj diff --git a/cellacdc/metrics/CV.py b/cellacdc/metrics/CV.py index 35fb9d004..5c08025e5 100755 --- a/cellacdc/metrics/CV.py +++ b/cellacdc/metrics/CV.py @@ -1,13 +1,14 @@ import numpy as np + def CV( - signal: np.ndarray, - autoBkgr: float, - dataPrepBkgr: float, - objectRp, - correct_with_bkgr=False, - which_bkgr='auto' - ): + signal: np.ndarray, + autoBkgr: float, + dataPrepBkgr: float, + objectRp, + correct_with_bkgr=False, + which_bkgr="auto", +): """Function used to calculate coefficient of variation. NOTE: Make sure to name the function with the same name as the Python file @@ -28,7 +29,7 @@ def CV( data prep step (Cell-ACDC module 1). Pass None if background correction with this vaue is not needed. objectRp: skimage.measure.RegionProperties class - Region properties for the single object. + Region properties for the single object. Refer to `skimage.measure.regionprops` for more information on the available region properties. correct_with_bkgr : boolean @@ -43,14 +44,14 @@ def CV( Coefficient of Variation """ - + if correct_with_bkgr: - if which_bkgr=='auto': + if which_bkgr == "auto": signal = signal - autoBkgr elif dataPrepBkgr is not None: signal = signal - dataPrepBkgr # Here goes your custom metric computation - CV = np.std(signal)/np.mean(signal) + CV = np.std(signal) / np.mean(signal) return CV diff --git a/cellacdc/metrics/channel_indipendent_metric_example.py b/cellacdc/metrics/channel_indipendent_metric_example.py index 2920243bf..a4d4f304a 100644 --- a/cellacdc/metrics/channel_indipendent_metric_example.py +++ b/cellacdc/metrics/channel_indipendent_metric_example.py @@ -4,36 +4,37 @@ from cellacdc import printl -# If you want to calculate the metric for each channel, set this to True. -# If you want to calculate the metric only once after metrics for all channels +# If you want to calculate the metric for each channel, set this to True. +# If you want to calculate the metric only once after metrics for all channels # have been computed, set this to False. CALCULATE_FOR_EACH_CHANNEL = False + def channel_indipendent_metric( - all_channels_signals, - all_channels_autoBkgr, - all_channels_dataPrepBkgr, - objectRp, - metrics_values, - images, - lab, - isSegm3D=False - ): - """Shows how to combine multiple metrics in a channel-indipendent manner + all_channels_signals, + all_channels_autoBkgr, + all_channels_dataPrepBkgr, + objectRp, + metrics_values, + images, + lab, + isSegm3D=False, +): + """Shows how to combine multiple metrics in a channel-indipendent manner using a custom function. Parameters ---------- all_channels_signals : dictionary of numpy 1D arrays - Dictionary with channel names as keys and the numpy array as value + Dictionary with channel names as keys and the numpy array as value with all the intensities of the signal from each single segmented object. all_channels_autoBkgr : dictionary of single numeric value - Dictionary with channel names as keys and as value the median of all + Dictionary with channel names as keys and as value the median of all the background pixels (i.e. pixels with value 0 in the segmentation mask). Pass None if background correction with this value is not needed. all_channels_dataPrepBkgr : dictionary of single numeric value - Dictionary with channel names as keys and as value the median of all + Dictionary with channel names as keys and as value the median of all the pixels inside the background ROIs added during the data prep step (Cell-ACDC module 1). Pass None if background correction with this vaue is not needed. @@ -41,15 +42,15 @@ def channel_indipendent_metric( Refer to `skimage.measure.regionprops` for more information on the available region properties. metrics_values : dict - Dictionary of metrics values of the specific segmented object - (i.e., cell). You can access these values with the name of the + Dictionary of metrics values of the specific segmented object + (i.e., cell). You can access these values with the name of the specific metric (i.e., column name in the acdc_output.csv file) - Examples: + Examples: - mCitrine_mean = metrics_values['mCitrine_mean'] - - _mean_key = [key for key in metrics_values if key.endswith('_mean')][0] + - _mean_key = [key for key in metrics_values if key.endswith('_mean')][0] _mean = metrics_values[_mean_key] images : dictionary of numpy array - Dictionary with channel names as keys and the corresponding image + Dictionary with channel names as keys and the corresponding image signal as value lab : numpy array Segmentation mask of `image` @@ -58,45 +59,45 @@ def channel_indipendent_metric( ------- float Numerical value of the computed metric - + Notes ----- - - 1. The function must have the same name as the Python file containing it + + 1. The function must have the same name as the Python file containing it (e.g., if this file is called CV.py the function must be called CV) 2. The function must return a single number. You will need one .py for each additional custom metric. - - This implementation shows how to compute the ratio of the amount between - the first two channels (alphabetically) divided by the cell_vol_fl. + + This implementation shows how to compute the ratio of the amount between + the first two channels (alphabetically) divided by the cell_vol_fl. """ - + channels = list(all_channels_signals.keys()) channels = natsorted(channels) channel_1 = channels[0] - + try: channel_2 = channels[1] except IndexError: # Only one channel loaded. Returning 0. return 0.0 - + ch1_amount_key = [ - key for key in metrics_values - if key.startswith(f'{channel_1}_amount_autoBkgr')][0] - + key for key in metrics_values if key.startswith(f"{channel_1}_amount_autoBkgr") + ][0] + ch1_amount = metrics_values[ch1_amount_key] - + ch2_amount_key = [ - key for key in metrics_values - if key.startswith(f'{channel_2}_amount_autoBkgr')][0] + key for key in metrics_values if key.startswith(f"{channel_2}_amount_autoBkgr") + ][0] ch2_amount = metrics_values[ch2_amount_key] - - cell_vol_fl = metrics_values['cell_vol_fl'] - + + cell_vol_fl = metrics_values["cell_vol_fl"] + amount_ratio = ch1_amount / ch2_amount - - totally_useless_metric = amount_ratio/cell_vol_fl + + totally_useless_metric = amount_ratio / cell_vol_fl return totally_useless_metric diff --git a/cellacdc/metrics/combine_metrics_example.py b/cellacdc/metrics/combine_metrics_example.py index 6698cee51..d53e1d356 100644 --- a/cellacdc/metrics/combine_metrics_example.py +++ b/cellacdc/metrics/combine_metrics_example.py @@ -1,10 +1,19 @@ import numpy as np + def combine_metrics_example( - signal, autoBkgr, dataPrepBkgr, objectRp, metrics_values, image, lab, - other_channel_foregr_img, correct_with_bkgr=False, which_bkgr='auto', - isSegm3D=False - ): + signal, + autoBkgr, + dataPrepBkgr, + objectRp, + metrics_values, + image, + lab, + other_channel_foregr_img, + correct_with_bkgr=False, + which_bkgr="auto", + isSegm3D=False, +): """Shows how to combine multiple metrics in a custom function. Parameters @@ -24,21 +33,21 @@ def combine_metrics_example( Refer to `skimage.measure.regionprops` for more information on the available region properties. metrics_values : dict - Dictionary of metrics values of the specific segmented object - (i.e., cell). You can access these values with the name of the + Dictionary of metrics values of the specific segmented object + (i.e., cell). You can access these values with the name of the specific metric (i.e., column name in the acdc_output.csv file) - Examples: + Examples: - mCitrine_mean = metrics_values['mCitrine_mean'] - - _mean_key = [key for key in metrics_values if key.endswith('_mean')][0] + - _mean_key = [key for key in metrics_values if key.endswith('_mean')][0] _mean = metrics_values[_mean_key] image : numpy array Image signal being analysed (if time-lapse this is the current frame) lab : numpy array Segmentation mask of `image` other_channel_foregr_img : dict - Dictionary with a single key another loaded channel name and values - the corresponding channel signal. Cell-ACDC will run this function - for as many other channles were loaded. Do not include in your custom + Dictionary with a single key another loaded channel name and values + the corresponding channel signal. Cell-ACDC will run this function + for as many other channles were loaded. Do not include in your custom function if you don't need it. correct_with_bkgr : boolean Pass True if you need background correction. @@ -50,23 +59,23 @@ def combine_metrics_example( ------- float Numerical value of the computed metric - + Notes ----- - - 1. The function must have the same name as the Python file containing it + + 1. The function must have the same name as the Python file containing it (e.g., this file is called CV.py and the function is called CV) 2. The function must return a single number. You will need one .py for each additional custom metric. - - This implementation shows how to compute the concentration for all the - available channels. Concentration is calculated as the ratio between - columns ending with `_amount_autoBkgr` and `cell_vol_fl`. + + This implementation shows how to compute the concentration for all the + available channels. Concentration is calculated as the ratio between + columns ending with `_amount_autoBkgr` and `cell_vol_fl`. """ - _amount_key = [key for key in metrics_values if key.endswith('_amount_autoBkgr')][0] + _amount_key = [key for key in metrics_values if key.endswith("_amount_autoBkgr")][0] _amount = metrics_values[_amount_key] - cell_vol_fl = metrics_values['cell_vol_fl'] - _concentration = _amount/cell_vol_fl + cell_vol_fl = metrics_values["cell_vol_fl"] + _concentration = _amount / cell_vol_fl return _concentration diff --git a/cellacdc/mixins/__init__.py b/cellacdc/mixins/__init__.py new file mode 100644 index 000000000..a247ec3e1 --- /dev/null +++ b/cellacdc/mixins/__init__.py @@ -0,0 +1,23 @@ +"""Mixins for gui.py.""" + +from __future__ import annotations + +import importlib + +_GRAPH = None + + +def _load_graph(): + global _GRAPH + if _GRAPH is None: + _GRAPH = importlib.import_module("cellacdc.mixins._graph") + return _GRAPH + + +def __getattr__(name: str): + graph = _load_graph() + if name not in graph.MODULE_TO_CLASS.values(): + raise AttributeError(name) + module = next(k for k, v in graph.MODULE_TO_CLASS.items() if v == name) + mod = importlib.import_module(f"cellacdc.mixins.{graph.file_module(module)}") + return getattr(mod, name) diff --git a/cellacdc/mixins/_apply_parents.py b/cellacdc/mixins/_apply_parents.py new file mode 100644 index 000000000..80dc76c77 --- /dev/null +++ b/cellacdc/mixins/_apply_parents.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +"""Apply upstream mixin parents to mixin class definitions.""" + +from __future__ import annotations + +import importlib.util +import re +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +MIXINS = ROOT / "mixins" +GUI = ROOT / "gui.py" + + +def load_graph(): + spec = importlib.util.spec_from_file_location( + "cellacdc_mixins_graph", MIXINS / "_graph.py" + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +GRAPH = load_graph() +MIXIN_PARENTS = GRAPH.MIXIN_PARENTS +class_name = GRAPH.class_name +file_module = GRAPH.file_module +guiwin_classes = GRAPH.guiwin_classes +guiwin_roots = GRAPH.guiwin_roots + +FILE_CLASSES: dict[str, list[tuple[str, str]]] = { + "combine": [("combine", "CombineGui"), ("combine_worker", "CombineWorker")], +} + + +def parent_imports(module: str, parents: tuple[str, ...]) -> list[str]: + lines = [] + seen: set[str] = set() + child_file = file_module(module) + for p in parents: + fm = file_module(p) + if fm == child_file: + continue + cn = class_name(p) + key = (fm, cn) + if key in seen: + continue + seen.add(key) + lines.append(f"from .{fm} import {cn}") + return lines + + +def rewrite_class_bases(content: str, cls: str, parents: tuple[str, ...]) -> str: + parent_names = [class_name(p) for p in parents] + if parent_names: + bases = ", ".join(parent_names) + content = re.sub( + rf"^class {cls}\([^)]*\)\:", + f"class {cls}({bases}):", + content, + count=1, + flags=re.MULTILINE, + ) + content = re.sub( + rf"^class {cls}\:", + f"class {cls}({bases}):", + content, + count=1, + flags=re.MULTILINE, + ) + else: + content = re.sub( + rf"^class {cls}\([^)]*\)\:", + f"class {cls}:", + content, + count=1, + flags=re.MULTILINE, + ) + return content + + +def inject_imports(content: str, import_lines: list[str]) -> str: + if not import_lines: + # strip stale mixin imports + return strip_old_mixin_imports(content) + block = "\n".join(import_lines) + content = strip_old_mixin_imports(content) + idx = content.find("\n\nclass ") + if idx == -1: + raise ValueError("Could not find class definition anchor") + return content[:idx] + "\n" + block + content[idx:] + + +def strip_old_mixin_imports(content: str) -> str: + lines = [] + for line in content.splitlines(keepends=True): + if re.match(r"from \.[a-z_]+ import [A-Z]", line): + continue + lines.append(line) + return "".join(lines) + + +def apply_file(module: str, cls: str) -> None: + parents = MIXIN_PARENTS.get(module, ()) + fp = MIXINS / f"{file_module(module)}.py" + content = fp.read_text() + content = rewrite_class_bases(content, cls, parents) + content = inject_imports(content, parent_imports(module, parents)) + fp.write_text(content) + print(f" {module}({cls}) <- {[class_name(p) for p in parents]}") + + +def update_gui() -> None: + classes = guiwin_classes() + src = GUI.read_text() + import_block = "from .mixins import (\n" + import_block += "".join(f" {c},\n" for c in classes) + import_block += ")\n" + src = re.sub( + r"from \.mixins import \(\n.*?\n\)\n", + import_block, + src, + count=1, + flags=re.DOTALL, + ) + bases = ",\n ".join(classes) + src = re.sub( + r"class guiWin\(QMainWindow,\n(?: .+\n)+?\):", + f"class guiWin(QMainWindow,\n {bases}):", + src, + count=1, + ) + GUI.write_text(src) + + +def main() -> None: + cycles = GRAPH.import_cycles() + if cycles: + raise SystemExit(f"Import cycles in mixin graph: {cycles}") + for mod in sorted(MIXIN_PARENTS): + if mod == "combine_worker": + continue + if mod in FILE_CLASSES: + continue + apply_file(mod, class_name(mod)) + for mod, cls in FILE_CLASSES["combine"]: + apply_file(mod, cls) + update_gui() + print("\nguiWin roots:", guiwin_classes()) + + +if __name__ == "__main__": + main() diff --git a/cellacdc/mixins/_graph.py b/cellacdc/mixins/_graph.py new file mode 100644 index 000000000..3ed139242 --- /dev/null +++ b/cellacdc/mixins/_graph.py @@ -0,0 +1,302 @@ +"""Mixin dependency graph and parent assignments for guiWin MRO.""" + +from __future__ import annotations + +# Layered parent map (downstream -> upstream). Edges only go from higher layers +# to lower layers so importing a mixin loads its parents first without cycles. +_RAW_MIXIN_PARENTS: dict[str, tuple[str, ...]] = { + # Layer 0 — foundation + "display_decorations": (), + "geometry": (), + "main_menu": (), + "measurements": (), + "canvas_tool": (), + "whitelist": (), + "combine": (), + # Layer 1 — display / chrome helpers + "image_display": ("display_decorations",), + "actions": ("image_display",), + "status_hover": ("image_display",), + "main_toolbar": ("actions",), + "quick_settings": ("actions",), + # Layer 2 — workers / session + "worker": ("image_display", "status_hover"), + "session": ("image_display", "worker"), + "app_shell": ("actions", "session"), + "tool_activation": ("image_display", "session", "worker"), + # Layer 3 — tools & canvas primitives + "brush_tools": ("geometry", "image_display", "tool_activation"), + "canvas_context_menu": ("image_display",), + "canvas_selection": ("canvas_tool", "geometry", "brush_tools"), + "label_editing": ("image_display", "session", "tool_activation"), + "undo_redo": ("session", "label_editing"), + "points_layers": ("image_display", "brush_tools"), + "mode_controls": ("session", "tool_activation"), + "annotation_display": ("image_display", "tool_activation", "mode_controls"), + # Layer 4 — canvas interaction stack + "canvas_drawing": ( + "canvas_selection", + "brush_tools", + "label_editing", + "image_display", + ), + "canvas_events": ( + "geometry", + "canvas_context_menu", + "canvas_selection", + "brush_tools", + "label_editing", + "image_display", + ), + "canvas_hover": ("canvas_events", "brush_tools", "tool_activation"), + "curvature_tools": ("brush_tools", "tool_activation", "undo_redo"), + "draw_clear_region": ("label_editing", "undo_redo", "image_display"), + "label_transform_tools": ("brush_tools", "label_editing", "image_display"), + "label_roi": ("session", "image_display", "brush_tools"), + # Layer 5 — domain features + "cell_cycle": ("session", "label_editing", "undo_redo", "image_display"), + "tracking": ("session", "label_editing", "tool_activation", "undo_redo"), + "deleted_rois": ("session", "cell_cycle", "tool_activation"), + "object_properties": ("cell_cycle", "image_display", "tracking"), + "segmentation": ("session", "image_display", "tool_activation"), + "preprocessing": ("image_display", "worker", "session"), + "saving": ("session", "worker", "app_shell"), + "graphics": ("image_display", "points_layers", "worker"), + "lineage_interactions": ("annotation_display", "tracking", "image_display"), + "custom_annotations": ("annotation_display", "object_properties"), + "magic_prompts": ("graphics", "session", "worker"), + # Layer 6 — high-level orchestrators + "frame_navigation": ( + "session", + "graphics", + "label_editing", + "display_decorations", + ), + "data_loading": ( + "app_shell", + "session", + "tool_activation", + "layout_controls", + ), + "image_controls": ("image_display", "frame_navigation"), + "window_events": ( + "app_shell", + "frame_navigation", + "label_editing", + "tool_activation", + ), + "layout_controls": ("image_controls", "window_events", "label_roi"), + "canvas_right_image": ("canvas_drawing", "canvas_events", "canvas_context_menu"), + "object_search": ("frame_navigation", "graphics", "session"), + "object_cleanup": ("cell_cycle", "session", "image_display"), + "seg_for_lost_ids": ( + "segmentation", + "frame_navigation", + "label_editing", + "session", + ), + "exporting": ("app_shell", "frame_navigation", "session"), + "combine_worker": ("combine", "graphics", "preprocessing", "worker"), +} + + +def _ancestors( + module: str, + graph: dict[str, tuple[str, ...]], + cache: dict[str, frozenset[str]], +) -> frozenset[str]: + if module not in cache: + seen: set[str] = set() + for parent in graph.get(module, ()): + seen.add(parent) + seen |= _ancestors(parent, graph, cache) + cache[module] = frozenset(seen) + return cache[module] + + +def _reduce_mixin_parents( + raw: dict[str, tuple[str, ...]], +) -> dict[str, tuple[str, ...]]: + """Drop direct parents already inherited through another direct parent.""" + cache: dict[str, frozenset[str]] = {} + reduced: dict[str, tuple[str, ...]] = {} + for module, parents in raw.items(): + kept = tuple( + parent + for parent in parents + if not any( + parent != other and parent in _ancestors(other, raw, cache) + for other in parents + ) + ) + reduced[module] = kept + return reduced + + +MIXIN_PARENTS = _reduce_mixin_parents(_RAW_MIXIN_PARENTS) + +MODULE_TO_CLASS: dict[str, str] = { + "actions": "Actions", + "annotation_display": "AnnotationDisplay", + "app_shell": "AppShell", + "brush_tools": "BrushTools", + "canvas_context_menu": "CanvasContextMenu", + "canvas_drawing": "CanvasDrawing", + "canvas_events": "CanvasEvents", + "canvas_hover": "CanvasHover", + "canvas_right_image": "CanvasRightImage", + "canvas_selection": "CanvasSelection", + "canvas_tool": "CanvasTool", + "cell_cycle": "CellCycle", + "combine": "CombineGui", + "combine_worker": "CombineWorker", + "curvature_tools": "CurvatureTools", + "custom_annotations": "CustomAnnotations", + "data_loading": "DataLoading", + "deleted_rois": "DeletedRois", + "display_decorations": "DisplayDecorations", + "draw_clear_region": "DrawClearRegion", + "exporting": "Exporting", + "frame_navigation": "FrameNavigation", + "geometry": "Geometry", + "graphics": "Graphics", + "image_controls": "ImageControls", + "image_display": "ImageDisplay", + "label_editing": "LabelEditing", + "label_roi": "LabelRoi", + "label_transform_tools": "LabelTransformTools", + "layout_controls": "LayoutControls", + "lineage_interactions": "LineageInteractions", + "magic_prompts": "MagicPrompts", + "main_menu": "MainMenu", + "main_toolbar": "MainToolbar", + "measurements": "Measurements", + "mode_controls": "ModeControls", + "object_cleanup": "ObjectCleanup", + "object_properties": "ObjectProperties", + "object_search": "ObjectSearch", + "points_layers": "PointsLayers", + "preprocessing": "Preprocessing", + "quick_settings": "QuickSettings", + "saving": "Saving", + "seg_for_lost_ids": "SegForLostIds", + "segmentation": "Segmentation", + "session": "Session", + "status_hover": "StatusHover", + "tool_activation": "ToolActivation", + "tracking": "Tracking", + "undo_redo": "UndoRedo", + "whitelist": "WhitelistGui", + "window_events": "WindowEvents", + "worker": "Worker", +} + +MODULE_FILE: dict[str, str] = { + "combine": "combine", + "combine_worker": "combine", +} + + +def class_name(module: str) -> str: + return MODULE_TO_CLASS[module] + + +def file_module(module: str) -> str: + return MODULE_FILE.get(module, module) + + +def guiwin_roots() -> list[str]: + """Modules listed directly on guiWin (not inherited via another root).""" + all_parents = {p for ps in MIXIN_PARENTS.values() for p in ps} + roots = [m for m in MODULE_TO_CLASS if m not in all_parents] + # combine is parent of combine_worker + roots = [m for m in roots if m != "combine"] + + order = [ + "whitelist", + "layout_controls", + "data_loading", + "canvas_right_image", + "canvas_hover", + "window_events", + "frame_navigation", + "graphics", + "lineage_interactions", + "custom_annotations", + "magic_prompts", + "object_search", + "object_cleanup", + "seg_for_lost_ids", + "exporting", + "combine_worker", + "curvature_tools", + "draw_clear_region", + "label_transform_tools", + "deleted_rois", + "cell_cycle", + "tracking", + "segmentation", + "preprocessing", + "saving", + "object_properties", + "annotation_display", + "mode_controls", + "main_toolbar", + "quick_settings", + "main_menu", + "measurements", + "canvas_events", + "canvas_drawing", + "canvas_selection", + "canvas_context_menu", + "brush_tools", + "canvas_tool", + "label_editing", + "label_roi", + "tool_activation", + "session", + "worker", + "app_shell", + "points_layers", + "image_controls", + "image_display", + "status_hover", + "actions", + "undo_redo", + "geometry", + "display_decorations", + ] + rank = {m: i for i, m in enumerate(order)} + return sorted(roots, key=lambda m: rank.get(m, 999)) + + +def guiwin_classes() -> list[str]: + return [class_name(m) for m in guiwin_roots()] + + +def import_cycles() -> list[list[str]]: + """Detect import cycles in the parent graph (child imports parent modules).""" + graph = MIXIN_PARENTS + mods = set(MODULE_TO_CLASS) + cycles = [] + path: list[str] = [] + visited: set[str] = set() + stack: set[str] = set() + + def dfs(node: str) -> None: + if node in stack: + cycles.append(path[path.index(node) :] + [node]) + return + if node in visited: + return + visited.add(node) + stack.add(node) + path.append(node) + for parent in graph.get(node, ()): + dfs(parent) + path.pop() + stack.remove(node) + + for mod in mods: + dfs(mod) + return cycles diff --git a/cellacdc/mixins/actions.py b/cellacdc/mixins/actions.py new file mode 100644 index 000000000..1937f5e97 --- /dev/null +++ b/cellacdc/mixins/actions.py @@ -0,0 +1,885 @@ +"""Qt view adapter for action and shortcut workflows.""" + +from __future__ import annotations + +import os +import re + +from qtpy.QtCore import Qt +from qtpy.QtGui import QIcon, QKeySequence +from qtpy.QtWidgets import QAction, QActionGroup, QToolButton + +from cellacdc import apps, is_mac, settings_folderpath, widgets + +shortcut_filepath = os.path.join(settings_folderpath, "shortcuts.ini") + +from .image_display import ImageDisplay + + +class Actions(ImageDisplay): + """Extracted from guiWin.""" + + def editShortcuts_cb(self): + if is_mac: + delObjKeySequenceText = "Ctrl" + delObjButtonText = "Left click" + else: + delObjKeySequenceText = "" + delObjButtonText = "Middle click" + + if self.delObjAction is not None: + delObjKeySequence, delObjQtButton = self.delObjAction + if delObjKeySequence is None: + delObjKeySequenceText = "" + else: + delObjKeySequenceText = delObjKeySequence.toString() + delObjKeySequenceText = delObjKeySequenceText.encode( + "ascii", "ignore" + ).decode("utf-8") + delObjButtonText = ( + "Left click" + if delObjQtButton == Qt.MouseButton.LeftButton + else "Middle click" + ) + + win = apps.ShortcutEditorDialog( + self.widgetsWithShortcut, + delObjectKey=delObjKeySequenceText, + delObjectButton=delObjButtonText, + zoomOutKeyValue=self.zoomOutKeyValue, + parent=self, + ) + win.exec_() + if win.cancel: + return + + self.delObjAction = win.delObjAction + self.zoomOutKeyValue = win.zoomOutKeyValue + self.setShortcuts(win.customShortcuts) + + def gui_connectActions(self): + # Connect File actions + if self.debug: + self.createEmptyDataAction.triggered.connect(self._createEmptyData) + self.segmNdimIndicator.clicked.connect(self.segmNdimIndicatorClicked) + self.newWindowAction.triggered.connect(self.openNewWindow) + self.newAction.triggered.connect(self.newFile) + self.openFolderAction.triggered.connect(self.openFolder) + self.openFileAction.triggered.connect(self.openFile) + self.manageVersionsAction.triggered.connect(self.manageVersions) + self.saveAction.triggered.connect(self.saveData) + self.saveAsAction.triggered.connect(self.saveAsData) + self.exportToVideoAction.triggered.connect(self.exportToVideoTriggered) + self.exportToImageAction.triggered.connect(self.exportToImageTriggered) + self.quickSaveAction.triggered.connect(self.quickSave) + self.viewPreprocDataToggle.toggled.connect(self.viewPreprocDataToggled) + self.viewCombineChannelDataToggle.toggled.connect( + self.viewCombineChannelDataToggled + ) + self.autoSaveToggle.toggled.connect(self.autoSaveToggled) + self.autoSaveAnnotToggle.toggled.connect(self.autoSaveAnnotToggled) + self.autoSaveIntervalDialog.sigValueChanged.connect( + self.autoSaveIntervalValueChanged + ) + self.autoSaveIntervalEditButton.clicked.connect(self.autoSaveIntervalEdit) + self.ccaIntegrCheckerToggle.toggled.connect(self.ccaIntegrCheckerToggled) + self.annotLostObjsToggle.toggled.connect(self.annotLostObjsToggled) + self.highLowResAction.clicked.connect(self.highLowResToggled) + self.showInExplorerAction.triggered.connect(self.showInExplorer_cb) + self.exitAction.triggered.connect(self.close) + self.undoAction.triggered.connect(self.undo) + self.redoAction.triggered.connect(self.redo) + self.nextAction.triggered.connect(self.nextActionTriggered) + self.prevAction.triggered.connect(self.prevActionTriggered) + + self.invertBwAction.toggled.connect(self.invertBw) + self.toggleColorSchemeAction.triggered.connect(self.onToggleColorScheme) + self.pxModeAction.clicked.connect(self.pxModeActionToggled) + self.editShortcutsAction.triggered.connect(self.editShortcuts_cb) + self.editAutoSaveIntervalAction.triggered.connect( + self.autoSaveIntervalEditButton.click + ) + self.showMirroredCursorAction.toggled.connect(self.showMirroredCursorToggled) + + # Connect Help actions + self.tipsAction.triggered.connect(self.showTipsAndTricks) + self.UserManualAction.triggered.connect(utils.browse_docs) + self.openLogFileAction.triggered.connect(self.openLogFile) + self.showLogFilesAction.triggered.connect(self.showLogFiles) + self.aboutAction.triggered.connect(self.showAbout) + # Connect Open Recent to dynamically populate it + # self.openRecentMenu.aboutToShow.connect(self.populateOpenRecent) + self.checkableQButtonsGroup.buttonClicked.connect(self.uncheckQButton) + + self.showPropsDockButton.sigClicked.connect(self.showPropsDockWidget) + + self.loadCustomAnnotationsAction.triggered.connect(self.loadCustomAnnotations) + self.addCustomAnnotationAction.triggered.connect(self.addCustomAnnotation) + self.viewAllCustomAnnotAction.toggled.connect(self.viewAllCustomAnnot) + self.addCustomModelVideoAction.triggered.connect( + self.showInstructionsCustomModel + ) + self.addCustomModelFrameAction.triggered.connect( + self.showInstructionsCustomModel + ) + self.addCustomModelFrameAction.callback = self.segmFrameCallback + self.addCustomModelVideoAction.callback = self.segmVideoCallback + + self.addCustomPromptModelAction.triggered.connect( + self.showInstructionsCustomPromptModel + ) + self.segmWithPromptableModelAction.triggered.connect( + self.segmWithPromptableModelActionTriggered + ) + + def gui_connectEditActions(self): + self.showInExplorerAction.setEnabled(True) + self.setEnabledFileToolbar(True) + self.loadFluoAction.setEnabled(True) + self.isEditActionsConnected = True + + self.preprocessImageAction.triggered.connect(self.preprocessAction.trigger) + self.combineChannelsAction.triggered.connect( + self.combineChannelsActionTriggered + ) + + self.overlayButton.toggled.connect(self.overlay_cb) + self.countObjsButton.toggled.connect(self.countObjectsCb) + self.togglePointsLayerAction.toggled.connect(self.pointsLayerToggled) + self.overlayLabelsButton.toggled.connect(self.overlayLabels_cb) + self.overlayButton.sigRightClick.connect(self.showOverlayContextMenu) + self.labelRoiButton.sigRightClick.connect(self.showLabelRoiContextMenu) + self.overlayLabelsButton.sigRightClick.connect( + self.showOverlayLabelsContextMenu + ) + self.rulerButton.toggled.connect(self.ruler_cb) + self.loadFluoAction.triggered.connect(self.loadFluo_cb) + self.loadPosAction.triggered.connect(self.loadPosTriggered) + # self.reloadAction.triggered.connect(self.reload_cb) + self.findIdAction.triggered.connect(self.findID) + self.zoomRectButton.toggled.connect(self.zoomRectActionToggled) + self.autoPilotButton.toggled.connect(self.autoPilotToggled) + self.skipToNewIdAction.triggered.connect(self.skipForwardToNewID) + self.slideshowButton.toggled.connect(self.launchSlideshow) + + self.copyLostObjButton.toggled.connect(self.copyLostObjContour_cb) + self.manualAnnotPastButton.toggled.connect(self.manualAnnotPast_cb) + + self.segmSingleFrameMenu.triggered.connect(self.segmFrameCallback) + self.segmVideoMenu.triggered.connect(self.segmVideoCallback) + + self.postProcessSegmAction.toggled.connect(self.postProcessSegm) + self.autoSegmAction.toggled.connect(self.autoSegm_cb) + self.realTimeTrackingToggle.clicked.connect(self.realTimeTrackingClicked) + self.repeatTrackingAction.triggered.connect(self.repeatTracking) + self.manualTrackingButton.toggled.connect(self.manualTracking_cb) + self.manualBackgroundButton.toggled.connect(self.manualBackground_cb) + self.repeatTrackingMenuAction.triggered.connect(self.repeatTracking) + self.repeatTrackingVideoAction.triggered.connect(self.repeatTrackingVideo) + for rtTrackerAction in self.trackingAlgosGroup.actions(): + rtTrackerAction.toggled.connect(self.rtTrackerActionToggled) + self.editRtTrackerParamsAction.triggered.connect(self.initRealTimeTracker) + self.delObjsOutSegmMaskAction.triggered.connect( + self.delObjsOutSegmMaskActionTriggered + ) + self.mergeIDsButton.toggled.connect(self.mergeObjs_cb) + self.brushButton.toggled.connect(self.Brush_cb) + self.eraserButton.toggled.connect(self.Eraser_cb) + self.curvToolButton.toggled.connect(self.curvTool_cb) + self.wandToolButton.toggled.connect(self.wand_cb) + self.labelRoiButton.toggled.connect(self.labelRoi_cb) + self.magicPromptsToolButton.toggled.connect(self.magicPrompts_cb) + self.drawClearRegionButton.toggled.connect(self.drawClearRegion_cb) + self.reInitCcaAction.triggered.connect(self.reInitCca) + self.moveLabelToolButton.toggled.connect(self.moveLabelButtonToggled) + self.editCcaToolAction.triggered.connect( + self.manualEditCcaToolbarActionTriggered + ) + self.assignBudMothAutoAction.triggered.connect(self.autoAssignBud_YeastMate) + self.keepIDsButton.toggled.connect(self.keepIDs_cb) + + self.whitelistIDsButton.toggled.connect(self.whitelistIDs_cb) + + self.whitelistIDsToolbar.sigWhitelistChanged.connect(self.whitelistIDsChanged) + + self.whitelistIDsToolbar.sigWhitelistAccepted.connect(self.whitelistIDsAccepted) + + self.whitelistIDsToolbar.sigViewOGIDs.connect(self.whitelistViewOGIDs) + + self.whitelistIDsToolbar.sigAddNewIDs.connect(self.whitelistAddNewIDsToggled) + + self.whitelistIDsToolbar.sigLoadOGLabs.connect(self.whitelistLoadOGLabs_cb) + + self.whitelistIDsToolbar.sigTrackOGagainstPreviousFrame.connect( + self.whitelistTrackOGagainstPreviousFrame_cb + ) + + self.expandLabelToolButton.toggled.connect(self.expandLabelCallback) + + self.reinitLastSegmFrameAction.triggered.connect(self.reInitLastSegmFrame) + + self.defaultRescaleIntensActionGroup.triggered.connect( + self.defaultRescaleIntensLutActionToggled + ) + + # self.repeatAutoCcaAction.triggered.connect(self.repeatAutoCca) + self.manuallyEditCcaAction.triggered.connect(self.manualEditCca) + self.addScaleBarAction.toggled.connect(self.addScaleBar) + self.addTimestampAction.toggled.connect(self.addTimestamp) + self.saveLabColormapAction.triggered.connect(self.saveLabelsColormap) + + self.enableSmartTrackAction.toggled.connect(self.enableSmartTrack) + # Brush/Eraser size action + self.brushSizeSpinbox.valueChanged.connect(self.brushSize_cb) + self.autoIDcheckbox.toggled.connect(self.autoIDtoggled) + # Mode + self.modeActionGroup.triggered.connect(self.changeModeFromMenu) + self.modeComboBox.sigTextChanged.connect(self.changeMode) + self.modeComboBox.activated.connect(self.clearComboBoxFocus) + self.equalizeHistPushButton.toggled.connect(self.equalizeHist) + + self.editOverlayColorAction.triggered.connect(self.toggleOverlayColorButton) + self.editTextIDsColorAction.triggered.connect(self.toggleTextIDsColorButton) + self.overlayColorButton.sigColorChanging.connect(self.changeOverlayColor) + self.overlayColorButton.sigColorChanged.connect(self.saveOverlayColor) + self.textIDsColorButton.sigColorChanging.connect(self.updateTextAnnotColor) + self.textIDsColorButton.sigColorChanged.connect(self.saveTextIDsColors) + + self.setMeasurementsAction.triggered.connect(self.showSetMeasurements) + self.addCustomMetricAction.triggered.connect(self.addCustomMetric) + self.addCombineMetricAction.triggered.connect(self.addCombineMetric) + + self.labelsGrad.colorButton.sigColorChanging.connect(self.updateBkgrColor) + self.labelsGrad.colorButton.sigColorChanged.connect(self.saveBkgrColor) + self.labelsGrad.sigGradientChangeFinished.connect(self.updateLabelsCmap) + self.labelsGrad.sigGradientChanged.connect(self.ticksCmapMoved) + self.labelsGrad.textColorButton.sigColorChanging.connect( + self.updateTextLabelsColor + ) + self.labelsGrad.textColorButton.sigColorChanged.connect( + self.saveTextLabelsColor + ) + # self.addFontSizeActions( + # self.labelsGrad.fontSizeMenu, self.setFontSizeActionChecked + # ) + + self.labelsGrad.shuffleCmapAction.triggered.connect(self.shuffle_cmap) + self.labelsGrad.greedyShuffleCmapAction.triggered.connect( + self.greedyShuffleCmap + ) + self.labelsGrad.permanentGreedyCmapAction.toggled.connect( + self.permanentGreedyCmapToggled + ) + self.shuffleCmapAction.triggered.connect(self.shuffle_cmap) + self.greedyShuffleCmapAction.triggered.connect(self.greedyShuffleCmap) + self.labelsGrad.invertBwAction.toggled.connect(self.setCheckedInvertBW) + self.labelsGrad.sigShowLabelsImgToggled.connect(self.showLabelImageItem) + self.labelsGrad.sigShowRightImgToggled.connect(self.showRightImageItem) + self.labelsGrad.sigShowNextFrameToggled.connect(self.showNextFrameImageItem) + + self.labelsGrad.defaultSettingsAction.triggered.connect( + self.restoreDefaultSettings + ) + + # self.addFontSizeActions( + # self.imgGrad.fontSizeMenu, self.setFontSizeActionChecked + # ) + self.imgGrad.invertBwAction.toggled.connect(self.setCheckedInvertBW) + self.imgGrad.textColorButton.disconnect() + self.imgGrad.textColorButton.clicked.connect( + self.editTextIDsColorAction.trigger + ) + self.imgGrad.labelsAlphaSlider.valueChanged.connect(self.updateLabelsAlpha) + self.imgGrad.defaultSettingsAction.triggered.connect( + self.restoreDefaultSettings + ) + + # Drawing mode + self.drawIDsContComboBox.currentIndexChanged.connect( + self.drawIDsContComboBox_cb + ) + self.drawIDsContComboBox.activated.connect(self.clearComboBoxFocus) + + self.annotateRightHowCombobox.currentIndexChanged.connect( + self.annotateRightHowCombobox_cb + ) + self.annotateRightHowCombobox.activated.connect(self.clearComboBoxFocus) + + self.showTreeInfoCheckbox.toggled.connect(self.setAnnotInfoMode) + + # Left + self.annotIDsCheckbox.clicked.connect(self.annotOptionClicked) + self.annotCcaInfoCheckbox.clicked.connect(self.annotOptionClicked) + self.annotContourCheckbox.clicked.connect(self.annotOptionClicked) + self.annotSegmMasksCheckbox.clicked.connect(self.annotOptionClicked) + self.drawMothBudLinesCheckbox.clicked.connect(self.annotOptionClicked) + self.drawNothingCheckbox.clicked.connect(self.annotOptionClicked) + self.annotNumZslicesCheckbox.clicked.connect(self.annotOptionClicked) + + # Right + self.annotIDsCheckboxRight.clicked.connect(self.annotOptionClickedRight) + self.annotCcaInfoCheckboxRight.clicked.connect(self.annotOptionClickedRight) + self.annotContourCheckboxRight.clicked.connect(self.annotOptionClickedRight) + self.annotSegmMasksCheckboxRight.clicked.connect(self.annotOptionClickedRight) + self.drawMothBudLinesCheckboxRight.clicked.connect(self.annotOptionClickedRight) + self.drawNothingCheckboxRight.clicked.connect(self.annotOptionClickedRight) + self.annotNumZslicesCheckboxRight.clicked.connect(self.annotOptionClickedRight) + + self.segmentToolAction.triggered.connect(self.segmentToolActionTriggered) + + self.addDelRoiAction.triggered.connect(self.addDelROI) + self.addDelPolyLineRoiButton.toggled.connect(self.addDelPolyLineRoi_cb) + self.delBorderObjAction.triggered.connect(self.delBorderObj) + self.delNewObjAction.triggered.connect(self.delNewObj) + + self.brushAutoFillCheckbox.toggled.connect(self.brushAutoFillToggled) + self.brushAutoHideCheckbox.toggled.connect(self.brushAutoHideToggled) + + self.imgGrad.sigAddScaleBar.connect(self.addScaleBarAction.setChecked) + self.imgGrad.sigAddTimestamp.connect(self.addTimestampAction.setChecked) + self.imgGrad.gradient.sigGradientChangeFinished.connect( + self.imgGradLUTfinished_cb + ) + + # self.normalizeQActionGroup.triggered.connect( + # self.normaliseIntensitiesActionTriggered + # ) + self.imgPropertiesAction.triggered.connect(self.editImgProperties) + + self.relabelSequentialAction.triggered.connect(self.relabelSequentialCallback) + + self.zoomToObjsAction.triggered.connect(self.zoomToObjsActionCallback) + self.zoomOutAction.triggered.connect(self.zoomOut) + self.preprocessAction.triggered.connect(self.preprocessActionTriggered) + self.combineChannelsAction.triggered.connect( + self.combineChannelsActionTriggered + ) + + self.viewCcaTableAction.triggered.connect(self.viewCcaTable) + + self.guiTabControl.propsQGBox.idSB.valueChanged.connect( + self.propsWidgetIDvalueChanged + ) + self.guiTabControl.highlightCheckbox.toggled.connect( + self.highlightIDonHoverCheckBoxToggled + ) + self.guiTabControl.highlightSearchedCheckbox.toggled.connect( + self.highlightSearchedIDcheckBoxToggled + ) + intensMeasurQGBox = self.guiTabControl.intensMeasurQGBox + intensMeasurQGBox.additionalMeasCombobox.currentTextChanged.connect( + self.updatePropsWidget + ) + intensMeasurQGBox.channelCombobox.currentTextChanged.connect( + self.updatePropsWidget + ) + + propsQGBox = self.guiTabControl.propsQGBox + propsQGBox.additionalPropsCombobox.currentTextChanged.connect( + self.updatePropsWidget + ) + + def gui_createActions(self): + # File actions + self.segmNdimIndicator = widgets.ToolButtonTextIcon(text="") + self.segmNdimIndicator.setCheckable(True) + self.segmNdimIndicator.setChecked(True) + # self.segmNdimIndicator.setDisabled(True) + + if self.debug: + self.createEmptyDataAction = QAction(self) + self.createEmptyDataAction.setText("DEBUG: Create empty data") + + self.newWindowAction = QAction("New Window", self) + + self.newAction = QAction(self) + self.newAction.setText("&New Segmentation File...") + self.newAction.setIcon(QIcon(":file-new.svg")) + self.openFolderAction = QAction( + QIcon(":folder-open.svg"), "&Load Folder...", self + ) + self.openFileAction = QAction( + QIcon(":image.svg"), "&Open Image/Video File...", self + ) + self.manageVersionsAction = QAction( + QIcon(":manage_versions.svg"), "Load Older Versions...", self + ) + self.manageVersionsAction.setDisabled(True) + self.saveAction = QAction(QIcon(":file-save.svg"), "Save", self) + self.saveAsAction = QAction("Save as...", self) + self.exportToVideoAction = QAction("&Video...", self) + self.exportToImageAction = QAction("&Image...", self) + self.quickSaveAction = QAction("Save Only Segmentation Masks", self) + self.loadFluoAction = QAction("Load Fluorescence Images...", self) + self.loadPosAction = QAction("Load Different Position...", self) + # self.reloadAction = QAction( + # QIcon(":reload.svg"), "Reload segmentation file", self + # ) + self.nextAction = QAction("Next", self) + self.prevAction = QAction("Previous", self) + self.showInExplorerAction = QAction( + QIcon(":drawer.svg"), f"&{self.openFolderText}", self + ) + self.exitAction = QAction("&Exit", self) + self.undoAction = QAction(QIcon(":undo.svg"), "Undo", self) + self.redoAction = QAction(QIcon(":redo.svg"), "Redo", self) + # String-based key sequences + self.newWindowAction.setShortcut("Ctrl+Shift+N") + self.newAction.setShortcut("Ctrl+N") + self.openFolderAction.setShortcut("Ctrl+O") + self.loadPosAction.setShortcut("Shift+P") + self.saveAsAction.setShortcut("Ctrl+Shift+S") + self.exportToVideoAction.setShortcut("Ctrl+Shift+V") + self.exportToImageAction.setShortcut("Ctrl+Shift+I") + self.saveAction.setShortcut("Ctrl+Alt+S") + self.quickSaveAction.setShortcut("Ctrl+S") + self.undoAction.setShortcut("Ctrl+Z") + self.redoAction.setShortcut("Ctrl+Y") + self.nextAction.setShortcut(Qt.Key_Right) + self.prevAction.setShortcut(Qt.Key_Left) + self.addAction(self.nextAction) + self.addAction(self.prevAction) + # Help tips + newTip = "Create a new segmentation file" + self.newAction.setStatusTip(newTip) + self.newAction.setWhatsThis("Create a new empty segmentation file") + + self.autoPilotButton = QAction(self) + self.autoPilotButton.setIcon(QIcon(":auto-pilot.svg")) + self.autoPilotButton.setCheckable(True) + self.autoPilotButton.setShortcut("Ctrl+Shift+A") + + self.findIdAction = QAction(self) + self.findIdAction.setIcon(QIcon(":find.svg")) + self.findIdAction.setShortcut("Ctrl+F") + + self.zoomRectButton = QToolButton(self) + self.zoomRectButton.setIcon(QIcon(":zoom_rect.svg")) + self.zoomRectButton.setCheckable(True) + self.zoomRectButton.setShortcut("Shift+Z") + self.LeftClickButtons.append(self.zoomRectButton) + self.checkableButtons.append(self.zoomRectButton) + self.checkableQButtonsGroup.addButton(self.zoomRectButton) + self.widgetsWithShortcut["Zoom to rectangular area"] = self.zoomRectButton + + self.skipToNewIdAction = QAction(self) + self.skipToNewIdAction.setIcon(QIcon(":skip_forward_new_ID.svg")) + self.skipToNewIdAction.setShortcut(widgets.KeySequenceFromText(Qt.Key_PageUp)) + + self.skipToNewIdAction.setDisabled(True) + + # Edit actions + models = utils.get_list_of_models() + models = [*models, "local_seg"] # Add local_seg for SegForLostIDsAction + self.segmActions = [] + self.modelNames = [] + self.acdcSegment_li = [] + self.models = [] + for model_name in models: + action = QAction(f"{model_name}...") + self.segmActions.append(action) + self.modelNames.append(model_name) + self.models.append(None) + self.acdcSegment_li.append(None) + action.setDisabled(True) + + self.addCustomModelFrameAction = QAction("Add custom model...", self) + self.addCustomModelVideoAction = QAction("Add custom model...", self) + + self.segmWithPromptableModelAction = QAction("Select promptable model...", self) + self.addCustomPromptModelAction = QAction( + "Add custom promptable model...", self + ) + + self.segmActionsVideo = [] + for model_name in models: + action = QAction(f"{model_name}...") + self.segmActionsVideo.append(action) + action.setDisabled(True) + + self.postProcessSegmAction = QAction("Segmentation post-processing...", self) + self.postProcessSegmAction.setDisabled(True) + self.postProcessSegmAction.setCheckable(True) + + self.EditSegForLostIDsSetSettings = QAction( + "Edit settings for Segmenting lost IDs...", self + ) + self.EditSegForLostIDsSetSettings.triggered.connect( + self.SegForLostIDsSetSettings + ) + + self.repeatTrackingAction = QAction( + QIcon(":repeat-tracking.svg"), "Repeat tracking", self + ) + self.repeatTrackingAction.setShortcut("Shift+T") + self.widgetsWithShortcut["Repeat Tracking"] = self.repeatTrackingAction + + self.editRtTrackerParamsAction = QAction( + "Edit real-time tracker parameters...", self + ) + + self.repeatTrackingMenuAction = QAction( + "Track current frame with real-time tracker...", self + ) + self.repeatTrackingMenuAction.setDisabled(True) + self.repeatTrackingMenuAction.setShortcut("Shift+T") + + self.repeatTrackingVideoAction = QAction( + "Select a tracker and track multiple frames...", self + ) + self.repeatTrackingVideoAction.setDisabled(True) + self.repeatTrackingVideoAction.setShortcut("Alt+Shift+T") + + self.trackingAlgosGroup = QActionGroup(self) + self.trackWithAcdcAction = QAction("Cell-ACDC", self) + self.trackWithAcdcAction.setCheckable(True) + self.trackingAlgosGroup.addAction(self.trackWithAcdcAction) + + self.trackWithYeazAction = QAction("YeaZ", self) + self.trackWithYeazAction.setCheckable(True) + self.trackingAlgosGroup.addAction(self.trackWithYeazAction) + + rt_trackers = utils.get_list_of_real_time_trackers() + for rt_tracker in rt_trackers: + rtTrackerAction = QAction(rt_tracker, self) + rtTrackerAction.setCheckable(True) + self.trackingAlgosGroup.addAction(rtTrackerAction) + + self.trackWithAcdcAction.setChecked(True) + aliases = utils.aliases_real_time_trackers() + + if "tracking_algorithm" in self.df_settings.index: + trackingAlgo = self.df_settings.at["tracking_algorithm", "value"] + if trackingAlgo in aliases: + trackingAlgo = aliases[trackingAlgo] + if trackingAlgo == "Cell-ACDC": + self.trackWithAcdcAction.setChecked(True) + elif trackingAlgo == "YeaZ": + self.trackWithYeazAction.setChecked(True) + else: + for rtTrackerAction in self.trackingAlgosGroup.actions(): + if rtTrackerAction.text() == trackingAlgo: + rtTrackerAction.setChecked(True) + break + + self.setMeasurementsAction = QAction("Set measurements...") + self.addCustomMetricAction = QAction("Add custom measurement...") + self.addCombineMetricAction = QAction("Add combined measurement...") + + # Standard key sequence + # self.copyAction.setShortcut(QKeySequence.StandardKey.Copy) + # self.pasteAction.setShortcut(QKeySequence.StandardKey.Paste) + # self.cutAction.setShortcut(QKeySequence.StandardKey.Cut) + # Help actions + self.tipsAction = QAction("Tips and tricks...", self) + self.UserManualAction = QAction("User Documentation...", self) + self.openLogFileAction = QAction("Open log file...", self) + self.showLogFilesAction = QAction("Show log files...", self) + self.aboutAction = QAction("About Cell-ACDC", self) + # self.aboutAction = QAction("&About...", self) + + # Assign mother to bud button + self.assignBudMothAutoAction = QAction(self) + self.assignBudMothAutoAction.setIcon(QIcon(":autoAssign.svg")) + self.assignBudMothAutoAction.setVisible(False) + + self.editCcaToolAction = QAction(self) + self.editCcaToolAction.setIcon(QIcon(":edit_cca.svg")) + # self.editCcaToolAction.setDisabled(True) + self.editCcaToolAction.setVisible(False) + + self.reInitCcaAction = QAction(self) + self.reInitCcaAction.setIcon(QIcon(":reinitCca.svg")) + self.reInitCcaAction.setVisible(False) + + self.toggleColorSchemeAction = QAction("Switch to light theme") + self.gui_updateSwitchColorSchemeActionText() + + self.pxModeAction = widgets.CheckableAction("Fixed size text annotations") + self.pxModeAction.setChecked(True) + pxModeTooltip = ( + "When the text annotations are with fixed size they scale relative " + "to the object when zooming in/out (fixed size in pixels).\n" + "This is typically faster to render, but it makes annotations " + "smaller/larger when zooming in/out, respectively.\n\n" + "Try activating it to speed up the annotation of many objects " + "in high resolution mode.\n\n" + "After activating it, you might need to increase the font size " + "from the menu on the top menubar `Edit --> Font size`." + ) + self.pxModeAction.setToolTip(pxModeTooltip) + + self.highLowResAction = widgets.CheckableAction( + "High resolution text annotations" + ) + highLowResTooltip = ( + "Resolution of the text annotations. High resolution results " + "in slower update of the annotations.\n" + "Not recommended with a number of segmented objects > 500.\n\n" + ) + self.highLowResAction.setToolTip(highLowResTooltip) + + self.editAutoSaveIntervalAction = QAction( + "Change autosave interval (minutes or frames)...", self + ) + + self.editShortcutsAction = QAction("Customize keyboard shortcuts...", self) + self.editShortcutsAction.setShortcut("Ctrl+K") + + self.showMirroredCursorAction = QAction("Show mirrored cursor on images", self) + self.showMirroredCursorAction.setCheckable(True) + if "showMirroredCursor" in self.df_settings.index: + checked = self.df_settings.at["showMirroredCursor", "value"] == "Yes" + self.showMirroredCursorAction.setChecked(checked) + else: + self.showMirroredCursorAction.setChecked(True) + self.showMirroredCursorAction.setShortcut("Ctrl+M") + + self.editTextIDsColorAction = QAction("Text annotation color...", self) + self.editTextIDsColorAction.setDisabled(True) + + self.editOverlayColorAction = QAction("Overlay color...", self) + self.editOverlayColorAction.setDisabled(True) + + self.manuallyEditCcaAction = QAction("Edit cell cycle annotations...", self) + self.manuallyEditCcaAction.setShortcut("Ctrl+Shift+P") + self.manuallyEditCcaAction.setDisabled(True) + + self.viewCcaTableAction = QAction("View cell cycle annotations...", self) + self.viewCcaTableAction.setDisabled(True) + self.viewCcaTableAction.setShortcut("Ctrl+P") + + self.addScaleBarAction = QAction("Add scale bar", self) + self.addScaleBarAction.setCheckable(True) + + self.addTimestampAction = QAction("Add timestamp", self) + self.addTimestampAction.setCheckable(True) + + self.invertBwAction = QAction("Invert black/white", self) + self.invertBwAction.setCheckable(True) + checked = self.df_settings.at["is_bw_inverted", "value"] == "Yes" + self.invertBwAction.setChecked(checked) + + self.shuffleCmapAction = QAction("Randomly shuffle colormap", self) + self.shuffleCmapAction.setShortcut("Shift+S") + + self.greedyShuffleCmapAction = QAction("Greedily shuffle colormap", self) + self.greedyShuffleCmapAction.setShortcut("Alt+Shift+S") + + self.saveLabColormapAction = QAction("Save labels colormap...", self) + + self.normalizeRawAction = QAction("Do not normalize. Display raw image", self) + self.normalizeToFloatAction = QAction( + "Convert to floating point format with values [0, 1]", self + ) + # self.normalizeToUbyteAction = QAction( + # 'Rescale to 8-bit unsigned integer format with values [0, 255]', self) + self.normalizeRescale0to1Action = QAction("Rescale to [0, 1]", self) + self.normalizeByMaxAction = QAction("Normalize by max value", self) + self.normalizeRawAction.setCheckable(True) + self.normalizeToFloatAction.setCheckable(True) + # self.normalizeToUbyteAction.setCheckable(True) + self.normalizeRescale0to1Action.setCheckable(True) + self.normalizeByMaxAction.setCheckable(True) + self.normalizeQActionGroup = QActionGroup(self) + self.normalizeQActionGroup.addAction(self.normalizeRawAction) + self.normalizeQActionGroup.addAction(self.normalizeToFloatAction) + # self.normalizeQActionGroup.addAction(self.normalizeToUbyteAction) + self.normalizeQActionGroup.addAction(self.normalizeRescale0to1Action) + self.normalizeQActionGroup.addAction(self.normalizeByMaxAction) + + self.preprocessAction = QAction("Pre-processing...", self) + self.preprocessAction.setShortcut("Alt+Shift+P") + + self.combineChannelsAction = QAction( + "Combine and manipulate channels and/or segmentation files...", self + ) + self.combineChannelsAction.setShortcut("Alt+Shift+C") + + self.zoomToObjsAction = QAction("Zoom to objects (Shortcut: H key)", self) + self.zoomOutAction = QAction("Zoom out (Shortcut: double press H key)", self) + + self.relabelSequentialAction = QAction("Relabel IDs sequentially...", self) + self.relabelSequentialAction.setShortcut("Ctrl+L") + self.relabelSequentialAction.setDisabled(True) + + self.setLastUserNormAction() + + self.autoSegmAction = QAction("Enable automatic segmentation", self) + self.autoSegmAction.setCheckable(True) + self.autoSegmAction.setDisabled(True) + + self.enableSmartTrackAction = QAction( + "Smart handling of enabling/disabling tracking", self + ) + self.enableSmartTrackAction.setCheckable(True) + self.enableSmartTrackAction.setChecked(True) + + self.enableAutoZoomToCellsAction = QAction( + 'Automatic zoom to all cells when pressing "Next/Previous"', self + ) + self.enableAutoZoomToCellsAction.setCheckable(True) + + self.imgPropertiesAction = QAction("Properties...", self) + self.imgPropertiesAction.setDisabled(True) + + self.addDelRoiAction = QAction(self) + self.addDelRoiAction.roiType = "rect" + self.addDelRoiAction.setIcon(QIcon(":addDelRoi.svg")) + + self.addDelPolyLineRoiButton = QToolButton(self) + self.addDelPolyLineRoiButton.setCheckable(True) + self.addDelPolyLineRoiButton.setIcon(QIcon(":addDelPolyLineRoi.svg")) + + self.checkableButtons.append(self.addDelPolyLineRoiButton) + self.LeftClickButtons.append(self.addDelPolyLineRoiButton) + + self.delBorderObjAction = QAction(self) + self.delBorderObjAction.setIcon(QIcon(":delBorderObj.svg")) + + self.delNewObjAction = QAction(self) + self.delNewObjAction.setIcon(QIcon(":delNewObj.svg")) + + self.loadCustomAnnotationsAction = QAction(self) + self.loadCustomAnnotationsAction.setIcon(QIcon(":load_annotation.svg")) + self.loadCustomAnnotationsAction.setToolTip( + "Load previously used custom annotations" + ) + + self.addCustomAnnotationAction = QAction(self) + self.addCustomAnnotationAction.setIcon(QIcon(":addCustomAnnotation.svg")) + self.addCustomAnnotationAction.setToolTip("Add custom annotation") + # self.functionsNotTested3D.append(self.addCustomAnnotationAction) + + self.viewAllCustomAnnotAction = QAction(self) + self.viewAllCustomAnnotAction.setCheckable(True) + self.viewAllCustomAnnotAction.setIcon(QIcon(":eye.svg")) + self.viewAllCustomAnnotAction.setToolTip("Show all custom annotations") + + def gui_updateSwitchColorSchemeActionText(self): + if self._colorScheme == "dark": + txt = "Switch to light theme" + else: + txt = "Switch to dark theme" + self.toggleColorSchemeAction.setText(txt) + + def initShortcuts(self): + from . import config + + cp = config.ConfigParser() + if os.path.exists(shortcut_filepath): + cp.read(shortcut_filepath) + + if "keyboard.shortcuts" not in cp: + cp["keyboard.shortcuts"] = {} + + if cp.has_option("keyboard.shortcuts", "Zoom out"): + zoomOutKeyValueStr = cp["keyboard.shortcuts"]["Zoom out"] + try: + self.zoomOutKeyValue = int(zoomOutKeyValueStr) + except Exception as err: + self.logger.warning( + f"{zoomOutKeyValueStr} is not a valid key " + 'zooming out action. Restoring default key "H".' + ) + + if "delete_object.action" not in cp: + self.delObjAction = None + else: + delObjKeySequenceText = cp["delete_object.action"]["Key sequence"] + delObjButtonText = cp["delete_object.action"]["Mouse button"] + delObjQtButton = ( + Qt.MouseButton.LeftButton + if delObjButtonText == "Left click" + else Qt.MouseButton.MiddleButton + ) + if not delObjKeySequenceText: + delObjKeySequence = None + else: + delObjKeySequence = widgets.KeySequenceFromText(delObjKeySequenceText) + self.delObjToolAction.setChecked(True) + self.delObjAction = delObjKeySequence, delObjQtButton + + shortcuts = {} + for name, widget in self.widgetsWithShortcut.items(): + if name not in cp.options("keyboard.shortcuts"): + if hasattr(widget, "keyPressShortcut"): + key = widget.keyPressShortcut + shortcut = widgets.KeySequenceFromText(key) + else: + shortcut = widget.shortcut() + shortcut_text = shortcut.toString() + cp["keyboard.shortcuts"][name] = shortcut_text + else: + shortcut_text = cp["keyboard.shortcuts"][name] + shortcut = widgets.KeySequenceFromText(shortcut_text) + + shortcuts[name] = (shortcut_text, shortcut) + self.setShortcuts(shortcuts, save=False) + with open(shortcut_filepath, "w") as ini: + cp.write(ini) + + def setShortcuts(self, shortcuts: dict, save=True): + for name, (text, shortcut) in shortcuts.items(): + widget = self.widgetsWithShortcut[name] + if shortcut is None: + shortcut = QKeySequence() + if hasattr(widget, "keyPressShortcut"): + widget.keyPressShortcut = shortcut + else: + widget.setShortcut(shortcut) + s = widget.toolTip() + toolTip = re.sub(r'Shortcut: "(.*)"', f'Shortcut: "{text}"', s) + widget.setToolTip(toolTip) + + if not save: + return + + from . import config + + cp = config.ConfigParser() + if os.path.exists(shortcut_filepath): + cp.read(shortcut_filepath) + + if "keyboard.shortcuts" not in cp: + cp["keyboard.shortcuts"] = {} + + for name, (text, shortcut) in shortcuts.items(): + cp["keyboard.shortcuts"][name] = text + + cp["keyboard.shortcuts"]["Zoom out"] = str(self.zoomOutKeyValue) + + if self.delObjAction is None: + with open(shortcut_filepath, "w") as ini: + cp.write(ini) + return + + delObjKeySequence, delObjQtButton = self.delObjAction + try: + if delObjKeySequence is None: + delObjKeySequenceText = "" + else: + delObjKeySequenceText = delObjKeySequence.toString() + + delObjKeySequenceText = delObjKeySequenceText.encode( + "ascii", "ignore" + ).decode("utf-8") + delObjButtonText = ( + "Left click" + if delObjQtButton == Qt.MouseButton.LeftButton + else "Middle click" + ) + cp["delete_object.action"] = { + "Key sequence": delObjKeySequenceText, + "Mouse button": delObjButtonText, + } + except Exception as err: + self.logger.warning( + f"{delObjKeySequence} is not a valid keys sequence for " + "deleting objects. Setting default action" + ) + self.delObjAction = None + cp.remove_section("delete_object.action") + + with open(shortcut_filepath, "w") as ini: + cp.write(ini) diff --git a/cellacdc/mixins/annotation_display.py b/cellacdc/mixins/annotation_display.py new file mode 100644 index 000000000..fefea09d1 --- /dev/null +++ b/cellacdc/mixins/annotation_display.py @@ -0,0 +1,1079 @@ +"""Qt view adapter for annotation display workflows.""" + +from __future__ import annotations + +import re + +from cellacdc import _palettes, apps, html_utils, widgets + +from typing import Mapping + +GREEN_HEX = _palettes.green() + +from .mode_controls import ModeControls + + +class AnnotationDisplay(ModeControls): + """Extracted from guiWin.""" + + def activateAnnotations(self): + if self.annotContourCheckbox.isChecked(): + return + if self.annotSegmMasksCheckbox.isChecked(): + return + + self.annotSegmMasksCheckbox.setChecked(True) + self.setDrawAnnotComboboxText() + + def annotGenNumTreeToggled(self, checked): + self.textAnnot[0].setGenNumTreeAnnotationsEnabled(checked) + + def annotLabelIDtreeToggled(self, checked): + self.textAnnot[0].setLabelTreeAnnotationsEnabled(checked) + + def annotOptionClicked(self, clicked=True, sender=None, saveSettings=True): + if sender is None: + sender = self.sender() + # First manually set exclusive with uncheckable + clickedIDs = sender == self.annotIDsCheckbox + clickedCca = sender == self.annotCcaInfoCheckbox + clickedMBline = sender == self.drawMothBudLinesCheckbox + if self.annotIDsCheckbox.isChecked() and clickedIDs: + if self.annotCcaInfoCheckbox.isChecked(): + self.annotCcaInfoCheckbox.setChecked(False) + if self.drawMothBudLinesCheckbox.isChecked(): + self.drawMothBudLinesCheckbox.setChecked(False) + + if self.annotCcaInfoCheckbox.isChecked() and clickedCca: + if self.annotIDsCheckbox.isChecked(): + self.annotIDsCheckbox.setChecked(False) + if self.drawMothBudLinesCheckbox.isChecked(): + self.drawMothBudLinesCheckbox.setChecked(False) + + if self.drawMothBudLinesCheckbox.isChecked() and clickedMBline: + if self.annotIDsCheckbox.isChecked(): + self.annotIDsCheckbox.setChecked(False) + if self.annotCcaInfoCheckbox.isChecked(): + self.annotCcaInfoCheckbox.setChecked(False) + + clickedCont = sender == self.annotContourCheckbox + clickedSegm = sender == self.annotSegmMasksCheckbox + if self.annotContourCheckbox.isChecked() and clickedCont: + if self.annotSegmMasksCheckbox.isChecked(): + self.annotSegmMasksCheckbox.setChecked(False) + + if self.annotSegmMasksCheckbox.isChecked() and clickedSegm: + if self.annotContourCheckbox.isChecked(): + self.annotContourCheckbox.setChecked(False) + + clickedDoNot = sender == self.drawNothingCheckbox + if clickedDoNot: + self.annotIDsCheckbox.setChecked(False) + self.annotCcaInfoCheckbox.setChecked(False) + self.annotContourCheckbox.setChecked(False) + self.annotSegmMasksCheckbox.setChecked(False) + self.drawMothBudLinesCheckbox.setChecked(False) + self.annotNumZslicesCheckbox.setChecked(False) + else: + self.drawNothingCheckbox.setChecked(False) + + if sender == self.annotNumZslicesCheckbox: + self.annotIDsCheckbox.setChecked(True) + self.drawNothingCheckbox.setChecked(False) + + self.setDrawAnnotComboboxText(saveSettings=saveSettings) + + def annotOptionClickedRight(self, clicked=True, sender=None, saveSettings=True): + if sender is None: + sender = self.sender() + # First manually set exclusive with uncheckable + clickedIDs = sender == self.annotIDsCheckboxRight + clickedCca = sender == self.annotCcaInfoCheckboxRight + clickedMBline = sender == self.drawMothBudLinesCheckboxRight + if self.annotIDsCheckboxRight.isChecked() and clickedIDs: + if self.annotCcaInfoCheckboxRight.isChecked(): + self.annotCcaInfoCheckboxRight.setChecked(False) + if self.drawMothBudLinesCheckboxRight.isChecked(): + self.drawMothBudLinesCheckboxRight.setChecked(False) + + if self.annotCcaInfoCheckboxRight.isChecked() and clickedCca: + if self.annotIDsCheckboxRight.isChecked(): + self.annotIDsCheckboxRight.setChecked(False) + if self.drawMothBudLinesCheckboxRight.isChecked(): + self.drawMothBudLinesCheckboxRight.setChecked(False) + + if self.drawMothBudLinesCheckboxRight.isChecked() and clickedMBline: + if self.annotIDsCheckboxRight.isChecked(): + self.annotIDsCheckboxRight.setChecked(False) + if self.annotCcaInfoCheckboxRight.isChecked(): + self.annotCcaInfoCheckboxRight.setChecked(False) + + clickedCont = sender == self.annotContourCheckboxRight + clickedSegm = sender == self.annotSegmMasksCheckboxRight + if self.annotContourCheckboxRight.isChecked() and clickedCont: + if self.annotSegmMasksCheckboxRight.isChecked(): + self.annotSegmMasksCheckboxRight.setChecked(False) + + if self.annotSegmMasksCheckboxRight.isChecked() and clickedSegm: + if self.annotContourCheckboxRight.isChecked(): + self.annotContourCheckboxRight.setChecked(False) + + clickedDoNot = sender == self.drawNothingCheckboxRight + if clickedDoNot: + self.annotIDsCheckboxRight.setChecked(False) + self.annotCcaInfoCheckboxRight.setChecked(False) + self.annotContourCheckboxRight.setChecked(False) + self.annotSegmMasksCheckboxRight.setChecked(False) + self.drawMothBudLinesCheckboxRight.setChecked(False) + self.annotNumZslicesCheckboxRight.setChecked(False) + else: + self.drawNothingCheckboxRight.setChecked(False) + + if sender == self.annotNumZslicesCheckboxRight: + self.annotIDsCheckboxRight.setChecked(True) + self.drawNothingCheckboxRight.setChecked(False) + + self.setDrawAnnotComboboxTextRight(saveSettings=saveSettings) + + def annotateRightHowCombobox_cb(self, idx): + how = self.annotateRightHowCombobox.currentText() + saveSettings = True + if hasattr(self.annotateRightHowCombobox, "saveSettings"): + saveSettings = self.annotateRightHowCombobox.saveSettings + + if saveSettings: + self.df_settings.at["how_draw_right_annotations", "value"] = how + self.df_settings.to_csv(self.settings_csv_path) + + mode = self.modeComboBox.currentText() + isCcaAnnot = ( + self.annotCcaInfoCheckboxRight.isChecked() + and mode != "Normal division: Lineage tree" + ) + isIDAnnot = self.annotIDsCheckboxRight.isChecked() or ( + self.annotCcaInfoCheckboxRight.isChecked() + and mode == "Normal division: Lineage tree" + ) + self.textAnnot[1].setCcaAnnot(isCcaAnnot) + + self.textAnnot[1].setLabelAnnot(isIDAnnot) + if not self.isDataLoading: + self.updateAllImages() + + def annotate_rip_and_bin_IDs(self, updateLabel=False): + depthAxes = self.switchPlaneCombobox.depthAxes() + if self.switchPlaneCombobox.isEnabled() and depthAxes != "z": + return + + posData = self.data[self.pos_i] + binnedIDs_xx = [] + binnedIDs_yy = [] + ripIDs_xx = [] + ripIDs_yy = [] + for obj in posData.rp: + obj.excluded = obj.label in posData.binnedIDs + obj.dead = obj.label in posData.ripIDs + if not self.isObjVisible(obj.bbox): + continue + + if obj.excluded: + y, x = self.getObjCentroid(obj.centroid) + binnedIDs_xx.append(x) + binnedIDs_yy.append(y) + if updateLabel: + self.getObjOptsSegmLabels(obj) + how = self.drawIDsContComboBox.currentText() + + if obj.dead: + y, x = self.getObjCentroid(obj.centroid) + ripIDs_xx.append(x) + ripIDs_yy.append(y) + if updateLabel: + self.getObjOptsSegmLabels(obj) + how = self.drawIDsContComboBox.currentText() + + self.ax2_binnedIDs_ScatterPlot.setData(binnedIDs_xx, binnedIDs_yy) + self.ax2_ripIDs_ScatterPlot.setData(ripIDs_xx, ripIDs_yy) + self.ax1_binnedIDs_ScatterPlot.setData(binnedIDs_xx, binnedIDs_yy) + self.ax1_ripIDs_ScatterPlot.setData(ripIDs_xx, ripIDs_yy) + + def applyToolNewFrameActionToggled(self, checked, toolName=None): + if toolName is None: + parentToolButton = self.sender().parent() + toolName = re.findall(r"Name: (.*)", parentToolButton.toolTip())[0] + toolName = toolName.strip() + button = self.applyToolNewFrameButtons[toolName] + toolName = toolName.replace(" ", "_") + settingName = f"{toolName}_applyNewFrame" + if checked: + self.df_settings.at[settingName, "value"] = "applyNewFrame" + button.setStyleSheet(f"background-color: {GREEN_HEX}") + else: + self.df_settings = self.df_settings.drop(index=settingName, errors="ignore") + button.setStyleSheet("background-color: none") + self.df_settings.to_csv(self.settings_csv_path) + + def areContoursRequested(self, ax): + if ax == 0 and self.annotContourCheckbox.isChecked(): + return True + + if ax == 1: + if not self.labelsGrad.showRightImgAction.isChecked(): + return False + + isRightDifferentAnnot = self.rightBottomGroupbox.isChecked() + areContRequestedRight = self.annotContourCheckboxRight.isChecked() + + if isRightDifferentAnnot and areContRequestedRight: + return True + + areContRequestedLeft = self.annotContourCheckbox.isChecked() + if not isRightDifferentAnnot and areContRequestedLeft: + return True + return False + + def areMothBudLinesRequested(self, ax): + if ax == 0: + if self.annotCcaInfoCheckbox.isChecked(): + return True + if self.drawMothBudLinesCheckbox.isChecked(): + return True + else: + if not self.labelsGrad.showRightImgAction.isChecked(): + return False + + isRightDifferentAnnot = self.rightBottomGroupbox.isChecked() + areLinesRequestedRight = ( + self.annotCcaInfoCheckboxRight.isChecked() + or self.drawMothBudLinesCheckboxRight.isChecked() + ) + + if isRightDifferentAnnot and areLinesRequestedRight: + return True + + areLinesRequestedLeft = ( + self.drawMothBudLinesCheckbox.isChecked() + or self.annotCcaInfoCheckbox.isChecked() + ) + if not isRightDifferentAnnot and areLinesRequestedLeft: + return True + return False + + def autoPilotToggled(self, checked): + self.autoPilotZoomToObjToolbar.setVisible(checked) + if checked: + self.autoPilotZoomToObjToggle.setChecked(False) + self.autoPilotZoomToObjToggle.toggle() + + def changeTextResolution(self): + mode = "high" if self.highLowResAction.isChecked() else "low" + self.logger.info(f"Switching to {mode} for the text annnotations...") + self.pxModeAction.setDisabled(not self.highLowResAction.isChecked()) + if not self.isDataLoaded: + return + + self.setAllIDs() + posData = self.data[self.pos_i] + allIDs = posData.allIDs + img_shape = self.img1.image.shape[:2] + self.textAnnot[0].changeResolution(mode, allIDs, self.ax1, img_shape) + self.textAnnot[1].changeResolution(mode, allIDs, self.ax2, img_shape) + self.updateAllImages() + + def clearAllCellToCellLines(self): + self.ax1_newMothBudLinesItem.setData([], []) + self.ax1_oldMothBudLinesItem.setData([], []) + self.ax2_newMothBudLinesItem.setData([], []) + self.ax2_oldMothBudLinesItem.setData([], []) + + def clearAnnotItems(self): + self.textAnnot[0].clear() + self.textAnnot[1].clear() + + def drawAllLineageTreeLines(self): + """ + Draw all lineage tree lines on the GUI. + + This method retrieves the lineage tree data and draws the lineage tree lines + connecting cells and their respective mothers when the mother has split. + """ + if self.lineage_tree is None: + return + + if len(self.lineage_tree.frames_for_dfs) < 2: + return + + self.clearAllCellToCellLines() + posData = self.data[self.pos_i] + frame_i = posData.frame_i + lin_tree_df = posData.allData_li[frame_i]["acdc_df"] + lin_tree_df_prev = posData.allData_li[frame_i - 1]["acdc_df"] + rp = posData.rp + prev_rp = posData.allData_li[frame_i - 1]["regionprops"] + + self.setTitleText() + + new_cells = lin_tree_df.index.difference( + lin_tree_df_prev.index + ) # I could use this for the if already but this is probably faster for frames where nothing changes + if new_cells.shape[0] == 0: + return + + for ax in (0, 1): + if not self.areMothBudLinesRequested(ax): + continue + + for ID in new_cells: + curr_obj = utils.get_obj_by_label(rp, ID) + lin_tree_df_ID = lin_tree_df.loc[ID] + + # lin_tree_df_mother_ID = lin_tree_df_prev.loc[lin_tree_df_ID["parent_ID_tree"]] + if ( + lin_tree_df_ID["parent_ID_tree"] == -1 + ): # make sure that new obj where the parents are not known get skipped + continue + + mother_obj = utils.get_obj_by_label( + prev_rp, lin_tree_df_ID["parent_ID_tree"] + ) + + emerg_frame_i = lin_tree_df_ID["emerg_frame_i"] + isNew = emerg_frame_i == frame_i + + self.drawObjLin_TreeMothBudLines(ax, curr_obj, mother_obj, isNew, ID=ID) + + def drawAllMothBudLines(self): + posData = self.data[self.pos_i] + for obj in posData.rp: + self.drawObjMothBudLines(obj, posData, ax=0) + self.drawObjMothBudLines(obj, posData, ax=1) + + def drawAnnotCombobox_to_options(self): + self.uncheckAnnotOptions() + + # Left + how = self.drawIDsContComboBox.currentText() + if how.find("IDs") != -1: + self.annotIDsCheckbox.setChecked(True) + if how.find("cell cycle info") != -1: + self.annotCcaInfoCheckbox.setChecked(True) + if how.find("contours") != -1: + self.annotContourCheckbox.setChecked(True) + if how.find("segm. masks") != -1: + self.annotSegmMasksCheckbox.setChecked(True) + if how.find("mother-bud lines") != -1: + self.drawMothBudLinesCheckbox.setChecked(True) + if how.find("nothing") != -1: + self.drawNothingCheckbox.setChecked(True) + + # Right + how = self.annotateRightHowCombobox.currentText() + if how.find("IDs") != -1: + self.annotIDsCheckboxRight.setChecked(True) + if how.find("cell cycle info") != -1: + self.annotCcaInfoCheckboxRight.setChecked(True) + if how.find("contours") != -1: + self.annotContourCheckboxRight.setChecked(True) + if how.find("segm. masks") != -1: + self.annotSegmMasksCheckboxRight.setChecked(True) + if how.find("mother-bud lines") != -1: + self.drawMothBudLinesCheckboxRight.setChecked(True) + if how.find("nothing") != -1: + self.drawNothingCheckboxRight.setChecked(True) + + def drawIDsContComboBox_cb(self, idx): + how = self.drawIDsContComboBox.currentText() + saveSettings = True + if hasattr(self.drawIDsContComboBox, "saveSettings"): + saveSettings = self.drawIDsContComboBox.saveSettings + + if saveSettings: + self.df_settings.at["how_draw_annotations", "value"] = how + self.df_settings.to_csv(self.settings_csv_path) + + mode = self.modeComboBox.currentText() + isCcaAnnot = ( + self.annotCcaInfoCheckbox.isChecked() + and mode != "Normal division: Lineage tree" + ) + isIDAnnot = self.annotIDsCheckbox.isChecked() or ( + self.annotCcaInfoCheckbox.isChecked() + and mode == "Normal division: Lineage tree" + ) + self.textAnnot[0].setCcaAnnot(isCcaAnnot) + + self.textAnnot[0].setLabelAnnot(isIDAnnot) + + if not self.isDataLoading: + self.updateAllImages() + + if self.eraserButton.isChecked(): + self.setTempImg1Eraser(None, init=True) + + def drawObjLin_TreeMothBudLines(self, ax, obj, mother_obj, isNew, ID=None): + """ + Draw moth-bud lines between an object and its mother object. + + Parameters + ---------- + ax : cellacdc.widgets.MainPlotItem + The Cell-ACDC GUI axes object to draw on. + obj : Object + The object for which to draw the moth-bud lines. + mother_obj : Object + The mother object to connect with. + isNew : bool + Indicates whether the object is new or not. + ID : int, optional + The ID of the object, by default None. + """ + if not self.areMothBudLinesRequested(ax): + return + + if not ID: + ID = obj.label + + isObjVisible = self.isObjVisible(obj.bbox) + + if not isObjVisible: + return + + scatterItem = self.getMothBudLineScatterItem(ax, isNew) + + y1, x1 = self.getObjCentroid(obj.centroid) + y2, x2 = self.getObjCentroid(mother_obj.centroid) + xx, yy = core.get_line(y1, x1, y2, x2, dashed=True) + scatterItem.addPoints(xx, yy) + + def drawObjMothBudLines(self, obj, posData, ax=0): + areMothBudLinesRequested = self.areMothBudLinesRequested(ax) + if not areMothBudLinesRequested: + return + + if posData.cca_df is None: + return + + mode = str(self.modeComboBox.currentText()) + if mode == "Normal division: Lineage Tree": + return + + ID = obj.label + try: + cca_df_ID = posData.cca_df.loc[ID] + except KeyError: + return + + isObjVisible = self.isObjVisible(obj.bbox) + if not isObjVisible: + return + + ccs_ID = cca_df_ID["cell_cycle_stage"] + if ccs_ID == "G1": + return + + relationship = cca_df_ID["relationship"] + if relationship != "bud": + return + + emerg_frame_i = cca_df_ID["emerg_frame_i"] + isNew = emerg_frame_i == posData.frame_i + scatterItem = self.getMothBudLineScatterItem(ax, isNew) + relative_ID = cca_df_ID["relative_ID"] + + try: + relative_rp_idx = posData.IDs_idxs[relative_ID] + except KeyError: + return + + relative_ID_obj = posData.rp[relative_rp_idx] + y1, x1 = self.getObjCentroid(obj.centroid) + y2, x2 = self.getObjCentroid(relative_ID_obj.centroid) + xx, yy = core.get_line(y1, x1, y2, x2, dashed=True) + scatterItem.addPoints(xx, yy) + + def getAnnotateHowRightImage(self): + if not self.labelsGrad.showRightImgAction.isChecked(): + return "nothing" + + if self.rightBottomGroupbox.isChecked(): + how = self.annotateRightHowCombobox.currentText() + else: + how = self.drawIDsContComboBox.currentText() + return how + + def getMothBudLineScatterItem(self, ax, new): + if ax == 0: + if new: + return self.ax1_newMothBudLinesItem + else: + return self.ax1_oldMothBudLinesItem + else: + if new: + return self.ax2_newMothBudLinesItem + else: + return self.ax2_oldMothBudLinesItem + + def getObjCentroid(self, obj_centroid): + if self.isSegm3D: + depthAxes = self.switchPlaneCombobox.depthAxes() + zc, yc, xc = obj_centroid + if depthAxes == "z": + return yc, xc + elif depthAxes == "y": + return zc, xc + else: + return zc, yc + else: + return obj_centroid + + def getObjOptsSegmLabels(self, obj): + if not self.labelsGrad.showLabelsImgAction.isChecked(): + return + + objOpts = self.getObjTextAnnotOpts(obj, "Draw only IDs", ax=1) + return objOpts + + def gui_raiseBottomLayoutContextMenu(self, event): + try: + # Convert QPointF to QPoint + self.bottomLayoutContextMenu.popup(event.globalPos().toPoint()) + except AttributeError: + self.bottomLayoutContextMenu.popup(event.globalPos()) + + def highLowResToggled(self, clicked=True): + self.changeTextResolution() + + def highlightZneighLabels_cb(self, checked): + if checked: + pass + else: + pass + + def keepAllToolsActiveActionToggled(self, checked): + for action in self.keepToolActiveActions.values(): + action.setChecked(checked) + + data_loaded = True + if not hasattr(self, "data"): + data_loaded = False + try: + self.labelRoiTrangeCheckbox.disconnect() + except TypeError: + pass + self.labelRoiTrangeCheckbox.setChecked( + checked + ) # why this is not wrapped in a QAction? + + if data_loaded: + self.labelRoiTrangeCheckbox.toggled.connect( + self.labelRoiTrangeCheckboxToggled + ) + + def keepToolActiveActionToggled(self, checked, toolName=None): + if toolName is None: + parentToolButton = self.sender().parent() + toolName = re.findall(r"Name: (.*)", parentToolButton.toolTip())[0] + + if checked: + self.df_settings.at[toolName, "value"] = "keepActive" + else: + self.df_settings = self.df_settings.drop(index=toolName, errors="ignore") + self.df_settings.to_csv(self.settings_csv_path) + + def labelRoiIsCircularRadioButtonToggled(self, checked): + if checked: + self.labelRoiCircularRadiusSpinbox.setDisabled(False) + else: + self.labelRoiCircularRadiusSpinbox.setDisabled(True) + + def onDoubleSpaceBar(self): + how = self.drawIDsContComboBox.currentText() + if how.find("nothing") == -1: + self.storeCurrentAnnotOptions_ax1() + self.drawNothingCheckbox.setChecked(True) + self.annotOptionClicked(sender=self.drawNothingCheckbox, saveSettings=False) + else: + self.restoreAnnotOptions_ax1() + + how = self.annotateRightHowCombobox.currentText() + if how.find("nothing") == -1: + self.storeCurrentAnnotOptions_ax2() + self.drawNothingCheckboxRight.setChecked(True) + self.annotOptionClickedRight( + sender=self.drawNothingCheckboxRight, saveSettings=False + ) + else: + self.restoreAnnotOptions_ax2() + + def pxModeActionToggled(self, checked): + self.df_settings.at["pxMode", "value"] = int(checked) + self.df_settings.to_csv(self.settings_csv_path) + + if not self.isDataLoaded: + return + + if self.highLowResAction.isChecked(): + for ax in range(2): + self.textAnnot[ax].setPxMode(checked) + + self.updateAllImages() + + def relabelSequentialCallback(self): + mode = str(self.modeComboBox.currentText()) + if mode == "Viewer" or mode == "Cell cycle analysis": + self.startBlinkingModeCB() + return + + posData = self.data[self.pos_i] + selectedPos = (posData.pos_foldername,) + if len(self.data) > 1: + selectedPos = self.askSelectPos(action="to process") + if selectedPos is None: + self.logger.info("Re-labelling process stopped.") + return + + self.store_data() + # acdc_df_concat = self.getConcatAcdcDf() + # load.store_unsaved_acdc_df( + # posData, acdc_df_concat, + # log_func=self.logger.info + # ) + # if posData.SizeT > 1: + self.progressWin = apps.QDialogWorkerProgress( + title="Re-labelling sequential", + parent=self, + pbarDesc="Relabelling sequential...", + ) + self.progressWin.show(self.app) + self.progressWin.mainPbar.setMaximum(0) + self.startRelabellingWorker(selectedPos) + + def restoreAnnotOptions_ax1(self, options=None): + if options is None and not hasattr(self, "annotOptionsToRestore"): + return + + if options is None: + options = self.annotOptionsToRestore + + if options is None: + return + + for option, state in options.items(): + checkbox = getattr(self, option) + checkbox.setChecked(state) + + self.setDrawAnnotComboboxText() + self.annotOptionsToRestore = None + + def restoreAnnotOptions_ax2(self): + if not hasattr(self, "annotOptionsToRestoreRight"): + return + + if self.annotOptionsToRestoreRight is None: + return + + for option, state in self.annotOptionsToRestoreRight.items(): + checkbox = getattr(self, option) + checkbox.setChecked(state) + + self.setDrawAnnotComboboxTextRight() + self.annotOptionsToRestoreRight = None + + def restoreAnnotationsOptions(self): + self.restoreAnnotOptions_ax1() + self.restoreAnnotOptions_ax2() + + def restoreSavedSettings(self): + if "how_draw_annotations" in self.df_settings.index: + how = self.df_settings.at["how_draw_annotations", "value"] + self.drawIDsContComboBox.setCurrentText(how) + else: + self.drawIDsContComboBox.setCurrentText("Draw IDs and contours") + + if "how_draw_right_annotations" in self.df_settings.index: + how = self.df_settings.at["how_draw_right_annotations", "value"] + self.annotateRightHowCombobox.setCurrentText(how) + else: + self.annotateRightHowCombobox.setCurrentText( + "Draw IDs and overlay segm. masks" + ) + + if "addNewIDsWhitelistToggle" in self.df_settings.index: + self.addNewIDsWhitelistToggle = ( + (self.df_settings.at["addNewIDsWhitelistToggle", "value"]) == "Yes" + ) + else: + self.addNewIDsWhitelistToggle = True + + self.drawAnnotCombobox_to_options() + self.drawIDsContComboBox_cb(0) + self.annotateRightHowCombobox_cb(0) + + def rtTrackerActionToggled(self, checked): + if not checked: + return + + aliases = utils.aliases_real_time_trackers(reverse=True) + if self.sender().text() in aliases: + trackingAlgo = aliases[self.sender().text()] + else: + trackingAlgo = self.sender().text() + self.df_settings.at["tracking_algorithm", "value"] = trackingAlgo + self.df_settings.to_csv(self.settings_csv_path) + + if self.sender().text() == "YeaZ": + msg = widgets.myMessageBox(wrapText=False) + info_txt = html_utils.paragraph(f""" + Note that YeaZ tracking algorithm tends to be sliglhtly more accurate + overall, but it is less capable of detecting segmentation + errors.

    + If you need to correct as many segmentation errors as possible + we recommend using Cell-ACDC tracking algorithm. + """) + msg.information(self, "Info about YeaZ", info_txt) + + self.isRealTimeTrackerInitialized = False + self.initRealTimeTracker() + + def setAllTextAnnotations(self, labelsToSkip=None): + delROIsIDs = self.setLostNewOldPrevIDs() + posData = self.data[self.pos_i] + self.textAnnot[0].setAnnotations( + posData=posData, + labelsToSkip=labelsToSkip, + isVisibleCheckFunc=self.isObjVisible, + highlightedID=self.highlightedID, + delROIsIDs=delROIsIDs, + annotateLost=self.annotLostObjsToggle.isChecked(), + getCurrentZfunc=self.z_lab, + getObjCentroidFunc=self.getObjCentroid, + ) + self.textAnnot[1].setAnnotations( + posData=posData, + labelsToSkip=labelsToSkip, + isVisibleCheckFunc=self.isObjVisible, + highlightedID=self.highlightedID, + delROIsIDs=delROIsIDs, + annotateLost=self.annotLostObjsToggle.isChecked(), + getCurrentZfunc=self.z_lab, + getObjCentroidFunc=self.getObjCentroid, + ) + self.textAnnot[0].update() + self.textAnnot[1].update() + return delROIsIDs + + def setAnnotInfoMode(self, checked): + if checked: + for action in self.annotSettingsIDmenu.actions(): + if action.text().find("tree") != -1: + self.textAnnot[0].setLabelTreeAnnotationsEnabled(True) + action.setChecked(True) + break + for action in self.annotSettingsGenNumMenu.actions(): + if action.text().find("tree") != -1: + self.textAnnot[0].setGenNumTreeAnnotationsEnabled(True) + action.setChecked(True) + break + else: + for action in self.annotSettingsIDmenu.actions(): + if action.text().find("tree") == -1: + action.setChecked(False) + self.textAnnot[0].setLabelTreeAnnotationsEnabled(False) + break + for action in self.annotSettingsGenNumMenu.actions(): + if action.text().find("tree") == -1: + action.setChecked(False) + self.textAnnot[0].setGenNumTreeAnnotationsEnabled(False) + break + self.setAllTextAnnotations() + + def setAnnotOptionsCcaMode(self): + self.prevAnnotOptions = self.storeCurrentAnnotOptions_ax1(return_value=True) + self.annotCcaInfoCheckbox.setChecked(True) + self.annotIDsCheckbox.setChecked(False) + self.drawMothBudLinesCheckbox.setChecked(False) + self.setDrawAnnotComboboxText() + + def setAnnotOptionsLin_treeMode(self): + # self.prevAnnotOptions = self.storeCurrentAnnotOptions_ax1( + # return_value=True + # ) + self.annotCcaInfoCheckbox.setChecked(True) + self.annotIDsCheckbox.setChecked(False) + self.drawMothBudLinesCheckbox.setChecked(False) + self.setDrawAnnotComboboxText() + self.showTreeInfoCheckbox.setChecked(True) + + def setDisabledAnnotCheckBoxesLeft(self, disabled): + self.annotIDsCheckbox.setDisabled(disabled) + self.annotCcaInfoCheckbox.setDisabled(disabled) + self.annotContourCheckbox.setDisabled(disabled) + self.annotSegmMasksCheckbox.setDisabled(disabled) + self.drawMothBudLinesCheckbox.setDisabled(disabled) + self.annotNumZslicesCheckbox.setDisabled(disabled) + self.drawNothingCheckbox.setDisabled(disabled) + + def setDisabledAnnotCheckBoxesRight(self, disabled): + self.annotIDsCheckboxRight.setDisabled(disabled) + self.annotCcaInfoCheckboxRight.setDisabled(disabled) + self.annotContourCheckboxRight.setDisabled(disabled) + self.annotSegmMasksCheckboxRight.setDisabled(disabled) + self.drawMothBudLinesCheckboxRight.setDisabled(disabled) + self.annotNumZslicesCheckboxRight.setDisabled(disabled) + self.drawNothingCheckboxRight.setDisabled(disabled) + + def setDisabledAnnotOptions(self, disabled): + # Left + self.annotIDsCheckbox.setDisabled(disabled) + self.annotCcaInfoCheckbox.setDisabled(disabled) + self.annotContourCheckbox.setDisabled(disabled) + # self.annotSegmMasksCheckbox.setDisabled(disabled) + self.drawMothBudLinesCheckbox.setDisabled(disabled) + # self.drawNothingCheckbox.setDisabled(disabled) + + # Right + self.annotIDsCheckboxRight.setDisabled(disabled) + self.annotCcaInfoCheckboxRight.setDisabled(disabled) + self.annotContourCheckboxRight.setDisabled(disabled) + # self.annotSegmMasksCheckboxRight.setDisabled(disabled) + self.drawMothBudLinesCheckboxRight.setDisabled(disabled) + + def setDrawAnnotComboboxText(self, saveSettings=True): + if self.annotIDsCheckbox.isChecked(): + if self.annotContourCheckbox.isChecked(): + t = "Draw IDs and contours" + elif self.annotSegmMasksCheckbox.isChecked(): + t = "Draw IDs and overlay segm. masks" + else: + t = "Draw only IDs" + + elif self.annotCcaInfoCheckbox.isChecked(): + if self.annotContourCheckbox.isChecked(): + t = "Draw cell cycle info and contours" + elif self.annotSegmMasksCheckbox.isChecked(): + t = "Draw cell cycle info and overlay segm. masks" + else: + t = "Draw only cell cycle info" + + elif self.annotSegmMasksCheckbox.isChecked(): + t = "Draw only overlay segm. masks" + + elif self.annotContourCheckbox.isChecked(): + t = "Draw only contours" + + elif self.drawMothBudLinesCheckbox.isChecked(): + t = "Draw only mother-bud lines" + + elif self.drawNothingCheckbox.isChecked(): + t = "Draw nothing" + else: + t = "Draw nothing" + + if t == self.drawIDsContComboBox.currentText(): + self.drawIDsContComboBox_cb(0) + + self.drawIDsContComboBox.saveSettings = saveSettings + self.drawIDsContComboBox.setCurrentText(t) + + def setDrawAnnotComboboxTextRight(self, saveSettings=True): + if self.annotIDsCheckboxRight.isChecked(): + if self.annotContourCheckboxRight.isChecked(): + t = "Draw IDs and contours" + elif self.annotSegmMasksCheckboxRight.isChecked(): + t = "Draw IDs and overlay segm. masks" + else: + t = "Draw only IDs" + + elif self.annotCcaInfoCheckboxRight.isChecked(): + if self.annotContourCheckboxRight.isChecked(): + t = "Draw cell cycle info and contours" + elif self.annotSegmMasksCheckboxRight.isChecked(): + t = "Draw cell cycle info and overlay segm. masks" + else: + t = "Draw only cell cycle info" + + elif self.annotSegmMasksCheckboxRight.isChecked(): + t = "Draw only overlay segm. masks" + + elif self.annotContourCheckboxRight.isChecked(): + t = "Draw only contours" + + elif self.drawMothBudLinesCheckboxRight.isChecked(): + t = "Draw only mother-bud lines" + + elif self.drawNothingCheckboxRight.isChecked(): + t = "Draw nothing" + else: + t = "Draw nothing" + + if t == self.annotateRightHowCombobox.currentText(): + self.annotateRightHowCombobox_cb(0) + + self.annotateRightHowCombobox.saveSettings = saveSettings + self.annotateRightHowCombobox.setCurrentText(t) + + def setDrawNothingAnnotations(self): + self.storeCurrentAnnotOptions_ax1() + self.storeCurrentAnnotOptions_ax2() + self.drawNothingCheckbox.setChecked(True) + self.annotOptionClicked(sender=self.drawNothingCheckbox, saveSettings=False) + self.drawNothingCheckboxRight.setChecked(True) + self.annotOptionClickedRight( + sender=self.drawNothingCheckboxRight, saveSettings=False + ) + + def setEnabledAnnotCheckBoxesLeftZdepthAxes(self): + if not self.isSegm3D: + return + + self.annotIDsCheckbox.setDisabled(False) + self.annotContourCheckbox.setDisabled(False) + self.annotIDsCheckbox.setChecked(True) + self.annotContourCheckbox.setChecked(True) + + self.annotOptionClicked(sender=self.annotIDsCheckbox, saveSettings=False) + + def setVisible3DsegmWidgets(self): + self.annotNumZslicesCheckbox.setVisible(self.isSegm3D) + self.annotNumZslicesCheckboxRight.setVisible(self.isSegm3D) + if not self.isSegm3D: + self.annotNumZslicesCheckbox.setChecked(False) + self.annotNumZslicesCheckboxRight.setChecked(False) + + def showHighlightZneighCheckbox(self): + if self.isSegm3D: + layout = self.bottomLeftLayout + # layout.addWidget(self.annotOptionsWidget, 0, 1, 1, 2) + # # layout.removeWidget(self.drawIDsContComboBox) + # # layout.addWidget(self.drawIDsContComboBox, 0, 1, 1, 1, + # # alignment=Qt.AlignCenter + # # ) + # layout.addWidget(self.highlightZneighObjCheckbox, 0, 2, 1, 2, + # alignment=Qt.AlignRight + # ) + self.highlightZneighObjCheckbox.show() + self.highlightZneighObjCheckbox.setChecked(True) + self.highlightZneighObjCheckbox.toggled.connect( + self.highlightZneighLabels_cb + ) + + def storeCurrentAnnotOptions_ax1(self, return_value=False): + if self.annotOptionsToRestore is not None: + return + + checkboxes = [ + "annotIDsCheckbox", + "annotCcaInfoCheckbox", + "annotContourCheckbox", + "annotSegmMasksCheckbox", + "drawMothBudLinesCheckbox", + "annotNumZslicesCheckbox", + "drawNothingCheckbox", + ] + annotOptions = {} + for checkboxName in checkboxes: + checkbox = getattr(self, checkboxName) + annotOptions[checkboxName] = checkbox.isChecked() + if return_value: + return annotOptions + self.annotOptionsToRestore = annotOptions + + def storeCurrentAnnotOptions_ax2(self): + if self.annotOptionsToRestoreRight is not None: + return + + checkboxes = [ + "annotIDsCheckboxRight", + "annotCcaInfoCheckboxRight", + "annotContourCheckboxRight", + "annotSegmMasksCheckboxRight", + "drawMothBudLinesCheckboxRight", + "annotNumZslicesCheckboxRight", + "drawNothingCheckboxRight", + ] + self.annotOptionsToRestoreRight = {} + for checkboxName in checkboxes: + checkbox = getattr(self, checkboxName) + self.annotOptionsToRestoreRight[checkboxName] = checkbox.isChecked() + + def uncheckAnnotOptions(self, left=True, right=True): + # Left + if left: + self.annotIDsCheckbox.setChecked(False) + self.annotCcaInfoCheckbox.setChecked(False) + self.annotContourCheckbox.setChecked(False) + self.annotSegmMasksCheckbox.setChecked(False) + self.drawMothBudLinesCheckbox.setChecked(False) + self.drawNothingCheckbox.setChecked(False) + + # Right + if right: + self.annotIDsCheckboxRight.setChecked(False) + self.annotCcaInfoCheckboxRight.setChecked(False) + self.annotContourCheckboxRight.setChecked(False) + self.annotSegmMasksCheckboxRight.setChecked(False) + self.drawMothBudLinesCheckboxRight.setChecked(False) + self.drawNothingCheckboxRight.setChecked(False) + + def updateAnnotatedIDs(self, oldIDs, newIDs, logger=print): + logger("Updating annotated IDs...") + posData = self.data[self.pos_i] + + mapper = dict(zip(oldIDs, newIDs)) + posData.ripIDs = set([mapper[ripID] for ripID in posData.ripIDs]) + posData.binnedIDs = set([mapper[binID] for binID in posData.binnedIDs]) + self.keptObjectsIDs = widgets.KeptObjectIDsList( + self.keptIDsLineEdit, self.keepIDsConfirmAction + ) + + customAnnotButtons = list(self.customAnnotDict.keys()) + for button in customAnnotButtons: + customAnnotValues = self.customAnnotDict[button] + annotatedIDs = customAnnotValues["annotatedIDs"][self.pos_i] + mappedAnnotIDs = {} + for frame_i, annotIDs_i in annotatedIDs.items(): + mappedIDs = [mapper[ID] for ID in annotIDs_i] + mappedAnnotIDs[frame_i] = mappedIDs + customAnnotValues["annotatedIDs"][self.pos_i] = mappedAnnotIDs + + def update_rp_metadata(self, draw=True): + posData = self.data[self.pos_i] + # Add to rp dynamic metadata (e.g. cells annotated as dead) + for i, obj in enumerate(posData.rp): + ID = obj.label + obj.excluded = ID in posData.binnedIDs + obj.dead = ID in posData.ripIDs + + def zoomRectActionToggled(self, checked): + if checked: + self.disconnectLeftClickButtons() + self.uncheckLeftClickButtons(self.sender()) + self.connectLeftClickButtons() + self.ax1.addItem(self.zoomRectItem) + else: + self.zoomRectItem.setPos((0, 0)) + self.zoomRectItem.setSize((0, 0)) + self.ax1.removeItem(self.zoomRectItem) + + def zoomRectCancelled(self): + self.isMouseDragImg1 = False + self.zoomRectItem.setPos((0, 0)) + self.zoomRectItem.setSize((0, 0)) + + def zoomRectDone(self): + xRange, yRange = self.ax1.viewRange() + self.zoomRectItem.storeLastRange(xRange, yRange) + + ymin, xmin, ymax, xmax = self.zoomRectItem.bbox() + + self.zoomRectItem.setPos((0, 0)) + self.zoomRectItem.setSize((0, 0)) + + self.ax1.setRange(xRange=(xmin, xmax), yRange=(ymin, ymax), padding=0) + + def showAllContoursToggled(self): + if not self.isDataLoaded: + return + + self.computeAllContours() + self.updateAllImages() diff --git a/cellacdc/mixins/app_shell.py b/cellacdc/mixins/app_shell.py new file mode 100644 index 000000000..73aeb6a10 --- /dev/null +++ b/cellacdc/mixins/app_shell.py @@ -0,0 +1,316 @@ +"""Qt view adapter for the application shell.""" + +from __future__ import annotations + +import os +import re +from datetime import timedelta + +from qtpy.QtGui import QIcon +from qtpy.QtWidgets import QWidget + +from cellacdc import ( + _warnings, + base_cca_dict, + cca_df_colnames, + html_utils, + settings_csv_path, + widgets, +) + +from .actions import Actions +from .session import Session + + +class AppShell(Actions, Session): + """Extracted from guiWin.""" + + def about(self): + pass + + def cleanUpOnError(self): + self.onEscape() + caller = "Cell-ACDC" + if self.module.startswith("spotmax"): + caller = "spotMAX" + txt = f"WARNING: {caller} is in error state. Please, restart." + _hl = "*" * 100 + self.titleLabel.setText(txt, color="r") + self.logger.info(f"{_hl}\n{txt}\n{_hl}") + + def copyContent(self): + pass + + def cutContent(self): + pass + + def determineSlideshowWinPos(self): + screens = self.app.screens() + self.numScreens = len(screens) + winScreen = self.screen() + + # Center main window and determine location of slideshow window + # depending on number of screens available + if self.numScreens > 1: + for screen in screens: + if screen != winScreen: + winScreen = screen + break + + winScreenGeom = winScreen.geometry() + winScreenCenter = winScreenGeom.center() + winScreenCenterX = winScreenCenter.x() + winScreenCenterY = winScreenCenter.y() + winScreenLeft = winScreenGeom.left() + winScreenTop = winScreenGeom.top() + self.slideshowWinLeft = winScreenCenterX - int(850 / 2) + self.slideshowWinTop = winScreenCenterY - int(800 / 2) + + def initGlobalAttr(self): + self.setOverlayColors() + + self.initImgCmap() + + # Colormap + self.setLut() + + self.fluoDataChNameActions = [] + + self.splineHoverON = False + self.tempSegmentON = False + self.xyOnCtrlPressedFirstTime = None + self.typingEditID = False + self.prevAnnotOptions = None + self.ghostObject = None + self.autoContourHoverON = False + self.navigateScrollBarStartedMoving = True + self.zSliceScrollBarStartedMoving = True + self.labelRoiRunning = False + self.isRangeReset = True + self.lastManualSeparateState = None + self.editIDmergeIDs = True + self.doNotAskAgainExistingID = False + self.doubleRightClickTimeElapsed = False + self.isRealTimeTrackerInitialized = False + self.isWarningCcaIntegrity = False + self.isDoubleRightClick = False + self.isExportingVideo = False + self.pointsLayersNeverToggled = True + self.highlightedIDopts = None + self.timestampStartTimedelta = timedelta(seconds=0) + self.keptObjectsIDs = widgets.KeptObjectIDsList( + self.keptIDsLineEdit, self.keepIDsConfirmAction + ) + self._ZprojWidgersEnabledState = None + self.imgValueFormatter = "d" + self.rawValueFormatter = "d" + self.lastHoverID = -1 + self.annotOptionsToRestore = None + self.annotOptionsToRestoreRight = None + self.rescaleIntensChannelHowMapper = { + self.user_ch_name: "Rescale each 2D image" + } + self.timestampDialog = None + self.scaleBarDialog = None + self.countObjsWindow = None + self.initLabelRoiModelDialog = None + + # Second channel used by cellpose + self.secondChannelName = None + + self.ax1_viewRange = None + self.measurementsWin = None + + self.model_kwargs = None + self.segmModelName = None + self.labelRoiModel = None + self.autoSegmDoNotAskAgain = False + self.labelRoiGarbageWorkers = [] + self.labelRoiActiveWorkers = [] + + self.clickedOnBud = False + self.postProcessSegmWin = None + + self.UserEnforced_DisabledTracking = False + self.UserEnforced_Tracking = False + + self.ax1BrushHoverID = 0 + + self.disabled_cca_warnings = set() + + self.last_pos_i = -1 + self.last_frame_i = -1 + + # Plots items + self.isMouseDragImg2 = False + self.isMouseDragImg1 = False + self.isMovingLabel = False + self.isRightClickDragImg1 = False + self.clickObjYc, self.clickObjXc = None, None + + self.cca_df_colnames = cca_df_colnames + self.cca_df_dtypes = [str, int, int, str, int, int, bool, bool, int] + self.cca_df_default_values = list(base_cca_dict.values()) + self.cca_df_int_cols = [ + col for col in cca_df_colnames if type(base_cca_dict[col]) == int + ] + self.lin_tree_df_bool_col = [ + col for col in cca_df_colnames if isinstance(base_cca_dict[col], bool) + ] + + self.lin_tree_col_checks = [ + "generation_num", + ] + + # self.lin_tree_df_colnames = set(base_cca_df.keys()) | set(lineage_tree_cols) + # # self.lin_tree_df_dtypes = [ #dk if i need this, for now ignored + # # str, int, int, str, int, int, bool, bool, int + # # ] + # self.lin_tree_df_default_values = list(base_cca_df.values()) + lineage_tree_cols_std_val + self.lin_tree_df_int_cols = [ + "generation_num", + "relative_ID", + "emerg_frame_i", + "division_frame_i", + "corrected_on_frame_i", + ] + self.lin_tree_df_bool_col = [ + "is_history_known", + ] + + self.lin_tree_col_checks = [ + "generation_num", + ] + + self.lin_tree_df_colnames = ( + self.lin_tree_df_int_cols + + self.lin_tree_df_bool_col + + self.lin_tree_col_checks + ) + self.SegForLostIDsSettings = {} + + def initProfileModels(self): + self.logger.info("Initiliazing profilers...") + + from ._profile.spline_to_obj import model + + self.splineToObjModel = model.Model() + + self.splineToObjModel.fit() + + def onToggleColorScheme(self): + if self.toggleColorSchemeAction.text().find("light") != -1: + self._colorScheme = "light" + setDarkModeToggleChecked = False + else: + self._colorScheme = "dark" + setDarkModeToggleChecked = True + self.gui_updateSwitchColorSchemeActionText() + _warnings.warnRestartCellACDCcolorModeToggled( + self._colorScheme, app_name=self._appName, parent=self + ) + load.rename_qrc_resources_file(self._colorScheme) + self.statusBarLabel.setText( + html_utils.paragraph( + f"Restart {self._appName} for the change to take effect", + font_color="red", + ) + ) + self.df_settings.at["colorScheme", "value"] = self._colorScheme + self.df_settings.to_csv(settings_csv_path) + + def openLogFile(self): + self.logger.info(f'Opening log file "{self.log_path}"...') + utils.showInExplorer(self.log_path) + + def openNewWindow(self): + self.logger.info("Opening a new window...") + if self.launcherSlot is not None: + self.launcherSlot() + return + + winClass = self.__class__ + win = winClass( + self.app, parent=self, mainWin=self.mainWin, version=self._version + ) + win.run() + self.newWindows.append(win) + + def pasteContent(self): + pass + + def setDisabled( + self, disabled: bool, keepDisabled: bool = None, force: bool = False + ): + if force: + if disabled: + super().setDisabled(disabled) + return + else: + self.keepDisabled = False + super().setDisabled(disabled) + return + + if keepDisabled is not None: + self.keepDisabled = keepDisabled + + if self.keepDisabled: + if disabled: + super().setDisabled(disabled) + return + else: + return + else: + super().setDisabled(disabled) + + def setTooltips(self): + tooltips = load.get_tooltips_from_docs() + + for key, tooltip in tooltips.items(): + setShortcut = getattr(self, key).shortcut().toString() + if "Shortcut: " in tooltip: + tooltip = tooltip.replace("Shortcut: ", "\nShortcut: ") + elif setShortcut != "": + tooltip = re.sub( + r"Shortcut: \"(.*)\"", f'Shortcut: "{setShortcut}"', tooltip + ) + else: + tooltip = re.sub( + r"Shortcut: \"(.*)\"", f'Shortcut: "No shortcut"', tooltip + ) + + getattr(self, key).setToolTip(tooltip) + getattr(self, key)._tooltip = tooltip + + def setWindowIcon(self, icon=None): + if icon is None: + icon = QIcon(":icon.ico") + super().setWindowIcon(icon) + + def setWindowTitle(self, title=None): + if title is None: + title = f"Cell-ACDC v{self._acdc_version} - GUI" + super().setWindowTitle(title) + + def showAbout(self): + from cellacdc.help import about + + self.aboutWin = about.QDialogAbout(parent=self) + self.aboutWin.show() + + def showInExplorer_cb(self): + posData = self.data[self.pos_i] + path = posData.images_path + utils.showInExplorer(path) + + def showLogFiles(self): + log_files_path = os.path.dirname(self.log_path) + self.logger.info(f'Opening log files folder "{log_files_path}"...') + utils.showInExplorer(log_files_path) + + def showTipsAndTricks(self): + from cellacdc.help import welcome + + self.welcomeWin = welcome.welcomeWin() + self.welcomeWin.showAndSetSize() + self.welcomeWin.showPage(self.welcomeWin.quickStartItem) diff --git a/cellacdc/mixins/brush_tools.py b/cellacdc/mixins/brush_tools.py new file mode 100644 index 000000000..334e88018 --- /dev/null +++ b/cellacdc/mixins/brush_tools.py @@ -0,0 +1,599 @@ +"""Qt view adapter for brush and eraser tools.""" + +from __future__ import annotations + +import cv2 +import numpy as np +import skimage.measure +from qtpy.QtWidgets import QCheckBox + +from cellacdc import html_utils, settings_csv_path, widgets + +from .geometry import Geometry +from .tool_activation import ToolActivation + + +class BrushTools(Geometry, ToolActivation): + """Extracted from guiWin.""" + + def Brush_cb(self, checked): + if checked: + self.typingEditID = False + self.setDiskMask() + self.setHoverToolSymbolData( + [], + [], + ( + self.ax1_EraserCircle, + self.ax2_EraserCircle, + self.ax1_EraserX, + self.ax2_EraserX, + ), + ) + self.updateBrushCursor(self.xHoverImg, self.yHoverImg) + self.setBrushID() + + self.disconnectLeftClickButtons() + self.uncheckLeftClickButtons(self.sender()) + c = self.defaultToolBarButtonColor + self.eraserButton.setStyleSheet(f"background-color: {c}") + self.connectLeftClickButtons() + self.setFocusGraphics() + else: + self.ax1_lostObjScatterItem.setVisible(True) + self.ax2_lostObjScatterItem.setVisible(True) + self.ax1_lostTrackedScatterItem.setVisible(True) + self.ax2_lostTrackedScatterItem.setVisible(True) + + self.setHoverToolSymbolData( + [], + [], + (self.ax2_BrushCircle, self.ax1_BrushCircle), + ) + self.resetCursors() + + self.showEditIDwidgets(checked) + self.enableSizeSpinbox(checked) + + def Eraser_cb(self, checked): + if checked: + self.setDiskMask() + self.setHoverToolSymbolData( + [], + [], + (self.ax2_BrushCircle, self.ax1_BrushCircle), + ) + self.updateEraserCursor(self.xHoverImg, self.yHoverImg) + self.disconnectLeftClickButtons() + self.uncheckLeftClickButtons(self.sender()) + c = self.defaultToolBarButtonColor + self.brushButton.setStyleSheet(f"background-color: {c}") + self.connectLeftClickButtons() + else: + self.setHoverToolSymbolData( + [], + [], + ( + self.ax1_EraserCircle, + self.ax2_EraserCircle, + self.ax1_EraserX, + self.ax2_EraserX, + ), + ) + self.resetCursors() + self.updateAllImages() + + self.showEditIDwidgets(checked) + self.enableSizeSpinbox(checked) + + def _setTempImageBrushContour(self): + pass + + def applyBrushMask(self, mask, ID, toLocalSlice=None): + posData = self.data[self.pos_i] + if self.isSegm3D: + zProjHow = self.zProjComboBox.currentText() + isZslice = zProjHow == "single z-slice" + if isZslice: + if toLocalSlice is not None: + toLocalSlice = (self.z_lab(), *toLocalSlice) + posData.lab[toLocalSlice][mask] = ID + else: + posData.lab[self.z_lab()][mask] = ID + else: + if toLocalSlice is not None: + for z in range(len(posData.lab)): + _slice = (z, *toLocalSlice) + posData.lab[_slice][mask] = ID + else: + posData.lab[:, mask] = ID + else: + if toLocalSlice is not None: + posData.lab[toLocalSlice][mask] = ID + else: + posData.lab[mask] = ID + + def applyEraserMask(self, mask): + posData = self.data[self.pos_i] + if self.isSegm3D: + zProjHow = self.zProjComboBox.currentText() + isZslice = zProjHow == "single z-slice" + if isZslice: + posData.lab[self.z_lab(), mask] = 0 + else: + posData.lab[:, mask] = 0 + else: + posData.lab[mask] = 0 + + def autoIDtoggled(self, checked): + self.editIDspinboxAction.setDisabled(checked) + self.editIDLabelAction.setDisabled(checked) + if not checked and self.editIDspinbox.value() == 0: + newID = self.setBrushID(return_val=True) + self.editIDspinbox.setValue(newID) + + def brushAutoFillToggled(self, checked): + val = "Yes" if checked else "No" + self.df_settings.at["brushAutoFill", "value"] = val + self.df_settings.to_csv(self.settings_csv_path) + + def brushAutoHideToggled(self, checked): + val = "Yes" if checked else "No" + self.df_settings.at["brushAutoHide", "value"] = val + self.df_settings.to_csv(self.settings_csv_path) + + def brushReleased(self): + posData = self.data[self.pos_i] + self.fillHolesID(posData.brushID, sender="brush") + + # Update data (rp, etc) + self.update_rp( + update_IDs=self.isNewID, + ) + + # Repeat tracking + if self.autoIDcheckbox.isChecked(): + self.trackManuallyAddedObject(posData.brushID, self.isNewID) + else: + self.update_rp(update_IDs=posData.brushID not in posData.IDs_idxs) + + # Update images + if self.isNewID: + editTxt = "Add new ID with brush tool" + if self.isSnapshot: + self.fixCcaDfAfterEdit(editTxt) + self.updateAllImages() + else: + self.warnEditingWithCca_df(editTxt) + else: + self.updateAllImages() + + self.isNewID = False + + def brushSize_cb(self, value): + self.ax2_EraserCircle.setSize(value * 2) + self.ax1_BrushCircle.setSize(value * 2) + self.ax2_BrushCircle.setSize(value * 2) + self.ax1_EraserCircle.setSize(value * 2) + self.ax2_EraserX.setSize(value) + self.ax1_EraserX.setSize(value) + self.setDiskMask() + + def changeBrushID(self): + """Function called when pressing or releasing shift""" + if not self.isSegm3D: + # Changing brush ID with shift is only for 3D segm + return + + if not self.brushButton.isChecked(): + # Brush if not active + return + + if not self.isMouseDragImg2 and not self.isMouseDragImg1: + # Mouse is not brushing at the moment + return + + posData = self.data[self.pos_i] + forceNewObj = not self.isNewID + + if forceNewObj: + # Shift is down --> force new object with brush + # e.g., 24 --> 28: + # 24 is hovering ID that we store as self.prevBrushID + # 24 object becomes 28 that is the new posData.brushID + self.isNewID = True + self.changedID = posData.brushID + self.restoreBrushID = posData.brushID + # Set a new ID + self.setBrushID() + else: + # Shift released or hovering on ID in z+-1 + # --> restore brush ID from before shift was pressed or from + # when we started brushing from outside an object + # but we hovered on ID in z+-1 while dragging. + # We change the entire 28 object to 24 so before changing the + # brush ID back to 24 we builg the mask with 28 to change it to 24 + self.isNewID = False + self.changedID = posData.brushID + # Restore ID + posData.brushID = self.restoreBrushID + + brushID = posData.brushID + brushIDmask = self.get_2Dlab(posData.lab) == self.changedID + self.applyBrushMask(brushIDmask, brushID) + if self.isMouseDragImg1: + self.brushColor = self.lut[posData.brushID] / 255 + self.setTempImg1Brush(True, brushIDmask, posData.brushID) + + def checkWarnDeletedIDwithEraser(self): + posData = self.data[self.pos_i] + + for ID in self.erasedIDs: + if ID == 0: + continue + if ID in posData.IDs_idxs: + continue + + self.instructHowDeleteID() + + if self.isSnapshot: + self.fixCcaDfAfterEdit("Delete ID with eraser") + self.updateAllImages() + else: + self.warnEditingWithCca_df("Delete ID with eraser") + + return True + + return False + + def clearObjFromMask(self, image, mask, toLocalSlice=None): + if mask is None: + return image + + if toLocalSlice is None: + image[mask] = 0 + else: + image[toLocalSlice][mask] = 0 + + return image + + def fillHolesID(self, ID, sender="brush"): + posData = self.data[self.pos_i] + if sender == "brush": + if not self.brushAutoFillCheckbox.isChecked(): + return False + + lab2D = self.get_2Dlab(posData.lab) + mask = lab2D == ID + filledMask = scipy.ndimage.binary_fill_holes(mask) + lab2D[filledMask] = ID + + self.set_2Dlab(lab2D) + return True + return False + + def getDiskMask(self, xdata, ydata): + Y, X = self.currentLab2D.shape[-2:] + + brushSize = self.brushSizeSpinbox.value() + yBottom, xLeft = ydata - brushSize, xdata - brushSize + yTop, xRight = ydata + brushSize + 1, xdata + brushSize + 1 + + if xLeft < 0: + if yBottom < 0: + # Disk mask out of bounds top-left + diskMask = self.diskMask.copy() + diskMask = diskMask[-yBottom:, -xLeft:] + yBottom = 0 + elif yTop > Y: + # Disk mask out of bounds bottom-left + diskMask = self.diskMask.copy() + diskMask = diskMask[0 : Y - yBottom, -xLeft:] + yTop = Y + else: + # Disk mask out of bounds on the left + diskMask = self.diskMask.copy() + diskMask = diskMask[:, -xLeft:] + xLeft = 0 + + elif xRight > X: + if yBottom < 0: + # Disk mask out of bounds top-right + diskMask = self.diskMask.copy() + diskMask = diskMask[-yBottom:, 0 : X - xLeft] + yBottom = 0 + elif yTop > Y: + # Disk mask out of bounds bottom-right + diskMask = self.diskMask.copy() + diskMask = diskMask[0 : Y - yBottom, 0 : X - xLeft] + yTop = Y + else: + # Disk mask out of bounds on the right + diskMask = self.diskMask.copy() + diskMask = diskMask[:, 0 : X - xLeft] + xRight = X + + elif yBottom < 0: + # Disk mask out of bounds on top + diskMask = self.diskMask.copy() + diskMask = diskMask[-yBottom:] + yBottom = 0 + + elif yTop > Y: + # Disk mask out of bounds on bottom + diskMask = self.diskMask.copy() + diskMask = diskMask[0 : Y - yBottom] + yTop = Y + + else: + # Disk mask fully inside the image + diskMask = self.diskMask + + return yBottom, xLeft, yTop, xRight, diskMask + + def getLabelsLayerImage(self, ax=0): + if ax == 0: + return self.labelsLayerImg1.image + else: + return self.labelsLayerRightImg.image + + def getMagicWandFloodTolerance(self): + tol_perc = self.wandControlsToolbar.toleranceSpinbox.value() + if tol_perc == 0: + return + + posData = self.data[self.pos_i] + _min, _max = posData.img_data_min_max + tol_fraction = tol_perc / 100 + tol = (_max - _min) * tol_fraction + + return tol + + def initFloodMaskImage(self): + posData = self.data[self.pos_i] + self.flood_img = posData.img_data[posData.frame_i] + if not self.isSegm3D and posData.SizeZ > 1: + self.flood_img = self.get_2Dimg_from_3D(self.flood_img) + return + + def initTempLayerBrush(self, ID, ax=0): + if ax == 0: + how = self.drawIDsContComboBox.currentText() + else: + how = self.getAnnotateHowRightImage() + + self.hideItemsHoverBrush(ID=ID, force=True) + Y, X = self.img1.image.shape[:2] + tempImage = np.zeros((Y, X), dtype=np.uint32) + if how.find("contours") != -1: + tempImage[self.currentLab2D == ID] = ID + self.brushImage = tempImage.copy() + self.brushContourImage = np.zeros((Y, X, 4), dtype=np.uint8) + color = self.imgGrad.contoursColorButton.color() + self.brushContoursRgba = color.getRgb() + opacity = 1.0 + else: + opacity = self.imgGrad.labelsAlphaSlider.value() + color = self.lut[ID] + lut = np.zeros((2, 4), dtype=np.uint8) + lut[1, -1] = 255 + lut[1, :-1] = color + self.tempLayerImg1.setLookupTable(lut) + self.tempLayerImg1.setOpacity(opacity) + self.tempLayerImg1.setImage(tempImage, force_set_linked=True) + + def instructHowDeleteID(self): + if "showInfoDeleteObject" not in self.df_settings.index: + self.df_settings.at["showInfoDeleteObject", "value"] = "Yes" + + showInfoDeleteObject = ( + self.df_settings.at["showInfoDeleteObject", "value"] == "Yes" + ) + if not showInfoDeleteObject: + return + + actionText = self.middleClickText() + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph( + "You have deleted an object using the eraser tool.

    " + 'Did you know that you can use the "Delete object" action
    ' + "to delete an object with a single click?

    " + f"To do so, use the following action: {actionText}

    " + "Note: You can also set a custom shortcut by going to the menu
    " + "Settings --> Customise keyboard shortcuts...." + ) + doNotShowAgainCheckbox = QCheckBox("Do not show again") + msg.information( + self, + "Delete objects with single click", + txt, + widgets=doNotShowAgainCheckbox, + ) + + showInfoDeleteObjectValue = ( + "No" if doNotShowAgainCheckbox.isChecked() else "Yes" + ) + self.df_settings.at["showInfoDeleteObject", "value"] = showInfoDeleteObjectValue + self.df_settings.to_csv(settings_csv_path) + + def resetCursors(self): + self.ax1_cursor.setData([], []) + self.ax2_cursor.setData([], []) + while self.app.overrideCursor() is not None: + self.app.restoreOverrideCursor() + + def setBrushID(self, useCurrentLab=True, return_val=False): + # Make sure that the brushed ID is always a new one based on + # already visited frames + posData = self.data[self.pos_i] + wl_init = posData.whitelist and posData.whitelist.whitelistIDs + if useCurrentLab: + IDs_tot = set(posData.IDs) + if wl_init: + try: + IDs_tot.update(posData.whitelist.originalLabsIDs[posData.frame_i]) + except: + pass + try: + if posData.whitelist.whitelistIDs[posData.frame_i]: + IDs_tot.update(posData.whitelist.whitelistIDs[posData.frame_i]) + except: + pass + newID = max(IDs_tot, default=0) + else: + newID = 0 + for frame_i, storedData in enumerate(posData.allData_li): + if frame_i == posData.frame_i: + continue + lab = storedData["labels"] + if lab is not None: + rp = storedData["regionprops"] + IDs_tot = {obj.label for obj in rp} + if wl_init: + if self.whitelistCheckOriginalLabels( + warning=False, frame_i=frame_i + ): + IDs_tot.update(posData.whitelist.originalLabsIDs[frame_i]) + if posData.whitelist.whitelistIDs[frame_i]: + IDs_tot.update(posData.whitelist.whitelistIDs[frame_i]) + _max = max(IDs_tot, default=0) + if _max > newID: + newID = _max + else: + break + + for y, x, manual_ID in posData.editID_info: + if manual_ID > newID: + newID = manual_ID + posData.brushID = newID + 1 + if return_val: + return posData.brushID + + def setDiskMask(self): + brushSize = self.brushSizeSpinbox.value() + # diam = brushSize*2 + # center = (brushSize, brushSize) + # diskShape = (diam+1, diam+1) + # diskMask = np.zeros(diskShape, bool) + # rr, cc = skimage.draw.disk(center, brushSize+1, shape=diskShape) + # diskMask[rr, cc] = True + self.diskMask = skimage.morphology.disk(brushSize, dtype=bool) + + def setTempBrushMaskFromWand(self, flood_mask, init=False): + if not np.any(flood_mask): + return + + posData = self.data[self.pos_i] + mask = np.logical_or(flood_mask, posData.lab == posData.brushID) + if mask.ndim == 3: + z_slice = self.zSliceScrollBar.sliderPosition() + mask = mask[z_slice] + + self.setTempImg1Brush(init, mask, posData.brushID) + + def setTempImg1Brush(self, init: bool, mask, ID, toLocalSlice=None, ax=0): + if init: + self.initTempLayerBrush(ID, ax=ax) + + if self.annotContourCheckbox.isChecked(): + brushImage = self.brushImage + else: + brushImage = self.tempLayerImg1.image + + if toLocalSlice is None: + brushImage[mask] = ID + else: + brushImage[toLocalSlice][mask] = ID + + if self.annotContourCheckbox.isChecked(): + try: + obj = skimage.measure.regionprops(brushImage)[0] + except IndexError: + return + objContour = [self.getObjContours(obj)] + # objContour = core.get_obj_contours( + # obj_image=(brushImage>0).astype(np.uint8), local=True + # ) + self.brushContourImage[:] = 0 + img = self.brushContourImage + color = self.brushContoursRgba + cv2.drawContours(img, objContour, -1, color, 1) + self.tempLayerImg1.setImage(img, force_set_linked=True) + else: + self.tempLayerImg1.setImage(brushImage, force_set_linked=True) + + def setTempImg1Eraser(self, mask, init=False, toLocalSlice=None, ax=0): + if init: + self.erasedLab = np.zeros_like(self.currentLab2D) + + if ax == 0: + how = self.drawIDsContComboBox.currentText() + else: + how = self.getAnnotateHowRightImage() + + if ax == 1 and not self.labelsGrad.showRightImgAction.isChecked(): + return + + if how.find("contours") != -1: + self.clearObjFromMask(self.contoursImage, mask, toLocalSlice=toLocalSlice) + erasedRp = skimage.measure.regionprops(self.erasedLab) + for obj in erasedRp: + self.addObjContourToContoursImage(obj=obj, ax=ax) + elif how.find("overlay segm. masks") != -1: + labelsImage = self.getLabelsLayerImage(ax=ax) + self.clearObjFromMask(labelsImage, mask, toLocalSlice=toLocalSlice) + if ax == 0: + self.labelsLayerImg1.setImage( + self.labelsLayerImg1.image, autoLevels=False + ) + else: + self.labelsLayerRightImg.setImage( + self.labelsLayerRightImg.image, autoLevels=False + ) + + def showEditIDwidgets(self, visible): + self.editIDLabelAction.setVisible(visible) + self.editIDspinboxAction.setVisible(visible) + self.autoIDcheckboxAction.setVisible(visible) + showToolbar = ( + visible + or self.brushSizeAction.isVisible() + or self.brushAutoFillAction.isVisible() + or self.brushAutoHideAction.isVisible() + ) + self.brushEraserToolBar.setVisible(showToolbar) + + def updateEraserCursor(self, x, y, xyLocked=None, isHoverImg1=True): + if x is None: + return + + xdata, ydata = int(x), int(y) + _img = self.currentLab2D + Y, X = _img.shape + + if not (xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y): + return + + size = self.brushSizeSpinbox.value() * 2 + self.setHoverToolSymbolData( + [x], [y], self.activeEraserCircleCursors(isHoverImg1), size=size + ) + self.setHoverToolSymbolData( + [x], [y], self.activeEraserXCursors(isHoverImg1), size=int(size / 2) + ) + + isMouseDrag = self.isMouseDragImg1 or self.isMouseDragImg2 + if isMouseDrag: + return + + if xyLocked is not None: + xdata, ydata = xyLocked + + self.setHoverToolSymbolColor( + xdata, + ydata, + self.eraserCirclePen, + self.activeEraserCircleCursors(isHoverImg1), + self.eraserButton, + hoverRGB=None, + ) diff --git a/cellacdc/mixins/canvas_context_menu.py b/cellacdc/mixins/canvas_context_menu.py new file mode 100644 index 000000000..14aa2cbf9 --- /dev/null +++ b/cellacdc/mixins/canvas_context_menu.py @@ -0,0 +1,129 @@ +"""View adapter for canvas context menus and deleted-ROI clicks.""" + +from __future__ import annotations + +import pyqtgraph as pg +from qtpy.QtCore import QPoint +from qtpy.QtWidgets import QAction, QMenu + +from .image_display import ImageDisplay + + +class CanvasContextMenu(ImageDisplay): + """Extracted from guiWin.""" + + def gui_clickedDelRoi(self, event, left_click, right_click): + posData = self.data[self.pos_i] + x, y = event.pos().x(), event.pos().y() + + # Check if right click on ROI + delROIs = posData.allData_li[posData.frame_i]["delROIs_info"]["rois"].copy() + for r, roi in enumerate(delROIs): + ROImask = self.getDelRoiMask(roi) + if self.isSegm3D: + clickedOnROI = ROImask[self.z_lab(), int(y), int(x)] + else: + clickedOnROI = ROImask[int(y), int(x)] + raiseContextMenuRoi = right_click and clickedOnROI + dragRoi = left_click and clickedOnROI + if raiseContextMenuRoi: + self.roi_to_del = roi + self.roiContextMenu = QMenu(self) + separator = QAction(self) + separator.setSeparator(True) + self.roiContextMenu.addAction(separator) + action = QAction("Remove ROI") + action.triggered.connect(self.removeDelROI) + self.roiContextMenu.addAction(action) + try: + # Convert QPointF to QPoint + self.roiContextMenu.exec_(event.screenPos().toPoint()) + except AttributeError: + self.roiContextMenu.exec_(event.screenPos()) + return True + elif dragRoi: + event.ignore() + return True + return False + + def checkHighlightScaleBar(self, x, y, activeToolButton): + if not hasattr(self, "scaleBar"): + return + + if not self.addScaleBarAction.isChecked(): + return + + if activeToolButton is not None: + return + + ymin, xmin, ymax, xmax = self.scaleBar.bbox() + if x < xmin: + self.scaleBar.setHighlighted(False) + return + + if x > xmax: + self.scaleBar.setHighlighted(False) + return + + if y < ymin: + self.scaleBar.setHighlighted(False) + return + + if y > ymax: + self.scaleBar.setHighlighted(False) + return + + self.scaleBar.setHighlighted(True) + + def checkHighlightTimestamp(self, x, y, activeToolButton): + if not hasattr(self, "timestamp"): + return + + if not self.addTimestampAction.isChecked(): + return + + if activeToolButton is not None: + return + + if hasattr(self, "scaleBar"): + if self.scaleBar.isHighlighted(): + return + + ymin, xmin, ymax, xmax = self.timestamp.bbox() + if x < xmin: + self.timestamp.setHighlighted(False) + return + + if x > xmax: + self.timestamp.setHighlighted(False) + return + + if y < ymin: + self.timestamp.setHighlighted(False) + return + + if y > ymax: + self.timestamp.setHighlighted(False) + return + + self.timestamp.setHighlighted(True) + + def gui_imgGradShowContextMenu(self, x, y): + if hasattr(self, "scaleBar"): + if self.scaleBar.isHighlighted(): + self.scaleBar.showContextMenu(x, y) + return + + if hasattr(self, "timestamp"): + if self.timestamp.isHighlighted(): + self.timestamp.showContextMenu(x, y) + return + + self.imgGrad.gradient.menu.popup(QPoint(int(x), int(y))) + + def gui_rightImageShowContextMenu(self, event): + try: + # Convert QPointF to QPoint + self.imgGradRight.gradient.menu.popup(event.screenPos().toPoint()) + except AttributeError: + self.imgGradRight.gradient.menu.popup(event.screenPos()) diff --git a/cellacdc/mixins/canvas_drawing.py b/cellacdc/mixins/canvas_drawing.py new file mode 100644 index 000000000..748ae3a0a --- /dev/null +++ b/cellacdc/mixins/canvas_drawing.py @@ -0,0 +1,624 @@ +"""Qt view adapter for canvas drawing interactions.""" + +from __future__ import annotations + +import numpy as np +import skimage.segmentation + +from qtpy.QtCore import Qt +from qtpy.QtGui import QGuiApplication +from qtpy.QtWidgets import QAction, QMessageBox + +from cellacdc import apps, exception_handler, html_utils, widgets + +from .canvas_selection import CanvasSelection +from .label_editing import LabelEditing + + +class CanvasDrawing(CanvasSelection, LabelEditing): + """Extracted from guiWin.""" + + def gui_addCreatedAxesItems(self): + self.ax1.addItem(self.ax1_contoursImageItem) + self.ax1.addItem(self.ax1_lostObjImageItem) + self.ax1.addItem(self.ax1_lostTrackedObjImageItem) + self.ax1.addItem(self.ax1_oldMothBudLinesItem) + self.ax1.addItem(self.ax1_newMothBudLinesItem) + self.ax1.addItem(self.ax1_lostObjScatterItem) + self.ax1.addItem(self.ax1_lostTrackedScatterItem) + self.ax1.addItem(self.ccaFailedScatterItem) + self.ax1.addItem(self.yellowContourScatterItem) + + self.ax2.addItem(self.ax2_contoursImageItem) + self.ax2.addItem(self.ax2_lostObjImageItem) + self.ax2.addItem(self.ax2_lostTrackedObjImageItem) + self.ax2.addItem(self.ax2_oldMothBudLinesItem) + self.ax2.addItem(self.ax2_newMothBudLinesItem) + self.ax2.addItem(self.ax2_lostObjScatterItem) + + self.textAnnot[0].addToPlotItem(self.ax1) + self.textAnnot[1].addToPlotItem(self.ax2) + + self.ax1.addItem(self.exportMaskImageItem) + self.ax1.exportMaskImageItem = self.exportMaskImageItem + + def gui_mouseDragEventImg1(self, event): + x, y = event.pos().x(), event.pos().y() + + if hasattr(self, "scaleBar"): + if self.scaleBarDialog is not None: + self.scaleBarDialog.locCombobox.setCurrentText("Custom") + if self.scaleBar.isHighlighted() and self.scaleBar.clicked: + self.scaleBar.setLocationProperty("custom") + self.scaleBar.move(x, y) + return + + if hasattr(self, "timestamp"): + if self.timestampDialog is not None: + self.timestampDialog.locCombobox.setCurrentText("Custom") + if self.timestamp.isHighlighted() and self.timestamp.clicked: + self.timestamp.setLocationProperty("custom") + self.timestamp.move(x, y) + return + + mode = str(self.modeComboBox.currentText()) + if mode == "Viewer": + return + + posData = self.data[self.pos_i] + Y, X = self.get_2Dlab(posData.lab).shape + xdata, ydata = int(x), int(y) + if not utils.is_in_bounds(xdata, ydata, X, Y): + return + + if self.isRightClickDragImg1 and self.curvToolButton.isChecked(): + self.drawAutoContour(y, x) + + # Brush dragging mouse --> keep brushing + elif self.isMouseDragImg1 and self.brushButton.isChecked(): + lab_2D = self.get_2Dlab(posData.lab) + + # t1 = time.perf_counter() + + ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata) + rrPoly, ccPoly = self.getPolygonBrush((y, x), Y, X) + + # t2 = time.perf_counter() + + diskSlice = (slice(ymin, ymax), slice(xmin, xmax)) + + # Build brush mask + mask = np.zeros(lab_2D.shape, bool) + mask[diskSlice][diskMask] = True + mask[rrPoly, ccPoly] = True + + modifiers = QGuiApplication.keyboardModifiers() + ctrl = modifiers == Qt.ControlModifier + + # t3 = time.perf_counter() + if not self.isPowerBrush() and not ctrl: + mask[lab_2D != 0] = False + self.setHoverToolSymbolColor( + xdata, + ydata, + self.ax2_BrushCirclePen, + (self.ax2_BrushCircle, self.ax1_BrushCircle), + self.brushButton, + brush=self.ax2_BrushCircleBrush, + ) + + # t4 = time.perf_counter() + + # Apply brush mask + self.applyBrushMask(mask, posData.brushID) + + self.setImageImg2(updateLookuptable=False) + + # t5 = time.perf_counter() + + lab2D = self.get_2Dlab(posData.lab) + brushMask = np.logical_and(lab2D[diskSlice] == posData.brushID, diskMask) + self.setTempImg1Brush( + False, brushMask, posData.brushID, toLocalSlice=diskSlice + ) + + # t6 = time.perf_counter() + + # printl( + # 'Brush exec times =\n' + # f' * {(t1-t0)*1000 = :.4f} ms\n' + # f' * {(t2-t1)*1000 = :.4f} ms\n' + # f' * {(t3-t2)*1000 = :.4f} ms\n' + # f' * {(t4-t3)*1000 = :.4f} ms\n' + # f' * {(t5-t4)*1000 = :.4f} ms\n' + # f' * {(t6-t5)*1000 = :.4f} ms\n' + # f' * {(t6-t0)*1000 = :.4f} ms' + # ) + + # Eraser dragging mouse --> keep erasing + elif self.isMouseDragImg1 and self.eraserButton.isChecked(): + posData = self.data[self.pos_i] + lab_2D = self.get_2Dlab(posData.lab) + rrPoly, ccPoly = self.getPolygonBrush((y, x), Y, X) + + ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata) + + diskSlice = (slice(ymin, ymax), slice(xmin, xmax)) + + # Build eraser mask + mask = np.zeros(lab_2D.shape, bool) + mask[ymin:ymax, xmin:xmax][diskMask] = True + mask[rrPoly, ccPoly] = True + + if self.eraseOnlyOneID: + mask[lab_2D != self.erasedID] = False + self.setHoverToolSymbolColor( + xdata, + ydata, + self.eraserCirclePen, + (self.ax2_EraserCircle, self.ax1_EraserCircle), + self.eraserButton, + hoverRGB=self.img2.lut[self.erasedID], + ID=self.erasedID, + ) + + self.erasedIDs.update(lab_2D[mask]) + self.applyEraserMask(mask) + + self.setImageImg2() + + for erasedID in self.erasedIDs: + if erasedID == 0: + continue + self.erasedLab[lab_2D == erasedID] = erasedID + self.erasedLab[mask] = 0 + + eraserMask = mask[diskSlice] + self.setTempImg1Eraser(eraserMask, toLocalSlice=diskSlice) + self.setTempImg1Eraser(eraserMask, toLocalSlice=diskSlice, ax=1) + + # Move label dragging mouse --> keep moving + elif self.isMovingLabel and self.moveLabelToolButton.isChecked(): + x, y = event.pos().x(), event.pos().y() + self.moveLabel(x, y) + + # Wand dragging mouse --> keep doing the magic + elif self.isMouseDragImg1 and self.wandToolButton.isChecked(): + tol = self.getMagicWandFloodTolerance() + if self.isSegm3D: + z_slice = self.zSliceScrollBar.sliderPosition() + seed = (z_slice, ydata, xdata) + else: + seed = (ydata, xdata) + + flood_mask = skimage.segmentation.flood(self.flood_img, seed, tolerance=tol) + drawUnderMask = np.logical_or( + posData.lab == 0, posData.lab == posData.brushID + ) + flood_mask = np.logical_and(flood_mask, drawUnderMask) + + self.flood_mask[flood_mask] = True + + if self.wandControlsToolbar.autoFillHolesCheckbox.isChecked(): + self.flood_mask = core.binary_fill_holes(self.flood_mask) + + if self.wandControlsToolbar.useConvexHullCheckbox.isChecked(): + self.flood_mask = core.convex_hull_mask(self.flood_mask) + + self.setTempBrushMaskFromWand(self.flood_mask) + + # Label ROI dragging mouse --> draw ROI + elif self.isMouseDragImg1 and self.labelRoiButton.isChecked(): + if self.labelRoiIsRectRadioButton.isChecked(): + x0, y0 = self.labelRoiItem.pos() + w, h = (xdata - x0), (ydata - y0) + self.labelRoiItem.setSize((w, h)) + elif self.labelRoiIsFreeHandRadioButton.isChecked(): + self.freeRoiItem.addPoint(xdata, ydata) + + # Draw freehand clear region --> draw region + elif self.isMouseDragImg1 and self.drawClearRegionButton.isChecked(): + self.freeRoiItem.addPoint(xdata, ydata) + + # Label ROI dragging mouse --> draw ROI + elif self.isMouseDragImg1 and self.zoomRectButton.isChecked(): + x0, y0 = self.zoomRectItem.pos() + w, h = (xdata - x0), (ydata - y0) + self.zoomRectItem.setSize((w, h)) + + def gui_mouseDragEventImg2(self, event): + posData = self.data[self.pos_i] + mode = str(self.modeComboBox.currentText()) + if mode == "Viewer": + return + + Y, X = self.get_2Dlab(posData.lab).shape + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + if not utils.is_in_bounds(xdata, ydata, X, Y): + return + + # Eraser dragging mouse --> keep erasing + if self.isMouseDragImg2 and self.eraserButton.isChecked(): + posData = self.data[self.pos_i] + lab_2D = self.get_2Dlab(posData.lab) + Y, X = lab_2D.shape + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + brushSize = self.brushSizeSpinbox.value() + rrPoly, ccPoly = self.getPolygonBrush((y, x), Y, X) + + ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata) + + # Build eraser mask + mask = np.zeros(lab_2D.shape, bool) + mask[ymin:ymax, xmin:xmax][diskMask] = True + mask[rrPoly, ccPoly] = True + + if self.eraseOnlyOneID: + mask[lab_2D != self.erasedID] = False + self.setHoverToolSymbolColor( + xdata, + ydata, + self.eraserCirclePen, + (self.ax2_EraserCircle, self.ax1_EraserCircle), + self.eraserButton, + hoverRGB=self.img2.lut[self.erasedID], + ID=self.erasedID, + ) + + self.erasedIDs.update(lab_2D[mask]) + + self.applyEraserMask(mask) + self.setImageImg2(updateLookuptable=False) + + # Brush paint dragging mouse --> keep painting + if self.isMouseDragImg2 and self.brushButton.isChecked(): + posData = self.data[self.pos_i] + lab_2D = self.get_2Dlab(posData.lab) + Y, X = lab_2D.shape + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + + ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata) + rrPoly, ccPoly = self.getPolygonBrush((y, x), Y, X) + + # Build brush mask + mask = np.zeros(lab_2D.shape, bool) + mask[ymin:ymax, xmin:xmax][diskMask] = True + mask[rrPoly, ccPoly] = True + + # If user double-pressed 'b' then draw over the labels + color = self.brushButton.palette().button().color().name() + if color != self.doublePressKeyButtonColor: + mask[lab_2D != 0] = False + self.setHoverToolSymbolColor( + xdata, + ydata, + self.ax2_BrushCirclePen, + (self.ax2_BrushCircle, self.ax1_BrushCircle), + self.eraserButton, + brush=self.ax2_BrushCircleBrush, + ) + + # Apply brush mask + self.applyBrushMask(mask, self.ax2BrushID) + + self.setImageImg2() + + # Move label dragging mouse --> keep moving + elif self.isMovingLabel and self.moveLabelToolButton.isChecked(): + x, y = event.pos().x(), event.pos().y() + self.moveLabel(x, y) + + def gui_mouseReleaseEventImg1(self, event): + modifiers = QGuiApplication.keyboardModifiers() + ctrl = modifiers == Qt.ControlModifier + alt = modifiers == Qt.AltModifier + right_click = event.button() == Qt.MouseButton.RightButton and not alt + + posData = self.data[self.pos_i] + mode = str(self.modeComboBox.currentText()) + if mode == "Viewer": + return + + Y, X = self.get_2Dlab(posData.lab).shape + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + if not utils.is_in_bounds(xdata, ydata, X, Y): + self.isMouseDragImg2 = False + self.updateAllImages() + return + + if hasattr(self, "scaleBar"): + if self.scaleBar.isHighlighted() and self.scaleBar.clicked: + self.scaleBar.clicked = False + return + + if hasattr(self, "timestamp"): + if self.timestamp.isHighlighted() and self.timestamp.clicked: + self.timestamp.clicked = False + return + + sendRightClickImg2 = ( + mode == "Segmentation and Tracking" or self.isSnapshot + ) and right_click + if sendRightClickImg2: + # Allow right-click actions on both images + self.gui_mouseReleaseEventImg2(event) + + # Right-click curvature tool mouse release + if self.isRightClickDragImg1 and self.curvToolButton.isChecked(): + self.isRightClickDragImg1 = False + try: + self.curvToolSplineToObj(isRightClick=True) + self.update_rp() + if self.autoIDcheckbox.isChecked(): + self.trackManuallyAddedObject(posData.brushID, True) + if self.isSnapshot: + self.fixCcaDfAfterEdit("Add new ID with curvature tool") + self.updateAllImages() + else: + self.warnEditingWithCca_df("Add new ID with curvature tool") + self.clearCurvItems() + self.curvTool_cb(True) + except ValueError: + self.clearCurvItems() + self.curvTool_cb(True) + pass + + # Eraser mouse release --> update IDs and contours + elif self.isMouseDragImg1 and self.eraserButton.isChecked(): + self.isMouseDragImg1 = False + + self.clearTempBrushImage() + + # Update data (rp, etc) + self.update_rp() + + doUpdateImages = self.checkWarnDeletedIDwithEraser() + + if doUpdateImages: + self.updateAllImages() + + # Brush button mouse release + elif self.isMouseDragImg1 and self.brushButton.isChecked(): + self.isMouseDragImg1 = False + + self.clearTempBrushImage() + + self.brushReleased() + + # Wand tool release, add new object + elif self.isMouseDragImg1 and self.wandToolButton.isChecked(): + self.isMouseDragImg1 = False + + self.clearTempBrushImage() + + posData = self.data[self.pos_i] + posData.lab[self.flood_mask] = posData.brushID + + # Update data (rp, etc) + self.update_rp() + + # Repeat tracking + self.trackManuallyAddedObject(posData.brushID, self.isNewID) + + if self.isSnapshot: + self.fixCcaDfAfterEdit("Add new ID with magic-wand") + self.updateAllImages() + else: + self.warnEditingWithCca_df("Add new ID with magic-wand") + + # Label ROI mouse release --> label the ROI with labelRoiWorker + elif self.isMouseDragImg1 and self.labelRoiButton.isChecked(): + self.labelRoiRunning = True + self.app.setOverrideCursor(Qt.WaitCursor) + self.isMouseDragImg1 = False + + if self.labelRoiIsFreeHandRadioButton.isChecked(): + self.freeRoiItem.closeCurve() + + proceed = self.labelRoiCheckStartStopFrame() + if not proceed: + self.labelRoiCancelled() + return + + roiImg, self.labelRoiSlice = self.getLabelRoiImage() + + if roiImg.size == 0: + self.labelRoiCancelled() + return + + if self.labelRoiModel is None: + cancel = self.initLabelRoiModel() + if cancel: + self.labelRoiCancelled() + return + + # Restore state of button because it was maybe unchecked by + # using other tools that are allowed --> see "elif" case in + # labelRoi_cb + self.labelRoiButton.blockSignals(True) + self.labelRoiButton.setChecked(True) + self.labelRoiToolbar.setVisible(True) + self.labelRoiButton.blockSignals(False) + + roiSecondChannel = None + if self.secondChannelName is not None: + secondChannelData = self.getSecondChannelData() + roiSecondChannel = secondChannelData[self.labelRoiSlice] + + isTimelapse = self.labelRoiTrangeCheckbox.isChecked() + if isTimelapse: + start_n = self.labelRoiStartFrameNoSpinbox.value() + stop_n = self.labelRoiStopFrameNoSpinbox.value() + self.progressWin = apps.QDialogWorkerProgress( + title="ROI segmentation", + parent=self, + pbarDesc=f"Segmenting frames n. {start_n} to {stop_n}...", + ) + self.progressWin.show(self.app) + self.progressWin.mainPbar.setMaximum(stop_n - start_n) + + self.app.restoreOverrideCursor() + labelRoiWorker = self.labelRoiActiveWorkers[-1] + labelRoiWorker.start( + roiImg, + posData, + roiSecondChannel=roiSecondChannel, + isTimelapse=isTimelapse, + ) + self.app.setOverrideCursor(Qt.WaitCursor) + self.logger.info( + f"Magic labeller started on image ROI = {self.labelRoiSlice}..." + ) + self.titleLabel.setText("Magic labeller is doing its magic...") + self.setDisabled(True) + + # Move label mouse released, update move + elif self.isMovingLabel and self.moveLabelToolButton.isChecked(): + self.isMovingLabel = False + + # Update data (rp, etc) + self.update_rp() + + # Repeat tracking + self.tracking(enforce=True, assign_unique_new_IDs=False) + + if not self.moveLabelToolButton.findChild(QAction).isChecked(): + self.moveLabelToolButton.setChecked(False) + else: + self.updateAllImages() + + # Assign mother to bud + elif self.assignBudMothButton.isChecked() and self.clickedOnBud: + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID == self.get_2Dlab(posData.lab)[self.yClickBud, self.xClickBud]: + return + + if ID == 0: + nearest_ID = core.nearest_nonzero_2D(self.get_2Dlab(posData.lab), y, x) + mothID_prompt = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.\n" + "Enter ID that you want to annotate as mother cell", + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + mothID_prompt.exec_() + if mothID_prompt.cancel: + return + else: + ID = mothID_prompt.EntryID + obj_idx = posData.IDs.index(ID) + y, x = posData.rp[obj_idx].centroid + xdata, ydata = int(x), int(y) + + if self.isSnapshot: + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + + relationship = posData.cca_df.at[ID, "relationship"] + ccs = posData.cca_df.at[ID, "cell_cycle_stage"] + is_history_known = posData.cca_df.at[ID, "is_history_known"] + # We allow assiging a cell in G1 as mother only on first frame + # OR if the history is unknown + if relationship == "bud" and posData.frame_i > 0 and is_history_known: + self.assignBudMothButton.setChecked(False) + txt = html_utils.paragraph( + f"You clicked on ID {ID} which is a BUD.

    " + "To assign a bud start by clicking on the bud " + "and release on a cell in G1" + ) + msg = widgets.myMessageBox() + msg.critical(self, "Released on a bud", txt) + self.assignBudMothButton.setChecked(True) + return + + elif posData.frame_i == 0: + # Check that clicked bud actually is smaller that mother + # otherwise warn the user that he might have clicked first + # on a mother + budID = self.get_2Dlab(posData.lab)[self.yClickBud, self.xClickBud] + new_mothID = self.get_2Dlab(posData.lab)[ydata, xdata] + bud_obj_idx = posData.IDs.index(budID) + new_moth_obj_idx = posData.IDs.index(new_mothID) + rp_budID = posData.rp[bud_obj_idx] + rp_new_mothID = posData.rp[new_moth_obj_idx] + if rp_budID.area >= rp_new_mothID.area: + self.assignBudMothButton.setChecked(False) + msg = widgets.myMessageBox() + txt = ( + f"You clicked FIRST on ID {budID} and then on {new_mothID}.
    " + f"For me this means that you want ID {budID} to be the " + f"BUD of ID {new_mothID}.
    " + f"However ID {budID} is bigger than {new_mothID} " + f"so maybe you should have clicked FIRST on {new_mothID}?

    " + "What do you want me to do?" + ) + txt = html_utils.paragraph(txt) + swapButton, keepButton = msg.warning( + self, + "Which one is bud?", + txt, + buttonsTexts=( + f"Assign ID {new_mothID} as the bud of ID {budID}", + f"Keep ID {budID} as the bud of ID {new_mothID}", + ), + ) + if msg.clickedButton == swapButton: + (xdata, ydata, self.xClickBud, self.yClickBud) = ( + self.xClickBud, + self.yClickBud, + xdata, + ydata, + ) + self.assignBudMothButton.setChecked(True) + + elif is_history_known and not self.clickedOnHistoryKnown: + self.assignBudMothButton.setChecked(False) + budID = self.get_2Dlab(posData.lab)[ydata, xdata] + # Allow assigning an unknown cell ONLY to another unknown cell + txt = ( + f"You started by clicking on ID {budID} which has " + "UNKNOWN history, but you then clicked/released on " + f"ID {ID} which has KNOWN history.\n\n" + "Only two cells with UNKNOWN history can be assigned as " + "relative of each other." + ) + msg = QMessageBox() + msg.critical(self, "Released on a cell with KNOWN history", txt, msg.Ok) + self.assignBudMothButton.setChecked(True) + return + + self.clickedOnHistoryKnown = is_history_known + self.xClickMoth, self.yClickMoth = xdata, ydata + + if ccs != "G1" and posData.frame_i > 0: + self.assignBudMothButton.setChecked(False) + self.onMotherNotInG1(ID) + self.assignBudMothButton.setChecked(True) + else: + self.annotateBudToDifferentMother() + + if not self.assignBudMothButton.findChild(QAction).isChecked(): + self.assignBudMothButton.setChecked(False) + + self.clickedOnBud = False + self.BudMothTempLine.setData([], []) + + # Draw clear region mouse release + elif self.isMouseDragImg1 and self.drawClearRegionButton.isChecked(): + self.isMouseDragImg1 = False + self.freeRoiItem.closeCurve() + self.clearObjsFreehandRegion() + + # Zoom rect mouse release + elif self.isMouseDragImg1 and self.zoomRectButton.isChecked(): + self.isMouseDragImg1 = False + self.zoomRectDone() diff --git a/cellacdc/mixins/canvas_events.py b/cellacdc/mixins/canvas_events.py new file mode 100644 index 000000000..53dcc3945 --- /dev/null +++ b/cellacdc/mixins/canvas_events.py @@ -0,0 +1,1047 @@ +"""Qt view adapter for canvas mouse events.""" + +from __future__ import annotations + +import numpy as np +import pyqtgraph as pg +import skimage.segmentation + +from qtpy.QtCore import Qt, QTimer +from qtpy.QtGui import QGuiApplication, QMouseEvent +from qtpy.QtWidgets import QAction, QMessageBox + +from cellacdc import apps, exception_handler + +from .canvas_context_menu import CanvasContextMenu +from .canvas_selection import CanvasSelection +from .label_editing import LabelEditing + + +class CanvasEvents(CanvasContextMenu, CanvasSelection, LabelEditing): + """Extracted from guiWin.""" + + def gui_mousePressEventImg1(self, event: QMouseEvent): + self.typingEditID = False + modifiers = QGuiApplication.keyboardModifiers() + ctrl = modifiers == Qt.ControlModifier + alt = modifiers == Qt.AltModifier + isMod = alt + posData = self.data[self.pos_i] + mode = str(self.modeComboBox.currentText()) + isCcaMode = mode == "Cell cycle analysis" + isCustomAnnotMode = mode == "Custom annotations" + left_click = event.button() == Qt.MouseButton.LeftButton and not isMod + middle_click = self.isMiddleClick(event, modifiers) + right_click = event.button() == Qt.MouseButton.RightButton + isPanImageClick = self.isPanImageClick(event, modifiers) + brushON = self.brushButton.isChecked() + curvToolON = self.curvToolButton.isChecked() + histON = self.setIsHistoryKnownButton.isChecked() + eraserON = self.eraserButton.isChecked() + rulerON = self.rulerButton.isChecked() + wandON = self.wandToolButton.isChecked() and not isPanImageClick + polyLineRoiON = self.addDelPolyLineRoiButton.isChecked() + labelRoiON = self.labelRoiButton.isChecked() + keepObjON = self.keepIDsButton.isChecked() + whitelistIDsON = self.whitelistIDsButton.isChecked() + separateON = self.separateBudButton.isChecked() + addPointsByClickingButton = self.buttonAddPointsByClickingActive() + manualBackgroundON = self.manualBackgroundButton.isChecked() + magicPromptsON = self.magicPromptsToolButton.isChecked() + pointsLayerON = self.togglePointsLayerAction.isChecked() + copyContourON = ( + self.copyLostObjButton.isChecked() + and self.ax1_lostObjScatterItem.hoverLostID > 0 + ) + findNextMotherButtonON = self.findNextMotherButton.isChecked() + unknownLineageButtonON = self.unknownLineageButton.isChecked() + drawClearRegionON = self.drawClearRegionButton.isChecked() + zoomRectON = self.zoomRectButton.isChecked() + + # Check if right-click on segment of polyline roi to add segment + segments = self.gui_getHoveredSegmentsPolyLineRoi() + if len(segments) == 1 and right_click: + seg = segments[0] + seg.roi.segmentClicked(seg, event) + return + + # Check if right-click on handle of polyline roi to remove it + handles = self.gui_getHoveredHandlesPolyLineRoi() + if len(handles) == 1 and right_click: + handle = handles[0] + handle.roi.removeHandle(handle) + return + + # Check if click on ROI + isClickOnDelRoi = self.gui_clickedDelRoi(event, left_click, right_click) + if isClickOnDelRoi: + return + + dragImgLeft = ( + left_click + and not brushON + and not histON + and not curvToolON + and not eraserON + and not rulerON + and not wandON + and not polyLineRoiON + and not labelRoiON + and not middle_click + and not keepObjON + and not separateON + and not manualBackgroundON + and not drawClearRegionON + and addPointsByClickingButton is None + and not whitelistIDsON + and not zoomRectON + ) + if isPanImageClick: + dragImgLeft = True + + is_right_click_custom_ON = any( + [b.isChecked() for b in self.customAnnotDict.keys()] + ) + + canAnnotateDivision = ( + not self.assignBudMothButton.isChecked() + and not self.setIsHistoryKnownButton.isChecked() + and not self.curvToolButton.isChecked() + and not is_right_click_custom_ON + and not labelRoiON + and not separateON + ) + + # In timelapse mode division can be annotated if isCcaMode and right-click + # while in snapshot mode with Ctrl+right-click + isAnnotateDivision = (right_click and isCcaMode and canAnnotateDivision) or ( + right_click and ctrl and self.isSnapshot + ) + + isCustomAnnot = ( + (right_click or dragImgLeft) + and (isCustomAnnotMode or self.isSnapshot) + and self.customAnnotButton is not None + ) + + is_right_click_action_ON = any( + [b.isChecked() for b in self.checkableQButtonsGroup.buttons()] + ) + + isOnlyRightClick = ( + right_click + and canAnnotateDivision + and not isAnnotateDivision + and not isMod + and not is_right_click_action_ON + and not is_right_click_custom_ON + and not copyContourON + and not findNextMotherButtonON + and not unknownLineageButtonON + and not middle_click + ) + + if isOnlyRightClick: + # Start timer or check if it is a double-right-click + if self.countRightClicks == 0: + self.isDoubleRightClick = False + self.countRightClicks = 1 + self.doubleRightClickTimeElapsed = False + screenPos = event.screenPos() + self._img1_click_xy = (screenPos.x(), screenPos.y()) + QTimer.singleShot(400, self.doubleRightClickTimerCallBack) + return + elif self.countRightClicks == 1 and not self.doubleRightClickTimeElapsed: + self.isDoubleRightClick = True + self.countRightClicks = 0 + self.editIDbutton.setChecked(True) + + # Left click actions + canCurv = ( + curvToolON + and not self.assignBudMothButton.isChecked() + and not brushON + and not dragImgLeft + and not eraserON + and not polyLineRoiON + and not labelRoiON + and addPointsByClickingButton is None + and not manualBackgroundON + and not drawClearRegionON + and not magicPromptsON + and not zoomRectON + ) + canBrush = ( + brushON + and not curvToolON + and not rulerON + and not dragImgLeft + and not eraserON + and not wandON + and not labelRoiON + and not manualBackgroundON + and addPointsByClickingButton is None + and not drawClearRegionON + and not magicPromptsON + and not zoomRectON + ) + canErase = ( + eraserON + and not curvToolON + and not rulerON + and not dragImgLeft + and not brushON + and not wandON + and not polyLineRoiON + and not labelRoiON + and addPointsByClickingButton is None + and not manualBackgroundON + and not drawClearRegionON + and not magicPromptsON + and not zoomRectON + ) + canRuler = ( + rulerON + and not curvToolON + and not brushON + and not dragImgLeft + and not brushON + and not wandON + and not polyLineRoiON + and not labelRoiON + and addPointsByClickingButton is None + and not manualBackgroundON + and not drawClearRegionON + and not magicPromptsON + and not zoomRectON + ) + canWand = ( + wandON + and not curvToolON + and not brushON + and not dragImgLeft + and not brushON + and not rulerON + and not polyLineRoiON + and not labelRoiON + and addPointsByClickingButton is None + and not manualBackgroundON + and not drawClearRegionON + and not magicPromptsON + and not zoomRectON + ) + canPolyLine = ( + polyLineRoiON + and not wandON + and not curvToolON + and not brushON + and not dragImgLeft + and not brushON + and not rulerON + and not labelRoiON + and not manualBackgroundON + and addPointsByClickingButton is None + and not drawClearRegionON + and not magicPromptsON + and not zoomRectON + ) + canLabelRoi = ( + labelRoiON + and not wandON + and not curvToolON + and not brushON + and not dragImgLeft + and not brushON + and not rulerON + and not polyLineRoiON + and not keepObjON + and addPointsByClickingButton is None + and not manualBackgroundON + and not drawClearRegionON + and not whitelistIDsON + and not magicPromptsON + and not zoomRectON + ) + canKeep = ( + keepObjON + and not wandON + and not curvToolON + and not brushON + and not dragImgLeft + and not brushON + and not rulerON + and not polyLineRoiON + and not labelRoiON + and addPointsByClickingButton is None + and not manualBackgroundON + and not drawClearRegionON + and not whitelistIDsON + and not magicPromptsON + and not zoomRectON + ) + canWhitelistIDs = ( + whitelistIDsON + and not wandON + and not curvToolON + and not brushON + and not dragImgLeft + and not brushON + and not rulerON + and not polyLineRoiON + and not labelRoiON + and addPointsByClickingButton is None + and not manualBackgroundON + and not drawClearRegionON + and not keepObjON + and not magicPromptsON + and not zoomRectON + ) + canAddPoint = ( + (pointsLayerON or magicPromptsON) + and addPointsByClickingButton is not None + and not wandON + and not curvToolON + and not brushON + and not dragImgLeft + and not brushON + and not rulerON + and not polyLineRoiON + and not labelRoiON + and not keepObjON + and not manualBackgroundON + and not drawClearRegionON + and not zoomRectON + ) + canAddManualBackgroundObj = ( + manualBackgroundON + and not wandON + and not curvToolON + and not brushON + and not dragImgLeft + and not brushON + and not rulerON + and not polyLineRoiON + and not labelRoiON + and addPointsByClickingButton is None + and not keepObjON + and not drawClearRegionON + and not magicPromptsON + and not whitelistIDsON + and not zoomRectON + ) + canDrawClearRegion = ( + drawClearRegionON + and not wandON + and not curvToolON + and not brushON + and not dragImgLeft + and not brushON + and not rulerON + and not labelRoiON + and not manualBackgroundON + and addPointsByClickingButton is None + and not polyLineRoiON + and not magicPromptsON + and not whitelistIDsON + and not zoomRectON + ) + canZoomRect = ( + zoomRectON + and not curvToolON + and not brushON + and not dragImgLeft + and not brushON + and not rulerON + and not polyLineRoiON + and not labelRoiON + and addPointsByClickingButton is None + and not manualBackgroundON + and not drawClearRegionON + and not wandON + and not whitelistIDsON + and not magicPromptsON + ) + + # Enable dragging of the image window or the scalebar + if dragImgLeft and not isCustomAnnot: + x, y = event.pos().x(), event.pos().y() + if hasattr(self, "scaleBar"): + if self.scaleBar.isHighlighted(): + self.scaleBar.mousePressed(x, y) + return + if hasattr(self, "timestamp"): + if self.timestamp.isHighlighted(): + self.timestamp.mousePressed(x, y) + return + pg.ImageItem.mousePressEvent(self.img1, event) + event.ignore() + return + + isAllowedActionViewer = canAddPoint or canRuler + + if mode == "Viewer" and not isAllowedActionViewer: + self.startBlinkingModeCB() + event.ignore() + return + + # Allow right-click or middle-click actions on both images + eventOnImg2 = ( + ( + right_click or (middle_click and not canAddPoint) + # or (left_click and separateON) + ) + and (mode == "Segmentation and Tracking" or self.isSnapshot) + and not isAnnotateDivision + and not manualBackgroundON + ) + if eventOnImg2: + event.isImg1Sender = True + self.gui_mousePressEventImg2(event) + + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + Y, X = self.get_2Dlab(posData.lab).shape + if xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y: + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + else: + return + + # Paint new IDs with brush and left click on the left image + if left_click and canBrush: + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + lab_2D = self.get_2Dlab(posData.lab) + Y, X = lab_2D.shape + + # Store undo state before modifying stuff + self.storeUndoRedoStates(False, storeOnlyZoom=True) + + ID = self.getHoverID(xdata, ydata) + + if ID > 0: + posData.brushID = ID + self.isNewID = False + else: + # Update brush ID. Take care of disappearing cells to remember + # to not use their IDs anymore in the future + self.isNewID = True + self.setBrushID() + self.updateLookuptable(lenNewLut=posData.brushID + 1) + + self.brushColor = self.lut[posData.brushID] / 255 + + self.yPressAx2, self.xPressAx2 = y, x + + ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata) + diskSlice = (slice(ymin, ymax), slice(xmin, xmax)) + + self.isMouseDragImg1 = True + + # Draw new objects + localLab = lab_2D[diskSlice] + mask = diskMask.copy() + if not self.isPowerBrush() and not ctrl: + mask[localLab != 0] = False + + self.applyBrushMask(mask, posData.brushID, toLocalSlice=diskSlice) + + self.setImageImg2(updateLookuptable=False) + + how = self.drawIDsContComboBox.currentText() + lab2D = self.get_2Dlab(posData.lab) + self.globalBrushMask = np.zeros(lab2D.shape, dtype=bool) + brushMask = localLab == posData.brushID + brushMask = np.logical_and(brushMask, diskMask) + self.setTempImg1Brush( + True, brushMask, posData.brushID, toLocalSlice=diskSlice + ) + + self.lastHoverID = -1 + + elif left_click and canErase: + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + lab_2D = self.get_2Dlab(posData.lab) + Y, X = lab_2D.shape + + # Store undo state before modifying stuff + self.storeUndoRedoStates(False, storeOnlyZoom=True) + + self.yPressAx2, self.xPressAx2 = y, x + # Keep a list of erased IDs got erased + self.erasedIDs = set() + + if self.xyOnCtrlPressedFirstTime is not None: + self.erasedID = self.getHoverID(*self.xyOnCtrlPressedFirstTime) + else: + self.erasedID = self.getHoverID(xdata, ydata) + + ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata) + + # Build eraser mask + mask = np.zeros(lab_2D.shape, bool) + mask[ymin:ymax, xmin:xmax][diskMask] = True + + # If user double-pressed 'b' then erase over ALL labels + color = self.eraserButton.palette().button().color().name() + eraseOnlyOneID = ( + color != self.doublePressKeyButtonColor and self.erasedID != 0 + ) + + self.eraseOnlyOneID = eraseOnlyOneID + + if eraseOnlyOneID: + mask[lab_2D != self.erasedID] = False + + self.setTempImg1Eraser(mask, init=True) + self.applyEraserMask(mask) + + self.erasedIDs.update(lab_2D[mask]) + + for erasedID in self.erasedIDs: + if erasedID == 0: + continue + self.erasedLab[lab_2D == erasedID] = erasedID + + self.isMouseDragImg1 = True + + elif canAddPoint: + action = addPointsByClickingButton.action + self.storeUndoAddPoint(action) + x, y = event.pos().x(), event.pos().y() + hoveredPoints = action.scatterItem.pointsAt(event.pos()) + if len(hoveredPoints) > 0: + removed_ids = self.removeClickedPoints(action, hoveredPoints) + if not magicPromptsON: + removed_id = min(removed_ids) + addPointsByClickingButton.pointIdSpinbox.setValue(removed_id) + addPointsByClickingButton.pointIdSpinbox.removedId = removed_id + else: + self.restorePrevPointIdRightClick(addPointsByClickingButton) + self.drawPointsLayers(computePointsLayers=False) + else: + point_id = self.getAddedPointId( + magicPromptsON, + addPointsByClickingButton, + right_click, + left_click, + middle_click, + ) + if point_id is None: + return + + self.addClickedPoint(action, x, y, point_id) + self.drawPointsLayers(computePointsLayers=False) + + point_id = self.getClickedPointNewId( + action, + point_id, + addPointsByClickingButton.pointIdSpinbox, + isMagicPrompts=magicPromptsON, + ) + addPointsByClickingButton.pointIdSpinbox.setValue( + point_id, setLinkedWidget=False + ) + + elif left_click and canDrawClearRegion: + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + self.freeRoiItem.addPoint(xdata, ydata) + + self.isMouseDragImg1 = True + + elif left_click and canRuler or canPolyLine: + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + closePolyLine = len(self.startPointPolyLineItem.pointsAt(event.pos())) > 0 + if not self.tempSegmentON or canPolyLine: + # Keep adding anchor points for polyline + self.ax1_rulerAnchorsItem.setData([xdata], [ydata]) + self.tempSegmentON = True + else: + modifiers = QGuiApplication.keyboardModifiers() + ctrl = modifiers == Qt.ControlModifier + self.tempSegmentON = False + xxRA, yyRA = self.ax1_rulerAnchorsItem.getData() + x0, y0 = xxRA[0], yyRA[0] + if ctrl: + x1, y1 = transformation.snap_xy_to_closest_angle( + x0, y0, xdata, ydata + ) + else: + x1, y1 = xdata, ydata + lengthText = self.getRulerLengthText() + self.ax1_rulerPlotItem.setData( + [x0, x1], [y0, y1], lengthText=lengthText + ) + self.ax1_rulerAnchorsItem.setData([x0, x1], [y0, y1]) + + xxPolyLine = self.startPointPolyLineItem.getData()[0] + if canPolyLine and len(xxPolyLine) == 0: + # Create and add roi item + self.createDelPolyLineRoi() + # Add start point of polyline roi + self.startPointPolyLineItem.setData([xdata], [ydata]) + self.polyLineRoi.points.append((xdata, ydata)) + elif canPolyLine: + # Add points to polyline roi and eventually close it + if not closePolyLine: + self.polyLineRoi.points.append((xdata, ydata)) + self.addPointsPolyLineRoi(closed=closePolyLine) + if closePolyLine: + # Close polyline ROI + if len(self.polyLineRoi.getLocalHandlePositions()) == 2: + self.polyLineRoi = self.replacePolyLineRoiWithLineRoi( + self.polyLineRoi + ) + self.tempSegmentON = False + self.ax1_rulerAnchorsItem.setData([], []) + self.ax1_rulerPlotItem.setData([], []) + self.startPointPolyLineItem.setData([], []) + self.addRoiToDelRoiInfo(self.polyLineRoi) + # Call roi moving on closing ROI + self.delROImoving(self.polyLineRoi) + self.delROImovingFinished(self.polyLineRoi) + + elif left_click and canKeep: + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID == 0: + nearest_ID = core.nearest_nonzero_2D(self.get_2Dlab(posData.lab), y, x) + keepID_win = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.\n" + "Enter ID that you want to keep", + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + keepID_win.exec_() + if keepID_win.cancel: + return + else: + ID = keepID_win.EntryID + + if ID in self.keptObjectsIDs: + self.keptObjectsIDs.remove(ID) + self.clearHighlightedText() + else: + self.keptObjectsIDs.append(ID) + self.highlightLabelID(ID) + + self.updateTempLayerKeepIDs() + + elif left_click and canWhitelistIDs: + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + + if ID == 0: + nearest_ID = core.nearest_nonzero_2D(self.get_2Dlab(posData.lab), y, x) + keepID_win = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.\n" + "Enter ID that you want to select", + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + keepID_win.exec_() + if keepID_win.cancel: + return + else: + ID = keepID_win.EntryID + + posData = self.data[self.pos_i] + + if not posData.whitelist: + wl_init = False + if not hasattr(self, "tempWhitelistIDs"): + self.tempWhitelistIDs = ( + set() + ) # not updated, only use in this context + current_whitelist = self.tempWhitelistIDs + else: + current_whitelist = self.tempWhitelistIDs + else: + wl_init = True + current_whitelist = posData.whitelist.get(posData.frame_i) + + if ID in current_whitelist: + current_whitelist.remove(ID) + self.removeHighlightLabelID(IDs=[ID]) + else: + current_whitelist.add(ID) + self.highlightLabelID(ID) + + self.whitelistIDsToolbar.whitelistLineEdit.setText(current_whitelist) + + if wl_init: + posData.whitelist[posData.frame_i] = current_whitelist + else: + self.tempWhitelistIDs = current_whitelist + + self.whitelistUpdateTempLayer() + + elif right_click and copyContourON: + hoverLostID = self.ax1_lostObjScatterItem.hoverLostID + self.copyLostObjectMask(hoverLostID) + self.update_rp() + self.updateAllImages() + self.store_data() + + elif right_click and canCurv: + # Draw manually assisted auto contour + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + Y, X = self.get_2Dlab(posData.lab).shape + + self.autoCont_x0 = xdata + self.autoCont_y0 = ydata + self.xxA_autoCont, self.yyA_autoCont = [], [] + self.curvAnchors.addPoints([x], [y]) + img = self.getDisplayedImg1() + self.autoContObjMask = np.zeros(img.shape, np.uint8) + self.isRightClickDragImg1 = True + + elif left_click and canCurv: + # Draw manual spline + x, y = event.pos().x(), event.pos().y() + Y, X = self.get_2Dlab(posData.lab).shape + + # Check if user clicked on starting anchor again --> close spline + closeSpline = False + clickedAnchors = self.curvAnchors.pointsAt(event.pos()) + xxA, yyA = self.curvAnchors.getData() + if len(xxA) > 0: + if len(xxA) == 1: + self.splineHoverON = True + x0, y0 = xxA[0], yyA[0] + if len(clickedAnchors) > 0: + xA_clicked, yA_clicked = clickedAnchors[0].pos() + if x0 == xA_clicked and y0 == yA_clicked: + x = x0 + y = y0 + closeSpline = True + + # Add anchors + self.curvAnchors.addPoints([x], [y]) + try: + xx, yy = self.curvHoverPlotItem.getData() + self.curvPlotItem.setData(xx, yy) + except Exception as e: + # traceback.print_exc() + pass + + if closeSpline: + self.splineHoverON = False + self.curvToolSplineToObj() + self.update_rp() + if self.autoIDcheckbox.isChecked(): + self.trackManuallyAddedObject(posData.brushID, True) + if self.isSnapshot: + self.fixCcaDfAfterEdit("Add new ID with curvature tool") + self.updateAllImages() + else: + self.warnEditingWithCca_df("Add new ID with curvature tool") + self.clearCurvItems() + self.curvTool_cb(True) + + elif left_click and canWand: + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + Y, X = self.get_2Dlab(posData.lab).shape + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + + self.isNewID = False + posData.brushID = self.get_2Dlab(posData.lab)[ydata, xdata] + if posData.brushID == 0: + self.setBrushID() + self.updateLookuptable(lenNewLut=posData.brushID + 1) + self.isNewID = True + self.brushColor = self.img2.lut[posData.brushID] / 255 + + # NOTE: flood is on mousedrag or release + tol = self.getMagicWandFloodTolerance() + self.initFloodMaskImage() + if self.isSegm3D: + z_slice = self.zSliceScrollBar.sliderPosition() + seed = (z_slice, ydata, xdata) + else: + seed = (ydata, xdata) + + flood_mask = skimage.segmentation.flood(self.flood_img, seed, tolerance=tol) + + drawUnderMask = np.logical_or( + posData.lab == 0, posData.lab == posData.brushID + ) + self.flood_mask = np.logical_and(flood_mask, drawUnderMask) + + if self.wandControlsToolbar.autoFillHolesCheckbox.isChecked(): + self.flood_mask = core.binary_fill_holes(self.flood_mask) + + if self.wandControlsToolbar.useConvexHullCheckbox.isChecked(): + self.flood_mask = core.convex_hull_mask(self.flood_mask) + + self.setTempBrushMaskFromWand(self.flood_mask, init=True) + self.isMouseDragImg1 = True + + elif right_click and self.manualTrackingButton.isChecked(): + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + manualTrackID = self.manualTrackingToolbar.spinboxID.value() + clickedID = self.getClickedID( + xdata, ydata, text=f"that you want to assign to {manualTrackID}" + ) + if clickedID is None: + return + + if clickedID == manualTrackID: + self.manualTrackingToolbar.showWarning( + f"The clicked object already has ID = {manualTrackID}" + ) + return + + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + + posData = self.data[self.pos_i] + currentIDs = posData.IDs.copy() + if manualTrackID in currentIDs: + tempID = max(currentIDs) + 1 + posData.lab[posData.lab == clickedID] = tempID + posData.lab[posData.lab == manualTrackID] = clickedID + posData.lab[posData.lab == tempID] = manualTrackID + self.manualTrackingToolbar.showWarning( + f"The ID {manualTrackID} already exists --> " + f"ID {manualTrackID} has been swapped with {clickedID}" + ) + else: + posData.lab[posData.lab == clickedID] = manualTrackID + self.manualTrackingToolbar.showInfo( + f"ID {clickedID} changed to {manualTrackID}." + ) + + self.update_rp() + self.updateAllImages() + + elif right_click and manualBackgroundON: + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + + delID = posData.manualBackgroundLab[ydata, xdata] + if delID == 0: + return + + self.clearManualBackgroundObject(delID) + textItem = self.manualBackgroundTextItems.pop(delID) + self.ax1.removeItem(textItem) + self.setManualBackgroundImage() + + elif left_click and canAddManualBackgroundObj: + x, y = event.pos().x(), event.pos().y() + + self.addManualBackgroundObject(x, y) + self.setManualBackgroundImage() + self.setManualBackgrounNextID() + + # Label ROI mouse press + elif (left_click or right_click) and canLabelRoi: + if right_click: + # Force model initialization on mouse release + self.labelRoiModel = None + + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + + if self.labelRoiIsRectRadioButton.isChecked(): + self.labelRoiItem.setPos((xdata, ydata)) + elif self.labelRoiIsFreeHandRadioButton.isChecked(): + self.freeRoiItem.addPoint(xdata, ydata) + + self.isMouseDragImg1 = True + + # Annotate cell cycle division + elif isAnnotateDivision: + if posData.cca_df is None: + return + + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID == 0: + nearest_ID = core.nearest_nonzero_2D(self.get_2Dlab(posData.lab), y, x) + divID_prompt = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.\n" + "Enter ID that you want to annotate as divided", + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + divID_prompt.exec_() + if divID_prompt.cancel: + return + else: + ID = divID_prompt.EntryID + obj_idx = posData.IDs.index(ID) + y, x = posData.rp[obj_idx].centroid + xdata, ydata = int(x), int(y) + + if not self.isSnapshot: + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + # Annotate or undo division + self.manualCellCycleAnnotation(ID) + else: + self.undoBudMothAssignment(ID) + + # Assign bud to mother (mouse down on bud) + elif right_click and self.assignBudMothButton.isChecked(): + if self.clickedOnBud: + # NOTE: self.clickedOnBud is set to False when assigning a mother + # is successfull in mouse release event + # We still have to click on a mother + return + + if posData.cca_df is None: + return + + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID == 0: + nearest_ID = core.nearest_nonzero_2D(self.get_2Dlab(posData.lab), y, x) + budID_prompt = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.\n" + "Enter ID of a bud you want to correct mother assignment", + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + budID_prompt.exec_() + if budID_prompt.cancel: + return + else: + ID = budID_prompt.EntryID + + obj_idx = posData.IDs.index(ID) + y, x = posData.rp[obj_idx].centroid + xdata, ydata = int(x), int(y) + + relationship = posData.cca_df.at[ID, "relationship"] + is_history_known = posData.cca_df.at[ID, "is_history_known"] + self.clickedOnHistoryKnown = is_history_known + # We allow assiging a cell in G1 as bud only on first frame + # OR if the history is unknown + if relationship != "bud" and posData.frame_i > 0 and is_history_known: + txt = ( + f"You clicked on ID {ID} which is NOT a bud.\n" + "To assign a bud to a cell start by clicking on a bud " + "and release on a cell in G1" + ) + msg = QMessageBox() + msg.critical(self, "Not a bud", txt, msg.Ok) + return + + self.clickedOnBud = True + self.xClickBud, self.yClickBud = xdata, ydata + + # Annotate (or undo) that cell has unknown history + elif right_click and self.setIsHistoryKnownButton.isChecked(): + if posData.cca_df is None: + return + + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID == 0: + nearest_ID = core.nearest_nonzero_2D(self.get_2Dlab(posData.lab), y, x) + unknownID_prompt = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.\n" + "Enter ID that you want to annotate as " + '"history UNKNOWN/KNOWN"', + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + unknownID_prompt.exec_() + if unknownID_prompt.cancel: + return + else: + ID = unknownID_prompt.EntryID + obj_idx = posData.IDs.index(ID) + y, x = posData.rp[obj_idx].centroid + xdata, ydata = int(x), int(y) + + self.annotateIsHistoryKnown(ID) + if not self.setIsHistoryKnownButton.findChild(QAction).isChecked(): + self.setIsHistoryKnownButton.setChecked(False) + + elif isCustomAnnot: + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID == 0: + nearest_ID = core.nearest_nonzero_2D(self.get_2Dlab(posData.lab), y, x) + clickedBkgrDialog = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.\n" + "Enter ID that you want to annotate as divided", + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + clickedBkgrDialog.exec_() + if clickedBkgrDialog.cancel: + return + else: + ID = clickedBkgrDialog.EntryID + obj_idx = posData.IDs.index(ID) + y, x = posData.rp[obj_idx].centroid + xdata, ydata = int(x), int(y) + + button = self.doCustomAnnotation(ID) + if button is None: + return + + keepActive = self.customAnnotDict[button]["state"]["keepActive"] + if not keepActive: + button.setChecked(False) + + elif right_click and findNextMotherButtonON: + if posData.frame_i == 0: + return + + self.find_mother_action(posData, event, ydata, xdata) + + elif right_click and unknownLineageButtonON: + if posData.frame_i == 0: + return + + self.annotate_unknown_lineage_action(posData, event, ydata, xdata) + + elif (left_click or right_click) and canZoomRect: + if left_click: + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + + self.zoomRectItem.setPos((xdata, ydata)) + + self.isMouseDragImg1 = True + else: + try: + xRange, yRange = self.zoomRectItem.getLastRange() + self.ax1.setRange(xRange=xRange, yRange=yRange, padding=0) + except Exception as err: + QTimer.singleShot(100, self.autoRange) diff --git a/cellacdc/mixins/canvas_hover.py b/cellacdc/mixins/canvas_hover.py new file mode 100644 index 000000000..13679525a --- /dev/null +++ b/cellacdc/mixins/canvas_hover.py @@ -0,0 +1,605 @@ +"""Qt view adapter for canvas hover and cursor interactions.""" + +from __future__ import annotations + +import pyqtgraph as pg +from typing import Any +from qtpy.QtCore import Qt +from qtpy.QtGui import QGuiApplication + +from cellacdc import html_utils, widgets + +from .canvas_events import CanvasEvents + + +class CanvasHover(CanvasEvents): + """Extracted from guiWin.""" + + def drawTempMergeObjsLine(self, event, posData, modifiers): + if self.clickObjYc is None: + return + modifier = modifiers == Qt.ShiftModifier + x, y = event.pos() + y2, x2 = y, x + xdata, ydata = int(x), int(y) + y1, x1 = self.clickObjYc, self.clickObjXc + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID != 0: + obj_idx = posData.IDs_idxs[ID] + obj = posData.rp[obj_idx] + y2, x2 = self.getObjCentroid(obj.centroid) + + if modifier and ID > 0: + self.mergeObjsTempLine.addPoint(x2, y2) + elif not modifier: + self.mergeObjsTempLine.setData([x1, x2], [y1, y2]) + + def drawTempMothBudLine(self, event, posData): + x, y = event.pos() + y2, x2 = y, x + xdata, ydata = int(x), int(y) + y1, x1 = self.yClickBud, self.xClickBud + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID == 0: + self.BudMothTempLine.setData([x1, x2], [y1, y2]) + else: + obj_idx = posData.IDs_idxs[ID] + obj = posData.rp[obj_idx] + y2, x2 = self.getObjCentroid(obj.centroid) + self.BudMothTempLine.setData([x1, x2], [y1, y2]) + + def drawTempRulerLine(self, event): + modifiers = QGuiApplication.keyboardModifiers() + ctrl = modifiers == Qt.ControlModifier + x, y = event.pos() + x1, y1 = int(x), int(y) + xxRA, yyRA = self.ax1_rulerAnchorsItem.getData() + x0, y0 = xxRA[0], yyRA[0] + if ctrl: + x1, y1 = transformation.snap_xy_to_closest_angle(x0, y0, x1, y1) + self.ax1_rulerPlotItem.setData([x0, x1], [y0, y1]) + + def gui_add_ax_cursors(self): + try: + self.ax1.removeItem(self.ax1_cursor) + self.ax2.removeItem(self.ax2_cursor) + except Exception as e: + pass + + self.ax2_cursor = pg.ScatterPlotItem( + symbol="+", + pxMode=True, + pen=pg.mkPen("k", width=1), + brush=pg.mkBrush("w"), + size=16, + tip=None, + ) + self.ax2.addItem(self.ax2_cursor) + + self.ax1_cursor = pg.ScatterPlotItem( + symbol="+", + pxMode=True, + pen=pg.mkPen("k", width=1), + brush=pg.mkBrush("w"), + size=16, + tip=None, + ) + self.ax1.addItem(self.ax1_cursor) + + def gui_hoverEventImg1(self, event, isHoverImg1=True): + try: + posData = self.data[self.pos_i] + except AttributeError: + return + + # Update x, y, value label bottom right + if not event.isExit(): + self.xHoverImg, self.yHoverImg = event.pos() + else: + self.xHoverImg, self.yHoverImg = None, None + + if event.isExit(): + self.resetCursor() + + if not event.isExit() and self.slideshowWin is not None: + self.slideshowWin.setMirroredCursorPos(*event.pos()) + + # Alt key was released --> restore cursor + modifiers = QGuiApplication.keyboardModifiers() + cursorsInfo = self.gui_setCursor(modifiers, event) + self.highlightHoverLostObj(modifiers, event) + + drawRulerLine = ( + (self.rulerButton.isChecked() or self.addDelPolyLineRoiButton.isChecked()) + and self.tempSegmentON + and not event.isExit() + ) + if drawRulerLine: + self.drawTempRulerLine(event) + + if not event.isExit(): + x, y = event.pos() + xdata, ydata = int(x), int(y) + _img = self.img1.image + Y, X = _img.shape[:2] + if xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y: + ID = self.currentLab2D[ydata, xdata] + self.updatePropsWidget(ID, fromHover=True) + activeToolButton = self.getActiveToolButton() + hoverText = self.hoverValuesFormatted( + xdata, ydata, activeToolButton, isHoverImg1 + ) + self.checkHighlightScaleBar(x, y, activeToolButton) + self.checkHighlightTimestamp(x, y, activeToolButton) + self.wcLabel.setText(hoverText) + else: + self.clickedOnBud = False + self.BudMothTempLine.setData([], []) + self.wcLabel.setText("") + + if cursorsInfo["setKeepObjCursor"]: + x, y = event.pos() + self.highlightHoverIDsKeptObj(x, y) + + if cursorsInfo["setManualTrackingCursor"]: + x, y = event.pos() + # self.highlightHoverID(x, y) + self.drawManualTrackingGhost(x, y) + + if cursorsInfo["setManualBackgroundCursor"]: + x, y = event.pos() + # self.highlightHoverID(x, y) + self.drawManualBackgroundObj(x, y) + + if ( + not cursorsInfo["setManualTrackingCursor"] + and not cursorsInfo["setManualBackgroundCursor"] + ): + self.clearGhost() + + setMoveLabelCursor = cursorsInfo["setMoveLabelCursor"] + setExpandLabelCursor = cursorsInfo["setExpandLabelCursor"] + if setMoveLabelCursor or setExpandLabelCursor: + x, y = event.pos() + self.updateHoverLabelCursor(x, y) + + # Draw eraser circle + if cursorsInfo["setEraserCursor"]: + x, y = event.pos() + self.updateEraserCursor(x, y, isHoverImg1=isHoverImg1) + self.hideItemsHoverBrush(xy=(x, y)) + elif self.eraserButton.isChecked() and not event.isExit(): + if self.xyOnCtrlPressedFirstTime is not None: + self.updateEraserCursor( + x, + y, + xyLocked=self.xyOnCtrlPressedFirstTime, + isHoverImg1=isHoverImg1, + ) + self.hideItemsHoverBrush(xy=(x, y)) + else: + eraserCursors = ( + self.ax1_EraserCircle, + self.ax2_EraserCircle, + self.ax1_EraserX, + self.ax2_EraserX, + ) + self.setHoverToolSymbolData([], [], eraserCursors) + + # Draw Brush circle + if cursorsInfo["setBrushCursor"]: + x, y = event.pos() + self.updateBrushCursor(x, y, isHoverImg1=isHoverImg1) + self.hideItemsHoverBrush(xy=(x, y)) + elif cursorsInfo["setAddPointCursor"]: + x, y = event.pos() + self.setHoverCircleAddPoint(x, y) + else: + self.setHoverToolSymbolData( + [], + [], + (self.ax2_BrushCircle, self.ax1_BrushCircle), + ) + + # Draw label ROi circular cursor + setLabelRoiCircCursor = cursorsInfo["setLabelRoiCircCursor"] + if setLabelRoiCircCursor: + x, y = event.pos() + else: + x, y = None, None + self.updateLabelRoiCircularCursor(x, y, setLabelRoiCircCursor) + + drawMothBudLine = ( + self.assignBudMothButton.isChecked() + and self.clickedOnBud + and not event.isExit() + ) + if drawMothBudLine: + self.drawTempMothBudLine(event, posData) + + drawMergeObjsLine = self.mergeIDsButton.isChecked() and not event.isExit() + if drawMergeObjsLine: + self.drawTempMergeObjsLine(event, posData, modifiers) + + # Temporarily draw spline curve + # see https://stackoverflow.com/questions/33962717/interpolating-a-closed-curve-using-scipy + drawSpline = ( + self.curvToolButton.isChecked() + and self.splineHoverON + and not event.isExit() + ) + if drawSpline: + self.hoverEventDrawSpline(event) + + setMirroredCursor = ( + self.app.overrideCursor() is None + and not event.isExit() + and isHoverImg1 + and self.showMirroredCursorAction.isChecked() + ) + if setMirroredCursor: + x, y = event.pos() + self.ax2_cursor.setData([x], [y]) + else: + self.ax2_cursor.setData([], []) + + return cursorsInfo + + def gui_hoverEventImg2(self, event): + try: + posData = self.data[self.pos_i] + except AttributeError: + return + + if not event.isExit(): + self.xHoverImg, self.yHoverImg = event.pos() + else: + self.xHoverImg, self.yHoverImg = None, None + + # Cursor left image --> restore cursor + if event.isExit() and self.app.overrideCursor() is not None: + while self.app.overrideCursor() is not None: + self.app.restoreOverrideCursor() + + # Alt key was released --> restore cursor + modifiers = QGuiApplication.keyboardModifiers() + noModifier = modifiers == Qt.NoModifier + shift = modifiers == Qt.ShiftModifier + ctrl = modifiers == Qt.ControlModifier + if self.app.overrideCursor() == Qt.SizeAllCursor and noModifier: + self.app.restoreOverrideCursor() + + setBrushCursor = ( + self.brushButton.isChecked() + and not event.isExit() + and (noModifier or shift or ctrl) + ) + setEraserCursor = ( + self.eraserButton.isChecked() and not event.isExit() and noModifier + ) + setLabelRoiCircCursor = ( + self.labelRoiButton.isChecked() + and not event.isExit() + and (noModifier or shift or ctrl) + and self.labelRoiIsCircularRadioButton.isChecked() + ) + if setBrushCursor or setEraserCursor or setLabelRoiCircCursor: + self.app.setOverrideCursor(Qt.CrossCursor) + + setMoveLabelCursor = ( + self.moveLabelToolButton.isChecked() and not event.isExit() and noModifier + ) + + setExpandLabelCursor = ( + self.expandLabelToolButton.isChecked() and not event.isExit() and noModifier + ) + + # Cursor is moving on image while Alt key is pressed --> pan cursor + alt = QGuiApplication.keyboardModifiers() == Qt.AltModifier + setPanImageCursor = alt and not event.isExit() + if setPanImageCursor and self.app.overrideCursor() is None: + self.app.setOverrideCursor(Qt.SizeAllCursor) + + setKeepObjCursor = ( + self.keepIDsButton.isChecked() and not event.isExit() and noModifier + ) + if setKeepObjCursor and self.app.overrideCursor() is None: + self.app.setOverrideCursor(Qt.PointingHandCursor) + + # Update x, y, value label bottom right + if not event.isExit(): + x, y = event.pos() + xdata, ydata = int(x), int(y) + _img = self.currentLab2D + Y, X = _img.shape + # hoverText = self.hoverValuesFormatted(xdata, ydata) + # self.wcLabel.setText(hoverText) + else: + if self.eraserButton.isChecked() or self.brushButton.isChecked(): + self.gui_mouseReleaseEventImg2(event) + self.wcLabel.setText(f"") + + if setMoveLabelCursor or setExpandLabelCursor: + x, y = event.pos() + self.updateHoverLabelCursor(x, y) + + if setKeepObjCursor: + x, y = event.pos() + self.highlightHoverIDsKeptObj(x, y) + + # Draw eraser circle + if setEraserCursor: + x, y = event.pos() + self.updateEraserCursor(x, y, isHoverImg1=False) + else: + self.setHoverToolSymbolData( + [], + [], + ( + self.ax1_EraserCircle, + self.ax2_EraserCircle, + self.ax1_EraserX, + self.ax2_EraserX, + ), + ) + + # Draw Brush circle + if setBrushCursor: + x, y = event.pos() + self.updateBrushCursor(x, y, isHoverImg1=False) + else: + self.setHoverToolSymbolData( + [], + [], + (self.ax2_BrushCircle, self.ax1_BrushCircle), + ) + + # Draw label ROi circular cursor + if setLabelRoiCircCursor: + x, y = event.pos() + else: + x, y = None, None + self.updateLabelRoiCircularCursor(x, y, setLabelRoiCircCursor) + + def gui_hoverEventRightImage(self, event): + try: + posData = self.data[self.pos_i] + except AttributeError: + return + + if event.isExit(): + self.resetCursors() + + self.gui_hoverEventImg1(event, isHoverImg1=False) + setMirroredCursor = ( + self.app.overrideCursor() is None + and not event.isExit() + and self.showMirroredCursorAction.isChecked() + ) + if setMirroredCursor: + x, y = event.pos() + self.ax1_cursor.setData([x], [y]) + + def gui_setCursor(self, modifiers, event): + noModifier = modifiers == Qt.NoModifier + shift = modifiers == Qt.ShiftModifier + ctrl = modifiers == Qt.ControlModifier + alt = modifiers == Qt.AltModifier + + # Alt key was released --> restore cursor + if self.app.overrideCursor() == Qt.SizeAllCursor and noModifier: + self.app.restoreOverrideCursor() + + setBrushCursor = ( + self.brushButton.isChecked() + and not event.isExit() + and (noModifier or shift or ctrl) + ) + setEraserCursor = ( + self.eraserButton.isChecked() and not event.isExit() and noModifier + ) + setAddDelPolyLineCursor = ( + self.addDelPolyLineRoiButton.isChecked() + and not event.isExit() + and noModifier + ) + setLabelRoiCircCursor = ( + self.labelRoiButton.isChecked() + and not event.isExit() + and (noModifier or shift or ctrl) + and self.labelRoiIsCircularRadioButton.isChecked() + ) + setWandCursor = ( + self.wandToolButton.isChecked() and not event.isExit() and noModifier + ) + setLabelRoiCursor = ( + self.labelRoiButton.isChecked() and not event.isExit() and noModifier + ) + setMoveLabelCursor = ( + self.moveLabelToolButton.isChecked() and not event.isExit() and noModifier + ) + setExpandLabelCursor = ( + self.expandLabelToolButton.isChecked() and not event.isExit() and noModifier + ) + setCurvCursor = ( + self.curvToolButton.isChecked() and not event.isExit() and noModifier + ) + setKeepObjCursor = ( + self.keepIDsButton.isChecked() and not event.isExit() and noModifier + ) + setCustomAnnotCursor = ( + self.customAnnotButton is not None and not event.isExit() and noModifier + ) + setManualTrackingCursor = ( + self.manualTrackingButton.isChecked() and not event.isExit() and noModifier + ) + setManualBackgroundCursor = ( + self.manualBackgroundButton.isChecked() + and not event.isExit() + and noModifier + ) + setZoomRectCursor = ( + self.zoomRectButton.isChecked() and not event.isExit() and noModifier + ) + setEditIDCursor = self.editIDbutton.isChecked() and not event.isExit() + magicPromptsON = self.magicPromptsToolButton.isChecked() + pointsLayerON = self.togglePointsLayerAction.isChecked() + addPointsByClickingButton = self.buttonAddPointsByClickingActive() + setAddPointCursor = ( + (pointsLayerON or magicPromptsON) + and addPointsByClickingButton is not None + and not event.isExit() + and noModifier + ) + overrideCursor = self.app.overrideCursor() + setPanImageCursor = alt and not event.isExit() + if setPanImageCursor and overrideCursor is None: + self.app.setOverrideCursor(Qt.SizeAllCursor) + elif setBrushCursor or setEraserCursor or setLabelRoiCircCursor: + self.app.setOverrideCursor(Qt.CrossCursor) + elif setWandCursor and overrideCursor is None: + self.app.setOverrideCursor(self.wandCursor) + elif setLabelRoiCursor and overrideCursor is None: + self.app.setOverrideCursor(Qt.CrossCursor) + elif setCurvCursor and overrideCursor is None: + self.app.setOverrideCursor(self.curvCursor) + elif setCustomAnnotCursor and overrideCursor is None: + self.app.setOverrideCursor(Qt.PointingHandCursor) + elif setAddDelPolyLineCursor: + self.app.setOverrideCursor(self.polyLineRoiCursor) + elif setCustomAnnotCursor: + x, y = event.pos() + self.highlightHoverID(x, y) + elif setKeepObjCursor and overrideCursor is None: + self.app.setOverrideCursor(Qt.PointingHandCursor) + elif setManualTrackingCursor and overrideCursor is None: + self.app.setOverrideCursor(Qt.PointingHandCursor) + elif setManualBackgroundCursor and overrideCursor is None: + self.app.setOverrideCursor(Qt.PointingHandCursor) + elif setAddPointCursor: + self.app.setOverrideCursor(self.addPointsCursor) + elif setZoomRectCursor: + self.app.setOverrideCursor(Qt.CrossCursor) + elif setEditIDCursor and overrideCursor is None: + if shift: + self.app.setOverrideCursor(Qt.CrossCursor) + else: + self.app.restoreOverrideCursor() + + return { + "setBrushCursor": setBrushCursor, + "setEraserCursor": setEraserCursor, + "setAddDelPolyLineCursor": setAddDelPolyLineCursor, + "setLabelRoiCircCursor": setLabelRoiCircCursor, + "setWandCursor": setWandCursor, + "setLabelRoiCursor": setLabelRoiCursor, + "setMoveLabelCursor": setMoveLabelCursor, + "setExpandLabelCursor": setExpandLabelCursor, + "setCurvCursor": setCurvCursor, + "setKeepObjCursor": setKeepObjCursor, + "setCustomAnnotCursor": setCustomAnnotCursor, + "setManualTrackingCursor": setManualTrackingCursor, + "setManualBackgroundCursor": setManualBackgroundCursor, + "setAddPointCursor": setAddPointCursor, + "setZoomRectCursor": setZoomRectCursor, + "setEditIDCursor": setEditIDCursor, + } + + def onCtrlPressedFirstTime(self): + x, y = self.xHoverImg, self.yHoverImg + if x is None: + self.xyOnCtrlPressedFirstTime = None + return + + xdata, ydata = int(x), int(y) + Y, X = self.currentLab2D.shape + + if not (xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y): + self.xyOnCtrlPressedFirstTime = None + return + + ID = self.currentLab2D[ydata, xdata] + if ID == 0: + self.xyOnCtrlPressedFirstTime = None + return + + self.xyOnCtrlPressedFirstTime = (xdata, ydata) + + def onCtrlReleased(self): + self.xyOnCtrlPressedFirstTime = None + + def updateHoverLabelCursor(self, x, y): + if x is None: + self.hoverLabelID = 0 + return + + xdata, ydata = int(x), int(y) + Y, X = self.currentLab2D.shape + if not (xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y): + return + + ID = self.currentLab2D[ydata, xdata] + self.hoverLabelID = ID + + if ID == 0: + if self.highlightedID != 0: + self.updateAllImages() + self.highlightedID = 0 + return + + if self.app.overrideCursor() != Qt.SizeAllCursor: + self.app.setOverrideCursor(Qt.SizeAllCursor) + + if not self.isMovingLabel: + self.highlightSearchedID(ID) + + def warnAddingPointWithExistingId(self, point_id, table_endname=""): + posData = self.data[self.pos_i] + if not point_id in posData.IDs_idxs: + return True + + msg = widgets.myMessageBox(wrapText=False) + txt = f""" + Cell ID {point_id} already exists!

    + Are you sure you want to add this point? + """ + if table_endname: + txt = f""" + The loaded table {table_endname} has point id + {point_id}. +

    However, {txt} + """ + txt = html_utils.paragraph(txt) + _, _, yesButton = msg.warning( + self, + f"Cell ID {point_id} already exist", + txt, + buttonsTexts=("Cancel", "No, do not add", f"Yes, add point id {point_id}"), + ) + return msg.clickedButton == yesButton + + def gui_getHoveredSegmentsPolyLineRoi(self): + posData = self.data[self.pos_i] + delROIs_info = posData.allData_li[posData.frame_i]["delROIs_info"] + segments = [] + for roi in delROIs_info["rois"]: + if not isinstance(roi, pg.PolyLineROI): + continue + for seg in roi.segments: + if seg.currentPen == seg.hoverPen: + seg.roi = roi + segments.append(seg) + return segments + + def gui_getHoveredHandlesPolyLineRoi(self): + posData = self.data[self.pos_i] + delROIs_info = posData.allData_li[posData.frame_i]["delROIs_info"] + handles = [] + for roi in delROIs_info["rois"]: + if not isinstance(roi, pg.PolyLineROI): + continue + for handle in roi.getHandles(): + if handle.currentPen == handle.hoverPen: + handle.roi = roi + handles.append(handle) + return handles diff --git a/cellacdc/mixins/canvas_right_image.py b/cellacdc/mixins/canvas_right_image.py new file mode 100644 index 000000000..f9f5710f1 --- /dev/null +++ b/cellacdc/mixins/canvas_right_image.py @@ -0,0 +1,51 @@ +"""View adapter for duplicated right-image interactions.""" + +from __future__ import annotations + +from qtpy.QtCore import Qt +from qtpy.QtGui import QGuiApplication + +from cellacdc import exception_handler + +from .canvas_drawing import CanvasDrawing +from .canvas_events import CanvasEvents + + +class CanvasRightImage(CanvasDrawing, CanvasEvents): + """Extracted from guiWin.""" + + def getMouseDataCoordsRightImage(self): + text = self.wcLabel.text() + if not text: + return + + ax_idx = int(re.findall(r"\(ax(\d)\)", text)[0]) + if ax_idx == 0: + return + + coords = re.findall(r"x=(\d+), y=(\d+) \|", text)[0] + + return tuple([int(val) for val in coords]) + + def gui_mousePressRightImage(self, event): + modifiers = QGuiApplication.keyboardModifiers() + ctrl = modifiers == Qt.ControlModifier + alt = modifiers == Qt.AltModifier + isMod = alt + right_click = event.button() == Qt.MouseButton.RightButton and not isMod + is_right_click_action_ON = any( + [b.isChecked() for b in self.checkableQButtonsGroup.buttons()] + ) + self.typingEditID = False + showLabelsGradMenu = right_click and not is_right_click_action_ON + if showLabelsGradMenu: + self.gui_rightImageShowContextMenu(event) + event.ignore() + else: + self.gui_mousePressEventImg1(event) + + def gui_mouseDragRightImage(self, event): + self.gui_mouseDragEventImg1(event) + + def gui_mouseReleaseRightImage(self, event): + self.gui_mouseReleaseEventImg1(event) diff --git a/cellacdc/mixins/canvas_selection.py b/cellacdc/mixins/canvas_selection.py new file mode 100644 index 000000000..cafa84b15 --- /dev/null +++ b/cellacdc/mixins/canvas_selection.py @@ -0,0 +1,801 @@ +"""Qt view adapter for canvas selection interactions.""" + +from __future__ import annotations + +import time + +import pyqtgraph as pg +import scipy.ndimage +import skimage.morphology + +from qtpy.QtCore import Qt +from qtpy.QtGui import QGuiApplication +from qtpy.QtWidgets import QAction, QGraphicsSceneMouseEvent + +from cellacdc import apps, exception_handler + +from .canvas_tool import CanvasTool +from .brush_tools import BrushTools + + +class CanvasSelection(CanvasTool, BrushTools): + """Extracted from guiWin.""" + + def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent): + modifiers = QGuiApplication.keyboardModifiers() + alt = modifiers == Qt.AltModifier + shift = modifiers == Qt.ShiftModifier + shift_regardless = bool(modifiers & Qt.ShiftModifier) + isMod = alt + posData = self.data[self.pos_i] + mode = str(self.modeComboBox.currentText()) + left_click = event.button() == Qt.MouseButton.LeftButton and not alt + middle_click = self.isMiddleClick(event, modifiers) + right_click = event.button() == Qt.MouseButton.RightButton and not alt + isPanImageClick = self.isPanImageClick(event, modifiers) + eraserON = self.eraserButton.isChecked() + brushON = self.brushButton.isChecked() + separateON = self.separateBudButton.isChecked() + self.typingEditID = False + + # Drag image if neither brush or eraser are On pressed + dragImg = left_click and not eraserON and not brushON and not middle_click + if isPanImageClick: + dragImg = True + + # Enable dragging of the image window like pyqtgraph original code + if dragImg: + pg.ImageItem.mousePressEvent(self.img2, event) + event.ignore() + return + + if mode == "Viewer" and middle_click: + self.startBlinkingModeCB() + event.ignore() + return + + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + Y, X = self.get_2Dlab(posData.lab).shape + if xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y: + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + else: + return + + # Check if right click on ROI + isClickOnDelRoi = self.gui_clickedDelRoi(event, left_click, right_click) + if isClickOnDelRoi: + return + + # show gradient widget menu if none of the right-click actions are ON + # and event is not coming from image 1 + is_right_click_action_ON = any( + [b.isChecked() for b in self.checkableQButtonsGroup.buttons()] + ) + is_right_click_custom_ON = any( + [b.isChecked() for b in self.customAnnotDict.keys()] + ) + is_event_from_img1 = False + if hasattr(event, "isImg1Sender"): + is_event_from_img1 = event.isImg1Sender + + is_only_right_click = ( + right_click and not is_right_click_action_ON and not middle_click + ) + + showLabelsGradMenu = is_only_right_click and not is_event_from_img1 + + if showLabelsGradMenu: + self.labelsGrad.showMenu(event) + event.ignore() + return + + editInViewerMode = ( + (is_right_click_action_ON or is_right_click_custom_ON) + and (right_click or middle_click) + and mode == "Viewer" + ) + + if editInViewerMode: + self.startBlinkingModeCB() + event.ignore() + return + + # Left-click is used for brush, eraser, separate bud, curvature tool + # and magic labeller + # Brush and eraser are mutually exclusive but we want to keep the eraser + # or brush ON and disable them temporarily to allow left-click with + # separate ON + canDelete = mode == "Segmentation and Tracking" or self.isSnapshot + + # Delete ID (set to 0) + if middle_click and canDelete: + t0 = time.perf_counter() + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + delID = self.get_2Dlab(posData.lab)[ydata, xdata] + if delID == 0: + nearest_ID = core.nearest_nonzero_2D(self.get_2Dlab(posData.lab), y, x) + delID_prompt = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.
    " + "Enter here ID(s) that you want to delete

    " + "You can enter multiple IDs separated by comma", + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + allowList=True, + isInteger=True, + ) + delID_prompt.exec_() + if delID_prompt.cancel: + return + delIDs = delID_prompt.EntryID + else: + delIDs = [delID] + + # Ask to propagate change to all future visited frames + key = "Delete ID" + askAction = self.askHowFutureFramesActions[key] + doNotShow = not askAction.isChecked() + (UndoFutFrames, applyFutFrames, endFrame_i, doNotShowAgain) = ( + self.propagateChange( + delIDs, + key, + doNotShow, + posData.UndoFutFrames_DelID, + posData.applyFutFrames_DelID, + ) + ) + + if UndoFutFrames is None: + return + + # Store undo state before modifying stuff + self.storeUndoRedoStates(UndoFutFrames) + posData.doNotShowAgain_DelID = doNotShowAgain + posData.UndoFutFrames_DelID = UndoFutFrames + posData.applyFutFrames_DelID = applyFutFrames + includeUnvisited = posData.includeUnvisitedInfo["Delete ID"] + + delID_mask = self.deleteIDmiddleClick( + delIDs, applyFutFrames, includeUnvisited, shift=shift_regardless + ) + if delID_mask.ndim == 3: + delID_mask = delID_mask[self.z_lab()] + + if self.isSnapshot: + self.fixCcaDfAfterEdit("Delete ID") + else: + self.warnEditingWithCca_df("Delete ID", update_images=False) + + self.setImageImg2() + delROIsIDs = self.setAllTextAnnotations() + self.setAllContoursImages(delROIsIDs=delROIsIDs, compute=False) + + how = self.drawIDsContComboBox.currentText() + if how.find("overlay segm. masks") != -1: + self.labelsLayerImg1.image[delID_mask] = 0 + self.labelsLayerImg1.setImage(self.labelsLayerImg1.image) + + how_ax2 = self.getAnnotateHowRightImage() + if how_ax2.find("overlay segm. masks") != -1: + self.labelsLayerRightImg.image[delID_mask] = 0 + self.labelsLayerRightImg.setImage(self.labelsLayerRightImg.image) + + self.highlightLostNew() + + # Separate bud or objects with same ID + elif right_click and separateON: + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID == 0: + nearest_ID = core.nearest_nonzero_2D(self.get_2Dlab(posData.lab), y, x) + sepID_prompt = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.\n" + "Enter here ID that you want to split", + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + sepID_prompt.exec_() + if sepID_prompt.cancel: + return + else: + ID = sepID_prompt.EntryID + y, x = posData.rp[posData.IDs_idxs[ID]].centroid[-2:] + xdata, ydata = int(x), int(y) + + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + max_ID = max(posData.IDs, default=1) + + if self.isSegm3D and not shift: + z = self.zSliceScrollBar.sliderPosition() + posData.lab, splittedIDs = measure.separate_with_label( + posData.lab, + posData.rp, + [ID], + max_ID, + click_coords_list=[(z, ydata, xdata)], + ) + success = True + # self.set_2Dlab(lab2D) + elif not shift: + result = core.split_along_convexity_defects( + ID, self.get_2Dlab(posData.lab), max_ID + ) + lab2D, success, splittedIDs = result + self.set_2Dlab(lab2D) + else: + success = False + + # If automatic bud separation was not successfull call manual one + if not success: + posData.disableAutoActivateViewerWindow = True + img = self.getDisplayedImg1() + col = "manual_separate_draw_mode" + drawMode = self.df_settings.at[col, "value"] + manualSep = apps.manualSeparateGui( + self.get_2Dlab(posData.lab), + ID, + img, + fontSize=self.fontSize, + IDcolor=self.lut[ID], + parent=self, + drawMode=drawMode, + ) + manualSep.setState(self.lastManualSeparateState) + manualSep.show() + manualSep.centerWindow() + manualSep.show(block=True) + if manualSep.cancel: + posData.disableAutoActivateViewerWindow = False + if not self.separateBudButton.findChild(QAction).isChecked(): + self.separateBudButton.setChecked(False) + return + self.lastManualSeparateState = manualSep.state() + lab2D = self.get_2Dlab(posData.lab) + lab2D[manualSep.lab != 0] = manualSep.lab[manualSep.lab != 0] + self.set_2Dlab(lab2D) + splittedIDs = [obj.label for obj in manualSep.rp] + posData.disableAutoActivateViewerWindow = False + self.storeManualSeparateDrawMode(manualSep.drawMode) + + # Update data (rp, etc) + self.update_rp() + + # Repeat tracking + self.trackSubsetIDs(splittedIDs) + + if self.isSnapshot: + self.fixCcaDfAfterEdit("Separate IDs") + self.updateAllImages() + else: + self.warnEditingWithCca_df("Separate IDs") + + self.store_data() + + if not self.separateBudButton.findChild(QAction).isChecked(): + self.separateBudButton.setChecked(False) + + # Fill holes + elif right_click and self.fillHolesToolButton.isChecked(): + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID == 0: + nearest_ID = core.nearest_nonzero_2D(self.get_2Dlab(posData.lab), y, x) + clickedBkgrID = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.\n" + "Enter here the ID that you want to " + "fill the holes of", + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + clickedBkgrID.exec_() + if clickedBkgrID.cancel: + return + else: + ID = clickedBkgrID.EntryID + + if ID in posData.lab: + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + obj_idx = posData.IDs.index(ID) + obj = posData.rp[obj_idx] + objMask = self.getObjImage(obj.image, obj.bbox) + localFill = scipy.ndimage.binary_fill_holes(objMask) + posData.lab[self.getObjSlice(obj.slice)][localFill] = ID + + self.update_rp() + self.updateAllImages() + + if not self.fillHolesToolButton.findChild(QAction).isChecked(): + self.fillHolesToolButton.setChecked(False) + + # Hull contour + elif right_click and self.hullContToolButton.isChecked(): + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID == 0: + nearest_ID = core.nearest_nonzero_2D(self.get_2Dlab(posData.lab), y, x) + mergeID_prompt = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.\n" + "Enter here the ID that you want to " + "replace with Hull contour", + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + mergeID_prompt.exec_() + if mergeID_prompt.cancel: + return + else: + ID = mergeID_prompt.EntryID + + if ID in posData.lab: + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + obj_idx = posData.IDs.index(ID) + obj = posData.rp[obj_idx] + objMask = self.getObjImage(obj.image, obj.bbox) + localHull = skimage.morphology.convex_hull_image(objMask) + posData.lab[self.getObjSlice(obj.slice)][localHull] = ID + + self.update_rp() + self.updateAllImages() + + if not self.hullContToolButton.findChild(QAction).isChecked(): + self.hullContToolButton.setChecked(False) + + # Move label + elif right_click and self.moveLabelToolButton.isChecked(): + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + + x, y = event.pos().x(), event.pos().y() + self.startMovingLabel(x, y) + + # Fill holes + elif right_click and self.fillHolesToolButton.isChecked(): + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID == 0: + nearest_ID = core.nearest_nonzero_2D(self.get_2Dlab(posData.lab), y, x) + clickedBkgrID = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.\n" + "Enter here the ID that you want to " + "fill the holes of", + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + clickedBkgrID.exec_() + if clickedBkgrID.cancel: + return + else: + ID = clickedBkgrID.EntryID + + # Merge IDs + elif right_click and self.mergeIDsButton.isChecked(): + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID == 0: + nearest_ID = core.nearest_nonzero_2D(self.get_2Dlab(posData.lab), y, x) + mergeID_prompt = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.\n" + "Enter here first ID that you want to merge", + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + mergeID_prompt.exec_() + if mergeID_prompt.cancel: + self.mergeObjsTempLine.setData([], []) + return + else: + ID = mergeID_prompt.EntryID + + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + self.firstID = ID + + obj_idx = posData.IDs_idxs[ID] + obj = posData.rp[obj_idx] + yc, xc = self.getObjCentroid(obj.centroid) + self.clickObjYc, self.clickObjXc = int(yc), int(xc) + + # Edit ID + elif right_click and self.editIDbutton.isChecked(): + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID == 0: + nearest_ID = core.nearest_nonzero_2D(self.get_2Dlab(posData.lab), y, x) + editID_prompt = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.\n" + "Enter here ID that you want to replace with a new one", + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + editID_prompt.show(block=True) + + if editID_prompt.cancel: + return + else: + ID = editID_prompt.EntryID + + obj_idx = posData.IDs_idxs[ID] + y, x = posData.rp[obj_idx].centroid[-2:] + xdata, ydata = int(x), int(y) + + posData.disableAutoActivateViewerWindow = True + currentIDs = posData.IDs.copy() + self.setAllIDs(onlyVisited=True) + addPropagateCheckbox = ( + not self.isSnapshot + and posData.frame_i == self.navigateScrollBar.maximum() - 1 + and posData.frame_i < posData.SizeT - 1 + ) + editID = apps.EditIDDialog( + ID, + posData.IDs, + doNotShowAgain=self.doNotAskAgainExistingID, + parent=self, + entryID=self.getNearestLostObjID(y, x), + nextUniqueID=self.setBrushID(return_val=True), + allIDs=posData.allIDs, + addPropagateCheckbox=addPropagateCheckbox, + ) + editID.show(block=True) + if editID.cancel: + posData.disableAutoActivateViewerWindow = False + if not self.editIDbutton.findChild(QAction).isChecked(): + self.editIDbutton.setChecked(False) + return + + if editID.assignNewID: + self.assignNewIDfromClickedID(ID, event) + return + + if not self.doNotAskAgainExistingID: + self.editIDmergeIDs = editID.mergeWithExistingID + self.doNotAskAgainExistingID = editID.doNotAskAgainExistingID + + self.applyEditID( + ID, + currentIDs, + editID.how, + x, + y, + shift=shift, + doPropagateUnvisited=editID.doPropagateFutureFrames, + ) + + elif (right_click or left_click) and self.keepIDsButton.isChecked(): + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID == 0: + nearest_ID = core.nearest_nonzero_2D(self.get_2Dlab(posData.lab), y, x) + keepID_win = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.\n" + "Enter ID that you want to keep", + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + keepID_win.exec_() + if keepID_win.cancel: + return + else: + ID = keepID_win.EntryID + + if ID in self.keptObjectsIDs: + self.keptObjectsIDs.remove(ID) + self.clearHighlightedText() + else: + self.keptObjectsIDs.append(ID) + self.highlightLabelID(ID) + + self.updateTempLayerKeepIDs() + + # Annotate cell as removed from the analysis + elif right_click and self.binCellButton.isChecked(): + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID == 0: + nearest_ID = core.nearest_nonzero_2D(self.get_2Dlab(posData.lab), y, x) + binID_prompt = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.\n" + "Enter ID that you want to remove from the analysis", + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + binID_prompt.exec_() + if binID_prompt.cancel: + return + else: + ID = binID_prompt.EntryID + + # Ask to propagate change to all future visited frames + key = "Exclude cell from analysis" + askAction = self.askHowFutureFramesActions[key] + doNotShow = not askAction.isChecked() + (UndoFutFrames, applyFutFrames, endFrame_i, doNotShowAgain) = ( + self.propagateChange( + ID, + key, + doNotShow, + posData.UndoFutFrames_BinID, + posData.applyFutFrames_BinID, + ) + ) + + if UndoFutFrames is None: + # User cancelled the process + return + + posData.doNotShowAgain_BinID = doNotShowAgain + posData.UndoFutFrames_BinID = UndoFutFrames + posData.applyFutFrames_BinID = applyFutFrames + + self.current_frame_i = posData.frame_i + + # Apply Exclude cell from analysis to future frames if requested + if applyFutFrames: + # Store current data before going to future frames + self.store_data() + for i in range(posData.frame_i + 1, endFrame_i + 1): + posData.frame_i = i + self.get_data() + if ID in posData.binnedIDs: + posData.binnedIDs.remove(ID) + else: + posData.binnedIDs.add(ID) + self.update_rp_metadata(draw=False) + self.store_data(autosave=i == endFrame_i) + + self.app.restoreOverrideCursor() + + # Back to current frame + if applyFutFrames: + posData.frame_i = self.current_frame_i + self.get_data() + + # Store undo state before modifying stuff + self.storeUndoRedoStates(UndoFutFrames) + + if ID in posData.binnedIDs: + posData.binnedIDs.remove(ID) + else: + posData.binnedIDs.add(ID) + + self.annotate_rip_and_bin_IDs(updateLabel=True) + + # Gray out ore restore binned ID + self.updateLookuptable() + + if not self.binCellButton.findChild(QAction).isChecked(): + self.binCellButton.setChecked(False) + + # Annotate cell as dead + elif right_click and self.ripCellButton.isChecked(): + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID == 0: + nearest_ID = core.nearest_nonzero_2D(self.get_2Dlab(posData.lab), y, x) + ripID_prompt = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.\n" + "Enter ID that you want to annotate as dead", + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + ripID_prompt.exec_() + if ripID_prompt.cancel: + return + else: + ID = ripID_prompt.EntryID + + # Ask to propagate change to all future visited frames + key = "Annotate cell as dead" + askAction = self.askHowFutureFramesActions[key] + doNotShow = not askAction.isChecked() + (UndoFutFrames, applyFutFrames, endFrame_i, doNotShowAgain) = ( + self.propagateChange( + ID, + key, + doNotShow, + posData.UndoFutFrames_RipID, + posData.applyFutFrames_RipID, + ) + ) + + if UndoFutFrames is None: + return + + posData.doNotShowAgain_RipID = doNotShowAgain + posData.UndoFutFrames_RipID = UndoFutFrames + posData.applyFutFrames_RipID = applyFutFrames + + self.current_frame_i = posData.frame_i + + # Apply Edit ID to future frames if requested + if applyFutFrames: + # Store current data before going to future frames + self.store_data() + for i in range(posData.frame_i + 1, endFrame_i + 1): + posData.frame_i = i + self.get_data() + if ID in posData.ripIDs: + posData.ripIDs.remove(ID) + else: + posData.ripIDs.add(ID) + self.update_rp_metadata(draw=False) + self.store_data(autosave=i == endFrame_i) + self.app.restoreOverrideCursor() + + # Back to current frame + if applyFutFrames: + posData.frame_i = self.current_frame_i + self.get_data() + + # Store undo state before modifying stuff + self.storeUndoRedoStates(UndoFutFrames) + + if ID in posData.ripIDs: + posData.ripIDs.remove(ID) + else: + posData.ripIDs.add(ID) + + self.annotate_rip_and_bin_IDs(updateLabel=True) + + # Gray out dead ID + self.updateLookuptable() + self.store_data() + + if self.isSnapshot: + self.fixCcaDfAfterEdit("Annotate ID as dead") + self.updateAllImages() + else: + self.warnEditingWithCca_df("Annotate ID as dead") + + if not self.ripCellButton.findChild(QAction).isChecked(): + self.ripCellButton.setChecked(False) + + def gui_mouseReleaseEventImg2(self, event): + posData = self.data[self.pos_i] + mode = str(self.modeComboBox.currentText()) + if mode == "Viewer": + return + + Y, X = self.get_2Dlab(posData.lab).shape + try: + x, y = event.pos().x(), event.pos().y() + except Exception as e: + return + + xdata, ydata = int(x), int(y) + if not utils.is_in_bounds(xdata, ydata, X, Y): + self.isMouseDragImg2 = False + self.updateAllImages() + return + + # Move label mouse released, update move + if self.isMovingLabel and self.moveLabelToolButton.isChecked(): + self.isMovingLabel = False + + # Update data (rp, etc) + self.update_rp() + + # Repeat tracking + self.tracking(enforce=True, assign_unique_new_IDs=False) + + self.updateAllImages() + + if not self.moveLabelToolButton.findChild(QAction).isChecked(): + self.moveLabelToolButton.setChecked(False) + + # Merge IDs + elif self.mergeIDsButton.isChecked(): + x, y = event.pos().x(), event.pos().y() + xdata, ydata = int(x), int(y) + lab2D = self.get_2Dlab(posData.lab) + ID = lab2D[ydata, xdata] + if ID == 0: + nearest_ID = core.nearest_nonzero_2D(lab2D, y, x) + mergeID_prompt = apps.QLineEditDialog( + title="Clicked on background", + msg="You clicked on the background.\n" + "Enter ID that you want to merge with ID " + f"{self.firstID}", + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + mergeID_prompt.exec_() + if mergeID_prompt.cancel: + return + else: + ID = mergeID_prompt.EntryID + obj_idx = posData.IDs_idxs[ID] + obj = posData.rp[obj_idx] + y2, x2 = self.getObjCentroid(obj.centroid) + self.mergeObjsTempLine.addPoint(x2, y2) + + xx, yy = self.mergeObjsTempLine.getData() + IDs_to_merge = lab2D[yy.astype(int), xx.astype(int)] + for ID in IDs_to_merge: + if ID == 0: + continue + posData.lab[posData.lab == ID] = self.firstID + + self.mergeObjsTempLine.setData([], []) + self.clickObjYc, self.clickObjXc = None, None + + # Update data (rp, etc) + self.update_rp() + + ask_back_prop = True + + if posData.frame_i == 0: + ask_back_prop = False + prev_IDs = [] + else: + prev_IDs = posData.allData_li[posData.frame_i - 1]["IDs"] + + if all(ID not in prev_IDs for ID in IDs_to_merge): + ask_back_prop = False + + if not self.isFrameCcaAnnotated() and ask_back_prop: + proceed = self.askPropagateChangePast(f"Merge IDs {IDs_to_merge}") + if proceed: + self.propagateMergeObjsPast(IDs_to_merge) + self.whitelistPropagateIDs( + only_future_frames=False, update_lab=True + ) # in the update_rp() call, this should also be done + + # Repeat tracking + self.tracking( + enforce=True, assign_unique_new_IDs=False, separateByLabel=False + ) + + if self.isSnapshot: + self.fixCcaDfAfterEdit("Merge IDs") + self.updateAllImages() + else: + self.warnEditingWithCca_df("Merge IDs") + + if not self.mergeIDsButton.findChild(QAction).isChecked(): + self.mergeIDsButton.setChecked(False) + self.store_data() diff --git a/cellacdc/mixins/canvas_tool.py b/cellacdc/mixins/canvas_tool.py new file mode 100644 index 000000000..2a48c7f6e --- /dev/null +++ b/cellacdc/mixins/canvas_tool.py @@ -0,0 +1,11 @@ +"""View adapter for canvas tool interaction decisions.""" + +from __future__ import annotations + + +class CanvasTool: + """Extracted from guiWin.""" + + def storeManualSeparateDrawMode(self, mode): + self.df_settings.at["manual_separate_draw_mode", "value"] = mode + self.df_settings.to_csv(self.settings_csv_path) diff --git a/cellacdc/mixins/cell_cycle.py b/cellacdc/mixins/cell_cycle.py new file mode 100644 index 000000000..148486d55 --- /dev/null +++ b/cellacdc/mixins/cell_cycle.py @@ -0,0 +1,2979 @@ +"""Qt view adapter for cell-cycle annotation workflows.""" + +from __future__ import annotations + +import traceback +import uuid + +from tqdm import tqdm +import pandas as pd +from qtpy.QtCore import QMutex, QThread, QTimer, QWaitCondition +from qtpy.QtWidgets import QCheckBox, QMessageBox, QPushButton + +from cellacdc import ( + apps, + _warnings, + base_cca_dict, + disableWindow, + exception_handler, + html_utils, +) +from cellacdc import widgets, workers + +from .undo_redo import UndoRedo + + +class CellCycle(UndoRedo): + """Extracted from guiWin.""" + + def _getCcaCostMatrix(self, numCellsG1, numNewCells, IDsCellsG1, newIDs_contours): + posData = self.data[self.pos_i] + dataDict = posData.allData_li[posData.frame_i] + dist_matrix_df = dataDict.get("obj_to_obj_dist_cost_matrix_df") + if dist_matrix_df is None: + cost = np.full((numCellsG1, numNewCells), np.inf) + for obj in posData.rp: + ID = obj.label + try: + i = IDsCellsG1.index(ID) + except ValueError: + continue + + cont = self.getObjContours(obj) + i = IDsCellsG1.index(ID) + + # Get distance from cell in G1 and all other new cells + for j, newID_cont in enumerate(newIDs_contours): + min_dist, nearest_xy = self.nearest_point_2Dyx(cont, newID_cont) + cost[i, j] = min_dist + + return cost + + cost = dist_matrix_df.loc[IDsCellsG1, posData.new_IDs].values + + return cost + + def addIDBaseCca_df(self, posData, ID): + if ID <= 0: + # When calling update_cca_df_deletedIDs we add relative IDs + # but they could be -1 for cells in G1 + return + + _zip = zip( + self.cca_df_colnames, + self.cca_df_default_values, + ) + if posData.cca_df.empty: + posData.cca_df = pd.DataFrame({col: val for col, val in _zip}, index=[ID]) + else: + for col, val in _zip: + posData.cca_df.at[ID, col] = val + self.store_cca_df() + + def addMissingIDs_cca_df(self, posData): + base_cca_df = self.getBaseCca_df() + if posData.cca_df is None: + posData.cca_df = base_cca_df + return + + posData.cca_df = posData.cca_df.combine_first(base_cca_df) + + def annotateBudToDifferentMother(self): + """ + This function is used for correcting automatic mother-bud assignment. + + It can be called at any frame of the bud life. + + There are three cells involved: bud, current mother, new mother. + + Eligibility: + - User clicked first on a bud (checked at click time) + - User released mouse button on a cell in G1 (checked at release time) + - The new mother MUST be in G1 for all the frames of the bud life + --> if not warn + - The new mother MUST have appeared in current frame OR be already + in G1 in previous frame, otherwise there would be no G1 cycle + + Result: + - The bud only changes relative ID to the new mother + - The new mother changes relative ID and stage to 'S' + - The old mother changes its entire status to the status it had + before being assigned to the clicked bud + """ + posData = self.data[self.pos_i] + lab2D = self.get_2Dlab(posData.lab) + budID = lab2D[self.yClickBud, self.xClickBud] + new_mothID = lab2D[self.yClickMoth, self.xClickMoth] + + if budID == new_mothID: + return + + if not self.isSnapshot: + eligible = self.checkMothEligibility(budID, new_mothID) + if not eligible: + return + + budEligible = self.checkChangeMotherBudEligible(budID, posData.frame_i) + if not budEligible: + return + + # Allow partial initialization of cca_df with mouse + if posData.frame_i == 0: + newMothCcs = posData.cca_df.at[new_mothID, "cell_cycle_stage"] + if not newMothCcs == "G1": + err_msg = "You are assigning the bud to a cell that is not in G1!" + msg = QMessageBox() + msg.critical(self, "New mother not in G1!", err_msg, msg.Ok) + return + # Store cca_df for undo action + undoId = uuid.uuid4() + self.storeUndoRedoCca(0, posData.cca_df, undoId) + currentRelID = posData.cca_df.at[budID, "relative_ID"] + if currentRelID in posData.cca_df.index: + posData.cca_df.at[currentRelID, "relative_ID"] = -1 + posData.cca_df.at[currentRelID, "generation_num"] = 2 + posData.cca_df.at[currentRelID, "cell_cycle_stage"] = "G1" + posData.cca_df.at[budID, "relationship"] = "bud" + posData.cca_df.at[budID, "generation_num"] = 0 + posData.cca_df.at[budID, "relative_ID"] = new_mothID + posData.cca_df.at[budID, "cell_cycle_stage"] = "S" + posData.cca_df.at[new_mothID, "relative_ID"] = budID + posData.cca_df.at[new_mothID, "generation_num"] = 2 + posData.cca_df.at[new_mothID, "cell_cycle_stage"] = "S" + self.updateAllImages() + self.store_cca_df() + return + + curr_mothID = posData.cca_df.at[budID, "relative_ID"] + if curr_mothID in posData.cca_df.index: + curr_moth_cca = self.getStatus_RelID_BeforeEmergence(budID, curr_mothID) + + # Store cca_df for undo action + undoId = uuid.uuid4() + self.storeUndoRedoCca(posData.frame_i, posData.cca_df, undoId) + + # Correct current frames and update LabelItems + posData.cca_df.at[budID, "relative_ID"] = new_mothID + posData.cca_df.at[budID, "generation_num"] = 0 + posData.cca_df.at[budID, "relative_ID"] = new_mothID + posData.cca_df.at[budID, "relationship"] = "bud" + posData.cca_df.at[budID, "corrected_on_frame_i"] = posData.frame_i + posData.cca_df.at[budID, "cell_cycle_stage"] = "S" + + posData.cca_df.at[new_mothID, "relative_ID"] = budID + posData.cca_df.at[new_mothID, "cell_cycle_stage"] = "S" + posData.cca_df.at[new_mothID, "relationship"] = "mother" + + if curr_mothID in posData.cca_df.index: + # Cells with UNKNOWN history has relative's ID = -1 + # which is not an existing cell + posData.cca_df.loc[curr_mothID] = curr_moth_cca + + self.updateAllImages() + + # self.checkMultiBudMoth(draw=True) + self.store_cca_df() + proceed = self.checkMothersExcludedOrDead() + if not proceed: + # User clicked on cancel in the message box + self.UndoCca() + return + + if self.ccaTableWin is not None: + zoomIDs = self.getZoomIDs() + self.ccaTableWin.updateTable(posData.cca_df, IDs=zoomIDs) + + # Correct future frames + for i in range(posData.frame_i + 1, posData.SizeT): + # Get cca_df for ith frame from allData_li + cca_df_i = self.get_cca_df(frame_i=i, return_df=True) + if cca_df_i is None: + # ith frame was not visited yet + break + + IDs = cca_df_i.index + if budID not in IDs or new_mothID not in IDs: + # For some reason ID disappeared from this frame + continue + + self.storeUndoRedoCca(i, cca_df_i, undoId) + bud_relationship = cca_df_i.at[budID, "relationship"] + bud_ccs = cca_df_i.at[budID, "cell_cycle_stage"] + + if bud_relationship == "mother" and bud_ccs == "S": + # The bud at the ith frame budded itself --> stop + break + + cca_df_i.at[budID, "relative_ID"] = new_mothID + cca_df_i.at[budID, "generation_num"] = 0 + cca_df_i.at[budID, "relative_ID"] = new_mothID + cca_df_i.at[budID, "relationship"] = "bud" + cca_df_i.at[budID, "cell_cycle_stage"] = "S" + + newMoth_bud_ccs = cca_df_i.at[new_mothID, "cell_cycle_stage"] + if newMoth_bud_ccs == "G1": + # Assign bud to new mother only if the new mother is in G1 + # This can happen if the bud already has a G1 annotated + cca_df_i.at[new_mothID, "relative_ID"] = budID + cca_df_i.at[new_mothID, "cell_cycle_stage"] = "S" + cca_df_i.at[new_mothID, "relationship"] = "mother" + + if curr_mothID in cca_df_i.index: + # Cells with UNKNOWN history has relative's ID = -1 + # which is not an existing cell + cca_df_i.loc[curr_mothID] = curr_moth_cca + + self.store_cca_df(frame_i=i, cca_df=cca_df_i, autosave=False) + + # Correct past frames + for i in range(posData.frame_i - 1, -1, -1): + # Get cca_df for ith frame from allData_li + cca_df_i = self.get_cca_df(frame_i=i, return_df=True) + + is_bud_existing = budID in cca_df_i.index + if not is_bud_existing: + # Bud was not emerged yet + break + + self.storeUndoRedoCca(i, cca_df_i, undoId) + cca_df_i.at[budID, "relative_ID"] = new_mothID + cca_df_i.at[budID, "generation_num"] = 0 + cca_df_i.at[budID, "relative_ID"] = new_mothID + cca_df_i.at[budID, "relationship"] = "bud" + cca_df_i.at[budID, "cell_cycle_stage"] = "S" + + cca_df_i.at[new_mothID, "relative_ID"] = budID + cca_df_i.at[new_mothID, "cell_cycle_stage"] = "S" + cca_df_i.at[new_mothID, "relationship"] = "mother" + + if curr_mothID in cca_df_i.index: + # Cells with UNKNOWN history has relative's ID = -1 + # which is not an existing cell + cca_df_i.loc[curr_mothID] = curr_moth_cca + + self.store_cca_df(frame_i=i, cca_df=cca_df_i, autosave=False) + + self.enqAutosave() + + def annotateDivision(self, cca_df, ID, relID, frame_i=None): + # Correct as follows: + # For frame_i > 0 --> assign to G1 and +1 on generation number + # For frame == 0 --> reinitialize to unknown cells + posData = self.data[self.pos_i] + if frame_i is None: + frame_i = posData.frame_i + + self.annotateWillDivide(ID, relID) + + store = False + cca_df.at[ID, "cell_cycle_stage"] = "G1" + cca_df.at[relID, "cell_cycle_stage"] = "G1" + + if frame_i > 0: + gen_num_clickedID = cca_df.at[ID, "generation_num"] + cca_df.at[ID, "generation_num"] += 1 + cca_df.at[ID, "division_frame_i"] = frame_i + gen_num_relID = cca_df.at[relID, "generation_num"] + cca_df.at[relID, "generation_num"] = gen_num_relID + 1 + cca_df.at[relID, "division_frame_i"] = frame_i + if gen_num_clickedID < gen_num_relID: + cca_df.at[ID, "relationship"] = "mother" + else: + cca_df.at[relID, "relationship"] = "mother" + else: + cca_df.at[ID, "generation_num"] = 2 + cca_df.at[relID, "generation_num"] = 2 + + cca_df.at[ID, "division_frame_i"] = -1 + cca_df.at[relID, "division_frame_i"] = -1 + + cca_df.at[ID, "relationship"] = "mother" + cca_df.at[relID, "relationship"] = "mother" + + store = True + return store + + def annotateIsHistoryKnown(self, ID): + """ + This function is used for annotating that a cell has unknown or known + history. Cells with unknown history are for example the cells already + present in the first frame or cells that appear in the frame from + outside of the field of view. + + With this function we simply set 'is_history_known' to False. + When the users saves instead we update the entire staus of the cell + with unknown history with the function "updateIsHistoryKnown()" + """ + posData = self.data[self.pos_i] + is_history_known = posData.cca_df.at[ID, "is_history_known"] + relID = posData.cca_df.at[ID, "relative_ID"] + if relID in posData.cca_df.index: + relID_cca = self.getStatus_RelID_BeforeEmergence(ID, relID) + + if is_history_known: + # Save status of ID when emerged to allow undoing + statusID_whenEmerged = self.getStatusKnownHistoryBud(ID) + if statusID_whenEmerged is None: + return + posData.ccaStatus_whenEmerged[ID] = statusID_whenEmerged + + # Store cca_df for undo action + undoId = uuid.uuid4() + self.storeUndoRedoCca(posData.frame_i, posData.cca_df, undoId) + + if ID not in posData.ccaStatus_whenEmerged: + self.warnSettingHistoryKnownCellsFirstFrame(ID) + return + + self.setHistoryKnowledge(ID, posData.cca_df) + + if relID in posData.cca_df.index: + # If the cell with unknown history has a relative ID assigned to it + # we set the cca of it to the status it had BEFORE the assignment + posData.cca_df.loc[relID] = relID_cca + + # Update cell cycle info LabelItems + obj_idx = posData.IDs.index(ID) + rp_ID = posData.rp[obj_idx] + + if relID in posData.IDs: + relObj_idx = posData.IDs.index(relID) + rp_relID = posData.rp[relObj_idx] + + self.setAllTextAnnotations() + self.drawAllMothBudLines() + + self.store_cca_df() + + if self.ccaTableWin is not None: + zoomIDs = self.getZoomIDs() + self.ccaTableWin.updateTable(posData.cca_df, IDs=zoomIDs) + + # Correct future frames + for i in range(posData.frame_i + 1, posData.SizeT): + cca_df_i = self.get_cca_df(frame_i=i, return_df=True) + if cca_df_i is None: + # ith frame was not visited yet + break + + self.storeUndoRedoCca(i, cca_df_i, undoId) + IDs = cca_df_i.index + if ID not in IDs: + # For some reason ID disappeared from this frame + continue + else: + self.setHistoryKnowledge(ID, cca_df_i) + if relID in IDs: + cca_df_i.loc[relID] = relID_cca + self.store_cca_df(frame_i=i, cca_df=cca_df_i, autosave=False) + + # Correct past frames + for i in range(posData.frame_i - 1, -1, -1): + cca_df_i = self.get_cca_df(frame_i=i, return_df=True) + if cca_df_i is None: + # ith frame was not visited yet + break + + self.storeUndoRedoCca(i, cca_df_i, undoId) + IDs = cca_df_i.index + if ID not in IDs: + # we reached frame where ID was not existing yet + break + else: + relID = cca_df_i.at[ID, "relative_ID"] + self.setHistoryKnowledge(ID, cca_df_i) + if relID in IDs: + cca_df_i.loc[relID] = relID_cca + self.store_cca_df(frame_i=i, cca_df=cca_df_i, autosave=False) + + self.enqAutosave() + + def annotateWillDivide(self, ID, relID, frame_i=None): + posData = self.data[self.pos_i] + if frame_i is None: + frame_i = posData.frame_i + + # Store in the past frames that division has been annotated + for past_frame_i in range(frame_i - 1, -1, -1): + past_cca_df = self.get_cca_df(frame_i=past_frame_i, return_df=True) + if past_cca_df is None: + return + + if ID not in past_cca_df.index: + # ID is a bud and is not emerged yet here + return + + if frame_i - 1 == past_frame_i: + # Get generation number at first iteration + gen_num = past_cca_df.at[ID, "generation_num"] + + if past_cca_df.at[ID, "generation_num"] != gen_num: + # ID is a mother and the cell cycle is finished here + return + + past_cca_df.at[ID, "will_divide"] = 1 + past_cca_df.at[relID, "will_divide"] = 1 + + self.store_cca_df(cca_df=past_cca_df, frame_i=past_frame_i, autosave=False) + + def applyManualCcaChangesFutureFrames(self, changes, stop_frame_i): + self.store_data(autosave=False) + posData = self.data[self.pos_i] + undoId = uuid.uuid4() + for i in range(posData.frame_i, stop_frame_i): + cca_df_i = self.get_cca_df(frame_i=i, return_df=True) + if cca_df_i is None: + # ith frame was not visited yet + break + + self.storeUndoRedoCca(i, cca_df_i, undoId) + + for ID, changes_ID in changes.items(): + if ID not in cca_df_i.index: + continue + for col, (oldValue, newValue) in changes_ID.items(): + cca_df_i.at[ID, col] = newValue + self.store_cca_df(frame_i=i, cca_df=cca_df_i, autosave=False) + self.get_data() + self.updateAllImages() + + def attempt_auto_cca(self, enforceAll=False): + mode = str(self.modeComboBox.currentText()) + posData = self.data[self.pos_i] + + if mode == "Cell cycle analysis": + notEnoughG1Cells, proceed = self.autoCca_df(enforceAll=enforceAll) + if not proceed: + return notEnoughG1Cells, proceed + + # mode = str(self.modeComboBox.currentText()) + if posData.cca_df is None: # ??? + notEnoughG1Cells = False + proceed = True + return notEnoughG1Cells, proceed + if posData.cca_df.isna().any(axis=None): + raise ValueError("Cell cycle analysis table contains NaNs") + # self.checkMultiBudMoth() + proceed = self.checkMothersExcludedOrDead() + return notEnoughG1Cells, proceed + + elif mode == "Normal division: Lineage tree": + self.autoLinTree_df() + notEnoughG1Cells = False + proceed = True + return notEnoughG1Cells, proceed + + else: + notEnoughG1Cells = False + proceed = True + return notEnoughG1Cells, proceed + + def autoAssignBud_YeastMate(self): + if not self.is_win: + txt = ( + "YeastMate is available only on Windows OS." + "We are working on expading support also on macOS and Linux.\n\n" + "Thank you for your patience!" + ) + msg = QMessageBox() + msg.critical(self, "Supported only on Windows", txt, msg.Ok) + return + + model_name = "YeastMate" + idx = self.modelNames.index(model_name) + + self.titleLabel.setText( + f"{model_name} is thinking... (check progress in terminal/console)", + color=self.titleColor, + ) + + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + + posData = self.data[self.pos_i] + # Check if model needs to be imported + acdcSegment = self.acdcSegment_li[idx] + if acdcSegment is None: + acdcSegment = utils.import_segment_module(model_name) + self.acdcSegment_li[idx] = acdcSegment + + # Read all models parameters + init_params, segment_params = utils.getModelArgSpec(acdcSegment) + # Prompt user to enter the model parameters + try: + url = acdcSegment.url_help() + except AttributeError: + url = None + + _SizeZ = None + if self.isSegm3D: + _SizeZ = posData.SizeZ + win = apps.QDialogModelParams( + init_params, + segment_params, + model_name, + url=url, + posData=posData, + df_metadata=posData.metadata_df, + ) + win.exec_() + if win.cancel: + self.titleLabel.setText("Segmentation aborted.") + return + + use_gpu = win.init_kwargs.get("gpu", False) + proceed = utils.check_gpu_available(model_name, use_gpu, qparent=self) + if not proceed: + self.logger.info("Segmentation process cancelled.") + self.titleLabel.setText("Segmentation process cancelled.") + return + + self.model_kwargs = win.model_kwargs + model = utils.init_segm_model(acdcSegment, posData, win.init_kwargs) + if model is None: + self.logger.info("Segmentation process cancelled.") + self.titleLabel.setText("Segmentation process cancelled.") + return + try: + model.setupLogger(self.logger) + except Exception as e: + pass + + self.models[idx] = model + + img = self.getDisplayedImg1() + + posData.cca_df = model.predictCcaState(img, posData.lab) + self.store_data() + self.updateAllImages() + + self.titleLabel.setText("Budding event prediction done.", color="g") + + def autoCca_df(self, enforceAll=False): + """ + Assign each bud to a mother with scipy linear sum assignment + (Hungarian or Munkres algorithm). First we build a cost matrix where + each (i, j) element is the minimum distance between bud i and mother j. + Then we minimize the cost of assigning each bud to a mother, and finally + we write the assignment info into cca_df + """ + proceed = True + notEnoughG1Cells = False + ScellsGone = False + + posData = self.data[self.pos_i] + + # Skip cca if not the right mode + mode = str(self.modeComboBox.currentText()) + if mode.find("Cell cycle") == -1: + return notEnoughG1Cells, proceed + + # Make sure that this is a visited frame in segmentation tracking mode + if posData.allData_li[posData.frame_i]["labels"] is None: + proceed = self.warnFrameNeverVisitedSegmMode() + return notEnoughG1Cells, proceed + + # Determine if this is the last visited frame for repeating + # bud assignment on non manually correct (corrected_on_frame_i>0) buds. + # The idea is that the user could have assigned division on a cell + # by going previous and we want to check if this cell could be a + # "better" mother for those non manually corrected buds + curr_df = posData.allData_li[posData.frame_i]["acdc_df"] + isLastVisitedAgain = self.isLastVisitedAgainCca(curr_df, enforceAll=enforceAll) + + frameAlreadyAnnotated = ( + posData.cca_df is not None and not enforceAll and not isLastVisitedAgain + ) + # Use stored cca_df and do not modify it with automatic stuff + if frameAlreadyAnnotated: + return notEnoughG1Cells, proceed + + # Keep only correctedAssignIDs if requested + # For the last visited frame we perform assignment again only on + # IDs where we didn't manually correct assignment + correctedAssignIDs = set() + if isLastVisitedAgain and not enforceAll: + try: + correctedAssignIDs = curr_df[curr_df["corrected_on_frame_i"] > 0].index + except Exception as e: + correctedAssignIDs = [] + posData.new_IDs = [ + ID for ID in posData.new_IDs if ID not in correctedAssignIDs + ] + + # Check if new IDs exist some time in the past + found_cca_df_IDs = self.checkCcaPastFramesNewIDs() + + # Check if there are some S cells that disappeared + abort, automaticallyDividedIDs = self.checkScellsGone() + if abort: + notEnoughG1Cells = False + proceed = False + return notEnoughG1Cells, proceed + + # Get previous dataframe + acdc_df = posData.allData_li[posData.frame_i - 1]["acdc_df"] + prev_cca_df = acdc_df[self.cca_df_colnames].copy() + + if posData.cca_df is None: + posData.cca_df = prev_cca_df.copy() + else: + posData.cca_df = curr_df[self.cca_df_colnames].copy() + + # concatenate new IDs found in past frames (before frame_i-1) + if found_cca_df_IDs is not None: + cca_df = pd.concat([posData.cca_df, *found_cca_df_IDs]) + unique_idx = ~cca_df.index.duplicated(keep="first") + posData.cca_df = cca_df[unique_idx] + + # If there are no new IDs we are done + if not posData.new_IDs: + proceed = True + self.store_cca_df() + return notEnoughG1Cells, proceed + + # Get cells in G1 (exclude dead) and check if there are enough cells in G1 + try: + prev_df_G1 = prev_cca_df[prev_cca_df["cell_cycle_stage"] == "G1"] + prev_df_G1 = prev_df_G1[~acdc_df.loc[prev_df_G1.index]["is_cell_dead"]] + IDsCellsG1 = set(prev_df_G1.index) + except Exception as err: + IDsCellsG1 = set() + + if isLastVisitedAgain or enforceAll: + # If we are repeating auto cca for last visited frame + # then we also add the cells in G1 that appears in current frame + # and we remove the ones that are already in S in current frame + # if they were manually corrected (i.e., they cannot be mother). + # Note that potential mother cells must be either appearing in + # current frame or in G1 also at previous frame. + # If we would consider cells that are in G1 at current frame + # but not in previous frame, assigning a bud to it would + # result in no G1 at all for the mother cell. + df_G1 = posData.cca_df[posData.cca_df["cell_cycle_stage"] == "G1"] + current_G1_IDs = df_G1.index + new_cell_G1 = [ID for ID in current_G1_IDs if ID not in prev_cca_df.index] + IDsCellsG1.update(new_cell_G1) + cells_S_current = posData.cca_df[ + (posData.cca_df["cell_cycle_stage"] == "S") + & (posData.cca_df["corrected_on_frame_i"] == posData.frame_i) + ].index + IDsCellsG1 = IDsCellsG1 - set(cells_S_current) + + # Remove cells that disappeared + IDsCellsG1 = [ID for ID in IDsCellsG1 if ID in posData.IDs] + + numCellsG1 = len(IDsCellsG1) + numNewCells = len(posData.new_IDs) + if numCellsG1 < numNewCells: + notEnoughG1Cells, proceed = self.handleNoCellsInG1(numCellsG1, numNewCells) + return notEnoughG1Cells, proceed + + # Compute new IDs contours + newIDs_contours = [] + for obj in posData.rp: + ID = obj.label + if ID in posData.new_IDs: + cont = self.getObjContours(obj) + newIDs_contours.append(cont) + + # Compute cost matrix + cost = self._getCcaCostMatrix( + numCellsG1, numNewCells, IDsCellsG1, newIDs_contours + ) + + # Run hungarian (munkres) assignment algorithm + row_idx, col_idx = scipy.optimize.linear_sum_assignment(cost) + + # New mother cells + newMothIDs = {IDsCellsG1[i] for i in row_idx} + + # Assign buds to mothers + for i, j in zip(row_idx, col_idx): + mothID = IDsCellsG1[i] + budID = posData.new_IDs[j] + + relID = None + # If we are repeating assignment for the bud then we also have to + # correct the possibily wrong mother --> it goes back to + # G1 if it's not a mother that we assign now + if budID in posData.cca_df.index: + relID = posData.cca_df.at[budID, "relative_ID"] + if relID in prev_cca_df.index and relID not in newMothIDs: + posData.cca_df.loc[relID] = prev_cca_df.loc[relID] + + posData.cca_df.at[mothID, "relative_ID"] = budID + posData.cca_df.at[mothID, "cell_cycle_stage"] = "S" + + bud_cca_dict = base_cca_dict.copy() + bud_cca_dict["cell_cycle_stage"] = "S" + bud_cca_dict["generation_num"] = 0 + bud_cca_dict["relative_ID"] = mothID + bud_cca_dict["relationship"] = "bud" + bud_cca_dict["emerg_frame_i"] = posData.frame_i + bud_cca_dict["is_history_known"] = True + bud_cca_dict["corrected_on_frame_i"] = -1 + posData.cca_df.loc[budID] = pd.Series(bud_cca_dict) + + # Keep only existing IDs + posData.cca_df = posData.cca_df.loc[posData.IDs] + + self.store_cca_df() + proceed = True + return notEnoughG1Cells, proceed + + def blinkPairingItem(self): + if self.blinkPairingItemTimer.flag: + opacity = 0.3 + self.blinkPairingItemTimer.flag = False + else: + opacity = 1.0 + self.blinkPairingItemTimer.flag = True + self.warnPairingItem.setOpacity(opacity) + + def ccaCheckerStopChecking(self): + if not self.ccaCheckerRunning: + return + + self.ccaIntegrityCheckerWorker.clearQueue() + + if self.ccaIntegrityCheckerWorker.isChecking: + self.ccaIntegrityCheckerWorker.abortChecking = True + + def ccaCheckerWorkerClosed(self, worker): + self.logger.info("Cell cycle annotations integrity checker stopped.") + self.ccaCheckerRunning = False + + def ccaCheckerWorkerDone(self): + self.setStatusBarLabel(log=False) + + def ccaIntegrCheckerToggled(self, checked): + self.df_settings.at["is_cca_integrity_checker_activated", "value"] = int( + checked + ) + self.df_settings.to_csv(self.settings_csv_path) + mode = self.modeComboBox.currentText() + if mode != "Cell cycle analysis": + return + + if checked: + self.startCcaIntegrityCheckerWorker() + else: + self.disableCcaIntegrityChecker() + + def checkCcaPastFramesNewIDs(self): + posData = self.data[self.pos_i] + if not posData.new_IDs: + return + + found_cca_df_IDs = [] + for frame_i in range(posData.frame_i - 2, -1, -1): + acdc_df = posData.allData_li[frame_i]["acdc_df"] + cca_df_i = acdc_df[self.cca_df_colnames] + intersect_idx = cca_df_i.index.intersection(posData.new_IDs) + cca_df_i = cca_df_i.loc[intersect_idx] + if cca_df_i.empty: + continue + found_cca_df_IDs.append(cca_df_i) + + # Remove IDs found in past frames from new_IDs list + newIDs = np.array(posData.new_IDs, dtype=np.uint32) + mask_index = np.in1d(newIDs, cca_df_i.index) + posData.new_IDs = list(newIDs[~mask_index]) + if not posData.new_IDs: + return found_cca_df_IDs + return found_cca_df_IDs + + def checkChangeMotherBudEligible(self, budID, frame_i): + result = self._checkBudFutureNoDivision(budID, frame_i) + if result is None: + return True + + self.warnBudAnnotatedDividedInFuture( + budID, *result, action="change mother cell" + ) + return False + + def checkDivisionCanBeUndone(self, ID, relID): + """Check that division annotation can be undone (see Notes section) + + Parameters + ---------- + ID : int + Cell ID of the clicked cell in G1 + relID : _type_ + Relative ID of the cell that was clicked + + Notes + ----- + Division annotation can be undone only if `relID` is also in G1 for the + entire duration of the correction + """ + posData = self.data[self.pos_i] + + ccs_relID = posData.cca_df.at[relID, "cell_cycle_stage"] + if ccs_relID == "S": + return posData.frame_i + + # Check future frames + for future_i in range(posData.frame_i + 1, posData.SizeT): + cca_df_i = self.get_cca_df(frame_i=future_i, return_df=True) + if cca_df_i is None: + # ith frame was not visited yet + break + + ccs_relID = cca_df_i.at[relID, "cell_cycle_stage"] + if ccs_relID == "S": + return future_i + + # Check past frames + for past_i in range(posData.frame_i - 1, -1, -1): + cca_df_i = self.get_cca_df(frame_i=past_i, return_df=True) + if ID not in cca_df_i.index or relID not in cca_df_i.index: + # Bud did not exist at frame_i = i + break + + ccs = cca_df_i.at[ID, "cell_cycle_stage"] + if ccs == "S": + break + + ccs_relID = cca_df_i.at[relID, "cell_cycle_stage"] + if ccs_relID == "S": + return future_i + + def checkMothEligibility(self, budID, new_mothID): + """ + Check that the new mother is in G1 for the entire life of the bud + and that the G1 duration is > than 1 frame + """ + last_cca_frame_i = self.navigateScrollBar.maximum() - 1 + posData = self.data[self.pos_i] + eligible = True + + # Check future frames + G1_duration_future = 0 + for future_i in range(posData.frame_i, posData.SizeT): + cca_df_i = self.get_cca_df(frame_i=future_i, return_df=True) + + if cca_df_i is None: + # ith frame was not visited yet + break + + if budID not in cca_df_i.index: + # Bud disappeared + break + + is_still_bud = cca_df_i.at[budID, "relationship"] == "bud" + if not is_still_bud: + break + + ccs = cca_df_i.at[new_mothID, "cell_cycle_stage"] + if ccs != "G1": + cancel, apply = self.warnMotherNotEligible( + new_mothID, budID, future_i, "not_G1_in_the_future" + ) + if apply: + self.resetCcaFuture(future_i) + break + isG1singleFrame = G1_duration_future == 1 + isFutureFrameNotLastAnnot = future_i != last_cca_frame_i + if cancel or (isG1singleFrame and isFutureFrameNotLastAnnot): + eligible = False + return eligible + + G1_duration_future += 1 + + # Check past frames + for past_i in range(posData.frame_i - 1, -1, -1): + # Get cca_df for ith frame from allData_li + cca_df_i = self.get_cca_df(frame_i=past_i, return_df=True) + + is_bud_existing = budID in cca_df_i.index + is_moth_existing = new_mothID in cca_df_i.index + + if not is_moth_existing: + # Mother not existing because it appeared from outside FOV + break + + ccs = cca_df_i.at[new_mothID, "cell_cycle_stage"] + if ccs != "G1" and is_bud_existing: + # Requested mother not in G1 in the past + # during the life of the bud (is_bud_existing = True) + self.warnMotherNotEligible( + new_mothID, budID, past_i, "not_G1_in_the_past" + ) + eligible = False + return eligible + + if not is_bud_existing: + # Bud stop existing --> check that mother is still in G1 + if ccs != "G1": + eligible = False + self.warnMotherNotEligible( + new_mothID, budID, past_i, "single_frame_G1_duration" + ) + break + + return eligible + + def checkMothersExcludedOrDead(self): + try: + posData = self.data[self.pos_i] + buds_df = posData.cca_df[ + (posData.cca_df.relationship == "bud") + & (posData.cca_df.emerg_frame_i == posData.frame_i) + ] + acdc_df_i = posData.allData_li[posData.frame_i]["acdc_df"] + moth_df = acdc_df_i.loc[buds_df.relative_ID.to_list()] + excluded_df = moth_df[ + (moth_df.is_cell_dead > 0) | (moth_df.is_cell_excluded > 0) + ] + excludedMothIDs = excluded_df.index.to_list() + if not excludedMothIDs: + self.stopBlinkingPairItem() + return True + budIDsOfExcludedMoth = excluded_df.relative_ID.to_list() + proceed = self.warnDeadOrExcludedMothers( + budIDsOfExcludedMoth, excludedMothIDs + ) + return proceed + except Exception as e: + self.logger.info(traceback.format_exc()) + print("-" * 100) + self.logger.warning("Checking if mother cell is excluded or dead failed.") + print("^" * 100) + return False + + def checkScellsGone(self): + """Check if there are cells in S phase whose relative disappear in + current frame. Allow user to choose between automatically assign + division to these cells or cancel and not visit the frame. + + Returns + ------- + bool + False if there are no cells disappeared or the user decided + to accept automatic division. + """ + automaticallyDividedIDs = [] + + mode = str(self.modeComboBox.currentText()) + if mode.find("Cell cycle") == -1: + # No cell cycle analysis mode --> do nothing + return False, automaticallyDividedIDs + + posData = self.data[self.pos_i] + + if posData.allData_li[posData.frame_i]["labels"] is None: + # Frame never visited/checked in segm mode --> autoCca_df will raise + # a critical message + return False, automaticallyDividedIDs + + # Check if there are S cells that either only mother or only + # bud disappeared and automatically assign division to it + # or abort visiting this frame + prev_acdc_df = posData.allData_li[posData.frame_i - 1]["acdc_df"] + prev_rp = posData.allData_li[posData.frame_i - 1]["regionprops"] + prev_cca_df = prev_acdc_df[self.cca_df_colnames].copy() + + ScellsIDsGone = [] + for ccSeries in prev_cca_df.itertuples(): + ID = ccSeries.Index + ccs = ccSeries.cell_cycle_stage + if ccs != "S": + continue + + relID = ccSeries.relative_ID + if relID == -1: + continue + + # Check is relID is gone while ID stays + if relID not in posData.IDs and ID in posData.IDs: + ScellsIDsGone.append(relID) + + if not ScellsIDsGone: + # No cells in S that disappears --> do nothing + return False, automaticallyDividedIDs + + self.highlightNewIDs_ccaFailed(ScellsIDsGone, rp=prev_rp) + proceed = self.warnScellsGone(ScellsIDsGone, posData.frame_i) + self.clearLostObjContoursItems() + + if not proceed: + return True, automaticallyDividedIDs + + for IDgone in ScellsIDsGone: + relID = prev_cca_df.at[IDgone, "relative_ID"] + self.annotateDisappearedBeforeDivision(relID, IDgone, prev_cca_df) + self.annotateDivision( + prev_cca_df, IDgone, relID, frame_i=posData.frame_i - 1 + ) + self.annotateDivisionCurrentFrameRelativeIDgone(relID) + automaticallyDividedIDs.append(relID) + + self.store_cca_df(frame_i=posData.frame_i - 1, cca_df=prev_cca_df) + + return False, automaticallyDividedIDs + + def checkSwapMothersEligibility(self): + posData = self.data[self.pos_i] + + lab2D = self.get_2Dlab(posData.lab) + budID = lab2D[self.yClickBud, self.xClickBud] + otherMothID = lab2D[self.yClickMoth, self.xClickMoth] + mothID = posData.cca_df.at[budID, "relative_ID"] + otherBudID = posData.cca_df.at[otherMothID, "relative_ID"] + + for _budID in (budID, otherBudID): + result = self._checkBudFutureNoDivision(_budID, posData.frame_i) + if result is None: + continue + + self.warnBudAnnotatedDividedInFuture(_budID, *result) + return + + correct_pairings = {otherBudID: mothID, budID: otherMothID} + wrong_pairings = {mothID: budID, otherMothID: otherBudID} + for correctBudID, correctMothID in correct_pairings.items(): + wrongBudID = wrong_pairings[correctMothID] + frame_no_G1 = self._checkMothInG1beforeBudEmergence( + correctMothID, correctBudID, wrongBudID, posData.frame_i + ) + if frame_no_G1 is None: + continue + + self.warnMotherNotAtLeastOneFrameG1( + correctBudID, correctMothID, frame_no_G1 + ) + return + + return budID, otherBudID, otherMothID, mothID + + def disableCcaIntegrityChecker(self): + self.stopCcaIntegrityCheckerWorker() + + def enqCcaIntegrityChecker(self): + if not self.ccaCheckerRunning: + return + posData = self.data[self.pos_i] + self.ccaIntegrityCheckerWorker.enqueue(posData) + + def fixCcaDfAfterEdit(self, editTxt): + posData = self.data[self.pos_i] + if posData.cca_df is not None: + # For snapshot mode we fix or reinit cca_df depending on the edit + self.update_cca_df_snapshots(editTxt, posData) + self.store_data() + + def fixWillDivide(self, warning_txt, IDs_will_divide_wrong): + self.logger.info(warning_txt) + self.logger.info("Fixing `will_divide` information...") + + global_cca_df = self.getConcatCcaDf() + global_cca_df = global_cca_df.reset_index().set_index( + ["Cell_ID", "generation_num"] + ) + global_cca_df.loc[IDs_will_divide_wrong, "will_divide"] = 0 + global_cca_df = global_cca_df.reset_index().set_index(["frame_i", "Cell_ID"]) + self.storeFromConcatCcaDf(global_cca_df) + + def getBaseCca_df(self, with_tree_cols=False): + posData = self.data[self.pos_i] + IDs = [obj.label for obj in posData.rp] + cca_df = core.getBaseCca_df(IDs, with_tree_cols=with_tree_cols) + return cca_df + + def getConcatCcaDf(self): + posData = self.data[self.pos_i] + cca_dfs = [] + keys = [] + for frame_i in range(0, posData.SizeT): + cca_df = self.get_cca_df(frame_i=frame_i, return_df=True) + if cca_df is None: + break + + cca_dfs.append(cca_df) + keys.append(frame_i) + + if not cca_dfs: + return + + global_cca_df = pd.concat(cca_dfs, keys=keys, names=["frame_i"]) + return global_cca_df + + def get_cca_df(self, frame_i=None, return_df=False, debug=False): + # cca_df is None unless the metadata contains cell cycle annotations + # NOTE: cell cycle annotations are either from the current session + # or loaded from HDD in "initPosAttr" with a .question to the user + posData = self.data[self.pos_i] + cca_df = None + i = posData.frame_i if frame_i is None else frame_i + df = posData.allData_li[i]["acdc_df"] + if df is not None: + if "cell_cycle_stage" in df.columns: + cca_df = df[self.cca_df_colnames].copy() + + if cca_df is None and self.isSnapshot: + cca_df = self.getBaseCca_df() + posData.cca_df = cca_df + + if cca_df is not None: + cca_df = cca_df.dropna() + + if return_df: + return cca_df + else: + posData.cca_df = cca_df + + def get_last_cca_frame_i(self): + posData = self.data[self.pos_i] + + i = 0 + # Determine last annotated frame index + for i, dict_frame_i in enumerate(posData.allData_li): + df = dict_frame_i["acdc_df"] + if df is None: + break + elif "cell_cycle_stage" not in df.columns: + break + + last_cca_frame_i = i if i == 0 or i + 1 == len(posData.allData_li) else i - 1 + + return last_cca_frame_i + + def goToFrameNumber(self, frame_n): + posData = self.data[self.pos_i] + posData.frame_i = frame_n - 1 + self.get_data() + self.updateAllImages() + self.updateScrollbars() + + def handleNoCellsInG1(self, numCellsG1, numNewCells): + posData = self.data[self.pos_i] + self.highlightNewCellNotEnoughG1cells(posData.new_IDs) + continueAnyway = _warnings.warnNotEnoughG1Cells( + numCellsG1, posData.frame_i, numNewCells, qparent=self + ) + if continueAnyway: + notEnoughG1Cells = False + proceed = True + # Annotate the new IDs with unknown history + for ID in posData.new_IDs: + posData.cca_df.loc[ID] = pd.Series(base_cca_dict) + cca_df_ID = self.getStatusKnownHistoryBud(ID) + posData.ccaStatus_whenEmerged[ID] = cca_df_ID + else: + notEnoughG1Cells = True + proceed = False + + # Clear new cells annotations + self.ccaFailedScatterItem.setData([], []) + return notEnoughG1Cells, proceed + + def highlightIDs(self, IDs, pen): + pass + + def highlightNewCellNotEnoughG1cells(self, IDsCellsG1): + posData = self.data[self.pos_i] + for obj in posData.rp: + if obj.label not in IDsCellsG1: + continue + objContours = self.getObjContours(obj) + if objContours is not None: + xx = objContours[:, 0] + 0.5 + yy = objContours[:, 1] + 0.5 + self.ccaFailedScatterItem.addPoints(xx, yy) + self.textAnnot[0].addObjAnnotation(obj, "green", f"{obj.label}?", False) + + def highlightNewIDs_ccaFailed(self, IDsWithIssue, rp=None): + if rp is None: + posData = self.data[self.pos_i] + rp = posData.rp + for obj in rp: + if obj.label not in IDsWithIssue: + continue + self.setCcaIssueContour(obj) + + def initCca(self): + posData = self.data[self.pos_i] + last_tracked_i = self.get_last_tracked_i() + defaultMode = "Viewer" + if last_tracked_i == 0: + txt = html_utils.paragraph( + "On this dataset either you never checked that the segmentation " + "and tracking are correct or you did not save yet.

    " + 'If you already visited some frames with "Segmentation and Tracking" ' + 'mode save data before switching to "Cell cycle analysis mode".

    ' + "Otherwise you first have to check (and eventually correct) some frames " + 'in "Segmentation and Tracking" mode before proceeding ' + "with cell cycle analysis." + ) + msg = widgets.myMessageBox() + msg.critical(self, "Tracking was never checked", txt) + self.modeComboBox.setCurrentText(defaultMode) + return + + proceed = True + + last_cca_frame_i = self.get_last_cca_frame_i() + if last_cca_frame_i == 0: + # Remove undoable actions from segmentation mode + posData.UndoRedoStates[0] = [] + self.undoAction.setEnabled(False) + self.redoAction.setEnabled(False) + + if posData.frame_i > last_cca_frame_i: + # Prompt user to go to last annotated frame + msg = widgets.myMessageBox() + txt = html_utils.paragraph(f""" + The last annotated frame is frame {last_cca_frame_i + 1}.

    + Do you want to restart cell cycle analysis from frame + {last_cca_frame_i + 1}?
    + """) + _, goToFrameButton, stayButton = msg.warning( + self, + "Go to last annotated frame?", + txt, + buttonsTexts=( + "Cancel", + f"Yes, go to frame {last_cca_frame_i + 1}", + "No, stay on current frame", + ), + ) + if goToFrameButton == msg.clickedButton: + self.addMissingIDs_cca_df(posData) + self.store_cca_df() + msg = "Looking good!" + self.last_cca_frame_i = last_cca_frame_i + posData.frame_i = last_cca_frame_i + self.titleLabel.setText(msg, color=self.titleColor) + self.get_data() + self.addMissingIDs_cca_df(posData) + self.store_cca_df() + self.updateAllImages() + self.updateScrollbars() + elif stayButton == msg.clickedButton: + self.addMissingIDs_cca_df(posData) + self.store_cca_df() + self.initMissingFramesCca(last_cca_frame_i, posData.frame_i) + last_cca_frame_i = posData.frame_i + msg = "Cell cycle analysis initialised!" + self.titleLabel.setText(msg, color="g") + elif msg.cancel: + msg = "Cell cycle analysis aborted." + self.logger.info(msg) + self.titleLabel.setText(msg, color=self.titleColor) + self.modeComboBox.setCurrentText(defaultMode) + proceed = False + return + elif posData.frame_i < last_cca_frame_i: + # Prompt user to go to last annotated frame + msg = widgets.myMessageBox() + txt = html_utils.paragraph(f""" + The last annotated frame is frame {last_cca_frame_i + 1}.

    + Do you want to restart cell cycle analysis from frame + {last_cca_frame_i + 1}?
    + """) + yesButton, noButton, _ = msg.question( + self, + "Go to last annotated frame?", + txt, + buttonsTexts=("Yes", "No", "Cancel"), + ) + if msg.cancel: + msg = "Cell cycle analysis aborted." + self.logger.info(msg) + self.titleLabel.setText(msg, color=self.titleColor) + self.modeComboBox.setCurrentText(defaultMode) + proceed = False + return + + self.addMissingIDs_cca_df(posData) + if msg.clickedButton == yesButton: + self.addMissingIDs_cca_df(posData) + msg = "Looking good!" + self.titleLabel.setText(msg, color=self.titleColor) + self.last_cca_frame_i = last_cca_frame_i + posData.frame_i = last_cca_frame_i + self.get_data() + self.addMissingIDs_cca_df(posData) + self.store_cca_df() + self.updateAllImages() + self.updateScrollbars() + else: + self.get_data() + self.addMissingIDs_cca_df(posData) + self.store_cca_df() + + self.last_cca_frame_i = last_cca_frame_i + + self.navigateScrollBar.setMaximum(last_cca_frame_i + 1) + self.navSpinBox.setMaximum(last_cca_frame_i + 1) + self.lastTrackedFrameLabel.setText( + f"Last cc annot. frame n. = {last_cca_frame_i + 1}" + ) + + if posData.cca_df is None: + posData.cca_df = self.getBaseCca_df() + self.store_cca_df() + msg = "Cell cycle analysis initialized!" + self.logger.info(msg) + self.titleLabel.setText(msg, color=self.titleColor) + else: + self.get_cca_df() + + self.enqCcaIntegrityChecker() + + return proceed + + def initCcaIntegrityChecker(self): + posData = self.data[self.pos_i] + for frame_i, data_frame_i in enumerate(posData.allData_li): + lab = data_frame_i["labels"] + if lab is None: + break + + cca_df = self.get_cca_df(frame_i, return_df=True) + self.store_cca_df_checker(posData, frame_i, cca_df) + + self.enqCcaIntegrityChecker() + + def initMissingFramesCca(self, last_cca_frame_i, current_frame_i): + self.logger.info( + "Initialising cell cycle annotations of missing past frames..." + ) + posData = self.data[self.pos_i] + current_frame_i = posData.frame_i + + annotated_cca_dfs = [] + for frame_i in range(last_cca_frame_i + 1): + acdc_df = posData.allData_li[frame_i]["acdc_df"] + if "cell_cycle_stage" in acdc_df.columns: + continue + + acdc_df[self.cca_df_colnames] = "" + + annotated_cca_dfs = [ + posData.allData_li[i]["acdc_df"][self.cca_df_colnames] + for i in range(last_cca_frame_i + 1) + ] + keys = range(last_cca_frame_i + 1) + names = ["frame_i", "Cell_ID"] + annotated_cca_df = ( + pd.concat(annotated_cca_dfs, keys=keys, names=names) + .reset_index() + .set_index(["Cell_ID", "frame_i"]) + .sort_index() + ) + + last_annotated_cca_df = annotated_cca_df.groupby(level=0).last() + cca_df_colnames = self.cca_df_colnames + pbar = tqdm(total=current_frame_i - last_cca_frame_i + 1, ncols=100) + for frame_i in range(last_cca_frame_i, current_frame_i + 1): + posData.frame_i = frame_i + self.get_data() + cca_df = self.getBaseCca_df() + + idx = last_annotated_cca_df.index.intersection(cca_df.index) + cca_df.loc[idx, cca_df_colnames] = last_annotated_cca_df.loc[idx] + + self.store_cca_df(cca_df=cca_df, frame_i=frame_i, autosave=False) + pbar.update() + pbar.close() + + posData.frame_i = current_frame_i + self.get_data() + + def isCcaCheckerChecking(self): + if not self.ccaCheckerRunning: + return False + + return self.ccaIntegrityCheckerWorker.isChecking + + def isCurrentFrameCcaVisited(self): + posData = self.data[self.pos_i] + curr_df = posData.allData_li[posData.frame_i]["acdc_df"] + return curr_df is not None and "cell_cycle_stage" in curr_df.columns + + def isFrameCcaAnnotated(self): + posData = self.data[self.pos_i] + acdc_df = posData.allData_li[posData.frame_i]["acdc_df"] + if acdc_df is None: + return False + + return "cell_cycle_stage" in acdc_df.columns + + def isLastVisitedAgainCca(self, curr_df, enforceAll=False): + # Determine if this is the last visited frame for repeating + # bud assignment on non manually corrected_on_frame_i buds. + # The idea is that the user could have assigned division on a cell + # by going previous and we want to check if this cell could be a + # "better" mother for those non manually corrected buds + posData = self.data[self.pos_i] + if curr_df is None: + return False + + if "cell_cycle_stage" not in curr_df.columns: + return False + + if enforceAll: + return False + + lastVisited = False + posData.new_IDs = [ + ID + for ID in posData.new_IDs + if curr_df.at[ID, "is_history_known"] + and curr_df.at[ID, "cell_cycle_stage"] == "S" + ] + if posData.frame_i + 1 < posData.SizeT: + next_df = posData.allData_li[posData.frame_i + 1]["acdc_df"] + if next_df is None: + lastVisited = True + else: + if "cell_cycle_stage" not in next_df.columns: + lastVisited = True + else: + lastVisited = True + + return lastVisited + + def manualCellCycleAnnotation(self, ID): + """ + This function is used for both annotating division or undoing the + annotation. It can be called on any frame. + + If we annotate division (right click on a cell in S) then it will + check if there are future frames to correct. + Frames to correct are those frames where both the mother and the bud + are annotated as S phase cells. + In this case we assign all those frames to G1, relationship to mother, + and +1 generation number + + If we undo the annotation (right click on a cell in G1) then it will + correct both past and future annotated frames (if present). + Frames to correct are those frames where both the mother and the bud + are annotated as G1 phase cells. + In this case we assign all those frames to G1, relationship back to + bud, and -1 generation number + """ + posData = self.data[self.pos_i] + + # Store cca_df for undo action + undoId = uuid.uuid4() + self.storeUndoRedoCca(posData.frame_i, posData.cca_df, undoId) + + # Correct current frame + clicked_ccs = posData.cca_df.at[ID, "cell_cycle_stage"] + relID = posData.cca_df.at[ID, "relative_ID"] + + if relID not in posData.IDs: + return + + if clicked_ccs == "G1" and posData.frame_i == 0: + # We do not allow undoing division annotation on first frame + return + + if clicked_ccs == "G1": + issue_frame_i = self.checkDivisionCanBeUndone(ID, relID) + if issue_frame_i is not None: + _warnings.warnDivisionAnnotationCannotBeUndone( + ID, relID, issue_frame_i, qparent=self + ) + return + + if clicked_ccs == "S": + self.annotateDivision(posData.cca_df, ID, relID) + self.store_cca_df() + else: + self.undoDivisionAnnotation(posData.cca_df, ID, relID) + self.store_cca_df() + + # Update cell cycle info LabelItems + self.ax1_newMothBudLinesItem.setData([], []) + self.ax1_oldMothBudLinesItem.setData([], []) + self.ax2_newMothBudLinesItem.setData([], []) + self.ax2_oldMothBudLinesItem.setData([], []) + self.drawAllMothBudLines() + self.setAllTextAnnotations() + + if self.ccaTableWin is not None: + zoomIDs = self.getZoomIDs() + self.ccaTableWin.updateTable(posData.cca_df, IDs=zoomIDs) + + # Correct future frames + for future_i in range(posData.frame_i + 1, posData.SizeT): + cca_df_i = self.get_cca_df(frame_i=future_i, return_df=True) + if cca_df_i is None: + # ith frame was not visited yet + break + + self.storeUndoRedoCca(future_i, cca_df_i, undoId) + IDs = cca_df_i.index + if ID not in IDs: + # For some reason ID disappeared from this frame + continue + + ccs = cca_df_i.at[ID, "cell_cycle_stage"] + relID = cca_df_i.at[ID, "relative_ID"] + if clicked_ccs == "S": + if ccs == "G1": + # Cell is in G1 in the future again so stop annotating + break + self.annotateDivision(cca_df_i, ID, relID) + self.store_cca_df(frame_i=future_i, cca_df=cca_df_i, autosave=False) + elif ccs == "S": + # Cell is in S in the future again so stop undoing (break) + # also leave a 1 frame duration G1 to avoid a continuous + # S phase + self.annotateDivision(cca_df_i, ID, relID) + self.store_cca_df(frame_i=future_i, cca_df=cca_df_i, autosave=False) + break + else: + self.undoDivisionAnnotation(cca_df_i, ID, relID) + self.store_cca_df(frame_i=future_i, cca_df=cca_df_i, autosave=False) + + # Correct past frames + for past_i in range(posData.frame_i - 1, -1, -1): + cca_df_i = self.get_cca_df(frame_i=past_i, return_df=True) + if ID not in cca_df_i.index or relID not in cca_df_i.index: + # Bud did not exist at frame_i = i + break + + self.storeUndoRedoCca(past_i, cca_df_i, undoId) + ccs = cca_df_i.at[ID, "cell_cycle_stage"] + relID = cca_df_i.at[ID, "relative_ID"] + if ccs == "S": + # We correct only those frames in which the ID was in 'G1' + break + else: + store = self.undoDivisionAnnotation(cca_df_i, ID, relID) + self.store_cca_df(frame_i=past_i, cca_df=cca_df_i, autosave=False) + + self.enqAutosave() + + def manualEditCca(self, checked=True): + posData = self.data[self.pos_i] + editCcaWidget = apps.editCcaTableWidget( + posData.cca_df, posData.SizeT, current_frame_i=posData.frame_i, parent=self + ) + editCcaWidget.sigApplyChangesFutureFrames.connect( + self.applyManualCcaChangesFutureFrames + ) + editCcaWidget.exec_() + if editCcaWidget.cancel: + return + posData.cca_df = editCcaWidget.cca_df + self.store_cca_df() + # self.checkMultiBudMoth() + self.updateAllImages() + + def manualEditCcaToolbarActionTriggered(self): + self.manualEditCca() + + def nearest_point_2Dyx(self, points, all_others): + """ + Given 2D array of [y, x] coordinates points and all_others return the + [y, x] coordinates of the two points (one from points and one from all_others) + that have the absolute minimum distance + """ + # Compute 3D array where each ith row of each kth page is the element-wise + # difference between kth row of points and ith row in all_others array. + # (i.e. diff[k,i] = points[k] - all_others[i]) + diff = points[:, np.newaxis] - all_others + # Compute 2D array of distances where + # dist[i, j] = euclidean dist (points[i],all_others[j]) + dist = np.linalg.norm(diff, axis=2) + # Compute i, j indexes of the absolute minimum distance + i, j = np.unravel_index(dist.argmin(), dist.shape) + nearest_point = all_others[j] + point = points[i] + min_dist = np.min(dist) + return min_dist, nearest_point + + def onMotherNotInG1(self, mothID): + txt = html_utils.paragraph( + f"You clicked on ID={mothID} which is NOT in G1

    " + "Do you want to proceed with swapping the mother cells?

    " + "NOTE: To assign a bud start by clicking on the bud " + "and release on a cell in G1" + ) + msg = widgets.myMessageBox() + swapMothersButton = widgets.reloadPushButton("Swap mother cells") + _, swapMothersButton = msg.warning( + self, + "Released on a cell NOT in G1", + txt, + buttonsTexts=("Cancel", swapMothersButton), + ) + if msg.cancel: + return + + pairings = self.checkSwapMothersEligibility() + if pairings is None: + self.logger.info("Swapping mothers is not possible.") + return + + self.swapMothers(*pairings) + + def reInitCca(self): + if not self.isSnapshot: + txt = html_utils.paragraph( + "If you decide to continue ALL cell cycle annotations from " + "this frame to the end will be erased from current session " + "(saved data is not touched of course).

    " + "To annotate future frames again you will have to revisit them.

    " + "Do you want to continue?" + ) + msg = widgets.myMessageBox() + msg.warning( + self, "Re-initialize annnotations?", txt, buttonsTexts=("Cancel", "Yes") + ) + posData = self.data[self.pos_i] + if msg.cancel: + return + + # Reset all future frames + self.resetCcaFuture(posData.frame_i + 1) + if posData.frame_i == 0: + # Reset everything since we are on first frame + posData.cca_df = self.getBaseCca_df() + self.store_data() + self.updateAllImages() + self.navigateScrollBar.setMaximum(posData.frame_i + 1) + self.navSpinBox.setMaximum(posData.frame_i + 1) + else: + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + + posData = self.data[self.pos_i] + posData.cca_df = self.getBaseCca_df() + self.store_data() + self.updateAllImages() + + def removeCcaAnnotationsCurrentFrame(self): + posData = self.data[self.pos_i] + posData.cca_df = None + + posData.allData_li[posData.frame_i].pop("cca_df", None) + posData.allData_li[posData.frame_i].pop("cca_df_checker", None) + + df = posData.allData_li[posData.frame_i]["acdc_df"] + if df is None: + # No more saved info to delete + return False + + if "cell_cycle_stage" not in df.columns: + # No cell cycle info present + return False + + df = df.drop(columns=self.cca_df_colnames) + posData.allData_li[posData.frame_i]["acdc_df"] = df + + return True + + def repeatAutoCca(self): + # Do not allow automatic bud assignment if there are future + # frames that already contain anotations + posData = self.data[self.pos_i] + next_df = posData.allData_li[posData.frame_i + 1]["acdc_df"] + if next_df is not None: + if "cell_cycle_stage" in next_df.columns: + msg = QMessageBox() + warn_cca = msg.critical( + self, + "Future visited frames detected!", + "Automatic bud assignment CANNOT be performed becasue " + "there are future frames that already contain cell cycle " + "annotations. The behaviour in this case cannot be predicted.\n\n" + "We suggest assigning the bud manually OR use the " + '"Re-initialize cell cycle annotations" button which properly ' + "re-initialize future frames.", + msg.Ok, + ) + return + + correctedAssignIDs = posData.cca_df[ + posData.cca_df["corrected_on_frame_i"] >= 0 + ].index + NeverCorrectedAssignIDs = [ + ID for ID in posData.new_IDs if ID not in correctedAssignIDs + ] + + # Store cca_df temporarily if attempt_auto_cca fails + posData.cca_df_beforeRepeat = posData.cca_df.copy() + + if not all(NeverCorrectedAssignIDs): + notEnoughG1Cells, proceed = self.attempt_auto_cca() + if notEnoughG1Cells or not proceed: + posData.cca_df = posData.cca_df_beforeRepeat + else: + self.updateAllImages() + return + + msg = QMessageBox() + msg.setIcon(msg.Question) + msg.setText( + "Do you want to automatically assign buds to mother cells for " + "ALL the new cells in this frame (excluding cells with unknown history) " + "OR only the cells where you never clicked on?" + ) + msg.setDetailedText( + f"New cells that you never touched:\n\n{NeverCorrectedAssignIDs}" + ) + enforceAllButton = QPushButton("ALL new cells") + b = QPushButton("Only cells that I never corrected assignment") + msg.addButton(b, msg.YesRole) + msg.addButton(enforceAllButton, msg.NoRole) + msg.exec_() + if msg.clickedButton() == enforceAllButton: + notEnoughG1Cells, proceed = self.attempt_auto_cca(enforceAll=True) + else: + notEnoughG1Cells, proceed = self.attempt_auto_cca() + if notEnoughG1Cells or not proceed: + posData.cca_df = posData.cca_df_beforeRepeat + else: + self.updateAllImages() + + def resetCcaFuture(self, from_frame_i): + posData = self.data[self.pos_i] + self.last_cca_frame_i = from_frame_i - 1 + self.ccaCheckerStopChecking() + + self.setNavigateScrollBarMaximum() + for i in range(from_frame_i, posData.SizeT): + posData.allData_li[i].pop("cca_df", None) + posData.allData_li[i].pop("cca_df_checker", None) + + df = posData.allData_li[i]["acdc_df"] + if df is None: + # No more saved info to delete + break + + if "cell_cycle_stage" not in df.columns: + # No cell cycle info present + continue + + df = df.drop(columns=self.cca_df_colnames) + posData.allData_li[i]["acdc_df"] = df + + if posData.acdc_df is not None: + frames = posData.acdc_df.index.get_level_values(0) + if from_frame_i in frames: + posData.acdc_df = posData.acdc_df.loc[:from_frame_i] + + self.resetWillDivideInfo() + + def resetFutureCcaColCurrentFrame(self): + posData = self.data[self.pos_i] + + cca_df_S_mask = posData.cca_df.cell_cycle_stage == "S" + posData.cca_df.loc[cca_df_S_mask, "will_divide"] = 0 + + mothers_mask = (posData.cca_df.relationship == "mother") & cca_df_S_mask + bud_mask = posData.cca_df.relationship == "bud" + + posData.cca_df.loc[mothers_mask, "daughter_disappears_before_division"] = 0 + posData.cca_df.loc[bud_mask, "disappears_before_division"] = 0 + + cca_df = self.get_cca_df(frame_i=posData.frame_i, return_df=True) + if cca_df is not None: + cca_df_S_mask = cca_df.cell_cycle_stage == "S" + cca_df.loc[cca_df_S_mask, "will_divide"] = 0 + + mothers_mask = (cca_df.relationship == "mother") & cca_df_S_mask + bud_mask = cca_df.relationship == "bud" + + cca_df.loc[mothers_mask, "daughter_disappears_before_division"] = 0 + cca_df.loc[bud_mask, "disappears_before_division"] = 0 + + self.store_data() + + def resetWillDivideInfo(self): + global_cca_df = self.getConcatCcaDf() + if global_cca_df is None: + return + + global_cca_df = load._fix_will_divide(global_cca_df) + self.storeFromConcatCcaDf(global_cca_df) + + def setCcaIssueContour(self, obj): + objContours = self.getObjContours(obj, all_external=True) + for cont in objContours: + xx = cont[:, 0] + 0.5 + yy = cont[:, 1] + 0.5 + self.ax1_lostObjScatterItem.addPoints(xx, yy) + self.textAnnot[0].addObjAnnotation(obj, "lost_object", f"{obj.label}?", False) + + def startBlinkingPairingItem(self, budIDs, mothIDs): + self.ax1_newMothBudLinesItem.setOpacity(0.2) + self.ax1_oldMothBudLinesItem.setOpacity(0.2) + + posData = self.data[self.pos_i] + acdc_df_i = posData.allData_li[posData.frame_i]["acdc_df"] + + # Blink one pairing at the time (the first found) + xc_b = acdc_df_i.loc[budIDs[0], "x_centroid"] + yc_b = acdc_df_i.loc[budIDs[0], "y_centroid"] + + xc_m = acdc_df_i.loc[mothIDs[0], "x_centroid"] + yc_m = acdc_df_i.loc[mothIDs[0], "y_centroid"] + + self.warnPairingItem.setData([xc_b, xc_m], [yc_b, yc_m]) + + self.blinkPairingItemTimer = QTimer() + self.blinkPairingItemTimer.flag = True + self.blinkPairingItemTimer.timeout.connect(self.blinkPairingItem) + self.blinkPairingItemTimer.start(300) + + def startCcaIntegrityCheckerWorker(self): + if not hasattr(self, "data"): + return + + if not self.isDataLoaded: + return + + if not self.ccaIntegrCheckerToggle.isChecked(): + return + + ccaCheckerThread = QThread() + self.ccaCheckerMutex = QMutex() + self.ccaCheckerWaitCond = QWaitCondition() + + worker = workers.CcaIntegrityCheckerWorker( + self.ccaCheckerMutex, self.ccaCheckerWaitCond + ) + self.ccaIntegrityCheckerWorker = worker + self.ccaCheckerThread = ccaCheckerThread + + worker.moveToThread(ccaCheckerThread) + worker.finished.connect(ccaCheckerThread.quit) + worker.finished.connect(worker.deleteLater) + ccaCheckerThread.finished.connect(ccaCheckerThread.deleteLater) + + worker.sigDone.connect(self.ccaCheckerWorkerDone) + worker.progress.connect(self.workerProgress) + worker.critical.connect(self.ccaIntegrityWorkerCritical) + worker.finished.connect(self.ccaCheckerWorkerClosed) + worker.sigWarning.connect(self.warnCcaIntegrity) + worker.sigFixWillDivide.connect(self.fixWillDivide) + + ccaCheckerThread.started.connect(worker.run) + ccaCheckerThread.start() + + self.ccaCheckerRunning = True + + self.initCcaIntegrityChecker() + + self.logger.info("Cell cycle annotations integrity checker started.") + + def stopBlinkingPairItem(self): + self.ax1_newMothBudLinesItem.setOpacity(1.0) + self.ax1_oldMothBudLinesItem.setOpacity(1.0) + + self.warnPairingItem.setData([], []) + try: + self.blinkPairingItemTimer.stop() + except Exception as e: + pass + + def stopCcaIntegrityCheckerWorker(self): + try: + self.ccaIntegrityCheckerWorker._stop() + except Exception as err: + pass + + def storeFromConcatCcaDf(self, global_cca_df): + posData = self.data[self.pos_i] + for frame_i in range(0, posData.SizeT): + try: + cca_df = global_cca_df.loc[frame_i] + except KeyError as err: + break + + self.store_cca_df(frame_i=frame_i, cca_df=cca_df, autosave=False) + + self.get_cca_df() + + def store_cca_df( + self, + pos_i=None, + frame_i=None, + cca_df=None, + mainThread=True, + autosave=True, + store_cca_df_copy=False, + ): + pos_i = self.pos_i if pos_i is None else pos_i + posData = self.data[pos_i] + i = posData.frame_i if frame_i is None else frame_i + if cca_df is None: + cca_df = posData.cca_df + if self.ccaTableWin is not None and mainThread: + zoomIDs = self.getZoomIDs() + self.ccaTableWin.updateTable(posData.cca_df, IDs=zoomIDs) + + acdc_df = posData.allData_li[i]["acdc_df"] + if acdc_df is None: + current_frame_i = None + if frame_i is not None and frame_i != posData.frame_i: + current_frame_i = posData.frame_i + posData.frame_i = frame_i + self.get_data() + self.store_data() + acdc_df = posData.allData_li[i]["acdc_df"] + if current_frame_i is not None: + # Back to current frame + posData.frame_i = current_frame_i + self.get_data(debug=False) + + if "cell_cycle_stage" in acdc_df.columns: + # Cell cycle info already present --> overwrite with new + acdc_df[self.cca_df_colnames] = cca_df[self.cca_df_colnames] + posData.allData_li[i]["acdc_df"] = acdc_df + elif cca_df is not None: + df = acdc_df.drop(cca_df.columns, axis=1, errors="ignore") + df = df.join(cca_df, how="left") + posData.allData_li[i]["acdc_df"] = df + + # Store copy for cca integrity worker + self.store_cca_df_checker(posData, i, cca_df) + + if store_cca_df_copy and cca_df is not None: + posData.allData_li[i]["cca_df"] = cca_df.copy() + + if autosave: + self.enqAutosave() + self.enqCcaIntegrityChecker() + + def store_cca_df_checker(self, posData, frame_i, cca_df): + if not self.ccaCheckerRunning: + return + + if cca_df is None: + return + + posData.allData_li[frame_i]["cca_df_checker"] = cca_df.copy() + + def swapMothers(self, budID, otherBudID, otherMothID, mothID): + posData = self.data[self.pos_i] + + # Store cca_df for undo action + undoId = uuid.uuid4() + self.storeUndoRedoCca(posData.frame_i, posData.cca_df, undoId) + + self.logger.info( + f"Swapping assignments (requested at frame n. {posData.frame_i + 1}):\n" + f" * Bud ID {budID} --> mother ID {otherMothID}\n" + f" * Bud ID {otherBudID} --> mother ID {mothID}" + ) + + correct_pairings = {otherBudID: mothID, budID: otherMothID} + + for correct_budID, correct_mothID in correct_pairings.items(): + posData.cca_df.at[correct_budID, "relative_ID"] = correct_mothID + posData.cca_df.at[correct_mothID, "relative_ID"] = correct_budID + posData.cca_df.at[correct_budID, "corrected_on_frame_i"] = posData.frame_i + posData.cca_df.at[correct_mothID, "corrected_on_frame_i"] = posData.frame_i + self.store_cca_df() + + # Correct past frames + corrected_budIDs_past = set() + for past_i in range(posData.frame_i - 1, -1, -1): + if len(corrected_budIDs_past) == 2: + break + + for correct_budID, correct_mothID in correct_pairings.items(): + # Get cca_df for ith frame from allData_li + cca_df_i = self.get_cca_df(frame_i=past_i, return_df=True) + + if correct_budID in corrected_budIDs_past: + continue + + if correct_budID not in cca_df_i.index: + # Bud does not exist anymore in the past + corrected_budIDs_past.add(correct_budID) + + if len(corrected_budIDs_past) < 2: + self.restoreMotherToBeforeWrongBudWasAssignedToIt( + correct_mothID, cca_df_i, past_i + ) + continue + + cca_df_i.at[correct_budID, "relative_ID"] = correct_mothID + cca_df_i.at[correct_mothID, "relative_ID"] = correct_budID + cca_df_i.at[correct_budID, "corrected_on_frame_i"] = posData.frame_i + cca_df_i.at[correct_mothID, "corrected_on_frame_i"] = posData.frame_i + + # Set mother cell cycle stage to S in case it is not + if cca_df_i.at[correct_mothID, "cell_cycle_stage"] == "G1": + cca_df_i.at[correct_mothID, "cell_cycle_stage"] = "S" + # cca_df_i.at[correct_mothID, 'generation_num'] -= 1 + + self.store_cca_df(frame_i=past_i, cca_df=cca_df_i, autosave=False) + + # Correct future frames + corrected_budIDs_future = set() + for future_i in range(posData.frame_i + 1, posData.SizeT): + if len(corrected_budIDs_future) == 2: + break + + # Get cca_df for ith frame from allData_li + cca_df_i = self.get_cca_df(frame_i=future_i, return_df=True) + if cca_df_i is None: + # ith frame was not visited yet + break + + for correct_budID, correct_mothID in correct_pairings.items(): + if correct_budID in corrected_budIDs_future: + # Bud already corrected in the future + continue + + if correct_budID not in cca_df_i.index: + # Bud disappeared in the future + corrected_budIDs_future.add(correct_budID) + continue + + ccs_bud = cca_df_i.at[correct_budID, "cell_cycle_stage"] + if ccs_bud == "G1": + # Bud divided in the future, annotate division between + # correct mother and wrong bud and then stop correcting + if correct_budID not in corrected_budIDs_future: + corrected_budIDs_future.add(correct_budID) + + if len(corrected_budIDs_future) < 2: + self.annotateDivisionFutureFramesSwapMothers( + cca_df_i, correct_mothID, future_i + ) + continue + + cca_df_i.at[correct_budID, "relative_ID"] = correct_mothID + cca_df_i.at[correct_mothID, "relative_ID"] = correct_budID + cca_df_i.at[correct_budID, "corrected_on_frame_i"] = posData.frame_i + cca_df_i.at[correct_mothID, "corrected_on_frame_i"] = posData.frame_i + + # Set mother cell cycle stage to S in case it is not + if cca_df_i.at[correct_mothID, "cell_cycle_stage"] == "G1": + cca_df_i.at[correct_mothID, "cell_cycle_stage"] = "S" + # cca_df_i.at[correct_mothID, 'generation_num'] -= 1 + + self.store_cca_df(frame_i=future_i, cca_df=cca_df_i, autosave=False) + + self.updateAllImages() + + def undoBudMothAssignment(self, ID): + posData = self.data[self.pos_i] + relID = posData.cca_df.at[ID, "relative_ID"] + ccs = posData.cca_df.at[ID, "cell_cycle_stage"] + if ccs == "G1": + return + posData.cca_df.at[ID, "relative_ID"] = -1 + posData.cca_df.at[ID, "generation_num"] = 2 + posData.cca_df.at[ID, "cell_cycle_stage"] = "G1" + posData.cca_df.at[ID, "relationship"] = "mother" + if relID in posData.cca_df.index: + posData.cca_df.at[relID, "relative_ID"] = -1 + posData.cca_df.at[relID, "generation_num"] = 2 + posData.cca_df.at[relID, "cell_cycle_stage"] = "G1" + posData.cca_df.at[relID, "relationship"] = "mother" + + obj_idx = posData.IDs.index(ID) + relObj_idx = posData.IDs.index(relID) + rp_ID = posData.rp[obj_idx] + rp_relID = posData.rp[relObj_idx] + + self.store_cca_df() + + # Update cell cycle info LabelItems + self.setAllTextAnnotations() + + if self.ccaTableWin is not None: + zoomIDs = self.getZoomIDs() + self.ccaTableWin.updateTable(posData.cca_df, IDs=zoomIDs) + + def undoDivisionAnnotation(self, cca_df, ID, relID): + # Correct as follows: + # If G1 then correct to S and -1 on generation number + store = False + cca_df.at[ID, "cell_cycle_stage"] = "S" + gen_num_clickedID = cca_df.at[ID, "generation_num"] + cca_df.at[ID, "generation_num"] -= 1 + cca_df.at[ID, "division_frame_i"] = -1 + cca_df.at[relID, "cell_cycle_stage"] = "S" + gen_num_relID = cca_df.at[relID, "generation_num"] + cca_df.at[relID, "generation_num"] -= 1 + cca_df.at[relID, "division_frame_i"] = -1 + if gen_num_clickedID < gen_num_relID: + cca_df.at[ID, "relationship"] = "bud" + else: + cca_df.at[relID, "relationship"] = "bud" + cca_df.at[ID, "will_divide"] = 0 + cca_df.at[relID, "will_divide"] = 0 + store = True + return store + + def unstore_cca_df(self): + posData = self.data[self.pos_i] + acdc_df = posData.allData_li[posData.frame_i]["acdc_df"] + for col in self.cca_df_colnames: + if col not in acdc_df.columns: + continue + acdc_df.drop(col, axis=1, inplace=True) + + def updateCcaDfDeletedIDsTimelapse( + self, posData, relIDsOfDelIDs, deletedIDs, undoId, dropInPast, dropInFuture + ): + # Get status of the relIDs (of deleted IDs) to restore + relIDsCcaStatus = {} + for relID in relIDsOfDelIDs: + try: + ccs = posData.cca_df.at[relID, "cell_cycle_stage"] + relationship = posData.cca_df.at[relID, "relationship"] + except Exception as err: + continue + + ccaStatus = core.getBaseCca_df([relID]).loc[relID] + if relationship == "mother" and ccs == "S": + for past_frame_i in range(posData.frame_i - 1, -1, -1): + cca_df_i = self.get_cca_df(frame_i=past_frame_i, return_df=True) + ccs_past = cca_df_i.at[relID, "cell_cycle_stage"] + if ccs_past == "G1": + ccaStatus = cca_df_i.loc[relID] + break + + posData.cca_df.loc[relID] = ccaStatus + self.store_data(autosave=False) + relIDsCcaStatus[relID] = ccaStatus + + for fut_frame_i in range(posData.frame_i + 1, posData.SizeT): + cca_df_i = self.get_cca_df(frame_i=fut_frame_i, return_df=True) + if cca_df_i is None: + # ith frame was not visited yet + break + + self.storeUndoRedoCca(fut_frame_i, cca_df_i, undoId) + + if dropInFuture: + cca_df_i = cca_df_i.drop(deletedIDs, errors="ignore") + else: + for delID in deletedIDs: + dataDict = posData.allData_li[fut_frame_i] + delIDexists = dataDict["IDs_idxs"].get(delID, False) + if not delIDexists: + continue + + cca_df_i.loc[delID] = core.getBaseCca_df([delID]).loc[delID] + + areRelIDsPresent = False + for relID in relIDsOfDelIDs: + try: + ccs = cca_df_i.at[relID, "cell_cycle_stage"] + relationship = cca_df_i.at[relID, "relationship"] + ccaStatus = relIDsCcaStatus[relID] + cca_df_i.loc[relID] = ccaStatus + areRelIDsPresent = True + except Exception as err: + continue + + if not areRelIDsPresent: + break + + self.store_cca_df(frame_i=fut_frame_i, cca_df=cca_df_i, autosave=False) + + # Correct past frames + for past_frame_i in range(posData.frame_i - 1, -1, -1): + cca_df_i = self.get_cca_df(frame_i=past_frame_i, return_df=True) + if cca_df_i is None: + # ith frame was not visited yet + break + + self.storeUndoRedoCca(past_frame_i, cca_df_i, undoId) + if dropInPast: + cca_df_i = cca_df_i.drop(deletedIDs, errors="ignore") + else: + for delID in deletedIDs: + dataDict = posData.allData_li[past_frame_i] + delIDexists = dataDict["IDs_idxs"].get(delID, False) + if not delIDexists: + continue + + cca_df_i.loc[delID] = core.getBaseCca_df([delID]).loc[delID] + + areRelIDsPresent = False + for relID in relIDsOfDelIDs: + try: + ccs = cca_df_i.at[relID, "cell_cycle_stage"] + relationship = cca_df_i.at[relID, "relationship"] + ccaStatus = relIDsCcaStatus[relID] + cca_df_i.loc[relID] = ccaStatus + areRelIDsPresent = True + except Exception as err: + continue + + if not areRelIDsPresent: + break + + self.store_cca_df(frame_i=past_frame_i, cca_df=cca_df_i, autosave=False) + + def updateIsHistoryKnown(): + """ + This function is called every time the user saves and it is used + for updating the status of cells where we don't know the history + + There are three possibilities: + + 1. The cell with unknown history is a BUD + --> we don't know when that bud emerged --> 'emerg_frame_i' = -1 + 2. The cell with unknown history is a MOTHER cell + --> we don't know emerging frame --> 'emerg_frame_i' = -1 + AND generation number --> we start from 'generation_num' = 2 + 3. The cell with unknown history is a CELL in G1 + --> we don't know emerging frame --> 'emerg_frame_i' = -1 + AND generation number --> we start from 'generation_num' = 2 + AND relative's ID in the previous cell cycle --> 'relative_ID' = -1 + """ + pass + + def update_cca_df_deletedIDs( + self, posData, deletedIDs, dropInPast=True, dropInFuture=True + ): + if posData.cca_df is None: + return + + # Store cca_df for undo action + undoId = uuid.uuid4() + self.storeUndoRedoCca(posData.frame_i, posData.cca_df, undoId) + + try: + relIDs = posData.cca_df.reindex(deletedIDs, fill_value=-1)["relative_ID"] + except KeyError as err: + return + + posData.cca_df = posData.cca_df.drop(deletedIDs, errors="ignore") + if self.isSnapshot: + self.update_cca_df_newIDs(posData, relIDs) + else: + self.updateCcaDfDeletedIDsTimelapse( + posData, relIDs, deletedIDs, undoId, dropInPast, dropInFuture + ) + + def update_cca_df_newIDs(self, posData, new_IDs): + for newID in new_IDs: + self.addIDBaseCca_df(posData, newID) + + def update_cca_df_relabelling(self, posData, oldIDs, newIDs): + relIDs = posData.cca_df["relative_ID"] + posData.cca_df["relative_ID"] = relIDs.replace(oldIDs, newIDs) + mapper = dict(zip(oldIDs, newIDs)) + posData.cca_df = posData.cca_df.rename(index=mapper) + + def update_cca_df_snapshots(self, editTxt, posData): + cca_df = posData.cca_df + cca_df_IDs = cca_df.index + if editTxt == "Delete ID": + deleted_IDs = [ID for ID in cca_df_IDs if ID not in posData.IDs] + self.update_cca_df_deletedIDs(posData, deleted_IDs) + + elif editTxt == "Separate IDs": + new_IDs = [ID for ID in posData.IDs if ID not in cca_df_IDs] + self.update_cca_df_newIDs(posData, new_IDs) + deleted_IDs = [ID for ID in cca_df_IDs if ID not in posData.IDs] + self.update_cca_df_deletedIDs(posData, deleted_IDs) + + elif editTxt == "Edit ID": + new_IDs = [ID for ID in posData.IDs if ID not in cca_df_IDs] + self.update_cca_df_newIDs(posData, new_IDs) + old_IDs = [ID for ID in cca_df_IDs if ID not in posData.IDs] + self.update_cca_df_deletedIDs(posData, old_IDs) + + elif editTxt == "Annotate ID as dead": + return + + elif editTxt == "Deleted non-selected objects": + deleted_IDs = [ID for ID in cca_df_IDs if ID not in posData.IDs] + self.update_cca_df_deletedIDs(posData, deleted_IDs) + + elif editTxt == "Delete ID with eraser": + deleted_IDs = [ID for ID in cca_df_IDs if ID not in posData.IDs] + self.update_cca_df_deletedIDs(posData, deleted_IDs) + + elif editTxt == "Add new ID with brush tool": + new_IDs = [ID for ID in posData.IDs if ID not in cca_df_IDs] + self.update_cca_df_newIDs(posData, new_IDs) + + elif editTxt == "Merge IDs": + deleted_IDs = [ID for ID in cca_df_IDs if ID not in posData.IDs] + self.update_cca_df_deletedIDs(posData, deleted_IDs) + + elif editTxt == "Add new ID with curvature tool": + new_IDs = [ID for ID in posData.IDs if ID not in cca_df_IDs] + self.update_cca_df_newIDs(posData, new_IDs) + + elif editTxt == "Delete IDs using ROI": + deleted_IDs = [ID for ID in cca_df_IDs if ID not in posData.IDs] + self.update_cca_df_deletedIDs(posData, deleted_IDs) + + elif editTxt == "Repeat segmentation": + posData.cca_df = self.getBaseCca_df() + + def viewCcaTable(self): + posData = self.data[self.pos_i] + zoomIDs = self.getZoomIDs() + + df = posData.allData_li[posData.frame_i]["acdc_df"] + current_cca_df = posData.cca_df + if zoomIDs is not None: + df = df.loc[zoomIDs] + current_cca_df = current_cca_df.loc[zoomIDs] + + for column in current_cca_df.columns: + header = ( + "================================================\n" + f"CURRENT vs STORED `{column}` column" + f"for frame number {posData.frame_i + 1}:\n" + ) + df_compare = current_cca_df[[column]].copy() + df_compare[f"STORED_{column}"] = df[column] + text = f"{header}{df_compare}" + self.logger.info(text) + + if "cell_cycle_stage" in df.columns: + cca_df = df[self.cca_df_colnames] + cca_df = cca_df.merge( + current_cca_df, + how="outer", + left_index=True, + right_index=True, + suffixes=("_STORED", "_CURRENT"), + ) + cca_df = cca_df.reindex(sorted(cca_df.columns), axis=1) + num_cols = len(cca_df.columns) + for j in range(0, num_cols, 2): + df_j_x = cca_df.iloc[:, j] + df_j_y = cca_df.iloc[:, j + 1] + if any(df_j_x != df_j_y): + self.logger.info("------------------------") + self.logger.info("DIFFERENCES:") + diff_df = cca_df.iloc[:, j : j + 2] + diff_mask = diff_df.iloc[:, 0] != diff_df.iloc[:, 1] + self.logger.info(diff_df[diff_mask]) + else: + cca_df = None + self.logger.info(cca_df) + self.logger.info("========================") + if current_cca_df is None: + return + if current_cca_df.empty: + msg = widgets.myMessageBox() + txt = html_utils.paragraph( + "Cell cycle annotations' table is empty.
    " + ) + msg.warning(self, "Table empty", txt) + return + + df = posData.add_tree_cols_to_cca_df(current_cca_df, frame_i=posData.frame_i) + if self.ccaTableWin is None: + self.ccaTableWin = apps.ViewCcaTableWindow(df, parent=self) + self.ccaTableWin.show() + self.ccaTableWin.setGeometryWindow() + self.ccaTableWin.sigUpdateCcaTable.connect(self.onSigUpdateCcaTableWindow) + else: + self.ccaTableWin.setFocus() + self.ccaTableWin.activateWindow() + self.ccaTableWin.updateTable(current_cca_df) + + def warnBudAnnotatedDividedInFuture( + self, budID, motherID, future_division_frame_i, action="swap mother cells" + ): + posData = self.data[self.pos_i] + + txt = html_utils.paragraph(f""" + Bud ID {budID} is annotated as divided from mother ID {motherID} + at frame n. {future_division_frame_i + 1},
    + therefore it is not possible to {action}.

    + We recommend reinitializing cell cycle annotations on any + frame
    between frames number {posData.frame_i + 1} and + {future_division_frame_i} before attempting to {action}.

    + Thank you for your patience! + """) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, f"{action} not possible".title(), txt) + return + + def warnCcaIntegrity(self, txt, category): + self.logger.warning(f"{html_utils.to_plain_text(txt)}") + + if "disable_all" in self.disabled_cca_warnings: + return + + if category in self.disabled_cca_warnings: + return + + if txt in self.disabled_cca_warnings: + return + + if self.isWarningCcaIntegrity: + # Some other warning is still open --> avoid opening another one + return + + self.isWarningCcaIntegrity = True + disabled_warning = _warnings.warn_cca_integrity( + txt, category, self, go_to_frame_callback=self.goToFrameNumber + ) + if disabled_warning: + self.disabled_cca_warnings.add(disabled_warning) + + self.isWarningCcaIntegrity = False + + def warnDeadOrExcludedMothers(self, budIDs, mothIDs): + self.startBlinkingPairingItem(budIDs, mothIDs) + msg = widgets.myMessageBox(wrapText=False) + pairings = [ + f"Mother ID {mID} --> bud ID {bID}" for mID, bID in zip(mothIDs, budIDs) + ] + txt = html_utils.paragraph(f""" + The mother cell in the following mother-bud pairings + (blinking line on the image) is
    + excluded from the analysis or dead: + {html_utils.to_list(pairings)} + """) + msg.warning( + self, "Mother cell is excluded or dead", txt, buttonsTexts=("Cancel", "Ok") + ) + return not msg.cancel + + def warnEditingWithCca_df( + self, + editTxt, + return_answer=False, + get_answer=False, + get_cancelled=False, + update_images=True, + ): + # Function used to warn that the user is editing in "Segmentation and + # Tracking" mode a frame that contains cca annotations. + # Ask whether to remove annotations from all future frames + if self.isSnapshot: + return True + + posData = self.data[self.pos_i] + acdc_df = posData.allData_li[posData.frame_i]["acdc_df"] + + if acdc_df is None and self.lineage_tree is None: + if update_images: + self.updateAllImages() + return True + + cell_cycle_stage_present = ( + acdc_df is not None and "cell_cycle_stage" in acdc_df.columns + ) + lineage_tree_present = ( + self.lineage_tree is not None or "parent_ID_tree" in acdc_df.columns + ) + if not cell_cycle_stage_present and not lineage_tree_present: + if update_images: + self.updateAllImages() + return True + + action = self.warnEditingWithAnnotActions.get(editTxt, None) + if action is not None and not action.isChecked(): + # user has checked that he does not want to be asked again AND he doesnt want to delete + if update_images: + self.updateAllImages() + return True + + msg = widgets.myMessageBox() + warn_type = ( + "cell cycle annotations" + if cell_cycle_stage_present + else "lineage tree annotations" + ) + txt = html_utils.paragraph( + f"You modified a frame that has {warn_type}.

    " + f'The change "{editTxt}" most likely makes the ' + "annotations wrong.

    " + "If you really want to apply this change we reccommend to remove" + f"ALL {warn_type}
    " + "from current frame to the end.

    " + "What do you want to do?" + ) + if action is not None: + checkBox = QCheckBox("Remember my choice and do not ask again") + else: + checkBox = None + + dropDelIDsNoteText = ( + "" if editTxt.find("Delete") == -1 else " (drop removed IDs)" + ) + _, removeAnnotButton, _ = msg.warning( + self, + "Edited segmentation with annotations!", + txt, + buttonsTexts=( + "Cancel", + "Remove annotations from future frames (RECOMMENDED)", + f"Do not remove annotations{dropDelIDsNoteText}", + ), + widgets=checkBox, + ) + if msg.cancel: + if get_cancelled: + return "cancelled" + removeAnnotations = False + return removeAnnotations + + if action is not None: + action.setChecked(not checkBox.isChecked()) + action.removeAnnot = msg.clickedButton == removeAnnotButton + + if return_answer: + return msg.clickedButton == removeAnnotButton + + if (msg.clickedButton == removeAnnotButton) and cell_cycle_stage_present: + self.resetFutureCcaColCurrentFrame() + self.resetCcaFuture(posData.frame_i + 1) + self.updateAllImages() + elif (msg.clickedButton == removeAnnotButton) and lineage_tree_present: + self.resetLin_tree_future() + self.updateAllImages() + else: + if dropDelIDsNoteText and posData.cca_df is not None: + delIDs = [ID for ID in posData.cca_df.index if ID not in posData.IDs] + self.update_cca_df_deletedIDs(posData, delIDs, dropInPast=False) + self.addMissingIDs_cca_df(posData) + self.updateAllImages() + self.store_data() + # if action is not None: + # if action.removeAnnot: + # self.store_data() + # posData.frame_i -= 1 + # self.get_data() + # if lineage_tree_present: + # self.resetLin_tree_future() + # self.resetCcaFuture(posData.frame_i) + # self.next_frame() + + if get_answer: + return msg.clickedButton == removeAnnotButton + else: + return True + + def warnFrameNeverVisitedSegmMode(self): + msg = widgets.myMessageBox() + warn_cca = msg.critical( + self, + "Next frame NEVER visited", + 'Next frame was never visited in "Segmentation and Tracking"' + "mode.\n You cannot perform cell cycle analysis on frames" + "where segmentation and/or tracking errors were not" + "checked/corrected.\n\n" + 'Switch to "Segmentation and Tracking" mode ' + "and check/correct next frame,\n" + "before attempting cell cycle analysis again", + ) + return False + + def warnMotherNotAtLeastOneFrameG1(self, budID, motherID, frame_no_G1): + posData = self.data[self.pos_i] + + txt = html_utils.paragraph(f""" + Assigning bud ID {budID} to cell ID {motherID} cannot be + done because cell ID {motherID} is not in G1 at frame n. + {frame_no_G1}.

    + This would result in no G1 phase between previous cell cycle of + cell ID {motherID} and current one. + This is unfortunately not allowed.

    + One possible solution is to annotate division on cell ID + {motherID} on any frame before frame n. {frame_no_G1}.

    + Thank you for your patience! + """) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "Swap mothers not possible", txt) + return + + def warnMotherNotEligible(self, new_mothID, budID, i, why): + if why == "not_G1_in_the_future": + err_msg = html_utils.paragraph(f""" + The requested cell in G1 (ID={new_mothID}) + at future frame {i + 1} has a bud assigned to it, + therefore it cannot be assigned as the mother + of bud ID {budID}.

    + You can assign a cell as the mother of bud ID {budID} + only if this cell is in G1 for the + entire life of the bud.

    + One possible solution is to click on "cancel", go to + frame {i + 1} and assign the bud of cell {new_mothID} + to another cell.\n' + A second solution is to assign bud ID {budID} to cell + {new_mothID} anyway by clicking "Apply".

    + However to ensure correctness of + future assignments Cell-ACDC will delete any cell cycle + information from frame {i + 1} to the end. Therefore, you + will have to visit those frames again.

    + The deletion of cell cycle information + CANNOT BE UNDONE! + Saved data is not changed of course.

    + Apply assignment or cancel process? + """) + applyButton = widgets.okPushButton(isDefault=False) + applyButton.setText("Apply and remove future annotations") + msg = widgets.myMessageBox() + _, applyButton = msg.warning( + self, "Cell not eligible", err_msg, buttonsTexts=("Cancel", applyButton) + ) + cancel = msg.cancel + apply = msg.clickedButton == applyButton + elif why == "not_G1_in_the_past": + err_msg = html_utils.paragraph(f""" + The requested cell in G1 + (ID={new_mothID}) at past frame {i + 1} + has a bud assigned to it, therefore it cannot be + assigned as mother of bud ID {budID}.
    + You can assign a cell as the mother of bud ID {budID} + only if this cell is in G1 for the entire life of the bud.
    + One possible solution is to first go to frame {i + 1} and + assign the bud of cell {new_mothID} to another cell. + """) + msg = widgets.myMessageBox() + msg.warning(self, "Cell not eligible", err_msg) + cancel = msg.cancel + apply = False + elif why == "single_frame_G1_duration": + err_msg = html_utils.paragraph(f""" + Assigning bud ID {budID} to cell ID {new_mothID} would result + in no G1 phase at all between previous cell cycle and + current cell cycle (see frame n. {i + 1}).

    + + The solution is to annotate division on cell ID {new_mothID} + on any frame before the frame number {i + 1}, and then + proceed to correcting the bud assignment.

    + + This will gurantee a G1 duration for the cell {new_mothID} + of at least 1 frame.

    + Thank you for your patience! + """) + msg = widgets.myMessageBox() + msg.warning(self, "Cell not eligible", err_msg) + cancel = msg.cancel + apply = False + return cancel, apply + + def warnScellsGone(self, ScellsIDsGone, frame_i): + msg = widgets.myMessageBox() + text = html_utils.paragraph(f""" + In the next frame the followning cells' IDs in S/G2/M + (highlighted with a yellow contour) will disappear:

    + {ScellsIDsGone}

    + If the cell does not exist you might have deleted it at some point. + If that's the case, then try to go to some previous frames and reset + the cell cycle annotations there (button on the top toolbar).

    + These cells are either buds or mother whose related IDs will not + disappear. This is likely due to cell division happening in + previous frame and the divided bud or mother will be + washed away.

    + If you decide to continue these cells will be automatically + annotated as divided at frame number {frame_i}.

    + Do you want to continue? + """) + _, yesButton, noButton = msg.warning( + self, + 'Cells in "S/G2/M" disappeared!', + text, + buttonsTexts=("Cancel", "Yes", "No"), + ) + return msg.clickedButton == yesButton + + def warnSettingHistoryKnownCellsFirstFrame(self, ID): + txt = html_utils.paragraph(f""" + Cell ID {ID} is a cell that is present since the first + frame.

    + These cells already have history UNKNOWN assigned and the + history status cannot be changed. + """) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "First frame cells", txt) + + def getStatusKnownHistoryBud(self, ID): + posData = self.data[self.pos_i] + cca_df_ID = None + for i in range(posData.frame_i - 1, -1, -1): + cca_df_i = self.get_cca_df(frame_i=i, return_df=True) + is_cell_existing = is_bud_existing = ID in cca_df_i.index + if not is_cell_existing: + bud_cca_dict = base_cca_dict.copy() + bud_cca_dict["cell_cycle_stage"] = "S" + bud_cca_dict["generation_num"] = 0 + bud_cca_dict["relationship"] = "bud" + bud_cca_dict["emerg_frame_i"] = i + 1 + bud_cca_dict["is_history_known"] = True + cca_df_ID = pd.Series(bud_cca_dict) + return cca_df_ID + + def setHistoryKnowledge(self, ID, cca_df): + posData = self.data[self.pos_i] + is_history_known = cca_df.at[ID, "is_history_known"] + if is_history_known: + cca_df.at[ID, "is_history_known"] = False + cca_df.at[ID, "cell_cycle_stage"] = "G1" + cca_df.at[ID, "generation_num"] += 2 + cca_df.at[ID, "emerg_frame_i"] = -1 + cca_df.at[ID, "relative_ID"] = -1 + cca_df.at[ID, "relationship"] = "mother" + else: + cca_df.loc[ID] = posData.ccaStatus_whenEmerged[ID] + + def annotateDivisionFutureFramesSwapMothers( + self, cca_df_at_future_division, mothIDofDisappearedBud, frame_i + ): + """This method is called as part of `guiWin.swapMothers`. + + It annotates cell division and propagates that to future frames to the + mother cell that stops having the correct bud because division between + wrong bud and other wrong mother was annotated in the future. + + Parameters + ---------- + cca_df_at_future_division : pd.DataFrame + _description_ + mothIDofDisappearedBud : int + Mother ID of the disappeared bud + frame_i : int + Frame since when the mother ID stops having the correct bud because + the correct bud was assigned as divided from the wrong mother + """ + posData = self.data[self.pos_i] + + relativeIDofMothID = cca_df_at_future_division.at[ + mothIDofDisappearedBud, "relative_ID" + ] + if relativeIDofMothID not in cca_df_at_future_division.index: + # Also wrong bud ID disappeared + return + + relativeIDofMothIDrelationship = cca_df_at_future_division.at[ + relativeIDofMothID, "relationship" + ] + if relativeIDofMothIDrelationship != "bud": + # The wrong bud ID is a cell in G1 from future cycle --> + # the actual wrong bud ID disappeared too. + return + + wrongBudID = relativeIDofMothID + + self.annotateDivision( + cca_df_at_future_division, + mothIDofDisappearedBud, + wrongBudID, + frame_i=frame_i, + ) + cca_df_at_future_division.at[mothIDofDisappearedBud, "corrected_on_frame_i"] = ( + frame_i + ) + self.store_cca_df( + frame_i=frame_i, cca_df=cca_df_at_future_division, autosave=False + ) + + ccaStatusToRestore = cca_df_at_future_division.loc[mothIDofDisappearedBud] + for future_i in range(frame_i + 1, posData.SizeT): + # Get cca_df for ith frame from allData_li + cca_df_i = self.get_cca_df(frame_i=future_i, return_df=True) + if cca_df_i is None: + # ith frame was not visited yet + break + + ccs = cca_df_i.at[mothIDofDisappearedBud, "cell_cycle_stage"] + if ccs == "G1": + # Mother cell in G1 again, stop correcting + break + + cca_df_i.loc[mothIDofDisappearedBud] = ccaStatusToRestore + cca_df_i.at[mothIDofDisappearedBud, "corrected_on_frame_i"] = frame_i + + self.store_cca_df(frame_i=future_i, cca_df=cca_df_i, autosave=False) + + def getStatus_RelID_BeforeEmergence(self, budID, curr_mothID): + posData = self.data[self.pos_i] + # Get status of the current mother before it had budID assigned to it + cca_status_before_bud_emerg = None + for i in range(posData.frame_i - 1, -1, -1): + # Get cca_df for ith frame from allData_li + cca_df_i = self.get_cca_df(frame_i=i, return_df=True) + + is_bud_existing = budID in cca_df_i.index + if not is_bud_existing: + # Bud was not emerged yet + if curr_mothID in cca_df_i.index: + cca_status_before_bud_emerg = cca_df_i.loc[curr_mothID] + return cca_status_before_bud_emerg + else: + # The bud emerged together with the mother because + # they appeared together from outside of the fov + # and they were trated as new IDs bud in S0 + bud_cca_dict = base_cca_dict.copy() + bud_cca_dict["cell_cycle_stage"] = "S" + bud_cca_dict["generation_num"] = 0 + bud_cca_dict["relationship"] = "bud" + bud_cca_dict["emerg_frame_i"] = i + 1 + bud_cca_dict["is_history_known"] = True + cca_status_before_bud_emerg = pd.Series(bud_cca_dict) + return cca_status_before_bud_emerg + + # Mother did not have a status before bud emergence because it was + # already paired with bud at first frame --> reinit to default + cca_status_before_bud_emerg = core.getBaseCca_df([curr_mothID]).loc[curr_mothID] + return cca_status_before_bud_emerg + + def _checkBudFutureNoDivision(self, budID, start_frame_i): + posData = self.data[self.pos_i] + + future_i = start_frame_i + for future_i in range(start_frame_i, posData.SizeT): + if future_i == 0: + continue + + # Get cca_df for ith frame from allData_li + cca_df_i = self.get_cca_df(frame_i=future_i, return_df=True) + if cca_df_i is None: + # ith frame was not visited yet + return + + if budID not in cca_df_i.index: + # Bud disappears in the future --> fine + return + + ccs = cca_df_i.at[budID, "cell_cycle_stage"] + if ccs == "G1": + return future_i, cca_df_i.at[budID, "relative_ID"] + + def _checkMothInG1beforeBudEmergence( + self, motherID, budID, wrongBudID, start_frame_i + ): + """Check that mother is in G1 on the frame before bud emergence + + Parameters + ---------- + motherID : int + ID of mother cell + budID : int + ID of bud + start_frame_i : int + Frame index from which to start checking in the past + """ + for past_i in range(start_frame_i, -1, -1): + cca_df_i = self.get_cca_df(frame_i=past_i, return_df=True) + if budID not in cca_df_i.index: + if cca_df_i.at[motherID, "cell_cycle_stage"] == "G1": + return + + budID_prev_cycle = cca_df_i.at[motherID, "relative_ID"] + if budID_prev_cycle != wrongBudID: + return past_i + 1 + + break + + def restoreMotherToBeforeWrongBudWasAssignedToIt( + self, mothIDofDisappearedBud, cca_df_at_correct_bud_ID_disappearance, frame_i + ): + """This method is called as part of `guiWin.swapMothers`. + + Parameters + ---------- + mothIDofDisappearedBud : int + Mother ID of the disappeared bud + cca_df_at_correct_bud_ID_disappearance : pd.DataFrame + Cell cycle annotations DataFrame when the correct bud ID stopped + existing (before emergence) + frame_i : int + Frame index when the correct bud ID stopped existing + (before emergence) + + Note + ---- + It restores the mother cell cycle annotations to the status it had + before the wrong bud was assigned to it. + + We need to do it only if the swapMothers past frames loop is still + iterating to correct the other bud. + + We also need to do this only if the wrong bud ID is actually a bud. + + When we swap mothers in the past frames it can be that the correct bud + ID stops existing (before emergence). In this case the correct mother + still has the wrong bud assigned to ID so we need to restore the status + it had before the wrong bud was assigned to it. + + To determine the status we go back until the wrong bud disappear. That + is the frame before the wrong bud was assigned to the mother we want to + correct. This is the status we want to restore. + + When we go back in time it could be that the wrong bud never disappears + becuase it is already emerged at frame 0. In this case the status we + want to restore at is the default G1 status at frame 0. + """ + relativeIDofMothID = cca_df_at_correct_bud_ID_disappearance.at[ + mothIDofDisappearedBud, "relative_ID" + ] + if relativeIDofMothID not in cca_df_at_correct_bud_ID_disappearance.index: + # Also wrong bud ID disappeared + return + + relativeIDofMothIDrelationship = cca_df_at_correct_bud_ID_disappearance.at[ + relativeIDofMothID, "relationship" + ] + if relativeIDofMothIDrelationship != "bud": + # The wrong bud ID is a cell in G1 from previous cycle --> + # the actual wrong bud ID disappeared too. + return + + wrongBudID = relativeIDofMothID + + mothCcaBeforeWrongBudID = base_cca_dict + # Search in the past for status of mother before wrong bud emerged + for past_i in range(frame_i, -1, -1): + cca_df_i = self.get_cca_df(frame_i=past_i, return_df=True) + if wrongBudID not in cca_df_i.index: + mothCcaBeforeWrongBudID = cca_df_i.loc[mothIDofDisappearedBud] + break + + # Restore in past frames the correct mother status + for past_i in range(frame_i, -1, -1): + cca_df_i = self.get_cca_df(frame_i=past_i, return_df=True) + if wrongBudID in cca_df_i.index: + cca_df_i.loc[mothIDofDisappearedBud] = mothCcaBeforeWrongBudID + cca_df_i.at[mothIDofDisappearedBud, "corrected_on_frame_i"] = frame_i + self.store_cca_df(frame_i=past_i, cca_df=cca_df_i, autosave=False) + else: + break + + def annotateDivisionCurrentFrameRelativeIDgone(self, IDwhoseRelativeIsGone): + posData = self.data[self.pos_i] + if posData.cca_df is None: + return + ID = IDwhoseRelativeIsGone + posData.cca_df.at[ID, "generation_num"] += 1 + posData.cca_df.at[ID, "division_frame_i"] = posData.frame_i - 1 + posData.cca_df.at[ID, "relationship"] = "mother" + + def annotateDisappearedBeforeDivision(self, relID, IDgone, cca_df, frame_i=None): + posData = self.data[self.pos_i] + gen_num = cca_df.at[relID, "generation_num"] + if frame_i is None: + frame_i = posData.frame_i + + for past_frame_i in range(frame_i - 1, -1, -1): + past_cca_df = self.get_cca_df(frame_i=past_frame_i, return_df=True) + if past_cca_df is None: + return + + try: + if past_cca_df.at[relID, "generation_num"] != gen_num: + # ID is a mother and the cell cycle is finished here + return + except Exception as err: + # Bud stops existing --> stop process + return + + past_cca_df.at[IDgone, "disappears_before_division"] = 1 + past_cca_df.at[relID, "daughter_disappears_before_division"] = 1 + + self.store_cca_df(cca_df=past_cca_df, frame_i=past_frame_i, autosave=False) diff --git a/cellacdc/gui_combine.py b/cellacdc/mixins/combine.py similarity index 69% rename from cellacdc/gui_combine.py rename to cellacdc/mixins/combine.py index 17303c837..b3d8459d1 100644 --- a/cellacdc/gui_combine.py +++ b/cellacdc/mixins/combine.py @@ -1,58 +1,75 @@ -from typing import List, Dict, Any, Tuple -from . import core, workers, widgets, html_utils, apps, preprocess, myutils, printl -from qtpy.QtCore import QThread, QTimer, QMutex, QWaitCondition -from natsort import natsorted +"""Combine channels GUI mixin extracted from gui_combine.py.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Tuple + import numpy as np -# from gui import guiWin +from natsort import natsorted +from qtpy.QtCore import QMutex, QThread, QTimer, QWaitCondition + +from cellacdc import ( + apps, + core, + html_utils, + utils, + preprocess, + printl, + widgets, + workers, +) + -class CombineGuiElements: +from .graphics import Graphics +from .preprocessing import Preprocessing + + +class CombineGui: def _setup_vars_combine(self): self.combineWorker = None self.combineDialog = None self.combineSegmViewToggle = None - + def combineDialogSaveCombinedData(self, dialog): # here check if all data has been processed? posData = self.data[self.pos_i] - + try: posData.combinedChannelsDataArray() except TypeError as e: - if 'Not all frames have been processed.' in str(e): + if "Not all frames have been processed." in str(e): msg = widgets.myMessageBox() txt = html_utils.paragraph( - 'Not all frames have been processed.
    ' - 'Please process all frames before saving.' + "Not all frames have been processed.
    " + "Please process all frames before saving." ) - msg.warning(self, 'Process all data before saving', txt) + msg.warning(self, "Process all data before saving", txt) return - helpText = ( - """ + helpText = """ The segm/img file will be saved with a different file name.

    Insert a name to append to the end of the new file name. The rest of the name will be the same as the original file base. """ - ) - hintText = 'Insert a name for the combined channels file:' + hintText = "Insert a name for the combined channels file:" basename = posData.basename if self.combineDialog.saveAsSegm(): - ext = '.npz' - hintText = hintText.replace('channels', 'segmentation') - helpText = helpText.replace('channels', 'segmentation') - basename = f'{basename}segm' + ext = ".npz" + hintText = hintText.replace("channels", "segmentation") + helpText = helpText.replace("channels", "segmentation") + basename = f"{basename}segm" else: - ext = '.tif' - + ext = ".tif" + win = apps.filenameDialog( basename=basename, ext=ext, hintText=hintText, - defaultEntry='combined', - helpText=helpText, + defaultEntry="combined", + helpText=helpText, allowEmpty=False, - parent=dialog + parent=dialog, ) win.exec_() if win.cancel: @@ -60,24 +77,25 @@ def combineDialogSaveCombinedData(self, dialog): appendedText = win.entryText if appendedText: - filename = f'{basename}_{appendedText}{ext}' + filename = f"{basename}_{appendedText}{ext}" else: - filename = f'{basename}{ext}' - + filename = f"{basename}{ext}" + self.progressWin = apps.QDialogWorkerProgress( - title='Saving combined channels(s)', + title="Saving combined channels(s)", parent=self, - pbarDesc='Saving combined channels(s)' + pbarDesc="Saving combined channels(s)", ) self.progressWin.show(self.app) self.progressWin.mainPbar.setMaximum(0) - - self.statusBarLabel.setText('Saving combined channels...') - + + self.statusBarLabel.setText("Saving combined channels...") + self.saveCombinedChannelsWorker = workers.SaveCombinedChannelsWorker( - self.data, filename, + self.data, + filename, ) - + self.saveCombinedChannelsThread = QThread() self.saveCombinedChannelsWorker.moveToThread(self.saveCombinedChannelsThread) self.saveCombinedChannelsWorker.signals.finished.connect( @@ -89,23 +107,19 @@ def combineDialogSaveCombinedData(self, dialog): self.saveCombinedChannelsThread.finished.connect( self.saveCombinedChannelsThread.deleteLater ) - - self.saveCombinedChannelsWorker.signals.critical.connect( - self.workerCritical - ) + + self.saveCombinedChannelsWorker.signals.critical.connect(self.workerCritical) self.saveCombinedChannelsWorker.signals.initProgressBar.connect( self.workerInitProgressbar ) self.saveCombinedChannelsWorker.signals.progressBar.connect( self.workerUpdateProgressbar ) - self.saveCombinedChannelsWorker.signals.progress.connect( - self.workerProgress - ) + self.saveCombinedChannelsWorker.signals.progress.connect(self.workerProgress) self.saveCombinedChannelsWorker.signals.finished.connect( self.saveCombinedChannelsWorkerFinished ) - + self.saveCombinedChannelsThread.started.connect( self.saveCombinedChannelsWorker.run ) @@ -113,155 +127,139 @@ def combineDialogSaveCombinedData(self, dialog): self.saveCombinedChannelsWorker.sigDebugShowImg.connect(self.debugShowImg) self.saveCombinedChannelsThread.start() - + def combineDialogStepsChanged(self): - steps, keep_input_data_type, formula = self.combineDialog.steps(return_keepInputDataType=True) + steps, keep_input_data_type, formula = self.combineDialog.steps( + return_keepInputDataType=True + ) if steps is None: - self.logger.warning('Combine channels recipe not initialized yet.') + self.logger.warning("Combine channels recipe not initialized yet.") return - - self.updateCombineChannelsPreview(steps=steps, keep_input_data_type=keep_input_data_type, formula=formula) + + self.updateCombineChannelsPreview( + steps=steps, keep_input_data_type=keep_input_data_type, formula=formula + ) def updateCombineChannelsPreview(self, *args, **kwargs): - force = kwargs.get('force', False) - + force = kwargs.get("force", False) + if self.combineDialog is None: return - + if not self.combineDialog.isVisible() and not force: return - + if not self.combineDialog.previewCheckbox.isChecked() and not force: return - - if kwargs.get('steps') is None: - steps, keep_input_data_type, formula = self.combineDialog.steps(return_keepInputDataType=True) + + if kwargs.get("steps") is None: + steps, keep_input_data_type, formula = self.combineDialog.steps( + return_keepInputDataType=True + ) else: - steps = kwargs.get('steps') - keep_input_data_type = kwargs.get('keep_input_data_type') - formula = kwargs.get('formula') + steps = kwargs.get("steps") + keep_input_data_type = kwargs.get("keep_input_data_type") + formula = kwargs.get("formula") if steps is None: - self.logger.warning('Combine channels recipe not initialized yet.') + self.logger.warning("Combine channels recipe not initialized yet.") return - - txt = 'Combining...' + + txt = "Combining..." self.logger.info(txt) self.statusBarLabel.setText(txt) - + self.combineEnqueueCurrentImage(steps, keep_input_data_type, formula) - + def viewCombineChannelDataToggled(self, checked): self.img1.setUseCombined(checked) - + if checked: self.combineViewAsSegmSetup() - else: # setimage1 is already called in combineViewAsSegmSetup + else: # setimage1 is already called in combineViewAsSegmSetup self.setImageImg1() if self.viewPreprocDataToggle.isChecked(): - self.viewPreprocDataToggle.toggled.disconnect() + self.viewPreprocDataToggle.toggled.disconnect() self.viewPreprocDataToggle.setChecked(False) - self.viewPreprocDataToggle.toggled.connect( - self.viewPreprocDataToggled - ) - + self.viewPreprocDataToggle.toggled.connect(self.viewPreprocDataToggled) + def setupCombiningChannels(self): posData = self.data[self.pos_i] if self.combineDialog is not None: self.combineDialog.close() - + ordered_channels = [ch for ch in posData.chNames if ch != self.user_ch_name] ordered_channels = natsorted(ordered_channels) ordered_channels = [self.user_ch_name] + ordered_channels - segmentations = [segm for segm in self.existingSegmEndNames] segmentations = natsorted(segmentations) - segmentations = ['current segm.'] + segmentations + segmentations = ["current segm."] + segmentations # also add segm ordered_channels.extend(segmentations) - + self.combineDialog = apps.CombineChannelsSetupDialogGUI( ordered_channels, - isTimelapse=posData.SizeT>1, - isZstack=posData.SizeZ>1, - isMultiPos=len(self.data)>1, + isTimelapse=posData.SizeT > 1, + isZstack=posData.SizeZ > 1, + isMultiPos=len(self.data) > 1, df_metadata=posData.metadata_df, - hideOnClosing=True, + hideOnClosing=True, # addApplyButton=True, - parent=self - ) - self.doPreviewPreprocImage = False #to do - self.combineDialog.sigApplyImage.connect( - self.combineCurrentImage - ) - self.combineDialog.sigApplyZstack.connect( - self.combineZStack - ) - self.combineDialog.sigApplyAllFrames.connect( - self.combineAllFrames - ) - self.combineDialog.sigApplyAllPos.connect( - self.combineAllPos - ) - self.combineDialog.sigPreviewToggled.connect( - self.combinePreviewToggled + parent=self, ) + self.doPreviewPreprocImage = False # to do + self.combineDialog.sigApplyImage.connect(self.combineCurrentImage) + self.combineDialog.sigApplyZstack.connect(self.combineZStack) + self.combineDialog.sigApplyAllFrames.connect(self.combineAllFrames) + self.combineDialog.sigApplyAllPos.connect(self.combineAllPos) + self.combineDialog.sigPreviewToggled.connect(self.combinePreviewToggled) self.combineDialog.sigSaveAsSegmCheckboxToggled.connect( self.combinePreviewViewAsSegmToggled ) - self.combineDialog.sigValuesChanged.connect( - self.combineDialogStepsChanged - ) + self.combineDialog.sigValuesChanged.connect(self.combineDialogStepsChanged) self.combineDialog.sigSavePreprocData.connect( self.combineDialogSaveCombinedData ) - self.combineDialog.sigClose.connect( - self.combineDialogClosed - ) + self.combineDialog.sigClose.connect(self.combineDialogClosed) if self.combineWorker is not None: return - + self.combineThread = QThread() self.combineMutex = QMutex() self.combineWaitCond = QWaitCondition() - + self.combineWorker = workers.CombineChannelsWorkerGUI( - self.combineMutex, self.combineWaitCond, + self.combineMutex, + self.combineWaitCond, logger_func=self.logger.info, # signals=self.signals # what are the singals for gui??? ) - + self.combineWorker.moveToThread(self.combineThread) self.combineWorker.signals.finished.connect(self.combineThread.quit) - self.combineWorker.signals.finished.connect( - self.combineWorker.deleteLater - ) + self.combineWorker.signals.finished.connect(self.combineWorker.deleteLater) self.combineThread.finished.connect(self.combineWorker.deleteLater) self.combineWorker.sigDone.connect(self.combineWorkerDone) - self.combineWorker.sigIsQueueEmpty.connect( - self.combineWorkerIsQueueEmpty - ) + self.combineWorker.sigIsQueueEmpty.connect(self.combineWorkerIsQueueEmpty) self.combineWorker.sigPreviewDone.connect(self.combineWorkerPreviewDone) self.combineWorker.signals.progress.connect(self.workerProgress) self.combineWorker.signals.critical.connect(self.workerCritical) self.combineWorker.signals.finished.connect(self.combineWorkerClosed) - self.combineWorker.sigAskLoadChannels.connect( - self.combineWorkerAskLoadChannels - ) - + self.combineWorker.sigAskLoadChannels.connect(self.combineWorkerAskLoadChannels) + self.combineThread.started.connect(self.combineWorker.run) self.combineThread.start() - - self.logger.info('Combine channels worker started.') - + + self.logger.info("Combine channels worker started.") + def combineDialogClosed(self, window): QTimer.singleShot(200, self._combineDialogClosed) - + def _combineDialogClosed(self): self.combineDialog = None @@ -274,17 +272,19 @@ def combineViewAsSegmSetup(self): if self.combineSegmViewToggle.isChecked(): self.combineSegmViewToggle.setChecked(False) self.combineSegmViewToggle.setCheckable(False) - + if not self.overlayLabelsButton.isChecked() and combineViewAsSegm: self.overlayLabelsButton.blockSignals(True) self.overlayLabelsButton.setChecked(True) - self.overlayLabels_cb(checked=True, selectedLabelsEndnames=['combined segm.']) + self.overlayLabels_cb( + checked=True, selectedLabelsEndnames=["combined segm."] + ) self.overlayLabelsButton.blockSignals(False) - + if combineViewAsSegm: if not self.combineSegmViewToggle.isChecked(): self.combineSegmViewToggle.setCheckable(True) - + # reset view to update the overlay labels self.combineSegmViewToggle.setChecked(False) self.combineSegmViewToggle.setChecked(True) @@ -295,22 +295,23 @@ def combineViewAsSegmSetup(self): def combineChannelsActionTriggered(self): if self.zProjComboBox is not None: curr_proj = self.zProjComboBox.currentText() - if curr_proj != 'single z-slice': - self.zProjComboBox.setCurrentText('single z-slice') - + if curr_proj != "single z-slice": + self.zProjComboBox.setCurrentText("single z-slice") + if self.switchPlaneCombobox is not None: depthAxes = self.switchPlaneCombobox.depthAxes() - if depthAxes != 'z': - self.switchPlaneCombobox.setCurrentText('xy') - + if depthAxes != "z": + self.switchPlaneCombobox.setCurrentText("xy") + if self.combineDialog is None: self.setupCombiningChannels() self.combineDialog.show() self.combineDialog.raise_() self.combineDialog.activateWindow() self.combineDialog.emitSigPreviewToggled() - -class CombineGUIWorker: + + +class CombineWorker(CombineGui, Graphics, Preprocessing): def combineEnqueueCurrentImage(self, steps, keep_input_data_type, formula): posData = self.data[self.pos_i] @@ -318,43 +319,43 @@ def combineEnqueueCurrentImage(self, steps, keep_input_data_type, formula): z_slice = self.z_slice_index() else: z_slice = 0 - + key = (self.pos_i, posData.frame_i, z_slice) self.combineWorker.enqueue( self.data, - steps, + steps, key, keep_input_data_type, output_as_segm=self.combineDialog.saveAsSegm(), formula=formula, ) - + def combinePreviewToggled(self, checked): self.viewCombineChannelDataToggle.setChecked(checked) self.updateCombineChannelsPreview() - + def combinePreviewViewAsSegmToggled(self, checked): self.updateCombineChannelsPreview() self.combineViewAsSegmSetup() - + def combineCurrentImage( - self, - steps: List[Dict[str, Any]]=None, - keep_input_data_type:bool=None, - formula: str=None, - ): + self, + steps: List[Dict[str, Any]] = None, + keep_input_data_type: bool = None, + formula: str = None, + ): if steps and keep_input_data_type is None: - raise ValueError('keep_input_data_type must be set if steps is set') - + raise ValueError("keep_input_data_type must be set if steps is set") + if steps is None: steps, keep_input_data_type, formula = self.combineDialog.steps( return_keepInputDataType=True ) - txt = 'Combining current image...' + txt = "Combining current image..." self.logger.info(txt) self.statusBarLabel.setText(txt) - + selected_channel = core.get_selected_channels(steps) self.getChData(requ_ch=selected_channel) @@ -364,46 +365,46 @@ def combineCurrentImage( key = (pos_i, self.data[pos_i].frame_i, z_slice) self.combineWorker.setupJob( - self.data, - steps, + self.data, + steps, keep_input_data_type, key, output_as_segm=self.combineDialog.saveAsSegm(), formula=formula, ) - + self.combineWorker.wakeUp() - + def combineZStack( - self, - steps: List[Dict[str, Any]]=None, - keep_input_data_type:bool=None, - formula: str=None, - ): + self, + steps: List[Dict[str, Any]] = None, + keep_input_data_type: bool = None, + formula: str = None, + ): if self.combineDialog is not None: keep_input_data_type = ( self.combineDialog.keepInputDataTypeToggle.isChecked() ) - + if steps and keep_input_data_type is None: - raise ValueError('keep_input_data_type must be set if steps is set') - + raise ValueError("keep_input_data_type must be set if steps is set") + if steps is None: steps, keep_input_data_type, formula = self.combineDialog.steps( return_keepInputDataType=True ) - txt = 'Combining z-stack...' + txt = "Combining z-stack..." self.statusBarLabel.setText(txt) self.logger.info(txt) - + selected_channel = core.get_selected_channels(steps) self.getChData(requ_ch=selected_channel) posData = self.data[self.pos_i] key = (self.pos_i, posData.frame_i, None) self.combineWorker.setupJob( - self.data, - steps, + self.data, + steps, keep_input_data_type, key, output_as_segm=self.combineDialog.saveAsSegm(), @@ -411,28 +412,31 @@ def combineZStack( ) self.combineWorker.wakeUp() - - def combineAllFrames(self, - steps: List[Dict[str, Any]]=None, - keep_input_data_type:bool=None, - formula: str=None, - ): + + def combineAllFrames( + self, + steps: List[Dict[str, Any]] = None, + keep_input_data_type: bool = None, + formula: str = None, + ): if steps and not keep_input_data_type: - raise ValueError('keep_input_data_type must be set if steps is set') - + raise ValueError("keep_input_data_type must be set if steps is set") + if steps is None: - steps, keep_input_data_type, formula = self.combineDialog.steps(return_keepInputDataType=True) - txt = 'Combining all frames...' + steps, keep_input_data_type, formula = self.combineDialog.steps( + return_keepInputDataType=True + ) + txt = "Combining all frames..." self.logger.info(txt) self.statusBarLabel.setText(txt) - + selected_channel = core.get_selected_channels(steps) self.getChData(requ_ch=selected_channel) key = (self.pos_i, None, None) self.combineWorker.setupJob( - self.data, - steps, + self.data, + steps, keep_input_data_type, key, output_as_segm=self.combineDialog.saveAsSegm(), @@ -440,31 +444,33 @@ def combineAllFrames(self, ) self.combineWorker.wakeUp() - - def combineAllPos(self, - steps: List[Dict[str, Any]]=None, - keep_input_data_type:bool=None, - formula: str=None, - ): + + def combineAllPos( + self, + steps: List[Dict[str, Any]] = None, + keep_input_data_type: bool = None, + formula: str = None, + ): if steps and not keep_input_data_type: - raise ValueError('keep_input_data_type must be set if steps is set') - + raise ValueError("keep_input_data_type must be set if steps is set") + if steps is None: - steps, keep_input_data_type, formula = self.combineDialog.steps(return_keepInputDataType=True) - txt = 'Combining all Positions...' + steps, keep_input_data_type, formula = self.combineDialog.steps( + return_keepInputDataType=True + ) + txt = "Combining all Positions..." self.logger.info(txt) self.statusBarLabel.setText(txt) - + selected_channel = core.get_selected_channels(steps) - + for pos_i in range(len(self.data)): self.getChData(requ_ch=selected_channel, pos_i=pos_i) - key = (None, None, None) self.combineWorker.setupJob( - self.data, - steps, + self.data, + steps, keep_input_data_type, key, output_as_segm=self.combineDialog.saveAsSegm(), @@ -472,14 +478,14 @@ def combineAllPos(self, ) self.combineWorker.wakeUp() - + def stopCombineWorker(self): - self.logger.info('Closing combine worker...') + self.logger.info("Closing combine worker...") try: self.combineWorker.stop() except Exception as err: pass - + def combineWorkerCritical(self, error): self.combineDialog.appliedFinished() self.workerCritical(error) @@ -490,15 +496,13 @@ def combineWorkerIsQueueEmpty(self, isEmpty: bool): else: self.combineDialog.setDisabled(True) self.combineDialog.infoLabel.setText( - 'Computing preview...
    ' - '(Feel free to use Cell-ACDC while waiting)' + "Computing preview...
    " + "(Feel free to use Cell-ACDC while waiting)" ) def combineWorkerPreviewDone( - self, - processed_data: List[np.ndarray], - keys: List[Tuple[int, int, int]] - ): + self, processed_data: List[np.ndarray], keys: List[Tuple[int, int, int]] + ): unique_pos = {key[0] for key in keys} per_pos_data = {pos_i: [] for pos_i in unique_pos} @@ -506,9 +510,9 @@ def combineWorkerPreviewDone( pos_i, frame_i, z_slice = key per_pos_data[pos_i].append((key, img)) - for pos_i in unique_pos: + for pos_i in unique_pos: posData = self.data[pos_i] - if not hasattr(posData, 'combine_img_data'): + if not hasattr(posData, "combine_img_data"): posData.combine_img_data = preprocess.PreprocessedData( image_data=np.zeros(posData.img_data.shape) ) @@ -535,22 +539,26 @@ def combineWorkerPreviewDone( self.data, pos_i, frame_i, z_slice ) else: - raise ValueError('Invalid number of dimensions in img_data.') - + raise ValueError("Invalid number of dimensions in img_data.") + posData = self.data[self.pos_i] curr_pos_i, curr_frame_i, curr_z_slice = ( - self.pos_i,self.data[self.pos_i].frame_i, self.z_slice_index() + self.pos_i, + self.data[self.pos_i].frame_i, + self.z_slice_index(), ) if not self.combineDialog.saveAsSegm(): self.img1.updateMinMaxValuesCombinedData( self.data, curr_pos_i, curr_frame_i, curr_z_slice ) - + self.combineViewAsSegmSetup() - + def combineWorkerAskLoadChannels(self, requ_channels, pos_i): # spit channels and segm to load - segms_to_load, channels_to_load, current_segm = myutils.separate_fluo_segment_channels(requ_channels) + segms_to_load, channels_to_load, current_segm = ( + utils.separate_fluo_segment_channels(requ_channels) + ) if pos_i is None: pos_i = list(range(len(self.data))) elif not isinstance(pos_i, list): @@ -562,12 +570,10 @@ def combineWorkerAskLoadChannels(self, requ_channels, pos_i): for segm in segms_to_load: self.loadOverlayLabelsData(segm, pos_i=i) self.combineWorker.wake_waitCondLoadFluoChannels() - + def combineWorkerDone( - self, - processed_data: List[np.ndarray], - keys: List[Tuple[int, int, int]] - ): + self, processed_data: List[np.ndarray], keys: List[Tuple[int, int, int]] + ): self.setStatusBarLabel(log=False) self.combineDialog.appliedFinished() @@ -578,9 +584,9 @@ def combineWorkerDone( pos_i, frame_i, z_slice = key per_pos_data[pos_i].append((key, img)) - for pos_i in unique_pos: + for pos_i in unique_pos: posData = self.data[pos_i] - if not hasattr(posData, 'combine_img_data'): + if not hasattr(posData, "combine_img_data"): posData.combine_img_data = preprocess.PreprocessedData( image_data=np.zeros(posData.img_data.shape) ) @@ -593,9 +599,9 @@ def combineWorkerDone( posData.combine_img_data[frame_i][z_slice] = processed_data if not self.combineDialog.saveAsSegm(): self.img1.updateMinMaxValuesCombinedData( - self.data, pos_i, frame_i, z_slice - ) - if not self.combineDialog.saveAsSegm(): + self.data, pos_i, frame_i, z_slice + ) + if not self.combineDialog.saveAsSegm(): self.img1.updateMinMaxValuesCombinedDataProjections( self.data, pos_i, frame_i ) @@ -607,31 +613,31 @@ def combineWorkerDone( self.img1.updateMinMaxValuesCombinedData( self.data, pos_i, frame_i, z_slice ) - + if not self.viewCombineChannelDataToggle.isChecked(): self.viewCombineChannelDataToggle.setChecked(True) else: self.setImageImg1() def combineWorkerClosed(self, worker): - self.logger.info('Combine worker stopped.') - + self.logger.info("Combine worker stopped.") + def saveCombinedChannelsWorkerFinished(self): if self.progressWin is not None: self.progressWin.workerFinished = True self.progressWin.close() self.progressWin = None - + self.setStatusBarLabel() - self.logger.info('Combined channels data saved!') - self.titleLabel.setText('Combined channels data saved!', color='w') + self.logger.info("Combined channels data saved!") + self.titleLabel.setText("Combined channels data saved!", color="w") def saveCombineWorkerFinished(self): if self.progressWin is not None: self.progressWin.workerFinished = True self.progressWin.close() self.progressWin = None - + self.setStatusBarLabel() - self.logger.info('Combined channels saved!') - self.titleLabel.setText('Combined channels saved!', color='w') \ No newline at end of file + self.logger.info("Combined channels saved!") + self.titleLabel.setText("Combined channels saved!", color="w") diff --git a/cellacdc/mixins/curvature_tools.py b/cellacdc/mixins/curvature_tools.py new file mode 100644 index 000000000..458546109 --- /dev/null +++ b/cellacdc/mixins/curvature_tools.py @@ -0,0 +1,309 @@ +"""Qt view adapter for curvature and spline tools.""" + +from __future__ import annotations + +import numpy as np +import pyqtgraph as pg +import skimage.draw +import skimage.measure + +from .brush_tools import BrushTools +from .undo_redo import UndoRedo + + +class CurvatureTools(BrushTools, UndoRedo): + """Extracted from guiWin.""" + + def clearCurvItems(self, removeItems=True): + try: + posData = self.data[self.pos_i] + curvItems = zip( + posData.curvPlotItems, posData.curvAnchorsItems, posData.curvHoverItems + ) + for plotItem, curvAnchors, hoverItem in curvItems: + plotItem.setData([], []) + curvAnchors.setData([], []) + hoverItem.setData([], []) + if removeItems: + self.ax1.removeItem(plotItem) + self.ax1.removeItem(curvAnchors) + self.ax1.removeItem(hoverItem) + + if removeItems: + posData.curvPlotItems = [] + posData.curvAnchorsItems = [] + posData.curvHoverItems = [] + except AttributeError: + # traceback.print_exc() + pass + + def curvToolSplineToObj(self, xxA=None, yyA=None, isRightClick=False): + posData = self.data[self.pos_i] + # Store undo state before modifying stuff + self.storeUndoRedoStates(False, storeOnlyZoom=True) + + if isRightClick: + xxS, yyS = self.curvPlotItem.getData() + if xxS is None: + self.setUncheckedAllButtons() + return + self.smoothAutoContWithSpline() + + xxS, yyS = self.getClosedSplineCoords() + + if self.autoIDcheckbox.isChecked(): + self.setBrushID() + curvToolID = posData.brushID + else: + curvToolID = self.editIDspinbox.value() + posData.brushID = curvToolID + + if curvToolID <= 0: + self.setBrushID() + curvToolID = posData.brushID + + lab2D = self.get_2Dlab(posData.lab).copy() + newIDMask = np.zeros(lab2D.shape, bool) + rr, cc = skimage.draw.polygon(yyS, xxS, shape=lab2D.shape) + newIDMask[rr, cc] = True + newIDMask[lab2D != 0] = False + lab2D[newIDMask] = curvToolID + self.set_2Dlab(lab2D) + self.currentLab2D = lab2D + + def curvTool_cb(self, checked): + posData = self.data[self.pos_i] + if checked: + self.disconnectLeftClickButtons() + self.uncheckLeftClickButtons(self.curvToolButton) + self.connectLeftClickButtons() + self.hoverLinSpace = np.linspace(0, 1, 1000) + self.curvPlotItem = pg.PlotDataItem(pen=self.newIDs_cpen) + self.curvHoverPlotItem = pg.PlotDataItem(pen=self.oldIDs_cpen) + self.curvAnchors = pg.ScatterPlotItem( + symbol="o", + size=9, + brush=pg.mkBrush((255, 0, 0, 50)), + pen=pg.mkPen((255, 0, 0), width=2), + hoverable=True, + hoverPen=pg.mkPen((255, 0, 0), width=3), + hoverBrush=pg.mkBrush((255, 0, 0)), + tip=None, + ) + self.ax1.addItem(self.curvAnchors) + self.ax1.addItem(self.curvPlotItem) + self.ax1.addItem(self.curvHoverPlotItem) + self.splineHoverON = True + posData.curvPlotItems.append(self.curvPlotItem) + posData.curvAnchorsItems.append(self.curvAnchors) + posData.curvHoverItems.append(self.curvHoverPlotItem) + else: + self.splineHoverON = False + self.isRightClickDragImg1 = False + self.clearCurvItems() + while self.app.overrideCursor() is not None: + self.app.restoreOverrideCursor() + + self.showEditIDwidgets(checked) + + def drawAutoContour(self, y2, x2): + y1, x1 = self.autoCont_y0, self.autoCont_x0 + Dy = abs(y2 - y1) + Dx = abs(x2 - x1) + edge = self.getDisplayedImg1() + if Dy != 0 or Dx != 0: + # NOTE: numIter takes care of any lag in mouseMoveEvent + numIter = int(round(max((Dy, Dx)))) + alfa = np.arctan2(y1 - y2, x2 - x1) + base = np.pi / 4 + alfa_dir = round((base * round(alfa / base)) * 180 / np.pi) + for _ in range(numIter): + y1, x1 = self.autoCont_y0, self.autoCont_x0 + yy, xx = self.get_dir_coords(alfa_dir, y1, x1, edge.shape) + a_dir = edge[yy, xx] + min_int = np.max(a_dir) + min_i = list(a_dir).index(min_int) + y, x = yy[min_i], xx[min_i] + try: + xx, yy = self.curvHoverPlotItem.getData() + except TypeError: + xx, yy = [], [] + + if xx is None or yy is None or len(xx) == 0 or len(yy) == 0: + xx, yy = [], [] + elif x == xx[-1] and y == yy[-1]: + # Do not append point equal to last point + return + + xx = np.r_[xx, x] + yy = np.r_[yy, y] + try: + self.curvHoverPlotItem.setData(xx, yy) + self.curvPlotItem.setData(xx, yy) + except TypeError: + pass + self.autoCont_y0, self.autoCont_x0 = y, x + + def getClosedSplineCoords(self): + xxS, yyS = self.curvPlotItem.getData() + bbox_area = (xxS.max() - xxS.min()) * (yyS.max() - yyS.min()) + if bbox_area < 26_000: + # Using 1000 is fast enough according to profiling + return xxS, yyS + + optimalSpaceSize = self.splineToObjModel.predict(bbox_area, max_exec_time=150) + if optimalSpaceSize >= 1000: + # Using 1000 is fast enough according to model + return xxS, yyS + + if optimalSpaceSize < 100: + # Do not allow a rough spline + optimalSpaceSize = 100 + + # Get spline with optimal space size so that exec time + # or skimage.draw.polygon is less than 150 ms + xx, yy = self.curvAnchors.getData() + resolutionSpace = np.linspace(0, 1, int(optimalSpaceSize)) + xxS, yyS = self.getSpline(xx, yy, resolutionSpace=resolutionSpace, per=True) + return xxS, yyS + + def getPolygonBrush(self, yxc2, Y, X): + # see https://en.wikipedia.org/wiki/Tangent_lines_to_circles + y1, x1 = self.yPressAx2, self.xPressAx2 + y2, x2 = yxc2 + R = self.brushSizeSpinbox.value() + r = R + + arcsin_den = np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + arctan_den = x2 - x1 + if arcsin_den != 0 and arctan_den != 0: + beta = np.arcsin((R - r) / arcsin_den) + gamma = -np.arctan((y2 - y1) / arctan_den) + alpha = gamma - beta + x3 = x1 + r * np.sin(alpha) + y3 = y1 + r * np.cos(alpha) + x4 = x2 + R * np.sin(alpha) + y4 = y2 + R * np.cos(alpha) + + alpha = gamma + beta + x5 = x1 - r * np.sin(alpha) + y5 = y1 - r * np.cos(alpha) + x6 = x2 - R * np.sin(alpha) + y6 = y2 - R * np.cos(alpha) + + rr_poly, cc_poly = skimage.draw.polygon( + [y3, y4, y6, y5], [x3, x4, x6, x5], shape=(Y, X) + ) + else: + rr_poly, cc_poly = [], [] + + self.yPressAx2, self.xPressAx2 = y2, x2 + return rr_poly, cc_poly + + def getSpline(self, xx, yy, resolutionSpace=None, per=False, appendFirst=False): + # Remove duplicates + valid = np.where(np.abs(np.diff(xx)) + np.abs(np.diff(yy)) > 0) + xx = np.r_[xx[valid], xx[-1]] + yy = np.r_[yy[valid], yy[-1]] + if appendFirst: + xx = np.r_[xx, xx[0]] + yy = np.r_[yy, yy[0]] + per = True + + # Interpolate splice + if resolutionSpace is None: + resolutionSpace = self.hoverLinSpace + k = 2 if len(xx) == 3 else 3 + + try: + tck, u = scipy.interpolate.splprep([xx, yy], s=0, k=k, per=per) + xi, yi = scipy.interpolate.splev(resolutionSpace, tck) + return xi, yi + except (ValueError, TypeError): + # Catch errors where we know why splprep fails + return [], [] + + def get_dir_coords(self, alfa_dir, yd, xd, shape, connectivity=1): + h, w = shape + y_above = yd + 1 if yd + 1 < h else yd + y_below = yd - 1 if yd > 0 else yd + x_right = xd + 1 if xd + 1 < w else xd + x_left = xd - 1 if xd > 0 else xd + if alfa_dir == 0: + yy = [y_below, y_below, yd, y_above, y_above] + xx = [xd, x_right, x_right, x_right, xd] + elif alfa_dir == 45: + yy = [y_below, y_below, y_below, yd, y_above] + xx = [x_left, xd, x_right, x_right, x_right] + elif alfa_dir == 90: + yy = [yd, y_below, y_below, y_below, yd] + xx = [x_left, x_left, xd, x_right, x_right] + elif alfa_dir == 135: + yy = [y_above, yd, y_below, y_below, y_below] + xx = [x_left, x_left, x_left, xd, x_right] + elif alfa_dir == -180 or alfa_dir == 180: + yy = [y_above, y_above, yd, y_below, y_below] + xx = [xd, x_left, x_left, x_left, xd] + elif alfa_dir == -135: + yy = [y_below, yd, y_above, y_above, y_above] + xx = [x_left, x_left, x_left, xd, x_right] + elif alfa_dir == -90: + yy = [yd, y_above, y_above, y_above, yd] + xx = [x_left, x_left, xd, x_right, x_right] + else: + yy = [y_above, y_above, y_above, yd, y_below] + xx = [x_left, xd, x_right, x_right, x_right] + if connectivity == 1: + return yy[1:4], xx[1:4] + else: + return yy, xx + + def hoverEventDrawSpline(self, event): + x, y = event.pos() + xx, yy = self.curvAnchors.getData() + hoverAnchors = self.curvAnchors.pointsAt(event.pos()) + per = False + # If we are hovering the starting point we generate + # a closed spline + if len(xx) < 2: + return + + if len(hoverAnchors) > 0: + xA_hover, yA_hover = hoverAnchors[0].pos() + if xx[0] == xA_hover and yy[0] == yA_hover: + per = True + if per: + # Append start coords and close spline + xx = np.r_[xx, xx[0]] + yy = np.r_[yy, yy[0]] + xi, yi = self.getSpline(xx, yy, per=per) + # self.curvPlotItem.setData([], []) + else: + # Append mouse coords + xx = np.r_[xx, x] + yy = np.r_[yy, y] + xi, yi = self.getSpline(xx, yy, per=per) + self.curvHoverPlotItem.setData(xi, yi) + + def smoothAutoContWithSpline(self, n=3): + try: + xx, yy = self.curvHoverPlotItem.getData() + if xx is None or yy is None: + return + # Downsample by taking every nth coord + xxA, yyA = xx[::n], yy[::n] + rr, cc = skimage.draw.polygon(yyA, xxA) + self.autoContObjMask[rr, cc] = 1 + rp = skimage.measure.regionprops(self.autoContObjMask) + if not rp: + return + obj = rp[0] + cont = self.getObjContours(obj) + xxC, yyC = cont[:, 0], cont[:, 1] + xxA, yyA = xxC[::n], yyC[::n] + self.xxA_autoCont, self.yyA_autoCont = xxA, yyA + xxS, yyS = self.getSpline(xxA, yyA, per=True, appendFirst=True) + if len(xxS) > 0: + self.curvPlotItem.setData(xxS, yyS) + except (TypeError, ValueError): + pass diff --git a/cellacdc/mixins/custom_annotations.py b/cellacdc/mixins/custom_annotations.py new file mode 100644 index 000000000..652015042 --- /dev/null +++ b/cellacdc/mixins/custom_annotations.py @@ -0,0 +1,640 @@ +"""Qt view adapter for custom annotations.""" + +from __future__ import annotations + +import json +import os +import re +import traceback +from collections import defaultdict + +import pyqtgraph as pg +import pandas as pd +from qtpy.QtGui import QColor + +from cellacdc import apps, html_utils, settings_folderpath, widgets + + +custom_annot_path = os.path.join(settings_folderpath, "custom_annotations.json") + +from .annotation_display import AnnotationDisplay +from .object_properties import ObjectProperties + + +class CustomAnnotations(AnnotationDisplay, ObjectProperties): + """Extracted from guiWin.""" + + def addCustomAnnnotScatterPlot(self, symbolColor, symbol, toolButton): + # Add scatter plot item + symbolColorBrush = [0, 0, 0, 50] + symbolColorBrush[:3] = symbolColor.getRgb()[:3] + scatterPlotItem = widgets.CustomAnnotationScatterPlotItem() + scatterPlotItem.setData( + [], + [], + symbol=symbol, + pxMode=False, + brush=pg.mkBrush(symbolColorBrush), + size=15, + pen=pg.mkPen(width=3, color=symbolColor), + hoverable=True, + hoverBrush=pg.mkBrush(symbolColor), + tip=None, + ) + scatterPlotItem.sigHovered.connect(self.customAnnotHovered) + scatterPlotItem.button = toolButton + self.customAnnotDict[toolButton]["scatterPlotItem"] = scatterPlotItem + self.ax1.addItem(scatterPlotItem) + + def addCustomAnnotButtonAllLoadedPos(self): + allPosCustomAnnot = {} + for pos_i, posData in enumerate(self.data): + self.addCustomAnnotationSavedPos(pos_i=pos_i) + allPosCustomAnnot = {**allPosCustomAnnot, **posData.customAnnot} + for posData in self.data: + posData.customAnnot = allPosCustomAnnot + + def addCustomAnnotation(self): + self.readSavedCustomAnnot() + + self.addAnnotWin = apps.customAnnotationDialog( + self.savedCustomAnnot, parent=self + ) + self.addAnnotWin.sigDeleteSelecAnnot.connect(self.deleteSelectedAnnot) + self.addAnnotWin.exec_() + if self.addAnnotWin.cancel: + self.logger.info("Custom annotation process cancelled.") + return + + symbol = self.addAnnotWin.symbol + symbolColor = self.addAnnotWin.state["symbolColor"] + keySequence = self.addAnnotWin.shortcutWidget.widget.keySequence + toolTip = self.addAnnotWin.toolTip + name = self.addAnnotWin.state["name"] + keepActive = self.addAnnotWin.state.get("keepActive", True) + isHideChecked = self.addAnnotWin.state.get("isHideChecked", True) + + proceed = self.checkNameExists(name) + if not proceed: + self.logger.info("Custom annotation process cancelled.") + return + + self.addCustomAnnotationItems( + symbol, + symbolColor, + keySequence, + toolTip, + name, + keepActive, + isHideChecked, + self.addAnnotWin.state, + ) + self.saveCustomAnnot() + self.doCustomAnnotation(0) + + def addCustomAnnotationButton( + self, + symbol, + symbolColor, + keySequence, + toolTip, + annotName, + keepActive, + isHideChecked, + ): + toolButton = widgets.customAnnotToolButton( + symbol, + symbolColor, + parent=self, + keepToolActive=keepActive, + isHideChecked=isHideChecked, + ) + toolButton.setCheckable(True) + self.checkableQButtonsGroup.addButton(toolButton) + if keySequence is not None: + toolButton.setShortcut(keySequence) + toolButton.setToolTip(toolTip) + toolButton.name = annotName + toolButton.toggled.connect(self.customAnnotButtonToggled) + toolButton.sigRemoveAction.connect(self.removeCustomAnnotButton) + toolButton.sigKeepActiveAction.connect(self.customAnnotKeepActive) + toolButton.sigHideAction.connect(self.customAnnotHide) + toolButton.sigModifyAction.connect(self.customAnnotModify) + action = self.annotateToolbar.addWidget(toolButton) + return toolButton, action + + def addCustomAnnotationItems( + self, + symbol, + symbolColor, + keySequence, + toolTip, + name, + keepActive, + isHideChecked, + state, + ): + toolButton, action = self.addCustomAnnotationButton( + symbol, symbolColor, keySequence, toolTip, name, keepActive, isHideChecked + ) + + self.customAnnotDict[toolButton] = { + "action": action, + "state": state, + "annotatedIDs": [defaultdict(list) for _ in range(len(self.data))], + } + + # Save custom annotation to cellacdc/temp/custom_annotations.json + state_to_save = state.copy() + state_to_save["symbolColor"] = tuple(symbolColor.getRgb()) + self.savedCustomAnnot[name] = state_to_save + for posData in self.data: + posData.customAnnot[name] = state_to_save + + # Add scatter plot item + self.addCustomAnnnotScatterPlot(symbolColor, symbol, toolButton) + + customAnnotButton = self.customAnnotDict[toolButton] + allPosAnnotatedIDs = customAnnotButton["annotatedIDs"] + # Add 0s column to acdc_df + for pos_i, posData in enumerate(self.data): + for frame_i, data_dict in enumerate(posData.allData_li): + acdc_df = data_dict["acdc_df"] + if acdc_df is None: + continue + if name not in acdc_df.columns: + acdc_df[name] = 0 + else: + acdc_df[name] = acdc_df[name].astype(int) + acdc_df_annot = acdc_df[acdc_df[name] == 1].reset_index() + annot_IDs = acdc_df_annot["Cell_ID"].to_list() + allPosAnnotatedIDs[pos_i][frame_i].extend(annot_IDs) + + if posData.acdc_df is not None: + if name not in posData.acdc_df.columns: + posData.acdc_df[name] = 0 + else: + posData.acdc_df[name] = posData.acdc_df[name].astype(int) + acdc_df_annot = posData.acdc_df[ + posData.acdc_df[name] == 1 + ].reset_index() + annot_IDs = acdc_df_annot["Cell_ID"].to_list() + allPosAnnotatedIDs[pos_i][frame_i].extend(annot_IDs) + + def addCustomAnnotationSavedPos(self, pos_i=None): + if pos_i is None: + pos_i = self.pos_i + + posData = self.data[pos_i] + for name, annotState in posData.customAnnot.items(): + # Check if button is already present and update only annotated IDs + buttons = [b for b in self.customAnnotDict.keys() if b.name == name] + if buttons: + toolButton = buttons[0] + allAnnotedIDs = self.customAnnotDict[toolButton]["annotatedIDs"] + allAnnotedIDs[pos_i] = posData.customAnnotIDs.get(name, {}) + continue + + try: + symbol = re.findall(r"\'(.+)\'", annotState["symbol"])[0] + except Exception as e: + self.logger.info(traceback.format_exc()) + symbol = "o" + + symbolColor = QColor(*annotState["symbolColor"]) + shortcut = annotState["shortcut"] + if shortcut is not None: + keySequence = widgets.macShortcutToWindows(shortcut) + keySequence = widgets.KeySequenceFromText(keySequence) + else: + keySequence = None + toolTip = utils.getCustomAnnotTooltip(annotState) + keepActive = annotState.get("keepActive", True) + isHideChecked = annotState.get("isHideChecked", True) + + toolButton, action = self.addCustomAnnotationButton( + symbol, + symbolColor, + keySequence, + toolTip, + name, + keepActive, + isHideChecked, + ) + allPosAnnotIDs = [ + pos.customAnnotIDs.get(name, defaultdict(list)) for pos in self.data + ] + self.customAnnotDict[toolButton] = { + "action": action, + "state": annotState, + "annotatedIDs": allPosAnnotIDs, + } + + self.addCustomAnnnotScatterPlot(symbolColor, symbol, toolButton) + + def askCustomAnnotationNameExists(self, name): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph(f""" + The annotationa called {name} already exists in the + acdc_output CSV file.

    + If you continue, this column will be used to initialize + pre-annotated objects.

    + Do you want to continue? + """) + noButton, yesButton = msg.question( + self, + "Custom annotation name already exists", + txt, + buttonsTexts=("No, stop process", "Yes, use existing column"), + ) + return msg.clickedButton == yesButton + + def checkNameExists(self, name): + posData = self.data[self.pos_i] + for frame_i, data_dict in enumerate(posData.allData_li): + acdc_df = data_dict["acdc_df"] + if acdc_df is None: + continue + if name in acdc_df.columns: + return self.askCustomAnnotationNameExists(name) + + if posData.acdc_df is not None and name in posData.acdc_df.columns: + return self.askCustomAnnotationNameExists(name) + + return True + + def clearCustomAnnot(self): + for button in self.customAnnotDict.keys(): + scatterPlotItem = self.customAnnotDict[button]["scatterPlotItem"] + scatterPlotItem.setData([], []) + + def clearScatterPlotCustomAnnotButton(self, button): + scatterPlotItem = self.customAnnotDict[button]["scatterPlotItem"] + scatterPlotItem.setData([], []) + + def customAnnotButtonToggled(self, checked): + if checked: + self.customAnnotButton = self.sender() + # Uncheck the other buttons + for button in self.customAnnotDict.keys(): + if button == self.sender(): + continue + + button.toggled.disconnect() + self.clearScatterPlotCustomAnnotButton(button) + button.setChecked(False) + button.toggled.connect(self.customAnnotButtonToggled) + self.doCustomAnnotation(0) + else: + self.customAnnotButton = None + button = self.sender() + clearAnnotation = ( + button.isHideChecked or not self.viewAllCustomAnnotAction.isChecked() + ) + if clearAnnotation: + self.clearScatterPlotCustomAnnotButton(button) + self.setHighlightID(False) + self.resetCursor() + + def customAnnotHide(self, button): + self.customAnnotDict[button]["state"]["isHideChecked"] = button.isHideChecked + clearAnnot = ( + not button.isChecked() + and button.isHideChecked + and not self.viewAllCustomAnnotAction.isChecked() + ) + if clearAnnot: + # User checked hide annot with the button not active --> clear + self.clearScatterPlotCustomAnnotButton(button) + elif not button.isChecked(): + # User uncheked hide annot with the button not active --> show + self.doCustomAnnotation(0) + + def customAnnotHovered(self, scatterPlotItem, points, event): + # Show tool tip when hovering an annotation with annotation name and ID + vb = scatterPlotItem.getViewBox() + if vb is None: + return + if len(points) > 0: + posData = self.data[self.pos_i] + point = points[0] + x, y = point.pos().x(), point.pos().y() + xdata, ydata = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + vb.setToolTip(f"Annotation name: {scatterPlotItem.button.name}\nID = {ID}") + else: + vb.setToolTip("") + + def customAnnotKeepActive(self, button): + self.customAnnotDict[button]["state"]["keepActive"] = button.keepToolActive + + def customAnnotModify(self, button): + state = self.customAnnotDict[button]["state"] + self.addAnnotWin = apps.customAnnotationDialog( + self.savedCustomAnnot, state=state + ) + self.addAnnotWin.sigDeleteSelecAnnot.connect(self.deleteSelectedAnnot) + self.addAnnotWin.exec_() + if self.addAnnotWin.cancel: + return + + # Rename column if existing + posData = self.data[self.pos_i] + acdc_df = posData.allData_li[posData.frame_i]["acdc_df"] + if acdc_df is not None: + old_name = self.customAnnotDict[button]["state"]["name"] + new_name = self.addAnnotWin.state["name"] + acdc_df = acdc_df.rename(columns={old_name: new_name}) + posData.allData_li[posData.frame_i]["acdc_df"] = acdc_df + + self.customAnnotDict[button]["state"] = self.addAnnotWin.state + + name = self.addAnnotWin.state["name"] + state_to_save = self.addAnnotWin.state.copy() + symbolColor = self.addAnnotWin.state["symbolColor"] + state_to_save["symbolColor"] = tuple(symbolColor.getRgb()) + self.savedCustomAnnot[name] = self.addAnnotWin.state + self.saveCustomAnnot() + + symbol = self.addAnnotWin.symbol + symbolColor = self.customAnnotDict[button]["state"]["symbolColor"] + button.setColor(symbolColor) + button.update() + symbolColorBrush = [0, 0, 0, 50] + symbolColorBrush[:3] = symbolColor.getRgb()[:3] + scatterPlotItem = self.customAnnotDict[button]["scatterPlotItem"] + xx, yy = scatterPlotItem.getData() + if xx is None: + xx, yy = [], [] + scatterPlotItem.setData( + xx, + yy, + symbol=symbol, + pxMode=False, + brush=pg.mkBrush(symbolColorBrush), + size=15, + pen=pg.mkPen(width=3, color=symbolColor), + ) + + def deleteSavedAnnotation(self): + for item in self.selectAnnotWin.listBox.selectedItems(): + name = item.text() + self.savedCustomAnnot.pop(name) + self.deleteSelectedAnnot(self.selectAnnotWin.listBox.selectedItems()) + items = list(self.savedCustomAnnot.keys()) + self.selectAnnotWin.listBox.clear() + self.selectAnnotWin.listBox.addItems(items) + + def deleteSelectedAnnot(self, itemsToDelete): + self.saveCustomAnnot(only_temp=True) + + def doCustomAnnotation(self, ID): + mode = self.modeComboBox.currentText() + if not self.isSnapshot and mode != "Custom annotations": + # Do not show annotations if timelapse and mode not annotations + return + + if self.switchPlaneCombobox.depthAxes() != "z": + return + + # NOTE: pass 0 for ID to not add + posData = self.data[self.pos_i] + if self.viewAllCustomAnnotAction.isChecked(): + # User requested to show all annotations --> iterate all buttons + # Unless it actively clicked to annotate --> avoid annotating object + # with all the annotations present + buttons = list(self.customAnnotDict.keys()) + else: + # Annotate if the button is active or isHideChecked is False + buttons = [ + b + for b in self.customAnnotDict.keys() + if (b.isChecked() or not b.isHideChecked) + ] + if not buttons: + return + + for button in buttons: + annotatedIDs = self.customAnnotDict[button]["annotatedIDs"][self.pos_i] + annotIDs_frame_i = annotatedIDs.get(posData.frame_i, []) + state = self.customAnnotDict[button]["state"] + acdc_df = posData.allData_li[posData.frame_i]["acdc_df"] + + if button.isChecked() and ID > 0: + # Annotate only if existing ID and the button is checked + if ID in annotIDs_frame_i: + annotIDs_frame_i.remove(ID) + acdc_df.at[ID, state["name"]] = 0 + elif ID != 0: + annotIDs_frame_i.append(ID) + + annotPerButton = self.customAnnotDict[button] + allAnnotedIDs = annotPerButton["annotatedIDs"] + posAnnotedIDs = allAnnotedIDs[self.pos_i] + posAnnotedIDs[posData.frame_i] = annotIDs_frame_i + + if acdc_df is None: + self.store_data(autosave=False) + acdc_df = posData.allData_li[posData.frame_i]["acdc_df"] + + xx, yy = [], [] + for annotID in annotIDs_frame_i: + if annotID not in posData.IDs_idxs: + continue + + obj_idx = posData.IDs_idxs[annotID] + obj = posData.rp[obj_idx] + acdc_df.at[annotID, state["name"]] = 1 + if not self.isObjVisible(obj.bbox): + continue + y, x = self.getObjCentroid(obj.centroid) + xx.append(x) + yy.append(y) + + scatterPlotItem = self.customAnnotDict[button]["scatterPlotItem"] + scatterPlotItem.setData(xx, yy) + + posData.allData_li[posData.frame_i]["acdc_df"] = acdc_df + + # if self.highlightedID != 0: + # self.highlightedID = 0 + # self.setHighlightID(False) + + if buttons: + return buttons[0] + + def loadCustomAnnotations(self): + items = list(self.savedCustomAnnot.keys()) + if len(items) == 0: + msg = widgets.myMessageBox() + txt = html_utils.paragraph(""" + There are no custom annotations saved.

    + Click on "Add custom annotation" button to start adding new + annotations. + """) + msg.warning(self, "No annotations saved", txt) + return + + self.selectAnnotWin = widgets.QDialogListbox( + "Load previously used custom annotation(s)", + "Select annotations to load:", + items, + additionalButtons=("Delete selected annnotations",), + parent=self, + multiSelection=True, + ) + for button in self.selectAnnotWin._additionalButtons: + button.disconnect() + button.clicked.connect(self.deleteSavedAnnotation) + self.selectAnnotWin.exec_() + if self.selectAnnotWin.cancel: + return + + for selectedAnnotName in self.selectAnnotWin.selectedItemsText: + selectedAnnot = self.savedCustomAnnot[selectedAnnotName] + + symbol = selectedAnnot["symbol"] + symbol = re.findall(r"\'(.+)\'", symbol)[0] + symbolColor = selectedAnnot["symbolColor"] + symbolColor = pg.mkColor(symbolColor) + keySequence = widgets.KeySequenceFromText(selectedAnnot["shortcut"]) + Type = selectedAnnot["type"] + toolTip = ( + f"Name: {selectedAnnotName}\n\n" + f"Type: {Type}\n\n" + f"Usage: activate the button and RIGHT-CLICK on cell to annotate\n\n" + f"Description: {selectedAnnot['description']}\n\n" + f'Shortcut: "{keySequence}"' + ) + keepActive = selectedAnnot["keepActive"] + isHideChecked = selectedAnnot["isHideChecked"] + state = { + "type": Type, + "name": selectedAnnotName, + "symbol": selectedAnnot["symbol"], + "shortcut": selectedAnnot["shortcut"], + "description": selectedAnnot["description"], + "keepActive": keepActive, + "isHideChecked": isHideChecked, + "symbolColor": symbolColor, + } + self.addCustomAnnotationItems( + symbol, + symbolColor, + keySequence, + toolTip, + selectedAnnotName, + keepActive, + isHideChecked, + state, + ) + for pos_i, posData in enumerate(self.data): + posData.customAnnot[selectedAnnotName] = selectedAnnot + + self.saveCustomAnnot() + + def readSavedCustomAnnot(self): + tempAnnot = {} + if os.path.exists(custom_annot_path): + self.logger.info("Loading saved custom annotations...") + tempAnnot = load.read_json(custom_annot_path, logger_func=self.logger.info) + + posData = self.data[self.pos_i] + self.savedCustomAnnot = tempAnnot + for pos_i, posData in enumerate(self.data): + self.savedCustomAnnot = {**self.savedCustomAnnot, **posData.customAnnot} + + def reinitCustomAnnot(self): + buttons = list(self.customAnnotDict.keys()) + for button in buttons: + self.clearScatterPlotCustomAnnotButton(button) + action = self.customAnnotDict[button]["action"] + self.annotateToolbar.removeAction(action) + self.checkableQButtonsGroup.removeButton(button) + self.customAnnotDict.pop(button) + # self.savedCustomAnnot.pop(name) + + self.saveCustomAnnot(only_temp=True) + + def removeCustomAnnotButton(self, button, askHow=True, save=True): + if askHow: + msg = widgets.myMessageBox() + txt = html_utils.paragraph(""" + Do you want to remove also the column with annotations or + only the annotation button?
    + """) + _, removeOnlyButton, removeColButton = msg.question( + self, + "Remove only button?", + txt, + buttonsTexts=( + "Cancel", + "Remove only button", + " Remove also column with annotations ", + ), + ) + if msg.cancel: + return + removeOnlyButton = msg.clickedButton == removeOnlyButton + else: + removeOnlyButton = True + + name = self.customAnnotDict[button]["state"]["name"] + # remove annotation from position + for posData in self.data: + try: + posData.customAnnot.pop(name) + posData.saveCustomAnnotationParams() + except KeyError as e: + # Current pos doesn't have any annotation button. Continue + continue + + if posData.acdc_df is None: + continue + + if removeOnlyButton: + continue + + posData.acdc_df = posData.acdc_df.drop(columns=name, errors="ignore") + for frame_i, data_dict in enumerate(posData.allData_li): + acdc_df = data_dict["acdc_df"] + if acdc_df is None: + continue + acdc_df = acdc_df.drop(columns=name, errors="ignore") + posData.allData_li[frame_i]["acdc_df"] = acdc_df + + self.clearScatterPlotCustomAnnotButton(button) + + action = self.customAnnotDict[button]["action"] + self.annotateToolbar.removeAction(action) + self.checkableQButtonsGroup.removeButton(button) + self.customAnnotDict.pop(button) + # self.savedCustomAnnot.pop(name) + + self.saveCustomAnnot(only_temp=True) + + def saveCustomAnnot(self, only_temp=False): + if not hasattr(self, "savedCustomAnnot"): + return + + if not self.savedCustomAnnot: + return + + # Save to cell acdc temp path + with open(custom_annot_path, mode="w") as file: + json.dump(self.savedCustomAnnot, file, indent=2) + + if only_temp: + return + + self.logger.info("Saving custom annotations parameters...") + # Save to pos path + for _posData in self.data: + _posData.saveCustomAnnotationParams() + + def viewAllCustomAnnot(self, checked): + if not checked: + # Clear all annotations before showing only checked + for button in self.customAnnotDict.keys(): + self.clearScatterPlotCustomAnnotButton(button) + self.doCustomAnnotation(0) diff --git a/cellacdc/mixins/data_loading.py b/cellacdc/mixins/data_loading.py new file mode 100644 index 000000000..0a4f3b6f1 --- /dev/null +++ b/cellacdc/mixins/data_loading.py @@ -0,0 +1,1653 @@ +"""Qt view adapter for data loading and recovery workflows.""" + +from __future__ import annotations + +import os +import shutil +import zipfile +from functools import partial + +import numpy as np +import pandas as pd +import psutil +import skimage +from datetime import datetime +import cv2 +import skimage.color +import skimage.io +from natsort import natsorted +from qtpy.QtCore import QEventLoop, QMutex, Qt, QThread, QTimer, QWaitCondition +from qtpy.QtGui import QIcon +from qtpy.QtWidgets import QFileDialog, QPushButton + +from cellacdc import ( + _palettes, + apps, + autopilot, + dataPrep, + data_structure_docs_url, + exception_handler, + html_utils, + load, + utils, + prompts, + user_manual_url, + widgets, + workers, +) + +GREEN_HEX = _palettes.green() + +from .layout_controls import LayoutControls + + +class DataLoading(LayoutControls): + """Extracted from guiWin.""" + + def _createEmptyData(self): + self.MostRecentPath = self.getMostRecentPath() + exp_path = QFileDialog.getExistingDirectory( + self, + "Select experiment folder where to create empty data", + self.MostRecentPath, + ) + if not exp_path: + return + + pos_path = os.path.join(exp_path, "Position_1") + images_path = os.path.join(pos_path, "Images") + if os.path.exists(images_path): + raise FileExistsError(f'The following path already exists "{images_path}"') + + os.makedirs(images_path, exist_ok=True) + + basename = "test_empty_" + tif_filename = f"{basename}channel_1.tif" + tif_filepath = os.path.join(images_path, tif_filename) + empty_img = np.zeros((256, 256), dtype=np.uint8) + empty_img[0, 0] = 255 + skimage.io.imsave(tif_filepath, empty_img) + + metadata_filename = f"{basename}metadata.csv" + metadata_filepath = os.path.join(images_path, metadata_filename) + df_metadata = pd.DataFrame({"Description": ["basename"], "values": [basename]}) + df_metadata.to_csv(metadata_filepath, index=False) + + self.isNewFile = True + self._openFolder(exp_path=images_path) + + def _loadFromExperimentFolder(self, exp_path): + select_folder = load.select_exp_folder() + values = select_folder.get_values_segmGUI(exp_path) + if not values: + self.criticalInvalidPosFolder(exp_path) + self.openFolderAction.setEnabled(True) + return [] + + if len(values) > 1: + select_folder.QtPrompt(self, values, allow_cancel=False) + if select_folder.cancel: + return [] + else: + select_folder.cancel = False + select_folder.selected_pos = select_folder.pos_foldernames + + images_paths = [] + for pos in select_folder.selected_pos: + images_paths.append(os.path.join(exp_path, pos, "Images")) + return images_paths + + def _openFile(self, file_path=None): + """ + Function used for loading an image file directly. + """ + if file_path is None: + self.MostRecentPath = self.getMostRecentPath() + file_path = QFileDialog.getOpenFileName( + self, + "Select image file", + self.MostRecentPath, + "Image/Video Files (*.png *.tif *.tiff *.jpg *.jpeg *.mov *.avi *.mp4)" + ";;All Files (*)", + )[0] + if not file_path: + return + + filename, ext = os.path.splitext(os.path.basename(file_path)) + ext = ext.lower() + dirpath = os.path.dirname(file_path) + dirname = os.path.basename(dirpath) + filename = filename.rstrip("_") + channel_name = None + do_copy = True + if dirname != "Images": + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + acdc_folder = f"{timestamp}_acdc" + exp_path = os.path.join(dirpath, acdc_folder, "Images") + proceed, do_copy = self.warnUserCreationImagesFolder(exp_path, ext) + if not proceed: + self.logger.info("Loading image file cancelled.") + return + + proceed, channel_name = self.askUserChannelName(filename, ".tif") + if not proceed: + self.logger.info("Loading image file cancelled.") + return + + os.makedirs(exp_path, exist_ok=True) + else: + exp_path = dirpath + + if channel_name is not None: + # Check if user wants to use the existing channel name + underscore_splits = filename.split("_") + if len(underscore_splits) > 1: + default_ch_name = underscore_splits[-1] + if channel_name == default_ch_name: + filename = "_".join(underscore_splits[:-1]) + + basename = f"{filename}_" + new_filename = f"{filename}_{channel_name}{ext}" + df_metadata = pd.DataFrame( + {"Description": ["basename"], "values": [basename]} + ) + metadata_csv_filename = f"{basename}metadata.csv" + metadata_csv_filepath = os.path.join(exp_path, metadata_csv_filename) + df_metadata.to_csv(metadata_csv_filepath, index=False) + else: + new_filename = f"{filename}{ext}" + + if do_copy: + action_text = "Copying" + else: + action_text = "Moving" + + if ext == ".tif" or ext == ".npz": + new_filepath = os.path.join(exp_path, new_filename) + if not os.path.exists(new_filepath): + self.logger.info(f"{action_text} file to Images folder...") + if do_copy: + shutil.copy2(file_path, new_filepath) + else: + shutil.move(file_path, new_filepath) + self._openFolder(exp_path=exp_path, imageFilePath=new_filepath) + else: + self.logger.info(f"{action_text} file to .tif format...") + data = load.loadData(file_path, "", log_func=self.logger.info) + data.loadImgData() + img = data.img_data + if img.ndim == 3 and (img.shape[-1] == 3 or img.shape[-1] == 4): + self.logger.info("Converting RGB image to grayscale...") + if img.shape[-1] == 3: + data.img_data = skimage.color.rgb2gray(data.img_data) + else: + data.img_data = cv2.cvtColor(data.img_data, cv2.COLOR_RGBA2GRAY) + data.img_data = skimage.img_as_ubyte(data.img_data) + new_filename_no_ext, ext = os.path.splitext(new_filename) + tif_filename = f"{new_filename_no_ext}.tif" + tif_path = os.path.join(exp_path, tif_filename) + if data.img_data.ndim == 3: + SizeT = data.img_data.shape[0] + SizeZ = 1 + elif data.img_data.ndim == 4: + SizeT = data.img_data.shape[0] + SizeZ = data.img_data.shape[1] + else: + SizeT = 1 + SizeZ = 1 + is_imageJ_dtype = ( + data.img_data.dtype == np.uint8 + or data.img_data.dtype == np.uint32 + or data.img_data.dtype == np.uint32 + or data.img_data.dtype == np.float32 + ) + if not is_imageJ_dtype: + data.img_data = skimage.img_as_ubyte(data.img_data) + + utils.to_tiff(tif_path, data.img_data) + self._openFolder(exp_path=exp_path, imageFilePath=tif_path) + + def _openFolder(self, checked=False, exp_path=None, imageFilePath=""): + """Main function to load data. + + Parameters + ---------- + checked : bool + kwarg needed because openFolder can be called by openFolderAction. + exp_path : string or None + Path selected by the user either directly, through openFile, + or drag and drop image file. + imageFilePath : string + Path of the image file that was either drag and dropped or opened + from File --> Open image/video file (openFileAction). + + Returns + ------- + None + """ + + if exp_path is None: + self.MostRecentPath = self.getMostRecentPath() + exp_path = QFileDialog.getExistingDirectory( + self, + "Select experiment folder containing Position_n folders " + "or specific Position_n folder", + self.MostRecentPath, + ) + + if not exp_path: + self.openFolderAction.setEnabled(True) + return + + proceed = self.reInitGui() + if not proceed: + self.openFolderAction.setEnabled(True) + return + + self.openFolderAction.setEnabled(False) + + if self.slideshowWin is not None: + self.slideshowWin.close() + + if self.ccaTableWin is not None: + self.ccaTableWin.close() + + self.exp_path = exp_path + self.logger.info(f"Loading from {self.exp_path}") + self.addToRecentPaths(exp_path, logger=self.logger) + self.addPathToOpenRecentMenu(exp_path) + + folder_type = utils.determine_folder_type(exp_path) + is_pos_folder, is_images_folder, exp_path = folder_type + + self.titleLabel.setText("Loading data...", color=self.titleColor) + + skip_channels = [] + ch_name_selector = prompts.select_channel_name( + which_channel="segm", allow_abort=False + ) + user_ch_name = None + if not is_pos_folder and not is_images_folder and not imageFilePath: + images_paths = self._loadFromExperimentFolder(exp_path) + if not images_paths: + self.loadingDataAborted() + return + + elif is_pos_folder and not imageFilePath: + pos_foldername = os.path.basename(exp_path) + exp_path = os.path.dirname(exp_path) + images_paths = [os.path.join(exp_path, pos_foldername, "Images")] + + elif is_images_folder and not imageFilePath: + images_paths = [exp_path] + pos_path = os.path.dirname(exp_path) + exp_path = os.path.dirname(pos_path) + + elif imageFilePath: + # images_path = exp_path because called by openFile func + filenames = utils.listdir(exp_path) + ch_names, basenameNotFound = ch_name_selector.get_available_channels( + filenames, exp_path + ) + filename = os.path.basename(imageFilePath) + self.ch_names = ch_names + user_ch_name = [ + chName for chName in ch_names if filename.find(chName) != -1 + ][0] + images_paths = [exp_path] + pos_path = os.path.dirname(exp_path) + exp_path = os.path.dirname(pos_path) + + self.images_paths = images_paths + + # Get info from first position selected + images_path = self.images_paths[0] + filenames = utils.listdir(images_path) + if ch_name_selector.is_first_call and user_ch_name is None: + ch_names, _ = ch_name_selector.get_available_channels( + filenames, images_path + ) + self.ch_names = ch_names + if not ch_names: + self.openFolderAction.setEnabled(True) + self.criticalNoTifFound(images_path) + return + if len(ch_names) > 1: + CbLabel = "Select channel name to load: " + ch_name_selector.QtPrompt(self, ch_names, CbLabel=CbLabel) + if ch_name_selector.was_aborted: + self.openFolderAction.setEnabled(True) + return + skip_channels.extend( + [ch for ch in ch_names if ch != ch_name_selector.channel_name] + ) + else: + ch_name_selector.channel_name = ch_names[0] + ch_name_selector.setUserChannelName() + user_ch_name = ch_name_selector.user_ch_name + else: + # File opened directly with self.openFile + ch_name_selector.channel_name = user_ch_name + + user_ch_file_paths = [] + not_allowed_ends = ["btrack_tracks.h5"] + for images_path in self.images_paths: + channel_file_path = load.get_filename_from_channel( + images_path, + user_ch_name, + skip_channels=skip_channels, + not_allowed_ends=not_allowed_ends, + logger=self.logger.info, + ) + if not channel_file_path: + self.criticalImgPathNotFound(images_path) + return + user_ch_file_paths.append(channel_file_path) + + ch_name_selector.setUserChannelName() + self.user_ch_name = user_ch_name + self.img1.channelName = user_ch_name + + self.AutoPilotProfile.storeSelectedChannel(self.user_ch_name) + + self.initGlobalAttr() + self.createOverlayContextMenu() + self.createUserChannelNameAction() + self.gui_createOverlayColors() + self.gui_createOverlayItems() + lastRow = self.bottomLeftLayout.rowCount() + self.bottomLeftLayout.setRowStretch(lastRow + 1, 1) + + self.num_pos = len(user_ch_file_paths) + proceed = self.loadSelectedData(user_ch_file_paths, user_ch_name) + if not proceed: + self.openFolderAction.setEnabled(True) + return + + def _workerDebug(self, stuff_to_debug): + pass + + def addToRecentPaths(self, path, logger=None): + utils.addToRecentPaths(path, logger=self.logger) + + def askMismatchSegmDataShape(self, posData): + msg = widgets.myMessageBox(wrapText=False) + title = "Segm. data shape mismatch" + f = "3D" if self.isSegm3D else "2D" + f = f"{f} over time" if posData.SizeT > 1 else f + r = "2D" if self.isSegm3D else "3D" + r = f"{r} over time" if posData.SizeT > 1 else r + text = html_utils.paragraph(f""" + The segmentation masks of the first Position that you loaded is + {f},
    + while {posData.pos_foldername} is {r}.

    + The loaded segmentation masks must be either all 3D + or all 2D.

    + Do you want to skip loading this position or cancel the process? + """) + _, skipPosButton = msg.warning( + self, title, text, buttonsTexts=("Cancel", "Skip this Position") + ) + if skipPosButton == msg.clickedButton: + self.loadDataWorker.skipPos = True + self.loadDataWorker.waitCond.wakeAll() + + def askRecoverNotSavedData(self, posData): + last_modified_time_unsaved = "NEVER" + if os.path.exists(posData.segm_npz_temp_path): + recovered_file_path = posData.segm_npz_temp_path + if os.path.exists(posData.segm_npz_path): + last_modified_time_unsaved = datetime.fromtimestamp( + os.path.getmtime(posData.segm_npz_path) + ).strftime("%a %d. %b. %y - %H:%M:%S") + else: + posData.setTempPaths() + if os.path.exists(posData.unsaved_acdc_df_autosave_path): + zip_path = posData.unsaved_acdc_df_autosave_path + with zipfile.ZipFile(zip_path, mode="r") as zip: + csv_names = natsorted(set(zip.namelist())) + iso_key = csv_names[-1][:-4] + most_recent_unsaved_acdc_df_datetime = datetime.strptime( + iso_key, load.ISO_TIMESTAMP_FORMAT + ) + last_modified_time_unsaved = ( + most_recent_unsaved_acdc_df_datetime + ).strftime("%a %d. %b. %y - %H:%M:%S") + + if os.path.exists(posData.acdc_output_csv_path): + acdc_df_mtime = os.path.getmtime(posData.acdc_output_csv_path) + timestamp = datetime.fromtimestamp(acdc_df_mtime) + last_modified_time_saved = timestamp.strftime("%a %d. %b. %y - %H:%M:%S") + else: + last_modified_time_saved = "Null" + + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + txt = html_utils.paragraph(""" + Cell-ACDC detected unsaved data.

    + Do you want to load and recover the unsaved data or + load the data that was last saved by the user? + """) + details = f""" + The unsaved data was created on {last_modified_time_unsaved}\n\n + The user saved the data last time on {last_modified_time_saved} + """ + msg.setDetailedText(details) + loadUnsavedButton = widgets.reloadPushButton("Recover unsaved data") + loadSavedButton = widgets.savePushButton("Load saved data") + infoButton = widgets.infoPushButton("More info...") + loadSafeNpzButton = "" + if posData.isSafeNpzOverwritePresent(): + loadSafeNpzButton = widgets.reloadPushButton( + "Load .safe.npz file from crash" + ) + buttons = ( + loadSavedButton, + loadUnsavedButton, + loadSafeNpzButton, + infoButton, + ) + else: + buttons = (loadSavedButton, loadUnsavedButton, infoButton) + msg.question( + self.progressWin, + "Recover unsaved data?", + txt, + buttonsTexts=("Cancel", *buttons), + showDialog=False, + ) + infoButton.disconnect() + infoButton.clicked.connect(partial(self.showInfoAutosave, posData)) + msg.exec_() + if msg.cancel: + self.loadDataWorker.abort = True + elif msg.clickedButton == loadUnsavedButton: + self.loadDataWorker.loadUnsaved = True + elif msg.clickedButton == loadSafeNpzButton: + self.loadDataWorker.loadSafeOverwriteNpz = True + + self.loadDataWorker.waitCond.wakeAll() + + def askUserChannelName(self, filename_no_ext, ext): + help_txt = html_utils.paragraph(f""" + Cell-ACDC requires that every image file has a basename and some + additional text, typically the channel name.

    + The basename will be common to all created files, while the additional text is used to identify the image files. + """) + + basename = filename_no_ext + underscore_splits = filename_no_ext.split("_") + if len(underscore_splits) > 1: + channel_name = underscore_splits[-1] + basename = "_".join(underscore_splits[:-1]) + else: + channel_name = "channel_1" + + txt = html_utils.paragraph(f""" + Provide some text (e.g., the channel name) to append at the end of the image file. + """) + win = apps.filenameDialog( + basename=basename, + ext=ext, + hintText=txt, + defaultEntry=channel_name, + helpText=help_txt, + allowEmpty=False, + parent=self, + title="Provide channel name for image file", + ) + win.exec_() + if win.cancel: + return False, "" + + return True, win.entryText + + def checkManageVersions(self): + posData = self.data[self.pos_i] + posData.setTempPaths(createFolder=False) + loaded_acdc_df_filename = os.path.basename(posData.acdc_output_csv_path) + + if os.path.exists(posData.recoveryFolderpath()): + self.manageVersionsAction.setDisabled(False) + self.manageVersionsAction.setToolTip( + f"Load an older version of the `{loaded_acdc_df_filename}` file " + "(table with annotations and measurements)." + ) + else: + self.manageVersionsAction.setDisabled(True) + + def checkMemoryRequirements(self, required_ram): + memory = psutil.virtual_memory() + total_ram = memory.total + available_ram = memory.available + if required_ram / available_ram > 0.3: + proceed = self.warnMemoryNotSufficient( + total_ram, available_ram, required_ram + ) + return proceed + else: + return True + + def criticalFluoChannelNotFound(self, fluo_ch, posData): + msg = widgets.myMessageBox(showCentered=False) + ls = "\n".join(utils.listdir(posData.images_path)) + msg.setDetailedText(f"Files present in the {posData.relPath} folder:\n{ls}") + title = "Requested channel data not found!" + txt = html_utils.paragraph( + f"The folder {posData.pos_path} " + "does not contain " + "either one of the following files:

    " + f"{posData.basename}{fluo_ch}.tif
    " + f"{posData.basename}{fluo_ch}_aligned.npz

    " + "Data loading aborted." + ) + msg.addShowInFileManagerButton(posData.images_path) + okButton = msg.warning(self, title, txt, buttonsTexts=("Ok")) + + def criticalImgPathNotFound(self, images_path): + self.logger.info( + "The following folder does not contain valid image files: " + f'"{images_path}"\n\n' + "Check that all the positions loaded contain the same channel name. " + "Make sure to double check for spelling mistakes or types in the " + "channel names." + ) + msg = widgets.myMessageBox() + msg.addShowInFileManagerButton(images_path) + err_msg = html_utils.paragraph(f""" + The folder

    + {images_path}

    + does not contain any valid image file!

    + Valid file formats are .h5, .tif, _aligned.h5, _aligned.npz. + """) + okButton = msg.critical( + self, "No valid files found!", err_msg, buttonsTexts=("Ok",) + ) + + def criticalInvalidPosFolder(self, exp_path): + href = html_utils.href_tag("here", data_structure_docs_url) + txt = html_utils.paragraph(f""" + The selected folder:

    + + {exp_path}

    + + is not a valid folder.

    + + Select a folder that contains the Position_n folders, + or a specific Position.

    + + If you are trying to load a single image file go to + File --> Open image/video file....

    + + To load a folder containing multiple .tif files the folder must + be called either Position_n
    + (with n being an integer) or Images.

    + + For more information about the correct folder structure see {href}. + """) + msg = widgets.myMessageBox(wrapText=False) + helpButton = widgets.helpPushButton("Help...") + msg.addButton(helpButton) + helpButton.clicked.disconnect() + helpButton.clicked.connect(partial(utils.browse_url, data_structure_docs_url)) + msg.addShowInFileManagerButton(exp_path) + msg.critical(self, "Incompatible folder", txt) + + def criticalNoTifFound(self, images_path): + err_title = "No .tif files found in folder." + err_msg = html_utils.paragraph( + "The following folder

    " + f"{images_path}

    " + "does not contain .tif or .h5 files.

    " + 'Only .tif or .h5 files can be loaded with "Open Folder" button.

    ' + "Try with File --> Open image/video file... " + "and directly select the file you want to load." + ) + msg = widgets.myMessageBox() + msg.addShowInFileManagerButton(images_path) + msg.critical(self, err_title, err_msg) + + def getFileExtensions(self, images_path): + alignedFound = any( + [f.find("_aligned.np") != -1 for f in utils.listdir(images_path)] + ) + if alignedFound: + extensions = ( + "Aligned channels (*npz *npy);; Tif channels(*tiff *tif);;All Files (*)" + ) + else: + extensions = "Tif channels(*tiff *tif);; All Files (*)" + return extensions + + def getMostRecentPath(self): + return utils.getMostRecentPath() + + def getPathFromChName(self, chName, posData): + ls = utils.listdir(posData.images_path) + endnames = {f[len(posData.basename) :]: f for f in ls} + validEnds = ["_aligned.npz", "_aligned.h5", ".h5", ".tif", ".npz"] + for end in validEnds: + files = [ + filename + for endname, filename in endnames.items() + if endname == f"{chName}{end}" + ] + if files: + filename = files[0] + break + else: + self.criticalFluoChannelNotFound(chName, posData) + self.app.restoreOverrideCursor() + return None, None + + fluo_path = os.path.join(posData.images_path, filename) + filename, _ = os.path.splitext(filename) + return fluo_path, filename + + def helpNewFile(self): + msg = widgets.myMessageBox(showCentered=False) + href = f'user manual' + txt = html_utils.paragraph(f""" + Cell-ACDC can open both a single image file or files structured + into Position folders.

    + If you are just testing out you can load a single image file, but + in general we reccommend structuring your data into Position + folders.

    + More info about Position folders in the {href} at the section + called "Create required data structure from microscopy file(s)". + """) + msg.information(self, "Help on Position folders", txt) + + def initFluoData(self): + if len(self.ch_names) <= 1: + return + + if "ask_load_fluo_at_init" in self.df_settings.index: + if self.df_settings.at["ask_load_fluo_at_init", "value"] == "No": + return + msg = widgets.myMessageBox(allowClose=False) + txt = ( + "Do you also want to load fluorescence images?
    " + "You can load as many channels as you want.

    " + "If you load fluorescence images then the software will " + "calculate metrics for each loaded fluorescence channel " + "such as min, max, mean, quantiles, etc. " + "of each segmented object.

    " + "NOTE: You can always load them later from the menu " + "File --> Load fluorescence images... or when you set " + "measurements from the menu " + "Measurements --> Set measurements..." + ) + msg.addDoNotShowAgainCheckbox(text="Don't ask again") + no, yes = msg.question( + self, + "Load fluorescence images?", + html_utils.paragraph(txt), + buttonsTexts=("No", "Yes"), + ) + if msg.doNotShowAgainCheckbox.isChecked(): + self.df_settings.at["ask_load_fluo_at_init", "value"] = "No" + self.df_settings.to_csv(self.settings_csv_path) + if msg.clickedButton == yes: + self.loadFluo_cb(None) + self.AutoPilotProfile.storeClickMessageBox( + "Load fluorescence images?", msg.clickedButton.text() + ) + + def loadDataWorkerDataIntegrityCritical(self): + errTitle = "All loaded positions contains frames over time!" + self.titleLabel.setText(errTitle, color="r") + + msg = widgets.myMessageBox(parent=self) + + err_msg = html_utils.paragraph(f""" + {errTitle}.

    + To load data that contains frames over time you have to select + only ONE position. + """) + msg.setIcon(iconName="SP_MessageBoxCritical") + msg.setWindowTitle("Loaded multiple positions with frames!") + msg.addText(err_msg) + msg.addButton("Ok") + msg.show(block=True) + + def loadDataWorkerDataIntegrityWarning(self, pos_foldername): + err_msg = ( + 'WARNING: Segmentation mask file ("..._segm.npz") not found. ' + "You could run segmentation module first." + ) + self.workerProgress(err_msg, "INFO") + self.titleLabel.setText(err_msg, color="r") + abort = False + msg = widgets.myMessageBox(parent=self) + warn_msg = html_utils.paragraph(f""" + The folder {pos_foldername} does not contain a + pre-computed segmentation mask.

    + You can continue with a blank mask or cancel and + pre-compute the mask with the segmentation module.

    + Do you want to continue? + """) + msg.setIcon(iconName="SP_MessageBoxWarning") + msg.setWindowTitle("Segmentation file not found") + msg.addText(warn_msg) + msg.addButton("Ok") + continueWithBlankSegm = msg.addButton(" Cancel ") + msg.show(block=True) + if continueWithBlankSegm == msg.clickedButton: + abort = True + self.loadDataWorker.abort = abort + self.loadDataWaitCond.wakeAll() + + def loadDataWorkerFinished(self, data): + self.funcDescription = "loading data worker finished" + if self.progressWin is not None: + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + + if data is None or data == "abort": + self.loadingDataAborted() + return + + if data[0].onlyEditMetadata: + self.loadingDataAborted() + return + + self.pos_i = 0 + self.data = data + self.gui_createGraphicsItems() + return True + + def loadFromArrays(self, image, labels=None, **kwargs): + """Load in-memory arrays into the GUI without filesystem dialogs.""" + from cellacdc.data_source import ExperimentData + + data = ExperimentData.from_arrays(image, labels, **kwargs) + self.loadFromExperimentData(data) + + def loadFromExperimentData(self, data): + """Load a materialized :class:`ExperimentData` instance into the GUI.""" + if not data.is_materialized: + raise ValueError("ExperimentData must be materialized before loading.") + + posData = data.positions[0] + self.user_ch_name = posData.user_ch_name + self.ch_names = posData.chNames + self.user_ch_file_paths = [posData.imgPath] + self.num_pos = len(data.positions) + self.exp_path = posData.exp_path + self.isNewFile = not posData.segmFound + self.newSegmEndName = "" + self.selectedSegmEndName = "" + self.labelBoolSegm = posData.labelBoolSegm + self.isSegm3D = posData.isSegm3D + self.SizeT = posData.SizeT + self.SizeZ = posData.SizeZ + self.TimeIncrement = posData.TimeIncrement + self.PhysicalSizeZ = posData.PhysicalSizeZ + self.PhysicalSizeY = posData.PhysicalSizeY + self.PhysicalSizeX = posData.PhysicalSizeX + self.loadSizeS = posData.loadSizeS + self.loadSizeT = posData.loadSizeT + self.loadSizeZ = posData.loadSizeZ + self.isSnapshot = posData.SizeT == 1 + self.isH5chunk = False + self.existingSegmEndNames = set() + self.createOverlayLabelsContextMenu(self.existingSegmEndNames) + self.createOverlayLabelsItems(self.existingSegmEndNames) + self.overlayLabelsButtonAction.setVisible(True) + self.disableNonFunctionalButtons() + self.overlayLabelsItems = {} + self.drawModeOverlayLabelsChannels = {} + self.loadDataWorkerFinished(data.positions) + + def loadFluo_cb(self, checked=True, fluo_channels=None): + if fluo_channels is None: + posData = self.data[self.pos_i] + ch_names = [ + ch + for ch in self.ch_names + if ch != self.user_ch_name and ch not in posData.loadedFluoChannels + ] + if not ch_names: + msg = widgets.myMessageBox() + txt = html_utils.paragraph( + "You already loaded ALL channels.

    " + "To change the overlaid channel " + "right-click on the overlay button." + ) + msg.information(self, "All channels are loaded", txt) + return False + selectFluo = widgets.QDialogListbox( + "Select channel to load", + "Select channel names to load:\n", + ch_names, + multiSelection=True, + parent=self, + ) + selectFluo.exec_() + + if selectFluo.cancel: + return False + + fluo_channels = selectFluo.selectedItemsText + self.AutoPilotProfile.storeLoadedFluoChannels(fluo_channels) + + for p, posData in enumerate(self.data): + # posData.ol_data = None + for fluo_ch in fluo_channels: + fluo_path, filename = self.getPathFromChName(fluo_ch, posData) + if fluo_path is None: + self.criticalFluoChannelNotFound(fluo_ch, posData) + return False + fluo_data, bkgrData = self.load_fluo_data(fluo_path) + if fluo_data is None: + return False + posData.loadedFluoChannels.add(fluo_ch) + + if posData.SizeT == 1: + fluo_data = fluo_data[np.newaxis] + + posData.fluo_data_dict[filename] = fluo_data + posData.fluo_bkgrData_dict[filename] = bkgrData + posData.ol_data_dict[filename] = fluo_data.copy() + + self.overlayButton.setStyleSheet(f"background-color: {GREEN_HEX}") + self.guiTabControl.addChannels( + [posData.user_ch_name, *posData.loadedFluoChannels] + ) + return True + + def loadNonAlignedFluoChannel(self, fluo_path): + posData = self.data[self.pos_i] + if posData.filename.find("aligned") != -1: + filename, _ = os.path.splitext(os.path.basename(fluo_path)) + path = f".../{posData.pos_foldername}/Images/{filename}_aligned.npz" + msg = widgets.myMessageBox() + msg.critical( + self, + "Aligned fluo channel not found!", + "Aligned data for fluorescence channel not found!\n\n" + f"You loaded aligned data for the cells channel, therefore " + "loading NON-aligned fluorescence data is not allowed.\n\n" + 'Run the script "dataPrep.py" to create the following file:\n\n' + f"{path}", + ) + return None + fluo_data = np.squeeze(skimage.io.imread(fluo_path)) + return fluo_data + + def loadPosTriggered(self): + if not self.isDataLoaded: + return + + self.startAutomaticLoadingPos() + + def loadSelectedData(self, user_ch_file_paths, user_ch_name): + data = [] + numPos = len(user_ch_file_paths) + self.user_ch_file_paths = user_ch_file_paths + + self.logger.info(f"Reading {user_ch_name} channel metadata...") + # Get information from first loaded position + posData = load.loadData( + user_ch_file_paths[0], user_ch_name, log_func=self.logger.info + ) + posData.getBasenameAndChNames(qparent=self) + posData.buildPaths() + + if posData.ext != ".h5": + self.lazyLoader.salute = False + self.lazyLoader.exit = True + self.lazyLoaderWaitCond.wakeAll() + self.waitReadH5cond.wakeAll() + + # Get end name of every existing segmentation file + existingSegmEndNames = set() + for filePath in user_ch_file_paths: + _posData = load.loadData(filePath, user_ch_name, log_func=self.logger.info) + _posData.getBasenameAndChNames(qparent=self) + segm_files = load.get_segm_files(_posData.images_path) + _existingEndnames = load.get_endnames(_posData.basename, segm_files) + existingSegmEndNames.update(_existingEndnames) + + selectedSegmEndName = "" + self.newSegmEndName = "" + if self.isNewFile or not existingSegmEndNames: + self.isNewFile = True + # Remove the 'segm_' part to allow filenameDialog to check if + # a new file is existing (since we only ask for the part after + # 'segm_') + existingEndNames = [ + n.replace("segm", "", 1).replace("_", "", 1) + for n in existingSegmEndNames + ] + if posData.basename.endswith("_"): + basename = f"{posData.basename}segm" + else: + basename = f"{posData.basename}_segm" + win = apps.filenameDialog( + basename=basename, + hintText="Insert a filename for the segmentation file:", + existingNames=existingEndNames, + ) + win.exec_() + if win.cancel: + self.loadingDataAborted() + return + self.newSegmEndName = win.entryText + else: + if len(existingSegmEndNames) > 0: + win = apps.SelectSegmFileDialog( + existingSegmEndNames, + self.exp_path, + parent=self, + addNewFileButton=True, + basename=posData.basename, + ) + win.exec_() + if win.cancel: + self.loadingDataAborted() + return + if win.newSegmEndName is None: + selectedSegmEndName = win.selectedItemText + self.AutoPilotProfile.storeSelectedSegmFile(selectedSegmEndName) + else: + self.newSegmEndName = win.newSegmEndName + self.isNewFile = True + elif len(existingSegmEndNames) == 1: + selectedSegmEndName = list(existingSegmEndNames)[0] + + posData.loadImgData() + + required_ram = posData.getBytesImageData() + if required_ram >= 5e8: + # Disable autosave for data > 500MB + self.autoSaveToggle.setChecked(False) + + proceed = self.checkMemoryRequirements(required_ram) + if not proceed: + self.loadingDataAborted() + return + + posData.loadOtherFiles( + load_segm_data=True, + load_metadata=True, + create_new_segm=self.isNewFile, + new_endname=self.newSegmEndName, + end_filename_segm=selectedSegmEndName, + ) + self.selectedSegmEndName = selectedSegmEndName + self.labelBoolSegm = posData.labelBoolSegm + posData.labelSegmData() + + print("") + self.logger.info(f"Segmentation filename: {posData.segm_npz_path}") + + proceed = posData.askInputMetadata( + self.num_pos, + ask_SizeT=self.num_pos == 1, + ask_TimeIncrement=True, + ask_PhysicalSizes=True, + singlePos=False, + save=True, + warnMultiPos=True, + ) + if not proceed: + self.loadingDataAborted() + return + + self.AutoPilotProfile.storeOkAskInputMetadata() + + if posData.isSegm3D is None: + self.isSegm3D = False + else: + self.isSegm3D = posData.isSegm3D + self.SizeT = posData.SizeT + self.SizeZ = posData.SizeZ + self.TimeIncrement = posData.TimeIncrement + self.PhysicalSizeZ = posData.PhysicalSizeZ + self.PhysicalSizeY = posData.PhysicalSizeY + self.PhysicalSizeX = posData.PhysicalSizeX + self.loadSizeS = posData.loadSizeS + self.loadSizeT = posData.loadSizeT + self.loadSizeZ = posData.loadSizeZ + + self.overlayLabelsItems = {} + self.drawModeOverlayLabelsChannels = {} + + self.existingSegmEndNames = existingSegmEndNames + self.createOverlayLabelsContextMenu(existingSegmEndNames) + self.overlayLabelsButtonAction.setVisible(True) + self.createOverlayLabelsItems(existingSegmEndNames) + self.disableNonFunctionalButtons() + + self.isH5chunk = posData.ext == ".h5" and ( + self.loadSizeT != self.SizeT or self.loadSizeZ != self.SizeZ + ) + + required_ram = posData.checkH5memoryFootprint() * self.loadSizeS + if required_ram > 0: + proceed = self.checkMemoryRequirements(required_ram) + if not proceed: + self.loadingDataAborted() + return + + if posData.SizeT == 1: + self.isSnapshot = True + else: + self.isSnapshot = False + + self.progressWin = apps.QDialogWorkerProgress( + title="Loading data...", + parent=self, + pbarDesc=f'Loading "{user_ch_file_paths[0]}"...', + ) + self.progressWin.show(self.app) + + func = partial( + self.startLoadDataWorker, user_ch_file_paths, user_ch_name, posData + ) + + QTimer.singleShot(150, func) + + def load_fluo_data(self, fluo_path, isGuiThread=True): + self.logger.info(f'Loading fluorescence image data from "{fluo_path}"...') + bkgrData = None + posData = self.data[self.pos_i] + # Load overlay frames and align if needed + filename = os.path.basename(fluo_path) + filename_noEXT, ext = os.path.splitext(filename) + if ext == ".npy" or ext == ".npz": + fluo_data = np.load(fluo_path) + try: + fluo_data = np.squeeze(fluo_data["arr_0"]) + except Exception as e: + fluo_data = np.squeeze(fluo_data) + + # Load background data + bkgrData_path = os.path.join( + posData.images_path, f"{filename_noEXT}_bkgrRoiData.npz" + ) + if os.path.exists(bkgrData_path): + bkgrData = np.load(bkgrData_path) + elif ext == ".tif" or ext == ".tiff": + aligned_filename = f"{filename_noEXT}_aligned.npz" + aligned_path = os.path.join(posData.images_path, aligned_filename) + if os.path.exists(aligned_path): + fluo_data = np.load(aligned_path)["arr_0"] + + # Load background data + bkgrData_path = os.path.join( + posData.images_path, f"{aligned_filename}_bkgrRoiData.npz" + ) + if os.path.exists(bkgrData_path): + bkgrData = np.load(bkgrData_path) + else: + fluo_data = self.loadNonAlignedFluoChannel(fluo_path) + if fluo_data is None: + return None, None + + # Load background data + bkgrData_path = os.path.join( + posData.images_path, f"{filename_noEXT}_bkgrRoiData.npz" + ) + if os.path.exists(bkgrData_path): + bkgrData = np.load(bkgrData_path) + elif isGuiThread: + txt = html_utils.paragraph( + f"File format {ext} is not supported!\n" + "Choose either .tif or .npz files." + ) + msg = widgets.myMessageBox() + msg.critical(self, "File not supported", txt) + return None, None + + return fluo_data, bkgrData + + def loadingDataAborted(self): + self.openFolderAction.setEnabled(True) + self.titleLabel.setText("Loading data aborted.") + + def loadingDataCompleted(self): + self.isDataLoading = True + posData = self.data[self.pos_i] + + files_format = "\n".join( + [f" - {file}" for file in posData.images_folder_files] + ) + sep = "-" * 100 + self.logger.info( + f"{sep}\nFiles present in the first Position folder loaded:\n\n" + f"{files_format}\n{sep}" + ) + self.logger.info(f"Basename of the first Position: {posData.basename}") + self.secondLevelToolbar.setVisible(True) + self.updateImageValueFormatter() + self.checkManageVersions() + self.initManualBackgroundImage() + self.initPixelSizePropsDockWidget() + + self.setWindowTitle( + f'Cell-ACDC v{self._acdc_version} - GUI - "{posData.exp_path}"' + ) + + self.setupPreprocessing() + self.setupCombiningChannels() + + if self.isSegm3D: + self.segmNdimIndicator.setText("3D") + else: + self.segmNdimIndicator.setText("2D") + + self.segmNdimIndicatorAction.setVisible(True) + + self.guiTabControl.addChannels([posData.user_ch_name]) + self.showPropsDockButton.setDisabled(False) + + self.bottomScrollArea.show() + self.gui_createStoreStateWorker() + self.init_segmInfo_df() + self.connectScrollbars() + self.initPosAttr() + + self.logger.info("Pre-computing min and max values of the images...") + self.img1.preComputedMinMaxValues(self.data) + self.img2.minMaxValuesMapper = self.img1.minMaxValuesMapper + + self.initMetrics() + self.initFluoData() + self.createChannelNamesActions() + self.addActionsLutItemContextMenu(self.imgGrad) + + # Scrollbar for opacity of img1 (when overlaying) + self.img1.alphaScrollbar = self.addAlphaScrollbar(self.user_ch_name, self.img1) + + self.navigateScrollBar.setSliderPosition(posData.frame_i + 1) + + # Connect events at the end of loading data process + self.gui_connectGraphicsEvents() + if not self.isEditActionsConnected: + self.gui_connectEditActions() + self.normalizeToFloatAction.setChecked(True) + + self.navSpinBox.connectValueChanged(self.navigateSpinboxValueChanged) + + self.setFramesSnapshotMode() + if self.isSnapshot: + self.navSizeLabel.setText(f"/{len(self.data)}") + else: + self.navSizeLabel.setText(f"/{posData.SizeT}") + + self.enableZstackWidgets(posData.SizeZ > 1) + # self.showHighlightZneighCheckbox() + + self.exportToVideoAction.setDisabled(posData.SizeZ == 1 and posData.SizeT == 1) + + self.img1BottomGroupbox.show() + + isLabVisible = self.df_settings.at["isLabelsVisible", "value"] == "Yes" + isRightImgVisible = self.df_settings.at["isRightImageVisible", "value"] == "Yes" + isNextFrameVisible = self.df_settings.at["isNextFrameVisible", "value"] == "Yes" + isNextFrameActive = ( + isNextFrameVisible and self.labelsGrad.showNextFrameAction.isEnabled() + ) + self.updateScrollbars() + self.openFolderAction.setEnabled(True) + self.editTextIDsColorAction.setDisabled(False) + self.imgPropertiesAction.setEnabled(True) + self.navigateToolBar.setVisible(True) + self.labelsGrad.showLabelsImgAction.setChecked(isLabVisible) + self.labelsGrad.showRightImgAction.setChecked(isRightImgVisible) + self.labelsGrad.showNextFrameAction.setChecked(isNextFrameActive) + if isRightImgVisible or isNextFrameActive: + self.rightBottomGroupbox.setChecked(True) + + isTwoImagesLayout = isRightImgVisible or isLabVisible or isNextFrameActive + self.setTwoImagesLayout(isTwoImagesLayout) + + self.setBottomLayoutStretch() + + if isNextFrameActive: + self.rightBottomGroupbox.show() + self.rightBottomGroupbox.setChecked(True) + self.drawNothingCheckboxRight.click() + + self.readSavedCustomAnnot() + self.addCustomAnnotButtonAllLoadedPos() + self.setStatusBarLabel() + + self.initLookupTableLab() + if self.invertBwAction.isChecked() and not self.invertBwAlreadyCalledOnce: + self.invertBw(True) + self.restoreSavedSettings() + + self.initContoursImage() + self.initTextAnnot() + self.initDelRoiLab() + + self.update_rp() + self.updateAllImages() + if posData.SizeT > 1: + self.rightImageFramesScrollbar.setValueNoSignal(posData.frame_i + 2) + self.setMetricsFunc() + + self.gui_createLabelRoiItem() + self.gui_createZoomRectItem() + + self.titleLabel.setText("Data successfully loaded.", color=self.titleColor) + + self.disableNonFunctionalButtons() + self.setVisible3DsegmWidgets() + + if len(self.data) == 1 and posData.SizeZ > 1 and posData.SizeT == 1: + self.zSliceCheckbox.setChecked(True) + else: + self.zSliceCheckbox.setChecked(False) + + self.labelRoiCircItemLeft.setImageShape(self.currentLab2D.shape) + self.labelRoiCircItemRight.setImageShape(self.currentLab2D.shape) + + self.retainSpaceSlidersToggled(self.retainSpaceSlidersAction.isChecked()) + + self.stopAutomaticLoadingPos() + self.viewAllCustomAnnotAction.setChecked(True) + + self.updateImageValueFormatter() + + posData.loadWhitelist() + + self.setFocusGraphics() + self.setFocusMain() + + # Overwrite axes viewbox context menu + self.ax1.vb.menu = self.imgGrad.gradient.menu + self.ax2.vb.menu = self.labelsGrad.menu + + QTimer.singleShot(200, self.resizeGui) + + self.isDataLoaded = True + self.isDataLoading = False + + self.initImgGradRescaleIntensitiesHowPreference() + + self.rescaleIntensitiesLut(setImage=False) + + self.gui_createAutoSaveWorker() + + def newFile(self): + self.newSegmEndName = "" + self.isNewFile = True + msg = widgets.myMessageBox(parent=self, showCentered=False) + msg.setWindowTitle("File or folder?") + msg.addText( + html_utils.paragraph(f""" + Do you want to load an image file or Position + folder(s)? + """) + ) + loadPosButton = QPushButton("Load Position folder", msg) + loadPosButton.setIcon(QIcon(":folder-open.svg")) + loadFileButton = QPushButton("Load image file", msg) + loadFileButton.setIcon(QIcon(":image.svg")) + helpButton = widgets.helpPushButton("Help...") + msg.addButton(helpButton) + helpButton.disconnect() + helpButton.clicked.connect(self.helpNewFile) + msg.addCancelButton(connect=True) + msg.addButton(loadFileButton) + msg.addButton(loadPosButton) + loadPosButton.setDefault(True) + msg.exec_() + if msg.cancel: + return + + if msg.clickedButton == loadPosButton: + self._openFolder() + else: + self._openFile() + + def openFile(self, checked=False, file_path=None): + self.logger.info(f'Opening FILE "{file_path}"') + + self.isNewFile = False + self._openFile(file_path=file_path) + + def openFolder(self, checked=False, exp_path=None, imageFilePath=""): + if exp_path is None: + self.logger.info("Asking to select a folder path...") + else: + self.logger.info(f'Opening FOLDER "{exp_path}"...') + + self.isNewFile = False + if hasattr(self, "data") and self.titleLabel.text != "Saved!": + msg = widgets.myMessageBox() + txt = html_utils.paragraph( + "Do you want to save before loading another dataset?" + ) + _, no, yes = msg.question( + self, "Save?", txt, buttonsTexts=("Cancel", "No", "Yes") + ) + if msg.clickedButton == yes: + func = partial(self._openFolder, exp_path, imageFilePath) + cancel = self.saveData(finishedCallback=func) + return + elif msg.cancel: + self.store_data() + return + else: + self.store_data(autosave=False) + + self._openFolder(exp_path=exp_path, imageFilePath=imageFilePath) + + def openRecentFile(self, path): + self.logger.info(f"Opening recent folder: {path}") + self.addToRecentPaths(path, logger=self.logger) + self.openFolder(exp_path=path) + + def reload_cb(self): + posData = self.data[self.pos_i] + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + labData = np.load(posData.segm_npz_path) + # Keep compatibility with .npy and .npz files + try: + lab = labData["arr_0"][posData.frame_i] + except Exception as e: + lab = labData[posData.frame_i] + posData.segm_data[posData.frame_i] = lab.copy() + self.get_data() + self.tracking() + self.updateAllImages() + + def showInfoAutosave(self, posData): + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + txt = f""" + Cell-ACDC either detected unsaved data in a previous session and it + stored it because the Autosave
    + function was active, or it crashed during saving.

    + You can toggle Autosave ON and OFF from the menu on the top menubar + File --> Autosave. + """ + txt = f""" + {txt}

    + If Cell-ACDC crashed during saving, the segmentation file ending + with .new.npz
    + is present and you might be able to recover the data from there. + """ + + txt = f""" + {txt}

    + You can find additional recovered data in the following folder: + """ + txt = html_utils.paragraph(txt) + msg.information( + self, + "Autosave info", + txt, + path_to_browse=posData.recoveryFolderPath, + commands=(posData.recoveryFolderPath,), + ) + + def startAutomaticLoadingPos(self): + self.AutoPilot = autopilot.AutoPilot(self) + self.AutoPilot.execLoadPos() + + def startLoadDataWorker(self, user_ch_file_paths, user_ch_name, firstPosData): + self.funcDescription = "loading data" + + self.guiTabControl.propsQGBox.idSB.setValue(0) + + self.thread = QThread() + self.loadDataMutex = QMutex() + self.loadDataWaitCond = QWaitCondition() + + self.loadDataWorker = workers.loadDataWorker( + self, user_ch_file_paths, user_ch_name, firstPosData + ) + + self.loadDataWorker.moveToThread(self.thread) + self.loadDataWorker.signals.finished.connect(self.thread.quit) + self.loadDataWorker.signals.finished.connect(self.loadDataWorker.deleteLater) + self.thread.finished.connect(self.thread.deleteLater) + + self.loadDataWorker.signals.finished.connect(self.loadDataWorkerFinished) + self.loadDataWorker.signals.progress.connect(self.workerProgress) + self.loadDataWorker.signals.initProgressBar.connect(self.workerInitProgressbar) + self.loadDataWorker.signals.progressBar.connect(self.workerUpdateProgressbar) + self.loadDataWorker.signals.critical.connect(self.workerCritical) + self.loadDataWorker.signals.dataIntegrityCritical.connect( + self.loadDataWorkerDataIntegrityCritical + ) + self.loadDataWorker.signals.dataIntegrityWarning.connect( + self.loadDataWorkerDataIntegrityWarning + ) + self.loadDataWorker.signals.sigPermissionError.connect( + self.workerPermissionError + ) + self.loadDataWorker.signals.sigWarnMismatchSegmDataShape.connect( + self.askMismatchSegmDataShape + ) + self.loadDataWorker.signals.sigRecovery.connect(self.askRecoverNotSavedData) + + self.thread.started.connect(self.loadDataWorker.run) + self.thread.start() + + def stopAutomaticLoadingPos(self): + if self.AutoPilot is None: + return + + if self.AutoPilot.timer.isActive(): + self.AutoPilot.timer.stop() + self.AutoPilot = None + + def warnMemoryNotSufficient(self, total_ram, available_ram, required_ram): + total_ram = utils._bytes_to_GB(total_ram) + available_ram = utils._bytes_to_GB(available_ram) + required_ram = utils._bytes_to_GB(required_ram) + required_perc = round(100 * required_ram / available_ram) + msg = widgets.myMessageBox() + txt = html_utils.paragraph(f""" + The total amount of data that you requested to load is about + {required_ram:.2f} GB ({required_perc}% of the available memory) + but there are only {available_ram:.2f} GB available.

    + For optimal operation, we recommend loading maximum 30% + of the available memory. To do so, try to close open apps to + free up some memory. Another option is to crop the images + using the data prep module.

    + If you choose to continue, the system might freeze + or your OS could simply kill the process.

    + What do you want to do? + """) + cancelButton, continueButton = msg.warning( + self, + "Memory not sufficient", + txt, + buttonsTexts=("Cancel", "Continue anyway"), + ) + if msg.clickedButton == continueButton: + # Disable autosaving since it would keep a copy of the data and + # we cannot afford it with low memory + self.autoSaveToggle.setChecked(False) + return True + else: + return False + + def warnUserCreationImagesFolder(self, images_path, ext): + msg = widgets.myMessageBox(wrapText=False) + txt = f""" + Cell-ACDC requires a specific folder structure to load the data.

    + Specifically, it requires the image(s) to be located in a + folder called Images.

    + The file format of the images must be TIFF or NPZ + (.tif or .npz extension).

    + You can choose to let Cell-ACDC create the required data structure + from your file,
    + or you can stop the + process and manually place the image(s) into a folder called + Images.

    + If you choose to proceed, Cell-ACDC will create the following + folder: + {images_path} +
    + """ + + if ext == ".tif" or ext == ".npz": + txt = f"{txt}How do you want to proceed?" + else: + txt = f"{txt}Do you want to proceed?" + txt = html_utils.paragraph(txt) + + if ext == ".tif" or ext == ".npz": + copyButton = widgets.copyPushButton("Copy the image into the new folder") + moveButton = widgets.movePushButton("Move the image into the new folder") + _, copyButton, moveButton = msg.information( + self, + "Creating Images folder", + txt, + buttonsTexts=("Cancel", copyButton, moveButton), + ) + if msg.cancel: + return False, None + + if msg.clickedButton == copyButton: + return True, True + elif msg.clickedButton == moveButton: + return True, False + + else: + msg.information( + self, + "Creating Images folder", + txt, + buttonsTexts=("Cancel", "Yes, proceed"), + ) + if msg.cancel: + return False, None + + return True, True + + def workerPermissionError(self, txt, waitCond): + msg = widgets.myMessageBox(parent=self) + msg.setIcon(iconName="SP_MessageBoxCritical") + msg.setWindowTitle("Permission denied") + msg.addText(txt) + msg.addButton(" Ok ") + msg.exec_() + waitCond.wakeAll() + + def zSliceAbsent(self, filename, posData): + self.app.restoreOverrideCursor() + SizeZ = posData.SizeZ + chNames = posData.chNames + filenamesPresent = posData.segmInfo_df.index.get_level_values(0).unique() + chNamesPresent = [ + ch + for ch in chNames + for file in filenamesPresent + if file.endswith(ch) or file.endswith(f"{ch}_aligned") + ] + win = apps.QDialogZsliceAbsent(filename, SizeZ, chNamesPresent) + win.exec_() + if win.cancel: + self.worker.abort = True + self.waitCond.wakeAll() + return + if win.useMiddleSlice: + user_ch_name = filename[len(posData.basename) :] + for _posData in self.data: + if _posData is None: + continue + _, filename = self.getPathFromChName(user_ch_name, _posData) + df = utils.getDefault_SegmInfo_df(_posData, filename) + _posData.segmInfo_df = pd.concat([df, _posData.segmInfo_df]) + unique_idx = ~_posData.segmInfo_df.index.duplicated() + _posData.segmInfo_df = _posData.segmInfo_df[unique_idx] + _posData.segmInfo_df.to_csv(_posData.segmInfo_df_csv_path) + elif win.useSameAsCh: + user_ch_name = filename[len(posData.basename) :] + for _posData in self.data: + if _posData is None: + continue + _, srcFilename = self.getPathFromChName(win.selectedChannel, _posData) + cellacdc_df = _posData.segmInfo_df.loc[srcFilename].copy() + _, dstFilename = self.getPathFromChName(user_ch_name, _posData) + if dstFilename is None: + self.worker.abort = True + self.waitCond.wakeAll() + return + dst_df = utils.getDefault_SegmInfo_df(_posData, dstFilename) + for z_info in cellacdc_df.itertuples(): + frame_i = z_info.Index + zProjHow = z_info.which_z_proj + if zProjHow == "single z-slice": + src_idx = (srcFilename, frame_i) + if _posData.segmInfo_df.at[src_idx, "resegmented_in_gui"]: + col = "z_slice_used_gui" + else: + col = "z_slice_used_dataPrep" + z_slice = _posData.segmInfo_df.at[src_idx, col] + dst_idx = (dstFilename, frame_i) + dst_df.at[dst_idx, "z_slice_used_dataPrep"] = z_slice + dst_df.at[dst_idx, "z_slice_used_gui"] = z_slice + _posData.segmInfo_df = pd.concat([dst_df, _posData.segmInfo_df]) + unique_idx = ~_posData.segmInfo_df.index.duplicated() + _posData.segmInfo_df = _posData.segmInfo_df[unique_idx] + _posData.segmInfo_df.to_csv(_posData.segmInfo_df_csv_path) + elif win.runDataPrep: + user_ch_file_paths = [] + user_ch_name = filename[len(self.data[self.pos_i].basename) :] + for _posData in self.data: + if _posData is None: + continue + user_ch_path = load.get_filename_from_channel( + _posData.images_path, user_ch_name + ) + if user_ch_path is None: + self.worker.abort = True + self.waitCond.wakeAll() + return + user_ch_file_paths.append(user_ch_path) + exp_path = os.path.dirname(_posData.pos_path) + + dataPrepWin = dataPrep.dataPrepWin() + dataPrepWin.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + dataPrepWin.titleText = """ + Select z-slice (or projection) for each frame/position.
    + Once happy, close the window. + """ + dataPrepWin.show() + dataPrepWin.initLoading() + dataPrepWin.SizeT = self.data[0].SizeT + dataPrepWin.SizeZ = self.data[0].SizeZ + dataPrepWin.metadataAlreadyAsked = True + self.logger.info(f"Loading channel {user_ch_name} data...") + dataPrepWin.loadFiles(exp_path, user_ch_file_paths, user_ch_name) + dataPrepWin.startAction.setDisabled(True) + dataPrepWin.onlySelectingZslice = True + + loop = QEventLoop(self) + dataPrepWin.loop = loop + loop.exec_() + + self.waitCond.wakeAll() + + def getConcatAcdcDf(self): + acdc_dfs = [] + keys = [] + posData = self.data[self.pos_i] + for frame_i, data_dict in enumerate(posData.allData_li): + lab = data_dict["labels"] + if lab is None: + break + + acdc_df = data_dict["acdc_df"] + if acdc_df is None: + break + + acdc_dfs.append(acdc_df) + keys.append(frame_i) + + if not acdc_dfs: + return + + return pd.concat(acdc_dfs, keys=keys, names=["frame_i"]) diff --git a/cellacdc/mixins/deleted_rois.py b/cellacdc/mixins/deleted_rois.py new file mode 100644 index 000000000..b3a8be256 --- /dev/null +++ b/cellacdc/mixins/deleted_rois.py @@ -0,0 +1,583 @@ +"""Qt view adapter for deleted-ROI workflows.""" + +from __future__ import annotations + +from functools import partial +import uuid + +import numpy as np +import pyqtgraph as pg +from collections.abc import Iterable +import skimage.measure +from qtpy.QtCore import QRect, QRectF, QTimer + +from cellacdc import widgets + +from .cell_cycle import CellCycle + + +class DeletedRois(CellCycle): + """Extracted from guiWin.""" + + def addDelPolyLineRoi_cb(self, checked): + if checked: + self.disconnectLeftClickButtons() + self.uncheckLeftClickButtons(self.addDelPolyLineRoiButton) + self.connectLeftClickButtons() + if self.isSnapshot: + self.fixCcaDfAfterEdit("Delete IDs using ROI") + self.updateAllImages() + else: + self.warnEditingWithCca_df("Delete IDs using ROI") + else: + self.tempSegmentON = False + self.ax1_rulerPlotItem.setData([], []) + self.ax1_rulerAnchorsItem.setData([], []) + self.startPointPolyLineItem.setData([], []) + while self.app.overrideCursor() is not None: + self.app.restoreOverrideCursor() + + def addDelROI(self, event): + roi, key = self.createDelROI() + self.addRoiToDelRoiInfo(roi) + if not self.labelsGrad.showLabelsImgAction.isChecked(): + self.ax1.addDelRoiItem(roi, key) + else: + self.ax2.addDelRoiItem(roi, key) + self.applyDelROIimg1(roi, init=True) + self.applyDelROIimg1(roi, init=True, ax=1) + + if self.isSnapshot: + self.fixCcaDfAfterEdit("Delete IDs using ROI") + self.updateAllImages() + else: + self.warnEditingWithCca_df("Delete IDs using ROI", get_cancelled=True) + + def addExistingDelROIs(self): + posData = self.data[self.pos_i] + delROIs_info = posData.allData_li[posData.frame_i]["delROIs_info"] + isAx2hidden = not self.labelsGrad.showLabelsImgAction.isChecked() + + for r, roi in enumerate(delROIs_info["rois"]): + if isinstance(roi, pg.PolyLineROI) or isAx2hidden: + # PolyLine ROIs are only on ax1 + self.ax1.addDelRoiItem(roi, roi.key) + else: + # Rect ROI is on ax2 because ax2 is visible + self.ax2.addDelRoiItem(roi, roi.key) + + self.setDelRoiState(roi, delROIs_info["state"][r]) + + def addPointsPolyLineRoi(self, closed=False): + self.polyLineRoi.setPoints(self.polyLineRoi.points, closed=closed) + if not closed: + return + + # Connect closed ROI + self.polyLineRoi.sigRegionChanged.connect(self.delROImoving) + self.polyLineRoi.sigRegionChangeFinished.connect(self.delROImovingFinished) + + def addRoiToDelRoiInfo(self, roi: pg.ROI): + posData = self.data[self.pos_i] + for i in range(posData.frame_i, posData.SizeT): + delROIs_info = posData.allData_li[i]["delROIs_info"] + delROIs_info["rois"].append(roi) + delROIs_info["state"].append(roi.getState()) + delROIs_info["delMasks"].append(np.zeros_like(self.currentLab2D)) + delROIs_info["delIDsROI"].append(set()) + + def applyDelROIimg1(self, roi, init=False, ax=0): + if ax == 0: + how = self.drawIDsContComboBox.currentText() + else: + how = self.getAnnotateHowRightImage() + + if ax == 1 and not self.labelsGrad.showRightImgAction.isChecked(): + return + + if init and how.find("contours") == -1: + self.setOverlaySegmMasks(force=True) + return + + posData = self.data[self.pos_i] + delROIs_info = posData.allData_li[posData.frame_i]["delROIs_info"] + try: + idx = delROIs_info["rois"].index(roi) + except Exception as err: + try: + ax.removeDelRoiItem(roi) + except Exception as err: + pass + return + delIDs = delROIs_info["delIDsROI"][idx] + delMask = delROIs_info["delMasks"][idx] + if how.find("nothing") != -1: + return + elif how.find("contours") != -1: + self.updateContoursImage(ax=ax) + + if not delIDs: + return + + if how.find("overlay segm. masks") != -1: + lab = self.currentLab2D.copy() + lab[delMask > 0] = 0 + if ax == 0: + self.labelsLayerImg1.setImage(lab, autoLevels=False) + else: + self.labelsLayerRightImg.setImage(lab, autoLevels=False) + + self.setAllTextAnnotations(labelsToSkip={ID: True for ID in delIDs}) + + def applyDelROIs(self): + self.logger.info("Applying deletion ROIs (if present)...") + + for posData in self.data: + self.current_frame_i = posData.frame_i + for frame_i in range(posData.SizeT): + lab = posData.allData_li[frame_i]["labels"] + if lab is None: + break + delROIs_info = posData.allData_li[frame_i]["delROIs_info"] + delIDs_rois = delROIs_info["delIDsROI"] + if not delIDs_rois: + continue + for delIDs in delIDs_rois: + for delID in delIDs: + lab[lab == delID] = 0 + posData.allData_li[frame_i]["labels"] = lab + # Get the rest of the metadata and store data based on the new lab + posData.frame_i = frame_i + self.get_data() + self.store_data(autosave=False) + + # Back to current frame + posData.frame_i = self.current_frame_i + self.get_data() + + def clearLostObjContoursItems(self): + self.ax1_lostObjScatterItem.setData([], []) + self.ax2_lostObjScatterItem.setData([], []) + + self.ax1_lostTrackedScatterItem.setData([], []) + self.ax2_lostTrackedScatterItem.setData([], []) + + self.ax2_lostObjImageItem.clear() + self.ax2_lostTrackedObjImageItem.clear() + + self.ax1_lostObjImageItem.clear() + self.ax1_lostTrackedObjImageItem.clear() + + def createDelPolyLineRoi(self): + Y, X = self.currentLab2D.shape + self.polyLineRoi = pg.PolyLineROI( + [], rotatable=False, removable=True, pen=pg.mkPen(color="r") + ) + self.polyLineRoi.handleSize = 7 + self.polyLineRoi.points = [] + key = uuid.uuid4() + self.ax1.addDelRoiItem(self.polyLineRoi, key) + + def createDelROI(self, xl=None, yb=None, w=32, h=32, anchors=None): + posData = self.data[self.pos_i] + if xl is None: + xRange, yRange = self.ax1.viewRange() + xl = 0 if xRange[0] < 0 else xRange[0] + yb = 0 if yRange[0] < 0 else yRange[0] + Y, X = self.currentLab2D.shape + if anchors is None: + roi = widgets.DelROI( + [xl, yb], + [w, h], + rotatable=False, + removable=True, + pen=pg.mkPen(color="r"), + maxBounds=QRectF(QRect(0, 0, X, Y)), + ) + ## handles scaling horizontally around center + roi.addScaleHandle([1, 0.5], [0, 0.5]) + roi.addScaleHandle([0, 0.5], [1, 0.5]) + + ## handles scaling vertically from opposite edge + roi.addScaleHandle([0.5, 0], [0.5, 1]) + roi.addScaleHandle([0.5, 1], [0.5, 0]) + + ## handles scaling both vertically and horizontally + roi.addScaleHandle([1, 1], [0, 0]) + roi.addScaleHandle([0, 0], [1, 1]) + roi.addScaleHandle([0, 1], [1, 0]) + roi.addScaleHandle([1, 0], [0, 1]) + + roi.handleSize = 7 + roi.sigRegionChanged.connect(self.delROImoving) + roi.sigRegionChanged.connect(self.delROIstartedMoving) + roi.sigRegionChangeFinished.connect(self.delROImovingFinished) + + key = uuid.uuid4() + + return roi, key + + def delROImoving(self, roi): + roi.setPen(color=(255, 255, 0)) + # First bring back IDs if the ROI moved away + self.restoreAnnotDelROI(roi) + self.setImageImg2() + self.applyDelROIimg1(roi) + self.applyDelROIimg1(roi, ax=1) + + def delROImovingFinished(self, roi: pg.ROI): + roi.setPen(color="r") + self.update_rp() + self.updateAllImages() + QTimer.singleShot(300, partial(self.updateDelROIinFutureFrames, roi)) + + def delROIstartedMoving(self, roi): + self.clearLostObjContoursItems() + + def getDelROIlab(self, input_lab_2D=None): + posData = self.data[self.pos_i] + if self.delRoiLab is None: + self.initDelRoiLab() + + out_lab = self.delRoiLab + if input_lab_2D is None: + out_lab[:] = self.get_2Dlab(posData.lab, force_z=False) + else: + out_lab[:] = input_lab_2D + + allDelIDs = set() + # Iterate rois and delete IDs + for roi in posData.allData_li[posData.frame_i]["delROIs_info"]["rois"]: + if not self.ax1.isDelRoiItemPresent( + roi + ) and not self.ax2.isDelRoiItemPresent(roi): + continue + ROImask = self.getDelRoiMask(roi) + delROIs_info = posData.allData_li[posData.frame_i]["delROIs_info"] + idx = delROIs_info["rois"].index(roi) + delObjROImask = delROIs_info["delMasks"][idx] + delIDsROI = delROIs_info["delIDsROI"][idx] + delROIlabRp = skimage.measure.regionprops(out_lab) + for delObj in delROIlabRp: + isDelObj = np.any(ROImask[delObj.slice][delObj.image]) + if not isDelObj: + continue + + delObjROImask[delObj.slice][delObj.image] = delObj.label + out_lab[delObj.slice][delObj.image] = 0 + + delIDsROI.add(delObj.label) + allDelIDs.add(delObj.label) + + # Keep a mask of deleted IDs to bring them back when roi moves + delROIs_info["delMasks"][idx] = delObjROImask + delROIs_info["delIDsROI"][idx] = delIDsROI + + # printl( + # f't1-t0: {(t1-t0)*1000:.3f} ms,', + # f't2-t1: {(t2-t1)*1000:.3f} ms,', + # f't3-t2: {(t3-t2)*1000:.3f} ms,', + # # f't4-t3: {(t4-t3)*1000:.3f} ms,', + # # f't5-t4: {(t5-t4)*1000:.3f} ms,', + # # f't6-t5: {(t6-t5)*1000:.3f} ms', + # sep='\n' + # ) + + return allDelIDs, out_lab + + def getDelRoiMask(self, roi, posData=None, z_slice=None): + if posData is None: + posData = self.data[self.pos_i] + if z_slice is None: + z_slice = self.z_lab() + ROImask = np.zeros(posData.lab.shape, bool) + if isinstance(roi, pg.PolyLineROI): + r, c = [], [] + x0, y0 = roi.pos().x(), roi.pos().y() + for _, point in roi.getLocalHandlePositions(): + xr, yr = point.x(), point.y() + r.append(int(yr + y0)) + c.append(int(xr + x0)) + if not r or not c: + return ROImask + + if len(r) == 2: + rr, cc, val = skimage.draw.line_aa(r[0], c[0], r[1], c[1]) + else: + rr, cc = skimage.draw.polygon(r, c, shape=self.currentLab2D.shape) + + Y, X = self.currentLab2D.shape + rr = rr[(rr >= 0) & (rr < Y)] + cc = cc[(cc >= 0) & (cc < X)] + + if self.isSegm3D: + ROImask[z_slice, rr, cc] = True + else: + ROImask[rr, cc] = True + elif isinstance(roi, pg.LineROI): + (_, point1), (_, point2) = roi.getSceneHandlePositions() + point1 = self.ax1.vb.mapSceneToView(point1) + point2 = self.ax1.vb.mapSceneToView(point2) + x1, y1 = int(point1.x()), int(point1.y()) + x2, y2 = int(point2.x()), int(point2.y()) + rr, cc, val = skimage.draw.line_aa(y1, x1, y2, x2) + if self.isSegm3D: + ROImask[z_slice, rr, cc] = True + else: + ROImask[rr, cc] = True + else: + x0, y0 = [int(c) for c in roi.pos()] + w, h = [int(c) for c in roi.size()] + if self.isSegm3D: + ROImask[z_slice, y0 : y0 + h, x0 : x0 + w] = True + else: + ROImask[y0 : y0 + h, x0 : x0 + w] = True + return ROImask + + def getDelRoisIDs(self): + posData = self.data[self.pos_i] + if posData.frame_i > 0: + prev_lab = posData.allData_li[posData.frame_i - 1]["labels"] + allDelIDs = set() + for roi in posData.allData_li[posData.frame_i]["delROIs_info"]["rois"]: + if not self.ax1.isDelRoiItemPresent( + roi + ) and not self.ax2.isDelRoiItemPresent(roi): + continue + + ROImask = self.getDelRoiMask(roi) + delIDs = posData.lab[ROImask] + allDelIDs.update(delIDs) + if posData.frame_i > 0: + delIDsPrevFrame = prev_lab[ROImask] + allDelIDs.update(delIDsPrevFrame) + return allDelIDs + + def getStoredDelRoiIDs(self, frame_i=None): + posData = self.data[self.pos_i] + if frame_i is None: + frame_i = posData.frame_i + allDelIDs = set() + delROIs_info = posData.allData_li[frame_i]["delROIs_info"] + delIDs_rois = delROIs_info["delIDsROI"] + for delIDs in delIDs_rois: + allDelIDs.update(delIDs) + return allDelIDs + + def initDelRoiLab(self): + posData = self.data[self.pos_i] + z_slice = self.z_lab() + img = posData.img_data[posData.frame_i] + Y, X = img[z_slice].shape[-2:] + + self.delRoiLab = np.zeros((Y, X), dtype=np.uint32) + + def moveDelRoisToLeft(self): + # Move del ROIs to the left image + for posData in self.data: + delROIs_info = posData.allData_li[posData.frame_i]["delROIs_info"] + for roi in delROIs_info["rois"]: + if not self.ax2.isDelRoiItemPresent(roi): + continue + + self.ax1.addDelRoiItem(roi, roi.key) + self.ax2.removeDelRoiItem(roi) + + def removeAlldelROIsCurrentFrame(self): + posData = self.data[self.pos_i] + delROIs_info = posData.allData_li[posData.frame_i]["delROIs_info"] + rois = delROIs_info["rois"].copy() + for roi in rois: + self.ax2.removeDelRoiItem(roi) + + for item in self.ax2.items: + if isinstance(item, pg.ROI): + self.ax2.removeDelRoiItem(item) + + for item in self.ax1.items: + if isinstance(item, pg.ROI) and item != self.labelRoiItem: + self.ax1.removeDelRoiItem(item) + + def removeDelROI(self, event): + posData = self.data[self.pos_i] + + for ax in (self.ax1, self.ax2): + try: + self.ax1.removeDelRoiItem(self.roi_to_del) + except Exception as err: + pass + + delROIs_info = posData.allData_li[posData.frame_i]["delROIs_info"] + idx = delROIs_info["rois"].index(self.roi_to_del) + delROIs_info["rois"].pop(idx) + delROIs_info["delMasks"].pop(idx) + delROIs_info["delIDsROI"].pop(idx) + delROIs_info["state"].pop(idx) + + self.removeDelROIFromFutureFrames(self.roi_to_del) + self.updateAllImages() + + def removeDelROIFromFutureFrames(self, roi_to_del): + posData = self.data[self.pos_i] + + # Restore deleted IDs from already visited future frames + current_frame_i = posData.frame_i + for i in range(posData.frame_i + 1, posData.SizeT): + if posData.allData_li[i]["labels"] is None: + break + + delROIs_info = posData.allData_li[i]["delROIs_info"] + try: + idx = delROIs_info["rois"].index(roi_to_del) + except IndexError: + continue + + posData.frame_i = i + idx = delROIs_info["rois"].index(roi_to_del) + if delROIs_info["delIDsROI"][idx]: + posData.lab = posData.allData_li[i]["labels"] + self.restoreAnnotDelROI(roi_to_del, enforce=True, draw=False) + posData.allData_li[i]["labels"] = posData.lab + self.get_data() + self.store_data(autosave=False) + delROIs_info["rois"].pop(idx) + delROIs_info["delMasks"].pop(idx) + delROIs_info["delIDsROI"].pop(idx) + delROIs_info["state"].pop(idx) + + if isinstance(self.roi_to_del, pg.PolyLineROI): + # PolyLine ROIs are only on ax1 + self.ax1.removeItem(self.roi_to_del) + elif not self.labelsGrad.showLabelsImgAction.isChecked(): + # Rect ROI is on ax1 because ax2 is hidden + self.ax1.removeItem(self.roi_to_del) + else: + # Rect ROI is on ax2 because ax2 is visible + self.ax2.removeItem(self.roi_to_del) + + # Back to current frame + posData.frame_i = current_frame_i + posData.lab = posData.allData_li[posData.frame_i]["labels"] + self.get_data() + self.store_data() + + def replacePolyLineRoiWithLineRoi(self, roi): + x0, y0 = roi.pos().x(), roi.pos().y() + (_, point1), (_, point2) = roi.getLocalHandlePositions() + xr1, yr1 = point1.x(), point1.y() + xr2, yr2 = point2.x(), point2.y() + x1, y1 = xr1 + x0, yr1 + y0 + x2, y2 = xr2 + x0, yr2 + x0 + lineRoi = pg.LineROI((x1, y1), (x2, y2), width=0.5) + lineRoi.handleSize = 7 + self.ax1.removeItem(self.polyLineRoi) + self.ax1.addItem(lineRoi) + lineRoi.removeHandle(2) + # Connect closed ROI + lineRoi.sigRegionChanged.connect(self.delROImoving) + lineRoi.sigRegionChangeFinished.connect(self.delROImovingFinished) + return lineRoi + + def restoreAnnotDelROI(self, roi, enforce=True, draw=True): + posData = self.data[self.pos_i] + ROImask = self.getDelRoiMask(roi) + delROIs_info = posData.allData_li[posData.frame_i]["delROIs_info"] + try: + idx = delROIs_info["rois"].index(roi) + except Exception as err: + return + + delMask = delROIs_info["delMasks"][idx] + delIDs = delROIs_info["delIDsROI"][idx] + overlapROIdelIDs = np.unique(delMask[ROImask]) + lab2D = self.get_2Dlab(posData.lab) + restoredIDs = set() + for ID in delIDs: + if ID in overlapROIdelIDs and not enforce: + continue + + restoredIDs.add(ID) + + delMaskID = delMask == ID + self.currentLab2D[delMaskID] = ID + lab2D[delMaskID] = ID + + if draw: + self.restoreDelROIimg1(delMaskID, ID, ax=0) + self.restoreDelROIimg1(delMaskID, ID, ax=1) + + delMask[delMaskID] = 0 + + delROIs_info["delIDsROI"][idx] = delIDs - restoredIDs + self.set_2Dlab(lab2D) + self.update_rp() + + def restoreDelROIimg1(self, delMaskID, delID, ax=0): + if ax == 0: + how = self.drawIDsContComboBox.currentText() + else: + how = self.getAnnotateHowRightImage() + + if how.find("nothing") != -1: + return + + if how.find("contours") != -1: + rp_delmask = skimage.measure.regionprops(delMaskID.astype(np.uint8)) + if len(rp_delmask) > 0: + obj = rp_delmask[0] + self.addObjContourToContoursImage(obj=obj, ax=ax) + elif how.find("overlay segm. masks") != -1: + if ax == 0: + self.labelsLayerImg1.setImage(self.currentLab2D, autoLevels=False) + else: + self.labelsLayerRightImg.setImage(self.currentLab2D, autoLevels=False) + + def setDelRoiState(self, roi: pg.ROI, state): + roi.sigRegionChanged.disconnect() + roi.sigRegionChangeFinished.disconnect() + roi.setState(state) + roi.sigRegionChanged.connect(self.delROImoving) + roi.sigRegionChangeFinished.connect(self.delROImovingFinished) + + def updateDelROIinFutureFrames(self, roi: pg.ROI): + posData = self.data[self.pos_i] + restore_current_frame = False + + roiState = roi.getState() + # Restore deleted IDs from already visited future frames + current_frame_i = posData.frame_i + delROIs_info = posData.allData_li[current_frame_i]["delROIs_info"] + try: + idx = delROIs_info["rois"].index(roi) + delROIs_info["state"][idx] = roiState + except Exception as err: + pass + + self.store_data() + + for i in range(posData.frame_i + 1, posData.SizeT): + delROIs_info = posData.allData_li[i]["delROIs_info"] + try: + idx = delROIs_info["rois"].index(roi) + except Exception as err: + continue + delROIs_info["state"][idx] = roiState + if posData.allData_li[i]["labels"] is None: + continue + + posData.frame_i = i + posData.lab = posData.allData_li[i]["labels"] + self.restoreAnnotDelROI(roi, enforce=False, draw=False) + posData.allData_li[i]["labels"] = posData.lab + self.get_data() + self.store_data(autosave=False) + restore_current_frame = True + + if not restore_current_frame: + return + + # Back to current frame + posData.frame_i = current_frame_i + posData.lab = posData.allData_li[posData.frame_i]["labels"] + self.get_data() + self.store_data() diff --git a/cellacdc/mixins/display_decorations.py b/cellacdc/mixins/display_decorations.py new file mode 100644 index 000000000..81e9daa1d --- /dev/null +++ b/cellacdc/mixins/display_decorations.py @@ -0,0 +1,152 @@ +"""View adapter for timestamp, scale-bar, and view-range decorations.""" + +from __future__ import annotations + +import numpy as np + +from cellacdc import apps, widgets + + +class DisplayDecorations: + """Extracted from guiWin.""" + + def addScaleBar(self, checked): + if checked: + posData = self.data[self.pos_i] + Y, X = self.img1.image.shape[:2] + viewRange = self.ax1ViewRange() + self.scaleBarDialog = apps.ScaleBarPropertiesDialog( + X, Y, posData.PhysicalSizeX, parent=self + ) + self.scaleBarDialog.show() + self.scaleBar = widgets.ScaleBar((Y, X), viewRange, parent=self.ax1) + self.scaleBar.sigEditProperties.connect(self.editScaleBarProperties) + self.scaleBar.sigRemove.connect(self.editScaleBarRemove) + self.scaleBar.addToAxis(self.ax1) + self.scaleBar.draw(**self.scaleBarDialog.kwargs()) + self.scaleBarDialog.sigValueChanged.connect(self.updateScaleBar) + self.scaleBarDialog.exec_() + if self.scaleBarDialog.cancel: + self.addScaleBarAction.setChecked(False) + return + else: + self.scaleBar.removeFromAxis(self.ax1) + + self.scaleBarDialog = None + self.imgGrad.addScaleBarAction.setChecked(checked) + + def addTimestamp(self, checked): + if checked: + posData = self.data[self.pos_i] + Y, X = self.img1.image.shape[:2] + viewRange = self.ax1ViewRange() + self.timestampDialog = apps.TimestampPropertiesDialog(parent=self) + self.timestampDialog.show() + self.timestamp = widgets.TimestampItem( + Y, + X, + viewRange, + secondsPerFrame=posData.TimeIncrement, + start_timedelta=self.timestampStartTimedelta, + ) + self.timestamp.sigEditProperties.connect(self.editTimestampProperties) + self.timestamp.sigRemove.connect(self.editTimestampRemove) + self.timestamp.addToAxis(self.ax1) + self.timestamp.draw(posData.frame_i, **self.timestampDialog.kwargs()) + self.timestampDialog.sigValueChanged.connect(self.updateTimestamp) + self.timestampDialog.exec_() + else: + self.timestamp.removeFromAxis(self.ax1) + + self.timestampDialog = None + self.imgGrad.addTimestampAction.setChecked(checked) + + def ax1ViewRange(self, integers=False): + if self.exportToImageWindow is None: + viewRange = self.ax1.viewRange() + else: + exportMask = np.all(self.exportMaskImage == [0, 0, 0, 0], axis=-1) + if np.all(exportMask): + viewRange = self.ax1.viewRange() + else: + viewRange = self.ax1.viewRange(exportMask) + + if not integers: + return viewRange + + xRange, yRange = viewRange + xmin = round(xRange[0]) + ymin = round(yRange[0]) + xmax = round(xRange[1]) + ymax = round(yRange[1]) + return [xmin, xmax], [ymin, ymax] + + def getViewRange(self): + Y, X = self.img1.image.shape[:2] + xRange, yRange = self.ax1.viewRange() + xmin = 0 if xRange[0] < 0 else xRange[0] + ymin = 0 if yRange[0] < 0 else yRange[0] + + xmax = X if xRange[1] >= X else xRange[1] + ymax = Y if yRange[1] >= Y else yRange[1] + return int(ymin), int(ymax), int(xmin), int(xmax) + + def editScaleBarProperties(self, properties): + Y, X = self.img1.image.shape[:2] + posData = self.data[self.pos_i] + self.scaleBarDialog = apps.ScaleBarPropertiesDialog( + X, Y, posData.PhysicalSizeX, parent=self, **properties + ) + self.scaleBarDialog.sigValueChanged.connect(self.updateScaleBar) + self.scaleBarDialog.exec_() + + def editScaleBarRemove(self, timestamp): + self.addScaleBarAction.setChecked(False) + + def editTimestampProperties(self, properties): + self.timestampDialog = apps.TimestampPropertiesDialog(parent=self, **properties) + self.timestampDialog.sigValueChanged.connect(self.updateTimestamp) + self.timestampDialog.show() + + def editTimestampRemove(self, timestamp): + self.addTimestampAction.setChecked(False) + + def viewRangeChanged(self, viewBox, viewRange, updateExportImageMask=True): + # self.updateViewRangeExportToImage(viewRange) + self.updateValuesStatusBar() + + if hasattr(self, "scaleBar"): + isScaleBarMoveWithZoom = self.scaleBar.properties()["move_with_zoom"] + else: + isScaleBarMoveWithZoom = False + doMoveScaleBar = self.scaleBarDialog is not None or isScaleBarMoveWithZoom + if doMoveScaleBar: + self.scaleBar.updatePosViewRangeChanged(viewRange) + + if hasattr(self, "timestamp"): + isTimestampMoveWithZoom = self.timestamp.properties()["move_with_zoom"] + else: + isTimestampMoveWithZoom = False + + doMoveTimestamp = self.timestampDialog is not None or isTimestampMoveWithZoom + if doMoveTimestamp: + self.timestamp.updatePosViewRangeChanged(viewRange) + + self._viewRange = viewRange + + def updateScaleBar(self, scaleBarKwargs): + self.scaleBar.draw(**scaleBarKwargs) + + def updateTimestamp(self, timeStampKwargs): + posData = self.data[self.pos_i] + self.timestamp.draw(posData.frame_i, **timeStampKwargs) + + def updateTimestampFrame(self): + if not hasattr(self, "timestamp"): + return + + if not self.addTimestampAction.isChecked(): + return + + posData = self.data[self.pos_i] + self.timestamp.setText(posData.frame_i) diff --git a/cellacdc/mixins/draw_clear_region.py b/cellacdc/mixins/draw_clear_region.py new file mode 100644 index 000000000..a4118cf5f --- /dev/null +++ b/cellacdc/mixins/draw_clear_region.py @@ -0,0 +1,90 @@ +"""View adapter for draw-clear-region workflows.""" + +from __future__ import annotations + +from .undo_redo import UndoRedo + + +class DrawClearRegion(UndoRedo): + """Extracted from guiWin.""" + + def drawClearRegion_cb(self, checked): + posData = self.data[self.pos_i] + if checked: + self.disconnectLeftClickButtons() + self.uncheckLeftClickButtons(self.drawClearRegionButton) + self.connectLeftClickButtons() + + self.drawClearRegionToolbar.setVisible(checked) + + if not self.isSegm3D: + self.drawClearRegionToolbar.setZslicesControlEnabled(False) + return + + if not checked: + return + + self.drawClearRegionToolbar.setZslicesControlEnabled(True, SizeZ=posData.SizeZ) + + def clearObjsFreehandRegion(self): + self.logger.info("Clearing objects inside freehand region...") + + # Store undo state before modifying stuff + self.storeUndoRedoStates(False, storeImage=False, storeOnlyZoom=True) + + posData = self.data[self.pos_i] + zRange = None + if self.isSegm3D: + zProjHow = self.zProjComboBox.currentText() + isZslice = zProjHow == "single z-slice" + if isZslice: + z_slice = self.z_lab() + zRange = self.drawClearRegionToolbar.zRange(z_slice, posData.SizeZ) + else: + zRange = (0, posData.SizeZ) + + regionSlice = self.freeRoiItem.slice(zRange=zRange) + mask = self.freeRoiItem.mask() + + regionLab = posData.lab[(...,) + regionSlice].copy() + + clearBorders = ( + self.drawClearRegionToolbar.clearOnlyEnclosedObjsRadioButton.isChecked() + ) + if clearBorders: + if regionLab.ndim == 2: + regionLab = transformation.clear_objects_not_in_mask(regionLab, mask) + regionRp = skimage.measure.regionprops(regionLab) + for obj in regionRp: + if np.all(mask[obj.slice][obj.image]): + continue + + regionLab[obj.slice][obj.image] = 0 + else: + for z, regionLab_z in enumerate(regionLab): + regionLab[z] = transformation.clear_objects_not_in_mask( + regionLab_z, mask + ) + else: + regionLab[..., ~mask] = 0 + + regionRp = skimage.measure.regionprops(regionLab) + clearIDs = [obj.label for obj in regionRp] + + if not clearIDs: + if clearBorders: + self.logger.warning( + "None of the objects in the freehand region are fully enclosed" + ) + else: + self.logger.warning( + "None of the objects are touching the freehand region" + ) + return + + self.deleteIDmiddleClick(clearIDs, False, False) + self.update_cca_df_deletedIDs(posData, clearIDs) + + self.freeRoiItem.clear() + + self.updateAllImages() diff --git a/cellacdc/mixins/exporting.py b/cellacdc/mixins/exporting.py new file mode 100644 index 000000000..ea66f1535 --- /dev/null +++ b/cellacdc/mixins/exporting.py @@ -0,0 +1,414 @@ +"""Qt view adapter for image and video export workflows.""" + +from __future__ import annotations + +import os +import shutil +import traceback +from functools import partial + +from datetime import datetime +import numpy as np +import skimage.measure +import skimage.segmentation +from qtpy.QtCore import QTimer + +from cellacdc import _warnings, apps, disableWindow, exception_handler +from cellacdc import exporters, html_utils, prompts, widgets + +from .app_shell import AppShell +from .frame_navigation import FrameNavigation + + +class Exporting(AppShell, FrameNavigation): + """Extracted from guiWin.""" + + def askTimelapseOrZslicesVideo(self): + txt = html_utils.paragraph(""" + Do you want to record a video of scrolling through the z-slices or + a Timelapse video? + """) + msg = widgets.myMessageBox(wrapText=False) + _, timelapseButton = msg.question( + self, + "Z-slices or Timelapse video?", + txt, + buttonsTexts=("Z-slices", "Timelapse"), + ) + if msg.cancel: + return + + return msg.clickedButton == timelapseButton + + def exportAddScaleBar(self, checked): + self.addScaleBarAction.setChecked(checked) + + def exportFrame(self): + nd = self.exportToVideoPreferences["num_digits"] + idx = str(self.exportToVideoCurrentNavVarIdx).zfill(nd) + filename = self.exportToVideoPreferences["filename"] + png_filename = f"{idx}_{filename}.png" + pngs_folderpath = self.exportToVideoPreferences["pngs_folderpath"] + + png_filepath = os.path.join(pngs_folderpath, png_filename) + img_bgr = self.exportToVideoImageExporter.export(png_filepath) + self.exportToVideoExporter.add_frame(img_bgr) + return True + + def exportToImage(self, preferences): + filepath = preferences["filepath"] + self.logger.info(f'Saving image to "{filepath}"...') + + if filepath.endswith(".svg"): + exporter = exporters.SVGExporter(self.ax1) + else: + exporter = exporters.ImageExporter(self.ax1, dpi=preferences["dpi"]) + exporter.export(filepath) + self.logger.info(f"Image saved.") + + self.setDisabled(False) + self.exportMaskImage[:] = 0 + self.exportMaskImageItem.setImage(self.exportMaskImage) + prompts.exportToImageFinished(filepath, qparent=self) + + def exportToImageTriggered(self): + posData = self.data[self.pos_i] + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"{timestamp}_acdc_exported_image" + win = apps.ExportToImageParametersDialog( + parent=self, + startFolderpath=posData.pos_path, + startFilename=filename, + startViewRange=self.ax1.viewRange(), + isScaleBarPresent=self.addScaleBarAction.isChecked(), + ) + win.sigAddScaleBar.connect(self.exportAddScaleBar) + win.sigRangeChanged.connect( + partial(self.setViewRangeFromExportToImageDialog, win=win) + ) + # self.ax1.vb.sigRangeChanged.connect( + # win.updateViewRangeExportToImageDialog + # ) + self.setExportMaskImage(self.ax1.viewRange()) + self.exportToImageWindow = win + win.exec_() + # self.ax1.vb.sigRangeChanged.disconnect() + if win.cancel: + self.exportMaskImage[:] = 0 + self.exportMaskImageItem.setImage(self.exportMaskImage) + self.exportToImageWindow = None + self.logger.info("Export to image process cancelled") + return + + isTransparent = self.overlayToolbar.isTransparent() + if not isTransparent: + # SVG export works only with RGBA not with setOpacity + # --> only true transparency mode can be used + self.overlayToolbar.setTransparent(True) + + self.exportToImage(win.selected_preferences) + self.exportToImageWindow = None + + if not isTransparent: + self.overlayToolbar.setTransparent(False) + + def exportToVideoAddTimestamp(self, checked): + self.addTimestampAction.setChecked(checked) + + def exportToVideoFinished(self, conversion_to_mp4_successful): + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + + # Back to current frame + if self.exportToVideoPreferences["is_timelapse"]: + posData = self.data[self.pos_i] + posData.frame_i = self.exportToVideoNavVarIdxToRestore + self.get_data() + self.store_data() + self.updateAllImages() + self.navigateScrollBar.setSliderPosition(posData.frame_i + 1) + self.navSpinBox.setValue(posData.frame_i + 1) + else: + self.update_z_slice(self.exportToVideoNavVarIdxToRestore) + + self.setDisabled(False) + self.isExportingVideo = False + + if not self.isTransparent: + # True transparency mode was activated programmatically + # --> restore what the user had before starting to export + self.overlayToolbar.setTransparent(False) + + prompts.exportToVideoFinished( + self.exportToVideoPreferences, conversion_to_mp4_successful, qparent=self + ) + + def exportToVideoTriggered(self): + posData = self.data[self.pos_i] + + doTimelapseVideo = posData.SizeT > 1 + if posData.SizeT > 1 and posData.SizeZ > 1: + doTimelapseVideo = self.askTimelapseOrZslicesVideo() + + if doTimelapseVideo is None: + self.logger.info("Export to video process cancelled") + return + + channels = [self.user_ch_name, *self.checkedOverlayChannels] + mode = "timelapse" if doTimelapseVideo else "z_slices" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"{timestamp}_acdc_exported_{mode}_video" + win = apps.ExportToVideoParametersDialog( + channels, + parent=self, + startFolderpath=posData.pos_path, + startFilename=filename, + startFrameNum=posData.frame_i + 1, + SizeT=posData.SizeT, + SizeZ=posData.SizeZ, + isTimelapseVideo=doTimelapseVideo, + isScaleBarPresent=self.addScaleBarAction.isChecked(), + isTimestampPresent=self.addTimestampAction.isChecked(), + rescaleIntensChannelHowMapper=self.rescaleIntensChannelHowMapper, + ) + win.sigAddScaleBar.connect(self.exportAddScaleBar) + win.sigAddTimestamp.connect(self.exportToVideoAddTimestamp) + win.sigRescaleIntensLut.connect(self.rescaleIntensExportToVideoDialog) + win.exec_() + if win.cancel: + self.logger.info("Export to video process cancelled") + return + + cancel = _warnings.warnExportToVideo(qparent=self) + if cancel: + self.logger.info("Export to video process cancelled") + return + + self.startExportToVideoWorker(win.selected_preferences) + + def exportingFramesFinished(self): + if not self.exportToVideoPreferences["save_pngs"]: + self.logger.info("Removing PNGs...") + try: + shutil.rmtree(self.exportToVideoPreferences["pngs_folderpath"]) + except Exception as err: + pass + + self.logger.info("Saving video...") + + self.exportToVideoExporter.release() + + # Run ffmpeg new process + conversion_to_mp4_successful = True + if self.exportToVideoPreferences["filepath"].endswith(".mp4"): + try: + self.exportToVideoExporter.avi_to_mp4() + try: + os.remove(self.exportToVideoPreferences["avi_filepath"]) + except Exception as err: + pass + except Exception as err: + self.logger.exception(traceback.format_exc()) + self.logger.info("Conversion to MP4 failed. See traceback above.") + conversion_to_mp4_successful = False + self.exportToVideoPreferences["filepath"] = ( + self.exportToVideoExporter._avi_filepath + ) + + self.exportToVideoFinished(conversion_to_mp4_successful) + + def exportingVideoCritical(self): + self.setDisabled(False) + + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + + self.logger.info("Exporting video process failed.") + + def getZoomIDs(self, viewRange=None): + if viewRange is None: + viewRange = self.ax1.viewRange() + + lab = self.currentLab2D + Y, X = lab.shape + ((xmin, xmax), (ymin, ymax)) = viewRange + if xmin <= 0 and ymin <= 0 and xmax >= X and ymax >= Y: + posData = self.data[self.pos_i] + return None + + xmin = xmin if xmin >= 0 else 0 + ymin = ymin if ymin >= 0 else 0 + xmax = xmax if xmax < X else X + ymax = ymax if ymax < Y else Y + + zoomSlice = ( + slice(round(ymin), round(ymax)), + slice(round(xmin), round(xmax)), + ) + + zoomLab = skimage.segmentation.clear_border(lab[zoomSlice]) + zoomRp = skimage.measure.regionprops(zoomLab) + zoomIDs = [obj.label for obj in zoomRp] + return zoomIDs + + def initExportMaskImage(self): + posData = self.data[self.pos_i] + z_slice = self.z_lab() + img = posData.img_data[posData.frame_i] + Y, X = img[z_slice].shape[-2:] + + self.exportMaskImage = np.zeros((Y, X, 4), dtype=np.uint8) + + def onSigUpdateCcaTableWindow(self, *args): + if not self.isDataLoaded: + return + + if self.ccaTableWin is None: + return + + viewRange = self.ax1.viewRange() + posData = self.data[self.pos_i] + zoomIDs = self.getZoomIDs(viewRange=viewRange) + + self.ccaTableWin.updateTable(posData.cca_df, IDs=zoomIDs) + + def setExportMaskImage(self, viewRange): + if not hasattr(self, "exportMaskImage"): + self.initExportMaskImage() + else: + self.exportMaskImage[:] = 0 + + xRange, yRange = viewRange + x0, x1 = map(round, xRange) + y0, y1 = map(round, yRange) + + if self.invertBwAction.isChecked(): + self.exportMaskImage[:, :, :3] = 255 + + if x0 > 0: + self.exportMaskImage[:, :x0, 3] = 255 + if x1 < self.exportMaskImage.shape[1]: + self.exportMaskImage[:, x1:, 3] = 255 + if y0 > 0: + self.exportMaskImage[:y0, :, 3] = 255 + if y1 < self.exportMaskImage.shape[0]: + self.exportMaskImage[y1:, :, 3] = 255 + + self.exportMaskImageItem.setImage(self.exportMaskImage) + + def setViewRangeFromExportToImageDialog(self, viewRange, win=None): + xRange, yRange = viewRange + # self.ax1.sigRangeChanged.disconnect(self.viewRangeChanged) + self.ax1.setRange(xRange=xRange, yRange=yRange) + # self.ax1.sigRangeChanged.connect(self.viewRangeChanged) + # self.viewRangeChanged( + # self.ax1.vb, viewRange, updateExportMaskImage=False + # ) + self.setExportMaskImage(viewRange) + + def startExportToVideoWorker(self, preferences): + self.isExportingVideo = True + self.isTransparent = self.overlayToolbar.isTransparent() + if not self.isTransparent: + # SVG export works only with RGBA not with setOpacity + # --> only true transparency mode can be used + self.overlayToolbar.setTransparent(True) + + self.setDisabled(True) + + self.progressWin = apps.QDialogWorkerProgress( + title="Exporting to video", + parent=self.mainWin, + pbarDesc="Exporting to video...", + ) + self.progressWin.show(self.app) + self.exportToVideoStopNavVarNum = preferences["stop_nav_var_num"] + self.numFramesExported = 0 + self.progressWin.mainPbar.setMaximum( + preferences["stop_nav_var_num"] - preferences["start_nav_var_num"] + 1 + ) + self.exportToVideoPreferences = preferences + + self.store_data() + posData = self.data[self.pos_i] + if self.exportToVideoPreferences["is_timelapse"]: + # Go to requested start frame + posData.frame_i = preferences["start_nav_var_num"] - 1 + self.get_data() + self.updateAllImages() + self.exportToVideoNavVarIdxToRestore = posData.frame_i + else: + self.update_z_slice(preferences["start_nav_var_num"] - 1) + self.exportToVideoNavVarIdxToRestore = self.zSliceScrollBar.sliderPosition() + self.exportToVideoCurrentNavVarIdx = preferences["start_nav_var_num"] - 1 + + self.exportToVideoImageExporter = exporters.ImageExporter( + self.ax1, save_pngs=preferences["save_pngs"], dpi=preferences["dpi"] + ) + self.exportToVideoExporter = exporters.VideoExporter( + preferences["avi_filepath"], preferences["fps"] + ) + + QTimer.singleShot(200, self.updateAndExportFrame) + + def updateAndExportFrame(self): + didVideoExporterFinish = ( + self.exportToVideoCurrentNavVarIdx == self.exportToVideoStopNavVarNum + ) + if didVideoExporterFinish: + self.progressWin.mainPbar.setMaximum(0) + self.progressWin.mainPbar.setValue(0) + QTimer.singleShot(50, self.exportingFramesFinished) + return + + posData = self.data[self.pos_i] + if self.exportToVideoPreferences["is_timelapse"]: + self.goToFrameNumber(self.exportToVideoCurrentNavVarIdx + 1) + else: + self.update_z_slice(self.exportToVideoCurrentNavVarIdx) + + success = self.exportFrame() + if success is None: + self.exportingVideoCritical() + return + + self.exportToVideoCurrentNavVarIdx += 1 + self.progressWin.mainPbar.update(1) + + QTimer.singleShot(50, self.updateAndExportFrame) + + def updateViewRangeExportToImage(self, viewRange): + if self.exportToImageWindow is None: + return + + # prevViewRange = self.exportToImageWindow.viewRange() + prevViewRange = self._viewRange + prevXRange = prevViewRange[0] + prevYRange = prevViewRange[1] + currXRange = viewRange[0] + currYRange = viewRange[1] + + prevX0, prevX1 = prevXRange + currX0, currX1 = currXRange + prevY0, prevY1 = prevYRange + currY0, currY1 = currYRange + + deltaX = currX0 - prevX0 + deltaY = currY0 - prevY0 + + winViewRange = self.exportToImageWindow.viewRange() + winXRange = winViewRange[0] + winYRange = winViewRange[1] + winX0, winX1 = winXRange + winY0, winY1 = winYRange + + newX0 = winX0 + deltaX + newX1 = winX1 + deltaX + newY0 = winY0 + deltaY + newY1 = winY1 + deltaY + + self.exportToImageWindow.setViewRange( + (newX0, newX1), (newY0, newY1), emitSignal=False + ) diff --git a/cellacdc/mixins/frame_navigation.py b/cellacdc/mixins/frame_navigation.py new file mode 100644 index 000000000..3be3e2005 --- /dev/null +++ b/cellacdc/mixins/frame_navigation.py @@ -0,0 +1,1139 @@ +"""Qt view adapter for frame and position navigation.""" + +from __future__ import annotations + +from collections import Counter +from functools import partial + +import numpy as np +from qtpy.QtCore import QTimer +from qtpy.QtWidgets import QAbstractSlider, QCheckBox + +from cellacdc import QtScoped, apps, exception_handler, html_utils, printl, widgets + + +SliderSingleStepAdd = QtScoped.SliderSingleStepAdd() +SliderSingleStepSub = QtScoped.SliderSingleStepSub() +SliderPageStepAdd = QtScoped.SliderPageStepAdd() +SliderPageStepSub = QtScoped.SliderPageStepSub() +SliderMove = QtScoped.SliderMove() + +from .graphics import Graphics +from .label_editing import LabelEditing + + +class FrameNavigation(Graphics, LabelEditing): + """Extracted from guiWin.""" + + def PosScrollBarAction(self, action): + if action == SliderSingleStepAdd: + self.next_cb() + elif action == SliderSingleStepSub: + self.prev_cb() + elif action == SliderPageStepAdd: + self.PosScrollBarReleased() + elif action == SliderPageStepSub: + self.PosScrollBarReleased() + + def PosScrollBarMoved(self, pos_n): + if self.navigateScrollBarStartedMoving: + self.store_data() + + self.pos_i = pos_n - 1 + self.updateFramePosLabel() + proceed_cca, never_visited = self.get_data() + self.updateAllImages() + self.setStatusBarLabel() + self.navigateScrollBarStartedMoving = False + + def PosScrollBarReleased(self): + self.navigateScrollBarStartedMoving = True + if self.pos_i == self.navigateScrollBar.sliderPosition() - 1: + # Slider released without changing value --> do nothing + return + + self.pos_i = self.navigateScrollBar.sliderPosition() - 1 + self.updateFramePosLabel() + self.updatePos() + + def _setViewRangeSwitchPlane(self, previousPlane): + posData = self.data[self.pos_i] + SizeZ = posData.SizeZ + SizeY, SizeX = self.img1.image.shape[:2] + currentPlane = self.switchPlaneCombobox.plane() + if previousPlane == "xy": + if currentPlane == "zy": + self.ax1.setRange(xRange=self.yRangePrev) + unusedRange = np.clip(self.xRangePrev, 0, SizeX) + elif currentPlane == "zx": + self.ax1.setRange(xRange=self.xRangePrev) + unusedRange = np.clip(self.yRangePrev, 0, SizeY) + elif previousPlane == "zy": + if currentPlane == "xy": + self.ax1.setRange(yRange=self.xRangePrev) + unusedRange = np.clip(self.yRangePrev, 0, SizeZ) + elif currentPlane == "zx": + self.ax1.setRange(yRange=self.yRangePrev) + unusedRange = np.clip(self.xRangePrev, 0, SizeY) + elif previousPlane == "zx": + if currentPlane == "xy": + self.ax1.setRange(xRange=self.xRangePrev) + unusedRange = np.clip(self.yRangePrev, 0, SizeZ) + elif currentPlane == "zy": + self.ax1.setRange(yRange=self.yRangePrev) + unusedRange = np.clip(self.xRangePrev, 0, SizeX) + + sliceValue = round((unusedRange[0] + unusedRange[1]) / 2) + self.zSliceScrollBar.setSliderPosition(sliceValue) + self.update_z_slice(self.zSliceScrollBar.sliderPosition()) + + def apply_tools_on_new_frame(self): + mode = str(self.modeComboBox.currentText()) + if mode != "Segmentation and Tracking": + return + posData = self.data[self.pos_i] + if ( + not (posData.last_tracked_i <= posData.frame_i) + or posData.frame_i == self.lastFrameRanOnFirstVisitTools + ): + return + + self.lastFrameRanOnFirstVisitTools = posData.frame_i + for name, checkbox in self.applyToolNewFrameActions.items(): + if not checkbox.isChecked(): + continue + + tool_button = self.applyToolNewFrameButtons[name] + try: + if hasattr(tool_button, "click"): + tool_button.click() + elif hasattr(tool_button, "trigger"): + tool_button.trigger() + else: + printl(f"Warning: {name} has no click or trigger method") + except Exception as e: + self.logger.info(f"Error applying tool {name}: {e}") + + def askInitCcaFirstFrame(self): + mode = str(self.modeComboBox.currentText()) + if mode != "Cell cycle analysis": + return True + + posData = self.data[self.pos_i] + if posData.frame_i != 0: + return True + + editCcaWidget = apps.editCcaTableWidget( + posData.cca_df, + posData.SizeT, + parent=self, + title="Initialize cell cycle annotations", + ) + editCcaWidget.sigApplyChangesFutureFrames.connect( + self.applyManualCcaChangesFutureFrames + ) + editCcaWidget.exec_() + if editCcaWidget.cancel: + self.resetNavigateFramesScrollbar() + return False + + if posData.cca_df is not None: + is_cca_same_as_stored = (posData.cca_df == editCcaWidget.cca_df).all( + axis=None + ) + if not is_cca_same_as_stored: + reinit_cca = self.warnEditingWithCca_df( + "Re-initialize cell cyle annotations first frame", + return_answer=True, + ) + if reinit_cca: + self.resetCcaFuture(0) + + posData.cca_df = editCcaWidget.cca_df + self.store_cca_df() + + return True + + def askInitLinTreeFirstFrame(self): + mode = str(self.modeComboBox.currentText()) + if mode != "Normal division: Lineage tree": + return True + + posData = self.data[self.pos_i] + if posData.frame_i != 0: + return True + + if self.lineage_tree is None: + self.initLinTree() + + return True + + def checkIfFutureFrameManualAnnotPastFrames(self): + if not self.manualAnnotPastButton.isChecked(): + return True + + posData = self.data[self.pos_i] + frame_to_restore = self.manualAnnotState.get("frame_i_to_restore") + if posData.frame_i <= frame_to_restore: + return True + + warn_txt = ( + "WARNING: Cannot navigate to future frames while in manual annotation mode." + ) + self.logger.info(warn_txt) + self.statusBarLabel.setText(f'

    {warn_txt}

    ') + + return False + + def connectScrollbars(self): + self.t_label.show() + self.navigateScrollBar.show() + self.navigateScrollBar.setDisabled(False) + + if self.data[0].SizeZ > 1: + self.enableZstackWidgets(True) + self.zSliceScrollBar.setMaximum(self.data[0].SizeZ - 1) + self.zSliceSpinbox.setMaximum(self.data[0].SizeZ) + self.SizeZlabel.setText(f"/{self.data[0].SizeZ}") + try: + self.zSliceScrollBar.actionTriggered.disconnect() + self.zSliceScrollBar.sliderReleased.disconnect() + self.zProjComboBox.currentTextChanged.disconnect() + self.zProjComboBox.activated.disconnect() + self.switchPlaneCombobox.sigPlaneChanged.disconnect() + self.zProjLockViewButton.toggled.disconnect() + except Exception as e: + pass + self.zSliceScrollBar.actionTriggered.connect( + self.zSliceScrollBarActionTriggered + ) + self.zSliceScrollBar.sliderReleased.connect(self.zSliceScrollBarReleased) + self.zProjComboBox.currentTextChanged.connect(self.updateZproj) + self.zProjComboBox.activated.connect(self.clearComboBoxFocus) + self.switchPlaneCombobox.sigPlaneChanged.connect(self.switchViewedPlane) + self.zProjLockViewButton.toggled.connect(self.zProjLockViewToggled) + + posData = self.data[self.pos_i] + if posData.SizeT == 1: + self.t_label.setText("Position n.") + self.navigateScrollBar.setMinimum(1) + self.navigateScrollBar.setMaximum(len(self.data)) + self.navigateScrollBar.setAbsoluteMaximum(len(self.data)) + self.navSpinBox.setMaximum(len(self.data)) + self.navigateScrollBar.connectEvents( + { + "sliderMoved": self.PosScrollBarMoved, + "sliderReleased": self.PosScrollBarReleased, + "actionTriggered": self.PosScrollBarAction, + } + ) + else: + self.navigateScrollBar.setMinimum(1) + self.navigateScrollBar.setAbsoluteMaximum(posData.SizeT) + self.rightImageFramesScrollbar.setMinimum(1) + self.rightImageFramesScrollbar.setMaximum(posData.SizeT) + if posData.last_tracked_i is not None: + self.navigateScrollBar.setMaximum(posData.last_tracked_i + 1) + self.navSpinBox.setMaximum(posData.last_tracked_i + 1) + self.t_label.setText("Frame n.") + self.navigateScrollBar.connectEvents( + { + "sliderMoved": self.framesScrollBarMoved, + "sliderReleased": self.framesScrollBarReleased, + "actionTriggered": self.framesScrollBarActionTriggered, + } + ) + self.rightImageFramesScrollbar.connectValueChanged( + self.rightImageFramesScrollbarValueChanged + ) + + def extendSegmDataIfNeeded(self, stopFrameNum): + posData = self.data[self.pos_i] + segmSizeT = len(posData.segm_data) + if stopFrameNum <= segmSizeT: + return + numFramesToAdd = stopFrameNum - segmSizeT + posData.allData_li.extend( + [utils.get_empty_stored_data_dict() for i in range(numFramesToAdd)] + ) + lab_shape = posData.segm_data[0].shape + shapeToAdd = (numFramesToAdd, *lab_shape) + additionalSegmData = np.zeros(shapeToAdd, dtype=posData.segm_data.dtype) + extendedSegmData = np.concatenate((posData.segm_data, additionalSegmData)) + posData.segm_data = extendedSegmData + + def framesScrollBarActionTriggered(self, action): + if action == SliderSingleStepAdd: + # Clicking on dialogs triggered by next_cb might trigger + # pressEvent of navigateQScrollBar, avoid that + self.navigateScrollBar.disableCustomPressEvent() + self.next_cb() + QTimer.singleShot(100, self.navigateScrollBar.enableCustomPressEvent) + elif action == SliderSingleStepSub: + self.prev_cb() + elif action == SliderPageStepAdd: + self.framesScrollBarReleased(do_store_data=True) + elif action == SliderPageStepSub: + self.framesScrollBarReleased(do_store_data=True) + + def framesScrollBarMoved(self, frame_n): + if self.navigateScrollBarStartedMoving: + mode = str(self.modeComboBox.currentText()) + if mode != "Viewer": + self.store_data(debug=False) + + posData = self.data[self.pos_i] + posData.frame_i = frame_n - 1 + if posData.allData_li[posData.frame_i]["labels"] is None: + if posData.frame_i < len(posData.segm_data): + posData.lab = posData.segm_data[posData.frame_i] + else: + posData.lab = np.zeros_like(posData.segm_data[0]) + else: + posData.lab = posData.allData_li[posData.frame_i]["labels"] + + self.setImageImg1() + if self.overlayButton.isChecked(): + self.setOverlayImages() + + if self.navigateScrollBarStartedMoving: + self.clearAllItems() + + self.navSpinBox.setValueNoEmit(posData.frame_i + 1) + if self.labelsGrad.showLabelsImgAction.isChecked(): + self.img2.setImage(posData.lab, z=self.z_lab(), autoLevels=False) + self.updateLookuptable() + self.updateFramePosLabel() + self.updateViewerWindow() + self.updateTimestampFrame() + self.updateHighlightedAxis() + self.navigateScrollBarStartedMoving = False + + def framesScrollBarReleased(self, do_store_data=False): + posData = self.data[self.pos_i] + if posData.frame_i == self.navigateScrollBar.sliderPosition() - 1: + # Slider released without changing value --> do nothing + return + + mode = str(self.modeComboBox.currentText()) + if mode != "Viewer" and do_store_data: + self.store_data(debug=False) + + self.navigateScrollBarStartedMoving = True + posData.frame_i = self.navigateScrollBar.sliderPosition() - 1 + self.updateFramePosLabel() + proceed_cca, never_visited = self.get_data() + self.updateAllImages() + + def goToZsliceSearchedID(self, obj): + if not self.isSegm3D: + return + + current_z = self.z_lab() + nearest_nonzero_z = core.nearest_nonzero_z_idx_from_z_centroid( + obj, current_z=current_z + ) + if nearest_nonzero_z == current_z: + self.drawPointsLayers(computePointsLayers=True) + return + + self.zSliceScrollBar.setSliderPosition(nearest_nonzero_z) + self.update_z_slice(nearest_nonzero_z) + + def isNavigateActionOnNextFrame(self): + posData = self.data[self.pos_i] + if posData.SizeT == 1: + return False + + ax1_coords = self.getMouseDataCoordsRightImage() + if ax1_coords is None: + return False + + if not self.labelsGrad.showNextFrameAction.isEnabled(): + return False + + if not self.labelsGrad.showNextFrameAction.isChecked(): + return + + # Mouse is on right image and next frame action is checked + return True + + def manualAnnotRestoreLastTrackedFrame(self, last_tracked_i_to_restore): + if self.navigateScrollBar.maximum() - 1 <= last_tracked_i_to_restore: + return + + posData = self.data[self.pos_i] + for frame_i in range(last_tracked_i_to_restore + 1, posData.SizeT): + data_frame_i = utils.get_empty_stored_data_dict() + + data_frame_i["manually_edited_lab"] = posData.allData_li[frame_i][ + "manually_edited_lab" + ] + + posData.allData_li[frame_i] = data_frame_i + + self.navigateScrollBar.setMaximum(last_tracked_i_to_restore + 1) + self.navSpinBox.setMaximum(last_tracked_i_to_restore + 1) + + def navigateSpinboxEditingFinished(self): + if self.isSnapshot: + self.PosScrollBarReleased() + else: + self.framesScrollBarReleased() + + def navigateSpinboxValueChanged(self, value): + self.navigateScrollBar.setSliderPosition(value) + if self.isSnapshot: + self.PosScrollBarMoved(value) + else: + self.navigateScrollBarStartedMoving = True + self.framesScrollBarMoved(value) + + def nextActionTriggered(self): + if self.isNavigateActionOnNextFrame(): + self.rightImageFramesScrollbar.setValue( + self.rightImageFramesScrollbar.value() + 1 + ) + return + + stepAddAction = QAbstractSlider.SliderAction.SliderSingleStepAdd + if self.zKeptDown or self.zSliceCheckbox.isChecked(): + self.zSliceScrollBar.triggerAction(stepAddAction) + else: + self.navigateScrollBar.triggerAction(stepAddAction) + + def nextFrameImage(self, current_frame_i=None): + if not self.labelsGrad.showNextFrameAction.isEnabled(): + return + + if not self.labelsGrad.showNextFrameAction.isChecked(): + return + + posData = self.data[self.pos_i] + if current_frame_i is None: + current_frame_i = posData.frame_i + + next_frame_i = current_frame_i + 1 + if next_frame_i >= len(posData.img_data): + img = posData.img_data[-1] + else: + img = posData.img_data[next_frame_i] + + if posData.SizeZ > 1: + img = self.get_2Dimg_from_3D(img, isLayer0=True) + + # img = self.normalizeIntensities(img) + + return img + + def next_cb(self): + if self.isSnapshot: + self.next_pos() + else: + self.next_frame() + if self.curvToolButton.isChecked(): + self.curvTool_cb(True) + + self.updatePropsWidget("") + + def next_frame(self, warn=True): + proceed = self.checkIfFutureFrameManualAnnotPastFrames() + if not proceed: + return + + proceed = self.askInitCcaFirstFrame() + if not proceed: + return + + proceed = self.askInitLinTreeFirstFrame() + if not proceed: + return + + mode = str(self.modeComboBox.currentText()) + posData = self.data[self.pos_i] + + if posData.frame_i >= posData.SizeT - 1: + # Store data for current frame + if mode != "Viewer": + self.store_data(debug=False) + msg = "You reached the last segmented frame!" + self.logger.info(msg) + self.titleLabel.setText(msg, color=self.titleColor) + return + + proceed = self.warnLostObjects() + if not proceed: + self.resetNavigateScrollbar() + return + + # Store data for current frame + if mode != "Viewer": + self.store_data(debug=False) + + self.askLineageTreeChanges() + posData.frame_i += 1 + self.removeAlldelROIsCurrentFrame() + proceed_cca, never_visited = self.get_data() + if not proceed_cca: + posData.frame_i -= 1 + self.get_data() + self.logger.info("No data for current frame. ") + return + + if mode == "Segmentation and Tracking" or self.isSnapshot: + self.addExistingDelROIs() + + self.updatePreprocessPreview() + self.updateCombineChannelsPreview() + self.postProcessing() + self.tracking(storeUndo=True, wl_update=False) + notEnoughG1Cells, proceed = self.attempt_auto_cca() + if notEnoughG1Cells or not proceed: + posData.frame_i -= 1 + self.get_data() + self.setAllTextAnnotations() + self.logger.info("Not enough G1 cells to compute cell cycle annotations.") + return + + self.store_zslices_rp() + self.resetExpandLabel() + self.updateAllImages() + self.updateHighlightedAxis() + self.updateViewerWindow() + self.updateLastVisitedFrame(last_visited_frame_i=posData.frame_i - 1) + self.setNavigateScrollBarMaximum() + self.updateScrollbars() + self.computeSegm() + self.initGhostObject() + self.whitelistPropagateIDs() + self.zoomToCells() + self.updateItemsMousePos() + self.updateObjectCounts() + + self.apply_tools_on_new_frame() + + def next_pos(self): + self.store_data(debug=True, autosave=False) + prev_pos_i = self.pos_i + if self.pos_i < self.num_pos - 1: + self.pos_i += 1 + self.updateSegmDataAutoSaveWorker() + else: + self.logger.info("You reached last position.") + self.pos_i = 0 + self.updatePos() + + def onZsliceSpinboxValueChange(self, value): + self.zSliceScrollBar.setSliderPosition(value - 1) + + def prevActionTriggered(self): + if self.isNavigateActionOnNextFrame(): + self.rightImageFramesScrollbar.setValue( + self.rightImageFramesScrollbar.value() - 1 + ) + return + + stepSubAction = QAbstractSlider.SliderAction.SliderSingleStepSub + if self.zKeptDown or self.zSliceCheckbox.isChecked(): + self.zSliceScrollBar.triggerAction(stepSubAction) + else: + self.navigateScrollBar.triggerAction(stepSubAction) + + def prev_cb(self): + if self.isSnapshot: + self.prev_pos() + else: + self.prev_frame() + if self.curvToolButton.isChecked(): + self.curvTool_cb(True) + + self.updatePropsWidget("") + + def prev_frame(self): + posData = self.data[self.pos_i] + if posData.frame_i <= 0: + msg = "You reached the first frame!" + self.logger.info(msg) + self.titleLabel.setText(msg, color=self.titleColor) + return + + # Store data for current frame + mode = str(self.modeComboBox.currentText()) + if mode != "Viewer": + self.store_data(debug=False) + + self.removeAlldelROIsCurrentFrame() + self.askLineageTreeChanges() + posData.frame_i -= 1 + _, never_visited = self.get_data() + + if mode == "Segmentation and Tracking" or self.isSnapshot: + self.addExistingDelROIs() + + self.resetExpandLabel() + self.updatePreprocessPreview() + self.updateCombineChannelsPreview() + self.postProcessing() + self.tracking() + self.whitelistPropagateIDs(update_lab=True) + self.updateAllImages() + self.updateScrollbars() + self.updateHighlightedAxis() + self.zoomToCells() + self.initGhostObject() + self.updateViewerWindow() + self.updateItemsMousePos() + self.updateObjectCounts() + + def prev_pos(self): + self.store_data(debug=False, autosave=False) + prev_pos_i = self.pos_i + if self.pos_i > 0: + self.pos_i -= 1 + self.updateSegmDataAutoSaveWorker() + else: + self.logger.info("You reached first position.") + self.pos_i = self.num_pos - 1 + self.updatePos() + + def reInitLastSegmFrame( + self, checked=True, from_frame_i=None, updateImages=True, force=False + ): + if not force: + cancel = self.warnReinitLastSegmFrame() + if cancel: + self.logger.info("Re-initialization of last validated frame cancelled.") + return + + posData = self.data[self.pos_i] + if from_frame_i is None: + from_frame_i = posData.frame_i + + self.lastFrameRanOnFirstVisitTools = posData.frame_i + + self.updateLastCheckedFrameWidgets(from_frame_i) + posData.last_tracked_i = from_frame_i + self.navigateScrollBar.setMaximum(from_frame_i + 1) + self.navSpinBox.setMaximum(from_frame_i + 1) + # self.navigateScrollBar.setMinimum(1) + + # posData.tracked_lost_centroids[from_frame_i-1] = set() + for i in range(from_frame_i, posData.SizeT): + if posData.allData_li[i]["labels"] is None: + break + + posData.segm_data[i] = posData.allData_li[i]["labels"] + posData.allData_li[i] = utils.get_empty_stored_data_dict() + + posData.tracked_lost_centroids[i] = set() + posData.acdcTracker2stepsAnnotInfo.pop(i, None) + + if posData.acdc_df is not None: + frames = posData.acdc_df.index.get_level_values(0) + if from_frame_i in frames: + posData.acdc_df = posData.acdc_df.loc[:from_frame_i] + + self.removeAlldelROIsCurrentFrame() + + if not updateImages: + return + + self.updateAllImages() + + def resetAcceptedLostIDs(self, from_frame_i=None): + posData = self.data[self.pos_i] + if from_frame_i is None: + from_frame_i = posData.frame_i + + posData.tracked_lost_centroids[from_frame_i - 1] = set() + for i in range(from_frame_i, posData.SizeT): + posData.tracked_lost_centroids[i] = set() + + def resetNavigateFramesScrollbar(self, frame_i=None): + posData = self.data[self.pos_i] + if frame_i is None: + frame_i = posData.frame_i + + self.navigateScrollBar.setValueNoSignal(frame_i + 1) + + def resetNavigateScrollbar(self): + try: + self.navigateScrollBar.blockSignals(True) + self.navigateScrollBar.actionTriggered.disconnect() + self.navigateScrollBar.sliderReleased.disconnect() + self.navigateScrollBar.sliderMoved.disconnect() + # self.navigateScrollBar.valueChanged.disconnect() + self.navigateScrollBar.setSliderPosition(self.navSpinBox.value()) + except Exception as e: + if "disconnect()" not in str(e): + printl(e) + pass + + self.navigateScrollBar.blockSignals(False) + self.navigateScrollBar.actionTriggered.connect( + self.framesScrollBarActionTriggered + ) + self.navigateScrollBar.sliderReleased.connect(self.framesScrollBarReleased) + self.navigateScrollBar.sliderMoved.connect(self.framesScrollBarMoved) + + def rightImageFramesScrollbarValueChanged(self, value): + img = self.nextFrameImage(current_frame_i=value - 2) + self.img1.linkedImageItem.frame_i = value + self.img1.linkedImageItem.setImage(img) + + def setFrameNavigationDisabled(self, disable: bool, why: str): + """Disables the frame navigation buttons and scrollbar. + This is used when the user is not allowed to navigate through frames + Call again to unlock it again. Also sets tooltips to inform the user + + Parameters + ---------- + disable : bool + if the navigation should be disabled + why : str + the reason for disabeling the navigation. + """ + + if disable: + self.whyNavigateDisabled.add(why) + else: + try: + self.whyNavigateDisabled.remove(why) + except KeyError: + pass + + if len(self.whyNavigateDisabled) == 0: + disable = False + else: + disable = True + + # Apply the disable/enable state + self.prevAction.setDisabled(disable) + self.nextAction.setDisabled(disable) + self.navigateScrollBar.setDisabled(disable) + + # Set appropriate tooltip + if not disable: + self.navigateScrollBar.setToolTip( + "NOTE: The maximum frame number that can be visualized with this " + "scrollbar\n" + "is the last visited frame with the selected mode\n" + '(see "Mode" selector on the top-right).\n\n' + "If the scrollbar does not move it means that you never visited\n" + "any frame with current mode.\n\n" + 'Note that the "Viewer" mode allows you to scroll ALL frames.' + ) + return + + txt = f"Frame navigation disabled: {self.whyNavigateDisabled}" + self.logger.info(txt) + self.navigateScrollBar.setToolTip(txt) + + def setNavigateScrollBarMaximum(self): + posData = self.data[self.pos_i] + mode = str(self.modeComboBox.currentText()) + if mode == "Segmentation and Tracking": + if posData.last_tracked_i is not None: + if posData.frame_i > posData.last_tracked_i: + self.navigateScrollBar.setMaximum(posData.frame_i + 1) + self.navSpinBox.setMaximum(posData.frame_i + 1) + else: + self.navigateScrollBar.setMaximum(posData.last_tracked_i + 1) + self.navSpinBox.setMaximum(posData.last_tracked_i + 1) + else: + self.navigateScrollBar.setMaximum(posData.frame_i + 1) + self.navSpinBox.setMaximum(posData.frame_i + 1) + + self.updateLastCheckedFrameWidgets(self.navSpinBox.maximum() - 1) + elif mode == "Cell cycle analysis": + if posData.frame_i > self.last_cca_frame_i: + self.navigateScrollBar.setMaximum(posData.frame_i + 1) + self.navSpinBox.setMaximum(posData.frame_i + 1) + else: + self.navigateScrollBar.setMaximum(self.last_cca_frame_i + 1) + self.navSpinBox.setMaximum(self.last_cca_frame_i + 1) + self.lastTrackedFrameLabel.setText( + f"Last cc annot. frame n. = {self.navSpinBox.maximum()}" + ) + elif mode == "Normal division: Lineage tree": + if self.lineage_tree is None: + self.navigateScrollBar.setMaximum(posData.frame_i + 1) + self.navSpinBox.setMaximum(posData.frame_i + 1) + else: + if self.lineage_tree.frames_for_dfs: + i = max(self.lineage_tree.frames_for_dfs) + else: + i = 0 + self.navigateScrollBar.setMaximum(i + 1) + self.navSpinBox.setMaximum(i + 1) + + def setSwitchViewedPlaneDisabled(self, disabled): + posData = self.data[self.pos_i] + if posData.SizeZ == 1: + return + + self.switchPlaneCombobox.setDisabled(disabled) + if disabled: + self.switchPlaneCombobox.setCurrentIndex(0) + + def setViewRangeSwitchPlane(self, previousPlane): + self.autoRange() + QTimer.singleShot(100, partial(self._setViewRangeSwitchPlane, previousPlane)) + + def setZprojDisabled(self, disabled, storePrevState=False): + self.combineChannelsAction.setDisabled(disabled) + for action in self.editToolBar.actions(): + button = self.editToolBar.widgetForAction(action) + if button == self.eraserButton: + continue + + if button in self.toolsActiveInProj3Dsegm: + continue + + try: + tooltip = button.toolTip() + prefix = "WARNING: Disabled due to projection mode\n\n" + if disabled: + if not tooltip.startswith(prefix): + button.setToolTip(prefix + tooltip) + else: + if tooltip.startswith(prefix): + button.setToolTip(tooltip[len(prefix) :]) + except: + pass + action.setDisabled(disabled) + try: + button.setChecked(False) + except Exception as err: + pass + + def switchViewedPlane(self, previousPlane, currentPlane): + posData = self.data[self.pos_i] + self.xRangePrev, self.yRangePrev = self.ax1.viewRange() + self.zSlicePrev = self.zSliceScrollBar.sliderPosition() + + self.zProjComboBox.setCurrentText("single z-slice") + depthAxes = self.switchPlaneCombobox.depthAxes() + self.onEscape() + self.initDelRoiLab() + if depthAxes != "z": + # Disable projections on plane that is not xy + self.zProjComboBox.setCurrentText("single z-slice") + self.zProjComboBox.setDisabled(True) + + # Clear annotations + self.clearAllItems() + self.setHighlightID(False) + + # Disable annotations on a plane that is not yz + self.setDrawNothingAnnotations() + self.setDisabledAnnotCheckBoxesLeft(True) + self.setDisabledAnnotCheckBoxesRight(True) + self.setEnabledAnnotCheckBoxesLeftZdepthAxes() + self.overlayButtonPrevState = self.overlayButton.isChecked() + self.overlayButton.setChecked(False) + self.overlayButton.setDisabled(True) + else: + self.zProjComboBox.setDisabled(False) + self.restoreAnnotationsOptions() + self.setDisabledAnnotCheckBoxesLeft(False) + self.setDisabledAnnotCheckBoxesRight(False) + self.overlayButton.setDisabled(False) + if self.overlayButtonPrevState: + self.overlayButton.setChecked(self.overlayButtonPrevState) + self.updateZsliceScrollbar(posData.frame_i) + + SizeY, SizeX = posData.img_data[posData.frame_i].shape[-2:] + + if depthAxes != "z" and self.isSnapshot: + # Disable editing when the plane is not xy + self.disableEditingViewPlaneNotXY() + elif self.isSnapshot: + # Re-enable editing in snapshot mode when the plane is xy + self.setEnabledSnapshotMode() + + if depthAxes == "z": + maxSliceNum = posData.SizeZ + elif depthAxes == "y": + maxSliceNum = SizeY + else: + maxSliceNum = SizeX + + maxSliceText = f"/{maxSliceNum}" + self.SizeZlabel.setText(maxSliceText) + self.zSliceCheckbox.setText(f"{depthAxes}-slice") + self.zSliceScrollBar.setMaximum(maxSliceNum - 1) + self.zSliceSpinbox.setMaximum(maxSliceNum) + + self.initContoursImage() + self.updateAllImages() + QTimer.singleShot(200, partial(self.setViewRangeSwitchPlane, previousPlane)) + + def updateFramePosLabel(self): + if self.isSnapshot: + posData = self.data[self.pos_i] + self.navSpinBox.setValueNoEmit(self.pos_i + 1) + else: + posData = self.data[0] + self.navSpinBox.setValueNoEmit(posData.frame_i + 1) + + def updateItemsMousePos(self): + if self.brushButton.isChecked(): + self.updateBrushCursor(self.xHoverImg, self.yHoverImg) + + if self.eraserButton.isChecked(): + self.updateEraserCursor(self.xHoverImg, self.yHoverImg) + + def updateOverlayZproj(self, how): + if how.find("max") != -1 or how == "same as above": + self.overlay_z_label.setDisabled(True) + self.zSliceOverlay_SB.setDisabled(True) + else: + self.overlay_z_label.setDisabled(False) + self.zSliceOverlay_SB.setDisabled(False) + self.setOverlayImages() + + def updateOverlayZslice(self, z): + self.setOverlayImages() + + def updatePos(self): + self.clearUndoQueue() + self.setStatusBarLabel() + self.checkManageVersions() + self.removeAlldelROIsCurrentFrame() + self.resetManualBackgroundItems() + proceed_cca, never_visited = self.get_data(debug=False) + self.pointsLayerLoadedDfsToData() + self.flushDirtyPointsLayersAutosave() + self.initContoursImage() + self.initDelRoiLab() + self.initTextAnnot() + self.postProcessing() + self.updateScrollbars() + self.updatePreprocessPreview() + self.updateCombineChannelsPreview() + self.updateAllImages() + self.computeSegm() + self.zoomOut() + self.restartZoomAutoPilot() + self.initManualBackgroundObject() + self.updateObjectCounts() + self.updateItemsMousePos() + + def updateScrollbars(self): + self.updateItemsMousePos() + self.updateFramePosLabel() + posData = self.data[self.pos_i] + navPos = self.pos_i + 1 if self.isSnapshot else posData.frame_i + 1 + self.navigateScrollBar.setSliderPosition(navPos) + if posData.SizeZ > 1: + self.updateZsliceScrollbar(posData.frame_i) + idx = (posData.filename, posData.frame_i) + self.zSliceScrollBar.setMaximum(posData.SizeZ - 1) + self.zSliceSpinbox.setMaximum(posData.SizeZ) + self.SizeZlabel.setText(f"/{posData.SizeZ}") + + def updateViewerWindow(self): + if self.slideshowWin is None: + return + + if self.slideshowWin.linkWindow is None: + return + + if not self.slideshowWin.linkWindowCheckbox.isChecked(): + return + + posData = self.data[self.pos_i] + self.slideshowWin.frame_i = posData.frame_i + self.slideshowWin.update_img() + + def updateZproj(self, how): + for p, posData in enumerate(self.data[self.pos_i :]): + if self.zProjLockViewButton.isChecked(): + idx = [(posData.filename, frame_i) for frame_i in range(posData.SizeT)] + else: + idx = [(posData.filename, posData.frame_i)] + posData.segmInfo_df.loc[idx, "which_z_proj_gui"] = how + posData.segmInfo_df.to_csv(posData.segmInfo_df_csv_path) + + posData = self.data[self.pos_i] + if how == "single z-slice": + self.zSliceScrollBar.setDisabled(False) + self.zSliceSpinbox.setDisabled(False) + self.zSliceCheckbox.setDisabled(False) + self.setZprojDisabled(False) + self.update_z_slice(self.zSliceScrollBar.sliderPosition()) + else: + self.zSliceScrollBar.setDisabled(True) + self.zSliceSpinbox.setDisabled(True) + self.zSliceCheckbox.setDisabled(True) + self.setZprojDisabled(self.isSegm3D) + self.updateAllImages() + + def update_z_slice(self, z): + posData = self.data[self.pos_i] + if self.switchPlaneCombobox.depthAxes() == "z": + if self.zProjLockViewButton.isChecked(): + idx = [(posData.filename, frame_i) for frame_i in range(posData.SizeT)] + else: + idx = [ + (posData.filename, frame_i) + for frame_i in range(posData.frame_i, posData.SizeT) + ] + posData.segmInfo_df.loc[idx, "z_slice_used_gui"] = z + + self.updatePreprocessPreview() + self.updateCombineChannelsPreview() + self.highlightedID = self.getHighlightedID() + self.updateAllImages( + computePointsLayers=False, computeContours=False, updateLookuptable=True + ) + self.updateItemsMousePos() + if self.isSegm3D: + self.updateObjectCounts() + + def warnLostObjects(self, do_warn=True): + if not do_warn: + return True + + if not self.warnLostCellsAction.isChecked(): + return True + + mode = str(self.modeComboBox.currentText()) + if not mode == "Segmentation and Tracking": + return True + + posData = self.data[self.pos_i] + if not posData.lost_IDs: + return True + + frame_i = posData.frame_i + try: + accepted_lost_IDs = posData.accepted_lost_IDs.get(frame_i, []) + already_accepted_lost = Counter(accepted_lost_IDs) == Counter( + posData.lost_IDs + ) + except AttributeError as err: + already_accepted_lost = False + + if already_accepted_lost: + return True + + self.nextAction.setDisabled(True) + self.prevAction.setDisabled(True) + self.navigateScrollBar.setDisabled(True) + + msg = widgets.myMessageBox() + warn_msg = html_utils.paragraph( + "Current frame (compared to previous frame) " + "has lost the following cells:

    " + f"{posData.lost_IDs}

    " + "Are you sure you want to continue?
    " + ) + checkBox = QCheckBox("Do not show again") + noButton, yesButton = msg.warning( + self, "Lost cells!", warn_msg, buttonsTexts=("No", "Yes"), widgets=checkBox + ) + doNotWarnLostCells = not checkBox.isChecked() + self.warnLostCellsAction.setChecked(doNotWarnLostCells) + if msg.clickedButton == noButton: + self.nextAction.setDisabled(False) + self.prevAction.setDisabled(False) + self.navigateScrollBar.setDisabled(False) + return False + + self.nextAction.setDisabled(False) + self.prevAction.setDisabled(False) + self.navigateScrollBar.setDisabled(False) + if not hasattr(posData, "accepted_lost_IDs"): + posData.accepted_lost_IDs = {} + if frame_i not in posData.accepted_lost_IDs: + posData.accepted_lost_IDs[frame_i] = [] + + posData.accepted_lost_IDs[frame_i].extend(posData.lost_IDs) + # This section is adding the lost cells to tracked_lost_centroids... TBH I dont know why this wasnt done in the first place + prev_rp = posData.allData_li[posData.frame_i - 1]["regionprops"] + prev_IDs_idxs = posData.allData_li[posData.frame_i - 1]["IDs_idxs"] + accepted_lost_centroids = { + tuple(int(val) for val in prev_rp[prev_IDs_idxs[ID]].centroid) + for ID in posData.lost_IDs + } + try: + posData.tracked_lost_centroids[frame_i] = posData.tracked_lost_centroids[ + frame_i + ] | (accepted_lost_centroids) + except KeyError: + posData.tracked_lost_centroids[frame_i] = accepted_lost_centroids + return True + + def warnReinitLastSegmFrame(self): + current_frame_n = self.navigateScrollBar.value() + msg = widgets.myMessageBox() + txt = html_utils.paragraph(f""" + Are you sure you want to re-initialize the last visited and + validated frame to number {current_frame_n}?

    + WARNING: If you save, all annotations after frame number + {current_frame_n} will be lost! + """) + msg.warning( + self, + "WARNING: Potential loss of data", + txt, + buttonsTexts=("Cancel", "Yes, I am sure"), + ) + return msg.cancel + + def zSliceScrollBarActionTriggered(self, action): + singleMove = ( + action == SliderSingleStepAdd + or action == SliderSingleStepSub + or action == SliderPageStepAdd + or action == SliderPageStepSub + ) + if singleMove: + self.update_z_slice(self.zSliceScrollBar.sliderPosition()) + elif action == SliderMove: + if self.zSliceScrollBarStartedMoving and self.isSegm3D: + self.clearAx1Items(onlyHideText=True) + self.clearAx2Items(onlyHideText=True) + posData = self.data[self.pos_i] + idx = (posData.filename, posData.frame_i) + z = self.zSliceScrollBar.sliderPosition() + if self.switchPlaneCombobox.depthAxes() == "z": + posData.segmInfo_df.at[idx, "z_slice_used_gui"] = z + self.zSliceSpinbox.setValueNoEmit(z + 1) + img = self._getImageupdateAllImages(None) + self.img1.setCurrentZsliceIndex(z) + self.img1.setImage( + img, + next_frame_image=self.nextFrameImage(), + scrollbar_value=posData.frame_i + 2, + ) + try: + self.setOverlayImages() + except Exception as err: + pass + + if self.labelsGrad.showLabelsImgAction.isChecked(): + self.img2.setImage(posData.lab, z=z, autoLevels=False) + self.updateViewerWindow() + self.setTextAnnotZsliceScrolling() + self.setGraphicalAnnotZsliceScrolling() + self.setOverlayLabelsItems() + self.drawPointsLayers(computePointsLayers=False) + self.zSliceScrollBarStartedMoving = False + self.highlightSearchedID(self.highlightedID, force=True) + + def zSliceScrollBarReleased(self): + self.clearTempBrushImage() + self.zSliceScrollBarStartedMoving = True + self.update_z_slice(self.zSliceScrollBar.sliderPosition()) + + def storeViewRange(self): + if not hasattr(self, "isRangeReset"): + return + + if not self.isRangeReset: + return + self.ax1_viewRange = self.ax1.viewRange() + self.isRangeReset = False diff --git a/cellacdc/mixins/geometry.py b/cellacdc/mixins/geometry.py new file mode 100644 index 000000000..a413f5ca0 --- /dev/null +++ b/cellacdc/mixins/geometry.py @@ -0,0 +1,65 @@ +"""Mouse and interaction geometry helpers.""" + +from __future__ import annotations + +from qtpy.QtCore import Qt + +from cellacdc import is_mac + + +class Geometry: + """Extracted from guiWin.""" + + def isDefaultMiddleClick(self, mouseEvent, modifiers): + if is_mac: + middle_click = ( + mouseEvent.button() == Qt.MouseButton.LeftButton + and modifiers == Qt.ControlModifier + and not self.brushButton.isChecked() + ) + else: + middle_click = mouseEvent.button() == Qt.MouseButton.MiddleButton + return middle_click + + def isMiddleClick(self, mouseEvent, modifiers): + if self.delObjAction is None: + return self.isDefaultMiddleClick(mouseEvent, modifiers) + + delObjKeySequence, delObjQtButton = self.delObjAction + if delObjKeySequence is None: + # Setting only middle click on mac is allowed, however the + # delObjKeySequence is None and the tool button is never checked + isDelObjectActive = True + else: + isDelObjectActive = self.delObjToolAction.isChecked() + + mouseEventButton = self.changeRightClickToLeftOnMac(mouseEvent) + + middle_click = mouseEventButton == delObjQtButton and isDelObjectActive + + return middle_click + + def isPanImageClick(self, mouseEvent, modifiers): + left_click = mouseEvent.button() == Qt.MouseButton.LeftButton + return modifiers == Qt.AltModifier and left_click + + def middleClickText(self): + if self.delObjAction is None and is_mac: + return "Command + Left Click" + + if self.delObjAction is None: + return "Middle Click" + + delObjKeySequence, delObjQtButton = self.delObjAction + + if delObjQtButton == Qt.MouseButton.LeftButton: + buttonName = "Left click" + elif delObjQtButton == Qt.MouseButton.RightButton: + buttonName = "Right click" + else: + buttonName = "Middle click" + + if delObjKeySequence is None: + return buttonName + + return f"{delObjKeySequence.toString()} + {buttonName}" diff --git a/cellacdc/mixins/graphics.py b/cellacdc/mixins/graphics.py new file mode 100644 index 000000000..fbd374bab --- /dev/null +++ b/cellacdc/mixins/graphics.py @@ -0,0 +1,2789 @@ +"""Qt view adapter for graphics item construction workflows.""" + +from __future__ import annotations + +import traceback +from functools import partial + +import cv2 +import matplotlib +import numpy as np +import pyqtgraph as pg +from collections.abc import Iterable, Mapping +import skimage.exposure +import skimage.measure +from natsort import natsorted +from qtpy.QtCore import QEventLoop, QRect, QRectF, Qt, QThread, QTimer +from qtpy.QtGui import QColor, QCursor, QFont +from qtpy.QtWidgets import QAction, QActionGroup, QLabel, QMenu +from qtpy.QtWidgets import QGraphicsProxyWidget, QPushButton + +from cellacdc import ( + _warnings, + annotate, + apps, + colors, + html_utils, + utils, + widgets, + workers, +) + +_font = QFont() +_font.setPixelSize(11) + +from .points_layers import PointsLayers + + +class Graphics(PointsLayers): + """Extracted from guiWin.""" + + def _computeAllContours2D(self, dataDict, obj, z, obj_bbox, include_internal=False): + obj_image = self.getObjImage(obj.image, obj.bbox, z_slice=z) + if obj_image is None: + return + + all_external = False + local = False + contours = core.get_obj_contours( + obj_image=obj_image, + obj_bbox=obj_bbox, + local=local, + all_external=all_external, + ) + key = (obj.label, str(z), all_external, local) + dataDict["contours"][key] = contours + + all_external = True + local = False + contours = core.get_obj_contours( + obj_image=obj_image, + obj_bbox=obj_bbox, + local=local, + all_external=all_external, + all=include_internal, + ) + key = (obj.label, str(z), all_external, local) + dataDict["contours"][key] = contours + + return dataDict + + def _computeAllObjToObjCostPairs(self, posData): + self.computeAllObjCostPairsWorker.signals.initProgressBar.emit( + len(posData.allData_li) + ) + for frame_i, dataDict in enumerate(posData.allData_li): + if frame_i == 0: + continue + + rp = dataDict["regionprops"] + if rp is None: + break + + prev_rp = posData.allData_li[frame_i - 1]["regionprops"] + dist_matrix = core._compute_all_obj_to_obj_contour_dist_pairs( + dataDict["contours"], rp, prev_rp=prev_rp, restrict_search=True + ) + dataDict["obj_to_obj_dist_cost_matrix_df"] = dist_matrix + self.computeAllObjCostPairsWorker.signals.progressBar.emit(1) + self.computeAllObjCostPairsWorker.signals.initProgressBar.emit(0) + + def _gui_createGraphicsItems(self): + for _posData in self.data: + _posData.allData_li = [None] * _posData.SizeT + + posData = self.data[self.pos_i] + + allIDs, posData = core.count_objects(posData, self.logger.info) + + self.highLowResAction.setChecked(True) + numItems = len(allIDs) + if numItems > 1500: + cancel, switchToLowRes = _warnings.warnTooManyItems( + self, numItems, self.progressWin + ) + if cancel: + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + self.loadingDataAborted() + return + if switchToLowRes: + self.highLowResAction.setChecked(False) + else: + # Many items requires pxMode active to be fast enough + self.pxModeAction.setChecked(True) + + self.logger.info(f"Creating graphical items...") + + self.ax1_contoursImageItem = pg.ImageItem() + + self.ax1_lostObjImageItem = pg.ImageItem() + self.ax2_lostObjImageItem = pg.ImageItem() + + self.ax1_lostTrackedObjImageItem = pg.ImageItem() + self.ax2_lostTrackedObjImageItem = pg.ImageItem() + + self.ax1_oldMothBudLinesItem = pg.ScatterPlotItem( + symbol="s", + pxMode=False, + brush=self.oldMothBudLineBrush, + size=self.mothBudLineWeight, + pen=None, + ) + self.ax1_newMothBudLinesItem = pg.ScatterPlotItem( + symbol="s", + pxMode=False, + brush=self.newMothBudLineBrush, + size=self.mothBudLineWeight, + pen=None, + ) + self.ax1_lostObjScatterItem = self.gui_getLostObjScatterItem() + self.yellowContourScatterItem = self.gui_getLostObjScatterItem() + + self.ax1_lostTrackedScatterItem = self.gui_getTrackedLostObjScatterItem() + self.greenContourScatterItem = self.gui_getTrackedLostObjScatterItem() + + brush = pg.mkBrush((0, 255, 0, 200)) + pen = pg.mkPen("g", width=1) + self.ccaFailedScatterItem = pg.ScatterPlotItem( + size=self.contLineWeight + 1, pen=pen, brush=brush, pxMode=False, symbol="s" + ) + + self.ax2_contoursImageItem = pg.ImageItem() + self.ax2_oldMothBudLinesItem = pg.ScatterPlotItem( + symbol="s", + pxMode=False, + brush=self.oldMothBudLineBrush, + size=self.mothBudLineWeight, + pen=None, + ) + self.ax2_newMothBudLinesItem = pg.ScatterPlotItem( + symbol="s", + pxMode=False, + brush=self.newMothBudLineBrush, + size=self.mothBudLineWeight, + pen=None, + ) + self.ax2_lostObjScatterItem = self.gui_getLostObjScatterItem() + self.ax2_lostTrackedScatterItem = self.gui_getTrackedLostObjScatterItem() + + self.gui_createTextAnnotItems(allIDs) # here + self.gui_setTextAnnotColors() # here + + self.setDisabledAnnotOptions(False) + + self.progressWin.mainPbar.setMaximum(0) + self.gui_addOverlayLayerItems() + self.gui_addTopLayerItems() + + self.gui_addCreatedAxesItems() + self.gui_add_ax_cursors() + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + + self.loadingDataCompleted() + + def _updateContColour(self, color): + self.gui_createContourPens() + for items in self.overlayLayersItems.values(): + lutItem = items[1] + lutItem.contoursColorButton.setColor(color) + + def _updateContLineThickness(self): + self.gui_createContourPens() + for act in self.imgGrad.contLineWightActionGroup.actions(): + if act == self.sender(): + act.setChecked(True) + act.toggled.connect(self.contLineWeightToggled) + + def _updateMothBudLineColour(self, color): + self.gui_createMothBudLinePens() + self.ax1_newMothBudLinesItem.setBrush(self.newMothBudLineBrush) + self.ax1_oldMothBudLinesItem.setBrush(self.oldMothBudLineBrush) + self.ax2_newMothBudLinesItem.setBrush(self.newMothBudLineBrush) + self.ax2_oldMothBudLinesItem.setBrush(self.oldMothBudLineBrush) + for items in self.overlayLayersItems.values(): + lutItem = items[1] + lutItem.mothBudLineColorButton.setColor(color) + + def _updateMothBudLineSize(self, size): + self.gui_createMothBudLinePens() + + for act in self.imgGrad.mothBudLineWightActionGroup.actions(): + if act == self.sender(): + act.setChecked(True) + act.toggled.connect(self.mothBudLineWeightToggled) + + self.ax1_oldMothBudLinesItem.setSize(size) + self.ax1_newMothBudLinesItem.setSize(size) + self.ax2_oldMothBudLinesItem.setSize(size) + self.ax2_newMothBudLinesItem.setSize(size) + + def addActionsLutItemContextMenu(self, lutItem): + lutItem.gradient.menu.addSection("Visible channels: ") + for action in self.overlayContextMenu.actions(): + if action.isSeparator(): + continue + lutItem.gradient.menu.addAction(action) + lutItem.gradient.menu.addSeparator() + + annotationMenu = lutItem.gradient.menu.addMenu("Annotations settings") + ID_menu = annotationMenu.addMenu("IDs") + self.annotSettingsIDmenu = QActionGroup(annotationMenu) + labID_action = QAction("Show label's ID") + labID_action.setCheckable(True) + labID_action.setChecked(True) + labID_action.toggled.connect(self.annotLabelIDtreeToggled) + treeID_action = QAction("Show tree's ID") + treeID_action.setCheckable(True) + treeID_action.toggled.connect(self.annotLabelIDtreeToggled) + self.annotSettingsIDmenu.addAction(labID_action) + self.annotSettingsIDmenu.addAction(treeID_action) + ID_menu.addAction(labID_action) + ID_menu.addAction(treeID_action) + + ID_menu = annotationMenu.addMenu("Generation number") + self.annotSettingsGenNumMenu = QActionGroup(annotationMenu) + gen_num_action = QAction("Show default generation number") + gen_num_action.setCheckable(True) + gen_num_action.setChecked(True) + gen_num_action.toggled.connect(self.annotGenNumTreeToggled) + tree_gen_num_action = QAction("Show tree generation number") + tree_gen_num_action.setCheckable(True) + tree_gen_num_action.toggled.connect(self.annotGenNumTreeToggled) + self.annotSettingsGenNumMenu.addAction(gen_num_action) + self.annotSettingsGenNumMenu.addAction(tree_gen_num_action) + ID_menu.addAction(gen_num_action) + ID_menu.addAction(tree_gen_num_action) + + def addAlphaScrollbar(self, channelName, imageItem): + alphaScrollBar = widgets.ScrollBar(Qt.Horizontal) + imageItem.alphaScrollBar = alphaScrollBar + alphaScrollBar.channelName = channelName + + label = QLabel(f"Alpha {channelName}") + label.setFont(_font) + label.hide() + alphaScrollBar.imageItem = imageItem + alphaScrollBar.label = label + alphaScrollBar.setFixedHeight(self.h) + alphaScrollBar.hide() + alphaScrollBar.setMinimum(0) + alphaScrollBar.setMaximum(40) + alphaScrollBar.setValue(20) + alphaScrollBar.setToolTip( + f"Control the alpha value of the overlaid channel {channelName}.\n" + "alpha=0 results in NO overlay,\n" + "alpha=1 results in only fluorescence data visible" + ) + self.bottomLeftLayout.addWidget( + alphaScrollBar.label, self.alphaScrollbarRow, 0, alignment=Qt.AlignRight + ) + self.bottomLeftLayout.addWidget(alphaScrollBar, self.alphaScrollbarRow, 1, 1, 2) + + alphaScrollBar.valueChanged.connect( + partial(self.setOpacityOverlayLayersItems, scrollbar=alphaScrollBar) + ) + + self.alphaScrollbarRow += 1 + return alphaScrollBar + + def addFluoChNameContextMenuAction(self, ch_name): + posData = self.data[self.pos_i] + allTexts = [action.text() for action in self.chNamesQActionGroup.actions()] + if ch_name not in allTexts: + action = QAction(self) + action.setText(ch_name) + action.setCheckable(True) + self.chNamesQActionGroup.addAction(action) + action.setChecked(True) + self.fluoDataChNameActions.append(action) + + def addObjContourToContoursImage( + self, ID=0, obj=None, ax=0, thickness=None, color=None, force=False + ): + imageItem = self.getContoursImageItem(ax, force=force) + if imageItem is None: + return + + if obj is None: + obj = self.getObjFromID(ID) + if obj is None: + return + + contours = self.getObjContours(obj, all_external=True) + if thickness is None: + thickness = self.contLineWeight + if color is None: + color = self.contLineColor + + self.setContoursImage(imageItem, contours, thickness, color) + + def addOverlayLabelsToggled(self, checked, name=None): + if name is None: + name = self.sender().text() + if checked: + gradItem = self.overlayLabelsItems[name][-1] + drawMode = gradItem.drawModeActionGroup.checkedAction().text() + self.drawModeOverlayLabelsChannels[name] = drawMode + else: + self.drawModeOverlayLabelsChannels.pop(name) + self.hideOverlayLabelsItems(specific=[name]) + self.setOverlayLabelsItems() + + def askLabelsToOverlay(self): + selectOverlayLabels = widgets.QDialogListbox( + "Select segmentation to overlay", + "Select segmentation file to overlay:\n", + natsorted(self.existingSegmEndNames), + multiSelection=True, + parent=self, + ) + selectOverlayLabels.exec_() + if selectOverlayLabels.cancel: + return + + return selectOverlayLabels.selectedItemsText + + def askSelectOverlayChannel(self): + ch_names = [ch for ch in self.ch_names if ch != self.user_ch_name] + selectFluo = widgets.QDialogListbox( + "Select channel", + "Select channel names to overlay:\n", + ch_names, + multiSelection=True, + parent=self, + ) + selectFluo.exec_() + if selectFluo.cancel: + return + + return selectFluo.selectedItemsText + + def changeOverlayColor(self, button): + rgb = button.color().getRgb()[:3] + lutItem = self.overlayLayersItems[button.channel][1] + self.initColormapOverlayLayerItem(rgb, lutItem) + lutItem.overlayColorButton.setColor(rgb) + + def clearAllItems(self): + self.clearAx1Items() + self.clearAx2Items() + + def clearAx1Items(self, onlyHideText=False): + self.ax1_binnedIDs_ScatterPlot.clear() + self.ax1_ripIDs_ScatterPlot.clear() + self.labelsLayerImg1.clear() + self.labelsLayerRightImg.clear() + self.keepIDsTempLayerLeft.clear() + self.keepIDsTempLayerRight.clear() + self.highLightIDLayerImg1.clear() + self.highLightIDLayerRightImage.clear() + self.searchedIDitemLeft.clear() + self.searchedIDitemRight.clear() + self.ax1_contoursImageItem.clear() + self.ax1_lostObjImageItem.clear() + self.ax1_lostTrackedObjImageItem.clear() + self.textAnnot[0].clear() + self.ax1_newMothBudLinesItem.setData([], []) + self.ax1_oldMothBudLinesItem.setData([], []) + self.ax1_lostObjScatterItem.setData([], []) + self.ax1_lostTrackedScatterItem.setData([], []) + self.ccaFailedScatterItem.setData([], []) + self.yellowContourScatterItem.setData([], []) + + self.clearPointsLayers() + + self.clearOverlayLabelsItems() + self.clearManualBackgroundAnnotations() + self.clearCustomAnnot() + + def clearAx2Items(self, onlyHideText=False): + self.ax2_binnedIDs_ScatterPlot.clear() + self.ax2_ripIDs_ScatterPlot.clear() + self.ax2_contoursImageItem.clear() + self.ax2_lostObjImageItem.clear() + self.ax2_lostTrackedObjImageItem.clear() + self.textAnnot[1].clear() + self.ax2_newMothBudLinesItem.setData([], []) + self.ax2_oldMothBudLinesItem.setData([], []) + self.ax2_lostObjScatterItem.setData([], []) + + def clearComputedContours(self): + for posData in self.data: + for frame_i, dataDict in enumerate(posData.allData_li): + dataDict["contours"] = {} + + def clearObjContour(self, ID=0, obj=None, ax=0, debug=False, updateImage=True): + imageItem = self.getContoursImageItem(ax) + if imageItem is None: + return + + if ID > 0: + self.contoursImage[self.currentLab2D == ID] = [0, 0, 0, 0] + else: + obj_slice = self.getObjSlice(obj.slice) + obj_image = self.getObjImage(obj.image, obj.bbox) + self.contoursImage[obj_slice][obj_image] = [0, 0, 0, 0] + + if not updateImage: + return + + imageItem.setImage(self.contoursImage) + + def clearOverlayImageItems(self): + for items in self.overlayLayersItems.values(): + imageItem = items[0] + imageItem.clear() + + self.rgbaImg1.clear() + + def clearOverlayLabelsItems(self): + for segmEndname, drawMode in self.drawModeOverlayLabelsChannels.items(): + items = self.overlayLabelsItems[segmEndname] + imageItem, contoursItem, gradItem = items + imageItem.clear() + contoursItem.clear() + + def computeAllContours(self): + self.logger.info("Computing all contours...") + posData = self.data[self.pos_i] + zz = [None] + if self.isSegm3D: + zz.extend(range(posData.SizeZ)) + + include_internal = self.showAllContoursToggle.isChecked() + for frame_i, dataDict in enumerate(posData.allData_li): + lab = dataDict["labels"] + if lab is None: + break + + rp = dataDict["regionprops"] + if rp is None: + rp = skimage.measure.regionprops(lab) + + dataDict["contours"] = {} + for obj in rp: + obj_bbox = self.getObjBbox(obj.bbox) + for z in zz: + if not self.isObjVisible(obj.bbox, z_slice=z): + continue + + try: + self._computeAllContours2D( + dataDict, + obj, + z, + obj_bbox, + include_internal=include_internal, + ) + except Exception as err: + # Contours computation fails on weird objects + pass + + def computeAllObjCostPairsWorkerCritical(self, error): + self.computeAllObjCostPairsWorkerLoop.exit() + self.workerCritical(error) + + def computeAllObjCostPairsWorkerFinished(self, output): + if self.progressWin is not None: + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + self.computeAllObjCostPairsWorkerLoop.exit() + + def computeAllObjToObjCostPairs(self): + desc = "Computing all object-to-object cost matrices..." + self.logger.info(desc) + posData = self.data[self.pos_i] + + self.progressWin = apps.QDialogWorkerProgress( + title=desc, parent=self, pbarDesc=desc + ) + self.progressWin.mainPbar.setMaximum(0) + self.progressWin.show(self.app) + + self.computeAllObjCostPairsThread = QThread() + self.computeAllObjCostPairsWorker = workers.SimpleWorker( + posData, self._computeAllObjToObjCostPairs + ) + + self.computeAllObjCostPairsWorker.moveToThread( + self.computeAllObjCostPairsThread + ) + + self.computeAllObjCostPairsWorker.signals.finished.connect( + self.computeAllObjCostPairsThread.quit + ) + self.computeAllObjCostPairsWorker.signals.finished.connect( + self.computeAllObjCostPairsWorker.deleteLater + ) + self.computeAllObjCostPairsThread.finished.connect( + self.computeAllObjCostPairsThread.deleteLater + ) + + self.computeAllObjCostPairsWorker.signals.critical.connect( + self.computeAllObjCostPairsWorkerCritical + ) + self.computeAllObjCostPairsWorker.signals.initProgressBar.connect( + self.workerInitProgressbar + ) + self.computeAllObjCostPairsWorker.signals.progressBar.connect( + self.workerUpdateProgressbar + ) + self.computeAllObjCostPairsWorker.signals.progress.connect(self.workerProgress) + self.computeAllObjCostPairsWorker.signals.finished.connect( + self.computeAllObjCostPairsWorkerFinished + ) + + self.computeAllObjCostPairsThread.started.connect( + self.computeAllObjCostPairsWorker.run + ) + self.computeAllObjCostPairsThread.start() + + self.computeAllObjCostPairsWorkerLoop = QEventLoop() + self.computeAllObjCostPairsWorkerLoop.exec_() + + def contLineWeightToggled(self, checked=True): + if not checked: + return + self.imgGrad.uncheckContLineWeightActions() + w = self.sender().lineWeight + self.df_settings.at["contLineWeight", "value"] = w + self.df_settings.to_csv(self.settings_csv_path) + self._updateContLineThickness() + self.updateAllImages() + + def createChannelNamesActions(self): + # LUT histogram channel name context menu actions + self.chNamesQActionGroup = QActionGroup(self) + self.chNamesQActionGroup.addAction(self.userChNameAction) + posData = self.data[self.pos_i] + for action in self.fluoDataChNameActions: + self.chNamesQActionGroup.addAction(action) + action.setChecked(False) + + self.userChNameAction.setChecked(True) + + for action in self.overlayContextMenu.actions(): + action.setChecked(False) + + def createOverlayContextMenu(self): + ch_names = [ch for ch in self.ch_names if ch != self.user_ch_name] + self.overlayContextMenu = QMenu() + self.overlayContextMenu.addSeparator() + self.checkedOverlayChannels = set() + for chName in ch_names: + action = QAction(chName, self.overlayContextMenu) + action.setCheckable(True) + action.toggled.connect(self.overlayChannelToggled) + self.overlayContextMenu.addAction(action) + + def createOverlayLabelsContextMenu(self, segmEndnames): + self.overlayLabelsContextMenu = QMenu() + self.overlayLabelsContextMenu.addSeparator() + self.drawModeOverlayLabelsChannels = {} + segmEndnames_extended = list(segmEndnames.copy()) + segmEndnames_extended = ["combined segm."] + segmEndnames_extended + for segmEndname in segmEndnames_extended: + action = QAction(segmEndname, self.overlayLabelsContextMenu) + if segmEndname == "combined segm.": + action.setCheckable(False) + self.combineSegmViewToggle = action + else: + action.setCheckable(True) + action.toggled.connect(self.addOverlayLabelsToggled) + self.overlayLabelsContextMenu.addAction(action) + + self.overlayLabelsContextMenu.addSeparator() + action = QAction("Edit appearance...", self.overlayLabelsContextMenu) + action.triggered.connect(self.editOverlayLabelsAppearance) + self.overlayLabelsContextMenu.addAction(action) + + def createOverlayLabelsItems(self, segmEndnames): + selectActionGroup = QActionGroup(self) + segmEndnames_extended = list(segmEndnames.copy()) + segmEndnames_extended = ["combined segm."] + segmEndnames_extended + for segmEndname in segmEndnames_extended: + action = QAction(segmEndname) + if segmEndname == "combined segm.": + action.setCheckable(False) + else: + action.setCheckable(True) + action.toggled.connect(self.setOverlayLabelsItemsVisible) + selectActionGroup.addAction(action) + self.selectOverlayLabelsActionGroup = selectActionGroup + + self.overlayLabelsItems = {} + for segmEndname in segmEndnames_extended: + imageItem = pg.ImageItem() + + gradItem = widgets.overlayLabelsGradientWidget( + imageItem, selectActionGroup, segmEndname + ) + gradItem.hide() + gradItem.drawModeActionGroup.triggered.connect( + self.overlayLabelsDrawModeToggled + ) + self.mainLayout.addWidget(gradItem, 0, 0) + + contoursItem = pg.ScatterPlotItem() + color = colors.get_complementary_color(self.contLineColor) + r, g, b, a = colors.rgba_str_to_values(color) + qcolor = QColor(r, g, b, a) + contoursItem.setData( + [], + [], + symbol="s", + pxMode=False, + size=self.contLineWeight * 2, + brush=pg.mkBrush(color=qcolor), + pen=pg.mkPen(width=3, color=qcolor), + tip=None, + ) + + items = (imageItem, contoursItem, gradItem) + self.overlayLabelsItems[segmEndname] = items + + def createUserChannelNameAction(self): + self.userChNameAction = QAction(self) + self.userChNameAction.setCheckable(True) + self.userChNameAction.setText(self.user_ch_name) + + def defaultRescaleIntensLutActionToggled(self, action): + how = action.text() + for rescaleIntensAction in self.imgGrad.rescaleActionGroup.actions(): + if how == rescaleIntensAction.text(): + rescaleIntensAction.setChecked(True) + rescaleIntensAction.trigger() + break + + for channel, items in self.overlayLayersItems.items(): + lutItem = items[1] + for rescaleIntensAction in lutItem.rescaleActionGroup.actions(): + if how == rescaleIntensAction.text(): + rescaleIntensAction.setChecked(True) + rescaleIntensAction.trigger() + break + + self.df_settings.at["default_rescale_intens_how", "value"] = how + self.df_settings.to_csv(self.settings_csv_path) + + def drawLostObjContoursImage( + self, + imageItem, + contours, + thickness=1, + color=(255, 165, 0, 255), # orange + ): + img = self.lostObjContoursImage + cv2.drawContours(img, contours, -1, color, thickness) + imageItem.setImage(img) + + def drawLostTrackedObjContoursImage(self, imageItem, contours): + thickness = 1 + color = (0, 255, 0, 255) # green + img = self.lostTrackedObjContoursImage + cv2.drawContours(img, contours, -1, color, thickness) + imageItem.setImage(img) + + def editOverlayLabelsAppearance(self, *args): + segmEndname = list(self.overlayLabelsItems.keys())[0] + contoursItem = self.overlayLabelsItems[segmEndname][1] + win = apps.OverlayLabelsAppearanceDialog( + scatterPlotItem=contoursItem, parent=self + ) + win.exec_() + if win.cancel: + return + + brush = win.properties["brush"] + pen = win.properties["pen"] + for items in self.overlayLabelsItems.values(): + imageItem, contoursItem, gradItem = items + contoursItem.setBrush(brush, update=False) + contoursItem.setPen(pen) + + def enableOverlayWidgets(self, enabled): + posData = self.data[self.pos_i] + if enabled: + self.overlayColorButton.setDisabled(False) + self.editOverlayColorAction.setDisabled(False) + + if posData.SizeZ == 1: + return + + self.zSliceOverlay_SB.setMaximum(posData.SizeZ - 1) + if self.zProjOverlay_CB.currentText().find("max") != -1: + self.overlay_z_label.setDisabled(True) + self.zSliceOverlay_SB.setDisabled(True) + else: + z = self.zSliceOverlay_SB.sliderPosition() + self.overlay_z_label.setText( + f"Overlay z-slice {z + 1:02}/{posData.SizeZ}" + ) + self.zSliceOverlay_SB.setDisabled(False) + self.overlay_z_label.setDisabled(False) + self.zSliceOverlay_SB.show() + self.overlay_z_label.show() + self.zProjOverlay_CB.show() + self.zSliceOverlay_SB.valueChanged.connect(self.updateOverlayZslice) + self.zProjOverlay_CB.currentTextChanged.connect(self.updateOverlayZproj) + self.zProjOverlay_CB.activated.connect(self.clearComboBoxFocus) + else: + self.zSliceOverlay_SB.setDisabled(True) + self.zSliceOverlay_SB.hide() + self.overlay_z_label.hide() + self.zProjOverlay_CB.hide() + self.overlayColorButton.setDisabled(True) + self.editOverlayColorAction.setDisabled(True) + + if posData.SizeZ == 1: + return + + self.zSliceOverlay_SB.valueChanged.disconnect() + self.zProjOverlay_CB.currentTextChanged.disconnect() + self.zProjOverlay_CB.activated.disconnect() + + def extendLabelsLUT(self, lenNewLut): + posData = self.data[self.pos_i] + # Build a new lut to include IDs > than original len of lut + if lenNewLut > len(self.lut): + numNewColors = lenNewLut - len(self.lut) + # Index original lut + _lut = np.zeros((lenNewLut, 3), np.uint8) + _lut[: len(self.lut)] = self.lut + # Pick random colors and append them at the end to recycle them + randomIdx = np.random.randint(0, len(self.lut), size=numNewColors) + for i, idx in enumerate(randomIdx): + rgb = self.lut[idx] + _lut[len(self.lut) + i] = rgb + self.lut = _lut + self.initLabelsImageItems() + return True + return False + + def getLabelsImageLut(self): + lut = np.zeros((len(self.lut), 4), dtype=np.uint8) + lut[:, -1] = 255 + lut[:, :-1] = self.lut + lut[0] = [0, 0, 0, 0] + return lut + + def getNearestLostObjID(self, y, x): + if not self.annotLostObjsToggle.isChecked(): + return + + posData = self.data[self.pos_i] + if not posData.lost_IDs: + return + + prev_lab = posData.allData_li[posData.frame_i - 1]["labels"] + if prev_lab is None: + return + + # if not hasattr(self, 'lostObjContoursImage'): + # self.store_data() + # posData.frame_i -= 1 + # self.get_data() + # self.store_data() + # posData.frame_i += 1 + # self.get_data() + # self.updateLostNewCurrentIDs() + # self.updateLostContoursImage(ax=0) + # self.updateLostContoursImage(ax=1) + # self.updateLostNewCurrentIDs() + + yy, xx, _ = np.nonzero(self.lostObjContoursImage) + lostObjsContourMask = np.zeros(self.currentLab2D.shape, dtype=bool) + lostObjsContourMask[yy.astype(int), xx.astype(int)] = True + + # Add accepted lost IDs + try: + yy, xx, _ = np.nonzero(self.lostTrackedObjContoursImage) + lostObjsContourMask[yy.astype(int), xx.astype(int)] = True + except Exception as err: + pass + + _, y_nearest, x_nearest = core.nearest_nonzero_2D( + lostObjsContourMask, y, x, return_coords=True + ) + nearest_ID = self.get_2Dlab(prev_lab)[y_nearest, x_nearest] + + if nearest_ID == 0: + return + + return nearest_ID + + def getObjContours( + self, + obj, + all_external=False, + local=False, + force_calc=True, + include_internal=False, + ): + posData = self.data[self.pos_i] + dataDict = posData.allData_li[posData.frame_i] + allContours = dataDict.get("contours") + if allContours is not None and not force_calc: + z = self.z_lab() + key = (obj.label, str(z), all_external, local) + contours = allContours.get(key) + if contours is not None: + return contours + + obj_image = self.getObjImage(obj.image, obj.bbox).astype(np.uint8) + obj_bbox = self.getObjBbox(obj.bbox) + try: + contours = core.get_obj_contours( + obj_image=obj_image, + obj_bbox=obj_bbox, + local=local, + all_external=all_external, + ) + except Exception as e: + if all_external: + contours = [] + else: + contours = None + self.logger.warning( + f"Object ID {obj.label} contours drawing failed. " + f"(bounding box = {obj.bbox})" + ) + return contours + + def getObjFromID(self, ID): + posData = self.data[self.pos_i] + try: + idx = posData.IDs_idxs[ID] + except KeyError as e: + # Object already cleared + return + + obj = posData.rp[idx] + return obj + + def getOlImg(self, key, frame_i=None): + posData = self.data[self.pos_i] + if frame_i is None: + frame_i = posData.frame_i + + img = posData.ol_data[key][frame_i] + if posData.SizeZ > 1: + zProjHow = self.zProjOverlay_CB.currentText() + z = self.zSliceOverlay_SB.sliderPosition() + if zProjHow == "same as above": + zProjHow = self.zProjComboBox.currentText() + z = self.zSliceScrollBar.sliderPosition() + reconnect = False + try: + self.zSliceOverlay_SB.valueChanged.disconnect() + reconnect = True + except TypeError: + pass + self.zSliceOverlay_SB.setSliderPosition(z) + if reconnect: + self.zSliceOverlay_SB.valueChanged.connect(self.updateOverlayZslice) + if zProjHow == "single z-slice": + self.overlay_z_label.setText( + f"Overlay z-slice {z + 1:02}/{posData.SizeZ}" + ) + ol_img = img[z].copy() + elif zProjHow == "max z-projection": + ol_img = img.max(axis=0) + elif zProjHow == "mean z-projection": + ol_img = img.mean(axis=0) + elif zProjHow == "median z-proj.": + ol_img = np.median(img, axis=0) + else: + ol_img = img.copy() + + return ol_img + + def getOpacitiesFromAlphaScrollbarValues(self): + alpha_values = [] + activeOverlayImageItems = [] + for items in self.overlayLayersItems.values(): + imgItem, lutItem, alphaSB = items[:3] + _toolbutton = alphaSB.toolbutton + if not _toolbutton.isChecked() or not _toolbutton.isVisible(): + continue + + alpha_values.append(alphaSB.value() / alphaSB.maximum()) + activeOverlayImageItems.append(imgItem) + + opacities = colors.hierarchical_weights(alpha_values)[::-1] + channel_opacity_mapper = {} + for i, imgItem in enumerate(activeOverlayImageItems): + channel_opacity_mapper[imgItem.channelName] = opacities[i + 1] + + channel_opacity_mapper[self.user_ch_name] = opacities[0] + + return channel_opacity_mapper + + def getOverlayItems(self, channelName, index): + imageItem = widgets.OverlayImageItem() + imageItem.setOpacity(0.5) + imageItem.channelName = channelName + + lutItem = widgets.myHistogramLUTitem( + parent=self, name="image", axisLabel=channelName + ) + imageItem.lutItem = lutItem + for action in lutItem.rescaleActionGroup.actions(): + if action.text() == self.defaultRescaleIntensHow: + action.setChecked(True) + break + + lutItem.removeAddScaleBarAction() + lutItem.removeAddTimestampAction() + lutItem.restoreState(self.df_settings) + lutItem.setImageItem(imageItem) + lutItem.vb.raiseContextMenu = lambda x: None + initColor = self.overlayColors[channelName] + self.initColormapOverlayLayerItem(initColor, lutItem) + lutItem.addOverlayColorButton(initColor, channelName) + lutItem.initColor = initColor + lutItem.hide() + + lutItem.overlayColorButton.sigColorChanging.connect(self.changeOverlayColor) + lutItem.overlayColorButton.sigColorChanged.connect(self.saveOverlayColor) + + lutItem.invertBwAction.toggled.connect(self.setCheckedInvertBW) + + lutItem.contoursColorButton.disconnect() + lutItem.contoursColorButton.clicked.connect( + self.imgGrad.contoursColorButton.click + ) + for act in lutItem.contLineWightActionGroup.actions(): + act.toggled.connect(self.contLineWeightToggled) + + lutItem.mothBudLineColorButton.disconnect() + lutItem.mothBudLineColorButton.clicked.connect( + self.imgGrad.mothBudLineColorButton.click + ) + for act in lutItem.mothBudLineWightActionGroup.actions(): + act.toggled.connect(self.mothBudLineWeightToggled) + + lutItem.textColorButton.disconnect() + lutItem.textColorButton.clicked.connect(self.editTextIDsColorAction.trigger) + + lutItem.defaultSettingsAction.triggered.connect(self.restoreDefaultSettings) + lutItem.labelsAlphaSlider.valueChanged.connect(self.setValueLabelsAlphaSlider) + lutItem.sigRescaleIntes.connect( + partial(self.rescaleIntensitiesLut, imageItem=imageItem) + ) + if f"how_rescale_intensities_{channelName}" in self.df_settings.index: + how = self.df_settings.at[f"how_rescale_intensities_{channelName}", "value"] + lutItem.setRescaleIntensitiesHow(how) + + self.rescaleIntensChannelHowMapper[channelName] = "Rescale each 2D image" + + self.addActionsLutItemContextMenu(lutItem) + + alphaScrollBar = self.addAlphaScrollbar(channelName, imageItem) + + toolbutton = widgets.OverlayChannelToolButton( + channelName, lutItem, shortcut=str(index) + ) + toolbutton.action = self.overlayToolbar.addWidget(toolbutton) + toolbutton.setVisible(False) + + toolbutton.clicked.connect(self.overlayChannelToolbuttonClicked) + + alphaScrollBar.toolbutton = toolbutton + + return imageItem, lutItem, alphaScrollBar, toolbutton + + def getOverlayLabelsData(self, segmEndname): + posData = self.data[self.pos_i] + + if posData.ol_labels_data is None: + self.loadOverlayLabelsData(segmEndname) + elif segmEndname not in posData.ol_labels_data: + self.loadOverlayLabelsData(segmEndname) + + comb_seg = False + if "combined segm." == segmEndname: + comb_seg = True + if not self.isSegm3D: + zStackImg = self.data[0].SizeZ > 1 + if zStackImg: + selected_z_stack = self.zSliceScrollBar.sliderPosition() + else: + selected_z_stack = 0 + out = posData.ol_labels_data["combined segm."][posData.frame_i][ + selected_z_stack + ] + return out.astype(np.uint32) + + if self.isSegm3D: + zProjHow = self.zProjComboBox.currentText() + isZslice = zProjHow == "single z-slice" + if isZslice: + z = self.zSliceScrollBar.sliderPosition() + ol_lab = posData.ol_labels_data[segmEndname][posData.frame_i][z] + if comb_seg: + ol_lab = ol_lab.astype(np.uint32) + return ol_lab + else: + ol_lab = posData.ol_labels_data[segmEndname][posData.frame_i].max( + axis=0 + ) + if comb_seg: + ol_lab = ol_lab.astype(np.uint32) + return ol_lab + else: + return posData.ol_labels_data[segmEndname][posData.frame_i] + + def greedyShuffleCmap(self, updateImages=True): + lut = self.labelsGrad.item.colorMap().getLookupTable(0, 1, 255) + greedy_lut = colors.get_greedy_lut(self.currentLab2D, lut) + self.lut = greedy_lut + self.initLabelsImageItems() + if updateImages: + self.updateAllImages() + + def gui_addGraphicsItems(self): + # Auto image adjustment button + proxy = QGraphicsProxyWidget() + equalizeHistPushButton = QPushButton("Enhance contrast") + widthHint = equalizeHistPushButton.sizeHint().width() + equalizeHistPushButton.setMaximumWidth(widthHint) + equalizeHistPushButton.setCheckable(True) + if not self.invertBwAction.isChecked(): + equalizeHistPushButton.setStyleSheet( + "QPushButton {background-color: #282828; color: #F0F0F0;}" + ) + self.equalizeHistPushButton = equalizeHistPushButton + proxy.setWidget(equalizeHistPushButton) + self.graphLayout.addItem(proxy, row=0, col=0) + self.equalizeHistPushButton = equalizeHistPushButton + + # Left image histogram + self.imgGrad = widgets.myHistogramLUTitem(parent=self, name="image") + self.imgGrad.restoreState(self.df_settings) + self.lutItemsLayout.addItem(self.imgGrad, row=0, col=0) + for action in self.imgGrad.rescaleActionGroup.actions(): + if action.text() == self.defaultRescaleIntensHow: + action.setChecked(True) + self.rescaleIntensMenu.addAction(action) + + # Colormap gradient widget + self.labelsGrad = widgets.labelsGradientWidget(parent=self) + try: + stateFound = self.labelsGrad.restoreState(self.df_settings) + except Exception as e: + self.logger.exception(traceback.format_exc()) + print("======================================") + self.logger.info( + "Failed to restore previously used colormap. " + 'Using default colormap "viridis"' + ) + self.labelsGrad.item.loadPreset("viridis") + + # Add actions to imgGrad gradient item + self.imgGrad.gradient.menu.addAction(self.labelsGrad.showLabelsImgAction) + self.imgGrad.gradient.menu.addAction(self.labelsGrad.showRightImgAction) + self.imgGrad.gradient.menu.addAction(self.labelsGrad.showNextFrameAction) + + self.imgGrad.gradient.menu.addSeparator() + + self.imgGrad.gradient.menu.addMenu(self.exportMenu) + + # Add actions to view menu + self.viewMenu.addAction(self.labelsGrad.showLabelsImgAction) + self.viewMenu.addAction(self.labelsGrad.showRightImgAction) + + # Right image histogram + self.imgGradRight = widgets.baseHistogramLUTitem( + name="image", parent=self, gradientPosition="left" + ) + self.imgGradRight.gradient.menu.addAction(self.labelsGrad.showLabelsImgAction) + self.imgGradRight.gradient.menu.addAction(self.labelsGrad.showRightImgAction) + self.imgGradRight.gradient.menu.addAction(self.labelsGrad.showNextFrameAction) + + self.imgGrad.setChildLutItem(self.imgGradRight) + + # Title + self.titleLabel = pg.LabelItem( + justify="center", color=self.titleColor, size="14pt" + ) + self.graphLayout.addItem(self.titleLabel, row=0, col=1, colspan=2) + + def gui_addOverlayLayerItems(self): + for items in self.overlayLabelsItems.values(): + imageItem, contoursItem, gradItem = items + self.ax1.addItem(imageItem) + self.ax1.addItem(contoursItem) + + def gui_addTopLayerItems(self): + for item in self.topLayerItems: + self.ax1.addItem(item) + + for item in self.topLayerItemsRight: + self.ax2.addItem(item) + + def gui_connectGraphicsEvents(self): + self.img1.hoverEvent = self.gui_hoverEventImg1 + self.img2.hoverEvent = self.gui_hoverEventImg2 + self.img1.mousePressEvent = self.gui_mousePressEventImg1 + self.img1.mouseMoveEvent = self.gui_mouseDragEventImg1 + self.img1.mouseReleaseEvent = self.gui_mouseReleaseEventImg1 + self.img2.mousePressEvent = self.gui_mousePressEventImg2 + self.img2.mouseMoveEvent = self.gui_mouseDragEventImg2 + self.img2.mouseReleaseEvent = self.gui_mouseReleaseEventImg2 + self.rightImageItem.mousePressEvent = self.gui_mousePressRightImage + self.rightImageItem.mouseMoveEvent = self.gui_mouseDragRightImage + self.rightImageItem.mouseReleaseEvent = self.gui_mouseReleaseRightImage + self.rightImageItem.hoverEvent = self.gui_hoverEventRightImage + # self.imgGrad.gradient.showMenu = self.gui_gradientContextMenuEvent + self.imgGradRight.gradient.showMenu = self.gui_rightImageShowContextMenu + # self.imgGrad.vb.contextMenuEvent = self.gui_gradientContextMenuEvent + self.ax1.sigRangeChanged.connect(self.viewRangeChanged) + + def gui_createContourPens(self): + if "contLineWeight" in self.df_settings.index: + val = self.df_settings.at["contLineWeight", "value"] + self.contLineWeight = int(val) + else: + self.contLineWeight = 1 + if "contLineColor" in self.df_settings.index: + val = self.df_settings.at["contLineColor", "value"] + rgba = colors.rgba_str_to_values(val) + self.contLineColor = rgba + self.newIDlineColor = [min(255, v + 50) for v in self.contLineColor] + else: + self.contLineColor = (255, 0, 0, 200) + self.newIDlineColor = (255, 0, 0, 255) + + try: + self.imgGrad.contoursColorButton.sigColorChanging.disconnect() + self.imgGrad.contoursColorButton.sigColorChanged.disconnect() + except Exception as e: + pass + try: + for act in self.imgGrad.contLineWightActionGroup.actions(): + act.toggled.disconnect() + except Exception as e: + pass + for act in self.imgGrad.contLineWightActionGroup.actions(): + if act.lineWeight == self.contLineWeight: + act.setChecked(True) + self.imgGrad.contoursColorButton.setColor(self.contLineColor[:3]) + + self.imgGrad.contoursColorButton.sigColorChanging.connect(self.updateContColour) + self.imgGrad.contoursColorButton.sigColorChanged.connect(self.saveContColour) + for act in self.imgGrad.contLineWightActionGroup.actions(): + act.toggled.connect(self.contLineWeightToggled) + + # Contours pens + self.oldIDs_cpen = pg.mkPen(color=self.contLineColor, width=self.contLineWeight) + self.newIDs_cpen = pg.mkPen( + color=self.newIDlineColor, width=self.contLineWeight + 1 + ) + self.tempNewIDs_cpen = pg.mkPen(color="g", width=self.contLineWeight + 1) + + def gui_createGraphicsItems(self): + # Create enough PlotDataItems and LabelItems to draw contours and IDs. + self.progressWin = apps.QDialogWorkerProgress( + title="Creating axes items", + parent=self, + pbarDesc="Creating axes items (see progress in the terminal)...", + ) + self.progressWin.show(self.app) + self.progressWin.mainPbar.setMaximum(0) + + QTimer.singleShot(50, self._gui_createGraphicsItems) + + def gui_createLabelRoiItem(self): + Y, X = self.currentLab2D.shape + # Label ROI rectangle + pen = pg.mkPen("r", width=3) + self.labelRoiItem = widgets.ROI( + (0, 0), + (0, 0), + maxBounds=QRectF(QRect(0, 0, X, Y)), + scaleSnap=True, + translateSnap=True, + pen=pen, + hoverPen=pen, + ) + + posData = self.data[self.pos_i] + if self.labelRoiZdepthSpinbox.value() == 0: + self.labelRoiZdepthSpinbox.setValue(posData.SizeZ) + self.labelRoiZdepthSpinbox.setMaximum(posData.SizeZ + 1) + + def gui_createMothBudLinePens(self): + if "mothBudLineSize" in self.df_settings.index: + val = self.df_settings.at["mothBudLineSize", "value"] + self.mothBudLineWeight = int(val) + else: + self.mothBudLineWeight = 2 + + self.newMothBudlineColor = (255, 0, 0) + if "mothBudLineColor" in self.df_settings.index: + val = self.df_settings.at["mothBudLineColor", "value"] + rgba = colors.rgba_str_to_values(val) + self.mothBudLineColor = rgba[0:3] + else: + self.mothBudLineColor = (255, 165, 0) + + try: + self.imgGrad.mothBudLineColorButton.sigColorChanging.disconnect() + self.imgGrad.mothBudLineColorButton.sigColorChanged.disconnect() + except Exception as e: + pass + try: + for act in self.imgGrad.mothBudLineWightActionGroup.actions(): + act.toggled.disconnect() + except Exception as e: + pass + for act in self.imgGrad.mothBudLineWightActionGroup.actions(): + if act.lineWeight == self.mothBudLineWeight: + act.setChecked(True) + else: + act.setChecked(False) + self.imgGrad.mothBudLineColorButton.setColor(self.mothBudLineColor[:3]) + + self.imgGrad.mothBudLineColorButton.sigColorChanging.connect( + self.updateMothBudLineColour + ) + self.imgGrad.mothBudLineColorButton.sigColorChanged.connect( + self.saveMothBudLineColour + ) + for act in self.imgGrad.mothBudLineWightActionGroup.actions(): + act.toggled.connect(self.mothBudLineWeightToggled) + + # MOther-bud lines brushes + self.NewBudMoth_Pen = pg.mkPen( + color=self.newMothBudlineColor, + width=self.mothBudLineWeight + 1, + style=Qt.DashLine, + ) + self.OldBudMoth_Pen = pg.mkPen( + color=self.mothBudLineColor, width=self.mothBudLineWeight, style=Qt.DashLine + ) + + self.redDashLinePen = pg.mkPen(color="r", width=2, style=Qt.DashLine) + + self.oldMothBudLineBrush = pg.mkBrush(self.mothBudLineColor) + self.newMothBudLineBrush = pg.mkBrush(self.newMothBudlineColor) + + def gui_createOverlayColors(self): + fluoChannels = [ch for ch in self.ch_names if ch != self.user_ch_name] + self.logger.info(f"Number of TIFF files detected: {len(fluoChannels)}") + self.overlayColors = {} + for c, ch in enumerate(fluoChannels): + if f"{ch}_rgb" in self.df_settings.index: + rgb_text = self.df_settings.at[f"{ch}_rgb", "value"] + rgb = tuple([int(val) for val in rgb_text.split("_")]) + self.overlayColors[ch] = rgb + else: + if c >= len(self.overlayRGBs) - 1: + i = c / len(fluoChannels) + additional_color_num = c - len(self.overlayRGBs) + 1 + rgbs = [ + tuple([round(c * 255) for c in self.overlayCmap(i)][:3]) + for _ in range(additional_color_num) + ] + self.overlayRGBs.extend(rgbs) + rgb = colors.FLUO_CHANNELS_COLORS.get(ch, self.overlayRGBs[c]) + self.overlayColors[ch] = rgb + + def gui_createOverlayItems(self): + self.imgGrad.setAxisLabel(self.user_ch_name) + self.baseLayerToolbutton = widgets.OverlayChannelToolButton( + self.user_ch_name, self.imgGrad + ) + self.baseLayerToolbutton.setChecked(True) + self.baseLayerToolbutton.clicked.connect(self.overlayChannelToolbuttonClicked) + self.allOverlayToolbuttons = {self.user_ch_name: self.baseLayerToolbutton} + self.allOverlayToolbuttonsByIdx = {0: self.baseLayerToolbutton} + self.baseLayerToolbutton.action = self.overlayToolbar.addWidget( + self.baseLayerToolbutton + ) + self.overlayLayersItems = {} + self.overlayToolbarAreChannelsChecked = {} + fluoChannels = [ch for ch in self.ch_names if ch != self.user_ch_name] + for c, ch in enumerate(fluoChannels): + overlayItems = self.getOverlayItems(ch, c + 1) + self.overlayLayersItems[ch] = overlayItems + imageItem, lutItem = overlayItems[:2] + self.ax1.addItem(imageItem) + self.lutItemsLayout.addItem(lutItem, row=0, col=c + 1) + toolbutton = overlayItems[3] + self.allOverlayToolbuttons[ch] = toolbutton + self.allOverlayToolbuttonsByIdx[c + 1] = toolbutton + + self.overlayToolbuttonsSep = self.overlayToolbar.addSeparator() + self.plotsCol = len(self.ch_names) + + self.ax1.addImageItem(self.rgbaImg1) + + def gui_createPlotItems(self): + if "textIDsColor" in self.df_settings.index: + rgbString = self.df_settings.at["textIDsColor", "value"] + r, g, b = colors.rgb_str_to_values(rgbString) + self.gui_createTextAnnotColors(r, g, b, custom=True) + self.textIDsColorButton.setColor((r, g, b)) + else: + self.gui_createTextAnnotColors(0, 0, 0, custom=False) + + if "labels_text_color" in self.df_settings.index: + rgbString = self.df_settings.at["labels_text_color", "value"] + r, g, b = colors.rgb_str_to_values(rgbString) + self.ax2_textColor = (r, g, b) + else: + self.ax2_textColor = (255, 0, 0) + + self.emptyLab = np.zeros((2, 2), dtype=np.uint8) + + # Right image item linked to left + self.rightImageItem = widgets.ChildImageItem( + linkedScrollbar=self.rightImageFramesScrollbar + ) + self.imgGradRight.setImageItem(self.rightImageItem) + self.ax2.addItem(self.rightImageItem) + + # Left image + self.img1 = widgets.ParentImageItem( + linkedImageItem=self.rightImageItem, + activatingActions=( + self.labelsGrad.showRightImgAction, + self.labelsGrad.showNextFrameAction, + ), + ) + self.imgGrad.setImageItem(self.img1) + self.img1.lutItem = self.imgGrad + self.imgGrad.sigRescaleIntes.connect(self.rescaleIntensitiesLut) + self.ax1.addBaseImageItem(self.img1) + + # RGBA image for true transparency mode + self.rgbaImg1 = pg.ImageItem() + + # self.rgbaImg1.setImage(self.emptyLab) + + # Right image + self.img2 = widgets.labImageItem() + self.ax2.addItem(self.img2) + + self.topLayerItems = [] + self.topLayerItemsRight = [] + + self.gui_createContourPens() + self.gui_createMothBudLinePens() + + self.eraserCirclePen = pg.mkPen(width=1.5, color="r") + + # Temporary line item connecting bud to new mother + self.BudMothTempLine = pg.PlotDataItem(pen=self.NewBudMoth_Pen) + self.topLayerItems.append(self.BudMothTempLine) + + # Temporary line item connecting objects to merge + self.mergeObjsTempLine = widgets.PlotCurveItem(pen=self.redDashLinePen) + self.topLayerItems.append(self.mergeObjsTempLine) + + # Overlay segm. masks item + self.labelsLayerImg1 = widgets.BaseLabelsImageItem() + self.ax1.addItem(self.labelsLayerImg1) + + self.labelsLayerRightImg = widgets.BaseLabelsImageItem() + self.ax2.addItem(self.labelsLayerRightImg) + + # Red/green border rect item + self.GreenLinePen = pg.mkPen(color="g", width=2) + self.RedLinePen = pg.mkPen(color="r", width=2) + self.ax1BorderLine = pg.PlotDataItem() + self.topLayerItems.append(self.ax1BorderLine) + self.ax2BorderLine = pg.PlotDataItem(pen=pg.mkPen(color="r", width=2)) + self.topLayerItems.append(self.ax2BorderLine) + + # Brush/Eraser/Wand.. layer item + self.tempLayerRightImage = pg.ImageItem() + self.tempLayerImg1 = widgets.ParentImageItem( + linkedImageItem=self.tempLayerRightImage, + activatingAction=(self.labelsGrad.showRightImgAction,), + ) + self.topLayerItems.append(self.tempLayerImg1) + self.topLayerItemsRight.append(self.tempLayerRightImage) + + # Highlighted ID layer items + self.highLightIDLayerImg1 = pg.ImageItem() + self.topLayerItems.append(self.highLightIDLayerImg1) + + # Highlighted ID layer items + self.highLightIDLayerRightImage = pg.ImageItem() + self.topLayerItemsRight.append(self.highLightIDLayerRightImage) + + # Keep IDs temp layers + self.keepIDsTempLayerRight = pg.ImageItem() + self.keepIDsTempLayerLeft = widgets.ParentImageItem( + linkedImageItem=self.keepIDsTempLayerRight, + activatingAction=self.labelsGrad.showRightImgAction, + ) + self.topLayerItems.append(self.keepIDsTempLayerLeft) + self.topLayerItemsRight.append(self.keepIDsTempLayerRight) + + # Searched ID contour + self.searchedIDitemRight = pg.ScatterPlotItem() + self.searchedIDitemRight.setData( + [], + [], + symbol="s", + pxMode=False, + size=1, + brush=pg.mkBrush(color=(255, 0, 0, 150)), + pen=pg.mkPen(width=2, color="r"), + tip=None, + ) + self.searchedIDitemLeft = pg.ScatterPlotItem() + self.searchedIDitemLeft.setData( + [], + [], + symbol="s", + pxMode=False, + size=1, + brush=pg.mkBrush(color=(255, 0, 0, 150)), + pen=pg.mkPen(width=2, color="r"), + tip=None, + ) + self.topLayerItems.append(self.searchedIDitemLeft) + self.topLayerItemsRight.append(self.searchedIDitemRight) + + # Brush circle img1 + self.ax1_BrushCircle = pg.ScatterPlotItem() + self.ax1_BrushCircle.setData( + [], + [], + symbol="o", + pxMode=False, + brush=pg.mkBrush((255, 255, 255, 50)), + pen=pg.mkPen(width=2), + tip=None, + ) + self.topLayerItems.append(self.ax1_BrushCircle) + + # Eraser circle img1 + self.ax1_EraserCircle = pg.ScatterPlotItem() + self.ax1_EraserCircle.setData( + [], + [], + symbol="o", + pxMode=False, + brush=None, + pen=self.eraserCirclePen, + tip=None, + ) + self.topLayerItems.append(self.ax1_EraserCircle) + + self.ax1_EraserX = pg.ScatterPlotItem() + self.ax1_EraserX.setData( + [], + [], + symbol="x", + pxMode=False, + size=3, + brush=pg.mkBrush(color=(255, 0, 0, 50)), + pen=pg.mkPen(width=1, color="r"), + tip=None, + ) + self.topLayerItems.append(self.ax1_EraserX) + + # Brush circle img1 + self.labelRoiCircItemLeft = widgets.LabelRoiCircularItem() + self.labelRoiCircItemLeft.cleared = False + self.labelRoiCircItemLeft.setData( + [], + [], + symbol="o", + pxMode=False, + brush=pg.mkBrush(color=(255, 0, 0, 0)), + pen=pg.mkPen(color="r", width=2), + tip=None, + ) + self.labelRoiCircItemRight = widgets.LabelRoiCircularItem() + self.labelRoiCircItemRight.cleared = False + self.labelRoiCircItemRight.setData( + [], + [], + symbol="o", + pxMode=False, + brush=pg.mkBrush(color=(255, 0, 0, 0)), + pen=pg.mkPen(color="r", width=2), + tip=None, + ) + self.topLayerItems.append(self.labelRoiCircItemLeft) + self.topLayerItemsRight.append(self.labelRoiCircItemRight) + + self.ax1_binnedIDs_ScatterPlot = widgets.BaseScatterPlotItem() + self.ax1_binnedIDs_ScatterPlot.setData( + [], + [], + symbol="t", + pxMode=False, + brush=pg.mkBrush((255, 0, 0, 50)), + size=15, + pen=pg.mkPen(width=3, color="r"), + tip=None, + ) + self.topLayerItems.append(self.ax1_binnedIDs_ScatterPlot) + + self.ax1_ripIDs_ScatterPlot = widgets.BaseScatterPlotItem() + self.ax1_ripIDs_ScatterPlot.setData( + [], + [], + symbol="x", + pxMode=False, + brush=pg.mkBrush((255, 0, 0, 50)), + size=15, + pen=pg.mkPen(width=2, color="r"), + tip=None, + ) + self.topLayerItems.append(self.ax1_ripIDs_ScatterPlot) + + # Ruler plotItem and scatterItem + rulerPen = pg.mkPen(color="r", style=Qt.DashLine, width=2) + self.ax1_rulerPlotItem = widgets.RulerPlotItem(pen=rulerPen) + self.ax1_rulerAnchorsItem = pg.ScatterPlotItem( + symbol="o", + size=9, + brush=pg.mkBrush((255, 0, 0, 50)), + pen=pg.mkPen((255, 0, 0), width=2), + tip=None, + ) + self.topLayerItems.append(self.ax1_rulerPlotItem) + self.topLayerItems.append(self.ax1_rulerPlotItem.labelItem) + self.topLayerItems.append(self.ax1_rulerAnchorsItem) + + # Start point of polyline roi + self.ax1_point_ScatterPlot = pg.ScatterPlotItem() + self.ax1_point_ScatterPlot.setData( + [], + [], + symbol="o", + pxMode=False, + size=3, + pen=pg.mkPen(width=2, color="r"), + brush=pg.mkBrush((255, 0, 0, 50)), + tip=None, + ) + self.topLayerItems.append(self.ax1_point_ScatterPlot) + + # Experimental: scatter plot to add a point marker + self.startPointPolyLineItem = pg.ScatterPlotItem() + self.startPointPolyLineItem.setData( + [], + [], + symbol="o", + size=9, + pen=pg.mkPen(width=2, color="r"), + brush=pg.mkBrush((255, 0, 0, 50)), + hoverable=True, + hoverBrush=pg.mkBrush((255, 0, 0, 255)), + tip=None, + ) + self.topLayerItems.append(self.startPointPolyLineItem) + + # Eraser circle img2 + self.ax2_EraserCircle = pg.ScatterPlotItem() + self.ax2_EraserCircle.setData( + [], + [], + symbol="o", + pxMode=False, + brush=None, + pen=self.eraserCirclePen, + tip=None, + ) + self.ax2.addItem(self.ax2_EraserCircle) + self.ax2_EraserX = pg.ScatterPlotItem() + self.ax2_EraserX.setData( + [], + [], + symbol="x", + pxMode=False, + size=3, + brush=pg.mkBrush(color=(255, 0, 0, 50)), + pen=pg.mkPen(width=1.5, color="r"), + ) + self.ax2.addItem(self.ax2_EraserX) + + # Brush circle img2 + self.ax2_BrushCirclePen = pg.mkPen(width=2) + self.ax2_BrushCircleBrush = pg.mkBrush((255, 255, 255, 50)) + self.ax2_BrushCircle = pg.ScatterPlotItem() + self.ax2_BrushCircle.setData( + [], + [], + symbol="o", + pxMode=False, + brush=self.ax2_BrushCircleBrush, + pen=self.ax2_BrushCirclePen, + tip=None, + ) + self.ax2.addItem(self.ax2_BrushCircle) + + # Annotated metadata markers (ScatterPlotItem) + self.ax2_binnedIDs_ScatterPlot = widgets.BaseScatterPlotItem() + self.ax2_binnedIDs_ScatterPlot.setData( + [], + [], + symbol="t", + pxMode=False, + brush=pg.mkBrush((255, 0, 0, 50)), + size=15, + pen=pg.mkPen(width=3, color="r"), + tip=None, + ) + self.ax2.addItem(self.ax2_binnedIDs_ScatterPlot) + + self.ax2_ripIDs_ScatterPlot = widgets.BaseScatterPlotItem() + self.ax2_ripIDs_ScatterPlot.setData( + [], + [], + symbol="x", + pxMode=False, + brush=pg.mkBrush((255, 0, 0, 50)), + size=15, + pen=pg.mkPen(width=2, color="r"), + tip=None, + ) + self.ax2.addItem(self.ax2_ripIDs_ScatterPlot) + + self.freeRoiItem = widgets.PlotCurveItem(pen=pg.mkPen(color="r", width=2)) + self.topLayerItems.append(self.freeRoiItem) + + self.warnPairingItem = widgets.PlotCurveItem( + pen=pg.mkPen(color="r", width=5, style=Qt.DashLine), pxMode=False + ) + self.topLayerItems.append(self.warnPairingItem) + + self.exportMaskImageItem = pg.ImageItem() + + self.ghostContourItemLeft = widgets.GhostContourItem(self.ax1) + self.ghostContourItemRight = widgets.GhostContourItem(self.ax2) + + self.ghostMaskItemLeft = widgets.GhostMaskItem(self.ax1) + self.ghostMaskItemRight = widgets.GhostMaskItem(self.ax2) + + self.manualBackgroundObjItem = widgets.GhostContourItem( + self.ax1, penColor="r", textColor="r" + ) + self.manualBackgroundImageItem = pg.ImageItem() + + def gui_createTextAnnotColors(self, r, g, b, custom=False): + if custom: + self.objLabelAnnotRgb = (int(r), int(g), int(b)) + self.SphaseAnnotRgb = (int(r * 0.9), int(r * 0.9), int(b * 0.9)) + self.G1phaseAnnotRgba = (int(r * 0.8), int(g * 0.8), int(b * 0.8), 220) + else: + self.objLabelAnnotRgb = (255, 255, 255) # white + self.SphaseAnnotRgb = (229, 229, 229) + self.G1phaseAnnotRgba = (204, 204, 204, 220) + self.dividedAnnotRgb = (245, 188, 1) # orange + + self.emptyBrush = pg.mkBrush((0, 0, 0, 0)) + self.emptyPen = pg.mkPen((0, 0, 0, 0)) + + def gui_createTextAnnotItems(self, allIDs): + self.textAnnot = {} + isHighResolution = self.highLowResAction.isChecked() + pxMode = self.pxModeAction.isChecked() + for ax in range(2): + ax_textAnnot = annotate.TextAnnotations() + ax_textAnnot.initFonts(self.fontSize) + ax_textAnnot.createItems(isHighResolution, allIDs, pxMode=pxMode) + self.textAnnot[ax] = ax_textAnnot + + def gui_createZoomRectItem(self): + Y, X = self.currentLab2D.shape + # Label ROI rectangle + pen = pg.mkPen("r", width=3, style=Qt.DashLine) + self.zoomRectItem = widgets.ZoomROI( + (0, 0), + (0, 0), + maxBounds=QRectF(QRect(0, 0, X, Y)), + scaleSnap=True, + translateSnap=True, + pen=pen, + hoverPen=pen, + ) + + def gui_getLostObjScatterItem(self): + self.objLostAnnotRgb = (245, 184, 0) + brush = pg.mkBrush((*self.objLostAnnotRgb, 150)) + pen = pg.mkPen(self.objLostAnnotRgb, width=1) + lostObjScatterItem = pg.ScatterPlotItem( + size=self.contLineWeight + 1, pen=pen, brush=brush, pxMode=False, symbol="s" + ) + return lostObjScatterItem + + def gui_getTrackedLostObjScatterItem(self): + self.objLostTrackedAnnotRgb = (0, 255, 0) + brush = pg.mkBrush((*self.objLostTrackedAnnotRgb, 150)) + pen = pg.mkPen(self.objLostTrackedAnnotRgb, width=1) + lostObjScatterItem = pg.ScatterPlotItem( + size=self.contLineWeight + 1, pen=pen, brush=brush, pxMode=False, symbol="s" + ) + return lostObjScatterItem + + def gui_initImg1BottomWidgets(self): + self.zSliceScrollBar.hide() + self.zProjComboBox.hide() + self.zProjLockViewButton.hide() + self.zSliceOverlay_SB.hide() + self.zProjOverlay_CB.hide() + self.overlay_z_label.hide() + self.zSliceCheckbox.hide() + self.zSliceSpinbox.hide() + self.SizeZlabel.hide() + + def gui_setTextAnnotColors(self): + self.textAnnot[0].setColors( + self.objLabelAnnotRgb, + self.dividedAnnotRgb, + self.SphaseAnnotRgb, + self.G1phaseAnnotRgba, + self.objLostAnnotRgb, + self.objLostTrackedAnnotRgb, + ) + + self.textAnnot[1].setColors( + self.objLabelAnnotRgb, + self.dividedAnnotRgb, + self.SphaseAnnotRgb, + self.G1phaseAnnotRgba, + self.objLostAnnotRgb, + self.objLostTrackedAnnotRgb, + ) + + def hideOverlayLabelsItems(self, specific=None): + if specific is None: + specific = self.overlayLabelsItems.keys() + for segmEndname in specific: + imageItem, contoursItem, gradItem = self.overlayLabelsItems[segmEndname] + imageItem.setVisible(False) + contoursItem.setVisible(False) + gradItem.setVisible(False) + + def imgGradLUTfinished_cb(self): + posData = self.data[self.pos_i] + ticks = self.imgGrad.gradient.listTicks() + + self.img1ChannelGradients[self.user_ch_name] = { + "ticks": [(x, t.color.getRgb()) for t, x in ticks], + "mode": "rgb", + } + + self.df_settings = self.imgGrad.saveState(self.df_settings) + self.df_settings.to_csv(self.settings_csv_path) + + def initColormapOverlayLayerItem(self, foregrColor, lutItem): + if self.invertBwAction.isChecked(): + bkgrColor = (255, 255, 255, 255) + else: + bkgrColor = (0, 0, 0, 255) + gradient = colors.get_pg_gradient((bkgrColor, foregrColor)) + lutItem.setGradient(gradient) + + def initLabelsImageItems(self): + lut = self.getLabelsImageLut() + self.labelsLayerImg1.setLevels([0, len(lut)]) + self.labelsLayerRightImg.setLevels([0, len(lut)]) + self.labelsLayerImg1.setLookupTable(lut) + self.labelsLayerRightImg.setLookupTable(lut) + alpha = self.imgGrad.labelsAlphaSlider.value() + self.labelsLayerImg1.setOpacity(alpha) + self.labelsLayerRightImg.setOpacity(alpha) + + def initLookupTableLab(self): + self.img2.setLookupTable(self.lut) + self.img2.setLevels([0, len(self.lut)]) + self.initLabelsImageItems() + + def loadOverlayData(self, ol_channels, addToExisting=False): + posData = self.data[self.pos_i] + for ol_ch in ol_channels: + if ol_ch not in list(posData.loadedFluoChannels): + # Requested channel was never loaded --> load it at first + # iter i == 0 + success = self.loadFluo_cb(fluo_channels=[ol_ch]) + if not success: + return False + + lastChannelName = ol_channels[-1] + for action in self.fluoDataChNameActions: + if action.text() == lastChannelName: + action.setChecked(True) + + for p, posData in enumerate(self.data): + if addToExisting: + ol_data = posData.ol_data + else: + ol_data = {} + for i, ol_ch in enumerate(ol_channels): + _, filename = self.getPathFromChName(ol_ch, posData) + ol_data[filename] = posData.ol_data_dict[filename].copy() + self.addFluoChNameContextMenuAction(ol_ch) + posData.ol_data = ol_data + + return True + + def loadOverlayLabelsData(self, segmEndname, pos_i=None): + if pos_i is None: + pos_i = self.pos_i + posData = self.data[pos_i] + + if posData.ol_labels_data is None: + posData.ol_labels_data = {} + if segmEndname == "combined segm.": + posData.ol_labels_data["combined segm."] = posData.combine_img_data + return + filePath, filename = load.get_path_from_endname( + segmEndname, posData.images_path + ) + self.logger.info(f'Loading "{segmEndname}.npz"...') + labelsData = np.load(filePath)["arr_0"] + if posData.SizeT == 1: + labelsData = labelsData[np.newaxis] + if self.isSegm3D and labelsData.ndim == 3: + # 2D segm --> stack to 3D + T, Y, X = labelsData.shape + repeat = [labelsData] * posData.SizeZ + labelsData = np.stack(repeat, axis=1) + + posData.ol_labels_data[segmEndname] = labelsData + + def mothBudLineWeightToggled(self, checked=True): + if not checked: + return + self.imgGrad.uncheckContLineWeightActions() + w = self.sender().lineWeight + self.df_settings.at["mothBudLineSize", "value"] = w + self.df_settings.to_csv(self.settings_csv_path) + self._updateMothBudLineSize(w) + self.updateAllImages() + + def mousePressColorButton(self, event): + posData = self.data[self.pos_i] + items = list(self.checkedOverlayChannels) + if len(items) > 1: + selectFluo = widgets.QDialogListbox( + "Select image", + "Select which fluorescence image you want to update the color of\n", + items, + multiSelection=False, + parent=self, + ) + selectFluo.exec_() + keys = selectFluo.selectedItemsText + if selectFluo.cancel or not keys: + return + else: + self.overlayColorButton.channel = keys[0] + else: + self.overlayColorButton.channel = items[0] + self.overlayColorButton.selectColor() + + def overlayChannelToggled(self, checked): + # Action toggled from overlayButton context menu + channelName = self.sender().text() + posData = self.data[self.pos_i] + if checked: + if channelName not in posData.loadedFluoChannels: + self.loadOverlayData([channelName], addToExisting=True) + else: + _, filename = self.getPathFromChName(channelName, posData) + posData.ol_data[filename] = posData.ol_data_dict[filename].copy() + + self.checkedOverlayChannels.add(channelName) + else: + self.checkedOverlayChannels.remove(channelName) + imageItem = self.overlayLayersItems[channelName][0] + imageItem.clear() + + self.setOverlayChannelsToolbuttonsChecked() + self.setOverlayItemsVisible() + self.setRetainSizePolicyLutItems() + self.updateAllImages() + + def overlayChannelToolbuttonClicked(self, checked=False, toolbutton=None): + if toolbutton is None: + toolbutton = self.sender() + + n_checked_buttons = sum( + [b.isChecked() for b in self.allOverlayToolbuttons.values()] + ) + + channelName = toolbutton.channelName() + + if n_checked_buttons == 0 or self.overlayToolbar.isSingleChannel(): + # At least one button must be checked + toolbutton.setChecked(True) + + if self.overlayToolbar.isSingleChannel(): + # Exclusive buttons + for channel, otherToolbutton in self.allOverlayToolbuttons.items(): + if channel == channelName: + continue + + otherToolbutton.setChecked(False) + + if self.overlayToolbar.isTransparent(): + self.setOverlayImages() + return + + self.setOverlayItemsOpacities() + + def overlayLabelsDrawModeToggled(self, action): + segmEndname = action.segmEndname + drawMode = action.text() + if segmEndname in self.drawModeOverlayLabelsChannels: + self.drawModeOverlayLabelsChannels[segmEndname] = drawMode + self.setOverlayLabelsItems() + + def overlayLabels_cb(self, checked, selectedLabelsEndnames=None): + if checked: + if not self.drawModeOverlayLabelsChannels: + if selectedLabelsEndnames is None: + selectedLabelsEndnames = self.askLabelsToOverlay() + if selectedLabelsEndnames is None: + self.logger.info("Overlay labels cancelled.") + self.overlayLabelsButton.setChecked(False) + return + for selectedEndname in selectedLabelsEndnames: + self.loadOverlayLabelsData(selectedEndname) + for action in self.overlayLabelsContextMenu.actions(): + if not action.isCheckable(): + continue + if action.text() == selectedEndname: + action.setChecked(True) + lastSelectedName = selectedLabelsEndnames[-1] + for action in self.selectOverlayLabelsActionGroup.actions(): + if action.text() == lastSelectedName: + action.setChecked(True) + self.updateAllImages() + + def overlay_cb(self, checked): + self.overlayToolbar.setVisible(checked) + + self.UserNormAction, _, _ = self.getCheckNormAction() + posData = self.data[self.pos_i] + if checked: + if posData.ol_data is None: + selectedChannels = self.askSelectOverlayChannel() + if selectedChannels is None: + self.overlayButton.toggled.disconnect() + self.overlayButton.setChecked(False) + self.overlayButton.toggled.connect(self.overlay_cb) + return + + success = self.loadOverlayData(selectedChannels) + if not success: + return False + lastChannel = selectedChannels[-1] + self.setCheckedOverlayContextMenusActions(selectedChannels) + imageItem = self.overlayLayersItems[lastChannel][0] + self.setOpacityOverlayLayersItems(None, imageItem=imageItem) + self.setOverlayChannelsToolbuttonsChecked() + + self.setRetainSizePolicyLutItems() + self.normalizeRescale0to1Action.setChecked(True) + + self.updateAllImages() + self.updateImageValueFormatter() + self.enableOverlayWidgets(True) + else: + self.img1.setOpacity(1.0) + self.updateAllImages() + self.updateImageValueFormatter() + self.enableOverlayWidgets(False) + self.clearOverlayImageItems() + + self.setOverlayItemsVisible() + + def permanentGreedyCmapToggled(self, checked): + if checked: + settings_value = "yes" + else: + self.setLut() + self.updateLookuptable() + self.initLabelsImageItems() + settings_value = "no" + + self.updateAllImages() + + if self.isSnapshot: + option_name = "permanent_greedy_lut_snapshots" + else: + option_name = "permanent_greedy_lut_timelapse" + + self.df_settings.at[option_name, "value"] = settings_value + self.df_settings.to_csv(self.settings_csv_path) + + def removeAllItems(self): + self.ax1.clear() + self.ax2.clear() + try: + self.chNamesQActionGroup.removeAction(self.userChNameAction) + except Exception as e: + pass + try: + posData = self.data[self.pos_i] + for action in self.fluoDataChNameActions: + self.chNamesQActionGroup.removeAction(action) + except Exception as e: + pass + try: + self.overlayButton.setChecked(False) + except Exception as e: + pass + + if hasattr(self, "contoursImage"): + self.initContoursImage() + + def removeOverlayItems(self): + self.lutItemsLayout.clear() + + try: + for toolbutton in self.allOverlayToolbuttonsByIdx.values(): + self.overlayToolbar.removeAction(toolbutton.action) + + self.overlayToolbuttonsSep.removeFromToolbar() + except Exception as err: + pass + + def restoreDefaultColors(self): + try: + color = self.defaultToolBarButtonColor + self.overlayButton.setStyleSheet(f"background-color: {color}") + except AttributeError: + # traceback.print_exc() + pass + + def restoreDefaultSettings(self): + df = self.df_settings + df.at["contLineWeight", "value"] = 1 + df.at["mothBudLineSize", "value"] = 1 + df.at["mothBudLineColor", "value"] = (255, 165, 0, 255) + df.at["contLineColor", "value"] = (205, 0, 0, 220) + + self._updateContColour((205, 0, 0, 220)) + self._updateMothBudLineColour((255, 165, 0, 255)) + self._updateMothBudLineSize(1) + self._updateContLineThickness() + + df.at["overlaySegmMasksAlpha", "value"] = 0.3 + df.at["img_cmap", "value"] = "grey" + self.imgCmap = self.imgGrad.cmaps["grey"] + self.imgCmapName = "grey" + self.labelsGrad.item.loadPreset("viridis") + df.at["labels_bkgrColor", "value"] = (25, 25, 25) + + if df.at["is_bw_inverted", "value"] == "Yes": + self.invertBw(update=False) + + df = df[~df.index.str.contains("lab_cmap")] + df.to_csv(self.settings_csv_path) + self.imgGrad.restoreState(df) + for items in self.overlayLayersItems.values(): + lutItem = items[1] + lutItem.restoreState(df) + + self.labelsGrad.saveState(df) + self.labelsGrad.restoreState(df, loadCmap=False) + + self.df_settings.to_csv(self.settings_csv_path) + self.updateAllImages() + + def saveBkgrColor(self, button): + color = button.color().getRgb()[:3] + self.df_settings.at["labels_bkgrColor", "value"] = color + self.df_settings.to_csv(self.settings_csv_path) + self.updateAllImages() + + def saveContColour(self, colorButton): + self.df_settings.to_csv(self.settings_csv_path) + + def saveMothBudLineColour(self, colorButton): + self.df_settings.to_csv(self.settings_csv_path) + + def saveOverlayColor(self, button): + rgb = button.color().getRgb()[:3] + rgb_text = "_".join([str(val) for val in rgb]) + self.df_settings.at[f"{button.channel}_rgb", "value"] = rgb_text + self.df_settings.to_csv(self.settings_csv_path) + + def saveTextIDsColors(self, button): + self.df_settings.at["textIDsColor", "value"] = self.objLabelAnnotRgb + self.df_settings.to_csv(self.settings_csv_path) + + def saveTextLabelsColor(self, button): + color = button.color().getRgb()[:3] + self.df_settings.at["labels_text_color", "value"] = color + self.df_settings.to_csv(self.settings_csv_path) + + def segmNdimIndicatorClicked(self): + ndimText = self.segmNdimIndicator.text() + if ndimText == "2D": + alternativeNdimText = "3D" + toggleText = "activate" + else: + alternativeNdimText = "2D" + toggleText = "de-activate" + msg = widgets.myMessageBox(wrapText=False) + important_txt = """ + The toggle to activate 3D segmentation is visible only when + the Number of z-slices is greater than 1. + """ + txt = html_utils.paragraph(f""" + This indicator shows that you are working with {ndimText} + segmentation masks.

    + + If instead, you want to work with {alternativeNdimText} segmentation, + you need to initialize a new segmentation file.

    + + To do so, go the menu on the top menubar File --> + New Segmentation File... and,
    + at the dialog where you insert the metadata (Number of z-slices, + pixel size, etc.),
    + {toggleText} the parameter called Work with 3D + segmentation masks (z-stack)
    + as indicated in the screenshot below
    . + {html_utils.to_admonition(important_txt, admonition_type="note")} +
    + """) + msg.information( + self, + "Segmentation nmber of dimensions info", + txt, + image_paths=":toggle_3D_screenshot.png", + ) + self.segmNdimIndicator.setChecked(True) + + def setAllContoursImages(self, delROIsIDs=None, compute=True): + if compute: + self.computeAllContours() + self.updateContoursImage(ax=0, delROIsIDs=delROIsIDs, compute=compute) + self.updateContoursImage(ax=1, delROIsIDs=delROIsIDs, compute=compute) + + def setAllLostObjContoursImage(self, delROIsIDs=None): + self.updateLostContoursImage(ax=0, delROIsIDs=None) + self.updateLostContoursImage(ax=1, delROIsIDs=None) + + def setAllLostTrackedObjContoursImage(self, delROIsIDs=None): + self.updateLostTrackedContoursImage(ax=0, delROIsIDs=None) + self.updateLostTrackedContoursImage(ax=1, delROIsIDs=None) + + def setCheckedOverlayContextMenusActions(self, channelNames): + for action in self.overlayContextMenu.actions(): + if action.text() in channelNames: + action.setChecked(True) + self.checkedOverlayChannels.add(action.text()) + + def setContoursImage(self, imageItem, contours, thickness, color): + cv2.drawContours(self.contoursImage, contours, -1, color, thickness) + imageItem.setImage(self.contoursImage) + + def setLostObjectContour(self, obj): + allContours = self.getObjContours(obj, all_external=True) + for objContours in allContours: + xx = objContours[:, 0] + 0.5 + yy = objContours[:, 1] + 0.5 + data = [obj.label] * len(xx) + self.ax1_lostObjScatterItem.addPoints(xx, yy, data=data) + self.ax2_lostObjScatterItem.addPoints(xx, yy) + + def setLut(self, shuffle=True): + self.lut = self.labelsGrad.item.colorMap().getLookupTable(0, 1, 255) + if shuffle: + np.random.shuffle(self.lut) + + # Insert background color + if "labels_bkgrColor" in self.df_settings.index: + rgbString = self.df_settings.at["labels_bkgrColor", "value"] + try: + r, g, b = rgbString + except Exception as e: + r, g, b = colors.rgb_str_to_values(rgbString) + else: + r, g, b = 25, 25, 25 + self.df_settings.at["labels_bkgrColor", "value"] = (r, g, b) + + self.lut = np.insert(self.lut, 0, [r, g, b], axis=0) + + def setOpacityOverlayLayersItems(self, value, imageItem=None, scrollbar=None): + if scrollbar is None: + scrollbar = imageItem.alphaScrollBar + + channel = scrollbar.channelName + toolbutton = self.allOverlayToolbuttons[channel] + if not toolbutton.isChecked() or not toolbutton.isVisible(): + return + + if value is None: + value = scrollbar.value() + + if imageItem is None: + imageItem = scrollbar.imageItem + alpha = value / scrollbar.maximum() + elif value > 1: + alpha = value / scrollbar.maximum() + else: + alpha = value + + alpha_values = [] + activeOverlayImageItems = [] + for items in self.overlayLayersItems.values(): + imgItem, lutItem, alphaSB = items[:3] + _toolbutton = alphaSB.toolbutton + if alphaSB.channelName == channel: + alpha_values.append(alpha) + elif not _toolbutton.isChecked() or not _toolbutton.isVisible(): + continue + else: + alpha_values.append(alphaSB.value() / alphaSB.maximum()) + + activeOverlayImageItems.append(imgItem) + + opacities = colors.hierarchical_weights(alpha_values)[::-1] + + for i, imgItem in enumerate(activeOverlayImageItems): + imgItem.setOpacity(opacities[i + 1]) + + self.img1.setOpacity(opacities[0], applyToLinked=False) + + def setOverlayChannelsToolbuttonsChecked(self): + for channel, items in self.overlayLayersItems.items(): + _, lutItem, alphaSB, toolbutton = items[:4] + toolbutton.setChecked( + not self.overlayToolbar.isSingleChannel() + and channel in self.checkedOverlayChannels + ) + + def setOverlayColors(self): + self.overlayRGBs = [ + (255, 255, 0), + (252, 72, 254), + (49, 222, 134), + (22, 108, 27), + ] + self.overlayCmap = matplotlib.colormaps["hsv"] + self.overlayRGBs.extend( + [ + tuple([round(c * 255) for c in self.overlayCmap(i)][:3]) + for i in np.linspace(0, 1, 8) + ] + ) + + def setOverlayImages(self, frame_i=None): + if not self.overlayButton.isChecked(): + return + + posData = self.data[self.pos_i] + if posData.ol_data is None: + return + + rgba_imgs_info = {} + for filename in posData.ol_data: + chName = utils.get_chname_from_basename( + filename, posData.basename, remove_ext=False + ) + if chName not in self.checkedOverlayChannels: + continue + + items = self.overlayLayersItems[chName] + imageItem, lutItem, alphaSB = items[:3] + + ol_img = self.getOlImg(filename, frame_i=frame_i) + + if self.overlayToolbar.isTransparent(): + toolbutton = items[3] + if not toolbutton.isChecked(): + continue + alpha_val = alphaSB.value() / alphaSB.maximum() + ol_img = skimage.exposure.rescale_intensity( + ol_img, out_range=(0.0, 1.0) + ) + out_range_min, out_range_max = lutItem.getLevels() + rgba_imgs_info[chName] = (ol_img, alpha_val, lutItem) + else: + self.rescaleIntensitiesLut(setImage=False, imageItem=imageItem) + imageItem.setImage(ol_img) + + if not self.overlayToolbar.isTransparent(): + return + + alpha_values = [] + images = [] + luts = [] + for channel, info in rgba_imgs_info.items(): + ol_img, alpha_val, lutItem = info + alpha_values.append(alpha_val) + images.append(ol_img) + luts.append(lutItem.gradient.getLookupTable(256, alpha=255) / 255) + + weights = colors.hierarchical_weights(alpha_values) + + if self.baseLayerToolbutton.isChecked(): + image1 = self._getImageupdateAllImages() + image1 = skimage.exposure.rescale_intensity(image1, out_range=(0.0, 1.0)) + images.append(image1) + baseLut = self.imgGrad.gradient.getLookupTable(256, alpha=255) / 255 + luts.append(baseLut) + + images_rgba = [] + for img, lut in zip(images, luts): + rgba = colors.grayscale_apply_lut(img, lut) + images_rgba.append(rgba) + + rgba_merge = colors.hierarchical_blend(images_rgba, weights) + self.rgbaImg1.setImage(rgba_merge) + + def setOverlayItemsOpacities(self): + n_checked_buttons = sum( + [b.isChecked() for b in self.allOverlayToolbuttons.values()] + ) + + isSingleChannel = ( + self.overlayToolbar.isSingleChannel() or n_checked_buttons == 1 + ) + + channel_opacity_mapper = self.getOpacitiesFromAlphaScrollbarValues() + + # Set opacity of every layer accordingly + for channel, otherToolbutton in self.allOverlayToolbuttons.items(): + if channel == self.user_ch_name: + otherImageItem = self.img1 + alphaScrollbar = None + # alpha_value = channel_opacity_mapper[channel] + else: + otherItems = self.overlayLayersItems[channel] + otherImageItem = otherItems[0] + alphaScrollbar = otherItems[2] + # alpha_value = alphaScrollbar.value()/alphaScrollbar.maximum() + + if otherToolbutton.isChecked() and isSingleChannel: + op_val = 1.0 + elif otherToolbutton.isChecked(): + op_val = channel_opacity_mapper[channel] + else: + op_val = 0.0 + + if op_val == 0: + op_val = 0.01 + + op_val = op_val if op_val < 1.0 else 0.999 + + otherImageItem.setOpacity(op_val, applyToLinked=False) + + if alphaScrollbar is None: + continue + + alphaScrollbar.setDisabled(bool(op_val == 0)) + + def setOverlayItemsVisible(self): + for channel, items in self.overlayLayersItems.items(): + _, lutItem, alphaSB, toolbutton = items[:4] + lutItem.hide() + alphaSB.hide() + alphaSB.label.hide() + toolbutton.setVisible(False) + + if not self.overlayButton.isChecked(): + return + + for channel, items in self.overlayLayersItems.items(): + _, lutItem, alphaSB, toolbutton = items[:4] + if channel in self.checkedOverlayChannels: + lutItem.show() + alphaSB.show() + alphaSB.label.show() + toolbutton.setVisible(True) + + def setOverlayLabelsItems(self, specific=None): + if not self.overlayLabelsButton.isChecked(): + self.hideOverlayLabelsItems(specific=specific) + return + + if specific is None: + specific = self.drawModeOverlayLabelsChannels.keys() + + for segmEndname in specific: + drawMode = self.drawModeOverlayLabelsChannels[segmEndname] + ol_lab = self.getOverlayLabelsData(segmEndname) + items = self.overlayLabelsItems[segmEndname] + imageItem, contoursItem, gradItem = items + contoursItem.clear() + if drawMode == "Draw contours": + for obj in skimage.measure.regionprops(ol_lab): + contours = self.getObjContours(obj, all_external=True) + for cont in contours: + contoursItem.addPoints(cont[:, 0] + 0.5, cont[:, 1] + 0.5) + elif drawMode == "Overlay labels": + imageItem.setImage(ol_lab, autoLevels=False) + self.showOverlayLabelsItems(specific=specific) + + def setOverlayLabelsItemsVisible(self, checked): + for _segmEndname, drawMode in self.drawModeOverlayLabelsChannels.items(): + items = self.overlayLabelsItems[_segmEndname] + gradItem = items[-1] + gradItem.hide() + + if checked: + segmEndname = self.sender().text() + gradItem = self.overlayLabelsItems[segmEndname][-1] + gradItem.show() + + def setOverlaySegmMasks(self, force=False, forceIfNotActive=False): + if not hasattr(self, "currentLab2D"): + return + + how = self.drawIDsContComboBox.currentText() + isOverlaySegmLeftActive = how.find("overlay segm. masks") != -1 + + how_ax2 = self.getAnnotateHowRightImage() + isOverlaySegmRightActive = ( + how_ax2.find("overlay segm. masks") != -1 + and self.labelsGrad.showRightImgAction.isChecked() + ) + + isOverlaySegmActive = ( + isOverlaySegmLeftActive or isOverlaySegmRightActive or force + ) + if not isOverlaySegmActive and not forceIfNotActive: + return + + alpha = self.imgGrad.labelsAlphaSlider.value() + if alpha == 0: + return + + posData = self.data[self.pos_i] + maxID = max(posData.IDs, default=0) + + if maxID >= len(self.lut): + self.extendLabelsLUT(maxID + 10) + + currentLab2D = self.currentLab2D + if isOverlaySegmLeftActive: + self.labelsLayerImg1.setImage(currentLab2D, autoLevels=False) + + if isOverlaySegmRightActive: + self.labelsLayerRightImg.setImage(currentLab2D, autoLevels=False) + + def setOverlaySingleChannel(self, *args, **kwargs): + if self.overlayToolbar.isSingleChannel(): + self.overlayToolbarAreChannelsChecked = { + channel: toolbutton.isChecked() + for channel, toolbutton in self.allOverlayToolbuttons.items() + } + firstActiveToolbutton = [ + toolbutton + for toolbutton in self.allOverlayToolbuttons.values() + if toolbutton.isChecked() + ][0] + firstActiveToolbutton.click() + else: + for ch, checked in self.overlayToolbarAreChannelsChecked.items(): + toolbutton = self.allOverlayToolbuttons[ch] + toolbutton.setChecked(checked) + + self.setOverlayItemsOpacities() + + def setOverlayTransparency(self, transparent: bool): + opacity = float(transparent) + opacity = opacity if opacity < 1.0 else 0.999 + self.rgbaImg1.setOpacity(opacity) + + if transparent: + self.img1.setOpacity(0.001, applyToLinked=False) + self.imgGrad.sigLookupTableChanged.connect( + self.updateTransparentOverlayRgba + ) + self.imgGrad.sigLevelsChanged.connect(self.updateTransparentOverlayRgba) + + for channel, items in self.overlayLayersItems.items(): + imageItem, lutItem, alphaSB = items[:3] + if transparent: + alphaSB.valueChanged.disconnect() + alphaSB.valueChanged.connect(self.updateTransparentOverlayRgba) + lutItem.sigLookupTableChanged.connect(self.updateTransparentOverlayRgba) + lutItem.sigLevelsChanged.connect(self.updateTransparentOverlayRgba) + imageItem.setOpacity(0) + + if not transparent: + self.setOverlayItemsOpacities() + + self.setOverlayImages() + + def setPermanentGreedyCmapPreferences(self): + if self.isSnapshot: + option_name = "permanent_greedy_lut_snapshots" + else: + option_name = "permanent_greedy_lut_timelapse" + + if option_name not in self.df_settings.index: + return + + checked = self.df_settings.at[option_name, "value"] == "yes" + self.labelsGrad.permanentGreedyCmapAction.setChecked(checked) + + def setRetainSizePolicyLutItems(self): + if not self.retainSizeLutItems: + return + for channel, items in self.overlayLayersItems.items(): + _, lutItem, alphaSB = items[:3] + utils.setRetainSizePolicy(lutItem, retain=True) + QTimer.singleShot(300, self.autoRange) + + def setTrackedLostObjectContour(self, obj): + if self.isExportingVideo: + return + + allContours = self.getObjContours(obj, all_external=True) + for objContours in allContours: + xx = objContours[:, 0] + 0.5 + yy = objContours[:, 1] + 0.5 + data = [obj.label] * len(xx) + self.ax1_lostTrackedScatterItem.addPoints(xx, yy, data=data) + self.ax2_lostTrackedScatterItem.addPoints(xx, yy) + + def setValueLabelsAlphaSlider(self, value): + self.imgGrad.labelsAlphaSlider.setValue(value) + self.updateLabelsAlpha(value) + + def showOverlayContextMenu(self, event): + if not self.overlayButton.isChecked(): + return + + self.overlayContextMenu.exec_(QCursor.pos()) + + def showOverlayLabelsContextMenu(self, event): + if not self.overlayLabelsButton.isChecked(): + return + + self.overlayLabelsContextMenu.exec_(QCursor.pos()) + + def showOverlayLabelsItems(self, specific=None): + if specific is None: + specific = self.overlayLabelsItems.keys() + for segmEndname in specific: + imageItem, contoursItem, gradItem = self.overlayLabelsItems[segmEndname] + drawMode = self.drawModeOverlayLabelsChannels[segmEndname] + if drawMode == "Draw contours": + contoursItem.setVisible(True) + elif drawMode == "Overlay labels": + imageItem.setVisible(True) + gradItem.setVisible(True) + + def shuffle_cmap(self): + np.random.shuffle(self.lut[1:]) + self.initLabelsImageItems() + self.updateAllImages() + + def ticksCmapMoved(self, gradient): + pass + + def toggleOverlayColorButton(self, checked=True): + self.mousePressColorButton(None) + + def toggleTextIDsColorButton(self, checked=True): + self.textIDsColorButton.selectColor() + + def updateBkgrColor(self, button): + color = button.color().getRgb()[:3] + self.lut[0] = color + self.updateLookuptable() + + def updateContColour(self, colorButton): + color = colorButton.color().getRgb() + self.df_settings.at["contLineColor", "value"] = str(color) + self._updateContColour(color) + self.updateAllImages() + + def updateContoursImage(self, ax, delROIsIDs=None, compute=True): + imageItem = self.getContoursImageItem(ax) + if imageItem is None: + return + + if not hasattr(self, "contoursImage"): + self.initContoursImage() + else: + self.contoursImage[:] = 0 + + contours = [] + for obj in skimage.measure.regionprops(self.currentLab2D): + obj_contours = self.getObjContours( + obj, + all_external=True, + force_calc=compute, + include_internal=self.showAllContoursToggle.isChecked(), + ) + contours.extend(obj_contours) + + thickness = self.contLineWeight + color = self.contLineColor + self.setContoursImage(imageItem, contours, thickness, color) + + def updateLabelsCmap(self, gradient): + self.setLut() + self.updateLookuptable() + self.initLabelsImageItems() + + self.df_settings = self.labelsGrad.saveState(self.df_settings) + self.df_settings.to_csv(self.settings_csv_path) + + self.updateAllImages() + + def updateLookuptable(self, lenNewLut=None, delIDs=None): + posData = self.data[self.pos_i] + if lenNewLut is None: + try: + if delIDs is None: + IDs = posData.IDs + else: + # Remove IDs removed with ROI from LUT + IDs = [ID for ID in posData.IDs if ID not in delIDs] + lenNewLut = max(IDs, default=0) + 1 + except ValueError: + # Empty segmentation mask + lenNewLut = 1 + # Build a new lut to include IDs > than original len of lut + updateLevels = self.extendLabelsLUT(lenNewLut) + lut = self.lut.copy() + + try: + # lut = self.lut[:lenNewLut].copy() + for ID in posData.binnedIDs: + lut[ID] = lut[ID] * 0.2 + + for ID in posData.ripIDs: + lut[ID] = lut[ID] * 0.2 + except Exception as e: + err_str = traceback.format_exc() + print("=" * 30) + self.logger.info(err_str) + print("=" * 30) + + if updateLevels: + self.img2.setLevels([0, len(lut)]) + + if self.keepIDsButton.isChecked(): + lut = np.round(lut * 0.3).astype(np.uint8) + keptLut = np.round(lut[self.keptObjectsIDs] / 0.3).astype(np.uint8) + lut[self.keptObjectsIDs] = keptLut + + self.img2.setLookupTable(lut) + + def updateLostContoursImage(self, ax, draw=True, delROIsIDs=None): + if draw: + imageItem = self.getLostObjImageItem(ax) + if imageItem is None: + return + + if not hasattr(self, "lostObjContoursImage"): + self.initLostObjContoursImage() + else: + self.lostObjContoursImage[:] = 0 + + if delROIsIDs is None: + delROIsIDs = set() + + posData = self.data[self.pos_i] + prev_rp = posData.allData_li[posData.frame_i - 1]["regionprops"] + prev_IDs_idxs = posData.allData_li[posData.frame_i - 1]["IDs_idxs"] + if posData.whitelist is not None and posData.whitelist.whitelistIDs is not None: + whitelist = posData.whitelist.whitelistIDs[posData.frame_i - 1] + else: + whitelist = None + + contours = [] + for lostID in posData.lost_IDs: + if lostID in delROIsIDs or ( + whitelist is not None and lostID not in whitelist + ): + continue + + obj = prev_rp[prev_IDs_idxs[lostID]] + if not self.isObjVisible(obj.bbox): + continue + + obj_contours = self.getObjContours(obj, all_external=True) + + if ax == 0: + self.addLostObjsToLostObjImage(obj, lostID) + + contours.extend(obj_contours) + + if not draw: + return + + self.drawLostObjContoursImage(imageItem, contours) + + def updateLostTrackedContoursImage( + self, ax, delROIsIDs=None, tracked_lost_IDs=None + ): + imageItem = self.getLostTrackedObjImageItem(ax) + if imageItem is None: + return + + if not hasattr(self, "lostTrackedObjContoursImage"): + self.initLostTrackedObjContoursImage() + else: + self.lostTrackedObjContoursImage[:] = 0 + + if delROIsIDs is None: + delROIsIDs = set() + + posData = self.data[self.pos_i] + if tracked_lost_IDs is None: + tracked_lost_IDs = self.getTrackedLostIDs() + + prev_rp = posData.allData_li[posData.frame_i - 1]["regionprops"] + prev_IDs_idxs = posData.allData_li[posData.frame_i - 1]["IDs_idxs"] + contours = [] + for tracked_lost_ID in tracked_lost_IDs: + if tracked_lost_ID in delROIsIDs: + continue + + obj = prev_rp[prev_IDs_idxs[tracked_lost_ID]] + if not self.isObjVisible(obj.bbox): + continue + + obj_contours = self.getObjContours(obj, all_external=True) + contours.extend(obj_contours) + + self.drawLostTrackedObjContoursImage(imageItem, contours) + + def updateMothBudLineColour(self, colorButton): + color = colorButton.color().getRgb() + self.df_settings.at["mothBudLineColor", "value"] = str(color) + self._updateMothBudLineColour(color) + self.updateAllImages() + + def updateTextAnnotColor(self, button): + r, g, b = np.array(self.textIDsColorButton.color().getRgb()[:3]) + self.imgGrad.textColorButton.setColor((r, g, b)) + for items in self.overlayLayersItems.values(): + lutItem = items[1] + lutItem.textColorButton.setColor((r, g, b)) + self.gui_createTextAnnotColors(r, g, b, custom=True) + self.gui_setTextAnnotColors() + self.updateAllImages() + + def updateTextLabelsColor(self, button): + self.ax2_textColor = button.color().getRgb()[:3] + posData = self.data[self.pos_i] + if posData.rp is None: + return + + for obj in posData.rp: + self.getObjOptsSegmLabels(obj) + + def updateTransparentOverlayRgba(self, *args, **kwargs): + self.setOverlayImages() diff --git a/cellacdc/mixins/image_controls.py b/cellacdc/mixins/image_controls.py new file mode 100644 index 000000000..8f9b98c14 --- /dev/null +++ b/cellacdc/mixins/image_controls.py @@ -0,0 +1,447 @@ +"""Qt view adapter for image controls and bottom layout.""" + +from __future__ import annotations + +from qtpy.QtCore import Qt +from qtpy.QtGui import QFont +from qtpy.QtWidgets import ( + QAction, + QActionGroup, + QCheckBox, + QGridLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QMenu, + QVBoxLayout, + QWidget, +) + +from cellacdc import widgets + +_font = QFont() +_font.setPixelSize(11) + +from .frame_navigation import FrameNavigation + + +class ImageControls(FrameNavigation): + """Extracted from guiWin.""" + + def gui_createBottomWidgetsToBottomLayout(self): + # self.bottomDockWidget = QDockWidget(self) + bottomScrollArea = widgets.ScrollArea(resizeVerticalOnShow=True) + bottomScrollArea.sigLeaveEvent.connect(self.setFocusMain) + bottomWidget = QWidget() + bottomScrollAreaLayout = QVBoxLayout() + self.bottomLayout = QHBoxLayout() + self.bottomLayout.addLayout(self.quickSettingsLayout) + self.bottomLayout.addStretch(1) + self.bottomLayout.addWidget(self.img1BottomGroupbox) + self.bottomLayout.addStretch(1) + self.bottomLayout.addWidget(self.rightBottomGroupbox) + self.bottomLayout.addStretch(1) + + bottomScrollAreaLayout.addLayout(self.bottomLayout) + bottomScrollAreaLayout.addStretch(1) + + bottomWidget.setLayout(bottomScrollAreaLayout) + bottomScrollArea.setWidgetResizable(True) + bottomScrollArea.setWidget(bottomWidget) + self.bottomScrollArea = bottomScrollArea + + if "bottom_sliders_zoom_perc" in self.df_settings.index: + val = int(self.df_settings.at["bottom_sliders_zoom_perc", "value"]) + zoom_perc = val + else: + zoom_perc = 100 + self.bottomLayoutContextMenu = QMenu("Bottom layout", self) + zoomMenu = self.bottomLayoutContextMenu.addMenu("Zoom") + actions = [] + self.bottomLayoutContextMenu.zoomActionGroup = QActionGroup(zoomMenu) + for perc in np.arange(50, 151, 10): + action = QAction(f"{perc}%", zoomMenu) + action.setCheckable(True) + if perc == zoom_perc: + action.setChecked(True) + action.toggled.connect(self.zoomBottomLayoutActionTriggered) + actions.append(action) + self.bottomLayoutContextMenu.zoomActionGroup.addAction(action) + zoomMenu.addActions(actions) + resetAction = self.bottomLayoutContextMenu.addAction("Reset default height") + resetAction.triggered.connect(self.resizeGui) + retainSpaceAction = self.bottomLayoutContextMenu.addAction( + "Retain space of hidden sliders" + ) + retainSpaceAction.setCheckable(True) + if "retain_space_hidden_sliders" in self.df_settings.index: + retainSpaceChecked = ( + self.df_settings.at["retain_space_hidden_sliders", "value"] == "Yes" + ) + else: + retainSpaceChecked = True + retainSpaceAction.setChecked(retainSpaceChecked) + retainSpaceAction.toggled.connect(self.retainSpaceSlidersToggled) + self.retainSpaceSlidersAction = retainSpaceAction + self.setBottomLayoutStretch() + + def gui_createGraphicsPlots(self): + self.graphLayout = pg.GraphicsLayoutWidget() + if self.invertBwAction.isChecked(): + self.graphLayout.setBackground(graphLayoutBkgrColor) + self.titleColor = "black" + else: + self.graphLayout.setBackground(darkBkgrColor) + self.titleColor = "white" + + self.lutItemsLayout = self.graphLayout.addLayout(row=1, col=0) + # self.lutItemsLayout.setBorder('w') + + # Left plot + self.ax1 = widgets.MainPlotItem(showWelcomeText=True) + self.ax1.invertY(True) + self.ax1.setAspectLocked(True) + self.ax1.hideAxis("bottom") + self.ax1.hideAxis("left") + self.plotsCol = 1 + self.graphLayout.addItem(self.ax1, row=1, col=1) + + # Right plot + self.ax2 = widgets.MainPlotItem() + self.ax2.setAspectLocked(True) + self.ax2.invertY(True) + self.ax2.hideAxis("bottom") + self.ax2.hideAxis("left") + # self.currentFrameLabelItem = pg.LabelItem( + # color=self.titleColor, size='13px' + # ) + self.graphLayout.addItem(self.ax2, row=1, col=2) + + def gui_createImg1Widgets(self): + # Toggle contours/ID combobox + self.drawIDsContComboBoxSegmItems = [ + "Draw IDs and contours", + "Draw IDs and overlay segm. masks", + "Draw only cell cycle info", + "Draw cell cycle info and contours", + "Draw cell cycle info and overlay segm. masks", + "Draw only mother-bud lines", + "Draw only IDs", + "Draw only contours", + "Draw only overlay segm. masks", + "Draw nothing", + ] + self.drawIDsContComboBox = widgets.ComboBox() + self.drawIDsContComboBox.setFont(_font) + self.drawIDsContComboBox.addItems(self.drawIDsContComboBoxSegmItems) + self.drawIDsContComboBox.setVisible(False) + + self.annotIDsCheckbox = widgets.CheckBox( + "IDs", keyPressCallback=self.resetFocus + ) + self.annotCcaInfoCheckbox = widgets.CheckBox( + "Cell cycle info", keyPressCallback=self.resetFocus + ) + self.annotNumZslicesCheckbox = widgets.CheckBox( + "No. z-slices/object", keyPressCallback=self.resetFocus + ) + + self.annotContourCheckbox = widgets.CheckBox( + "Contours", keyPressCallback=self.resetFocus + ) + self.annotSegmMasksCheckbox = widgets.CheckBox( + "Segm. masks", keyPressCallback=self.resetFocus + ) + + self.drawMothBudLinesCheckbox = widgets.CheckBox( + "Only mother-daughter line", keyPressCallback=self.resetFocus + ) + + self.drawNothingCheckbox = widgets.CheckBox( + "Do not annotate", keyPressCallback=self.resetFocus + ) + + self.annotOptionsWidget = QWidget() + annotOptionsLayout = QHBoxLayout() + + # Show tree info checkbox + self.showTreeInfoCheckbox = widgets.CheckBox( + "Show tree info", keyPressCallback=self.resetFocus + ) + self.showTreeInfoCheckbox.setFont(_font) + sp = self.showTreeInfoCheckbox.sizePolicy() + sp.setRetainSizeWhenHidden(True) + self.showTreeInfoCheckbox.setSizePolicy(sp) + self.showTreeInfoCheckbox.hide() + + annotOptionsLayout.addWidget(self.showTreeInfoCheckbox) + annotOptionsLayout.addWidget(QLabel(" | ")) + annotOptionsLayout.addWidget(self.annotIDsCheckbox) + annotOptionsLayout.addWidget(self.annotCcaInfoCheckbox) + annotOptionsLayout.addWidget(self.drawMothBudLinesCheckbox) + annotOptionsLayout.addWidget(self.annotNumZslicesCheckbox) + annotOptionsLayout.addWidget(QLabel(" | ")) + annotOptionsLayout.addWidget(self.annotContourCheckbox) + annotOptionsLayout.addWidget(self.annotSegmMasksCheckbox) + annotOptionsLayout.addWidget(QLabel(" | ")) + annotOptionsLayout.addWidget(self.drawNothingCheckbox) + annotOptionsLayout.addWidget(self.drawIDsContComboBox) + self.annotOptionsLayout = annotOptionsLayout + + # Toggle highlight z+-1 objects combobox + self.highlightZneighObjCheckbox = widgets.CheckBox( + "Highlight objects in neighbouring z-slices", + keyPressCallback=self.resetFocus, + ) + self.highlightZneighObjCheckbox.setFont(_font) + self.highlightZneighObjCheckbox.hide() + + annotOptionsLayout.addWidget(self.highlightZneighObjCheckbox) + self.annotOptionsWidget.setLayout(annotOptionsLayout) + + # Annotations options right image + self.annotIDsCheckboxRight = widgets.CheckBox( + "IDs", keyPressCallback=self.resetFocus + ) + self.annotCcaInfoCheckboxRight = widgets.CheckBox( + "Cell cycle info", keyPressCallback=self.resetFocus + ) + self.annotNumZslicesCheckboxRight = widgets.CheckBox( + "No. z-slices/object", keyPressCallback=self.resetFocus + ) + + self.annotContourCheckboxRight = widgets.CheckBox( + "Contours", keyPressCallback=self.resetFocus + ) + self.annotSegmMasksCheckboxRight = widgets.CheckBox( + "Segm. masks", keyPressCallback=self.resetFocus + ) + + self.drawMothBudLinesCheckboxRight = widgets.CheckBox( + "Only mother-daughter line", keyPressCallback=self.resetFocus + ) + + self.drawNothingCheckboxRight = widgets.CheckBox( + "Do not annotate", keyPressCallback=self.resetFocus + ) + + self.annotOptionsWidgetRight = QWidget() + annotOptionsLayoutRight = QHBoxLayout() + + annotOptionsLayoutRight.addWidget(QLabel(" ")) + annotOptionsLayoutRight.addWidget(QLabel(" | ")) + annotOptionsLayoutRight.addWidget(self.annotIDsCheckboxRight) + annotOptionsLayoutRight.addWidget(self.annotCcaInfoCheckboxRight) + annotOptionsLayoutRight.addWidget(self.drawMothBudLinesCheckboxRight) + annotOptionsLayoutRight.addWidget(self.annotNumZslicesCheckboxRight) + annotOptionsLayoutRight.addWidget(QLabel(" | ")) + annotOptionsLayoutRight.addWidget(self.annotContourCheckboxRight) + annotOptionsLayoutRight.addWidget(self.annotSegmMasksCheckboxRight) + annotOptionsLayoutRight.addWidget(QLabel(" | ")) + annotOptionsLayoutRight.addWidget(self.drawNothingCheckboxRight) + self.annotOptionsLayoutRight = annotOptionsLayoutRight + + self.annotOptionsWidgetRight.setLayout(annotOptionsLayoutRight) + + # Frames scrollbar + self.navigateScrollBar = widgets.navigateQScrollBar(Qt.Horizontal) + self.navigateScrollBar.setDisabled(True) + self.navigateScrollBar.setMinimum(1) + self.navigateScrollBar.setMaximum(1) + self.navigateScrollBar.setToolTip( + "NOTE: The maximum frame number that can be visualized with this " + "scrollbar\n" + "is the last visited frame with the selected mode\n" + '(see "Mode" selector on the top-right).\n\n' + "If the scrollbar does not move it means that you never visited\n" + "any frame with current mode.\n\n" + 'Note that the "Viewer" mode allows you to scroll ALL frames.' + ) + t_label = QLabel("frame n. ") + t_label.setFont(_font) + self.t_label = t_label + + # z-slice scrollbars + self.zSliceScrollBar = widgets.linkedQScrollbar(Qt.Horizontal) + + self.zProjComboBox = widgets.ComboBox() + self.zProjComboBox.setFont(_font) + self.zProjComboBox.addItems( + [ + "single z-slice", + "max z-projection", + "mean z-projection", + "median z-proj.", + ] + ) + self.zProjLockViewButton = widgets.LockPushButton() + self.zProjLockViewButton.setCheckable(True) + self.zProjLockViewButton.setToolTip( + "If active, the selected z-slice view is applied to all frames" + ) + self.zProjLockViewButton.hide() + + self.switchPlaneCombobox = widgets.SwitchPlaneCombobox() + self.switchPlaneCombobox.setToolTip("Switch viewed plane") + + self.zSliceOverlay_SB = widgets.ScrollBar(Qt.Horizontal) + _z_label = QLabel("Overlay z-slice ") + _z_label.setFont(_font) + _z_label.setDisabled(True) + self.overlay_z_label = _z_label + + self.zProjOverlay_CB = widgets.ComboBox() + self.zProjOverlay_CB.setFont(_font) + self.zProjOverlay_CB.addItems( + [ + "single z-slice", + "max z-projection", + "mean z-projection", + "median z-proj.", + "same as above", + ] + ) + self.zProjOverlay_CB.setCurrentIndex(4) + self.zSliceOverlay_SB.setDisabled(True) + + self.img1BottomGroupbox = self.gui_getImg1BottomWidgets() + + def gui_createLabWidgets(self): + bottomRightLayout = QVBoxLayout() + self.rightBottomGroupbox = widgets.GroupBox( + "Annotate right image independent of left image", + keyPressCallback=self.resetFocus, + ) + self.rightBottomGroupbox.setCheckable(True) + self.rightBottomGroupbox.setChecked(False) + self.rightBottomGroupbox.hide() + + self.annotateRightHowCombobox = widgets.ComboBox() + self.annotateRightHowCombobox.setFont(_font) + self.annotateRightHowCombobox.addItems(self.drawIDsContComboBoxSegmItems) + self.annotateRightHowCombobox.setCurrentIndex( + self.drawIDsContComboBox.currentIndex() + ) + self.annotateRightHowCombobox.setVisible(False) + + self.annotOptionsLayoutRight.addWidget(self.annotateRightHowCombobox) + + self.rightImageFramesScrollbar = widgets.ScrollBarWithNumericControl( + labelText="Frame n. " + ) + self.rightImageFramesScrollbar.setVisible(False) + + bottomRightLayout.addWidget(self.annotOptionsWidgetRight) + bottomRightLayout.addWidget(self.rightImageFramesScrollbar) + bottomRightLayout.addStretch(1) + + self.rightBottomGroupbox.setLayout(bottomRightLayout) + + self.rightBottomGroupbox.toggled.connect(self.rightImageControlsToggled) + + def gui_getImg1BottomWidgets(self): + bottomLeftLayout = QGridLayout() + self.bottomLeftLayout = bottomLeftLayout + container = QGroupBox("Navigate and annotate left image") + + row = 0 + bottomLeftLayout.addWidget(self.annotOptionsWidget, row, 0, 1, 4) + # bottomLeftLayout.addWidget( + # self.drawIDsContComboBox, row, 1, 1, 2, + # alignment=Qt.AlignCenter + # ) + + # bottomLeftLayout.addWidget( + # self.showTreeInfoCheckbox, row, 0, 1, 1, + # alignment=Qt.AlignCenter + # ) + + row += 1 + navWidgetsLayout = QHBoxLayout() + self.navSpinBox = widgets.SpinBox(disableKeyPress=True) + self.navSpinBox.setMinimum(1) + self.navSpinBox.setMaximum(100) + self.navSizeLabel = QLabel("/ND") + navWidgetsLayout.addWidget(self.t_label) + navWidgetsLayout.addWidget(self.navSpinBox) + navWidgetsLayout.addWidget(self.navSizeLabel) + bottomLeftLayout.addLayout(navWidgetsLayout, row, 0, alignment=Qt.AlignRight) + bottomLeftLayout.addWidget(self.navigateScrollBar, row, 1, 1, 2) + sp = self.navigateScrollBar.sizePolicy() + sp.setRetainSizeWhenHidden(True) + self.navigateScrollBar.setSizePolicy(sp) + self.navSpinBox.connectValueChanged(self.navigateSpinboxValueChanged) + self.navSpinBox.editingFinished.connect(self.navigateSpinboxEditingFinished) + self.navSpinBox.sigUpClicked.connect(self.navigateSpinboxEditingFinished) + self.navSpinBox.sigDownClicked.connect(self.navigateSpinboxEditingFinished) + + self.lastTrackedFrameLabel = QLabel() + self.lastTrackedFrameLabel.setFont(_font) + bottomLeftLayout.addWidget(self.lastTrackedFrameLabel, row, 3) + + row += 1 + zSliceCheckboxLayout = QHBoxLayout() + self.zSliceCheckbox = QCheckBox("z-slice") + self.zSliceSpinbox = widgets.SpinBox(disableKeyPress=True) + self.zSliceSpinbox.setMinimum(1) + self.SizeZlabel = QLabel("/ND") + self.zSliceCheckbox.setToolTip( + "Activate/deactivate control of the z-slices with keyboard arrows.\n\n" + 'SHORTCUT to toggle ON/OFF: "Z" key' + ) + zSliceCheckboxLayout.addWidget(self.zSliceCheckbox) + zSliceCheckboxLayout.addWidget(self.zSliceSpinbox) + zSliceCheckboxLayout.addWidget(self.SizeZlabel) + bottomLeftLayout.addLayout( + zSliceCheckboxLayout, row, 0, alignment=Qt.AlignRight + ) + bottomLeftLayout.addWidget(self.zSliceScrollBar, row, 1, 1, 2) + bottomLeftLayout.addWidget(self.zProjComboBox, row, 3) + bottomLeftLayout.addWidget(self.zProjLockViewButton, row, 4) + bottomLeftLayout.addWidget(self.switchPlaneCombobox, row, 5) + self.zSliceSpinbox.connectValueChanged(self.onZsliceSpinboxValueChange) + self.zSliceSpinbox.editingFinished.connect(self.zSliceScrollBarReleased) + + row += 1 + bottomLeftLayout.addWidget( + self.overlay_z_label, row, 0, alignment=Qt.AlignRight + ) + bottomLeftLayout.addWidget(self.zSliceOverlay_SB, row, 1, 1, 2) + + bottomLeftLayout.addWidget(self.zProjOverlay_CB, row, 3) + + row += 1 + self.alphaScrollbarRow = row + + bottomLeftLayout.setColumnStretch(0, 0) + bottomLeftLayout.setColumnStretch(1, 3) + bottomLeftLayout.setColumnStretch(2, 0) + + container.setLayout(bottomLeftLayout) + return container + + def gui_resetBottomLayoutHeight(self): + self.h = self.defaultWidgetHeightBottomLayout + self.checkBoxesHeight = 14 + self.fontPixelSize = 11 + self.resizeSlidersArea() + + def resetFocus(self): + self.setFocusGraphics() + self.setFocusMain() + + def rightImageControlsToggled(self, checked): + if self.isDataLoading: + return + if checked: + self.annotateRightHowCombobox.setCurrentText( + self.drawIDsContComboBox.currentText() + ) + self.updateAllImages() + + def setFocusGraphics(self): + self.graphLayout.setFocus() + + def setFocusMain(self): + # on macOS with Qt6 setFocus causes crashes. Disabled for now. + return diff --git a/cellacdc/mixins/image_display.py b/cellacdc/mixins/image_display.py new file mode 100644 index 000000000..96b6407df --- /dev/null +++ b/cellacdc/mixins/image_display.py @@ -0,0 +1,1323 @@ +"""Qt view adapter for image display, LUT, and cursor workflows.""" + +from __future__ import annotations + +from functools import partial + +import numpy as np +import pyqtgraph as pg +import skimage.exposure +import skimage.measure +from qtpy.QtCore import QTimer +from qtpy.QtWidgets import QAction, QActionGroup + +from cellacdc import ( + apps, + darkBkgrColor, + disableWindow, + exception_handler, + graphLayoutBkgrColor, + utils, + settings_csv_path, +) + +from .display_decorations import DisplayDecorations + + +class ImageDisplay(DisplayDecorations): + """Extracted from guiWin.""" + + def RGBtoGray(self, R, G, B): + # see https://stackoverflow.com/questions/17615963/standard-rgb-to-grayscale-conversion + C_linear = (0.2126 * R + 0.7152 * G + 0.0722 * B) / 255 + if C_linear <= 0.0031309: + gray = 12.92 * C_linear + else: + gray = 1.055 * (C_linear) ** (1 / 2.4) - 0.055 + return gray + + def _getImageupdateAllImages(self, image=None): + if image is not None: + return image + + img = self.getImage() + return img + + def activeBrushCircleCursors(self, isHoverImg1): + if self.showMirroredCursorAction.isChecked(): + return self.ax1_BrushCircle, self.ax2_BrushCircle + + if isHoverImg1: + return (self.ax1_BrushCircle,) + else: + return (self.ax2_BrushCircle,) + + def activeEraserCircleCursors(self, isHoverImg1): + if self.showMirroredCursorAction.isChecked(): + return self.ax1_EraserCircle, self.ax2_EraserCircle + + if isHoverImg1: + return (self.ax1_EraserCircle,) + else: + return (self.ax2_EraserCircle,) + + def activeEraserXCursors(self, isHoverImg1): + if self.showMirroredCursorAction.isChecked(): + return self.ax1_EraserX, self.ax2_EraserX + + if isHoverImg1: + return (self.ax1_EraserX,) + else: + return (self.ax2_EraserX,) + + def addFontSizeActions(self, menu, slot): + fontActionGroup = QActionGroup(self) + fontActionGroup.setExclusive(True) + for fontSize in range(4, 27): + action = QAction(self) + action.setText(str(fontSize)) + action.setCheckable(True) + if fontSize == self.fontSize: + action.setChecked(True) + fontActionGroup.addAction(action) + menu.addAction(action) + action.triggered.connect(slot) + return fontActionGroup + + def autoRange(self): + if self.labelsGrad.showLabelsImgAction.isChecked(): + self.ax2.autoRange() + self.ax1.autoRange() + + def changeFontSize(self): + fontSize = self.fontSizeSpinBox.value() + if fontSize == self.fontSize: + return + + self.fontSize = fontSize + + self.df_settings.at["fontSize", "value"] = self.fontSize + self.df_settings.to_csv(self.settings_csv_path) + + self.setAllIDs() + posData = self.data[self.pos_i] + for ax in range(2): + self.textAnnot[ax].changeFontSize(self.fontSize) + if self.highLowResAction.isChecked(): + self.setAllTextAnnotations() + else: + self.updateAllImages() + + def clearCursors(self): + self.ax1_cursor.setData([], []) + self.ax2_cursor.setData([], []) + self.setHoverToolSymbolData( + [], + [], + (self.ax2_BrushCircle, self.ax1_BrushCircle), + ) + eraserCursors = ( + self.ax1_EraserCircle, + self.ax2_EraserCircle, + self.ax1_EraserX, + self.ax2_EraserX, + ) + self.setHoverToolSymbolData([], [], eraserCursors) + + def customLevelsLutChanged(self, levels, imageItem=None): + imageItem.setLevels(levels) + + def editImgProperties(self, checked=True): + posData = self.data[self.pos_i] + posData.askInputMetadata( + len(self.data), + ask_SizeT=True, + ask_TimeIncrement=True, + ask_PhysicalSizes=True, + save=True, + singlePos=True, + askSegm3D=False, + ) + if hasattr(self, "timestamp"): + self.timestamp.setSecondsPerFrame(posData.TimeIncrement) + self.updateTimestampFrame() + + if hasattr(self, "scaleBar"): + self.scaleBar.updatePhysicalLength(posData.PhysicalSizeX) + + def enableZstackWidgets(self, enabled): + if enabled: + utils.setRetainSizePolicy(self.zSliceScrollBar) + utils.setRetainSizePolicy(self.zProjComboBox) + utils.setRetainSizePolicy(self.zSliceOverlay_SB) + utils.setRetainSizePolicy(self.zProjOverlay_CB) + utils.setRetainSizePolicy(self.overlay_z_label) + self.zSliceScrollBar.setDisabled(False) + self.zProjComboBox.show() + if self.data[self.pos_i].SizeT > 1: + self.zProjLockViewButton.show() + self.zSliceScrollBar.show() + self.zSliceCheckbox.show() + self.zSliceSpinbox.show() + self.switchPlaneCombobox.show() + self.switchPlaneCombobox.setDisabled(False) + self.SizeZlabel.show() + else: + utils.setRetainSizePolicy(self.zSliceScrollBar, retain=False) + utils.setRetainSizePolicy(self.zProjComboBox, retain=False) + utils.setRetainSizePolicy(self.zSliceOverlay_SB, retain=False) + utils.setRetainSizePolicy(self.zProjOverlay_CB, retain=False) + utils.setRetainSizePolicy(self.overlay_z_label, retain=False) + self.zSliceScrollBar.setDisabled(True) + self.zProjComboBox.hide() + self.zProjComboBox.hide() + self.zSliceScrollBar.hide() + self.zSliceCheckbox.hide() + self.zSliceSpinbox.hide() + self.SizeZlabel.hide() + self.switchPlaneCombobox.hide() + self.switchPlaneCombobox.setDisabled(True) + + self.imgGrad.rescaleAcrossZstackAction.setDisabled(not enabled) + for ch, overlayItems in self.overlayLayersItems.items(): + lutItem = overlayItems[1] + lutItem.rescaleAcrossZstackAction.setDisabled(not enabled) + + def equalizeHist(self, checked=True): + self.img1.useEqualized = checked + + if not checked: + self.updateAllImages() + return + + self.logger.info("Equalizing image histogram...") + for pos_i, _posData in enumerate(self.data): + n_dim_img = _posData.img_data.ndim + _posData.equalized_img_data = preprocess.PreprocessedData() + for frame_i, img_frame in enumerate(_posData.img_data): + if n_dim_img == 4: + for z, img_z in enumerate(img_frame): + eq_img = skimage.exposure.equalize_adapthist(img_z) + _posData.equalized_img_data[frame_i][z] = eq_img + self.img1.updateMinMaxValuesEqualizedData( + self.data, pos_i, frame_i, z + ) + self.img1.updateMinMaxValuesEqualizedDataProjections( + self.data, pos_i, frame_i + ) + else: + eq_img = skimage.exposure.equalize_adapthist(img_frame) + _posData.equalized_img_data[frame_i] = eq_img + self.img1.updateMinMaxValuesEqualizedData( + self.data, pos_i, frame_i, None + ) + + self.updateAllImages() + + def getCheckNormAction(self): + normalize = False + how = "" + for action in self.normalizeQActionGroup.actions(): + if action.isChecked(): + how = action.text() + normalize = True + break + return action, normalize, how + + def getContoursImageItem(self, ax, force=False): + if not self.areContoursRequested(ax) and not force: + return + + if ax == 0: + return self.ax1_contoursImageItem + else: + return self.ax2_contoursImageItem + + def getDisplayedImg1(self): + return self.img1.image + + def getDisplayedZstack(self): + posData = self.data[self.pos_i] + return posData.img_data[posData.frame_i] + + def getDistantGray(self, desiredGray, bkgrGray): + isDesiredSimilarToBkgr = abs(desiredGray - bkgrGray) < 0.3 + if isDesiredSimilarToBkgr: + return 1 - desiredGray + else: + return desiredGray + + def getImage(self, frame_i=None, raw=False): + posData = self.data[self.pos_i] + if frame_i is None: + frame_i = posData.frame_i + + if raw: + return self.getRawImageLayer0(frame_i) + + if self.viewPreprocDataToggle.isChecked(): + try: + img = posData.preproc_img_data[frame_i] + if posData.SizeZ == 1: + return np.array(img) + + self.updateZsliceScrollbar(frame_i) + z_slice = self.z_slice_index() + img = img[z_slice] + return img + except Exception as err: + # self.logger.warning( + # 'Pre-processed image not existing --> returning raw image' + # ) + return self.getRawImageLayer0(frame_i) + + viewCombinedImageData = ( + self.viewCombineChannelDataToggle.isChecked() + and self.combineDialog is not None + and not self.combineDialog.saveAsSegm() + ) + + if viewCombinedImageData: + try: + img = posData.combine_img_data[frame_i] + if posData.SizeZ == 1: + return np.array(img) + + self.updateZsliceScrollbar(frame_i) + z_slice = self.z_slice_index() + img = img[z_slice] + return img + except Exception as err: + # self.logger.warning( + # 'combined image not existing --> returning raw image' + # ) + return self.getRawImageLayer0(frame_i) + + if self.equalizeHistPushButton.isChecked(): + img = posData.equalized_img_data[frame_i] + if posData.SizeZ == 1: + return np.array(img) + + self.updateZsliceScrollbar(frame_i) + z_slice = self.z_slice_index() + img = img[z_slice] + return img + + return self.getRawImageLayer0(frame_i) + + def getImageDataFromFilename(self, filename): + posData = self.data[self.pos_i] + if filename == posData.filename: + return posData.img_data[posData.frame_i] + else: + return posData.ol_data_dict.get(filename) + + def getLostObjImageItem(self, ax): + if ax == 0: + return self.ax1_lostObjImageItem + else: + return self.ax1_lostTrackedObjImageItem + + def getLostTrackedObjImageItem(self, ax): + if ax == 0: + return self.ax1_lostTrackedObjImageItem + else: + return self.ax2_lostTrackedObjImageItem + + def getObjBbox(self, obj_bbox): + if self.isSegm3D and len(obj_bbox) == 6: + obj_bbox = (obj_bbox[1], obj_bbox[2], obj_bbox[4], obj_bbox[5]) + return obj_bbox + else: + return obj_bbox + + def getObjImage(self, obj_image, obj_bbox, z_slice=None): + if self.isSegm3D and len(obj_bbox) == 6: + zProjHow = self.zProjComboBox.currentText() + isZslice = zProjHow == "single z-slice" + if not isZslice: + # required a projection + return obj_image.max(axis=0) + + min_z = obj_bbox[0] + if z_slice is None: + z_slice = self.z_lab() + if isinstance(z_slice, tuple): + z_slice = z_slice[-1] + + local_z = z_slice - min_z + try: + obi_image_2d = obj_image[local_z] + except Exception as err: + obi_image_2d = None + return obi_image_2d + else: + return obj_image + + def getObjSlice(self, obj_slice): + if self.isSegm3D: + return obj_slice[1:3] + else: + return obj_slice + + def getObject2DimageFromZ(self, z, obj): + posData = self.data[self.pos_i] + z_min = obj.bbox[0] + local_z = z - z_min + if local_z >= posData.SizeZ or local_z < 0: + return + return obj.image[local_z] + + def getObject2DsliceFromZ(self, z, obj): + posData = self.data[self.pos_i] + z_min = obj.bbox[0] + local_z = z - z_min + if local_z >= posData.SizeZ or local_z < 0: + return + return obj.image[local_z] + + def getPreComputedMinMaxZstack(self, channel: str): + if channel != self.user_ch_name: + return None + + posData = self.data[self.pos_i] + zstack_min, zstack_max = np.inf, 0 + for z in range(posData.SizeZ): + key = (self.pos_i, posData.frame_i, z) + levels = self.img1.minMaxValuesMapper.get(key) + if levels is None: + return + + img_min, img_max = levels + if img_min < zstack_min: + zstack_min = img_min + + if img_max > zstack_max: + zstack_max = img_max + + return (zstack_min, zstack_max) + + def getRawImage(self, frame_i=None, filename=None): + posData = self.data[self.pos_i] + if frame_i is None: + frame_i = posData.frame_i + if filename is None: + rawImgData = posData.img_data[frame_i] + isLayer0 = True + else: + rawImgData = posData.ol_data[filename][frame_i] + isLayer0 = False + if posData.SizeZ > 1: + rawImg = self.get_2Dimg_from_3D(rawImgData, isLayer0=isLayer0) + else: + rawImg = rawImgData + return rawImg + + def getRawImageLayer0(self, frame_i): + posData = self.data[self.pos_i] + + if posData.SizeZ > 1: + img = posData.img_data[frame_i] + self.updateZsliceScrollbar(frame_i) + img = self.get_2Dimg_from_3D(img) + else: + img = posData.img_data[frame_i].copy() + + if img.ndim == 2: + return img + if img.ndim == 3 and img.shape[-1] in (3, 4): + return img + + raise ValueError( + "Raw image for display must be 2D (Y, X) or RGB/A (Y, X, 3 or 4); " + f"got shape={getattr(img, 'shape', None)}, ndim={getattr(img, 'ndim', None)} " + f"for frame_i={frame_i} (metadata SizeT={posData.SizeT}, SizeZ={posData.SizeZ}). " + "Check that metadata SizeT/SizeZ matches the loaded array (e.g. squeezed TIFF vs CSV)." + ) + + def get_2Dimg_from_3D(self, imgData, isLayer0=True, frame_i=None): + posData = self.data[self.pos_i] + if frame_i is None: + frame_i = posData.frame_i + if frame_i < 0: + frame_i = 0 + frame_i = posData.frame_i = 0 + + axis_slice = self.zSliceScrollBar.sliderPosition() + if self.switchPlaneCombobox.depthAxes() == "x": + return imgData[:, :, axis_slice].copy() + elif self.switchPlaneCombobox.depthAxes() == "y": + return imgData[:, axis_slice].copy() + + idx = (posData.filename, frame_i) + zProjHow_L0 = self.zProjComboBox.currentText() + if isLayer0: + try: + z = posData.segmInfo_df.at[idx, "z_slice_used_gui"] + except ValueError as e: + z = posData.segmInfo_df.loc[idx, "z_slice_used_gui"].iloc[0] + zProjHow = zProjHow_L0 + else: + z = self.zSliceOverlay_SB.sliderPosition() + zProjHow_L1 = self.zProjOverlay_CB.currentText() + if zProjHow_L1 == "same as above": + zProjHow = zProjHow_L0 + else: + zProjHow = zProjHow_L1 + + if zProjHow == "single z-slice": + img = imgData[z] # .copy() + elif zProjHow == "max z-projection": + img = imgData.max(axis=0) + elif zProjHow == "mean z-projection": + img = imgData.mean(axis=0) + elif zProjHow == "median z-proj.": + img = np.median(imgData, axis=0) + return img + + def get_2Dlab(self, lab, force_z=True): + if self.isSegm3D: + if force_z: + return lab[self.z_lab()] + zProjHow = self.zProjComboBox.currentText() + isZslice = zProjHow == "single z-slice" + if isZslice: + return lab[self.z_lab()] + else: + return lab.max(axis=0) + else: + return lab + + def get_2Drp(self, lab=None): + if self.isSegm3D: + if lab is None: + # self.currentLab2D is defined at self.setImageImg2() + lab = self.currentLab2D + lab = self.get_2Dlab(lab) + rp = skimage.measure.regionprops(lab) + return rp + else: + return self.data[self.pos_i].rp + + def initContoursImage(self): + posData = self.data[self.pos_i] + z_slice = self.z_lab() + img = posData.img_data[posData.frame_i] + Y, X = img[z_slice].shape[-2:] + + self.contoursImage = np.zeros((Y, X, 4), dtype=np.uint8) + + def initImgCmap(self): + if not "img_cmap" in self.df_settings.index: + self.df_settings.at["img_cmap", "value"] = "grey" + self.imgCmapName = self.df_settings.at["img_cmap", "value"] + self.imgCmap = self.imgGrad.cmaps[self.imgCmapName] + if self.imgCmapName != "grey": + # To ensure mapping to colors we need to normalize image + self.normalizeByMaxAction.setChecked(True) + + def initImgGradRescaleIntensitiesHowPreference(self): + posData = self.data[self.pos_i] + channelName = posData.user_ch_name + if f"how_rescale_intensities_{channelName}" not in self.df_settings.index: + return + + how = self.df_settings.at[f"how_rescale_intensities_{channelName}", "value"] + self.imgGrad.setRescaleIntensitiesHow(how) + + def initLostObjContoursImage(self): + posData = self.data[self.pos_i] + z_slice = self.z_lab() + img = posData.img_data[posData.frame_i] + Y, X = img[z_slice].shape[-2:] + + self.lostObjContoursImage = np.zeros((Y, X, 4), dtype=np.uint8) + + def initLostTrackedObjContoursImage(self): + posData = self.data[self.pos_i] + z_slice = self.z_lab() + img = posData.img_data[posData.frame_i] + Y, X = img[z_slice].shape[-2:] + + self.lostTrackedObjContoursImage = np.zeros((Y, X, 4), dtype=np.uint8) + + def initManualBackgroundImage(self): + posData = self.data[self.pos_i] + if hasattr(posData, "lab"): + Y, X = posData.lab.shape[-2:] + else: + Y, X = posData.img_data.shape[-2:] + if not hasattr(self, "manualBackgroundTextItems"): + self.manualBackgroundTextItems = {} + posData.manualBackgroundImage = np.zeros((Y, X, 4), dtype=np.uint8) + if posData.manualBackgroundLab is None: + posData.manualBackgroundLab = np.zeros((Y, X), dtype=np.uint32) + + def initTextAnnot(self, force=False): + posData = self.data[self.pos_i] + if hasattr(posData, "lab"): + Y, X = posData.lab.shape[-2:] + else: + Y, X = posData.img_data.shape[-2:] + self.textAnnot[0].initItem((Y, X)) + self.textAnnot[1].initItem((Y, X)) + + def invertBw(self, checked, update=True): + self.invertBwAlreadyCalledOnce = True + + try: + self.labelsGrad.invertBwAction.toggled.disconnect() + except Exception as err: + pass + + self.labelsGrad.invertBwAction.setChecked(checked) + self.labelsGrad.invertBwAction.toggled.connect(self.setCheckedInvertBW) + + try: + self.imgGrad.invertBwAction.toggled.disconnect() + except Exception as err: + pass + self.imgGrad.invertBwAction.setChecked(checked) + self.imgGrad.invertBwAction.toggled.connect(self.setCheckedInvertBW) + + self.imgGrad.setInvertedColorMaps(checked) + self.imgGrad.invertCurrentColormap(checked) + + self.imgGradRight.setInvertedColorMaps(checked) + self.imgGradRight.invertCurrentColormap(checked) + + if hasattr(self, "overlayLayersItems"): + for items in self.overlayLayersItems.values(): + lutItem = items[1] + lutItem.invertBwAction.toggled.disconnect() + lutItem.invertBwAction.setChecked(checked) + lutItem.invertBwAction.toggled.connect(self.setCheckedInvertBW) + lutItem.setInvertedColorMaps(checked) + + if self.slideshowWin is not None: + self.slideshowWin.is_bw_inverted = checked + self.slideshowWin.update_img() + self.df_settings.at["is_bw_inverted", "value"] = "Yes" if checked else "No" + self.df_settings.to_csv(self.settings_csv_path) + if checked: + # Light mode + self.equalizeHistPushButton.setStyleSheet("") + self.graphLayout.setBackground(graphLayoutBkgrColor) + self.ax2_BrushCirclePen = pg.mkPen((150, 150, 150), width=2) + self.ax2_BrushCircleBrush = pg.mkBrush((200, 200, 200, 150)) + self.titleColor = "black" + else: + # Dark mode + self.equalizeHistPushButton.setStyleSheet( + "QPushButton {background-color: #282828; color: #F0F0F0;}" + ) + self.graphLayout.setBackground(darkBkgrColor) + self.ax2_BrushCirclePen = pg.mkPen(width=2) + self.ax2_BrushCircleBrush = pg.mkBrush((255, 255, 255, 50)) + self.titleColor = "white" + + if not hasattr(self, "textAnnot"): + return + + self.textAnnot[0].invertBlackAndWhite() + self.textAnnot[1].invertBlackAndWhite() + + self.objLabelAnnotRgb = tuple(self.textAnnot[0].item.colors()["label"][:3]) + self.textIDsColorButton.setColor(self.objLabelAnnotRgb) + self.imgGrad.textColorButton.setColor(self.objLabelAnnotRgb) + for items in self.overlayLayersItems.values(): + lutItem = items[1] + lutItem.textColorButton.setColor(self.objLabelAnnotRgb) + + if update: + self.updateAllImages() + + def isObjVisible(self, obj_bbox, debug=False, z_slice=None): + if z_slice is None: + z_slice = self.z_lab() + + if self.isSegm3D: + zProjHow = self.zProjComboBox.currentText() + isZslice = zProjHow == "single z-slice" + if not isZslice: + # required a projection --> all obj are visible + return True + + depthAxes = self.switchPlaneCombobox.depthAxes() + + min_z, min_y, min_x, max_z, max_y, max_x = obj_bbox + if depthAxes == "z": + min_val, max_val = min_z, max_z + val = z_slice + elif depthAxes == "y": + min_val, max_val = min_y, max_y + val = z_slice[-1] + else: + min_val, max_val = min_x, max_x + val = z_slice[-1] + + if val >= min_val and val < max_val: + return True + else: + return False + else: + return True + + def launchSlideshow(self): + posData = self.data[self.pos_i] + self.determineSlideshowWinPos() + if self.slideshowButton.isChecked(): + self.slideshowWin = apps.imageViewer( + parent=self, + button_toUncheck=self.slideshowButton, + linkWindow=posData.SizeT > 1, + enableOverlay=True, + enableMirroredCursor=True, + ) + self.slideshowWin.img.minMaxValuesMapper = self.img1.minMaxValuesMapper + self.slideshowWin.img.setCurrentPosIndex(self.pos_i) + h = self.drawIDsContComboBox.size().height() + self.slideshowWin.framesScrollBar.setFixedHeight(h) + self.slideshowWin.overlayButton.setChecked(self.overlayButton.isChecked()) + self.slideshowWin.sigHoveringImage.connect( + self.setMirroredCursorFromSecondWindow + ) + if posData.SizeZ > 1: + z_slice = self.zSliceScrollBar.sliderPosition() + self.slideshowWin.img.setCurrentZsliceIndex(z_slice) + self.slideshowWin.zSliceScrollBar.setSliderPosition(z_slice) + self.slideshowWin.z_label.setText( + f"z-slice {z_slice + 1:02}/{posData.SizeZ}" + ) + self.slideshowWin.update_img() + self.slideshowWin.show(left=self.slideshowWinLeft, top=self.slideshowWinTop) + else: + self.slideshowWin.close() + self.slideshowWin = None + + def normaliseIntensitiesActionTriggered(self, action): + how = action.text() + self.df_settings.at["how_normIntensities", "value"] = how + self.df_settings.to_csv(self.settings_csv_path) + self.updateAllImages() + self.updateImageValueFormatter() + + def normalizeIntensities(self, img): + action, normalize, how = self.getCheckNormAction() + if not normalize: + return img + + if how == "Do not normalize. Display raw image": + img = img + elif how == "Convert to floating point format with values [0, 1]": + img = utils.img_to_float(img) + # elif how == 'Rescale to 8-bit unsigned integer format with values [0, 255]': + # img = skimage.img_as_float(img) + # img = (img*255).astype(np.uint8) + # return img + elif how == "Rescale to [0, 1]": + img = skimage.img_as_float(img) + img = skimage.exposure.rescale_intensity(img) + elif how == "Normalize by max value": + img = img / np.max(img) + return img + + def removeAxLimits(self): + self.ax1.vb.state["limits"]["xLimits"] = [-1e307, +1e307] + self.ax1.vb.state["limits"]["yLimits"] = [-1e307, +1e307] + + def rescaleIntensExportToVideoDialog(self, how, channel, setImage=True): + if channel == self.user_ch_name: + lutItem = self.imgGrad + else: + lutItem = self.overlayLayersItems[channel][1] + + for action in lutItem.rescaleActionGroup.actions(): + if action.text() == how: + action.trigger() + # self.rescaleIntensitiesLut(setImage=setImage) + break + + def rescaleIntensitiesLut( + self, action: QAction = None, setImage: bool = True, imageItem=None + ): + if not self.isDataLoaded: + self.logger.info( + "WARNING: Data is not loaded. Intensities will be rescaled later." + ) + return + + posData = self.data[self.pos_i] + if imageItem is None: + imageItem = self.img1 + channel = self.user_ch_name + image_data = posData.img_data + else: + channel = imageItem.channelName + _, filename = self.getPathFromChName(channel, posData) + image_data = posData.fluo_data_dict[filename] + + triggeredByUser = True + if action is None: + triggeredByUser = False + action = imageItem.lutItem.rescaleActionGroup.checkedAction() + + how = action.text() + + self.df_settings.at[f"how_rescale_intensities_{channel}", "value"] = how + self.df_settings.to_csv(self.settings_csv_path) + + if how == "Rescale each 2D image": + if how == self.rescaleIntensChannelHowMapper[channel]: + # No need to update since we have autoscale + return + + imageItem.setEnableAutoLevels(True) + if setImage: + imageItem.setImage(imageItem.image) + return + + lutLevelsCh = posData.lutLevels[channel] + + if how == "Rescale across z-stack": + imageItem.setEnableAutoLevels(False) + levels_key = (how, posData.frame_i) + levels = lutLevelsCh.get(levels_key) + if levels is None: + levels = self.getPreComputedMinMaxZstack(channel) + + if levels is None: + image_zstack = image_data[posData.frame_i] + levels = (image_zstack.min(), image_zstack.max()) + lutLevelsCh[levels_key] = levels + imageItem.setLevels(levels) + elif how == "Rescale across time frames": + imageItem.setEnableAutoLevels(False) + levels_key = (how, None) + levels = lutLevelsCh.get(levels_key) + if levels is None: + levels = (image_data.min(), image_data.max()) + + lutLevelsCh[levels_key] = levels + imageItem.setLevels(levels) + elif how == "Choose custom levels...": + autoLevelsEnabledBefore = imageItem.autoLevelsEnabled + imageItem.setEnableAutoLevels(False) + if triggeredByUser: + current_min, current_max = imageItem.getLevels() + dtype_max = np.iinfo(image_data.dtype).max + max_value = image_data.max() + min_value = image_data.min() + win = apps.SetCustomLevelsLut( + init_min_value=current_min, + init_max_value=current_max, + maximum_max_value=max_value, + minimum_min_value=min_value, + parent=self, + ) + win.sigLevelsChanged.connect( + partial(self.customLevelsLutChanged, imageItem=imageItem) + ) + win.exec_() + if win.cancel: + imageItem.setEnableAutoLevels(autoLevelsEnabledBefore) + self.logger.info("Custom LUT levels setting cancelled.") + self.updateAllImages() + return + selectedLevels = win.selectedLevels + else: + selectedLevels = imageItem.getLevels() + imageItem.setLevels(selectedLevels) + elif how == "Do no rescale, display raw image": + imageItem.setEnableAutoLevels(False) + levels_key = (how, None) + levels = lutLevelsCh.get(levels_key) + if levels is None: + dtype_max = np.iinfo(image_data.dtype).max + levels = (0, dtype_max) + lutLevelsCh[levels_key] = levels + imageItem.setLevels(levels) + + self.rescaleIntensChannelHowMapper[channel] = how + + if setImage: + imageItem.setImage(imageItem.image) + + def resetRange(self): + if self.ax1_viewRange is None: + return + xRange, yRange = self.ax1_viewRange + if self.labelsGrad.showLabelsImgAction.isChecked(): + self.ax2.vb.setRange(xRange=xRange, yRange=yRange) + self.ax1.vb.setRange(xRange=xRange, yRange=yRange) + self.ax1_viewRange = None + self.isRangeReset = True + + def resizeGui(self): + self.ax1.vb.state["limits"]["xRange"] = [None, None] + self.ax1.vb.state["limits"]["yRange"] = [None, None] + self.autoRange() + if self.ax1.getViewBox().state["limits"]["xRange"][0] is not None: + self.bottomScrollArea._resizeVertical() + return + (xmin, xmax), (ymin, ymax) = self.ax1.viewRange() + maxYRange = int((ymax - ymin) * 1.5) + maxXRange = int((xmax - xmin) * 1.5) + self.ax1.setLimits(maxYRange=maxYRange, maxXRange=maxXRange) + self.bottomScrollArea._resizeVertical() + QTimer.singleShot(200, self.autoRange) + + def ruler_cb(self, checked): + if checked: + self.disconnectLeftClickButtons() + self.uncheckLeftClickButtons(self.sender()) + self.connectLeftClickButtons() + else: + self.tempSegmentON = False + self.ax1_rulerPlotItem.setData([], []) + self.ax1_rulerAnchorsItem.setData([], []) + + def saveLabelsColormap(self): + self.labelsGrad.saveColormap() + + def setAnnotOptionsRightImageLabelsDisabled(self, disabled): + self.annotContourCheckboxRight.setDisabled(disabled) + self.annotSegmMasksCheckboxRight.setDisabled(disabled) + if disabled: + self.annotSegmMasksCheckboxRight.setChecked(False) + self.annotSegmMasksCheckboxRight.setChecked(False) + self.annotIDsCheckboxRight.setChecked(True) + + def setBottomLayoutStretch(self): + if ( + self.labelsGrad.showRightImgAction.isChecked() + or self.labelsGrad.showNextFrameAction.isChecked() + ): + # Equally share space between the two control groupboxes + self.bottomLayout.setStretch(1, 1) + self.bottomLayout.setStretch(2, 5) + self.bottomLayout.setStretch(3, 1) + self.bottomLayout.setStretch(4, 5) + self.bottomLayout.setStretch(5, 1) + elif self.labelsGrad.showLabelsImgAction.isChecked(): + # Left control takes only left space + self.bottomLayout.setStretch(1, 1) + self.bottomLayout.setStretch(2, 5) + self.bottomLayout.setStretch(3, 5) + self.bottomLayout.setStretch(4, 1) + self.bottomLayout.setStretch(5, 1) + else: + # Left control takes all the space + self.bottomLayout.setStretch(1, 3) + self.bottomLayout.setStretch(2, 10) + self.bottomLayout.setStretch(3, 1) + self.bottomLayout.setStretch(4, 1) + self.bottomLayout.setStretch(5, 1) + + def setCheckedInvertBW(self, checked): + self.invertBwAction.setChecked(checked) + + def setGraphicalAnnotZsliceScrolling(self): + posData = self.data[self.pos_i] + if self.isSegm3D: + self.currentLab2D = posData.lab[self.z_lab()] + self.setOverlaySegmMasks() + self.doCustomAnnotation(0) + self.update_rp_metadata() + else: + self.currentLab2D = posData.lab + self.setOverlaySegmMasks() + self.updateContoursImage(0) + self.updateContoursImage(1) + + def setHoverToolSymbolData(self, xx, yy, ScatterItems, size=None): + if not xx: + self.ax1_lostObjScatterItem.setVisible(True) + self.ax2_lostObjScatterItem.setVisible(True) + + self.ax1_lostTrackedScatterItem.setVisible(True) + self.ax2_lostTrackedScatterItem.setVisible(True) + + for item in ScatterItems: + if size is None: + item.setData(xx, yy) + else: + item.setData(xx, yy, size=size) + + def setImageImg1(self, image=None): + img = self._getImageupdateAllImages(image=image) + posData = self.data[self.pos_i] + self.img1.setCurrentPosIndex(self.pos_i) + self.img1.setCurrentFrameIndex(posData.frame_i) + if posData.SizeZ > 1: + zProjHow = self.zProjComboBox.currentText() + if zProjHow == "single z-slice": + z = self.zSliceScrollBar.sliderPosition() + else: + z = zProjHow + + self.img1.setCurrentZsliceIndex(z) + + self.img1.setImage( + img, + next_frame_image=self.nextFrameImage(), + scrollbar_value=posData.frame_i + 2, + ) + + def setImageImg2(self, updateLookuptable=True, set_image=True): + posData = self.data[self.pos_i] + mode = str(self.modeComboBox.currentText()) + if mode == "Segmentation and Tracking" or self.isSnapshot: + # self.addExistingDelROIs() + allDelIDs, lab2D = self.getDelROIlab() + else: + lab2D = self.get_2Dlab(posData.lab, force_z=False) + allDelIDs = set() + + self.currentLab2D = lab2D + if self.labelsGrad.permanentGreedyCmapAction.isChecked() and updateLookuptable: + self.greedyShuffleCmap(updateImages=False) + + if self.labelsGrad.showLabelsImgAction.isChecked() and set_image: + self.img2.setImage(lab2D, z=self.z_lab(), autoLevels=False) + + if updateLookuptable: + self.updateLookuptable(delIDs=allDelIDs) + + def setLastUserNormAction(self): + how = self.df_settings.at["how_normIntensities", "value"] + for action in self.normalizeQActionGroup.actions(): + if action.text() == how: + action.setChecked(True) + break + + def setMirroredCursorFromSecondWindow(self, x, y): + if x is None: + xx, yy = [], [] + else: + xx, yy = [x], [y] + self.ax1_cursor.setData(xx, yy) + if not self.isTwoImageLayout: + return + self.ax2_cursor.setData(xx, yy) + + def setTextAnnotZsliceScrolling(self): + pass + + def setTwoImagesLayout(self, isTwoImages): + self.isTwoImageLayout = isTwoImages + if isTwoImages: + self.graphLayout.removeItem(self.titleLabel) + self.graphLayout.addItem(self.titleLabel, row=0, col=1, colspan=2) + # self.mainLayout.setAlignment(self.bottomLayout, Qt.AlignLeft) + self.ax2.show() + self.ax2.vb.setYLink(self.ax1.vb) + self.ax2.vb.setXLink(self.ax1.vb) + else: + self.graphLayout.removeItem(self.titleLabel) + self.graphLayout.addItem(self.titleLabel, row=0, col=1) + # self.mainLayout.setAlignment(self.bottomLayout, Qt.AlignCenter) + self.ax2.hide() + oldLink = self.ax2.vb.linkedView(self.ax1.vb.YAxis) + try: + oldLink.sigYRangeChanged.disconnect() + oldLink.sigXRangeChanged.disconnect() + except TypeError: + pass + + def set_2Dlab(self, lab2D, lab3D=None): + posData = self.data[self.pos_i] + + if lab3D is None: + lab3D = posData.lab + + if self.isSegm3D: + zProjHow = self.zProjComboBox.currentText() + isZslice = zProjHow == "single z-slice" + if isZslice: + lab3D[self.z_lab()] = lab2D + else: + lab3D[:] = lab2D + else: + if lab3D.shape == lab2D.shape: + lab3D[...] = lab2D + else: + posData.lab = lab2D + + def showLabelImageItem(self, checked): + self.rightImageFramesScrollbar.setVisible(not checked) + self.rightImageFramesScrollbar.setDisabled(checked) + self.setTwoImagesLayout(checked) + self.setAnnotOptionsRightImageLabelsDisabled(checked) + if checked: + self.df_settings.at["isLabelsVisible", "value"] = "Yes" + self.df_settings.at["isNextFrameVisible", "value"] = "No" + self.df_settings.at["isRightImageVisible", "value"] = "No" + self.rightBottomGroupbox.show() + self.rightBottomGroupbox.setChecked(True) + if not self.isDataLoading: + self.updateAllImages() + else: + self.clearAx2Items() + self.img2.clear() + self.df_settings.at["isLabelsVisible", "value"] = "No" + self.rightBottomGroupbox.hide() + self.moveDelRoisToLeft() + + self.df_settings.to_csv(self.settings_csv_path) + QTimer.singleShot(200, self.resizeGui) + + self.setBottomLayoutStretch() + + def showMirroredCursorToggled(self, checked): + value = "Yes" if checked else "No" + self.df_settings.at["showMirroredCursor", "value"] = value + self.df_settings.to_csv(settings_csv_path) + + if not checked: + self.clearCursors() + + def showNextFrameImageItem(self, checked): + self.rightImageFramesScrollbar.setVisible(checked) + self.rightImageFramesScrollbar.setDisabled(not checked) + self.setTwoImagesLayout(checked) + if checked: + self.df_settings.at["isNextFrameVisible", "value"] = "Yes" + self.df_settings.at["isRightImageVisible", "value"] = "No" + self.df_settings.at["isLabelsVisible", "value"] = "No" + self.graphLayout.addItem(self.imgGradRight, row=1, col=self.plotsCol + 2) + self.rightBottomGroupbox.show() + self.rightBottomGroupbox.setChecked(True) + self.drawNothingCheckboxRight.click() + if not self.isDataLoading: + self.updateAllImages() + else: + self.clearAx2Items() + self.rightBottomGroupbox.hide() + self.df_settings.at["isNextFrameVisible", "value"] = "No" + try: + self.graphLayout.removeItem(self.imgGradRight) + except Exception: + return + self.rightImageItem.clear() + + self.df_settings.to_csv(self.settings_csv_path) + + QTimer.singleShot(300, self.resizeGui) + + self.setBottomLayoutStretch() + + def showRightImageItem(self, checked): + self.rightImageFramesScrollbar.setVisible(not checked) + self.rightImageFramesScrollbar.setDisabled(checked) + self.setTwoImagesLayout(checked) + if checked: + self.df_settings.at["isRightImageVisible", "value"] = "Yes" + self.df_settings.at["isNextFrameVisible", "value"] = "No" + self.df_settings.at["isLabelsVisible", "value"] = "No" + self.graphLayout.addItem(self.imgGradRight, row=1, col=self.plotsCol + 2) + self.rightBottomGroupbox.show() + if not self.isDataLoading: + self.updateAllImages() + else: + self.clearAx2Items() + self.rightBottomGroupbox.hide() + self.df_settings.at["isRightImageVisible", "value"] = "No" + try: + self.graphLayout.removeItem(self.imgGradRight) + except Exception: + return + self.rightImageItem.clear() + + self.df_settings.to_csv(self.settings_csv_path) + + QTimer.singleShot(300, self.resizeGui) + + self.setBottomLayoutStretch() + + def updateAllImages( + self, + image=None, + computePointsLayers=True, + computeContours=True, + updateLookuptable=True, + ): + self.clearAllItems() + + posData = self.data[self.pos_i] + + self.last_pos_i = self.pos_i + self.last_frame_i = posData.frame_i + + self.rescaleIntensitiesLut(setImage=False) + + self.setImageImg1(image=image) + self.setImageImg2(updateLookuptable=updateLookuptable) + + self.setOverlayImages() + + self.setOverlayLabelsItems() + self.setOverlaySegmMasks() + + if self.slideshowWin is not None: + self.slideshowWin.frame_i = posData.frame_i + self.slideshowWin.update_img() + + # self.update_rp() + + # Annotate ID and draw contours + delROIsIDs = self.setAllTextAnnotations() + self.setAllContoursImages(delROIsIDs=delROIsIDs, compute=False) + + mode = self.modeComboBox.currentText() + self.drawAllMothBudLines() + if mode == "Normal division: Lineage tree": + self.drawAllLineageTreeLines() + + self.highlightLostNew() + + if self.ccaTableWin is not None: # need to add for lin tree, later + zoomIDs = self.getZoomIDs() + self.ccaTableWin.updateTable(posData.cca_df, IDs=zoomIDs) + + self.doCustomAnnotation(0) + + self.annotate_rip_and_bin_IDs() + self.updateTempLayerKeepIDs() + self.whitelistUpdateTempLayer() + self.drawPointsLayers(computePointsLayers=computePointsLayers) + self.setManualBackgroundImage() + self.annotateAssignedObjsAcdcTrackerSecondStep() + + self.highlightSearchedID(self.highlightedID, force=True) + self.updateTimestampFrame() + + posData.visited = True + + def updateImageValueFormatter(self): + if self.img1.image is not None: + dtype = self.img1.image.dtype + n_digits = len(str(int(self.img1.image.max()))) + self.imgValueFormatter = utils.get_number_fstring_formatter( + dtype, precision=abs(n_digits - 5) + ) + + rawImgData = self.data[self.pos_i].img_data + dtype = rawImgData.dtype + n_digits = len(str(int(rawImgData.max()))) + self.rawValueFormatter = utils.get_number_fstring_formatter( + dtype, precision=abs(n_digits - 5) + ) + + def updateLabelsAlpha(self, value): + self.df_settings.at["overlaySegmMasksAlpha", "value"] = value + self.df_settings.to_csv(self.settings_csv_path) + if self.keepIDsButton.isChecked(): + value = value / 3 + self.labelsLayerImg1.setOpacity(value) + self.labelsLayerRightImg.setOpacity(value) + + def updateZsliceScrollbar(self, frame_i): + posData = self.data[self.pos_i] + if self.switchPlaneCombobox.depthAxes() != "z": + return + + idx = (posData.filename, frame_i) + try: + z = posData.segmInfo_df.at[idx, "z_slice_used_gui"] + except ValueError as e: + z = posData.segmInfo_df.loc[idx, "z_slice_used_gui"].iloc[0] + try: + zProjHow = posData.segmInfo_df.at[idx, "which_z_proj_gui"] + except ValueError as e: + zProjHow = posData.segmInfo_df.loc[idx, "which_z_proj_gui"].iloc[0] + + self.zProjComboBox.setCurrentText(zProjHow) + + reconnect = False + try: + self.zSliceScrollBar.actionTriggered.disconnect() + self.zSliceScrollBar.sliderReleased.disconnect() + reconnect = True + except TypeError: + pass + self.zSliceScrollBar.setSliderPosition(z) + if reconnect: + self.zSliceScrollBar.actionTriggered.connect( + self.zSliceScrollBarActionTriggered + ) + self.zSliceScrollBar.sliderReleased.connect(self.zSliceScrollBarReleased) + self.zSliceSpinbox.setValueNoEmit(z + 1) + + def zProjLockViewToggled(self, checked): + self.updateZproj(self.zProjComboBox.currentText()) + + def z_lab(self, checkIfProj=False): + if checkIfProj and self.zProjComboBox.currentText() != "single z-slice": + return + + if not self.isSegm3D: + return + + posData = self.data[self.pos_i] + + idx = self.zSliceScrollBar.sliderPosition() + + # ensure idx doesnt exceed the number of z-slices of the position + idx_z = min(idx, posData.SizeZ - 1) + + if not self.switchPlaneCombobox.isEnabled(): + return idx_z + + depthAxes = self.switchPlaneCombobox.depthAxes() + if depthAxes == "z": + return idx_z + elif depthAxes == "y": + idx_y = min(idx, posData.SizeY - 1) + return (slice(None), idx_y) + else: + idx_x = min(idx, posData.SizeX - 1) + return (slice(None), slice(None), idx_x) + + def z_slice_index(self): + posData = self.data[self.pos_i] + if posData.SizeZ == 1: + return None + zProjHow = self.zProjComboBox.currentText() + if zProjHow != "single z-slice": + return zProjHow + + axis_slice = self.zSliceScrollBar.sliderPosition() + if self.switchPlaneCombobox.depthAxes() == "x": + z_slice = (slice(None, None, None), slice(None, None, None), axis_slice) + elif self.switchPlaneCombobox.depthAxes() == "y": + z_slice = (slice(None, None, None), axis_slice) + else: + z_slice = axis_slice + + return z_slice + + def zoomOut(self): + self.ax1.autoRange() + + def zoomToCells(self, enforce=False): + if not self.enableAutoZoomToCellsAction.isChecked() and not enforce: + return + + posData = self.data[self.pos_i] + lab_mask = (self.currentLab2D > 0).astype(np.uint8) + rp = skimage.measure.regionprops(lab_mask) + if not rp: + Y, X = lab_mask.shape + xRange = -0.5, X + 0.5 + yRange = -0.5, Y + 0.5 + else: + obj = rp[0] + min_row, min_col, max_row, max_col = self.getObjBbox(obj.bbox) + xRange = min_col - 10, max_col + 10 + yRange = max_row + 10, min_row - 10 + + self.ax1.setRange(xRange=xRange, yRange=yRange) + + def zoomToObjsActionCallback(self): + self.zoomToCells(enforce=True) diff --git a/cellacdc/mixins/label_editing.py b/cellacdc/mixins/label_editing.py new file mode 100644 index 000000000..ac5ab9fec --- /dev/null +++ b/cellacdc/mixins/label_editing.py @@ -0,0 +1,747 @@ +"""Qt view adapter for label-editing workflows.""" + +from __future__ import annotations + +import math + +import numpy as np +import skimage.measure +from qtpy.QtCore import Qt +from qtpy.QtGui import QGuiApplication +from qtpy.QtWidgets import QAction + +from cellacdc import apps, disableWindow, exception_handler + +from .tool_activation import ToolActivation + + +class LabelEditing(ToolActivation): + """Extracted from guiWin.""" + + def _get_editID_info(self, df): + if "was_manually_edited" not in df.columns: + return [] + + if "y_centroid" not in df.columns or "x_centroid" not in df.columns: + df = self.addYXcentroidToDf(df) + + manually_edited_df = df[df["was_manually_edited"] > 0] + editID_info = [ + (row.y_centroid, row.x_centroid, row.Index) + for row in manually_edited_df.itertuples() + ] + return editID_info + + def _update_zslices_rp(self): + if not self.isSegm3D: + return + + posData = self.data[self.pos_i] + posData.zSlicesRp = {} + for z, lab2d in enumerate(posData.lab): + lab2d_rp = skimage.measure.regionprops(lab2d) + posData.zSlicesRp[z] = {obj.label: obj for obj in lab2d_rp} + + def addYXcentroidToDf(self, df): + posData = self.data[self.pos_i] + for obj in posData.rp: + y_centroid = int(self.getObjCentroid(obj.centroid)[0]) + x_centroid = int(self.getObjCentroid(obj.centroid)[1]) + df.at[obj.label, "y_centroid"] = y_centroid + df.at[obj.label, "x_centroid"] = x_centroid + return df + + def applyEditID( + self, + clickedID, + currentIDs, + oldIDnewIDMapper, + clicked_x, + clicked_y, + shift=False, + doPropagateUnvisited=False, + ): + posData = self.data[self.pos_i] + + # Ask to propagate change to all future visited frames + key = "Edit ID" + askAction = self.askHowFutureFramesActions[key] + doNotShow = not askAction.isChecked() + (UndoFutFrames, applyFutFrames, endFrame_i, doNotShowAgain) = ( + self.propagateChange( + clickedID, + key, + doNotShow, + posData.UndoFutFrames_EditID, + posData.applyFutFrames_EditID, + applyTrackingB=True, + ) + ) + + if UndoFutFrames is None: + return + + if shift and self.isSegm3D: + lab = self.get_2Dlab(posData.lab) + else: + lab = posData.lab + + # Store undo state before modifying stuff + self.storeUndoRedoStates(UndoFutFrames) + maxID = max(posData.IDs, default=0) + for old_ID, new_ID in oldIDnewIDMapper: + if new_ID in currentIDs and not self.editIDmergeIDs: + tempID = maxID + 1 + lab[lab == old_ID] = maxID + 1 + lab[lab == new_ID] = old_ID + lab[lab == tempID] = new_ID + maxID += 1 + + old_ID_idx = currentIDs.index(old_ID) + new_ID_idx = currentIDs.index(new_ID) + + # Append information for replicating the edit in tracking + # List of tuples (y, x, replacing ID) + objo = posData.rp[old_ID_idx] + yo, xo = self.getObjCentroid(objo.centroid) + objn = posData.rp[new_ID_idx] + yn, xn = self.getObjCentroid(objn.centroid) + if not math.isnan(yo) and not math.isnan(yn): + yn, xn = int(yn), int(xn) + posData.editID_info.append((yn, xn, new_ID)) + yo, xo = int(clicked_y), int(clicked_x) + posData.editID_info.append((yo, xo, old_ID)) + else: + lab[lab == old_ID] = new_ID + if new_ID > maxID: + maxID = new_ID + old_ID_idx = posData.IDs.index(old_ID) + + # Append information for replicating the edit in tracking + # List of tuples (y, x, replacing ID) + obj = posData.rp[old_ID_idx] + y, x = self.getObjCentroid(obj.centroid) + if not math.isnan(y) and not math.isnan(y): + y, x = int(y), int(x) + posData.editID_info.append((y, x, new_ID)) + + self.updateAssignedObjsAcdcTrackerSecondStep(new_ID) + + if shift and self.isSegm3D: + self.set_2Dlab(lab) + + # Update rps + self.update_rp() + + # Since we manually changed an ID we don't want to repeat tracking + self.setAllTextAnnotations() + self.highlightLostNew() + # self.checkIDsMultiContour() + + # Update colors for the edited IDs + self.updateLookuptable() + + if self.isSnapshot: + self.fixCcaDfAfterEdit("Edit ID") + self.updateAllImages() + else: + self.warnEditingWithCca_df("Edit ID", update_images=False) + + if not self.editIDbutton.findChild(QAction).isChecked(): + self.editIDbutton.setChecked(False) + + posData.disableAutoActivateViewerWindow = True + + # Perform desired action on future frames + posData.doNotShowAgain_EditID = doNotShowAgain + posData.UndoFutFrames_EditID = UndoFutFrames + posData.applyFutFrames_EditID = applyFutFrames + includeUnvisited = ( + posData.includeUnvisitedInfo["Edit ID"] or doPropagateUnvisited + ) + + if not applyFutFrames and not doPropagateUnvisited: + return + + self.changeIDfutureFrames( + endFrame_i, oldIDnewIDMapper, includeUnvisited, shift=shift + ) + + def apply_manual_edits_to_lab_if_needed(self, lab): + posData = self.data[self.pos_i] + data_frame_i = posData.allData_li[posData.frame_i] + edited_lab_dict = data_frame_i["manually_edited_lab"]["lab"] + if not edited_lab_dict: + return lab + + # zoom_slice = data_frame_i['manually_edited_lab']['zoom_slice'] + for z, lab_edited in edited_lab_dict.items(): + if not self.isSegm3D: + # lab[zoom_slice] = lab_edited + lab = lab_edited + break + + lab[z] = lab_edited + + # lab[z, zoom_slice[0], zoom_slice[1]] = zoom_lab + + return lab + + def assignNewIDfromClickedID(self, clickedID: int, event: QGraphicsSceneMouseEvent): + posData = self.data[self.pos_i] + x, y = event.pos().x(), event.pos().y() + newID = self.setBrushID(return_val=True) + mapper = [(clickedID, newID)] + self.applyEditID(clickedID, posData.IDs.copy(), mapper, x, y) + + def changeIDfutureFrames( + self, endFrame_i, oldIDnewIDMapper, includeUnvisited, shift=False + ): + posData = self.data[self.pos_i] + self.current_frame_i = posData.frame_i + + # Store data for current frame + self.store_data() + if endFrame_i is None: + self.app.restoreOverrideCursor() + return + + segmSizeT = len(posData.segm_data) + for i in range(posData.frame_i + 1, segmSizeT): + lab = posData.allData_li[i]["labels"] + if lab is None and not includeUnvisited: + self.enqAutosave() + break + + if lab is not None: + # Visited frame + posData.frame_i = i + self.get_data(lin_tree_init=False) + if shift and self.isSegm3D: + lab = self.get_2Dlab(posData.lab) + else: + lab = posData.lab + + if self.onlyTracking: + self.tracking(enforce=True) + elif not posData.IDs: + continue + else: + maxID = max(posData.IDs, default=0) + 1 + for old_ID, new_ID in oldIDnewIDMapper: + if new_ID in lab: + tempID = maxID + 1 # lab.max() + 1 + lab[lab == old_ID] = tempID + lab[lab == new_ID] = old_ID + lab[lab == tempID] = new_ID + maxID += 1 + else: + lab[lab == old_ID] = new_ID + + if shift and self.isSegm3D: + self.set_2Dlab(lab) + + self.update_rp(draw=False) + self.store_data(autosave=i == endFrame_i) + elif includeUnvisited: + # Unvisited frame (includeUnvisited = True) + lab = posData.segm_data[i] + if shift and self.isSegm3D: + lab = self.get_2Dlab(lab) + else: + lab = lab + + for old_ID, new_ID in oldIDnewIDMapper: + if new_ID in lab: + tempID = lab.max() + 1 + lab[lab == old_ID] = tempID + lab[lab == new_ID] = old_ID + lab[lab == tempID] = new_ID + else: + lab[lab == old_ID] = new_ID + + if shift and self.isSegm3D: + posData.segm_data[i][self.z_lab()] = lab + + # Back to current frame + posData.frame_i = self.current_frame_i + self.get_data() + self.app.restoreOverrideCursor() + + def delBorderObj(self, checked): + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + + posData = self.data[self.pos_i] + posData.lab = skimage.segmentation.clear_border(posData.lab, buffer_size=1) + oldIDs = posData.IDs.copy() + self.update_rp() + removedIDs = [ID for ID in oldIDs if ID not in posData.IDs] + if posData.cca_df is not None: + posData.cca_df = posData.cca_df.drop(index=removedIDs) + self.store_data() + self.updateAllImages() + + def delNewObj(self, checked): + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + + posData = self.data[self.pos_i] + frame_i = posData.frame_i + + if frame_i == 0: + return + + prev_IDs = posData.allData_li[frame_i - 1]["IDs"] + curr_IDs = posData.IDs + new_IDs = list(set(curr_IDs) - set(prev_IDs)) + + lab = posData.lab + del_mask = np.isin(lab, new_IDs) + lab[del_mask] = 0 + posData.lab = lab + + self.update_rp() + + if posData.cca_df is not None: + posData.cca_df = posData.cca_df.drop(index=new_IDs) + self.store_data() + self.updateAllImages() + + def deleteIDFromLab(self, lab, delID, frame_i=None, delMask=None, shift=False): + posData = self.data[self.pos_i] + frame_i = posData.frame_i if frame_i is None else frame_i + + if shift and self.isSegm3D: + lab3D = lab + delMask3D = delMask + lab = self.get_2Dlab(lab) + if delMask is not None: + delMask = self.get_2Dlab(delMask) + rp = skimage.measure.regionprops(lab) + IDs_idxs = {obj.label: idx for idx, obj in enumerate(rp)} + else: + if frame_i == posData.frame_i: + rp = posData.rp + IDs_idxs = posData.IDs_idxs + else: + rp = posData.allData_li[frame_i]["regionprops"] + IDs_idxs = posData.allData_li[frame_i]["IDs_idxs"] + + if isinstance(delID, int): + delID = [delID] + + is_any_id_present = False + for _delID in delID: + if _delID in IDs_idxs: + is_any_id_present = True + break + + if not is_any_id_present: + return lab, delMask + + if delMask is None: + delMask = np.zeros(lab.shape, dtype=bool) + else: + delMask[:] = False + + for _delID in delID: + idx = IDs_idxs.get(_delID, None) + if idx is None: + continue + obj = rp[idx] + delMask[obj.slice][obj.image] = True + lab[delMask] = 0 + + if shift and self.isSegm3D: + self.set_2Dlab(lab, lab3D=lab3D) + lab = lab3D + if delMask3D is not None: + self.set_2Dlab(delMask, lab3D=delMask3D) + delMask = delMask3D + + return lab, delMask + + def deleteIDmiddleClick( + self, delIDs: Iterable, applyFutFrames, includeUnvisited, shift=False + ): + self.clearHighlightedID() + + posData = self.data[self.pos_i] + current_frame_i = posData.frame_i + + # Apply Delete ID to future frames if requested + if applyFutFrames: + delMask = np.zeros(posData.lab.shape, dtype=bool) + # Store current data before going to future frames + self.store_data() + segmSizeT = len(posData.segm_data) + for i in range(posData.frame_i + 1, segmSizeT): + lab = posData.allData_li[i]["labels"] + if lab is None and not includeUnvisited: + self.enqAutosave() + break + + if lab is not None: + # Visited frame + lab, _ = self.deleteIDFromLab( + lab, delIDs, frame_i=i, delMask=delMask, shift=shift + ) + + # Store change + posData.allData_li[i]["labels"] = lab + # Get the rest of the stored metadata based on the new lab + posData.frame_i = i + self.get_data() + self.store_data(autosave=False) + elif includeUnvisited: + # Unvisited frame (includeUnvisited = True) + lab = posData.segm_data[i] + lab, _ = self.deleteIDFromLab( + lab, delIDs, frame_i=i, delMask=delMask, shift=shift + ) + + # Back to current frame + if applyFutFrames: + posData.frame_i = current_frame_i + self.get_data() + + z_slice = None + if shift and self.isSegm3D: + z_slice = self.z_lab() + + posData.lab, delID_mask = self.deleteIDFromLab(posData.lab, delIDs, shift=shift) + for _delID in delIDs: + self.clearObjContour(ID=_delID, ax=0) + self.clearObjContour(ID=_delID, ax=1) + if z_slice is None: + self.removeObjectFromRp(_delID) + self.removeStoredContours(_delID, z_slice=z_slice) + + if shift and self.isSegm3D: + self.update_rp() + + self.store_data(autosave=False) + self.whitelistPropagateIDs( + IDs_to_remove=delIDs, curr_frame_only=(not applyFutFrames) + ) + return delID_mask + + def getClickedID(self, xdata, ydata, text=""): + posData = self.data[self.pos_i] + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + if ID == 0: + msg = f"You clicked on the background.\nEnter here the ID {text}" + nearest_ID = core.nearest_nonzero_2D( + self.get_2Dlab(posData.lab), xdata, ydata + ) + clickedBkgrID = apps.QLineEditDialog( + title="Clicked on background", + msg=msg, + parent=self, + allowedValues=posData.IDs, + defaultTxt=str(nearest_ID), + isInteger=True, + ) + clickedBkgrID.exec_() + if clickedBkgrID.cancel: + return + else: + ID = clickedBkgrID.EntryID + return ID + + def getHoverID(self, xdata, ydata, byPassShiftCheck=False): + if not hasattr(self, "diskMask"): + return 0 + + modifiers = QGuiApplication.keyboardModifiers() + ctrl = modifiers == Qt.ControlModifier + if byPassShiftCheck: + shift = False + else: + shift = modifiers == Qt.ShiftModifier + + if self.isPowerBrush() and not ctrl: + return 0 + + if not self.autoIDcheckbox.isChecked(): + return self.editIDspinbox.value() + + ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata) + posData = self.data[self.pos_i] + lab_2D = self.get_2Dlab(posData.lab) + ID = lab_2D[ydata, xdata] + self.isHoverZneighID = False + if self.isSegm3D: + z = self.z_lab() + SizeZ = posData.lab.shape[0] + doNotLinkThroughZ = self.brushButton.isChecked() and shift + if doNotLinkThroughZ: + if self.brushHoverCenterModeAction.isChecked() or ID > 0: + hoverID = ID + else: + masked_lab = lab_2D[ymin:ymax, xmin:xmax][diskMask] + hoverID = np.bincount(masked_lab).argmax() + else: + if z > 0: + ID_z_under = posData.lab[z - 1, ydata, xdata] + if self.brushHoverCenterModeAction.isChecked() or ID_z_under > 0: + hoverIDa = ID_z_under + else: + lab = posData.lab + masked_lab_a = lab[z - 1, ymin:ymax, xmin:xmax][diskMask] + hoverIDa = np.bincount(masked_lab_a).argmax() + else: + hoverIDa = 0 + + if self.brushHoverCenterModeAction.isChecked() or ID > 0: + hoverIDb = lab_2D[ydata, xdata] + else: + masked_lab_b = lab_2D[ymin:ymax, xmin:xmax][diskMask] + hoverIDb = np.bincount(masked_lab_b).argmax() + + if z < SizeZ - 1: + ID_z_above = posData.lab[z + 1, ydata, xdata] + if self.brushHoverCenterModeAction.isChecked() or ID_z_above > 0: + hoverIDc = ID_z_above + else: + lab = posData.lab + masked_lab_c = lab[z + 1, ymin:ymax, xmin:xmax][diskMask] + hoverIDc = np.bincount(masked_lab_c).argmax() + else: + hoverIDc = 0 + + if hoverIDa > 0: + hoverID = hoverIDa + self.isHoverZneighID = True + elif hoverIDb > 0: + hoverID = hoverIDb + elif hoverIDc > 0: + hoverID = hoverIDc + self.isHoverZneighID = True + else: + hoverID = 0 + else: + if self.brushButton.isChecked() and shift: + # Force new ID with brush and Shift + hoverID = 0 + elif self.brushHoverCenterModeAction.isChecked() or ID > 0: + hoverID = ID + else: + masked_lab = lab_2D[ymin:ymax, xmin:xmax][diskMask] + hoverID = np.bincount(masked_lab).argmax() + + self.editIDspinbox.setValue(hoverID) + + return hoverID + + def getLastHoveredID(self): + if self.xHoverImg is None: + return 0 + + xdata, ydata = int(self.xHoverImg), int(self.yHoverImg) + ID = self.currentLab2D[ydata, xdata] + return ID + + def get_zslices_rp(self): + if not self.isSegm3D: + return + + posData = self.data[self.pos_i] + self.store_zslices_rp() + posData.zSlicesRp = posData.allData_li[posData.frame_i]["z_slices_rp"] + + def isPowerBrush(self): + color = self.brushButton.palette().button().color().name() + return color == self.doublePressKeyButtonColor + + def isPowerButton(self, button): + color = button.palette().button().color().name() + return color == self.doublePressKeyButtonColor + + def isPowerEraser(self): + color = self.eraserButton.palette().button().color().name() + return color == self.doublePressKeyButtonColor + + def mergeObjs_cb(self, checked): + if not checked: + self.mergeObjsTempLine.setData([], []) + + def removeObjectFromRp(self, delID): + posData = self.data[self.pos_i] + rp = [] + IDs = [] + IDs_idxs = {} + idx = 0 + for obj in posData.rp: + if obj.label == delID: + continue + rp.append(obj) + IDs.append(obj.label) + IDs_idxs[obj.label] = idx + idx += 1 + + posData.rp = rp + posData.IDs = IDs + posData.IDs_idxs = IDs_idxs + + if not self.isSegm3D: + return + + zSlicesRp = {} + for z, zSliceRp in posData.zSlicesRp.items(): + if delID in zSliceRp: + continue + + zSlicesRp[z] = zSlicesRp + + posData.zSlicesRp = zSlicesRp + self.store_zslices_rp(force_update=True) + + def removeStoredContours(self, delID, frame_i=None, z_slice=None): + posData = self.data[self.pos_i] + + if frame_i is None: + frame_i = posData.frame_i + + dataDict = posData.allData_li[posData.frame_i] + try: + newContours = {} + for key, contours in dataDict["contours"].items(): + ID = key[0] + if ID == delID: + continue + + if z_slice is not None: + z_slice_i = key[1] + if z_slice_i != z_slice: + continue + + newContours[key] = contours + + dataDict["contours"] = newContours + except KeyError as err: + pass + + def setHoverToolSymbolColor( + self, + xdata, + ydata, + pen, + ScatterItems, + button, + brush=None, + hoverRGB=None, + ID=None, + byPassShiftCheck=False, + ): + modifiers = QGuiApplication.keyboardModifiers() + if byPassShiftCheck: + shift = False + else: + shift = modifiers == Qt.ShiftModifier + + posData = self.data[self.pos_i] + Y, X = self.get_2Dlab(posData.lab).shape + if not utils.is_in_bounds(xdata, ydata, X, Y): + return + + self.isHoverZneighID = False + if ID is None: + hoverID = self.getHoverID(xdata, ydata, byPassShiftCheck=byPassShiftCheck) + else: + hoverID = ID + + if hoverID == 0: + for item in ScatterItems: + item.setPen(pen) + item.setBrush(brush) + else: + try: + rgb = self.lut[hoverID] + rgb = rgb if hoverRGB is None else hoverRGB + rgbPen = np.clip(rgb * 1.1, 0, 255) + for item in ScatterItems: + item.setPen(*rgbPen, width=2) + item.setBrush(*rgb, 100) + except IndexError: + pass + + checkChangeID = ( + self.isHoverZneighID and not shift and self.lastHoverID != hoverID + ) + if checkChangeID: + # We are hovering an ID in z+1 or z-1 + self.restoreBrushID = hoverID + # self.changeBrushID() + + self.lastHoverID = hoverID + + def store_zslices_rp(self, force_update=False): + if not self.isSegm3D: + return + + posData = self.data[self.pos_i] + are_zslices_rp_stored = ( + posData.allData_li[posData.frame_i].get("z_slices_rp") is not None + ) + if force_update or not are_zslices_rp_stored: + self._update_zslices_rp() + + posData.allData_li[posData.frame_i]["z_slices_rp"] = posData.zSlicesRp + + def update_rp( + self, + draw=True, + debug=False, + update_IDs=True, + wl_update=True, + wl_track_og_curr=False, + wl_update_lab=False, + ): + + posData = self.data[self.pos_i] + # Update rp for current posData.lab (e.g. after any change) + + if wl_update: + if self.whitelistOriginalIDs is None: + old_IDs = posData.allData_li[posData.frame_i][ + "IDs" + ].copy() # for whitelist stuff + else: + old_IDs = self.whitelistOriginalIDs.copy() + self.whitelistOriginalIDs = None + elif self.whitelistOriginalIDs is None: + self.whitelist_old_IDs = posData.allData_li[posData.frame_i]["IDs"].copy() + + posData.rp = skimage.measure.regionprops(posData.lab) + if update_IDs: + IDs = [] + IDs_idxs = {} + for idx, obj in enumerate(posData.rp): + IDs.append(obj.label) + IDs_idxs[obj.label] = idx + posData.IDs = IDs + posData.IDs_idxs = IDs_idxs + self.update_rp_metadata(draw=draw) + self.store_zslices_rp(force_update=True) + + if not wl_update: + return + + # Update tracking whitelist + accepted_lost_centroids = self.getTrackedLostIDs() + new_IDs = posData.IDs + added_IDs = set(new_IDs) - set(old_IDs) + removed_IDs = set(old_IDs) - set(new_IDs) - set(accepted_lost_centroids) + + self.whitelistPropagateIDs( + IDs_to_add=added_IDs, + IDs_to_remove=removed_IDs, + curr_frame_only=True, + IDs_curr=new_IDs, + track_og_curr=wl_track_og_curr, + curr_lab=posData.lab, + curr_rp=posData.rp, + update_lab=wl_update_lab, + ) diff --git a/cellacdc/mixins/label_roi.py b/cellacdc/mixins/label_roi.py new file mode 100644 index 000000000..8a4d9cdaa --- /dev/null +++ b/cellacdc/mixins/label_roi.py @@ -0,0 +1,503 @@ +"""Qt view adapter for label-ROI workflows.""" + +from __future__ import annotations + +import numpy as np +import os +from qtpy.QtCore import QMutex, Qt, QThread, QWaitCondition +from qtpy.QtGui import QCursor +from qtpy.QtWidgets import QAction, QMenu + +from cellacdc import ( + apps, + exception_handler, + html_utils, + qutils, + settings_folderpath, + widgets, + workers, +) + +from .brush_tools import BrushTools + + +class LabelRoi(BrushTools): + """Extracted from guiWin.""" + + def getLabelRoiImage(self): + posData = self.data[self.pos_i] + + if self.labelRoiTrangeCheckbox.isChecked(): + start_frame_i = self.labelRoiStartFrameNoSpinbox.value() - 1 + stop_frame_n = self.labelRoiStopFrameNoSpinbox.value() + tRangeLen = stop_frame_n - start_frame_i + else: + tRangeLen = 1 + + if tRangeLen > 1: + tRange = (start_frame_i, stop_frame_n) + else: + tRange = None + + if self.isSegm3D: + if tRangeLen > 1: + imgData = posData.img_data + else: + # Filtered data not existing + imgData = posData.img_data[posData.frame_i] + + roi_zdepth = self.labelRoiZdepthSpinbox.value() + if roi_zdepth == posData.SizeZ: + z0 = 0 + z1 = posData.SizeZ + elif roi_zdepth == 1: + z0 = self.zSliceScrollBar.sliderPosition() + z1 = z0 + 1 + else: + if roi_zdepth % 2 != 0: + roi_zdepth += 1 + half_zdepth = int(roi_zdepth / 2) + zc = self.zSliceScrollBar.sliderPosition() + 1 + z0 = zc - half_zdepth + z0 = z0 if z0 >= 0 else 0 + z1 = zc + half_zdepth + z1 = z1 if z1 < posData.SizeZ else posData.SizeZ + + if self.labelRoiIsRectRadioButton.isChecked(): + labelRoiSlice = self.labelRoiItem.slice(zRange=(z0, z1), tRange=tRange) + elif self.labelRoiIsFreeHandRadioButton.isChecked(): + labelRoiSlice = self.freeRoiItem.slice(zRange=(z0, z1), tRange=tRange) + elif self.labelRoiIsCircularRadioButton.isChecked(): + labelRoiSlice = self.labelRoiCircItemLeft.slice( + zRange=(z0, z1), tRange=tRange + ) + else: + if self.labelRoiIsRectRadioButton.isChecked(): + labelRoiSlice = self.labelRoiItem.slice(tRange=tRange) + elif self.labelRoiIsFreeHandRadioButton.isChecked(): + labelRoiSlice = self.freeRoiItem.slice(tRange=tRange) + elif self.labelRoiIsCircularRadioButton.isChecked(): + labelRoiSlice = self.labelRoiCircItemLeft.slice(tRange=tRange) + if tRangeLen > 1: + imgData = posData.img_data + else: + imgData = self.img1.image + + roiImg = imgData[labelRoiSlice] + if self.labelRoiIsFreeHandRadioButton.isChecked(): + mask = self.freeRoiItem.mask() + elif self.labelRoiIsCircularRadioButton.isChecked(): + mask = self.labelRoiCircItemLeft.mask() + else: + mask = None + + if mask is not None: + # Copy roiImg otherwise we are replacing minimum inside original image + roiImg = roiImg.copy() + # Fill outside of freehand roi with minimum of the ROI image + if tRangeLen > 1: + for i in range(tRangeLen): + ith_roiImg = roiImg[i] + if self.isSegm3D: + roiImg[i, :, ~mask] = ith_roiImg.min() + else: + roiImg[i, ~mask] = ith_roiImg.min() + else: + if self.isSegm3D: + roiImg[:, ~mask] = roiImg.min() + else: + roiImg[~mask] = roiImg.min() + + return roiImg, labelRoiSlice + + def getSecondChannelData(self): + if self.secondChannelName is None: + return + + posData = self.data[self.pos_i] + + fluo_ch = self.secondChannelName + fluo_path, filename = self.getPathFromChName(fluo_ch, posData) + if filename in posData.fluo_data_dict: + fluo_data = posData.fluo_data_dict[filename] + else: + fluo_data, bkgrData = self.load_fluo_data(fluo_path) + posData.fluo_data_dict[filename] = fluo_data + posData.fluo_bkgrData_dict[filename] = bkgrData + + if self.labelRoiTrangeCheckbox.isChecked(): + start_frame_i = self.labelRoiStartFrameNoSpinbox.value() - 1 + stop_frame_n = self.labelRoiStopFrameNoSpinbox.value() + tRangeLen = stop_frame_n - start_frame_i + else: + tRangeLen = 1 + + if tRangeLen > 1: + # fluo_img_data = fluo_data[start_frame_i:stop_frame_n] + if self.isSegm3D or posData.SizeZ == 1: + return fluo_data + else: + T, Z, Y, X = fluo_data.shape + secondChannelData = np.zeros((T, Y, X), dtype=fluo_data.dtype) + for frame_i, fluo_img in enumerate(fluo_data): + secondChannelData[frame_i] = self.get_2Dimg_from_3D( + fluo_data, frame_i=frame_i + ) + return secondChannelData + else: + if posData.SizeT > 1: + fluo_img_data = fluo_data[posData.frame_i] + else: + fluo_img_data = fluo_data + + if self.isSegm3D or posData.SizeZ == 1: + return fluo_img_data + else: + return self.get_2Dimg_from_3D(fluo_img_data) + + def indexRoiLab(self, roiLab, roiLabSlice, lab, brushID): + # Delete only objects touching borders in X and Y not in Z + if self.labelRoiAutoClearBorderCheckbox.isChecked(): + mask = np.zeros(roiLab.shape, dtype=bool) + mask[..., 1:-1, 1:-1] = True + roiLab = skimage.segmentation.clear_border(roiLab, mask=mask) + + roiLabMask = roiLab > 0 + roiLab[roiLabMask] += brushID - 1 + if self.labelRoiReplaceExistingObjectsCheckbox.isChecked(): + IDs_touched_by_new_objects = np.unique(lab[roiLabSlice][roiLabMask]) + for ID in IDs_touched_by_new_objects: + lab[lab == ID] = 0 + + lab[roiLabSlice][roiLabMask] = roiLab[roiLabMask] + return lab + + def initLabelRoiModel(self): + self.app.restoreOverrideCursor() + # Ask which model + self.initLabelRoiModelDialog = apps.QDialogSelectModel(parent=self) + self.initLabelRoiModelDialog.exec_() + if self.initLabelRoiModelDialog.cancel: + self.logger.info("Magic labeller aborted.") + self.initLabelRoiModelDialog = None + return True + self.app.setOverrideCursor(Qt.WaitCursor) + model_name = self.initLabelRoiModelDialog.selectedModel + self.labelRoiModel = self.repeatSegm( + model_name=model_name, askSegmParams=True, is_label_roi=True + ) + if self.labelRoiModel is None: + self.initLabelRoiModelDialog = None + return True + self.labelRoiViewCurrentModelAction.setDisabled(False) + self.initLabelRoiModelDialog = None + return False + + def labelRoiCancelled(self): + self.labelRoiRunning = False + self.app.restoreOverrideCursor() + self.labelRoiItem.setPos((0, 0)) + self.labelRoiItem.setSize((0, 0)) + self.freeRoiItem.clear() + self.logger.info("Magic labeller process cancelled.") + + def labelRoiCheckStartStopFrame(self): + if not self.labelRoiTrangeCheckbox.isChecked(): + return True + + start_n = self.labelRoiStartFrameNoSpinbox.value() + stop_n = self.labelRoiStopFrameNoSpinbox.value() + if start_n <= stop_n: + return True + + self.blinker = qutils.QControlBlink( + self.labelRoiStopFrameNoSpinbox, qparent=self + ) + self.blinker.start() + msg = widgets.myMessageBox() + txt = html_utils.paragraph(""" + Stop frame number is less than start frame number!

    + What do you want to do? + """) + msg.warning( + self, + "Stop frame number lower than start", + txt, + buttonsTexts=("Cancel", "Segment only current frame"), + ) + if msg.cancel: + return False + + posData = self.data[self.pos_i] + self.labelRoiStartFrameNoSpinbox.setValue(posData.frame_i + 1) + self.labelRoiStopFrameNoSpinbox.setValue(posData.frame_i + 1) + + def labelRoiDone(self, roiSegmData, isTimeLapse): + self.setDisabled(False) + + posData = self.data[self.pos_i] + self.setBrushID() + + if isTimeLapse: + self.progressWin.mainPbar.setMaximum(0) + self.progressWin.mainPbar.setValue(0) + current_frame_i = posData.frame_i + start_frame_i = self.labelRoiStartFrameNoSpinbox.value() - 1 + for i, roiLab in enumerate(roiSegmData): + frame_i = start_frame_i + i + lab = posData.allData_li[frame_i]["labels"] + store = True + if lab is None: + if frame_i >= len(posData.segm_data): + lab = np.zeros_like(posData.segm_data[0]) + posData.segm_data = np.append( + posData.segm_data, lab[np.newaxis], axis=0 + ) + else: + lab = posData.segm_data[frame_i] + store = False + roiLabSlice = self.labelRoiSlice[1:] + lab = self.indexRoiLab(roiLab, roiLabSlice, lab, posData.brushID) + if store: + posData.frame_i = frame_i + posData.allData_li[frame_i]["labels"] = lab.copy() + self.get_data() + self.store_data(autosave=False) + + # Back to current frame + posData.frame_i = current_frame_i + self.get_data() + else: + roiLab = roiSegmData + posData.lab = self.indexRoiLab( + roiLab, self.labelRoiSlice, posData.lab, posData.brushID + ) + + self.update_rp() + + # Repeat tracking + if self.autoIDcheckbox.isChecked(): + self.tracking(enforce=True, assign_unique_new_IDs=False) + + self.store_data() + self.updateAllImages() + + self.labelRoiItem.setPos((0, 0)) + self.labelRoiItem.setSize((0, 0)) + self.freeRoiItem.clear() + self.logger.info("Magic labeller done!") + self.app.restoreOverrideCursor() + + self.labelRoiRunning = False + if self.progressWin is not None: + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + + uncheckLabelRoiTRange = ( + self.labelRoiTrangeCheckbox.isChecked() + and not self.labelRoiTrangeCheckbox.findChild(QAction).isChecked() + ) + if uncheckLabelRoiTRange: + self.labelRoiTrangeCheckbox.setChecked(False) + + def labelRoiFromCurrentFrameTriggered(self): + posData = self.data[self.pos_i] + self.labelRoiStartFrameNoSpinbox.setValue(posData.frame_i + 1) + + def labelRoiToEndFramesTriggered(self): + posData = self.data[self.pos_i] + self.labelRoiStopFrameNoSpinbox.setValue(posData.SizeT) + + def labelRoiTrangeCheckboxToggled(self, checked): + disabled = not checked + self.labelRoiStartFrameNoSpinbox.setDisabled(disabled) + self.labelRoiStopFrameNoSpinbox.setDisabled(disabled) + self.labelRoiStartFrameNoSpinbox.label.setDisabled(disabled) + self.labelRoiStopFrameNoSpinbox.label.setDisabled(disabled) + self.labelRoiToEndFramesAction.setDisabled(disabled) + self.labelRoiFromCurrentFrameAction.setDisabled(disabled) + + if disabled: + return + + posData = self.data[self.pos_i] + + self.labelRoiStartFrameNoSpinbox.setValue(posData.frame_i + 1) + self.labelRoiStopFrameNoSpinbox.setValue(posData.SizeT) + + def labelRoiViewCurrentModel(self): + from . import config + + ini_path = os.path.join(settings_folderpath, "last_params_segm_models.ini") + configPars = config.ConfigParser() + configPars.read(ini_path) + model_name = self.labelRoiModel.model_name + txt = f"Model: {model_name}" + SECTION = f"{model_name}.init" + txt = f"{txt}

    [Initialization parameters]
    " + for option in configPars.options(SECTION): + value = configPars[SECTION][option] + param_txt = f"{option} = {value}
    " + txt = f"{txt}{param_txt}" + + SECTION = f"{model_name}.segment" + txt = f"{txt}
    [Segmentation parameters]
    " + for option in configPars.options(SECTION): + value = configPars[SECTION][option] + param_txt = f"{option} = {value}
    " + txt = f"{txt}{param_txt}" + + win = apps.ViewTextDialog(txt, parent=self) + win.exec_() + + def labelRoiWorkerFinished(self): + self.logger.info("Magic labeller closed.") + worker = self.labelRoiActiveWorkers.pop(-1) + + def labelRoi_cb(self, checked): + posData = self.data[self.pos_i] + if checked: + self.disconnectLeftClickButtons() + self.uncheckLeftClickButtons(self.labelRoiButton) + self.connectLeftClickButtons() + + self.labelRoiStartFrameNoSpinbox.setMaximum(posData.SizeT) + self.labelRoiStopFrameNoSpinbox.setMaximum(posData.SizeT) + + if self.labelRoiActiveWorkers: + lastActiveWorker = self.labelRoiActiveWorkers[-1] + self.labelRoiGarbageWorkers.append(lastActiveWorker) + lastActiveWorker.finished.emit() + self.logger.info("Collected garbage w5orker (magic labeller).") + + self.labelRoiToolbar.setVisible(True) + if self.isSegm3D: + self.labelRoiZdepthSpinbox.setDisabled(False) + else: + self.labelRoiZdepthSpinbox.setDisabled(True) + + # Start thread and pause it + self.labelRoiThread = QThread() + self.labelRoiMutex = QMutex() + self.labelRoiWaitCond = QWaitCondition() + + labelRoiWorker = workers.LabelRoiWorker(self) + + labelRoiWorker.moveToThread(self.labelRoiThread) + labelRoiWorker.finished.connect(self.labelRoiThread.quit) + labelRoiWorker.finished.connect(labelRoiWorker.deleteLater) + self.labelRoiThread.finished.connect(self.labelRoiThread.deleteLater) + + labelRoiWorker.finished.connect(self.labelRoiWorkerFinished) + labelRoiWorker.sigLabellingDone.connect(self.labelRoiDone) + labelRoiWorker.sigProgressBar.connect(self.workerUpdateProgressbar) + + labelRoiWorker.progress.connect(self.workerProgress) + labelRoiWorker.critical.connect(self.workerCritical) + + self.labelRoiActiveWorkers.append(labelRoiWorker) + + self.labelRoiThread.started.connect(labelRoiWorker.run) + self.labelRoiThread.start() + + # Add the rectROI to ax1 + self.ax1.addItem(self.labelRoiItem) + elif self.initLabelRoiModelDialog is not None: + # User is using other tools while the dialog is still open + # --> we allow this because it's useful to be able to use + # the ruler or check things --> do nothing + pass + else: + self.labelRoiToolbar.setVisible(False) + + for worker in self.labelRoiActiveWorkers: + worker._stop() + while self.app.overrideCursor() is not None: + self.app.restoreOverrideCursor() + + self.labelRoiItem.setPos((0, 0)) + self.labelRoiItem.setSize((0, 0)) + self.freeRoiItem.clear() + self.ax1.removeItem(self.labelRoiItem) + self.updateLabelRoiCircularCursor(None, None, False) + + def loadLabelRoiLastParams(self): + idx = "labelRoi_checkedRoiType" + if idx in self.df_settings.index: + checkedRoiType = self.df_settings.at[idx, "value"] + for button in self.labelRoiTypesGroup.buttons(): + if button.text() == checkedRoiType: + button.setChecked(True) + break + + idx = "labelRoi_circRoiRadius" + if idx in self.df_settings.index: + circRoiRadius = self.df_settings.at[idx, "value"] + self.labelRoiCircularRadiusSpinbox.setValue(int(circRoiRadius)) + + idx = "labelRoi_roiZdepth" + if idx in self.df_settings.index: + roiZdepth = self.df_settings.at[idx, "value"] + self.labelRoiZdepthSpinbox.setValue(int(roiZdepth)) + + idx = "labelRoi_autoClearBorder" + if idx in self.df_settings.index: + clearBorder = self.df_settings.at[idx, "value"] + checked = clearBorder == "Yes" + self.labelRoiAutoClearBorderCheckbox.setChecked(checked) + + idx = "labelRoi_replaceExistingObjects" + if idx in self.df_settings.index: + val = self.df_settings.at[idx, "value"] + checked = val == "Yes" + self.labelRoiReplaceExistingObjectsCheckbox.setChecked(checked) + + if self.labelRoiIsCircularRadioButton.isChecked(): + self.labelRoiCircularRadiusSpinbox.setDisabled(False) + + def showLabelRoiContextMenu(self, event): + menu = QMenu(self.labelRoiButton) + action = QAction("Re-initialize magic labeller model...") + action.triggered.connect(self.initLabelRoiModel) + menu.addAction(action) + menu.exec_(QCursor.pos()) + + def storeLabelRoiParams(self, value=None, checked=True): + checkedRoiType = self.labelRoiTypesGroup.checkedButton().text() + circRoiRadius = self.labelRoiCircularRadiusSpinbox.value() + roiZdepth = self.labelRoiZdepthSpinbox.value() + autoClearBorder = self.labelRoiAutoClearBorderCheckbox.isChecked() + clearBorder = "Yes" if autoClearBorder else "No" + self.df_settings.at["labelRoi_checkedRoiType", "value"] = checkedRoiType + self.df_settings.at["labelRoi_circRoiRadius", "value"] = circRoiRadius + self.df_settings.at["labelRoi_roiZdepth", "value"] = roiZdepth + self.df_settings.at["labelRoi_autoClearBorder", "value"] = clearBorder + self.df_settings.at["labelRoi_replaceExistingObjects", "value"] = ( + "Yes" if self.labelRoiReplaceExistingObjectsCheckbox.isChecked() else "No" + ) + self.df_settings.to_csv(self.settings_csv_path) + + def updateLabelRoiCircularCursor(self, x, y, checked): + if not self.labelRoiButton.isChecked(): + return + if not self.labelRoiIsCircularRadioButton.isChecked(): + return + if self.labelRoiRunning: + return + + size = self.labelRoiCircularRadiusSpinbox.value() + if not checked: + xx, yy = [], [] + else: + xx, yy = [x], [y] + + if not xx and len(self.labelRoiCircItemLeft.getData()[0]) == 0: + return + + self.labelRoiCircItemLeft.setData(xx, yy, size=size) + self.labelRoiCircItemRight.setData(xx, yy, size=size) + + def updateLabelRoiCircularSize(self, value): + self.labelRoiCircItemLeft.setSize(value) + self.labelRoiCircItemRight.setSize(value) diff --git a/cellacdc/mixins/label_transform_tools.py b/cellacdc/mixins/label_transform_tools.py new file mode 100644 index 000000000..cd8f165e8 --- /dev/null +++ b/cellacdc/mixins/label_transform_tools.py @@ -0,0 +1,225 @@ +"""View adapter for label transform tools.""" + +from __future__ import annotations + +import skimage.measure + +from .brush_tools import BrushTools +from .label_editing import LabelEditing + + +class LabelTransformTools(BrushTools, LabelEditing): + """Extracted from guiWin.""" + + def expandLabel(self, dilation=True): + posData = self.data[self.pos_i] + if self.hoverLabelID == 0: + self.isExpandingLabel = False + return + + # Re-initialize label to expand when we hover on a different ID + # or we change direction + reinitExpandingLab = ( + self.expandingID != self.hoverLabelID or dilation != self.isDilation + ) + + ID = self.hoverLabelID + + obj = posData.rp[posData.IDs.index(ID)] + + if reinitExpandingLab: + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + # hoverLabelID different from previously expanded ID --> reinit + self.isExpandingLabel = True + self.expandingID = ID + self.expandingLab = np.zeros_like(self.currentLab2D) + self.expandingLab[obj.coords[:, -2], obj.coords[:, -1]] = ID + self.expandFootprintSize = 1 + + prevCoords = (obj.coords[:, -2], obj.coords[:, -1]) + self.currentLab2D[obj.coords[:, -2], obj.coords[:, -1]] = 0 + lab_2D = self.get_2Dlab(posData.lab) + lab_2D[obj.coords[:, -2], obj.coords[:, -1]] = 0 + + footprint = skimage.morphology.disk(self.expandFootprintSize) + if dilation: + expandedLab = skimage.morphology.dilation(self.expandingLab, footprint) + self.isDilation = True + else: + expandedLab = skimage.morphology.erosion(self.expandingLab, footprint) + self.isDilation = False + + # Prevent expanding into neighbouring labels + expandedLab[self.currentLab2D > 0] = 0 + + # Get coords of the dilated/eroded object + expandedObj = skimage.measure.regionprops(expandedLab)[0] + expandedObjCoords = (expandedObj.coords[:, -2], expandedObj.coords[:, -1]) + + # Add the dilated/erored object + self.currentLab2D[expandedObjCoords] = self.expandingID + lab_2D[expandedObjCoords] = self.expandingID + + self.set_2Dlab(lab_2D) + self.currentLab2D = lab_2D + + self.update_rp() + + if self.labelsGrad.showLabelsImgAction.isChecked(): + self.img2.setImage(img=self.currentLab2D, autoLevels=False) + + self.setTempImgExpandLabel(prevCoords, expandedObjCoords) + + def expandLabelCallback(self, checked): + if checked: + self.disconnectLeftClickButtons() + self.uncheckLeftClickButtons(self.sender()) + self.connectLeftClickButtons() + self.expandFootprintSize = 1 + else: + self.clearHighlightedID() + alpha = self.imgGrad.labelsAlphaSlider.value() + self.labelsLayerImg1.setOpacity(alpha) + self.labelsLayerRightImg.setOpacity(alpha) + self.hoverLabelID = 0 + self.expandingID = 0 + self.updateAllImages() + + def _setTempImgExpandLabelContours(self, prevCoords, ax=0): + self.contoursImage[prevCoords] = [0, 0, 0, 0] + currentLab2Drp = skimage.measure.regionprops(self.currentLab2D) + for obj in currentLab2Drp: + if obj.label == self.expandingID: + # self.clearObjContour(obj=obj, ax=ax) + self.addObjContourToContoursImage(obj=obj, ax=ax, force=True) + break + + def _setTempImgExpandLabelSegmMasks(self, prevCoords, ax=0): + # Remove previous overlaid mask + labelsImage = self.getLabelsLayerImage(ax=ax) + labelsImage[prevCoords] = 0 + + # Overlay new moved mask + labelsImage[prevCoords] = self.expandingID + + if ax == 0: + self.labelsLayerImg1.setImage(self.labelsLayerImg1.image, autoLevels=False) + else: + self.labelsLayerRightImg.setImage( + self.labelsLayerRightImg.image, autoLevels=False + ) + + def resetExpandLabel(self): + self.expandingID = -1 + + def startMovingLabel(self, xPos, yPos): + posData = self.data[self.pos_i] + xdata, ydata = int(xPos), int(yPos) + lab_2D = self.get_2Dlab(posData.lab) + ID = lab_2D[ydata, xdata] + if ID == 0: + self.isMovingLabel = False + return + + posData = self.data[self.pos_i] + self.isMovingLabel = True + + self.searchedIDitemRight.setData([], []) + self.searchedIDitemLeft.setData([], []) + self.movingID = ID + self.prevMovePos = (xdata, ydata) + movingObj = posData.rp[posData.IDs.index(ID)] + self.movingObjCoords = movingObj.coords.copy() + yy, xx = movingObj.coords[:, -2], movingObj.coords[:, -1] + self.currentLab2D[yy, xx] = 0 + + def moveLabel(self, xPos, yPos): + posData = self.data[self.pos_i] + lab_2D = self.get_2Dlab(posData.lab) + Y, X = lab_2D.shape + xdata, ydata = int(xPos), int(yPos) + if xdata < 0 or ydata < 0 or xdata >= X or ydata >= Y: + return + + self.clearObjContour(ID=self.movingID, ax=0) + + xStart, yStart = self.prevMovePos + deltaX = xdata - xStart + deltaY = ydata - yStart + + yy, xx = self.movingObjCoords[:, -2], self.movingObjCoords[:, -1] + + if self.isSegm3D: + zz = self.movingObjCoords[:, 0] + posData.lab[zz, yy, xx] = 0 + else: + posData.lab[yy, xx] = 0 + + self.movingObjCoords[:, -2] = self.movingObjCoords[:, -2] + deltaY + self.movingObjCoords[:, -1] = self.movingObjCoords[:, -1] + deltaX + + yy, xx = self.movingObjCoords[:, -2], self.movingObjCoords[:, -1] + + yy[yy < 0] = 0 + xx[xx < 0] = 0 + yy[yy >= Y] = Y - 1 + xx[xx >= X] = X - 1 + + if self.isSegm3D: + zz = self.movingObjCoords[:, 0] + posData.lab[zz, yy, xx] = self.movingID + else: + posData.lab[yy, xx] = self.movingID + + self.currentLab2D = self.get_2Dlab(posData.lab) + if self.labelsGrad.showLabelsImgAction.isChecked(): + self.img2.setImage(self.currentLab2D, autoLevels=False) + + self.setTempImg1MoveLabel() + + self.prevMovePos = (xdata, ydata) + + def setTempImgExpandLabel(self, prevCoords, expandedObjCoords, ax=0): + if ax == 0: + how = self.drawIDsContComboBox.currentText() + else: + how = self.getAnnotateHowRightImage() + + self._setTempImgExpandLabelContours(prevCoords, ax=ax) + + def setTempImg1MoveLabel(self, ax=0): + if ax == 0: + how = self.drawIDsContComboBox.currentText() + else: + how = self.getAnnotateHowRightImage() + + if how.find("contours") != -1: + currentLab2Drp = skimage.measure.regionprops(self.currentLab2D) + for obj in currentLab2Drp: + if obj.label == self.movingID: + self.addObjContourToContoursImage(obj=obj, ax=ax) + break + elif how.find("overlay segm. masks") != -1: + if ax == 0: + self.labelsLayerImg1.setImage(self.currentLab2D, autoLevels=False) + self.highLightIDLayerImg1.image[:] = 0 + mask = self.currentLab2D == self.movingID + self.highLightIDLayerImg1.image[mask] = self.movingID + highlightedImage = self.highLightIDLayerImg1.image + self.highLightIDLayerImg1.setImage(highlightedImage) + else: + self.labelsLayerRightImg.setImage(self.currentLab2D, autoLevels=False) + self.highLightIDLayerRightImage.image[:] = 0 + mask = self.currentLab2D == self.movingID + self.highLightIDLayerRightImage.image[mask] = self.movingID + highlightedImage = self.highLightIDLayerRightImage.image + self.highLightIDLayerRightImage.setImage(highlightedImage) + + def moveLabelButtonToggled(self, checked): + if not checked: + self.hoverLabelID = 0 + self.highlightedID = 0 + self.highLightIDLayerImg1.clear() + self.highLightIDLayerRightImage.clear() + self.setHighlightID(False) diff --git a/cellacdc/mixins/layout_controls.py b/cellacdc/mixins/layout_controls.py new file mode 100644 index 000000000..7155b7def --- /dev/null +++ b/cellacdc/mixins/layout_controls.py @@ -0,0 +1,773 @@ +"""Qt view adapter for layout-control workflows.""" + +from __future__ import annotations + +from functools import partial + +from natsort import natsorted +import re +from qtpy.QtCore import QTimer, Qt +from qtpy.QtGui import QIcon +from qtpy.QtWidgets import ( + QAction, + QActionGroup, + QButtonGroup, + QCheckBox, + QDockWidget, + QGridLayout, + QLabel, + QRadioButton, + QSizePolicy, + QWidget, +) + +from cellacdc import utils, widgets +from cellacdc.gui_decorators import resetViewRange + +from .image_controls import ImageControls +from .window_events import WindowEvents +from .label_roi import LabelRoi + + +class LayoutControls(ImageControls, WindowEvents, LabelRoi): + """Extracted from guiWin.""" + + def gui_createControlsToolbar(self): + self.controlToolBars = [] + self.addToolBarBreak() + + # Edit toolbar + modeToolBar = widgets.ToolBar("Mode", self) + self.addToolBar(modeToolBar) + + self.modeComboBox = widgets.ComboBox() + self.modeComboBox.addItems(self.modeItems) + self.modeComboBoxLabel = QLabel(" Mode: ") + self.modeComboBoxLabel.setBuddy(self.modeComboBox) + modeToolBar.addWidget(self.modeComboBoxLabel) + modeToolBar.addWidget(self.modeComboBox) + modeToolBar.setVisible(False) + + self.modeToolBar = modeToolBar + + self.overlayToolbar = widgets.OverlayToolbar(parent=self) + self.addToolBar(Qt.TopToolBarArea, self.overlayToolbar) + self.overlayToolbar.setVisible(False) + self.overlayToolbar.sigSetTranspacency.connect(self.setOverlayTransparency) + self.overlayToolbar.sigSetSingleChannel.connect(self.setOverlaySingleChannel) + + self.autoPilotZoomToObjToolbar = widgets.ToolBar("Auto-zoom to objects", self) + self.autoPilotZoomToObjToolbar.setContextMenuPolicy(Qt.PreventContextMenu) + self.autoPilotZoomToObjToolbar.setMovable(False) + self.addToolBar(Qt.TopToolBarArea, self.autoPilotZoomToObjToolbar) + # self.autoPilotZoomToObjToolbar.setIconSize(QSize(16, 16)) + self.autoPilotZoomToObjToolbar.setVisible(False) + self.autoPilotZoomToObjToolbar.keepVisibleWhenActive = True + self.controlToolBars.append(self.autoPilotZoomToObjToolbar) + + # Highlighted ID or searched ID toolbar + self.highlightIDToolbar = widgets.HighlightedIDToolbar(parent=self) + self.addToolBar(Qt.TopToolBarArea, self.highlightIDToolbar) + self.highlightIDToolbar.setVisible(False) + self.highlightIDToolbar.keepVisibleWhenActive = True + self.controlToolBars.append(self.highlightIDToolbar) + + self.highlightIDToolbar.sigIDChanged.connect(self.setHighlighedIDfromToolbar) + + # Widgets toolbar + brushEraserToolBar = widgets.ToolBar("Widgets", self) + self.addToolBar(Qt.TopToolBarArea, brushEraserToolBar) + self.controlToolBars.append(brushEraserToolBar) + + self.editIDspinbox = widgets.SpinBox() + # self.editIDspinbox.setMaximum(2**32-1) + editIDLabel = QLabel(" ID: ") + self.editIDLabelAction = brushEraserToolBar.addWidget(editIDLabel) + self.editIDspinboxAction = brushEraserToolBar.addWidget(self.editIDspinbox) + self.editIDLabelAction.setVisible(False) + self.editIDspinboxAction.setVisible(False) + self.editIDspinboxAction.setDisabled(True) + self.editIDLabelAction.setDisabled(True) + + brushEraserToolBar.addWidget(QLabel(" ")) + self.autoIDcheckbox = QCheckBox("Auto-ID") + self.autoIDcheckbox.setChecked(True) + self.autoIDcheckboxAction = brushEraserToolBar.addWidget(self.autoIDcheckbox) + self.autoIDcheckboxAction.setVisible(False) + + self.brushSizeSpinbox = widgets.SpinBox( + disableKeyPress=True, allowNegative=False + ) + self.brushSizeSpinbox.setValue(4) + brushSizeLabel = QLabel(" Size: ") + brushSizeLabel.setBuddy(self.brushSizeSpinbox) + self.brushSizeLabelAction = brushEraserToolBar.addWidget(brushSizeLabel) + self.brushSizeAction = brushEraserToolBar.addWidget(self.brushSizeSpinbox) + self.brushSizeLabelAction.setVisible(False) + self.brushSizeAction.setVisible(False) + + brushEraserToolBar.addWidget(QLabel(" ")) + self.brushAutoFillCheckbox = QCheckBox("Auto-fill holes") + self.brushAutoFillAction = brushEraserToolBar.addWidget( + self.brushAutoFillCheckbox + ) + self.brushAutoFillAction.setVisible(False) + if "brushAutoFill" in self.df_settings.index: + checked = self.df_settings.at["brushAutoFill", "value"] == "Yes" + self.brushAutoFillCheckbox.setChecked(checked) + + brushEraserToolBar.addWidget(QLabel(" ")) + self.brushAutoHideCheckbox = QCheckBox("Hide objects when hovering") + self.brushAutoHideAction = brushEraserToolBar.addWidget( + self.brushAutoHideCheckbox + ) + self.brushAutoHideCheckbox.setChecked(True) + self.brushAutoHideAction.setVisible(False) + if "brushAutoHide" in self.df_settings.index: + checked = self.df_settings.at["brushAutoHide", "value"] == "Yes" + self.brushAutoHideCheckbox.setChecked(checked) + + brushEraserToolBar.setVisible(False) + self.brushEraserToolBar = brushEraserToolBar + + self.wandControlsToolbar = widgets.WandControlsToolbar(parent=self) + + self.addToolBar(Qt.TopToolBarArea, self.wandControlsToolbar) + self.wandControlsToolbar.setVisible(False) + self.controlToolBars.append(self.wandControlsToolbar) + + separatorW = 5 + self.labelRoiToolbar = widgets.ToolBar("Magic labeller controls", self) + self.labelRoiToolbar.addWidget(QLabel("ROI n. of z-slices: ")) + self.labelRoiZdepthSpinbox = widgets.SpinBox(disableKeyPress=True) + self.labelRoiToolbar.addWidget(self.labelRoiZdepthSpinbox) + + self.labelRoiToolbar.addWidget(widgets.QHWidgetSpacer(width=separatorW)) + self.labelRoiToolbar.addWidget(widgets.QVLine()) + self.labelRoiToolbar.addWidget(widgets.QHWidgetSpacer(width=separatorW)) + + self.labelRoiReplaceExistingObjectsCheckbox = QCheckBox( + "Remove objs. touched by new ones" + ) + self.labelRoiToolbar.addWidget(self.labelRoiReplaceExistingObjectsCheckbox) + self.labelRoiAutoClearBorderCheckbox = QCheckBox( + "Clear ROI borders before adding new objs." + ) + self.labelRoiAutoClearBorderCheckbox.setChecked(True) + self.labelRoiToolbar.addWidget(self.labelRoiAutoClearBorderCheckbox) + + self.labelRoiToolbar.addWidget(widgets.QHWidgetSpacer(width=separatorW)) + self.labelRoiToolbar.addWidget(widgets.QVLine()) + self.labelRoiToolbar.addWidget(widgets.QHWidgetSpacer(width=separatorW)) + + group = QButtonGroup() + group.setExclusive(True) + self.labelRoiIsRectRadioButton = QRadioButton("Rect. ROI") + self.labelRoiIsRectRadioButton.setChecked(True) + self.labelRoiIsFreeHandRadioButton = QRadioButton("Freehand ROI") + self.labelRoiIsCircularRadioButton = QRadioButton("Circular ROI") + group.addButton(self.labelRoiIsRectRadioButton) + group.addButton(self.labelRoiIsFreeHandRadioButton) + group.addButton(self.labelRoiIsCircularRadioButton) + self.labelRoiToolbar.addWidget(self.labelRoiIsRectRadioButton) + self.labelRoiToolbar.addWidget(self.labelRoiIsFreeHandRadioButton) + self.labelRoiToolbar.addWidget(self.labelRoiIsCircularRadioButton) + self.labelRoiToolbar.addWidget(QLabel(" | Radius (pixel): ")) + self.labelRoiCircularRadiusSpinbox = widgets.SpinBox(disableKeyPress=True) + self.labelRoiCircularRadiusSpinbox.setMinimum(1) + self.labelRoiCircularRadiusSpinbox.setValue(11) + self.labelRoiCircularRadiusSpinbox.setDisabled(True) + self.labelRoiToolbar.addWidget(self.labelRoiCircularRadiusSpinbox) + + self.labelRoiToolbar.addWidget(widgets.QHWidgetSpacer(width=separatorW)) + self.labelRoiToolbar.addWidget(widgets.QVLine()) + self.labelRoiToolbar.addWidget(widgets.QHWidgetSpacer(width=separatorW)) + + startFrameLabel = QLabel("Start frame n. ") + startFrameLabel.setDisabled(True) + self.labelRoiToolbar.addWidget(startFrameLabel) + self.labelRoiStartFrameNoSpinbox = widgets.SpinBox(disableKeyPress=True) + self.labelRoiStartFrameNoSpinbox.label = startFrameLabel + self.labelRoiStartFrameNoSpinbox.setValue(1) + self.labelRoiStartFrameNoSpinbox.setMinimum(1) + self.labelRoiToolbar.addWidget(self.labelRoiStartFrameNoSpinbox) + self.labelRoiStartFrameNoSpinbox.setDisabled(True) + + self.labelRoiFromCurrentFrameAction = QAction(self) + self.labelRoiFromCurrentFrameAction.setText("Segment from current frame") + self.labelRoiFromCurrentFrameAction.setIcon(QIcon(":frames_current.svg")) + self.labelRoiToolbar.addAction(self.labelRoiFromCurrentFrameAction) + self.labelRoiFromCurrentFrameAction.setDisabled(True) + + self.labelRoiToolbar.addWidget(widgets.QHWidgetSpacer(width=3)) + stopFrameLabel = QLabel(" Stop frame n. ") + stopFrameLabel.setDisabled(True) + self.labelRoiToolbar.addWidget(stopFrameLabel) + self.labelRoiStopFrameNoSpinbox = widgets.SpinBox(disableKeyPress=True) + self.labelRoiStopFrameNoSpinbox.label = stopFrameLabel + self.labelRoiStopFrameNoSpinbox.setValue(1) + self.labelRoiStopFrameNoSpinbox.setMinimum(1) + self.labelRoiToolbar.addWidget(self.labelRoiStopFrameNoSpinbox) + self.labelRoiStopFrameNoSpinbox.setDisabled(True) + + self.labelRoiToEndFramesAction = QAction(self) + self.labelRoiToEndFramesAction.setText("Segment all remaining frames") + self.labelRoiToEndFramesAction.setIcon(QIcon(":frames_end.svg")) + self.labelRoiToolbar.addAction(self.labelRoiToEndFramesAction) + self.labelRoiToEndFramesAction.setDisabled(True) + + self.labelRoiTrangeCheckbox = QCheckBox("Segment range of frames") + self.labelRoiToolbar.addWidget(self.labelRoiTrangeCheckbox) + + self.labelRoiViewCurrentModelAction = QAction(self) + self.labelRoiViewCurrentModelAction.setText("View current model's parameters") + self.labelRoiViewCurrentModelAction.setIcon(QIcon(":view.svg")) + self.labelRoiToolbar.addAction(self.labelRoiViewCurrentModelAction) + self.labelRoiViewCurrentModelAction.setDisabled(True) + + self.addToolBar(Qt.TopToolBarArea, self.labelRoiToolbar) + self.controlToolBars.append(self.labelRoiToolbar) + self.labelRoiToolbar.setVisible(False) + self.labelRoiTypesGroup = group + + self.loadLabelRoiLastParams() + + self.labelRoiTrangeCheckbox.toggled.connect(self.labelRoiTrangeCheckboxToggled) + self.labelRoiReplaceExistingObjectsCheckbox.toggled.connect( + self.storeLabelRoiParams + ) + self.labelRoiIsCircularRadioButton.toggled.connect( + self.labelRoiIsCircularRadioButtonToggled + ) + self.labelRoiCircularRadiusSpinbox.valueChanged.connect( + self.updateLabelRoiCircularSize + ) + self.labelRoiCircularRadiusSpinbox.valueChanged.connect( + self.storeLabelRoiParams + ) + self.labelRoiZdepthSpinbox.valueChanged.connect(self.storeLabelRoiParams) + self.labelRoiAutoClearBorderCheckbox.toggled.connect(self.storeLabelRoiParams) + group.buttonToggled.connect(self.storeLabelRoiParams) + + self.labelRoiToEndFramesAction.triggered.connect( + self.labelRoiToEndFramesTriggered + ) + self.labelRoiFromCurrentFrameAction.triggered.connect( + self.labelRoiFromCurrentFrameTriggered + ) + self.labelRoiViewCurrentModelAction.triggered.connect( + self.labelRoiViewCurrentModel + ) + + self.keepIDsToolbar = widgets.ToolBar("Keep IDs controls", self) + self.keepIDsConfirmAction = QAction() + self.keepIDsConfirmAction.setIcon(QIcon(":greenTick.svg")) + self.keepIDsConfirmAction.setToolTip('Apply "keep IDs" selection') + self.keepIDsConfirmAction.setDisabled(True) + self.keepIDsToolbar.addAction(self.keepIDsConfirmAction) + self.keepIDsToolbar.addWidget(QLabel(" IDs to keep: ")) + instructionsText = ( + " (Separate IDs by comma. Use a dash to denote a range of IDs)" + ) + instructionsLabel = QLabel(instructionsText) + self.keptIDsLineEdit = widgets.KeepIDsLineEdit(instructionsLabel, parent=self) + self.keepIDsToolbar.addWidget(self.keptIDsLineEdit) + self.keepIDsToolbar.addWidget(instructionsLabel) + spacer = QWidget() + spacer.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred) + self.keepIDsToolbar.addWidget(spacer) + self.addToolBar(Qt.TopToolBarArea, self.keepIDsToolbar) + self.keepIDsToolbar.setVisible(False) + self.controlToolBars.append(self.keepIDsToolbar) + + self.keptIDsLineEdit.sigEnterPressed.connect(self.applyKeepObjects) + self.keptIDsLineEdit.sigIDsChanged.connect(self.updateKeepIDs) + self.keepIDsConfirmAction.triggered.connect(self.applyKeepObjects) + + # closeToolbarAction = QAction( + # QIcon(":cancelButton.svg"), "Close toolbar...", self + # ) + # closeToolbarAction.triggered.connect(self.closeToolbars) + # self.autoPilotZoomToObjToolbar.addAction(closeToolbarAction) + + self.autoPilotZoomToObjToolbar.addWidget(widgets.QVLine()) + self.autoPilotZoomToObjToolbar.addWidget( + widgets.QHWidgetSpacer(width=separatorW) + ) + + spinBox = widgets.SpinBox() + spinBox.setMinimum(1) + spinBox.label = QLabel(" Zoom to ID: ") + spinBox.labelAction = self.autoPilotZoomToObjToolbar.addWidget(spinBox.label) + spinBox.action = self.autoPilotZoomToObjToolbar.addWidget(spinBox) + spinBox.editingFinished.connect(self.zoomToObj) + spinBox.sigUpClicked.connect(self.autoZoomNextObj) + spinBox.sigDownClicked.connect(self.autoZoomPrevObj) + self.autoPilotZoomToObjSpinBox = spinBox + toggle = widgets.Toggle() + self.autoPilotZoomToObjToggle = toggle + toggle.toggled.connect(self.autoPilotZoomToObjToggled) + toggle.label = QLabel(" Auto-pilot: ") + tooltip = ( + "When auto-pilot is active, you can use Up/Down arrows to " + "automatically zoom to the next/previous object.\n\n" + "Alternatively, you can type the ID of the object you want to " + "zoom to." + ) + toggle.label.setToolTip(tooltip) + toggle.setToolTip(tooltip) + self.autoPilotZoomToObjToolbar.addWidget(toggle.label) + self.autoPilotZoomToObjToolbar.addWidget(toggle) + + self.pointsLayersToolbars = [] + + self.pointsLayersToolbar = widgets.PointsLayersToolbar(parent=self) + self.pointsLayersToolbar.setContextMenuPolicy(Qt.PreventContextMenu) + + self.pointsLayersToolbar.sigAddPointsLayer.connect(self.addPointsLayerTriggered) + + self.addToolBar(Qt.TopToolBarArea, self.pointsLayersToolbar) + + self.pointsLayersToolbar.setVisible(False) + self.pointsLayersToolbar.keepVisibleWhenActive = True + self.controlToolBars.append(self.pointsLayersToolbar) + + self.pointsLayersToolbars.append(self.pointsLayersToolbar) + + self.manualTrackingToolbar = widgets.ManualTrackingToolBar( + "Manual tracking controls", self + ) + self.manualTrackingToolbar.sigIDchanged.connect(self.initGhostObject) + self.manualTrackingToolbar.sigDisableGhost.connect(self.clearGhost) + self.manualTrackingToolbar.sigClearGhostContour.connect(self.clearGhostContour) + self.manualTrackingToolbar.sigClearGhostMask.connect(self.clearGhostMask) + self.manualTrackingToolbar.sigGhostOpacityChanged.connect( + self.updateGhostMaskOpacity + ) + + self.addToolBar(Qt.TopToolBarArea, self.manualTrackingToolbar) + self.manualTrackingToolbar.setVisible(False) + self.controlToolBars.append(self.manualTrackingToolbar) + + self.manualBackgroundToolbar = widgets.ManualBackgroundToolBar( + "Manual background controls", self + ) + self.manualBackgroundToolbar.sigIDchanged.connect( + self.initManualBackgroundObject + ) + self.addToolBar(Qt.TopToolBarArea, self.manualBackgroundToolbar) + self.manualBackgroundToolbar.setVisible(False) + self.controlToolBars.append(self.manualBackgroundToolbar) + + # Copy lost object contour toolbar + self.copyLostObjToolbar = widgets.CopyLostObjectToolbar( + "Copy lost object controls", self + ) + for name, action in self.copyLostObjToolbar.widgetsWithShortcut.items(): + self.widgetsWithShortcut[name] = action + + self.copyLostObjToolbar.sigCopyAllObjects.connect(self.copyAllLostObjects) + + self.addToolBar(Qt.TopToolBarArea, self.copyLostObjToolbar) + self.copyLostObjToolbar.setVisible(False) + # self.controlToolBars.append(self.copyLostObjToolbar) + + # Copy lost object contour toolbar + self.drawClearRegionToolbar = widgets.DrawClearRegionToolbar( + "Draw freehand region and clear objects controls", self + ) + + self.addToolBar(Qt.TopToolBarArea, self.drawClearRegionToolbar) + self.drawClearRegionToolbar.setVisible(False) + self.controlToolBars.append(self.drawClearRegionToolbar) + + try: + addNewIDToggleState = ( + self.df_settings.at["addNewIDsWhitelistToggle", "value"] == "Yes" + ) + except KeyError: + addNewIDToggleState = True + + self.whitelistIDsToolbar = widgets.WhitelistIDsToolbar( + addNewIDToggleState, self + ) + for name, action in self.whitelistIDsToolbar.widgetsWithShortcut.items(): + self.widgetsWithShortcut[name] = action + + self.addToolBar(Qt.TopToolBarArea, self.whitelistIDsToolbar) + self.whitelistIDsToolbar.setVisible(False) + self.controlToolBars.append(self.whitelistIDsToolbar) + + self.magicPromptsToolbar = widgets.MagicPromptsToolbar(self) + for name, action in self.magicPromptsToolbar.widgetsWithShortcut.items(): + self.widgetsWithShortcut[name] = action + + self.magicPromptsToolbar.sigComputeOnZoom.connect( + self.magicPromptsComputeOnZoomTriggered + ) + self.magicPromptsToolbar.sigComputeOnImage.connect( + self.magicPromptsComputeOnImageTriggered + ) + self.magicPromptsToolbar.sigInitSelectedModel.connect( + self.magicPromptsInitModel + ) + self.magicPromptsToolbar.sigViewModelParams.connect( + self.viewSetMagicPromptModelParams + ) + self.magicPromptsToolbar.sigClearPoints.connect( + partial(self.magicPromptsClearPoints, only_zoom=False) + ) + self.magicPromptsToolbar.sigClearPointsOnZmom.connect( + partial(self.magicPromptsClearPoints, only_zoom=True) + ) + self.magicPromptsToolbar.sigInterpolateZslice.connect( + self.magicPromptsInterpolateZsliceToggled + ) + + self.addToolBar(Qt.TopToolBarArea, self.magicPromptsToolbar) + self.magicPromptsToolbar.setVisible(False) + self.magicPromptsToolbar.keepVisibleWhenActive = True + self.controlToolBars.append(self.magicPromptsToolbar) + + self.promptSegmentPointsLayerToolbar = ( + widgets.PromptableModelPointsLayerToolbar(parent=self) + ) + self.promptSegmentPointsLayerToolbar.setContextMenuPolicy(Qt.PreventContextMenu) + + self.addToolBar(Qt.TopToolBarArea, self.promptSegmentPointsLayerToolbar) + self.promptSegmentPointsLayerToolbar.setVisible(False) + + self.pointsLayersToolbars.append(self.promptSegmentPointsLayerToolbar) + + # Second level toolbar + secondLevelToolbar = widgets.ToolBar("Second level toolbar", self) + self.addToolBar(Qt.TopToolBarArea, secondLevelToolbar) + self.delObjToolAction = QAction(self) + self.delObjToolAction.setIcon(QIcon(":del_obj_click.svg")) + self.delObjToolAction.setCheckable(True) + self.delObjToolAction.setToolTip( + "Customisable delete object action\n\n" + "Go to the `Settings --> Customise keyboard shortcuts...` menu " + "on the top menubar\n" + "to customise the action required to delete " + "an object with a click.\n\n" + 'When working with 3D segmentations, to delete only the z-slice mask, hold "Shift" while clicking.' + ) + secondLevelToolbar.addAction(self.delObjToolAction) + secondLevelToolbar.setMovable(False) + self.secondLevelToolbar = secondLevelToolbar + self.secondLevelToolbar.setVisible(False) + + def gui_createMainLayout(self): + mainLayout = QGridLayout() + row, col = 0, 1 # Leave column 1 for the overlay labels gradient editor + mainLayout.addLayout(self.leftSideDocksLayout, row, col, 2, 1) + + row = 0 + col = 2 + mainLayout.addWidget(self.graphLayout, row, col, 1, 2) + mainLayout.setRowStretch(row, 2) + + col = 4 # graphLayout spans two columns + mainLayout.addWidget(self.labelsGrad, row, col) + + col = 5 + mainLayout.addLayout(self.rightSideDocksLayout, row, col, 2, 1) + + col = 2 + row += 1 + self.resizeBottomLayoutLine = widgets.VerticalResizeHline() + mainLayout.addWidget(self.resizeBottomLayoutLine, row, col, 1, 2) + self.resizeBottomLayoutLine.dragged.connect(self.resizeBottomLayoutLineDragged) + self.resizeBottomLayoutLine.clicked.connect(self.resizeBottomLayoutLineClicked) + self.resizeBottomLayoutLine.released.connect( + self.resizeBottomLayoutLineReleased + ) + + # row += 1 + # mainLayout.addItem(QSpacerItem(5,5), row+1, col, 1, 2) + + # row, col = 1, 2 + # mainLayout.addLayout( + # self.bottomLayout, row, col, 1, 2, alignment=Qt.AlignLeft + # ) + + row += 1 + mainLayout.addWidget(self.bottomScrollArea, row, col, 1, 2) + mainLayout.setRowStretch(row, 0) + + # row, col = 2, 1 + # mainLayout.addWidget(self.terminal, row, col, 1, 4) + # self.terminal.hide() + + return mainLayout + + def gui_createRegionPropsDockWidget(self, side=Qt.LeftDockWidgetArea): + self.propsDockWidget = QDockWidget("Cell-ACDC objects", self) + self.guiTabControl = widgets.guiTabControl(self.propsDockWidget) + + # self.guiTabControl.setFont(_font) + + self.propsDockWidget.setWidget(self.guiTabControl) + self.propsDockWidget.setFeatures( + QDockWidget.DockWidgetFeature.DockWidgetFloatable + | QDockWidget.DockWidgetFeature.DockWidgetMovable + ) + self.propsDockWidget.setAllowedAreas( + Qt.LeftDockWidgetArea | Qt.RightDockWidgetArea + ) + + self.addDockWidget(side, self.propsDockWidget) + self.propsDockWidget.hide() + + def gui_createStatusBar(self): + self.statusbar = self.statusBar() + # Permanent widget + self.wcLabel = QLabel("") + self.statusbar.addPermanentWidget(self.wcLabel) + + # self.toggleTerminalButton = widgets.ToggleTerminalButton() + # self.statusbar.addWidget(self.toggleTerminalButton) + # self.toggleTerminalButton.sigClicked.connect( + # self.gui_terminalButtonClicked + # ) + + self.statusBarLabel = QLabel("") + self.statusbar.addWidget(self.statusBarLabel) + + def gui_createTerminalWidget(self): + self.terminal = widgets.QLog(logger=self.logger) + self.terminal.connect() + self.terminalDock = QDockWidget("Log", self) + + self.terminalDock.setWidget(self.terminal) + self.terminalDock.setFeatures( + QDockWidget.DockWidgetFeature.DockWidgetFloatable + | QDockWidget.DockWidgetFeature.DockWidgetMovable + ) + self.terminalDock.setAllowedAreas(Qt.BottomDockWidgetArea) + self.addDockWidget(Qt.BottomDockWidgetArea, self.terminalDock) + # self.terminalDock.widget().layout().setContentsMargins(10,0,10,0) + self.terminalDock.setVisible(False) + + def gui_populateToolSettingsMenu(self): + brushHoverModeActionGroup = QActionGroup(self) + brushHoverModeActionGroup.setExclusive(True) + self.brushHoverCenterModeAction = QAction() + self.brushHoverCenterModeAction.setCheckable(True) + self.brushHoverCenterModeAction.setText( + "Use center of the brush/eraser cursor to determine hover ID" + ) + self.brushHoverCircleModeAction = QAction() + self.brushHoverCircleModeAction.setCheckable(True) + self.brushHoverCircleModeAction.setText( + "Use the entire circle of the brush/eraser cursor to determine hover ID" + ) + brushHoverModeActionGroup.addAction(self.brushHoverCenterModeAction) + brushHoverModeActionGroup.addAction(self.brushHoverCircleModeAction) + brushHoverModeMenu = self.settingsMenu.addMenu( + "Brush/eraser cursor hovering mode" + ) + brushHoverModeMenu.addAction(self.brushHoverCenterModeAction) + brushHoverModeMenu.addAction(self.brushHoverCircleModeAction) + + if "useCenterBrushCursorHoverID" not in self.df_settings.index: + self.df_settings.at["useCenterBrushCursorHoverID", "value"] = "Yes" + + useCenterBrushCursorHoverID = ( + self.df_settings.at["useCenterBrushCursorHoverID", "value"] == "Yes" + ) + self.brushHoverCenterModeAction.setChecked(useCenterBrushCursorHoverID) + self.brushHoverCircleModeAction.setChecked(not useCenterBrushCursorHoverID) + + self.brushHoverCenterModeAction.toggled.connect( + self.useCenterBrushCursorHoverIDtoggled + ) + + self.settingsMenu.addSeparator() + + keepToolActiveNames = {"Segment range of frames": self.labelRoiTrangeCheckbox} + for button in self.checkableQButtonsGroup.buttons(): + if button.toolTip() == "": + toolName = "MISSING" + continue + else: + toolName = re.findall(r"Name: (.*)", button.toolTip())[0] + keepToolActiveNames[toolName] = button + + keepToolActiveNames = dict(natsorted(keepToolActiveNames.items())) + + applyToNewFrameNames = { + "Segmenting for lost IDs": self.segForLostIDsButton, + "Delete bordering objects": self.delBorderObjAction.button, + "Delete newly segmented objects": self.delNewObjAction.button, + } + + allToolsList = list(keepToolActiveNames.keys()) + list( + applyToNewFrameNames.keys() + ) + allToolsList = natsorted(allToolsList) + + menus = {} + + for toolName in allToolsList: + menuItemText = f"{toolName} tool".replace(" ", " ") + menus[toolName] = self.settingsMenu.addMenu(menuItemText) + + self.keepToolActiveActions = dict() + self.applyToolNewFrameActions = dict() + self.applyToolNewFrameButtons = dict() + all_checked = True + + for toolName, button in keepToolActiveNames.items(): + menu = menus[toolName] + action = QAction(button) + action.setText("Keep tool active after using it") + action.setCheckable(True) + if toolName in self.df_settings.index: + action.setChecked(True) + else: + all_checked = False + action.toggled.connect(self.keepToolActiveActionToggled) + menu.addAction(action) + self.keepToolActiveActions[toolName] = action + + for toolName, button in applyToNewFrameNames.items(): + menu = menus[toolName] + action = QAction(button) + action.setText("Apply when visitng new frame") + action.setCheckable(True) + action.toggled.connect(self.applyToolNewFrameActionToggled) + menu.addAction(action) + self.applyToolNewFrameActions[toolName] = action + self.applyToolNewFrameButtons[toolName] = button + + for toolName in self.applyToolNewFrameActions.keys(): + settingString = toolName.strip() + settingString = toolName.replace(" ", "_") + settingString = f"{settingString}_applyNewFrame" + if settingString in self.df_settings.index: + val = self.df_settings.at[settingString, "value"] + if val == "applyNewFrame": + self.applyToolNewFrameActions[toolName].setChecked(True) + + self.settingsMenu.addSeparator() + + self.keepAllToolsActiveToggle = QAction() + self.keepAllToolsActiveToggle.setText("Keep all tools active after using them") + self.keepAllToolsActiveToggle.setCheckable(True) + self.keepAllToolsActiveToggle.setChecked(all_checked) + self.keepAllToolsActiveToggle.toggled.connect( + self.keepAllToolsActiveActionToggled + ) + self.settingsMenu.addAction(self.keepAllToolsActiveToggle) + self.settingsMenu.addSeparator() + + askHowFutureFramesMenu = self.settingsMenu.addMenu( + "Ask how to propagate changes to future frames" + ) + self.askHowFutureFramesActions = {} + askHowFutureFramesActionsKeys = ( + "Delete ID", + "Exclude cell from analysis", + "Annotate cell as dead", + "Edit ID", + "Keep ID", + ) + for key in askHowFutureFramesActionsKeys: + askHowFutureFramesAction = QAction() + askHowFutureFramesAction.setText(f'Ask for "{key}" action') + askHowFutureFramesAction.setCheckable(True) + askHowFutureFramesAction.setChecked(True) + askHowFutureFramesAction.setDisabled(True) + askHowFutureFramesMenu.addAction(askHowFutureFramesAction) + self.askHowFutureFramesActions[key] = askHowFutureFramesAction + + warningsMenu = self.settingsMenu.addMenu("Warnings and pop-ups") + self.warnLostCellsAction = QAction() + self.warnLostCellsAction.setText("Show pop-up warning for lost cells") + self.warnLostCellsAction.setCheckable(True) + self.warnLostCellsAction.setChecked(True) + warningsMenu.addAction(self.warnLostCellsAction) + + warnEditingWithAnnotTexts = { + "Delete ID": "Show warning when deleting ID that has annotations", + "Separate IDs": "Show warning when separating IDs that have annotations", + "Edit ID": "Show warning when editing ID that has annotations", + "Annotate ID as dead": "Show warning when annotating dead ID that has annotations", + "Delete ID with eraser": "Show warning when erasing ID that has annotations", + "Add new ID with brush tool": "Show warning when adding new ID (brush) that has annotations", + "Merge IDs": "Show warning when merging IDs that have annotations", + "Add new ID with curvature tool": "Show warning when adding new ID (curv. tool) that has annotations", + "Add new ID with magic-wand": "Show warning when adding new ID (magic-wand) that has annotations", + "Delete IDs using ROI": "Show warning when using ROIs to delete IDs that have annotations", + } + self.warnEditingWithAnnotActions = {} + for key, desc in warnEditingWithAnnotTexts.items(): + action = QAction() + action.setText(desc) + action.setCheckable(True) + action.setChecked(True) + action.removeAnnot = False + self.warnEditingWithAnnotActions[key] = action + warningsMenu.addAction(action) + + def gui_terminalButtonClicked(self, terminalVisible): + self.terminalDock.setVisible(terminalVisible) + + def retainSpaceSlidersToggled(self, checked): + if checked: + self.df_settings.at["retain_space_hidden_sliders", "value"] = "Yes" + else: + self.df_settings.at["retain_space_hidden_sliders", "value"] = "No" + self.df_settings.to_csv(self.settings_csv_path) + if not self.zSliceScrollBar.isEnabled(): + retainSpaceZ = False + else: + retainSpaceZ = checked + utils.setRetainSizePolicy(self.zSliceScrollBar, retain=retainSpaceZ) + utils.setRetainSizePolicy(self.zProjComboBox, retain=retainSpaceZ) + utils.setRetainSizePolicy(self.zSliceOverlay_SB, retain=retainSpaceZ) + utils.setRetainSizePolicy(self.zProjOverlay_CB, retain=retainSpaceZ) + utils.setRetainSizePolicy(self.overlay_z_label, retain=retainSpaceZ) + + QTimer.singleShot(200, self.resizeGui) + + def useCenterBrushCursorHoverIDtoggled(self, checked): + if checked: + self.df_settings.at["useCenterBrushCursorHoverID", "value"] = "Yes" + else: + self.df_settings.at["useCenterBrushCursorHoverID", "value"] = "No" + self.df_settings.to_csv(self.settings_csv_path) + + def zoomBottomLayoutActionTriggered(self, checked): + if not checked: + return + perc = int(re.findall(r"(\d+)%", self.sender().text())[0]) + if perc != 100: + fontSizeFactor = perc / 100 + heightFactor = perc / 100 + self.resizeSlidersArea( + fontSizeFactor=fontSizeFactor, heightFactor=heightFactor + ) + else: + self.gui_resetBottomLayoutHeight() + self.df_settings.at["bottom_sliders_zoom_perc", "value"] = perc + self.df_settings.to_csv(self.settings_csv_path) + QTimer.singleShot(150, self.resizeGui) + + def gui_createShowPropsButton(self, side="left"): + self.leftSideDocksLayout = QVBoxLayout() + self.leftSideDocksLayout.setSpacing(0) + self.leftSideDocksLayout.setContentsMargins(0, 0, 0, 0) + self.rightSideDocksLayout = QVBoxLayout() + self.rightSideDocksLayout.setSpacing(0) + self.rightSideDocksLayout.setContentsMargins(0, 0, 0, 0) + self.showPropsDockButton = widgets.expandCollapseButton() + self.showPropsDockButton.setDisabled(True) + self.showPropsDockButton.setFocusPolicy(Qt.NoFocus) + self.showPropsDockButton.setToolTip("Show object properties") + if side == "left": + self.leftSideDocksLayout.addWidget(self.showPropsDockButton) + else: + self.rightSideDocksLayout.addWidget(self.showPropsDockButton) diff --git a/cellacdc/mixins/lineage_interactions.py b/cellacdc/mixins/lineage_interactions.py new file mode 100644 index 000000000..d5c9c19be --- /dev/null +++ b/cellacdc/mixins/lineage_interactions.py @@ -0,0 +1,721 @@ +"""Qt view adapter for lineage-tree interaction workflows.""" + +from __future__ import annotations + +from collections.abc import Callable, Iterable, Sequence +import numpy as np +import pandas as pd +from qtpy.QtCore import Qt + +from cellacdc import ( + disableWindow, + exception_handler, + html_utils, + lineage_tree_cols, + printl, + widgets, +) +from cellacdc.trackers.CellACDC_normal_division.CellACDC_normal_division_tracker import ( + normal_division_lineage_tree, +) + +from .annotation_display import AnnotationDisplay +from .tracking import Tracking + + +class LineageInteractions(AnnotationDisplay, Tracking): + """Extracted from guiWin.""" + + def annotate_unknown_lineage_action(self, posData, event, ydata, xdata): + """ + This function is part of the lin_tree edit functionality. + Associated with the right-click action of the 'unknownLineageButton' button. + Annotates an unknown lineage by setting its parent ID to -1 in the lineage tree (self.lineage_tree.lineage_list) + + Parameters + ---------- + posData : cellacdc.load.loadData + The position data. + event : QtGui.QMouseEvent + The event that triggered the annotation. + ydata : int + The y-coordinate data. + xdata : int + The x-coordinate data. + """ + point, ID = self.repeat_click_and_backup(posData, event, ydata, xdata) + + if point is None: + return + posData = self.data[self.pos_i] + acdc_df_frame = posData.allData_li[posData.frame_i]["acdc_df"] + acdc_df_frame.at[ID, "parent_ID_tree"] = -1 + self.drawAllLineageTreeLines() + + def askLineageTreeChanges(self): + """ + Asks the user for changes in the lineage tree. + + This method is called when the user selects the 'Normal division: Lineage tree' mode. + It compared the backed up df (self.original_df from repeat_click_and_backup) with the current df (self.lineage_tree.export_df(posData.frame_i)) and propts the user to keep, propagate or discard the changes. + + """ + mode = str(self.modeComboBox.currentText()) + if mode != "Normal division: Lineage tree": + return + + if not self.lineage_tree: + return + + posData = self.data[self.pos_i] + + if ( + self.original_df_lin_tree_i is not None + and self.original_df_lin_tree_i != posData.frame_i + ): + printl("!This should not happen!") + self.store_data(autosave=False) + og_frame = posData.frame_i + posData.frame_i = self.original_df_lin_tree_i + self.get_data() + self.logger.info( + "Lineage tree changes were not propagated, going back to original frame." + ) + self.askLineageTreeChanges() + self.store_data(autosave=False) + posData.frame_i = og_frame + self.get_data() + return + + result = self.get_difference_table( + return_css_separated=True, return_differece=True + ) + if result is None: + self.original_df_lin_tree = None + self.original_df_lin_tree_i = None + return + + css, txt, differences = result + changed_IDs = differences["Cell_ID"].unique() + + if posData.frame_i == max(self.lineage_tree.frames_for_dfs): + # here we can just propagate the cahnged. This is super fast, since there is no recursion, no children and fast finding of parents + self.lineage_tree.propagate(posData.frame_i, relevant_cells=changed_IDs) + self.original_df_lin_tree = None + self.original_df_lin_tree_i = None + return + + txt = txt + "Do you want to keep, propgagte or discard the changes?" + txt = css + html_utils.paragraph("Changes made in this frame
    " + txt) + + msg = widgets.myMessageBox() + + propagate_btn, discard_btn, _ = msg.question( + self, + "Changes in lineage tree", + txt, + buttonsTexts=("Propagate", "Discard", "Cancel"), + ) + + if msg.clickedButton == propagate_btn: + self.lineage_tree.propagate(posData.frame_i, relevant_cells=changed_IDs) + self.original_df_lin_tree = None + self.original_df_lin_tree_i = None + self.logger.info("Lineage tree propagated.") + + elif msg.clickedButton == discard_btn: + posData.allData_li[posData.frame_i]["acdc_df"] = ( + self.original_df_lin_tree.copy() + ) + self.original_df_lin_tree = None + self.original_df_lin_tree_i = None + self.logger.info("Lineage tree changes discarded.") + + elif msg.cancel: + # Go back to current frame + msg = widgets.myMessageBox() + txt = html_utils.paragraph(""" + Changes were kept but not propagated! + Please make sure to come back and propagate them, + otherwise your table might be inconsistent! + There is a button for this next to the edit buttons. + Please also do not visit new frames! + + """) + msg.warning(self, "Changes kept but not propagated!", txt) + self.original_df_lin_tree = None + self.original_df_lin_tree_i = None + self.logger.info("Lineage tree changes discarded.") + + def autoLinTree_df(self, enforceAll=False): + """Automatically generates a lineage tree dataframe. + + This method generates a lineage tree dataframe based on the current mode and data. + It checks if the mode is set to 'Normal division: Lineage tree' and if the current frame + is not already processed. If the conditions are met, it retrieves the necessary data + from the current position data and previous position data, and passes it to the + `real_time` method of the `lineage_tree` object. Finally, it converts the lineage tree + to an ACDC dataframe and adds the current frame to the set of frames that have been + processed. + + Parameters + ---------- + enforceAll : bool, optional + If True, enforces processing of all frames, even if they have been processed before. + If False, only processes frames that have not been processed before. Default is False. + + Returns + ------- + bool + True if there are not enough G1 cells for lineage tree generation, False otherwise. + bool + True if the lineage tree generation should proceed, False otherwise. + """ + proceed = True + notEnoughG1Cells = False + mode = str(self.modeComboBox.currentText()) + + # Skip if not the right mode + if mode != "Normal division: Lineage tree": + return notEnoughG1Cells, proceed + + posData = self.data[self.pos_i] + frame_i = posData.frame_i + + if frame_i in self.lineage_tree.frames_for_dfs: + return notEnoughG1Cells, proceed + + # Make sure that this is a visited frame in segmentation tracking mode + if posData.allData_li[frame_i]["labels"] is None: # may need to change this + proceed = self.warnFrameNeverVisitedSegmMode() + return notEnoughG1Cells, proceed + + self.store_data(autosave=False) + self.get_data() + lab = posData.lab + prev_lab = posData.allData_li[frame_i - 1]["labels"] + rp = posData.rp + prev_rp = posData.allData_li[frame_i - 1]["regionprops"] + + self.lineage_tree.real_time(frame_i, lab, prev_lab, rp=rp, prev_rp=prev_rp) + self.store_data() + + def find_mother_action(self, posData, event, ydata, xdata): + """ + This function is part of the lin_tree edit functionality. + Associated with the right-click action of the 'findNextMotherButton' button. + Handles the right click action, which cycles through possible mothers of the clicked cell. + Changes the parent ID of the clicked cell to the next possible mother in self.lineage_tree.lineage_list. + + Parameters + ---------- + posData : cellacdc.load.loadData + The position data object. + event : QtGui.QMouseEvent + The event object. + ydata : int + The y-coordinate data. + xdata : int + The x-coordinate data. + """ + point, ID = self.repeat_click_and_backup(posData, event, ydata, xdata) + + if point is None: + return + posData = self.data[self.pos_i] + acdc_df_frame = posData.allData_li[posData.frame_i]["acdc_df"] + filtered_IDs = self.getDistanceListMissingIDs(point, ID) + if len(filtered_IDs) == 0: + self.logger.info("No mother candidates found.") + return + + i = self.right_click_i % len(filtered_IDs) + i = abs(i) # Ensure i is non-negative + new_mother = filtered_IDs[i] + + if ( + acdc_df_frame.loc[ID]["parent_ID_tree"] == new_mother + and self.original_mother_skipped == False + ): # if a mother is already present, skip it + self.right_click_i += 1 + self.original_mother_skipped = True + + i = self.right_click_i % len(filtered_IDs) + i = abs(i) # Ensure i is non-negative + new_mother = filtered_IDs[i] + + acdc_df_frame.at[ID, "parent_ID_tree"] = ( + new_mother # update mother in the df, no need to propagate or stuff lile this + ) + # dont need to update alldata_li as acdc_df_frame is just a view + self.drawAllLineageTreeLines() + + def getDistanceListMissingIDs(self, point, ID): + posData = self.data[self.pos_i] + frame_i = posData.frame_i + if self.getDistanceListMissingIDsCachedFrame != frame_i: + self.distanceListMissingIDs = dict() + self.getDistanceListMissingIDsCachedFrame = frame_i + # self.store_data(autosave=False) + # self.get_data() + + if ID not in self.distanceListMissingIDs.keys(): + prev_rp = posData.allData_li[frame_i - 1]["regionprops"] + relevant_rp = [obj for obj in prev_rp if obj.label not in posData.IDs] + len_relevant_rp = len(relevant_rp) + if len_relevant_rp == 0: + self.logger.info("No missing IDs found in previous frame.") + return [] + elif len_relevant_rp == 1: + self.distanceListMissingIDs[ID] = [relevant_rp[0].label] + return [relevant_rp[0].label] + else: + sorted_missing_IDs = utils.sort_IDs_dist(relevant_rp, point=point) + self.distanceListMissingIDs[ID] = sorted_missing_IDs + return sorted_missing_IDs + else: + return self.distanceListMissingIDs[ID] + + def get_difference_table(self, return_css_separated=False, return_differece=False): + + if self.original_df_lin_tree is None: + return + + posData = self.data[self.pos_i] + + new_df = posData.allData_li[posData.frame_i]["acdc_df"] + original_df = self.original_df_lin_tree.copy() + + if original_df.equals(new_df): + return + + compare_columns = ["parent_ID_tree"] + + new_df = new_df[original_df.columns] + new_df = utils.checked_reset_index_Cell_ID(new_df) + new_df = new_df[compare_columns] + new_df = new_df.sort_index() + original_df = utils.checked_reset_index_Cell_ID(original_df) + original_df = original_df[compare_columns] + original_df = original_df.sort_index() + + differences = original_df.compare(new_df) + if differences.empty: + return + + differences = utils.checked_reset_index_Cell_ID(differences) + + differences = differences["parent_ID_tree"] + differences = differences.reset_index() + + txt = """
    + + + + + """ + + for diff in differences.itertuples(): + ID = str(int(diff.Cell_ID)) + old_parent = str(int(diff.self)) + new_parent = str(int(diff.other)) + + txt += f""" + + + + """ + txt += "
    IDold parent -->new parent
    {ID}{old_parent}{new_parent}
    " + + css = r""" + + """ + if return_css_separated and not return_differece: + return css, txt + elif return_css_separated and return_differece: + return css, txt, differences + elif not return_css_separated and return_differece: + return txt, differences + else: + txt = css + html_utils.paragraph(txt) + return txt + + def initLinTree(self, force=False): + """ + Initializes the lineage tree analysis. + + This method checks if the tracking has been previously checked and saved. If not, it displays a message to the user. + It also prompts the user to go to the last annotated frame and restart the lineage tree analysis if necessary. + Finally, it initializes the necessary data structures and updates the GUI. + + Returns + ------- + proceed : bool + True if the initialization is successful, nothing otherwise. + """ + + if not force and self.lineage_tree is not None: + return + + mode = str(self.modeComboBox.currentText()) + if mode != "Normal division: Lineage tree" and not force: + return + + posData = self.data[self.pos_i] + last_tracked_i = self.get_last_tracked_i() + defaultMode = "Viewer" + if last_tracked_i == 0: + # Display message to the user + txt = html_utils.paragraph( + "On this dataset either you never checked that the segmentation " + "and tracking are correct or you did not save yet.

    " + 'If you already visited some frames with "Segmentation and Tracking" ' + 'mode save data before switching to "Normal division: Lineage Tree".

    ' + "Otherwise you first have to check (and eventually correct) some frames " + 'in "Segmentation and Tracking" mode before proceeding ' + "with lineage tree analysis." + ) + msg = widgets.myMessageBox() + msg.critical(self, "Tracking was never checked", txt) + self.modeComboBox.setCurrentText(defaultMode) + return + + proceed = True + last_lin_tree_frame_i = 0 + # Determine last annotated frame index + for i, dict_frame_i in enumerate(posData.allData_li): + df = dict_frame_i["acdc_df"] + if ( + df is None + or "generation_num_tree" not in df.columns + or df["generation_num_tree"].isin([np.nan, 0]).all() + ): + break + else: + last_lin_tree_frame_i = i + + if last_lin_tree_frame_i == 0: + # Remove undoable actions from segmentation mode + posData.UndoRedoStates[0] = [] + self.undoAction.setEnabled(False) + self.redoAction.setEnabled(False) + + if posData.frame_i > last_lin_tree_frame_i: + # Prompt user to go to last annotated frame + msg = widgets.myMessageBox() + txt = html_utils.paragraph(f""" + The last annotated frame is frame {last_lin_tree_frame_i + 1}.

    + Do you want to restart lineage tree analysis from frame + {last_lin_tree_frame_i + 1}?
    + """) + _, yesButton, stayButton = msg.warning( + self, + "Go to last annotated frame?", + txt, + buttonsTexts=( + "Cancel", + f"Yes, go to frame {last_lin_tree_frame_i + 1}", + "No, stay on current frame", + ), + ) + if yesButton == msg.clickedButton: + msg = "Looking good!" + self.last_lin_tree_frame_i = last_lin_tree_frame_i + posData.frame_i = last_lin_tree_frame_i + self.titleLabel.setText(msg, color=self.titleColor) + self.get_data(lin_tree_init=False) + self.updateAllImages() # i dont think I need to change this + self.updateScrollbars() # i dont think I need to change this + elif stayButton == msg.clickedButton: + self.initMissingFramesLinTree(posData.frame_i) #!!! + last_lin_tree_frame_i = posData.frame_i + msg = "Lineage tree analysis initialised!" + self.titleLabel.setText(msg, color="g") + elif msg.cancel: + msg = "Lineage tree analysis aborted." + self.logger.info(msg) + self.titleLabel.setText(msg, color=self.titleColor) + self.modeComboBox.setCurrentText(defaultMode) + proceed = False + return + + elif posData.frame_i < last_lin_tree_frame_i: + # Prompt user to go to last annotated frame + msg = widgets.myMessageBox() + txt = html_utils.paragraph(f""" + The last annotated frame is frame {last_lin_tree_frame_i + 1}.

    + Do you want to restart lineage tree analysis from frame + {last_lin_tree_frame_i + 1}?
    + """) + goTo_last_annotated_frame_i = msg.question( + self, + "Go to last annotated frame?", + txt, + buttonsTexts=("Yes", "No", "Cancel"), + )[0] + if goTo_last_annotated_frame_i == msg.clickedButton: + msg = "Looking good!" + self.titleLabel.setText(msg, color=self.titleColor) + self.last_lin_tree_frame_i = last_lin_tree_frame_i + posData.frame_i = last_lin_tree_frame_i + self.get_data(lin_tree_init=False) + self.updateAllImages() # i dont think I need to change this + self.updateScrollbars() # i dont think I need to change this + elif msg.cancel: + msg = "Lineage tree analysis aborted." + self.logger.info(msg) + self.titleLabel.setText(msg, color=self.titleColor) + self.modeComboBox.setCurrentText(defaultMode) + proceed = False + return + else: + self.get_data(lin_tree_init=False) + + self.last_lin_tree_frame_i = last_lin_tree_frame_i + + self.navigateScrollBar.setMaximum(last_lin_tree_frame_i + 1) + self.navSpinBox.setMaximum(last_lin_tree_frame_i + 1) + + if self.lineage_tree is None or force: + self.store_data(autosave=False) + self.get_data(lin_tree_init=False) + self.lineage_tree = normal_division_lineage_tree(gui=self) + + msg = "Lineage tree analysis initialized!" + self.logger.info(msg) + self.titleLabel.setText(msg, color=self.titleColor) + + return proceed + + def initMissingFramesLinTree( + self, current_frame_i + ): # done Need to add partially missing previous frames and loading + """ + When not starting from the first frame, automatically creates lineage tree dfs for all "skipped" frames and initializes the tree if not done so before. + + Parameters + ---------- + current_frame_i : int + The index of the current frame. + + Returns + ------- + None + + Notes + ----- + This method initializes the lineage tree annotations of missing past frames. If the lineage tree has not been initialized before, it creates a new lineage tree based on the labels of the first frame. It then iterates over the missing frames and updates the lineage tree with the labels and region properties of each frame. + """ + + self.logger.info( + "Initialising lineage tree annotations of missing past frames..." + ) + + self.store_data(autosave=False) + self.get_data() + + posData = self.data[self.pos_i] + current_frame_i = posData.frame_i + + if not self.lineage_tree: # init lin tree if not done already + self.lineage_tree = normal_division_lineage_tree( + gui=self + ) # here frame_i!=0 + + missing_frames = list(range(current_frame_i + 1)) + present_frames = ( + list(self.lineage_tree.frames_for_dfs) if self.lineage_tree else [] + ) + present_frames = [] if not present_frames else present_frames # deal with None + missing_frames = [ + frame_i for frame_i in missing_frames if frame_i not in present_frames + ] + missing_frames.sort() + + for frame_i in missing_frames: + lab = posData.allData_li[frame_i]["labels"] + prev_lab = posData.allData_li[frame_i - 1]["labels"] + rp = posData.allData_li[frame_i]["regionprops"] + prev_rp = posData.allData_li[frame_i - 1]["regionprops"] + # i might need to change this if I need support for only partially missing frames... Although I probably never have to care about that though + self.lineage_tree.real_time(frame_i, lab, prev_lab, rp=rp, prev_rp=prev_rp) + + posData.frame_i = current_frame_i + self.store_data() + + def propagateLinTreeAction(self, dummy_for_button=None): + """ + Propagates the lineage tree based on the current frame_i. Used in self.propagateLinTreeButton. + """ + posData = self.data[self.pos_i] + self.lineage_tree.propagate(posData.frame_i) + if posData.frame_i == self.original_df_lin_tree_i: + self.original_df_lin_tree = posData.allData_li[posData.frame_i][ + "acdc_df" + ].copy() + + self.logger.info("Lineage tree propagated.") + + def repeat_click_and_backup(self, posData, event, ydata, xdata): + """ + This function is part of the lin_tree edit functionality. + It handles the back up of the original self.lineage_tree.lineage_list + df and the repeated clicking on the same ID to cycle through pssible mothers. + + Parameters + ---------- + posData : cellacdc.load.loadData + The position data. + event : QtGui.QMouseEvent + The event object. + ydata : int + The y-coordinate data. + xdata : int + The x-coordinate data. + + Returns + ------- + tuple + A tuple containing the point(tuple: (x, y) coords) and ID of clicked cell. + """ + if self.original_df_lin_tree is None: + self.original_df_lin_tree = posData.allData_li[posData.frame_i][ + "acdc_df" + ].copy() + self.original_df_lin_tree_i = posData.frame_i + elif self.original_df_lin_tree_i != posData.frame_i: + self.logger.info( + "[WARNING]: !!! Original lineage tree df changed, resetting original_df_lin_tree !!!" + ) + self.original_df_lin_tree = posData.allData_li[posData.frame_i][ + "acdc_df" + ].copy() + self.original_df_lin_tree_i = posData.frame_i + + if not self.right_click_ID: + self.right_click_i = 0 + self.right_click_ID = 0 + + x, y = event.pos().x(), event.pos().y() + point = int(x), int(y) + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + + if ID == 0: + return None, None + + if self.right_click_ID != ID: + self.right_click_i = 0 + self.right_click_ID = ID + self.original_mother_skipped = False + elif event.modifiers() & Qt.ShiftModifier: + self.right_click_i -= 1 + else: + self.right_click_i += 1 + + return point, ID + + def resetLin_tree_future(self): + posData = self.data[self.pos_i] + frame_i = posData.frame_i + + for i in range(frame_i, posData.SizeT): + if self.lineage_tree is not None: + self.lineage_tree.frames_for_dfs.discard(frame_i) + df = posData.allData_li[i]["acdc_df"] + # reste lineage tree columns + if df is None: + continue + df = df.drop(columns=lineage_tree_cols, errors="ignore") + posData.allData_li[i]["acdc_df"] = df + + def viewLinTreeInfoAction(self): + mode = str(self.modeComboBox.currentText()) + if mode != "Normal division: Lineage tree": + self.logger.info( + 'This action is only available in the "Normal division: Lineage tree" mode.' + ) + return + + if not self.lineage_tree: + self.logger.info("No lineage tree found.") + return + + posData = self.data[self.pos_i] + + if self.original_df_lin_tree_i != posData.frame_i: + # could be that this is not entirley true and self.curr_original_df_i just didnt get set right though! + txt_changes = "
    No changes were made in this frame.

    " + + else: + result = self.get_difference_table(return_css_separated=True) + + if result is None: + txt_changes = "No changes were made in this frame." + else: + css, txt_changes = result + + txt_changes = "Changes made in this frame:" + txt_changes + "

    " + + cells_with_parent, orphan_cells, lost_cells = ( + self.lineage_tree.export_lin_tree_info(posData.frame_i) + ) + + if orphan_cells == []: + txt_orphan_cells = "No orphan Cells!" + else: + txt_orphan_cells = ", ".join([str(cell) for cell in orphan_cells]) + txt_orphan = f"Orphan cells:
    {txt_orphan_cells}

    " + + lost_cells = list(lost_cells) + if lost_cells == []: + txt_lost_cells = "No lost Cells!" + else: + txt_lost_cells = ", ".join([str(cell) for cell in lost_cells]) + txt_lost = f"Lost cells:
    {txt_lost_cells}

    " + + if cells_with_parent == []: + table_cells_with_parent = "
    No cells with parents!" + else: + table_cells_with_parent = """ + + + + """ + + for cell, parent in cells_with_parent: + table_cells_with_parent += f""" + + + """ + table_cells_with_parent += "
    Parent IDID
    {parent}{cell}
    " + + txt_cells_with_parents = ( + f"Cells with parents:{table_cells_with_parent}

    " + ) + + css = r""" + + """ + + txt = css + html_utils.paragraph( + txt_changes + txt_orphan + txt_lost + txt_cells_with_parents + ) + + msg = widgets.myMessageBox() + msg.information(self, "lineage tree information", txt) diff --git a/cellacdc/mixins/magic_prompts.py b/cellacdc/mixins/magic_prompts.py new file mode 100644 index 000000000..5171184fc --- /dev/null +++ b/cellacdc/mixins/magic_prompts.py @@ -0,0 +1,420 @@ +"""Qt view adapter for promptable segmentation workflows.""" + +from __future__ import annotations + +from functools import partial + +from typing import Mapping +from qtpy.QtCore import QEventLoop, QThread + +from cellacdc import ( + _warnings, + apps, + exception_handler, + html_utils, + prompts, + qutils, + widgets, + workers, +) +from cellacdc import disableWindow + +from .graphics import Graphics + + +class MagicPrompts(Graphics): + """Extracted from guiWin.""" + + def _importInitMagicPromptModel( + self, model_name, posData, win, acdcPromptSegment, toolbar + ): + self.logger.info(f"Initializing promptable model {model_name}...") + init_kwargs = win.init_kwargs + model = utils.init_prompt_segm_model( + acdcPromptSegment, posData, win.init_kwargs + ) + toolbar.model = model + toolbar.model_segment_kwargs = win.model_kwargs + toolbar.model_name = model_name + toolbar.viewModelParamsAction.setDisabled(False) + + self.magicPromptsToolbar.setInitializedModel( + init_kwargs, toolbar.model_segment_kwargs + ) + + self.logger.info(f"Promptable model {model_name} successfully initialised!") + + def getMagicPromptsInputs(self, toolbar): + if not self.promptSegmentPointsLayerToolbar.isPointsLayerInit: + _warnings.warnPromptSegmentPointsLayerNotInit(qparent=self) + return + + if not self.magicPromptsToolbar.viewModelParamsAction.isEnabled(): + _warnings.warnPromptSegmentModelNotInit(qparent=self) + return + + posData = self.data[self.pos_i] + image = self.getDisplayedZstack() + df_points = self.promptSegmentPointsLayerToolbar.pointsLayerDf( + posData, isSegm3D=self.isSegm3D + ) + + self.logger.info( + f"Starting {toolbar.model_name} promptable segmentation with the " + f"following prompts:\n\n{df_points}" + ) + + return image, df_points + + def magicPromptsClearPoints(self, toolbar, only_zoom=False): + posData = self.data[self.pos_i] + scatterItem = self.promptSegmentPointsLayerToolbar.scatterItem() + action = scatterItem.action + + pointsDataPos = action.pointsData.get(self.pos_i) + if pointsDataPos is None: + return + + framePointsData = action.pointsData[self.pos_i].pop(posData.frame_i, None) + if framePointsData is None: + return + + if not only_zoom: + scatterItem.clear() + return + + ((xmin, xmax), (ymin, ymax)) = self.ax1.viewRange() + Y, X = posData.img_data.shape[-2:] + + xmin = int(max(0, xmin)) + xmax = int(min(X, xmax)) + ymin = int(max(0, ymin)) + ymax = int(min(Y, ymax)) + + if "x" in framePointsData: + newFramePointsData = {"x": [], "y": [], "id": []} + xx = framePointsData["x"] + yy = framePointsData["y"] + ids = framePointsData["id"] + for x, y, point_id in zip(xx, yy, ids): + if x < xmin or x >= xmax or y < ymin or y >= ymax: + newFramePointsData["x"].append(x) + newFramePointsData["y"].append(y) + newFramePointsData["id"].append(point_id) + else: + newFramePointsData = {} + for z, zSliceFramePointsData in framePointsData.items(): + newFramePointsData[z] = {"x": [], "y": [], "id": []} + xx = zSliceFramePointsData["x"] + yy = zSliceFramePointsData["y"] + ids = zSliceFramePointsData["id"] + for x, y, point_id in zip(xx, yy, ids): + if x < xmin or x >= xmax or y < ymin or y >= ymax: + newFramePointsData[z]["x"].append(x) + newFramePointsData[z]["y"].append(y) + newFramePointsData[z]["id"].append(point_id) + + action.pointsData[self.pos_i][posData.frame_i] = newFramePointsData + self.drawPointsLayers() + + def magicPromptsComputeOnImageTriggered(self, toolbar): + inputs = self.getMagicPromptsInputs(toolbar) + if inputs is None: + self.logger.info( + '"Computing promptable segmentation on entire image" process cancelled.' + ) + return + + image, df_points = inputs + + self.startMagicPromptsWorkerAndWait( + image, df_points, toolbar.model, toolbar.model_segment_kwargs + ) + + def magicPromptsComputeOnZoomTriggered(self, toolbar): + inputs = self.getMagicPromptsInputs(toolbar) + if inputs is None: + self.logger.info( + '"Computing promptable segmentation on zoom" process cancelled.' + ) + return + + posData = self.data[self.pos_i] + image, df_points = inputs + + ((xmin, xmax), (ymin, ymax)) = self.ax1.viewRange() + Y, X = image.shape[-2:] + + xmin = int(max(0, xmin)) + xmax = int(min(X, xmax)) + ymin = int(max(0, ymin)) + ymax = int(min(Y, ymax)) + + self.logger.info( + f"Zoom range: xmin={xmin}, xmax={xmax}, ymin={ymin}, ymax={ymax}" + ) + + zoom_slice = (slice(ymin, ymax), slice(xmin, xmax)) + + image = image[..., ymin:ymax, xmin:xmax] + image_origin = (0, ymin, xmin) + + df_points = df_points[df_points["y"] >= ymin] + df_points = df_points[df_points["x"] >= xmin] + df_points = df_points[df_points["y"] < ymax] + df_points = df_points[df_points["x"] < xmax] + + df_points["y"] -= ymin + df_points["x"] -= xmin + + df_points = df_points[df_points["frame_i"] == posData.frame_i] + + self.logger.info(f"Image origin = {image_origin}\nImage shape = {image.shape}") + + self.startMagicPromptsWorkerAndWait( + image, + df_points, + toolbar.model, + toolbar.model_segment_kwargs, + image_origin=image_origin, + zoom_slice=zoom_slice, + ) + + def magicPromptsInitModel( + self, + model_name, + acdcPromptSegment, + init_argspecs, + segment_argspecs, + help_url, + toolbar, + ): + posData = self.data[self.pos_i] + + out = prompts.init_prompt_model_params( + posData, + model_name, + init_argspecs, + segment_argspecs, + help_url=help_url, + qparent=self, + init_last_params=True, + ) + win = out.get("win") + if win.cancel: + self.logger.info( + f"Initialization of {model_name} promptable model cancelled." + ) + return + + self._importInitMagicPromptModel( + model_name, posData, win, acdcPromptSegment, toolbar + ) + + def magicPromptsInterpolateZsliceToggled(self, checked): + # See 'self.promptSegmentPointsLayerToolbar.addPointsZslicesInterpolation' + self.promptSegmentPointsLayerToolbar.doAddPointsZslicesInterpolation = checked + + def magicPromptsWorkerCritical(self, error): + self.magicPromptsWorkerLoop.exit() + self.workerCritical(error) + + def magicPromptsWorkerFinished(self, output, zoom_slice=None): + if self.progressWin is not None: + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + self.magicPromptsWorkerLoop.exit() + + lab_new, lab_union, lab_interesection = output + + posData = self.data[self.pos_i] + + is_zoom = True + if zoom_slice is None: + zoom_slice = (slice(None), slice(None)) + is_zoom = False + + img = posData.img_data[posData.frame_i][..., zoom_slice[0], zoom_slice[1]] + images = [img, img, img, img] + labels_overlays = [ + posData.lab[..., zoom_slice[0], zoom_slice[1]], + lab_new[..., zoom_slice[0], zoom_slice[1]], + lab_union[..., zoom_slice[0], zoom_slice[1]], + lab_interesection[..., zoom_slice[0], zoom_slice[1]], + ] + labels_overlays_lut = self.getLabelsImageLut() + labels_overlays_luts = [ + labels_overlays_lut, + labels_overlays_lut, + labels_overlays_lut, + labels_overlays_lut, + ] + axis_titles = [ + "Original masks", + "New masks", + "Union of original and new masks", + "Intersection of original and new masks", + ] + + from cellacdc.plot import imshow + + promptSegmResultsWindow = imshow( + *images, + labels_overlays=labels_overlays, + labels_overlays_luts=labels_overlays_luts, + axis_titles=axis_titles, + window_title="Promptable segmentation results", + figure_title="Ctrl+Click to select the result to use", + annotate_labels_idxs=[0, 1, 2, 3], + selectable_images=True, + max_ncols=2, + lut="gray", + infer_rgb=False, + ) + if promptSegmResultsWindow.selected_idx is None: + self.logger.info( + "Selection of the promptable model segmentation result cancelled." + ) + return + + if promptSegmResultsWindow.selected_idx == 0: + self.logger.info( + "No selection of a promptable model segmentation result was made" + ) + return + + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + + results = (None, lab_new, lab_union, lab_interesection) + selected_idx = promptSegmResultsWindow.selected_idx + zoom_out_lab = results[selected_idx][..., zoom_slice[0], zoom_slice[1]] + zoom_out_lab_mask = zoom_out_lab > 0 + + lab = posData.allData_li[posData.frame_i]["labels"] + lab[..., zoom_slice[0], zoom_slice[1]][zoom_out_lab_mask] = zoom_out_lab[ + zoom_out_lab_mask + ] + + posData.allData_li[posData.frame_i]["labels"] = lab + self.get_data() + self.store_data(autosave=False) + self.updateAllImages() + + def segmWithPromptableModelActionTriggered(self): + self.blinker = qutils.QControlBlink(self.magicPromptsToolButton, qparent=self) + self.blinker.start() + + def showInstructionsCustomPromptModel(self): + modelFilePath = apps.addCustomPromptModelMessages(QParent=self) + if modelFilePath is None: + self.logger.info("Adding custom promptable model process stopped.") + return + + utils.store_custom_promptable_model_path(modelFilePath) + + msg = widgets.myMessageBox(wrapText=False) + info_txt = html_utils.paragraph(f""" + Done!

    + The custom promptable model has been added to the list of models.

    + Use the Magic prompts button (top toolbar) to use it.

    + Have fun! + """) + msg.information(self, "Custom promptable model added", info_txt) + + def startMagicPromptsWorkerAndWait( + self, + image, + df_points, + model, + model_segment_kwargs, + image_origin=(0, 0, 0), + zoom_slice=None, + ): + desc = "Running promptable segmentation model..." + self.logger.info(desc) + posData = self.data[self.pos_i] + + self.progressWin = apps.QDialogWorkerProgress( + title=desc, parent=self, pbarDesc=desc + ) + self.progressWin.mainPbar.setMaximum(0) + self.progressWin.show(self.app) + + self.magicPromptsThread = QThread() + self.magicPromptsWorker = workers.MagicPromptsWorker( + posData, + image, + df_points, + model, + model_segment_kwargs, + image_origin=image_origin, + global_image=posData.img_data[posData.frame_i], + ) + + self.magicPromptsWorker.moveToThread(self.magicPromptsThread) + + self.magicPromptsWorker.signals.finished.connect(self.magicPromptsThread.quit) + self.magicPromptsWorker.signals.finished.connect( + self.magicPromptsWorker.deleteLater + ) + self.magicPromptsThread.finished.connect(self.magicPromptsThread.deleteLater) + + self.magicPromptsWorker.signals.critical.connect( + self.magicPromptsWorkerCritical + ) + self.magicPromptsWorker.signals.initProgressBar.connect( + self.workerInitProgressbar + ) + self.magicPromptsWorker.signals.progressBar.connect( + self.workerUpdateProgressbar + ) + self.magicPromptsWorker.signals.progress.connect(self.workerProgress) + self.magicPromptsWorker.signals.finished.connect( + partial(self.magicPromptsWorkerFinished, zoom_slice=zoom_slice) + ) + + self.magicPromptsThread.started.connect(self.magicPromptsWorker.run) + self.magicPromptsThread.start() + + self.magicPromptsWorkerLoop = QEventLoop() + self.magicPromptsWorkerLoop.exec_() + + def viewSetMagicPromptModelParams( + self, + model_name, + acdcPromptSegment, + init_argspecs, + segment_argspecs, + help_url, + init_kwargs, + segment_kwargs, + toolbar, + ): + posData = self.data[self.pos_i] + + init_argspecs = utils.setDefaultValueArgSpecsFromKwargs( + init_argspecs, init_kwargs + ) + segment_argspecs = utils.setDefaultValueArgSpecsFromKwargs( + segment_argspecs, segment_kwargs + ) + + out = prompts.init_prompt_model_params( + posData, + model_name, + init_argspecs, + segment_argspecs, + help_url=help_url, + qparent=self, + init_last_params=False, + ) + win = out.get("win") + if win.cancel: + return + + if win.model_kwargs != segment_kwargs or win.init_kwargs != init_kwargs: + self._importInitMagicPromptModel( + model_name, posData, win, acdcPromptSegment, toolbar + ) diff --git a/cellacdc/mixins/main_menu.py b/cellacdc/mixins/main_menu.py new file mode 100644 index 000000000..4ce6c14aa --- /dev/null +++ b/cellacdc/mixins/main_menu.py @@ -0,0 +1,195 @@ +"""View adapter for the main menu.""" + +from __future__ import annotations + +from qtpy.QtWidgets import QAction, QActionGroup, QMenu + + +class MainMenu: + """Extracted from guiWin.""" + + def gui_createMenuBar(self): + menuBar = self.menuBar() + menuBar.setNativeMenuBar(False) + + # File menu + fileMenu = QMenu("&File", self) + self.fileMenu = fileMenu + menuBar.addMenu(fileMenu) + if self.debug: + fileMenu.addAction(self.createEmptyDataAction) + fileMenu.addAction(self.newAction) + fileMenu.addAction(self.newWindowAction) + fileMenu.addSeparator() + fileMenu.addAction(self.openFolderAction) + fileMenu.addAction(self.openFileAction) + # Open Recent submenu + self.openRecentMenu = fileMenu.addMenu("Open Recent") + fileMenu.addSeparator() + fileMenu.addAction(self.manageVersionsAction) + fileMenu.addAction(self.saveAction) + fileMenu.addAction(self.saveAsAction) + fileMenu.addAction(self.quickSaveAction) + fileMenu.addSeparator() + + self.exportMenu = fileMenu.addMenu("Export") + self.exportMenu.addAction(self.exportToVideoAction) + self.exportMenu.addAction(self.exportToImageAction) + fileMenu.addSeparator() + fileMenu.addAction(self.loadFluoAction) + fileMenu.addAction(self.loadPosAction) + # Separator + self.fileMenu.lastSeparator = fileMenu.addSeparator() + fileMenu.addAction(self.exitAction) + + # Edit menu + editMenu = menuBar.addMenu("&Edit") + editMenu.addSeparator() + + editMenu.addAction(self.editShortcutsAction) + editMenu.addAction(self.editTextIDsColorAction) + editMenu.addAction(self.editOverlayColorAction) + editMenu.addAction(self.manuallyEditCcaAction) + editMenu.addAction(self.enableSmartTrackAction) + editMenu.addAction(self.enableAutoZoomToCellsAction) + + # View menu + self.viewMenu = menuBar.addMenu("&View") + self.viewMenu.addSeparator() + self.viewMenu.addAction(self.viewCcaTableAction) + + # Image menu + ImageMenu = menuBar.addMenu("&Image") + ImageMenu.addSeparator() + ImageMenu.addAction(self.imgPropertiesAction) + self.defaultRescaleIntensLutMenu = ImageMenu.addMenu( + "Default method to rescale intensities (LUT)" + ) + + self.defaultRescaleIntensActionGroup = QActionGroup( + self.defaultRescaleIntensLutMenu + ) + howTexts = ( + "Rescale each 2D image", + "Rescale across z-stack", + "Rescale across time frames", + "Do no rescale, display raw image", + ) + try: + self.defaultRescaleIntensHow = self.df_settings.at[ + "default_rescale_intens_how", "value" + ] + except Exception as err: + self.defaultRescaleIntensHow = howTexts[0] + + for howText in howTexts: + action = QAction(howText, self.defaultRescaleIntensLutMenu) + action.setCheckable(True) + if howText == self.defaultRescaleIntensHow: + action.setChecked(True) + + self.defaultRescaleIntensActionGroup.addAction(action) + self.defaultRescaleIntensLutMenu.addAction(action) + + ImageMenu.addAction(self.addScaleBarAction) + ImageMenu.addAction(self.addTimestampAction) + + self.rescaleIntensMenu = ImageMenu.addMenu("Rescale intensities (LUT)") + + ImageMenu.addAction(self.preprocessAction) + ImageMenu.addAction(self.combineChannelsAction) + ImageMenu.addAction(self.saveLabColormapAction) + ImageMenu.addAction(self.shuffleCmapAction) + ImageMenu.addAction(self.greedyShuffleCmapAction) + ImageMenu.addAction(self.zoomToObjsAction) + ImageMenu.addAction(self.zoomOutAction) + + # Segment menu + SegmMenu = menuBar.addMenu("&Segment") + self.segmentMenu = SegmMenu + SegmMenu.addSeparator() + self.segmSingleFrameMenu = SegmMenu.addMenu("Segment displayed frame") + for action in self.segmActions: + self.segmSingleFrameMenu.addAction(action) + + self.segmSingleFrameMenu.addSeparator() + self.segmSingleFrameMenu.addAction(self.addCustomModelFrameAction) + + self.segmVideoMenu = SegmMenu.addMenu("Segment multiple frames") + for action in self.segmActionsVideo: + self.segmVideoMenu.addAction(action) + + self.segmVideoMenu.addSeparator() + self.segmVideoMenu.addAction(self.addCustomModelVideoAction) + + self.segmWithPromptableModelMenu = SegmMenu.addMenu( + "Segment with promptable model" + ) + + self.segmWithPromptableModelMenu.addAction(self.segmWithPromptableModelAction) + + self.segmWithPromptableModelMenu.addSeparator() + self.segmWithPromptableModelMenu.addAction(self.addCustomPromptModelAction) + + SegmMenu.addAction(self.EditSegForLostIDsSetSettings) + SegmMenu.addAction(self.postProcessSegmAction) + SegmMenu.addAction(self.autoSegmAction) + SegmMenu.addAction(self.relabelSequentialAction) + SegmMenu.aboutToShow.connect(self.nonViewerEditMenuOpened) + + # Tracking menu + trackingMenu = menuBar.addMenu("&Tracking") + self.trackingMenu = trackingMenu + trackingMenu.addSeparator() + selectTrackAlgoMenu = trackingMenu.addMenu( + "Select real-time tracking algorithm" + ) + for rtTrackerAction in self.trackingAlgosGroup.actions(): + selectTrackAlgoMenu.addAction(rtTrackerAction) + + trackingMenu.addAction(self.editRtTrackerParamsAction) + trackingMenu.addAction(self.repeatTrackingVideoAction) + + trackingMenu.addAction(self.repeatTrackingMenuAction) + trackingMenu.aboutToShow.connect(self.nonViewerEditMenuOpened) + + if self.mainWin is not None: + trackingMenu.addAction(self.mainWin.applyTrackingFromTableAction) + trackingMenu.addAction(self.mainWin.applyTrackingFromTrackMateXMLAction) + + # Measurements menu + measurementsMenu = menuBar.addMenu("&Measurements") + self.measurementsMenu = measurementsMenu + measurementsMenu.addSeparator() + measurementsMenu.addAction(self.setMeasurementsAction) + measurementsMenu.addAction(self.addCustomMetricAction) + measurementsMenu.addAction(self.addCombineMetricAction) + measurementsMenu.setDisabled(True) + + # Settings menu + self.settingsMenu = QMenu("Settings", self) + menuBar.addMenu(self.settingsMenu) + self.settingsMenu.addAction(self.invertBwAction) + self.settingsMenu.addAction(self.toggleColorSchemeAction) + self.settingsMenu.addSeparator() + self.settingsMenu.addAction(self.pxModeAction) + self.settingsMenu.addAction(self.highLowResAction) + self.settingsMenu.addAction(self.editShortcutsAction) + self.settingsMenu.addAction(self.showMirroredCursorAction) + self.settingsMenu.addSeparator() + self.settingsMenu.addAction(self.editAutoSaveIntervalAction) + self.settingsMenu.addSeparator() + + # Mode menu (actions added when self.modeComboBox is created) + self.modeMenu = menuBar.addMenu("Mode") + self.modeMenu.menuAction().setVisible(False) + + # Help menu + helpMenu = menuBar.addMenu("&Help") + helpMenu.addAction(self.openLogFileAction) + helpMenu.addAction(self.showLogFilesAction) + helpMenu.addAction(self.tipsAction) + helpMenu.addAction(self.UserManualAction) + helpMenu.addSeparator() + helpMenu.addAction(self.aboutAction) + self.helpMenu = helpMenu diff --git a/cellacdc/mixins/main_toolbar.py b/cellacdc/mixins/main_toolbar.py new file mode 100644 index 000000000..12dba9747 --- /dev/null +++ b/cellacdc/mixins/main_toolbar.py @@ -0,0 +1,561 @@ +"""Qt view adapter for the main GUI toolbars.""" + +from __future__ import annotations + +from qtpy.QtCore import Qt +from qtpy.QtGui import QIcon +from qtpy.QtWidgets import QAction, QActionGroup, QButtonGroup, QToolButton + +import pyqtgraph as pg + +from cellacdc import widgets + +from .actions import Actions + + +class MainToolbar(Actions): + """Extracted from guiWin.""" + + def closeToolbars(self): + for toolbar in self.sender().toolbars: + toolbar.setVisible(False) + for action in toolbar.actions(): + try: + action.button.setChecked(False) + except Exception as e: + pass + + def gui_createAnnotateToolbar(self): + # Edit toolbar + self.annotateToolbar = widgets.ToolBar("Custom annotations", self) + self.annotateToolbar.setContextMenuPolicy(Qt.PreventContextMenu) + self.addToolBar(Qt.LeftToolBarArea, self.annotateToolbar) + self.annotateToolbar.addAction(self.loadCustomAnnotationsAction) + self.annotateToolbar.addAction(self.addCustomAnnotationAction) + self.annotateToolbar.addAction(self.viewAllCustomAnnotAction) + self.annotateToolbar.setVisible(False) + + def gui_createToolBars(self): + # File toolbar + fileToolBar = self.addToolBar("File") + # fileToolBar.setIconSize(QSize(toolbarSize, toolbarSize)) + fileToolBar.setMovable(False) + + self.segmNdimIndicatorAction = fileToolBar.addWidget(self.segmNdimIndicator) + self.segmNdimIndicatorAction.setVisible(False) + fileToolBar.addAction(self.newAction) + fileToolBar.addAction(self.openFolderAction) + fileToolBar.addAction(self.openFileAction) + fileToolBar.addAction(self.manageVersionsAction) + fileToolBar.addAction(self.saveAction) + fileToolBar.addAction(self.showInExplorerAction) + # fileToolBar.addAction(self.reloadAction) + fileToolBar.addAction(self.undoAction) + fileToolBar.addAction(self.redoAction) + self.fileToolBar = fileToolBar + self.setEnabledFileToolbar(False) + + self.undoAction.setEnabled(False) + self.redoAction.setEnabled(False) + + # Navigation toolbar + navigateToolBar = widgets.ToolBar("Navigation", self) + navigateToolBar.setContextMenuPolicy(Qt.PreventContextMenu) + # navigateToolBar.setIconSize(QSize(toolbarSize, toolbarSize)) + self.addToolBar(navigateToolBar) + navigateToolBar.addAction(self.findIdAction) + + navigateToolBar.addWidget(self.zoomRectButton) + + self.slideshowButton = QToolButton(self) + self.slideshowButton.setIcon(QIcon(":eye-plus.svg")) + self.slideshowButton.setCheckable(True) + self.slideshowButton.setShortcut("Ctrl+W") + navigateToolBar.addWidget(self.slideshowButton) + + navigateToolBar.addAction(self.autoPilotButton) + + # navigateToolBar.setIconSize(QSize(toolbarSize, toolbarSize)) + navigateToolBar.addAction(self.skipToNewIdAction) + + self.preprocessImageAction = QAction("Preprocess image", self) + self.preprocessImageAction.setIcon(QIcon(":filter_image.svg")) + navigateToolBar.addAction(self.preprocessImageAction) + + self.overlayButton = widgets.rightClickToolButton(parent=self) + self.overlayButton.setIcon(QIcon(":overlay.svg")) + self.overlayButton.setCheckable(True) + + self.overlayButtonAction = navigateToolBar.addWidget(self.overlayButton) + # self.checkableButtons.append(self.overlayButton) + # self.checkableQButtonsGroup.addButton(self.overlayButton) + + self.countObjsButton = QToolButton(self) + self.countObjsButton.setIcon(QIcon(":count_objects.svg")) + self.countObjsButton.setCheckable(True) + self.countObjsButton.setShortcut("Ctrl+Shift+C") + self.countObjsButtonAction = navigateToolBar.addWidget(self.countObjsButton) + + self.togglePointsLayerAction = QAction("Activate points layer", self) + self.togglePointsLayerAction.setCheckable(True) + self.togglePointsLayerAction.setIcon(QIcon(":pointsLayer.svg")) + navigateToolBar.addAction(self.togglePointsLayerAction) + + self.overlayLabelsButton = widgets.rightClickToolButton(parent=self) + self.overlayLabelsButton.setIcon(QIcon(":overlay_labels.svg")) + self.overlayLabelsButton.setCheckable(True) + # self.overlayLabelsButton.setVisible(False) + self.overlayLabelsButtonAction = navigateToolBar.addWidget( + self.overlayLabelsButton + ) + self.overlayLabelsButtonAction.setVisible(False) + + self.rulerButton = QToolButton(self) + self.rulerButton.setIcon(QIcon(":ruler.svg")) + self.rulerButton.setCheckable(True) + navigateToolBar.addWidget(self.rulerButton) + self.checkableButtons.append(self.rulerButton) + self.LeftClickButtons.append(self.rulerButton) + + # fluorescence image color widget + colorsToolBar = widgets.ToolBar("Colors", self) + + self.overlayColorButton = pg.ColorButton(self, color=(230, 230, 230)) + self.overlayColorButton.setDisabled(True) + colorsToolBar.addWidget(self.overlayColorButton) + + self.textIDsColorButton = pg.ColorButton(self) + colorsToolBar.addWidget(self.textIDsColorButton) + + self.addToolBar(colorsToolBar) + colorsToolBar.setVisible(False) + + self.navigateToolBar = navigateToolBar + + # cca toolbar + ccaToolBar = widgets.ToolBar("Cell cycle annotations", self) + self.addToolBar(ccaToolBar) + + # Assign mother to bud button + self.assignBudMothButton = QToolButton(self) + self.assignBudMothButton.setIcon(QIcon(":assign-motherbud.svg")) + self.assignBudMothButton.setCheckable(True) + self.assignBudMothButton.setShortcut("A") + self.assignBudMothButton.setVisible(False) + self.assignBudMothButton.action = ccaToolBar.addWidget(self.assignBudMothButton) + self.checkableButtons.append(self.assignBudMothButton) + self.checkableQButtonsGroup.addButton(self.assignBudMothButton) + self.functionsNotTested3D.append(self.assignBudMothButton) + + # Set is_history_known button + self.setIsHistoryKnownButton = QToolButton(self) + self.setIsHistoryKnownButton.setIcon(QIcon(":history.svg")) + self.setIsHistoryKnownButton.setCheckable(True) + self.setIsHistoryKnownButton.setShortcut("U") + self.setIsHistoryKnownButton.setVisible(False) + self.setIsHistoryKnownButton.action = ccaToolBar.addWidget( + self.setIsHistoryKnownButton + ) + self.checkableButtons.append(self.setIsHistoryKnownButton) + self.checkableQButtonsGroup.addButton(self.setIsHistoryKnownButton) + self.functionsNotTested3D.append(self.setIsHistoryKnownButton) + + ccaToolBar.addAction(self.assignBudMothAutoAction) + ccaToolBar.addAction(self.editCcaToolAction) + ccaToolBar.addAction(self.reInitCcaAction) + ccaToolBar.setVisible(False) + self.ccaToolBar = ccaToolBar + self.functionsNotTested3D.append(self.assignBudMothAutoAction) + self.functionsNotTested3D.append(self.reInitCcaAction) + self.functionsNotTested3D.append(self.editCcaToolAction) + + # Edit toolbar + editToolBar = widgets.ToolBar("Edit", self) + editToolBar.setContextMenuPolicy(Qt.PreventContextMenu) + + self.addToolBar(editToolBar) + + self.manulAnnotToolButtons = set() + + self.brushButton = QToolButton(self) + self.brushButton.setIcon(QIcon(":brush.svg")) + self.brushButton.setCheckable(True) + editToolBar.addWidget(self.brushButton) + self.checkableButtons.append(self.brushButton) + self.LeftClickButtons.append(self.brushButton) + self.brushButton.keyPressShortcut = Qt.Key_B + self.widgetsWithShortcut["Brush"] = self.brushButton + self.manulAnnotToolButtons.add(self.brushButton) + + self.eraserButton = QToolButton(self) + self.eraserButton.setIcon(QIcon(":eraser.svg")) + self.eraserButton.setCheckable(True) + editToolBar.addWidget(self.eraserButton) + self.eraserButton.keyPressShortcut = Qt.Key_X + self.widgetsWithShortcut["Eraser"] = self.eraserButton + self.checkableButtons.append(self.eraserButton) + self.LeftClickButtons.append(self.eraserButton) + self.manulAnnotToolButtons.add(self.eraserButton) + + self.curvToolButton = QToolButton(self) + self.curvToolButton.setIcon(QIcon(":curvature-tool.svg")) + self.curvToolButton.setCheckable(True) + self.curvToolButton.setShortcut("C") + self.curvToolButton.action = editToolBar.addWidget(self.curvToolButton) + self.LeftClickButtons.append(self.curvToolButton) + # self.functionsNotTested3D.append(self.curvToolButton) + self.widgetsWithShortcut["Curvature tool"] = self.curvToolButton + # self.checkableButtons.append(self.curvToolButton) + self.manulAnnotToolButtons.add(self.curvToolButton) + + self.wandToolButton = QToolButton(self) + self.wandToolButton.setIcon(QIcon(":magic_wand.svg")) + self.wandToolButton.setCheckable(True) + self.wandToolButton.setShortcut("Ctrl+D") + self.wandToolButton.action = editToolBar.addWidget(self.wandToolButton) + self.LeftClickButtons.append(self.wandToolButton) + self.checkableButtons.append(self.eraserButton) + self.widgetsWithShortcut["Magic wand"] = self.wandToolButton + + self.magicPromptsToolButton = QToolButton(self) + self.magicPromptsToolButton.setIcon(QIcon(":magic-prompts.svg")) + self.magicPromptsToolButton.setCheckable(True) + self.magicPromptsToolButton.setShortcut("W") + self.magicPromptsToolButton.action = editToolBar.addWidget( + self.magicPromptsToolButton + ) + self.widgetsWithShortcut["Magic prompts"] = self.magicPromptsToolButton + + self.drawClearRegionButton = QToolButton(self) + self.drawClearRegionButton.setCheckable(True) + self.drawClearRegionButton.setIcon(QIcon(":clear_freehand_region.svg")) + self.widgetsWithShortcut["Clear freehand region"] = self.drawClearRegionButton + self.toolsActiveInProj3Dsegm.add(self.drawClearRegionButton) + + self.checkableButtons.append(self.drawClearRegionButton) + self.LeftClickButtons.append(self.drawClearRegionButton) + + self.drawClearRegionAction = editToolBar.addWidget(self.drawClearRegionButton) + + self.widgetsWithShortcut["Annotate mother/daughter pairing"] = ( + self.assignBudMothButton + ) + self.widgetsWithShortcut["Annotate unknown history"] = ( + self.setIsHistoryKnownButton + ) + + self.copyLostObjButton = QToolButton(self) + self.copyLostObjButton.setIcon(QIcon(":copyContour.svg")) + self.copyLostObjButton.setCheckable(True) + self.copyLostObjButton.setShortcut("V") + self.copyLostObjButton.action = editToolBar.addWidget(self.copyLostObjButton) + self.checkableButtons.append(self.copyLostObjButton) + self.checkableQButtonsGroup.addButton(self.copyLostObjButton) + self.widgetsWithShortcut["Copy lost object contour"] = self.copyLostObjButton + self.functionsNotTested3D.append(self.copyLostObjButton) + + self.labelRoiButton = widgets.rightClickToolButton(parent=self) + self.labelRoiButton.setIcon(QIcon(":label_roi.svg")) + self.labelRoiButton.setCheckable(True) + self.labelRoiButton.setShortcut("L") + self.labelRoiButton.action = editToolBar.addWidget(self.labelRoiButton) + self.LeftClickButtons.append(self.labelRoiButton) + self.checkableButtons.append(self.labelRoiButton) + self.checkableQButtonsGroup.addButton(self.labelRoiButton) + self.widgetsWithShortcut["Label ROI"] = self.labelRoiButton + # self.functionsNotTested3D.append(self.labelRoiButton) + + self.manualAnnotPastButton = QToolButton(self) + self.manualAnnotPastButton.setIcon(QIcon(":lock_id_annotate_future.svg")) + self.manualAnnotPastButton.setCheckable(True) + self.manualAnnotPastButton.setShortcut("Y") + self.manualAnnotPastButton.action = editToolBar.addWidget( + self.manualAnnotPastButton + ) + self.checkableButtons.append(self.manualAnnotPastButton) + self.widgetsWithShortcut["Lock ID and annotate single object"] = ( + self.manualAnnotPastButton + ) + self.functionsNotTested3D.append(self.manualAnnotPastButton) + self.manulAnnotToolButtons.add(self.manualAnnotPastButton) + + self.segmentToolAction = QAction("Segment with last used model", self) + self.segmentToolAction.setIcon(QIcon(":segment.svg")) + self.segmentToolAction.setShortcut("R") + self.widgetsWithShortcut["Repeat segmentation"] = self.segmentToolAction + editToolBar.addAction(self.segmentToolAction) + + self.segForLostIDsButton = QToolButton(self) + self.segForLostIDsButton.setIcon(QIcon(":segForLostIDs.svg")) + self.segForLostIDsAction = editToolBar.addWidget(self.segForLostIDsButton) + self.segForLostIDsButton.clicked.connect(self.segForLostIDsButtonClicked) + + # self.SegForLostIDsButton.setShortcut('U') + # self.widgetsWithShortcut['Unknown lineage (lineage tree)'] = self.SegForLostIDsButton + + self.manualBackgroundButton = QToolButton(self) + self.manualBackgroundButton.setIcon(QIcon(":manual_background.svg")) + self.manualBackgroundButton.setCheckable(True) + self.manualBackgroundButton.setShortcut("G") + self.LeftClickButtons.append(self.manualBackgroundButton) + self.checkableButtons.append(self.manualBackgroundButton) + self.checkableQButtonsGroup.addButton(self.manualBackgroundButton) + self.widgetsWithShortcut["Manual background"] = self.manualBackgroundButton + + self.manualBackgroundAction = editToolBar.addWidget(self.manualBackgroundButton) + + self.delObjsOutSegmMaskAction = QAction( + QIcon(":del_objs_out_segm.svg"), + "Select a segmentation file and delete all objects on the background", + self, + ) + self.delObjsOutSegmMaskAction.setShortcut("I") + self.widgetsWithShortcut["Delete all objects outside segm"] = ( + self.delObjsOutSegmMaskAction + ) + editToolBar.addAction(self.delObjsOutSegmMaskAction) + + self.hullContToolButton = QToolButton(self) + self.hullContToolButton.setIcon(QIcon(":hull.svg")) + self.hullContToolButton.setCheckable(True) + self.hullContToolButton.setShortcut("O") + self.hullContToolButton.action = editToolBar.addWidget(self.hullContToolButton) + self.checkableButtons.append(self.hullContToolButton) + self.checkableQButtonsGroup.addButton(self.hullContToolButton) + self.functionsNotTested3D.append(self.hullContToolButton) + self.widgetsWithShortcut["Hull contour"] = self.hullContToolButton + + self.fillHolesToolButton = QToolButton(self) + self.fillHolesToolButton.setIcon(QIcon(":fill_holes.svg")) + self.fillHolesToolButton.setCheckable(True) + self.fillHolesToolButton.setShortcut("F") + self.fillHolesToolButton.action = editToolBar.addWidget( + self.fillHolesToolButton + ) + self.checkableButtons.append(self.fillHolesToolButton) + self.checkableQButtonsGroup.addButton(self.fillHolesToolButton) + self.functionsNotTested3D.append(self.fillHolesToolButton) + self.widgetsWithShortcut["Fill holes"] = self.fillHolesToolButton + + self.moveLabelToolButton = QToolButton(self) + self.moveLabelToolButton.setIcon(QIcon(":moveLabel.svg")) + self.moveLabelToolButton.setCheckable(True) + self.moveLabelToolButton.setShortcut("P") + self.moveLabelToolButton.action = editToolBar.addWidget( + self.moveLabelToolButton + ) + self.checkableButtons.append(self.moveLabelToolButton) + self.checkableQButtonsGroup.addButton(self.moveLabelToolButton) + self.widgetsWithShortcut["Move label"] = self.moveLabelToolButton + + self.expandLabelToolButton = QToolButton(self) + self.expandLabelToolButton.setIcon(QIcon(":expandLabel.svg")) + self.expandLabelToolButton.setCheckable(True) + self.expandLabelToolButton.setShortcut("E") + self.expandLabelToolButton.action = editToolBar.addWidget( + self.expandLabelToolButton + ) + self.expandLabelToolButton.hide() + self.checkableButtons.append(self.expandLabelToolButton) + self.LeftClickButtons.append(self.expandLabelToolButton) + self.checkableQButtonsGroup.addButton(self.expandLabelToolButton) + self.widgetsWithShortcut["Expand/shrink label"] = self.expandLabelToolButton + + self.editIDbutton = QToolButton(self) + self.editIDbutton.setIcon(QIcon(":edit-id.svg")) + self.editIDbutton.setCheckable(True) + self.editIDbutton.setShortcut("N") + editToolBar.addWidget(self.editIDbutton) + self.checkableButtons.append(self.editIDbutton) + self.checkableQButtonsGroup.addButton(self.editIDbutton) + self.widgetsWithShortcut["Edit ID"] = self.editIDbutton + + self.separateBudButton = QToolButton(self) + self.separateBudButton.setIcon(QIcon(":separate-bud.svg")) + self.separateBudButton.setCheckable(True) + self.separateBudButton.setShortcut("S") + self.separateBudButton.action = editToolBar.addWidget(self.separateBudButton) + self.checkableButtons.append(self.separateBudButton) + self.checkableQButtonsGroup.addButton(self.separateBudButton) + # self.functionsNotTested3D.append(self.separateBudButton) + self.widgetsWithShortcut["Separate objects"] = self.separateBudButton + + self.mergeIDsButton = QToolButton(self) + self.mergeIDsButton.setIcon(QIcon(":merge-IDs.svg")) + self.mergeIDsButton.setCheckable(True) + self.mergeIDsButton.setShortcut("M") + self.mergeIDsButton.action = editToolBar.addWidget(self.mergeIDsButton) + self.checkableButtons.append(self.mergeIDsButton) + self.checkableQButtonsGroup.addButton(self.mergeIDsButton) + # self.functionsNotTested3D.append(self.mergeIDsButton) + self.widgetsWithShortcut["Merge objects"] = self.mergeIDsButton + + self.keepIDsButton = QToolButton(self) + self.keepIDsButton.setIcon(QIcon(":keep_objects.svg")) + self.keepIDsButton.setCheckable(True) + self.keepIDsButton.action = editToolBar.addWidget(self.keepIDsButton) + self.keepIDsButton.setShortcut("K") + self.checkableButtons.append(self.keepIDsButton) + self.checkableQButtonsGroup.addButton(self.keepIDsButton) + # self.functionsNotTested3D.append(self.keepIDsButton) + self.widgetsWithShortcut["Select objects to keep"] = self.keepIDsButton + + self.whitelistIDsButton = QToolButton(self) + self.whitelistIDsButton.setIcon(QIcon(":whitelist.svg")) + self.whitelistIDsButton.setCheckable(True) + self.whitelistIDsButton.action = editToolBar.addWidget(self.whitelistIDsButton) + self.whitelistIDsButton.setShortcut("Ctrl+K") + self.checkableButtons.append(self.whitelistIDsButton) + self.checkableQButtonsGroup.addButton(self.whitelistIDsButton) + self.LeftClickButtons.append(self.whitelistIDsButton) + # self.functionsNotTested3D.append(self.whitelistIDsButton) + self.widgetsWithShortcut["Select objects to add to a tracking whitelist"] = ( + self.whitelistIDsButton + ) + + self.binCellButton = QToolButton(self) + self.binCellButton.setIcon(QIcon(":bin.svg")) + self.binCellButton.setCheckable(True) + # self.binCellButton.setShortcut('R') + self.binCellButton.action = editToolBar.addWidget(self.binCellButton) + self.checkableButtons.append(self.binCellButton) + self.checkableQButtonsGroup.addButton(self.binCellButton) + # self.functionsNotTested3D.append(self.binCellButton) + + self.manualTrackingButton = QToolButton(self) + self.manualTrackingButton.setIcon(QIcon(":manual_tracking.svg")) + self.manualTrackingButton.setCheckable(True) + self.manualTrackingButton.setShortcut("T") + self.checkableQButtonsGroup.addButton(self.manualTrackingButton) + self.checkableButtons.append(self.manualTrackingButton) + self.widgetsWithShortcut["Manual tracking"] = self.manualTrackingButton + + self.ripCellButton = QToolButton(self) + self.ripCellButton.setIcon(QIcon(":rip.svg")) + self.ripCellButton.setCheckable(True) + self.ripCellButton.setShortcut("D") + self.ripCellButton.action = editToolBar.addWidget(self.ripCellButton) + self.checkableButtons.append(self.ripCellButton) + self.checkableQButtonsGroup.addButton(self.ripCellButton) + self.functionsNotTested3D.append(self.ripCellButton) + self.widgetsWithShortcut["Annotate cell as dead"] = self.ripCellButton + + editToolBar.addAction(self.addDelRoiAction) + # editToolBar.addAction(self.addDelPolyLineRoiAction) + + self.addDelPolyLineRoiAction = editToolBar.addWidget( + self.addDelPolyLineRoiButton + ) + self.addDelPolyLineRoiAction.roiType = "polyline" + + editToolBar.addAction(self.delBorderObjAction) + self.delBorderObjAction.button = editToolBar.widgetForAction( + self.delBorderObjAction + ) + editToolBar.addAction(self.delNewObjAction) + self.delNewObjAction.button = editToolBar.widgetForAction(self.delNewObjAction) + + self.addDelRoiAction.toolbar = editToolBar + self.functionsNotTested3D.append(self.addDelRoiAction) + + self.addDelPolyLineRoiAction.toolbar = editToolBar + self.functionsNotTested3D.append(self.addDelPolyLineRoiAction) + + self.delBorderObjAction.toolbar = editToolBar + self.functionsNotTested3D.append(self.delBorderObjAction) + + self.delNewObjAction.toolbar = editToolBar + # self.functionsNotTested3D.append(self.delNewObjAction) so id this doesnt work in 3d i dont know anymore + + editToolBar.addAction(self.repeatTrackingAction) + + self.manualTrackingAction = editToolBar.addWidget(self.manualTrackingButton) + + self.functionsNotTested3D.append(self.repeatTrackingAction) + self.functionsNotTested3D.append(self.manualTrackingAction) + + self.reinitLastSegmFrameAction = QAction(self) + self.reinitLastSegmFrameAction.setIcon(QIcon(":reinitLastSegm.svg")) + self.reinitLastSegmFrameAction.setVisible(False) + editToolBar.addAction(self.reinitLastSegmFrameAction) + editToolBar.setVisible(False) + self.reinitLastSegmFrameAction.toolbar = editToolBar + self.functionsNotTested3D.append(self.reinitLastSegmFrameAction) + + self.editLin_TreeBar = widgets.ToolBar("Lin Tree Edit", self) + self.editLin_TreeBar.setContextMenuPolicy(Qt.PreventContextMenu) + + self.addToolBar(self.editLin_TreeBar) + self.editLin_TreeGroup = QButtonGroup() + self.editLin_TreeGroup.setExclusive(True) + + self.findNextMotherButton = QToolButton(self) + self.findNextMotherButton.setIcon(QIcon(":magnGlass.svg")) + self.findNextMotherButton.setCheckable(True) + self.editLin_TreeBar.addWidget(self.findNextMotherButton) + self.editLin_TreeGroup.addButton(self.findNextMotherButton) + self.findNextMotherButton.setShortcut("F") + self.widgetsWithShortcut["Find next potential mother (lineage tree)"] = ( + self.findNextMotherButton + ) + + self.unknownLineageButton = QToolButton(self) + self.unknownLineageButton.setIcon(QIcon(":history.svg")) + self.unknownLineageButton.setCheckable(True) + self.editLin_TreeBar.addWidget(self.unknownLineageButton) + self.editLin_TreeGroup.addButton(self.unknownLineageButton) + self.unknownLineageButton.setShortcut("U") + self.widgetsWithShortcut["Unknown lineage (lineage tree)"] = ( + self.unknownLineageButton + ) + + self.noToolLinTreeButton = QToolButton(self) + self.noToolLinTreeButton.setIcon(QIcon(":arrow_cursor.svg")) + self.noToolLinTreeButton.setCheckable(True) + self.editLin_TreeBar.addWidget(self.noToolLinTreeButton) + self.editLin_TreeGroup.addButton(self.noToolLinTreeButton) + self.noToolLinTreeButton.setShortcut("N") + self.widgetsWithShortcut["No tool (lineage tree)"] = self.noToolLinTreeButton + + self.propagateLinTreeButton = QToolButton(self) + self.propagateLinTreeButton.setIcon(QIcon(":compute.svg")) + self.editLin_TreeBar.addWidget(self.propagateLinTreeButton) + self.propagateLinTreeButton.setShortcut("P") + self.widgetsWithShortcut["Propagate (lineage tree)"] = ( + self.propagateLinTreeButton + ) + self.propagateLinTreeButton.clicked.connect(self.propagateLinTreeAction) + + self.viewLinTreeInfoButton = QToolButton(self) + self.viewLinTreeInfoButton.setIcon(QIcon(":addCustomAnnotation.svg")) + self.editLin_TreeBar.addWidget(self.viewLinTreeInfoButton) + self.viewLinTreeInfoButton.setShortcut("S") + self.widgetsWithShortcut["View Changes (lineage tree)"] = ( + self.viewLinTreeInfoButton + ) + self.viewLinTreeInfoButton.clicked.connect(self.viewLinTreeInfoAction) + + modes_available = [ + "Segmentation and Tracking", + "Cell cycle analysis", + "Viewer", + "Custom annotations", + "Normal division: Lineage tree", + ] + self.modeItems = modes_available + + self.modeActionGroup = QActionGroup(self.modeMenu) + for mode in self.modeItems: + action = QAction(mode) + action.setCheckable(True) + self.modeActionGroup.addAction(action) + self.modeMenu.addAction(action) + if mode == "Viewer": + action.setChecked(True) + + self.editToolBar = editToolBar + self.editToolBar.setVisible(False) + self.navigateToolBar.setVisible(False) + self.editLin_TreeBar.setVisible(False) + + self.gui_createAnnotateToolbar() diff --git a/cellacdc/mixins/measurements.py b/cellacdc/mixins/measurements.py new file mode 100644 index 000000000..2066c6bf3 --- /dev/null +++ b/cellacdc/mixins/measurements.py @@ -0,0 +1,151 @@ +"""View adapter for measurement setup and dialogs.""" + +from __future__ import annotations + +import pandas as pd + +from cellacdc import apps, cli, favourite_func_metrics_csv_path, widgets + + +class Measurements: + """Extracted from guiWin.""" + + def _setMetrics(self, measurementsWin): + self._measurements_kernel.set_metrics_from_set_measurements_dialog( + measurementsWin + ) + for ch in self._measurements_kernel.chNamesToProcess: + if ch not in self.notLoadedChNames: + continue + + success = self.loadFluo_cb(fluo_channels=[ch]) + if not success: + continue + + def addCombineMetric(self): + posData = self.data[self.pos_i] + isZstack = posData.SizeZ > 1 + win = apps.combineMetricsEquationDialog( + self.ch_names, isZstack, self.isSegm3D, parent=self + ) + win.sigOk.connect(self.saveCombineMetricsToPosData) + win.exec_() + win.sigOk.disconnect() + + def addCustomMetric(self, checked=False): + txt = measurements.add_metrics_instructions() + metrics_path = measurements.metrics_path + msg = widgets.myMessageBox() + msg.addShowInFileManagerButton(metrics_path, "Show example...") + title = "Add custom metrics instructions" + msg.information(self, title, txt, buttonsTexts=("Ok",)) + + def initMetricsToSave(self, posData): + self._measurements_kernel._init_metrics_to_save(posData) + + def initMetrics(self): + self.logger.info("Initializing measurements...") + posData = self.data[self.pos_i] + self._measurements_kernel = cli.ComputeMeasurementsKernel( + self.logger, self.log_path, False + ) + self._measurements_kernel.init_args(posData.chNames, posData.getSegmEndname()) + self._measurements_kernel._init_metrics(posData, self.isSegm3D) + + def showSetMeasurements(self, checked=False, qparent=None): + qparent = qparent if qparent is not None else self + if self.measurementsWin is not None: + self.measurementsWin.show() + self.measurementsWin.raise_() + self.measurementsWin.activateWindow() + return + + try: + df_favourite_funcs = pd.read_csv(favourite_func_metrics_csv_path) + favourite_funcs = df_favourite_funcs["favourite_func_name"].to_list() + except Exception as e: + favourite_funcs = None + + posData = self.data[self.pos_i] + allPos_acdc_df_cols = set() + for _posData in self.data: + for frame_i, data_dict in enumerate(_posData.allData_li): + acdc_df = data_dict["acdc_df"] + if acdc_df is None: + continue + + allPos_acdc_df_cols.update(acdc_df.columns) + loadedChNames = posData.setLoadedChannelNames(returnList=True) + posData.fluo_data_dict.pop(self.user_ch_name, None) + if self.user_ch_name not in loadedChNames: + loadedChNames.insert(0, self.user_ch_name) + notLoadedChNames = [c for c in self.ch_names if c not in loadedChNames] + self.notLoadedChNames = notLoadedChNames + self.measurementsWin = apps.SetMeasurementsDialog( + loadedChNames, + notLoadedChNames, + posData.SizeZ > 1, + self.isSegm3D, + favourite_funcs=favourite_funcs, + allPos_acdc_df_cols=list(allPos_acdc_df_cols), + acdc_df_path=posData.images_path, + posData=posData, + addCombineMetricCallback=self.addCombineMetric, + allPosData=self.data, + parent=qparent, + state=self.setMeasWinState, + ) + self.measurementsWin.sigCancel.connect(self.setMeasurementsCancelled) + self.measurementsWin.sigClosed.connect(self.setMeasurements) + self.measurementsWin.show() + + def setMeasurementsCancelled(self): + self.measurementsWin = None + + def setMeasurements(self): + posData = self.data[self.pos_i] + if self.measurementsWin.delExistingCols: + self.logger.info("Removing existing unchecked measurements...") + delCols = self.measurementsWin.existingUncheckedColnames + delRps = self.measurementsWin.existingUncheckedRps + delCols_format = [f" * {colname}" for colname in delCols] + delRps_format = [f" * {colname}" for colname in delRps] + delCols_format.extend(delRps_format) + delCols_format = "\n".join(delCols_format) + self.logger.info(delCols_format) + for _posData in self.data: + for frame_i, data_dict in enumerate(_posData.allData_li): + acdc_df = data_dict["acdc_df"] + if acdc_df is None: + continue + + acdc_df = acdc_df.drop(columns=delCols, errors="ignore") + for col_rp in delRps: + drop_df_rp = acdc_df.filter(regex=rf"{col_rp}.*", axis=1) + drop_cols_rp = drop_df_rp.columns + acdc_df = acdc_df.drop(columns=drop_cols_rp, errors="ignore") + _posData.allData_li[frame_i]["acdc_df"] = acdc_df + self.setMeasWinState = self.measurementsWin.state() + self.logger.info("Setting measurements...") + self._setMetrics(self.measurementsWin) + self.logger.info("Metrics successfully set.") + self.measurementsWin = None + + def saveCombineMetricsToPosData(self, window): + for posData in self.data: + equationsDict, isMixedChannels = window.getEquationsDict() + for newColName, equation in equationsDict.items(): + posData.addEquationCombineMetrics(equation, newColName, isMixedChannels) + posData.saveCombineMetrics() + + if self.measurementsWin is None: + return + + self.measurementsWinState = self.measurementsWin.state() + self.measurementsWin.close() + self.showSetMeasurements() + self.measurementsWin.restoreState(self.measurementsWinState) + + def setMetricsFunc(self): + posData = self.data[self.pos_i] + self._measurements_kernel._set_metrics_func_from_posData(posData) diff --git a/cellacdc/mixins/mode_controls.py b/cellacdc/mixins/mode_controls.py new file mode 100644 index 000000000..113973d49 --- /dev/null +++ b/cellacdc/mixins/mode_controls.py @@ -0,0 +1,440 @@ +"""Qt view adapter for mode and toolbar state controls.""" + +from __future__ import annotations + +from qtpy.QtCore import QTimer + +from cellacdc import disableWindow + +from .tool_activation import ToolActivation + + +class ModeControls(ToolActivation): + """Extracted from guiWin.""" + + def blinkModeComboBox(self): + if self.flag: + self.modeComboBox.setStyleSheet("background-color: orange") + else: + self.modeComboBox.setStyleSheet("background-color: none") + self.flag = not self.flag + + def changeMode(self, text): + self.reconnectUndoRedo() + self.updateModeMenuAction() + self.clearCustomAnnot() + posData = self.data[self.pos_i] + mode = text + prevMode = self.modeComboBox.previousText() + self.annotateToolbar.setVisible(False) + if prevMode != "Viewer": + self.store_data(autosave=True) + + self.copyLostObjButton.setChecked(False) + self.stopCcaIntegrityCheckerWorker() + self.setAutoSaveSegmentationEnabled(False) + self.setAutoSaveAnnotationsEnabled(False) + if prevMode == "Normal division: Lineage tree": + self.askLineageTreeChanges() + self.lineage_tree = None + self.editLin_TreeBar.setVisible(False) + self.uncheckAllButtonsFromButtonGroup(self.editLin_TreeGroup) + + elif prevMode == "Cell cycle analysis": + self.setEnabledCcaToolbar(enabled=False) + + if mode == "Segmentation and Tracking": + self.setAutoSaveSegmentationEnabled(True) + self.setSwitchViewedPlaneDisabled(True) + self.trackingMenu.setDisabled(False) + self.modeToolBar.setVisible(True) + self.lastTrackedFrameLabel.setText("") + self.initSegmTrackMode() + self.setEnabledEditToolbarButton(enabled=True) + self.addExistingDelROIs() + self.isFirstTimeOnNextFrame() + self.setEnabledCcaToolbar(enabled=False) + self.clearComputedContours() + self.realTimeTrackingToggle.setDisabled(False) + self.realTimeTrackingToggle.label.setDisabled(False) + if posData.cca_df is not None: + self.store_cca_df() + self.restorePrevAnnotOptions() + self.whitelistViewOGIDs(False) + elif mode == "Cell cycle analysis": + self.setAutoSaveAnnotationsEnabled(True) + self.setSwitchViewedPlaneDisabled(True) + self.startCcaIntegrityCheckerWorker() + proceed = self.initCca() + if proceed: + self.applyDelROIs() + self.modeToolBar.setVisible(True) + self.realTimeTrackingToggle.setDisabled(True) + self.realTimeTrackingToggle.label.setDisabled(True) + self.computeAllContours() + # RAWR!!!!! + # self.computeAllObjToObjCostPairs() + if proceed: + self.setEnabledEditToolbarButton(enabled=False) + if self.isSnapshot: + self.editToolBar.setVisible(True) + self.setEnabledCcaToolbar(enabled=True) + self.removeAlldelROIsCurrentFrame() + self.setAnnotOptionsCcaMode() + self.clearGhost() + elif mode == "Viewer": + self.autoSaveTimer.stop() + self.setSwitchViewedPlaneDisabled(False) + self.modeToolBar.setVisible(True) + self.realTimeTrackingToggle.setDisabled(True) + self.realTimeTrackingToggle.label.setDisabled(True) + self.setEnabledEditToolbarButton(enabled=False) + self.setEnabledCcaToolbar(enabled=False) + self.removeAlldelROIsCurrentFrame() + self.setStatusBarLabel() + self.navigateScrollBar.setMaximum(posData.SizeT) + self.navSpinBox.setMaximum(posData.SizeT) + self.clearGhost() + self.computeAllContours() + elif mode == "Custom annotations": + self.setAutoSaveAnnotationsEnabled(True) + self.setSwitchViewedPlaneDisabled(True) + self.modeToolBar.setVisible(True) + self.realTimeTrackingToggle.setDisabled(True) + self.realTimeTrackingToggle.label.setDisabled(True) + self.setEnabledEditToolbarButton(enabled=False) + self.setEnabledCcaToolbar(enabled=False) + self.removeAlldelROIsCurrentFrame() + self.annotateToolbar.setVisible(True) + self.clearGhost() + self.doCustomAnnotation(0) + self.computeAllContours() + elif mode == "Snapshot": + self.setAutoSaveAnnotationsEnabled(True) + self.setSwitchViewedPlaneDisabled(False) + self.reconnectUndoRedo() + self.setEnabledSnapshotMode() + self.doCustomAnnotation(0) + self.clearComputedContours() + elif ( + mode == "Normal division: Lineage tree" + ): # Mode activation for lineage tree + # self.startLinTreeIntegrityCheckerWorker() # need to replace (postponed) + proceed = self.initLinTree() + self.setEnabledCcaToolbar(enabled=False) + self.setNavigateScrollBarMaximum() + if proceed: + self.applyDelROIs() + self.modeToolBar.setVisible(True) + self.realTimeTrackingToggle.setDisabled(True) + self.realTimeTrackingToggle.label.setDisabled(True) + if proceed: + self.setAutoSaveAnnotationsEnabled(True) + self.setEnabledEditToolbarButton(enabled=False) + if self.isSnapshot: + self.editToolBar.setVisible(True) + self.removeAlldelROIsCurrentFrame() + self.setAnnotOptionsLin_treeMode() + self.clearGhost() + self.editLin_TreeBar.setVisible(True) + + self.disableNonFunctionalButtons() + + def changeModeFromMenu(self, action): + self.modeComboBox.setCurrentText(action.text()) + + def clearComboBoxFocus(self, mode): + # Remove focus from modeComboBox to avoid the key_up changes its value + self.sender().clearFocus() + try: + self.timer.stop() + self.modeComboBox.setStyleSheet("background-color: none") + except Exception as e: + pass + + def disableEditingViewPlaneNotXY(self): + posData = self.data[self.pos_i] + self.manuallyEditCcaAction.setDisabled(True) + for action in self.segmActions: + action.setDisabled(True) + if posData.SizeT == 1: + self.segmVideoMenu.setDisabled(True) + self.postProcessSegmAction.setDisabled(True) + self.autoSegmAction.setDisabled(True) + self.ccaToolBar.setVisible(False) + self.editToolBar.setVisible(False) + for action in self.ccaToolBar.actions(): + button = self.editToolBar.widgetForAction(action) + if button is not None: + button.setDisabled(True) + action.setVisible(False) + for action in self.editToolBar.actions(): + button = self.editToolBar.widgetForAction(action) + action.setVisible(False) + if button is not None: + button.setDisabled(True) + + def enableSizeSpinbox(self, enabled): + self.brushSizeLabelAction.setVisible(enabled) + self.brushSizeAction.setVisible(enabled) + self.brushAutoFillAction.setVisible(enabled) + self.brushAutoHideAction.setVisible(enabled) + self.brushEraserToolBar.setVisible(enabled) + self.disableNonFunctionalButtons() + + def nonViewerEditMenuOpened(self): + mode = str(self.modeComboBox.currentText()) + if mode == "Viewer": + self.startBlinkingModeCB() + + def reconnectUndoRedo(self): + try: + self.undoAction.triggered.disconnect() + self.redoAction.triggered.disconnect() + except Exception as e: + pass + mode = self.modeComboBox.currentText() + if mode == "Segmentation and Tracking" or mode == "Snapshot": + self.undoAction.triggered.connect(self.undo) + self.redoAction.triggered.connect(self.redo) + elif mode == "Cell cycle analysis": + self.undoAction.triggered.connect(self.UndoCca) + elif mode == "Custom annotations": + self.undoAction.triggered.connect(self.undoCustomAnnotation) + else: + self.undoAction.setDisabled(True) + self.redoAction.setDisabled(True) + + def restorePrevAnnotOptions(self): + if self.prevAnnotOptions is None: + return + self.restoreAnnotOptions_ax1(options=self.prevAnnotOptions) + self.setDrawAnnotComboboxText() + self.prevAnnotOptions = None + + def setEnabledCcaToolbar(self, enabled=False): + self.manuallyEditCcaAction.setDisabled(False) + self.viewCcaTableAction.setDisabled(False) + self.ccaToolBar.setVisible(enabled) + for action in self.ccaToolBar.actions(): + button = self.ccaToolBar.widgetForAction(action) + action.setVisible(enabled) + button.setEnabled(enabled) + + def setEnabledEditToolbarButton(self, enabled=False): + for action in self.segmActions: + action.setEnabled(enabled) + + for action in self.segmActionsVideo: + action.setEnabled(enabled) + + self.relabelSequentialAction.setEnabled(enabled) + self.repeatTrackingMenuAction.setEnabled(enabled) + self.repeatTrackingVideoAction.setEnabled(enabled) + self.postProcessSegmAction.setEnabled(enabled) + self.autoSegmAction.setEnabled(enabled) + self.editToolBar.setVisible(enabled) + mode = self.modeComboBox.currentText() + ccaON = mode == "Cell cycle analysis" + for action in self.editToolBar.actions(): + button = self.editToolBar.widgetForAction(action) + # Keep binCellButton active in cca mode + if button == self.binCellButton and not enabled and ccaON: + action.setVisible(True) + button.setEnabled(True) + else: + action.setVisible(enabled) + button.setEnabled(enabled) + if not enabled: + self.setUncheckedAllButtons() + + def setEnabledFileToolbar(self, enabled): + for action in self.fileToolBar.actions(): + button = self.fileToolBar.widgetForAction(action) + if action == self.openFolderAction or action == self.newAction: + continue + if action == self.manageVersionsAction: + continue + if action == self.openFileAction: + continue + action.setEnabled(enabled) + button.setEnabled(enabled) + + def setEnabledSnapshotMode(self): + posData = self.data[self.pos_i] + self.manuallyEditCcaAction.setDisabled(False) + self.viewCcaTableAction.setDisabled(False) + for action in self.segmActions: + action.setDisabled(False) + + self.segmVideoMenu.setDisabled(True) + self.trackingMenu.setDisabled(True) + self.modeToolBar.setVisible(False) + + self.relabelSequentialAction.setDisabled(False) + self.postProcessSegmAction.setDisabled(False) + self.autoSegmAction.setDisabled(False) + self.ccaToolBar.setVisible(True) + self.editToolBar.setVisible(True) + self.reinitLastSegmFrameAction.setVisible(False) + for action in self.ccaToolBar.actions(): + button = self.ccaToolBar.widgetForAction(action) + if button == self.assignBudMothButton: + button.setDisabled(False) + action.setVisible(True) + elif action == self.reInitCcaAction: + action.setVisible(True) + elif action == self.assignBudMothAutoAction and posData.SizeT == 1: + action.setVisible(True) + for action in self.editToolBar.actions(): + button = self.editToolBar.widgetForAction(action) + action.setVisible(True) + button.setEnabled(True) + self.realTimeTrackingToggle.setDisabled(True) + self.realTimeTrackingToggle.label.setDisabled(True) + self.repeatTrackingAction.setVisible(False) + self.manualTrackingAction.setVisible(False) + button = self.editToolBar.widgetForAction(self.repeatTrackingAction) + button.setDisabled(True) + button = self.editToolBar.widgetForAction(self.manualTrackingAction) + button.setDisabled(True) + self.disableNonFunctionalButtons() + self.reinitLastSegmFrameAction.setVisible(False) + + def setFramesSnapshotMode(self): + self.measurementsMenu.setDisabled(False) + self.setPermanentGreedyCmapPreferences() + if self.isSnapshot: + self.realTimeTrackingToggle.setDisabled(True) + self.realTimeTrackingToggle.label.setDisabled(True) + try: + self.drawIDsContComboBox.currentIndexChanged.disconnect() + except Exception as e: + pass + + self.imgGrad.rescaleAcrossTimeAction.setDisabled(True) + self.repeatTrackingAction.setDisabled(True) + self.manualTrackingAction.setDisabled(True) + self.logger.info('Setting GUI mode to "Snapshots"...') + self.modeComboBox.clear() + self.modeComboBox.addItems(["Snapshot"]) + self.modeComboBox.setDisabled(True) + self.modeMenu.menuAction().setVisible(False) + self.drawIDsContComboBox.clear() + self.drawIDsContComboBox.addItems(self.drawIDsContComboBoxSegmItems) + self.drawIDsContComboBox.setCurrentIndex(1) + self.modeToolBar.setVisible(False) + self.skipToNewIdAction.setVisible(False) + self.skipToNewIdAction.setDisabled(True) + self.modeComboBox.setCurrentText("Snapshot") + self.annotateToolbar.setVisible(True) + self.labelsGrad.showNextFrameAction.setDisabled(True) + self.drawIDsContComboBox.currentIndexChanged.connect( + self.drawIDsContComboBox_cb + ) + self.showTreeInfoCheckbox.hide() + self.rightImageFramesScrollbar.setVisible(False) + self.rightImageFramesScrollbar.setDisabled(True) + if not self.isSegm3D: + self.manualBackgroundAction.setVisible(True) + self.manualBackgroundAction.setDisabled(False) + else: + self.manualBackgroundAction.setVisible(False) + self.manualBackgroundAction.setDisabled(True) + self.manualAnnotPastButton.setDisabled(True) + self.manualAnnotPastButton.action.setDisabled(True) + self.manualAnnotPastButton.setVisible(False) + self.manualAnnotPastButton.action.setVisible(False) + self.copyLostObjButton.setDisabled(True) + self.copyLostObjButton.action.setDisabled(True) + self.copyLostObjButton.setVisible(False) + self.copyLostObjButton.action.setVisible(False) + self.segForLostIDsAction.setVisible(False) + self.segForLostIDsAction.setDisabled(True) + self.delNewObjAction.setVisible(False) + self.delNewObjAction.setDisabled(True) + else: + self.imgGrad.rescaleAcrossTimeAction.setDisabled(False) + self.annotateToolbar.setVisible(False) + self.realTimeTrackingToggle.setDisabled(False) + self.repeatTrackingAction.setDisabled(False) + self.manualTrackingAction.setDisabled(False) + self.modeComboBox.setDisabled(False) + self.modeMenu.menuAction().setVisible(True) + self.skipToNewIdAction.setVisible(True) + self.skipToNewIdAction.setDisabled(False) + try: + self.modeComboBox.activated.disconnect() + self.modeComboBox.sigTextChanged.disconnect() + self.drawIDsContComboBox.currentIndexChanged.disconnect() + except Exception as e: + pass + # traceback.print_exc() + self.modeComboBox.clear() + self.modeComboBox.addItems(self.modeItems) + self.drawIDsContComboBox.clear() + self.drawIDsContComboBox.addItems(self.drawIDsContComboBoxSegmItems) + self.modeComboBox.sigTextChanged.connect(self.changeMode) + self.modeComboBox.activated.connect(self.clearComboBoxFocus) + self.drawIDsContComboBox.currentIndexChanged.connect( + self.drawIDsContComboBox_cb + ) + self.modeComboBox.setCurrentText("Viewer") + self.showTreeInfoCheckbox.show() + self.manualBackgroundAction.setVisible(False) + self.manualBackgroundAction.setDisabled(True) + self.labelsGrad.showNextFrameAction.setDisabled(False) + self.manualAnnotPastButton.setDisabled(False) + self.manualAnnotPastButton.action.setDisabled(False) + self.manualAnnotPastButton.setVisible(True) + self.manualAnnotPastButton.action.setVisible(True) + self.copyLostObjButton.setDisabled(False) + self.copyLostObjButton.action.setDisabled(False) + self.copyLostObjButton.setVisible(True) + self.copyLostObjButton.action.setVisible(True) + self.segForLostIDsAction.setVisible(True) + self.segForLostIDsAction.setDisabled(False) + self.delNewObjAction.setVisible(True) + self.delNewObjAction.setDisabled(False) + + for ch, overlayItems in self.overlayLayersItems.items(): + lutItem = overlayItems[1] + lutItem.rescaleAcrossTimeAction.setDisabled(self.isSnapshot) + + def startBlinkingModeCB(self): + try: + self.timer.stop() + self.stopBlinkTimer.stop() + except Exception as e: + pass + if self.rulerButton.isChecked(): + return + self.timer = QTimer(self) + self.timer.timeout.connect(self.blinkModeComboBox) + self.timer.start(200) + self.stopBlinkTimer = QTimer(self) + self.stopBlinkTimer.timeout.connect(self.stopBlinkingCB) + self.stopBlinkTimer.start(2000) + + def stopBlinkingCB(self): + self.timer.stop() + self.modeComboBox.setStyleSheet("background-color: none") + + def uncheckAllButtonsFromButtonGroup(self, buttonGroup): + for button in buttonGroup.buttons(): + if not button.isCheckable(): + continue + + if not button.isChecked(): + continue + + button.setChecked(False) + + def updateModeMenuAction(self): + self.modeActionGroup.triggered.disconnect() + for action in self.modeActionGroup.actions(): + if action.text() != self.modeComboBox.currentText(): + continue + action.setChecked(True) + break + self.modeActionGroup.triggered.connect(self.changeModeFromMenu) diff --git a/cellacdc/mixins/object_cleanup.py b/cellacdc/mixins/object_cleanup.py new file mode 100644 index 000000000..a906527cb --- /dev/null +++ b/cellacdc/mixins/object_cleanup.py @@ -0,0 +1,94 @@ +"""View adapter for object cleanup workflows.""" + +from __future__ import annotations + +import numpy as np +from qtpy.QtCore import QThread + +from cellacdc import apps, widgets, workers + +from .cell_cycle import CellCycle + + +class ObjectCleanup(CellCycle): + """Extracted from guiWin.""" + + def delObjsOutSegmMaskActionTriggered(self): + posData = self.data[self.pos_i] + segm_files = load.get_segm_files(posData.images_path) + existingSegmEndnames = load.get_endnames(posData.basename, segm_files) + selectSegmWin = widgets.QDialogListbox( + "Select segmentation file", + "Select segmentation file to use as ROI:\n", + existingSegmEndnames, + multiSelection=False, + parent=self, + ) + selectSegmWin.exec_() + if selectSegmWin.cancel: + self.logger.info("Delete objects process cancelled.") + return + + selectedSegmEndname = selectSegmWin.selectedItemsText[0] + + self.startDelObjsOutSegmMaskWorker(selectedSegmEndname) + + def delObjsOutSegmMaskWorkerFinished(self, result): + posData = self.data[self.pos_i] + worker, cleared_segm_data, delIDs = result + if posData.SizeT == 1: + cleared_segm_data = cleared_segm_data[np.newaxis] + + self.update_cca_df_deletedIDs(posData, delIDs) + + current_frame_i = posData.frame_i + for frame_i, cleared_lab in enumerate(cleared_segm_data): + # Store change + posData.allData_li[frame_i]["labels"] = cleared_lab + # Get the rest of the stored metadata based on the new lab + posData.frame_i = frame_i + self.get_data() + self.store_data(autosave=False) + + # Back to current frame + posData.frame_i = current_frame_i + self.get_data() + + if self.progressWin is not None: + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + self.logger.info("Deleting objects outside of ROIs finished.") + self.titleLabel.setText("Deleting objects outside of ROIs finished.", color="w") + self.updateAllImages() + + def startDelObjsOutSegmMaskWorker(self, selectedSegmEndname): + self.store_data(autosave=False) + posData = self.data[self.pos_i] + segm_data = np.squeeze(self.getStoredSegmData()) + + self.progressWin = apps.QDialogWorkerProgress( + title="Deleting objects outside of ROIs", + parent=self, + pbarDesc="Deleting objects outside of ROIs...", + ) + self.progressWin.show(self.app) + self.progressWin.mainPbar.setMaximum(0) + + self.thread = QThread() + self.worker = workers.DelObjectsOutsideSegmROIWorker( + selectedSegmEndname, segm_data, posData.images_path + ) + self.worker.moveToThread(self.thread) + self.worker.finished.connect(self.thread.quit) + self.worker.finished.connect(self.worker.deleteLater) + self.thread.finished.connect(self.thread.deleteLater) + + self.worker.progress.connect(self.workerProgress) + self.worker.critical.connect(self.workerCritical) + self.worker.finished.connect(self.delObjsOutSegmMaskWorkerFinished) + + self.worker.debug.connect(self.workerDebug) + + self.thread.started.connect(self.worker.run) + self.thread.start() diff --git a/cellacdc/mixins/object_properties.py b/cellacdc/mixins/object_properties.py new file mode 100644 index 000000000..5606f1dee --- /dev/null +++ b/cellacdc/mixins/object_properties.py @@ -0,0 +1,888 @@ +"""Qt view adapter for object-property workflows.""" + +from __future__ import annotations + +import numpy as np +import skimage.measure +from tqdm import tqdm + +from cellacdc import apps, exception_handler, html_utils, widgets + +from .cell_cycle import CellCycle +from .tracking import Tracking + + +class ObjectProperties(CellCycle, Tracking): + """Extracted from guiWin.""" + + def _keepObjects(self, keepIDs=None, lab=None, rp=None): + posData = self.data[self.pos_i] + if lab is None: + lab = posData.lab + + if rp is None: + rp = posData.rp + + if keepIDs is None: + keepIDs = self.keptObjectsIDs + + for obj in rp: + if obj.label in keepIDs: + continue + + lab[obj.slice][obj.image] = 0 + + return lab + + def applyKeepObjects(self): + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + + self._keepObjects() + self.highlightHoverIDsKeptObj(0, 0, hoverID=0) + + posData = self.data[self.pos_i] + + self.update_rp() + # Repeat tracking + self.tracking(enforce=True, assign_unique_new_IDs=False) + + if self.isSnapshot: + self.fixCcaDfAfterEdit("Deleted non-selected objects") + self.updateAllImages() + self.keptObjectsIDs = widgets.KeptObjectIDsList( + self.keptIDsLineEdit, self.keepIDsConfirmAction + ) + return + else: + removeAnnot = self.warnEditingWithCca_df( + "Deleted non-selected objects", get_answer=True + ) + if not removeAnnot: + # We can propagate changes only if the user agrees on + # removing annotations + return + + self.current_frame_i = posData.frame_i + if posData.frame_i > 0: + txt = html_utils.paragraph(""" + Do you want to remove un-kept objects in the past frames too? + """) + msg = widgets.myMessageBox(wrapText=False, showCentered=False) + _, _, applyToPastButton = msg.question( + self, + "Propagate to past frames?", + txt, + buttonsTexts=("Cancel", "No", "Yes, apply to past frames"), + ) + if msg.cancel: + return + if msg.clickedButton == applyToPastButton: + self.store_data() + self.logger.info("Applying keep objects to past frames...") + if not removeAnnot and posData.cca_df is not None: + delIDs = [ + ID for ID in posData.cca_df.index if ID not in posData.IDs + ] + self.update_cca_df_deletedIDs(posData, delIDs) + + for i in tqdm(range(posData.frame_i), ncols=100): + lab = posData.allData_li[i]["labels"] + rp = posData.allData_li[i]["regionprops"] + keepLab = self._keepObjects(lab=lab, rp=rp) + # Store change + posData.allData_li[i]["labels"] = keepLab.copy() + # Get the rest of the stored metadata based on the new lab + posData.frame_i = i + self.get_data() + self.store_data(autosave=False) + + posData.frame_i = self.current_frame_i + self.get_data() + + # Ask to propagate change to all future visited frames + key = "Keep ID" + askAction = self.askHowFutureFramesActions[key] + doNotShow = not askAction.isChecked() + (UndoFutFrames, applyFutFrames, endFrame_i, doNotShowAgain) = ( + self.propagateChange( + self.keptObjectsIDs, + key, + doNotShow, + posData.UndoFutFrames_keepID, + posData.applyFutFrames_keepID, + force=True, + applyTrackingB=True, + ) + ) + + if UndoFutFrames is None: + # Empty keep object list + self.keptObjectsIDs = widgets.KeptObjectIDsList( + self.keptIDsLineEdit, self.keepIDsConfirmAction + ) + return + + posData.doNotShowAgain_keepID = doNotShowAgain + posData.UndoFutFrames_keepID = UndoFutFrames + posData.applyFutFrames_keepID = applyFutFrames + includeUnvisited = posData.includeUnvisitedInfo["Keep ID"] + + if applyFutFrames: + self.store_data() + + self.logger.info("Applying to future frames...") + pbar = tqdm(total=posData.SizeT - posData.frame_i - 1, ncols=100) + segmSizeT = len(posData.segm_data) + if not removeAnnot and posData.cca_df is not None: + delIDs = [ID for ID in posData.cca_df.index if ID not in posData.IDs] + self.update_cca_df_deletedIDs(posData, delIDs) + + for i in range(posData.frame_i + 1, segmSizeT): + lab = posData.allData_li[i]["labels"] + if lab is None and not includeUnvisited: + self.enqAutosave() + pbar.update(posData.SizeT - i) + break + + rp = posData.allData_li[i]["regionprops"] + + if lab is not None: + keepLab = self._keepObjects(lab=lab, rp=rp) + # Store change + posData.allData_li[i]["labels"] = keepLab.copy() + # Get the rest of the stored metadata based on the new lab + posData.frame_i = i + self.get_data() + self.store_data(autosave=False) + elif includeUnvisited: + # Unvisited frame (includeUnvisited = True) + lab = posData.segm_data[i] + rp = skimage.measure.regionprops(lab) + keepLab = self._keepObjects(lab=lab, rp=rp) + posData.segm_data[i] = keepLab + + pbar.update() + pbar.close() + + # Back to current frame + if applyFutFrames: + posData.frame_i = self.current_frame_i + self.get_data() + + self.keptObjectsIDs = widgets.KeptObjectIDsList( + self.keptIDsLineEdit, self.keepIDsConfirmAction + ) + + def clearHighlightedID(self): + self.highlightIDToolbar.setVisible(False) + + try: + self.updateLostContoursImage(ax=0, delROIsIDs=None) + except Exception as err: + pass + + if self.highlightedID == 0: + return + + self.highlightedID = 0 + self.guiTabControl.highlightCheckbox.setChecked(False) + self.guiTabControl.highlightSearchedCheckbox.setChecked(False) + self.setHighlightID(False) + + def clearHighlightedKeepIDs(self): + self.setAllTextAnnotations() + self.highlightedID = 0 + self.searchedIDitemRight.setData([], []) + self.searchedIDitemLeft.setData([], []) + self.highLightIDLayerImg1.clear() + self.highLightIDLayerRightImage.clear() + + def clearHighlightedText(self): + pass + + def countObjects(self): + self.logger.info("Counting objects...") + + posData = self.data[self.pos_i] + if posData.SizeT > 1: + return self.countObjectsTimelapse() + + return self.countObjectsSnapshots() + + def countObjectsCb(self, checked): + if self.countObjsWindow is None: + categoryCountMapper = self.countObjects() + self.countObjsWindow = apps.ObjectCountDialog( + categoryCountMapper=categoryCountMapper, parent=self, data=self.data + ) + self.countObjsWindow.sigShowEvent.connect(self.updateObjectCounts) + self.countObjsWindow.sigUpdateCounts.connect(self.updateObjectCounts) + + if checked: + self.countObjsWindow.show() + else: + self.countObjsWindow.hide() + + def countObjectsSnapshots(self): + posData = self.data[self.pos_i] + if self.countObjsWindow is None: + activeCategories = { + "In current position", + "In all visited positions (current session)", + "In all visited positions (previous sessions)", + "In all loaded positions", + } + if self.isSegm3D: + activeCategories.add("In current z-slice") + else: + activeCategories = self.countObjsWindow.activeCategories() + + numObjectsCurrentPos = len(posData.IDs) + numObjectsAllPos = 0 + numObjectsVisitedPosPrevious = 0 + numObjectsVisitedPosCurrent = 0 + numObjectsCurrentZslice = None + if "In current z-slice" in activeCategories: + numObjectsCurrentZslice = len( + skimage.measure.regionprops(self.currentLab2D) + ) + + for pos_i, _posData in enumerate(self.data): + IDs = _posData.allData_li[0]["IDs"] + if os.path.exists(_posData.acdc_output_csv_path): + numObjectsVisitedPosPrevious += len(IDs) + if IDs: + numObjs = len(IDs) + numObjectsAllPos += len(IDs) + else: + lab = _posData.segm_data[0] + rp = skimage.measure.regionprops(lab) + numObjs = len(rp) + numObjectsAllPos += numObjs + + if _posData.visited: + numObjectsVisitedPosCurrent += numObjs + + allCategoryCountMapper = { + "In current position": numObjectsCurrentPos, + "In all visited positions (current session)": numObjectsVisitedPosCurrent, + "In all visited positions (previous sessions)": numObjectsVisitedPosPrevious, + "In all loaded positions": numObjectsAllPos, + } + if numObjectsCurrentZslice is not None: + allCategoryCountMapper["In current z-slice"] = numObjectsCurrentZslice + + if self.countObjsWindow is None: + return allCategoryCountMapper + + categoryCountMapper = {} + for category in activeCategories: + categoryCountMapper[category] = allCategoryCountMapper[category] + + return categoryCountMapper + + def countObjectsTimelapse(self): + if self.countObjsWindow is None: + activeCategories = { + "In current frame", + "In all visited frames", + "In entire video", + "Unique objects in all visited frames", + "Unique objects in entire video", + } + else: + activeCategories = self.countObjsWindow.activeCategories() + + posData = self.data[self.pos_i] + allCategoryCountMapper = posData.countObjectsInSegmTimelapse(activeCategories) + if self.countObjsWindow is None: + return allCategoryCountMapper + + categoryCountMapper = {} + for category in activeCategories: + categoryCountMapper[category] = allCategoryCountMapper[category] + + return categoryCountMapper + + def getHighlightedID(self): + if self.highlightedID > 0: + return self.highlightedID + + doHighlight = self.propsDockWidget.isVisible() and ( + self.guiTabControl.highlightCheckbox.isChecked() + or self.guiTabControl.highlightSearchedCheckbox.isChecked() + ) + if not doHighlight: + return 0 + + return self.guiTabControl.propsQGBox.idSB.value() + + def get_curr_lab( + self, curr_lab: np.ndarray | None = None, frame_i: int | None = None + ): + """Get the current labels for the position data. Hirarchically checks: + 1. If `curr_lab` is provided, use it. + 2. If `posData.lab` is not None, use it. + 3. If `posData.allData_li[frame_i]['labels']` exists, use it. + 4. If `posData.segm_data[frame_i]` exists, use it. + + If frame_i is None, uses the current frame index from `posData`. + + Parameters + ---------- + curr_lab : np.ndarray, optional + Current labels for the position data if it should be checked + if its not None first, by default None + frame_i : int, optional + Frame index to use for retrieving labels, by default None + + Returns + ------- + np.ndarray + Current labels for the position data + """ + posData = self.data[self.pos_i] + if frame_i is None: + frame_i = posData.frame_i + + if curr_lab is None and frame_i == posData.frame_i: + curr_lab = posData.lab + + if curr_lab is None: + try: + curr_lab = posData.allData_li[frame_i]["labels"].copy() + except: + pass + + if curr_lab is None: + try: + curr_lab = posData.segm_data[frame_i].copy() + except: + pass + + return curr_lab + + def grayOutHighlightedLabels(self, nonGrayedIDs=None, alpha=None): + if nonGrayedIDs is None: + nonGrayedIDs = set() + + posData = self.data[self.pos_i] + if alpha is None: + alpha = self.imgGrad.labelsAlphaSlider.value() + + if not hasattr(self, "highlightedLab"): + self.highlightedLab = np.zeros_like(self.currentLab2D) + else: + self.highlightedLab[:] = 0 + + lut = np.zeros((2, 4), dtype=np.uint8) + for _obj in posData.rp: + if not self.isObjVisible(_obj.bbox): + continue + if _obj.label not in nonGrayedIDs: + continue + _slice = self.getObjSlice(_obj.slice) + _objMask = self.getObjImage(_obj.image, _obj.bbox) + self.highlightedLab[_slice][_objMask] = _obj.label + rgb = self.lut[_obj.label].copy() + lut[1, :-1] = rgb + # Set alpha to 0.7 + lut[1, -1] = 178 + + return lut + + def grayOutOverlaySegm(self, ax=0): + if ax == 0: + how = self.drawIDsContComboBox.currentText() + else: + how = self.getAnnotateHowRightImage() + + isOverlaySegmActive = how.find("segm. masks") != -1 + if not isOverlaySegmActive: + return + + grayedLut = self.grayOutHighlightedLabels() + + def highlightHoverID(self, x, y, hoverID=None): + if hoverID is None: + try: + hoverID = self.currentLab2D[int(y), int(x)] + except IndexError: + return + + if hoverID == 0: + return + + posData = self.data[self.pos_i] + objIdx = posData.IDs_idxs[hoverID] + obj = posData.rp[objIdx] + self.goToZsliceSearchedID(obj) + self.highlightSearchedID(hoverID) + + def highlightHoverIDsKeptObj(self, x, y, hoverID=None): + if hoverID is None: + try: + hoverID = self.currentLab2D[int(y), int(x)] + except IndexError: + return + + self.highlightSearchedID(hoverID, greyOthers=False) + + if hoverID == 0 and self.highlightedID == 0: + return + + if hoverID == 0 and self.highlightedID != 0: + self.clearHighlightedKeepIDs() + for ID in self.keptObjectsIDs: + self.highlightLabelID(ID) + return + + posData = self.data[self.pos_i] + try: + objIdx = posData.IDs_idxs[hoverID] + except KeyError as err: + return + + obj = posData.rp[objIdx] + self.goToZsliceSearchedID(obj) + + for ID in self.keptObjectsIDs: + self.highlightLabelID(ID) + + def highlightIDonHoverCheckBoxToggled(self, checked): + doHighlight = ( + self.guiTabControl.highlightCheckbox.isChecked() + or self.guiTabControl.highlightSearchedCheckbox.isChecked() + ) + if not doHighlight: + self.highlightedID = 0 + self.initLookupTableLab() + else: + self.highlightedID = self.guiTabControl.propsQGBox.idSB.value() + self.highlightSearchedID(self.highlightedID, force=True) + self.updatePropsWidget(self.highlightedID) + self.updateAllImages() + + def highlightLabelID(self, ID, ax=0): + posData = self.data[self.pos_i] + try: + obj = posData.rp[posData.IDs_idxs[ID]] + except KeyError: + return + + self.textAnnot[ax].highlightObject(obj) + + def highlightSearchedID(self, ID, force=False, greyOthers=True): + self.highlightIDToolbar.setIDNoSignals(ID) + + if ID == 0: + self.highlightIDToolbar.setVisible(False) + return + + if ID == self.highlightedID and not force: + return + + doHighlight = self.propsDockWidget.isVisible() and ( + self.guiTabControl.highlightCheckbox.isChecked() + or self.guiTabControl.highlightSearchedCheckbox.isChecked() + ) + if doHighlight: + self.highlightedID = self.guiTabControl.propsQGBox.idSB.value() + ID = self.highlightedID + + if self.highlightedID > 0: + self.clearHighlightedText() + + self.searchedIDitemRight.setData([], []) + self.searchedIDitemLeft.setData([], []) + + posData = self.data[self.pos_i] + + self.highlightedID = ID + self.highlightIDToolbar.setVisible(True) + + objIdx = posData.IDs_idxs.get(ID) + if objIdx is None: + return + + obj = posData.rp[objIdx] + isObjVisible = self.isObjVisible(obj.bbox) + if not isObjVisible: + return + + if greyOthers: + self.textAnnot[0].grayOutAnnotations() + self.textAnnot[1].grayOutAnnotations() + + how_ax1 = self.drawIDsContComboBox.currentText() + how_ax2 = self.getAnnotateHowRightImage() + isOverlaySegm_ax1 = how_ax1.find("segm. masks") != -1 + isOverlaySegm_ax2 = how_ax2.find("segm. masks") != -1 + alpha = self.imgGrad.labelsAlphaSlider.value() + + if isOverlaySegm_ax1 or isOverlaySegm_ax2: + grayedLut = self.grayOutHighlightedLabels( + nonGrayedIDs={obj.label}, alpha=alpha + ) + + cont = None + contours = None + if isOverlaySegm_ax1: + self.highLightIDLayerImg1.setLookupTable(grayedLut) + self.highLightIDLayerImg1.setImage(self.highlightedLab) + self.labelsLayerImg1.setOpacity(alpha / 3) + else: + contours = self.getObjContours(obj, all_external=True) + for cont in contours: + self.searchedIDitemLeft.addPoints(cont[:, 0] + 0.5, cont[:, 1] + 0.5) + + if isOverlaySegm_ax2: + self.highLightIDLayerRightImage.setLookupTable(grayedLut) + self.highLightIDLayerRightImage.setImage(self.highlightedLab) + self.labelsLayerRightImg.setOpacity(alpha / 3) + else: + if contours is None: + contours = self.getObjContours(obj, all_external=True) + for cont in contours: + self.searchedIDitemRight.addPoints(cont[:, 0] + 0.5, cont[:, 1] + 0.5) + + # Gray out all IDs excpet searched one + lut = self.lut.copy() # [:max(posData.IDs)+1] + lut[:ID] = lut[:ID] * 0.2 + lut[ID + 1 :] = lut[ID + 1 :] * 0.2 + self.img2.setLookupTable(lut) + + # Highlight text + self.highlightLabelID(ID, ax=0) + self.highlightLabelID(ID, ax=1) + + def highlightSearchedIDcheckBoxToggled(self, checked): + self.highlightIDonHoverCheckBoxToggled(checked) + if checked: + posData = self.data[self.pos_i] + self.highlightedID = self.getHighlightedID() + if self.highlightedID == 0: + return + objIdx = posData.IDs_idxs[self.highlightedID] + obj_idx = posData.IDs_idxs.get(self.highlightedID) + if obj_idx is None: + return + obj = posData.rp[objIdx] + self.goToZsliceSearchedID(obj) + + def initKeepObjLabelsLayers(self): + lut = np.zeros((len(self.lut), 4), dtype=np.uint8) + lut[:, :-1] = self.lut + lut[:, -1:] = 255 + lut[0] = [0, 0, 0, 0] + self.keepIDsTempLayerLeft.setLevels([0, len(lut)]) + self.keepIDsTempLayerLeft.setLookupTable(lut) + + def initPixelSizePropsDockWidget(self): + posData = self.data[self.pos_i] + PhysicalSizeX = posData.PhysicalSizeX + PhysicalSizeY = posData.PhysicalSizeY + PhysicalSizeZ = posData.PhysicalSizeZ + self.guiTabControl.initPixelSize(PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ) + + def keepIDs_cb(self, checked): + if checked: + self.highlightedLab = np.zeros_like(self.currentLab2D) + if self.annotCcaInfoCheckbox.isChecked(): + self.annotCcaInfoCheckbox.setChecked(False) + self.annotIDsCheckbox.setChecked(True) + self.setDrawAnnotComboboxText() + self.uncheckLeftClickButtons(None) + self.initKeepObjLabelsLayers() + self.setAllIDs() + else: + # restore items to non-grayed out + self.clearTempBrushImage() + alpha = self.imgGrad.labelsAlphaSlider.value() + self.labelsLayerImg1.setOpacity(alpha) + self.labelsLayerRightImg.setOpacity(alpha) + self.ax1_contoursImageItem.setOpacity(1.0) + self.ax2_contoursImageItem.setOpacity(1.0) + self.ax1_lostObjImageItem.setOpacity(1.0) + self.ax2_lostObjImageItem.setOpacity(1.0) + self.ax1_lostTrackedObjImageItem.setOpacity(1.0) + self.ax2_lostTrackedObjImageItem.setOpacity(1.0) + + self.keepIDsToolbar.setVisible(checked) + self.highlightedIDopts = None + self.keptObjectsIDs = widgets.KeptObjectIDsList( + self.keptIDsLineEdit, self.keepIDsConfirmAction + ) + self.updateAllImages() + + def propsWidgetIDvalueChanged(self, ID): + posData = self.data[self.pos_i] + if ID == 0: + self.updatePropsWidget(int(ID)) + return + + propsQGBox = self.guiTabControl.propsQGBox + obj_idx = posData.IDs_idxs.get(ID) + if obj_idx is None: + s = f"Object ID {int(ID):d} does not exist" + propsQGBox.notExistingIDLabel.setText(s) + return + + obj = posData.rp[obj_idx] + self.goToZsliceSearchedID(obj) + self.updatePropsWidget(int(ID)) + + def removeHighlightLabelID(self, IDs=None, ax=0): + posData = self.data[self.pos_i] + if IDs is None: + IDs = posData.IDs + + for ID in IDs: + obj = posData.rp[posData.IDs_idxs[ID]] + self.textAnnot[ax].removeHighlightObject(obj) + + def setAllIDs(self, onlyVisited=False): + for posData in self.data: + posData.allIDs = set() + for frame_i in range(len(posData.segm_data)): + if frame_i >= len(posData.allData_li): + break + lab = posData.allData_li[frame_i]["labels"] + if lab is None and onlyVisited: + break + + if lab is None: + rp = skimage.measure.regionprops(posData.segm_data[frame_i]) + else: + rp = posData.allData_li[frame_i]["regionprops"] + posData.allIDs.update([obj.label for obj in rp]) + + def setHighlighedIDfromToolbar(self, ID: int): + self.findID(ID=ID) + + def setHighlightID(self, doHighlight): + if not doHighlight: + self.highlightedID = 0 + self.initLookupTableLab() + else: + self.highlightedID = self.guiTabControl.propsQGBox.idSB.value() + self.highlightSearchedID(self.highlightedID, force=True) + self.updatePropsWidget(self.highlightedID) + self.updateAllImages() + + def showPropsDockWidget(self, checked=False): + if self.showPropsDockButton.isExpand: + self.propsDockWidget.setVisible(False) + self.setHighlightID(False) + else: + self.highlightedID = self.guiTabControl.propsQGBox.idSB.value() + if self.isSegm3D: + self.guiTabControl.propsQGBox.cellVolVox3D_SB.show() + self.guiTabControl.propsQGBox.cellVolVox3D_SB.label.show() + self.guiTabControl.propsQGBox.cellVolFl3D_DSB.show() + self.guiTabControl.propsQGBox.cellVolFl3D_DSB.label.show() + else: + self.guiTabControl.propsQGBox.cellVolVox3D_SB.hide() + self.guiTabControl.propsQGBox.cellVolVox3D_SB.label.hide() + self.guiTabControl.propsQGBox.cellVolFl3D_DSB.hide() + self.guiTabControl.propsQGBox.cellVolFl3D_DSB.label.hide() + + self.propsDockWidget.setVisible(True) + self.propsDockWidget.setEnabled(True) + self.updateAllImages() + + def updateKeepIDs(self, IDs): + posData = self.data[self.pos_i] + + self.clearHighlightedText() + + isAnyIDnotExisting = False + # Check if IDs from line edit are present in current keptObjectIDs list + for ID in IDs: + if ID not in posData.allIDs: + isAnyIDnotExisting = True + continue + if ID not in self.keptObjectsIDs: + self.keptObjectsIDs.append(ID, editText=False) + self.highlightLabelID(ID) + + # Check if IDs in current keptObjectsIDs are present in IDs from line edit + for ID in self.keptObjectsIDs: + if ID not in posData.allIDs: + isAnyIDnotExisting = True + continue + if ID not in IDs: + self.keptObjectsIDs.remove(ID, editText=False) + + self.updateTempLayerKeepIDs() + if isAnyIDnotExisting: + self.keptIDsLineEdit.warnNotExistingID() + else: + self.keptIDsLineEdit.setInstructionsText() + + def updateObjectCounts(self): + if self.countObjsWindow is None: + return + + if not self.countObjsWindow.isVisible(): + return + + if not self.countObjsWindow.livePreviewCheckbox.isChecked(): + return + + categoryCountMapper = self.countObjects() + self.countObjsWindow.updateCounts(categoryCountMapper) + + def updatePropsWidget(self, ID, fromHover=False): + if isinstance(ID, str): + # Function called by currentTextChanged of channelCombobox or + # additionalMeasCombobox. We set self.currentPropsID = 0 to force update + ID = self.guiTabControl.propsQGBox.idSB.value() + self.currentPropsID = -1 + + ID = int(ID) + + update = ( + self.propsDockWidget.isVisible() and ID != 0 and ID != self.currentPropsID + ) + if not update: + return + + posData = self.data[self.pos_i] + if not hasattr(posData, "rp"): + return + + if posData.rp is None: + self.update_rp() + + if not posData.IDs: + # empty segmentation mask + return + + if fromHover and not self.guiTabControl.highlightCheckbox.isChecked(): + # Do not highlight on hover + return + + propsQGBox = self.guiTabControl.propsQGBox + + obj_idx = posData.IDs_idxs.get(ID) + if obj_idx is None: + s = f"Object ID {int(ID):d} does not exist" + propsQGBox.notExistingIDLabel.setText(s) + return + + propsQGBox.notExistingIDLabel.setText("") + self.currentPropsID = ID + propsQGBox.idSB.setValue(ID) + + doHighlight = ( + self.guiTabControl.highlightCheckbox.isChecked() + or self.guiTabControl.highlightSearchedCheckbox.isChecked() + ) + if doHighlight: + self.highlightSearchedID(ID) + + obj = posData.rp[obj_idx] + + if self.isSegm3D: + if self.zProjComboBox.currentText() == "single z-slice": + local_z = self.z_lab() - obj.bbox[0] + area_pxl = np.count_nonzero(obj.image[local_z]) + else: + area_pxl = np.count_nonzero(obj.image.max(axis=0)) + else: + area_pxl = obj.area + + propsQGBox.cellAreaPxlSB.setValue(area_pxl) + + pixelSizeQGBox = self.guiTabControl.pixelSizeQGBox + PhysicalSizeX = pixelSizeQGBox.pixelWidthWidget.value() + PhysicalSizeY = pixelSizeQGBox.pixelHeightWidget.value() + PhysicalSizeZ = pixelSizeQGBox.voxelDepthWidget.value() + + yx_pxl_to_um2 = PhysicalSizeY * PhysicalSizeX + + area_um2 = area_pxl * yx_pxl_to_um2 + + propsQGBox.cellAreaUm2DSB.setValue(area_um2) + + if self.isSegm3D: + PhysicalSizeZ = posData.PhysicalSizeZ + vol_vox_3D = obj.area + vol_fl_3D = vol_vox_3D * PhysicalSizeZ * PhysicalSizeY * PhysicalSizeX + propsQGBox.cellVolVox3D_SB.setValue(vol_vox_3D) + propsQGBox.cellVolFl3D_DSB.setValue(vol_fl_3D) + + vol_vox, vol_fl = _calc_rot_vol(obj, PhysicalSizeY, PhysicalSizeX) + propsQGBox.cellVolVoxSB.setValue(int(vol_vox)) + propsQGBox.cellVolFlDSB.setValue(vol_fl) + + minor_axis_length = max(1, obj.minor_axis_length) + elongation = obj.major_axis_length / minor_axis_length + propsQGBox.elongationDSB.setValue(elongation) + + solidity = obj.solidity + propsQGBox.solidityDSB.setValue(solidity) + + additionalPropName = propsQGBox.additionalPropsCombobox.currentText() + additionalPropValue = getattr(obj, additionalPropName) + propsQGBox.additionalPropsCombobox.indicator.setValue(additionalPropValue) + + intensMeasurQGBox = self.guiTabControl.intensMeasurQGBox + selectedChannel = intensMeasurQGBox.channelCombobox.currentText() + + try: + _, filename = self.getPathFromChName(selectedChannel, posData) + image = posData.ol_data_dict[filename][posData.frame_i] + except Exception as e: + image = posData.img_data[posData.frame_i] + + if posData.SizeZ > 1 and not self.isSegm3D: + z = self.zSliceScrollBar.sliderPosition() + objData = image[z][obj.slice][obj.image] + img = self.img1.image + else: + objData = image[obj.slice][obj.image] + img = image + + intensMeasurQGBox.minimumDSB.setValue(np.min(objData)) + intensMeasurQGBox.maximumDSB.setValue(np.max(objData)) + intensMeasurQGBox.meanDSB.setValue(np.mean(objData)) + intensMeasurQGBox.medianDSB.setValue(np.median(objData)) + + funcDesc = intensMeasurQGBox.additionalMeasCombobox.currentText() + func = intensMeasurQGBox.additionalMeasCombobox.functions[funcDesc] + if funcDesc == "Concentration": + bkgrVal = np.median(img[posData.lab == 0]) + amount = func(objData, bkgrVal, obj.area) + value = amount / vol_vox + elif funcDesc == "Amount": + bkgrVal = np.median(img[posData.lab == 0]) + amount = func(objData, bkgrVal, obj.area) + value = amount + else: + value = func(objData) + + intensMeasurQGBox.additionalMeasCombobox.indicator.setValue(value) + + def updateTempLayerKeepIDs(self): + if not self.keepIDsButton.isChecked(): + return + + keptLab = np.zeros_like(self.currentLab2D) + + posData = self.data[self.pos_i] + for obj in posData.rp: + if obj.label not in self.keptObjectsIDs: + continue + + if not self.isObjVisible(obj.bbox): + continue + + _slice = self.getObjSlice(obj.slice) + _objMask = self.getObjImage(obj.image, obj.bbox) + + keptLab[_slice][_objMask] = obj.label + + self.keepIDsTempLayerLeft.setImage(keptLab, autoLevels=False) diff --git a/cellacdc/mixins/object_search.py b/cellacdc/mixins/object_search.py new file mode 100644 index 000000000..925a1b6d2 --- /dev/null +++ b/cellacdc/mixins/object_search.py @@ -0,0 +1,253 @@ +"""Qt view adapter for object search and navigation.""" + +from __future__ import annotations + +from collections.abc import Callable +from qtpy.QtCore import QEventLoop, QThread + +from cellacdc import apps, html_utils, widgets, workers + +from .frame_navigation import FrameNavigation + + +class ObjectSearch(FrameNavigation): + """Extracted from guiWin.""" + + def askGoToFrameFoundID(self, searchedID, frame_i_found): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph(f""" + Object ID {searchedID} was found at frame n. {frame_i_found + 1}.

    + Do you want to go to frame n. {frame_i_found + 1}. + """) + noButton, yesButton = msg.information( + self, + f"ID {searchedID} found at frame n. {frame_i_found + 1}", + txt, + buttonsTexts=( + "No, stay on current frame", + f"Yes, go to frame n. {frame_i_found + 1}", + ), + ) + return msg.clickedButton == yesButton + + def findID(self, checked=False, ID=None): + posData = self.data[self.pos_i] + if ID is None: + searchIDdialog = apps.FindIDDialog( + title="Search object by ID", + msg="Enter object ID to find and highlight", + parent=self, + isInteger=True, + ) + searchIDdialog.exec_() + if searchIDdialog.cancel: + return + + searchedID = searchIDdialog.EntryID + else: + searchedID = ID + + if searchedID in posData.IDs: + self.goToObjectID(searchedID) + return + + if posData.SizeT == 1: + self.warnIDnotFound(searchedID) + return + + if searchedID in posData.lost_IDs: + self.goToLostObjectID(searchedID) + return + + tracked_lost_IDs = self.getTrackedLostIDs() + if searchedID in tracked_lost_IDs: + self.goToAcceptedLostObjectID(searchedID) + return + + self.logger.info(f"Searching ID {searchedID} in other frames...") + + frame_i_found = self.startSearchIDworker(searchedID) + if frame_i_found is None: + self.warnIDnotFound(searchedID) + return + + self.logger.info( + f"Object ID {searchedID} found at frame n. {frame_i_found + 1}." + ) + proceed = self.askGoToFrameFoundID(searchedID, frame_i_found) + if not proceed: + return + + posData.frame_i = frame_i_found + self.get_data() + self.updateAllImages() + self.updateScrollbars() + + self.goToObjectID(searchedID) + + def findNextNewIdWorkerFinished(self, next_frame_i): + if self.progressWin is not None: + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + + self.navSpinBox.setValue(next_frame_i + 1) + self.framesScrollBarReleased() + + def goToAcceptedLostObjectID(self, acceptedLostID): + posData = self.data[self.pos_i] + frame_i = posData.frame_i + prev_rp = posData.allData_li[frame_i - 1]["regionprops"] + prev_IDs_idxs = posData.allData_li[frame_i - 1]["IDs_idxs"] + obj = prev_rp[prev_IDs_idxs[acceptedLostID]] + self.goToZsliceSearchedID(obj) + + self.updateLostTrackedContoursImage(tracked_lost_IDs=[acceptedLostID]) + + def goToLostObjectID(self, lostID, color=(255, 165, 0, 255)): + posData = self.data[self.pos_i] + frame_i = posData.frame_i + prev_rp = posData.allData_li[frame_i - 1]["regionprops"] + prev_IDs_idxs = posData.allData_li[frame_i - 1]["IDs_idxs"] + obj = prev_rp[prev_IDs_idxs[lostID]] + self.goToZsliceSearchedID(obj) + + imageItem = self.getLostObjImageItem(0) + thickness = 1 + if not hasattr(self, "lostObjContoursImage"): + self.initLostObjContoursImage() + else: + self.lostObjContoursImage[:] = 0 + + contours = [] + obj_contours = self.getObjContours(obj, all_external=True) + contours.extend(obj_contours) + + self.addLostObjsToLostObjImage(obj, lostID) + self.drawLostObjContoursImage(imageItem, contours, thickness=2, color=color) + + def goToObjectID(self, ID): + posData = self.data[self.pos_i] + objIdx = posData.IDs_idxs[ID] + obj = posData.rp[objIdx] + self.goToZsliceSearchedID(obj) + + self.highlightSearchedID(ID) + propsQGBox = self.guiTabControl.propsQGBox + propsQGBox.idSB.setValue(ID) + + def searchIDworkerCallback(self, posData, searchedID): + self.searchIDworker.signals.initProgressBar.emit(0) + self.setAllIDs() + self.searchIDworker.signals.initProgressBar.emit(posData.SizeT) + frame_i_found = None + for frame_i in range(len(posData.segm_data)): + if frame_i >= len(posData.allData_li): + break + lab = posData.allData_li[frame_i]["labels"] + if lab is None: + rp = skimage.measure.regionprops(posData.segm_data[frame_i]) + IDs = set([obj.label for obj in rp]) + else: + IDs = posData.allData_li[frame_i]["IDs"] + + if searchedID in IDs: + frame_i_found = frame_i + break + + self.searchIDworker.signals.progressBar.emit(1) + + self.searchIDworker.frame_i_found = frame_i_found + + def searchIDworkerCritical(self, error): + self.searchIDworkerLoop.exit() + self.workerCritical(error) + + def searchIDworkerFinished(self): + if self.progressWin is not None: + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + + self.searchIDworkerLoop.exit() + + def skipForwardToNewID(self): + self.progressWin = apps.QDialogWorkerProgress( + title="Searching the next frame with a new object", + parent=self, + pbarDesc=f"Searching the next frame with a new object...", + ) + self.progressWin.show(self.app) + self.progressWin.mainPbar.setMaximum(0) + + self.startFindNextNewIdWorker() + + def startFindNextNewIdWorker(self): + posData = self.data[self.pos_i] + self._thread = QThread() + self.findNextNewIdWorker = workers.FindNextNewIdWorker(posData, self) + self.findNextNewIdWorker.moveToThread(self._thread) + + self.findNextNewIdWorker.signals.finished.connect(self._thread.quit) + self.findNextNewIdWorker.signals.finished.connect( + self.findNextNewIdWorker.deleteLater + ) + self._thread.finished.connect(self._thread.deleteLater) + + self.findNextNewIdWorker.signals.finished.connect( + self.findNextNewIdWorkerFinished + ) + self.findNextNewIdWorker.signals.progress.connect(self.workerProgress) + self.findNextNewIdWorker.signals.initProgressBar.connect( + self.workerInitProgressbar + ) + self.findNextNewIdWorker.signals.progressBar.connect( + self.workerUpdateProgressbar + ) + self.findNextNewIdWorker.signals.critical.connect(self.workerCritical) + + self._thread.started.connect(self.findNextNewIdWorker.run) + self._thread.start() + + def startSearchIDworker(self, searchedID): + posData = self.data[self.pos_i] + + desc = "Searching ID in all frames..." + + self.progressWin = apps.QDialogWorkerProgress( + title=desc, parent=self.mainWin, pbarDesc=desc + ) + self.progressWin.mainPbar.setMaximum(posData.SizeT) + self.progressWin.show(self.app) + + self.searchIDthread = QThread() + self.searchIDworker = workers.SimpleWorker( + posData, self.searchIDworkerCallback, func_args=(searchedID,) + ) + self.searchIDworker.frame_i_found = None + self.searchIDworker.moveToThread(self.searchIDthread) + + self.searchIDworker.signals.finished.connect(self.searchIDthread.quit) + self.searchIDworker.signals.finished.connect(self.searchIDworker.deleteLater) + self.searchIDthread.finished.connect(self.searchIDthread.deleteLater) + + self.searchIDworker.signals.critical.connect(self.searchIDworkerCritical) + self.searchIDworker.signals.initProgressBar.connect(self.workerInitProgressbar) + self.searchIDworker.signals.progressBar.connect(self.workerUpdateProgressbar) + self.searchIDworker.signals.progress.connect(self.workerProgress) + self.searchIDworker.signals.finished.connect(self.searchIDworkerFinished) + + self.searchIDthread.started.connect(self.searchIDworker.run) + self.searchIDthread.start() + + self.searchIDworkerLoop = QEventLoop() + self.searchIDworkerLoop.exec_() + + return self.searchIDworker.frame_i_found + + def warnIDnotFound(self, searchedID): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph(f""" + Object ID {searchedID} was not found.

    + """) + msg.warning(self, f"ID {searchedID} not found", txt) diff --git a/cellacdc/mixins/points_layers.py b/cellacdc/mixins/points_layers.py new file mode 100644 index 000000000..39b0143e6 --- /dev/null +++ b/cellacdc/mixins/points_layers.py @@ -0,0 +1,1251 @@ +"""Qt view adapter for points-layer workflows.""" + +from __future__ import annotations + +import os +from collections import defaultdict +from collections.abc import Mapping +from copy import deepcopy +from datetime import datetime +from functools import partial + +import matplotlib +import numpy as np +import pyqtgraph as pg +import skimage.draw +import skimage.measure +from qtpy.QtCore import QTimer +from qtpy.QtWidgets import QLabel + +from cellacdc import _warnings, apps, colors, exception_handler, html_utils, widgets + +from .brush_tools import BrushTools + + +class PointsLayers(BrushTools): + """Extracted from guiWin.""" + + def addClickedPoint(self, action, x, y, id): + x, y = round(x, 2), round(y, 2) + posData = self.data[self.pos_i] + pointsDataPos = action.pointsData.get(self.pos_i) + if pointsDataPos is None: + action.pointsData[self.pos_i] = {} + + framePointsData = action.pointsData[self.pos_i].get(posData.frame_i) + if action.snapToMax: + radius = round(action.pointSize / 2) + rr, cc = skimage.draw.disk((round(y), round(x)), radius) + idx_max = (self.img1.image[rr, cc]).argmax() + y, x = rr[idx_max], cc[idx_max] + + if framePointsData is None: + if posData.SizeZ > 1: + zSlice = self.zSliceScrollBar.sliderPosition() + action.pointsData[self.pos_i][posData.frame_i] = { + zSlice: {"x": [x], "y": [y], "id": [id]} + } + else: + action.pointsData[self.pos_i][posData.frame_i] = { + "x": [x], + "y": [y], + "id": [id], + } + else: + if posData.SizeZ > 1: + zSlice = self.zSliceScrollBar.sliderPosition() + z_data = framePointsData.get(zSlice) + if z_data is None: + framePointsData[zSlice] = {"x": [x], "y": [y], "id": [id]} + else: + framePointsData[zSlice]["x"].append(x) + framePointsData[zSlice]["y"].append(y) + framePointsData[zSlice]["id"].append(id) + action.pointsData[self.pos_i][posData.frame_i] = framePointsData + else: + pointsDataPos = action.pointsData[self.pos_i] + framePointsData = pointsDataPos[posData.frame_i] + framePointsData["x"].append(x) + framePointsData["y"].append(y) + framePointsData["id"].append(id) + + self.markPointsLayerDirty(action=action) + + def addPointsByClickingButtonToggled(self, checked=True, sender=None): + if sender is None: + sender = self.sender() + if not sender.isChecked(): + action = sender.action + action.scatterItem.setVisible(False) + return + + self.disconnectLeftClickButtons() + self.uncheckLeftClickButtons(sender) + self.connectLeftClickButtons() + action = sender.action + action.scatterItem.setVisible(True) + self.ax1_BrushCircle.setBrush(action.brushColor) + self.ax1_BrushCircle.setPen(action.penColor) + + def addPointsByClickingScatterItemHoverEntered(self, item, points, event): + point = points[0] + point_id = point.data() + toolButton = item.action.button + toolButton.rightClickIDSpinbox.prevId = toolButton.rightClickIDSpinbox.value() + toolButton.rightClickIDSpinbox.setValue(point_id) + + def addPointsLayer(self, toolbar=None): + proceed = self.checkLoadedTableIds(toolbar) + + if self.addPointsWin.cancel or not proceed: + self.addPointsWin = None + self.logger.info("Adding points layer cancelled.") + return + + if toolbar is None: + toolbar = self.pointsLayersToolbar + + symbol = self.addPointsWin.symbol + color = self.addPointsWin.color + pointSize = self.addPointsWin.pointSize + zRadius = int((self.addPointsWin.zHeight - 1) / 2) + r, g, b, a = color.getRgb() + + scatterItem = widgets.PointsScatterPlotItem( + [], + [], + ax=self.ax1, + symbol=symbol, + pxMode=False, + size=pointSize, + brush=pg.mkBrush(color=(r, g, b, 100)), + pen=pg.mkPen(width=2, color=(r, g, b)), + hoverable=True, + hoverBrush=pg.mkBrush((r, g, b, 200)), + tip=None, + show_data_as_tip=True, + ) + self.ax1.addItem(scatterItem) + + toolButton = widgets.PointsLayerToolButton(symbol, color, parent=self) + toolButton.actions = [] + toolButton.setCheckable(True) + toolButton.setChecked(True) + if self.addPointsWin.keySequence is not None: + toolButton.setShortcut(self.addPointsWin.keySequence) + toolButton.toggled.connect(self.pointLayerToolbuttonToggled) + toolButton.sigEditAppearance.connect(self.editPointsLayerAppearance) + toolButton.sigShowIdsToggled.connect(self.showPointsLayerIdsToggled) + toolButton.sigRemove.connect(partial(self.removePointsLayer, toolbar=toolbar)) + + action = toolbar.addWidget(toolButton) + action.state = self.addPointsWin.state() + + toolButton.action = action + action.brushColor = (r, g, b, 100) + action.brushColorId0 = ( + *colors.hex_to_rgb( + colors.lighten_color(np.array(action.brushColor) / 255, 0.3) + ), + 100, + ) + action.penColor = (r, g, b) + action.penColorId0 = colors.lighten_color(np.array(action.penColor) / 255, 0.3) + action.pointSize = pointSize + action.zRadius = zRadius + action.button = toolButton + action.scatterItem = scatterItem + scatterItem.action = action + action.layerType = self.addPointsWin.layerType + action.layerTypeIdx = self.addPointsWin.layerTypeIdx + action.loadedDf = self.addPointsWin.loadedDf + posData = self.data[self.pos_i] + action.pointsData = {} + action.pointsData[self.pos_i] = self.addPointsWin.pointsData + action.snapToMax = False + action.loadedDfInfo = self.addPointsWin.loadedDfInfo + self.setPointsLayerLoadedDfEndanme(action) + + if self.addPointsWin.layerType.startswith("Click to annotate point"): + action.snapToMax = self.addPointsWin.snapToMaxToggle.isChecked() + isLoadedDf = self.addPointsWin.clickEntryIsLoadedDf + self.setupAddPointsByClicking(toolButton, isLoadedDf, toolbar=toolbar) + if self.addPointsWin.autoPilotToggle.isChecked(): + self.autoPilotZoomToObjToggle.setChecked(True) + + weighingChannel = self.addPointsWin.weighingChannel + self.loadPointsLayerWeighingData(action, weighingChannel) + + self.drawPointsLayers() + + if toolbar == self.promptSegmentPointsLayerToolbar: + self.promptSegmentPointsLayerToolbar.isPointsLayerInit = True + self.magicPromptsToolbar.clearPointsAction.setDisabled(False) + self.magicPromptsToolbar.clearPointsActionOnZoom.setDisabled(False) + QTimer.singleShot(200, self.magicPromptsToolbar.selectModelAction.trigger) + + self.addPointsWin = None + + def addPointsLayerTriggered(self, checked=False, toolbar=None): + if toolbar is None: + toolbar = self.pointsLayersToolbar + + if self.addPointsWin is not None: + self.logger.info("Add points layer window is already open. Cannot add now.") + return + + onlyMouseClicks = toolbar == self.promptSegmentPointsLayerToolbar + posData = self.data[self.pos_i] + self.addPointsWin = apps.AddPointsLayerDialog( + channelNames=posData.chNames, + imagesPath=posData.images_path, + hideCentroidsSection=onlyMouseClicks, + hideWeightedCentroidsSection=onlyMouseClicks, + hideFromTableSection=onlyMouseClicks, + hideManualEntrySection=onlyMouseClicks, + hideWithMouseClicksSection=False, + parent=self, + ) + cmap = matplotlib.colormaps["gist_rainbow"] + i = np.random.default_rng(seed=123).uniform() + for action in toolbar.actions()[1:]: + if not hasattr(action, "layerTypeIdx"): + continue + rgb = [round(c * 255) for c in cmap(i)][:3] + self.addPointsWin.appearanceGroupbox.colorButton.setColor(rgb) + break + + self.addPointsWin.sigCriticalReadTable.connect(self.logger.info) + self.addPointsWin.sigLoadedTable.connect(self.logLoadedTablePointsLayer) + self.addPointsWin.sigClosed.connect( + partial(self.addPointsLayer, toolbar=toolbar) + ) + self.addPointsWin.sigCheckClickEntryTableEndnameExists.connect( + self.checkClickEntryTableEndnameExists + ) + self.addPointsWin.show() + if self.addPointsWin.clickEntryRadiobutton.isChecked(): + QTimer.singleShot( + 200, + partial( + self.addPointsWin.sigCheckClickEntryTableEndnameExists.emit, + self.addPointsWin.clickEntryTableEndname.text(), + False, + ), + ) + + def askLoadNewerRecoveryClickEntryDfs(self, tableEndName, newer_recovery_filepaths): + if not newer_recovery_filepaths: + return False + + num_tables = len(newer_recovery_filepaths) + filepath, recovery_filepath = newer_recovery_filepaths[0] + main_timestamp = datetime.fromtimestamp(os.path.getmtime(filepath)).strftime( + "%a %d. %b. %y - %H:%M:%S" + ) + recovery_timestamp = datetime.fromtimestamp( + os.path.getmtime(recovery_filepath) + ).strftime("%a %d. %b. %y - %H:%M:%S") + + if num_tables == 1: + text = html_utils.paragraph( + f"A newer recovery version of {tableEndName}.csv " + "was found.

    " + f"Main table save date: {main_timestamp}
    " + f"Recovery save date: {recovery_timestamp}

    " + "Do you want to load the newer recovery version?" + ) + else: + text = html_utils.paragraph( + f"Newer recovery versions of {tableEndName}.csv " + f"were found for {num_tables} positions.

    " + f"Example main table save date: {main_timestamp}
    " + f"Example recovery save date: {recovery_timestamp}

    " + "Do you want to load the newer recovery version where available?" + ) + + msg = widgets.myMessageBox(wrapText=False) + _, yesButton, _ = msg.warning( + self.addPointsWin, + "Newer recovery table found", + text, + buttonsTexts=("Cancel", "Yes, load newer recovery", "No, load main table"), + ) + return msg.clickedButton == yesButton + + def askSaveAddedPoints(self): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph("Do you want to save the annotated points?") + _, noButton, yesButton = msg.question( + self, "Save?", txt, buttonsTexts=("Cancel", "No", "Yes") + ) + if msg.clickedButton != yesButton: + return + + for toolbar in self.pointsLayersToolbars: + for action in self.pointsLayersToolbar.actions(): + try: + if "Save annotated" in action.text(): + action.trigger() + except Exception as err: + pass + + def askSavePointsLayer(self, action): + toolButton = action.button + tableEndName = toolButton.clickEntryTableEndName + saveAction = toolButton.saveAction + + txt = html_utils.paragraph(f""" + Do you want to save the points you added + (table called {tableEndName}.csv)? + """) + msg = widgets.myMessageBox(wrapText=False) + _, _, saveButton = msg.question( + self, + "Save points layer?", + txt, + buttonsTexts=("Cancel", "No, do not save", "Yes, save points"), + ) + if msg.clickedButton == saveButton: + self.savePointsAddedByClicking(saveAction.saveToolbutton, None) + + return msg.cancel + + def autoPilotZoomToObjToggled(self, checked): + if not checked: + self.zoomOut() + return + + posData = self.data[self.pos_i] + if not posData.IDs: + self.logger.info("There are no objects in current segmentation mask") + return + self.autoPilotZoomToObjSpinBox.setValue(posData.IDs[0]) + self.zoomToObj(posData.rp[0]) + + def autoZoomNextObj(self): + self.sender().setValue(self.sender().value() - 1) + self.pointsLayerAutoPilot("next") + self.setFocusMain() + self.setFocusGraphics() + + def autoZoomPrevObj(self): + self.sender().setValue(self.sender().value() + 1) + self.pointsLayerAutoPilot("prev") + self.setFocusMain() + self.setFocusGraphics() + + def buttonAddPointsByClickingActive(self): + for toolbar in self.pointsLayersToolbars: + for action in toolbar.actions()[1:]: + if not hasattr(action, "layerTypeIdx"): + continue + if action.layerTypeIdx == 4 and action.button.isChecked(): + return action.button + + def checkAskSavePointsLayers(self): + for toolbar in self.pointsLayersToolbars: + for action in toolbar.actions()[1:]: + if not hasattr(action, "layerTypeIdx"): + continue + if action.layerTypeIdx != 4: + continue + + scatterItem = action.scatterItem + xx, yy = scatterItem.getData() + + if xx is None or len(xx) == 0: + toolButton = action.button + tableEndName = toolButton.clickEntryTableEndName + # Check in other loaded pos + are_there_points_to_save = False + for pos_i, _posData in enumerate(self.data): + if pos_i == self.pos_i: + continue + + df = _posData.clickEntryPointsDfs.get(tableEndName) + if df is None: + continue + + are_there_points_to_save = True + break + + if not are_there_points_to_save: + continue + + cancel = self.askSavePointsLayer(action) + if cancel: + return cancel + + return False + + def checkClickEntryTableEndnameExists(self, tableEndName, forceLoading): + doesTableExists = False + for posData in self.data: + filepath, _ = self.getClickEntryTableFilepaths(posData, tableEndName) + if os.path.exists(filepath): + doesTableExists = True + break + + if not doesTableExists: + return + + if not forceLoading: + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph( + f"The table {tableEndName}.csv already exists!

    " + "Do you want to load it?" + ) + _, yesButton, _ = msg.warning( + self.addPointsWin, + "Table exists!", + txt, + buttonsTexts=("Cancel", "Yes, load it", "No, let me enter a new name"), + ) + if msg.clickedButton != yesButton: + return + + newer_recovery_filepaths = self.getClickEntryNewerRecoveryFilepaths( + tableEndName + ) + load_recovery_if_newer = self.askLoadNewerRecoveryClickEntryDfs( + tableEndName, newer_recovery_filepaths + ) + + self.loadClickEntryDfs(tableEndName, loadRecoveryIfNewer=load_recovery_if_newer) + + def checkLoadedTableIds(self, toolbar): + if toolbar != self.promptSegmentPointsLayerToolbar: + return True + + for posData in self.data: + for tableEndName, df in posData.clickEntryPointsDfs.items(): + for point_id in df["id"].values: + if point_id in posData.IDs_idxs: + proceed = self.warnAddingPointWithExistingId( + point_id, table_endname=tableEndName + ) + return proceed + + return True + + def clearPointsLayers(self): + for toolbar in self.pointsLayersToolbars: + for action in toolbar.actions()[1:]: + try: + action.scatterItem.clear() + except Exception as e: + continue + + def drawPointsLayers(self, computePointsLayers=True): + posData = self.data[self.pos_i] + for toolbar in self.pointsLayersToolbars: + for action in toolbar.actions()[1:]: + if not hasattr(action, "layerTypeIdx"): + continue + + if action.layerTypeIdx < 2 and computePointsLayers: + self.getCentroidsPointsData(action) + + if not action.button.isChecked(): + continue + + frames = action.pointsData.get(self.pos_i, set()) + if posData.frame_i not in frames: + if action.layerTypeIdx != 4: + self.logger.info( + f"Frame number {posData.frame_i + 1} does not have any " + f'"{action.layerType}" point to display.' + ) + continue + + framePointsData = action.pointsData[self.pos_i][posData.frame_i] + + if "x" not in framePointsData: + # 3D points + zProjHow = self.zProjComboBox.currentText() + isZslice = zProjHow == "single z-slice" and posData.SizeZ > 1 + if isZslice: + xx, yy, ids, data = [], [], [], [] + zSlice = self.zSliceScrollBar.sliderPosition() + zRadius = action.zRadius + zRange = range(zSlice - zRadius, zSlice + zRadius + 1) + for z in zRange: + z_data = framePointsData.get(z) + if z_data is None: + continue + xx.extend(z_data["x"]) + yy.extend(z_data["y"]) + ids.extend(z_data["id"]) + try: + data.extend(z_data["data"]) + except KeyError as err: + # data is needed only for loaded tables + pass + else: + xx, yy, ids, data = [], [], [], [] + # z-projection --> draw all points + for z, z_data in framePointsData.items(): + xx.extend(z_data["x"]) + yy.extend(z_data["y"]) + ids.extend(z_data["id"]) + try: + data.extend(z_data["data"]) + except KeyError as err: + # data is needed only for loaded tables + pass + else: + # 2D segmentation + xx = framePointsData["x"] + yy = framePointsData["y"] + ids = framePointsData["id"] + try: + data = framePointsData["data"] + except KeyError as err: + # data is needed only for loaded tables + pass + + brushColors = [ + action.brushColor if id != 0 else action.brushColorId0 for id in ids + ] + brushes = [pg.mkBrush(color) for color in brushColors] + + pensColor = [ + action.penColor if id != 0 else action.penColorId0 for id in ids + ] + pens = [pg.mkPen(color) for color in pensColor] + + if action.layerTypeIdx == 2: + # For loaded table show the rest of the table as a tooltip + data = data + show_data_as_tip = True + else: + data = ids + show_data_as_tip = False + + xx = np.array(xx) # + 0.5 + yy = np.array(yy) # + 0.5 + + action.scatterItem.show_data_as_tip = show_data_as_tip + action.scatterItem.setData(xx, yy, data=data, brush=brushes, pen=pens) + + def editPointsLayerAppearance(self, button): + win = apps.EditPointsLayerAppearanceDialog(parent=self) + win.restoreState(button.action.state) + win.exec_() + if win.cancel: + return + + symbol = win.symbol + color = win.color + pointSize = win.pointSize + zRadius = int((win.zHeight - 1) / 2) + r, g, b, a = color.getRgb() + + scatterItem = button.action.scatterItem + scatterItem.opts["hoverBrush"] = pg.mkBrush((r, g, b, 200)) + scatterItem.setSymbol(symbol, update=False) + scatterItem.setBrush(pg.mkBrush(color=(r, g, b, 100)), update=False) + scatterItem.setPen(pg.mkPen(width=2, color=(r, g, b)), update=False) + scatterItem.setSize(pointSize, update=True) + + button.action.brushColor = (r, g, b, 100) + button.action.penColor = (r, g, b) + button.action.pointSize = pointSize + button.action.zRadius = zRadius + + button.action.state = win.state() + + def flushDirtyPointsLayersAutosave(self): + if not self.dirtyPointsLayerTableEndNames: + return + + for tableEndName in tuple( + self.dirtyPointsLayerTableEndNames + ): # avoid runtime error + self.savePointsAddedByClickingFromEndname(tableEndName, recovery=True) + + self.dirtyPointsLayerTableEndNames.clear() + + def getAddedPointId( + self, + isMagicPrompts, + addPointsByClickingButton, + right_click, + left_click, + middle_click, + ): + action = addPointsByClickingButton.action + if right_click: + id = addPointsByClickingButton.rightClickIDSpinbox.value() + elif left_click: + id = addPointsByClickingButton.pointIdSpinbox.value() + id = self.getClickedPointNewId( + action, + id, + addPointsByClickingButton.pointIdSpinbox, + isMagicPrompts=isMagicPrompts, + ) + if isMagicPrompts: + proceed = self.warnAddingPointWithExistingId(id) + if not proceed: + return + + addPointsByClickingButton.pointIdSpinbox.setValue(id) + elif middle_click: + id = 0 + + return id + + def getCentroidsPointsData(self, action): + # Centroids (either weighted or not) + # NOTE: if user requested to draw from table we load that in + # apps.AddPointsLayerDialog.ok_cb() + posData = self.data[self.pos_i] + action.pointsData[self.pos_i] = {posData.frame_i: {}} + if hasattr(action, "weighingData"): + lab = posData.lab + img = action.weighingData[self.pos_i][posData.frame_i] + rp = skimage.measure.regionprops(lab, intensity_image=img) + attr = "weighted_centroid" + else: + rp = posData.rp + attr = "centroid" + for i, obj in enumerate(rp): + centroid = getattr(obj, attr) + if len(centroid) == 3: + zc, yc, xc = centroid + z_int = round(zc) + if z_int not in action.pointsData[self.pos_i][posData.frame_i]: + action.pointsData[self.pos_i][posData.frame_i][z_int] = { + "x": [xc], + "y": [yc], + "id": [obj.label], + } + else: + z_data = action.pointsData[self.pos_i][posData.frame_i][z_int] + z_data["x"].append(xc) + z_data["y"].append(yc) + z_data["id"].append(obj.label) + else: + yc, xc = centroid + if "y" not in action.pointsData[self.pos_i][posData.frame_i]: + action.pointsData[self.pos_i][posData.frame_i]["y"] = [yc] + action.pointsData[self.pos_i][posData.frame_i]["x"] = [xc] + action.pointsData[self.pos_i][posData.frame_i]["id"] = [obj.label] + else: + action.pointsData[self.pos_i][posData.frame_i]["y"].append(yc) + action.pointsData[self.pos_i][posData.frame_i]["x"].append(xc) + action.pointsData[self.pos_i][posData.frame_i]["id"].append( + obj.label + ) + + def getClickEntryNewerRecoveryFilepaths(self, tableEndName): + newer_recovery_filepaths = [] + for posData in self.data: + filepath, recovery_filepath = self.getClickEntryTableFilepaths( + posData, tableEndName + ) + if not os.path.exists(filepath) or not os.path.exists(recovery_filepath): + continue + + if ( + os.path.getmtime(recovery_filepath) <= os.path.getmtime(filepath) + 15 + ): # add a 15 second tolerance + continue + + newer_recovery_filepaths.append((filepath, recovery_filepath)) + + return newer_recovery_filepaths + + def getClickEntryTableFilepaths(self, posData, tableEndName): + if posData.basename.endswith("_"): + basename = posData.basename + else: + basename = f"{posData.basename}_" + + csv_filename = f"{basename}{tableEndName}" + if not csv_filename.endswith(".csv"): + csv_filename = f"{csv_filename}.csv" + + filepath = os.path.join(posData.images_path, csv_filename) + recovery_filepath = os.path.join(posData.images_path, "recovery", csv_filename) + return filepath, recovery_filepath + + def getClickedPointNewId( + self, action, current_id, pointIdSpinbox, isMagicPrompts=False + ): + removed_id = getattr(pointIdSpinbox, "removedId", None) + if removed_id is not None: + pointIdSpinbox.removedId = None + return removed_id + + posData = self.data[self.pos_i] + if isMagicPrompts: + is_already_new = self.isPointIdAlreadyNew(current_id, action) + if is_already_new: + return current_id + + new_ID = self.setBrushID(return_val=True) + new_id = max(current_id, new_ID) + 1 + return new_id + else: + pointsDataPos = action.pointsData.get(self.pos_i) + if pointsDataPos is None: + return 1 + + framePointsData = pointsDataPos.get(posData.frame_i) + if framePointsData is None: + return 1 + if posData.SizeZ > 1: + new_id = 1 + for z_data in framePointsData.values(): + max_id = max(z_data.get("id", 0), default=0) + 1 + if max_id > new_id: + new_id = max_id + else: + new_id = max(framePointsData.get("id", 0), default=0) + 1 + if current_id >= new_id: + return current_id + return new_id + + def isPointIdAlreadyNew(self, point_id, action): + posData = self.data[self.pos_i] + if point_id in posData.IDs_idxs: + return False + + is_ID = point_id in posData.IDs_idxs + pointsDataPos = action.pointsData.get(self.pos_i) + if pointsDataPos is None: + return not is_ID + + framePointsData = pointsDataPos.get(posData.frame_i) + if framePointsData is None: + return not is_ID + + if "x" not in framePointsData: + is_id_already_added = False + for z, z_data in framePointsData.items(): + if point_id in z_data["id"]: + is_id_already_added = True + break + else: + is_id_already_added = point_id in framePointsData["id"] + + is_already_new = not is_ID and not is_id_already_added + return is_already_new + + def loadClickEntryDfs(self, tableEndName, loadRecoveryIfNewer=False): + for posData in self.data: + filepath, recovery_filepath = self.getClickEntryTableFilepaths( + posData, tableEndName + ) + + if loadRecoveryIfNewer: + recovery_exists = os.path.exists(recovery_filepath) + main_exists = os.path.exists(filepath) + if recovery_exists and ( + not main_exists + or os.path.getmtime(recovery_filepath) + > os.path.getmtime(filepath) + 15 + ): + filepath = recovery_filepath + elif not main_exists: + continue + + if not os.path.exists(filepath): + continue + + self.logger.info(f'Loading points from "{filepath}"...') + df = pd.read_csv(filepath) + if "id" not in df.columns: + df["id"] = range(1, len(df) + 1) + posData.clickEntryPointsDfs[tableEndName] = df + + try: + self.addPointsWin.loadButton.confirmAction() + except Exception as err: + pass + + def loadPointsLayerWeighingData(self, action, weighingChannel): + if not weighingChannel: + return + + self.logger.info(f'Loading "{weighingChannel}" weighing data...') + action.weighingData = [] + for p, posData in enumerate(self.data): + if weighingChannel == posData.user_ch_name: + wData = posData.img_data + action.weighingData.append(wData) + continue + + path, filename = self.getPathFromChName(weighingChannel, posData) + if path is None: + self.criticalFluoChannelNotFound(weighingChannel, posData) + action.weighingData = [] + return + + if filename in posData.fluo_data_dict: + # Weighing data already loaded as additional fluo channel + wData = posData.fluo_data_dict[filename] + else: + # Weighing data never loaded --> load now + wData, _ = self.load_fluo_data(path) + if posData.SizeT == 1: + wData = wData[np.newaxis] + action.weighingData.append(wData) + + def logLoadedTablePointsLayer(self, df, filename: str): + separator = f"-" * 100 + header = f'First 10 rows of loaded table - "{filename}":' + footer = f"Number of points: {len(df)}" + text = f"{separator}\n{header}\n\n{df.head(10)}\n\n{footer}\n{separator}" + if filename: + text = f"{text}\nFilename: {filename}" + self.logger.info(text) + + def markPointsLayerDirty(self, tableEndName=None, action=None): + if tableEndName is None and action is not None: + tableEndName = getattr(action, "clickEntryTableEndName", None) + + if tableEndName is None: + addPointsByClickingButton = self.buttonAddPointsByClickingActive() + if addPointsByClickingButton is None: + return + tableEndName = addPointsByClickingButton.clickEntryTableEndName + + self.dirtyPointsLayerTableEndNames.add(tableEndName) + + def pointLayerToolbuttonToggled(self, checked): + action = self.sender().action + action.scatterItem.setVisible(checked) + + def pointsLayerAutoPilot(self, direction): + if not self.autoPilotZoomToObjToggle.isChecked(): + return + ID = self.autoPilotZoomToObjSpinBox.value() + posData = self.data[self.pos_i] + if not posData.IDs: + return + + try: + ID_idx = posData.IDs_idxs[ID] + if direction == "next": + nextID_idx = ID_idx + 1 + else: + nextID_idx = ID_idx - 1 + obj = posData.rp[nextID_idx] + except Exception as e: + self.logger.info(f"Auto-pilot restarted from first ID") + obj = posData.rp[0] + + self.autoPilotZoomToObjSpinBox.setValue(obj.label) + self.zoomToObj(obj) + + def pointsLayerClicksDfsToData(self, posData, toolbar=None): + if toolbar is None: + toolbar = self.pointsLayersToolbar + + for action in toolbar.actions()[1:]: + if not hasattr(action, "button"): + continue + + if not hasattr(action.button, "clickEntryTableEndName"): + continue + tableEndName = action.button.clickEntryTableEndName + action.pointsData[self.pos_i] = {} + if posData.clickEntryPointsDfs.get(tableEndName) is None: + continue + + df = posData.clickEntryPointsDfs[tableEndName] + + if posData.SizeZ > 1 and df["z"].isna().any(): + self.warnLoadedPointsTableIsNot3D(tableEndName) + return + + for frame_i, df_frame in df.groupby("frame_i"): + action.pointsData[self.pos_i][frame_i] = {} + if posData.SizeZ > 1: + for z, df_zlice in df_frame.groupby("z"): + xx = df_zlice["x"].to_list() + yy = df_zlice["y"].to_list() + ids = df_zlice["id"].to_list() + action.pointsData[self.pos_i][frame_i][z] = { + "x": xx, + "y": yy, + "id": ids, + } + else: + xx = df_frame["x"].to_list() + yy = df_frame["y"].to_list() + ids = df_frame["id"].to_list() + action.pointsData[self.pos_i][frame_i] = { + "x": xx, + "y": yy, + "id": ids, + } + + def pointsLayerDataToDf(self, posData, getOnlyActive=False, toolbar=None): + df = None + for toolbar in self.pointsLayersToolbars: + for action in toolbar.actions()[1:]: + if not hasattr(action, "button"): + continue + if not hasattr(action.button, "clickEntryTableEndName"): + continue + + tableEndName = action.button.clickEntryTableEndName + if getOnlyActive and not action.button.isChecked(): + continue + + df = toolbar.fromActionToDataFrame( + action, posData, isSegm3D=self.isSegm3D + ) + posData.clickEntryPointsDfs[tableEndName] = df + return df + + def pointsLayerDfsToData(self, posData): + self.pointsLayerClicksDfsToData(posData) + + def pointsLayerLoadedDfsToData(self): + posData = self.data[self.pos_i] + for toolbar in self.pointsLayersToolbars: + for action in toolbar.actions()[1:]: + if not hasattr(action, "loadedDfInfo"): + continue + + if action.loadedDfInfo is None: + continue + + endname = action.loadedDfInfo.get("endname") + if endname is None: + continue + + filename = f"{posData.basename}{endname}" + filepath = os.path.join(posData.images_path, filename) + if not os.path.exists(filepath): + action.pointsData[self.pos_i] = {} + + df = load.load_df_points_layer(filepath) + action.pointsData[self.pos_i] = load.loaded_df_to_points_data( + df, + action.loadedDfInfo["t"], + action.loadedDfInfo["z"], + action.loadedDfInfo["y"], + action.loadedDfInfo["x"], + ) + self.logLoadedTablePointsLayer(df, filename=filename) + + def pointsLayerToggled(self, checked): + if not checked: + for action in self.pointsLayersToolbar.actions(): + try: + if "Save annotated" in action.text(): + self.askSaveAddedPoints() + break + except Exception as err: + pass + self.pointsLayersToolbar.setVisible(checked) + self.autoPilotZoomToObjToolbar.setVisible(checked) + if self.pointsLayersNeverToggled: + self.pointsLayersToolbar.sigAddPointsLayer.emit() + self.pointsLayersNeverToggled = False + QTimer.singleShot(200, self.autoRange) + + def reinitPointsLayers(self): + for toolbar in self.pointsLayersToolbars: + for action in toolbar.actions()[1:]: + toolbar.removeAction(action) + toolbar.setVisible(False) + self.autoPilotZoomToObjToolbar.setVisible(False) + + def removeClickedPoints(self, action, points): + posData = self.data[self.pos_i] + framePointsData = action.pointsData[self.pos_i][posData.frame_i] + if posData.SizeZ > 1: + zProjHow = self.zProjComboBox.currentText() + if zProjHow != "single z-slice": + _warnings.warnCannotAddRemovePointsProjection() + return + zSlice = self.zSliceScrollBar.sliderPosition() + else: + zSlice = None + + removed_ids = [] + for point in points: + pos = point.pos() + x, y = pos.x(), pos.y() + if zSlice is not None: + zSliceRad = action.zRadius + sliceFramePointsData = [ + framePointsData[z] + for z in range(zSlice - zSliceRad, zSlice + zSliceRad + 1) + if z in framePointsData.keys() + ] + else: + sliceFramePointsData = [framePointsData] + + for sliceFramePointsData in sliceFramePointsData: + if point.data() in sliceFramePointsData["id"]: + sliceFramePointsData["x"].remove(x) + sliceFramePointsData["y"].remove(y) + sliceFramePointsData["id"].remove(point.data()) + removed_ids.append(point.data()) + + if removed_ids: + self.markPointsLayerDirty(action=action) + + return removed_ids + + def removePointsLayer(self, button, toolbar=None): + button.setChecked(False) + button.action.scatterItem.setData([], []) + button.action.loadedDfInfo = None + self.ax1.removeItem(button.action.scatterItem) + toolbar.removeAction(button.action) + for action in button.actions: + toolbar.removeAction(action) + + if toolbar == self.promptSegmentPointsLayerToolbar: + self.promptSegmentPointsLayerToolbar.isPointsLayerInit = False + + def resizeRangeWelcomeText(self): + xRange, yRange = self.ax1.viewRange() + deltaX = xRange[1] - xRange[0] + deltaY = yRange[1] - yRange[0] + self.ax1.setXRange(0, deltaX) + self.ax1.setYRange(0, deltaY) + self.ax1.setLimits(xMin=0, xMax=deltaX, yMin=0, yMax=deltaY) + + def restartZoomAutoPilot(self): + if not self.autoPilotZoomToObjToggle.isChecked(): + return + + posData = self.data[self.pos_i] + if not posData.IDs: + return + + self.autoPilotZoomToObjSpinBox.setValue(posData.IDs[0]) + self.zoomToObj(posData.rp[0]) + + def restorePrevPointIdRightClick(self, addPointsByClickingButton): + # Try to restore the id that was there before hovering + # because the hovering was required only to delete the + # point + try: + prevId = addPointsByClickingButton.rightClickIDSpinbox.prevId + addPointsByClickingButton.rightClickIDSpinbox.setValue(prevId) + except Exception as err: + addPointsByClickingButton.rightClickIDSpinbox.prevId = None + + def savePointsAddedByClicking(self, button, event): + sender = button.action + toolButton = sender.toolButton + tableEndName = toolButton.clickEntryTableEndName + + self.logger.info(f"Saving _{tableEndName}.csv table...") + + self.savePointsAddedByClickingFromEndname(tableEndName) + + self.logger.info(f"{tableEndName}.csv saved!") + self.titleLabel.setText(f"{tableEndName}.csv saved!", color="g") + + def savePointsAddedByClickingFromEndname(self, tableEndName, recovery=False): + self.pointsLayerDataToDf(self.data[self.pos_i]) + for posData in self.data: + if not posData.basename.endswith("_"): + basename = f"{posData.basename}_" + else: + basename = posData.basename + tableFilename = f"{basename}{tableEndName}.csv" + if recovery: + tableFilepath = os.path.join( + posData.recoveryFolderpath(), tableFilename + ) + else: + tableFilepath = os.path.join(posData.images_path, tableFilename) + df = posData.clickEntryPointsDfs.get(tableEndName) + if df is None: + continue + df = df.sort_values(["frame_i", "Cell_ID"]) + df.to_csv(tableFilepath, index=False) + + def setHoverCircleAddPoint(self, x, y): + addPointsByClickingButton = self.buttonAddPointsByClickingActive() + if addPointsByClickingButton is None: + return + action = addPointsByClickingButton.action + self.setHoverToolSymbolData( + [x], [y], (self.ax1_BrushCircle,), size=action.pointSize + ) + + def setPointsLayerLoadedDfEndanme(self, action): + if action.loadedDfInfo is None: + return + + posData = self.data[self.pos_i] + images_path = posData.images_path.replace("\\", "/") + + df_folderpath = os.path.dirname( + action.loadedDfInfo["filepath"].replace("\\", "/") + ) + + if images_path != df_folderpath: + return + + df_filename = os.path.basename(action.loadedDfInfo["filepath"]) + + if not df_filename.startswith(posData.basename): + return + + endname = df_filename[len(posData.basename) :] + action.loadedDfInfo["endname"] = endname + + action.button.setToolTip(endname) + + def setupAddPointsByClicking(self, toolButton, isLoadedDf, toolbar): + self.LeftClickButtons.append(toolButton) + posData = self.data[self.pos_i] + tableEndName = self.addPointsWin.clickEntryTableEndnameText + if isLoadedDf is not None: + posData = self.data[self.pos_i] + tableEndName = tableEndName[len(posData.basename) :] + self.loadClickEntryDfs(tableEndName) + + toolButton.toolbar = toolbar + toolButton.clickEntryTableEndName = tableEndName + self.checkableQButtonsGroup.addButton(toolButton) + toolButton.toggled.connect(self.addPointsByClickingButtonToggled) + + self.addPointsByClickingButtonToggled(sender=toolButton) + + toolButton.setToolTip(tableEndName) + + pointIdSpinbox = widgets.SpinBox() + pointIdSpinbox.setMinimum(0) + pointIdSpinbox.setValue(1) + pointIdSpinbox.label = QLabel(" Left-click ID: ") + pointIdSpinbox.labelAction = toolbar.addWidget(pointIdSpinbox.label) + if toolbar == self.promptSegmentPointsLayerToolbar: + newID = self.setBrushID(return_val=True) + pointIdSpinbox.setValue(newID) + pointIdSpinbox.setReadOnly(True) + pointIdSpinbox.setToolTip( + "The ids added with left-click cannot be manually edited. " + "They are always a new, non-existing id." + ) + + toolButton.actions.append(pointIdSpinbox.labelAction) + pointIdSpinbox.action = toolbar.addWidget(pointIdSpinbox) + toolButton.actions.append(pointIdSpinbox.action) + pointIdSpinbox.toolButton = toolButton + toolButton.pointIdSpinbox = pointIdSpinbox + + rightClickIDSpinbox = widgets.SpinBox() + pointIdSpinbox.setLinkedValueWidget(rightClickIDSpinbox) + rightClickIDSpinbox.setMaximumWidth(pointIdSpinbox.sizeHint().width()) + rightClickIDSpinbox.setValue(pointIdSpinbox.value()) + rightClickIDSpinbox.setMinimum(0) + rightClickIDSpinbox.label = QLabel(" | Right-click ID: ") + rightClickIDSpinbox.labelAction = toolbar.addWidget(rightClickIDSpinbox.label) + toolButton.actions.append(rightClickIDSpinbox.labelAction) + rightClickIDSpinbox.action = toolbar.addWidget(rightClickIDSpinbox) + toolButton.actions.append(rightClickIDSpinbox.action) + rightClickIDSpinbox.toolButton = toolButton + toolButton.rightClickIDSpinbox = rightClickIDSpinbox + + saveToolbutton = widgets.SavePointsLayerButton(tableEndName, parent=self) + saveToolbutton.sigRenameTableAction.connect( + self.updatePointsLayerClickEntryTableEndname + ) + saveToolbutton.sigLeftClick.connect(self.savePointsAddedByClicking) + saveAction = toolbar.addWidget(saveToolbutton) + saveToolbutton.action = saveAction + saveAction.saveToolbutton = saveToolbutton + saveAction.toolButton = toolButton + toolButton.saveAction = saveAction + toolButton.saveToolbutton = saveToolbutton + + toolButton.actions.append(saveAction) + + vlineAction = toolbar.addWidget(widgets.QVLine()) + spacerAction = toolbar.addWidget(widgets.QHWidgetSpacer(width=5)) + + toolButton.actions.append(vlineAction) + toolButton.actions.append(spacerAction) + + action = toolButton.action + scatterItem = action.scatterItem + scatterItem.sigHoverEntered.connect( + self.addPointsByClickingScatterItemHoverEntered + ) + + self.pointsLayerClicksDfsToData(posData, toolbar=toolbar) + + def showPointsLayerIdsToggled(self, button, checked): + button.action.scatterItem.drawIds = checked + self.drawPointsLayers() + + def storeUndoAddPoint(self, action): + if not hasattr(self, "undoAddPointQueueMapper"): + self.undoAddPointQueueMapper = defaultdict(list) + + posData = self.data[self.pos_i] + pointsDataPos = action.pointsData.get(self.pos_i) + if pointsDataPos is None: + return + + state = deepcopy(pointsDataPos) + self.undoAddPointQueueMapper[action].append(state) + self.undoAction.setEnabled(True) + + def undoAddPoint(self, action): + undoAddPointQueue = self.undoAddPointQueueMapper.get(action) + if undoAddPointQueue is None: + return False + + if len(undoAddPointQueue) == 0: + return False + + posData = self.data[self.pos_i] + state = undoAddPointQueue.pop(-1) + action.pointsData[self.pos_i] = state + self.markPointsLayerDirty(action=action) + + self.drawPointsLayers(computePointsLayers=False) + + if len(self.undoAddPointQueueMapper[action]) == 0: + self.undoAction.setEnabled(True) + + return True + + def updatePointsLayerClickEntryTableEndname(self, saveToolbutton, table_endname): + saveAction = saveToolbutton.action + toolButton = saveAction.toolButton + toolButton.clickEntryTableEndName = table_endname + + self.logger.info( + f'Done. Click entry table endname updated to "{table_endname}"' + ) + + def zoomToObj(self, obj=None): + if not hasattr(self, "data"): + return + posData = self.data[self.pos_i] + if obj is None: + ID = self.sender().value() + try: + ID_idx = posData.IDs_idxs[ID] + obj = obj = posData.rp[ID_idx] + except Exception as e: + self.logger.warning(f"ID {ID} does not exist (add points by clicking)") + + if obj is None: + return + + self.goToZsliceSearchedID(obj) + min_row, min_col, max_row, max_col = self.getObjBbox(obj.bbox) + xRange = min_col - 5, max_col + 5 + yRange = max_row + 5, min_row - 5 + + self.ax1.setRange(xRange=xRange, yRange=yRange) diff --git a/cellacdc/mixins/preprocessing.py b/cellacdc/mixins/preprocessing.py new file mode 100644 index 000000000..b4c0f5644 --- /dev/null +++ b/cellacdc/mixins/preprocessing.py @@ -0,0 +1,407 @@ +"""Qt view adapter for image preprocessing workflows.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +from qtpy.QtCore import QMutex, QThread, QWaitCondition + +from cellacdc import apps, html_utils, widgets, workers +from cellacdc.plot import imshow + +from .session import Session + + +class Preprocessing(Session): + """Extracted from guiWin.""" + + def askGet2Dor3Dimage(self): + txt = html_utils.paragraph(""" + Do you want to test the denoising on the visualized 2D image or + on the entire 3D z-stack? + """) + msg = widgets.myMessageBox(wrapText=False) + _, use3Dbutton, use2Dbutton = msg.question( + self, + "3D denoising?", + txt, + buttonsTexts=("Cancel", "Denoise 3D z-stack", "Denoise 2D image"), + ) + if msg.cancel: + return + + if msg.clickedButton == use3Dbutton: + posData = self.data[self.pos_i] + zslice = self.zSliceScrollBar.sliderPosition() + return posData.img_data[posData.frame_i, zslice] + else: + return self.getDisplayedImg1() + + def debugShowImg(self, img): + imshow(img) + + def getChData(self, requ_ch=None, pos_i=None): + if not pos_i: + pos_i = self.pos_i + + posData = self.data[pos_i] + + if not requ_ch: + requ_ch = set(self.ch_names) + else: + requ_ch = set(requ_ch) + + posData.setLoadedChannelNames() + + loaded_channels = set(posData.loadedChNames) + missing_channels = requ_ch - loaded_channels + + self.loadFluo_cb(fluo_channels=missing_channels) + + def preprocWorkerClosed(self, worker): + self.logger.info("Pre-processing worker stopped.") + + def preprocWorkerCritical(self, error): + self.preprocessDialog.appliedFinished() + self.workerCritical(error) + + def preprocWorkerDone( + self, + processed_data: np.ndarray, + how: str, + ): + self.setStatusBarLabel(log=False) + self.preprocessDialog.appliedFinished() + + posData = self.data[self.pos_i] + if not hasattr(posData, "preproc_img_data"): + posData.preproc_img_data = preprocess.PreprocessedData() + + if how == "current_image": + if posData.SizeZ > 1: + z_slice = self.z_slice_index() + posData.preproc_img_data[posData.frame_i][z_slice] = processed_data + else: + posData.preproc_img_data[posData.frame_i] = processed_data + z_slice = 0 + self.img1.updateMinMaxValuesPreprocessedData( + self.data, self.pos_i, posData.frame_i, z_slice + ) + elif how == "z_stack": + for z_slice, processed_img in enumerate(processed_data): + posData.preproc_img_data[posData.frame_i][z_slice] = processed_img + self.img1.updateMinMaxValuesPreprocessedData( + self.data, self.pos_i, posData.frame_i, z_slice + ) + self.img1.updateMinMaxValuesPreprocessedProjections( + self.data, self.pos_i, posData.frame_i + ) + elif how == "all_frames": + for frame_i, processed_frame in enumerate(processed_data): + if processed_frame.ndim == 2: + processed_frame = (processed_frame,) + + for z_slice, processed_img in enumerate(processed_frame): + posData.preproc_img_data[frame_i][z_slice] = processed_img + self.img1.updateMinMaxValuesPreprocessedData( + self.data, self.pos_i, frame_i, z_slice + ) + self.img1.updateMinMaxValuesPreprocessedProjections( + self.data, self.pos_i, frame_i + ) + elif how == "all_pos": + for pos_i, processed_pos_data in enumerate(processed_data): + if processed_pos_data.ndim == 2: + processed_pos_data = (processed_pos_data,) + + posData = self.data[pos_i] + if not hasattr(posData, "preproc_img_data"): + posData.preproc_img_data = preprocess.PreprocessedData() + for z_slice, processed_img in enumerate(processed_pos_data): + posData.preproc_img_data[0][z_slice] = processed_img + self.img1.updateMinMaxValuesPreprocessedData( + self.data, pos_i, 0, z_slice + ) + + if posData.SizeZ > 1: + self.img1.updateMinMaxValuesPreprocessedProjections( + self.data, pos_i, frame_i + ) + + if not self.viewPreprocDataToggle.isChecked(): + self.viewPreprocDataToggle.setChecked(True) + else: + self.setImageImg1() + + def preprocWorkerIsQueueEmpty(self, isEmpty: bool): + if isEmpty: + self.preprocessDialog.appliedFinished() + else: + self.preprocessDialog.setDisabled(True) + self.preprocessDialog.infoLabel.setText( + "Computing preview...
    " + "(Feel free to use Cell-ACDC while waiting)" + ) + + def preprocWorkerPreviewDone( + self, processed_data: np.ndarray, key: Tuple[int, int, Union[int, str]] + ): + pos_i, frame_i, z_slice = key + posData = self.data[pos_i] + if not hasattr(posData, "preproc_img_data"): + posData.preproc_img_data = preprocess.PreprocessedData( + image_data=np.zeros(posData.img_data.shape) + ) + + posData.preproc_img_data[frame_i][z_slice] = processed_data + self.img1.updateMinMaxValuesPreprocessedData(self.data, pos_i, frame_i, z_slice) + + self.setImageImg1() + + def preprocessActionTriggered(self): + self.preprocessDialog.show() + self.preprocessDialog.raise_() + self.preprocessDialog.activateWindow() + self.preprocessDialog.emitSigPreviewToggled() + + def preprocessAllFrames(self, recipe: List[Dict[str, Any]]): + txt = "Pre-processing all frames..." + self.logger.info(txt) + self.statusBarLabel.setText(txt) + + posData = self.data[self.pos_i] + func = core.preprocess_video_from_recipe + image_data = posData.img_data + self.preprocWorker.setupJob(func, image_data, recipe, "all_frames") + self.preprocWorker.wakeUp() + + def preprocessAllPos(self, recipe: List[Dict[str, Any]]): + txt = "Pre-processing all Positions..." + self.logger.info(txt) + self.statusBarLabel.setText(txt) + + func = core.preprocess_multi_pos_from_recipe + recipe = core.validate_multidimensional_recipe( + recipe, apply_to_all_frames=False + ) + image_data = [posData.img_data[0] for posData in self.data] + self.preprocWorker.setupJob(func, image_data, recipe, "all_pos") + + self.preprocWorker.wakeUp() + + def preprocessCurrentImage(self, recipe: List[Dict[str, Any]], *args): + txt = "Pre-processing current image..." + self.logger.info(txt) + self.statusBarLabel.setText(txt) + + func = core.preprocess_image_from_recipe + recipe = core.validate_multidimensional_recipe(recipe) + + image_data = self.getImage(raw=True) + self.preprocWorker.setupJob(func, image_data, recipe, "current_image") + + self.preprocWorker.wakeUp() + + def preprocessDialogRecipeChanged( + self, recipe + ): # why does this need the recepie as an arg + recipe = self.preprocessDialog.recipe() + if recipe is None: + self.logger.warning("Pre-processing recipe not initialized yet.") + return + + self.updatePreprocessPreview(recipe=recipe) + + def preprocessDialogSavePreprocessedData(self, dialog): + posData = self.data[self.pos_i] + + try: + posData.preprocessedDataArray() + except TypeError as e: + if "Not all frames have been processed." in str(e): + msg = widgets.myMessageBox() + txt = html_utils.paragraph( + "Not all frames have been processed.
    " + "Please process all frames before saving." + ) + msg.warning(self, "Process all data before saving", txt) + return + + helpText = """ + The preprocessed image file will be saved with a different + file name.

    + Insert a name to append to the end of the new file name. The rest of + the name will be the same as the original file. + """ + + win = apps.filenameDialog( + basename=f"{posData.basename}{self.user_ch_name}", + ext=".tif", + hintText="Insert a name for the preprocessed image file:", + defaultEntry="preprocessed", + helpText=helpText, + allowEmpty=False, + parent=dialog, + ) + win.exec_() + if win.cancel: + return + + appendedText = win.entryText + + self.progressWin = apps.QDialogWorkerProgress( + title="Saving pre-processed image(s)", + parent=self, + pbarDesc="Saving pre-processed image(s)", + ) + self.progressWin.show(self.app) + self.progressWin.mainPbar.setMaximum(0) + + self.statusBarLabel.setText("Saving pre-processed data...") + + self.savePreprocWorker = workers.SaveProcessedDataWorker( + self.data, appendedText, ext=".tif" + ) + + self.savePreprocThread = QThread() + self.savePreprocWorker.moveToThread(self.savePreprocThread) + self.savePreprocWorker.signals.finished.connect(self.savePreprocThread.quit) + self.savePreprocWorker.signals.finished.connect( + self.savePreprocWorker.deleteLater + ) + self.savePreprocThread.finished.connect(self.savePreprocThread.deleteLater) + + self.savePreprocWorker.signals.critical.connect(self.workerCritical) + self.savePreprocWorker.signals.initProgressBar.connect( + self.workerInitProgressbar + ) + self.savePreprocWorker.signals.progressBar.connect(self.workerUpdateProgressbar) + self.savePreprocWorker.signals.progress.connect(self.workerProgress) + self.savePreprocWorker.signals.finished.connect(self.savePreprocWorkerFinished) + + self.savePreprocThread.started.connect(self.savePreprocWorker.run) + self.savePreprocThread.start() + + def preprocessEnqueueCurrentImage(self, recipe): + posData = self.data[self.pos_i] + func = core.preprocess_image_from_recipe + image_data = self.getImage(raw=True) + if posData.SizeZ > 1: + z_slice = self.z_slice_index() + else: + z_slice = 0 + + recipe = core.validate_multidimensional_recipe(recipe) + + key = (self.pos_i, posData.frame_i, z_slice) + self.preprocWorker.enqueue(func, image_data, recipe, key) + + def preprocessPreviewToggled(self, checked): + self.viewPreprocDataToggle.setChecked(checked) + self.updatePreprocessPreview() + + def preprocessZStack(self, recipe: List[Dict[str, Any]], *args): + txt = "Pre-processing z-stack..." + self.statusBarLabel.setText(txt) + self.logger.info(txt) + + posData = self.data[self.pos_i] + func = core.preprocess_zstack_from_recipe + recipe = core.validate_multidimensional_recipe( + recipe, apply_to_all_frames=False + ) + image_data = posData.img_data[posData.frame_i] + self.preprocWorker.setupJob(func, image_data, recipe, "z_stack") + + self.preprocWorker.wakeUp() + + def setupPreprocessing(self): + posData = self.data[self.pos_i] + if self.preprocessDialog is not None: + self.preprocessDialog.close() + + self.preprocessDialog = apps.PreProcessRecipeDialog( + isTimelapse=posData.SizeT > 1, + isZstack=posData.SizeZ > 1, + isMultiPos=len(self.data) > 1, + df_metadata=posData.metadata_df, + hideOnClosing=True, + addApplyButton=True, + parent=self, + ) + self.doPreviewPreprocImage = False + self.preprocessDialog.sigApplyImage.connect(self.preprocessCurrentImage) + self.preprocessDialog.sigApplyZstack.connect(self.preprocessZStack) + self.preprocessDialog.sigApplyAllFrames.connect(self.preprocessAllFrames) + self.preprocessDialog.sigApplyAllPos.connect(self.preprocessAllPos) + self.preprocessDialog.sigPreviewToggled.connect(self.preprocessPreviewToggled) + self.preprocessDialog.sigValuesChanged.connect( + self.preprocessDialogRecipeChanged + ) + self.preprocessDialog.sigSavePreprocData.connect( + self.preprocessDialogSavePreprocessedData + ) + + if self.preprocWorker is not None: + return + + self.preprocThread = QThread() + self.preprocMutex = QMutex() + self.preprocWaitCond = QWaitCondition() + + self.preprocWorker = workers.CustomPreprocessWorkerGUI( + self.preprocMutex, self.preprocWaitCond + ) + + self.preprocWorker.moveToThread(self.preprocThread) + self.preprocWorker.signals.finished.connect(self.preprocThread.quit) + self.preprocWorker.signals.finished.connect(self.preprocWorker.deleteLater) + self.preprocThread.finished.connect(self.preprocThread.deleteLater) + + self.preprocWorker.sigDone.connect(self.preprocWorkerDone) + self.preprocWorker.sigIsQueueEmpty.connect(self.preprocWorkerIsQueueEmpty) + self.preprocWorker.sigPreviewDone.connect(self.preprocWorkerPreviewDone) + self.preprocWorker.signals.progress.connect(self.workerProgress) + self.preprocWorker.signals.critical.connect(self.workerCritical) + self.preprocWorker.signals.finished.connect(self.preprocWorkerClosed) + + self.preprocThread.started.connect(self.preprocWorker.run) + self.preprocThread.start() + + self.logger.info("Pre-processing worker started.") + + def updatePreprocessPreview(self, *args, **kwargs): + force = kwargs.get("force", False) + + if not self.preprocessDialog.isVisible() and not force: + return + + if not self.preprocessDialog.previewCheckbox.isChecked() and not force: + return + + if kwargs.get("recipe") is None: + recipe = self.preprocessDialog.recipe() + else: + recipe = kwargs.get("recipe") + + if recipe is None: + self.logger.warning("Pre-processing recipe not initialized yet.") + return + + txt = "Pre-processing current image..." + self.logger.info(txt) + self.statusBarLabel.setText(txt) + + self.preprocessEnqueueCurrentImage(recipe) + + def viewPreprocDataToggled(self, checked): + self.img1.setUsePreprocessed(checked) + self.setImageImg1() + + if self.viewCombineChannelDataToggle.isChecked(): + self.viewCombineChannelDataToggle.toggled.disconnect() + self.viewCombineChannelDataToggle.setChecked(False) + self.viewCombineChannelDataToggle.toggled.connect( + self.viewCombineChannelDataToggled + ) diff --git a/cellacdc/mixins/quick_settings.py b/cellacdc/mixins/quick_settings.py new file mode 100644 index 000000000..db95f18c1 --- /dev/null +++ b/cellacdc/mixins/quick_settings.py @@ -0,0 +1,148 @@ +"""View adapter for quick settings and side-panel widgets.""" + +from __future__ import annotations + +from qtpy.QtCore import Qt +from qtpy.QtWidgets import QFormLayout, QLabel, QVBoxLayout + +from cellacdc import apps, settings_csv_path, widgets + +from .actions import Actions + + +class QuickSettings(Actions): + """Extracted from guiWin.""" + + def gui_createQuickSettingsWidgets(self): + self.quickSettingsLayout = QVBoxLayout() + self.quickSettingsGroupbox = widgets.GroupBox() + self.quickSettingsGroupbox.setTitle("Quick settings") + + layout = QFormLayout() + layout.setFieldGrowthPolicy(QFormLayout.FieldGrowthPolicy.FieldsStayAtSizeHint) + layout.setFormAlignment(Qt.AlignRight | Qt.AlignVCenter) + + self.viewPreprocDataToggle = widgets.Toggle() + viewPreprocDataToggleTooltip = ( + "View pre-processed data. See menu `Image --> Pre-processing...`\n" + "on the top menubar." + ) + self.viewPreprocDataToggle.setChecked(False) + self.viewPreprocDataToggle.setToolTip(viewPreprocDataToggleTooltip) + viewPreprocDataToggleLabel = QLabel("View pre-processed image") + viewPreprocDataToggleLabel.setToolTip(viewPreprocDataToggleTooltip) + layout.addRow(viewPreprocDataToggleLabel, self.viewPreprocDataToggle) + + self.viewCombineChannelDataToggle = widgets.Toggle() + viewCombineChannelDataToggleTooltip = ( + "View combined channel. See menu `Image --> combing channels...`\n" + "on the top menubar." + ) + self.viewCombineChannelDataToggle.setChecked(False) + self.viewCombineChannelDataToggle.setToolTip( + viewCombineChannelDataToggleTooltip + ) + viewCombineChannelDataToggleLabel = QLabel("View combined channels") + viewCombineChannelDataToggleLabel.setToolTip( + viewCombineChannelDataToggleTooltip + ) + layout.addRow( + viewCombineChannelDataToggleLabel, self.viewCombineChannelDataToggle + ) + + self.autoSaveToggle = widgets.Toggle() + autoSaveTooltip = ( + "Automatically store a copy of the segmentation data " + "in the `.recovery` folder after every edit." + ) + self.autoSaveToggle.setChecked(True) + self.autoSaveToggle.setToolTip(autoSaveTooltip) + autoSaveLabel = QLabel("Autosave segmentation") + autoSaveLabel.setToolTip(autoSaveTooltip) + layout.addRow(autoSaveLabel, self.autoSaveToggle) + + self.autoSaveAnnotToggle = widgets.Toggle() + autoSaveAnnotTooltip = ( + "Automatically store a copy of the annotations (acdc_output CSV file) " + "in the `.recovery` folder after every edit." + ) + self.autoSaveAnnotToggle.setChecked(True) + self.autoSaveAnnotToggle.setToolTip(autoSaveAnnotTooltip) + autoSaveAnnotLabel = QLabel("Autosave annotations") + autoSaveAnnotLabel.setToolTip(autoSaveAnnotTooltip) + layout.addRow(autoSaveAnnotLabel, self.autoSaveAnnotToggle) + + self.autoSaveIntervalEditButton = widgets.editPushButton( + flat=True, hoverable=True + ) + self.autoSaveIntervalLabel = QLabel("Autosave interval") + self.autoSaveIntervalSetTooltip() + layout.addRow(self.autoSaveIntervalLabel, self.autoSaveIntervalEditButton) + + self.autoSaveIntervalDialog = apps.AutoSaveIntervalDialog(parent=self) + self.autoSaveIntervalDialog.setValues(*self.autoSaveIntevalValueUnit) + + self.ccaIntegrCheckerToggle = widgets.Toggle() + ccaIntegrCheckerToggleTooltip = ( + "Toggle background cell cycle annotations integrity checker ON/OFF" + ) + self.ccaIntegrCheckerToggle.setChecked(False) + self.ccaIntegrCheckerToggle.setToolTip(ccaIntegrCheckerToggleTooltip) + label = QLabel("Cc annot. checker") + label.setToolTip(ccaIntegrCheckerToggleTooltip) + layout.addRow(label, self.ccaIntegrCheckerToggle) + if "is_cca_integrity_checker_activated" in self.df_settings.index: + idx = "is_cca_integrity_checker_activated" + val = int(self.df_settings.at[idx, "value"]) + self.ccaIntegrCheckerToggle.setChecked(not val) + + self.annotLostObjsToggle = widgets.Toggle() + annotLostObjsToggleTooltip = "Toggle annotation of lost objects mode ON/OFF" + self.annotLostObjsToggle.setChecked(True) + self.annotLostObjsToggle.setToolTip(annotLostObjsToggleTooltip) + label = QLabel("Annot. lost objects") + label.setToolTip(annotLostObjsToggleTooltip) + layout.addRow(label, self.annotLostObjsToggle) + + self.realTimeTrackingToggle = widgets.Toggle() + self.realTimeTrackingToggle.setChecked(True) + self.realTimeTrackingToggle.setDisabled(True) + label = QLabel("Real-time tracking") + label.setDisabled(True) + self.realTimeTrackingToggle.label = label + layout.addRow(label, self.realTimeTrackingToggle) + + self.showAllContoursToggle = widgets.Toggle() + showAllContoursTooltip = ( + "If active, all contours will be displayed, including inner contours" + "(e.g. holes and sub-objects)" + ) + self.showAllContoursToggle.setToolTip(showAllContoursTooltip) + showAllContourLabel = QLabel("Show all contours") + showAllContourLabel.setToolTip(showAllContoursTooltip) + layout.addRow(showAllContourLabel, self.showAllContoursToggle) + self.showAllContoursToggle.toggled.connect(self.showAllContoursToggled) + + # Font size + self.fontSizeSpinBox = widgets.SpinBox() + self.fontSizeSpinBox.setMinimum(1) + self.fontSizeSpinBox.setMaximum(99) + layout.addRow("Font size", self.fontSizeSpinBox) + savedFontSize = str(self.df_settings.at["fontSize", "value"]) + if savedFontSize.find("pt") != -1: + savedFontSize = savedFontSize[:-2] + self.fontSize = int(savedFontSize) + if "pxMode" not in self.df_settings.index: + # Users before introduction of pxMode had pxMode=False, but now + # the new default is True. This requires larger font size. + self.fontSize = 2 * self.fontSize + self.df_settings.at["pxMode", "value"] = 1 + self.df_settings.to_csv(settings_csv_path) + self.fontSizeSpinBox.setValue(self.fontSize) + self.fontSizeSpinBox.editingFinished.connect(self.changeFontSize) + self.fontSizeSpinBox.sigUpClicked.connect(self.changeFontSize) + self.fontSizeSpinBox.sigDownClicked.connect(self.changeFontSize) + + self.quickSettingsGroupbox.setLayout(layout) + self.quickSettingsLayout.addWidget(self.quickSettingsGroupbox) + self.quickSettingsLayout.addStretch(1) diff --git a/cellacdc/mixins/saving.py b/cellacdc/mixins/saving.py new file mode 100644 index 000000000..176164960 --- /dev/null +++ b/cellacdc/mixins/saving.py @@ -0,0 +1,1003 @@ +"""Qt view adapter for save and autosave workflows.""" + +from __future__ import annotations + +import os +import uuid +from datetime import datetime +from functools import partial +from typing import Literal + +import pandas as pd +from qtpy.QtCore import QEventLoop, QMutex, QThread, QTimer, QWaitCondition +from qtpy.QtGui import QFont +from qtpy.QtWidgets import QCheckBox, QMessageBox +from tqdm import tqdm + +from cellacdc import _warnings, apps, disableWindow, exception_handler +from cellacdc import cca_df_colnames, html_utils, settings_csv_path, widgets +from cellacdc import load +from cellacdc import workers + + +_font = QFont() +_font.setPixelSize(11) + +from .app_shell import AppShell + + +class Saving(AppShell): + """Extracted from guiWin.""" + + def _enqueueAutoSave(self): + if not self.statusBarLabel.text().endswith("Autosaving..."): + self.statusBarLabel.setText(f"{self.statusBarLabel.text()} | Autosaving...") + + timestamp = datetime.now().strftime(r"%H:%M:%S.%f")[:-3] + self.logger.info(f"Autosaving... - {timestamp}") + + posData = self.data[self.pos_i] + worker, thread = self.autoSaveActiveWorkers[-1] + worker.enqueue(posData) + + def _waitCloseAutoSaveWorker(self): + didWorkersFinished = [True] + for worker, thread in self.autoSaveActiveWorkers: + if worker.isFinished: + didWorkersFinished.append(True) + else: + didWorkersFinished.append(False) + if all(didWorkersFinished): + self.waitCloseAutoSaveWorkerLoop.stop() + + def askConcatenate(self): + if self.mainWin is None: + return + + if self._isQuickSave: + return + + if "showAskConcatenate" not in self.df_settings.index: + self.df_settings.at["showAskConcatenate", "value"] = "Yes" + + showAskConcatenate = self.df_settings.at["showAskConcatenate", "value"] == "Yes" + if not showAskConcatenate: + return + + txt = html_utils.paragraph(f""" + Do you want to concatenate the `acdc_output.csv` tables from + multiple Positions into one single CSV file?
    + """) + doNotShowAgainCheckbox = QCheckBox("Do not show again") + msg = widgets.myMessageBox(wrapText=False) + noButton, yesButton = msg.question( + self, + "Concatenate tables?", + txt, + buttonsTexts=("No", "Yes"), + widgets=doNotShowAgainCheckbox, + ) + showAskConcatenate = "No" if doNotShowAgainCheckbox.isChecked() else "Yes" + self.df_settings.at["showAskConcatenate", "value"] = showAskConcatenate + self.df_settings.to_csv(settings_csv_path) + + if not msg.clickedButton == yesButton: + return + + txt = html_utils.paragraph(f""" + To concatenate the `acdc_output.csv` tables from + multiple Positions and multiple experiments
    + launch the concatenation utility from the top menubar of the Cell-ACDC main launcher:

    + Utilities --> Concatenate --> Concatenate acdc output tables from multiple Positions and experiments.... + """) + msg = widgets.myMessageBox(wrapText=False) + msg.information(self, "How to concatenate tables", txt) + + def askPosToSave(self): + return self.askSelectPos() + + def askSaveLastVisitedCcaMode(self, isQuickSave=False): + posData = self.data[self.pos_i] + current_frame_i = posData.frame_i + frame_i = 0 + last_tracked_i = 0 + self.save_until_frame_i = 0 + if self.isSnapshot: + return True + + for frame_i, data_dict in enumerate(posData.allData_li): + lab = data_dict["labels"] + if lab is None: + frame_i -= 1 + break + + self.save_until_frame_i = frame_i + self.save_cca_until_frame_i = frame_i + self.last_tracked_i = frame_i + + if isQuickSave: + return True + + last_cca_frame_i = self.navigateScrollBar.maximum() - 1 + # Ask to save last visited frame or not + txt = html_utils.paragraph(f""" + You annotated the cell cycle stages up + until frame number {last_cca_frame_i + 1}.

    + Enter up to which frame number you want to save the + cell cycle annotations: + """) + lastFrameDialog = apps.QLineEditDialog( + title="Last annoated frame number to save", + defaultTxt=str(last_cca_frame_i + 1), + msg=txt, + parent=self, + allowedValues=(1, last_cca_frame_i + 1), + warnLastFrame=True, + isInteger=True, + stretchEntry=False, + lastVisitedFrame=last_cca_frame_i + 1, + ) + lastFrameDialog.exec_() + if lastFrameDialog.cancel: + return False + + last_save_cca_frame_i = lastFrameDialog.enteredValue - 1 + + if last_save_cca_frame_i < last_cca_frame_i: + self.resetCcaFuture(last_cca_frame_i) + + self.save_cca_until_frame_i = last_save_cca_frame_i + + return True + + def askSaveLastVisitedSegmMode(self, isQuickSave=False): + posData = self.data[self.pos_i] + current_frame_i = posData.frame_i + frame_i = 0 + last_tracked_i = 0 + self.save_until_frame_i = 0 + self.save_cca_until_frame_i = 0 + if self.isSnapshot: + return True + + for frame_i, data_dict in enumerate(posData.allData_li): + lab = data_dict["labels"] + if lab is None: + frame_i -= 1 + break + + if isQuickSave: + self.save_until_frame_i = frame_i + self.save_cca_until_frame_i = frame_i + self.last_tracked_i = frame_i + return True + + # Ask to save last visited frame or not + txt = html_utils.paragraph(f""" + You visualised and corrected segmentation and tracking data up + until frame number {frame_i + 1}.

    + Enter up to which frame number you want to save data: + """) + lastFrameDialog = apps.QLineEditDialog( + title="Last frame number to save", + defaultTxt=str(frame_i + 1), + msg=txt, + parent=self, + allowedValues=(1, posData.SizeT), + warnLastFrame=True, + isInteger=True, + stretchEntry=False, + lastVisitedFrame=frame_i + 1, + ) + lastFrameDialog.exec_() + if lastFrameDialog.cancel: + return False + + self.save_until_frame_i = lastFrameDialog.enteredValue - 1 + self.save_cca_until_frame_i = self.save_until_frame_i + if self.save_until_frame_i > frame_i: + self.logger.info( + f"Storing frames {frame_i + 1}-{self.save_until_frame_i + 1}..." + ) + current_frame_i = posData.frame_i + # User is requesting to save past the last visited frame --> + # store data as if they were visited + for i in range(frame_i + 1, self.save_until_frame_i + 1): + posData.frame_i = i + self.get_data() + self.store_data(autosave=False) + + # Go back to current frame + posData.frame_i = current_frame_i + self.get_data() + last_tracked_i = self.save_until_frame_i + + self.last_tracked_i = last_tracked_i + return True + + def askSaveMetrics(self): + txt = html_utils.paragraph( + """ + Do you also want to save the measurements + (e.g., cell volume, mean, amount etc.)?

    + + You can find more information by clicking on the + "Set measurements" button below
    + where you will be able to select which measurements + you want to save.

    + If you already set the measurements and you want to save them click "Yes".

    + + NOTE: Saving metrics might be slow, + we recommend doing it only when you need it.
    + """ + ) + msg = widgets.myMessageBox(parent=self, resizeButtons=False, wrapText=False) + setMeasurementsButton = widgets.setPushButton("Set measurements...") + _, yesButton, noButton, _ = msg.question( + self, + "Save measurements?", + txt, + buttonsTexts=("Cancel", "Yes", "No", setMeasurementsButton), + showDialog=False, + ) + setMeasurementsButton.disconnect() + setMeasurementsButton.clicked.connect( + partial( + self.showSetMeasurements, + qparent=msg, + ) + ) + msg.exec_() + save_metrics = msg.clickedButton == yesButton + return save_metrics, msg.cancel + + def askSaveOnClosing(self, event): + if not self.saveAction.isEnabled(): + return True + if self.titleLabel.text == "Saved!": + return True + if not self.isDataLoaded: + return True + + msg = widgets.myMessageBox() + txt = html_utils.paragraph("Do you want to save before closing?") + _, noButton, yesButton = msg.question( + self, "Save?", txt, buttonsTexts=("Cancel", "No", "Yes") + ) + if msg.cancel: + event.ignore() + return False + + if msg.clickedButton == yesButton: + self.closeGUI = True + QTimer.singleShot(100, self.saveAction.trigger) + event.ignore() + return False + return True + + def askSaveOriginalSegm(self, isQuickSave=False): + if isQuickSave: + return "", True, True + + posData = self.data[self.pos_i] + if not posData.whitelist: + return "", True, True + + help_txt = html_utils.paragraph(f""" + You have whitelisted IDs in the current position.
    + Do you want to save the not whitelisted segmentation data
    + This will allow you to revisit the original segmentation.
    + """) + + txt = html_utils.paragraph(f""" + You have whitelisted IDs in the current position.
    + Do you want to save the not whitelisted segmentation data?
    + """) + + found_files = load.get_segm_files(posData.images_path) + existingEndnames = load.get_endnames(posData.basename, found_files) + + segmFilename = os.path.basename(posData.segm_npz_path) + segmFilename = f"{segmFilename[:-4]}_not_whitelisted" + win = apps.filenameDialog( + basename=posData.basename, + hintText=txt, + defaultEntry=segmFilename, + existingNames=existingEndnames, + helpText=help_txt, + allowEmpty=False, + parent=self, + title="Save not whitelisted segmentation data", + addDoNotSaveButton=True, + ) + win.exec_() + if win.cancel: + return "", False, True + if win.doNotSave: + return "", True, True + return win.entryText, True, False + + def askSelectPos(self, action="to save"): + last_pos = 1 + for p, posData in enumerate(self.data): + acdc_df = posData.allData_li[0]["acdc_df"] + if acdc_df is None: + last_pos = p + break + else: + last_pos = len(self.data) + + items = [posData.pos_foldername for posData in self.data] + selectPosWin = widgets.QDialogListbox( + f"Select Positions {action}", + f"Select Positions {action}:\n", + items, + multiSelection=True, + parent=self, + preSelectedItems=items[:last_pos], + ) + selectPosWin.exec_() + if selectPosWin.cancel: + return + + return selectPosWin.selectedItemsText + + def autoSaveAnnotToggled(self, checked): + if not self.autoSaveActiveWorkers: + self.gui_createAutoSaveWorker() + + if not self.autoSaveActiveWorkers: + return + + worker, thread = self.autoSaveActiveWorkers[-1] + + mode = self.modeComboBox.currentText() + if mode != "Viewer": + # No reason to save in viewer mode + checked = False + + worker.isAutoSaveAnnotON = checked + + def autoSaveClose(self): + for worker, thread in self.autoSaveActiveWorkers: + worker._stop() + + def autoSaveIntervalEdit(self): + self.autoSaveIntervalDialog.show() + self.autoSaveIntervalDialog.raise_() + self.autoSaveIntervalDialog.activateWindow() + + def autoSaveIntervalSetTooltip(self): + value, unit = self.autoSaveIntevalValueUnit + autoSaveIntervalEditTooltip = ( + "Change autosave interval to every N frames or minutes\n\n" + f"Current autosave interval: {value} {unit}" + ) + self.autoSaveIntervalLabel.setToolTip(autoSaveIntervalEditTooltip) + self.autoSaveIntervalEditButton.setToolTip(autoSaveIntervalEditTooltip) + + def autoSaveIntervalValueChanged( + self, value: float, unit: Literal["minutes", "frames"] + ): + self.autoSaveIntevalValueUnit = (value, unit) + self.autoSaveTimer.stop() + + self.df_settings.at["autoSaveIntevalValue", "value"] = str(value) + self.df_settings.at["autoSaveIntervalUnit", "value"] = unit + self.df_settings.to_csv(settings_csv_path) + + self.logger.info(f"Autosave interval changed to: {value} {unit}") + self.autoSaveIntervalSetTooltip() + + if unit == "frames": + self.startAutoSaveEveryNframesTimer() + + def autoSaveTimerCountFrames(self): + if not hasattr(self, "data"): + # This happes when the self.autoSaveTimer times out after + # the GUI has been closed --> we simply ignore it + return + + posData = self.data[self.pos_i] + autoSaveIntevalValue, autoSaveIntervalUnit = self.autoSaveIntevalValueUnit + isTimeToAutoSave = ( + abs(posData.frame_i - self.autoSaveTimeStartFrameIdx) + >= autoSaveIntevalValue + ) + if not isTimeToAutoSave: + return + + self.autoSaveTimeStartFrameIdx = posData.frame_i + self.flushDirtyPointsLayersAutosave() + self._enqueueAutoSave() + + def autoSaveTimerTimedOut(self): + if not hasattr(self, "data"): + # This happes when the self.autoSaveTimer times out after + # the GUI has been closed --> we simply ignore it + self.autoSaveTimer.stop() + return + + self.autoSaveTimer.stop() + self.flushDirtyPointsLayersAutosave() + self._enqueueAutoSave() + + def autoSaveToggled(self, checked): + if not self.autoSaveActiveWorkers: + self.gui_createAutoSaveWorker() + + if not self.autoSaveActiveWorkers: + return + + worker, thread = self.autoSaveActiveWorkers[-1] + + mode = self.modeComboBox.currentText() + if mode != "Segmentation and Tracking": + # Autosaving segmentation makes sense only in + # "Segmentation and Tracking" mode + checked = False + + worker.isAutoSaveON = checked + + def cancelSavingInitialisation(self): + self.titleLabel.setText("Saving data process cancelled.", color=self.titleColor) + self.closeGUI = False + + def checkMissingCca(self): + proceed = True + ignore = False + doNotShowAgain = False + if not self.doNotShowAgainMissingCca: + return proceed, ignore, doNotShowAgain + + missing_cca_items = [] + for posData in self.data: + for frame_i, data_dict in enumerate(posData.allData_li): + acdc_df = data_dict["acdc_df"] + if acdc_df is None: + continue + + if "cell_cycle_stage" not in acdc_df.columns: + continue + + cca_df = acdc_df[cca_df_colnames] + if cca_df.isnull().values.any(): + i = frame_i if not self.isSnapshot else None + missing_cca_items.append((cca_df, posData, i)) + + if not missing_cca_items: + return proceed, ignore, doNotShowAgain + + proceed = False + ignore, doNotShowAgain = _warnings.warnMissingCca( + missing_cca_items, qparent=self + ) + + if doNotShowAgain: + self.df_settings.at["doNotShowAgainMissingCca", "value"] = "Yes" + self.df_settings.to_csv(self.settings_csv_path) + + return proceed, ignore, doNotShowAgain + + def computeVolumeRegionprop(self): + if "cell_vol_vox" not in self._measurements_kernel.sizeMetricsToSave: + return + + # We compute the cell volume in the main thread because calling + # skimage.transform.rotate in a separate thread causes crashes + # with segmentation fault on macOS. I don't know why yet. + self.logger.info("Computing cell volume...") + end_i = self.save_until_frame_i + pos_iter = tqdm(self.data, ncols=100) + for p, posData in enumerate(pos_iter): + if self.posToSave is not None: + if posData.pos_foldername not in self.posToSave: + continue + + PhysicalSizeY = posData.PhysicalSizeY + PhysicalSizeX = posData.PhysicalSizeX + frame_iter = tqdm( + posData.allData_li[: end_i + 1], ncols=100, position=1, leave=False + ) + for frame_i, data_dict in enumerate(frame_iter): + lab = data_dict["labels"] + if lab is None: + break + rp = data_dict["regionprops"] + obj_iter = tqdm(rp, ncols=100, position=2, leave=False) + for i, obj in enumerate(obj_iter): + vol_vox, vol_fl = _calc_rot_vol(obj, PhysicalSizeY, PhysicalSizeX) + obj.vol_vox = vol_vox + obj.vol_fl = vol_fl + posData.allData_li[frame_i]["regionprops"] = rp + + def enqAutosave(self): + mode = str(self.modeComboBox.currentText()) + if mode == "Viewer": + if self.statusBarLabel.text().endswith("Autosaving..."): + self.statusBarLabel.setText( + self.statusBarLabel.text().replace(" | Autosaving...", "") + ) + return + + if not self.autoSaveActiveWorkers: + self.gui_createAutoSaveWorker() + + if not self.autoSaveActiveWorkers: + return + + if self.autoSaveTimer.isActive(): + return + + self._enqueueAutoSave() + autoSaveIntevalValue, autoSaveIntervalUnit = self.autoSaveIntevalValueUnit + if autoSaveIntevalValue == 0: + return + + try: + self.autoSaveTimer.timeout.disconnect() + except Exception as err: + pass + + if autoSaveIntervalUnit == "minutes": + autosave_interval_ms = round(autoSaveIntevalValue * 60 * 1000) + self.autoSaveTimer.timeout.connect(self.autoSaveTimerTimedOut) + self.autoSaveTimer.start(autosave_interval_ms) + else: + self.startAutoSaveEveryNframesTimer() + + def manageVersions(self): + posData = self.data[self.pos_i] + selectVersion = apps.SelectAcdcDfVersionToRestore(posData, parent=self) + selectVersion.exec_() + + if selectVersion.cancel: + return + + undoId = uuid.uuid4() + if posData.cca_df is not None: + self.storeUndoRedoCca(posData.frame_i, posData.cca_df, undoId) + + selectedTime = selectVersion.selectedTimestamp + + self.modeComboBox.setCurrentText("Viewer") + self.logger.info(f"Loading file from {selectedTime}...") + + acdc_df = load.read_acdc_df_from_archive( + selectVersion.archiveFilePath, selectVersion.selectedKey + ) + posData.acdc_df = acdc_df + frames = acdc_df.index.get_level_values(0) + last_visited_frame_i = frames.max() + current_frame_i = posData.frame_i + pbar = tqdm(total=last_visited_frame_i + 1, ncols=100) + for frame_i in range(last_visited_frame_i + 1): + posData.frame_i = frame_i + self.get_data() + if posData.cca_df is not None: + self.storeUndoRedoCca(posData.frame_i, posData.cca_df, undoId) + if posData.allData_li[frame_i]["labels"] is None: + pbar.update() + continue + + if frame_i not in frames: + acdc_df_i = pd.DataFrame(columns=acdc_df.columns) + acdc_df_i.drop(self.cca_df_colnames, axis=1, errors="ignore") + acdc_df_i.index.name = "Cell_ID" + else: + acdc_df_i = acdc_df.loc[frame_i].dropna(axis=1, how="all") + + posData.allData_li[frame_i]["acdc_df"] = acdc_df_i + pbar.update() + pbar.close() + + # Back to current frame + posData.frame_i = current_frame_i + self.get_data(debug=False) + self.updateAllImages() + self.logger.info("Annotations correctly recovered.") + + def quickSave(self): + self.saveData(isQuickSave=True) + + def saveAsData(self, checked=True): + try: + posData = self.data[self.pos_i] + except AttributeError: + return + + existingFilenames = set() + for _posData in self.data: + segm_files = load.get_segm_files(_posData.images_path) + _existingEndnames = load.get_endnames(_posData.basename, segm_files) + existingFilenames.update( + [f"{_posData.basename}{endname}.npz" for endname in _existingEndnames] + ) + posData = self.data[self.pos_i] + if posData.basename.endswith("_"): + basename = f"{posData.basename}segm" + else: + basename = f"{posData.basename}_segm" + win = apps.filenameDialog( + basename=basename, + hintText="Insert a filename for the segmentation file:
    ", + existingNames=existingFilenames, + ) + win.exec_() + if win.cancel: + return + + for posData in self.data: + posData.setFilePaths(new_endname=win.entryText) + + self.setStatusBarLabel() + self.saveData() + + def saveData(self, checked=False, finishedCallback=None, isQuickSave=False): + self.setDisabled(True, keepDisabled=True) + + self.askLineageTreeChanges() + + self.store_data(autosave=False) + self.applyDelROIs() + self.store_data() + self._isQuickSave = isQuickSave + + # Wait autosave worker to finish + for worker, thread in self.autoSaveActiveWorkers: + self.logger.info("Stopping autosaving process...") + self.statusBarLabel.setText("Stopping autosaving process...") + worker.stop() + self.waitAutoSaveWorkerTimer = QTimer() + self.waitAutoSaveWorkerTimer.timeout.connect( + partial(self.waitAutoSaveWorker, worker) + ) + self.waitAutoSaveWorkerTimer.start(100) + self.waitAutoSaveWorkerLoop = QEventLoop() + self.waitAutoSaveWorkerLoop.exec_() + + self.titleLabel.setText( + "Saving data... (check progress in the terminal)", color=self.titleColor + ) + + # Check channel name correspondence to warn + posData = self.data[self.pos_i] + lastSegmChannel, segmEndName = posData.getSegmentedChannelHyperparams() + if lastSegmChannel != self.user_ch_name and lastSegmChannel: + cancel = self.warnDifferentSegmChannel( + self.user_ch_name, lastSegmChannel, segmEndName + ) + if cancel: + self.cancelSavingInitialisation() + self.setDisabled(False, keepDisabled=False) + self.activateWindow() + return True + posData.updateSegmentedChannelHyperparams(self.user_ch_name) + + # Check missing cca annotations in snaphots + proceed, ignore, self.doNotShowAgainMissingCca = self.checkMissingCca() + if not proceed and not ignore: + self.cancelSavingInitialisation() + self.setDisabled(False, keepDisabled=False) + self.activateWindow() + return + + self.save_metrics = False + if not isQuickSave: + self.save_metrics, cancel = self.askSaveMetrics() + if cancel: + self.cancelSavingInitialisation() + self.setDisabled(False, keepDisabled=False) + self.activateWindow() + return True + + self.posToSave = None + if self.isSnapshot and not isQuickSave and len(self.data) > 1: + self.posToSave = self.askPosToSave() + if self.posToSave is None: + self.cancelSavingInitialisation() + self.setDisabled(False, keepDisabled=False) + self.activateWindow() + return True + + if isQuickSave: + # Quick save only current pos + self.posToSave = {self.data[self.pos_i].pos_foldername} + + if self.isSnapshot: + self.store_data(mainThread=False) + + mode = self.modeComboBox.currentText() + if mode == "Cell cycle analysis": + proceed = self.askSaveLastVisitedCcaMode(isQuickSave=isQuickSave) + if not proceed: + self.cancelSavingInitialisation() + self.setDisabled(False, keepDisabled=False) + self.activateWindow() + return True + else: + proceed = self.askSaveLastVisitedSegmMode(isQuickSave=isQuickSave) + if not proceed: + self.cancelSavingInitialisation() + self.setDisabled(False, keepDisabled=False) + self.activateWindow() + return True + + append_name_og_whitelist, proceed, do_not_save_og_whitelist = ( + self.askSaveOriginalSegm(isQuickSave=isQuickSave) + ) + if not proceed: + self.cancelSavingInitialisation() + self.setDisabled(False, keepDisabled=False) + self.activateWindow() + return True + + if self.save_metrics or mode == "Cell cycle analysis": + self.computeVolumeRegionprop() + + infoTxt = html_utils.paragraph( + f"Saving {self.exp_path}...
    ", font_size="14px" + ) + + self.saveWin = apps.QDialogPbar( + parent=self, title="Saving data", infoTxt=infoTxt + ) + self.saveWin.setFont(_font) + # if not self.save_metrics: + self.saveWin.metricsQPbar.hide() + self.saveWin.progressLabel.setText("Preparing data...") + self.saveWin.show() + + # Set up separate thread for saving and show progress bar widget + self.mutex = QMutex() + self.waitCond = QWaitCondition() + self.thread = QThread() + self.worker = workers.saveDataWorker(self) + self.worker.mode = mode + self.worker.isQuickSave = isQuickSave + self.worker.append_name_og_whitelist = append_name_og_whitelist + self.worker.do_not_save_og_whitelist = do_not_save_og_whitelist + + self.worker.moveToThread(self.thread) + + self.worker.finished.connect(self.thread.quit) + self.worker.finished.connect(self.worker.deleteLater) + self.thread.finished.connect(self.thread.deleteLater) + + # Custom signals + self.worker.finished.connect(self.saveDataFinished) + if finishedCallback is not None: + self.worker.finished.connect(finishedCallback) + self.worker.progress.connect(self.saveDataProgress) + self.worker.sigLog.connect(self.workerLog) + self.worker.progressBar.connect(self.saveDataUpdatePbar) + # self.worker.metricsPbarProgress.connect(self.saveDataUpdateMetricsPbar) + self.worker.critical.connect(self.saveDataWorkerCritical) + self.worker.customMetricsCritical.connect(self.saveDataCustomMetricsCritical) + self.worker.sigCombinedMetricsMissingColumn.connect( + self.saveDataCombinedMetricsMissingColumn + ) + self.worker.addMetricsCritical.connect(self.saveDataAddMetricsCritical) + self.worker.regionPropsCritical.connect(self.saveDataRegionPropsCritical) + self.worker.criticalPermissionError.connect(self.saveDataPermissionError) + self.worker.askZsliceAbsent.connect(self.zSliceAbsent) + self.worker.sigDebug.connect(self._workerDebug) + + self.thread.started.connect(self.worker.run) + + self.thread.start() + + return False + + def saveDataAddMetricsCritical(self, traceback_format, error_message): + self.setDisabled(False, keepDisabled=False) + self.activateWindow() + self.logger.info("") + _hl = "====================================" + self.logger.info(f"{_hl}\n{traceback_format}\n{_hl}") + self.worker.addMetricsErrors[error_message] = traceback_format + + def saveDataCombinedMetricsMissingColumn(self, error_msg, func_name): + self.setDisabled(False, keepDisabled=False) + self.activateWindow() + self.logger.info("") + warning = f"[WARNING]: {error_msg}. Metric {func_name} was skipped." + _hl = "====================================" + self.logger.info(f"{_hl}\n{warning}\n{_hl}") + self.worker.customMetricsErrors[func_name] = warning + + def saveDataCustomMetricsCritical(self, traceback_format, func_name): + self.setDisabled(False, keepDisabled=False) + self.activateWindow() + self.logger.info("") + _hl = "====================================" + self.logger.info(f"{_hl}\n{traceback_format}\n{_hl}") + self.worker.customMetricsErrors[func_name] = traceback_format + + def saveDataFinished(self): + self.setDisabled(False, keepDisabled=False) + self.activateWindow() + if self.saveWin.aborted or self.worker.abort: + self.titleLabel.setText("Saving process cancelled.", color="r") + elif self._isQuickSave: + self.titleLabel.setText("Saved segmentation file and annotations") + else: + self.titleLabel.setText("Saved!") + self.saveWin.workerFinished = True + self.saveWin.close() + + if not self.closeGUI: + # Update savedSegmData in autosave worker + self.updateSegmDataAutoSaveWorker() + + if self.worker.addMetricsErrors: + self.warnErrorsAddMetrics() + if self.worker.regionPropsErrors: + self.warnErrorsRegionProps() + if self.worker.customMetricsErrors: + self.warnErrorsCustomMetrics() + + self.checkManageVersions() + + self.askConcatenate() + + if self.closeGUI: + salute_string = utils.get_salute_string() + msg = widgets.myMessageBox() + txt = html_utils.paragraph( + f"Data saved!. The GUI will now close.

    {salute_string}" + ) + msg.information(self, "Data saved", txt) + self.close() + + def saveDataPermissionError(self, err_msg): + self.setDisabled(False, keepDisabled=False) + self.activateWindow() + msg = QMessageBox() + msg.critical(self, "Permission denied", err_msg, msg.Ok) + self.waitCond.wakeAll() + + def saveDataProgress(self, text): + self.logger.info(text) + self.saveWin.progressLabel.setText(text) + + def saveDataRegionPropsCritical(self, traceback_format, error_message): + self.setDisabled(False, keepDisabled=False) + self.activateWindow() + self.logger.info("") + _hl = "====================================" + self.logger.info(f"{_hl}\n{traceback_format}\n{_hl}") + self.worker.regionPropsErrors[error_message] = traceback_format + + def saveDataUpdateMetricsPbar(self, max, step): + if max > 0: + self.saveWin.metricsQPbar.setMaximum(max) + self.saveWin.metricsQPbar.setValue(0) + self.saveWin.metricsQPbar.setValue(self.saveWin.metricsQPbar.value() + step) + + def saveDataUpdatePbar(self, step, max=-1, exec_time=0.0): + if max >= 0: + self.saveWin.QPbar.setMaximum(max) + else: + self.saveWin.QPbar.setValue(self.saveWin.QPbar.value() + step) + steps_left = self.saveWin.QPbar.maximum() - self.saveWin.QPbar.value() + seconds = round(exec_time * steps_left) + ETA = utils.seconds_to_ETA(seconds) + self.saveWin.ETA_label.setText(f"ETA: {ETA}") + + def saveMetricsCritical(self, traceback_format): + print("\n====================================") + self.logger.exception(traceback_format) + print("====================================\n") + self.logger.info("Warning: calculating metrics failed see above...") + print("------------------------------") + + msg = widgets.myMessageBox(wrapText=False) + err_msg = html_utils.paragraph(f""" + Error while saving metrics.

    + More details below or in the terminal/console.

    + Note that the error details from this session are also saved + in the file
    + {self.log_path}

    + Please send the log file when reporting a bug, thanks! + Please restart Cell-ACDC, we apologise for any inconvenience.

    + + """) + msg.addShowInFileManagerButton(self.logs_path, txt="Show log file...") + msg.setDetailedText(traceback_format, visible=True) + msg.critical(self, "Critical error while saving metrics", err_msg) + + self.is_error_state = True + self.waitCond.wakeAll() + + def setAutoSaveAnnotationsEnabled(self, enabled): + if not self.autoSaveActiveWorkers: + return + + worker, thread = self.autoSaveActiveWorkers[-1] + + if enabled: + worker.isAutoSaveAnnotON = self.autoSaveToggle.isChecked() + else: + worker.isAutoSaveAnnotON = False + + def setAutoSaveSegmentationEnabled(self, enabled): + if not self.autoSaveActiveWorkers: + return + + worker, thread = self.autoSaveActiveWorkers[-1] + + if enabled: + worker.isAutoSaveON = self.autoSaveToggle.isChecked() + else: + worker.isAutoSaveON = False + + def startAutoSaveEveryNframesTimer(self): + posData = self.data[self.pos_i] + self.autoSaveTimeStartFrameIdx = posData.frame_i + self.autoSaveTimer.timeout.connect(self.autoSaveTimerCountFrames) + self.autoSaveTimer.start(500) + + def turnOffAutoSaveWorker(self): + self.autoSaveToggle.setChecked(False) + + def updateSegmDataAutoSaveWorker(self): + # Update savedSegmData in autosave worker + posData = self.data[self.pos_i] + for worker, thread in self.autoSaveActiveWorkers: + worker.savedSegmData = posData.segm_data.copy() + + def waitAutoSaveWorker(self, worker): + if worker.isFinished or worker.isPaused or len(worker.dataQ) == 0: + self.waitAutoSaveWorkerLoop.exit() + self.waitAutoSaveWorkerTimer.stop() + self.setStatusBarLabel(log=False) + + def warnDifferentSegmChannel( + self, loaded_channel, segm_channel_hyperparams, segmEndName + ): + txt = html_utils.paragraph(f""" + You loaded the segmentation file ending with _{segmEndName}.npz + which corresponds to the channel + {segm_channel_hyperparams}.

    + However, in this session you loaded the channel + {loaded_channel}.

    + If you proceed with saving, the segmentation file ending with + _{segmEndName}.npz will be OVERWRITTEN.

    + Are you sure you want to proceed? + """) + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + msg.warning( + self, + "WARNING: Potential for data loss", + txt, + buttonsTexts=("Cancel", "Yes"), + ) + return msg.cancel + + def warnErrorsAddMetrics(self): + win = apps.ComputeMetricsErrorsDialog( + self.worker.addMetricsErrors, + self.logs_path, + log_type="standard_metrics", + parent=self, + ) + win.exec_() + + def warnErrorsCustomMetrics(self): + win = apps.ComputeMetricsErrorsDialog( + self.worker.customMetricsErrors, + self.logs_path, + log_type="custom_metrics", + parent=self, + ) + win.exec_() + + def warnErrorsRegionProps(self): + win = apps.ComputeMetricsErrorsDialog( + self.worker.regionPropsErrors, + self.logs_path, + log_type="region_props", + parent=self, + ) + win.exec_() diff --git a/cellacdc/mixins/seg_for_lost_ids.py b/cellacdc/mixins/seg_for_lost_ids.py new file mode 100644 index 000000000..20105b13c --- /dev/null +++ b/cellacdc/mixins/seg_for_lost_ids.py @@ -0,0 +1,301 @@ +"""Qt view adapter for segmenting lost IDs.""" + +from __future__ import annotations + +from typing import Any + +from qtpy.QtCore import QMutex, QThread, QWaitCondition + +from cellacdc import apps, workers +from cellacdc.plot import imshow + +from .segmentation import Segmentation +from .frame_navigation import FrameNavigation + + +class SegForLostIds(Segmentation, FrameNavigation): + """Extracted from guiWin.""" + + def SegForLostIDsSetSettings(self): + + try: + prev_model = str(self.df_settings.at["SegForLostIDsModel", "value"]) + except KeyError: + prev_model = None + win = apps.QDialogSelectModel(parent=self, customFirst=prev_model) + win.exec_() + if win.cancel: + self.logger.info("Seg for lost IDs cancelled.") + return + base_model_name = win.selectedModel + + if base_model_name: + self.df_settings.at["SegForLostIDsModel", "value"] = base_model_name + self.df_settings.to_csv(self.settings_csv_path) + + model_name = "local_seg" + + idx = self.modelNames.index(model_name) + acdcSegment = self.acdcSegment_li[idx] + + try: + if acdcSegment is None or base_model_name != self.local_seg_base_model_name: + self.logger.info(f"Importing {base_model_name}...") + acdcSegment = utils.import_segment_module(base_model_name) + self.acdcSegment_li[idx] = acdcSegment + self.local_seg_base_model_name = base_model_name + except (IndexError, ImportError, KeyError) as e: + self.logger.error(f"Error importing {base_model_name}: {e}") + return + + extra_params = [ + "overlap_threshold", + "padding", + "size_perc_diff", + "distance_filler_growth", + "max_iterations", + "allow_only_tracked_cells", + ] + + extra_types = [float, float, float, float, int, bool] + + extra_defaults = [0.5, 0.8, 0.3, 1.0, 2, False] + + extra_desc = [ + "Overlap threshold with other already segemented cells over which newly segmented cells are discarded", + "Padding of the box used for new segmentation around the segmentation from the previous frame", + "Relative size difference acceptable compared to previous frames", + """Cells which are already segmented are filled with random noise sampled from background + to ensure that they don't get segmented again. + This parameter controls the additional padding around the already segmented cells.""", + """The algorithm will try and segment the maximum amount + of cells in the image by running the model several + times and filling new found cells with background noise. + How many of these iterations should be run?""", + "If no new cell IDs should be permitted (based on real time tracking)", + ] + + extra_ArgSpec = [] + for i, param in enumerate(extra_params): + param = ArgSpec( + name=param, + default=extra_defaults[i], + type=extra_types[i], + desc=extra_desc[i], + docstring="", + ) + + extra_ArgSpec.append(param) + + init_params, segment_params = utils.getModelArgSpec(acdcSegment) + segment_params = [arg for arg in segment_params if arg[0] != "diameter"] + + extraParamsTitle = "Settings for local segmentation" + win = self.initSegmModelParams( + base_model_name, + acdcSegment, + init_params, + segment_params, + extraParams=extra_ArgSpec, + extraParamsTitle=extraParamsTitle, + initLastParams=True, + ini_filename="segmentation_for_lostIDs.ini", + ) + + if win is None: + self.logger.info("Segmentation for lost IDs cancelled.") + return + + init_kwargs_new = {} + args_new = {} + for key, val in win.init_kwargs.items(): + if key in extra_params: + args_new[key] = val + else: + init_kwargs_new[key] = val + + for key, val in win.extra_kwargs.items(): + if key in extra_params: + args_new[key] = val + + self.SegForLostIDsSettings = { + "win": win, + "init_kwargs_new": init_kwargs_new, + "args_new": args_new, + "base_model_name": base_model_name, + } + + def SegForLostIDsWorkerAskInstallGPU(self, model_name, use_gpu): + result = utils.check_gpu_available(model_name, use_gpu, qparent=self) + self.SegForLostIDsWorker.gpu_go = result + dont_force_cpu = utils.check_gpu_available( + model_name, use_gpu, do_not_warn=True + ) + self.SegForLostIDsWorker.dont_force_cpu = dont_force_cpu + self.SegForLostIDsWaitCond.wakeAll() + + def SegForLostIDsWorkerAskInstallModel(self, model_name): + utils.check_install_package(model_name) + self.SegForLostIDsWaitCond.wakeAll() + + def SegForLostIDsWorkerFinished(self): + self.updateAllImages() + self.update_rp() + self.store_data(autosave=True) + self.setFrameNavigationDisabled(disable=False, why="Segmentation for lost IDs") + + if self.progressWin is not None: + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + + def onSegForLostInit(self): + self.logger.info("Settings for segmentation for lost IDs not set.") + self.SegForLostIDsSetSettings() + self.SegForLostIDsWaitCond.wakeAll() + + def onSigGetData(self, waitcond, debug=False): + self.get_data(debug=debug) + waitcond.wakeAll() + + def onSigStoreData( + self, + waitcond, + pos_i=None, + enforce=True, + debug=False, + mainThread=True, + autosave=True, + store_cca_df_copy=False, + ): + self.store_data( + pos_i=pos_i, + enforce=enforce, + debug=debug, + mainThread=mainThread, + autosave=autosave, + store_cca_df_copy=store_cca_df_copy, + ) + waitcond.wakeAll() + + def onSigStoreDataSegForLostIDsWorker(self, autosave): + self.onSigStoreData(self.SegForLostIDsWaitCond, autosave=autosave) + + def onSigTrackManuallyAddedObjectSegForLostIDsWorker( + self, added_IDs, isNewID, wl_update, wl_track_og_curr + ): + self.trackManuallyAddedObject( + added_IDs, isNewID, wl_update=wl_update, wl_track_og_curr=wl_track_og_curr + ) + self.SegForLostIDsWaitCond.wakeAll() + + def onSigUpdateRP( + self, + waitcond, + draw=True, + debug=False, + update_IDs=True, + wl_update=True, + wl_track_og_curr=False, + ): + self.update_rp( + draw=draw, + debug=debug, + update_IDs=update_IDs, + wl_update=wl_update, + wl_track_og_curr=wl_track_og_curr, + ) + waitcond.wakeAll() + + def onSigUpdateRPSegForLostIDsWorker(self, wl_update, wl_track_og_curr): + self.onSigUpdateRP( + self.SegForLostIDsWaitCond, + wl_update=wl_update, + wl_track_og_curr=wl_track_og_curr, + ) + + def segForLostIDsButtonClicked(self): + + self.setFrameNavigationDisabled(disable=True, why="Segmentation for lost IDs") + posData = self.data[self.pos_i] + if posData.frame_i == 0: + self.logger.info("Segmentation for lost IDs not available on first frame.") + self.setFrameNavigationDisabled( + disable=False, why="Segmentation for lost IDs" + ) + return + self.storeUndoRedoStates(False) + self.progressWin = apps.QDialogWorkerProgress( + title="Segmenting for lost IDs", + parent=self, + pbarDesc=f"Segmenting for lost IDs...", + ) + self.progressWin.show(self.app) + self.progressWin.mainPbar.setMaximum(0) + + self.startSegForLostIDsWorker() + + def showImageDebug(self, img): + imshow(img) + + def startSegForLostIDsWorker(self): + self.SegForLostIDsMutex = QMutex() + self.SegForLostIDsWaitCond = QWaitCondition() + self._thread = QThread() + + # Initialize the worker with mutex and wait condition + self.SegForLostIDsWorker = workers.SegForLostIDsWorker( + self, self.SegForLostIDsMutex, self.SegForLostIDsWaitCond + ) + + # Connect the worker's signal to the main thread's slot + self.SegForLostIDsWorker.sigAskInit.connect(self.onSegForLostInit) + self.SegForLostIDsWorker.sigAskInstallModel.connect( + self.SegForLostIDsWorkerAskInstallModel + ) + self.SegForLostIDsWorker.sigshowImageDebug.connect(self.showImageDebug) + + self.SegForLostIDsWorker.sigSegForLostIDsWorkerAskInstallGPU.connect( + self.SegForLostIDsWorkerAskInstallGPU + ) + + self.SegForLostIDsWorker.sigStoreData.connect( + self.onSigStoreDataSegForLostIDsWorker + ) + self.SegForLostIDsWorker.sigUpdateRP.connect( + self.onSigUpdateRPSegForLostIDsWorker + ) + # self.SegForLostIDsWorker.sigGetData.connect(self.onSigGetDataSegForLostIDsWorker) + # self.SegForLostIDsWorker.sigGet2Dlab.connect(self.onSigGet2DlabSegForLostIDsWorker) + # self.SegForLostIDsWorker.sigGetTrackedLostIDs.connect(self.onSigGetTrackedSegForLostIDsWorker) + # self.SegForLostIDsWorker.sigGetBrushID.connect(self.onSigGetBrushIDSegForLostIDsWorker) + self.SegForLostIDsWorker.sigTrackManuallyAddedObject.connect( + self.onSigTrackManuallyAddedObjectSegForLostIDsWorker + ) + + # Move the worker to the thread + self.SegForLostIDsWorker.moveToThread(self._thread) + + # Manage thread lifecycle + self.SegForLostIDsWorker.signals.finished.connect(self._thread.quit) + self.SegForLostIDsWorker.signals.finished.connect( + self.SegForLostIDsWorker.deleteLater + ) + self._thread.finished.connect(self._thread.deleteLater) + + # Connect other worker signals to the appropriate slots + self.SegForLostIDsWorker.signals.finished.connect( + self.SegForLostIDsWorkerFinished + ) + self.SegForLostIDsWorker.signals.progress.connect(self.workerProgress) + self.SegForLostIDsWorker.signals.initProgressBar.connect( + self.workerInitProgressbar + ) + self.SegForLostIDsWorker.signals.progressBar.connect( + self.workerUpdateProgressbar + ) + self.SegForLostIDsWorker.signals.critical.connect(self.workerCritical) + + # Start the thread and worker + self._thread.started.connect(self.SegForLostIDsWorker.run) + self._thread.start() diff --git a/cellacdc/mixins/segmentation.py b/cellacdc/mixins/segmentation.py new file mode 100644 index 000000000..7a4236cfc --- /dev/null +++ b/cellacdc/mixins/segmentation.py @@ -0,0 +1,752 @@ +"""Qt view adapter for segmentation workflows.""" + +from __future__ import annotations + +import os + +import numpy as np +from qtpy.QtCore import QMutex, QThread, QTimer, QWaitCondition, Qt +from qtpy.QtWidgets import QAction + +from cellacdc import ( + apps, + exception_handler, + html_utils, + prompts, + printl, + widgets, + workers, +) +from cellacdc.plot import imshow + +from .tool_activation import ToolActivation + + +class Segmentation(ToolActivation): + """Extracted from guiWin.""" + + def autoSegm_cb(self, checked): + if checked: + self.askSegmParam = True + # Ask which model + models = utils.get_list_of_models() + win = widgets.QDialogListbox( + "Select model", + "Select model to use for segmentation: ", + models, + multiSelection=False, + parent=self, + ) + win.exec_() + if win.cancel: + return + model_name = win.selectedItemsText[0] + self.segmModelName = model_name + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + self.updateAllImages() + self.computeSegm() + self.askSegmParam = False + else: + self.segmModelName = None + + def checkIfAutoSegm(self): + """ + If there are any frame or position with empty segmentation mask + ask whether automatic segmentation should be turned ON + """ + if self.autoSegmAction.isChecked(): + return + if self.autoSegmDoNotAskAgain: + return + + ask = False + for posData in self.data: + if posData.SizeT > 1: + for lab in posData.segm_data: + if not np.any(lab): + ask = True + txt = "frames" + break + else: + if not np.any(posData.segm_data): + ask = True + txt = "positions" + break + + if not ask: + return + + questionTxt = html_utils.paragraph( + f"Some or all loaded {txt} contain empty segmentation masks.

    " + "Do you want to activate automatic segmentation* " + f"when visiting these {txt}?

    " + "* Automatic segmentation can always be turned ON/OFF from the menu
    " + " Edit --> Segmentation --> Enable automatic segmentation

    " + f"NOTE: you can automatically segment all {txt} using the
    " + " segmentation module." + ) + msg = widgets.myMessageBox(wrapText=False) + noButton, yesButton = msg.question( + self, "Automatic segmentation?", questionTxt, buttonsTexts=("No", "Yes") + ) + if msg.clickedButton == yesButton: + self.autoSegmAction.setChecked(True) + else: + self.autoSegmDoNotAskAgain = True + self.autoSegmAction.setChecked(False) + + def computeSegm(self, force=False): + posData = self.data[self.pos_i] + mode = str(self.modeComboBox.currentText()) + if mode == "Viewer" or mode == "Cell cycle analysis": + return + + if np.any(posData.lab) and not force: + # Do not compute segm if there is already a mask + return + + if not self.autoSegmAction.isChecked(): + return + + self.repeatSegm(model_name=self.segmModelName) + + def debugSegmWorker(self, to_debug): + img, _lab, lab = to_debug + printl(img.shape, _lab.shape, lab.shape) + imshow(img, _lab, lab) + self.segmWorkerWaitCond.wakeAll() + + def initSegmModelParams( + self, + model_name, + acdcSegment, + init_params, + segment_params, + is_label_roi=False, + initLastParams=False, + extraParams=None, + extraParamsTitle=None, + ini_filename=None, + ): + posData = self.data[self.pos_i] + try: + url = acdcSegment.url_help() + except AttributeError: + url = None + + text_if_cancelled = "Segmentation process cancelled." + out = prompts.init_segm_model_params( + posData, + model_name, + init_params, + segment_params, + help_url=url, + qparent=self, + init_last_params=initLastParams, + check_sam_embeddings=not is_label_roi, + is_gui_caller=True, + extraParams=extraParams, + extraParamsTitle=extraParamsTitle, + ini_filename=ini_filename, + ) + if out.get("load_sam_embeddings", False): + self.logger.info("Loading Segment Anything image embeddings...") + for _posData in self.data: + _posData.loadSamEmbeddings(logger_func=None) + text_if_cancelled = "SAM embeddings loaded." + + win = out.get("win") + if win is None: + self.logger.info(text_if_cancelled) + self.titleLabel.setText(text_if_cancelled) + return + + if win.cancel: + self.logger.info(text_if_cancelled) + self.titleLabel.setText(text_if_cancelled) + return + + if model_name != "thresholding": + self.model_kwargs = win.model_kwargs + + return win + + def init_segmInfo_df(self): + for posData in self.data: + if posData is None: + # posData is None when computing measurements with the utility + # and with timelapse data + continue + posData.init_segmInfo_df() + + def postProcessSegm(self, checked): + if self.isSegm3D: + SizeZ = max([posData.SizeZ for posData in self.data]) + else: + SizeZ = None + if checked: + posData = self.data[self.pos_i] + self.postProcessSegmWin = apps.PostProcessSegmDialog(posData, mainWin=self) + self.postProcessSegmWin.sigClosed.connect(self.postProcessSegmWinClosed) + self.postProcessSegmWin.sigValueChanged.connect( + self.postProcessSegmValueChanged + ) + self.postProcessSegmWin.sigEditingFinished.connect( + self.postProcessSegmEditingFinished + ) + self.postProcessSegmWin.sigApplyToAllFutureFrames.connect( + self.postProcessSegmApplyToAllFutureFrames + ) + self.postProcessSegmWin.show() + self.postProcessSegmWin.valueChanged(None) + else: + self.postProcessSegmWin.close() + self.postProcessSegmWin = None + + def postProcessSegmApplyToAllFutureFrames( + self, + postProcessKwargs, + customPostProcessGroupedFeatures, + customPostProcessFeatures, + ): + proceed = self.warnEditingWithCca_df( + "post-processing segmentation", update_images=False + ) + if not proceed: + self.logger.info("Post-processing segmentation cancelled.") + return + + self.progressWin = apps.QDialogWorkerProgress( + title="Post-processing segmentation", + parent=self, + pbarDesc=f"Post-processing segmentation masks...", + ) + self.progressWin.show(self.app) + self.progressWin.mainPbar.setMaximum(0) + + self.startPostProcessSegmWorker( + postProcessKwargs, + customPostProcessGroupedFeatures, + customPostProcessFeatures, + ) + + def postProcessSegmEditingFinished(self): + self.update_rp() + self.store_data() + self.updateAllImages() + + def postProcessSegmValueChanged(self, lab, delObjs: dict): + for delObj in delObjs.values(): + self.clearObjContour(obj=delObj, ax=0) + self.clearObjContour(obj=delObj, ax=1) + + posData = self.data[self.pos_i] + + labelsToSkip = {} + for ID in posData.IDs: + if ID in delObjs: + labelsToSkip[ID] = True + continue + + restoreObj = self.postProcessSegmWin.origObjs[ID] + self.addObjContourToContoursImage(obj=restoreObj, ax=0) + self.addObjContourToContoursImage(obj=restoreObj, ax=1) + + # self.setAllTextAnnotations(labelsToSkip=labelsToSkip) + + posData.lab = lab + self.setImageImg2() + if self.annotSegmMasksCheckbox.isChecked(): + self.labelsLayerImg1.setImage(self.currentLab2D, autoLevels=False) + if self.annotSegmMasksCheckboxRight.isChecked(): + self.labelsLayerRightImg.setImage(self.currentLab2D, autoLevels=False) + + def postProcessSegmWinClosed(self): + self.postProcessSegmWin = None + self.postProcessSegmAction.toggled.disconnect() + self.postProcessSegmAction.setChecked(False) + self.postProcessSegmAction.toggled.connect(self.postProcessSegm) + + def postProcessSegmWorkerFinished(self): + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + self.get_data() + self.updateAllImages() + self.titleLabel.setText("Post-processing segmentation done!", color="w") + self.logger.info("Post-processing segmentation done!") + + def postProcessing(self): + if self.postProcessSegmWin is None: + return + + self.postProcessSegmWin.setPosData() + posData = self.data[self.pos_i] + lab, delIDs = self.postProcessSegmWin.apply() + if posData.allData_li[posData.frame_i]["labels"] is None: + posData.lab = lab.copy() + self.update_rp() + else: + posData.allData_li[posData.frame_i]["labels"] = lab + self.get_data() + + def reinitStoredSegmModels(self): + self.models = [None] * len(self.models) + + def repeatSegm(self, model_name="", askSegmParams=False, is_label_roi=False): + if model_name == "thresholding": + # thresholding model is stored as 'Automatic thresholding' + # at line of code `models.append('Automatic thresholding')` + model_name = "Automatic thresholding" + + idx = self.modelNames.index(model_name) + # Ask segm parameters if not already set + # and not called by segmSingleFrameMenu (askSegmParams=False) + if not askSegmParams: + askSegmParams = self.model_kwargs is None + + self.downloadWin = apps.downloadModel(model_name, parent=self) + self.downloadWin.download() + + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + + if model_name == "Automatic thresholding": + # Automatic thresholding is the name of the models as stored + # in self.modelNames, but the actual model is called thresholding + # (see cellacdc/models/thresholding) + model_name = "thresholding" + + posData = self.data[self.pos_i] + # Check if model needs to be imported + acdcSegment = self.acdcSegment_li[idx] + if acdcSegment is None: + self.logger.info(f"Importing {model_name}...") + acdcSegment = utils.import_segment_module(model_name) + self.acdcSegment_li[idx] = acdcSegment + + # Ask parameters if the user clicked on the action + # Otherwise this function is called by "computeSegm" function and + # we use loaded parameters + if askSegmParams: + if self.app.overrideCursor() == Qt.WaitCursor: + self.app.restoreOverrideCursor() + self.segmModelName = model_name + # Read all models parameters + init_params, segment_params = utils.getModelArgSpec(acdcSegment) + # Prompt user to enter the model parameters + try: + url = acdcSegment.url_help() + except AttributeError: + url = None + + self.preproc_recipe = None + initLastParams = True + if model_name == "thresholding": + win = apps.QDialogAutomaticThresholding( + parent=self, isSegm3D=self.isSegm3D + ) + win.exec_() + if win.cancel: + return + self.model_kwargs = win.segment_kwargs + thresh_method = self.model_kwargs["threshold_method"] + gauss_sigma = self.model_kwargs["gauss_sigma"] + segment_params = utils.insertModelArgSpec( + segment_params, "threshold_method", thresh_method + ) + segment_params = utils.insertModelArgSpec( + segment_params, "gauss_sigma", gauss_sigma + ) + initLastParams = False + + win = self.initSegmModelParams( + model_name, + acdcSegment, + init_params, + segment_params, + is_label_roi=is_label_roi, + initLastParams=initLastParams, + ) + if win is None: + return + + self.standardPostProcessKwargs = win.standardPostProcessKwargs + self.customPostProcessFeatures = win.customPostProcessFeatures + self.customPostProcessGroupedFeatures = win.customPostProcessGroupedFeatures + self.applyPostProcessing = win.applyPostProcessing + self.secondChannelName = win.secondChannelName + self.preproc_recipe = win.preproc_recipe + + utils.log_segm_params( + model_name, + win.init_kwargs, + win.model_kwargs, + logger_func=self.logger.info, + preproc_recipe=win.preproc_recipe, + apply_post_process=self.applyPostProcessing, + standard_postprocess_kwargs=self.standardPostProcessKwargs, + custom_postprocess_features=self.customPostProcessFeatures, + ) + + use_gpu = win.init_kwargs.get("gpu", False) + proceed = utils.check_gpu_available(model_name, use_gpu, qparent=self) + if not proceed: + self.logger.info("Segmentation process cancelled.") + self.titleLabel.setText("Segmentation process cancelled.") + return + + model = utils.init_segm_model(acdcSegment, posData, win.init_kwargs) + if model is None: + self.logger.info("Segmentation process cancelled.") + self.titleLabel.setText("Segmentation process cancelled.") + return + try: + model.setupLogger(self.logger) + except Exception as e: + pass + self.models[idx] = model + model.model_name = model_name + else: + model = self.models[idx] + + if is_label_roi: + return model + + self.titleLabel.setText( + f"Segmenting with {model_name}... (check progress in terminal/console)", + color=self.titleColor, + ) + + post_process_params = {"applied_postprocessing": self.applyPostProcessing} + post_process_params = { + **post_process_params, + **self.standardPostProcessKwargs, + **self.customPostProcessFeatures, + } + if askSegmParams: + posData.saveSegmHyperparams( + model_name, + win.init_kwargs, + win.model_kwargs, + post_process_params=post_process_params, + preproc_recipe=self.preproc_recipe, + ) + + if self.askRepeatSegment3D: + self.segment3D = False + if self.isSegm3D and self.askRepeatSegment3D: + msg = widgets.myMessageBox(showCentered=False) + msg.addDoNotShowAgainCheckbox(text="Do not ask again") + txt = html_utils.paragraph( + "Do you want to segment the entire z-stack or only the " + "current z-slice?" + ) + _, segment3DButton, _ = msg.question( + self, + "3D segmentation?", + txt, + buttonsTexts=("Cancel", "Segment 3D z-stack", "Segment 2D z-slice"), + ) + if msg.cancel: + self.titleLabel.setText("Segmentation process aborted.") + self.logger.info("Segmentation process aborted.") + return + self.segment3D = msg.clickedButton == segment3DButton + if msg.doNotShowAgainCheckbox.isChecked(): + self.askRepeatSegment3D = False + + if self.askZrangeSegm3D: + self.z_range = None + if self.isSegm3D and self.segment3D and self.askZrangeSegm3D: + idx = (posData.filename, posData.frame_i) + try: + orignal_z = posData.segmInfo_df.at[idx, "z_slice_used_gui"] + except ValueError as e: + orignal_z = posData.segmInfo_df.loc[idx, "z_slice_used_gui"].iloc[0] + selectZtool = apps.QCropZtool( + posData.SizeZ, + parent=self, + cropButtonText="Ok", + addDoNotShowAgain=True, + title="Select z-slice range to segment", + ) + selectZtool.sigZvalueChanged.connect(self.selectZtoolZvalueChanged) + selectZtool.sigCrop.connect(selectZtool.close) + selectZtool.exec_() + self.update_z_slice(orignal_z) + if selectZtool.cancel: + self.titleLabel.setText("Segmentation process aborted.") + self.logger.info("Segmentation process aborted.") + return + startZ = selectZtool.lowerZscrollbar.value() + stopZ = selectZtool.upperZscrollbar.value() + self.z_range = (startZ, stopZ) + if selectZtool.doNotShowAgainCheckbox.isChecked(): + self.askZrangeSegm3D = False + + secondChannelData = None + if self.secondChannelName is not None: + secondChannelData = self.getSecondChannelData() + + self.titleLabel.setText( + f"{model_name} is thinking... (check progress in terminal/console)", + color=self.titleColor, + ) + + self.model = model + + self.segmWorkerMutex = QMutex() + self.segmWorkerWaitCond = QWaitCondition() + self.thread = QThread() + self.worker = workers.segmWorker( + self, + secondChannelData=secondChannelData, + mutex=self.segmWorkerMutex, + waitCond=self.segmWorkerWaitCond, + ) + self.worker.z_range = self.z_range + self.worker.moveToThread(self.thread) + self.worker.finished.connect(self.thread.quit) + self.worker.finished.connect(self.worker.deleteLater) + if self.debug: + self.worker.debug.connect(self.debugSegmWorker) + self.thread.finished.connect(self.thread.deleteLater) + + # Custom signals + self.worker.critical.connect(self.workerCritical) + self.worker.finished.connect(self.segmWorkerFinished) + + self.thread.started.connect(self.worker.run) + self.thread.start() + + def repeatSegmVideo(self, model_name, startFrameNum, stopFrameNum): + if model_name == "thresholding": + # thresholding model is stored as 'Automatic thresholding' + # at line of code `models.append('Automatic thresholding')` + model_name = "Automatic thresholding" + + idx = self.modelNames.index(model_name) + + self.downloadWin = apps.downloadModel(model_name, parent=self) + self.downloadWin.download() + + if model_name == "Automatic thresholding": + # Automatic thresholding is the name of the models as stored + # in self.modelNames, but the actual model is called thresholding + # (see cellacdc/models/thresholding) + model_name = "thresholding" + + posData = self.data[self.pos_i] + # Check if model needs to be imported + acdcSegment = self.acdcSegment_li[idx] + if acdcSegment is None: + self.logger.info(f"Importing {model_name}...") + acdcSegment = utils.import_segment_module(model_name) + self.acdcSegment_li[idx] = acdcSegment + + # Read all models parameters + init_params, segment_params = utils.getModelArgSpec(acdcSegment) + # Prompt user to enter the model parameters + try: + url = acdcSegment.url_help() + except AttributeError: + url = None + + if model_name == "thresholding": + autoThreshWin = apps.QDialogAutomaticThresholding( + parent=self, isSegm3D=self.isSegm3D + ) + autoThreshWin.exec_() + if autoThreshWin.cancel: + return + + win = self.initSegmModelParams( + model_name, acdcSegment, init_params, segment_params + ) + if win is None: + return + + self.standardPostProcessKwargs = win.standardPostProcessKwargs + self.customPostProcessFeatures = win.customPostProcessFeatures + self.customPostProcessGroupedFeatures = win.customPostProcessGroupedFeatures + self.applyPostProcessing = win.applyPostProcessing + self.preproc_recipe = win.preproc_recipe + + utils.log_segm_params( + model_name, + win.init_kwargs, + win.model_kwargs, + logger_func=self.logger.info, + preproc_recipe=win.preproc_recipe, + apply_post_process=self.applyPostProcessing, + standard_postprocess_kwargs=self.standardPostProcessKwargs, + custom_postprocess_features=self.customPostProcessFeatures, + ) + + secondChannelData = None + if win.secondChannelName is not None: + secondChannelData = self.getSecondChannelData() + + use_gpu = win.init_kwargs.get("gpu", False) + proceed = utils.check_gpu_available(model_name, use_gpu, qparent=self) + if not proceed: + self.logger.info("Segmentation process cancelled.") + self.titleLabel.setText("Segmentation process cancelled.") + return + + model = utils.init_segm_model(acdcSegment, posData, win.init_kwargs) + if model is None: + self.logger.info("Segmentation process cancelled.") + self.titleLabel.setText("Segmentation process cancelled.") + return + try: + model.setupLogger(self.logger) + except Exception as e: + pass + + self.extendSegmDataIfNeeded(stopFrameNum) + self.reInitLastSegmFrame(from_frame_i=startFrameNum - 1, updateImages=False) + + self.titleLabel.setText( + f"{model_name} is thinking... (check progress in terminal/console)", + color=self.titleColor, + ) + + self.progressWin = apps.QDialogWorkerProgress( + title="Segmenting video", + parent=self, + pbarDesc=f"Segmenting from frame n. {startFrameNum} to {stopFrameNum}...", + ) + self.progressWin.show(self.app) + self.progressWin.mainPbar.setMaximum(stopFrameNum - startFrameNum) + + self.thread = QThread() + self.worker = workers.segmVideoWorker( + posData, win, model, startFrameNum, stopFrameNum + ) + self.worker.secondChannelData = secondChannelData + self.worker.moveToThread(self.thread) + self.worker.finished.connect(self.thread.quit) + self.worker.finished.connect(self.worker.deleteLater) + self.thread.finished.connect(self.thread.deleteLater) + + # Custom signals + self.worker.critical.connect(self.workerCritical) + self.worker.finished.connect(self.segmVideoWorkerFinished) + self.worker.progressBar.connect(self.workerUpdateProgressbar) + self.worker.progress.connect(self.workerProgress) + + self.thread.started.connect(self.worker.run) + self.thread.start() + + def resetCursor(self): + if self.app.overrideCursor() is not None: + while self.app.overrideCursor() is not None: + self.app.restoreOverrideCursor() + + def segmFrameCallback(self, action): + if action == self.addCustomModelFrameAction: + return + + idx = self.segmActions.index(action) + model_name = self.modelNames[idx] + self.repeatSegm(model_name=model_name, askSegmParams=True) + + def segmVideoCallback(self, action): + if action == self.addCustomModelVideoAction: + return + + posData = self.data[self.pos_i] + win = apps.startStopFramesDialog( + posData.SizeT, currentFrameNum=posData.frame_i + 1 + ) + win.exec_() + if win.cancel: + self.logger.info("Segmentation on multiple frames aborted.") + return + + idx = self.segmActionsVideo.index(action) + model_name = self.modelNames[idx] + self.repeatSegmVideo(model_name, win.startFrame, win.stopFrame) + + def segmVideoWorkerFinished(self, exec_time): + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + + self.activateAnnotations() + + self.get_data() + self.tracking(enforce=True) + self.updateAllImages() + + txt = f"Done. Segmentation computed in {exec_time:.3f} s" + self.logger.info("-----------------") + self.logger.info(txt) + self.logger.info("=================") + self.titleLabel.setText(txt, color="g") + + def segmWorkerFinished(self, lab, exec_time): + posData = self.data[self.pos_i] + + if posData.segmInfo_df is not None and posData.SizeZ > 1: + idx = (posData.filename, posData.frame_i) + posData.segmInfo_df.at[idx, "resegmented_in_gui"] = True + + if lab.ndim == 2 and self.isSegm3D: + self.set_2Dlab(lab) + else: + posData.lab = lab.copy() + + self.activateAnnotations() + + self.update_rp(wl_update=False) + self.tracking(enforce=True, against_next=posData.frame_i == 0) + + if self.isSnapshot: + self.fixCcaDfAfterEdit("Repeat segmentation") + self.updateAllImages() + else: + self.warnEditingWithCca_df("Repeat segmentation") + + txt = f"Done. Segmentation computed in {exec_time:.3f} s" + self.logger.info("-----------------") + self.logger.info(txt) + self.logger.info("=================") + self.titleLabel.setText(txt, color="g") + self.checkIfAutoSegm() + + QTimer.singleShot(200, self.resizeGui) + + def segmentToolActionTriggered(self): + if self.segmModelName is None: + win = apps.QDialogSelectModel(parent=self) + win.exec_() + if win.cancel: + self.logger.info("Repeat segmentation cancelled.") + return + model_name = win.selectedModel + self.repeatSegm(model_name=model_name, askSegmParams=True) + else: + self.repeatSegm(model_name=self.segmModelName) + + def selectZtoolZvalueChanged(self, whichZ, z): + self.update_z_slice(z) + + def showInstructionsCustomModel(self): + modelFilePath = apps.addCustomModelMessages(self) + if modelFilePath is None: + self.logger.info("Adding custom model process stopped.") + return + + utils.store_custom_model_path(modelFilePath) + modelName = os.path.basename(os.path.dirname(modelFilePath)) + customModelAction = QAction(modelName) + self.segmSingleFrameMenu.addAction(customModelAction) + self.segmActions.append(customModelAction) + self.segmActionsVideo.append(customModelAction) + self.modelNames.append(modelName) + self.models.append(None) + self.sender().callback(customModelAction) diff --git a/cellacdc/mixins/session.py b/cellacdc/mixins/session.py new file mode 100644 index 000000000..ce0a395f9 --- /dev/null +++ b/cellacdc/mixins/session.py @@ -0,0 +1,695 @@ +"""Qt view adapter for session workflows.""" + +from __future__ import annotations + +import os +from functools import partial + +import numpy as np +import skimage.measure +from qtpy.QtWidgets import QAction + +from cellacdc import ( + exception_handler, + html_utils, + recentPaths_path, + settings_csv_path, + widgets, +) +from cellacdc.gui_decorators import get_data_exception_handler + +from .worker import Worker + + +class Session(Worker): + """Extracted from guiWin.""" + + def _get_data_unvisited( + self, + posData, + debug=False, + lin_tree_init=True, + ): + posData.editID_info = [] + proceed_cca = True + never_visited = True + if str(self.modeComboBox.currentText()) == "Cell cycle analysis": + # Warn that we are visiting a frame that was never segm-checked + # on cell cycle analysis mode + msg = widgets.myMessageBox() + txt = html_utils.paragraph( + "Segmentation and Tracking was never checked from " + f"frame {posData.frame_i + 1} onwards.

    " + "To ensure correct cell cell cycle analysis you have to " + "first visit the frames after " + f'{posData.frame_i + 1} with "Segmentation and Tracking" mode.' + ) + warn_cca = msg.critical( + self, "Never checked segmentation on requested frame", txt + ) + proceed_cca = False + return proceed_cca, never_visited + + elif str(self.modeComboBox.currentText()) == "Normal division: Lineage tree": + # Warn that we are visiting a frame that was never segm-checked + # on cell cycle analysis mode + msg = widgets.myMessageBox() + txt = html_utils.paragraph( + "Segmentation and Tracking was never checked from " + f"frame {posData.frame_i + 1} onwards.

    " + "To ensure correct lineage tree analysis you have to " + "first visit the frames after " + f'{posData.frame_i + 1} with "Segmentation and Tracking" mode.' + ) + warn_cca = msg.critical( # ??? + self, "Never checked segmentation on requested frame", txt + ) + proceed_cca = False + return proceed_cca, never_visited + + # Requested frame was never visited before. Load from HDD + labels = self.get_labels() + posData.lab = self.apply_manual_edits_to_lab_if_needed(labels) + posData.rp = skimage.measure.regionprops(posData.lab) + self.setManualBackgroundLab() + + if posData.acdc_df is not None: + frames = posData.acdc_df.index.get_level_values(0) + if posData.frame_i in frames: + # Since there was already segmentation metadata from + # previous closed session add it to current metadata + df = posData.acdc_df.loc[posData.frame_i].copy() + binnedIDs_df = df[df["is_cell_excluded"] > 0] + binnedIDs = set(binnedIDs_df.index).union(posData.binnedIDs) + posData.binnedIDs = binnedIDs + ripIDs_df = df[df["is_cell_dead"] > 0] + ripIDs = set(ripIDs_df.index).union(posData.ripIDs) + posData.ripIDs = ripIDs + posData.editID_info.extend(self._get_editID_info(df)) + # Load cca df into current metadata + if "cell_cycle_stage" in df.columns: + cca_cols = df.columns.intersection(self.cca_df_colnames) + cca_df = df[cca_cols].dropna() + if cca_df.empty: + df = df.drop(columns=self.cca_df_colnames, errors="ignore") + else: + df = df.loc[cca_df.index] + cols = self.cca_df_int_cols + df[cols] = df[cols].astype("Int64") + + i = posData.frame_i + posData.allData_li[i]["acdc_df"] = df.copy() + + if self.lineage_tree is None and lin_tree_init: + self.initLinTree() + + self.get_cca_df() + + return proceed_cca, never_visited + + def _get_data_visited( + self, + posData, + debug=False, + lin_tree_init=True, + ): + # Requested frame was already visited. Load from RAM. + never_visited = False + posData.lab = self.get_labels(from_store=True) + posData.rp = skimage.measure.regionprops(posData.lab) + df = posData.allData_li[posData.frame_i]["acdc_df"] + if df is None: + posData.binnedIDs = set() + posData.ripIDs = set() + posData.editID_info = [] + else: + try: + binnedIDs_df = df[df["is_cell_excluded"] > 0] + except Exception as err: + df = utils.fix_acdc_df_dtypes(df) + binnedIDs_df = df[df["is_cell_excluded"] > 0] + posData.binnedIDs = set(binnedIDs_df.index) + ripIDs_df = df[df["is_cell_dead"] > 0] + posData.ripIDs = set(ripIDs_df.index) + posData.editID_info = self._get_editID_info(df) + self.setManualBackgroundLab(load_from_store=True, debug=debug) + if self.lineage_tree is None and lin_tree_init: + self.initLinTree() + + self.get_cca_df(debug=debug) + + return True, never_visited + + def addPathToOpenRecentMenu(self, path): + for action in self.openRecentMenu.actions(): + if path == action.text(): + break + else: + action = QAction(path, self) + action.triggered.connect(partial(self.openRecentFile, path)) + + try: + firstAction = self.openRecentMenu.actions()[0] + self.openRecentMenu.insertAction(firstAction, action) + except Exception as e: + pass + + def getStoredSegmData(self): + posData = self.data[self.pos_i] + segm_data = [] + for data_frame_i in posData.allData_li: + lab = data_frame_i["labels"] + if lab is None: + break + segm_data.append(lab) + return np.array(segm_data) + + def get_data(self, debug=False, lin_tree_init=True): + posData = self.data[self.pos_i] + proceed_cca = True + never_visited = False + if posData.frame_i > 2: + # Remove undo states from 4 frames back to avoid memory issues + posData.UndoRedoStates[posData.frame_i - 4] = [] + # Check if current frame contains undo states (not empty list) + if posData.UndoRedoStates[posData.frame_i]: + self.undoAction.setDisabled(False) + elif posData.UndoRedoCcaStates[posData.frame_i]: + self.undoAction.setDisabled(False) + else: + self.undoAction.setDisabled(True) + self.UndoCount = 0 + # If stored labels is None then it is the first time we visit this frame + if posData.allData_li[posData.frame_i]["labels"] is None: + proceed_cca, never_visited = self._get_data_unvisited( + posData, + lin_tree_init=lin_tree_init, + ) + if not proceed_cca: + return proceed_cca, never_visited + else: + proceed_cca, never_visited = self._get_data_visited( + posData, lin_tree_init=lin_tree_init, debug=debug + ) + + self.update_rp_metadata(draw=False) + posData.IDs = [obj.label for obj in posData.rp] + posData.IDs_idxs = { + ID: i for ID, i in zip(posData.IDs, range(len(posData.IDs))) + } + self.get_zslices_rp() + self.pointsLayerDfsToData(posData) + return proceed_cca, never_visited + + def get_labels( + self, from_store=False, frame_i=None, return_existing=False, return_copy=True + ): + """Get the labels array. + + Parameters + ---------- + from_store : bool, optional + If True load the labels array from the stored posData.allData_li, + i.e., from RAM. Default is False + frame_i : int, optional + If None, use the current frame index. Default is None + return_existing : bool, optional + If True, the second return element will be a boolean that + is True if the labels array was found stored in `posData.allData_li`. + Default is False + return_copy : bool, optional + If True returns a copy of the labels array + + Returns + ------- + numpy.ndarray or tuple of (numpy.ndarray, bool) + The first element is the labels array requested. If `return_existing` + is True then this method also returns a second boolean element that + is True if the labels array was found in in `posData.allData_li`. + + Note + ---- + + If `from_store` is True then this method will try to get the stored + labels array. If any error occurs then the returned labels are the + saved ones in the segmentation file (i.e., from hard drive). + + """ + posData = self.data[self.pos_i] + if frame_i is None: + frame_i = posData.frame_i + + existing = True + if from_store: + try: + labels = posData.allData_li[frame_i]["labels"] + if labels is None: + from_store = False + except Exception as err: + from_store = False + + if not from_store: + try: + labels = posData.segm_data[frame_i] + except IndexError: + existing = False + # Visting a frame that was not segmented --> empty masks + if self.isSegm3D: + shape = (posData.SizeZ, posData.SizeY, posData.SizeX) + else: + shape = (posData.SizeY, posData.SizeX) + labels = np.zeros(shape, dtype=np.uint32) + return_copy = False + + if return_copy: + labels = labels.copy() + + if return_existing: + return labels, existing + else: + return labels + + def initPosAttr(self): + exp_path = self.data[self.pos_i].exp_path + pos_foldernames = utils.get_pos_foldernames(exp_path) + if len(pos_foldernames) == 1: + self.loadPosAction.setDisabled(True) + else: + self.loadPosAction.setDisabled(False) + + for p, posData in enumerate(self.data): + self.pos_i = p + posData.curvPlotItems = [] + posData.curvAnchorsItems = [] + posData.curvHoverItems = [] + posData.trackedLostIDs = set() + + posData.HDDmaxID = np.max(posData.segm_data) + + # Decision on what to do with changes to future frames attr + posData.doNotShowAgain_EditID = False + posData.UndoFutFrames_EditID = False + posData.applyFutFrames_EditID = False + + posData.doNotShowAgain_RipID = False + posData.UndoFutFrames_RipID = False + posData.applyFutFrames_RipID = False + + posData.doNotShowAgain_DelID = False + posData.UndoFutFrames_DelID = False + posData.applyFutFrames_DelID = False + + posData.doNotShowAgain_keepID = False + posData.UndoFutFrames_keepID = False + posData.applyFutFrames_keepID = False + + posData.doNotShowAgainAssignNewID = False + posData.UndoFutFramesAssignNewID = False + posData.applyFutFramesAssignNewID = False + + posData.includeUnvisitedInfo = { + "Delete ID": False, + "Edit ID": False, + "Keep ID": False, + } + + posData.loadTrackedLostCentroids() + posData.acdcTracker2stepsAnnotInfo = {} + + posData.doNotShowAgain_BinID = False + posData.UndoFutFrames_BinID = False + posData.applyFutFrames_BinID = False + + posData.disableAutoActivateViewerWindow = False + posData.new_IDs = [] + posData.lost_IDs = [] + posData.multiBud_mothIDs = [2] + posData.UndoRedoStates = [[] for _ in range(posData.SizeT)] + posData.UndoRedoCcaStates = [[] for _ in range(posData.SizeT)] + + posData.ol_data_dict = {} + posData.ol_data = None + + posData.ol_labels_data = None + + missing_frames = posData.SizeT - len(posData.allData_li) + if missing_frames > 0: + posData.allData_li.extend([None] * missing_frames) + for i in range(posData.SizeT): + if posData.allData_li[i] is None: + posData.allData_li[i] = utils.get_empty_stored_data_dict() + + posData.lutLevels = {channel: {} for channel in self.ch_names} + + posData.ccaStatus_whenEmerged = {} + + posData.frame_i = 0 + posData.brushID = 0 + posData.binnedIDs = set() + posData.ripIDs = set() + posData.cca_df = None + if posData.last_tracked_i is not None: + last_tracked_num = posData.last_tracked_i + 1 + # Load previous session data + # Keep track of which ROIs have already been added + # in previous frame + delROIshapes = [[] for _ in range(posData.SizeT)] + for i in range(last_tracked_num): + posData.frame_i = i + self.get_data(debug=True) + self.store_data( + enforce=True, autosave=False, store_cca_df_copy=True + ) + + # Ask whether to resume from last frame + if last_tracked_num > 1: + msg = widgets.myMessageBox() + txt = html_utils.paragraph( + "Cell-ACDC detected a previous session ended " + f"at frame {last_tracked_num}.

    " + f"Do you want to resume from frame " + f"{last_tracked_num}?" + ) + noButton, yesButton = msg.question( + self, + "Start from last session?", + txt, + buttonsTexts=(" No ", "Yes"), + ) + self.AutoPilotProfile.storeClickMessageBox( + "Start from last session?", msg.clickedButton.text() + ) + if msg.clickedButton == yesButton: + posData.frame_i = posData.last_tracked_i + self.lastFrameRanOnFirstVisitTools = posData.frame_i + else: + posData.frame_i = 0 + + posData.img_data_min_max = (posData.img_data.min(), posData.img_data.max()) + + # Back to first position + self.pos_i = 0 + self.get_data(debug=False) + self.store_data(autosave=False) + # self.updateAllImages() + + # Link Y and X axis of both plots to scroll zoom and pan together + self.ax2.vb.setYLink(self.ax1.vb) + self.ax2.vb.setXLink(self.ax1.vb) + + self.setAllIDs() + + def loadLastSessionSettings(self): + self.settings_csv_path = settings_csv_path + if os.path.exists(settings_csv_path): + self.df_settings = pd.read_csv(settings_csv_path, index_col="setting") + if "is_bw_inverted" not in self.df_settings.index: + self.df_settings.at["is_bw_inverted", "value"] = "No" + else: + self.df_settings.loc["is_bw_inverted"] = self.df_settings.loc[ + "is_bw_inverted" + ].astype(str) + if "fontSize" not in self.df_settings.index: + self.df_settings.at["fontSize", "value"] = 12 + if "overlayColor" not in self.df_settings.index: + self.df_settings.at["overlayColor", "value"] = "255-255-0" + if "how_normIntensities" not in self.df_settings.index: + raw = "Do not normalize. Display raw image" + self.df_settings.at["how_normIntensities", "value"] = raw + else: + idx = ["is_bw_inverted", "fontSize", "overlayColor", "how_normIntensities"] + values = ["No", 12, "255-255-0", "raw"] + self.df_settings = pd.DataFrame( + {"setting": idx, "value": values} + ).set_index("setting") + + if "isLabelsVisible" not in self.df_settings.index: + self.df_settings.at["isLabelsVisible", "value"] = "No" + + if "isNextFrameVisible" not in self.df_settings.index: + self.df_settings.at["isNextFrameVisible", "value"] = "No" + + if "isRightImageVisible" not in self.df_settings.index: + self.df_settings.at["isRightImageVisible", "value"] = "Yes" + + if "manual_separate_draw_mode" not in self.df_settings.index: + col = "manual_separate_draw_mode" + self.df_settings.at[col, "value"] = "threepoints_arc" + + if "colorScheme" in self.df_settings.index: + col = "colorScheme" + self._colorScheme = self.df_settings.at[col, "value"] + else: + self._colorScheme = "light" + + self.doNotShowAgainMissingCca = False + if "doNotShowAgainMissingCca" not in self.df_settings.index: + self.df_settings.at["doNotShowAgainMissingCca", "value"] = "No" + else: + val = self.df_settings.at["doNotShowAgainMissingCca", "value"] + self.doNotShowAgainMissingCca = val == "Yes" + + def reInitGui(self): + cancel = self.checkAskSavePointsLayers() + if cancel: + return False + + if self.overlayToolbar.isTransparent(): + self.overlayToolbar.setTransparent(False) + + self.secondLevelToolbar.setVisible(False) + + self.gui_createLazyLoader() + + try: + self.navSpinBox.valueChanged.disconnect() + except Exception as e: + pass + + try: + self.scaleBar.removeFromAxis(self.ax1) + except Exception as e: + pass + + self.lineage_tree = None + self.getDistanceListMissingIDsCachedFrame = None + self.isZmodifier = False + self.zKeptDown = False + self.askRepeatSegment3D = True + self.askZrangeSegm3D = True + self.isDataLoaded = False + self.retainSizeLutItems = False + self.setMeasWinState = None + self.addPointsWin = None + self.delRoiLab = None + self.showPropsDockButton.setDisabled(True) + self.removeOverlayItems() + self.lutItemsLayout.addItem(self.imgGrad, row=0, col=0) + + self.reinitWidgetsPos() + self.removeAllItems() + self.reinitCustomAnnot() + self.reinitPointsLayers() + self.gui_createPlotItems() + self.setUncheckedAllButtons() + self.setUncheckedPointsLayers() + self.restoreDefaultColors() + self.reinitStoredSegmModels() + self.removeAxLimits() + self.curvToolButton.setChecked(False) + + self.wandControlsToolbar.setVisible(False) + self.wandToolButton.setChecked(False) + self.segmNdimIndicatorAction.setVisible(False) + + self.navigateToolBar.hide() + self.ccaToolBar.hide() + self.editToolBar.hide() + self.brushEraserToolBar.hide() + self.modeToolBar.hide() + + self.modeComboBox.setCurrentText("Viewer") + + alpha = self.imgGrad.labelsAlphaSlider.value() + self.labelsLayerImg1.setOpacity(alpha) + self.labelsLayerRightImg.setOpacity(alpha) + self.lastTrackedFrameLabel.setText("") + + self.promptSegmentPointsLayerToolbar.isPointsLayerInit = False + + for action in self.askHowFutureFramesActions.values(): + action.setChecked(True) + action.setDisabled(True) + + return True + + def readRecentPaths(self, recent_paths_path=None): + # Step 0. Remove the old options from the menu + self.openRecentMenu.clear() + + # Step 1. Read recent Paths + if recent_paths_path is None: + recent_paths_path = recentPaths_path + + if os.path.exists(recent_paths_path): + df = pd.read_csv(recent_paths_path, index_col="index") + df["path"] = df["path"].str.replace("\\", "/") + df = df.drop_duplicates(subset=["path"]) + df.to_csv(recent_paths_path) + if "opened_last_on" in df.columns: + df = df.sort_values("opened_last_on", ascending=False) + recentPaths = df["path"].to_list() + else: + recentPaths = [] + + # Step 2. Dynamically create the actions + actions = [] + for path in recentPaths: + if not os.path.exists(path): + continue + action = QAction(path, self) + action.triggered.connect(partial(self.openRecentFile, path)) + actions.append(action) + + # Step 3. Add the actions to the menu + self.openRecentMenu.addActions(actions) + + def reinitWidgetsPos(self): + pass + + def store_data( + self, + pos_i=None, + enforce=True, + debug=False, + mainThread=True, + autosave=True, + store_cca_df_copy=False, + ): + pos_i = self.pos_i if pos_i is None else pos_i + posData = self.data[pos_i] + if posData.frame_i < 0: + # In some cases we set frame_i = -1 and then call next_frame + # to visualize frame 0. In that case we don't store data + # for frame_i = -1 + return + + mode = str(self.modeComboBox.currentText()) + + if mode == "Viewer" and not enforce: + return + + # if not mainThread: + # self.lin_tree_ask_changes() + + allData_li = posData.allData_li[posData.frame_i] + allData_li["regionprops"] = posData.rp.copy() + allData_li["labels"] = posData.lab.copy() + allData_li["IDs"] = posData.IDs.copy() + allData_li["manualBackgroundLab"] = posData.manualBackgroundLab + allData_li["IDs_idxs"] = posData.IDs_idxs.copy() + if self.manualAnnotPastButton.isChecked(): + self.store_manual_annot_data(posData=posData, data_frame_i=allData_li) + + self.store_zslices_rp() + + # Store dynamic metadata + is_cell_dead_li = [False] * len(posData.rp) + is_cell_excluded_li = [False] * len(posData.rp) + IDs = [0] * len(posData.rp) + xx_centroid = [0] * len(posData.rp) + yy_centroid = [0] * len(posData.rp) + if self.isSegm3D: + zz_centroid = [0] * len(posData.rp) + areManuallyEdited = [0] * len(posData.rp) + editedNewIDs = [vals[2] for vals in posData.editID_info] + for i, obj in enumerate(posData.rp): + is_cell_dead_li[i] = obj.dead + is_cell_excluded_li[i] = obj.excluded + IDs[i] = obj.label + try: + xx_centroid[i] = int(self.getObjCentroid(obj.centroid)[1]) + yy_centroid[i] = int(self.getObjCentroid(obj.centroid)[0]) + except Exception as err: + printl(obj, obj.centroid, obj.label, posData.frame_i) + if self.isSegm3D: + zz_centroid[i] = int(obj.centroid[0]) + if obj.label in editedNewIDs: + areManuallyEdited[i] = 1 + + posData.STOREDmaxID = max(IDs, default=0) + + acdc_df = allData_li["acdc_df"] + if acdc_df is None: + allData_li["acdc_df"] = pd.DataFrame( + { + "Cell_ID": IDs, + "is_cell_dead": is_cell_dead_li, + "is_cell_excluded": is_cell_excluded_li, + "x_centroid": xx_centroid, + "y_centroid": yy_centroid, + "was_manually_edited": areManuallyEdited, + } + ).set_index("Cell_ID") + + if self.isSegm3D: + allData_li["acdc_df"]["z_centroid"] = zz_centroid + else: + # Filter or add IDs that were not stored yet + acdc_df = acdc_df.drop(columns=["time_seconds"], errors="ignore") + acdc_df = acdc_df.reindex(IDs, fill_value=0) + acdc_df["is_cell_dead"] = is_cell_dead_li + acdc_df["is_cell_excluded"] = is_cell_excluded_li + acdc_df["x_centroid"] = xx_centroid + acdc_df["y_centroid"] = yy_centroid + if self.isSegm3D: + acdc_df["z_centroid"] = zz_centroid + acdc_df["was_manually_edited"] = areManuallyEdited + allData_li["acdc_df"] = acdc_df + + if mainThread: + self.pointsLayerDataToDf(posData) + + self.store_cca_df( + pos_i=pos_i, + mainThread=mainThread, + autosave=autosave, + store_cca_df_copy=store_cca_df_copy, + ) + + def store_manual_annot_data(self, posData=None, data_frame_i=None): + if posData is None: + posData = self.data[self.pos_i] + + if data_frame_i is None: + data_frame_i = posData.allData_li[posData.frame_i] + + if not self.isSegm3D: + lab = [posData.lab] + else: + lab = posData.lab + + for z, lab_2D in enumerate(lab): + data_frame_i["manually_edited_lab"]["lab"][z] = lab_2D + + def unstore_data(self): + posData = self.data[self.pos_i] + posData.allData_li[posData.frame_i] = utils.get_empty_stored_data_dict() + + def updateLastVisitedFrame(self, last_visited_frame_i=None): + if last_visited_frame_i is None: + posData = self.data[self.pos_i] + last_visited_frame_i = posData.frame_i + + mode = str(self.modeComboBox.currentText()) + if mode == "Viewer": + return + elif mode == "Segmentation and Tracking": + posData = self.data[self.pos_i] + if posData.last_tracked_i >= last_visited_frame_i: + return + posData.last_tracked_i = last_visited_frame_i + elif mode == "Cell cycle analysis": + if self.last_cca_frame_i >= last_visited_frame_i: + return + self.last_cca_frame_i = last_visited_frame_i diff --git a/cellacdc/mixins/status_hover.py b/cellacdc/mixins/status_hover.py new file mode 100644 index 000000000..fefeedbc2 --- /dev/null +++ b/cellacdc/mixins/status_hover.py @@ -0,0 +1,150 @@ +"""View adapter for hover and status-bar formatting.""" + +from __future__ import annotations + + +import math +import os +import re + +from .image_display import ImageDisplay + + +class StatusHover(ImageDisplay): + """Extracted from guiWin.""" + + def _addOverlayHoverValuesFormatted(self, txt, xdata, ydata): + posData = self.data[self.pos_i] + if posData.ol_data is None: + return txt + + for filename in posData.ol_data: + chName = utils.get_chname_from_basename( + filename, posData.basename, remove_ext=False + ) + if chName not in self.checkedOverlayChannels: + continue + + raw_overlay_img = self.getRawImage(filename=filename) + raw_overlay_value = raw_overlay_img[ydata, xdata] + # raw_overlay_max_value = raw_overlay_img.max() + + raw_txt = self._channelHoverValues("Raw", chName, raw_overlay_value) + + txt = f"{txt} | {raw_txt}" + return txt + + def _addRulerMeasurementText(self, txt): + posData = self.data[self.pos_i] + xx, yy = self.ax1_rulerPlotItem.getData() + if xx is None: + return txt + + lenPxl = math.sqrt((xx[0] - xx[1]) ** 2 + (yy[0] - yy[1]) ** 2) + depthAxes = self.switchPlaneCombobox.depthAxes() + if depthAxes != "z": + pxlToUm = posData.PhysicalSizeZ + else: + pxlToUm = posData.PhysicalSizeX + + length_txt = f"length = {int(lenPxl)} pxl ({lenPxl * pxlToUm:.2f} μm)" + txt = f"{txt} | Measurement: {length_txt}" + return txt + + def _channelHoverValues(self, descr, channel, value, ff=None): + if ff is None: + n_digits = len(str(int(value))) + ff = utils.get_number_fstring_formatter( + type(value), precision=abs(n_digits - 5) + ) + txt = f"{descr} {channel}: value={value:{ff}}" + return txt + + def getActiveToolButton(self): + for button in self.LeftClickButtons: + if button.isChecked(): + return button + + def updateValuesStatusBar(self): + (xl, xr), (yt, yb) = self.ax1ViewRange(integers=True) + W = round(xr - xl) + H = round(yb - yt) + txt = self.wcLabel.text() + pattern = ( + r"W=.*?, H=.*? \| " + r"x_left=.*?, y_top=.*? \| " + r"x_right=.*?, y_bottom=.*? \| " + ) + replacing = ( + f"W={W:d}, H={H:d} | " + f"x_left={xl:d}, y_top={yt:d} | " + f"x_right={xr:d}, y_bottom={yb:d} | " + ) + txt = re.sub(pattern, replacing, txt) + self.wcLabel.setText(txt) + + def hoverValuesFormatted(self, xdata, ydata, activeToolButton, is_ax0): + (xl, xr), (yt, yb) = self.ax1ViewRange(integers=True) + W = round(xr - xl) + H = round(yb - yt) + ax_idx = 0 if is_ax0 else 1 + txt = ( + f"x={xdata:d}, y={ydata:d} | " + f"W={W:d}, H={H:d} | " + f"x_left={xl:d}, y_top={yt:d} | " + f"x_right={xr:d}, y_bottom={yb:d} | " + f"(ax{ax_idx})" + ) + if activeToolButton == self.rulerButton: + txt = self._addRulerMeasurementText(txt) + return txt + elif activeToolButton is not None: + return txt + + posData = self.data[self.pos_i] + + raw_img = self.getRawImage() + raw_value = raw_img[ydata, xdata] + # raw_max_value = raw_img.max() + + ch = self.user_ch_name + raw_txt = self._channelHoverValues("Raw", ch, raw_value) + + txt = f"{txt} | {raw_txt}" + + txt = self._addOverlayHoverValuesFormatted(txt, xdata, ydata) + + ID = self.currentLab2D[ydata, xdata] + maxID = max(posData.IDs, default=0) + + num_obj = len(posData.IDs) + lab_txt = ( + f"Objects: ID={ID}, max ID={maxID}, num. of objects={num_obj}" + ) + txt = f"{txt} | {lab_txt}" + + txt = self._addRulerMeasurementText(txt) + return txt + + def setStatusBarLabel(self, log=True): + self.statusbar.clearMessage() + posData = self.data[self.pos_i] + segmentedChannelname = posData.filename[len(posData.basename) :] + segmFilename = os.path.basename(posData.segm_npz_path) + segmEndName = segmFilename[len(posData.basename) :] + txt = ( + f"{posData.pos_foldername} || " + f"Basename: {posData.basename} || " + f"Segmented channel: {segmentedChannelname} || " + f"Segmentation file name: {segmEndName}" + ) + mode = str(self.modeComboBox.currentText()) + if log: + self.logger.info(txt) + self.statusBarLabel.setText(txt) + + def getRulerLengthText(self): + text = self.wcLabel.text() + lengthText = re.findall(r"length = (.*)\)", text)[0] + lengthText = lengthText.replace("pxl", "pixels") + return f"{lengthText})" diff --git a/cellacdc/mixins/tool_activation.py b/cellacdc/mixins/tool_activation.py new file mode 100644 index 000000000..012d595b7 --- /dev/null +++ b/cellacdc/mixins/tool_activation.py @@ -0,0 +1,808 @@ +"""Qt view adapter for active-tool workflows.""" + +from __future__ import annotations + +import numpy as np +from qtpy.QtCore import QEventLoop, QThread, QTimer, Qt + +from cellacdc import apps, qutils, widgets, workers +from cellacdc import disableWindow + +from .session import Session + + +class ToolActivation(Session): + """Extracted from guiWin.""" + + def _copyAllLostObjects_navigateToFrame(self, frame_i): + posData = self.data[self.pos_i] + self.store_data(mainThread=False, autosave=False) + + posData.frame_i = frame_i + self.get_data() + self.tracking(wl_update=False) + self.currentLab2D = self.get_2Dlab(posData.lab) + self.update_rp() + self.updateLostNewCurrentIDs() + self.store_data(mainThread=False, autosave=False) + + self.lostObjContoursImage[:] = 0 + self.lostObjImage[:] = 0 + prev_rp = posData.allData_li[frame_i - 1]["regionprops"] + prev_IDs_idxs = posData.allData_li[frame_i - 1][ + "IDs_idxs" + ] # need to change this when merging with opt. + for lostID in posData.lost_IDs: + obj = prev_rp[prev_IDs_idxs[lostID]] + self.addLostObjsToLostObjImage(obj, lostID, force=True) + + def _copyAllLostObjects_refreshRp(self): + self.update_rp( + draw=False, wl_update=False + ) # need to change this when merging with opt. + + def _copyAllLostObjects_returnToFrame(self, frame_i): + posData = self.data[self.pos_i] + self.store_data(autosave=False, mainThread=False) + posData.frame_i = frame_i + self.get_data() + + def addLostObjsToLostObjImage(self, lostObj, lostID, force=False): + if not force: + if not self.copyLostObjButton.isChecked(): + return + + obj_slice = self.getObjSlice(lostObj.slice) + obj_image = self.getObjImage(lostObj.image, lostObj.bbox) + self.lostObjImage[obj_slice][obj_image] = lostID + + def annotLostObjsToggled(self, checked): + if not self.isDataLoaded: + return + self.updateAllImages() + + def clearTempBrushImage(self, forceClearLinked=True): + if not hasattr(self, "tempLayerImg1"): + return + + self.tempLayerImg1.setImage(self.emptyLab, force_set_linked=forceClearLinked) + + try: + self.brushContourImage[:] = 0 + except Exception as err: + pass + + try: + self.brushImage[:] = 0 + except Exception as err: + pass + + def connectLeftClickButtons(self): + self.brushButton.toggled.connect(self.Brush_cb) + self.curvToolButton.toggled.connect(self.curvTool_cb) + self.rulerButton.toggled.connect(self.ruler_cb) + self.eraserButton.toggled.connect(self.Eraser_cb) + self.wandToolButton.toggled.connect(self.wand_cb) + self.labelRoiButton.toggled.connect(self.labelRoi_cb) + self.magicPromptsToolButton.toggled.connect(self.magicPrompts_cb) + self.drawClearRegionButton.toggled.connect(self.drawClearRegion_cb) + self.expandLabelToolButton.toggled.connect(self.expandLabelCallback) + self.addDelPolyLineRoiButton.toggled.connect(self.addDelPolyLineRoi_cb) + self.manualBackgroundButton.toggled.connect(self.manualBackground_cb) + self.whitelistIDsButton.toggled.connect(self.whitelistIDs_cb) + self.zoomRectButton.toggled.connect(self.zoomRectActionToggled) + self.connectLeftClickButtonsPointsLayersToolbar() + + def connectLeftClickButtonsPointsLayersToolbar(self): + for toolbar in self.pointsLayersToolbars: + for action in toolbar.actions()[1:]: + if not hasattr(action, "layerTypeIdx"): + continue + if action.layerTypeIdx != 4: + continue + action.button.toggled.connect(self.addPointsByClickingButtonToggled) + + def copyAllLostObjects(self, for_future_frame_n, max_overlap_perc): + if not self.copyLostObjButton.isChecked(): + return + + posData = self.data[self.pos_i] + + desc = "Copying all lost objects..." + + self.progressWin = apps.QDialogWorkerProgress( + title=desc, parent=self.mainWin, pbarDesc=desc + ) + self.progressWin.mainPbar.setMaximum(for_future_frame_n + 1) + self.progressWin.show(self.app) + + self.copyAllLostObjectsThread = QThread() + + self.copyAllLostObjectsWorker = workers.CopyAllLostObjectsWorker( + self, posData, for_future_frame_n, max_overlap_perc + ) + self.copyAllLostObjectsWorker.moveToThread(self.copyAllLostObjectsThread) + + self.copyAllLostObjectsWorker.navigateToFrame.connect( + self._copyAllLostObjects_navigateToFrame, Qt.BlockingQueuedConnection + ) + self.copyAllLostObjectsWorker.returnToFrame.connect( + self._copyAllLostObjects_returnToFrame, Qt.BlockingQueuedConnection + ) + self.copyAllLostObjectsWorker.copyLostObjectMask.connect( + self.copyLostObjectMask, Qt.BlockingQueuedConnection + ) + self.copyAllLostObjectsWorker.refreshRp.connect( + self._copyAllLostObjects_refreshRp, Qt.BlockingQueuedConnection + ) + self.copyAllLostObjectsWorker.progressBar.connect(self.workerUpdateProgressbar) + self.copyAllLostObjectsWorker.critical.connect( + self.copyAllLostObjectsWorkerCritical + ) + self.copyAllLostObjectsWorker.finished.connect( + self.copyAllLostObjectsThread.quit + ) + self.copyAllLostObjectsWorker.finished.connect( + self.copyAllLostObjectsWorker.deleteLater + ) + self.copyAllLostObjectsThread.finished.connect( + self.copyAllLostObjectsThread.deleteLater + ) + self.copyAllLostObjectsWorker.finished.connect( + self.copyAllLostObjectsWorkerFinished + ) + + self.copyAllLostObjectsThread.started.connect(self.copyAllLostObjectsWorker.run) + self.copyAllLostObjectsThread.start() + + self.copyAllLostObjectsWorkerLoop = QEventLoop() + self.copyAllLostObjectsWorkerLoop.exec_() + + def copyAllLostObjectsWorkerCritical(self, error): + self.copyAllLostObjectsWorkerLoop.exit() + self.workerCritical(error) + + def copyAllLostObjectsWorkerFinished(self, output): + if self.progressWin is not None: + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + + if output.get("doReinitLastSegmFrame", False): + self.reInitLastSegmFrame( + from_frame_i=output.get("last_visited_frame_i"), + updateImages=False, + force=True, + ) + + if output.get("overlap_warning", False): + self.blinker = qutils.QControlBlink( + self.copyLostObjToolbar.maxOverlapNumberControl, qparent=self.mainWin + ) + self.blinker.start() + + self.copyAllLostObjectsWorkerLoop.exit() + self.update_rp() + self.updateAllImages() + self.store_data() + + def copyLostObjContour_cb(self, checked): + self.copyLostObjToolbar.setVisible(checked) + + self.ax1_lostObjScatterItem.hoverLostID = 0 + if not checked: + return + + self.lostObjImage = np.zeros_like(self.currentLab2D) + self.updateLostContoursImage(0) + + def copyLostObjectMask(self, ID: int): + posData = self.data[self.pos_i] + mask = self.lostObjImage == ID + lab2D = self.get_2Dlab(posData.lab) + lab2D[mask] = ID + self.lostObjImage[mask] = 0 + self.set_2Dlab(lab2D) + + def disableNonFunctionalButtons(self): + if not self.isSegm3D: + return + + for item in self.functionsNotTested3D: + if hasattr(item, "action"): + toolButton = item + action = toolButton.action + toolButton.setDisabled(True) + elif hasattr(item, "toolbar"): + toolbar = item.toolbar + action = item + toolButton = toolbar.widgetForAction(action) + toolButton.setDisabled(True) + else: + action = item + action.setDisabled(True) + + def disconnectLeftClickButtons(self): + for button in self.LeftClickButtons: + try: + button.toggled.disconnect() + except Exception as e: + # Not all the LeftClickButtons have toggled connected + pass + + def getPrevFrameIDs(self, current_frame_i=None): + posData = self.data[self.pos_i] + if current_frame_i is None: + current_frame_i = posData.frame_i + + if current_frame_i is None: + return [] + + prev_frame_i = current_frame_i - 1 + prevIDs = posData.allData_li[prev_frame_i]["IDs"] + + if prevIDs: + return prevIDs + + # IDs in previous frame were not stored --> load prev lab from HDD + prev_lab = self.get_labels( + from_store=False, frame_i=prev_frame_i, return_copy=False + ) + rp = skimage.measure.regionprops(prev_lab) + prevIDs = [obj.label for obj in rp] + return prevIDs + + def hideItemsHoverBrush(self, xy=None, ID=None, force=False): + if xy is not None: + x, y = xy + if x is None: + return + + xdata, ydata = int(x), int(y) + Y, X = self.currentLab2D.shape + + if not (xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y): + return + + if not self.brushAutoHideCheckbox.isChecked() and not force: + return + + posData = self.data[self.pos_i] + size = self.brushSizeSpinbox.value() * 2 + + if xy is not None: + ID = self.get_2Dlab(posData.lab)[ydata, xdata] + + if self.ax1_lostObjScatterItem.isVisible(): + self.ax1_lostObjScatterItem.setVisible(False) + + if self.ax1_lostTrackedScatterItem.isVisible(): + self.ax1_lostTrackedScatterItem.setVisible(False) + + if self.ax2_lostObjScatterItem.isVisible(): + self.ax2_lostObjScatterItem.setVisible(False) + + if self.ax2_lostTrackedScatterItem.isVisible(): + self.ax2_lostTrackedScatterItem.setVisible(False) + + # Restore ID previously hovered + if ID != self.ax1BrushHoverID and not self.isMouseDragImg1: + try: + self.restoreHoverObjBrush() + except Exception as e: + self.ax1BrushHoverID = 0 + return + + # Hide items hover ID + if ID != 0: + self.clearObjContour(ID=ID, ax=0) + self.clearObjContour(ID=ID, ax=1) + self.ax1BrushHoverID = ID + else: + self.ax1BrushHoverID = 0 + + def highlightHoverLostObj(self, modifiers, event): + noModifier = modifiers == Qt.NoModifier + if not noModifier: + return + + if not self.copyLostObjButton.isChecked(): + return + + if event.isExit(): + return + + posData = self.data[self.pos_i] + x, y = event.pos() + xdata, ydata = int(x), int(y) + try: + hoverLostID = self.lostObjImage[ydata, xdata] + except IndexError: + return + + self.ax1_lostObjScatterItem.hoverLostID = hoverLostID + if hoverLostID == 0: + self.ax1_lostObjScatterItem.setSize(self.contLineWeight + 1) + self.ax1_lostObjScatterItem.setData([], []) + else: + prev_rp = posData.allData_li[posData.frame_i - 1]["regionprops"] + prev_IDs_idxs = posData.allData_li[posData.frame_i - 1]["IDs_idxs"] + lostObj = prev_rp[prev_IDs_idxs[hoverLostID]] + obj_contours = self.getObjContours(lostObj, all_external=True) + for cont in obj_contours: + xx = cont[:, 0] + yy = cont[:, 1] + self.ax1_lostObjScatterItem.addPoints(xx, yy) + self.ax1_lostObjScatterItem.setSize(self.contLineWeight + 2) + + def highlightLostNew(self): + if self.modeComboBox.currentText() == "Viewer": + return + + posData = self.data[self.pos_i] + delROIsIDs = self.getDelRoisIDs() + + # self.setAllContoursImages(delROIsIDs=delROIsIDs) + if posData.frame_i == 0: + return + + if not self.annotLostObjsToggle.isChecked(): + return + + prev_rp = posData.allData_li[posData.frame_i - 1]["regionprops"] + + if prev_rp is None: + return + + self.setAllLostObjContoursImage(delROIsIDs=delROIsIDs) + self.setAllLostTrackedObjContoursImage(delROIsIDs=delROIsIDs) + + def highlightManualAnnotMode(self, viewBox, viewRange): + self.ax1.setHighlighted(True) + + def magicPrompts_cb(self, checked): + if checked: + self.disconnectLeftClickButtons() + self.uncheckLeftClickButtons(self.magicPromptsToolButton) + self.connectLeftClickButtons() + self.magicPromptsToolbar.setVisible(True) + self.promptSegmentPointsLayerToolbar.setVisible(True) + if not self.promptSegmentPointsLayerToolbar.isPointsLayerInit: + self.addPointsLayerTriggered( + toolbar=self.promptSegmentPointsLayerToolbar + ) + else: + self.resetCursors() + self.promptSegmentPointsLayerToolbar.setVisible(False) + self.magicPromptsToolbar.setVisible(False) + + def manualAnnotPast_cb(self, checked): + posData = self.data[self.pos_i] + if checked: + for _ in range(3): + self.onEscape( + buttonsToNotUncheck=[self.manualAnnotPastButton], doAutoRange=False + ) + + self.brushButton.setChecked(True) + self.store_data() + self.manualAnnotState = { + "editID": self.editIDspinbox.value(), + "isAutoID": self.autoIDcheckbox.isChecked(), + "doWarnLostObj": self.warnLostCellsAction.isChecked(), + } + self.autoIDcheckbox.setChecked(False) + self.warnLostCellsAction.setChecked(False) + hoverID = self.getLastHoveredID() + if hoverID == 0: + win = apps.QLineEditDialog( + title="Not hovering any ID", + msg="You are not hovering on any ID.\n" + "Enter the ID that you want to lock.", + parent=self, + isInteger=True, + defaultTxt=self.setBrushID(return_val=True), + ) + win.exec_() + if win.cancel: + self.manualAnnotPastButton.setChecked(False) + return + hoverID = win.EntryID + self.logger.info( + "Setting manual annotation for ID = " + f"{hoverID}, at frame n. {posData.frame_i + 1}" + ) + self.editIDspinbox.setValue(hoverID) + try: + obj_idx = posData.IDs_idxs[hoverID] + obj = posData.rp[obj_idx] + radius = ( + 0.9 * obj.minor_axis_length / 2 + ) # math.sqrt(obj.area/math.pi)*0.9 + self.brushSizeSpinbox.setValue(round(radius)) + except Exception as err: + pass + + self.manualAnnotState["frame_i_to_restore"] = posData.frame_i + self.manualAnnotState["last_tracked_i"] = ( + self.navigateScrollBar.maximum() - 1 + ) + self.ax1.sigRangeChanged.connect(self.highlightManualAnnotMode) + self.ax1.setHighlighted(True, color="green") + else: + self.setStatusBarLabel() + self.autoIDcheckbox.setChecked(self.manualAnnotState["isAutoID"]) + self.editIDspinbox.setValue(self.manualAnnotState["editID"]) + self.warnLostCellsAction.setChecked(self.manualAnnotState["doWarnLostObj"]) + frame_to_restore = self.manualAnnotState.get("frame_i_to_restore") + if frame_to_restore is None: + return + + self.store_data() + self.store_manual_annot_data() + + last_tracked_i_to_restore = self.manualAnnotState["last_tracked_i"] + self.manualAnnotRestoreLastTrackedFrame(last_tracked_i_to_restore) + + self.logger.info(f"Restoring view to frame n. {posData.frame_i + 1}...") + posData.frame_i = frame_to_restore + self.get_data() + self.updateAllImages() + self.updateScrollbars() + self.ax1.sigRangeChanged.disconnect() + self.ax1.setHighlighted(False) + QTimer.singleShot(150, self.autoRange) + + self.setManualAnnotModeEnabledTools(checked) + + def onEscape( + self, + isTypingIDFunctionChecked=False, + buttonsToNotUncheck=None, + doAutoRange=True, + ): + if buttonsToNotUncheck is None: + buttonsToNotUncheck = set() + + if self.keepIDsButton.isChecked() and self.keptObjectsIDs: + self.keptObjectsIDs = widgets.KeptObjectIDsList( + self.keptIDsLineEdit, self.keepIDsConfirmAction + ) + self.highlightHoverIDsKeptObj(0, 0, hoverID=0) + QTimer.singleShot(300, self.autoRange) + return + + if self.brushButton.isChecked() and self.typingEditID: + self.autoIDcheckbox.setChecked(True) + self.typingEditID = False + QTimer.singleShot(300, self.autoRange) + return + + if isTypingIDFunctionChecked and self.typingEditID: + self.typingEditID = False + QTimer.singleShot(300, self.autoRange) + return + + if self.labelRoiButton.isChecked() and self.isMouseDragImg1: + self.isMouseDragImg1 = False + self.labelRoiItem.setPos((0, 0)) + self.labelRoiItem.setSize((0, 0)) + self.freeRoiItem.clear() + QTimer.singleShot(300, self.autoRange) + return + + if self.zoomRectButton.isChecked(): + self.zoomRectCancelled() + QTimer.singleShot(300, self.autoRange) + return + + self.setUncheckedAllButtons(buttonsToNotUncheck=buttonsToNotUncheck) + self.setUncheckedAllCustomAnnotButtons() + self.setUncheckedPointsLayers() + self.clearTempBrushImage() + self.isMouseDragImg1 = False + self.typingEditID = False + self.clearHighlightedID() + try: + self.polyLineRoi.clearPoints() + except Exception as e: + pass + + if doAutoRange: + QTimer.singleShot(11, self.autoRange) + + def restoreHoverObjBrush(self): + posData = self.data[self.pos_i] + if self.ax1BrushHoverID in posData.IDs: + obj_idx = posData.IDs_idxs[self.ax1BrushHoverID] + obj = posData.rp[obj_idx] + if not self.isObjVisible(obj.bbox): + return + + self.addObjContourToContoursImage(obj=obj, ax=0) + self.addObjContourToContoursImage(obj=obj, ax=1) + + def setLostNewOldPrevIDs(self): + posData = self.data[self.pos_i] + if posData.frame_i == 0: + posData.lost_IDs = [] + posData.new_IDs = [] + posData.old_IDs = [] + # posData.multiContIDs = set() + self.titleLabel.setText("Looking good!", color=self.titleColor) + return [] + + # elif self.modeComboBox.currentText() == 'Viewer': + # pass + + out = self.updateLostNewCurrentIDs() + lost_IDs, new_IDs, IDs_with_holes, tracked_lost_IDs, curr_delRoiIDs = out + self.setTitleText(lost_IDs, new_IDs, IDs_with_holes, tracked_lost_IDs) + return curr_delRoiIDs + + def setManualAnnotModeEnabledTools(self, enabled): + for action in self.editToolBar.actions(): + toolButton = self.editToolBar.widgetForAction(action) + if toolButton in self.manulAnnotToolButtons: + continue + + toolButton.setDisabled(enabled) + action.setDisabled(enabled) + + def setTitleFormatter(self, htmlTxt_li, htmlTxtFull_li, pretxt, color, IDs): + if not IDs: + return htmlTxt_li, htmlTxtFull_li + + if isinstance(IDs, set): + IDs = list(IDs) + + trim_IDs = utils.get_trimmed_list(IDs) + txt = f"{pretxt}: {trim_IDs}" + txt_full = f"{pretxt}:
    {IDs}" + + txt = f'{txt}' + txt_full = f'{txt_full}' + + htmlTxt_li.append(txt) + htmlTxtFull_li.append(txt_full) + + return htmlTxt_li, htmlTxtFull_li + + def setTitleText( + self, lost_IDs=None, new_IDs=None, IDs_with_holes=None, tracked_lost_IDs=None + ): + if self.manualAnnotPastButton.isChecked(): + lockedID = self.editIDspinbox.value() + frame_to_restore = self.manualAnnotState.get("frame_i_to_restore") + txt = f"Locked ID {lockedID} since frame n. {frame_to_restore + 1}" + htmlTxt = f'{txt}' + self.titleLabel.setText(htmlTxt) + return + + mode = self.modeComboBox.currentText() + try: + posData = self.data[self.pos_i] + posData.segm_data[posData.frame_i] + prev_segmented = True + except IndexError: + prev_segmented = False + + if prev_segmented: + htmlTxt_li = [] + htmlTxtFull_li = [] + else: + htmlTxt = f'Never segmented frame. ' + self.titleLabel.setText(htmlTxt) + self.titleLabel.setToolTip(htmlTxt) + return + + if mode != "Normal division: Lineage tree": + htmlTxt_li, htmlTxtFull_li = self.setTitleFormatter( + htmlTxt_li, htmlTxtFull_li, "IDs lost", "orange", lost_IDs + ) + htmlTxt_li, htmlTxtFull_li = self.setTitleFormatter( + htmlTxt_li, htmlTxtFull_li, "New IDs", "red", new_IDs + ) + htmlTxt_li, htmlTxtFull_li = self.setTitleFormatter( + htmlTxt_li, htmlTxtFull_li, "Acc. IDs lost", "green", tracked_lost_IDs + ) + + for i, htmlTxtFull in enumerate(htmlTxtFull_li): + htmlTxtFull_li[i] = htmlTxtFull.replace("Acc.", "Accepted") + + htmlTxt_li, htmlTxtFull_li = self.setTitleFormatter( + htmlTxt_li, htmlTxtFull_li, "IDs with holes", "red", IDs_with_holes + ) + else: + try: + cells_with_parent, orphan_cells, lost_cells = ( + self.lineage_tree.export_lin_tree_info(posData.frame_i) + ) + except IndexError or KeyError: + title = "Processing lineage tree..." + htmlTxt = f'{title}' + self.titleLabel.setText(htmlTxt) + self.titleLabel.setToolTip(htmlTxt) + return + except AttributeError: + title = "Lineage tree still initializing..." + htmlTxt = f'{title}' + self.titleLabel.setText(htmlTxt) + self.titleLabel.setToolTip(htmlTxt) + return + + parent_cell_txt_raw = [] + if cells_with_parent: + # aggregate same parents + parent_cell_groups = dict() + for cell, parent in cells_with_parent: + if parent not in parent_cell_groups: + parent_cell_groups[parent] = [] + parent_cell_groups[parent].append(cell) + for parent, daughters in parent_cell_groups.items(): + cells_str = ",".join([str(daughter) for daughter in daughters]) + parent_cell_txt_raw.append(f"({parent}>{cells_str})") + + htmlTxt_li, htmlTxtFull_li = self.setTitleFormatter( + htmlTxt_li, htmlTxtFull_li, "New w/out mother", "red", orphan_cells + ) + htmlTxt_li, htmlTxtFull_li = self.setTitleFormatter( + htmlTxt_li, htmlTxtFull_li, "Lost", "yellow", lost_cells + ) + htmlTxt_li, htmlTxtFull_li = self.setTitleFormatter( + htmlTxt_li, + htmlTxtFull_li, + "Parent > Cell", + "green", + parent_cell_txt_raw, + ) + + if not htmlTxt_li: + title = "Looking good" + htmlTxt = f'{title}' + self.titleLabel.setText(htmlTxt) + self.titleLabel.setToolTip(htmlTxt) + return + + htmlTxt = ", ".join(htmlTxt_li) + htmlTxtFull = "
    ".join(htmlTxtFull_li) + + self.titleLabel.setText(htmlTxt) + self.titleLabel.setToolTip(htmlTxtFull) + + def setUncheckedAllButtons(self, buttonsToNotUncheck=None): + self.clickedOnBud = False + if buttonsToNotUncheck is None: + buttonsToNotUncheck = set() + + try: + self.BudMothTempLine.setData([], []) + except Exception as e: + pass + for button in self.checkableButtons: + if button in buttonsToNotUncheck: + continue + button.setChecked(False) + + if self.countObjsButton not in buttonsToNotUncheck: + self.countObjsButton.setChecked(False) + self.splineHoverON = False + self.tempSegmentON = False + self.isRightClickDragImg1 = False + self.clearCurvItems(removeItems=False) + + def setUncheckedAllCustomAnnotButtons(self): + for button in self.customAnnotDict.keys(): + button.setChecked(False) + + def setUncheckedPointsLayers(self): + self.togglePointsLayerAction.setChecked(False) + self.magicPromptsToolButton.setChecked(False) + + def uncheckLeftClickButtons(self, sender): + for button in self.LeftClickButtons: + if button != sender: + button.setChecked(False) + + if button != self.labelRoiButton: + # self.labelRoiButton is disconnected so we manually call uncheck + self.labelRoi_cb(False) + self.secondLevelToolbar.setVisible(True) + for toolbar in self.controlToolBars: + try: + toolbar.keepVisibleWhenActive + if toolbar.isVisible(): + self.secondLevelToolbar.setVisible(False) + continue + except: + pass + toolbar.setVisible(False) + + self.enableSizeSpinbox(False) + if sender is not None: + self.keepIDsButton.setChecked(False) + + def uncheckQButton(self, button): + # Manual exclusive where we allow to uncheck all buttons + for b in self.checkableQButtonsGroup.buttons(): + if b != button: + b.setChecked(False) + + def updateBrushCursor(self, x, y, isHoverImg1=True): + if x is None: + return + + xdata, ydata = int(x), int(y) + _img = self.currentLab2D + Y, X = _img.shape + + if not (xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y): + return + + size = self.brushSizeSpinbox.value() * 2 + self.setHoverToolSymbolData( + [x], [y], self.activeBrushCircleCursors(isHoverImg1), size=size + ) + self.setHoverToolSymbolColor( + xdata, + ydata, + self.ax2_BrushCirclePen, + self.activeBrushCircleCursors(isHoverImg1), + self.brushButton, + brush=self.ax2_BrushCircleBrush, + ) + + def updateHighlightedAxis(self): + if not self.manualAnnotPastButton.isChecked(): + return + + frame_to_restore = self.manualAnnotState.get("frame_i_to_restore") + posData = self.data[self.pos_i] + if posData.frame_i == frame_to_restore: + color = "green" + elif posData.frame_i < frame_to_restore: + color = "gold" + else: + color = "red" + + self.ax1.setHighlightingRectItemsColor(color) + + def updateLostNewCurrentIDs(self): + posData = self.data[self.pos_i] + + prev_IDs = self.getPrevFrameIDs() + tracked_lost_IDs = self.getTrackedLostIDs() + curr_IDs = posData.IDs + curr_delRoiIDs = self.getStoredDelRoiIDs() + prev_delRoiIDs = self.getStoredDelRoiIDs(frame_i=posData.frame_i - 1) + lost_IDs = [ + ID + for ID in prev_IDs + if ID not in curr_IDs + and ID not in prev_delRoiIDs + and ID not in tracked_lost_IDs + ] + new_IDs = [ + ID for ID in curr_IDs if ID not in prev_IDs and ID not in curr_delRoiIDs + ] + IDs_with_holes = [] + posData.lost_IDs = lost_IDs + posData.new_IDs = new_IDs + posData.old_IDs = prev_IDs + posData.IDs = curr_IDs + + out = (lost_IDs, new_IDs, IDs_with_holes, tracked_lost_IDs, curr_delRoiIDs) + return out + + def wand_cb(self, checked): + posData = self.data[self.pos_i] + if checked: + self.disconnectLeftClickButtons() + self.uncheckLeftClickButtons(self.wandToolButton) + self.connectLeftClickButtons() + self.wandControlsToolbar.setVisible(True) + # self.secondLevelToolbar.setVisible(False) + else: + self.resetCursors() + # self.secondLevelToolbar.setVisible(True) + self.wandControlsToolbar.setVisible(False) diff --git a/cellacdc/mixins/tracking.py b/cellacdc/mixins/tracking.py new file mode 100644 index 000000000..a132263c4 --- /dev/null +++ b/cellacdc/mixins/tracking.py @@ -0,0 +1,1357 @@ +"""Qt view adapter for tracking and manual tracking workflows.""" + +from __future__ import annotations + +import cv2 +from functools import partial +from typing import Iterable, List, Set + +import numpy as np +import pyqtgraph as pg +import skimage.measure +from tqdm import tqdm +from qtpy.QtCore import QTimer +from qtpy.QtGui import QFont + +from cellacdc import apps, exception_handler, html_utils, widgets +from cellacdc.trackers.CellACDC import CellACDC_tracker + + +font_13px = QFont() +font_13px.setPixelSize(13) + +from .undo_redo import UndoRedo + + +class Tracking(UndoRedo): + """Extracted from guiWin.""" + + def _drawGhostContour(self, x, y): + if self.ghostObject is None: + return + + ID = self.ghostObject.label + yc, xc = self.ghostObject.local_centroid + Dx = x - xc + Dy = y - yc + xx = self.ghostObject.xx_contour + Dx + yy = self.ghostObject.yy_contour + Dy + self.ghostContourItemLeft.setData( + xx, yy, fontSize=self.fontSize, ID=ID, y_cursor=y, x_cursor=x + ) + self.ghostContourItemRight.setData( + xx, yy, fontSize=self.fontSize, ID=ID, y_cursor=y, x_cursor=x + ) + + def _drawGhostMask(self, x, y): + if self.ghostObject is None: + return + + self.clearGhostMask() + ID = self.ghostObject.label + h, w = self.ghostObject.image.shape[-2:] + yc, xc = self.ghostObject.local_centroid + Dx = int(x - xc) + Dy = int(y - yc) + bbox = ((Dy, Dy + h), (Dx, Dx + w)) + + Y, X = self.currentLab2D.shape + slices = utils.get_slices_local_into_global_arr(bbox, (Y, X)) + slice_global_to_local, slice_crop_local = slices + + obj_image = self.ghostObject.image[slice_crop_local] + + self.ghostMaskItemLeft.image[slice_global_to_local][obj_image] = ID + self.ghostMaskItemLeft.updateGhostImage( + fontSize=self.fontSize, ID=ID, y_cursor=y, x_cursor=x + ) + + self.ghostMaskItemRight.image[slice_global_to_local][obj_image] = ID + self.ghostMaskItemRight.updateGhostImage( + fontSize=self.fontSize, ID=ID, y_cursor=y, x_cursor=x + ) + + def _drawManualBackgroundObjContour(self, x, y): + if self.manualBackgroundObj is None: + return + + ID = self.manualBackgroundObj.label + yc, xc = self.manualBackgroundObj.local_centroid + Dx = x - xc + Dy = y - yc + xx = self.manualBackgroundObj.xx_contour + Dx + yy = self.manualBackgroundObj.yy_contour + Dy + self.manualBackgroundObjItem.setData( + xx, yy, fontSize=self.fontSize, ID=ID, y_cursor=y, x_cursor=x + ) + + def addManualBackgroundItems(self): + self.manualBackgroundObjItem.addToPlotItem() + self.ax1.addItem(self.manualBackgroundImageItem) + + def addManualBackgroundObject(self, x, y): + posData = self.data[self.pos_i] + + if not hasattr(self, "manualBackgroundObj"): + self.initManualBackgroundObject() + + Y, X = self.currentLab2D.shape + ymin, xmin, ymax, xmax = self.manualBackgroundObj.bbox + width, height = xmax - xmin, ymax - ymin + yc, xc = self.manualBackgroundObj.local_centroid + xstart, ystart = round(x - xc), round(y - yc) + xstart = xstart if xstart >= 0 else 0 + ystart = ystart if ystart >= 0 else 0 + + xend = xstart + width + yend = ystart + height + xend = xend if xend <= X else X + yend = yend if yend <= Y else Y + + width = xend - xstart + height = yend - ystart + + obj_image = self.manualBackgroundObj.image[:height, :width] + obj_slice = (slice(ystart, yend), slice(xstart, xend)) + ID = self.manualBackgroundObj.label + self.clearManualBackgroundObject(ID) + posData.manualBackgroundLab[obj_slice][obj_image] = ID + + if ID in self.manualBackgroundTextItems: + self.manualBackgroundTextItems[ID].setPos(x, y) + return + + textItem = pg.TextItem(text=str(ID), color="r", anchor=(0.5, 0.5)) + textItem.setFont(font_13px) + textItem.setPos(x, y) + self.manualBackgroundTextItems[ID] = textItem + + self.ax1.addItem(textItem) + + def addManualTrackingItems(self): + self.ghostContourItemLeft.addToPlotItem() + self.ghostContourItemRight.addToPlotItem() + + self.ghostMaskItemLeft.addToPlotItem() + self.ghostMaskItemRight.addToPlotItem() + + Y, X = self.img1.image.shape[:2] + self.ghostMaskItemLeft.initImage((Y, X)) + self.ghostMaskItemRight.initImage((Y, X)) + + self.updateGhostMaskOpacity() + + def annotateAssignedObjsAcdcTrackerSecondStep(self): + posData = self.data[self.pos_i] + annotInfo = posData.acdcTracker2stepsAnnotInfo.get(posData.frame_i) + if annotInfo is None: + return + + new_objs_1st_step, lost_objs_1st_step = annotInfo + for lostObj, newObj in zip(lost_objs_1st_step, new_objs_1st_step): + allContours = self.getObjContours(lostObj, all_external=True) + for objContours in allContours: + isObjVisible = self.isObjVisible(newObj.bbox) + if not isObjVisible: + continue + xx = objContours[:, 0] + 0.5 + yy = objContours[:, 1] + 0.5 + self.yellowContourScatterItem.addPoints(xx, yy) + + y1, x1 = self.getObjCentroid(lostObj.centroid) + y2, x2 = self.getObjCentroid(newObj.centroid) + xx, yy = core.get_line(y1, x1, y2, x2, dashed=False) + self.ax1_oldMothBudLinesItem.addPoints(xx, yy) + + posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = None + + def clearAssignedObjsSecondStep(self): + posData = self.data[self.pos_i] + posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = None + + def clearGhost(self): + self.clearGhostContour() + self.clearGhostMask() + + def clearGhostContour(self): + self.ghostContourItemLeft.clear() + self.ghostContourItemRight.clear() + self.manualBackgroundObjItem.clear() + + def clearGhostMask(self): + self.ghostMaskItemLeft.clear() + self.ghostMaskItemRight.clear() + + def clearManualBackgroundAnnotations(self): + try: + for textItem in self.manualBackgroundTextItems.values(): + textItem.setText("") + except Exception as error: + pass + + def clearManualBackgroundObject(self, ID): + posData = self.data[self.pos_i] + mask = posData.manualBackgroundLab == ID + posData.manualBackgroundImage[mask, :] = 0 + posData.manualBackgroundLab[mask] = 0 + + def doSkipTracking(self, against_next: bool, enforce: bool): + if self.isSnapshot: + return True + + mode = str(self.modeComboBox.currentText()) + if mode != "Segmentation and Tracking": + return True + + if self.UserEnforced_DisabledTracking: + return True + + if not self.realTimeTrackingToggle.isChecked(): + return True + + posData = self.data[self.pos_i] + if against_next: + reference_lab = posData.allData_li[posData.frame_i + 1]["labels"] + if reference_lab is None: + # Next frame never visited --> cannot track against next + return True + + if posData.frame_i == posData.SizeT - 1: + # Last frame --> cannot track against next + return True + + else: + # check that we are not on the last frame + if posData.frame_i == 0: + return True + + if enforce or self.UserEnforced_Tracking: + # Enforce even if not last visited frame + return False + + is_first_time_on_next_frame = self.isFirstTimeOnNextFrame() + skip_tracking = not is_first_time_on_next_frame + + return skip_tracking + + def drawManualBackgroundObj(self, x, y): + if x is None or y is None: + self.clearGhost() + return + + self._drawManualBackgroundObjContour(x, y) + + def drawManualTrackingGhost(self, x, y): + if not self.manualTrackingToolbar.showGhostCheckbox.isChecked(): + return + + if x is None or y is None: + self.clearGhost() + return + + if self.manualTrackingToolbar.ghostContourRadiobutton.isChecked(): + self._drawGhostContour(x, y) + else: + self._drawGhostMask(x, y) + + def enableSmartTrack(self, checked): + posData = self.data[self.pos_i] + # Disable tracking for already visited frames + + if posData.allData_li[posData.frame_i]["labels"] is not None: + trackingEnabled = True + else: + trackingEnabled = False + + if checked: + self.UserEnforced_DisabledTracking = False + self.UserEnforced_Tracking = False + else: + if trackingEnabled: + self.UserEnforced_DisabledTracking = True + self.UserEnforced_Tracking = False + else: + self.UserEnforced_DisabledTracking = False + self.UserEnforced_Tracking = True + + def getLastTrackedFrame(self, posData): + last_tracked_i = 0 + for frame_i, data_dict in enumerate(posData.allData_li): + lab = data_dict["labels"] + if lab is None: + frame_i -= 1 + break + if frame_i > 0: + return frame_i + else: + return last_tracked_i + + def getTrackedLostIDs(self, prev_lab=None, IDs_in_frames=None, frame_i=None): + trackedLostIDs = set() + posData = self.data[self.pos_i] + if self.isExportingVideo: + posData.trackedLostIDs = trackedLostIDs + return trackedLostIDs + + retrackedLostcent = set() + if frame_i is None: + frame_i = posData.frame_i + + if prev_lab is None: + prev_lab = self.get_labels( + from_store=True, + frame_i=posData.frame_i - 1, + return_existing=False, + return_copy=False, + ) + + if IDs_in_frames is None: + IDs_in_frames = posData.IDs + + try: + tracked_lost_centroids = posData.tracked_lost_centroids[frame_i] + except KeyError: + tracked_lost_centroids = set() + + for centroid in tracked_lost_centroids: + if len(centroid) < 3 and prev_lab.ndim == 3: + # Ignore wrongly stored centroids + continue + + ID = prev_lab[centroid] + if ID == 0: + continue + + if ID in IDs_in_frames: + retrackedLostcent.add(centroid) + continue + + trackedLostIDs.add(ID) + + posData.tracked_lost_centroids[frame_i] = ( + tracked_lost_centroids - retrackedLostcent + ) + posData.trackedLostIDs = trackedLostIDs + + return trackedLostIDs + + def get_last_tracked_i(self): + posData = self.data[self.pos_i] + last_tracked_i = 0 + for frame_i, data_dict in enumerate(posData.allData_li): + lab = data_dict["labels"] + if lab is None and frame_i == 0: + last_tracked_i = 0 + break + elif lab is None: + last_tracked_i = frame_i - 1 + break + else: + last_tracked_i = posData.segmSizeT - 1 + return last_tracked_i + + def handleAdditionalInfoRealTimeTracker(self, prev_rp, *args): + if self._rtTrackerName == "CellACDC_normal_division": + tracked_lost_IDs = args[0] + self.setTrackedLostCentroids(prev_rp, tracked_lost_IDs) + elif self._rtTrackerName == "CellACDC_2steps": + if args[0] is None: + return + posData = self.data[self.pos_i] + posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = args[0] + + def initGhostObject(self, ID=None): + mode = self.modeComboBox.currentText() + if mode != "Segmentation and Tracking": + self.ghostObject = None + return + + if not self.manualTrackingButton.isChecked(): + self.ghostObject = None + return + + if not self.manualTrackingToolbar.showGhostCheckbox.isChecked(): + self.ghostObject = None + return + + if ID is None: + ID = self.manualTrackingToolbar.spinboxID.value() + + posData = self.data[self.pos_i] + if posData.frame_i == 0: + self.ghostObject = None + return + + prevFrameRp = posData.allData_li[posData.frame_i - 1]["regionprops"] + if prevFrameRp is None: + self.ghostObject = None + return + + for obj in prevFrameRp: + if obj.label != ID: + continue + self.ghostObject = obj + break + else: + self.ghostObject = None + self.manualTrackingToolbar.showWarning( + f"The ID {ID} does not exist in previous frame " + "--> starting a new track." + ) + return + + self.manualTrackingToolbar.clearInfoText() + + self.ghostObject.contour = self.getObjContours(self.ghostObject, local=True) + self.ghostObject.xx_contour = self.ghostObject.contour[:, 0] + self.ghostObject.yy_contour = self.ghostObject.contour[:, 1] + + self.ghostMaskItemLeft.initLookupTable(self.lut[ID]) + self.ghostMaskItemRight.initLookupTable(self.lut[ID]) + + def initManualBackgroundObject(self, ID=None): + if not self.manualBackgroundButton.isChecked(): + self.manualBackgroundObj = None + return + + if ID is None: + ID = self.manualBackgroundToolbar.spinboxID.value() + + posData = self.data[self.pos_i] + if ID not in posData.IDs: + self.manualBackgroundObj = None + self.manualBackgroundToolbar.showWarning(f"The ID {ID} does not exist") + self.manualBackgroundObjItem.clear() + return + + ID_idx = posData.IDs_idxs[ID] + self.manualBackgroundObj = posData.rp[ID_idx] + + self.manualBackgroundToolbar.clearInfoText() + self.manualBackgroundObj.contour = self.getObjContours( + self.manualBackgroundObj, local=True + ) + xx_contour = self.manualBackgroundObj.contour[:, 0] + yy_contour = self.manualBackgroundObj.contour[:, 1] + self.manualBackgroundObj.xx_contour = xx_contour + self.manualBackgroundObj.yy_contour = yy_contour + + def initRealTimeTracker(self, force=False): + for rtTrackerAction in self.trackingAlgosGroup.actions(): + if rtTrackerAction.isChecked(): + break + + aliases = utils.aliases_real_time_trackers(reverse=True) + + rtTracker = rtTrackerAction.text() + rtTracker_txt = rtTracker + + if rtTracker in aliases: + rtTracker = aliases[rtTracker] + + if rtTracker == "Cell-ACDC": + return + if rtTracker == "YeaZ": + return + + if self.isRealTimeTrackerInitialized and not force: + return + + self.logger.info(f"Initializing {rtTracker_txt} tracker...") + self._rtTrackerName = rtTracker + posData = self.data[self.pos_i] + realTimeTracker, track_frame_params = utils.init_tracker( + posData, rtTracker, qparent=self, realTime=True + ) + if realTimeTracker is None: + self.logger.info(f"{rtTracker} tracker initialization cancelled.") + return + + self.realTimeTracker = realTimeTracker + self.track_frame_params = track_frame_params + self.logger.info(f"{rtTracker} tracker successfully initialized.") + if "image_channel_name" in self.track_frame_params: + # Remove the channel name since it was already loaded in init_tracker + del self.track_frame_params["image_channel_name"] + + def initSegmTrackMode(self): + posData = self.data[self.pos_i] + last_tracked_i = self.get_last_tracked_i() + + if posData.frame_i > last_tracked_i: + # Prompt user to go to last tracked frame + msg = widgets.myMessageBox() + txt = html_utils.paragraph( + f'The last visited frame in "Segmentation and Tracking mode" ' + f"is frame {last_tracked_i + 1}.\n\n" + f"We recommend to resume from that frame.

    " + "How do you want to proceed?" + ) + goToButton, stayButton = msg.warning( + self, + "Go to last visited frame?", + txt, + buttonsTexts=( + f"Resume from frame {last_tracked_i + 1} (RECOMMENDED)", + f"Stay on current frame {posData.frame_i + 1}", + ), + ) + if msg.clickedButton == goToButton: + posData.frame_i = last_tracked_i + self.lastFrameRanOnFirstVisitTools = posData.frame_i + self.get_data() + self.updateAllImages() + self.updateScrollbars() + else: + last_tracked_i = posData.frame_i + current_frame_i = posData.frame_i + self.lastFrameRanOnFirstVisitTools = posData.frame_i + self.logger.info( + f"Storing data up until frame n. {current_frame_i + 1}..." + ) + pbar = tqdm(total=current_frame_i + 1, ncols=100) + for i in range(current_frame_i): + posData.frame_i = i + self.get_data() + self.store_data(autosave=i == current_frame_i - 1) + pbar.update() + pbar.close() + + posData.frame_i = current_frame_i + self.get_data() + + self.highlightLostNew() + self.updateLastCheckedFrameWidgets(last_tracked_i) + + self.isFirstTimeOnNextFrame() + self.initRealTimeTracker() + + def isFirstTimeOnNextFrame(self): + posData = self.data[self.pos_i] + posData.last_tracked_i = self.navigateScrollBar.maximum() - 1 + return posData.frame_i > posData.last_tracked_i + + def keepOnlyNewIDAssignedObjsSecondStep(self, trackedID): + posData = self.data[self.pos_i] + annotInfo = posData.acdcTracker2stepsAnnotInfo.get(posData.frame_i) + + if annotInfo is None: + return + + new_objs_1st_step, lost_objs_1st_step = annotInfo + correct_new_objs, correct_lost_objs = [], [] + for lostObj, newObj in zip(lost_objs_1st_step, new_objs_1st_step): + newObj_ID = posData.lab[newObj.slice][newObj.image][0] + if newObj_ID != trackedID: + continue + + correct_new_objs.append(newObj) + correct_lost_objs.append(lostObj) + + if not correct_new_objs: + posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = None + else: + posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = ( + correct_new_objs, + correct_lost_objs, + ) + + def manualBackground_cb(self, checked): + if checked: + posData = self.data[self.pos_i] + minID = min(posData.IDs, default=0) + if minID == self.manualBackgroundToolbar.spinboxID.value(): + self.initManualBackgroundObject() + else: + self.manualBackgroundToolbar.spinboxID.setValue(minID) + # self.initManualBackgroundObject() + # self.initManualBackgroundImage() + self.addManualBackgroundItems() + self.disconnectLeftClickButtons() + self.uncheckLeftClickButtons(self.manualBackgroundButton) + self.connectLeftClickButtons() + self.updateAllImages() + else: + self.removeManualTrackingItems() + self.clearGhost() + self.clearManualBackgroundAnnotations() + self.manualBackgroundToolbar.setVisible(checked) + + def manualTracking_cb(self, checked): + self.manualTrackingToolbar.setVisible(checked) + if checked: + self.realTimeTrackingToggle.previousStatus = ( + self.realTimeTrackingToggle.isChecked() + ) + self.realTimeTrackingToggle.setChecked(False) + self.UserEnforced_DisabledTracking_previousStatus = ( + self.UserEnforced_DisabledTracking + ) + self.UserEnforced_Tracking_previousStatus = self.UserEnforced_Tracking + + self.UserEnforced_DisabledTracking = True + self.UserEnforced_Tracking = False + self.initGhostObject() + self.addManualTrackingItems() + else: + self.realTimeTrackingToggle.setChecked( + self.realTimeTrackingToggle.previousStatus + ) + self.UserEnforced_DisabledTracking = ( + self.UserEnforced_DisabledTracking_previousStatus + ) + self.UserEnforced_Tracking = self.UserEnforced_Tracking_previousStatus + self.removeManualTrackingItems() + self.clearGhost() + + def manuallyEditTracking(self, tracked_lab, allIDs): + posData = self.data[self.pos_i] + infoToRemove = [] + # Correct tracking with manually changed IDs + maxID = max(allIDs, default=1) + for y, x, new_ID in posData.editID_info: + old_ID = tracked_lab[y, x] + if old_ID == 0 or old_ID == new_ID: + infoToRemove.append((y, x, new_ID)) + continue + if new_ID in allIDs: + tempID = maxID + 1 + tracked_lab[tracked_lab == old_ID] = tempID + tracked_lab[tracked_lab == new_ID] = old_ID + tracked_lab[tracked_lab == tempID] = new_ID + else: + tracked_lab[tracked_lab == old_ID] = new_ID + if new_ID > maxID: + maxID = new_ID + for info in infoToRemove: + posData.editID_info.remove(info) + + def realTimeTrackingClicked(self, checked): + # Event called ONLY if the user click on Disable tracking + # NOT called if setChecked is called. This allows to keep track + # of the user choice. This way user con enforce tracking + # NOTE: I know two booleans doing the same thing is overkill + # but the code is more readable when we actually need them + + posData = self.data[self.pos_i] + isRealTimeTrackingDisabled = not checked + + # Turn off smart tracking + self.enableSmartTrackAction.toggled.disconnect() + self.enableSmartTrackAction.setChecked(False) + if isRealTimeTrackingDisabled: + self.UserEnforced_DisabledTracking = True + self.UserEnforced_Tracking = False + else: + txt = html_utils.paragraph(""" + + Do you want to keep tracking always active including on already + visited frames?

    + Note: To re-activate automatic handling of tracking go to
    + Edit --> Smart handling of enabling/disabling tracking. + + """) + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + yesButton, noButton = msg.question( + self, "Keep tracking always active?", txt, buttonsTexts=("Yes", "No") + ) + if msg.clickedButton == yesButton: + self.repeatTracking() + self.UserEnforced_DisabledTracking = False + self.UserEnforced_Tracking = True + else: + self.enableSmartTrackAction.setChecked(True) + + def removeManualBackgroundItems(self): + self.manualBackgroundObjItem.removeFromPlotItem() + self.ax1.removeItem(self.manualBackgroundImageItem) + + def removeManualTrackingItems(self): + self.ghostContourItemLeft.removeFromPlotItem() + self.ghostContourItemRight.removeFromPlotItem() + + self.ghostMaskItemLeft.removeFromPlotItem() + self.ghostMaskItemRight.removeFromPlotItem() + + def repeatTracking(self): + posData = self.data[self.pos_i] + prev_lab = self.get_2Dlab(posData.lab).copy() + self.tracking(enforce=True, DoManualEdit=False) + if posData.editID_info: + editedIDsInfo = { + posData.lab[y, x]: newID + for y, x, newID in posData.editID_info + if posData.lab[y, x] != newID + } + editedIDsInfoItems = [ + f"ID {oldID} --> {newID}" for oldID, newID in editedIDsInfo.items() + ] + editIDul = html_utils.to_list(editedIDsInfoItems) + msg = widgets.myMessageBox() + txt = html_utils.paragraph(f""" + You requested to repeat tracking but there are manually + edited IDs (see edited IDs in the details section below) +

    + Do you want to keep these edits or ignore them? + """) + keepManualEditButton = widgets.okPushButton("Keep manually edited IDs") + ignoreButton = widgets.noPushButton("Ignore manually edited IDs") + msg.question( + self, + "Repeat tracking mode", + txt, + buttonsTexts=(keepManualEditButton, ignoreButton), + detailsText=editIDul, + ) + if msg.cancel: + return + if msg.clickedButton == keepManualEditButton: + allIDs = [obj.label for obj in posData.rp] + lab2D = self.get_2Dlab(posData.lab) + self.manuallyEditTracking(lab2D, allIDs) + self.update_rp() + self.setAllTextAnnotations() + self.highlightLostNew() + # self.checkIDsMultiContour() + else: + posData.editID_info = [] + if np.any(posData.lab != prev_lab): + if self.isSnapshot: + self.fixCcaDfAfterEdit("Repeat tracking") + self.updateAllImages() + else: + self.warnEditingWithCca_df("Repeat tracking") + else: + self.updateAllImages() + + def repeatTrackingVideo(self, checked=False): + posData = self.data[self.pos_i] + win = widgets.selectTrackerGUI( + posData.SizeT, currentFrameNo=posData.frame_i + 1 + ) + win.exec_() + if win.cancel: + self.logger.info("Tracking aborted.") + return + + trackerName = win.selectedItemsText[0] + start_n = win.startFrame + stop_n = win.stopFrame + video_to_track = posData.segm_data + for frame_i in range(start_n - 1, stop_n): + data_dict = posData.allData_li[frame_i] + lab = data_dict["labels"] + if lab is None: + break + + video_to_track[frame_i] = lab + video_to_track = video_to_track[start_n - 1 : stop_n] + + self.logger.info(f"Importing {trackerName} tracker...") + self.tracker, self.track_params, init_params = utils.init_tracker( + posData, trackerName, qparent=self, return_init_params=True + ) + if self.track_params is None: + self.logger.info("Tracking aborted.") + return + + warningText = utils.validate_tracker_input(self.tracker, video_to_track) + if warningText is not None: + self.logger.info(warningText) + self.warnTrackerInputNotValid(trackerName, warningText) + return + + if "image_channel_name" in self.track_params: + # Remove the channel name since it was already loaded in init_tracker + del self.track_params["image_channel_name"] + + track_params_log = { + key: value for key, value in self.track_params.items() if key != "image" + } + self.logger.info( + "Tracking parameters:\n\n" + f"Initialization parameters: {init_params}\n" + f"Track parameters: {track_params_log}" + ) + + last_cca_i = self.get_last_cca_frame_i() + if start_n - 2 <= last_cca_i and start_n > 1: + proceed = self.warnRepeatTrackingVideoWithAnnotations(last_cca_i, start_n) + if not proceed: + self.logger.info("Tracking aborted.") + return + + self.logger.info(f"Removing annotations from frame n. {start_n}.") + self.resetCcaFuture(start_n - 1) + + self.start_n = start_n + self.stop_n = stop_n + + info_txt = f"Tracking from frame n. {start_n} to {stop_n}..." + self.logger.info(info_txt) + + self.progressWin = apps.QDialogWorkerProgress( + title="Tracking", parent=self, pbarDesc=info_txt + ) + self.progressWin.show(self.app) + self.progressWin.mainPbar.setMaximum(stop_n - start_n) + self.startTrackingWorker(posData, video_to_track) + + def resetManualBackgroundItems(self): + self.initManualBackgroundImage() + self.resetManualBackgroundSpinboxID() + self.drawManualTrackingGhost(self.xHoverImg, self.yHoverImg) + self.drawManualBackgroundObj(self.xHoverImg, self.yHoverImg) + + def resetManualBackgroundSpinboxID(self): + if not self.manualBackgroundButton.isChecked(): + self.manualBackgroundObj = None + return + + posData = self.data[self.pos_i] + minID = min(posData.IDs, default=0) + self.manualBackgroundToolbar.spinboxID.setValue(minID) + + def separateByLabelling(self, lab, rp, maxID=None): + """ + Label each single object in posData.lab and if the result is more than + one object then we insert the separated object into posData.lab + """ + setRp = False + posData = self.data[self.pos_i] + if maxID is None: + maxID = max(posData.IDs, default=1) + for obj in rp: + lab_obj = skimage.measure.label(obj.image) + rp_lab_obj = skimage.measure.regionprops(lab_obj) + if len(rp_lab_obj) <= 1: + continue + lab_obj += maxID + _slice = obj.slice # self.getObjSlice(obj.slice) + _objMask = obj.image # self.getObjImage(obj.image) + lab[_slice][_objMask] = lab_obj[_objMask] + setRp = True + maxID += 1 + return setRp + + def setManualBackgrounNextID(self): + posData = self.data[self.pos_i] + currentID = self.manualBackgroundObj.label + idx = posData.IDs_idxs[currentID] + next_idx = idx + 1 + if next_idx >= len(posData.IDs): + return + next_ID = posData.IDs[next_idx] + self.manualBackgroundToolbar.spinboxID.setValue(next_ID) + + def setManualBackgroundImage(self): + if not self.manualBackgroundButton.isChecked(): + return + + posData = self.data[self.pos_i] + if not hasattr(posData, "manualBackgroundImage"): + self.initManualBackgroundImage() + + contours = [] + for obj in skimage.measure.regionprops(posData.manualBackgroundLab): + obj_contours = self.getObjContours(obj, all_external=True) + contours.extend(obj_contours) + textItem = self.manualBackgroundTextItems[obj.label] + textItem.setText(f"{obj.label}") + self.ax1.addItem(textItem) + yc, xc = obj.centroid + textItem.setPos(xc, yc) + + cv2.drawContours( + posData.manualBackgroundImage, contours, -1, (255, 0, 0, 200), 1 + ) + self.manualBackgroundImageItem.setImage(posData.manualBackgroundImage) + + def setManualBackgroundLab(self, load_from_store=False, debug=True): + posData = self.data[self.pos_i] + if posData.manualBackgroundLab is None: + self.initManualBackgroundImage() + + for obj in skimage.measure.regionprops(posData.manualBackgroundLab): + textItem = pg.TextItem(text="", color="r", anchor=(0.5, 0.5)) + if obj.label in self.manualBackgroundTextItems: + continue + self.manualBackgroundTextItems[obj.label] = textItem + + def setTrackedLostCentroids(self, prev_rp, tracked_lost_IDs): + """Store centroids of those IDs the tracker decided is fine to lose + (e.g., upon standard cell division the ID of the mother is fine) + + Parameters + ---------- + prev_rp : skimage.measure.RegionProperties + List of region properties of the object in previous frame + tracked_lost_IDs : iterable + List-like container of the IDs that is fine to lose from previous + frame to current frame + + Note + ---- + This function stores the centroids because the user could change IDs + in multiple ways. Storing centroids is more robust. + """ + posData = self.data[self.pos_i] + frame_i = posData.frame_i + + for obj in prev_rp: + if obj.label not in tracked_lost_IDs: + continue + + int_centroid = tuple([int(val) for val in obj.centroid]) + try: + posData.tracked_lost_centroids[frame_i].add(int_centroid) + except KeyError: + posData.tracked_lost_centroids[frame_i] = {int_centroid} + + def trackFrame( + self, + prev_lab, + prev_rp, + curr_lab, + curr_rp, + curr_IDs, + assign_unique_new_IDs=True, + IDs=None, + unique_ID=None, + ): + if self.trackWithAcdcAction.isChecked(): + tracked_result = CellACDC_tracker.track_frame( + prev_lab, + prev_rp, + curr_lab, + curr_rp, + IDs_curr_untracked=curr_IDs, + setBrushID_func=self.setBrushID, + posData=self.data[self.pos_i], + assign_unique_new_IDs=assign_unique_new_IDs, + IDs=IDs, + unique_ID=unique_ID, + ) + elif self.trackWithYeazAction.isChecked(): + tracked_result = self.tracking_yeaz.correspondence( + prev_lab, curr_lab, use_modified_yeaz=True, use_scipy=True + ) + else: + tracked_result = self.trackFrameCustomTracker( + prev_lab, curr_lab, IDs=IDs, unique_ID=unique_ID + ) + + # Check if tracker also returns additional info + if isinstance(tracked_result, tuple): + tracked_lab, tracked_lost_IDs = tracked_result + self.handleAdditionalInfoRealTimeTracker(prev_rp, tracked_lost_IDs) + else: + tracked_lab = tracked_result + + return tracked_lab + + def trackFrameCustomTracker(self, prev_lab, currentLab, IDs=None, unique_ID=None): + if unique_ID is None: + unique_ID = self.setBrushID() + try: + tracked_result = self.realTimeTracker.track_frame( + prev_lab, + currentLab, + unique_ID=unique_ID, + IDs=IDs, + **self.track_frame_params, + ) + except TypeError as err: + if str(err).find("an unexpected keyword argument 'unique_ID'") != -1: + try: + tracked_result = self.realTimeTracker.track_frame( + prev_lab, currentLab, IDs=IDs, **self.track_frame_params + ) + except TypeError as err: + if str(err).find("an unexpected keyword argument 'IDs'") != -1: + tracked_result = self.realTimeTracker.track_frame( + prev_lab, currentLab, **self.track_frame_params + ) + else: + raise err + elif str(err).find("an unexpected keyword argument 'IDs'") != -1: + try: + tracked_result = self.realTimeTracker.track_frame( + prev_lab, + currentLab, + unique_ID=unique_ID, + **self.track_frame_params, + ) + except TypeError as err: + if ( + str(err).find("an unexpected keyword argument 'unique_ID'") + != -1 + ): + tracked_result = self.realTimeTracker.track_frame( + prev_lab, currentLab, **self.track_frame_params + ) + else: + raise err + else: + raise err + return tracked_result + + def trackManuallyAddedObject( + self, + added_IDs: List[int] | int | Set[int], + isNewID: bool, + wl_update: bool = True, + wl_track_og_curr: bool = False, + ): + """Track object added manually on frame that was already visited. + + Parameters + ---------- + added_IDs : int | list of int | set + ID or IDs of the object added manually + isNewID : bool + If True, the added object is new + + Notes + ----- + This method tracks the new added object against the previous frame + labels. If the ID determined by tracking is different from `added_ID` + (meaning that tracking thinks the new ID should be changed to the + tracked ID) and the tracked ID is not already existing (which would + otherwise causing merging) we assign the tracked ID to the object with + `added_ID`. + + If instead the tracked ID is the same as `added_ID` we are dealing + with a truly new object. In this case we want to try tracking it against + the next frame (since the next frame was already validated). + As before, we assign the tracked ID (against the next frame) only if + not already existing in current frame (to avoid merging). + """ + if self.isSnapshot: + return + + if not isNewID: + return + + if isinstance(added_IDs, int): + added_IDs = [added_IDs] + + posData = self.data[self.pos_i] + tracked_lab = self.tracking( + enforce=True, assign_unique_new_IDs=False, return_lab=True, IDs=added_IDs + ) + self.clearAssignedObjsSecondStep() + if tracked_lab is None: + return + + # Track only new object + prevIDs = posData.allData_li[posData.frame_i - 1]["IDs"] + + # mask = np.zeros(posData.lab.shape, dtype=bool) + update_rp = False + + for added_ID in added_IDs: + # try: + # obj = posData.rp[added_ID] # ID not present + # mask[obj.slice][obj.image] = True + + # except IndexError as err: + mask = posData.lab == added_ID + try: + trackedID = tracked_lab[mask][0] + except IndexError as err: + # added_ID is not present + continue + + isTrackedIDalreadyPresentAndNotNew = ( + posData.IDs_idxs.get(trackedID) is not None and added_ID != trackedID + ) + if isTrackedIDalreadyPresentAndNotNew: + continue + + isTrackedIDinPrevIDs = trackedID in prevIDs + if isTrackedIDinPrevIDs: + posData.lab[mask] = trackedID + else: + # New object where we can try to track against next frame + trackedID = self.trackNewIDtoNewIDsFutureFrame(added_ID, mask) + if trackedID is None: + self.clearAssignedObjsSecondStep() + continue + posData.lab[mask] = trackedID + + self.keepOnlyNewIDAssignedObjsSecondStep(trackedID) + update_rp = True + + if update_rp: + self.update_rp(wl_update=wl_update) + + def trackNewIDtoNewIDsFutureFrame(self, newID, newIDmask): + posData = self.data[self.pos_i] + try: + nextLab = posData.allData_li[posData.frame_i + 1]["labels"] + except IndexError: + # This is last frame --> there are no future frames + return + + if nextLab is None: + return + + newID_lab = np.zeros_like(posData.lab) + newID_lab[newIDmask] = newID + newLab_rp = [posData.rp[posData.IDs_idxs[newID]]] + newLab_IDs = [newID] + nextRp = posData.allData_li[posData.frame_i + 1]["regionprops"] + + tracked_lab = self.trackFrame( + nextLab, + nextRp, + newID_lab, + newLab_rp, + newLab_IDs, + assign_unique_new_IDs=False, + ) + trackedID = tracked_lab[newID_lab > 0][0] + if trackedID == newID: + # Object does not exist in future frame --> do not track + return + + if posData.IDs_idxs.get(trackedID) is not None: + # Tracked ID already exists --> do not track to avoid merging + return + + return trackedID + + def trackSubsetIDs(self, subsetIDs: Iterable[int]): + posData = self.data[self.pos_i] + if posData.frame_i == 0: + return + + subsetLab = np.zeros_like(posData.lab) + for subsetID in subsetIDs: + subsetLab[posData.lab == subsetID] = subsetID + + prev_lab = posData.allData_li[posData.frame_i - 1]["labels"] + prev_rp = posData.allData_li[posData.frame_i - 1]["regionprops"] + tracked_lab = self.trackFrame( + prev_lab, + prev_rp, + posData.lab, + posData.rp, + posData.IDs, + assign_unique_new_IDs=True, + ) + doUpdateRp = False + for subsetID in subsetIDs: + subsetIDmask = posData.lab == subsetID + trackedID = tracked_lab[subsetIDmask][0] + if trackedID == subsetID: + continue + + is_manually_edited = False + for y, x, new_ID in posData.editID_info: + if new_ID == subsetID: + # Do not track because it was manually edited + break + + posData.lab[subsetIDmask] = tracked_lab[subsetIDmask] + doUpdateRp = True + + if not doUpdateRp: + return + + self.update_rp() + + def tracking( + self, + enforce=False, + DoManualEdit=True, + storeUndo=False, + prev_lab=None, + prev_rp=None, + return_lab=False, + assign_unique_new_IDs=True, + separateByLabel=True, + wl_update=True, + IDs=None, + against_next=False, + ): + posData = self.data[self.pos_i] + + if self.doSkipTracking(against_next, enforce): + self.setLostNewOldPrevIDs() + return + + """Tracking starts here""" + staturBarLabelText = self.statusBarLabel.text() + self.statusBarLabel.setText("Tracking...") + + if storeUndo: + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + + # First separate by labelling + if separateByLabel: + maxID = max(posData.IDs, default=1) + setRp = core.split_connected_components( + posData.lab, rp=posData.rp, max_ID=maxID + ) + if setRp: + self.update_rp( + wl_update=wl_update, + ) + + if prev_lab is None: + if not against_next: + prev_lab = posData.allData_li[posData.frame_i - 1]["labels"] + else: + prev_lab = posData.allData_li[posData.frame_i + 1]["labels"] + if prev_rp is None: + if not against_next: + prev_rp = posData.allData_li[posData.frame_i - 1]["regionprops"] + else: + prev_rp = posData.allData_li[posData.frame_i + 1]["regionprops"] + + unique_ID = None + if posData.frame_i < self.get_last_tracked_i(): + unique_ID = self.setBrushID(return_val=True) + + tracked_lab = self.trackFrame( + prev_lab, + prev_rp, + posData.lab, + posData.rp, + posData.IDs, + assign_unique_new_IDs=assign_unique_new_IDs, + IDs=IDs, + unique_ID=unique_ID, + ) + + if DoManualEdit: + # Correct tracking with manually changed IDs + rp = skimage.measure.regionprops(tracked_lab) + IDs = [obj.label for obj in rp] + self.manuallyEditTracking(tracked_lab, IDs) + + if return_lab: + QTimer.singleShot( + 50, partial(self.statusBarLabel.setText, staturBarLabelText) + ) + return tracked_lab + + # Update labels, regionprops and determine new and lost IDs + posData.lab = tracked_lab + self.update_rp( + wl_update=wl_update, + ) + self.setAllTextAnnotations() + QTimer.singleShot(50, partial(self.statusBarLabel.setText, staturBarLabelText)) + + def updateAssignedObjsAcdcTrackerSecondStep(self, newID): + posData = self.data[self.pos_i] + annotInfo = posData.acdcTracker2stepsAnnotInfo.get(posData.frame_i) + if annotInfo is None: + return + + new_objs_1st_step, lost_objs_1st_step = annotInfo + correct_new_objs, correct_lost_objs = [], [] + for lostObj, newObj in zip(lost_objs_1st_step, new_objs_1st_step): + newObj_ID = posData.lab[newObj.slice][newObj.image][0] + if newObj_ID == newID: + # The ID of the new object tracked with 2nd step was + # manually edit --> do not annotate its linking to lost obj anymore + continue + correct_new_objs.append(newObj) + correct_lost_objs.append(lostObj) + + if not correct_new_objs: + posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = None + else: + posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = ( + correct_new_objs, + correct_lost_objs, + ) + self.annotateAssignedObjsAcdcTrackerSecondStep() + + def updateGhostMaskOpacity(self, alpha_percentage=None): + if alpha_percentage is None: + alpha_percentage = ( + self.manualTrackingToolbar.ghostMaskOpacitySpinbox.value() + ) + alpha = alpha_percentage / 100 + self.ghostMaskItemLeft.setOpacity(alpha) + self.ghostMaskItemRight.setOpacity(alpha) + + def updateLastCheckedFrameWidgets(self, last_tracked_i): + self.navigateScrollBar.setMaximum(last_tracked_i + 1) + self.navSpinBox.setMaximum(last_tracked_i + 1) + self.lastTrackedFrameLabel.setText( + f"Last checked frame n. = {last_tracked_i + 1}" + ) + + def warnRepeatTrackingVideoOnVisitedFrames(self, last_tracked_i, start_n): + msg = widgets.myMessageBox() + txt = html_utils.paragraph( + "You are repeating tracking on frames that have already " + "been visited/tracked before.

    " + "This will very likely make the annotations wrong.

    " + "If you really want to repeat tracking on the frames before " + f"{last_tracked_i + 1} the annotations from frame " + f"{start_n} to frame {last_tracked_i + 1} " + "will be removed.

    " + "Do you want to continue?" + ) + noButton, yesButton = msg.warning( + self, + "Repating tracking with annotations!", + txt, + buttonsTexts=( + " No, stop tracking and keep annotations.", + " Yes, repeat tracking and DELETE annotations.", + ), + ) + if msg.cancel: + return False + + if msg.clickedButton == noButton: + return False + else: + return True + + def warnRepeatTrackingVideoWithAnnotations(self, last_tracked_i, start_n): + msg = widgets.myMessageBox() + txt = html_utils.paragraph( + "You are repeating tracking on frames that have cell cycle " + "annotations.

    " + "This will very likely make the annotations wrong.

    " + "If you really want to repeat tracking on the frames before " + f"{last_tracked_i + 1} the annotations from frame " + f"{start_n} to frame {last_tracked_i + 1} " + "will be removed.

    " + "Do you want to continue?" + ) + noButton, yesButton = msg.warning( + self, + "Repating tracking with annotations!", + txt, + buttonsTexts=( + " No, stop tracking and keep annotations.", + " Yes, repeat tracking and DELETE annotations.", + ), + ) + if msg.cancel: + return False + + if msg.clickedButton == noButton: + return False + else: + return True + + def warnTrackerInputNotValid(self, trackerName, warningText): + msg = widgets.myMessageBox(wrapText=False) + txt = warningText.replace("\n", "
    ") + txt = html_utils.paragraph( + f"{txt}

    " + "Tracking process will be cancelled. Thank you for your patience!" + ) + msg.warning(self, "Invalid input for tracker", txt) diff --git a/cellacdc/mixins/undo_redo.py b/cellacdc/mixins/undo_redo.py new file mode 100644 index 000000000..16e2abcde --- /dev/null +++ b/cellacdc/mixins/undo_redo.py @@ -0,0 +1,401 @@ +"""Qt view adapter for undo, redo, and future-frame propagation.""" + +from __future__ import annotations + +import uuid + +from cellacdc import apps, html_utils, widgets + + +from collections import defaultdict + +from .label_editing import LabelEditing + + +class UndoRedo(LabelEditing): + """Extracted from guiWin.""" + + def UndoCca(self): + posData = self.data[self.pos_i] + # Undo current ccaState + storeState = False + if self.UndoCount == 0: + undoId = uuid.uuid4() + self.addCcaState(posData.frame_i, posData.cca_df, undoId) + storeState = True + + # Get previously stored state + self.UndoCount += 1 + currentCcaStates = posData.UndoRedoCcaStates[posData.frame_i] + prevCcaState = currentCcaStates[self.UndoCount] + posData.cca_df = prevCcaState["cca_df"] + self.store_cca_df() + self.updateAllImages() + + # Check if we have undone all states + if len(currentCcaStates) > self.UndoCount: + # There are no states left to undo for current frame_i + self.undoAction.setEnabled(False) + + # Undo all past and future frames that has a last status inserted + # when modyfing current frame + prevStateId = prevCcaState["id"] + for frame_i in range(0, posData.SizeT): + if storeState: + cca_df_i = self.get_cca_df(frame_i=frame_i, return_df=True) + if cca_df_i is None: + break + # Store current state to enable redoing it + self.addCcaState(frame_i, cca_df_i, undoId) + + CcaStates_i = posData.UndoRedoCcaStates[frame_i] + if len(CcaStates_i) <= self.UndoCount: + # There are no states to undo for frame_i + continue + + CcaState_i = CcaStates_i[self.UndoCount] + id_i = CcaState_i["id"] + if id_i != prevStateId: + # The id of the state in frame_i is different from current frame + continue + + cca_df_i = CcaState_i["cca_df"] + self.store_cca_df(frame_i=frame_i, cca_df=cca_df_i, autosave=False) + + self.resetWillDivideInfo() + self.enqAutosave() + + def addCcaState(self, frame_i, cca_df, undoId): + posData = self.data[self.pos_i] + posData.UndoRedoCcaStates[frame_i].insert( + 0, {"id": undoId, "cca_df": cca_df.copy()} + ) + + def addCurrentState(self, storeImage=False, storeOnlyZoom=False): + posData = self.data[self.pos_i] + if posData.cca_df is not None: + cca_df = posData.cca_df.copy() + else: + cca_df = None + + if storeImage: + image = self.img1.image.copy() + else: + image = None + + if storeOnlyZoom: + labels, crop_slice = transformation.crop_2D( + self.currentLab2D, self.ax1.viewRange(), tolerance=10, return_copy=False + ) + if self.isSegm3D: + z = self.z_lab(checkIfProj=True) + if z is None: + z_slice = slice(0, len(posData.lab)) + crop_slice = (z_slice, *crop_slice) + labels = posData.lab[crop_slice].copy() + else: + z_slice = z + crop_slice = (z_slice, *crop_slice) + labels = labels.copy() + else: + labels = labels.copy() + else: + labels = posData.lab.copy() + crop_slice = None + + state = { + "image": image, + "labels": labels, + "editID_info": posData.editID_info.copy(), + "binnedIDs": posData.binnedIDs.copy(), + "keptObejctsIDs": self.keptObjectsIDs.copy(), + "ripIDs": posData.ripIDs.copy(), + "cca_df": cca_df, + "crop_slice": crop_slice, + } + posData.UndoRedoStates[posData.frame_i].insert(0, state) + + def askPropagateChangePast(self, change_txt): + txt = html_utils.paragraph(f""" + Do you want to propagate the change "{change_txt}" to the past frames? + """) + msg = widgets.myMessageBox(wrapText=False) + yesButton, _ = msg.question( + self, "Propagate change to past frames", txt, buttonsTexts=("Yes", "No") + ) + return msg.clickedButton == yesButton + + def clearUndoQueue(self): + posData = self.data[self.pos_i] + self.UndoCount = 0 + self.redoAction.setEnabled(False) + self.undoAction.setEnabled(False) + posData.UndoRedoStates = [[] for _ in range(posData.SizeT)] + posData.UndoRedoCcaStates = [[] for _ in range(posData.SizeT)] + if hasattr(self, "undoAddPointQueueMapper"): + self.undoAddPointQueueMapper = defaultdict(list) + + def getCurrentState(self): + posData = self.data[self.pos_i] + i = posData.frame_i + c = self.UndoCount + state = posData.UndoRedoStates[i][c] + if state["image"] is None: + image_left = None + else: + image_left = state["image"].copy() + + crop_slice = state["crop_slice"] + if crop_slice is None: + posData.lab = state["labels"].copy() + elif self.isSegm3D: + z_slice, slice_y, slice_x = crop_slice + posData.lab[..., z_slice, slice_y, slice_x] = state["labels"].copy() + else: + slice_y, slice_x = crop_slice + posData.lab[..., slice_y, slice_x] = state["labels"].copy() + + posData.editID_info = state["editID_info"].copy() + posData.binnedIDs = state["binnedIDs"].copy() + posData.ripIDs = state["ripIDs"].copy() + self.keptObjectsIDs = state["keptObejctsIDs"].copy() + cca_df = state["cca_df"] + if cca_df is not None: + posData.cca_df = state["cca_df"].copy() + else: + posData.cca_df = None + return image_left + + def propagateChange( + self, + modID, + modTxt, + doNotShow, + UndoFutFrames, + applyFutFrames, + applyTrackingB=False, + force=False, + ): + """ + This function determines whether there are already visited future frames + that contains "modID". If so, it triggers a pop-up asking the user + what to do (propagate change to future frames o not) + """ + posData = self.data[self.pos_i] + # Do not check the future for the last frame + if posData.frame_i + 1 == posData.SizeT: + # No future frames to propagate the change to + return False, False, None, doNotShow + + includeUnvisited = posData.includeUnvisitedInfo.get(modTxt, False) + areFutureIDs_affected = [] + # Get number of future frames already visited and check if future + # frames has an ID affected by the change + last_tracked_i_found = False + segmSizeT = len(posData.segm_data) + for i in range(posData.frame_i + 1, segmSizeT): + if posData.allData_li[i]["labels"] is None: + if not last_tracked_i_found: + # We set last tracked frame at -1 first None found + last_tracked_i = i - 1 + last_tracked_i_found = True + if not includeUnvisited: + # Stop at last visited frame since includeUnvisited = False + break + else: + lab = posData.segm_data[i] + else: + lab = posData.allData_li[i]["labels"] + + if modID in lab: + areFutureIDs_affected.append(True) + + if not last_tracked_i_found: + # All frames have been visited in segm&track mode + last_tracked_i = posData.SizeT - 1 + + if last_tracked_i == posData.frame_i and not includeUnvisited: + # No future frames to propagate the change to + return False, False, None, doNotShow + + if not areFutureIDs_affected and not force: + # There are future frames but they are not affected by the change + return UndoFutFrames, False, None, doNotShow + + # Ask what to do unless the user has previously checked doNotShowAgain + if doNotShow: + endFrame_i = last_tracked_i + if applyFutFrames and not UndoFutFrames and modTxt == "Edit ID": + self.whitelistSyncIDsOG(frame_is=range(posData.frame_i, endFrame_i + 1)) + return UndoFutFrames, applyFutFrames, endFrame_i, doNotShow + else: + addApplyAllButton = ( + modTxt == "Delete ID" + or modTxt == "Edit ID" + or modTxt == "Assign new ID" + ) + ffa = apps.FutureFramesAction_QDialog( + posData.frame_i + 1, + last_tracked_i, + modTxt, + applyTrackingB=applyTrackingB, + parent=self, + addApplyAllButton=addApplyAllButton, + ) + ffa.exec_() + decision = ffa.decision + + if decision is None: + return None, None, None, doNotShow + + endFrame_i = ffa.endFrame_i + doNotShowAgain = ffa.doNotShowCheckbox.isChecked() + askAction = self.askHowFutureFramesActions[modTxt] + askAction.setChecked(not doNotShowAgain) + askAction.setDisabled(False) + + self.onlyTracking = False + if decision == "apply_and_reinit": + UndoFutFrames = True + applyFutFrames = False + elif decision == "apply_and_NOTreinit": + UndoFutFrames = False + applyFutFrames = False + elif decision == "apply_to_all_visited": + UndoFutFrames = False + applyFutFrames = True + elif decision == "only_tracking": + UndoFutFrames = False + applyFutFrames = True + self.onlyTracking = True + elif decision == "apply_to_all": + UndoFutFrames = False + applyFutFrames = True + posData.includeUnvisitedInfo[modTxt] = True + + if applyFutFrames and not UndoFutFrames and modTxt == "Edit ID": + self.whitelistSyncIDsOG(frame_is=range(posData.frame_i, endFrame_i + 1)) + return UndoFutFrames, applyFutFrames, endFrame_i, doNotShowAgain + + def propagateMergeObjsPast(self, IDs_to_merge): + self.store_data(autosave=False) + posData = self.data[self.pos_i] + current_frame_i = posData.frame_i + for past_frame_i in range(posData.frame_i - 1, -1, -1): + posData.frame_i = past_frame_i + self.get_data() + + IDs = posData.allData_li[past_frame_i]["IDs"] + stop_loop = False + for ID in IDs_to_merge: + if ID not in IDs: + stop_loop = True + break + + if ID == 0: + continue + posData.lab[posData.lab == ID] = self.firstID + self.update_rp() + + self.store_data(autosave=False) + + if stop_loop: + break + + posData.frame_i = current_frame_i + self.get_data() + + def redo(self): + posData = self.data[self.pos_i] + # Get previously stored state + if self.UndoCount > 0: + self.UndoCount -= 1 + # Since we have redone then it is possible to undo + self.undoAction.setEnabled(True) + + # Restore state + image_left = self.getCurrentState() + self.update_rp() + self.updateAllImages(image=image_left) + self.store_data() + + if not self.UndoCount > 0: + # We have redone all available states + self.redoAction.setEnabled(False) + + if self.whitelistIDsButton.isChecked(): + self.whitelistHighlightIDs() + + def storeUndoRedoCca(self, frame_i, cca_df, undoId): + if self.isSnapshot: + # For snapshot mode we don't store anything because we have only + # segmentation undo action active + return + """ + Store current cca_df along with a unique id to know which cca_df needs + to be restored + """ + + posData = self.data[self.pos_i] + + # Restart count from the most recent state (index 0) + # NOTE: index 0 is most recent state before doing last change + self.UndoCcaCount = 0 + self.undoAction.setEnabled(True) + + self.addCcaState(frame_i, cca_df, undoId) + + # Keep only 10 Undo/Redo states + if len(posData.UndoRedoCcaStates[frame_i]) > 10: + posData.UndoRedoCcaStates[frame_i].pop(-1) + + def storeUndoRedoStates(self, UndoFutFrames, storeImage=False, storeOnlyZoom=False): + posData = self.data[self.pos_i] + if UndoFutFrames: + # Since we modified current frame all future frames that were already + # visited are not valid anymore. Undo changes there + self.reInitLastSegmFrame(updateImages=False) + + # Keep only 5 Undo/Redo states + if len(posData.UndoRedoStates[posData.frame_i]) > 5: + posData.UndoRedoStates[posData.frame_i].pop(-1) + + # Restart count from the most recent state (index 0) + # NOTE: index 0 is most recent state before doing last change + self.UndoCount = 0 + self.undoAction.setEnabled(True) + self.addCurrentState(storeImage=storeImage, storeOnlyZoom=storeOnlyZoom) + + def undo(self): + addPointsByClickingButton = self.buttonAddPointsByClickingActive() + if addPointsByClickingButton is not None: + done = self.undoAddPoint(addPointsByClickingButton.action) + if done: + return + + if self.UndoCount == 0: + # Store current state to enable redoing it + self.addCurrentState() + + posData = self.data[self.pos_i] + # Get previously stored state + if self.UndoCount < len(posData.UndoRedoStates[posData.frame_i]) - 1: + self.UndoCount += 1 + # Since we have undone then it is possible to redo + self.redoAction.setEnabled(True) + + # Restore state + image_left = self.getCurrentState() + self.update_rp() + self.updateAllImages(image=image_left) + self.store_data() + + if not self.UndoCount < len(posData.UndoRedoStates[posData.frame_i]) - 1: + # We have undone all available states + self.undoAction.setEnabled(False) + + if self.whitelistIDsButton.isChecked(): + self.whitelistHighlightIDs() + + def undoCustomAnnotation(self): + pass diff --git a/cellacdc/mixins/whitelist.py b/cellacdc/mixins/whitelist.py new file mode 100644 index 000000000..99ea2b236 --- /dev/null +++ b/cellacdc/mixins/whitelist.py @@ -0,0 +1,1177 @@ +"""Whitelist GUI mixin extracted from whitelist.py.""" + +from __future__ import annotations + +import os +import time + +import numpy as np +import skimage.measure +from typing import Set, List, Tuple + +from cellacdc import ( + apps, + disableWindow, + exception_handler, + exec_time, + html_utils, + printl, + widgets, +) +from cellacdc.trackers.CellACDC import CellACDC_tracker +from cellacdc.whitelist import Whitelist + + +class WhitelistGui: + """A class to manage the whitelist GUI elements.""" + + def whitelistCheckOriginalLabels(self, warning: bool = True, frame_i: int = None): + """Warns the user that there are no original labels labels are present + for the frame""" + posData = self.data[self.pos_i] + if posData.whitelist is None: + return False + + if frame_i is None: + frame_i = posData.frame_i + + if posData.whitelist.originalLabsIDs is None: + return False + + if ( + frame_i >= len(posData.whitelist.originalLabsIDs) + or posData.whitelist.originalLabsIDs[frame_i] is None + ): + txt = """ + No original labels are present for the current frame, + this action cannot be performed.""" + self.logger.warning(txt) + if not warning: + return False + msg = widgets.myMessageBox.warning( + self, + "No original labels", + txt, + ) + + return False + else: + return True + + @disableWindow + def whitelistTrackOGagainstPreviousFrame_cb(self, signal_slot=None): + """Tracks the original labels against the previous frame. + This is used as a callback for sigTrackOGagainstPreviousFrame signal + """ + posData = self.data[self.pos_i] + frame_i = posData.frame_i + if not self.whitelistCheckOriginalLabels(): + return + old_cell_IDs = posData.whitelist.originalLabsIDs[frame_i] + prev_cell_IDs = posData.allData_li[frame_i - 1]["IDs"] + self.whitelistTrackOGCurr(against_prev=True) + new_cell_IDs = posData.whitelist.originalLabsIDs[frame_i] + + new_IDs = new_cell_IDs - old_cell_IDs + new_IDs = new_IDs & set(prev_cell_IDs) + + self.whitelistUpdateLab( + track_og_curr=False, + IDs_to_add=new_IDs, + ) + + def whitelistLoadOGLabs_cb(self): + """Generates a dialog to load the original (not whitelisted) labels""" + posData = self.data[self.pos_i] + curr_seg_path = posData.segm_npz_path + + segmFilename = os.path.basename(curr_seg_path) + custom_first = f"{segmFilename[:-4]}_not_whitelisted.npz" + images_path = posData.images_path + existingEndnames = [ + files for files in os.listdir(images_path) if files.endswith(".npz") + ] + if custom_first not in existingEndnames: + custom_first = None + + infoText = html_utils.paragraph( + "Select the segmentation file containing the original labels " + 'of the objects. Pleae note that the current saved "original" ' + "labels will be replaced with the new ones, but the filtered " + "labels will be kept." + ) + + win = apps.SelectSegmFileDialog( + existingEndnames, + images_path, + parent=self, + basename=posData.basename, + infoText=infoText, + custom_first=custom_first, + ) + win.exec_() + if win.cancel: + self.logger.info("Loading original labels canceled.") + return + selected = win.selectedItemText + self.logger.info(f"Loading original labels from {selected}...") + self.whitelistLoadOGLabs(selected) + + @disableWindow + def whitelistLoadOGLabs(self, selected: str): + """Loads the original labels from the selected files + + Parameters + ---------- + selected : str + Selected file name from the dialog. + """ + posData = self.data[self.pos_i] + images_path = posData.images_path + + selected_path = os.path.join(images_path, selected) + posData.whitelist.loadOGLabs(selected_path) + + self.whitelistIDsToolbar.viewOGToggle.setCheckable(True) + + @exception_handler + @disableWindow + def whitelistViewOGIDs(self, checked: bool): + """Switch between selected and original labels. + Uses self.viewOriginalLabels to see what has to be done. + + Parameters + ---------- + checked : bool + True if the original labels have to be shown, False otherwise. + """ + switch_to_og = checked and not self.viewOriginalLabels + switch_to_seg = not checked and self.viewOriginalLabels + + if not switch_to_og and not switch_to_seg: + return + + posData = self.data[self.pos_i] + if posData.whitelist is None: + return + + if posData.whitelist._debug: + printl("whitelistViewOGIDs", checked) + + frame_i = posData.frame_i + if frame_i > 0: + frames_range = [frame_i - 1, frame_i] + else: + frames_range = [frame_i] + + self.store_data(autosave=False) + + if not self.whitelistCheckOriginalLabels(): + return + if switch_to_og: + self.setFrameNavigationDisabled(True, why="Viewing original labels") + self.viewOriginalLabels = True + + for i in frames_range: + posData.frame_i = i + self.get_data() + self.whitelistTrackOGCurr(frame_i=i) + + IDs = posData.IDs + + og_frame = posData.whitelist.originalLabs[i].copy() + IDs_to_uppdate = ( + posData.whitelist.whitelistIDs[i] + & posData.whitelist.originalLabsIDs[i] + ) + if IDs_to_uppdate: + mask = np.isin(og_frame, list(IDs_to_uppdate)) + og_frame[mask] = 0 + + mask = np.isin(posData.lab, list(IDs_to_uppdate)) + og_frame[mask] = posData.lab[mask] + + IDs_to_add = ( + posData.whitelist.whitelistIDs[i] + - posData.whitelist.originalLabsIDs[i] + ) + if IDs_to_add: + mask = np.isin(posData.lab, list(IDs_to_add)) + og_frame[mask] = posData.lab[mask] + + posData.lab = og_frame + self.update_rp(wl_update=False) + self.store_data(autosave=False) + + if frame_i > 0: + missing_IDs = set(posData.IDs) - set( + posData.allData_li[frame_i - 1]["IDs"] + ) + self.trackManuallyAddedObject( + missing_IDs, isNewID=True, wl_update=False + ) + + self.setAllTextAnnotations() + self.updateAllImages() + + elif switch_to_seg: + self.viewOriginalLabels = False + self.setFrameNavigationDisabled(False, why="Viewing original labels") + + for i in frames_range: + posData.frame_i = i + self.get_data() + try: + posData.whitelist.originalLabs[i] = posData.lab.copy() + posData.whitelist.originalLabsIDs[i] = set(posData.IDs) + except AttributeError: + lab = posData.segm_data[i].copy() + IDs = [obj.label for obj in skimage.measure.regionprops(lab)] + posData.whitelist.originalLabs[i] = lab + posData.whitelist.originalLabsIDs[i] = set(IDs) + + # self.whitelistTrackCurrOG() + self.update_rp(wl_update=False) + self.store_data(autosave=False) + self.whitelistUpdateLab(frame_i=i) # has update_rp and store data + self.setAllTextAnnotations() + self.updateAllImages() + + def whitelistSetViewOGIDsToggle(self, checked: bool): + """Set the view original labels toggle button to checked or unchecked. + This also updates the self.viewOriginalLabels variable. + !!! Doesn't change the actually displayed labels, use self.whitelistViewOGIDs + to do that.!!! + + Parameters + ---------- + checked : bool + True if the original labels are shown, False otherwise. + """ + self.viewOriginalLabels = checked + self.whitelistIDsToolbar.viewOGToggle.blockSignals(True) + self.whitelistIDsToolbar.viewOGToggle.setChecked(checked) + self.whitelistIDsToolbar.viewOGToggle.blockSignals(False) + + def whitelistAddNewIDsToggled(self, checked: bool): + """Will set self.addNewIDsWhitelistToggle to checked and call + whitelistAddNewIDs if checked is True. + + Parameters + ---------- + checked : bool + True if the add new IDs toggle is checked, False otherwise. + """ + self.addNewIDsWhitelistToggle = checked + if checked: + self.df_settings.at["addNewIDsWhitelistToggle", "value"] = "Yes" + else: + self.df_settings.at["addNewIDsWhitelistToggle", "value"] = "No" + self.df_settings.to_csv(self.settings_csv_path) + if checked: + self.whitelistAddNewIDs(ignore_not_first_time=True) + self.whitelistPropagateIDs() + self.updateAllImages() + self.whitelistIDsUpdateText() + + def whitelistAddNewIDs(self, ignore_not_first_time: bool = False): + """Function which adds new IDs to the whitelist, based on the original labels. + It will check if the frame is visited the first time, unless + ignore_not_first_time is True. + It does nothing if self.addNewIDsWhitelistToggle is False. + !!!Careful, does not change the lab, just the whitelist!!! + + Parameters + ---------- + ignore_not_first_time : bool, optional + Weather it should be checked if the frame is visited + the first time, by default False + """ + mode = self.modeComboBox.currentText() + if mode != "Segmentation and Tracking": + return + + if not self.addNewIDsWhitelistToggle: + return + + posData = self.data[self.pos_i] + if posData.whitelist is None: + return + + debug = posData.whitelist._debug + + if debug: + printl("whitelistAddNewIDs") + + posData = self.data[self.pos_i] + frame_i = posData.frame_i + + if self.get_last_tracked_i() > frame_i and not ignore_not_first_time: + return + + if frame_i == 0: + return + + if ( + self.whitelistAddNewIDsFrame is not None + and frame_i == self.whitelistAddNewIDsFrame + ): + return + + self.whitelistAddNewIDsFrame = frame_i + + curr_lab = self.get_curr_lab() + + posData.whitelist.addNewIDs( + frame_i=frame_i, + allData_li=posData.allData_li, + IDs_curr=posData.IDs, + curr_lab=curr_lab, + ) + + def whitelistIDsAccepted(self, whitelistIDs: Set[int] | List[int]): + """Function which is called when the user accepts a whitelist. + Also initializes the whitelist if it is not already initialized. (Aka not loaded) + + Parameters + ---------- + whitelistIDs : set | list + The accepted IDs from the whitelist dialog. + """ + # Store undo state before modifying stuff + self.storeUndoRedoStates(False) + + self.whitelistIDsToolbar.viewOGToggle.setCheckable(True) + self.whitelistSetViewOGIDsToggle(False) + self.setFrameNavigationDisabled(False, why="Viewing original labels") + + self.store_data(autosave=False) + + posData = self.data[self.pos_i] + + if not posData.whitelist: + posData.whitelist = Whitelist( + total_frames=posData.SizeT, + ) + + if posData.whitelist._debug: + printl("whitelistIDsAccepted", whitelistIDs) + + whitelistIDs = set(whitelistIDs) + + IDs_curr = set(posData.IDs) + + posData.whitelist.IDsAccepted( + whitelistIDs, + segm_data=posData.segm_data, + frame_i=posData.frame_i, + allData_li=posData.allData_li, + IDs_curr=IDs_curr, + curr_lab=posData.lab, + ) + + # self.whitelistPropagateIDs(new_whitelist=whitelistIDs, + # try_create_new_whitelists=True, + # only_future_frames=True, + # force_not_dynamic_update=True, + # update_lab=True + # ) + self.whitelistUpdateLab(track_og_curr=True) + + self.whitelistIDsUpdateText() + self.keepIDsTempLayerLeft.clear() + + def whitelistUpdateLab( + self, + frame_i: int = None, + track_og_curr=False, + new_frame: bool = False, + IDs_to_add: List[int] | Set[int] = None, + IDs_to_remove: List[int] | Set[int] = None, + ): + # this should also work for 3D i think... + """Updates the displayed lab based on the whitelist. + + Parameters + ---------- + frame_i : int, optional + frame which should be updated. If not provided, + uses posData.frame_i, by default None + track_og_curr : bool, optional + if True, will track the original current IDs, by default False + new_frame : bool, optional + if True, will set the frame to the new frame, by default False + IDs_to_add : list, optional + IDs to add to the whitelist, by default None + IDs_to_remove : list, optional + IDs to remove from the whitelist, by default None + """ + got_data = False + benchmark = False + if benchmark: + ts = [time.perf_counter()] + titles = [ + "", + "store_data", + "whitelistSetViewOGIDsToggle", + "get_data", + "get what to add/remove", + "track_og_curr", + "get current lab", + "add/remove IDs", + "store data", + "update images", + ] + + mode = self.modeComboBox.currentText() + if mode != "Segmentation and Tracking": + return + + posData = self.data[self.pos_i] + if posData.whitelist is None: + return + + if frame_i is None: + frame_i = posData.frame_i + og_frame_i = frame_i + else: + og_frame_i = posData.frame_i + posData.frame_i = frame_i + # getting data is handles later in the code + + debug = posData.whitelist._debug + if debug: + printl("whitelistUpdateLab", frame_i, og_frame_i) + from . import debugutils + + debugutils.print_call_stack() + + if benchmark: + ts.append(time.perf_counter()) + + self.whitelistSetViewOGIDsToggle(False) ### + + if benchmark: + ts.append(time.perf_counter()) + + if self.whitelistCheckOriginalLabels(warning=False, frame_i=frame_i): + og_lab = posData.whitelist.originalLabs[frame_i] ### + else: + og_lab = None + if benchmark: + ts.append(time.perf_counter()) + + #### + whitelist = posData.whitelist.get(frame_i=frame_i) + IDs_to_add_remove_provided = IDs_to_add is not None or IDs_to_remove is not None + if not IDs_to_add_remove_provided: + self.get_data() + got_data = True + current_IDs = set(posData.IDs) + missing_IDs = list(whitelist - current_IDs) + to_be_removed_IDs = list(current_IDs - whitelist) + else: + missing_IDs = list(IDs_to_add) if IDs_to_add is not None else [] + to_be_removed_IDs = list(IDs_to_remove) if IDs_to_remove is not None else [] + + ### + + if benchmark: + ts.append(time.perf_counter()) + + ### + if not missing_IDs and not to_be_removed_IDs: # nothing to do + if og_frame_i != frame_i: + posData.frame_i = og_frame_i + if got_data and og_frame_i != frame_i: + self.get_data() + if benchmark: + print("No IDs to add/remove") + ts.append(time.perf_counter()) + indx = titles.index("track_og_curr") + titles[indx + 1] = "store_data" + time_taken = time.perf_counter() - ts[0] + print(f"\nTotal time for whitelistUpdateLab: {time_taken:.2f}s") + for i in range(1, len(ts)): + time_taken = ts[i] - ts[i - 1] + print(f"Time taken for {titles[i]}: {time_taken:.2f}s") + print("") + return + + if not got_data and og_frame_i != frame_i: + self.get_data() + got_data = True + + if benchmark: + ts.append(time.perf_counter()) + + ### + if missing_IDs and track_og_curr and not new_frame: + self.whitelistTrackOGCurr(frame_i=frame_i, lab=posData.lab, rp=posData.rp) + + missing_IDs = np.array(missing_IDs, dtype=np.int32) + to_be_removed_IDs = np.array(to_be_removed_IDs, dtype=np.int32) + + if debug: + printl(missing_IDs, to_be_removed_IDs) + + curr_lab = posData.lab # or curr_lab = posData.lab??? + # convert values to int if they are not already + if curr_lab is None: + try: + curr_lab = posData.allData_li[frame_i]["labels"].copy() + except: + pass + if curr_lab is None: + try: + curr_lab = posData.segm_data[frame_i].copy() + except: + pass + if curr_lab is None: + printl("No current lab?") + curr_lab = np.zeros_like(posData.segm_data[0]) + curr_lab = curr_lab.astype(np.int32) + if benchmark: + ts.append(time.perf_counter()) + + if missing_IDs.size > 0 and og_lab is not None: + mask = np.isin(og_lab, missing_IDs) # add missing_IDs + curr_lab[mask] = og_lab[mask] + + if to_be_removed_IDs.size > 0: + curr_lab[np.isin(curr_lab, to_be_removed_IDs)] = ( + 0 # remove to_be_removed_IDs + ) + + if benchmark: + ts.append(time.perf_counter()) + + posData.lab = curr_lab + + self.update_rp(wl_update=False) + self.store_data() + + if benchmark: + ts.append(time.perf_counter()) + if og_frame_i != frame_i: + posData.frame_i = og_frame_i + self.get_data() + + self.updateAllImages() + self.setAllTextAnnotations() + + if benchmark: + ts.append(time.perf_counter()) + time_taken = time.perf_counter() - ts[0] + print(f"\nTotal time for whitelistUpdateLab: {time_taken:.2f}s") + for i in range(1, len(ts)): + time_taken = ts[i] - ts[i - 1] + print(f"Time taken for {titles[i]}: {time_taken:.2f}s") + print("") + + def whitelistIDsUpdateText(self): + """Updates the text. Carefull, triggers whitelistLineEdit.textChanged!""" + mode = self.modeComboBox.currentText() + if mode != "Segmentation and Tracking": + return + + posData = self.data[self.pos_i] + if posData.whitelist is None: + return + + if posData.whitelist._debug: + printl("whitelistIDsUpdateText") + + frame_i = posData.frame_i + whitelist = posData.whitelist.get(frame_i=frame_i) + + self.whitelistIDsToolbar.whitelistLineEdit.setText(whitelist) + + def whitelistTrackOGCurr( + self, + frame_i: int = None, + against_prev: bool = False, + lab: np.ndarray = None, + rp: list = None, + IDs: Set[int] | List[int] = None, + ): + """Track the original labels in relation to the current (whitelisted) + labels. + Parameters + + Parameters + ---------- + frame_i : int, optional + frame_i to be tracked, posData.frame_i if not provided, + by default None + against_prev : bool, optional + if the original frame should be tracked against frame_i-1. + Cannot be used with rp or lab, by default False + lab : np.ndarray, optional + lab to be tracked against, by default None + rp : list, optional + regionprops for this lab, by default None + IDs : Set[int] | List[int], optional + IDs that should be tracked based on og + + Raises + ------ + ValueError + Cannot provide both rp and lab when tracking against previous frame. + Instead only provide rp and lab, and dont set against_prev. + """ + posData = self.data[self.pos_i] + if posData.whitelist is None: + return + + debug = posData.whitelist._debug + + if debug: + from . import debugutils + + debugutils.print_call_stack(depth=2) + printl("whitelistTrackOGCurr", against_prev) + + if against_prev and (rp is not None or lab is not None): + raise ValueError( + "Cannot provide both rp and lab when tracking" + " against previous frame." + "Instead only provide rp and lab, and dont set against_prev." + ) + + if frame_i is None: + frame_i = posData.frame_i + + if against_prev and frame_i == 0: + return + + if not self.whitelistCheckOriginalLabels(warning=False, frame_i=frame_i): + if debug: + printl("No original labels, cannot track.") + return + + og_frame_i = posData.frame_i + ### against what should I track? + + if lab is not None and not rp: + rp = skimage.measure.regionprops(lab) + + changed_frame = False + if lab is None: + if debug: + printl("No lab and no rp provided.") + if against_prev: + rp = posData.allData_li[frame_i - 1]["regionprops"] + lab = posData.allData_li[frame_i - 1]["labels"] + else: + if frame_i != og_frame_i: + self.store_data(autosave=False) + posData.frame_i = frame_i + self.get_data() + changed_frame = True + rp = posData.rp + lab = posData.lab + og_lab = posData.whitelist.originalLabs[frame_i] + og_rp = skimage.measure.regionprops(og_lab) + # lab = lab.copy() + + denom_overlap_matrix = "union" if not against_prev else "area_prev" + + og_lab = CellACDC_tracker.track_frame( + lab, + rp, + og_lab, + og_rp, + denom_overlap_matrix=denom_overlap_matrix, + posData=posData, + setBrushID_func=self.setBrushID, + IDs=IDs, + # assign_unique_new_IDs=False, + ) + + posData.whitelist.originalLabs[frame_i] = og_lab + posData.whitelist.originalLabsIDs[frame_i] = { + obj.label for obj in skimage.measure.regionprops(og_lab) + } + + if changed_frame: + posData.frame_i = og_frame_i + self.get_data() + + def whitelistTrackCurrOG(self, frame_i: int = None, against_prev: bool = False): + """Track the current (whitelisted) labels in relation to the original labels. + Parameters + ---------- + frame_i : int, optional + frame_i to be tracked, posData.frame_i if not provided, by default None + against_prev : bool, optional + if the original frame should be tracked against frame_i-1. + """ + posData = self.data[self.pos_i] + if posData.whitelist is None: + return + + if posData.whitelist._debug: + printl("whitelistTrackCurrOG", frame_i, against_prev) + + if frame_i is None: + frame_i = posData.frame_i + + if against_prev and frame_i == 0: + return + + og_frame = posData.frame_i + if frame_i != og_frame: + self.store_data(autosave=False) + posData.frame_i = frame_i + self.get_data() + + lab = posData.lab + rp = posData.rp + + if not self.whitelistCheckOriginalLabels( + warning=False, frame_i=frame_i if not against_prev else frame_i - 1 + ): + if posData.whitelist._debug: + printl("No original labels, cannot track.") + return + + if against_prev: + og_lab = posData.whitelist.originalLabs[frame_i - 1] + else: + og_lab = posData.whitelist.originalLabs[frame_i] + + og_rp = skimage.measure.regionprops(og_lab) + + denom_overlap_matrix = "union" if not against_prev else "area_prev" + + lab = CellACDC_tracker.track_frame( + og_lab, + og_rp, + lab, + rp, + denom_overlap_matrix=denom_overlap_matrix, + posData=posData, + setBrushID_func=self.setBrushID, + ) + + posData.lab = lab + + self.update_rp(wl_update=False) + self.store_data(autosave=False) + + if frame_i != og_frame: + posData.frame_i = og_frame + self.get_data() + + def whitelistSyncIDsOG( + self, + frame_is: List[int] = None, + against_prev: bool = False, + ): + """Interates over the frames and calls whitelistTrackOGCurr for each frame. + + Parameters + ---------- + frame_is : List[int], optional + list of frame_i, if None goes through all, by default None + against_prev : bool, optional + if the original frame should be tracked against frame_i-1. + """ + posData = self.data[self.pos_i] + if frame_is is None: + frame_is = range(posData.SizeT) + + for frame_i in frame_is: + self.whitelistTrackOGCurr(frame_i=frame_i, against_prev=against_prev) + + def whitelistInitNewFrames(self, frame_i: int = None, force: bool = False): + """Initialize the whitelist for a new frame. The class whitelist keeps track + of the init frames and doesnt try to init them again, unless forced. + Does not init the class! + + Parameters + ---------- + frame_i : int, optional + frame_i to be init, posData.frame_i if not provided, by default None + force : bool, optional + if the init should be forced, by default False + + Returns + ------- + bool + if the frame was new or not + list + list of frames that were updated, and info about added/removed IDs + """ + + posData = self.data[self.pos_i] + if posData.whitelist is None: + return False, [] + + if frame_i is None: + frame_i = posData.frame_i + + if posData.whitelist._debug: + printl("whitelistInitNewFrames", frame_i, force) + + if frame_i not in posData.whitelist.initialized_i: + self.whitelistTrackOGCurr(frame_i=frame_i, against_prev=True) + + new_frame, update_frames = posData.whitelist.initNewFrames( + frame_i=frame_i, force=force + ) + + self.whitelistAddNewIDs() + return new_frame, update_frames + + # @exec_time + def whitelistPropagateIDs( + self, + new_whitelist: Set[int] | List[int] = None, + IDs_to_add: Set[int] = None, + IDs_to_remove: Set[int] = None, + frame_i: int = None, + try_create_new_whitelists: bool = False, + curr_frame_only: bool = False, + force_not_dynamic_update: bool = False, + only_future_frames: bool = True, + allow_only_current_IDs: bool = False, + track_og_curr: bool = True, + IDs_curr: Set[int] | List[int] = None, + index_lab_combo: Tuple[int, np.ndarray] = None, + curr_rp: list = None, + curr_lab: np.ndarray = None, + store_data: bool = True, + update_lab: bool = False, + ): + """ + Propagates whitelist IDs across frames in the dataset. (Doesnt update labs) + Should also be called when viewing a new frame! + + This function updates whitelist. If curr_frame_only is True, it only updates the + whitelist of the current frame. If the frame changes, this function should be called + again to update the whitelist for the new frame (without this argument). + It should also handle cases were this is not done, but this is less safe. + Then, all the additions and removals are propagated to the other frames. + If force_not_dynamic_update is True, the function will propagate the entire whitelist to + frames, and not only the IDs which were added or removed. + + Hierarchy of arguments for current_IDs: + 1. IDs_curr (if provided) + (2. index_lab_combo (if provided) (is also passed to not current frame only + propagation if that propagation is necessary, and used when the frame_i matches)) + 3. curr_rp (if provided) + 4. curr_lab (if provided) + 5. allData_li + + Parameters + ---------- + new_whitelist : Set[int] | List[int], optional + A new set of whitelist IDs to replace the current whitelist. Cannot be + used together with `IDs_to_add` or `IDs_to_remove`, by default None. + IDs_to_add : Set[int], optional + A set of IDs to add to the current whitelist, by default None. + IDs_to_remove : Set[int], optional + A set of IDs to remove from the current whitelist, by default None. + frame_i : int, optional + The frame index for the propagation. + If None, uses posData.frame_i, by default None. + try_create_new_whitelists : bool, optional + If True, creates new whitelist entries for frames that do not already + have them. Should only be necessary when its initialized, by default False. + curr_frame_only : bool, optional + If True, only updates the whitelist for the current frame. + (See description of function), by default False. + force_not_dynamic_update : bool, optional + If True, disables dynamic updates to the whitelist. + (See description of function), by default False. + only_future_frames : bool, optional + If True, propagates changes only to future frames, by default True. + allow_only_current_IDs : bool, optional + If True, only allows IDs that are present in the current frame + to be added to the whitelist, by default True. + track_og_curr : bool, optional + If True, tracks the original labels in relation to the current + (whitelisted) labels. This is done by calling whitelistTrackOGCurr. + If its a new frame, this is done in whitelistInitNewFrames against the + previous frame, + by default True. + IDs_curr : Set[int] | List[int], optional + A set of IDs for the current frame, if None, + will be calculated from other stuff (see description), by default None. + index_lab_combo : Tuple[int, np.ndarray], optional + Combination of frame_i and current frame, + Used to get IDs_curr (see description), when the frame_i matches + (is also passed to not current frame only + propagation if that propagation is necessary, + and used when the frame_i matches), by default None. + curr_rp : list, optional + Region properties for the current frame. For IDs_curr. (see description), + by default None. + curr_lab : np.ndarray, optional + Labels for the current frame for IDs_curr. (see description), + by default None. + store_data : bool, optional + If True, stores the data before propagating the IDs. + update_lab : bool, optional + If True, updates the labels after propagating the IDs. + Will always update labels for newly init frames, by default False. + + Raises + ------ + ValueError + If both `new_whitelistIDs` and `IDs_to_add`/`IDs_to_remove` are provided. + + Example + ------- + To add IDs 5 and 6 to the whitelist for the current frame: + ```python + self.whitelistPropagateIDs(IDs_to_add={5, 6}, curr_frame_only=True) + ``` + Then when the frame changes: + ```python + self.whitelistPropagateIDs() + ``` + + To replace the whitelist for frame 10 with a new set of IDs: + ```python + self.whitelistPropagateIDs(new_whitelistIDs={1, 2, 3}, frame_i=10) + ``` + This would also propagate the changes to all other frames. + + """ + # doesnt update the frame displayed, only wl + try: # safety XD + IDs_curr = IDs_curr.copy() + except AttributeError: + pass + + IDs_curr = set(IDs_curr) if IDs_curr is not None else None + + posData = self.data[self.pos_i] + + debug = posData.whitelist._debug if posData.whitelist is not None else False + + if debug: + printl("Propagating IDs...") + from . import debugutils + + debugutils.print_call_stack() + printl(new_whitelist, IDs_to_add, IDs_to_remove) + + if posData.whitelist is None: + return + + # og_frame_i = posData.frame_i + if frame_i is None: + frame_i = posData.frame_i + + new_frame, update_frames_init = self.whitelistInitNewFrames(frame_i=frame_i) + + if new_frame: + self.update_rp(wl_update=False) + # if track_og_curr and not new_frame: + # self.whitelistTrackOGCurr(frame_i=frame_i, rp=curr_rp, lab=curr_lab) + + update_frames = posData.whitelist.propagateIDs( + frame_i, + posData.allData_li, + new_whitelist=new_whitelist, + IDs_to_add=IDs_to_add, + IDs_to_remove=IDs_to_remove, + try_create_new_whitelists=try_create_new_whitelists, + curr_frame_only=curr_frame_only, + force_not_dynamic_update=force_not_dynamic_update, + only_future_frames=only_future_frames, + allow_only_current_IDs=allow_only_current_IDs, + IDs_curr=IDs_curr, + index_lab_combo=index_lab_combo, + curr_rp=curr_rp, + curr_lab=curr_lab, + ) + if update_lab: + update_frames = update_frames_init + update_frames + else: + update_frames = update_frames_init + # printl(posData.whitelistIDs[frame_i]) + # posData.frame_i = og_frame_i + self.whitelistIDsUpdateText() + if store_data: + self.store_data(autosave=False) + + for frame_i, IDs_to_add, IDs_to_remove, new_frame in update_frames: + self.whitelistUpdateLab( + frame_i=frame_i, + track_og_curr=track_og_curr, + new_frame=new_frame, + IDs_to_add=IDs_to_add, + IDs_to_remove=IDs_to_remove, + ) + + def whitelistIDs_cb(self, checked: bool): + """Callback for when the whitelist IDs button is checked or unchecked. + Initialises the pointlayer and the whitelist IDs toolbar if checked. + + Parameters + ---------- + checked : bool + True if the whitelist IDs button is checked, False otherwise. + """ + if checked: + self.initKeepObjLabelsLayers() + self.disconnectLeftClickButtons() + self.uncheckLeftClickButtons(self.whitelistIDsButton) + self.connectLeftClickButtons() + + self.whitelistIDsToolbar.setVisible(checked) + self.whitelistHighlightIDs(checked) + self.whitelistIDsUpdateText() + self.whitelistUpdateTempLayer() + + if not checked: + self.setLostNewOldPrevIDs() + self.updateAllImages() + + def whitelistHighlightIDs(self, checked: bool = True): + """Highlights the IDs in the current frame based on the whitelist. + + Parameters + ---------- + checked : bool, optional + If False, will delete all highlights, by default True + """ + if not checked: + self.removeHighlightLabelID() + return + + posData = self.data[self.pos_i] + + if posData.whitelist is None: + if not hasattr(self, "tempWhitelistIDs"): + self.tempWhitelistIDs = set() # not updated, only use in this context + current_whitelist = self.tempWhitelistIDs + else: + current_whitelist = self.tempWhitelistIDs + else: + current_whitelist = posData.whitelist.get(frame_i=posData.frame_i) + + for ID in current_whitelist: + self.highlightLabelID(ID) + + def whitelistIDsChanged( + self, whitelistIDs: Set[int] | List[int], debug: bool = False + ): + """Callback for when the whitelist IDs are changed. + This is called when the user changed the IDs in the whitelist IDs toolbar + (or when its programmatically changed, but if its not + visible it should return instantly) + Will update the temp layer and also complain when IDs + are not valid/present in the current lab + + Parameters + ---------- + whitelistIDs : set | list + The IDs that are currently in the whitelist. + debug : bool, optional + debug, by default False + """ + if not self.whitelistIDsButton.isChecked(): + return + + posData = self.data[self.pos_i] + + if posData.whitelist: + debug = posData.whitelist._debug + if debug: + printl("whitelistIDsChanged", whitelistIDs) + + if posData.whitelist is None: + wl_init = False + if not hasattr(self, "tempWhitelistIDs"): + self.tempWhitelistIDs = set() # not updated, only use in this context + current_whitelist = self.tempWhitelistIDs + else: + current_whitelist = self.tempWhitelistIDs + else: + wl_init = True + current_whitelist = posData.whitelist.get(frame_i=posData.frame_i) + + current_whitelist_copy = current_whitelist.copy() + if ( + not hasattr(posData, "originalLabsIDs") + or posData.whitelist.originalLabsIDs is None + ): + possible_IDs = posData.IDs.copy() + else: + if not self.whitelistCheckOriginalLabels(warning=False): + possible_IDs = set(posData.IDs) + else: + possible_IDs = posData.whitelist.originalLabsIDs[posData.frame_i] + possible_IDs.update(posData.IDs) + + isAnyIDnotExisting = False + for ID in whitelistIDs: + if ID not in possible_IDs: + isAnyIDnotExisting = True + continue + if ID not in current_whitelist_copy: + current_whitelist.add(ID) + self.highlightLabelID(ID) + + for ID in current_whitelist_copy: + if ID not in possible_IDs: + isAnyIDnotExisting = True + continue + if ID not in whitelistIDs: + current_whitelist.remove(ID) + self.removeHighlightLabelID(IDs=[ID]) + + if wl_init: + posData.whitelist.whitelistIDs[posData.frame_i] = current_whitelist + else: + self.tempWhitelistIDs = current_whitelist + + self.whitelistUpdateTempLayer() + if isAnyIDnotExisting: + self.whitelistIDsToolbar.whitelistLineEdit.warnNotExistingID() + else: + self.whitelistIDsToolbar.whitelistLineEdit.setInstructionsText() + + # @exec_time + def whitelistUpdateTempLayer(self): + """Updates the temp layer with the current whitelist IDs.""" + if not self.whitelistIDsButton.isChecked(): + self.keepIDsTempLayerLeft.clear() + return + + if not hasattr(self, "keptLab"): + self.keptLab = np.zeros_like(self.currentLab2D) + keptLab = self.keptLab + else: + keptLab = self.keptLab + keptLab[:] = 0 + + posData = self.data[self.pos_i] + if posData.whitelist is None: + if not hasattr(self, "tempWhitelistIDs"): + self.tempWhitelistIDs = set() # not updated, only use in this context + current_whitelist = self.tempWhitelistIDs + else: + current_whitelist = self.tempWhitelistIDs + else: + current_whitelist = posData.whitelist.get(posData.frame_i) + + for obj in posData.rp: + if obj.label not in current_whitelist: + continue + + if not self.isObjVisible(obj.bbox): + continue + + _slice = self.getObjSlice(obj.slice) + _objMask = self.getObjImage(obj.image, obj.bbox) + + keptLab[_slice][_objMask] = obj.label + + self.keepIDsTempLayerLeft.setImage(keptLab, autoLevels=False) diff --git a/cellacdc/mixins/window_events.py b/cellacdc/mixins/window_events.py new file mode 100644 index 000000000..1f1a0ae69 --- /dev/null +++ b/cellacdc/mixins/window_events.py @@ -0,0 +1,975 @@ +"""Qt view adapter for main-window and pointer events.""" + +from __future__ import annotations + +import gc +import os +import traceback +import time + +from qtpy.QtCore import Qt, QSettings, QTimer +from qtpy.QtGui import QCursor, QFont, QKeyEvent, QKeySequence, QPixmap +from qtpy.QtWidgets import QAbstractSlider, QCheckBox, QMainWindow + +from cellacdc import ( + apps, + exception_handler, + html_utils, + is_mac, + printl, + qutils, + widgets, +) +from cellacdc.plot import imshow + + +_font = QFont() +_font.setPixelSize(11) + +from .app_shell import AppShell +from .frame_navigation import FrameNavigation + + +class WindowEvents(AppShell, FrameNavigation): + """Extracted from guiWin.""" + + def _resizeLeaveSpaceTerminalBelow(self): + geometry = self.geometry() + left = geometry.left() + top = geometry.top() + width = geometry.width() + height = geometry.height() + self.setGeometry(left, top + 10, width, height - 200) + + def _resizeSlidersArea(self): + self.navigateScrollBar.setFixedHeight(self.newHeight) + self.zSliceScrollBar.setFixedHeight(self.newHeight) + self.zSliceOverlay_SB.setFixedHeight(self.newHeight) + self.zProjComboBox.setFixedHeight(self.newHeight) + self.zProjOverlay_CB.setFixedHeight(self.newHeight) + self.navSpinBox.setFixedHeight(self.newHeight) + self.zSliceSpinbox.setFixedHeight(self.newHeight) + try: + self.img1.alphaScrollbar.setFixedHeight(self.newHeight) + except Exception as e: + pass + try: + for channel, items in self.overlayLayersItems.items(): + alphaScrollbar = items[2] + alphaScrollbar.setFixedHeight(self.newHeight) + except: + pass + checkBoxStyleSheet = ( + "QCheckBox::indicator {" + f"width: {self.newCheckBoxesHeight}px;" + f"height: {self.newCheckBoxesHeight}px" + "}" + ) + for i in range(self.annotOptionsLayout.count()): + widget = self.annotOptionsLayout.itemAt(i).widget() + if isinstance(widget, QCheckBox): + widget.setStyleSheet(checkBoxStyleSheet) + for i in range(self.annotOptionsLayoutRight.count()): + widget = self.annotOptionsLayoutRight.itemAt(i).widget() + if isinstance(widget, QCheckBox): + widget.setStyleSheet(checkBoxStyleSheet) + self.zSliceCheckbox.setStyleSheet(checkBoxStyleSheet) + + def _temp_debug(self, id=None): + posData = self.data[self.pos_i] + imshow(posData.lab, annotate_labels_idxs=[0]) + + def askCloseAllWindows(self): + txt = html_utils.paragraph(""" + There are other open windows that were created from this window. +

    + If you proceed, the other windows will be closed too.
    + """) + msg = widgets.myMessageBox(wrapText=False) + msg.warning(self, "Open windows", txt, buttonsTexts=("Cancel", "Ok, close now")) + return msg.cancel + + def changeEvent(self, event): + try: + self.delObjToolAction.setChecked(False) + except Exception as err: + return + + def changeRightClickToLeftOnMac(self, mouseEvent): + button = mouseEvent.button() + if not is_mac: + return button + + delObjKeySequence, delObjQtButton = self.delObjAction + if delObjKeySequence is None: + return button + + if not delObjKeySequence.toString() == "Control": + return button + + if button != Qt.MouseButton.RightButton: + return button + + if delObjQtButton == Qt.MouseButton.LeftButton: + # On mac, pressing "Control" and clicking with left button changes + # it to a right click button --> here, left click is required for + # delete object --> force return of left click + return Qt.MouseButton.LeftButton + + return button + + def checkOverlayToolbuttonClicked(self, event): + success = False + try: + n = int(event.text()) + toolbutton = self.allOverlayToolbuttonsByIdx.get(n, None) + toolbutton.click() + success = True + except Exception as e: + # printl(traceback.format_exc()) + success = False + return success + + def checkSetDelObjActionActive(self, event): + if self.delObjAction is None and self.is_win: + return + + if self.delObjAction is None: + # On mac we check for Key_Control + if event.key() == Qt.Key_Control: + self.delObjToolAction.setChecked(True) + return + + delObjKeySequence, delObjQtButton = self.delObjAction + keySequenceText = widgets.QKeyEventToString(event).rstrip("+") + + if delObjKeySequence is None: + # self.delObjToolAction.setChecked(True) + return + + delObjKeySequenceText = widgets.macShortcutToWindows( + delObjKeySequence.toString() + ) + keySequenceText = widgets.macShortcutToWindows(keySequenceText) + + # printl( + # delObjKeySequence.toString(), + # keySequenceText, + # delObjKeySequenceText + # ) + + if keySequenceText == delObjKeySequenceText: + self.delObjToolAction.setChecked(True) + + def checkTriggerKeyPressShortcuts(self, event: QKeyEvent): + isBrushKey = event.key() == self.brushButton.keyPressShortcut + isEraserKey = event.key() == self.eraserButton.keyPressShortcut + if isBrushKey or isEraserKey: + return isBrushKey, isEraserKey + + modifierText = widgets.modifierKeyToText(event.modifiers()) + for widget in self.widgetsWithShortcut.values(): + if not hasattr(widget, "keyPressShortcut"): + continue + + if event.key() == widget.keyPressShortcut: + if widget.isCheckable(): + widget.setChecked(True) + else: + widget.trigger() + continue + + shortcutText = widget.keyPressShortcut.toString() + try: + mod, key = shortcutText.split("+") + if modifierText == mod and event.key() == QKeySequence(key): + widget.trigger() + + except Exception as e: + pass + + return isBrushKey, isEraserKey + + def clearMemory(self): + if not hasattr(self, "data"): + return + self.logger.info("Clearing memory...") + for posData in self.data: + try: + del posData.img_data + except Exception as e: + pass + try: + del posData.segm_data + except Exception as e: + pass + try: + del posData.ol_data_dict + except Exception as e: + pass + try: + del posData.fluo_data_dict + except Exception as e: + pass + try: + del posData.ol_data + except Exception as e: + pass + del self.data + + def closeEvent(self, event): + self.setDisabled(False) + cancel = self.checkAskSavePointsLayers() + if cancel: + event.ignore() + return + + self.onEscape() + self.saveWindowGeometry() + + if self.newWindows: + cancel = self.askCloseAllWindows() + if cancel: + event.ignore() + return + + for window in self.newWindows: + window.close() + + if self.slideshowWin is not None: + self.slideshowWin.close() + if self.ccaTableWin is not None: + self.ccaTableWin.close() + + proceed = self.askSaveOnClosing(event) + if not proceed: + event.ignore() + return + + self.autoSaveClose() + + if self.autoSaveActiveWorkers: + progressWin = apps.QDialogWorkerProgress( + title="Closing autosaving worker", + parent=self, + pbarDesc="Closing autosaving worker...", + ) + progressWin.show(self.app) + progressWin.mainPbar.setMaximum(0) + self.waitCloseAutoSaveWorkerLoop = qutils.QWhileLoop( + self._waitCloseAutoSaveWorker, period=250 + ) + self.waitCloseAutoSaveWorkerLoop.exec_() + progressWin.workerFinished = True + progressWin.close() + + self.stopPreprocWorker() + self.stopCombineWorker() + self.stopCcaIntegrityCheckerWorker() + + # Close the inifinte loop of the thread + if self.lazyLoader is not None: + self.lazyLoader.exit = True + self.lazyLoaderWaitCond.wakeAll() + self.waitReadH5cond.wakeAll() + + if self.storeStateWorker is not None: + # Close storeStateWorker + self.storeStateWorker._stop() + while self.storeStateWorker.isFinished: + time.sleep(0.05) + + # Block main thread while separate threads closes + time.sleep(0.1) + + self.clearMemory() + + self.logger.info("Closing GUI logger...") + self.logger.close() + + if self.lazyLoader is None: + self.sigClosed.emit(self) + + gc.collect() + + def doubleKeySpacebarTimerCallback(self): + if self.isKeyDoublePress: + self.doubleKeyTimeElapsed = False + return + self.doubleKeyTimeElapsed = True + self.countKeyPress = 0 + + def doubleKeyTimerCallBack(self): + if self.isKeyDoublePress: + self.doubleKeyTimeElapsed = False + return + self.doubleKeyTimeElapsed = True + self.countKeyPress = 0 + if self.Button is None: + return + + isBrushChecked = self.Button.isChecked() + if isBrushChecked and self.uncheck: + self.Button.setChecked(False) + c = self.defaultToolBarButtonColor + self.Button.setStyleSheet(f"background-color: {c}") + + def doubleRightClickTimerCallBack(self): + if self.isDoubleRightClick: + self.doubleRightClickTimeElapsed = False + return + self.doubleRightClickTimeElapsed = True + self.countRightClicks = 0 + + # Time to double right click on img1 expired --> single right-click + self.gui_imgGradShowContextMenu(*self._img1_click_xy) + + def dragEnterEvent(self, event): + file_path = event.mimeData().urls()[0].toLocalFile() + if os.path.isdir(file_path): + exp_path = file_path + basename = os.path.basename(file_path) + if basename.find("Position_") != -1 or basename == "Images": + event.acceptProposedAction() + else: + event.ignore() + else: + event.acceptProposedAction() + + def dropEvent(self, event): + event.setDropAction(Qt.CopyAction) + file_path = event.mimeData().urls()[0].toLocalFile() + self.logger.info(f'Dragged and dropped path "{file_path}"') + basename = os.path.basename(file_path) + if os.path.isdir(file_path): + exp_path = file_path + self.openFolder(exp_path=exp_path) + else: + self.openFile(file_path=file_path) + + def editingSpinboxValueTimerCallback(self): + self.typingEditID = False + + def enterEvent(self, event): + event.accept() + if self.slideshowWin is not None: + posData = self.data[self.pos_i] + mainWinGeometry = self.geometry() + mainWinLeft = mainWinGeometry.left() + mainWinTop = mainWinGeometry.top() + mainWinWidth = mainWinGeometry.width() + mainWinHeight = mainWinGeometry.height() + mainWinRight = mainWinLeft + mainWinWidth + mainWinBottom = mainWinTop + mainWinHeight + + slideshowWinGeometry = self.slideshowWin.geometry() + slideshowWinLeft = slideshowWinGeometry.left() + slideshowWinTop = slideshowWinGeometry.top() + slideshowWinWidth = slideshowWinGeometry.width() + slideshowWinHeight = slideshowWinGeometry.height() + + # Determine if overlap + overlap = (slideshowWinTop < mainWinBottom) and ( + slideshowWinLeft < mainWinRight + ) + + autoActivate = ( + self.isDataLoaded + and not overlap + and not posData.disableAutoActivateViewerWindow + ) + + if autoActivate: + # self.setFocus() + self.activateWindow() + + def gui_createCursors(self): + pixmap = QPixmap(":wand_cursor.svg") + self.wandCursor = QCursor(pixmap, 16, 16) + + pixmap = QPixmap(":curv_cursor.svg") + self.curvCursor = QCursor(pixmap, 16, 16) + + pixmap = QPixmap(":addDelPolyLineRoi_cursor.svg") + self.polyLineRoiCursor = QCursor(pixmap, 16, 16) + + pixmap = QPixmap(":cross_cursor.svg") + self.addPointsCursor = QCursor(pixmap, 16, 16) + + def keyDownCallback( + self, isBrushActive, isWandActive, isExpandLabelActive, isLabelRoiCircActive + ): + isAutoPilotActive = ( + self.autoPilotZoomToObjToggle.isChecked() + and self.autoPilotZoomToObjToolbar.isVisible() + ) + if isBrushActive: + brushSize = self.brushSizeSpinbox.value() + self.brushSizeSpinbox.setValue(brushSize - 1) + elif isWandActive: + wandTolerance = self.wandControlsToolbar.toleranceSpinbox.value() + self.wandControlsToolbar.toleranceSpinbox.setValue(wandTolerance - 1) + elif isExpandLabelActive: + self.expandLabel(dilation=False) + self.expandFootprintSize += 1 + elif isLabelRoiCircActive: + val = self.labelRoiCircularRadiusSpinbox.value() + self.labelRoiCircularRadiusSpinbox.setValue(val - 1) + elif isAutoPilotActive: + self.pointsLayerAutoPilot("prev") + elif self.isNavigateActionOnNextFrame(): + posData = self.data[self.pos_i] + self.rightImageFramesScrollbar.setValue(posData.frame_i + 2) + else: + self.zSliceScrollBar.triggerAction( + QAbstractSlider.SliderAction.SliderSingleStepSub + ) + + def keyPressCheckSetSpinboxValue(self, event, spinbox): + """Check if the key pressed is a digit and set the spinbox value + accordingly.""" + try: + n = int(event.text()) + if self.typingEditID: + value = int(f"{spinbox.value()}{n}") + else: + value = n + self.typingEditID = True + spinbox.setValue(value) + + try: + spinbox.timer.stop() + except Exception as err: + pass + + spinbox.timer = QTimer(spinbox) + spinbox.timer.timeout.connect(self.editingSpinboxValueTimerCallback) + spinbox.timer.start(2000) + spinbox.timer.setSingleShot(True) + success = True + except Exception as e: + # printl(traceback.format_exc()) + success = False + return success + + def keyPressEvent(self, ev): + ctrl = ev.modifiers() == Qt.ControlModifier + if ctrl and ev.key() == Qt.Key_D: + self.resizeLeaveSpaceTerminalBelow() + return + + if ev.key() == Qt.Key_Q and self.debug: + try: + from . import _q_debug + + _q_debug.q_debug(self) + except Exception as err: + printl(traceback.format_exc()) + printl('[ERROR]: Error with "_qdebug" module. See Traceback above.') + pass + + if not self.isDataLoaded: + self.logger.warning( + "Data not loaded yet. Key pressing events are not connected." + ) + return + + if ev.key() == Qt.Key_Control: + if not ctrl: + self.wasCtrlPressedFirstTime = True + self.onCtrlPressedFirstTime() + + if ev.key() == Qt.Key_PageDown: + self.onKeyPageDown() + + if ev.key() == Qt.Key_PageUp: + self.onKeyPageUp() + + if ev.key() == Qt.Key_Home: + self.onKeyHome() + + if ev.key() == Qt.Key_End: + self.onKeyEnd() + + modifiers = ev.modifiers() + isAltModifier = modifiers == Qt.AltModifier + isCtrlModifier = modifiers == Qt.ControlModifier + isShiftModifier = modifiers == Qt.ShiftModifier + + self.checkSetDelObjActionActive(ev) + + self.isZmodifier = ( + ev.key() == Qt.Key_Z + and not isAltModifier + and not isCtrlModifier + and not isShiftModifier + ) + if isShiftModifier: + if self.brushButton.isChecked(): + # Force default brush symbol with shift down + self.setHoverToolSymbolColor( + 1, + 1, + self.ax2_BrushCirclePen, + (self.ax2_BrushCircle, self.ax1_BrushCircle), + self.brushButton, + brush=self.ax2_BrushCircleBrush, + ID=0, + ) + if self.isSegm3D: + self.changeBrushID() + + isAnyModifier = isAltModifier or isCtrlModifier or isShiftModifier + if not isAnyModifier and self.overlayButton.isChecked(): + isButtonClicked = self.checkOverlayToolbuttonClicked(ev) + if isButtonClicked: + return + + isBrushActive = self.brushButton.isChecked() or self.eraserButton.isChecked() + isManualTrackingActive = self.manualTrackingButton.isChecked() + isManualBackgroundActive = self.manualBackgroundButton.isChecked() + isTypingIDFunctionChecked = False + if self.brushButton.isChecked() and not self.autoIDcheckbox.isChecked(): + success = self.keyPressCheckSetSpinboxValue(ev, self.editIDspinbox) + isTypingIDFunctionChecked = True + + if isManualTrackingActive: + isTypingIDFunctionChecked = self.keyPressCheckSetSpinboxValue( + ev, self.manualTrackingToolbar.spinboxID + ) + + elif isManualBackgroundActive: + isTypingIDFunctionChecked = self.keyPressCheckSetSpinboxValue( + ev, self.manualBackgroundToolbar.spinboxID + ) + + addPointsByClickingButton = self.buttonAddPointsByClickingActive() + if ( + addPointsByClickingButton is not None + and addPointsByClickingButton.toolbar.isVisible() + ): + isTypingIDFunctionChecked = self.keyPressCheckSetSpinboxValue( + ev, addPointsByClickingButton.rightClickIDSpinbox + ) + + isBrushKey, isEraserKey = self.checkTriggerKeyPressShortcuts(ev) + isExpandLabelActive = self.expandLabelToolButton.isChecked() + isWandActive = self.wandToolButton.isChecked() + isLabelRoiCircActive = ( + self.labelRoiButton.isChecked() + and self.labelRoiIsCircularRadioButton.isChecked() + ) + how = self.drawIDsContComboBox.currentText() + isOverlaySegm = how.find("overlay segm. masks") != -1 + if ev.key() == Qt.Key_Up and not isCtrlModifier: + self.keyUpCallback( + isBrushActive, isWandActive, isExpandLabelActive, isLabelRoiCircActive + ) + elif ev.key() == Qt.Key_Down and not isCtrlModifier: + self.keyDownCallback( + isBrushActive, isWandActive, isExpandLabelActive, isLabelRoiCircActive + ) + elif ev.key() == Qt.Key_Enter or ev.key() == Qt.Key_Return: + if isTypingIDFunctionChecked: + self.typingEditID = False + elif self.keepIDsButton.isChecked(): + self.keepIDsConfirmAction.trigger() + elif ev.key() == Qt.Key_Escape: + self.onEscape(isTypingIDFunctionChecked=isTypingIDFunctionChecked) + elif isAltModifier: + isCursorSizeAll = self.app.overrideCursor() == Qt.SizeAllCursor + # Alt is pressed while cursor is on images --> set SizeAllCursor + if self.xHoverImg is not None and not isCursorSizeAll: + self.app.setOverrideCursor(Qt.SizeAllCursor) + elif isCtrlModifier and isOverlaySegm: + if ev.key() == Qt.Key_Up: + val = self.imgGrad.labelsAlphaSlider.value() + delta = 5 / self.imgGrad.labelsAlphaSlider.maximum() + val = val + delta + self.imgGrad.labelsAlphaSlider.setValue(val, emitSignal=True) + elif ev.key() == Qt.Key_Down: + val = self.imgGrad.labelsAlphaSlider.value() + delta = 5 / self.imgGrad.labelsAlphaSlider.maximum() + val = val - delta + self.imgGrad.labelsAlphaSlider.setValue(val, emitSignal=True) + elif ev.key() == self.zoomOutKeyValue: + self.zoomToCells(enforce=True) + if self.countKeyPress == 0: + self.isKeyDoublePress = False + self.countKeyPress = 1 + self.doubleKeyTimeElapsed = False + self.Button = None + QTimer.singleShot(400, self.doubleKeyTimerCallBack) + elif self.countKeyPress == 1 and not self.doubleKeyTimeElapsed: + self.ax1.autoRange() + self.isKeyDoublePress = True + self.countKeyPress = 0 + elif ev.key() == Qt.Key_Space: + if self.countKeyPress == 0: + # Single press --> wait that it's not double press + self.isKeyDoublePress = False + self.countKeyPress = 1 + self.doubleKeyTimeElapsed = False + QTimer.singleShot(300, self.doubleKeySpacebarTimerCallback) + elif self.countKeyPress == 1 and not self.doubleKeyTimeElapsed: + self.isKeyDoublePress = True + # Double press --> toggle draw nothing + self.onDoubleSpaceBar() + self.countKeyPress = 0 + elif isBrushKey or isEraserKey: + if isBrushKey: + self.Button = self.brushButton + else: + self.Button = self.eraserButton + + if not self.Button.isVisible(): + return + + if self.countKeyPress == 0: + # If first time clicking B activate brush and start timer + # to catch double press of B + if not self.Button.isChecked(): + self.uncheck = False + self.Button.setChecked(True) + else: + self.uncheck = True + self.countKeyPress = 1 + self.isKeyDoublePress = False + self.doubleKeyTimeElapsed = False + + QTimer.singleShot(400, self.doubleKeyTimerCallBack) + elif self.countKeyPress == 1 and not self.doubleKeyTimeElapsed: + self.isKeyDoublePress = True + color = self.Button.palette().button().color().name() + if color == self.doublePressKeyButtonColor: + c = self.defaultToolBarButtonColor + else: + c = self.doublePressKeyButtonColor + self.Button.setStyleSheet(f"background-color: {c}") + self.countKeyPress = 0 + if self.xHoverImg is not None: + xdata, ydata = int(self.xHoverImg), int(self.yHoverImg) + if isBrushKey: + self.setHoverToolSymbolColor( + xdata, + ydata, + self.ax2_BrushCirclePen, + (self.ax2_BrushCircle, self.ax1_BrushCircle), + self.brushButton, + brush=self.ax2_BrushCircleBrush, + ) + elif isEraserKey: + self.setHoverToolSymbolColor( + xdata, + ydata, + self.eraserCirclePen, + (self.ax2_EraserCircle, self.ax1_EraserCircle), + self.eraserButton, + ) + + def keyReleaseEvent(self, ev): + if self.app.overrideCursor() == Qt.SizeAllCursor: + self.app.restoreOverrideCursor() + if ev.key() == Qt.Key_Control: + self.onCtrlReleased() + elif ev.key() == Qt.Key_Shift: + self.onShiftReleased() + + canRepeat = ( + ev.key() == Qt.Key_Left + or ev.key() == Qt.Key_Right + or ev.key() == Qt.Key_Up + or ev.key() == Qt.Key_Down + or ev.key() == Qt.Key_Control + or ev.key() == Qt.Key_Backspace + or self.delObjToolAction.isChecked() + ) + + if canRepeat and ev.isAutoRepeat(): + return + + self.delObjToolAction.setChecked(False) + + if ev.isAutoRepeat() and not ev.key() == Qt.Key_Z: + if self.warnKeyPressedMsg is not None: + return + self.warnKeyPressedMsg = widgets.myMessageBox( + showCentered=False, wrapText=False + ) + txt = html_utils.paragraph(f""" + Please, do not keep the key "{ev.text().upper()}" + pressed.

    + It confuses me :)

    + Thanks! + """) + self.warnKeyPressedMsg.warning(self, "Release the key, please", txt) + self.warnKeyPressedMsg = None + elif ev.isAutoRepeat() and ev.key() == Qt.Key_Z and self.isZmodifier: + self.zKeptDown = True + elif ev.key() == Qt.Key_Z and self.isZmodifier: + posData = self.data[self.pos_i] + self.isZmodifier = False + if not self.zKeptDown and posData.SizeZ > 1: + self.zSliceCheckbox.setChecked(not self.zSliceCheckbox.isChecked()) + self.zKeptDown = False + + def keyUpCallback( + self, isBrushActive, isWandActive, isExpandLabelActive, isLabelRoiCircActive + ): + isAutoPilotActive = ( + self.autoPilotZoomToObjToggle.isChecked() + and self.autoPilotZoomToObjToolbar.isVisible() + ) + if isBrushActive: + brushSize = self.brushSizeSpinbox.value() + self.brushSizeSpinbox.setValue(brushSize + 1) + elif isWandActive: + wandTolerance = self.wandControlsToolbar.toleranceSpinbox.value() + self.wandControlsToolbar.toleranceSpinbox.setValue(wandTolerance + 1) + elif isExpandLabelActive: + self.expandLabel(dilation=True) + self.expandFootprintSize += 1 + elif isLabelRoiCircActive: + val = self.labelRoiCircularRadiusSpinbox.value() + self.labelRoiCircularRadiusSpinbox.setValue(val + 1) + elif isAutoPilotActive: + self.pointsLayerAutoPilot("next") + else: + self.zSliceScrollBar.triggerAction( + QAbstractSlider.SliderAction.SliderSingleStepAdd + ) + + def leaveEvent(self, event): + if self.slideshowWin is not None: + posData = self.data[self.pos_i] + mainWinGeometry = self.geometry() + mainWinLeft = mainWinGeometry.left() + mainWinTop = mainWinGeometry.top() + mainWinWidth = mainWinGeometry.width() + mainWinHeight = mainWinGeometry.height() + mainWinRight = mainWinLeft + mainWinWidth + mainWinBottom = mainWinTop + mainWinHeight + + slideshowWinGeometry = self.slideshowWin.geometry() + slideshowWinLeft = slideshowWinGeometry.left() + slideshowWinTop = slideshowWinGeometry.top() + slideshowWinWidth = slideshowWinGeometry.width() + slideshowWinHeight = slideshowWinGeometry.height() + + # Determine if overlap + overlap = (slideshowWinTop < mainWinBottom) and ( + slideshowWinLeft < mainWinRight + ) + + autoActivate = ( + self.isDataLoaded + and not overlap + and not posData.disableAutoActivateViewerWindow + ) + + if autoActivate: + self.slideshowWin.setFocus() + self.slideshowWin.activateWindow() + + def mousePressEvent(self, event) -> None: + if event.button() == Qt.MouseButton.RightButton: + pos = self.resizeBottomLayoutLine.mapFromGlobal(event.globalPos()) + if pos.y() >= 0: + self.gui_raiseBottomLayoutContextMenu(event) + return super().mousePressEvent(event) + + def onKeyEnd(self): + self.zSliceScrollBar.triggerAction( + QAbstractSlider.SliderAction.SliderSingleStepSub + ) + + def onKeyHome(self): + self.zSliceScrollBar.triggerAction( + QAbstractSlider.SliderAction.SliderSingleStepAdd + ) + + def onKeyPageDown(self): + isAutoPilotActive = ( + self.autoPilotZoomToObjToggle.isChecked() + and self.autoPilotZoomToObjToolbar.isVisible() + ) + if isAutoPilotActive: + self.pointsLayerAutoPilot("prev") + elif self.zSliceScrollBar.isVisible(): + self.zSliceScrollBar.triggerAction( + QAbstractSlider.SliderAction.SliderSingleStepAdd + ) + + def onKeyPageUp(self): + isAutoPilotActive = ( + self.autoPilotZoomToObjToggle.isChecked() + and self.autoPilotZoomToObjToolbar.isVisible() + ) + if isAutoPilotActive: + self.pointsLayerAutoPilot("next") + elif self.zSliceScrollBar.isVisible(): + self.zSliceScrollBar.triggerAction( + QAbstractSlider.SliderAction.SliderSingleStepAdd + ) + + def onShiftReleased(self): + if self.brushButton.isChecked() and self.xHoverImg is not None: + self.updateBrushCursorOnShiftRelease() + + def readSettings(self): + settings = QSettings("schmollerlab", "acdc_gui") + if settings.value("geometry") is not None: + self.restoreGeometry(settings.value("geometry")) + + def resizeBottomLayoutLineClicked(self, event): + pass + + def resizeBottomLayoutLineDragged(self, event): + if not self.img1BottomGroupbox.isVisible(): + return + newBottomLayoutHeight = self.bottomScrollArea.minimumHeight() - event.y() + self.bottomScrollArea.setFixedHeight(newBottomLayoutHeight) + + def resizeBottomLayoutLineReleased(self): + QTimer.singleShot(100, self.autoRange) + + def resizeEvent(self, event): + if hasattr(self, "ax1"): + self.ax1.autoRange() + + def resizeLeaveSpaceTerminalBelow(self): + self.setWindowState(Qt.WindowMaximized) + QTimer.singleShot(200, self._resizeLeaveSpaceTerminalBelow) + + def resizeSlidersArea(self, fontSizeFactor=None, heightFactor=None): + global _font + if heightFactor is None: + self.newCheckBoxesHeight = self.checkBoxesHeight + self.newHeight = self.h + else: + self.newHeight = round(self.h * heightFactor) + self.newCheckBoxesHeight = round(self.checkBoxesHeight * heightFactor) + + if fontSizeFactor is None: + newFontSize = self.fontPixelSize + else: + newFontSize = round(self.fontPixelSize * fontSizeFactor) + newFont = QFont() + newFont.setPixelSize(newFontSize) + _font = newFont + self.zProjComboBox.setFont(newFont) + self.t_label.setFont(newFont) + self.zProjOverlay_CB.setFont(newFont) + self.annotateRightHowCombobox.setFont(newFont) + self.drawIDsContComboBox.setFont(newFont) + self.showTreeInfoCheckbox.setFont(newFont) + self.highlightZneighObjCheckbox.setFont(newFont) + self.navSpinBox.setFont(newFont) + self.zSliceSpinbox.setFont(newFont) + self.SizeZlabel.setFont(newFont) + self.navSizeLabel.setFont(newFont) + self.overlay_z_label.setFont(newFont) + self.img1BottomGroupbox.setFont(newFont) + self.rightBottomGroupbox.setFont(newFont) + try: + self.img1.alphaScrollbar.label.setFont(newFont) + except Exception as e: + pass + for i in range(self.annotOptionsLayout.count()): + widget = self.annotOptionsLayout.itemAt(i).widget() + widget.setFont(newFont) + for i in range(self.annotOptionsLayoutRight.count()): + widget = self.annotOptionsLayoutRight.itemAt(i).widget() + widget.setFont(newFont) + try: + for channel, items in self.overlayLayersItems.items(): + alphaScrollbar = items[2] + alphaScrollbar.label.setFont(newFont) + except: + pass + QTimer.singleShot(100, self._resizeSlidersArea) + + def saveWindowGeometry(self): + settings = QSettings("schmollerlab", "acdc_gui") + settings.setValue("geometry", self.saveGeometry()) + + def show(self): + self.setFont(_font) + QMainWindow.show(self) + + self.setWindowState(Qt.WindowNoState) + self.setWindowState(Qt.WindowActive) + self.raise_() + + self.readSettings() + self.storeDefaultAndCustomColors() + + self.h = self.navSpinBox.size().height() + fontSizeFactor = None + heightFactor = None + if "bottom_sliders_zoom_perc" in self.df_settings.index: + val = int(self.df_settings.at["bottom_sliders_zoom_perc", "value"]) + if val != 100: + fontSizeFactor = val / 100 + heightFactor = val / 100 + + self.defaultWidgetHeightBottomLayout = self.h + self.checkBoxesHeight = 14 + self.fontPixelSize = 11 + self.defaultBottomLayoutHeight = self.img1BottomGroupbox.height() + + self.bottomLayout.setStretch(0, 0) + self.bottomLayout.addSpacing(self.quickSettingsGroupbox.width()) + self.resizeSlidersArea(fontSizeFactor=fontSizeFactor, heightFactor=heightFactor) + self.bottomScrollArea.hide() + + self.gui_initImg1BottomWidgets() + self.img1BottomGroupbox.hide() + + w = self.showPropsDockButton.width() + h = self.showPropsDockButton.height() + + self.showPropsDockButton.setMaximumWidth(15) + self.showPropsDockButton.setMaximumHeight(120) + + for toolbar in self.controlToolBars: + toolbar.setMinimumHeight(self.secondLevelToolbar.sizeHint().height()) + + self.graphLayout.setFocus() + + def showEvent(self, event): + if self.mainWin is not None: + if not self.mainWin.isMinimized(): + return + self.mainWin.showAllWindows() + # self.setFocus() + self.activateWindow() + + def stopPreprocWorker(self): + self.logger.info("Closing pre-processing worker...") + try: + self.preprocWorker.stop() + except Exception as err: + pass + + def storeDefaultAndCustomColors(self): + c = self.overlayButton.palette().button().color().name() + self.defaultToolBarButtonColor = c + self.doublePressKeyButtonColor = "#fa693b" + + def super_show(self): + super().show() + + def updateBrushCursorOnShiftRelease(self): + xdata, ydata = int(self.xHoverImg), int(self.yHoverImg) + self.setHoverToolSymbolColor( + xdata, + ydata, + self.ax2_BrushCirclePen, + (self.ax2_BrushCircle, self.ax1_BrushCircle), + self.brushButton, + brush=self.ax2_BrushCircleBrush, + byPassShiftCheck=True, + ) + if self.isSegm3D: + self.changeBrushID() diff --git a/cellacdc/mixins/worker.py b/cellacdc/mixins/worker.py new file mode 100644 index 000000000..010a8b226 --- /dev/null +++ b/cellacdc/mixins/worker.py @@ -0,0 +1,415 @@ +"""Qt view adapter for GUI worker lifecycle handling.""" + +from __future__ import annotations + +import logging +import traceback +from functools import partial +from typing import Tuple + +from qtpy.QtCore import QObject, QMutex, QThread, QTimer, QWaitCondition + +from cellacdc import apps, exception_handler, html_utils, issues_url, widgets, workers + +from .status_hover import StatusHover + + +class Worker(StatusHover): + """Extracted from guiWin.""" + + def autoSaveWorkerClosed(self, worker): + if self.autoSaveActiveWorkers: + self.logger.info("Autosaving worker closed.") + try: + self.autoSaveActiveWorkers.remove(worker) + except Exception as e: + pass + + def autoSaveWorkerDone(self): + self.setStatusBarLabel(log=False) + + def autoSaveWorkerStartTimer(self, worker, posData): + self.autoSaveWorkerTimer = QTimer() + self.autoSaveWorkerTimer.timeout.connect( + partial(self.autoSaveWorkerTimerCallback, worker, posData) + ) + self.autoSaveWorkerTimer.start(150) + + def autoSaveWorkerTimerCallback(self, worker, posData): + if not self.isSaving: + self.autoSaveWorkerTimer.stop() + worker._enqueue(posData) + + def ccaIntegrityWorkerCritical(self, error): + try: + raise error + except Exception as err: + self.logger.exception(traceback.format_exc()) + + href = f'GitHub page' + txt = html_utils.paragraph(f""" + Unfortunately the experimental feature + check cell cycle annotations integrity raised a + critical error.

    + Cell-ACDC will now disable this feature to allow you to keep + using the software.

    + However, we kindly ask you to report the issue on our + {href}, thank you very much!

    + Please, include the log file when reporting the issue.

    + Log file location: + """) + msg = widgets.myMessageBox(wrapText=False) + msg.warning( + self, + "Experimental feature error", + txt, + commands=(self.log_path,), + path_to_browse=self.logs_path, + ) + self.disableCcaIntegrityChecker() + + def gui_createAutoSaveWorker(self): + if not hasattr(self, "data"): + return + + if not self.isDataLoaded: + return + + if self.autoSaveActiveWorkers: + garbage = self.autoSaveActiveWorkers[-1] + self.autoSaveGarbageWorkers.append(garbage) + worker = garbage[0] + worker._stop() + + posData = self.data[self.pos_i] + autoSaveThread = QThread() + self.autoSaveMutex = QMutex() + self.autoSaveWaitCond = QWaitCondition() + + savedSegmData = posData.segm_data.copy() + autoSaveWorker = workers.AutoSaveWorker( + self.autoSaveMutex, self.autoSaveWaitCond, savedSegmData + ) + autoSaveWorker.isAutoSaveON = self.autoSaveToggle.isChecked() + + autoSaveWorker.moveToThread(autoSaveThread) + autoSaveWorker.finished.connect(autoSaveThread.quit) + autoSaveWorker.finished.connect(autoSaveWorker.deleteLater) + autoSaveThread.finished.connect(autoSaveThread.deleteLater) + + autoSaveWorker.sigDone.connect(self.autoSaveWorkerDone) + autoSaveWorker.progress.connect(self.workerProgress) + autoSaveWorker.finished.connect(self.autoSaveWorkerClosed) + autoSaveWorker.sigAutoSaveCannotProceed.connect(self.turnOffAutoSaveWorker) + + autoSaveThread.started.connect(autoSaveWorker.run) + autoSaveThread.start() + + self.autoSaveActiveWorkers.append((autoSaveWorker, autoSaveThread)) + + self.logger.info("Autosaving worker started.") + + def gui_createLazyLoader(self): + if not self.lazyLoader is None: + return + + self.lazyLoaderThread = QThread() + self.lazyLoaderMutex = QMutex() + self.lazyLoaderWaitCond = QWaitCondition() + self.waitReadH5cond = QWaitCondition() + self.readH5mutex = QMutex() + self.lazyLoader = workers.LazyLoader( + self.lazyLoaderMutex, + self.lazyLoaderWaitCond, + self.waitReadH5cond, + self.readH5mutex, + ) + self.lazyLoader.moveToThread(self.lazyLoaderThread) + self.lazyLoader.wait = True + + self.lazyLoader.signals.finished.connect(self.lazyLoaderThread.quit) + self.lazyLoader.signals.finished.connect(self.lazyLoader.deleteLater) + self.lazyLoaderThread.finished.connect(self.lazyLoaderThread.deleteLater) + + self.lazyLoader.signals.progress.connect(self.workerProgress) + self.lazyLoader.signals.sigLoadingNewChunk.connect(self.loadingNewChunk) + self.lazyLoader.sigLoadingFinished.connect(self.lazyLoaderFinished) + self.lazyLoader.signals.critical.connect(self.lazyLoaderCritical) + self.lazyLoader.signals.finished.connect(self.lazyLoaderWorkerClosed) + + self.lazyLoaderThread.started.connect(self.lazyLoader.run) + self.lazyLoaderThread.start() + + def gui_createStoreStateWorker(self): + self.storeStateWorker = None + return + self.storeStateThread = QThread() + self.autoSaveMutex = QMutex() + self.autoSaveWaitCond = QWaitCondition() + + self.storeStateWorker = workers.StoreGuiStateWorker( + self.autoSaveMutex, self.autoSaveWaitCond + ) + + self.storeStateWorker.moveToThread(self.storeStateThread) + self.storeStateWorker.finished.connect(self.storeStateThread.quit) + self.storeStateWorker.finished.connect(self.storeStateWorker.deleteLater) + self.storeStateThread.finished.connect(self.storeStateThread.deleteLater) + + self.storeStateWorker.sigDone.connect(self.storeStateWorkerDone) + self.storeStateWorker.progress.connect(self.workerProgress) + self.storeStateWorker.finished.connect(self.storeStateWorkerClosed) + + self.storeStateThread.started.connect(self.storeStateWorker.run) + self.storeStateThread.start() + + self.logger.info("Store state worker started.") + + def lazyLoaderCritical(self, error): + if self.progressWin is not None: + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + self.lazyLoader.pause() + raise error + + def lazyLoaderFinished(self): + self.logger.info("Load chunk data worker done.") + if self.lazyLoader.updateImgOnFinished: + self.updateAllImages() + + if self.progressWin is not None: + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + + def lazyLoaderWorkerClosed(self): + if self.lazyLoader.salute: + self.logger.info("Cell-ACDC GUI closed.") + self.sigClosed.emit(self) + + self.lazyLoader = None + + def loadingNewChunk(self, chunk_range): + coord0_chunk, coord1_chunk = chunk_range + desc = f"Loading new window, range = ({coord0_chunk}, {coord1_chunk})..." + self.progressWin = apps.QDialogWorkerProgress( + title="Loading data...", parent=self, pbarDesc=desc + ) + self.progressWin.mainPbar.setMaximum(0) + self.progressWin.show(self.app) + + def relabelWorkerFinished(self): + self.updateAllImages() + + def saveDataWorkerCritical(self, error): + self.logger.warning("Saving process stopped because of critical error.") + self.saveWin.aborted = True + self.worker.finished.emit() + self.workerCritical(error) + + def savePreprocWorkerFinished(self): + if self.progressWin is not None: + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + + self.setStatusBarLabel() + self.logger.info("Pre-processed data saved!") + self.titleLabel.setText("Pre-processed data saved!", color="w") + + def startPostProcessSegmWorker( + self, + postProcessKwargs, + customPostProcessGroupedFeatures, + customPostProcessFeatures, + ): + self.thread = QThread() + self.postProcessWorker = workers.PostProcessSegmWorker( + postProcessKwargs, + customPostProcessGroupedFeatures, + customPostProcessFeatures, + self, + ) + + self.postProcessWorker.moveToThread(self.thread) + self.postProcessWorker.signals.finished.connect(self.thread.quit) + self.postProcessWorker.signals.finished.connect( + self.postProcessWorker.deleteLater + ) + self.thread.finished.connect(self.thread.deleteLater) + + self.postProcessWorker.signals.finished.connect( + self.postProcessSegmWorkerFinished + ) + self.postProcessWorker.signals.progress.connect(self.workerProgress) + self.postProcessWorker.signals.initProgressBar.connect( + self.workerInitProgressbar + ) + self.postProcessWorker.signals.progressBar.connect(self.workerUpdateProgressbar) + self.postProcessWorker.signals.critical.connect(self.workerCritical) + + self.thread.started.connect(self.postProcessWorker.run) + self.thread.start() + + def startRelabellingWorker(self, posFoldernames): + self.thread = QThread() + self.worker = workers.relabelSequentialWorker(self, posFoldernames) + self.worker.moveToThread(self.thread) + self.worker.finished.connect(self.thread.quit) + self.worker.finished.connect(self.worker.deleteLater) + self.thread.finished.connect(self.thread.deleteLater) + + self.worker.progress.connect(self.workerProgress) + self.worker.critical.connect(self.workerCritical) + self.worker.finished.connect(self.workerFinished) + self.worker.finished.connect(self.relabelWorkerFinished) + + self.worker.debug.connect(self.workerDebug) + + self.thread.started.connect(self.worker.run) + self.thread.start() + + def startTrackingWorker(self, posData, video_to_track): + self.thread = QThread() + self.trackingWorker = workers.trackingWorker(posData, self, video_to_track) + self.trackingWorker.moveToThread(self.thread) + self.trackingWorker.finished.connect(self.thread.quit) + self.trackingWorker.finished.connect(self.trackingWorker.deleteLater) + self.thread.finished.connect(self.thread.deleteLater) + + # Custom signals + self.trackingWorker.signals.progress = self.trackingWorker.progress + self.trackingWorker.signals.progressBar.connect(self.workerUpdateProgressbar) + self.trackingWorker.signals.initProgressBar.connect(self.workerInitProgressbar) + self.trackingWorker.signals.sigInitInnerPbar.connect(self.workerInitInnerPbar) + self.trackingWorker.progress.connect(self.workerProgress) + self.trackingWorker.critical.connect(self.workerCritical) + self.trackingWorker.finished.connect(self.trackingWorkerFinished) + + self.trackingWorker.debug.connect(self.workerDebug) + + self.thread.started.connect(self.trackingWorker.run) + self.thread.start() + + def storeStateWorkerClosed(self): + self.logger.info("Store state worker started.") + + def storeStateWorkerDone(self): + if self.storeStateWorker.callbackOnDone is not None: + self.storeStateWorker.callbackOnDone() + self.storeStateWorker.callbackOnDone = None + + def trackingWorkerFinished(self): + if self.progressWin is not None: + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + self.logger.info("Worker process ended.") + askDisableRealTimeTracking = ( + self.trackingWorker.trackingOnNeverVisitedFrames + and self.realTimeTrackingToggle.isChecked() + ) + if askDisableRealTimeTracking: + msg = widgets.myMessageBox() + title = "Disable real-time tracking?" + txt = ( + "You perfomed tracking on frames that you have " + "never visited.

    " + "Cell-ACDC default behaviour is to track them again when you " + "will visit them.

    " + "However, you can overwrite this behaviour and explicitly " + "disable tracking for all of the frames you already tracked.

    " + "NOTE: you can reactivate real-time tracking by clicking on the " + '"Reset last segmented frame" button on the top toolbar.

    ' + "What do you want me to do?" + ) + _, disableTrackingButton = msg.information( + self, + title, + html_utils.paragraph(txt), + buttonsTexts=( + "Keep real-time tracking active (recommended)", + "Disable real-time tracking", + ), + ) + if msg.clickedButton == disableTrackingButton: + self.logger.info("Disabling real time tracking...") + self.realTimeTrackingToggle.setChecked(False) + # posData = self.data[self.pos_i] + # current_frame_i = posData.frame_i + # for frame_i in range(self.start_n-1, self.stop_n): + # posData.frame_i = frame_i + # self.get_data() + # self.store_data(autosave=frame_i==self.stop_n-1) + # posData.last_tracked_i = frame_i + # self.setNavigateScrollBarMaximum() + + # # Back to current frame + # posData.frame_i = current_frame_i + # self.get_data() + posData = self.data[self.pos_i] + self.updateAllImages() + self.titleLabel.setText("Done", color="w") + + def workerCritical(self, out: Tuple[QObject, Exception]): + self.setDisabled(False) + try: + worker, error = out + except TypeError as err: + error = out + if self.progressWin is not None: + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + self.logger.info(error) + try: + worker.thread().quit() + worker.deleteLater() + worker.thread().deleteLater() + except Exception as err: + # Worker already closed + pass + raise error + + def workerDebug(self, item): + tracked_video, worker = item + from cellacdc.plot import imshow + + imshow(tracked_video) + worker.waitCond.wakeAll() + + def workerFinished(self): + if self.progressWin is not None: + self.progressWin.workerFinished = True + self.progressWin.close() + self.progressWin = None + self.logger.info("Worker process ended.") + self.updateAllImages() + self.titleLabel.setText("Done", color="w") + + def workerInitInnerPbar(self, totalIter): + self.progressWin.innerPbar.setValue(0) + if totalIter == 1: + totalIter = 0 + self.progressWin.innerPbar.setMaximum(totalIter) + + def workerInitProgressbar(self, totalIter): + self.progressWin.mainPbar.setValue(0) + if totalIter == 1: + totalIter = 0 + self.progressWin.mainPbar.setMaximum(totalIter) + + def workerLog(self, text): + self.logger.info(text) + + def workerProgress(self, text, loggerLevel="INFO"): # used in cca and lin tree + if self.progressWin is not None: + self.progressWin.logConsole.append(text) + self.logger.log(getattr(logging, loggerLevel), text) + + def workerUpdateInnerPbar(self, step): + self.progressWin.innerPbar.update(step) + + def workerUpdateProgressbar(self, step): + self.progressWin.mainPbar.update(step) diff --git a/cellacdc/models/Cellpose_germlineNuclei/__init__.py b/cellacdc/models/Cellpose_germlineNuclei/__init__.py deleted file mode 100644 index e1e15a14e..000000000 --- a/cellacdc/models/Cellpose_germlineNuclei/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from cellacdc import myutils - -myutils.check_install_cellpose() \ No newline at end of file diff --git a/cellacdc/models/Cellpose_germlineNuclei/acdcSegment.py b/cellacdc/models/Cellpose_germlineNuclei/acdcSegment.py deleted file mode 100644 index d5c5fb6cc..000000000 --- a/cellacdc/models/Cellpose_germlineNuclei/acdcSegment.py +++ /dev/null @@ -1,200 +0,0 @@ -import os -import numpy as np - -from skimage.measure import label as skiLabel -import math -import scipy -import scipy.ndimage - -import skimage.exposure -import skimage.filters -import skimage.measure - -from cellpose import models -from cellacdc import user_profile_path - -default_model_path = os.path.join( - user_profile_path, - 'acdc-Cellpose_germlineNuclei', - 'cellpose_germlineNuclei_2023' -) - -class Model: - def __init__( - self, - model_path: os.PathLike=default_model_path, - gpu=False - ): - self.model = models.CellposeModel( - gpu=gpu, diam_mean=30, pretrained_model=model_path - ) - - def setupLogger(self, logger): - models.models_logger = logger - - def setLoggerPropagation(self, propagate:bool): - models.models_logger.propagate = propagate - - def setLoggerLevel(self, level:str): - import logging - if level == 'error': - models.models_logger.setLevel(logging.ERROR) - - - def closeLogger(self): - handlers = models.models_logger.handlers[:] - for handler in handlers: - handler.close() - models.models_logger.removeHandler(handler) - - def _eval(self, image, **kwargs): - return self.model.eval(image.astype(np.float32), **kwargs)[0] - - def _initialize_image(self, image): - # See cellpose.gui.io._initialize_images - if image.ndim > 3: - # make tiff Z x channels x W x H - if image.shape[0]<4: - # tiff is channels x Z x W x H - image = np.transpose(image, (1,0,2,3)) - elif image.shape[-1]<4: - # tiff is Z x W x H x channels - image = np.transpose(image, (0,3,1,2)) - # fill in with blank channels to make 3 channels - if image.shape[1] < 3: - shape = image.shape - shape_to_concat = (shape[0], 3-shape[1], shape[2], shape[3]) - to_concat = np.zeros(shape_to_concat, dtype=np.uint8) - image = np.concatenate((image, to_concat), axis=1) - image = np.transpose(image, (0,2,3,1)) - elif image.ndim==3: - if image.shape[0] < 5: - image = np.transpose(image, (1,2,0)) - if image.shape[-1] < 3: - shape = image.shape - #if parent.autochannelbtn.isChecked(): - # image = normalize99(image) * 255 - shape_to_concat = (shape[0], shape[1], 3-shape[2]) - to_concat = np.zeros(shape_to_concat,dtype=type(image[0,0,0])) - image = np.concatenate((image, to_concat), axis=-1) - image = image[np.newaxis,...] - elif image.shape[-1]<5 and image.shape[-1]>2: - image = image[:,:,:3] - #if parent.autochannelbtn.isChecked(): - # image = normalize99(image) * 255 - image = image[np.newaxis,...] - else: - image = image[np.newaxis,...] - - if image.ndim < 4: - image = image[:,:,:,np.newaxis] - return image - - - def segment( - self, image, - diameter_um=3.5, - blurfactor=2.50, - PhysicalSizeZ = 1.0001, - PhysicalSizeY = 1.0001, - PhysicalSizeX = 1.0001, - cellprob_threshold=0.0, - clean_borders=False - ): - """ Cellpose model for C. elegans germline nuclei. This model works on a single channel only. - - Parameters - ---------- - diameter_um : float - Expected diameter of a nucleus in micrometer - blurfactor : float - Sigma value of the gaussian filter used for blurring of the data. - PhysicalSizeZ : float - Spacing of slices in z (unit: micrometer/slice). Prepopulated from image metadata - PhysicalSizeY : float - Pixelsize in y (unit: micrometer/pixel). Prepopulated from image metadata - PhysicalSizeX : float - Pixelsize in x (unit: micrometer/pixel). Prepopulated from image metadata - cellprob_threshold : float - cellprob_threshold for cellpose. - clean_borders : bool - Remove masks that touch the top or bottom slice in z, or that are closer than 2 pixels to the edges in x or y. - - Returns - ----- - np.ndarray - Instance segmentation array with the same shape as the input image. - """ - - - # Preprocess image - # image = image/image.max() - # image = skimage.filters.gaussian(image, sigma=1) - # image = skimage.exposure.equalize_adapthist(image) - zspacing = PhysicalSizeZ - xysize = np.mean([PhysicalSizeX, PhysicalSizeY]) - - isRGB = image.shape[-1] == 3 or image.shape[-1] == 4 - if isRGB: - raise TypeError( - "This model was trained for 1 channel only. Please specify a single channel (DNA or synaptonemal complex/axis staining). " - ) - - isZstack = (image.ndim==3 and not isRGB) or (image.ndim==4) - - anisotropy = math.ceil(abs(zspacing/xysize)) - pxScale=xysize*30/diameter_um - - - do_3D = True - - #if stitch_threshold > 0: - # do_3D = False - - - channels = [0,0] - - - - # Run cellpose eval - if not isZstack: - raise TypeError( - "This script is for 3D data (at least 5 slices) only. If needed, please modify the script to segment 2D data." - ) - else: - img_scaled=np.zeros((image.shape[0],round(image.shape[1]*pxScale),round(image.shape[2]*pxScale))) - img_blur=np.zeros((img_scaled.shape)) - image[image==0] = np.quantile(image[image>0],0.01) - - if pxScale > 1: - for i in range(image.shape[0]): - img_scaled[i,:,:] = scipy.ndimage.zoom(image[i,:,:],pxScale, order=3) - img_blur[i,:,:]=scipy.ndimage.gaussian_filter(img_scaled[i,:,:],blurfactor) - - else: - for i in range(image.shape[0]): - img_scaled[i,:,:] = scipy.ndimage.zoom(image[i,:,:],pxScale, order=3) - img_blur[i,:,:]=scipy.ndimage.gaussian_filter(img_scaled[i,:,:],blurfactor) - img_blur = self._initialize_image(img_blur) - labels_scaled, flows_blur, styles_blur = self.model.eval(img_blur.astype(np.uint16), - diameter=30, - channels=channels, do_3D=True, - anisotropy=anisotropy, - batch_size=3, - cellprob_threshold=cellprob_threshold) - - labels=np.zeros(image.shape,dtype=labels_scaled.dtype) - for i in range(image.shape[0]): - labels[i,:,:]=scipy.ndimage.zoom(labels_scaled[i,:,:],(image.shape[1]/labels_scaled.shape[1],image.shape[2]/labels_scaled.shape[2]),order=0) - - if clean_borders: - idx = np.unique(np.concatenate([np.unique(labels[-1,:,:][labels[-1,:,:]>0]),np.unique(labels[0,:,:][labels[0,:,:]>0]), - np.unique(labels[:,0:2,:][labels[:,0:2,:]>0]),np.unique(labels[:,-3:-1,:][labels[:,-3:-1,:]>0]), - np.unique(labels[:,:,0:2][labels[:,:,0:2]>0]),np.unique(labels[:,:,-3:-1][labels[:,:,-3:-1]>0]),])) - - labels[np.isin(labels,idx)] = 0 - - return labels - -def url_help(): - return 'https://cellpose.readthedocs.io/en/latest/api.html' diff --git a/cellacdc/models/InstanSeg/__init__.py b/cellacdc/models/InstanSeg/__init__.py deleted file mode 100644 index 39ed82738..000000000 --- a/cellacdc/models/InstanSeg/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from cellacdc import myutils - -myutils.check_install_instanseg() - -INSTANSEG_MODELS = ( - 'fluorescence_nuclei_and_cells', - 'brightfield_nuclei' -) - diff --git a/cellacdc/models/YeaZ/__init__.py b/cellacdc/models/YeaZ/__init__.py deleted file mode 100755 index f212dc81a..000000000 --- a/cellacdc/models/YeaZ/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from cellacdc import myutils - -myutils.check_install_package('tensorflow', max_version='2.17') diff --git a/cellacdc/models/YeaZ/unet/model.py b/cellacdc/models/YeaZ/unet/model.py deleted file mode 100755 index 587de0035..000000000 --- a/cellacdc/models/YeaZ/unet/model.py +++ /dev/null @@ -1,80 +0,0 @@ -""" -Source of the code: https://github.com/zhixuhao/unet -""" -# Turn off GPU access so can train and use the YeaZ-GUI -import os -os.environ["CUDA_VISIBLE_DEVICES"] = "-1" - -# Import tensorflow differently depending on version -from tensorflow import __version__ as tf_version -tf_version_old = int(tf_version[0]) <= 1 - -from tensorflow.keras.models import Model -from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, Dropout, - concatenate, UpSampling2D) -from tensorflow.keras.optimizers import Adam -#from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler - -if tf_version_old: - from tensorflow import ConfigProto - from tensorflow import InteractiveSession - -else: - from tensorflow.compat.v1 import ConfigProto - from tensorflow.compat.v1 import InteractiveSession - - -config = ConfigProto() -config.gpu_options.allow_growth = True -session = InteractiveSession(config=config) - -def unet(pretrained_weights = None,input_size = (256,256,1)): - inputs = Input(input_size) - conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs) - conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1) - pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) - conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1) - conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2) - pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) - conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2) - conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3) - pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) - conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3) - conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4) - drop4 = Dropout(0.5)(conv4) - pool4 = MaxPooling2D(pool_size=(2, 2))(drop4) - - conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4) - conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5) - drop5 = Dropout(0.5)(conv5) - - up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5)) - merge6 = concatenate([drop4,up6], axis = 3) - conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6) - conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6) - - up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6)) - merge7 = concatenate([conv3,up7], axis = 3) - conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7) - conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7) - - up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7)) - merge8 = concatenate([conv2,up8], axis = 3) - conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8) - conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8) - - up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) - merge9 = concatenate([conv1,up9], axis = 3) - conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9) - conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9) - conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9) - conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9) - - model = Model(inputs = inputs, outputs = conv10) - - model.compile(optimizer = Adam(learning_rate = 1e-4), loss = 'binary_crossentropy', metrics = ['accuracy']) - - if(pretrained_weights): - model.load_weights(pretrained_weights) - - return model diff --git a/cellacdc/models/YeaZ_v2/__init__.py b/cellacdc/models/YeaZ_v2/__init__.py deleted file mode 100644 index bb3071951..000000000 --- a/cellacdc/models/YeaZ_v2/__init__.py +++ /dev/null @@ -1,58 +0,0 @@ -import os - -from cellacdc import myutils, load - -myutils.check_install_yeaz() - -custom_weights_json_filename = 'custom_weights_name_filepath.json' - -def add_model_filepath(name: str, filepath: os.PathLike): - _, model_folderpath = myutils.get_model_path( - 'YeaZ_v2', create_temp_dir=False - ) - custom_weights_json_file = os.path.join( - model_folderpath, custom_weights_json_filename - ) - custom_weights_mapper = {} - if os.path.exists(custom_weights_json_file): - custom_weights_mapper = load.read_json( - custom_weights_json_file, - desc='YeaZ_v2 custom weights filepath info' - ) - - custom_weights_mapper[name] = filepath - load.write_json(custom_weights_mapper, custom_weights_json_file) - -def load_models_filepath(): - values = [ - 'Phase contrast', - 'Bright-field', - 'Fission yeast' - ] - mapper = { - 'Phase contrast': 'weights_budding_PhC_multilab_0_1', - 'Bright-field': 'weights_budding_BF_multilab_0_1', - 'Fission yeast': 'weights_fission_multilab_0_2' - } - _, model_folderpath = myutils.get_model_path( - 'YeaZ_v2', create_temp_dir=False - ) - mapper = { - name: os.path.join(model_folderpath, filename) - for name, filename in mapper.items() - } - - custom_weights_json_file = os.path.join( - model_folderpath, custom_weights_json_filename - ) - if not os.path.exists(custom_weights_json_file): - return values, mapper - - custom_weights_mapper = load.read_json( - custom_weights_json_file, - desc='YeaZ_v2 custom weights filepath info' - ) - values.extend(custom_weights_mapper.keys()) - mapper = {**mapper, **custom_weights_mapper} - - return values, mapper \ No newline at end of file diff --git a/cellacdc/models/__init__.py b/cellacdc/models/__init__.py old mode 100755 new mode 100644 index b9c59c7b0..8b1378917 --- a/cellacdc/models/__init__.py +++ b/cellacdc/models/__init__.py @@ -1,5 +1 @@ -STARDIST_MODELS = [ - '2D_versatile_fluo', - '2D_versatile_he', - '2D_paper_dsb2018' -] \ No newline at end of file + diff --git a/cellacdc/models/_cellpose_base/__init__.py b/cellacdc/models/_cellpose_base/__init__.py deleted file mode 100644 index 59bc26e30..000000000 --- a/cellacdc/models/_cellpose_base/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -min_target_versions_cp = { - '2': '2.3.2', - '3': '3.1.1.2', - '4': '4.0.6', -} \ No newline at end of file diff --git a/cellacdc/models/cellpose_v2/__init__.py b/cellacdc/models/cellpose_v2/__init__.py deleted file mode 100644 index e99ca1b01..000000000 --- a/cellacdc/models/cellpose_v2/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -import cellacdc.myutils as myutils -myutils.check_install_cellpose(2) - -class AvailableModelsv2: - from cellpose.models import MODEL_NAMES - values = MODEL_NAMES - - is_exclusive_with = ['model_path'] - default_exclusive = 'Using custom model' \ No newline at end of file diff --git a/cellacdc/models/cellpose_v3/__init__.py b/cellacdc/models/cellpose_v3/__init__.py deleted file mode 100644 index cca6ea45e..000000000 --- a/cellacdc/models/cellpose_v3/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -import cellacdc.myutils as myutils -myutils.check_install_cellpose(3) - -class AvailableModelsv3: - from cellpose.models import MODEL_NAMES - values = MODEL_NAMES - - is_exclusive_with = ['model_path'] - default_exclusive = 'Using custom model' - -class AvailableModelsv3Denoise: - from cellpose.denoise import MODEL_NAMES - values = MODEL_NAMES - - is_exclusive_with = ['denoise_model_path'] - default_exclusive = 'Using custom denoise model' \ No newline at end of file diff --git a/cellacdc/models/cellpose_v4/__init__.py b/cellacdc/models/cellpose_v4/__init__.py deleted file mode 100644 index 4a53003ca..000000000 --- a/cellacdc/models/cellpose_v4/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -import cellacdc.myutils as myutils -myutils.check_install_cellpose(4) - -class AvailableModelsv4: - from cellpose.models import MODEL_NAMES - values = MODEL_NAMES - - is_exclusive_with = ['model_path'] - default_exclusive = 'Using custom model' \ No newline at end of file diff --git a/cellacdc/models/delta/__init__.py b/cellacdc/models/delta/__init__.py deleted file mode 100644 index cd7539ea5..000000000 --- a/cellacdc/models/delta/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -""" -Installs delta2 into acdc. - -@author: jroberts / jamesr787 -""" - -from cellacdc import myutils - -myutils.check_install_package('delta', pypi_name='delta2') diff --git a/cellacdc/models/omnipose/__init__.py b/cellacdc/models/omnipose/__init__.py deleted file mode 100644 index b21279481..000000000 --- a/cellacdc/models/omnipose/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -import os -import sys -import subprocess - -from cellacdc import myutils - -myutils.check_install_omnipose() \ No newline at end of file diff --git a/cellacdc/models/omnipose_custom/__init__.py b/cellacdc/models/omnipose_custom/__init__.py deleted file mode 100644 index 7ff29a02c..000000000 --- a/cellacdc/models/omnipose_custom/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -import os -import sys -import subprocess - -from cellacdc import myutils - -myutils.check_install_package('omnipose_acdc') diff --git a/cellacdc/models/pomBseen/__init__.py b/cellacdc/models/pomBseen/__init__.py deleted file mode 100644 index 03dbe438a..000000000 --- a/cellacdc/models/pomBseen/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from cellacdc import myutils - -myutils.check_install_package('pombseen', pypi_name='pomBseen') diff --git a/cellacdc/models/pomBseen_nuclear/__init__.py b/cellacdc/models/pomBseen_nuclear/__init__.py deleted file mode 100644 index 03dbe438a..000000000 --- a/cellacdc/models/pomBseen_nuclear/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from cellacdc import myutils - -myutils.check_install_package('pombseen', pypi_name='pomBseen') diff --git a/cellacdc/models/sam2/__init__.py b/cellacdc/models/sam2/__init__.py deleted file mode 100644 index 65b4ab001..000000000 --- a/cellacdc/models/sam2/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -from cellacdc import myutils - -myutils.check_install_sam2() -import sam2 - -import os -from pathlib import Path - -# Get SAM2 models path -# Using the same pattern as segment_anything -_, sam_models_path = myutils.get_model_path('sam2', create_temp_dir=False) - -# SAM2 model configurations -# Format: 'Display Name': ('config_file', 'checkpoint_filename') -model_types = { - 'Large': ('configs/sam2.1/sam2.1_hiera_l.yaml', 'sam2.1_hiera_large.pt'), - 'Base Plus': ('configs/sam2.1/sam2.1_hiera_b+.yaml', 'sam2.1_hiera_base_plus.pt'), - 'Small': ('configs/sam2.1/sam2.1_hiera_s.yaml', 'sam2.1_hiera_small.pt'), - 'Tiny': ('configs/sam2.1/sam2.1_hiera_t.yaml', 'sam2.1_hiera_tiny.pt'), -} diff --git a/cellacdc/models/segment_anything/__init__.py b/cellacdc/models/segment_anything/__init__.py deleted file mode 100644 index 807d1a8ec..000000000 --- a/cellacdc/models/segment_anything/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from cellacdc import myutils - -myutils.check_install_segment_anything() - -import os -from cellacdc import segment_anything_weights_filenames - -_, sam_models_path = myutils.get_model_path('segment_anything', create_temp_dir=False) - -model_types = { - 'Large': ('default', segment_anything_weights_filenames[0]), - 'Medium': ('vit_l', segment_anything_weights_filenames[1]), - 'Small': ('vit_b', segment_anything_weights_filenames[2]) -} \ No newline at end of file diff --git a/cellacdc/myutils.py b/cellacdc/myutils.py deleted file mode 100644 index d04c49739..000000000 --- a/cellacdc/myutils.py +++ /dev/null @@ -1,5707 +0,0 @@ -import os -import re -import ast - -import typing -from typing import Literal, List, Callable, Tuple, Dict - -import pathlib -import difflib -import sys -import platform -import tempfile -import shutil -import traceback -import logging -import datetime -import time -import subprocess -import importlib -from uuid import uuid4 -from importlib import import_module -from math import pow, ceil, floor -from functools import wraps, partial -from collections import namedtuple, Counter -from tqdm import tqdm -import requests -import zipfile -import json -import numpy as np -import pandas as pd -import skimage -import inspect - -import traceback -import itertools -from packaging import version as packaging_version - -from natsort import natsorted - -import tifffile -import skimage.io -import skimage.measure - -from . import GUI_INSTALLED, KNOWN_EXTENSIONS, is_conda_env - -from . import core, load -from . import html_utils, is_linux, is_win, is_mac, issues_url, is_mac_arm64 -from . import cellacdc_path, printl, acdc_fiji_path, logs_path, acdc_ffmpeg_path -from . import user_profile_path, recentPaths_path -from . import models_list_file_path, models_path -from . import promptable_models_list_file_path, promptable_models_path -from . import github_home_url -from . import try_input_install_package -from . import _warnings -from . import urls -from . import qrc_resources_path -from . import settings_folderpath -from .models._cellpose_base import min_target_versions_cp - -if GUI_INSTALLED: - from qtpy.QtWidgets import QMessageBox - from qtpy.QtCore import Signal, QObject, QCoreApplication - - from . import widgets, apps - from . import config - -ArgSpec = namedtuple('ArgSpec', ['name', 'default', 'type', 'desc', 'docstring']) - -def get_module_name(script_file_path): - parts = pathlib.Path(script_file_path).parts - parts = list(parts[parts.index('cellacdc')+1:]) - parts[-1] = os.path.splitext(parts[-1])[0] - module = '.'.join(parts) - return module - -def get_pos_status_acdc(pos_path): - images_path = os.path.join(pos_path, 'Images') - ls = listdir(images_path) - for file in ls: - if file.endswith('acdc_output.csv'): - acdc_df_path = os.path.join(images_path, file) - break - else: - return '' - - acdc_df = pd.read_csv(acdc_df_path) - last_tracked_i = acdc_df['frame_i'].max() - last_cca_i = 0 - if 'cell_cycle_stage' in acdc_df.columns: - cca_df = acdc_df[['frame_i', 'cell_cycle_stage']].dropna() - last_cca_i = cca_df['frame_i'].max() - if last_cca_i > 0: - return ( - f' (last tracked frame = {last_tracked_i+1}, ' - f'last annotated frame = {last_cca_i+1})' - ) - else: - return f' (last tracked frame = {last_tracked_i+1})' - -def get_pos_status_spotmax(pos_path): - spotmax_out_path = os.path.join(pos_path, 'spotMAX_output') - is_smax_out_present = 'Yes' if os.path.exists(spotmax_out_path) else 'No' - if os.path.exists(spotmax_out_path): - return ' (SpotMAX output exists)' - else: - return '' - -def get_pos_status( - pos_path, - caller: Literal['Cell-ACDC', 'SpotMAX']='Cell-ACDC' - ): - if caller == 'Cell-ACDC': - return get_pos_status_acdc(pos_path) - - if caller == 'SpotMAX': - return get_pos_status_spotmax(pos_path) - -def get_gdrive_path(): - if is_win: - return os.path.join(f'G:{os.sep}', 'My Drive') - elif is_mac: - return os.path.join( - '/Users/francesco.padovani/Library/CloudStorage/' - 'GoogleDrive-padovaf@tcd.ie/My Drive' - ) - -def get_acdc_data_path(): - Cell_ACDC_path = os.path.dirname(cellacdc_path) - return os.path.join(Cell_ACDC_path, 'data') - -def get_open_filemaneger_os_string(): - if is_win: - return 'Show in Explorer...' - elif is_mac: - return 'Reveal in Finder...' - elif is_linux: - return 'Show in File Manager...' - -def filterCommonStart(images_path): - startNameLen = 6 - ls = listdir(images_path) - if not ls: - return [] - allFilesStartNames = [f[:startNameLen] for f in ls] - mostCommonStart = Counter(allFilesStartNames).most_common(1)[0][0] - commonStartFilenames = [f for f in ls if f.startswith(mostCommonStart)] - return commonStartFilenames - -def get_salute_string(): - time_now = datetime.datetime.now().time() - time_end_morning = datetime.time(12,00,00) - time_end_lunch = datetime.time(13,00,00) - time_end_afternoon = datetime.time(15,00,00) - time_end_evening = datetime.time(20,00,00) - time_end_night = datetime.time(4,00,00) - if time_now >= time_end_night and time_now < time_end_morning: - return 'Have a good day!' - elif time_now >= time_end_morning and time_now < time_end_lunch: - return 'Enjoy your lunch!' - elif time_now >= time_end_lunch and time_now < time_end_afternoon: - return 'Have a good afternoon!' - elif time_now >= time_end_afternoon and time_now < time_end_evening: - return 'Have a good evening!' - else: - return 'Have a good night!' - -def remove_known_extension(name): - for ext in KNOWN_EXTENSIONS: - if name.endswith(ext): - return name[:-len(ext)], ext - - return name, '' - -def getCustomAnnotTooltip(annotState): - toolTip = ( - f'Name: {annotState["name"]}\n\n' - f'Type: {annotState["type"]}\n\n' - f'Usage: activate the button and RIGHT-CLICK on cell to annotate\n\n' - f'Description: {annotState["description"]}\n\n' - f'SHORTCUT: "{annotState["shortcut"]}"' - ) - return toolTip - -def trim_path(path, depth=3, start_with_dots=True): - path_li = os.path.abspath(path).split(os.sep) - rel_path = f'{f"{os.sep}".join(path_li[-depth:])}' - if start_with_dots: - return f'...{os.sep}{rel_path}' - else: - return rel_path - -def get_add_custom_prompt_model_instructions(): - init_sh = html_utils.init_sh - segment_sh = html_utils.segment_sh - add_prompt_sh = html_utils.add_prompt_sh - href = f'here' - text = html_utils.paragraph(f""" - To use a custom prompt model, you need to create a Python file with the name - acdcPromptModel.py.
    - Note that the folder name where you place this file will be used as the - model name.

    - In this file, you will implement a class called Model with - at least the {init_sh} to initialise the model,
    - the {add_prompt_sh} method to add prompts (points, boxes, etc.) - to the model, and the {segment_sh} method to run the - segmentation.

    - Have a look at the existing models in the promptable_models - folder for examples.

    - If it doesn't work, please report the issue {href} with the - code you wrote. Thanks! - """) - return text - -def get_add_custom_model_instructions(): - user_manual_url = 'https://github.com/SchmollerLab/Cell_ACDC/blob/main/UserManual/Cell-ACDC_User_Manual.pdf' - href_user_manual = f'user manual' - href = f'here' - class_sh = html_utils.class_sh - def_sh = html_utils.def_sh - kwargs_sh = html_utils.kwargs_sh - Model_sh = html_utils.Model_sh - segment_sh = html_utils.segment_sh - predict_sh = html_utils.predict_sh - init_sh = html_utils.init_sh - myModel_sh = html_utils.myModel_sh - return_sh = html_utils.return_sh - equal_sh = html_utils.equal_sh - open_par_sh = html_utils.open_par_sh - close_par_sh = html_utils.close_par_sh - image_sh = html_utils.image_sh - from_sh = html_utils.from_sh - import_sh = html_utils.import_sh - s = html_utils.paragraph(f""" - To use a custom model first create a folder with the name of your model.

    - Inside this new folder create a file named acdcSegment.py.

    - In the acdcSegment.py file you will implement the model class.

    - Have a look at the other existing models, but essentially you have to create - a class called Model with at least
    - the {init_sh} and the {segment_sh} method.

    - The {segment_sh} method takes the image (2D or 3D) as an input and return the segmentation mask.

    - You can find more details in the {href_user_manual} at the section - called Adding segmentation models to the pipeline.

    - Pseudo-code for the acdcSegment.py file: -

    
    -    {from_sh} myModel {import_sh} {myModel_sh}
    -
    -    {class_sh} {Model_sh}:
    -        {def_sh} {init_sh}(self, {kwargs_sh}):
    -            self.model {equal_sh} {myModel_sh}{open_par_sh}{close_par_sh}
    -
    -        {def_sh} {segment_sh}(self, {image_sh}, {kwargs_sh}):
    -            labels {equal_sh} self.model.{predict_sh}{open_par_sh}{image_sh}{close_par_sh}
    -            {return_sh} labels
    -    
    - - If it doesn't work, please report the issue {href} with the - code you wrote. Thanks. - """) - return s - -def is_iterable(item): - try: - iter(item) - return True - except TypeError as e: - return False - -class utilClass: - pass - -def get_trimmed_list(li: list, max_num_digits=10): - if len(li) == 0: - return '[]' - - tom_num_digits = sum([len(str(val)) for val in li]) - - if tom_num_digits == 0: - return f"[{', '.join(map(str, li))}]" - - avg_num_digits = tom_num_digits/len(li) - max_num_vals = int(round(max_num_digits/avg_num_digits)) - - if tom_num_digits > max_num_digits: - front_vals = ceil(max_num_vals / 2) - back_vals = max_num_vals // 2 - - if front_vals + back_vals >= len(li): - return f"[{', '.join(map(str, li))}]" - - li = li[:front_vals] + ['...'] + li[len(li) - back_vals:] - - return f"[{', '.join(map(str, li))}]" - -def get_trimmed_dict(di: dict, max_num_digits=10): - di_str = di.copy() - total_num_digits = sum([len(str(key)) + len(str(val)) for key, val in di.items()]) - avg_num_digits = total_num_digits / len(di) - max_num_vals = int(round(max_num_digits / avg_num_digits)) - if total_num_digits > max_num_digits: - keys = list(di_str.keys()) - for key in keys[max_num_vals:-max_num_vals]: - del di_str[key] - di_str[keys[max_num_vals]] = "..." - return f"[{', '.join([f'{key} -> {val}' for key, val in di_str.items()])}]" - -def checked_reset_index(df): - if df.index.names is None or df.index.names == [None]: - return df.reset_index(drop=True) - else: - return df.reset_index() - -def checked_reset_index_Cell_ID(df): - if df.index.names == ['Cell_ID']: - return df - df = checked_reset_index(df) - return df.set_index('Cell_ID') - - -def _bytes_to_MB(size_bytes): - factor = pow(2, -20) - size_MB = round(size_bytes*factor) - return size_MB - -def _bytes_to_GB(size_bytes): - factor = pow(2, -30) - size_GB = round(size_bytes*factor, 2) - return size_GB - -def getMemoryFootprint(files_list): - required_memory = sum([ - 48 if file.endswith('.h5') else os.path.getsize(file) - for file in files_list - ]) - return required_memory - -def get_logs_path(): - return logs_path - -class Logger(logging.Logger): - def __init__( - self, - module='base', - name='cellacdc-logger', - level=logging.DEBUG - ): - super().__init__(f'{name}-{module}', level=level) - self._stdout = sys.stdout - self._stderr = StdErr(logger=self) - sys.stderr = self._stderr - self._levelToName = { - 50: "CRITICAL", - 40: "ERROR", - 30: "WARNING", - 20: "INFO", - 10: "DEBUG", - 0: "NOTSET" - } - - def write(self, text, log_to_file=True, write_to_stdout=True): - """Capture print statements, print to terminal and log text to - the open log file - - Parameters - ---------- - text : str - Text to log - log_to_file : bool, optional - If True, call `info` method with `text`. Default is True - """ - if write_to_stdout: - self._stdout.write(text) - - if not log_to_file: - return - - if text == '\n': - return - - if not text: - return - - self.debug(text) - - def close(self): - for handler in self.handlers: - handler.close() - self.removeHandler(handler) - sys.stdout = self._stdout - self._stderr.close() - - def __del__(self): - sys.stdout = self._stdout - self._stderr.close() - - def info(self, text, *args, **kwargs): - super().info(text, *args, **kwargs) - try: - self.write(f'{text}\n', log_to_file=False) - except TypeError: - # Sometimes the logger is patched (e.g., by spotiflow), which - # triggers the TypeError because the patching function does not have - # log_to_file argument - self.write(f'{text}\n') - - def warning(self, text, *args, **kwargs): - super().warning(text, *args, **kwargs) - try: - self.write(f'[WARNING]: {text}\n', log_to_file=False) - except TypeError: - # Sometimes the logger is patched (e.g., by spotiflow), which - # triggers the TypeError because the patching function does not have - # log_to_file argument - self.write(f'[WARNING]: {text}\n') - - def error(self, text, *args, write_traceback=True, **kwargs): - super().error(text, *args, **kwargs) - self.write(traceback.format_exc()) - try: - self.write(f'[ERROR]: {text}\n', log_to_file=False) - except TypeError: - # Sometimes the logger is patched (e.g., by spotiflow), which - # triggers the TypeError because the patching function does not have - # log_to_file argument - self.write(f'[ERROR]: {text}\n') - - def plain(self, text, write_to_stdout=False): - orig_formatters = [handler.formatter for handler in self.handlers] - for handler in self.handlers: - handler.setFormatter(logging.Formatter('%(message)s')) - self.write(text, write_to_stdout=write_to_stdout) - for handler in self.handlers: - handler.setFormatter(orig_formatters.pop(0)) - - def critical(self, text, *args, **kwargs): - super().critical(text, *args, **kwargs) - try: - self.write(f'[CRITICAL]: {text}\n', log_to_file=False) - except TypeError: - # Sometimes the logger is patched (e.g., by spotiflow), which - # triggers the TypeError because the patching function does not have - # log_to_file argument - self.write(f'[CRITICAL]: {text}\n') - - def exception(self, text, *args, write_traceback=True, **kwargs): - super().exception(text, *args, **kwargs) - self.write(traceback.format_exc()) - try: - self.write(f'[ERROR]: {text}\n', log_to_file=False) - except TypeError: - # Sometimes the logger is patched (e.g., by spotiflow), which - # triggers the TypeError because the patching function does not have - # log_to_file argument - self.write(f'[ERROR]: {text}\n') - - def log(self, level, text): - if not isinstance(level, int): - printl(level, text, type(level), type(text), sep='\n') - super().log(level, text) - levelName = self._levelToName.get(level, 'INFO') - getattr(self, levelName.lower())(text) - - def flush(self): - self._stdout.flush() - -class StdErr: - def __init__(self, logger: Logger=None): - self._sys_stderr = sys.stderr - self._err_msg_line_buffer = [] - self._logger = logger - - def write(self, text: str): - if text.startswith('Traceback'): - print('-'*100) - - self._sys_stderr.write(text) - - if not text: - return - - self._err_msg_line_buffer.append(text) - if not text.endswith('\n'): - return - - # If the line ends with a newline, flush the buffer - err_line = ''.join(self._err_msg_line_buffer) - if self._logger is not None: - self._logger.plain(err_line, write_to_stdout=False) - else: - print(err_line) - - self._err_msg_line_buffer = [] - - def flush(self): - self._sys_stderr.flush() - - def close(self): - """Close the StdErr stream""" - sys.stderr = self._sys_stderr - -def delete_older_log_files(logs_path): - if not os.path.exists(logs_path): - return - - log_files = os.listdir(logs_path) - for log_file in log_files: - if not log_file.endswith('.log'): - continue - - log_filepath = os.path.join(logs_path, log_file) - try: - mtime = os.path.getmtime(log_filepath) - except Exception as err: - continue - - mdatetime = datetime.datetime.fromtimestamp(mtime) - days = (datetime.datetime.now() - mdatetime).days - if days < 7: - continue - - try: - os.remove(log_filepath) - except Exception as err: - continue - -def get_info_version_text(is_cli=False, cli_formatted_text=True): - version = read_version() - release_date = get_date_from_version(version, package='cellacdc') - py_ver = sys.version_info - env_folderpath = sys.prefix - python_version = f'{py_ver.major}.{py_ver.minor}.{py_ver.micro}' - info_txts = [ - f'Version {version}', - f'Released on: {release_date}', - f'Installed in "{cellacdc_path}"', - f'Environment folder: "{env_folderpath}"', - f'User profile folder: "{user_profile_path}"', - f'Settings folder: "{settings_folderpath}"', - f'Python {python_version}', - f'Platform: {platform.platform()}', - f'System: {platform.system()}', - ] - if is_linux: - try: - distro_name = get_linux_distribution_name() - except Exception as err: - distro_name = 'Undetermined' - - info_txts.append(f'Linux distribution: {distro_name}') - - if GUI_INSTALLED and not is_cli: - info_txts.append(f'Icons from: "{qrc_resources_path}"') - try: - from qtpy import QtCore - info_txts.append(f'Qt {QtCore.__version__}') - except Exception as err: - info_txts.append('Qt: Not installed') - - try: - branch_name = get_git_branch_name() - info_txts.append(f'Git branch: "{branch_name}"') - except Exception as err: - pass - - info_txts.append(f'Working directory: {os.getcwd()}') - - if not cli_formatted_text: - return info_txts - - info_txts = [f' - {txt}' for txt in info_txts] - - max_len = max([len(txt) for txt in info_txts]) + 2 - - formatted_info_txts = [] - for txt in info_txts: - horiz_spacing = ' '*(max_len - len(txt)) - txt = f'{txt}{horiz_spacing}|' - formatted_info_txts.append(txt) - - formatted_info_txts.insert(0, 'Cell-ACDC info:\n') - formatted_info_txts.insert(0, '='*max_len) - formatted_info_txts.append('='*max_len) - info_txt = '\n'.join(formatted_info_txts) - - try: - from spotmax.utils import get_info_version_text as smax_info - smax_info_txt = smax_info(include_platform=False, is_cli=is_cli) - info_txt += '\n\n' + smax_info_txt - except ImportError: - pass - - return info_txt - -def _log_system_info(logger, log_path, is_cli=False, also_spotmax=False): - logger.info(f'Initialized log file "{log_path}"') - - info_txt = get_info_version_text(is_cli=is_cli) - - logger.info(info_txt) - - if not also_spotmax: - return - - from spotmax.utils import get_info_version_text as smax_info - smax_info_txt = smax_info(include_platform=False) - logger.info(smax_info_txt) - -def setupLogger(module='base', logs_path=None, caller='Cell-ACDC'): - if logs_path is None: - logs_path = get_logs_path() - - logger = Logger(module=module) - sys.stdout = logger - - delete_older_log_files(logs_path) - if not os.path.exists(logs_path): - os.mkdir(logs_path) - - date_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') - id = uuid4() - log_filename = f'{date_time}_{module}_{id}_stdout.log' - log_path = os.path.join(logs_path, log_filename) - - output_file_handler = logging.FileHandler(log_path, mode='w') - - # Format your logs (optional) - formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s:\n' - '------------------------\n' - '%(message)s\n' - '------------------------\n', - datefmt='%d-%m-%Y, %H:%M:%S' - ) - output_file_handler.setFormatter(formatter) - - logger.addHandler(output_file_handler) - - _log_system_info(logger, log_path, also_spotmax=caller!='Cell-ACDC') - - # if module == 'gui' and GUI_INSTALLED: - # qt_handler = widgets.QtHandler() - # qt_handler.setFormatter(logging.Formatter("%(message)s")) - # logger.addHandler(qt_handler) - - return logger, logs_path, log_path, log_filename - -def get_pos_foldernames(exp_path, check_if_is_sub_folder=False): - if not check_if_is_sub_folder: - ls = listdir(exp_path) - pos_foldernames = [ - pos for pos in ls if is_pos_folderpath(os.path.join(exp_path, pos)) - ] - else: - folder_type = determine_folder_type(exp_path) - is_pos_folder, is_images_folder, _ = folder_type - if is_pos_folder: - return [os.path.basename(exp_path)] - elif is_images_folder: - pos_path = os.path.dirname(exp_path) - if is_pos_folderpath(pos_path): - return [os.path.basename(pos_path)] - else: - return [] - else: - return get_pos_foldernames(exp_path) - return pos_foldernames - -def get_images_folderpath(folderpath): - if os.path.isfile(folderpath): - folderpath = os.path.dirname(folderpath) - - if folderpath.endswith('Images'): - return folderpath - - images_folderpath = os.path.join(folderpath, 'Images') - if os.path.exists(images_folderpath): - return images_folderpath - - return '' - -def getMostRecentPath(): - if os.path.exists(recentPaths_path): - df = pd.read_csv(recentPaths_path, index_col='index') - if 'opened_last_on' in df.columns: - df = df.sort_values('opened_last_on', ascending=False) - MostRecentPath = '' - for path in df['path']: - if os.path.exists(path): - MostRecentPath = path - break - else: - MostRecentPath = '' - return MostRecentPath - -def addToRecentPaths(exp_path, logger=None): - if not os.path.exists(exp_path): - return - exp_path = exp_path.replace('\\', '/') - if os.path.exists(recentPaths_path): - try: - df = pd.read_csv(recentPaths_path, index_col='index') - recentPaths = df['path'].to_list() - if 'opened_last_on' in df.columns: - openedOn = df['opened_last_on'].to_list() - else: - openedOn = [np.nan]*len(recentPaths) - if exp_path in recentPaths: - pop_idx = recentPaths.index(exp_path) - recentPaths.pop(pop_idx) - openedOn.pop(pop_idx) - recentPaths.insert(0, exp_path) - openedOn.insert(0, datetime.datetime.now()) - # Keep max 40 recent paths - if len(recentPaths) > 40: - recentPaths.pop(-1) - openedOn.pop(-1) - except Exception as e: - recentPaths = [exp_path] - openedOn = [datetime.datetime.now()] - else: - recentPaths = [exp_path] - openedOn = [datetime.datetime.now()] - df = pd.DataFrame({ - 'path': recentPaths, - 'opened_last_on': pd.Series(openedOn, dtype='datetime64[ns]')} - ) - df.index.name = 'index' - df.to_csv(recentPaths_path) - -def checkDataIntegrity(filenames, parent_path, parentQWidget=None): - if not filenames: - msg = widgets.myMessageBox(wrapText=False) - txt = html_utils.paragraph( - 'Cell-ACDC could not find any files in the folder ' - f'{parent_path}.

    ' - 'Please make sure that the folder contains at least one image file.

    ' - 'Thank you for your patience!' - ) - msg.warning(parentQWidget, 'Selected folder is emppty', txt) - raise FileNotFoundError( - f'No files found in the folder {parent_path}. ' - ) - - char = filenames[0][:2] - startWithSameChar = all([f.startswith(char) for f in filenames]) - if not startWithSameChar: - msg = widgets.myMessageBox() - txt = html_utils.paragraph( - 'Cell-ACDC detected files inside the folder ' - 'that do not start with the same, common basename.

    ' - 'To ensure correct loading of the data, the folder where ' - 'the file(s) is/are should either contain a single image file or' - 'only files that start with the same, common basename.

    ' - 'For example the following filenames:

    ' - 'F014_s01_phase_contr.tif
    ' - 'F014_s01_mCitrine.tif

    ' - 'are named correctly since they all start with the ' - 'the common basename "F014_s01_". After the common basename you ' - 'can write whatever text you want. In the example above, "phase_contr" ' - 'and "mCitrine" are the channel names.

    ' - 'Data loading may still be successfull, so Cell-ACDC will ' - 'still try to load data now.
    ' - ) - filesFormat = [f' - {file}' for file in filenames] - filesFormat = "\n".join(filesFormat) - detailsText = ( - f'Files present in the folder {parent_path}:\n\n' - f'{filesFormat}' - ) - msg.addShowInFileManagerButton(parent_path, txt='Open folder...') - msg.warning( - parentQWidget, 'Data structure compromised', txt, - detailsText=detailsText, buttonsTexts=('Cancel', 'Ok') - ) - if msg.cancel: - raise TypeError( - 'Process aborted by the user.' - ) - return False - return True - -def get_cca_colname_desc(): - desc = { - 'Cell ID': ( - 'ID of the segmented cell. All of the other columns ' - 'are properties of this ID.' - ), - 'Cell cycle stage': ( - 'G1 if the cell does NOT have a bud. S/G2/M if it does.' - ), - 'Relative ID': ( - 'ID of the bud related to the Cell ID (row). For cells in G1 write the ' - 'bud ID it had in the previous cycle.' - ), - 'Generation number': ( - 'Number of times the cell divided from a bud. For cells in the first ' - 'frame write any number greater than 1.' - ), - 'Relationship': ( - 'Relationship of the current Cell ID (row). ' - 'Either mother or bud. An object is a bud if ' - 'it didn\'t divide from the mother yet. All other instances ' - '(e.g., cell in G1) are still labelled as mother.' - ), - 'Emerging frame num.': ( - 'Frame number at which the object emerged/appeared in the scene.' - ), - 'Division frame num.': ( - 'Frame number at which the bud separated from the mother.' - ), - 'Is history known?': ( - 'Cells that are already present in the first frame or appears ' - 'from outside of the field of view, have some information missing. ' - 'For example, for cells in the first frame we do not know how many ' - 'times it budded and divided in the past. ' - 'In these cases Is history known? is True.' - ) - } - return desc - -def testQcoreApp(): - print(QCoreApplication.instance()) - -def store_custom_model_path(model_file_path): - model_file_path = model_file_path.replace('\\', '/') - model_name = os.path.basename(os.path.dirname(model_file_path)) - cp = config.ConfigParser() - if os.path.exists(models_list_file_path): - cp.read(models_list_file_path) - if model_name not in cp: - cp[model_name] = {} - cp[model_name]['path'] = model_file_path - with open(models_list_file_path, 'w') as configFile: - cp.write(configFile) - -def store_custom_promptable_model_path(promptable_model_file_path): - model_file_path = promptable_model_file_path.replace('\\', '/') - model_name = os.path.basename(os.path.dirname(model_file_path)) - cp = config.ConfigParser() - if os.path.exists(promptable_models_list_file_path): - cp.read(promptable_models_list_file_path) - if model_name not in cp: - cp[model_name] = {} - cp[model_name]['path'] = model_file_path - with open(promptable_models_list_file_path, 'w') as configFile: - cp.write(configFile) - -def check_git_installed(parent=None): - try: - subprocess.check_call(['git', '--version'], shell=True) - return True - except Exception as e: - print('='*20) - traceback.print_exc() - print('='*20) - git_url = 'https://git-scm.com/book/en/v2/Getting-Started-Installing-Git' - msg = widgets.myMessageBox() - txt = html_utils.paragraph(f""" - In order to install javabridge you first need to install - Git (it was not found).

    - Close Cell-ACDC and follow the instructions - {html_utils.tag('here', f'a href="{git_url}"')}.

    - NOTE: After installing Git you might need to restart the - terminal. - """) - msg.warning( - parent, 'Git not installed', txt - ) - return False - -def browse_url(url): - import webbrowser - webbrowser.open(url) - -def browse_docs(): - browse_url(urls.docs_homepage) - -def install_java(): - try: - subprocess.check_call(['javac', '-version'], shell=True) - return False - except Exception as e: - from . import widgets - win = widgets.installJavaDialog() - win.exec_() - return win.clickedButton == win.cancelButton - -def install_javabridge(force_compile=False, attempt_uninstall_first=False): - if attempt_uninstall_first: - try: - subprocess.check_call( - [sys.executable, '-m', 'pip', 'uninstall', '-y', 'javabridge'] - ) - except Exception as e: - pass - if sys.platform.startswith('win'): - if force_compile: - subprocess.check_call( - [sys.executable, '-m', 'pip', 'install', '-U', - 'git+https://github.com/SchmollerLab/python-javabridge-acdc'] - ) - else: - subprocess.check_call( - [sys.executable, '-m', 'pip', 'install', '-U', - 'git+https://github.com/SchmollerLab/python-javabridge-windows'] - ) - elif is_mac: - subprocess.check_call( - [sys.executable, '-m', 'pip', 'install', '-U', - 'git+https://github.com/SchmollerLab/python-javabridge-acdc'] - ) - elif is_linux: - subprocess.check_call( - [sys.executable, '-m', 'pip', 'install', '-U', - 'git+https://github.com/LeeKamentsky/python-javabridge.git@master'] - ) - -def is_in_bounds(x,y,X,Y): - in_bounds = x >= 0 and x < X and y >= 0 and y < Y - return in_bounds - -def read_version(logger=None, return_success=False): - cellacdc_parent_path = os.path.dirname(cellacdc_path) - cellacdc_parent_folder = os.path.basename(cellacdc_parent_path) - if cellacdc_parent_folder == 'site-packages': - from . import __version__ - version = __version__ - success = True - else: - try: - from setuptools_scm import get_version - version = get_version(root='..', relative_to=__file__) - success = True - except Exception as e: - if logger is None: - logger = print - logger('*'*40) - logger(traceback.format_exc()) - logger('-'*40) - logger( - '[WARNING]: Cell-ACDC could not determine the current version. ' - 'Returning the version determined at installation time. ' - 'See details above.' - ) - logger('='*40) - try: - from . import _version - version = _version.version - success = False - except Exception as e: - version = 'ND' - success = False - - if return_success: - return version, success - else: - return version - -def get_date_from_version(version: str, package='cellacdc', debug=False): - try: - response = requests.get( - f'https://pypi.org/pypi/{package}/json', - timeout=2 - ) - res_json = response.json() - pypi_releases_json = res_json['releases'] - version_json = pypi_releases_json[version][0] - upload_time = version_json['upload_time_iso_8601'] - date = datetime.datetime.strptime( - upload_time, r'%Y-%m-%dT%H:%M:%S.%fZ' - ) - date_str = date.strftime(r'%A %d %B %Y at %H:%M') - return date_str - except Exception as err: - if debug: - traceback.print_exc() - - try: - # Locate the direct_url.json file for the package - # installed with pip git+ - dist = importlib.metadata.distribution(package) - dist_info_dir = dist._path # internal path to .dist-info - direct_url_path = os.path.join(dist_info_dir, "direct_url.json") - - with open(direct_url_path) as f: - data = json.load(f) - - vcs_info = data["vcs_info"] - commit_id = vcs_info.get("commit_id") - url = data.get("url") - - parts = url.split("github.com/")[1].split(".git")[0] - owner, repo = parts.split("/", 1) - - # Query GitHub API for commit date - api_url = ( - f"https://api.github.com/repos/{owner}/{repo}/commits/{commit_id}" - ) - response = requests.get(api_url) - response.raise_for_status() - - commit_data = response.json() - date_utc = commit_data["commit"]["committer"]["date"] - - date_str = format_commit_date_utc(date_utc) - - return date_str - except Exception as err: - if debug: - traceback.print_exc() - - try: - if package == 'cellacdc': - pkg_path = cellacdc_path - elif package == 'spotmax': - from spotmax import spotmax_path - pkg_path = spotmax_path - commit_hash = re.findall(r'\+g([A-Za-z0-9]+)(\.d)?', version)[0][0] - git_path = os.path.dirname(pkg_path) - command = f'git -C {git_path} show {commit_hash}' - commit_log = _subprocess_run_command( - command, shell=False, callback='check_output' - ) - commit_log = commit_log.decode() - date_log = re.findall(r'Date:(.*) \+', commit_log)[0].strip() - date = datetime.datetime.strptime(date_log, r'%a %b %d %H:%M:%S %Y') - date_str = date.strftime(r'%A %d %B %Y at %H:%M') - return date_str - except Exception as err: - if debug: - traceback.print_exc() - - return 'ND' - -def get_git_branch_name(): - command = 'git rev-parse --abbrev-ref HEAD' - output = _subprocess_run_command( - command, shell=False, callback='check_output' - ) - branch_name = output.decode().strip() - return branch_name - -def showInExplorer(path): - if is_mac: - os.system(f'open "{path}"') - elif is_linux: - os.system(f'xdg-open "{path}"') - else: - os.startfile(path) - -def exec_time(func): - @wraps(func) - def inner_function(self, *args, **kwargs): - t0 = time.perf_counter() - if func.__code__.co_argcount==1 and func.__defaults__ is None: - result = func(self) - elif func.__code__.co_argcount>1 and func.__defaults__ is None: - result = func(self, *args) - else: - result = func(self, *args, **kwargs) - t1 = time.perf_counter() - s = f'{func.__name__} execution time = {(t1-t0)*1000:.3f} ms' - printl(s, is_decorator=True) - return result - return inner_function - -def setRetainSizePolicy(widget, retain=True): - sp = widget.sizePolicy() - sp.setRetainSizeWhenHidden(retain) - widget.setSizePolicy(sp) - -def getAcdcDfSegmPaths(images_path): - ls = listdir(images_path) - basename = getBasename(ls) - paths = {} - for file in ls: - filePath = os.path.join(images_path, file) - fileName, ext = os.path.splitext(file) - endName = fileName[len(basename):] - if endName.find('acdc_output') != -1 and ext=='.csv': - info_name = endName.replace('acdc_output', '') - paths.setdefault(info_name, {}) - paths[info_name]['acdc_df_path'] = filePath - paths[info_name]['acdc_df_filename'] = fileName - elif endName.find('segm') != -1 and ext=='.npz': - info_name = endName.replace('segm', '') - paths.setdefault(info_name, {}) - paths[info_name]['segm_path'] = filePath - paths[info_name]['segm_filename'] = fileName - return paths - -def getChannelFilePath(images_path, chName): - file = '' - alignedFilePath = '' - tifFilePath = '' - h5FilePath = '' - for file in listdir(images_path): - filePath = os.path.join(images_path, file) - if file.endswith(f'{chName}_aligned.npz'): - alignedFilePath = filePath - elif file.endswith(f'{chName}.tif'): - tifFilePath = filePath - elif file.endswith(f'{chName}.h5'): - h5FilePath = filePath - if alignedFilePath: - return alignedFilePath - elif h5FilePath: - return h5FilePath - elif tifFilePath: - return tifFilePath - else: - return '' - -def get_number_fstring_formatter(dtype, precision=4): - if np.issubdtype(dtype, np.integer): - return 'd' - else: - return f'.{precision}f' - -def get_chname_from_basename(filename, basename, remove_ext=True): - if remove_ext: - filename, ext = os.path.splitext(filename) - chName = filename[len(basename):] - aligned_idx = chName.find('_aligned') - if aligned_idx != -1: - chName = chName[:aligned_idx] - return chName - -def getBaseAcdcDf(rp): - zeros_list = [0]*len(rp) - nones_list = [None]*len(rp) - minus1_list = [-1]*len(rp) - IDs = [] - xx_centroid = [] - yy_centroid = [] - zz_centroid = [] - for obj in rp: - xc, yc = obj.centroid[-2:] - IDs.append(obj.label) - xx_centroid.append(xc) - yy_centroid.append(yc) - if len(obj.centroid) == 3: - zc = obj.centroid[0] - zz_centroid.append(zc) - - df = pd.DataFrame( - { - 'Cell_ID': IDs, - 'is_cell_dead': zeros_list, - 'is_cell_excluded': zeros_list, - 'x_centroid': xx_centroid, - 'y_centroid': yy_centroid, - 'was_manually_edited': minus1_list - } - ).set_index('Cell_ID') - if zz_centroid: - df['z_centroid'] = zz_centroid - - return df - -def getBasenameAndChNames(images_path, useExt=None): - _tempPosData = utilClass() - _tempPosData.images_path = images_path - load.loadData.getBasenameAndChNames(_tempPosData, useExt=useExt) - return _tempPosData.basename, _tempPosData.chNames - -def getBasename(files): - basename = files[0] - for file in files: - # Determine the basename based on intersection of all files - _, ext = os.path.splitext(file) - sm = difflib.SequenceMatcher(None, file, basename) - i, j, k = sm.find_longest_match( - 0, len(file), 0, len(basename) - ) - basename = file[i:i+k] - return basename - -def findalliter(patter, string): - """Function used to return all re.findall objects in string""" - m_test = re.findall(r'(\d+)_(.+)', string) - m_iter = [m_test] - while m_test: - m_test = re.findall(r'(\d+)_(.+)', m_test[0][1]) - m_iter.append(m_test) - return m_iter - -def clipSelemMask(mask, shape, Yc, Xc, copy=True): - if copy: - mask = mask.copy() - - Y, X = shape - h, w = mask.shape - - # Bottom, Left, Top, Right global coordinates of mask - Y0, X0, Y1, X1 = Yc-(h/2), Xc-(w/2), Yc+(h/2), Xc+(w/2) - mask_limits = [floor(Y0)+1, floor(X0)+1, floor(Y1)+1, floor(X1)+1] - - if Y0>=0 and X0>=0 and Y1<=Y and X1<=X: - # Mask is withing shape boundaries, no need to clip - ystart, xstart, yend, xend = mask_limits - mask_slice = slice(ystart, yend), slice(xstart, xend) - return mask, mask_slice - - if Y0<0: - # Mask is exceeding at the bottom - ystart = floor(abs(Y0)) - mask_limits[0] = 0 - mask = mask[ystart:] - if X0<0: - # Mask is exceeding at the left - xstart = floor(abs(X0)) - mask_limits[1] = 0 - mask = mask[:, xstart:] - if Y1>Y: - # Mask is exceeding at the top - yend = ceil(abs(Y1)) - Y - mask_limits[2] = Y - mask = mask[:-yend] - if X1>X: - # Mask is exceeding at the right - xend = ceil(abs(X1)) - X - mask_limits[3] = X - mask = mask[:, :-xend] - - ystart, xstart, yend, xend = mask_limits - mask_slice = slice(ystart, yend), slice(xstart, xend) - return mask, mask_slice - - -def listdir(path) -> List[str]: - return natsorted([ - f for f in os.listdir(path) - if not f.startswith('.') - and not f == 'desktop.ini' - and not f == 'recovery' - and not f.endswith('.new.npz') - ]) - -def setDefaultValueArgSpecsFromKwargs( - params: List[ArgSpec], - kwargs: Dict[str, object] - ): - new_params = [] - for param in params: - new_value = kwargs.get(param.name) - if new_value is None: - new_params.append(param) - continue - - new_param = ArgSpec( - name=param.name, - default=new_value, - type=param.type, - desc=param.desc, - docstring=param.docstring - ) - new_params.append(new_param) - return new_params - -def insertModelArgSpec( - params, param_name, param_value, param_type=None, desc='', - docstring='' - ): - updated_params = [] - for param in params: - if param.name == param_name: - if param_type is None: - param_type = param.type - new_param = ArgSpec( - name=param_name, default=param_value, type=param_type, - desc=desc, docstring=docstring - ) - updated_params.append(new_param) - else: - updated_params.append(param) - return updated_params - -def get_function_argspec(function, args_to_skip={'logger_func',}): - argspecs = inspect.getfullargspec(function) - kwargs_type_hints = typing.get_type_hints(function) - docstring = function.__doc__ - params = params_to_ArgSpec( - argspecs, kwargs_type_hints, docstring, - args_to_skip=args_to_skip - ) - return params - -def getModelArgSpec(acdcSegment): - init_ArgSpec = inspect.getfullargspec(acdcSegment.Model.__init__) - init_kwargs_type_hints = typing.get_type_hints(acdcSegment.Model.__init__) - init_doc = acdcSegment.Model.__init__.__doc__ - init_params = params_to_ArgSpec( - init_ArgSpec, init_kwargs_type_hints, init_doc - ) - init_params = add_segm_data_param(init_params, init_ArgSpec) - - segment_ArgSpec = inspect.getfullargspec(acdcSegment.Model.segment) - segment_kwargs_type_hints = typing.get_type_hints(acdcSegment.Model.segment) - try: - segment_ArgSpec.args.remove('frame_i') - except Exception as e: - pass - - segment_doc = acdcSegment.Model.segment.__doc__ - segment_params = params_to_ArgSpec( - segment_ArgSpec, segment_kwargs_type_hints, segment_doc, - ) - - return init_params, segment_params - -def _get_doc_stop_idx(docstring, start_idx, next_param_name=None, debug=False): - if debug: - import pdb; pdb.set_trace() - - if next_param_name is not None: - doc_stop_idx = docstring.find(f'{next_param_name} : ') - if doc_stop_idx > 1: - return doc_stop_idx - - docstring_from_start = docstring[start_idx:] - next_param_searched = re.search(r'\w+ : ', docstring_from_start) - if next_param_searched is not None: - return next_param_searched.start(0) + start_idx - - doc_stop_idx = docstring.find('Returns') - if doc_stop_idx > 1: - return doc_stop_idx - - doc_stop_idx = docstring.find('Notes') - if doc_stop_idx > 1: - return doc_stop_idx - - return -1 - -def parse_model_param_doc(name, next_param_name=None, docstring=None): - if not docstring: - return '' - - try: - # Extract parameter description from 'param : ...' - start_text = f'{name} : ' - if docstring.find(start_text) == -1: - # Parameter not present in docstring - return '' - - doc_start_idx = docstring.find(start_text) + len(start_text) - - doc_stop_idx = _get_doc_stop_idx( - docstring, doc_start_idx, next_param_name=next_param_name - ) - if doc_stop_idx == -1: - doc_stop_idx = len(docstring) - - param_doc = docstring[doc_start_idx:doc_stop_idx] - - # Start at first end of line - param_doc = param_doc[param_doc.find('\n')+1:] - - # Replace multiples spaces with single space - param_doc = re.sub(' +', ' ', param_doc) - - # Remove trailing spaces - param_doc = param_doc.strip() - except Exception as err: - param_doc = '' - - param_doc = param_doc.replace(', optional', '') - - return param_doc - -def add_segm_data_param(init_params, init_argspecs): - if init_argspecs.defaults is None: - num_kwargs = 0 - else: - num_kwargs = len(init_argspecs.defaults) - - # Segm model requires segm data --> add it to params - num_args = len(init_argspecs.args) - num_kwargs - if num_args == 1: - # Args is only self --> segm data not needed - return init_params - - desc = ( -'This model requires an additional segmentation file as input.\n\n' -'Please, select which segmentation file to provide to the model.' - ) - - segm_data_argspec = ArgSpec( - name='Auxiliary segmentation file', - default='', - type=str, - desc=desc, - docstring=None - ) - - init_params.insert(0, segm_data_argspec) - return init_params - -def params_to_ArgSpec( - fullargspecs, type_hints, docstring, args_to_skip=None - ): - params = [] - - if fullargspecs.defaults is None: - return params - - if args_to_skip is None: - args_to_skip = set() - - num_params = len(fullargspecs.args) - ip = num_params - len(fullargspecs.defaults) - if ip < 0: - return params - - for arg, default in zip(fullargspecs.args[ip:], fullargspecs.defaults): - if arg in args_to_skip: - continue - - if arg in type_hints: - _type = type_hints[arg] - else: - _type = type(default) - - next_param_name = None - if ip+1 < num_params: - next_param_name = fullargspecs.args[ip+1] - - param_doc = parse_model_param_doc( - arg, - next_param_name=next_param_name, - docstring=docstring - ) - param = ArgSpec( - name=arg, - default=default, - type=_type, - desc=param_doc, - docstring=docstring - ) - params.append(param) - ip += 1 - return params - -def getClassArgSpecs(classModule, runMethodName='run'): - init_ArgSpec = inspect.getfullargspec(classModule.__init__) - init_kwargs_type_hints = typing.get_type_hints( - classModule.__init__ - ) - init_doc = classModule.__init__.__doc__ - init_params = params_to_ArgSpec( - init_ArgSpec, init_kwargs_type_hints, init_doc - ) - - run_ArgSpec = inspect.getfullargspec(getattr(classModule, runMethodName)) - run_kwargs_type_hints = typing.get_type_hints( - getattr(classModule, runMethodName) - ) - run_doc = getattr(classModule, runMethodName).__doc__ - run_params = params_to_ArgSpec( - run_ArgSpec, run_kwargs_type_hints, run_doc, - args_to_skip={'signals', 'export_to'} - ) - return init_params, run_params - -def getTrackerArgSpec(trackerModule, realTime=False): - init_ArgSpec = inspect.getfullargspec(trackerModule.tracker.__init__) - init_kwargs_type_hints = typing.get_type_hints( - trackerModule.tracker.__init__ - ) - init_doc = trackerModule.tracker.__init__.__doc__ - init_params = params_to_ArgSpec( - init_ArgSpec, init_kwargs_type_hints, init_doc - ) - if realTime: - track_ArgSpec = inspect.getfullargspec(trackerModule.tracker.track_frame) - track_kwargs_type_hints = typing.get_type_hints( - trackerModule.tracker.track_frame - ) - track_doc = trackerModule.tracker.track_frame.__doc__ - else: - track_ArgSpec = inspect.getfullargspec(trackerModule.tracker.track) - track_kwargs_type_hints = typing.get_type_hints( - trackerModule.tracker.track - ) - track_doc = trackerModule.tracker.track.__doc__ - - track_params = params_to_ArgSpec( - track_ArgSpec, track_kwargs_type_hints, track_doc, - args_to_skip={'signals', 'export_to'} - ) - return init_params, track_params - -def isIntensityImgRequiredForTracker(trackerModule): - track_ArgSpec = inspect.getfullargspec(trackerModule.tracker.track) - num_args = len(track_ArgSpec.args) - len(track_ArgSpec.defaults) - # If the number of args is 3 then we have `self, labels, image` as args - # which means the tracker requires the image - return num_args == 3 - -def getDefault_SegmInfo_df(posData, filename): - mid_slice = int(posData.SizeZ/2) - df = pd.DataFrame({ - 'filename': [filename]*posData.SizeT, - 'frame_i': range(posData.SizeT), - 'z_slice_used_dataPrep': [mid_slice]*posData.SizeT, - 'which_z_proj': ['single z-slice']*posData.SizeT, - 'z_slice_used_gui': [mid_slice]*posData.SizeT, - 'which_z_proj_gui': ['single z-slice']*posData.SizeT, - 'resegmented_in_gui': [False]*posData.SizeT, - 'is_from_dataPrep': [False]*posData.SizeT - }).set_index(['filename', 'frame_i']) - return df - -def get_examples_path(which): - if which == 'time_lapse_2D': - foldername = 'TimeLapse_2D' - url = 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/KgJQtsQKZJnWZjL/download/TimeLapse_2D.zip' - file_size = 45143552 - elif which == 'snapshots_3D': - foldername = 'Multi_3D_zStack_Analysed' - url = 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/3RNjGiPwKcdnGtj/download/Yeast_Analysed_multi3D_zStacks.zip' - file_size = 124822528 - else: - return '' - - examples_path = os.path.join(user_profile_path, 'acdc-examples') - example_path = os.path.join(examples_path, foldername) - return examples_path, example_path, url, file_size - -def download_examples(which='time_lapse_2D', progress=None): - examples_path, example_path, url, file_size = get_examples_path(which) - if os.path.exists(example_path): - if progress is not None: - # display 100% progressbar - progress.emit(0, 0) - return example_path - - zip_dst = os.path.join(examples_path, 'example_temp.zip') - - if not os.path.exists(examples_path): - os.makedirs(examples_path, exist_ok=True) - - print(f'Downloading example to {example_path}') - - download_url( - url, zip_dst, verbose=False, file_size=file_size, - progress=progress - ) - exctract_to = examples_path - extract_zip(zip_dst, exctract_to) - - if progress is not None: - # display 100% progressbar - progress.emit(0, 0) - - # Remove downloaded zip archive - os.remove(zip_dst) - print('Example downloaded successfully') - return example_path - -def get_acdc_java_path(): - acdc_java_path = os.path.join(user_profile_path, 'acdc-java') - dot_acdc_java_path = os.path.join(user_profile_path, '.acdc-java') - return acdc_java_path, dot_acdc_java_path - -def get_java_url(): - is_linux = sys.platform.startswith('linux') - is_mac = sys.platform == 'darwin' - is_win = sys.platform.startswith("win") - is_win64 = (is_win and (os.environ["PROCESSOR_ARCHITECTURE"] == "AMD64")) - - # https://drive.google.com/drive/u/0/folders/1MxhySsxB1aBrqb31QmLfVpq8z1vDyLbo - if is_win64: - os_foldername = 'win64' - unzipped_foldername = 'java_portable_windows-0.1' - file_size = 214798150 - # url = 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/eMyirTw8qG2wJMt/download/java_portable_windows-0.1.zip' - url = 'https://github.com/SchmollerLab/java_portable_windows/archive/refs/tags/v0.1.zip' - elif is_mac: - os_foldername = 'macOS' - unzipped_foldername = 'java_portable_macos-0.1' - url = 'https://github.com/SchmollerLab/java_portable_macos/archive/refs/tags/v0.1.zip' - # url = 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/SjZb8aommXgrECq/download/java_portable_macos-0.1.zip' - file_size = 108478751 - elif is_linux: - os_foldername = 'linux' - unzipped_foldername = 'java_portable_linux-0.1' - url = 'https://github.com/SchmollerLab/java_portable_linux/archive/refs/tags/v0.1.zip' - # url = 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/HjeQagixE2cjbZL/download/java_portable_linux-0.1.zip' - file_size = 92520706 - return url, file_size, os_foldername, unzipped_foldername - -def _jdk_exists(jre_path): - # If jre_path exists and it's windows search for ~/acdc-java/win64/jdk - # or ~/.acdc-java/win64/jdk. If not Windows return jre_path - if not jre_path: - return '' - os_acdc_java_path = os.path.dirname(jre_path) - os_foldername = os.path.basename(os_acdc_java_path) - if not os_foldername.startswith('win'): - return jre_path - if os.path.exists(os_acdc_java_path): - for folder in os.listdir(os_acdc_java_path): - if not folder.startswith('jdk'): - continue - dir_path = os.path.join(os_acdc_java_path, folder) - for file in os.listdir(dir_path): - if file == 'bin': - return dir_path - return '' - -def get_package_version(import_pkg_name): - import importlib.metadata - version = importlib.metadata.version(import_pkg_name) - return version - -def check_upgrade_javabridge(): - try: - version = get_package_version('javabridge') - except Exception as e: - return - patch = int(version.split('.')[2]) - if patch > 18: - return - install_javabridge() - -def _java_exists(os_foldername): - acdc_java_path, dot_acdc_java_path = get_acdc_java_path() - os_acdc_java_path = os.path.join(acdc_java_path, os_foldername) - if os.path.exists(os_acdc_java_path): - for folder in os.listdir(os_acdc_java_path): - if not folder.startswith('jre'): - continue - dir_path = os.path.join(os_acdc_java_path, folder) - for file in os.listdir(dir_path): - if file == 'bin': - return dir_path - - # Some users still has the old .acdc folder --> check - os_dot_acdc_java_path = os.path.join(dot_acdc_java_path, os_foldername) - if os.path.exists(os_dot_acdc_java_path): - for folder in os.listdir(os_dot_acdc_java_path): - if not folder.startswith('jre'): - continue - dir_path = os.path.join(os_dot_acdc_java_path, folder) - for file in os.listdir(dir_path): - if file == 'bin': - return dir_path - return '' - - # Check if the user unzipped the javabridge_portable folder and not its content - os_acdc_java_path = os.path.join(acdc_java_path, os_foldername) - if os.path.exists(os_acdc_java_path): - for folder in os.listdir(os_acdc_java_path): - dir_path = os.path.join(os_acdc_java_path, folder) - if folder.startswith('java_portable') and os.path.isdir(dir_path): - # Move files one level up - unzipped_path = os.path.join(os_acdc_java_path, folder) - for name in os.listdir(unzipped_path): - # move files up one level - src = os.path.join(unzipped_path, name) - shutil.move(src, os_acdc_java_path) - try: - shutil.rmtree(unzipped_path) - except PermissionError as e: - pass - # Check if what we moved one level up was actually java - for folder in os.listdir(os_acdc_java_path): - if not folder.startswith('jre'): - continue - dir_path = os.path.join(os_acdc_java_path, folder) - for file in os.listdir(dir_path): - if file == 'bin': - return dir_path - return '' - -def download_java(): - url, file_size, os_foldername, unzipped_foldername = get_java_url() - jre_path = _java_exists(os_foldername) - jdk_path = _jdk_exists(jre_path) - if os_foldername.startswith('win') and jre_path and jdk_path: - return jre_path, jdk_path, url - - if jre_path: - # on macOS jdk is the same as jre - return jre_path, jre_path, url - - acdc_java_path, _ = get_acdc_java_path() - os_acdc_java_path = os.path.join(acdc_java_path, os_foldername) - temp_zip = os.path.join(os_acdc_java_path, 'acdc_java_temp.zip') - - if not os.path.exists(os_acdc_java_path): - os.makedirs(os_acdc_java_path, exist_ok=True) - - try: - download_url(url, temp_zip, file_size=file_size, desc='Java') - extract_zip(temp_zip, os_acdc_java_path) - except Exception as e: - print('=======================') - traceback.print_exc() - print('=======================') - finally: - os.remove(temp_zip) - - # Move files one level up - unzipped_path = os.path.join(os_acdc_java_path, unzipped_foldername) - for name in os.listdir(unzipped_path): - # move files up one level - src = os.path.join(unzipped_path, name) - shutil.move(src, os_acdc_java_path) - try: - shutil.rmtree(unzipped_path) - except PermissionError as e: - pass - - jre_path = _java_exists(os_foldername) - jdk_path = _jdk_exists(jre_path) - return jre_path, jdk_path, url - -def get_model_path(model_name, create_temp_dir=True): - if model_name == 'Automatic thresholding': - model_name == 'thresholding' - - model_info_path = os.path.join(cellacdc_path, 'models', model_name, 'model') - - if os.path.exists(model_info_path): - for file in listdir(model_info_path): - if file != 'weights_location_path.txt': - continue - with open(os.path.join(model_info_path, file), 'r') as txt: - model_path = txt.read() - model_path = os.path.expanduser(model_path) - if not os.path.exists(model_path): - model_path = _write_model_location_to_txt(model_name) - else: - break - else: - model_path = _write_model_location_to_txt(model_name) - else: - os.makedirs(model_info_path, exist_ok=True) - model_path = _write_model_location_to_txt(model_name) - - model_path = migrate_to_new_user_profile_path(model_path) - - if not os.path.exists(model_path): - os.makedirs(model_path, exist_ok=True) - - if not create_temp_dir: - return '', model_path - - exists = check_model_exists(model_path, model_name) - if exists: - return '', model_path - - temp_zip_path = _create_temp_dir() - return temp_zip_path, model_path - -def check_model_exists(model_path, model_name): - try: - import cellacdc - m = model_name.lower() - weights_filenames = getattr(cellacdc, f'{m}_weights_filenames') - files_present = listdir(model_path) - return all([f in files_present for f in weights_filenames]) - except Exception as e: - return True - -def _create_temp_dir(): - temp_model_path = tempfile.mkdtemp() - temp_zip_path = os.path.join(temp_model_path, 'model_temp.zip') - return temp_zip_path - -def _model_url(model_name, return_alternative=False): - if model_name == 'YeaZ': - url = 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/8PMePcwJXmaMMS6/download/YeaZ_weights.zip' - alternative_url = 'https://zenodo.org/record/6125825/files/YeaZ_weights.zip?download=1' - file_size = 693685011 - elif model_name == 'YeastMate': - url = 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/pMT8pAmMkNtN8BP/download/yeastmate_weights.zip' - alternative_url = 'https://zenodo.org/record/6140067/files/yeastmate_weights.zip?download=1' - file_size = 164911104 - elif model_name == 'segment_anything': - url = [ - 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', - 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth', - 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth' - ] - file_size = [2564550879, 1249524736, 375042383] - alternative_url = '' - elif model_name == 'YeaZ_v2': - url = [ - 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/5PARckkcJcN9D3S/download/weights_budding_BF_multilab_0_1', - 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/CTHq4HN3adyFbnE/download/weights_budding_PhC_multilab_0_1', - 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/QTtBJycYnLQZsHQ/download/weights_fission_multilab_0_2' - ] - file_size = [124142981, 124143031, 124144759] - alternative_url = 'https://github.com/rahi-lab/YeaZ-GUI#installation' - elif model_name == 'DeepSea': - url = [ - 'https://github.com/abzargar/DeepSea/raw/master/deepsea/trained_models/segmentation.pth', - 'https://github.com/abzargar/DeepSea/raw/master/deepsea/trained_models/tracker.pth' - ] - file_size = [7988969, 8637439] - alternative_url = '' - elif model_name == 'TAPIR': - url = [ - 'https://storage.googleapis.com/dm-tapnet/tapir_checkpoint.npy' - ] - file_size = [124408122] - alternative_url = '' - elif model_name == 'Cellpose_germlineNuclei': - url = [ - 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/AXG6fFfD8o5GZ83/download/cellpose_germlineNuclei_2023' - ] - file_size = [26570752] - alternative_url = '' - elif model_name == 'omnipose': - url = [ - 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/DynLkocWRbQfyRp/download/bact_fluor_cptorch_0' - 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/2248Eoyozp3Ezj2/download/bact_fluor_omnitorch_0', - 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/GiacDfXGerxE7PT/download/bact_phase_omnitorch_0', - 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/DDq8s3CgnG2Yw6H/download/cyto2_omnitorch_0', - 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/MM5meM2J5HbWqXR/download/plant_cptorch_0', - 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/aap7znrWq5sE6JQ/download/plant_omnitorch_0', - 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/w5M46x9qr8zLHZH/download/size_cyto2_omnitorch_0.npy' - ] - file_size = [ - 26558464, - 26558464, - 26558464, - 26558464, - 26558464, - 75071488, - 4096 - ] - alternative_url = '' - elif model_name == 'sam2': - url = [ - 'https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt', - 'https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt', - 'https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt', - 'https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt' - ] - file_size = [155233385, 184211977, 319128965, 910600801] - alternative_url = '' - else: - return - if return_alternative: - return url, alternative_url - else: - return url, file_size - -def _download_segment_anything_models(): - urls, file_sizes = _model_url('segment_anything') - temp_model_path = tempfile.mkdtemp() - _, final_model_path = ( - get_model_path('segment_anything', create_temp_dir=False) - ) - for url, file_size in zip(urls, file_sizes): - filename = url.split('/')[-1] - final_dst = os.path.join(final_model_path, filename) - if os.path.exists(final_dst): - continue - - temp_dst = os.path.join(temp_model_path, filename) - download_url( - url, temp_dst, file_size=file_size, desc='segment_anything', - verbose=False - ) - - shutil.move(temp_dst, final_dst) - -def _download_sam2_models(): - urls, file_sizes = _model_url('sam2') - temp_model_path = tempfile.mkdtemp() - _, final_model_path = ( - get_model_path('sam2', create_temp_dir=False) - ) - for url, file_size in zip(urls, file_sizes): - filename = url.split('/')[-1] - final_dst = os.path.join(final_model_path, filename) - if os.path.exists(final_dst): - continue - - temp_dst = os.path.join(temp_model_path, filename) - download_url( - url, temp_dst, file_size=file_size, desc='sam2', - verbose=False - ) - - shutil.move(temp_dst, final_dst) - -def _download_deepsea_models(): - urls, file_sizes = _model_url('DeepSea') - temp_model_path = tempfile.mkdtemp() - _, final_model_path = ( - get_model_path('deepsea', create_temp_dir=False) - ) - for url, file_size in zip(urls, file_sizes): - filename = url.split('/')[-1] - final_dst = os.path.join(final_model_path, filename) - if os.path.exists(final_dst): - continue - - temp_dst = os.path.join(temp_model_path, filename) - download_url( - url, temp_dst, file_size=file_size, desc='deepsea', - verbose=False - ) - - shutil.move(temp_dst, final_dst) - -def download_manual(): - manual_folder_path = os.path.join(user_profile_path, 'acdc-manual') - if not os.path.exists(manual_folder_path): - os.makedirs(manual_folder_path, exist_ok=True) - - manual_file_path = os.path.join(user_profile_path, 'Cell-ACDC_User_Manual.pdf') - if not os.path.exists(manual_file_path): - url = 'https://github.com/SchmollerLab/Cell_ACDC/raw/main/UserManual/Cell-ACDC_User_Manual.pdf' - download_url(url, manual_file_path, file_size=1727470) - return manual_file_path - -def download_bioformats_jar( - qparent=None, logger_info=print, logger_exception=print - ): - dst_filepath = os.path.join( - cellacdc_path, 'bioformats', 'jars', 'bioformats_package.jar' - ) - if os.path.exists(dst_filepath): - return True, dst_filepath - urls_to_try = (urls.bioformats_jar_home_url, urls.bioformats_jar_hmgu_url) - success = False - for url in urls_to_try: - try: - logger_info( - f'Downloading `bioformats_package.jar`...' - ) - download_url(url, dst_filepath, file_size=43233280) - success = True - break - except Exception as err: - success = False - traceback_str = traceback.format_exc() - logger_exception(traceback_str) - continue - - if success: - return True, dst_filepath - - _warnings.warn_download_bioformats_jar_failed(dst_filepath, qparent=qparent) - raise ModuleNotFoundError( - 'Bioformats package jar could not be downloaded. Please, ' - f'download it from here {urls.bioformats_download_page} and ' - f'place it in the following path "{dst_filepath}". ' - 'Thank you for your patience!' - ) - return False, dst_filepath - - -def showUserManual(): - manual_file_path = download_manual() - showInExplorer(manual_file_path) - -def get_confirm_token(response): - for key, value in response.cookies.items(): - if key.startswith('download_warning'): - return value - return None - -def download_url( - url, dst, desc='', file_size=None, verbose=True, progress=None - ): - import urllib3 - urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - - CHUNK_SIZE = 32768 - if verbose: - print(f'Downloading {desc} to: {os.path.dirname(dst)}') - response = requests.get(url, stream=True, timeout=20, verify=False) - if file_size is not None and progress is not None: - progress.emit(file_size, -1) - pbar = tqdm( - total=file_size, unit='B', unit_scale=True, - unit_divisor=1024, ncols=100 - ) - with open(dst, 'wb') as f: - for chunk in response.iter_content(CHUNK_SIZE): - # if chunk: - f.write(chunk) - pbar.update(len(chunk)) - if progress is not None: - progress.emit(-1, len(chunk)) - pbar.close() - -def save_response_content( - response, destination, file_size=None, - model_name='cellpose', progress=None - ): - print(f'Downloading {model_name} to: {os.path.dirname(destination)}') - CHUNK_SIZE = 32768 - - # Download to a temp folder in user path - temp_folder = pathlib.Path.home().joinpath('.acdc_temp') - if not os.path.exists(temp_folder): - os.mkdir(temp_folder) - temp_dst = os.path.join(temp_folder, os.path.basename(destination)) - if file_size is not None and progress is not None: - progress.emit(file_size, -1) - pbar = tqdm( - total=file_size, unit='B', unit_scale=True, - unit_divisor=1024, ncols=100 - ) - with open(temp_dst, "wb") as f: - for chunk in response.iter_content(CHUNK_SIZE): - if chunk: - f.write(chunk) - pbar.update(len(chunk)) - if progress is not None: - progress.emit(-1, len(chunk)) - pbar.close() - - # Move to destination and delete temp folder - destination_dir = os.path.dirname(destination) - if not os.path.exists(destination_dir): - os.makedirs(destination_dir, exist_ok=True) - shutil.move(temp_dst, destination) - shutil.rmtree(temp_folder) - -def extract_zip(zip_path, extract_to_path, verbose=True): - if verbose: - print(f'Extracting to {extract_to_path}...') - with zipfile.ZipFile(zip_path, 'r') as zip_ref: - zip_ref.extractall(extract_to_path) - -def check_v123_model_path(model_name): - # Cell-ACDC v1.2.3 saved the weights inside the package, - # while from v1.2.4 we save them on user folder. If we find the - # weights in the package we move them to user folder without downloading - # new ones. - v123_model_path = os.path.join(cellacdc_path, 'models', model_name, 'model') - exists = check_model_exists(v123_model_path, model_name) - if exists: - return v123_model_path - else: - return '' - -def is_old_user_profile_path(path_to_check: os.PathLike): - from . import user_data_dir - user_data_folderpath = user_data_dir() - user_profile_path_txt = os.path.join( - user_data_folderpath, 'acdc_user_profile_location.txt' - ) - if os.path.exists(user_profile_path_txt): - return False - - from . import user_home_path - user_home_path = user_home_path.replace('\\', '/') - path_to_check = path_to_check.replace('\\', '/') - return user_home_path == path_to_check - -def migrate_to_new_user_profile_path(path_to_migrate: os.PathLike): - parent_dir = os.path.dirname(path_to_migrate) - if not is_old_user_profile_path(parent_dir): - return path_to_migrate - folder = os.path.basename(path_to_migrate) - return os.path.join(user_profile_path, folder) - -def _write_model_location_to_txt(model_name): - model_info_path = os.path.join(cellacdc_path, 'models', model_name, 'model') - model_path = os.path.join(user_profile_path, f'acdc-{model_name}') - file = 'weights_location_path.txt' - with open(os.path.join(model_info_path, file), 'w') as txt: - txt.write(model_path) - return os.path.expanduser(model_path) - -def determine_folder_type(folder_path): - is_pos_folder = is_pos_folderpath(folder_path) - is_images_folder = folder_path.endswith('Images') and listdir(folder_path) - contains_images_folder = os.path.exists( - os.path.join(folder_path, 'Images') - ) - contains_pos_folders = len(get_pos_foldernames(folder_path)) > 0 - if contains_pos_folders: - is_pos_folder = False - is_images_folder = False - elif contains_images_folder and not is_pos_folder: - # Folder created by loading an image - is_images_folder = True - folder_path = os.path.join(folder_path, 'Images') - - return is_pos_folder, is_images_folder, folder_path - -def download_model(model_name): - if model_name == 'segment_anything': - try: - _download_segment_anything_models() - return True - except Exception as e: - traceback.print_exc() - return False - elif model_name == 'sam2': - try: - _download_sam2_models() - return True - except Exception as e: - traceback.print_exc() - return False - elif model_name == 'DeepSea': - try: - _download_deepsea_models() - return True - except Exception as e: - traceback.print_exc() - return False - elif model_name == 'TAPIR': - try: - _download_tapir_model() - return True - except Exception as e: - traceback.print_exc() - return False - elif model_name == 'YeaZ_v2': - try: - _download_yeaz_models() - return True - except Exception as e: - traceback.print_exc() - return False - elif model_name == 'Cellpose_germlineNuclei': - try: - _download_cellpose_germlineNuclei_model() - return True - except Exception as e: - traceback.print_exc() - return False - elif model_name == 'omnipose': - try: - _download_omnipose_models() - return True - except Exception as err: - return False - elif model_name != 'YeastMate' and model_name != 'YeaZ': - # We manage only YeastMate and YeaZ - return True - - try: - # Check if model exists - temp_zip_path, model_path = get_model_path(model_name) - if not temp_zip_path: - # Model exists return - return True - - # Check if user has model in the old v1.2.3 location - v123_model_path = check_v123_model_path(model_name) - if v123_model_path: - print(f'Weights files found in {v123_model_path}') - print(f'--> moving to new location: {model_path}...') - for file in listdir(v123_model_path): - src = os.path.join(v123_model_path, file) - dst = os.path.join(model_path, file) - shutil.copy(src, dst) - return True - - # Download model from url to tempDir/model_temp.zip - temp_dir = os.path.dirname(temp_zip_path) - url, file_size = _model_url(model_name) - print(f'Downloading {model_name} to {model_path}') - download_url( - url, temp_zip_path, file_size=file_size, desc=model_name, - verbose=False - ) - - # Extract zip file inside temp dir - print(f'Extracting model...') - extract_zip(temp_zip_path, temp_dir, verbose=False) - - # Move unzipped files to ~/acdc-{model_name} folder - print(f'Moving files from temporary folder to {model_path}...') - for file in listdir(temp_dir): - if file.endswith('.zip'): - continue - src = os.path.join(temp_dir, file) - dst = os.path.join(model_path, file) - shutil.move(src, dst) - - # Remove temp directory - print(f'Removing temporary folder...') - shutil.rmtree(temp_dir) - return True - - except Exception as e: - traceback.print_exc() - return False - -# def get_tiff_metadata( -# image_arr, -# SizeT=None, -# SizeZ=None, -# PhysicalSizeZ=None, -# PhysicalSizeX=None, -# PhysicalSizeY=None, -# TimeIncrement=None -# ): -# SizeY, SizeX = image_arr.shape[-2:] -# Type = str(image_arr.dtype) - -# metadata = { -# 'SizeX': SizeX, -# 'SizeY': SizeY, -# 'Type': Type -# } - -# axes = 'YX' -# if SizeZ is not None and SizeZ > 1: -# axes = f'Z{axes}' -# metadata['SizeZ'] = SizeZ - -# if SizeT is not None and SizeT > 1: -# axes = f'T{axes}' -# metadata['SizeT'] = SizeT - -# metadata['axes'] = axes - -# if PhysicalSizeX is not None: -# metadata['PhysicalSizeX'] = PhysicalSizeX - -# if PhysicalSizeY is not None: -# metadata['PhysicalSizeY'] = PhysicalSizeY - -# if PhysicalSizeZ is not None: -# metadata['PhysicalSizeZ'] = PhysicalSizeZ - -# if TimeIncrement is not None: -# metadata['TimeIncrement'] = TimeIncrement - -# return metadata - -def get_tiff_metadata( - image_arr, - SizeT=None, - SizeZ=None, - PhysicalSizeZ=None, - PhysicalSizeX=None, - PhysicalSizeY=None, - TimeIncrement=None - ): - SizeY, SizeX = image_arr.shape[-2:] - Type = str(image_arr.dtype) - - metadata = { - 'Pixels': { - 'SizeX': SizeX, - 'SizeY': SizeY, - 'Type': Type - } - } - - axes = 'YX' - if SizeZ is not None and SizeZ > 1: - axes = f'Z{axes}' - metadata['Pixels']['SizeZ'] = SizeZ - - if SizeT is not None and SizeT > 1: - axes = f'T{axes}' - metadata['Pixels']['SizeT'] = SizeT - - metadata['axes'] = axes - - if PhysicalSizeX is not None: - metadata['Pixels']['PhysicalSizeX'] = PhysicalSizeX - - if PhysicalSizeY is not None: - metadata['Pixels']['PhysicalSizeY'] = PhysicalSizeY - - if PhysicalSizeZ is not None: - metadata['Pixels']['PhysicalSizeZ'] = PhysicalSizeZ - - if TimeIncrement is not None: - metadata['Pixels']['TimeIncrement'] = TimeIncrement - - return metadata - -def to_tiff( - new_path, data, - SizeT=None, - SizeZ=None, - PhysicalSizeZ=None, - PhysicalSizeX=None, - PhysicalSizeY=None, - TimeIncrement=None - ): - valid_dtypes = ( - np.uint8, np.uint16, np.float32 - ) - is_valid_dtype = False - for valid_dtype in valid_dtypes: - if np.issubdtype(data.dtype, valid_dtype): - is_valid_dtype = True - break - - if not is_valid_dtype: - data = data.astype(np.float32) - - metadata = get_tiff_metadata( - data, - SizeT=SizeT, - SizeZ=SizeZ, - PhysicalSizeZ=PhysicalSizeZ, - PhysicalSizeX=PhysicalSizeX, - PhysicalSizeY=PhysicalSizeY, - TimeIncrement=TimeIncrement - ) - - # # Potential alternative - # hyperstack = tifffile.memmap( - # new_path, - # shape=img.shape, - # dtype=img.dtype, - # imagej=True, - # metadata={'axes': 'TZYX'}, - # ) - # hyperstack[:] = img - # hyperstack.flush() - - try: - tifffile.imwrite( - new_path, data, metadata=metadata, imagej=True - ) - except Exception as err: - tifffile.imwrite(new_path, data) - -def from_lab_to_obj_coords(lab): - rp = skimage.measure.regionprops(lab) - dfs = [] - keys = [] - for obj in rp: - keys.append(obj.label) - obj_coords = obj.coords - ndim = obj_coords.shape[1] - if ndim == 3: - columns = ['z', 'y', 'x'] - else: - columns = ['y', 'x'] - df_obj = pd.DataFrame(data=obj_coords, columns=columns) - dfs.append(df_obj) - df = pd.concat(dfs, keys = keys, names=['Cell_ID', 'idx']).droplevel('idx') - return df - -def lab2d_to_rois(ImagejRoi, lab2D, ndigits, t=None, z=None): - rp = skimage.measure.regionprops(lab2D) - rois = [] - for obj in rp: - cont = core.get_obj_contours(obj) - yc, xc = obj.centroid - x_str = str((int(xc))).zfill(ndigits) - y_str = str((int(yc))).zfill(ndigits) - name = f'{x_str}-{y_str}' - if z is not None: - z_str = str(z).zfill(ndigits) - name = f'{z_str}-{name}' - - if t is not None: - t_str = str(t).zfill(ndigits) - name = f'{t_str}-{name}' - - name = f'id={obj.label}-{name}' - - roi = ImagejRoi.frompoints( - cont, name=name, t=t, z=z, index=obj.label - ) - rois.append(roi) - return rois - -def from_lab_to_imagej_rois(lab, ImagejRoi, t=0, SizeT=1, max_ID=None): - if max_ID is None: - max_ID = lab.max() - - if SizeT == 1: - t = None - - SizeY, SizeX = lab.shape[-2:] - ndigitsT = len(str(SizeT)) - ndigitsY = len(str(SizeY)) - ndigitsX = len(str(SizeX)) - - if lab.ndim == 3: - rois = [] - SizeZ = len(lab) - ndigitsZ = len(str(SizeZ)) - ndigits = max(ndigitsT, ndigitsZ, ndigitsY, ndigitsX) - for z, lab2D in enumerate(lab): - z_rois = lab2d_to_rois(ImagejRoi, lab2D, ndigits, t=t, z=z) - rois.extend(z_rois) - else: - ndigits = max(ndigitsT, ndigitsY, ndigitsX) - rois = lab2d_to_rois(ImagejRoi, lab, ndigits, t=t) - return rois - -def from_imagej_rois_to_segm_data( - TZYX_shape, ID_to_roi_mapper, rescale_rois_sizes, - repeat_2d_rois_zslices_range - ): - SizeT, SizeZ, SizeY, SizeX = TZYX_shape - segm_data = np.zeros(TZYX_shape, dtype=np.uint32) - for ID, roi in ID_to_roi_mapper.items(): - name = roi.name - name_parts = name.split('-') - zz = [0] - if len(name_parts) == 2 and SizeZ > 1: - # 2D roi in 3D segm data --> place 2D roi on each z-slice - zz = range(*repeat_2d_rois_zslices_range) - - elif len(name_parts) > 2 and SizeZ > 1: - # 2D roi from a 3D roi --> place at requested z-slice - zz = [int(name_parts[-3])] - - tt = [0]*len(zz) - if SizeT > 1: - tt = [roi.t_position]*len(zz) - - y0, x0 = roi.top, roi.left - contours = roi.integer_coordinates + (x0, y0) - xx = contours[:, 0] - yy = contours[:, 1] - if rescale_rois_sizes is not None: - rescale_z = rescale_rois_sizes['Z'] - rescale_y = rescale_rois_sizes['Y'] - rescale_x = rescale_rois_sizes['X'] - - factor_z = rescale_z[1]/rescale_z[0] - factor_y = rescale_y[1]/rescale_y[0] - factor_x = rescale_x[1]/rescale_x[0] - - xx = np.clip(np.round(xx * factor_x).astype(int), 0, SizeX-1) - yy = np.clip(np.round(yy * factor_y).astype(int), 0, SizeY-1) - - for t, z in zip(tt, zz): - if rescale_rois_sizes is not None: - z = round(z*factor_z) - z = z if z=0 else 0 - - rr, cc = skimage.draw.polygon(yy, xx) - segm_data[t, z, rr, cc] = ID - - return np.squeeze(segm_data) - -def aliases_real_time_trackers(reverse=False): - """ - Returns a dictionary with aliases for real-time trackers. - """ - - aliases = { - 'CellACDC_normal_division': 'Cell-ACDC symmetric division', - 'CellACDC_2steps' : 'Cell-ACDC 2 steps', - } - - if reverse: - aliases = {v: k for k, v in aliases.items()} - - return aliases - -def get_list_of_real_time_trackers(): - trackers = get_list_of_trackers() - rt_trackers = [] - aliases = aliases_real_time_trackers() - for tracker in trackers: - if tracker == 'CellACDC': - continue - if tracker == 'YeaZ': - continue - tracker_filename = f'{tracker}_tracker.py' - tracker_path = os.path.join( - cellacdc_path, 'trackers', tracker, tracker_filename - ) - try: - with open(tracker_path) as file: - txt = file.read() - if txt.find('def track_frame') != -1: - rt_trackers.append(tracker) - except Exception as e: - continue - - for i, tracker in enumerate(rt_trackers): - if tracker in aliases: - rt_trackers[i] = aliases[tracker] - - return natsorted(rt_trackers, key=str.casefold) - -def get_list_of_trackers(): - trackers_path = os.path.join(cellacdc_path, 'trackers') - trackers = [] - for name in listdir(trackers_path): - _path = os.path.join(trackers_path, name) - tracker_script_path = os.path.join(_path, f'{name}_tracker.py') - is_valid_tracker = ( - os.path.isdir(_path) and os.path.exists(tracker_script_path) - and not name.endswith('__') - ) - - if name.startswith('_'): - continue - - if is_valid_tracker: - trackers.append(name) - return natsorted(trackers, key=str.casefold) - -def get_list_of_models(): - models = set() - for name in listdir(models_path): - _path = os.path.join(models_path, name) - if not os.path.exists(_path): - continue - - if not os.path.isdir(_path): - continue - - if name.endswith('__'): - continue - - if name.startswith('_'): - continue - - if name == 'skip_segmentation': - continue - - if not os.path.exists(os.path.join(_path, 'acdcSegment.py')): - continue - - if name == 'thresholding': - name = 'Automatic thresholding' - - models.add(name) - - if not os.path.exists(models_list_file_path): - return natsorted(list(models), key=str.casefold) - - cp = config.ConfigParser() - cp.read(models_list_file_path) - models.update(cp.sections()) - return natsorted(list(models), key=str.casefold) - -def get_list_of_promptable_models(): - models = set() - for name in listdir(promptable_models_path): - _path = os.path.join(promptable_models_path, name) - if not os.path.exists(_path): - continue - - if not os.path.isdir(_path): - continue - - if name.endswith('__'): - continue - - if not os.path.exists(os.path.join(_path, 'acdcPromptSegment.py')): - continue - - models.add(name) - - if not os.path.exists(promptable_models_list_file_path): - return natsorted(list(models), key=str.casefold) - - cp = config.ConfigParser() - cp.read(promptable_models_list_file_path) - models.update(cp.sections()) - return natsorted(list(models), key=str.casefold) - -def seconds_to_ETA(seconds): - seconds = round(seconds) - ETA = datetime.timedelta(seconds=seconds) - ETA_split = str(ETA).split(':') - if seconds < 0: - ETA = '00h:00m:00s' - elif seconds >= 86400: - days, hhmmss = str(ETA).split(',') - h, m, s = hhmmss.split(':') - ETA = f'{days}, {int(h):02}h:{int(m):02}m:{int(s):02}s' - else: - h, m, s = str(ETA).split(':') - ETA = f'{int(h):02}h:{int(m):02}m:{int(s):02}s' - return ETA - -def to_uint8(img): - if img.dtype == np.uint8: - return img - img = np.round(img_to_float(img)*255).astype(np.uint8) - return img - -def to_uint16(img): - if img.dtype == np.uint16: - return img - img = np.round(img_to_float(img)*65535).astype(np.uint16) - return img - -def elided_text(text, max_len=50, elid_idx=None): - if len(text) <= max_len: - return text - - if elid_idx is None: - elid_idx = int(max_len/2) - if elid_idx >= max_len: - elid_idx = max_len - 1 - idx1 = elid_idx - idx2 = elid_idx - max_len - text = f'{text[:idx1]}...{text[idx2:]}' - return text - -def to_relative_path(path, levels=3, prefix='...'): - path = path.replace('\\', '/') - parts = path.split('/') - if levels >= len(parts): - return path - parts = parts[-levels:] - rel_path = '/'.join(parts) - rel_path.replace('/', os.sep) - if prefix: - rel_path = f'{prefix}{os.sep}{rel_path}' - return rel_path - -def img_to_float(img, force_dtype=None, force_missing_dtype=None, warn=True): - input_img_dtype = img.dtype - value = img[(0,) * img.ndim] - img_max = np.max(img) - # Check if float outside of -1, 1 - if img_max <= 1.0 and isinstance(value, (np.floating, float)): - return img - - uint8_max = np.iinfo(np.uint8).max - uint16_max = np.iinfo(np.uint16).max - uint32_max = np.iinfo(np.uint32).max - - img = img.astype(float) - if force_dtype is not None: - dtype_max = np.iinfo(force_dtype).max - img = img/dtype_max - elif input_img_dtype == np.uint8: - # Input image is 8-bit - img = img/uint8_max - elif input_img_dtype == np.uint16: - # Input image is 16-bit - img = img/uint16_max - elif input_img_dtype == np.uint32: - # Input image is 32-bit - img = img/uint32_max - elif force_missing_dtype is not None: - img = img.astype(force_dtype) - elif img_max <= uint8_max: - # Input image is probably 8-bit - if warn: - _warnings.warn_image_overflow_dtype(input_img_dtype, img_max, '8-bit') - img = img/uint8_max - elif img_max <= uint16_max: - # Input image is probably 16-bit - if warn: - _warnings.warn_image_overflow_dtype(input_img_dtype, img_max, '16-bit') - img = img/uint16_max - elif img_max <= uint32_max: - # Input image is probably 32-bit - if warn: - _warnings.warn_image_overflow_dtype(input_img_dtype, img_max, '32-bit') - img = img/uint32_max - else: - # Input image is a non-supported data type - raise TypeError( - f'The maximum value in the image is {img_max} which is greater than the ' - f'maximum value supported of {uint32_max} (32-bit). ' - 'Please consider converting your images to 32-bit or 16-bit first.' - ) - return img - -def float_img_to_dtype(img, dtype): - if img.dtype == dtype: - return img - - img_max = img.max() - if img_max > 1.0: - raise TypeError( - 'Images of float data type with values greater than 1.0 cannot ' - f'be safely casted to {dtype}. ' - f'The max value of the input image is {img_max:.3f}' - ) - - img_min = img.min() - if img_min < -1.0: - raise TypeError( - 'Images of float data type with values smaller than -1.0 cannot ' - f'be safely casted to {dtype}.' - f'The minumum value of the input image is {img_min:.3f}' - ) - - if dtype == np.uint8: - return skimage.img_as_ubyte(img) - - if dtype == np.uint16: - return skimage.img_as_uint(img) - - if dtype == np.float32: - return img.astype(np.float32) - - if dtype == np.float64: - return img.astype(np.float64) - - raise TypeError( - f'Invalid output data type `{dtype}`. ' - 'Valid output data types are `np.uint8` and `np.uint16`' - ) - -def convert_to_dtype(data: np.ndarray, dtype): - if data.dtype == dtype: - return data - val = data[tuple([0]*data.ndim)] - if isinstance(val, (np.floating, float)): - data = float_img_to_dtype(data, dtype) - elif dtype == np.uint8: - data = np.round(img_to_float(data)*255).astype(np.uint8) - elif dtype == np.uint16: - data = np.round(img_to_float(data)*65535).astype(np.uint16) - else: - raise TypeError( - f'Invalid output data type `{dtype}`. ' - 'Valid data types are floating-point format, `np.uint8` ' - 'and `np.uint16`' - ) - return data - -def _install_homebrew_command(): - return '/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"' - -def _brew_install_java_command(): - return 'brew install --cask homebrew/cask-versions/adoptopenjdk8' - -def _brew_install_hdf5(): - return 'brew install hdf5' - -def _apt_update_command(): - return 'sudo apt-get update' - -def _apt_gcc_command(): - return 'sudo apt install python-dev gcc' - -def _apt_install_java_command(): - return 'sudo apt-get install openjdk-8-jdk' - -def _java_instructions_linux(): - s1 = html_utils.paragraph(""" - Run the following commands
    - in the Teminal one by one: - """) - - s2 = html_utils.paragraph(f""" - {_apt_gcc_command().replace(' ', ' ')} - """) - - s3 = html_utils.paragraph(f""" - {_apt_update_command().replace(' ', ' ')} - """) - - s4 = html_utils.paragraph(f""" - {_apt_install_java_command().replace(' ', ' ')} - """) - - s5 = html_utils.paragraph(""" - The first command is used to install GCC, which is needed later.

    - The second and third commands are used is used to install - Java Development Kit 8.

    - Follow the instructions on the terminal to complete - installation.

    - """) - return s1, s2, s3, s4 - -def _java_instructions_macOS(): - s1 = html_utils.paragraph(""" - Run the following commands
    - in the Teminal one by one: - """) - - s2 = html_utils.paragraph(f""" - {_install_homebrew_command()} - """) - - s3 = html_utils.paragraph(f""" - {_brew_install_java_command().replace(' ', ' ')} - """) - - s4 = html_utils.paragraph(""" - The first command is used to install Homebrew
    - a package manager for macOS/Linux.

    - The second command is used to install Java 8.
    - Follow the instructions on the terminal to complete - installation.

    - Alternatively, you can install Java as a regular app
    - by downloading the app from - - here - . - """) - return s1, s2, s3, s4 - -def jdk_windows_url(): - return 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/R62Ktcda6jWea2s' - -def cpp_windows_url(): - return 'https://visualstudio.microsoft.com/visual-cpp-build-tools/' - -def _java_instructions_windows(): - jdk_url = f'"{jdk_windows_url()}"' - cpp_url = f'"{cpp_windows_url()}"' - s1 = html_utils.paragraph(""" - Download and install Java Development Kit and
    - Microsoft C++ Build Tools for Windows (links below).

    - IMPORTANT: when installing "Microsoft C++ Build Tools"
    - make sure to select "Desktop development with C++".
    - Click "See the screenshot" for more details.
    - """) - - s2 = html_utils.paragraph(f""" - Java Development Kit: - - here - - """) - - s3 = html_utils.paragraph(f""" - Microsoft C++ Build Tools: - - here - - """) - return s1, s2, s3 - -def install_javabridge_instructions_text(): - if is_win: - return _java_instructions_windows() - elif is_mac: - return _java_instructions_macOS() - elif is_linux: - return _java_instructions_linux() - -def install_javabridge_help(parent=None): - msg = widgets.myMessageBox() - txt = html_utils.paragraph(f""" - Cell-ACDC is going to download and install - javabridge.

    - Make sure you have an active internet connection, - before continuing. - Progress will be displayed on the terminal

    - IMPORTANT: If the installation fails, please open an issue - on our - - GitHub page - .

    - Alternatively, you can cancel the process and try later. - """) - msg.setIcon() - msg.setWindowTitle('Installing javabridge') - msg.addText(txt) - msg.addButton(' Ok ') - cancel = msg.addButton(' Cancel ') - msg.exec_() - return msg.clickedButton == cancel - -def check_napari_plugin(plugin_name, module_name, parent=None): - try: - import_module(module_name) - except ModuleNotFoundError as e: - url = 'https://napari.org/stable/plugins/find_and_install_plugin.html#find-and-install-plugins' - href = html_utils.href_tag('this guide', url) - txt = html_utils.paragraph(f""" - To correctly use this napari utility you need to install the - plugin called {plugin_name}.

    - Please, read {href} on how to install plugins in napari.

    - You will need to restart both napari and Cell-ACDC after installing - the plugin.

    - NOTE: in the text box in napari you will need to write the full name - {plugin_name} becasue it is NOT A SEARCH BOX. - """) - msg = widgets.myMessageBox() - msg.critical(parent, f'Napari plugin required', txt) - raise e - -def _install_pip_package( - pkg_name: str, - logger: Callable = print, - install_dependencies: bool = True, - force_binary: bool = True, - pref_binary: bool = True, - ) -> None: - command = [sys.executable, '-m', 'pip', 'install', pkg_name,] - if force_binary: - command.append('--only-binary=:all:') - elif pref_binary: - command.append('--prefer-binary') - if not install_dependencies: - command.append('--no-deps') - try: - subprocess.check_call( - command - ) - except subprocess.CalledProcessError as e: - if "--only-binary=:all:" in str(e): - logger(f"Error: {pkg_name} does not have a binary distribution available, trying preferred binary.") - _install_pip_package( - pkg_name=pkg_name, - logger=logger, - install_dependencies=install_dependencies, - force_binary=False, - pref_binary=True, - ) - elif "--prefer-binary" in str(e): - logger(f"Error: {pkg_name} does not have a preferred binary distribution available, trying source.") - command.remove('--prefer-binary') - command.append('--no-binary=:all:') - _install_pip_package( - pkg_name=pkg_name, - logger=logger, - install_dependencies=install_dependencies, - force_binary=False, - pref_binary=False, - ) - else: - logger(f"""Error: {pkg_name} installation failed. Please check the error message. This is probably due to the package - not being available for your platform or python version.""") - raise e - -def uninstall_pip_package(pkg_name): - subprocess.check_call( - [sys.executable, '-m', 'pip', 'uninstall', '-y', pkg_name] - ) - -def uninstall_omnipose_acdc(): - """Uninstall omnipose-acdc if present. Since v1.5.0 it is not needed. - """ - import json - pip_list_output = subprocess.check_output( - [sys.executable, '-m', 'pip', 'list', '--format', 'json'] - ) - installed_packages = json.loads(pip_list_output) - pkgs_to_uninstall = [] - for package_info in installed_packages: - if package_info['name'] == 'omnipose-acdc': - pkgs_to_uninstall.append('omnipose-acdc') - elif package_info['name'] == 'cellpose-omni-acdc': - pkgs_to_uninstall.append('cellpose-omni-acdc') - - for pkg_to_uninstall in pkgs_to_uninstall: - uninstall_pip_package(pkg_to_uninstall) - -def get_cellpose_major_version(errors='raise'): - major_installed = None - try: - installed_version = get_package_version('cellpose') - major_installed = int(installed_version.split('.')[0]) - except Exception as err: - if errors == 'raise': - raise err - - return major_installed - -def check_cellpose_version(version: str): - if isinstance(version, int): - version = f'{version}.0' - - major_requested = int(version.split('.')[0]) - cancel = False - try: - installed_version = get_package_version('cellpose') - major_installed = int(installed_version.split('.')[0]) - is_version_correct = major_installed == major_requested - if not is_version_correct: - cancel = _warnings.warn_installing_different_cellpose_version( - version, installed_version - ) - if not is_second_version_greater( - min_target_versions_cp[str(major_requested)], - installed_version - ): - is_version_correct = False - except Exception as err: - is_version_correct = False - - if cancel: - raise ModuleNotFoundError('Cellpose installation cancelled by the user.') - return is_version_correct - -def purge_module(module_name): - to_delete = [mod for mod in sys.modules if mod == module_name or mod.startswith(module_name + '.')] - for mod in to_delete: - del sys.modules[mod] - - importlib.invalidate_caches() - importlib.import_module(module_name) - if module_name in sys.modules: - importlib.reload(sys.modules[module_name]) - else: - raise ModuleNotFoundError(f"Module '{module_name}' not found in sys.modules.") - -def is_second_version_greater( - target_version: str, - current_version: str, -): - """ - Compares two model versions and returns True if the current version is - greater than or equal to the target version. - """ - target_version = packaging_version.parse(target_version) - current_version = packaging_version.parse(current_version) - - return current_version >= target_version - -def is_pkg_version_within_range( - package_version: str, min_version='', max_version='' - ): - package_version_number = packaging_version.parse(package_version) - is_greater_than_min = True - if min_version: - min_version_number = packaging_version.parse(min_version) - is_greater_than_min = package_version_number >= min_version_number - - is_less_than_max = True - if max_version: - max_version_number = packaging_version.parse(max_version) - is_less_than_max = package_version_number <= max_version_number - - return is_greater_than_min and is_less_than_max - - -def check_install_cellpose( - version: Literal['2.0', '3.0', '4.0', 'any'] = '2.0', - version_to_install_if_missing: Literal['2.0', '3.0', '4.0'] = '4.0' - ): - if isinstance(version, int): - version = f'{version}.0' - - check_install_torch() - - if version == 'any': - try: - from cellpose import models - return - except Exception as err: - version = version_to_install_if_missing # after this the version will for sure be a valid format and not 'any' - - is_version_correct = check_cellpose_version(version) - if is_version_correct: - return - - major_version = int(version.split('.')[0]) - - next_version = major_version+1 - - min_version = min_target_versions_cp[str(major_version)] - - check_install_package( - 'cellpose', - max_version=f'{next_version}.0', - min_version=min_version, - include_lower_version=True, - ) - - purge_module('cellpose') - -def check_install_baby(): - check_install_package( - 'TensorFlow', - pypi_name='tensorflow', - import_pkg_name='tensorflow', - max_version='2.14' - ) - check_install_package('baby', pypi_name='baby-seg', import_pkg_name='baby') - -def check_install_nnInteractive(): - check_install_package('huggingface-hub') - check_install_torch() - check_install_package('nnInteractive') - - purge_module('nnInteractive') - - importlib.invalidate_caches() - import nnInteractive - importlib.reload(nnInteractive) - -def check_install_microsam(): - check_install_package( - 'micro-sam', - pypi_name='micro_sam', - installer='conda' - ) - -def check_install_yeaz(): - check_install_torch() - check_install_package('yeaz') - -def check_install_segment_anything(): - check_install_torch() - check_install_package('segment_anything') - -def check_install_sam2(): - check_install_torch() - check_install_package('sam2') - - -def check_install_cellsam(): - check_install_torch() - check_install_package( - 'cellSAM', - pypi_name='git+https://github.com/vanvalenlab/cellSAM.git', - import_pkg_name='cellSAM', - note=( - 'CellSAM requires a DeepCell access token to download models.\n' - 'Set the DEEPCELL_ACCESS_TOKEN environment variable before use.\n' - 'Get your token at: https://deepcell.org' - ) - ) - -def is_gui_running(): - if not GUI_INSTALLED: - return False - - return QCoreApplication.instance() is not None - -def check_pkg_version(import_pkg_name, min_version, include_lower_version, raise_err=True): - is_version_correct = False - try: - installed_version = get_package_version(import_pkg_name) - if include_lower_version: - is_version_correct = ( - packaging_version.parse(installed_version) - >= packaging_version.parse(min_version) - ) - else: - is_version_correct = ( - packaging_version.parse(installed_version) - > packaging_version.parse(min_version) - ) - except Exception as err: - is_version_correct = False - - if raise_err and not is_version_correct: - raise ModuleNotFoundError( - f'{import_pkg_name}>{min_version} not installed.' - ) - else: - return is_version_correct - -def check_pkg_exact_version(import_pkg_name, version: str, raise_err=True): - is_version_correct = False - try: - installed_version = get_package_version(import_pkg_name) - is_version_correct = ( - packaging_version.parse(installed_version) - == packaging_version.parse(version) - ) - except Exception as err: - is_version_correct = False - - if raise_err and not is_version_correct: - raise ModuleNotFoundError( - f'{import_pkg_name}=={version} not installed.' - ) - else: - return is_version_correct - -def check_pkg_max_version( - import_pkg_name, max_version, include_higher_version, raise_err=True - ): - is_version_correct = False - try: - from packaging import version - installed_version = get_package_version(import_pkg_name) - if include_higher_version: - is_version_correct = ( - packaging_version.parse(installed_version) - <= packaging_version.parse(max_version) - ) - else: - is_version_correct = ( - packaging_version.parse(installed_version) - < packaging_version.parse(max_version) - ) - except Exception as err: - is_version_correct = False - - if raise_err and not is_version_correct: - raise ModuleNotFoundError( - f'{import_pkg_name}<={max_version} not installed.' - ) - else: - return is_version_correct - -def install_package_conda(conda_pkg_name, channel='conda-forge'): - if not is_conda_env(): - raise EnvironmentError( - 'Cell-ACDC is not running in a `conda` environment.' - ) - conda_prefix, pip_prefix = get_pip_conda_prefix() - conda_prefix = re.sub( - r'(-c\sconda-forge\s?|--channel=conda-forge\s?)', f'-c {channel} ', - conda_prefix - ) - - command = f'{conda_prefix} -y {conda_pkg_name}' - _subprocess_run_command(command) - -def _subprocess_run_command(command, shell=True, callback='check_call'): - func = getattr(subprocess, callback) - try: - out = func(command, shell=shell) - except Exception as err: - print( - f'[WARNING]: Command `{command}` failed. ' - f'Trying with `{command.split()}`...' - ) - out = func(command.split(), shell=shell) - - return out - -def check_install_omnipose(): - try: - import_module('omnipose') - return - except ModuleNotFoundError: - pass - - try: - check_install_package('omnipose', pypi_name='omnipose_acdc') - except Exception as err: - install_package_conda('mahotas') - _install_pip_package('omnipose-acdc') - -def _run_command(command: str | list[str], shell=False): - if not isinstance(command, (str, list)): - raise TypeError( - f'Command must be a string or a list of strings, not {type(command)}' - ) - - command_str = None - if isinstance(command, str): - args_list = [command] - command_str = command - else: - args_list = command - if len(command) == 1: - command_str = command[0] - - try: - subprocess.check_call(args_list, shell=shell) - return - except Exception as err: - pass - - if command_str is None: - return - - try: - subprocess.check_call(command_str, shell=shell) - return - except Exception as err: - pass - - try: - from . import acdc_regex - args = acdc_regex.RE_SPLIT_SPACES_IGNORE_QUOTES.split(command_str)[1::2] - subprocess.check_call(args, shell=shell) - return - except Exception as err: - pass - -def _warn_dll_torch(qparent=None): - msg = widgets.myMessageBox() - txt = html_utils.paragraph(""" - An error message will occur after you close this message.
    - Please save your data and restart Cell-ACDC.
    - Sorry for the inconvenience!
    - This error is not critical for the main functionality of Cell-ACDC, - and only concerns the segmentation model. Your can save your data without - a problem.
    - The specific reason is that PyTorch and QtPy have weird issues with - DLL conflicts. - """) - msg.information( - qparent, 'Please restart Cell-ACDC', txt, - buttonsTexts=('Ok, I will save my data and restart Cell-ACDC'), - ) - -def check_install_torch(is_cli=False, caller_name='Cell-ACDC', qparent=None): - try: - import torch - import torchvision - return - - except OSError as err: - if 'dll' in str(err): - _warn_dll_torch(qparent=qparent) - raise err - else: - traceback.print_exc() - except Exception as err: - traceback.print_exc() - - if is_cli: - _install_pytorch_cli(caller_name=caller_name) - return - - win = apps.InstallPyTorchDialog(parent=qparent, caller_name=caller_name) - win.exec_() - if win.cancel: - _warnings.log_pytorch_not_installed() - return - - command = win.command - print(f'Running command: "{command}"') - _run_command(command) - - try: - import torch - except OSError as e: - if 'dll' in str(e): - _warn_dll_torch(qparent=qparent) - raise e - - purge_module('torch') - -def check_install_package( - pkg_name: str, - import_pkg_name: str='', - pypi_name='', - note='', - parent=None, - raise_on_cancel=True, - logger_func=print, - is_cli=False, - caller_name='Cell-ACDC', - force_upgrade=False, - upgrade=False, - min_version='', - max_version='', - exact_version='', - install_dependencies=True, - return_outcome=False, - installer: Literal['pip', 'conda']='pip', - include_higher_version: bool = False, - include_lower_version: bool = False - ): - """Try to import a package. If import fails, ask user to install it - automatically. - - Parameters - ---------- - pkg_name : str - The name of the package that is displayed to the user. - import_pkg_name : str, optional - The name of the package as it should be imported (case sensitive). - If empty string, `pkg_name` will be imported instead. Default is '' - pypi_name : str, optional - The name of the package to be installed with pip. - If empty string, `pkg_name` will be installed instead. Default is '' - note : str, optional - Additional text to display to the user. Default is '' - parent : QObject, optional - Calling QtWidget. Default is None - raise_on_cancel : bool, optional - Raise exception if processed cancelled. Default is True - logger_func : callable, optional - Function used to log text. Default is print - is_cli : bool, optional - If True, message will be displayed in the terminal. - If False, message will be displayed in a Qt message box. - Default is False - caller_name : str, optional - Program calling this function. Default is 'Cell-ACDC' - force_upgrade : bool, optional - If True, we force the upgrade even if package is installed. - upgrade : bool, optional - If True, pip will upgrade the package. This value is True if - `force_upgrade` is True. Without min_version and max_version - it will never upgrade or downgrade the package. - min_version : str, optional - If not empty it must be a valid version `major[.minor][.patch]` where - minor and patch are optional. If the installed package is older the - upgrade will be forced. - max_version : str, optional - If not empty it must be a valid version `major[.minor][.patch]` where - minor and patch are optional. If the installed package is newer the - upgrade will be forced. - exact_version : str, optional - If not empty, install this exact version. It must be a valid - `major[.minor][.patch]`. - install_dependencies : bool, optional - If False, the `--no-deps` flag will be added to the pip command. - return_outcome : bool, optional - If True, returns 1 on successfull action - installer : str, optional - Package manager to use to install the package. Either 'pip' or 'conda'. - Default is 'pip' - include_higher_version : bool, optional - If True, if the higher version is installed, it will not be downgraded. - Default is False - include_lower_version : bool, optional - If True, if the lower version is installed, it will not be upgraded. - Default is False - - Raises - ------ - ModuleNotFoundError - Error raised if process is cancelled and `raise_on_cancel=True`. - """ - if not import_pkg_name: - import_pkg_name = pkg_name - - if not is_gui_running(): - is_cli=True - - try: # check_pkg_version and check_pkg_max_version - import_pkg_name = import_pkg_name.replace('-', '_') - import_module(import_pkg_name) - if force_upgrade: - upgrade = True - raise ModuleNotFoundError( - f'User requested to forcefully upgrade the package "{pkg_name}"') - if exact_version: - check_pkg_exact_version(import_pkg_name, exact_version) - if min_version: - check_pkg_version(import_pkg_name, min_version, include_lower_version) - if max_version: - check_pkg_max_version(import_pkg_name, max_version, include_higher_version) - except ModuleNotFoundError: - proceed = _install_package_msg( - pkg_name, - note=note, - parent=parent, - upgrade=upgrade, - is_cli=is_cli, - caller_name=caller_name, - logger_func=logger_func, - pkg_command=pypi_name, - max_version=max_version, - min_version=min_version, - exact_version=exact_version, - installer=installer, - include_higher_version=include_higher_version, - include_lower_version=include_lower_version - ) - if pypi_name: - pkg_name = pypi_name - if not proceed: - if raise_on_cancel: - raise ModuleNotFoundError( - f'User aborted {pkg_name} installation' - ) - else: - return traceback.format_exc() - try: - if pkg_name == 'tensorflow': - _install_tensorflow( - max_version=max_version, min_version=min_version - ) - elif pkg_name == 'deepsea': - _install_deepsea() - elif pkg_name == 'segment_anything': - _install_segment_anything() - elif pkg_name == 'sam2': - _install_sam2() - else: - pkg_command = _get_pkg_command_pip_install( - pkg_name, - exact_version=exact_version, - max_version=max_version, - min_version=min_version, - including_higher_version=include_higher_version, - including_lower_version=include_lower_version, - ) - if installer == 'pip': - _install_pip_package(pkg_command, install_dependencies=install_dependencies) - else: - install_package_conda(pkg_command) - except Exception as e: - printl(traceback.format_exc()) - _inform_install_package_failed( - pkg_name, parent=parent, do_exit=raise_on_cancel - ) - if return_outcome: - return True - -def check_install_custom_dependencies(custom_install_requires, *args, **kwargs): - """Used to install a package with custom dependencies, usefull if they have - random pinned versions for their dependencies. - - For *args and **kwargs see `myutils.check_install_package`. - - Parameters - ---------- - custom_install_requires : list - list of dependencies. Check either requirements.txt, setup.py, - setup.cfg, pyproject.toml, or any other file that lists the dependencies. - For formatting of the dependencies with min max version, - use _get_pkg_command_pip_install. - """ - kwargs['install_dependencies'] = False - kwargs['return_outcome'] = True - success = check_install_package(*args, **kwargs) - if not success: - return - for pkg_name in custom_install_requires: - _install_pip_package(pkg_name) - -def get_chained_attr(_object, _name): - for attr in _name.split('.'): - _object = getattr(_object, attr) - return _object - -def check_matplotlib_version(qparent=None): - mpl_version = get_package_version('matplotlib') - mpl_version_digits = mpl_version.split('.') - - mpl_major = int(mpl_version_digits[0]) - mpl_minor = int(mpl_version_digits[1]) - is_less_than_3_5 = ( - mpl_major < 3 or (mpl_major >= 3 and mpl_minor < 5) - ) - if not is_less_than_3_5: - return - - proceed = _install_package_msg('matplotlib', parent=qparent, upgrade=True) - if not proceed: - raise ModuleNotFoundError( - f'User aborted "matplotlib" installation' - ) - import subprocess - try: - subprocess.check_call( - [sys.executable, '-m', 'pip', 'install', '-U', 'matplotlib'] - ) - except Exception as e: - printl(traceback.format_exc()) - _inform_install_package_failed( - 'matplotlib', parent=qparent, do_exit=False - ) - -def _inform_install_package_failed(pkg_name, parent=None, do_exit=True): - conda_prefix, pip_prefix = get_pip_conda_prefix() - - install_command = f'{pip_prefix} --upgrade {pkg_name}' - txt = html_utils.paragraph(f""" - Unfortunately, installation of {pkg_name} returned an error.

    - Try restarting Cell-ACDC. If it doesn't work, - please close Cell-ACDC and, with the acdc environment ACTIVE, - install {pkg_name} manually using the follwing command:

    - {install_command}

    - Thank you for your patience. - """) - msg = widgets.myMessageBox() - msg.critical(parent, f'{pkg_name} installation failed', txt) - print('*'*50) - print( - f'[ERROR]: Installation of "{pkg_name}" failed. ' - f'Please, close Cell-ACDC and run the command ' - f'{pip_prefix} --upgrade {pkg_name}`' - ) - print('^'*50) - -def download_fiji(logger_func=print): - url = None - if is_mac: - url = 'https://downloads.micron.ox.ac.uk/fiji_update/mirrors/fiji-latest/fiji-macosx.zip' - file_size = 474_525_405 - - if url is None: - return - - if os.path.exists(get_fiji_exec_folderpath()): - return - - os.makedirs(acdc_fiji_path) - - temp_dir = tempfile.mkdtemp() - zip_dst = os.path.join(temp_dir, 'fiji-macosx.zip') - logger_func(f'Downloading Fiji to "{acdc_fiji_path}"...') - download_url( - url, zip_dst, verbose=False, file_size=file_size - ) - extract_zip(zip_dst, acdc_fiji_path) - - return acdc_fiji_path - -def _install_package_msg( - pkg_name, note='', parent=None, upgrade=False, caller_name='Cell-ACDC', - is_cli=False, pkg_command='', logger_func=print, - exact_version='', max_version='', min_version='', - installer: Literal['pip', 'conda']='pip', - include_higher_version: bool = False, - include_lower_version: bool = False - ): - if is_cli: - proceed = _install_package_cli_msg( - pkg_name, note=note, upgrade=upgrade, caller_name=caller_name, - pkg_command=pkg_command, - exact_version=exact_version, - max_version=max_version, - min_version=min_version, logger_func=logger_func, - installer=installer, - include_higher_version=include_higher_version, - include_lower_version=include_lower_version - ) - else: - proceed = _install_package_gui_msg( - pkg_name, note=note, parent=parent, upgrade=upgrade, - caller_name=caller_name, pkg_command=pkg_command, - exact_version=exact_version, - max_version=max_version, min_version=min_version, - logger_func=logger_func, installer=installer, - including_higher_version=include_higher_version, - including_lower_version=include_lower_version - ) - return proceed - -def get_cli_multi_choice_question(question, choices): - choices_format = [f'{i+1}) {choice}.' for i, choice in enumerate(choices)] - choices_format = ' '.join(choices_format) - choices_opts = '/'.join([str(i) for i in range(1, len(choices)+1)]) - text = f'{question} {choices_format} q) Quit. ({choices_opts})?: ' - return text - -def _install_pytorch_cli( - caller_name='Cell-ACDC', action='install', logger_func=print - ): - separator = '-'*60 - txt = ( - f'{separator}\n{caller_name} needs to {action} PyTorch\n\n' - 'You can choose to install it now or stop the process and install it ' - 'later. To install it correctly, we need to know your preferences.\n' - ) - logger_func(txt) - questions = { - 'Choose your OS:': ('Windows', 'Mac', 'Linux'), - 'Package manager:': ('Pip'), - 'Compute platform:': ( - 'CPU', 'CUDA 11.8 (NVIDIA GPU)', 'CUDA 12.1 (NVIDIA GPU)' - ) - } - selected_command = get_pytorch_command() - selected_preferences = [] - for question, choices in questions.items(): - input_txt = get_cli_multi_choice_question(question, choices) - while True: - answer = input(input_txt) - if answer.lower() == 'q': - exit('Execution stopped by the user.') - - try: - idx = int(answer) - 1 - if idx >= len(choices): - raise TypeError('Not a valid answer') - except Exception as err: - print('-'*100) - logger_func( - f'"{answer}" is not a valid answer.' - 'Choose one of the options or "q" to quit.' - ) - print('^'*100) - continue - - preference = choices[idx] - selected_command = selected_command[preference] - selected_preferences.append(preference) - print('') - break - - print('-'*100) - selected_preferences = ', '.join(selected_preferences) - logger_func(f'Selected preferences: {selected_preferences}') - print('-'*100) - logger_func(f'Command:\n\n{selected_command}\n') - while True: - answer = input('Do you want to run the command now ([y]/n)?: ') - if answer.lower() == 'n': - exit('Execution stopped by the user.') - - if answer.lower() == 'y' or not answer: - break - - print('-'*100) - print( - f'"{answer}" is not a valid answer. ' - 'Choose "y" for yes or "n" for no.' - ) - print('^'*100) - - if selected_command.startswith('conda'): - try: - subprocess.check_call([selected_command], shell=True) - except Exception as err: - cmd_list = selected_command.split() - cmd_list = [cmd.strip('"') for cmd in cmd_list] - cmd_list = [cmd.strip("'") for cmd in cmd_list] - cmd_list = [cmd.lstrip(".") for cmd in cmd_list] - subprocess.check_call(cmd_list, shell=True) - else: - cmd_list = selected_command.split()[1:] - cmd_list = [cmd.strip('"') for cmd in cmd_list] - cmd_list = [cmd.strip("'") for cmd in cmd_list] - cmd_list = [cmd.lstrip(".") for cmd in cmd_list] - subprocess.check_call([sys.executable, *cmd_list], shell=True) - -def _get_pkg_command_pip_install( - pkg_command, - exact_version='', - max_version='', - min_version='', - including_lower_version=False, - including_higher_version=False - ): - if exact_version: - pkg_command = f'{pkg_command}=={exact_version}' - return pkg_command - - if including_higher_version: - sign_max = "<=" - else: - sign_max = "<" - if including_lower_version: - sign_min = ">=" - else: - sign_min = ">" - if min_version: - pkg_command = f'{pkg_command}{sign_min}{min_version}' - if max_version: - pkg_command = f'{pkg_command},' - - if max_version: - pkg_command = f'{pkg_command}{sign_max}{max_version}' - - return pkg_command - -def _install_package_cli_msg( - pkg_name, note='', upgrade=False, caller_name='Cell-ACDC', - logger_func=print, pkg_command='', exact_version='', max_version='', - min_version='', installer: Literal['pip', 'conda']='pip', - include_lower_version=False, - include_higher_version=False - ): - if not pkg_command: - pkg_command = pkg_name - - pkg_command = _get_pkg_command_pip_install( - pkg_command, exact_version=exact_version, - max_version=max_version, min_version=min_version, - including_lower_version=include_lower_version, - including_higher_version=include_higher_version - ) - - if upgrade: - action = 'upgrade' - else: - action = 'install' - - conda_prefix, pip_prefix = get_pip_conda_prefix() - - if installer == 'pip': - install_command = f'{pip_prefix} --upgrade {pkg_command}' - elif installer == 'conda': - install_command = f'{conda_prefix} {pkg_command}' - - separator = '-'*60 - txt = ( - f'{separator}\n{caller_name} needs to {action} {pkg_name}\n\n' - 'You can choose to install it now or stop the process and install it ' - 'later with the following command:\n\n' - f'{install_command}\n' - ) - logger_func(txt) - - - - while True: - answer = try_input_install_package(pkg_name, install_command) - if not answer or answer.lower() == 'y': - return True - - if answer.lower() == 'n': - return False - - logger_func( - f'{answer} is not a valid answer. Valid answers are "y" for Yes and ' - '"n" for No.' - ) - -def _install_package_gui_msg( - pkg_name, note='', parent=None, upgrade=False, caller_name='Cell-ACDC', - pkg_command='', logger_func=None, exact_version='', - max_version='', min_version='', - including_lower_version=False, including_higher_version=False, - installer: Literal['pip', 'conda']='pip' - ): - msg = widgets.myMessageBox(parent=parent) - if upgrade: - install_text = 'upgrade' - else: - install_text = 'install' - if pkg_name == 'BayesianTracker': - pkg_name = 'btrack' - - if not pkg_command: - pkg_command = pkg_name - - pkg_command = _get_pkg_command_pip_install( - pkg_command, exact_version=exact_version, - max_version=max_version, min_version=min_version, - including_lower_version=including_lower_version, - including_higher_version=including_higher_version - ) - - conda_prefix, pip_prefix = get_pip_conda_prefix() - - if installer == 'pip': - command = f'{pip_prefix} --upgrade {pkg_command}' - elif installer == 'conda': - command = f'{conda_prefix} {pkg_command}' - - command_html = command.lower().replace('<', '<').replace('>', '>') - - txt = html_utils.paragraph(f""" - {caller_name} is going to download and {install_text} - {pkg_name}.

    - Make sure you have an active internet connection, - before continuing.
    - Progress will be displayed on the terminal

    - You might have to restart {caller_name}.

    - Alternatively, you can cancel the process and try later.

    - To install later, or if the installation fails, run the following - command: - """) - if note: - txt = f'{txt}{note}' - _, okButton = msg.information( - parent, f'Install {pkg_name}', txt, - buttonsTexts=('Cancel', 'Ok'), - commands=(command_html,) - ) - return msg.clickedButton == okButton - -def _install_tensorflow(max_version='', min_version=''): - cpu = platform.processor() - pkg_command = _get_pkg_command_pip_install( - 'tensorflow', - max_version=max_version, - min_version=min_version - ) - conda_prefix, pip_prefix = get_pip_conda_prefix() - - if is_mac and cpu == 'arm': - args = [f'{conda_prefix} "{pkg_command}"'] - shell = True - else: - args = [sys.executable, '-m', 'pip', 'install', '-U', pkg_command] - shell = False - subprocess.check_call(args, shell=shell) - - # purge numpy - purge_module('numpy') - -def _install_segment_anything(): - args = [ - sys.executable, '-m', 'pip', 'install', - '-U', '--use-pep517', - 'git+https://github.com/facebookresearch/segment-anything.git' - ] - subprocess.check_call(args) - -def _install_sam2(): - args = [ - sys.executable, '-m', 'pip', 'install', - '-U', '--use-pep517', - 'git+https://github.com/facebookresearch/sam2.git' - ] - subprocess.check_call(args) - -def _install_deepsea(): - subprocess.check_call( - [sys.executable, '-m', 'pip', 'install', 'deepsea'] - ) - -def import_tracker_module(tracker_name): - module_name = f'cellacdc.trackers.{tracker_name}.{tracker_name}_tracker' - tracker_module = import_module(module_name) - return tracker_module - -def download_ffmpeg(): - ffmpeg_folderpath = acdc_ffmpeg_path - if is_win: - url = 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/rXioWZpwjwn9JTT/download/windows_ffmpeg-7.0-full_build.zip' - file_size = 173477888 - ffmep_exec_path = os.path.join(ffmpeg_folderpath, 'bin', 'ffmpeg.exe') - elif is_mac: - url = 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/We7rcTLzqAP4zf7/download/mac_ffmpeg.zip' - file_size = 25288704 - ffmep_exec_path = os.path.join(ffmpeg_folderpath, 'ffmpeg') - elif is_linux: - ffmep_exec_path = '' - return ffmep_exec_path - - if os.path.exists(ffmep_exec_path): - return ffmep_exec_path.replace('\\', os.sep).replace('/', os.sep) - - print('Downloading FFMPEG...') - temp_dir = tempfile.mkdtemp() - temp_zip_path = os.path.join(temp_dir, 'acdc-ffmpeg.zip') - - download_url( - url, temp_zip_path, verbose=True, file_size=file_size, - ) - extract_zip(temp_zip_path, ffmpeg_folderpath) - - return ffmep_exec_path.replace('\\', os.sep).replace('/', os.sep) - -def get_fiji_binary_filepath_mac(fiji_app_filepath): - if not is_mac: - return '' - - fiji_binary_path = os.path.join( - fiji_app_filepath, 'Contents', 'MacOS', 'ImageJ-macosx' - ) - if os.path.exists(fiji_binary_path): - return fiji_binary_path - - fiji_binary_path = os.path.join( - fiji_app_filepath, 'Contents', 'MacOS', 'fiji-macos' - ) - if os.path.exists(fiji_binary_path): - return fiji_binary_path - - return '' - -def get_fiji_exec_folderpath() -> str: - if not is_mac: - return '' - - from cellacdc import fiji_location_filepath - - if os.path.exists(fiji_location_filepath): - with open(fiji_location_filepath, 'r') as txt: - fiji_app_filepath = txt.read() - - return get_fiji_binary_filepath_mac(fiji_app_filepath) - - if os.path.exists('/Applications/Fiji.app'): - return get_fiji_binary_filepath_mac('/Applications/Fiji.app') - - acdc_fiji_app_path = os.path.join(acdc_fiji_path, 'Fiji.app') - acdc_fiji_binary_path = get_fiji_binary_filepath_mac(acdc_fiji_app_path) - - return acdc_fiji_binary_path - -def get_fiji_base_command(): - command = None - if is_mac: - command = get_fiji_exec_folderpath() - - return command - -def _init_fiji_cli(): - if is_win: - return True - - fiji_app_folderpath = get_fiji_exec_folderpath() - args_add_to_path = [f'chmod 755 {fiji_app_folderpath}'] - try: - subprocess.check_call(args_add_to_path, shell=True) - return True - except Exception as e: - printl(f'Error occurred while setting permissions: {e}') - return False - -def test_fiji_base_command(logger_func=print): - base_command = get_fiji_base_command() - - if base_command is None: - logger_func('[WARNING]: Fiji is not present.') - return False - - command = f'{base_command} --headless' - return run_fiji_command(command=command, logger_func=logger_func) - -def run_fiji_command(command=None, logger_func=print): - if command is None: - command = f'{get_fiji_base_command()} --headless' - - init_success = _init_fiji_cli() - if not init_success: - return False - - separator = '-'*100 - commands = (command, command.split()) - for args in commands: - logger_func( - f'{separator}\n' - f'Trying Fiji command: "{args}"...\n' - f'{separator}\n' - ) - try: - subprocess.check_call(args, shell=True) - return True - except Exception as err: - continue - return False - -def import_promptable_segment_module(model_name): - try: - acdcPromptSegment = import_module( - f'cellacdc.promptable_models.{model_name}.acdcPromptSegment' - ) - except ModuleNotFoundError as e: - # Check if custom model - cp = config.ConfigParser() - cp.read(promptable_models_list_file_path) - model_path = cp[model_name]['path'] - spec = importlib.util.spec_from_file_location( - 'acdcPromptSegment', model_path - ) - acdcPromptSegment = importlib.util.module_from_spec(spec) - sys.modules['acdcPromptSegment'] = acdcPromptSegment - spec.loader.exec_module(acdcPromptSegment) - return acdcPromptSegment - -def init_tracker( - posData, trackerName, realTime=False, qparent=None, - return_init_params=False - ): - from . import apps - downloadWin = apps.downloadModel(trackerName, parent=qparent) - downloadWin.download() - - trackerModule = import_tracker_module(trackerName) - init_params = {} - track_params = {} - paramsWin = None - if trackerName == 'BayesianTracker': - Y, X = posData.img_data_shape[-2:] - if posData.isSegm3D: - labShape = (posData.SizeZ, Y, X) - else: - labShape = (1, Y, X) - paramsWin = apps.BayesianTrackerParamsWin( - labShape, parent=qparent, channels=posData.chNames, - currentChannelName=posData.user_ch_name - ) - paramsWin.exec_() - if not paramsWin.cancel: - init_params = paramsWin.params - track_params['export_to'] = posData.get_btrack_export_path() - if paramsWin.intensityImageChannel is not None: - chName = paramsWin.intensityImageChannel - track_params['image'] = posData.loadChannelData(chName) - track_params['image_channel_name'] = chName - elif trackerName == 'CellACDC': - paramsWin = apps.CellACDCTrackerParamsWin(parent=qparent) - paramsWin.exec_() - if not paramsWin.cancel: - init_params = paramsWin.params - elif trackerName == 'delta': - paramsWin = apps.DeltaTrackerParamsWin(posData=posData, parent=qparent) - paramsWin.exec_() - if not paramsWin.cancel: - init_params = paramsWin.params - else: - init_argspecs, track_argspecs = getTrackerArgSpec( - trackerModule, realTime=realTime - ) - intensityImgRequiredForTracker = isIntensityImgRequiredForTracker( - trackerModule - ) - if init_argspecs or track_argspecs: - try: - url = trackerModule.url_help() - except AttributeError: - url = None - try: - channels = posData.chNames - except Exception as e: - channels = None - try: - currentChannelName = posData.user_ch_name - except Exception as e: - currentChannelName = None - try: - df_metadata = posData.metadata_df - except Exception as e: - df_metadata = None - - if not intensityImgRequiredForTracker: - currentChannelName = None - - paramsWin = apps.QDialogModelParams( - init_argspecs, track_argspecs, trackerName, url=url, - channels=channels, is_tracker=True, - currentChannelName=currentChannelName, - df_metadata=df_metadata, posData=posData - ) - if not intensityImgRequiredForTracker and channels is not None: - paramsWin.channelCombobox.setDisabled(True) - - paramsWin.exec_() - if not paramsWin.cancel: - init_params = paramsWin.init_kwargs - track_params = paramsWin.model_kwargs - if paramsWin.inputChannelName != 'None': - chName = paramsWin.inputChannelName - track_params['image'] = posData.loadChannelData(chName) - track_params['image_channel_name'] = chName - if 'export_to_extension' in track_params: - ext = track_params['export_to_extension'] - track_params['export_to'] = posData.get_tracker_export_path( - trackerName, ext - ) - - if paramsWin is not None and paramsWin.cancel: - tracker = None, - track_params = None - init_params = None - else: - tracker = trackerModule.tracker(**init_params) - - if return_init_params: - return tracker, track_params, init_params - else: - return tracker, track_params - -def import_segment_module(model_name): - try: - acdcSegment = import_module(f'cellacdc.models.{model_name}.acdcSegment') - except ModuleNotFoundError as e: - # Check if custom model - cp = config.ConfigParser() - cp.read(models_list_file_path) - model_path = cp[model_name]['path'] - spec = importlib.util.spec_from_file_location('acdcSegment', model_path) - acdcSegment = importlib.util.module_from_spec(spec) - sys.modules['acdcSegment'] = acdcSegment - spec.loader.exec_module(acdcSegment) - return acdcSegment - -def get_pip_conda_prefix(list_return=False): - from .config import parser_args - try: - cp = parser_args - if cp["install_details"] is not None: - no_cli_install = True - install_details = cp["install_details"] - venv_path = install_details["venv_path"] - conda_path = install_details["conda_path"] - if ' ' not in conda_path: - conda_path = conda_path.strip('"').strip("'") - else: - no_cli_install = False - except: - no_cli_install = False - pass - - if no_cli_install: - conda_prefix = f'{conda_path} install -y -p {venv_path} -c conda-forge' - exec_path = sys.executable - if ' ' in exec_path: - exec_path = f'"{exec_path}"' - pip_prefix = f"{exec_path} -m pip install" - else: - conda_prefix = 'conda install -y -c conda-forge' - pip_prefix = 'pip install' - - pip_list = [sys.executable, '-m', 'pip', 'install'] - if no_cli_install: - conda_list = [conda_path.strip('"').strip("'"), 'install', '-y', '-p', venv_path.strip('"').strip("'"), '-c', 'conda-forge'] - else: - conda_list = ['conda', 'install', '-y', '-c', 'conda-forge'] - if list_return: - return conda_list, pip_list - else: - return conda_prefix, pip_prefix - - -def _warn_install_gpu(model_name, ask_installs, qparent=None): - - cellpose_cuda_url = ( - r'https://github.com/mouseland/cellpose#gpu-version-cuda-on-windows-or-linux' - ) - torch_cuda_url = ( - r'https://pytorch.org/get-started/locally/' - ) - direct_ml_url = ( - r'https://microsoft.github.io/DirectML/' - ) - torch_directml_url = ( - r'https://learn.microsoft.com/en-us/windows/ai/directml/pytorch-windows' - ) - - - cellpose_href = f'{html_utils.href_tag("here", cellpose_cuda_url)}' - torch_href = f'{html_utils.href_tag("here", torch_cuda_url)}' - direct_ml_href = f'{html_utils.href_tag("direct_ml_DirectMLref", direct_ml_url)}' - torch_directml_href = f'{html_utils.href_tag("directml pytorch", torch_directml_url)}' - - conda_prefix, pip_prefix = get_pip_conda_prefix() - - msg = widgets.myMessageBox(showCentered=False, wrapText=False) - txt = html_utils.paragraph(f""" - In order to use {model_name} with the GPU you need - to install a PyTorch version which can use it.
    - We recomment using CUDA over DirectML, but if you are using a Windows - machine with an AMD GPU, you can use DirectML.
    - """) - txt_cuda_title = html_utils.paragraph(f"CUDA", font_size='18px') - - pip_prefix = pip_prefix.replace('install -y', 'uninstall') - txt_cuda = html_utils.paragraph(f""" - Check out these instructions {cellpose_href}, and {torch_href}.
    - First, uninstall the CPU version of PyTorch with the following command: - {pip_prefix} uninstall torch -
    Then, install the CUDA version required by your GPU with the following - command (in this case 12.8): - {pip_prefix} torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 -
    - """) - - add_info = html_utils.to_admonition( - f""" - Pleae use the following table to find the correct link for the command. - You can check the highest CUDA
    version supported on your system with the - command nvidia-smi in the terminal.
    - - {html_utils.table_style_header} - - CUDA Version - PyTorch Installation Link - - - CUDA 11.8 - https://download.pytorch.org/whl/cu118 - - - CUDA 12.6 - https://download.pytorch.org/whl/cu126 - - - CUDA 12.8 - https://download.pytorch.org/whl/cu128 - - - """, - "info" - ) - - txt_cuda = f'{txt_cuda}{add_info}' - - txt_directML_title = html_utils.paragraph(f"DirectML", font_size='18px') - txt_directML = html_utils.paragraph(f""" - Check out {direct_ml_href}, and {torch_directml_href} for more info.
    - Only supported on Windows 10/11 with Python 3.8-3.12.
    - Click the Install DirectML button to install DirectML. -

    - """) - - txt_end = html_utils.paragraph(f""" - How do you want to proceed? - """) - - stopButton = widgets.cancelPushButton('Stop the process') - directMLButton = widgets.okPushButton('Install DirectML') - proceedButton = widgets.okPushButton('Proceed without GPU') - - buttons = [stopButton] - - if 'cuda' in ask_installs: - txt = f'{txt}{txt_cuda_title}{txt_cuda}' - if 'directML' in ask_installs: - txt = f'{txt}{txt_directML_title}{txt_directML}' - buttons.append(directMLButton) - txt = f'{txt}{txt_end}' - buttons.append(proceedButton) - - msg.warning( - qparent, 'PyTorch GPU version not installed', txt, - buttonsTexts=buttons, - ) - - if msg.cancel: - return False, False - - if msg.clickedButton == directMLButton: - py_ver = sys.version_info - if is_win and py_ver.major == 3 and py_ver.minor < 13: - success = check_install_package( - pkg_name = 'torch-directml', - import_pkg_name = 'torch_directml', - pypi_name = 'torch-directml', - return_outcome=True, - ) - purge_module('torch') - return success, True - else: - msg = widgets.myMessageBox() - msg.warning( - qparent, 'DirectML not supported', - 'DirectML is only supported on Python 3.8-3.12 and Windows 10/11', - ) - return False, False - - if msg.clickedButton == stopButton: - return False, False - - if msg.clickedButton == proceedButton: - return True, False - -def check_gpu_requested_segm_model(init_kwargs): - gpu = init_kwargs.get('gpu', False) - if gpu: - return True - - device_type = init_kwargs.get('device_type', 'cpu') - return device_type == 'gpu' or device_type == '' - -def check_gpu_available( - model_name, use_gpu, - do_not_warn=False, - qparent=None, - cuda=False, - directML=False, - return_available_gpu_type=False - ): - if not use_gpu: - if return_available_gpu_type: - return True, [] - else: - return True - - ask_for_cuda = False - if cuda: - try: - import torch - if not torch.cuda.is_available(): - ask_for_cuda = True - if not torch.cuda.device_count() > 0: - ask_for_cuda = True - except ModuleNotFoundError: - ask_for_cuda = True - - ask_for_directML = False - if directML: - if is_win: - try: - import torch_directml - if not torch_directml.is_available(): - ask_for_directML = True - except ModuleNotFoundError: - ask_for_directML = True - - frameworks = _available_frameworks(model_name) - ask_installs = set() if not ask_for_cuda else {'cuda'} - ask_installs.update( - {'directML'} if ask_for_directML else set() - ) - framework_available = False - available_frameworks_list = [] - for framework, model_compatible in frameworks.items(): - if not model_compatible: - continue - if framework == 'cuda': - import torch - if not torch.cuda.is_available(): - ask_installs.add('cuda') - elif not torch.cuda.device_count() > 0: - ask_installs.add('cuda') - else: - framework_available = True - available_frameworks_list.append('cuda') - elif framework == 'directML': - if is_win: - try: - import torch_directml - if not torch_directml.is_available(): - ask_installs.add('directML') - else: - framework_available = True - available_frameworks_list.append('directML') - except ModuleNotFoundError: - ask_installs.add('directML') - elif is_mac_arm64: - framework_available = True - break - - if framework_available and not ask_for_cuda and not ask_for_directML: - if return_available_gpu_type: - return True, available_frameworks_list - else: - return True - - elif do_not_warn: - if return_available_gpu_type: - return False, available_frameworks_list - else: - return False - - proceed, directML_installed = _warn_install_gpu(model_name, ask_installs, qparent=qparent) - if return_available_gpu_type: - if directML_installed: - available_frameworks_list.append('directML') - return proceed, available_frameworks_list - else: - return proceed - - -def _available_frameworks(model_name): - frameworks = { - - "cuda":( - model_name.lower().find('cellpose') != -1 - or model_name.lower().find('omnipose') != -1 - or model_name.lower().find('deepsea') != -1 - or model_name.lower().find('segment_anything') != -1 - or model_name.lower().find('sam2') != -1 - or model_name.lower().find('yeaz') != -1 - or model_name.lower().find('yeaz_v2') != -1 - ), - "directML":( - model_name.lower().find('cellpose_v4') != -1 - or model_name.lower().find('cellpose_v3') != -1# has its own way to check - - ) - } - return frameworks - -def find_missing_integers(lst, max_range=None): - if max_range is not None: - max_range = lst[-1]+1 - return [x for x in range(lst[0], max_range) if x not in lst] - -def synthetic_image_geneator(size=(512,512), f_x=1, f_y=1): - Y, X = size - x = np.linspace(0, 10, Y) - y = np.linspace(0, 10, X) - xx, yy = np.meshgrid(x, y) - img = np.sin(f_x*xx)*np.cos(f_y*yy) - return img - -def get_show_in_file_manager_text(): - if is_mac: - return 'Reveal in Finder' - elif is_linux: - return 'Show in File Manager' - elif is_win: - return 'Show in File Explorer' - -def get_slices_local_into_global_arr(bbox_coords, global_shape): - slice_global_to_local = [] - slice_crop_local = [] - for (_min, _max), _D in zip(bbox_coords, global_shape): - _min_crop, _max_crop = None, None - if _min < 0: - _min_crop = abs(_min) - _min = 0 - if _max > _D: - _max_crop = _D - _max - _max = _D - - slice_global_to_local.append(slice(_min, _max)) - slice_crop_local.append(slice(_min_crop, _max_crop)) - - return tuple(slice_global_to_local), tuple(slice_crop_local) - -def get_pip_install_cellacdc_version_command(version=None): - conda_prefix, pip_prefix = get_pip_conda_prefix() - - if version is None: - version = read_version() - commit_hash_idx = version.find('+g') - is_dev_version = commit_hash_idx > 0 - if is_dev_version: - commit_hash = version[commit_hash_idx+2:].split('.')[0] - command = f'{pip_prefix} --upgrade "git+{github_home_url}.git@{commit_hash}"' - command_github = None - else: - command = f'{pip_prefix} --upgrade cellacdc=={version}' - command_github = f'{pip_prefix} --upgrade "git+{urls.github_url}@{version}"' - return command, command_github - -def get_git_pull_checkout_cellacdc_version_commands(version=None): - if version is None: - version = read_version() - commit_hash_idx = version.find('+g') - is_dev_version = commit_hash_idx > 0 - if not is_dev_version: - return [] - commit_hash = version[commit_hash_idx+2:].split('.')[0] - commands = ( - f'cd "{os.path.dirname(cellacdc_path)}"', - 'git pull', - f'git checkout {commit_hash}' - ) - return commands - -def check_install_tapir(): - check_install_package( - 'tapnet', pypi_name='git+https://github.com/ElpadoCan/TAPIR.git' - ) - -def _download_tapir_model(): - urls, file_sizes = _model_url('TAPIR') - temp_model_path = tempfile.mkdtemp() - _, final_model_path = ( - get_model_path('TAPIR', create_temp_dir=False) - ) - for url, file_size in zip(urls, file_sizes): - filename = url.split('/')[-1] - final_dst = os.path.join(final_model_path, filename) - if os.path.exists(final_dst): - continue - - temp_dst = os.path.join(temp_model_path, filename) - download_url( - url, temp_dst, file_size=file_size, desc='TAPIR', - verbose=False - ) - - shutil.move(temp_dst, final_dst) - -def _download_yeaz_models(): - urls, file_sizes = _model_url('YeaZ_v2') - temp_model_path = tempfile.mkdtemp() - _, final_model_path = ( - get_model_path('YeaZ_v2', create_temp_dir=False) - ) - for url, file_size in zip(urls, file_sizes): - filename = url.split('/')[-1] - final_dst = os.path.join(final_model_path, filename) - if os.path.exists(final_dst): - continue - - temp_dst = os.path.join(temp_model_path, filename) - download_url( - url, temp_dst, file_size=file_size, desc='YeaZ_v2', - verbose=False - ) - - shutil.move(temp_dst, final_dst) - -def _download_cellpose_germlineNuclei_model(): - urls, file_sizes = _model_url('Cellpose_germlineNuclei') - temp_model_path = tempfile.mkdtemp() - _, final_model_path = ( - get_model_path('Cellpose_germlineNuclei', create_temp_dir=False) - ) - for url, file_size in zip(urls, file_sizes): - filename = url.split('/')[-1] - final_dst = os.path.join(final_model_path, filename) - if os.path.exists(final_dst): - continue - - temp_dst = os.path.join(temp_model_path, filename) - download_url( - url, temp_dst, file_size=file_size, desc='Cellpose_germlineNuclei', - verbose=False - ) - - shutil.move(temp_dst, final_dst) - -def _download_omnipose_models(): - urls, file_sizes = _model_url('omnipose') - temp_model_path = tempfile.mkdtemp() - final_model_path = os.path.expanduser(r'~\.cellpose\models') - for url, file_size in zip(urls, file_sizes): - filename = url.split('/')[-1] - final_dst = os.path.join(final_model_path, filename) - if os.path.exists(final_dst): - continue - - temp_dst = os.path.join(temp_model_path, filename) - download_url( - url, temp_dst, file_size=file_size, desc='omnipose', - verbose=False - ) - - shutil.move(temp_dst, final_dst) - -def format_cca_manual_changes(changes: dict): - txt = '' - for ID, changes_ID in changes.items(): - txt = f'{txt}* ID {ID}:\n' - for col, (old_val, new_val) in changes_ID.items(): - txt = f'{txt} - {col}: {old_val} --> {new_val}\n' - txt = f'{txt}--------------------------------\n\n' - return txt - -def init_prompt_segm_model(acdcPromptSegment, posData, init_kwargs): - model = acdcPromptSegment.Model(**init_kwargs) - return model - -def init_segm_model(acdcSegment, posData, init_kwargs): - segm_endname = init_kwargs.pop('segm_endname', 'None') - if segm_endname != 'None': - load_segm = True - if not hasattr(posData, 'segm_data'): - load_segm = True - elif posData.segm_npz_path.endswith(f'{segm_endname}.npz'): - load_segm = False - if not load_segm: - segm_data = np.squeeze(posData.segm_data) - else: - segm_filepath, _ = load.get_path_from_endname( - segm_endname, posData.images_path - ) - printl(f'Loading segmentation data from "{segm_filepath}"...') - segm_data = np.load(segm_filepath)['arr_0'] - else: - segm_data = None - - # Initialize input_points_df for models promptable with points - input_points_filepath = init_kwargs.pop('input_points_path', '') - if input_points_filepath: - input_points_df = init_input_points_df( - posData, input_points_filepath - ) - init_kwargs['input_points_df'] = input_points_df - - try: - # Models introduced before 1.3.2 do not have the segm_data as input - kwargs = inspect.getfullargspec(acdcSegment.Model.__init__).args - if 'is_rgb' not in kwargs and 'is_rgb' in init_kwargs: - del init_kwargs['is_rgb'] - model = acdcSegment.Model(**init_kwargs) - - - except Exception as e: - model = acdcSegment.Model(segm_data, **init_kwargs) - - if hasattr(model, 'init_successful'): - if not model.init_successful: - return None - return model - -def _parse_bool_str(value): - if isinstance(value, bool): - return value - - if value == 'True': - return True - elif value == 'False': - return False - -def check_install_trackastra(): - check_install_package( - 'Trackastra', - import_pkg_name='trackastra', - pypi_name='trackastra' - ) - -def get_torch_device(gpu=False): - import torch - if torch.cuda.is_available() and gpu: - device = torch.device('cuda') - elif torch.backends.mps.is_available(): - device = torch.device('mps') - else: - device = torch.device('cpu') - return device - -def parse_model_params(model_argspecs, model_params): - parsed_model_params = {} - for row, argspec in enumerate(model_argspecs): - value = model_params.get(argspec.name) - if value is None: - continue - if argspec.type == bool: - value = _parse_bool_str(value) - elif argspec.type == int: - value = int(value) - elif argspec.type == float: - value = float(value) - parsed_model_params[argspec.name] = value - return parsed_model_params - -# def init_cellpose_denoise_model(): -# from . import apps - -# from cellacdc.models.cellpose_v3._denoise import ( -# CellposeDenoiseModel, url_help -# ) - -# init_argspecs, run_argspecs = getClassArgSpecs(CellposeDenoiseModel) -# url = url_help() - -# paramsWin = apps.QDialogModelParams( -# init_argspecs, run_argspecs, 'Cellpose 3.0', -# url=url, is_tracker=True, action_type='denoising' -# ) -# paramsWin.exec_() -# if paramsWin.cancel: -# return - -# init_params = paramsWin.init_kwargs -# run_params = paramsWin.model_kwargs -# denoise_model = CellposeDenoiseModel(**init_params) -# return denoise_model, init_params, run_params - -def init_input_points_df(posData, input_points_filepath): - input_points_df = None - if os.path.exists(input_points_filepath): - input_points_df = pd.read_csv(input_points_filepath) - else: - # input_points_filepath is actually and endname - for file in listdir(posData.images_path): - if file.endswith(input_points_filepath): - filepath = os.path.join(posData.images_path, file) - input_points_df = pd.read_csv(filepath) - break - - if input_points_df is None: - raise FileNotFoundError( - f'Could not find input points table from file "input_points_filepath" ' - 'Perhaps, you forgot to save the table?' - ) - - for col in ('x', 'y', 'id'): - if col not in input_points_df.columns: - raise KeyError( - f'Input points table is missing colum {col}. It must have ' - 'the colums (x, y, id)' - ) - - return input_points_df - -def are_acdc_dfs_equal(df_left, df_right): - if df_left.shape != df_right.shape: - return False - - try: - for col in df_left.columns: - if col not in df_right.columns: - return False - - try: - eq_mask = np.isclose(df_left[col], df_right[col], equal_nan=True) - except Exception as err: - # Data type is string - eq_mask = df_left[col] == df_right[col] - - nan_mask = ((df_left[col].isna()) & (df_right[col].isna())) - equality_mask = (eq_mask) | (nan_mask) - if not equality_mask.all(): - return False - except Exception as err: - return False - - return True - -def is_pos_folderpath(folderpath): - """Determine if a path is a valid Cell-ACDC Position folder - - Parameters - ---------- - folderpath : PathLike - Path to check - - Returns - ------- - bool - True if the path is a valid Cell-ACDC Position folder, False otherwise - - Notes - ----- - A valid Cell-ACDC Position folder must: - - Have a name matching the pattern 'Position_' - - Be a directory - - Contain an 'Images' subdirectory - - The 'Images' subdirectory must not be empty - """ - foldername = os.path.basename(folderpath) - is_valid_pos_folder = ( - re.search(r'^Position_(\d+)$', foldername) is not None - and os.path.isdir(folderpath) - and os.path.exists(os.path.join(folderpath, 'Images')) - and listdir(os.path.join(folderpath, 'Images')) - ) - return is_valid_pos_folder - -def log_segm_params( - model_name, init_params, segm_params, logger_func=print, - preproc_recipe=None, apply_post_process=False, - standard_postprocess_kwargs=None, custom_postprocess_features=None - ): - init_params_format = [ - f' * {option} = {value}' for option, value in init_params.items() - ] - init_params_format = '\n'.join(init_params_format) - - segm_params_format = [ - f' * {option} = {value}' for option, value in segm_params.items() - ] - segm_params_format = '\n'.join(segm_params_format) - - preproc_recipe_format = None - if preproc_recipe is not None: - preproc_recipe_format = [] - for s, step in enumerate(preproc_recipe): - preproc_recipe_format.append(f' * Step {s+1}') - method = step['method'] - preproc_recipe_format.append(f' - Method: {method}') - for option, value in step['kwargs'].items(): - preproc_recipe_format.append(f' - {option}: {value}') - preproc_recipe_format = '\n'.join(preproc_recipe_format) - - standard_postproc_format = None - if apply_post_process and standard_postprocess_kwargs is not None: - standard_postproc_format = [ - f' * {option} = {value}' - for option, value in standard_postprocess_kwargs.items() - ] - standard_postproc_format = '\n'.join(standard_postproc_format) - - custom_postproc_format = None - if apply_post_process and custom_postprocess_features is not None: - custom_postproc_format = [ - f' * {feature} = ({low}, {high})' - for feature, (low, high) in custom_postprocess_features.items() - ] - custom_postproc_format = '\n'.join(custom_postproc_format) - - separator = '-'*100 - params_format = ( - f'{separator}\n' - f'Model name: {model_name}\n\n' - 'Preprocessing recipe:\n\n' - f'{preproc_recipe_format}\n\n' - 'Initialization parameters:\n\n' - f'{init_params_format}\n\n' - 'Segmentation parameters:\n\n' - f'{segm_params_format}\n\n' - 'Post-processing:\n\n' - f'{standard_postproc_format}\n\n' - 'Custom post-processing:\n\n' - f'{custom_postproc_format}\n' - f'{separator}' - ) - logger_func(params_format) - -def pairwise(iterable): - # pairwise('ABCDEFG') → AB BC CD DE EF FG - iterator = iter(iterable) - a = next(iterator, None) - for b in iterator: - yield a, b - a = b - -def append_text_filename(filename: str, text_to_append: str): - filename_noext, ext = os.path.splitext(filename) - filename_out = f'{filename_noext}{text_to_append}{ext}' - return filename_out - -def validate_images_path(input_path: os.PathLike, create_dirs_tree=False): - is_images_path = input_path.endswith('Images') - parent_dir = os.path.dirname(input_path) - parent_foldername = os.path.basename(parent_dir) - is_pos_folder = ( - re.search(r'^Position_(\d+)$', parent_foldername) is not None - and os.path.isdir(parent_dir) - ) - if not is_pos_folder: - existing_pos_foldernames = get_pos_foldernames(input_path) - pos_n = len(existing_pos_foldernames) + 1 - pos_folderpath = os.path.join(input_path, f'Position_{pos_n}') - images_path = os.path.join(pos_folderpath, 'Images') - elif is_images_path: - pos_folderpath = input_path - images_path = os.path.join(pos_folderpath, 'Images') - else: - images_path = input_path - - if create_dirs_tree: - os.makedirs(images_path, exist_ok=True) - - return images_path - -def fix_acdc_df_dtypes(acdc_df): - acdc_df['is_cell_excluded'] = acdc_df['is_cell_excluded'].astype(bool) - return acdc_df - -def _relabel_cca_dfs_and_segm_data( - cca_dfs, - IDs_mapper, - asymm_tracked_segm, - progressbar=True, - ): - # Rename Cell_ID index according to asymmetric cell div convention - if progressbar: - pbar = tqdm( - desc='Applying asymmetric division', - total=len(IDs_mapper), ncols=100 - ) - for key, (root_ID, parent_ID) in IDs_mapper.items(): - div_frame_i, daughter_ID = key - for frame_i in range(div_frame_i, len(asymm_tracked_segm)): - - - lab = asymm_tracked_segm[frame_i] - rp = skimage.measure.regionprops(lab) - rp_mapper = {obj.label: obj for obj in rp} - obj_daught = rp_mapper.get(daughter_ID) - mother_ID = root_ID if rp_mapper.get(root_ID) is None else parent_ID - - cca_dfs[frame_i].rename( - index={daughter_ID: mother_ID}, inplace=True - ) - - if obj_daught is None: - continue - - lab[obj_daught.slice][obj_daught.image] = mother_ID - - if progressbar: - pbar.update() - - if progressbar: - pbar.close() - -def df_ctc_to_acdc_df( - df_ctc, tracked_segm, cell_division_mode='Normal', return_list=False, - progressbar=True - ): - """Convert Cell Tracking Challenge DataFrame with annotated division to - Cell-ACDC cell cycle annotations DataFrame. - - Parameters - ---------- - df_ctc : pd.DataFrame - DataFrame with {'label', 't1', 't2', 'parent'} columns where - 't1' is the frame index of cell division. - tracked_segm : (T, Y, X) array of ints - Array of tracked segmentation labels. - cell_division_mode : {'Normal', 'Asymmetric'}, optional - Type of cell division. `Normal` is the standard cell division, - where the mother cell divides into two daughter cells. For the - tracking, that means the two daughter cells get a new, unique ID - each. - - `Asymmetric` means that the mother cell grows one daughter - cell that eventually divides from the mother (e.g., budding yeast). - For the tracking, this means that the mother cell ID keeps - existing after division and the daughter cell gets a new, unique ID. - - If `Asymmetric`, the third returned element is the segmentation data - with the asymmetric Cell IDs. - return_list : bool, optional - If `True`, the second returned element is the list of created dataframes, - one per frame. Default is False - progressbar : bool, optional - If `True`, displays a tqdm progressbar. Default is True - """ - cca_dfs = [] - keys = [] - df_ctc = df_ctc.set_index(['t1', 'parent']) - - if cell_division_mode == 'Asymmetric': - asymm_tracked_segm = tracked_segm.copy() - - asymmetric_IDs_rename_mapper = {} - if progressbar: - pbar = tqdm( - desc='Converting to Cell-ACDC format', - total=len(tracked_segm), ncols=100 - ) - for frame_i, lab in enumerate(tracked_segm): - rp = skimage.measure.regionprops(lab) - IDs = [obj.label for obj in rp] - cca_df = core.getBaseCca_df(IDs, with_tree_cols=True) - keys.append(frame_i) - if frame_i == 0: - cca_dfs.append(cca_df) - if progressbar: - pbar.update() - continue - - # Copy annotations from previous frames - prev_cca_df = cca_dfs[frame_i-1] - old_IDs = cca_df.index.intersection(prev_cca_df.index) - cca_df.loc[old_IDs] = prev_cca_df.loc[old_IDs] - - try: - df_ctc_i = df_ctc.loc[frame_i] - except KeyError as err: - # No division detected --> nothing to annotate - cca_dfs.append(cca_df) - if progressbar: - pbar.update() - continue - - for parent_ID, df_ctc_i_pID in df_ctc_i.groupby(level=0): - daughter_IDs = df_ctc_i_pID['label'].to_list() - - if parent_ID == 0: - continue - - cca_df.loc[daughter_IDs, 'parent_ID_tree'] = parent_ID - cca_df.loc[daughter_IDs, 'emerg_frame_i'] = frame_i - cca_df.loc[daughter_IDs, 'division_frame_i'] = frame_i - - root_ID = prev_cca_df.at[parent_ID, 'root_ID_tree'] - if root_ID == -1: - root_ID = parent_ID - cca_df.loc[daughter_IDs, 'root_ID_tree'] = root_ID - - cca_df.loc[daughter_IDs[0], 'sister_ID_tree'] = daughter_IDs[1] - cca_df.loc[daughter_IDs[1], 'sister_ID_tree'] = daughter_IDs[0] - - prev_gen_num = prev_cca_df.loc[parent_ID, 'generation_num_tree'] - cca_df.loc[daughter_IDs, 'generation_num_tree'] = prev_gen_num + 1 - - # Annotate division from df_ctc_i into - if cell_division_mode == 'Asymmetric': - # Recycle the root_ID and assign it to one of the daughters - replaced_daught_ID = daughter_IDs[1] - key = (frame_i, replaced_daught_ID) - asymmetric_IDs_rename_mapper[key] = (root_ID, parent_ID) - - cca_dfs.append(cca_df) - - if progressbar: - pbar.update() - - if progressbar: - pbar.close() - - if asymmetric_IDs_rename_mapper: - _relabel_cca_dfs_and_segm_data( - cca_dfs, - asymmetric_IDs_rename_mapper, - asymm_tracked_segm, - progressbar=True, - ) - - cca_df = pd.concat(cca_dfs, keys=keys, names=['frame_i']) - - out = [cca_df, None, None] - - if return_list: - out[1] = cca_dfs - - if cell_division_mode == 'Asymmetric': - out[2] = asymm_tracked_segm - - return out - -def check_install_instanseg(): - check_install_package( - pkg_name='InstanSeg', - import_pkg_name='instanseg', - pypi_name='instanseg-torch' - ) - -def validate_tracker_input(tracker, segm_video_to_track): - try: - warning_text = tracker.validate_input(segm_video_to_track) - return warning_text - except Exception as err: - printl(traceback.format_exc()) - pass - return -def format_IDs(IDs): - if isinstance(IDs, str): - raise ValueError('IDs must not be a string') - - IDsRange = [] - text = '' - sorted_vals = sorted(IDs) - for i, e in enumerate(sorted_vals): - e = int(e) - # Get previous and next value (if possible) - if i > 0: - prevVal = sorted_vals[i-1] - else: - prevVal = -1 - if i < len(sorted_vals)-1: - nextVal = sorted_vals[i+1] - else: - nextVal = -1 - - if e-prevVal == 1 or nextVal-e == 1: - if not IDsRange: - if nextVal-e == 1 and e-prevVal != 1: - # Current value is the first value of a new range - IDsRange = [e] - else: - # Current value is the second element of a new range - IDsRange = [prevVal, e] - else: - if e-prevVal == 1: - # Current value is part of an ongoing range - IDsRange.append(e) - else: - # Current value is the first element of a new range - # --> create range text and this element will - # be added to the new range at the next iter - start, stop = IDsRange[0], IDsRange[-1] - if stop-start > 1: - sep = '-' - else: - sep = ',' - text = f'{text},{start}{sep}{stop}' - IDsRange = [] - else: - # Current value doesn't belong to a range - if IDsRange: - # There was a range not added to text --> add it now - start, stop = IDsRange[0], IDsRange[-1] - if stop-start > 1: - sep = '-' - else: - sep = ',' - text = f'{text},{start}{sep}{stop}' - - text = f'{text},{e}' - IDsRange = [] - - if IDsRange: - # Last range was not added --> add it now - start, stop = IDsRange[0], IDsRange[-1] - text = f'{text},{start}-{stop}' - - text = text[1:] - - return text - -def get_empty_stored_data_dict(): - return { - 'regionprops': None, - 'labels': None, - 'acdc_df': None, - 'delROIs_info': { - 'rois': [], 'delMasks': [], 'delIDsROI': [], 'state': [] - }, - 'IDs': [], - 'manually_edited_lab': {'lab': {}, 'zoom_slice': None} - } - -def iterate_along_axes(arr, axes, arr_ndim=None): - if arr_ndim is None: - arr_ndim = arr.ndim - axes = list(axes) - front_axes = axes + [i for i in range(arr_ndim) if i not in axes] - arr_moved = np.moveaxis(arr, front_axes, range(arr_ndim)) - iter_shape = arr_moved.shape[:len(axes)] - for idx in np.ndindex(iter_shape): - # Build the index for the original array - full_idx = [slice(None)] * arr_ndim - for axis, i in zip(axes, idx): - full_idx[axis] = i - yield tuple(full_idx) - -def get_input_output_mapper( - input_shape: Tuple[int], - iterate_axes: Tuple[int], - output_shape: Tuple[int], - output_axes: Tuple[int], -) -> List[Tuple[Tuple[int, ...], Tuple[int, ...]]]: - """Creates list of tuples with the input and output indices - - Parameters - ---------- - input_shape : Tuple[int] - Shape of the input array - iterate_axes : Tuple[int] - Axes to iterate over - output_shape : Tuple[int] - Shape of the output array - output_axes : Tuple[int] - Axes of the output array - """ - assert len(iterate_axes) == len(output_axes) - - iterate_shape = tuple(input_shape[axis] for axis in iterate_axes) - mapper = [] - - for idx_vals in itertools.product(*[range(s) for s in iterate_shape]): - # Build full input index - input_index = [slice(None)] * len(input_shape) - for axis in iterate_axes: - i = iterate_axes.index(axis) - input_index[axis] = idx_vals[i] - - # Build full output index - output_index = [slice(None)] * len(output_shape) - for axis in output_axes: - i = output_axes.index(axis) - output_index[axis] = idx_vals[i] - - input_index = tuple(input_index) - output_index = tuple(output_index) - - mapper.append((input_index, output_index)) - - return mapper - -def translateStrNone(*args): - args = list(args) - for i, arg in enumerate(args): - if isinstance(arg, str): - if arg.lower() == 'none': - args[i] = None - elif arg.lower() == 'true': - args[i] = True - elif arg.lower() == 'false': - args[i] = False - - return args - -def get_pytorch_command(): - """Get the command to install pytorch CPU or CUDA - - Returns - ------- - dict - Dictionary mapping OS to commands for installing PyTorch - - Notes - ----- - As of Oct 2024, the `pytorch` channel on Anaconda was deprecated. - See here https://github.com/pytorch/pytorch/issues/138506 - """ - conda_prefix, pip_prefix = get_pip_conda_prefix() - - pytorch_commands = { - 'Windows': { - # 'Conda': { - # 'CPU': f'{conda_prefix} pytorch torchvision cpuonly -c conda-forge', - # 'CUDA 11.8 (NVIDIA GPU)': f'{conda_prefix} pytorch torchvision pytorch-cuda=11.8 -c conda-forge -c nvidia', - # 'CUDA 12.1 (NVIDIA GPU)': f'{conda_prefix} pytorch torchvision pytorch-cuda=12.1 -c conda-forge -c nvidia' - # }, - 'Pip': { - 'CPU': f'{pip_prefix} torch torchvision', - 'CUDA 11.8 (NVIDIA GPU)': f'{pip_prefix} torch torchvision --index-url https://download.pytorch.org/whl/cu118', - 'CUDA 12.1 (NVIDIA GPU)': f'{pip_prefix} torch torchvision --index-url https://download.pytorch.org/whl/cu121' - } - }, - 'Mac': { - # 'Conda': { - # 'CPU': f'{conda_prefix} pytorch torchvision cpuonly -c conda-forge', - # 'CUDA 11.8 (NVIDIA GPU)': '[WARNING]: CUDA is not available on MacOS', - # 'CUDA 12.1 (NVIDIA GPU)': '[WARNING]: CUDA is not available on MacOS' - # }, - 'Pip': { - 'CPU': f'{pip_prefix} torch torchvision', - 'CUDA 11.8 (NVIDIA GPU)': '[WARNING]: CUDA is not available on MacOS', - 'CUDA 12.1 (NVIDIA GPU)': '[WARNING]: CUDA is not available on MacOS' - } - }, - 'Linux': { - # 'Conda': { - # 'CPU': f'{conda_prefix} pytorch torchvision cpuonly -c conda-forge', - # 'CUDA 11.8 (NVIDIA GPU)': f'{conda_prefix} pytorch torchvision pytorch-cuda=11.8 -c conda-forge -c nvidia', - # 'CUDA 12.1 (NVIDIA GPU)': f'{conda_prefix} pytorch torchvision pytorch-cuda=12.1 -c conda-forge -c nvidia' - # }, - 'Pip': { - 'CPU': f'{pip_prefix} torch torchvision --index-url https://download.pytorch.org/whl/cpu', - 'CUDA 11.8 (NVIDIA GPU)': f'{pip_prefix} torch torchvision --index-url https://download.pytorch.org/whl/cu118', - 'CUDA 12.1 (NVIDIA GPU)': f'{pip_prefix} torch torchvision' - } - } - } - - return pytorch_commands - -def get_package_info(package_name): - try: - result = subprocess.run([ - sys.executable, '-m', 'pip', 'show', package_name - ], capture_output=True, text=True, check=True) - - info = {} - for line in result.stdout.split('\n'): - if ':' in line: - key, value = line.split(':', 1) - info[key.strip()] = value.strip() - - # Check if it's editable by looking at the location - location = info.get('Location', '') - editable_location = info.get('Editable project location', '') - - return { - 'installed': True, - 'editable': bool(editable_location), - 'location': location, - 'editable_location': editable_location - } - - except subprocess.CalledProcessError: - return {'installed': False, 'editable': False} - -# Usage -def update_package(parent, package_name): - package_info = get_package_info(package_name) - if not package_info['installed']: - printl(f"Package {package_name} is not installed.") - return False - editable = package_info.get('editable', False) - if editable: - return update_editable_package(parent, package_name, package_info) - else: - return update_not_editable_package(package_name, package_info) - -def update_editable_package(parent, package_name, package_info): - repo_location = package_info.get('editable_location', '') - - if not repo_location or not os.path.exists(repo_location): - print(f"Repository location not found for {package_name}") - return False - - return _update_repo_with_git_command(package_name, repo_location) - -def _update_repo_with_git_command(package_name, repo_location): - """Update repository using git command""" - try: - print(f"Updating {package_name} repository at {repo_location} using git command...") - - # Change to repository directory - original_cwd = os.getcwd() - os.chdir(repo_location) - - stashed_changes = False - - # check if there is a portable git - from .config import parser_args - try: - cp = parser_args - if cp["install_details"] is not None: - no_cli_install = True - install_details = cp["install_details"] - target_dir = install_details.get('target_dir', '') - target_dir = target_dir.strip().strip('"').strip("'") - target_dir = os.path.abspath(target_dir) - else: - no_cli_install = False - except: - no_cli_install = False - pass - - if is_win and no_cli_install: - git_loc = os.path.join(target_dir, - "portable_git", - "cmd", - "git.exe") - if not os.path.exists(git_loc): - print(f"Portable git not found at {git_loc}. Using system git.") - git_loc = 'git' - else: - git_loc = 'git' - - # Check if git is available - if not shutil.which(git_loc): - print(f"Git command not found. Please install git to update {package_name}.") - return False - - try: - # Check for uncommitted changes - - branch_result = subprocess.run([git_loc, 'branch', '--show-current'], - capture_output=True, text=True, check=True) - current_branch = branch_result.stdout.strip() - print(f"Current branch: {current_branch}") - - result = subprocess.run([git_loc, 'status', '--porcelain'], - capture_output=True, text=True, check=True) - if result.stdout.strip(): - print(f"Repository {package_name} has uncommitted changes") - print("Stashing changes before update...") - subprocess.run([git_loc, 'stash'], check=True) - stashed_changes = True - - # Pull changes - subprocess.run([git_loc, 'pull'], check=True) - print(f"Successfully updated {package_name}") - - # Pop stashed changes if any were stashed - if stashed_changes: - try: - subprocess.run([git_loc, 'stash', 'pop'], check=True) - print("Restored stashed changes") - except subprocess.CalledProcessError as pop_error: - print(f"Warning: Could not restore stashed changes: {pop_error}") - - return True - - except subprocess.CalledProcessError as e: - print(f"Git command failed for {package_name}: {e}") - return False - finally: - os.chdir(original_cwd) - - except Exception as e: - print(f"Error updating {package_name} with git command: {e}") - return False - -def update_not_editable_package(package_name, package_info): - """Update a non-editable package using pip""" - try: - _, pip_list = get_pip_conda_prefix(list_return=True) - command = pip_list + ["--upgrade ", package_name] - - print(f"Updating {package_name} using pip...") - result = subprocess.run(command, shell=True, capture_output=True, text=True) - - if result.returncode == 0: - print(f"Successfully updated {package_name}") - return True - else: - print(f"Failed to update {package_name}: {result.stderr}") - return False - - except Exception as e: - print(f"Error updating {package_name}: {e}") - return False - -def try_kwargs(func, *args, **kwargs): - """ - Attempt to call a function with the provided arguments and keyword arguments. - - If the function raises a TypeError due to unexpected keyword arguments, - those arguments are dynamically removed, and the function is retried. - This process continues until the function succeeds or no keyword arguments - remain, in which case the exception is re-raised. - - Args: - func (Callable): The function to call. - *args: Positional arguments to pass to the function. - **kwargs: Keyword arguments to pass to the function. - - Returns: - Tuple[Any, List[str]]: A tuple containing: - - The result of the function call (or None if it fails). - - A list of keyword arguments that were removed. - - Raises: - ValueError: If a keyword argument mentioned in the error message - is not found in the provided kwargs. - TypeError: If the function fails with a TypeError after all keyword - arguments have been removed. - """ - - kwargs = kwargs.copy() # Create a copy to avoid modifying the original - removed_kwargs = [] - pattern = r"unexpected keyword argument ['\"](\w+)['\"]" - while True: - try: - return func(*args, **kwargs), removed_kwargs - except TypeError as e: - match = re.search(pattern, str(e)) - if match: - kwarg_name = match.group(1) - if kwarg_name in kwargs: - del kwargs[kwarg_name] - removed_kwargs.append(kwarg_name) - else: - raise ValueError( - f"Keyword argument '{kwarg_name}' not found in kwargs." - ) - else: - raise e - - if len(kwargs) == 0: - print(f"Function {func.__name__} failed with TypeError: {e}") - raise e - -def get_obj_by_label(rp, target_label): - """ - Returns the object with the specified label from the given list of objects. - - Parameters - ---------- - rp : list - The list of objects to search through. - target_label : str - The label of the object to find. - - Returns - ------- - object - The object with the specified label, or None if not found. - """ - for obj in rp: - if obj.label == target_label: - return obj - return None - -def find_distances_ID(rps, point=None, ID=None): - """ - Calculate the distances between a given point and the centroids of a list of regionprops. - - Parameters - ---------- - rps : list - List of regionprops objects. - point : tuple, optional - The coordinates of the point. Defaults to None. - ID : int, optional - The label ID of the regionprops object. Defaults to None. - - Returns - ------- - numpy.ndarray - A matrix of distances between the point and the centroids. - - Raises - ------ - ValueError - If ID is not found in the list of regionprops (list of cells). - ValueError - If neither ID nor point is provided. - ValueError - If both ID and point are provided. - """ - - if ID is not None and point is None: - try: - point = [rp.centroid for rp in rps if rp.label == ID][0] - except IndexError: - raise ValueError(f'ID {ID} not found in regionprops (list of cells).') - - elif ID is None and point is None: - raise ValueError('Either ID or point must be provided.') - - elif ID is not None and point is not None: - raise ValueError('Only one of ID or point must be provided.') - - point = point[::-1] # rp are in (y, x) format (or (z, y, x) for 3D data) so I need to reverse order - point = np.array([point]) - centroids = np.array([rp.centroid for rp in rps]) - diff = point[:, np.newaxis] - centroids - dist_matrix = np.linalg.norm(diff, axis=2) - return dist_matrix - -def sort_IDs_dist(rps, point=None, ID=None): - """Sorts the IDs of regionprops based on their distances to a given point. - - Parameters - ---------- - rps : list - A list of regionprops objects representing cells. - point : tuple, optional - The coordinates of the point to calculate distances from. - If not provided, it will be calculated based on the given ID. - ID : int, optional - The ID of the regionprops object to calculate distances from. - If this and point are both provided, or neither, an error will be - raised. - - Returns - ------- - list - A sorted list of IDs based on their distances to the given point. - - Raises - ------ - ValueError - If ID is not found in the list of regionprops objects. - ValueError - If neither ID nor point is provided. - ValueError - If both ID and point are provided. - - """ - if ID is not None and point is None: - try: - point = [rp.centroid for rp in rps if rp.label == ID][0] - except IndexError: - raise ValueError(f'ID {ID} not found in regionprops (list of cells).') - - elif ID is None and point is None: - raise ValueError('Either ID or point must be provided.') - - elif ID is not None and point is not None: - raise ValueError('Only one of ID or point must be provided.') - - - IDs = [rp.label for rp in rps] - if len(IDs) == 0: - return [] - elif len(IDs) == 1: - return IDs - dist_matrix = find_distances_ID(rps, point=point) - dist_matrix = np.squeeze(dist_matrix) - - sorted_ids = sorted(zip(dist_matrix, IDs)) - sorted_ids = [ID for _, ID in sorted_ids] - return sorted_ids - -def safe_get_or_call(obj, path: str): - """Safely get nested attributes or call methods with literal args from a string path.""" - expr = ast.parse(path, mode='eval').body - - def _eval(node, current_obj): - if isinstance(node, ast.Attribute): - return getattr(_eval(node.value, current_obj), node.attr) - elif isinstance(node, ast.Call): - func = _eval(node.func, current_obj) - args = [ast.literal_eval(arg) for arg in node.args] - kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords} - return func(*args, **kwargs) - elif isinstance(node, ast.Name): - # First name in chain is assumed to be from `obj` - return getattr(current_obj, node.id) - else: - raise ValueError(f"Unsupported syntax: {ast.dump(node)}") - - return _eval(expr, obj) - -def format_commit_date_utc(utc_str): - # Parse the UTC date string (ISO 8601 format) - dt = datetime.datetime.fromisoformat(utc_str.replace("Z", "+00:00")) - - # Convert to your local time zone (optional) - local_dt = dt.astimezone() # removes UTC offset if local - - # Format nicely - return local_dt.strftime(r"%A %d %B %Y at %H:%M") - -def get_linux_distribution_name(): - import csv - RELEASE_DATA = {} - with open("/etc/os-release") as f: - reader = csv.reader(f, delimiter="=") - for row in reader: - if row: - RELEASE_DATA[row[0]] = row[1] - if RELEASE_DATA["ID"] in ["debian", "raspbian"]: - with open("/etc/debian_version") as f: - DEBIAN_VERSION = f.readline().strip() - major_version = DEBIAN_VERSION.split(".")[0] - version_split = RELEASE_DATA["VERSION"].split(" ", maxsplit=1) - if version_split[0] == major_version: - # Just major version shown, replace it with the full version - RELEASE_DATA["VERSION"] = " ".join([DEBIAN_VERSION] + version_split[1:]) - - name_version = f'{RELEASE_DATA["NAME"]} {RELEASE_DATA["VERSION"]}' - - return name_version - -def reset_settings(): - question = ( - 'Do you want to reset Cell-ACDC settings' - '- type "h" for help - (y/[n]/h)? ' - ) - info_txt = ( - 'If you reset Cell-ACDC settings, the folder below will be deleted.\n\n' - 'This means deeleting things like custom shortcuts, recent paths, last ' - 'selections, and GUI preferences.\n\n' - f'Settings folder path: "{settings_folderpath}"' - ) - answer = 'y' - while True: - try: - answer = input(f'\n{question}') - except Exception as err: - break - - if answer == 'n': - print('*'*100) - return 'Resetting Cell-ACDC settings cancelled.' - - if answer == 'y': - break - - if answer == 'h': - print('-'*100) - print(f'\n{info_txt}') - print('='*100) - - print( - f'"{answer}" is not a valid answer. ' - 'Type "y" for "yes", "n" for "no", or "h" for help.' - ) - - try: - os.remove(settings_folderpath) - print('*'*100) - out_txt = ( - 'Cell-ACDC settings have been reset.\n\n' - 'The following folder was deleted:\n\n' - f'{settings_folderpath}' - ) - except Exception as err: - traceback.print_exc() - print('*'*100) - out_txt = ( - '**ERROR** occured when trying to remove the settings folder.\n\n' - 'To reset Cell-ACDC settings, please remove this folder:\n\n' - f'{settings_folderpath}\n' - ) - return out_txt - -def separate_fluo_segment_channels(channels): - segms_to_load = [] - channels_to_load = [] - current_segm = False - for ch in channels: - if ch == 'current segm.': - current_segm = True - elif 'segm' in ch: - segms_to_load.append(ch) - else: - channels_to_load.append(ch) - return segms_to_load, channels_to_load, current_segm diff --git a/cellacdc/napari_utils/arboretum.py b/cellacdc/napari_utils/arboretum.py index e23de9edf..00dac275a 100644 --- a/cellacdc/napari_utils/arboretum.py +++ b/cellacdc/napari_utils/arboretum.py @@ -2,21 +2,18 @@ from functools import partial from natsort import natsorted -from .. import myutils, apps, load, printl, core, widgets +from .. import utils, apps, load, printl, core, widgets from .. import exception_handler -from ..utils import base +from ..tools import base from qtpy.QtCore import QTimer, Signal + class NapariArboretumDialog(base.MainThreadSinglePosUtilBase): - def __init__( - self, posPath, app, title: str, infoText: str, parent=None - ): - - module = myutils.get_module_name(__file__) - super().__init__( - app, title, module, infoText, parent - ) + def __init__(self, posPath, app, title: str, infoText: str, parent=None): + + module = utils.get_module_name(__file__) + super().__init__(app, title, module, infoText, parent) self.sigClose.connect(self.close) @@ -25,95 +22,98 @@ def __init__( @exception_handler def launchNapariArboretum(self, posPath): - images_path = os.path.join(posPath, 'Images') - ls = myutils.listdir(images_path) + images_path = os.path.join(posPath, "Images") + ls = utils.listdir(images_path) image_files = [ - file for file in ls - if file.endswith('.tif') - or file.endswith('aligned.npz') - or file.endswith('.h5') + file + for file in ls + if file.endswith(".tif") + or file.endswith("aligned.npz") + or file.endswith(".h5") ] selectImageFile = widgets.QDialogListbox( - 'Select image file', - 'Select which image file to load\n', - image_files, multiSelection=False, parent=self + "Select image file", + "Select which image file to load\n", + image_files, + multiSelection=False, + parent=self, ) selectImageFile.exec_() if selectImageFile.cancel: - self.logger.info('napari-arboretum utility aborted.') + self.logger.info("napari-arboretum utility aborted.") return imageFile = selectImageFile.selectedItemsText[0] - self.logger.info(f'Loading image file {imageFile}...') - + self.logger.info(f"Loading image file {imageFile}...") + imagePath = os.path.join(images_path, imageFile) - posData = load.loadData(imagePath, '') + posData = load.loadData(imagePath, "") posData.getBasenameAndChNames() posData.loadImgData() segm_files = load.get_segm_files(posData.images_path) - existingEndnames = load.get_endnames( - posData.basename, segm_files - ) + existingEndnames = load.get_endnames(posData.basename, segm_files) if len(existingEndnames) > 1: win = apps.SelectSegmFileDialog( - existingEndnames, images_path, parent=self, - basename=posData.basename + existingEndnames, images_path, parent=self, basename=posData.basename ) win.exec_() if win.cancel: - self.logger.info('napari-arboretum utility aborted.') + self.logger.info("napari-arboretum utility aborted.") return selectedSegmEndName = win.selectedItemText else: selectedSegmEndName = existingEndnames[0] - self.logger.info(f'Loading segmentation file ending with {selectedSegmEndName}...') + self.logger.info( + f"Loading segmentation file ending with {selectedSegmEndName}..." + ) posData.loadOtherFiles( load_segm_data=True, load_acdc_df=True, - end_filename_segm=selectedSegmEndName + end_filename_segm=selectedSegmEndName, ) - self.logger.info('Importing napari...') + self.logger.info("Importing napari...") import napari - self.logger.info('Building arboretum lineage tree...') + self.logger.info("Building arboretum lineage tree...") acdc_df = posData.acdc_df.reset_index() tree = core.LineageTree(acdc_df, logging_func=self.logger.info) tracks_data, graph, properties = tree.to_arboretum() props = natsorted(acdc_df.columns.to_list()) selectProps = widgets.QDialogListbox( - 'Select measurements', - 'Select measurements to add as properties in napari viewer

    ' - 'Ctrl+Click to select multiple items
    ' - 'Shift+Click to select a range of items
    ', - props, multiSelection=True, parent=self + "Select measurements", + "Select measurements to add as properties in napari viewer

    " + "Ctrl+Click to select multiple items
    " + "Shift+Click to select a range of items
    ", + props, + multiSelection=True, + parent=self, ) selectProps.exec_() if selectProps.cancel: - self.logger.info('napari-arboretum utility aborted.') + self.logger.info("napari-arboretum utility aborted.") return - + for col in selectProps.selectedItemsText: try: properties[col] = acdc_df[col] except Exception as e: pass - self.logger.info('Launching napari viewer...') + self.logger.info("Launching napari viewer...") viewer = napari.Viewer() viewer.add_image(posData.img_data, name=imageFile) viewer.add_labels(posData.segm_data, name=selectedSegmEndName) - acdc_df_endname = selectedSegmEndName.replace('segm', 'acdc_tracks') + acdc_df_endname = selectedSegmEndName.replace("segm", "acdc_tracks") viewer.add_tracks( - tracks_data, graph=graph, name=acdc_df_endname, - properties=properties + tracks_data, graph=graph, name=acdc_df_endname, properties=properties ) viewer.window.add_plugin_dock_widget( plugin_name="napari-arboretum", widget_name="Arboretum" @@ -121,7 +121,5 @@ def launchNapariArboretum(self, posPath): napari.run(max_loop_level=2) - self.logger.info('napari viewer closed.') + self.logger.info("napari viewer closed.") self.close() - - diff --git a/cellacdc/path.py b/cellacdc/path.py index beca21849..68fd9da76 100644 --- a/cellacdc/path.py +++ b/cellacdc/path.py @@ -9,76 +9,80 @@ from . import is_mac, is_linux from . import printl -from . import myutils +from . import utils + def listdir(path): - return natsorted([ - f for f in os.listdir(path) - if not f.startswith('.') - and not f == 'desktop.ini' - and not f == 'recovery' - ]) - -def newfilepath(file_path, appended_text: str=None): + return natsorted( + [ + f + for f in os.listdir(path) + if not f.startswith(".") and not f == "desktop.ini" and not f == "recovery" + ] + ) + + +def newfilepath(file_path, appended_text: str = None): if appended_text is None: - appended_text='' - + appended_text = "" + if not os.path.exists(file_path): return file_path, appended_text - + folder_path = os.path.dirname(file_path) filename = os.path.basename(file_path) filename, ext = os.path.splitext(filename) if appended_text: - if appended_text.startswith('_'): - appended_text = appended_text.lstrip('_') + if appended_text.startswith("_"): + appended_text = appended_text.lstrip("_") if appended_text: - new_filename = f'{filename}_{appended_text}{ext}' + new_filename = f"{filename}_{appended_text}{ext}" new_filepath = os.path.join(folder_path, new_filename) if not os.path.exists(new_filepath): return new_filepath, appended_text - + i = 0 while True: if appended_text: - new_filename = f'{filename}_{appended_text}_{i+1}{ext}' + new_filename = f"{filename}_{appended_text}_{i + 1}{ext}" else: - new_filename = f'{filename}_{i+1}{ext}' + new_filename = f"{filename}_{i + 1}{ext}" new_filepath = os.path.join(folder_path, new_filename) if not os.path.exists(new_filepath): - return new_filepath, f'{appended_text}_{i+1}' + return new_filepath, f"{appended_text}_{i + 1}" i += 1 + def show_in_file_manager(path): if is_mac: - args = ['open', fr'{path}'] + args = ["open", rf"{path}"] elif is_linux: - args = ['xdg-open', fr'{path}'] + args = ["xdg-open", rf"{path}"] else: if os.path.isfile(path): - args = ['explorer', '/select,', os.path.realpath(path)] + args = ["explorer", "/select,", os.path.realpath(path)] else: - args = ['explorer', os.path.realpath(path)] + args = ["explorer", os.path.realpath(path)] subprocess.run(args) + def copy_or_move_tree( - src: os.PathLike, dst: os.PathLike, copy=False, - sigInitPbar=None, sigUpdatePbar=None - ): + src: os.PathLike, dst: os.PathLike, copy=False, sigInitPbar=None, sigUpdatePbar=None +): if sigInitPbar is not None: sigInitPbar.emit(0) - + files_failed_move = {} files_info = {} for root, dirs, files in os.walk(src): for file in files: - rel_path = os.path.relpath(root, src).replace('\\', '/') + rel_path = os.path.relpath(root, src).replace("\\", "/") src_filepath = os.path.join(root, file) - dst_filepath = os.path.join(dst, *rel_path.split('/'), file) + dst_filepath = os.path.join(dst, *rel_path.split("/"), file) files_info[src_filepath] = dst_filepath - + if sigInitPbar is not None: sigInitPbar.emit(len(files_info)) for src_filepath, dst_filepath in files_info.items(): @@ -95,25 +99,27 @@ def copy_or_move_tree( sigUpdatePbar.emit(1) return files_failed_move + def get_posfolderpaths_walk(folderpath): pos_folderpaths = defaultdict(set) for root, dirs, files in os.walk(folderpath): - if not root.endswith('Images'): + if not root.endswith("Images"): continue - + pos_folderpath = os.path.dirname(root) - if not myutils.is_pos_folderpath(pos_folderpath): + if not utils.is_pos_folderpath(pos_folderpath): continue - - exp_path = os.path.dirname(pos_folderpath).replace('\\', '/') + + exp_path = os.path.dirname(pos_folderpath).replace("\\", "/") pos_foldername = os.path.basename(pos_folderpath) pos_folderpaths[exp_path].add(pos_foldername) - + for exp_path in pos_folderpaths.keys(): pos_folderpaths[exp_path] = natsorted(pos_folderpaths[exp_path]) - + return pos_folderpaths + def get_exp_path_pos_foldernames_mapper(paths): mapper = defaultdict(lambda: defaultdict(list)) @@ -123,9 +129,9 @@ def get_exp_path_pos_foldernames_mapper(paths): filename = os.path.basename(path) path = os.path.dirname(path) - folder_type = myutils.determine_folder_type(path) + folder_type = utils.determine_folder_type(path) is_pos_folder, is_images_folder, _ = folder_type - + if filename is not None and not is_images_folder: continue @@ -139,16 +145,15 @@ def get_exp_path_pos_foldernames_mapper(paths): else: path_mapper = get_posfolderpaths_walk(path) for exp_path, pos_foldernames in path_mapper.items(): - mapper[exp_path]['pos_foldernames'].extend(pos_foldernames) + mapper[exp_path]["pos_foldernames"].extend(pos_foldernames) continue - + exp_path = os.path.dirname(pos_folderpath) pos_foldername = os.path.basename(pos_folderpath) - key = exp_path.replace('\\', '/') - mapper[key]['pos_foldernames'].append(pos_foldername) + key = exp_path.replace("\\", "/") + mapper[key]["pos_foldernames"].append(pos_foldername) if filename is not None: - mapper[key]['filenames'].append(filename) - - return mapper + mapper[key]["filenames"].append(filename) + return mapper diff --git a/cellacdc/plot.py b/cellacdc/plot.py index 3d33b3c23..124864569 100644 --- a/cellacdc/plot.py +++ b/cellacdc/plot.py @@ -25,48 +25,49 @@ from . import printl from . import _core, error_below, error_close -from . import _run, core, myutils +from . import _run, core, utils + def matplotlib_cmap_to_lut( - cmap: Union[Iterable, matplotlib.colors.Colormap, str], - n_colors: int=256 - ): + cmap: Union[Iterable, matplotlib.colors.Colormap, str], n_colors: int = 256 +): if isinstance(cmap, str): cmap = plt.get_cmap(cmap) - - rgbs = [cmap(i) for i in np.linspace(0,1,n_colors)] - lut = (np.array(rgbs)*255).astype(np.uint8) + + rgbs = [cmap(i) for i in np.linspace(0, 1, n_colors)] + lut = (np.array(rgbs) * 255).astype(np.uint8) return lut + def imshow( - *images: Union[np.ndarray, dict], - labels_overlays: np.ndarray | List[np.ndarray]=None, - labels_overlays_luts: np.ndarray | List[np.ndarray]=None, - points_coords: np.ndarray=None, - points_coords_df: pd.DataFrame | List[pd.DataFrame]=None, - points_groups: List[str]=None, - points_data: Union[np.ndarray, pd.DataFrame, pd.Series]=None, - hide_axes: bool=True, - lut: Union[Iterable, matplotlib.colors.Colormap, str]=None, - autoLevels: bool=True, - autoLevelsOnScroll: bool=False, - block: bool=True, - showMaximised=False, - max_ncols=4, - axis_titles: Union[Iterable, None]=None, - parent=None, - window_title='Cell-ACDC image viewer', - figure_title='', - color_scheme=None, - link_scrollbars=True, - annotate_labels_idxs: List[int]=None, - show_duplicated_cursor=True, - selectable_images=False, - infer_rgb=True, - print_call_stack: bool=False - ): + *images: Union[np.ndarray, dict], + labels_overlays: np.ndarray | List[np.ndarray] = None, + labels_overlays_luts: np.ndarray | List[np.ndarray] = None, + points_coords: np.ndarray = None, + points_coords_df: pd.DataFrame | List[pd.DataFrame] = None, + points_groups: List[str] = None, + points_data: Union[np.ndarray, pd.DataFrame, pd.Series] = None, + hide_axes: bool = True, + lut: Union[Iterable, matplotlib.colors.Colormap, str] = None, + autoLevels: bool = True, + autoLevelsOnScroll: bool = False, + block: bool = True, + showMaximised=False, + max_ncols=4, + axis_titles: Union[Iterable, None] = None, + parent=None, + window_title="Cell-ACDC image viewer", + figure_title="", + color_scheme=None, + link_scrollbars=True, + annotate_labels_idxs: List[int] = None, + show_duplicated_cursor=True, + selectable_images=False, + infer_rgb=True, + print_call_stack: bool = False, +): if print_call_stack: - myutils.print_call_stack() + utils.print_call_stack() if isinstance(images[0], dict): images_dict = images[0] @@ -77,26 +78,27 @@ def imshow( axis_titles.append(title) if color_scheme is None: from ._palettes import get_color_scheme + color_scheme = get_color_scheme() - + if lut is None: - lut = matplotlib_cmap_to_lut('viridis') + lut = matplotlib_cmap_to_lut("viridis") if isinstance(lut, str): lut = matplotlib_cmap_to_lut(lut) if isinstance(lut, np.ndarray): - luts = [lut]*len(images) + luts = [lut] * len(images) else: luts = lut - + if luts is not None: for l in range(len(luts)): if not isinstance(luts[l], str): continue - + luts[l] = matplotlib_cmap_to_lut(luts[l]) - + casted_images = [] for image in images: if image.dtype == bool: @@ -105,7 +107,7 @@ def imshow( app = _run._setup_app() win = widgets.ImShow( - parent=parent, + parent=parent, link_scrollbars=link_scrollbars, infer_rgb=infer_rgb, figure_title=figure_title, @@ -117,23 +119,23 @@ def imshow( win.setupMainLayout() win.setupStatusBar() win.setupGraphicLayout( - *casted_images, - hide_axes=hide_axes, + *casted_images, + hide_axes=hide_axes, max_ncols=max_ncols, - color_scheme=color_scheme + color_scheme=color_scheme, ) if axis_titles is not None: win.setupTitles(*axis_titles) win.showImages( - *casted_images, + *casted_images, labels_overlays=labels_overlays, labels_overlays_luts=labels_overlays_luts, - luts=luts, - autoLevels=autoLevels, - autoLevelsOnScroll=autoLevelsOnScroll + luts=luts, + autoLevels=autoLevels, + autoLevelsOnScroll=autoLevelsOnScroll, ) if points_coords_df is not None: - win.drawPointsFromDf(points_coords_df, points_groups=points_groups) + win.drawPointsFromDf(points_coords_df, points_groups=points_groups) if points_coords is not None: points_coords = np.round(points_coords).astype(int) win.drawPoints(points_coords) @@ -142,32 +144,33 @@ def imshow( if show_duplicated_cursor: win.setupDuplicatedCursors() win.annotateObjectIDs( - annotate_labels_idxs=annotate_labels_idxs, + annotate_labels_idxs=annotate_labels_idxs, init=True, ) win.run(block=block, showMaximised=showMaximised, screenToWindowRatio=0.8) return win + def _add_colorbar_axes( - ax: plt.Axes, im: matplotlib.image.AxesImage, size='5%', pad=0.07, - label='' - ): + ax: plt.Axes, im: matplotlib.image.AxesImage, size="5%", pad=0.07, label="" +): divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.05) cbar = plt.colorbar(im, cax=cax) if label: cbar.set_label(label) + def _raise_non_unique_groups(grouping, dfs, groups_xx): groups_with_duplicates = {} for d, df in enumerate(dfs): if df.index.is_unique: continue group_xx = groups_xx[d] - group_with_duplicates = df.columns[0].split(';;')[1].replace('-', ', ') - duplicated_xx = group_xx[df.index.duplicated(keep='first')] + group_with_duplicates = df.columns[0].split(";;")[1].replace("-", ", ") + duplicated_xx = group_xx[df.index.duplicated(keep="first")] groups_with_duplicates[group_with_duplicates] = duplicated_xx - + duplicates = [] for group_values, duplicated_xx in groups_with_duplicates.items(): xx_name = duplicated_xx.name @@ -177,27 +180,34 @@ def _raise_non_unique_groups(grouping, dfs, groups_xx): ) duplicates.append(duplicates_str) - duplicates = '\n'.join(duplicates) + duplicates = "\n".join(duplicates) traceback.print_exc() print(error_below) - grouping_str = f'{grouping}'.strip('()').strip(',') + grouping_str = f"{grouping}".strip("()").strip(",") print(f'The groups determined by "{grouping_str}" are not unique:\n') - print(f'{duplicates}') + print(f"{duplicates}") print(error_close) exit() + def raise_missing_arg(argument_name): traceback.print_exc() print(error_below) - print(f'The argument `{argument_name}` is required.') + print(f"The argument `{argument_name}` is required.") print(error_close) exit() + def _get_groups_data( - df: pd.DataFrame, x: str, z: str, grouping: str, bin_size: int=None, - normalize_x: bool=False, zeroize_x: bool=False - ): + df: pd.DataFrame, + x: str, + z: str, + grouping: str, + bin_size: int = None, + normalize_x: bool = False, + zeroize_x: bool = False, +): grouped = df.groupby(list(grouping)) dfs = [] groups_xx = [] @@ -205,45 +215,43 @@ def _get_groups_data( max_n_decimals = None min_norm_bin_size = None if normalize_x: - min_dx = min([ - group_df[x].diff().abs().min() for _, group_df in grouped - ]) + min_dx = min([group_df[x].diff().abs().min() for _, group_df in grouped]) max_n_decimals = 0 min_norm_bin_size = np.inf for name, group_df in grouped: groups_xx.append(group_df[x]) if zeroize_x: - group_xx = group_df[x]-group_df[x].min() + group_xx = group_df[x] - group_df[x].min() else: group_xx = group_df[x] if len(grouping) == 1: name_str = str(name) else: - name_str = '-'.join([str(n) for n in name]) - group_cols = {col:f'{col};;{name_str}' for col in group_df.columns} + name_str = "-".join([str(n) for n in name]) + group_cols = {col: f"{col};;{name_str}" for col in group_df.columns} group_df = group_df.rename(columns=group_cols) - if normalize_x: + if normalize_x: max_xx = group_xx.max() - norm_dx = min_dx/max_xx + norm_dx = min_dx / max_xx min_dx_rounded = _core.round_to_significant(norm_dx, 2) - n_decimals = len(str(min_dx_rounded).split('.')[1]) + n_decimals = len(str(min_dx_rounded).split(".")[1]) if n_decimals > max_n_decimals: max_n_decimals = n_decimals - norm_xx = (group_xx/max_xx).round(n_decimals) - norm_xx_perc = norm_xx*100 + norm_xx = (group_xx / max_xx).round(n_decimals) + norm_xx_perc = norm_xx * 100 if bin_size is not None: - norm_bin_size = (bin_size/max_xx).round(n_decimals)*100 + norm_bin_size = (bin_size / max_xx).round(n_decimals) * 100 if norm_bin_size < min_norm_bin_size: min_norm_bin_size = norm_bin_size - group_df['x'] = norm_xx_perc + group_df["x"] = norm_xx_perc else: - group_df['x'] = group_xx - col_name = f'{z};;{name_str}' - dfs.append(group_df[[col_name, 'x']].dropna().set_index('x')) - - yticks_labels.append(f'{name}'.strip('()')) - + group_df["x"] = group_xx + col_name = f"{z};;{name_str}" + dfs.append(group_df[[col_name, "x"]].dropna().set_index("x")) + + yticks_labels.append(f"{name}".strip("()")) + try: df_data = pd.concat(dfs, names=[x], axis=1).sort_index() except pd.errors.InvalidIndexError as err: @@ -259,124 +267,127 @@ def _get_groups_data( n_decimals = max_n_decimals - 2 order_of_magnitude = 10**n_decimals df_data = df_data.reset_index() - df_data['x_int'] = (df_data['x']*order_of_magnitude).astype(int) - df_data = df_data.set_index('x_int').drop(columns='x') - bin_size = int(bin_size*order_of_magnitude) + df_data["x_int"] = (df_data["x"] * order_of_magnitude).astype(int) + df_data = df_data.set_index("x_int").drop(columns="x") + bin_size = int(bin_size * order_of_magnitude) df_data.index = pd.to_datetime(df_data.index.astype(int)) - rs = f'{bin_size}ns' - df_data = df_data.resample(rs, label='right').mean() - df_data.index = df_data.index.astype(np.int64)/order_of_magnitude + rs = f"{bin_size}ns" + df_data = df_data.resample(rs, label="right").mean() + df_data.index = df_data.index.astype(np.int64) / order_of_magnitude data = df_data.fillna(0).values.T xx = df_data.index return data, xx, yticks_labels + def _check_df_data_args(**kwargs): for arg_name, arg_value in kwargs.items(): - if arg_value: + if arg_value: continue if arg_value is not None: continue raise_missing_arg(arg_name) + def _raise_group_label_depth_too_deep(group_label_depth, n_levels): traceback.print_exc() print(error_below) print( - f'The `group_label_depth = {group_label_depth}` is too high, ' - f'there are only {n_levels} levels.' + f"The `group_label_depth = {group_label_depth}` is too high, " + f"there are only {n_levels} levels." ) print(error_close) exit() -def _get_heatmap_yticks( - nrows, group_height, yticks_labels, group_label_depth - ): - yticks = np.arange(0,nrows*group_height, group_height) - 0.5 + +def _get_heatmap_yticks(nrows, group_height, yticks_labels, group_label_depth): + yticks = np.arange(0, nrows * group_height, group_height) - 0.5 # yticks = yticks + group_height/2 - 0.5 if group_label_depth is not None: - df_ticks = pd.DataFrame({ - 'yticks': yticks, - 'yticks_labels': yticks_labels - }).set_index('yticks').astype(str) - df_ticks = df_ticks['yticks_labels'].str.split(',', expand=True) + df_ticks = ( + pd.DataFrame({"yticks": yticks, "yticks_labels": yticks_labels}) + .set_index("yticks") + .astype(str) + ) + df_ticks = df_ticks["yticks_labels"].str.split(",", expand=True) if group_label_depth > len(df_ticks.columns): n_levels = len(df_ticks.columns) _raise_group_label_depth_too_deep(group_label_depth, n_levels) df_ticks = df_ticks[list(range(group_label_depth))] - df_ticks['yticks_labels'] = df_ticks.agg(','.join, axis=1) - df_ticks = df_ticks.reset_index().set_index('yticks_labels') - yticks_first = df_ticks[~df_ticks.index.duplicated(keep='first')] - yticks_last = df_ticks[~df_ticks.index.duplicated(keep='last')] - yticks_start = yticks_first['yticks'] - yticks_end = yticks_last['yticks'] - yticks_center = yticks_start + (yticks_end-yticks_start)/2 + df_ticks["yticks_labels"] = df_ticks.agg(",".join, axis=1) + df_ticks = df_ticks.reset_index().set_index("yticks_labels") + yticks_first = df_ticks[~df_ticks.index.duplicated(keep="first")] + yticks_last = df_ticks[~df_ticks.index.duplicated(keep="last")] + yticks_start = yticks_first["yticks"] + yticks_end = yticks_last["yticks"] + yticks_center = yticks_start + (yticks_end - yticks_start) / 2 yticks_center = yticks_center return yticks_start, yticks_end, yticks_center + def _raise_convert_time_how(convert_time_how): print(error_below) - conversion_methods = [ - f' * {how}' for how in _core.time_units_converters.keys() - ] - conversion_methods = '\n'.join(conversion_methods) - print( - f'"{convert_time_how}" is not a valid `convert_time_how` value.\n' - ) - print( - f'Valid methods are:\n\n{conversion_methods}' - ) + conversion_methods = [f" * {how}" for how in _core.time_units_converters.keys()] + conversion_methods = "\n".join(conversion_methods) + print(f'"{convert_time_how}" is not a valid `convert_time_how` value.\n') + print(f"Valid methods are:\n\n{conversion_methods}") print(error_close) exit() + def _get_heatmap_xticks( - xx, x_unit_width, num_xticks, convert_time_how, - num_decimals_xticks_labels, x_label_loc='right', - add_x_0_label=True, x_labels=None - ): + xx, + x_unit_width, + num_xticks, + convert_time_how, + num_decimals_xticks_labels, + x_label_loc="right", + add_x_0_label=True, + x_labels=None, +): series_xindex = pd.Series(xx).repeat(x_unit_width) - if x_label_loc == 'right': + if x_label_loc == "right": series_xindex.index = series_xindex.index + 1 - elif x_label_loc == 'center': + elif x_label_loc == "center": series_xindex.index = series_xindex.index + 0.5 - elif x_label_loc == 'left': + elif x_label_loc == "left": pass - + if x_labels is not None: - series_xticks = ( - series_xindex[series_xindex.isin(x_labels)] - .drop_duplicates(keep='first') + series_xticks = series_xindex[series_xindex.isin(x_labels)].drop_duplicates( + keep="first" ) else: - resampling_step = round(len(series_xindex)/(num_xticks)) + resampling_step = round(len(series_xindex) / (num_xticks)) series_xticks = series_xindex.iloc[::resampling_step] - + xticks = series_xticks.index.to_list() xticks_labels = series_xticks.values.astype(int) - + if add_x_0_label and xticks[0] != 0: xticks = [0, *xticks] xticks_labels = np.zeros(len(xticks), dtype=int) xticks_labels[1:] = series_xticks - + if convert_time_how is None: return xticks, xticks_labels - - from_unit, to_unit = convert_time_how.split('->') + + from_unit, to_unit = convert_time_how.split("->") xticks_labels = _core.convert_time_units(xticks_labels, from_unit, to_unit) if xticks_labels is None: _raise_convert_time_how(convert_time_how) - + if num_decimals_xticks_labels is None: return xticks, xticks_labels - + xticks_labels = xticks_labels.round(num_decimals_xticks_labels) - + return xticks, xticks_labels + def _check_x_dtype(df, x, force_x_to_int): if force_x_to_int: return @@ -385,46 +396,47 @@ def _check_x_dtype(df, x, force_x_to_int): return print(error_below) print( - f'The `x` column must be of data type integer. ' - 'Pass `force_x_to_int=True` if you want to force conversion to ' - 'integers.' + f"The `x` column must be of data type integer. " + "Pass `force_x_to_int=True` if you want to force conversion to " + "integers." ) print(error_close) exit() + def heatmap( - data: Union[pd.DataFrame, np.ndarray], - x: str='', - z: str='', - y_grouping: Union[str, List[str]]='', - sort_groups: bool=True, - normalize_x: bool=False, - zeroize_x: bool=False, - x_bin_size: int=None, - x_label_loc: str='right', - x_labels: np.ndarray=None, - add_x_0_label: bool=False, - convert_time_how: str=None, - xlabel: str=None, - num_decimals_xticks_labels: int=None, - force_x_to_int: bool=False, - z_min: Union[int, float]=None, - z_max: Union[int, float]=None, - stretch_height_factor: float=None, - stretch_width_factor: float=None, - group_label_depth: int=1, - num_xticks: int=6, - colormap: Union[str, matplotlib.colors.Colormap]='viridis', - missing_values_color=None, - colorbar_pad: float= 0.07, - colorbar_size: float=0.05, - colorbar_label: str='', - ax: plt.Axes=None, - fig: plt.Figure=None, - backend: str='matplotlib', - block: bool=False, - imshow_kwargs: dict=None - ): + data: Union[pd.DataFrame, np.ndarray], + x: str = "", + z: str = "", + y_grouping: Union[str, List[str]] = "", + sort_groups: bool = True, + normalize_x: bool = False, + zeroize_x: bool = False, + x_bin_size: int = None, + x_label_loc: str = "right", + x_labels: np.ndarray = None, + add_x_0_label: bool = False, + convert_time_how: str = None, + xlabel: str = None, + num_decimals_xticks_labels: int = None, + force_x_to_int: bool = False, + z_min: Union[int, float] = None, + z_max: Union[int, float] = None, + stretch_height_factor: float = None, + stretch_width_factor: float = None, + group_label_depth: int = 1, + num_xticks: int = 6, + colormap: Union[str, matplotlib.colors.Colormap] = "viridis", + missing_values_color=None, + colorbar_pad: float = 0.07, + colorbar_size: float = 0.05, + colorbar_label: str = "", + ax: plt.Axes = None, + fig: plt.Figure = None, + backend: str = "matplotlib", + block: bool = False, + imshow_kwargs: dict = None, +): """Generate heatmap plot from data Parameters @@ -434,10 +446,10 @@ def heatmap( x : str, optional Name of the column used for the x-axis. Default is '' z : str, optional - Name of the column used for the z-axis, i.e., the values that + Name of the column used for the z-axis, i.e., the values that determine the color of each pixel. Default is '' y_grouping : Union[str, List[str]], optional - Column or list of columns that identifies a single row in the + Column or list of columns that identifies a single row in the heatmap. Default is '' sort_groups : bool, optional _description_. Default is True @@ -498,11 +510,11 @@ def heatmap( ------- _type_ _description_ - """ - + """ + if ax is None: fig, ax = plt.subplots() - + if imshow_kwargs is None: imshow_kwargs = {} @@ -518,33 +530,38 @@ def heatmap( if sort_groups: data = data.sort_values(list(y_cols)) data, xx, yticks_labels = _get_groups_data( - data, x, z, grouping=y_cols, normalize_x=normalize_x, - bin_size=x_bin_size, zeroize_x=zeroize_x + data, + x, + z, + grouping=y_cols, + normalize_x=normalize_x, + bin_size=x_bin_size, + zeroize_x=zeroize_x, ) else: - x = 'x' if not x else x - y_grouping = 'groups' if not y_grouping else y_grouping - z = 'x' if not z else z + x = "x" if not x else x + y_grouping = "groups" if not y_grouping else y_grouping + z = "x" if not z else z xx = np.arange(data.shape[-1]) if z_min is None: z_min = np.nanmin(data) - + if z_max is None: z_max = np.nanmax(data) Y, X = data.shape - group_height = round(X/Y) + group_height = round(X / Y) if stretch_height_factor is not None: - group_height = round(group_height*stretch_height_factor) - + group_height = round(group_height * stretch_height_factor) + Y, X = data.shape - x_unit_width = round(Y/X) + x_unit_width = round(Y / X) if stretch_width_factor is not None: x_unit_width = round(stretch_width_factor) - - group_height = group_height if group_height>1 else 1 - x_unit_width = x_unit_width if x_unit_width>1 else 1 + + group_height = group_height if group_height > 1 else 1 + x_unit_width = x_unit_width if x_unit_width > 1 else 1 yticks_start, yticks_end, yticks_center = _get_heatmap_yticks( len(data), group_height, yticks_labels, group_label_depth @@ -553,34 +570,39 @@ def heatmap( yticks = yticks_start.values xticks, xticks_labels = _get_heatmap_xticks( - xx, x_unit_width, num_xticks, convert_time_how, - num_decimals_xticks_labels, x_label_loc=x_label_loc, - add_x_0_label=add_x_0_label, x_labels=x_labels + xx, + x_unit_width, + num_xticks, + convert_time_how, + num_decimals_xticks_labels, + x_label_loc=x_label_loc, + add_x_0_label=add_x_0_label, + x_labels=x_labels, ) if group_height > 1: - data = np.repeat(data, [group_height]*len(data), axis=0) - + data = np.repeat(data, [group_height] * len(data), axis=0) + if x_unit_width > 1: ncols = data.shape[-1] - data = np.repeat(data, [x_unit_width]*ncols, axis=1) - xticks = [xtick*x_unit_width for xtick in xticks] - + data = np.repeat(data, [x_unit_width] * ncols, axis=1) + xticks = [xtick * x_unit_width for xtick in xticks] + if missing_values_color is not None: if isinstance(colormap, str): colormap = plt.get_cmap(colormap) bkgr_color = matplotlib.colors.to_rgba(missing_values_color) - colors = colormap(np.linspace(0,1,256)) + colors = colormap(np.linspace(0, 1, 256)) colors[0] = bkgr_color colormap = matplotlib.colors.ListedColormap(colors) if xlabel is None: xlabel = x - # Make sure to label the side of the pixel + # Make sure to label the side of the pixel xticks = np.array(xticks) - xticks = (xticks + (xticks-x_unit_width))/2 + xticks = (xticks + (xticks - x_unit_width)) / 2 xticks -= 0.5 im = ax.imshow(data, cmap=colormap, vmin=z_min, vmax=z_max, **imshow_kwargs) @@ -588,57 +610,53 @@ def heatmap( ax.set_xticks(xticks, labels=xticks_labels) ax.set_ylabel(y_grouping) ax.set_yticks(yticks, labels=yticks_labels) - - _size_perc = f'{int(colorbar_size*100)}%' - _add_colorbar_axes( - ax, im, size=_size_perc, pad=colorbar_pad, label=colorbar_label - ) - + + _size_perc = f"{int(colorbar_size * 100)}%" + _add_colorbar_axes(ax, im, size=_size_perc, pad=colorbar_pad, label=colorbar_label) + if block: plt.show() else: return fig, ax, im + def _binned_mean_stats(x, y, bins, bins_min_count): x = np.array(x).astype(float) y = np.array(y).astype(float) - bin_counts, _, _ = scipy.stats.binned_statistic( - x, y, statistic='count', bins=bins - ) + bin_counts, _, _ = scipy.stats.binned_statistic(x, y, statistic="count", bins=bins) bin_means, bin_edges, _ = scipy.stats.binned_statistic( x, y, bins=bins, statistic=np.nanmean ) - bin_std, _, _ = scipy.stats.binned_statistic( - x, y, statistic=np.nanstd, bins=bins - ) - bin_width = (bin_edges[1] - bin_edges[0]) - bin_centers = bin_edges[1:] - bin_width/2 + bin_std, _, _ = scipy.stats.binned_statistic(x, y, statistic=np.nanstd, bins=bins) + bin_width = bin_edges[1] - bin_edges[0] + bin_centers = bin_edges[1:] - bin_width / 2 if bins_min_count > 1: bin_centers = bin_centers[bin_counts > bins_min_count] bin_means = bin_means[bin_counts > bins_min_count] bin_std = bin_std[bin_counts > bins_min_count] bin_counts = bin_counts[bin_counts > bins_min_count] - std_err = bin_std/np.sqrt(bin_counts) + std_err = bin_std / np.sqrt(bin_counts) return bin_centers, bin_means, bin_std, std_err + def binned_means_plot( - x: Union[str, Iterable] = None, - y: Union[str, Iterable] = None, - bins: Union[int, Iterable] = 10, - bins_min_count: int = 1, - data: pd.DataFrame = None, - ci_plot: Literal['errorbar', 'fill_between']='errorbar', - scatter: bool = True, - line_plot = True, - use_std_err: bool = True, - color = None, - label = None, - scatter_kws = None, - errorbar_kws = None, - fill_between_kws = None, - ax: plt.Axes = None, - scatter_colors = None - ): + x: Union[str, Iterable] = None, + y: Union[str, Iterable] = None, + bins: Union[int, Iterable] = 10, + bins_min_count: int = 1, + data: pd.DataFrame = None, + ci_plot: Literal["errorbar", "fill_between"] = "errorbar", + scatter: bool = True, + line_plot=True, + use_std_err: bool = True, + color=None, + label=None, + scatter_kws=None, + errorbar_kws=None, + fill_between_kws=None, + ax: plt.Axes = None, + scatter_colors=None, +): if ax is None: fig, ax = plt.subplots(1) @@ -655,16 +673,16 @@ def binned_means_plot( if color is None: color = sns.color_palette(n_colors=1)[0] - + if scatter_kws is None: - scatter_kws = {'alpha': 0.3} - - if 'alpha' not in scatter_kws: - scatter_kws['alpha'] = 0.3 - + scatter_kws = {"alpha": 0.3} + + if "alpha" not in scatter_kws: + scatter_kws["alpha"] = 0.3 + if label is None: - label = '' - + label = "" + if scatter_colors is None: scatter_colors = color @@ -672,34 +690,34 @@ def binned_means_plot( if scatter: ax.scatter(x, y, color=scatter_colors, **scatter_kws) yerr = std_err if use_std_err else std - - if ci_plot == 'errorbar': + + if ci_plot == "errorbar": if errorbar_kws is None: - errorbar_kws = {'capsize': 3, 'lw': 2} - + errorbar_kws = {"capsize": 3, "lw": 2} + if not line_plot: - fmt = '.' + fmt = "." else: - fmt = '' - + fmt = "" + ax.errorbar( xe, ye, yerr=yerr, fmt=fmt, color=color, label=label, **errorbar_kws ) - elif ci_plot == 'fill_between': + elif ci_plot == "fill_between": if fill_between_kws is None: - fill_between_kws = {'alpha': 0.3} - + fill_between_kws = {"alpha": 0.3} + if line_plot: ax.plot(xe, ye, color=color, label=label) - label = '' - + label = "" + ax.fill_between( - xe, ye-yerr, ye+yerr, color=color, label=label, **fill_between_kws + xe, ye - yerr, ye + yerr, color=color, label=label, **fill_between_kws ) - return ax + def text_to_pg_scatter_symbol(text: str, font=None, return_scale=False): if font is None: font = QtGui.QFont() @@ -708,27 +726,28 @@ def text_to_pg_scatter_symbol(text: str, font=None, return_scale=False): symbol = QtGui.QPainterPath() symbol.addText(0, 0, font, text) br = symbol.boundingRect() - scale = min(1. / br.width(), 1. / br.height()) + scale = min(1.0 / br.width(), 1.0 / br.height()) tr = QtGui.QTransform() tr.scale(scale, scale) - tr.translate(-br.x() - br.width()/2., -br.y() - br.height()/2.) + tr.translate(-br.x() - br.width() / 2.0, -br.y() - br.height() / 2.0) symbol = tr.map(symbol) if return_scale: return symbol, scale else: return symbol + def get_symbol_sizes(scales: dict, symbols: dict, size: int): scales_arr = np.array([scales[text] for text in symbols.keys()]) - normalized_scales = scales_arr/scales_arr.max() - sizes = np.round(size/normalized_scales).astype(int) - sizes = {text:scale for text, scale in zip(symbols.keys(), sizes)} + normalized_scales = scales_arr / scales_arr.max() + sizes = np.round(size / normalized_scales).astype(int) + sizes = {text: scale for text, scale in zip(symbols.keys(), sizes)} return sizes + def texts_to_pg_scatter_symbols( - texts: Union[str, list], font=None, progress=True, - return_scales=False - ): + texts: Union[str, list], font=None, progress=True, return_scales=False +): if font is None: font = QtGui.QFont() font.setPixelSize(11) @@ -736,7 +755,7 @@ def texts_to_pg_scatter_symbols( texts = [texts] if progress: - pbar = tqdm(total=len(texts)*2, ncols=100) + pbar = tqdm(total=len(texts) * 2, ncols=100) symbols = {} scales = {} @@ -744,102 +763,108 @@ def texts_to_pg_scatter_symbols( symbol = QtGui.QPainterPath() symbol.addText(0, 0, font, text) br = symbol.boundingRect() - scale = min(1. / br.width(), 1. / br.height()) + scale = min(1.0 / br.width(), 1.0 / br.height()) if progress: pbar.update() tr = QtGui.QTransform() tr.scale(scale, scale) - tr.translate(-br.x() - br.width()*0.5, -br.y() - br.height()*0.5) + tr.translate(-br.x() - br.width() * 0.5, -br.y() - br.height() * 0.5) symbols[text] = tr.map(symbol) scales[text] = scale if progress: pbar.update() - + if progress: pbar.close() - + if return_scales: return symbols, scales else: return symbols + def plt_contours( - ax, lab=None, rp=None, plot_kwargs=None, only_IDs=None, - clear_borders=True, obj_contours_kwargs=None - ): + ax, + lab=None, + rp=None, + plot_kwargs=None, + only_IDs=None, + clear_borders=True, + obj_contours_kwargs=None, +): if rp is None: rp = skimage.measure.regionprops(lab) if plot_kwargs is None: plot_kwargs = {} - + if obj_contours_kwargs is None: obj_contours_kwargs = {} - + for obj in rp: if only_IDs is not None and obj.label not in only_IDs: continue - + contours = core.get_obj_contours(obj, **obj_contours_kwargs) if not isinstance(contours, list): - contours = [contours] - + contours = [contours] + for contour in contours: xx = contour[:, 0] yy = contour[:, 1] if clear_borders: - valid_mask = np.logical_and(xx>0.5, yy>0.5) + valid_mask = np.logical_and(xx > 0.5, yy > 0.5) xx = xx[valid_mask] yy = yy[valid_mask] - + ax.plot(xx, yy, **plot_kwargs) + def plt_moth_bud_lines( - ax, cca_df, lab=None, rp=None, plot_kwargs=None, - only_moth_IDs=None - ): + ax, cca_df, lab=None, rp=None, plot_kwargs=None, only_moth_IDs=None +): if rp is None: rp = skimage.measure.regionprops(lab) if plot_kwargs is None: plot_kwargs = {} - - rp_mapper = {obj.label:obj for obj in rp} - + + rp_mapper = {obj.label: obj for obj in rp} + for obj in rp: - ccs = cca_df.at[obj.label, 'cell_cycle_stage'] - if ccs == 'G1': + ccs = cca_df.at[obj.label, "cell_cycle_stage"] + if ccs == "G1": continue - - status = cca_df.at[obj.label, 'relationship'] - if status == 'mother': + + status = cca_df.at[obj.label, "relationship"] + if status == "mother": continue - - mothID = cca_df.at[obj.label, 'relative_ID'] + + mothID = cca_df.at[obj.label, "relative_ID"] if only_moth_IDs is not None and mothID not in only_moth_IDs: continue - + moth_obj = rp_mapper[mothID] - + y1, x1 = obj.centroid y2, x2 = moth_obj.centroid - + ax.plot([x1, x2], [y1, y2], **plot_kwargs) -if __name__ == '__main__': + +if __name__ == "__main__": x = np.arange(0, 1000).astype(float) - y = 2*x+10 + y = 2 * x + 10 noise = np.random.normal(0, 100, size=1000) y += noise - data = pd.DataFrame({'x': x, 'y': y}) + data = pd.DataFrame({"x": x, "y": y}) nbins = 10 bins_min_count = 10 binned_means_plot( - x='x', y='y', data=data, nbins=nbins, bins_min_count=bins_min_count + x="x", y="y", data=data, nbins=nbins, bins_min_count=bins_min_count ) - + plt.show() - \ No newline at end of file diff --git a/cellacdc/preprocess.py b/cellacdc/preprocess.py index 60d006140..67ab4b11d 100644 --- a/cellacdc/preprocess.py +++ b/cellacdc/preprocess.py @@ -1,20 +1,21 @@ """ -This module contains the functions that can be used as pre-processing steps -before segmentation. +This module contains the functions that can be used as pre-processing steps +before segmentation. -These functions are automatically added to `apps.QDialogModelParams` and they -can be selected in the pre-processing recipe. +These functions are automatically added to `apps.QDialogModelParams` and they +can be selected in the pre-processing recipe. -Every function must have a single argument for the image, while all -other parameters must be keyword arguments. +Every function must have a single argument for the image, while all +other parameters must be keyword arguments. -Functions that should not be used as pre-processing steps must start with `_`. -The list of functions is generated in the module `cellacdc.config` +Functions that should not be used as pre-processing steps must start with `_`. +The list of functions is generated in the module `cellacdc.config` (see PREPROCESS_MAPPER variable). -IMPORTANT: Do not import functions otherwise they will be added as possible +IMPORTANT: Do not import functions otherwise they will be added as possible step (for example do not do `from skimage.util import img_as_ubyte`). """ + from typing import Hashable, Union, Optional, Tuple from tqdm import tqdm @@ -25,6 +26,7 @@ try: import cupyx.scipy.ndimage import cupy as cp + CUPY_INSTALLED = True except Exception as e: CUPY_INSTALLED = False @@ -40,13 +42,11 @@ SQRT_2 = math.sqrt(2) + def remove_hot_pixels( - image, - logger_func=print, - progress=True, - apply_to_all_zslices=True - ): - """Apply a morphological opening operation to remove isolated bright + image, logger_func=print, progress=True, apply_to_all_zslices=True +): + """Apply a morphological opening operation to remove isolated bright pixels. Parameters @@ -62,7 +62,7 @@ def remove_hot_pixels( ------- (Y, X) or (Z, Y, X) numpy.ndarray Filtered image - """ + """ is_3D = image.ndim == 3 if is_3D: if progress: @@ -78,13 +78,14 @@ def remove_hot_pixels( filtered = skimage.morphology.opening(image) return filtered + def gaussian_filter( - image, - sigma: _types.Vector=0.75, - use_gpu=False, - logger_func=print, - apply_to_all_zslices=True - ): + image, + sigma: _types.Vector = 0.75, + use_gpu=False, + logger_func=print, + apply_to_all_zslices=True, +): """Multi-dimensional Gaussian filter Parameters @@ -92,11 +93,11 @@ def gaussian_filter( image : numpy.ndarray Input image (grayscale or color) to filter. sigma : types.Vector - Standard deviation for Gaussian kernel. The standard deviations of the - Gaussian filter are given for each axis as a sequence, or as a single + Standard deviation for Gaussian kernel. The standard deviations of the + Gaussian filter are given for each axis as a sequence, or as a single number, in which case it is equal for all axes. use_gpu : bool, optional - If True, uses `cupy` instead of `skimage.filters.gaussian`. + If True, uses `cupy` instead of `skimage.filters.gaussian`. Default is False logger_func : callable, optional Function used to log information. Default is print @@ -109,48 +110,45 @@ def gaussian_filter( See also -------- Wikipedia link: `Gaussian blur `_ - """ + """ try: if len(sigma) > 1 and sigma[0] == 0: return image except Exception as err: pass - + try: if sigma == 0: return image except Exception as err: pass - + try: if len(sigma) == 0: sigma = sigma[0] except Exception as err: pass - + if CUPY_INSTALLED and use_gpu: try: image = cp.array(image, dtype=float) filtered = cupyx.scipy.ndimage.gaussian_filter(image, sigma) filtered = cp.asnumpy(filtered) except Exception as err: - logger_func('*'*100) + logger_func("*" * 100) logger_func(err) logger_func( - '[WARNING]: GPU acceleration of the gaussian filter failed. ' - f'Using CPU...{error_up_str}' + "[WARNING]: GPU acceleration of the gaussian filter failed. " + f"Using CPU...{error_up_str}" ) filtered = skimage.filters.gaussian(image, sigma=sigma) else: filtered = skimage.filters.gaussian(image, sigma=sigma) return filtered -def ridge_filter( - image, - sigmas: _types.Vector=(1.0, 2.0), - apply_to_all_zslices=True - ): - """Filter used to enhance network-like structures (Sato filter). More info + +def ridge_filter(image, sigmas: _types.Vector = (1.0, 2.0), apply_to_all_zslices=True): + """Filter used to enhance network-like structures (Sato filter). More info here https://scikit-image.org/docs/stable/auto_examples/edges/plot_ridge_filter.html Parameters @@ -164,20 +162,21 @@ def ridge_filter( ------- (Y, X) or (Z, Y, X) numpy.ndarray Filtered image - """ + """ input_shape = image.shape filtered = skimage.filters.sato( np.squeeze(image), sigmas=sigmas, black_ridges=False ).reshape(input_shape) return filtered + def spot_detector_filter( - image, - spots_zyx_radii_pxl: _types.Vector=(3, 5, 5), - use_gpu=False, - logger_func=print, - apply_to_all_zslices=True - ): + image, + spots_zyx_radii_pxl: _types.Vector = (3, 5, 5), + use_gpu=False, + logger_func=print, + apply_to_all_zslices=True, +): """Spot detection using Difference of Gaussians filter. Parameters @@ -185,10 +184,10 @@ def spot_detector_filter( image : (Y, X) or (Z, Y, X) numpy.ndarray Input image spots_zyx_radii_pxl : sequence of floats, one for each dimension, optional - Expected size of the spots in pixels. One size for each dimension in + Expected size of the spots in pixels. One size for each dimension in `image`. Default is (3, 5, 5) use_gpu : bool, optional - If `True` uses GPU if `cupy` is installed and a CUDA-compatible GPU + If `True` uses GPU if `cupy` is installed and a CUDA-compatible GPU is available . Default is False logger_func : callable, optional Function used to log additional information on progress. Default is print @@ -202,45 +201,42 @@ def spot_detector_filter( ------ TypeError Error raised when on of the input sigmas is zero. - """ + """ spots_zyx_radii_pxl = np.array(spots_zyx_radii_pxl) if image.ndim == 2 and len(spots_zyx_radii_pxl) == 3: spots_zyx_radii_pxl = spots_zyx_radii_pxl[1:] - - sigma1 = spots_zyx_radii_pxl/(1+SQRT_2) - + + sigma1 = spots_zyx_radii_pxl / (1 + SQRT_2) + if 0 in sigma1: raise TypeError( - f'Sharpening filter input sigmas cannot be 0. `zyx_sigma1 = {sigma1}`' + f"Sharpening filter input sigmas cannot be 0. `zyx_sigma1 = {sigma1}`" ) - - blurred1 = gaussian_filter( - image, sigma1, use_gpu=use_gpu, logger_func=logger_func - ) - - sigma2 = SQRT_2*sigma1 - blurred2 = gaussian_filter( - image, sigma2, use_gpu=use_gpu, logger_func=logger_func - ) - + + blurred1 = gaussian_filter(image, sigma1, use_gpu=use_gpu, logger_func=logger_func) + + sigma2 = SQRT_2 * sigma1 + blurred2 = gaussian_filter(image, sigma2, use_gpu=use_gpu, logger_func=logger_func) + sharpened = blurred1 - blurred2 - + out_range = (image.min(), image.max()) - in_range = 'image' + in_range = "image" sharp_rescaled = skimage.exposure.rescale_intensity( sharpened, in_range=in_range, out_range=out_range ) - + return sharp_rescaled + def correct_illumination( - image, - block_size=45, - # rescale_illumination=True, - approximate_object_diameter=15, - # background_threshold=0.3, - apply_gaussian_filter=True - ): + image, + block_size=45, + # rescale_illumination=True, + approximate_object_diameter=15, + # background_threshold=0.3, + apply_gaussian_filter=True, +): """ Correct illumination of an image. Based on CellProfiler's illumination correction. @@ -291,11 +287,9 @@ def correct_illumination( return corrected_image -def enhance_speckles(img, - radius=15, - apply_to_all_zslices=False - ): - """Enhance speckles in an image using white_tophat. Based on + +def enhance_speckles(img, radius=15, apply_to_all_zslices=False): + """Enhance speckles in an image using white_tophat. Based on EnhanceOrSuppressFeatures from Cell profiler with 'Feature type: Speckles' Parameters @@ -303,7 +297,7 @@ def enhance_speckles(img, image : np.ndarray 2D image to enhance radius : int, optional - Radius to use for the enhancer. Will suppress objects smaller than this + Radius to use for the enhancer. Will suppress objects smaller than this radius. Default is 15 Returns @@ -318,18 +312,19 @@ def enhance_speckles(img, output_image = skimage.morphology.white_tophat(img, footprint=footprint) return output_image + def fucci_filter( - image, - correct_illumination_toggle=False, - do_basicpy_background_correction=True, - enhance_speckles_toggle=True, - block_size=120, - # rescale_illumination=False, - approximate_object_diameter=25, - # background_threshold=0.3, - apply_gaussian_filter=True, - speckle_radius=25 - ): + image, + correct_illumination_toggle=False, + do_basicpy_background_correction=True, + enhance_speckles_toggle=True, + block_size=120, + # rescale_illumination=False, + approximate_object_diameter=25, + # background_threshold=0.3, + apply_gaussian_filter=True, + speckle_radius=25, +): """Basic filter pipeline proposed for Fucci images. If you want custom pipelines and more in depth control, create your own recipe using the GUI or segmentation and tracking modules. @@ -339,13 +334,13 @@ def fucci_filter( image : (Y, X) numpy.ndarray 2D image to correct correct_illumination_toggle : bool, optional - If illumination should be corrected. + If illumination should be corrected. Default is True do_basicpy_background_correction : bool, optional - If BaSiC background correction should be applied. + If BaSiC background correction should be applied. Default is False enhance_speckles_toggle : bool, optional - If speckles should be enhanced. + If speckles should be enhanced. Default is True block_size : int, optional Block size for which to calculate the background illumination. @@ -354,16 +349,16 @@ def fucci_filter( # if illumination should be rescaled with skimage.exposure.rescale_intensity range=(0, 1). # Default is True approximate_object_diameter : int, optional - Approximate object diameter for gaussian_filter. + Approximate object diameter for gaussian_filter. Default is 25 # background_threshold : float, optional - # Threshold to be used to determine the background. + # Threshold to be used to determine the background. # Default is 0.3 apply_gaussian_filter : bool, optional - If gaussian_filter should be applied to the illumination_function. + If gaussian_filter should be applied to the illumination_function. Default is True speckle_radius : int, optional - Radius to use for the enhancer. Will suppress objects smaller than this + Radius to use for the enhancer. Will suppress objects smaller than this radius. Default is 25 Returns @@ -373,14 +368,14 @@ def fucci_filter( """ if do_basicpy_background_correction: image = basicpy_background_correction( - image, + image, apply_to_all_frames=False, apply_to_all_zslices=False, ) if correct_illumination_toggle: image = correct_illumination( - image, - block_size=block_size, + image, + block_size=block_size, # rescale_illumination=rescale_illumination, approximate_object_diameter=approximate_object_diameter, # background_threshold=background_threshold, @@ -388,136 +383,132 @@ def fucci_filter( ) if enhance_speckles_toggle: image = enhance_speckles(image, radius=speckle_radius) - + return image + def dummy_filter( - image: np.ndarray, - apply_to_all_zslices=False, - apply_to_all_frames=False - ): + image: np.ndarray, apply_to_all_zslices=False, apply_to_all_frames=False +): printl(image.shape) return image + class VolumeImageData: def __init__(self): self._data = {} - def __setitem__( - self, - z_slice: int, - image: np.ndarray - ): + def __setitem__(self, z_slice: int, image: np.ndarray): if not isinstance(z_slice, (int, str)): raise TypeError( - f'{z_slice} is not not a valid index. ' - f'It must be an integer or a string and not {type(z_slice)}' + f"{z_slice} is not not a valid index. " + f"It must be an integer or a string and not {type(z_slice)}" ) - + if image.ndim != 2: raise TypeError( - 'Only 2D images can be assigned to a specifc z-slice index.' + "Only 2D images can be assigned to a specifc z-slice index." ) - + self._data[z_slice] = image - - def __getitem__( - self, z_slice: Union[int, Tuple[Union[int, slice]], None] - ): + + def __getitem__(self, z_slice: Union[int, Tuple[Union[int, slice]], None]): if isinstance(z_slice, int): return self._data[z_slice] - + arr = self._build_arr() return arr[z_slice] - + def __array__(self) -> np.ndarray: return self._build_arr() - + def __repr__(self): return str(self._data) - + def _build_arr(self): if not self._data: return - + img = self._data[0] SizeZ = len(self._data) arr = np.zeros((SizeZ, *img.shape), dtype=img.dtype) for z_slice, img in self._data.items(): arr[z_slice] = img return np.squeeze(arr) - + def max(self, axis=None): arr = self._build_arr() if arr is None: return - + return arr.max(axis=axis) def min(self, axis=None): arr = self._build_arr() if arr is None: return - + return arr.min(axis=axis) - + def mean(self, axis=None): arr = self._build_arr() if arr is None: return - + return arr.mean(axis=axis) - + + class PreprocessedData: def __init__(self, image_data=None): self._data = {} if image_data is not None: self._init_data(image_data) - + def _init_data(self, image_data): for frame_i, img in enumerate(image_data): self[frame_i] = img - + def __getitem__(self, frame_i: int): if frame_i not in self._data: self._data[frame_i] = VolumeImageData() - + return self._data[frame_i] - + def __setitem__(self, frame_i: int, image: np.ndarray): if not isinstance(frame_i, int): raise TypeError( - f'{frame_i} is not not a valid index. ' - f'It must be an integer and not {type(frame_i)}' + f"{frame_i} is not not a valid index. " + f"It must be an integer and not {type(frame_i)}" ) - + if frame_i not in self._data: self._data[frame_i] = VolumeImageData() - + if image.ndim == 2: self._data[frame_i][0] = image else: for z_slice, img in enumerate(image): self._data[frame_i][z_slice] = img - + def __repr__(self): return str(self._data) - + def get(self, frame_i: int, default_value=None): try: return self._data[frame_i] except KeyError: return default_value + def rescale_intensities( - image: np.array, - out_range_low: float=0.0, - out_range_high: float=1.0, - in_range_how: _types.RescaleIntensitiesInRangeHow='percentage', - in_range_low: float=0.0, - in_range_high: float=1.0, - apply_to_all_zslices=True, - ): + image: np.array, + out_range_low: float = 0.0, + out_range_high: float = 1.0, + in_range_how: _types.RescaleIntensitiesInRangeHow = "percentage", + in_range_low: float = 0.0, + in_range_high: float = 1.0, + apply_to_all_zslices=True, +): """Rescale the intensities of an image to a given range. Parameters @@ -529,19 +520,19 @@ def rescale_intensities( out_range_high : float, optional Max value of the output image. Default is 1.0 in_range_low : float, optional - Min value of the output image. See `in_range_how` for more details. + Min value of the output image. See `in_range_how` for more details. Default is 0.0 in_range_high : float, optional - Max value of the output image. See `in_range_how` for more details. + Max value of the output image. See `in_range_how` for more details. Default is 1.0 in_range_how : {'percentage', 'image', 'absolute'}, optional - If `percentage`, the image is first rescaled to (0, 1) using the - minimum and maximum value of the input image. This allows to specify - the input range as a percentage of the image intensity range. - If `image`, the input range is the minimum and maximum value of the - input image. - If `absolute`, the input range is specified by `in_range_low` and - `in_range_high` in absolute values (same scale as the input image). + If `percentage`, the image is first rescaled to (0, 1) using the + minimum and maximum value of the input image. This allows to specify + the input range as a percentage of the image intensity range. + If `image`, the input range is the minimum and maximum value of the + input image. + If `absolute`, the input range is specified by `in_range_low` and + `in_range_high` in absolute values (same scale as the input image). Default is 'percentage'. apply_to_all_zslices : bool, optional Scale intensities across multi-dimensional images. Default is True @@ -552,33 +543,36 @@ def rescale_intensities( The rescaled image """ out_range = (out_range_low, out_range_high) - if in_range_how == 'image': - in_range = 'image' - elif in_range_how == 'percentage': + if in_range_how == "image": + in_range = "image" + elif in_range_how == "percentage": image = skimage.exposure.rescale_intensity( - image, in_range='image', out_range=(0, 1) + image, in_range="image", out_range=(0, 1) ) - in_range = (in_range_low, in_range_high) # which now will be in (0, 1) - elif in_range_how == 'absolute': + in_range = (in_range_low, in_range_high) # which now will be in (0, 1) + elif in_range_how == "absolute": in_range = (in_range_low, in_range_high) - + rescaled = skimage.exposure.rescale_intensity( image, in_range=in_range, out_range=out_range ) return rescaled + def _init_dummy_filter(**kwargs): """ - This function runs automatically as part of the preprocessing recipe if - the user selects the 'dummy_filter' step. The 'dummy_filter' is available - only in debug mode. Initialization functions run in the main GUI thread - and they can be used to set up the related function, for example to + This function runs automatically as part of the preprocessing recipe if + the user selects the 'dummy_filter' step. The 'dummy_filter' is available + only in debug mode. Initialization functions run in the main GUI thread + and they can be used to set up the related function, for example to prompt the user that a package needs to be installed. """ pass + def _init_basicpy_background_correction(**kwargs): - from . import myutils + from . import utils + custom_install_requires = [ "hyperactive>=4.4.0", "jax>=0.3.10,<0.4.23", @@ -587,62 +581,63 @@ def _init_basicpy_background_correction(**kwargs): "pooch", "pydantic>=2.7.0,<3.0.0", "scikit-image", - "scipy", # this will theoretically have the wrong version of scipy in the end - ] - - myutils.check_install_custom_dependencies( - custom_install_requires, 'basicpy', parent=kwargs.get('parent') + "scipy", # this will theoretically have the wrong version of scipy in the end + ] + + utils.check_install_custom_dependencies( + custom_install_requires, "basicpy", parent=kwargs.get("parent") ) + def basicpy_background_correction( - images, - apply_to_all_frames=False, - apply_to_all_zslices=False, - smoothness_flatfield=1.0, - get_darkfield=True, - smoothness_darkfield=1.0, - sparse_cost_darkfield=0.01, - # baseline=None, - # darkfield=None, - fitting_mode: _types.BaSiCpyFittingModes="ladmap", - epsilon=0.1, - # flatfield=None, - autosegment=False, - autosegment_margin=10, - max_iterations=500, - max_reweight_iterations=10, - max_reweight_iterations_baseline=5, - # max_workers=2, - rho=1.5, - mu_coef=12.5, - max_mu_coef=10000000.0, - optimization_tol=0.001, - optimization_tol_diff=0.01, - resize_mode: _types.BaSiCpyResizeModes="jax", - resize_params: _types.NotGUIParam=None, - reweighting_tol=0.01, - sort_intensity=False, - working_size=128, - timelapse: _types.BaSiCpyTimelapse="True", - parent: _types.NotGUIParam=None - ): + images, + apply_to_all_frames=False, + apply_to_all_zslices=False, + smoothness_flatfield=1.0, + get_darkfield=True, + smoothness_darkfield=1.0, + sparse_cost_darkfield=0.01, + # baseline=None, + # darkfield=None, + fitting_mode: _types.BaSiCpyFittingModes = "ladmap", + epsilon=0.1, + # flatfield=None, + autosegment=False, + autosegment_margin=10, + max_iterations=500, + max_reweight_iterations=10, + max_reweight_iterations_baseline=5, + # max_workers=2, + rho=1.5, + mu_coef=12.5, + max_mu_coef=10000000.0, + optimization_tol=0.001, + optimization_tol_diff=0.01, + resize_mode: _types.BaSiCpyResizeModes = "jax", + resize_params: _types.NotGUIParam = None, + reweighting_tol=0.01, + sort_intensity=False, + working_size=128, + timelapse: _types.BaSiCpyTimelapse = "True", + parent: _types.NotGUIParam = None, +): """ A function for fitting and applying BaSiC illumination correction profiles. Parameters ---------- images : (T, Z, Y, X) numpy.ndarray - Image. Make sure to set have (T, Z, Y, X) dimensions, + Image. Make sure to set have (T, Z, Y, X) dimensions, or missing dimensions - in accordance with the `apply_to_all_frames` and + in accordance with the `apply_to_all_frames` and `apply_to_all_zslices` parameters. apply_to_all_frames : bool, default=True - Whether to apply the correction to all frames. - If set to falce, assumes that the image has + Whether to apply the correction to all frames. + If set to falce, assumes that the image has no T dimension, so either (Z, Y, X) or (Y, X). apply_to_all_zslices : bool, default=True - Whether to apply the correction to all Z slices. - If set to falce, assumes that the image has + Whether to apply the correction to all Z slices. + If set to falce, assumes that the image has no Z dimension, so either (T, Y, X) or (Y, X). smoothness_flatfield : float, default=1.0 Weight of the flatfield term in the Lagrangian. @@ -687,7 +682,7 @@ def basicpy_background_correction( optimization_tol_diff : float, default=0.01 Optimization tolerance for update difference. resize_mode : str, default="jax" - Resize mode for downsampling images. Must be one of + Resize mode for downsampling images. Must be one of ['jax', 'skimage', 'skimage_dask']. resize_params : dict, default={} Parameters for the resize function. @@ -739,7 +734,7 @@ def basicpy_background_correction( images = transformation.correct_img_dimension( images, input_dims=input_dims, output_dims=output_dims ) - + from basicpy import BaSiC basic = BaSiC( @@ -767,16 +762,13 @@ def basicpy_background_correction( resize_params=resize_params, reweighting_tol=reweighting_tol, sort_intensity=sort_intensity, - working_size=working_size - ) + working_size=working_size, + ) print("Fitting BaSiC model, may take a while...") basic.fit(images) - images = basic.transform( - images, - timelapse=timelapse - ) - + images = basic.transform(images, timelapse=timelapse) + images = images.squeeze() images = np.array(images) - return images \ No newline at end of file + return images diff --git a/cellacdc/promptable_models/micro-sam/__init__.py b/cellacdc/promptable_models/micro-sam/__init__.py deleted file mode 100644 index 40ce20431..000000000 --- a/cellacdc/promptable_models/micro-sam/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -import cellacdc.myutils as myutils - -myutils.check_install_microsam() \ No newline at end of file diff --git a/cellacdc/promptable_models/nnInteractive/__init__.py b/cellacdc/promptable_models/nnInteractive/__init__.py deleted file mode 100644 index da985eec7..000000000 --- a/cellacdc/promptable_models/nnInteractive/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -import cellacdc.myutils as myutils - -myutils.check_install_nnInteractive() diff --git a/cellacdc/promptable_models/sam2/__init__.py b/cellacdc/promptable_models/sam2/__init__.py deleted file mode 100644 index 554a3b083..000000000 --- a/cellacdc/promptable_models/sam2/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -import cellacdc.myutils as myutils - -myutils.check_install_sam2() diff --git a/cellacdc/promptable_models/segment_anything/__init__.py b/cellacdc/promptable_models/segment_anything/__init__.py deleted file mode 100644 index 14bd9dda0..000000000 --- a/cellacdc/promptable_models/segment_anything/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -import cellacdc.myutils as myutils - -myutils.check_install_segment_anything() diff --git a/cellacdc/prompts.py b/cellacdc/prompts.py index 86e20a3cc..a676c152e 100755 --- a/cellacdc/prompts.py +++ b/cellacdc/prompts.py @@ -13,15 +13,20 @@ if GUI_INSTALLED: from qtpy.QtWidgets import ( - QApplication, QPushButton, QHBoxLayout, QLabel, QSizePolicy + QApplication, + QPushButton, + QHBoxLayout, + QLabel, + QSizePolicy, ) from qtpy.QtCore import Qt from qtpy.QtGui import QFont from . import widgets, apps -from . import myutils, printl, html_utils, load +from . import utils, printl, html_utils, load from . import settings_folderpath + class select_channel_name: def __init__(self, which_channel=None, allow_abort=True): self.is_first_call = True @@ -31,30 +36,28 @@ def __init__(self, which_channel=None, allow_abort=True): self.allow_abort = allow_abort def _get_available_channels_from_metadata( - self, metadata_csv_path, filenames, channelExt - ): + self, metadata_csv_path, filenames, channelExt + ): df = pd.read_csv(metadata_csv_path) basename = None channel_names = None - if 'Description' not in df.columns: + if "Description" not in df.columns: return [] - channelNamesMask = df.Description.str.contains(r'channel_\d+_name') - channelNames = df[channelNamesMask]['values'].to_list() + channelNamesMask = df.Description.str.contains(r"channel_\d+_name") + channelNames = df[channelNamesMask]["values"].to_list() try: - basename = df.set_index('Description').at['basename', 'values'] + basename = df.set_index("Description").at["basename", "values"] except Exception as e: basename = None if channelNames: - # There are channel names in metadata --> check that they + # There are channel names in metadata --> check that they # are still existing as files channel_names = channelNames.copy() for chName in channelNames: chSaved = [] for file in filenames: - patterns = ( - f'{chName}.tif', f'{chName}_aligned.npz' - ) + patterns = (f"{chName}.tif", f"{chName}_aligned.npz") ends = [p for p in patterns if file.endswith(p)] if ends: pattern = ends[0] @@ -77,36 +80,40 @@ def _get_available_channels_from_metadata( if channel_names is None or basename is None: return [] - + # Add additional channels existing as file but not in metadata.csv for file in filenames: ends = [ - ext for ext in channelExt if (file.endswith(ext) - and not file.endswith('btrack_tracks.h5')) - and not file.endswith('edited.h5') + ext + for ext in channelExt + if (file.endswith(ext) and not file.endswith("btrack_tracks.h5")) + and not file.endswith("edited.h5") ] if ends: - endName = file[len(basename):] - chName = endName.replace(ends[0], '') + endName = file[len(basename) :] + chName = endName.replace(ends[0], "") if chName not in channel_names: channel_names.append(chName) - + channel_names = natsorted(channel_names) - + return channel_names - + def get_available_channels( - self, filenames, images_path, useExt=None, - channelExt=('.tif', '_aligned.npz'), - validEndnames=('aligned.npz', 'acdc_output.csv', 'segm.npz') - ): + self, + filenames, + images_path, + useExt=None, + channelExt=(".tif", "_aligned.npz"), + validEndnames=("aligned.npz", "acdc_output.csv", "segm.npz"), + ): # First check if metadata.csv already has the channel names metadata_csv_path = None - for file in myutils.listdir(images_path): - if file.endswith('metadata.csv'): + for file in utils.listdir(images_path): + if file.endswith("metadata.csv"): metadata_csv_path = os.path.join(images_path, file) break - + chNames_found = False channel_names = set() basename = None @@ -116,11 +123,11 @@ def get_available_channels( ) if channel_names: return channel_names, False - + # Find basename as intersection of filenames channel_names = set() self.basenameNotFound = False - isBasenamePresent = myutils.checkDataIntegrity(filenames, images_path) + isBasenamePresent = utils.checkDataIntegrity(filenames, images_path) if basename is None: basename = filenames[0] basename = filenames[0] @@ -130,14 +137,13 @@ def get_available_channels( validFile = False if useExt is None: validFile = True - elif ext in useExt and not file.endswith('btrack_tracks.h5'): + elif ext in useExt and not file.endswith("btrack_tracks.h5"): validFile = True elif any([file.endswith(end) for end in validEndnames]): validFile = True else: - validFile = ( - (file.find('_acdc_output_') != -1 and ext == '.csv') - or (file.find('_segm_') != -1 and ext == '.npz') + validFile = (file.find("_acdc_output_") != -1 and ext == ".csv") or ( + file.find("_segm_") != -1 and ext == ".npz" ) if not validFile: continue @@ -145,29 +151,29 @@ def get_available_channels( i, j, k = sm.find_longest_match(0, len(file), 0, len(basename)) if i > 0: continue - basename = file[i:i+k] + basename = file[i : i + k] self.basename = basename - + basenameNotFound = [False] for file in filenames: - if file.endswith('edited.h5'): + if file.endswith("edited.h5"): continue - - if file.endswith('btrack_tracks.h5'): + + if file.endswith("btrack_tracks.h5"): continue - + filename, ext = os.path.splitext(file) validImageFile = False if ext in channelExt: validImageFile = True - elif file.endswith('aligned.npz'): + elif file.endswith("aligned.npz"): validImageFile = True - filename = filename[:-len('_aligned')] - + filename = filename[: -len("_aligned")] + if not validImageFile: continue - - channel_name = filename.split(basename)[-1] + + channel_name = filename.split(basename)[-1] channel_names.add(channel_name) if channel_name == filename: # Warn that an intersection could not be found @@ -176,26 +182,29 @@ def get_available_channels( if any(basenameNotFound): self.basenameNotFound = True filenameNOext, _ = os.path.splitext(basename) - self.basename = f'{filenameNOext}_' + self.basename = f"{filenameNOext}_" if self.which_channel is not None: # Search for "phase" and put that channel first on the list - if self.which_channel == 'segm': - is_phase_contr_li = [c.lower().find('phase')!=-1 - for c in channel_names] + if self.which_channel == "segm": + is_phase_contr_li = [ + c.lower().find("phase") != -1 for c in channel_names + ] if any(is_phase_contr_li): idx = is_phase_contr_li.index(True) channel_names[0], channel_names[idx] = ( - channel_names[idx], channel_names[0]) - + channel_names[idx], + channel_names[0], + ) + channel_names = natsorted(channel_names) - + return channel_names, any(basenameNotFound) def _load_last_selection(self): last_sel_channel = None ch = self.which_channel if self.which_channel is not None: - txt_path = os.path.join(settings_folderpath, f'{ch}_last_sel.txt') + txt_path = os.path.join(settings_folderpath, f"{ch}_last_sel.txt") if os.path.exists(txt_path): with open(txt_path) as txt: last_sel_channel = txt.read() @@ -206,20 +215,21 @@ def _save_last_selection(self, selection): if self.which_channel is not None: if not os.path.exists(settings_folderpath): os.mkdir(settings_folderpath) - txt_path = os.path.join(settings_folderpath, f'{ch}_last_sel.txt') - with open(txt_path, 'w') as txt: + txt_path = os.path.join(settings_folderpath, f"{ch}_last_sel.txt") + with open(txt_path, "w") as txt: txt.write(selection) - + def askChannelName(self, filenames, images_path, ask, ch_names): from . import apps + if not ask: return ch_names filename = self.basename possibleChannelNames = [] - splits = [split for split in filename.split('_') if split] + splits = [split for split in filename.split("_") if split] possibleChannelNames = [] - for i in range(len(splits)-1): - possibleChanneName = '_'.join(splits[i+1:]) + for i in range(len(splits) - 1): + possibleChanneName = "_".join(splits[i + 1 :]) possibleChannelNames.append(possibleChanneName) possibleChannelNames = possibleChannelNames[::-1] @@ -230,8 +240,12 @@ def askChannelName(self, filenames, images_path, ask, ch_names): Filename: {filename} """) win = apps.QDialogCombobox( - 'Select channel name', possibleChannelNames, txt, - CbLabel='Select channel name: ', parent=None, centeredCombobox=True + "Select channel name", + possibleChannelNames, + txt, + CbLabel="Select channel name: ", + parent=None, + centeredCombobox=True, ) win.exec_() if win.cancel: @@ -244,24 +258,29 @@ def askChannelName(self, filenames, images_path, ask, ch_names): df_metadata, metadata_csv_path = load.get_posData_metadata( images_path, basename ) - df_metadata.at['channel_0_name', 'values'] = channel_name + df_metadata.at["channel_0_name", "values"] = channel_name df_metadata.to_csv(metadata_csv_path) ch_names, _ = self.get_available_channels(filenames, images_path) return ch_names - - def QtPrompt(self, parent, channel_names, informativeText='', - CbLabel='Select channel name: '): + def QtPrompt( + self, + parent, + channel_names, + informativeText="", + CbLabel="Select channel name: ", + ): from . import apps + font = QFont() font.setPixelSize(13) win = apps.QDialogCombobox( - 'Select channel name', + "Select channel name", channel_names, informativeText, CbLabel=CbLabel, parent=parent, - defaultChannelName=self.last_sel_channel + defaultChannelName=self.last_sel_channel, ) win.setFont(font) win.exec_() @@ -274,7 +293,7 @@ def QtPrompt(self, parent, channel_names, informativeText='', def setUserChannelName(self): if self.basenameNotFound: reverse_ch_name = self.channel_name[::-1] - idx = reverse_ch_name.find('_') + idx = reverse_ch_name.find("_") if idx != -1: self.user_ch_name = self.channel_name[-idx:] else: @@ -288,57 +307,63 @@ def _test(self, name=None, index=None, mode=None): def _abort(self): self.was_aborted = True if self.allow_abort: - exit('Execution aborted by the user') + exit("Execution aborted by the user") + def exportToImageFinished(filepath, qparent=None): from cellacdc import widgets - - txt = 'Exporting to image done!' - txt = f'{txt}

    Files were saved here:' - + + txt = "Exporting to image done!" + txt = f"{txt}

    Files were saved here:" + txt = html_utils.paragraph(txt) msg = widgets.myMessageBox(wrapText=False) msg.information( - qparent, 'Exporting image finished', txt, - commands=(filepath,), - path_to_browse=os.path.dirname(filepath) + qparent, + "Exporting image finished", + txt, + commands=(filepath,), + path_to_browse=os.path.dirname(filepath), ) -def exportToVideoFinished( - preferences, conversion_to_mp4_successful, qparent=None - ): + +def exportToVideoFinished(preferences, conversion_to_mp4_successful, qparent=None): from cellacdc import widgets - - txt = 'Exporting to video finished!' - - msg_type = 'information' + + txt = "Exporting to video finished!" + + msg_type = "information" if not conversion_to_mp4_successful: from . import urls - github_href = html_utils.href_tag('GitHub page', urls.issues_url) - msg_type = 'warning' + + github_href = html_utils.href_tag("GitHub page", urls.issues_url) + msg_type = "warning" txt = ( - f'{txt}

    ' - 'WARNING: Conversion to MP4 failed. ' - 'Video file was saved as AVI instead. ' - f'Feel free to report the issue on our {github_href}' + f"{txt}

    " + "WARNING: Conversion to MP4 failed. " + "Video file was saved as AVI instead. " + f"Feel free to report the issue on our {github_href}" ) - - txt = f'{txt}

    Files were saved here:' - + + txt = f"{txt}

    Files were saved here:" + txt = html_utils.paragraph(txt) - - - folderpath = os.path.dirname(preferences['filepath']) - commands = [preferences['filepath']] - if preferences['save_pngs']: - commands.append(preferences['pngs_folderpath']) - + + folderpath = os.path.dirname(preferences["filepath"]) + commands = [preferences["filepath"]] + if preferences["save_pngs"]: + commands.append(preferences["pngs_folderpath"]) + msg = widgets.myMessageBox(wrapText=False) getattr(msg, msg_type)( - qparent, 'Exporting video finished', txt, - commands=commands, path_to_browse=folderpath + qparent, + "Exporting video finished", + txt, + commands=commands, + path_to_browse=folderpath, ) + def askSamSaveEmbeddings(qparent=None): txt = html_utils.paragraph(""" Segment Anything Model generates image embeddings that you @@ -347,24 +372,31 @@ def askSamSaveEmbeddings(qparent=None): prompts).

    Do you want to save the image embeddings? """) - saveOnlyButton = widgets.BedPushButton('Save only embeddings') - saveButton = widgets.BedPlusLabelPushButton('Save also embeddings') - saveOnlyButton = widgets.BedPushButton('Save only embeddings') + saveOnlyButton = widgets.BedPushButton("Save only embeddings") + saveButton = widgets.BedPlusLabelPushButton("Save also embeddings") + saveOnlyButton = widgets.BedPushButton("Save only embeddings") msg = widgets.myMessageBox(wrapText=False) _, saveOnlyButton, saveButton, _ = msg.question( - qparent, 'Save SAM Image Embeddings?', txt, + qparent, + "Save SAM Image Embeddings?", + txt, buttonsTexts=( - 'Cancel', saveOnlyButton, saveButton, - widgets.NoBedPushButton('Do not save embeddings') - ) + "Cancel", + saveOnlyButton, + saveButton, + widgets.NoBedPushButton("Do not save embeddings"), + ), ) sam_only_embeddings = msg.clickedButton == saveOnlyButton sam_also_embeddings = msg.clickedButton == saveButton return sam_only_embeddings, sam_also_embeddings, msg.cancel + def askSamLoadEmbeddings( - sam_embeddings_path, qparent=None, is_gui_caller=False, - ): + sam_embeddings_path, + qparent=None, + is_gui_caller=False, +): txt = html_utils.paragraph(""" Cell-ACDC detected previously saved Segment Anything image embeddings (see file path below).

    @@ -372,18 +404,20 @@ def askSamLoadEmbeddings( Do you want to load the image embeddings? """) msg = widgets.myMessageBox(wrapText=False) - loadButton = widgets.BedPlusLabelPushButton('Load embeddings and segment') - doNotLoadButton = widgets.NoBedPushButton('Do not load embeddings') + loadButton = widgets.BedPlusLabelPushButton("Load embeddings and segment") + doNotLoadButton = widgets.NoBedPushButton("Do not load embeddings") buttons = (loadButton, doNotLoadButton) if is_gui_caller: - loadOnlyEmbedButton = widgets.BedPushButton('Only load embeddings') + loadOnlyEmbedButton = widgets.BedPushButton("Only load embeddings") buttons = (loadOnlyEmbedButton, *buttons) - + msg.question( - qparent, 'Load SAM Image Embeddings?', txt, - buttonsTexts=('Cancel', *buttons), - commands=(sam_embeddings_path,), - path_to_browse=os.path.dirname(sam_embeddings_path) + qparent, + "Load SAM Image Embeddings?", + txt, + buttonsTexts=("Cancel", *buttons), + commands=(sam_embeddings_path,), + path_to_browse=os.path.dirname(sam_embeddings_path), ) loadEmbed = msg.clickedButton == loadButton onlyLoadEmbed = False @@ -391,77 +425,86 @@ def askSamLoadEmbeddings( onlyLoadEmbed = msg.clickedButton == loadOnlyEmbedButton return loadEmbed, onlyLoadEmbed, msg.cancel + def init_prompt_model_params( - posData, model_name, init_params, segment_params, - qparent=None, help_url=None, init_last_params=False, - ini_filename=None - ): + posData, + model_name, + init_params, + segment_params, + qparent=None, + help_url=None, + init_last_params=False, + ini_filename=None, +): out = {} - + segm_files = load.get_segm_files(posData.images_path) - existingSegmEndnames = load.get_endnames( - posData.basename, segm_files - ) + existingSegmEndnames = load.get_endnames(posData.basename, segm_files) win = apps.QDialogModelParams( init_params, segment_params, - model_name, + model_name, parent=qparent, - url=help_url, - initLastParams=init_last_params, + url=help_url, + initLastParams=init_last_params, posData=posData, segmFileEndnames=existingSegmEndnames, df_metadata=posData.metadata_df, addPreProcessParams=False, addPostProcessParams=False, ini_filename=ini_filename, - add_additional_segm_params=False + add_additional_segm_params=False, ) win.setChannelNames(posData.chNames) - out['win'] = win + out["win"] = win win.exec_() return out + def init_segm_model_params( - posData, model_name, init_params, segment_params, - qparent=None, help_url=None, init_last_params=False, - check_sam_embeddings=True, is_gui_caller=False, - extraParams=None, extraParamsTitle=None, - ini_filename=None, add_additional_segm_params=False - ): + posData, + model_name, + init_params, + segment_params, + qparent=None, + help_url=None, + init_last_params=False, + check_sam_embeddings=True, + is_gui_caller=False, + extraParams=None, + extraParamsTitle=None, + ini_filename=None, + add_additional_segm_params=False, +): out = {} - - is_sam_model = ( - model_name in ('segment_anything', 'sam2') and check_sam_embeddings - ) - + + is_sam_model = model_name in ("segment_anything", "sam2") and check_sam_embeddings + # If SAM with prompts and embeddings were prev saved, asks to load them load_sam_embed = False only_load_sam_embed = False sam_embeddings_exist = os.path.exists(posData.sam_embeddings_path) - sam_embeddings_loaded = hasattr(posData, 'sam_embeddings') + sam_embeddings_loaded = hasattr(posData, "sam_embeddings") if is_sam_model and sam_embeddings_exist and not sam_embeddings_loaded: load_sam_embed, only_load_sam_embed, cancel = askSamLoadEmbeddings( - posData.sam_embeddings_path, qparent=qparent, - is_gui_caller=is_gui_caller + posData.sam_embeddings_path, qparent=qparent, is_gui_caller=is_gui_caller ) if cancel: return out - - out['load_sam_embeddings'] = only_load_sam_embed or load_sam_embed + + out["load_sam_embeddings"] = only_load_sam_embed or load_sam_embed if only_load_sam_embed: return out - + segm_files = load.get_segm_files(posData.images_path) - existingSegmEndnames = load.get_endnames( - posData.basename, segm_files - ) + existingSegmEndnames = load.get_endnames(posData.basename, segm_files) win = apps.QDialogModelParams( init_params, segment_params, - model_name, parent=qparent, - url=help_url, - initLastParams=init_last_params, + model_name, + parent=qparent, + url=help_url, + initLastParams=init_last_params, posData=posData, segmFileEndnames=existingSegmEndnames, df_metadata=posData.metadata_df, @@ -469,35 +512,32 @@ def init_segm_model_params( extraParams=extraParams, extraParamsTitle=extraParamsTitle, ini_filename=ini_filename, - add_additional_segm_params=add_additional_segm_params + add_additional_segm_params=add_additional_segm_params, ) win.setChannelNames(posData.chNames) - out['win'] = win + out["win"] = win win.exec_() if win.cancel: return out - + if load_sam_embed: - win.model_kwargs['use_loaded_embeddings'] = True + win.model_kwargs["use_loaded_embeddings"] = True posData.loadSamEmbeddings() - + ask_sam_embeddings = ( - model_name in ('segment_anything', 'sam2') + model_name in ("segment_anything", "sam2") and not load_sam_embed and check_sam_embeddings ) # If SAM and embeddings were not laoded, asks to save them if ask_sam_embeddings: - sam_only_embeddings, sam_also_embeddings, cancel = ( - askSamSaveEmbeddings(qparent=qparent) + sam_only_embeddings, sam_also_embeddings, cancel = askSamSaveEmbeddings( + qparent=qparent ) if cancel: return out - win.model_kwargs['only_embeddings'] = sam_only_embeddings - win.model_kwargs['save_embeddings'] = ( - sam_only_embeddings or sam_also_embeddings - ) - + win.model_kwargs["only_embeddings"] = sam_only_embeddings + win.model_kwargs["save_embeddings"] = sam_only_embeddings or sam_also_embeddings + return out - \ No newline at end of file diff --git a/cellacdc/qrc_resources_dark.py b/cellacdc/qrc_resources_dark.py index 4a7518461..f15931e41 100644 --- a/cellacdc/qrc_resources_dark.py +++ b/cellacdc/qrc_resources_dark.py @@ -338070,7 +338070,7 @@ \x00\x00\x01\x97\x4b\x92\x55\xaa\ " -qt_version = [int(v) for v in QtCore.qVersion().split('.')] +qt_version = [int(v) for v in QtCore.qVersion().split(".")] if qt_version < [5, 8, 0]: rcc_version = 1 qt_resource_struct = qt_resource_struct_v1 @@ -338078,10 +338078,17 @@ rcc_version = 2 qt_resource_struct = qt_resource_struct_v2 + def qInitResources(): - QtCore.qRegisterResourceData(rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data) + QtCore.qRegisterResourceData( + rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data + ) + def qCleanupResources(): - QtCore.qUnregisterResourceData(rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data) + QtCore.qUnregisterResourceData( + rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data + ) + qInitResources() diff --git a/cellacdc/qrc_resources_light.py b/cellacdc/qrc_resources_light.py index 4a977f553..21d48ba99 100644 --- a/cellacdc/qrc_resources_light.py +++ b/cellacdc/qrc_resources_light.py @@ -337032,7 +337032,7 @@ \x00\x00\x01\x97\x4b\x92\x55\xaa\ " -qt_version = [int(v) for v in QtCore.qVersion().split('.')] +qt_version = [int(v) for v in QtCore.qVersion().split(".")] if qt_version < [5, 8, 0]: rcc_version = 1 qt_resource_struct = qt_resource_struct_v1 @@ -337040,10 +337040,17 @@ rcc_version = 2 qt_resource_struct = qt_resource_struct_v2 + def qInitResources(): - QtCore.qRegisterResourceData(rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data) + QtCore.qRegisterResourceData( + rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data + ) + def qCleanupResources(): - QtCore.qUnregisterResourceData(rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data) + QtCore.qUnregisterResourceData( + rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data + ) + qInitResources() diff --git a/cellacdc/qutils.py b/cellacdc/qutils.py index 721eef9d0..352fd66a3 100644 --- a/cellacdc/qutils.py +++ b/cellacdc/qutils.py @@ -1,13 +1,10 @@ -from qtpy.QtCore import ( - Qt, QTimer, QEventLoop -) +from qtpy.QtCore import Qt, QTimer, QEventLoop from qtpy.QtWidgets import QWidget import functools + class QWhileLoop: - def __init__( - self, loop_callback, period=100, max_duration=None - ): + def __init__(self, loop_callback, period=100, max_duration=None): self._loop_callback = loop_callback self._period = period self._max_duration = max_duration @@ -22,20 +19,21 @@ def exec_(self): self.max_duration_timer.timeout.connect(self.stop) self.max_duration_timer.start(self._max_duration) self.loop.exec_() - + def stop(self): self.timer.stop() if self._max_duration is not None: self.max_duration_timer.stop() self.loop.exit() + class QControlBlink: def __init__(self, QWidgetToBlink: QWidget, duration_ms=2000, qparent=None) -> None: self.duration_ms = duration_ms self._widget = QWidgetToBlink self.qparent = qparent self.blinkON = False - + def start(self): self.timer = QTimer(self.qparent) self.timer.timeout.connect(self.timerCallback) @@ -44,17 +42,18 @@ def start(self): self.stopTimer = QTimer(self.qparent) self.stopTimer.timeout.connect(self.stop) self.stopTimer.start(self.duration_ms) - + def timerCallback(self): if self.blinkON: - self._widget.setStyleSheet('background-color: orange') + self._widget.setStyleSheet("background-color: orange") else: - self._widget.setStyleSheet('background-color: none') + self._widget.setStyleSheet("background-color: none") self.blinkON = not self.blinkON def stop(self): self.timer.stop() - self._widget.setStyleSheet('background-color: none') + self._widget.setStyleSheet("background-color: none") + def hide_and_delete_layout(layout): # Hide all widgets in the layout @@ -64,20 +63,23 @@ def hide_and_delete_layout(layout): widget.hide() layout.removeWidget(widget) widget.setParent(None) - + # Delete the layout layout.deleteLater() + def delete_widget(widget): widget.hide() widget.setParent(None) widget.deleteLater() + def replace_certain_vals(getVal, replace_val, by_val): """ Decorator: If the return value of getVal equals replace_val (type-cast to value's type), return by_val instead. Otherwise, return the original value. """ + @functools.wraps(getVal) def wrapper(*args, **kwargs): value = getVal(*args, **kwargs) @@ -88,13 +90,16 @@ def wrapper(*args, **kwargs): if value == target_val: return by_val return value + return wrapper + def set_value_no_signals(widget, value): was_blocked = widget.blockSignals(True) widget.setValue(value) widget.blockSignals(was_blocked) + def set_exclusive_valueSetter(widget, valueSetter, value): was_blocked = widget.blockSignals(True) try: @@ -103,6 +108,7 @@ def set_exclusive_valueSetter(widget, valueSetter, value): valueSetter(value) widget.blockSignals(was_blocked) + def hardDelete(item, setPosData=True): try: item.setParent(None) @@ -118,9 +124,10 @@ def hardDelete(item, setPosData=True): except AttributeError: pass item = None - + + def insert_row(layout, insert_at, new_widget, col=0, dont_shift_other_cols=False): -# Shift all widgets down by one row from insert_at onwards + # Shift all widgets down by one row from insert_at onwards for row in range(layout.rowCount() - 1, insert_at - 1, -1): for loc_col in range(layout.columnCount()): if loc_col != col and dont_shift_other_cols: @@ -129,4 +136,4 @@ def insert_row(layout, insert_at, new_widget, col=0, dont_shift_other_cols=False if item is not None: layout.removeItem(item) layout.addItem(item, row + 1, loc_col) - layout.addWidget(new_widget, insert_at, col) \ No newline at end of file + layout.addWidget(new_widget, insert_at, col) diff --git a/cellacdc/record.py b/cellacdc/record.py index 094ad654a..e97d84f1e 100644 --- a/cellacdc/record.py +++ b/cellacdc/record.py @@ -13,6 +13,7 @@ from .. import user_data_folderpath from .. import workers + class ScreenRecorderFrame(QFrame): def __init__(self, app, parent=None): super().__init__(parent) @@ -21,14 +22,14 @@ def __init__(self, app, parent=None): # Border tolerance to trigger resizing self.px = 10 self.app = app - + def mousePressEvent(self, event): x, y = event.pos().x(), event.pos().y() # x00, y00 = self._parent.x0-self.px, self._parent.y0-self.px - x01, y01 = self._parent.x0+self.px, self._parent.y0+self.px - x10, y10 = self._parent.x1-self.px, self._parent.y1-self.px + x01, y01 = self._parent.x0 + self.px, self._parent.y0 + self.px + x10, y10 = self._parent.x1 - self.px, self._parent.y1 - self.px # x11, y11 = self._parent.x1+self.px, self._parent.y1+self.px - if yy01 and xx01: + if y < y10 and y > y01 and x < x10 and x > x01: # Cursor click inside rectangle self.app.setOverrideCursor(Qt.ClosedHandCursor) self.xc, self.yc = x, y @@ -40,56 +41,55 @@ def mouseMoveEvent(self, event): return x, y = event.pos().x(), event.pos().y() - x00, y00 = self._parent.x0-self.px, self._parent.y0-self.px - x01, y01 = self._parent.x0+self.px, self._parent.y0+self.px - x10, y10 = self._parent.x1-self.px, self._parent.y1-self.px - x11, y11 = self._parent.x1+self.px, self._parent.y1+self.px - if yy01 and xx01: + x00, y00 = self._parent.x0 - self.px, self._parent.y0 - self.px + x01, y01 = self._parent.x0 + self.px, self._parent.y0 + self.px + x10, y10 = self._parent.x1 - self.px, self._parent.y1 - self.px + x11, y11 = self._parent.x1 + self.px, self._parent.y1 + self.px + if y < y10 and y > y01 and x < x10 and x > x01: # Cursor inside rectangle self.app.setOverrideCursor(Qt.OpenHandCursor) - elif yy00 and xx00: + elif y < y11 and y > y00 and x < x11 and x > x00: # Cursor on border --> determine if ver, hor or diags - if xy10: + self.corner = "topLeft" + elif x < x01 and y > y10: # Bottom left corner self.app.setOverrideCursor(Qt.SizeBDiagCursor) - self.corner = 'bottomLeft' - elif x>x10 and y x10 and y < y01: # Top right corner self.app.setOverrideCursor(Qt.SizeBDiagCursor) - self.corner = 'topRight' - elif x>x10 and y>y10: + self.corner = "topRight" + elif x > x10 and y > y10: # Bottom right corner self.app.setOverrideCursor(Qt.SizeFDiagCursor) - self.corner = 'bottomRight' - elif xx10: + self.corner = "bottomRight" + elif x < x01 or x > x10: # Left or right side self.app.setOverrideCursor(Qt.SizeHorCursor) - if x xmax: x = xmax - + if y < xmin: y = ymin elif y > ymax: y = ymax - - return x, y - + + return x, y + def mouseMoveEvent(self, event): x, y = event.pos().x(), event.pos().y() x, y = self.boundXYtoScreen(x, y) if self.app.overrideCursor() == Qt.SizeFDiagCursor: - if self.frame.corner == 'topLeft': + if self.frame.corner == "topLeft": self.x0, self.y0 = x, y self.update() else: @@ -201,7 +202,7 @@ def mouseMoveEvent(self, event): self.x1, self.y1 = x, y self.update() elif self.app.overrideCursor() == Qt.SizeBDiagCursor: - if self.frame.corner == 'bottomLeft': + if self.frame.corner == "bottomLeft": self.x0, self.y1 = x, y self.update() else: @@ -209,23 +210,23 @@ def mouseMoveEvent(self, event): self.x1, self.y0 = x, y self.update() elif self.app.overrideCursor() == Qt.SizeHorCursor: - if self.frame.corner == 'left': + if self.frame.corner == "left": self.x0 = x self.update() else: self.x1 = x self.update() elif self.app.overrideCursor() == Qt.SizeVerCursor: - if self.frame.corner == 'top': + if self.frame.corner == "top": self.y0 = y self.update() else: self.y1 = y self.update() elif self.app.overrideCursor() == Qt.ClosedHandCursor: - deltax, deltay = x-self.frame.xc, y-self.frame.yc - self.x0, self.y0 = self.x0+deltax, self.y0+deltay - self.x1, self.y1 = self.x1+deltax, self.y1+deltay + deltax, deltay = x - self.frame.xc, y - self.frame.yc + self.x0, self.y0 = self.x0 + deltax, self.y0 + deltay + self.x1, self.y1 = self.x1 + deltax, self.y1 + deltay self.frame.xc, self.frame.yc = x, y self.update() @@ -237,9 +238,7 @@ def mouseReleaseEvent(self, event): def startRecorder(self): self.thread = QThread() - self.screenGrabWorker = workers.ScreenRecorderWorker( - self, user_data_folderpath - ) + self.screenGrabWorker = workers.ScreenRecorderWorker(self, user_data_folderpath) self.screenGrabWorker.moveToThread(self.thread) self.screenGrabWorker.finished.connect(self.thread.quit) @@ -248,7 +247,7 @@ def startRecorder(self): self.thread.started.connect(self.screenGrabWorker.run) self.thread.start() - print('Recording started...') + print("Recording started...") def keyPressEvent(self, event): if event.key() == Qt.Key_Escape: diff --git a/cellacdc/resources/to_dark_mode_svg.py b/cellacdc/resources/to_dark_mode_svg.py index 896cc6145..418995d25 100644 --- a/cellacdc/resources/to_dark_mode_svg.py +++ b/cellacdc/resources/to_dark_mode_svg.py @@ -4,8 +4,8 @@ from tqdm import tqdm LIGHT_TO_DARK_MAPPER = { - '#666666': '#9a9a9a', - '#4d4d4d': '#f0f0f0', + "#666666": "#9a9a9a", + "#4d4d4d": "#f0f0f0", # '#d9d9d9': '#4d4d4d', } @@ -13,50 +13,55 @@ # Read resources_light.qrc file and extract SVG relative paths resources_folderpath = os.path.dirname(os.path.abspath(__file__)) cellacdc_path = os.path.dirname(resources_folderpath) -resources_filepath = os.path.join(cellacdc_path, 'resources_light.qrc') +resources_filepath = os.path.join(cellacdc_path, "resources_light.qrc") -qrc_resources_light_path = os.path.join(cellacdc_path, 'qrc_resources_light.py') -qrc_resources_dark_path = os.path.join(cellacdc_path, 'qrc_resources_dark.py') -qrc_resources_path = os.path.join(cellacdc_path, 'qrc_resources.py') +qrc_resources_light_path = os.path.join(cellacdc_path, "qrc_resources_light.py") +qrc_resources_dark_path = os.path.join(cellacdc_path, "qrc_resources_dark.py") +qrc_resources_path = os.path.join(cellacdc_path, "qrc_resources.py") if os.path.exists(qrc_resources_light_path): os.rename(qrc_resources_path, qrc_resources_dark_path) os.rename(qrc_resources_light_path, qrc_resources_path) - -with open(resources_filepath, 'r') as resources_file: + +with open(resources_filepath, "r") as resources_file: resources_txt = resources_file.read() resources_dark_txt = resources_txt svg_relpaths = re.findall(r'(.+)', resources_txt) -# Iterate SVGs and replace colors +# Iterate SVGs and replace colors for svg_relpath in tqdm(svg_relpaths, ncols=100): - svg_relpath_parts = svg_relpath.split('/') + svg_relpath_parts = svg_relpath.split("/") svg_abspath = os.path.join(cellacdc_path, *svg_relpath_parts) svg_folderpath = os.path.dirname(svg_abspath) - if 'icons' not in svg_relpath_parts: + if "icons" not in svg_relpath_parts: # Skip SVGs outside of the icons folder continue - + # Read svg files and replace colors - with open(svg_abspath, 'r', encoding="utf8") as svg_file: + with open(svg_abspath, "r", encoding="utf8") as svg_file: svg_text = svg_file.read() for light_hex, dark_hex in LIGHT_TO_DARK_MAPPER.items(): svg_text_dark = svg_text.replace(light_hex, dark_hex) - + # Save additional _dark.svg and replace them in resources_txt - svg_dark_abspath = svg_abspath.replace('.svg', '_dark.svg') - with open(svg_dark_abspath, 'w', encoding="utf8") as svg_file: + svg_dark_abspath = svg_abspath.replace(".svg", "_dark.svg") + with open(svg_dark_abspath, "w", encoding="utf8") as svg_file: svg_file.write(svg_text_dark) - svg_relpath_dark = svg_relpath.replace('.svg', '_dark.svg') + svg_relpath_dark = svg_relpath.replace(".svg", "_dark.svg") resources_txt = resources_txt.replace(svg_relpath, svg_relpath_dark) # Save a new resouces_dark.qrc file -with open(qrc_resources_dark_path, 'w') as resources_file: +with open(qrc_resources_dark_path, "w") as resources_file: resources_file.write(resources_txt) # Compule new qrc_resources.py dark -print('Compiling the Qt resource file...') -qrc_resources_dark_filepath = os.path.join(cellacdc_path, 'qrc_resources_dark.py') -commands = ['pyrcc5', f"{qrc_resources_dark_path}", '-o', f"{qrc_resources_dark_filepath}"] -subprocess.run(commands, check=True) \ No newline at end of file +print("Compiling the Qt resource file...") +qrc_resources_dark_filepath = os.path.join(cellacdc_path, "qrc_resources_dark.py") +commands = [ + "pyrcc5", + f"{qrc_resources_dark_path}", + "-o", + f"{qrc_resources_dark_filepath}", +] +subprocess.run(commands, check=True) diff --git a/cellacdc/scripts/correct_shift_X.py b/cellacdc/scripts/correct_shift_X.py index 3d31db1dd..1b4ccfaa0 100644 --- a/cellacdc/scripts/correct_shift_X.py +++ b/cellacdc/scripts/correct_shift_X.py @@ -19,124 +19,152 @@ import json - PREVIEW_Z_STACK = None PREVIEW_Z = None NEW_PATH_SUF = None INCLUDE_PATTERN_TIF_SEARCH = None + def load_constants(): - print('Loading constants...') + print("Loading constants...") global PREVIEW_Z_STACK global PREVIEW_Z global INCLUDE_PATTERN_TIF_SEARCH - with open('regex.txt', 'r') as input_file: + with open("regex.txt", "r") as input_file: regex_file = input_file.read() for line in regex_file.splitlines(): - if re.search('x_INCLUDE_PATTERN_TIF_SEARCH', line): - line = line.split(':', 1)[1].strip().lstrip().rstrip(',') + if re.search("x_INCLUDE_PATTERN_TIF_SEARCH", line): + line = line.split(":", 1)[1].strip().lstrip().rstrip(",") INCLUDE_PATTERN_TIF_SEARCH = line - with open('config.json', 'r') as input_file: + with open("config.json", "r") as input_file: config = json.load(input_file) - PREVIEW_Z_STACK = config['correct_shift_x']['PREVIEW_Z_STACK'] - PREVIEW_Z = config['correct_shift_x']['PREVIEW_Z'] - NEW_PATH_SUF = config['correct_shift_x']['NEW_PATH_SUF'] + PREVIEW_Z_STACK = config["correct_shift_x"]["PREVIEW_Z_STACK"] + PREVIEW_Z = config["correct_shift_x"]["PREVIEW_Z"] + NEW_PATH_SUF = config["correct_shift_x"]["NEW_PATH_SUF"] return NEW_PATH_SUF + def correct_constant_shift_X_img(img, shift): for i, row in enumerate(img[::2]): - l = i*2 + l = i * 2 img[l] = np.roll(row, shift) return img + def correct_constant_shift_X(z_stack, shift): for z, img in enumerate(z_stack): img = correct_constant_shift_X_img(img, shift) z_stack[z] = img return z_stack + def find_other_tif(file_path): folder_path = os.path.dirname(file_path) file_list = os.listdir(folder_path) - tif_files = [filename for filename in file_list if filename.lower().endswith('.tif')] + tif_files = [ + filename for filename in file_list if filename.lower().endswith(".tif") + ] return tif_files + def finding_shift(tif_data, shift, NEW_PATH_SUF): eval_img = (tif_data[PREVIEW_Z_STACK][PREVIEW_Z]).copy() - eval_img = correct_constant_shift_X_img(eval_img, shift) + eval_img = correct_constant_shift_X_img(eval_img, shift) imshow(tif_data[PREVIEW_Z_STACK][PREVIEW_Z], eval_img) while True: - answer = input('Do you want to proceed with the shift or change it ([y]/n/"number"/help)? ') - if answer.lower() == 'n': + answer = input( + 'Do you want to proceed with the shift or change it ([y]/n/"number"/help)? ' + ) + if answer.lower() == "n": exit() elif answer.isdigit(): shift = int(answer) shift = finding_shift(tif_data, shift, NEW_PATH_SUF) return shift - elif answer.lstrip('-').isdigit(): + elif answer.lstrip("-").isdigit(): shift = int(answer) shift = finding_shift(tif_data, shift, NEW_PATH_SUF) return shift - elif answer.lower() == 'help': - print('Change the shown image by changing PREVIEW_Z_STACK and PREVIEW_Z in the beginning of the code. \nChange the ending of the new file name by changing NEW_PATH_SUF in the code. \nCurrent z stack and z displayed: ' + str(PREVIEW_Z_STACK) + ' ' +str(PREVIEW_Z) + '\nCurrent ending: ' + NEW_PATH_SUF) + elif answer.lower() == "help": + print( + "Change the shown image by changing PREVIEW_Z_STACK and PREVIEW_Z in the beginning of the code. \nChange the ending of the new file name by changing NEW_PATH_SUF in the code. \nCurrent z stack and z displayed: " + + str(PREVIEW_Z_STACK) + + " " + + str(PREVIEW_Z) + + "\nCurrent ending: " + + NEW_PATH_SUF + ) finding_shift(tif_data, shift, NEW_PATH_SUF) return shift elif not answer: return shift - elif answer.lower() == 'y': + elif answer.lower() == "y": return shift else: - print('The input is not an integer') - + print("The input is not an integer") + def shiftingstuff_main(shift, tif_data, tif_path, NEW_PATH_SUF): corrected_data = tif_data.copy() for frame_i, img in enumerate(tif_data): corrected_data[frame_i] = correct_constant_shift_X(img.copy(), shift) - new_path = tif_path.replace('.tif', NEW_PATH_SUF + '.tif' ) + new_path = tif_path.replace(".tif", NEW_PATH_SUF + ".tif") skimage.io.imsave(new_path, corrected_data, check_contrast=False) del corrected_data del tif_data return + def shiftingstuff_other(tif_name, shift, tif_path, scan_other, NEW_PATH_SUF): if scan_other == True: - tif_path = os.path.join(os.path.dirname(tif_path), tif_name) + tif_path = os.path.join(os.path.dirname(tif_path), tif_name) tif_data = load.imread(tif_path) shiftingstuff_main(shift, tif_data, tif_path, NEW_PATH_SUF) del tif_data return + def sequential(NEW_PATH_SUF): parser = argparse.ArgumentParser() - parser.add_argument('tif_path', help='Path to the tif-file') - parser.add_argument('shift', help='Amount of shift') + parser.add_argument("tif_path", help="Path to the tif-file") + parser.add_argument("shift", help="Amount of shift") args = parser.parse_args() tif_path = args.tif_path shift = int(args.shift) - print('Path: \n' + tif_path) - print('Original Shift: ' + str(shift)) + print("Path: \n" + tif_path) + print("Original Shift: " + str(shift)) tif_data = load.imread(tif_path) - print('Please close the window after inspecting if the shift value is right in order to proceed.') + print( + "Please close the window after inspecting if the shift value is right in order to proceed." + ) shift = finding_shift(tif_data, shift, NEW_PATH_SUF) - print('Shift used: ' +str(shift)) + print("Shift used: " + str(shift)) - tif_files = find_other_tif(tif_path) - tif_names = [tif_file for tif_file in tif_files if re.match(INCLUDE_PATTERN_TIF_SEARCH, tif_file)] - print('New tif file(s) found:\n' + "\n".join(tif_names)) + tif_files = find_other_tif(tif_path) + tif_names = [ + tif_file + for tif_file in tif_files + if re.match(INCLUDE_PATTERN_TIF_SEARCH, tif_file) + ] + print("New tif file(s) found:\n" + "\n".join(tif_names)) while True: - answer = input('Do you want to shift the other .tif files in the folder too? ([y]/n/help)') - if answer.lower() == 'n': + answer = input( + "Do you want to shift the other .tif files in the folder too? ([y]/n/help)" + ) + if answer.lower() == "n": scan_other = False break - elif answer.lower() == 'help': - print('You can change the regex pattern in the beginning of the code (EXCLUDE_PATTERN_TIF_SEARCH). \nIf you dont know regex, ask Chat_GPT to generate one for you by giving it examples of file names and then asking it to generate a regex code which excludes the files you want to exclude. \nCurrent expression is: ' + INCLUDE_PATTERN_TIF_SEARCH) + elif answer.lower() == "help": + print( + "You can change the regex pattern in the beginning of the code (EXCLUDE_PATTERN_TIF_SEARCH). \nIf you dont know regex, ask Chat_GPT to generate one for you by giving it examples of file names and then asking it to generate a regex code which excludes the files you want to exclude. \nCurrent expression is: " + + INCLUDE_PATTERN_TIF_SEARCH + ) exit() else: scan_other = True @@ -149,8 +177,15 @@ def sequential(NEW_PATH_SUF): shift, tif_data, tif_names, scan_other, tif_path = sequential(NEW_PATH_SUF) with concurrent.futures.ProcessPoolExecutor() as executor: futures = [] - futures = [executor.submit(shiftingstuff_other, tif_name, shift, tif_path, scan_other, NEW_PATH_SUF) for tif_name in tif_names] - futures.append(executor.submit(shiftingstuff_main, shift, tif_data, tif_path, NEW_PATH_SUF)) + futures = [ + executor.submit( + shiftingstuff_other, tif_name, shift, tif_path, scan_other, NEW_PATH_SUF + ) + for tif_name in tif_names + ] + futures.append( + executor.submit(shiftingstuff_main, shift, tif_data, tif_path, NEW_PATH_SUF) + ) results = [future.result() for future in futures] - print('Done!') - exit() \ No newline at end of file + print("Done!") + exit() diff --git a/cellacdc/scripts/correct_shift_X_multi.py b/cellacdc/scripts/correct_shift_X_multi.py index 519ec78ca..728551a36 100644 --- a/cellacdc/scripts/correct_shift_X_multi.py +++ b/cellacdc/scripts/correct_shift_X_multi.py @@ -11,13 +11,17 @@ from cellacdc.plot import imshow -#Change this if your data structure is different:# +# Change this if your data structure is different:# def finding_base_tif_files_path(root_path): print(INCLUDE_PATTERN_TIF_BASESEARCH) - base_tif_files_paths =[] + base_tif_files_paths = [] tif_files_paths = [] folder_list = os.listdir(root_path) - folder_list = [os.path.join(root_path, folder_name, 'Images') for folder_name in folder_list if folder_name.lower().startswith(FOLDER_FILTER.lower())] + folder_list = [ + os.path.join(root_path, folder_name, "Images") + for folder_name in folder_list + if folder_name.lower().startswith(FOLDER_FILTER.lower()) + ] for folder_name in folder_list: folder_cont = os.listdir(folder_name) for file_name in folder_cont: @@ -25,94 +29,116 @@ def finding_base_tif_files_path(root_path): base_tif_files_paths.append(os.path.join(folder_name, file_name)) tif_files_paths.append(folder_name) return base_tif_files_paths, tif_files_paths + + ################################################## + def load_constants(): - print('Loading constants...') + print("Loading constants...") global PREVIEW_Z_STACK global PREVIEW_Z global FOLDER_FILTER global INCLUDE_PATTERN_TIF_SEARCH global INCLUDE_PATTERN_TIF_BASESEARCH global PRESET_SHIFT - with open('regex.txt', 'r') as input_file: + with open("regex.txt", "r") as input_file: regex_file = input_file.read() for line in regex_file.splitlines(): - if re.search('x_mult_INCLUDE_PATTERN_TIF_SEARCH', line): - line = line.split(':', 1)[1].strip().lstrip().rstrip(',') + if re.search("x_mult_INCLUDE_PATTERN_TIF_SEARCH", line): + line = line.split(":", 1)[1].strip().lstrip().rstrip(",") INCLUDE_PATTERN_TIF_SEARCH = line - elif re.search('x_mult_INCLUDE_PATTERN_TIF_BASESEARCH', line): - line = line.split(':', 1)[1].strip().lstrip().rstrip(',') + elif re.search("x_mult_INCLUDE_PATTERN_TIF_BASESEARCH", line): + line = line.split(":", 1)[1].strip().lstrip().rstrip(",") INCLUDE_PATTERN_TIF_BASESEARCH = line - with open('config.json', 'r') as input_file: + with open("config.json", "r") as input_file: config = json.load(input_file) - PREVIEW_Z_STACK = config['correct_shift_x_multi']['PREVIEW_Z_STACK'] - PREVIEW_Z = config['correct_shift_x_multi']['PREVIEW_Z'] - NEW_PATH_SUF = config['correct_shift_x_multi']['NEW_PATH_SUF'] - FOLDER_FILTER = config['correct_shift_x_multi']['FOLDER_FILTER'] - PRESET_SHIFT = config['correct_shift_x_multi']['PRESET_SHIFT'] - return NEW_PATH_SUF #IDK WHY THIS CAN'T BE GLOBAL(ID DOESNT WORK LIKE THE OTHERS? WHY?) + PREVIEW_Z_STACK = config["correct_shift_x_multi"]["PREVIEW_Z_STACK"] + PREVIEW_Z = config["correct_shift_x_multi"]["PREVIEW_Z"] + NEW_PATH_SUF = config["correct_shift_x_multi"]["NEW_PATH_SUF"] + FOLDER_FILTER = config["correct_shift_x_multi"]["FOLDER_FILTER"] + PRESET_SHIFT = config["correct_shift_x_multi"]["PRESET_SHIFT"] + return NEW_PATH_SUF # IDK WHY THIS CAN'T BE GLOBAL(ID DOESNT WORK LIKE THE OTHERS? WHY?) + + # #Ok it is bc it is used in concurrent.futures. Wellp I guess I'll just return it then + def correct_constant_shift_X_img(img, shift): for i, row in enumerate(img[::2]): - l = i*2 + l = i * 2 img[l] = np.roll(row, shift) return img + def correct_constant_shift_X(z_stack, shift): - for z, img in enumerate(z_stack): + for z, img in enumerate(z_stack): for i, row in enumerate(img[::2]): - l = i*2 + l = i * 2 z_stack[z, l] = np.roll(row, shift) return z_stack + def find_other_tif(file_path): folder_path = os.path.dirname(file_path) file_list = os.listdir(folder_path) - file_list = [filename for filename in file_list if filename.lower().endswith('.tif')] + file_list = [ + filename for filename in file_list if filename.lower().endswith(".tif") + ] return file_list + def finding_shift(tif_data, shift, NEW_PATH_SUF): eval_img = (tif_data[PREVIEW_Z_STACK][PREVIEW_Z]).copy() eval_img = correct_constant_shift_X_img(eval_img, shift) imshow(tif_data[PREVIEW_Z_STACK][PREVIEW_Z], eval_img) while True: - answer = input('Do you want to proceed with the shift or change it?([y]/n/"number"/help)') - if answer.lower() == 'n': + answer = input( + 'Do you want to proceed with the shift or change it?([y]/n/"number"/help)' + ) + if answer.lower() == "n": exit() elif answer.isdigit(): shift = int(answer) shift = finding_shift(tif_data, shift, NEW_PATH_SUF) return shift - elif answer.lstrip('-').isdigit(): + elif answer.lstrip("-").isdigit(): shift = int(answer) shift = finding_shift(tif_data, shift, NEW_PATH_SUF) return shift - elif answer.lower() == 'help': - print('Change the shown image by changing PREVIEW_Z_STACK and PREVIEW_Z in the beginning of the code. \nChange the ending of the new file name by changing NEW_PATH_SUF in the code. \nCurrent z stack and z displayed: ' + str(PREVIEW_Z_STACK) + ' ' +str(PREVIEW_Z) + '\nCurrent ending: ' + NEW_PATH_SUF) + elif answer.lower() == "help": + print( + "Change the shown image by changing PREVIEW_Z_STACK and PREVIEW_Z in the beginning of the code. \nChange the ending of the new file name by changing NEW_PATH_SUF in the code. \nCurrent z stack and z displayed: " + + str(PREVIEW_Z_STACK) + + " " + + str(PREVIEW_Z) + + "\nCurrent ending: " + + NEW_PATH_SUF + ) finding_shift(tif_data, shift, NEW_PATH_SUF) return shift elif not answer: return shift - elif answer.lower() == 'y': + elif answer.lower() == "y": return shift else: - print('The input is not an integer') + print("The input is not an integer") + def shiftingstuff_main(shift, tif_data, tif_path, NEW_PATH_SUF): corrected_data = tif_data.copy() for frame_i, img in enumerate(tif_data): corrected_data[frame_i] = correct_constant_shift_X(img.copy(), shift) - new_path = tif_path.replace('.tif', NEW_PATH_SUF + '.tif') + new_path = tif_path.replace(".tif", NEW_PATH_SUF + ".tif") skimage.io.imsave(new_path, corrected_data, check_contrast=False) print("Saved under:\n" + str(new_path)) del tif_data del corrected_data return + def shiftingstuff_other(shifttif, NEW_PATH_SUF): if shifttif[0] != 0: tif_data = load.imread(shifttif[1]) @@ -120,25 +146,34 @@ def shiftingstuff_other(shifttif, NEW_PATH_SUF): del tif_data return + def sequential(NEW_PATH_SUF): parser = argparse.ArgumentParser() - parser.add_argument('root_path', help='Path to the folder containing all the folders with the positions') + parser.add_argument( + "root_path", + help="Path to the folder containing all the folders with the positions", + ) args = parser.parse_args() root_path = args.root_path base_file_paths, other_files_paths = finding_base_tif_files_path(root_path) - print('Path: \n' + root_path) - print('Base files found:\n' + "\n".join(base_file_paths)) + print("Path: \n" + root_path) + print("Base files found:\n" + "\n".join(base_file_paths)) if base_file_paths == []: - print('No files found!') + print("No files found!") exit() while True: - answer = input('Do you want to shift the other .tif files in the folders too? ([y]/n/help)') - if answer.lower() == 'n': + answer = input( + "Do you want to shift the other .tif files in the folders too? ([y]/n/help)" + ) + if answer.lower() == "n": scan_other = False break - elif answer.lower() == 'help': - print('You can change the regex pattern in the beginning of the code (EXCLUDE_PATTERN_TIF_SEARCH). \nIf you dont know regex, ask Chat_GPT to generate one for you by giving it examples of file names and then asking it to generate a regex code which excludes the files you want to exclude. \nCurrent expression is: ' + INCLUDE_PATTERN_TIF_SEARCH) + elif answer.lower() == "help": + print( + "You can change the regex pattern in the beginning of the code (EXCLUDE_PATTERN_TIF_SEARCH). \nIf you dont know regex, ask Chat_GPT to generate one for you by giving it examples of file names and then asking it to generate a regex code which excludes the files you want to exclude. \nCurrent expression is: " + + INCLUDE_PATTERN_TIF_SEARCH + ) exit() else: scan_other = True @@ -148,28 +183,43 @@ def sequential(NEW_PATH_SUF): for i, tif_path in enumerate(base_file_paths): shift = PRESET_SHIFT tif_data = load.imread(tif_path) - print('You are looking at:\n' + str(tif_path) + '\nPlease close the window after inspecting if the shift value is right in order to proceed.') + print( + "You are looking at:\n" + + str(tif_path) + + "\nPlease close the window after inspecting if the shift value is right in order to proceed." + ) shift = finding_shift(tif_data, shift, NEW_PATH_SUF) tif_files_master.append([shift, tif_path]) del tif_data if scan_other == True: other_tif_files = [] - other_tif_files = find_other_tif(tif_path) - other_tif_files = [tif_file for tif_file in other_tif_files if re.match(INCLUDE_PATTERN_TIF_SEARCH, tif_file)] - other_tif_files = [os.path.join(other_files_paths[i], tif_file) for tif_file in other_tif_files] + other_tif_files = find_other_tif(tif_path) + other_tif_files = [ + tif_file + for tif_file in other_tif_files + if re.match(INCLUDE_PATTERN_TIF_SEARCH, tif_file) + ] + other_tif_files = [ + os.path.join(other_files_paths[i], tif_file) + for tif_file in other_tif_files + ] for other_tif_file in other_tif_files: tif_files_master.append([shift, other_tif_file]) return tif_files_master + if __name__ == "__main__": NEW_PATH_SUF = load_constants() tif_files_master = sequential(NEW_PATH_SUF) - print('\nFiles with shift:\n') + print("\nFiles with shift:\n") for sub_list in tif_files_master: - print('Shift: ' + str(sub_list[0]) + '\nPath:' + str(sub_list[1]) + '\n') + print("Shift: " + str(sub_list[0]) + "\nPath:" + str(sub_list[1]) + "\n") with concurrent.futures.ProcessPoolExecutor() as executor: futures = [] - futures = [executor.submit(shiftingstuff_other, shifttif, NEW_PATH_SUF) for shifttif in tif_files_master] + futures = [ + executor.submit(shiftingstuff_other, shifttif, NEW_PATH_SUF) + for shifttif in tif_files_master + ] results = [future.result() for future in futures] - print('Done!') - exit() \ No newline at end of file + print("Done!") + exit() diff --git a/cellacdc/scripts/correct_shift_X_single.py b/cellacdc/scripts/correct_shift_X_single.py index 81b7d6096..d5463a8bb 100644 --- a/cellacdc/scripts/correct_shift_X_single.py +++ b/cellacdc/scripts/correct_shift_X_single.py @@ -26,124 +26,155 @@ NEW_PATH_SUF = None INCLUDE_PATTERN_TIF_SEARCH = None + def load_constants(): - print('Loading constants...') + print("Loading constants...") global PREVIEW_Z global INCLUDE_PATTERN_TIF_SEARCH - with open('regex.txt', 'r') as input_file: + with open("regex.txt", "r") as input_file: regex_file = input_file.read() for line in regex_file.splitlines(): - if re.search('x_INCLUDE_PATTERN_TIF_SEARCH', line): - line = line.split(':', 1)[1].strip().lstrip().rstrip(',') + if re.search("x_INCLUDE_PATTERN_TIF_SEARCH", line): + line = line.split(":", 1)[1].strip().lstrip().rstrip(",") INCLUDE_PATTERN_TIF_SEARCH = line - with open('config.json', 'r') as input_file: + with open("config.json", "r") as input_file: config = json.load(input_file) - PREVIEW_Z = config['correct_shift_x_single']['PREVIEW_Z'] - NEW_PATH_SUF = config['correct_shift_x_single']['NEW_PATH_SUF'] + PREVIEW_Z = config["correct_shift_x_single"]["PREVIEW_Z"] + NEW_PATH_SUF = config["correct_shift_x_single"]["NEW_PATH_SUF"] return NEW_PATH_SUF + def correct_constant_shift_X_img(img, shift): for i, row in enumerate(img[::2]): - l = i*2 + l = i * 2 img[l] = np.roll(row, shift) return img + def correct_constant_shift_X(z_stack, shift): for z, img in enumerate(z_stack): img = correct_constant_shift_X_img(img, shift) z_stack[z] = img return z_stack + def find_other_tif(file_path): folder_path = os.path.dirname(file_path) file_list = os.listdir(folder_path) - tif_files = [filename for filename in file_list if filename.lower().endswith('.tif')] + tif_files = [ + filename for filename in file_list if filename.lower().endswith(".tif") + ] return tif_files + def finding_shift(tif_data, shift, start_frame, NEW_PATH_SUF): eval_img = (tif_data[start_frame][PREVIEW_Z]).copy() - eval_img = correct_constant_shift_X_img(eval_img, shift) + eval_img = correct_constant_shift_X_img(eval_img, shift) imshow(tif_data[start_frame][PREVIEW_Z], eval_img) while True: - answer = input('Do you want to proceed with the shift or change it ([y]/n/"number"/help)? ') - if answer.lower() == 'n': + answer = input( + 'Do you want to proceed with the shift or change it ([y]/n/"number"/help)? ' + ) + if answer.lower() == "n": exit() elif answer.isdigit(): shift = int(answer) shift = finding_shift(tif_data, shift, start_frame, NEW_PATH_SUF) return shift - elif answer.lstrip('-').isdigit(): + elif answer.lstrip("-").isdigit(): shift = int(answer) shift = finding_shift(tif_data, shift, start_frame, NEW_PATH_SUF) return shift - elif answer.lower() == 'help': - print('Change the shown image by changing PREVIEW_Z in the beginning of the code. \nChange the ending of the new file name by changing NEW_PATH_SUF in the code. \nCurrent z stack and z displayed: ' + str(PREVIEW_Z) + '\nCurrent ending: ' + NEW_PATH_SUF) + elif answer.lower() == "help": + print( + "Change the shown image by changing PREVIEW_Z in the beginning of the code. \nChange the ending of the new file name by changing NEW_PATH_SUF in the code. \nCurrent z stack and z displayed: " + + str(PREVIEW_Z) + + "\nCurrent ending: " + + NEW_PATH_SUF + ) finding_shift(tif_data, shift, start_frame, NEW_PATH_SUF) return shift elif not answer: return shift - elif answer.lower() == 'y': + elif answer.lower() == "y": return shift else: - print('The input is not an integer') - + print("The input is not an integer") + def shiftingstuff_main(shift, tif_data, tif_path, start_frame, end_frame, NEW_PATH_SUF): corrected_data = tif_data.copy() for frame_i, img in islice(enumerate(tif_data), start_frame, end_frame): corrected_data[frame_i] = correct_constant_shift_X(img.copy(), shift) - new_path = tif_path.replace('.tif', NEW_PATH_SUF + '.tif' ) + new_path = tif_path.replace(".tif", NEW_PATH_SUF + ".tif") skimage.io.imsave(new_path, corrected_data, check_contrast=False) del corrected_data del tif_data return -def shiftingstuff_other(tif_name, shift, tif_path, scan_other, start_frame, end_frame, NEW_PATH_SUF): + +def shiftingstuff_other( + tif_name, shift, tif_path, scan_other, start_frame, end_frame, NEW_PATH_SUF +): if scan_other == True: tif_path = os.path.join(os.path.dirname(tif_path), tif_name) tif_data = load.imread(tif_path) - shiftingstuff_main(shift, tif_data, tif_path, start_frame, end_frame, NEW_PATH_SUF) + shiftingstuff_main( + shift, tif_data, tif_path, start_frame, end_frame, NEW_PATH_SUF + ) del tif_data return + def sequential(NEW_PATH_SUF): parser = argparse.ArgumentParser() - parser.add_argument('tif_path', help='Path to the tif-file') - parser.add_argument('shift', help='Amount of shift') - parser.add_argument('frame_start', help='Start of frames which should be shifted') - parser.add_argument('frame_end', help='End of frames which should be shifted') + parser.add_argument("tif_path", help="Path to the tif-file") + parser.add_argument("shift", help="Amount of shift") + parser.add_argument("frame_start", help="Start of frames which should be shifted") + parser.add_argument("frame_end", help="End of frames which should be shifted") args = parser.parse_args() tif_path = args.tif_path shift = int(args.shift) start_frame = int(args.frame_start) end_frame = int(args.frame_end) - print('Path: \n' + tif_path) - print('Original Shift: ' + str(shift)) - print('Start from frame: ' + str(start_frame)) - print('End on frame: ' + str(end_frame)) + print("Path: \n" + tif_path) + print("Original Shift: " + str(shift)) + print("Start from frame: " + str(start_frame)) + print("End on frame: " + str(end_frame)) tif_data = load.imread(tif_path) start_frame -= 1 - print('Please close the window after inspecting if the shift value is right in order to proceed.') + print( + "Please close the window after inspecting if the shift value is right in order to proceed." + ) shift = finding_shift(tif_data, shift, start_frame, NEW_PATH_SUF) - print('Shift used: ' +str(shift)) + print("Shift used: " + str(shift)) - tif_files = find_other_tif(tif_path) - tif_names = [tif_file for tif_file in tif_files if re.match(INCLUDE_PATTERN_TIF_SEARCH, tif_file)] - print('New tif file(s) found:\n' + "\n".join(tif_names)) + tif_files = find_other_tif(tif_path) + tif_names = [ + tif_file + for tif_file in tif_files + if re.match(INCLUDE_PATTERN_TIF_SEARCH, tif_file) + ] + print("New tif file(s) found:\n" + "\n".join(tif_names)) while True: - answer = input('Do you want to shift the other .tif files in the folder too? ([y]/n/help)') - if answer.lower() == 'n': + answer = input( + "Do you want to shift the other .tif files in the folder too? ([y]/n/help)" + ) + if answer.lower() == "n": scan_other = False break - elif answer.lower() == 'help': - print('You can change the regex pattern in the beginning of the code (INCLUDE_PATTERN_TIF_SEARCH). \nIf you dont know regex, ask Chat_GPT to generate one for you by giving it examples of file names and then asking it to generate a regex code which excludes the files you want to exclude. \nCurrent expression is: ' + INCLUDE_PATTERN_TIF_SEARCH) + elif answer.lower() == "help": + print( + "You can change the regex pattern in the beginning of the code (INCLUDE_PATTERN_TIF_SEARCH). \nIf you dont know regex, ask Chat_GPT to generate one for you by giving it examples of file names and then asking it to generate a regex code which excludes the files you want to exclude. \nCurrent expression is: " + + INCLUDE_PATTERN_TIF_SEARCH + ) exit() else: scan_other = True @@ -153,11 +184,35 @@ def sequential(NEW_PATH_SUF): if __name__ == "__main__": NEW_PATH_SUF = load_constants() - shift, tif_data, tif_names, scan_other, tif_path, start_frame, end_frame = sequential(NEW_PATH_SUF) + shift, tif_data, tif_names, scan_other, tif_path, start_frame, end_frame = ( + sequential(NEW_PATH_SUF) + ) with concurrent.futures.ProcessPoolExecutor() as executor: futures = [] - futures = [executor.submit(shiftingstuff_other, tif_name, shift, tif_path, scan_other, start_frame, end_frame, NEW_PATH_SUF) for tif_name in tif_names] - futures.append(executor.submit(shiftingstuff_main, shift, tif_data, tif_path, start_frame, end_frame, NEW_PATH_SUF)) + futures = [ + executor.submit( + shiftingstuff_other, + tif_name, + shift, + tif_path, + scan_other, + start_frame, + end_frame, + NEW_PATH_SUF, + ) + for tif_name in tif_names + ] + futures.append( + executor.submit( + shiftingstuff_main, + shift, + tif_data, + tif_path, + start_frame, + end_frame, + NEW_PATH_SUF, + ) + ) results = [future.result() for future in futures] - print('Done!') - exit() \ No newline at end of file + print("Done!") + exit() diff --git a/cellacdc/scripts/pngtotif.py b/cellacdc/scripts/pngtotif.py index cb956dfd4..af6828e17 100644 --- a/cellacdc/scripts/pngtotif.py +++ b/cellacdc/scripts/pngtotif.py @@ -1,9 +1,10 @@ from PIL import Image import os + def convert_png_to_tif(input_folder, output_tif): images = [] - + # Get a list of all PNG files in the input folder png_files = [file for file in os.listdir(input_folder) if file.endswith(".png")] @@ -21,7 +22,12 @@ def convert_png_to_tif(input_folder, output_tif): print(f"Conversion completed. TIFF file saved at {output_tif}") -input_folder = r"C:\Users\SchmollerLab\Documents\Timon\DeepSea_data\test\set_22_MESC\images" -output_tif = r"C:\Users\SchmollerLab\Documents\Timon\DeepSea_data\test\set_22_MESC\images.tiff" + +input_folder = ( + r"C:\Users\SchmollerLab\Documents\Timon\DeepSea_data\test\set_22_MESC\images" +) +output_tif = ( + r"C:\Users\SchmollerLab\Documents\Timon\DeepSea_data\test\set_22_MESC\images.tiff" +) convert_png_to_tif(input_folder, output_tif) diff --git a/cellacdc/scripts/split_segm_mask_yeast.py b/cellacdc/scripts/split_segm_mask_yeast.py index db1ac7d1b..9cdd4ec8c 100644 --- a/cellacdc/scripts/split_segm_mask_yeast.py +++ b/cellacdc/scripts/split_segm_mask_yeast.py @@ -6,22 +6,24 @@ import qtpy.compat -from cellacdc import printl, myutils, apps, load, core, widgets +from cellacdc import printl, utils, apps, load, core, widgets from cellacdc._run import _setup_app -from cellacdc.utils.base import NewThreadMultipleExpBaseUtil +from cellacdc.tools.base import NewThreadMultipleExpBaseUtil from cellacdc import io DEBUG = False + def ask_select_folder(): selected_path = qtpy.compat.getexistingdirectory( - caption='Select experiment folder to analyse', - basedir=myutils.getMostRecentPath() + caption="Select experiment folder to analyse", + basedir=utils.getMostRecentPath(), ) return selected_path + def get_exp_path_pos_foldernames(selected_path): - folder_type = myutils.determine_folder_type(selected_path) + folder_type = utils.determine_folder_type(selected_path) is_pos_folder, is_images_folder, exp_path = folder_type if is_pos_folder: exp_path = os.path.dirname(selected_path) @@ -32,151 +34,143 @@ def get_exp_path_pos_foldernames(selected_path): pos_foldernames = [os.path.basename(pos_path)] else: exp_path = selected_path - pos_foldernames = myutils.get_pos_foldernames(exp_path) - + pos_foldernames = utils.get_pos_foldernames(exp_path) + return exp_path, pos_foldernames + def select_segm_masks(exp_path, pos_foldernames): - infoText = 'Select which segmentation file OF THE CELLS:' + infoText = "Select which segmentation file OF THE CELLS:" existingEndNames = load.get_segm_endnames_from_exp_path( exp_path, pos_foldernames=pos_foldernames ) win = apps.SelectSegmFileDialog( - existingEndNames, exp_path, - infoText=infoText, - fileType='segmentation' + existingEndNames, exp_path, infoText=infoText, fileType="segmentation" ) win.exec_() if win.cancel: return - + cells_segm_endname = win.selectedItemText - - infoText = 'Select segmentation files to SPLIT:' + + infoText = "Select segmentation files to SPLIT:" existingEndNames.discard(cells_segm_endname) win = apps.SelectSegmFileDialog( - existingEndNames, exp_path, - infoText=infoText, - fileType='segmentation', - allowMultipleSelection=True + existingEndNames, + exp_path, + infoText=infoText, + fileType="segmentation", + allowMultipleSelection=True, ) win.exec_() if win.cancel: return - + list_segm_endnames_to_split = win.selectedItemTexts return cells_segm_endname, list_segm_endnames_to_split + def run(): - app, splashScreen = _setup_app(splashscreen=True) + app, splashScreen = _setup_app(splashscreen=True) splashScreen.close() - + selected_path = ask_select_folder() if not selected_path: - exit('Execution cancelled') - - myutils.addToRecentPaths(selected_path) + exit("Execution cancelled") + + utils.addToRecentPaths(selected_path) exp_path, pos_foldernames = get_exp_path_pos_foldernames(selected_path) - + if len(pos_foldernames) > 1: selectPosWin = widgets.QDialogListbox( - 'Select Positions to analyse', - 'Select Positions to analyse:\n', - pos_foldernames, - multiSelection=True, - parent=None + "Select Positions to analyse", + "Select Positions to analyse:\n", + pos_foldernames, + multiSelection=True, + parent=None, ) selectPosWin.exec_() if selectPosWin.cancel: - print('Execution stopped by the user') + print("Execution stopped by the user") return - + pos_foldernames = selectPosWin.selectedItemsText - + selected_segm_endnames = select_segm_masks(exp_path, pos_foldernames) if selected_segm_endnames is None: - exit('Execution cancelled') - + exit("Execution cancelled") + cells_segm_endname, list_segm_endnames_to_split = selected_segm_endnames - + list_segm_endnames_to_split_str = [ - f' {val}' for val in list_segm_endnames_to_split + f" {val}" for val in list_segm_endnames_to_split ] - list_segm_endnames_to_split_str = '\n'.join(list_segm_endnames_to_split_str) - print('='*100) + list_segm_endnames_to_split_str = "\n".join(list_segm_endnames_to_split_str) + print("=" * 100) print( - f' - Cells segmentation endname: {cells_segm_endname}', - f' - Segmentation files to split:', - f'{list_segm_endnames_to_split_str}', - sep='\n' + f" - Cells segmentation endname: {cells_segm_endname}", + f" - Segmentation files to split:", + f"{list_segm_endnames_to_split_str}", + sep="\n", ) - - acdc_df_endname = cells_segm_endname.replace('segm', 'acdc_output') - if not acdc_df_endname.endswith('.csv'): - acdc_df_endname = f'{acdc_df_endname}.csv' - - print(f' - Cell cycle annotations file: {acdc_df_endname}') + + acdc_df_endname = cells_segm_endname.replace("segm", "acdc_output") + if not acdc_df_endname.endswith(".csv"): + acdc_df_endname = f"{acdc_df_endname}.csv" + + print(f" - Cell cycle annotations file: {acdc_df_endname}") pbar = tqdm(total=len(pos_foldernames), ncols=100) for pos in pos_foldernames: - images_path = os.path.join(exp_path, pos, 'Images') + images_path = os.path.join(exp_path, pos, "Images") cells_segm_data = load.load_segm_file( images_path, end_name_segm_file=cells_segm_endname ) - + acdc_df = load.load_acdc_df_file( images_path, end_name_acdc_df_file=acdc_df_endname ) if acdc_df is None: - files_format = '\n'.join([ - f' - {file}' for file in os.listdir(images_path) - ]) - print('', '='*100, sep='\n') - print( - f'Files present in "{images_path}":\n\n{files_format}' + files_format = "\n".join( + [f" - {file}" for file in os.listdir(images_path)] ) + print("", "=" * 100, sep="\n") + print(f'Files present in "{images_path}":\n\n{files_format}') print( f'\n[WARNING]: Cell cycle annotations file "{acdc_df_endname}" ' - 'not found in the following folder. Skipping it.\n\n' - f'{images_path}' + "not found in the following folder. Skipping it.\n\n" + f"{images_path}" ) - print('='*100) + print("=" * 100) continue - + pbar.set_description(pos) for segm_endname in list_segm_endnames_to_split: segm_data_to_split, segm_data_to_split_fp = load.load_segm_file( - images_path, end_name_segm_file=segm_endname, - return_path=True + images_path, end_name_segm_file=segm_endname, return_path=True ) out = core.split_segm_masks_mother_bud_line( - cells_segm_data, segm_data_to_split, acdc_df, - debug=DEBUG + cells_segm_data, segm_data_to_split, acdc_df, debug=DEBUG ) split_segm_close, split_segm_away = out - + segm_data_to_split_fn = os.path.basename(segm_data_to_split_fp) - + split_close_filename = segm_data_to_split_fn.replace( - segm_endname, f'{segm_endname}_split_close.npz' - ).replace('.npz.npz', '.npz') - split_close_filepath = os.path.join( - images_path, split_close_filename - ) - + segm_endname, f"{segm_endname}_split_close.npz" + ).replace(".npz.npz", ".npz") + split_close_filepath = os.path.join(images_path, split_close_filename) + io.savez_compressed(split_close_filepath, split_segm_close) - - + split_away_filename = segm_data_to_split_fn.replace( - segm_endname, f'{segm_endname}_split_away.npz' - ).replace('.npz.npz', '.npz') - split_away_filepath = os.path.join( - images_path, split_away_filename - ) + segm_endname, f"{segm_endname}_split_away.npz" + ).replace(".npz.npz", ".npz") + split_away_filepath = os.path.join(images_path, split_away_filename) io.savez_compressed(split_away_filepath, split_segm_away) pbar.update() - + pbar.close() - -if __name__ == '__main__': - run() \ No newline at end of file + +if __name__ == "__main__": + run() diff --git a/cellacdc/segm.py b/cellacdc/segm.py index c75c2589d..aca24674c 100755 --- a/cellacdc/segm.py +++ b/cellacdc/segm.py @@ -19,19 +19,34 @@ from tqdm import tqdm from qtpy.QtWidgets import ( - QApplication, QMainWindow, QFileDialog, - QVBoxLayout, QPushButton, QLabel, QProgressBar, QHBoxLayout, - QStyleFactory, QWidget, QMessageBox, QTextEdit + QApplication, + QMainWindow, + QFileDialog, + QVBoxLayout, + QPushButton, + QLabel, + QProgressBar, + QHBoxLayout, + QStyleFactory, + QWidget, + QMessageBox, + QTextEdit, ) from qtpy.QtCore import ( - Qt, QEventLoop, Signal, QObject, QMutex, QWaitCondition, QThread, - QTimer + Qt, + QEventLoop, + Signal, + QObject, + QMutex, + QWaitCondition, + QThread, + QTimer, ) from qtpy import QtGui import qtpy.compat # Custom modules -from . import prompts, load, myutils, apps, core, dataPrep, widgets +from . import prompts, load, utils, apps, core, dataPrep, widgets from . import html_utils, printl from . import exception_handler from . import workers @@ -39,21 +54,24 @@ from . import config from . import urls -if os.name == 'nt': +if os.name == "nt": try: # Set taskbar icon in windows import ctypes - myappid = 'schmollerlab.cellacdc.pyqt.v1' # arbitrary string + + myappid = "schmollerlab.cellacdc.pyqt.v1" # arbitrary string ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(myappid) except: pass + class QTerminal(QTextEdit): def write(self, message): - message = message.replace('\r ', '') + message = message.replace("\r ", "") if message: self.setText(message) + class SegmWorkerSignals(QObject): finished = Signal(object) progress = Signal(str) @@ -66,6 +84,7 @@ class SegmWorkerSignals(QObject): debug = Signal(object) critical = Signal(object) + import os import time @@ -74,28 +93,25 @@ class SegmWorkerSignals(QObject): from cellacdc import load, core, features, cli + class SegmWorker(QObject): - def __init__( - self, img_path, mainWin, stop_frame_n - ): + def __init__(self, img_path, mainWin, stop_frame_n): QObject.__init__(self) self.signals = SegmWorkerSignals() self.img_path = img_path self.stop_frame_n = stop_frame_n self.mainWin = mainWin self.init_kernel(mainWin) - + def init_kernel(self, mainWin): use_ROI = not mainWin.ROIdeactivatedByUser - self.kernel = cli.SegmKernel( - mainWin.logger, mainWin.log_path, is_cli=False - ) + self.kernel = cli.SegmKernel(mainWin.logger, mainWin.log_path, is_cli=False) self.kernel.init_args( - mainWin.user_ch_name, + mainWin.user_ch_name, mainWin.endFilenameSegm, - mainWin.model_name, + mainWin.model_name, mainWin.do_tracking, - mainWin.applyPostProcessing, + mainWin.applyPostProcessing, mainWin.save, mainWin.image_chName_tracker, mainWin.standardPostProcessKwargs, @@ -107,7 +123,7 @@ def init_kernel(self, mainWin): mainWin.use3DdataFor2Dsegm, mainWin.model_kwargs, mainWin.track_params, - mainWin.SizeT, + mainWin.SizeT, mainWin.SizeZ, model=mainWin.model, tracker=mainWin.tracker, @@ -115,24 +131,21 @@ def init_kernel(self, mainWin): signals=self.signals, logger_func=self.signals.progress.emit, innerPbar_available=mainWin.innerPbar_available, - is_segment3DT_available=mainWin.is_segment3DT_available, - preproc_recipe=mainWin.preproc_recipe, + is_segment3DT_available=mainWin.is_segment3DT_available, + preproc_recipe=mainWin.preproc_recipe, reduce_memory_usage=mainWin.reduce_memory_usage, use_freehand_ROI=mainWin.useFreeHandROI, ) - + def run_kernels(self): - self.kernel.run( - self.img_path, - self.stop_frame_n - ) + self.kernel.run(self.img_path, self.stop_frame_n) if self.mainWin._measurements_kernel is None: return - - segm_endname = self.kernel.segm_endname.replace('.npz', '') + + segm_endname = self.kernel.segm_endname.replace(".npz", "") self.mainWin._measurements_kernel.run( - img_path=self.img_path, - stop_frame_n=self.stop_frame_n, + img_path=self.img_path, + stop_frame_n=self.stop_frame_n, end_filename_segm=segm_endname, ) @@ -140,14 +153,19 @@ def run_kernels(self): def run(self): self.run_kernels() self.signals.finished.emit(self) - + + class segmWin(QMainWindow): sigClosed = Signal() - + def __init__( - self, parent=None, allowExit=False, buttonToRestore=None, - mainWin=None, version=None - ): + self, + parent=None, + allowExit=False, + buttonToRestore=None, + mainWin=None, + version=None, + ): super().__init__(parent) self.allowExit = allowExit @@ -155,23 +173,21 @@ def __init__( self.mainWin = mainWin if mainWin is not None: self.app = mainWin.app - + self._version = version - logger, logs_path, log_path, log_filename = myutils.setupLogger( - module='segm' - ) + logger, logs_path, log_path, log_filename = utils.setupLogger(module="segm") self.logger = logger self.log_path = log_path self.log_filename = log_filename self.logs_path = logs_path if self._version is not None: - logger.info(f'Initializing Segmentation module v{self._version}...') + logger.info(f"Initializing Segmentation module v{self._version}...") else: - logger.info(f'Initializing Segmentation module...') + logger.info(f"Initializing Segmentation module...") - self.setWindowTitle(f'Cell-ACDC v{self._version} - Segment') + self.setWindowTitle(f"Cell-ACDC v{self._version} - Segment") self.setWindowIcon(QtGui.QIcon(":icon.ico")) mainContainer = QWidget() @@ -212,7 +228,7 @@ def __init__( self.progressLabel = widgets.Label(self, force_html=True) self.mainLayout.addWidget(self.progressLabel) - abortButton = widgets.cancelPushButton('Stop processs') + abortButton = widgets.cancelPushButton("Stop processs") abortButton.clicked.connect(self.close) buttonsLayout.addStretch(1) buttonsLayout.addWidget(abortButton) @@ -224,25 +240,25 @@ def __init__( def getMostRecentPath(self): if os.path.exists(recentPaths_path): - df = pd.read_csv(recentPaths_path, index_col='index') - if 'opened_last_on' in df.columns: - df = df.sort_values('opened_last_on', ascending=False) - self.MostRecentPath = df.iloc[0]['path'] + df = pd.read_csv(recentPaths_path, index_col="index") + if "opened_last_on" in df.columns: + df = df.sort_values("opened_last_on", ascending=False) + self.MostRecentPath = df.iloc[0]["path"] if not isinstance(self.MostRecentPath, str): - self.MostRecentPath = '' + self.MostRecentPath = "" else: - self.MostRecentPath = '' + self.MostRecentPath = "" def addToRecentPaths(self, exp_path): if not os.path.exists(exp_path): return if os.path.exists(recentPaths_path): - df = pd.read_csv(recentPaths_path, index_col='index') - recentPaths = df['path'].to_list() - if 'opened_last_on' in df.columns: - openedOn = df['opened_last_on'].to_list() + df = pd.read_csv(recentPaths_path, index_col="index") + recentPaths = df["path"].to_list() + if "opened_last_on" in df.columns: + openedOn = df["opened_last_on"].to_list() else: - openedOn = [np.nan]*len(recentPaths) + openedOn = [np.nan] * len(recentPaths) if exp_path in recentPaths: pop_idx = recentPaths.index(exp_path) recentPaths.pop(pop_idx) @@ -256,10 +272,13 @@ def addToRecentPaths(self, exp_path): else: recentPaths = [exp_path] openedOn = [datetime.datetime.now()] - df = pd.DataFrame({'path': recentPaths, - 'opened_last_on': pd.Series(openedOn, - dtype='datetime64[ns]')}) - df.index.name = 'index' + df = pd.DataFrame( + { + "path": recentPaths, + "opened_last_on": pd.Series(openedOn, dtype="datetime64[ns]"), + } + ) + df.index.name = "index" df.to_csv(recentPaths_path) def addPbar(self, add_inner=False): @@ -267,7 +286,7 @@ def addPbar(self, add_inner=False): QPbar = widgets.ProgressBar(self) pBarLayout.addWidget(QPbar) ETA_label = QLabel() - ETA_label.setText('ETA: NDh:NDm:NDs') + ETA_label.setText("ETA: NDh:NDm:NDs") pBarLayout.addWidget(ETA_label) if add_inner: self.innerQPbar = QPbar @@ -281,7 +300,7 @@ def addPbar(self, add_inner=False): screen = self.screen() screenHeight = screen.size().height() screenWidth = screen.size().width() - self.resize(int(screenWidth*0.5), int(screenHeight*0.6)) + self.resize(int(screenWidth * 0.5), int(screenHeight * 0.6)) def askHowToHandleROI(self): if len(self.posData.dataPrepFreeRoiPoints) > 0: @@ -293,42 +312,44 @@ def askHowToHandleROI(self): """) msg = widgets.myMessageBox(wrapText=False) _, noButton, yesButton = msg.question( - self, 'Use the free-hand ROI?', txt, - buttonsTexts = ( - 'Cancel', - 'No, segment the entire image', - 'Yes, use the free-hand ROI' - ) + self, + "Use the free-hand ROI?", + txt, + buttonsTexts=( + "Cancel", + "No, segment the entire image", + "Yes, use the free-hand ROI", + ), ) return False, False, msg.clickedButton == yesButton - - idx_slice = pd.IndexSlice[:, 'cropped'] + + idx_slice = pd.IndexSlice[:, "cropped"] df_ROI = self.posData.dataPrep_ROIcoords if df_ROI is None: - href = html_utils.href_tag('here', urls.dataprep_docs) + href = html_utils.href_tag("here", urls.dataprep_docs) txt = html_utils.paragraph(f""" Do you want to segment only a rectangluar sub-region (ROI) of the image?

    If yes, Cell-ACDC will launch the Data-prep module later.

    See {href} for more details on how to use the Data-prep module. """) - elif int(df_ROI.loc[idx_slice, 'value'].iloc[0]) > 0: + elif int(df_ROI.loc[idx_slice, "value"].iloc[0]) > 0: # Data is cropped, do not ask to segment a roi return False, False, False else: - xl_slice = pd.IndexSlice[:, 'x_left'] - xr_slice = pd.IndexSlice[:, 'x_right'] - yt_slice = pd.IndexSlice[:, 'y_top'] - yb_slice = pd.IndexSlice[:, 'y_bottom'] + xl_slice = pd.IndexSlice[:, "x_left"] + xr_slice = pd.IndexSlice[:, "x_right"] + yt_slice = pd.IndexSlice[:, "y_top"] + yb_slice = pd.IndexSlice[:, "y_bottom"] SizeY, SizeX = self.posData.img_data.shape[-2:] - x0 = int(df_ROI.loc[xl_slice, 'value'].iloc[0]) - x1 = int(df_ROI.loc[xr_slice, 'value'].iloc[0]) - y0 = int(df_ROI.loc[yt_slice, 'value'].iloc[0]) - y1 = int(df_ROI.loc[yb_slice, 'value'].iloc[0]) - if x0 == 0 and y0 == 0 and y1==SizeY and y1 == SizeX: + x0 = int(df_ROI.loc[xl_slice, "value"].iloc[0]) + x1 = int(df_ROI.loc[xr_slice, "value"].iloc[0]) + y0 = int(df_ROI.loc[yt_slice, "value"].iloc[0]) + y1 = int(df_ROI.loc[yb_slice, "value"].iloc[0]) + if x0 == 0 and y0 == 0 and y1 == SizeY and y1 == SizeX: # ROI is present but with same shape as image --> ignore return False, False, False - + note = html_utils.to_admonition(""" If you need to modify the existing ROI, cancel the process now and launch Data-prep again. @@ -340,67 +361,61 @@ def askHowToHandleROI(self): {note} """) msg = widgets.myMessageBox(showCentered=False, wrapText=False) - _, yesButton, noButton = msg.question(self, 'ROI?', txt, - buttonsTexts = ('Cancel','Yes','No') + _, yesButton, noButton = msg.question( + self, "ROI?", txt, buttonsTexts=("Cancel", "Yes", "No") ) return msg.cancel, msg.clickedButton == yesButton, False - + def main(self): selectFoldersWin = apps.SelectFoldersToAnalyse( - parent=self, - instructionsText= - 'Select experiment folders to analyse using ' - 'the same set of parameters', - askSelectPosFolders=True + parent=self, + instructionsText="Select experiment folders to analyse using " + "the same set of parameters", + askSelectPosFolders=True, ) selectFoldersWin.exec_() if selectFoldersWin.cancel: self.processStopped() return - - expToPosFoldersMapper = ( - selectFoldersWin.selectedExpFolderToPosFoldernamesMapper - ) + + expToPosFoldersMapper = selectFoldersWin.selectedExpFolderToPosFoldernamesMapper font = QtGui.QFont() font.setPixelSize(13) self.setWindowTitle( - f'Cell-ACDC v{self._version} - Segmentation and Tracking workflow' + f"Cell-ACDC v{self._version} - Segmentation and Tracking workflow" ) self.addPbar() self.addlogTerminal() - self.log('Loading data...') - self.progressLabel.setText('Loading data...') + self.log("Loading data...") + self.progressLabel.setText("Loading data...") ch_name_selector = prompts.select_channel_name( - which_channel='segm', allow_abort=True + which_channel="segm", allow_abort=True ) images_paths = [] for exp_path, pos_foldernames in expToPosFoldersMapper.items(): for pos_foldername in pos_foldernames: - images_path = os.path.join( - exp_path, pos_foldername, 'Images' - ) + images_path = os.path.join(exp_path, pos_foldername, "Images") images_paths.append(images_path) user_ch_file_paths = [] for images_path in images_paths: - print('') - self.log(f'Processing {images_path}') - filenames = myutils.listdir(images_path) + print("") + self.log(f"Processing {images_path}") + filenames = utils.listdir(images_path) if not filenames: self.criticalImagesFolderEmpty(images_path) self.close() return if ch_name_selector.is_first_call: - ch_names, warn = ( - ch_name_selector.get_available_channels( - filenames, images_path - )) + ch_names, warn = ch_name_selector.get_available_channels( + filenames, images_path + ) if not ch_names: self.criticalNoTifFound(images_path) self.close() @@ -420,30 +435,34 @@ def main(self): tif_found = False dataPrep_fn = None for filename in filenames: - if filename.find(f'{user_ch_name}_aligned.npz') != -1: + if filename.find(f"{user_ch_name}_aligned.npz") != -1: img_path = os.path.join(images_path, filename) - idx = filename.find('_aligned.npz') + idx = filename.find("_aligned.npz") dataPrep_fn = filename[:idx] aligned_npz_found = True - elif filename.find(f'{user_ch_name}.tif') != -1: + elif filename.find(f"{user_ch_name}.tif") != -1: img_path = os.path.join(images_path, filename) tif_found = True if not aligned_npz_found and not tif_found: - print('') - print('-------------------------------------------------------') - self.log(f'The folder {images_path}\n does not contain the file ' - f'{user_ch_name}_aligned.npz\n or the file {user_ch_name}.tif. ' - 'Skipping it.') - print('-------------------------------------------------------') - print('') + print("") + print("-------------------------------------------------------") + self.log( + f"The folder {images_path}\n does not contain the file " + f"{user_ch_name}_aligned.npz\n or the file {user_ch_name}.tif. " + "Skipping it." + ) + print("-------------------------------------------------------") + print("") elif not aligned_npz_found and tif_found: - print('') - print('-------------------------------------------------------') - self.log(f'The folder {images_path}\n does not contain the file ' - f'{user_ch_name}_aligned.npz. Segmenting .tif data.') - print('-------------------------------------------------------') - print('') + print("") + print("-------------------------------------------------------") + self.log( + f"The folder {images_path}\n does not contain the file " + f"{user_ch_name}_aligned.npz. Segmenting .tif data." + ) + print("-------------------------------------------------------") + print("") user_ch_file_paths.append(img_path) elif aligned_npz_found: user_ch_file_paths.append(img_path) @@ -468,7 +487,7 @@ def main(self): load_last_tracked_i=False, load_metadata=True, load_dataprep_free_roi=True, - load_customCombineMetrics=True + load_customCombineMetrics=True, ) proceed = self.posData.askInputMetadata( self.numPos, @@ -476,7 +495,7 @@ def main(self): ask_TimeIncrement=False, ask_PhysicalSizes=False, save=True, - forceEnableAskSegm3D=True + forceEnableAskSegm3D=True, ) # Store metadata for all other positions loaded for other_img_path in user_ch_file_paths[1:]: @@ -489,127 +508,130 @@ def main(self): ) self._posData.isSegm3D = self.posData.isSegm3D try: - _SizeT = int(self._posData.metadata_df.at['SizeT', 'values']) + _SizeT = int(self._posData.metadata_df.at["SizeT", "values"]) if _SizeT == self.posData.SizeT: continue - - self._posData.metadata_df.at['SizeT', 'values'] = self.posData.SizeT + + self._posData.metadata_df.at["SizeT", "values"] = self.posData.SizeT self._posData.SizeT = self.posData.SizeT except Exception as err: self._posData.SizeT = self.posData.SizeT - + self._posData.saveMetadata() - + self.isSegm3D = self.posData.isSegm3D self.SizeT = self.posData.SizeT self.SizeZ = self.posData.SizeZ if not proceed: self.processStopped() return - + # Ask which model win = apps.QDialogSelectModel( - parent=self, addSkipSegmButton=self.posData.SizeT>1 + parent=self, addSkipSegmButton=self.posData.SizeT > 1 ) win.exec_() if win.cancel: self.processStopped() return - + model_name = win.selectedModel - if model_name == 'thresholding': - win = apps.QDialogAutomaticThresholding( - parent=self, isSegm3D=self.isSegm3D - ) + if model_name == "thresholding": + win = apps.QDialogAutomaticThresholding(parent=self, isSegm3D=self.isSegm3D) win.exec_() if win.cancel: self.processStopped() return self.model_kwargs = win.segment_kwargs - self.log(f'Downloading {model_name} (if needed)...') + self.log(f"Downloading {model_name} (if needed)...") self.downloadWin = apps.downloadModel(model_name, parent=self) self.downloadWin.download() - - self.log(f'Importing {model_name}...') + + self.log(f"Importing {model_name}...") self.model_name = model_name - acdcSegment = myutils.import_segment_module(model_name) - self.acdcSegment = acdcSegment + acdcSegment = utils.import_segment_module(model_name) + self.acdcSegment = acdcSegment # Read all models parameters - init_params, segment_params = myutils.getModelArgSpec(self.acdcSegment) + init_params, segment_params = utils.getModelArgSpec(self.acdcSegment) # Prompt user to enter the model parameters try: url = acdcSegment.url_help() except AttributeError: url = None - + out = prompts.init_segm_model_params( - self.posData, model_name, init_params, segment_params, - help_url=url, qparent=self, init_last_params=False, - add_additional_segm_params=True + self.posData, + model_name, + init_params, + segment_params, + help_url=url, + qparent=self, + init_last_params=False, + add_additional_segm_params=True, ) - win = out.get('win') + win = out.get("win") if win.cancel: self.processStopped() return - - if model_name != 'thresholding': + + if model_name != "thresholding": self.model_kwargs = win.model_kwargs self.standardPostProcessKwargs = win.standardPostProcessKwargs self.customPostProcessFeatures = win.customPostProcessFeatures - self.customPostProcessGroupedFeatures = ( - win.customPostProcessGroupedFeatures - ) + self.customPostProcessGroupedFeatures = win.customPostProcessGroupedFeatures self.applyPostProcessing = win.applyPostProcessing self.secondChannelName = win.secondChannelName - - myutils.log_segm_params( - model_name, win.init_kwargs, win.model_kwargs, - logger_func=self.logger.info, - preproc_recipe=win.preproc_recipe, - apply_post_process=self.applyPostProcessing, - standard_postprocess_kwargs=self.standardPostProcessKwargs, - custom_postprocess_features=self.customPostProcessFeatures + + utils.log_segm_params( + model_name, + win.init_kwargs, + win.model_kwargs, + logger_func=self.logger.info, + preproc_recipe=win.preproc_recipe, + apply_post_process=self.applyPostProcessing, + standard_postprocess_kwargs=self.standardPostProcessKwargs, + custom_postprocess_features=self.customPostProcessFeatures, ) init_kwargs = win.init_kwargs self.init_model_kwargs = init_kwargs self.preproc_recipe = win.preproc_recipe self.reduce_memory_usage = win.reduceMemoryUsage - + if self.secondChannelName is not None: - init_kwargs['is_rgb'] = True - - self.model = myutils.init_segm_model(acdcSegment, self.posData, init_kwargs) + init_kwargs["is_rgb"] = True + + self.model = utils.init_segm_model(acdcSegment, self.posData, init_kwargs) if self.model is None: - self.logger.info('Segmentation model was not initialized correctly!') + self.logger.info("Segmentation model was not initialized correctly!") self.processStopped() return try: self.model.setupLogger(self.logger) except Exception as e: pass - + self.predictCcaState_model = None self.is_segment3DT_available = False - if self.posData.SizeT>1 and not self.isSegm3D: + if self.posData.SizeT > 1 and not self.isSegm3D: self.is_segment3DT_available = any( - [name=='segment3DT' for name in dir(acdcSegment.Model)] + [name == "segment3DT" for name in dir(acdcSegment.Model)] ) self.innerPbar_available = False - if len(user_ch_file_paths)>1 and self.posData.SizeT>1: + if len(user_ch_file_paths) > 1 and self.posData.SizeT > 1: self.addPbar(add_inner=True) self.innerPbar_available = True - + # Check if there are segmentation already computed self.selectedSegmFile = None - self.endFilenameSegm = 'segm.npz' + self.endFilenameSegm = "segm.npz" self.isNewSegmFile = False askNewName = True isMultiSegm = False @@ -619,55 +641,55 @@ def main(self): if len(segm_files) > 0: isMultiSegm = True break - - sam_only_embeddings = self.model_kwargs.get('only_embeddings', False) + + sam_only_embeddings = self.model_kwargs.get("only_embeddings", False) self.save = not sam_only_embeddings if isMultiSegm and not sam_only_embeddings: askNewName = self.askMultipleSegm( - segm_files, isTimelapse=self.posData.SizeT>1 + segm_files, isTimelapse=self.posData.SizeT > 1 ) if askNewName is None: self.save = False self.processStopped() return - + if self.selectedSegmFile is not None: - self.endFilenameSegm = self.selectedSegmFile[len(self.posData.basename):] - + self.endFilenameSegm = self.selectedSegmFile[len(self.posData.basename) :] + if askNewName and self.save: self.isNewSegmFile = True win = apps.filenameDialog( - basename=f'{self.posData.basename}segm', - hintText='Insert a filename for the segmentation file:
    ', - existingNames=segm_files + basename=f"{self.posData.basename}segm", + hintText="Insert a filename for the segmentation file:
    ", + existingNames=segm_files, ) win.exec_() if win.cancel: self.processStopped() return if win.entryText: - self.endFilenameSegm = f'segm_{win.entryText}.npz' + self.endFilenameSegm = f"segm_{win.entryText}.npz" else: - self.endFilenameSegm = f'segm.npz' + self.endFilenameSegm = f"segm.npz" # Save hyperparams + post_process_params = {"applied_postprocessing": self.applyPostProcessing} post_process_params = { - 'applied_postprocessing': self.applyPostProcessing - } - post_process_params = { - **post_process_params, + **post_process_params, **self.standardPostProcessKwargs, - **self.customPostProcessFeatures + **self.customPostProcessFeatures, } - + for other_img_path in user_ch_file_paths: self._posData = load.loadData(other_img_path, user_ch_name, QParent=self) self._posData.getBasenameAndChNames(qparent=self) self._posData.buildPaths() self._posData.saveSegmHyperparams( - model_name, self.init_model_kwargs, self.model_kwargs, - post_process_params=post_process_params, - preproc_recipe=self.preproc_recipe + model_name, + self.init_model_kwargs, + self.model_kwargs, + post_process_params=post_process_params, + preproc_recipe=self.preproc_recipe, ) # Ask ROI @@ -701,16 +723,16 @@ def main(self): if self._posData.segmInfo_df is None: isSegmInfoPresent = False break - + self.use3DdataFor2Dsegm = False if self.posData.SizeZ > 1 and not self.isSegm3D: cancel, use3DdataFor2Dsegm = self.askHowToHandle2DsegmOn3Ddata() if cancel: self.processStopped() return - + self.use3DdataFor2Dsegm = use3DdataFor2Dsegm - + segm2D_never_visualized_dataPrep = ( not self.isSegm3D and self.posData.SizeZ > 1 @@ -739,21 +761,19 @@ def main(self): ) dataPrepWin.show() if selectROI: - dataPrepWin.titleText = ( - """ + dataPrepWin.titleText = """ If you need to crop press the green tick button,
    otherwise you can close the window. """ - ) else: - print('') + print("") self.log( - f'WARNING: The image data in {img_path} is 3D but ' - f'_segmInfo.csv file not found. Launching dataPrep.py...' + f"WARNING: The image data in {img_path} is 3D but " + f"_segmInfo.csv file not found. Launching dataPrep.py..." ) self.logTerminal.setText( - f'The image data in {img_path} is 3D but ' - f'_segmInfo.csv file not found. Launching dataPrep.py...' + f"The image data in {img_path} is 3D but " + f"_segmInfo.csv file not found. Launching dataPrep.py..." ) msg = widgets.myMessageBox() txt = html_utils.paragraph(f""" @@ -766,15 +786,13 @@ def main(self): or projection for each Position or frame
    . """) msg.warning( - self, '3D z-stacks info missing', txt, - buttonsTexts=('Cancel', 'Ok') + self, "3D z-stacks info missing", txt, buttonsTexts=("Cancel", "Ok") ) if msg.cancel: self.processStopped() return - dataPrepWin.titleText = ( - """ + dataPrepWin.titleText = """ Select z-slice (or projection) for each frame/position.
    Then, if you want to segment the entire field of view, close the window.
    @@ -782,11 +800,9 @@ def main(self): press the "Start" button, draw the ROI
    and confirm with the green tick button. """ - ) autoStart = False dataPrepWin.initLoading() - dataPrepWin.loadFiles( - exp_path, user_ch_file_paths, user_ch_name) + dataPrepWin.loadFiles(exp_path, user_ch_file_paths, user_ch_name) if self.posData.SizeZ == 1: dataPrepWin.prepData(None) loop = QEventLoop(self) @@ -794,10 +810,7 @@ def main(self): loop.exec_() # If data was aligned then we make sure to load it here - user_ch_file_paths = load.get_user_ch_paths( - images_paths, - user_ch_name - ) + user_ch_file_paths = load.get_user_ch_paths(images_paths, user_ch_name) img_path = user_ch_file_paths[0] self.posData = load.loadData(img_path, user_ch_name, QParent=self) @@ -813,33 +826,36 @@ def main(self): load_dataPrep_ROIcoords=True, load_bkgr_data=False, load_last_tracked_i=False, - load_metadata=True + load_metadata=True, ) self.posData.isSegm3D = self.isSegm3D - elif self.posData.SizeZ > 1 and not self.isSegm3D and not self.use3DdataFor2Dsegm: + elif ( + self.posData.SizeZ > 1 and not self.isSegm3D and not self.use3DdataFor2Dsegm + ): df = self.posData.segmInfo_df.loc[self.posData.filename] - zz = df['z_slice_used_dataPrep'].to_list() + zz = df["z_slice_used_dataPrep"].to_list() isROIactive = False - if self.posData.dataPrep_ROIcoords is not None and not self.ROIdeactivatedByUser: + if ( + self.posData.dataPrep_ROIcoords is not None + and not self.ROIdeactivatedByUser + ): df_roi = self.posData.dataPrep_ROIcoords.loc[0] - isROIactive = df_roi.at['cropped', 'value'] == 0 - x0, x1, y0, y1 = df_roi['value'][:4] + isROIactive = df_roi.at["cropped", "value"] == 0 + x0, x1, y0, y1 = df_roi["value"][:4] df_roi = self.posData.dataPrep_ROIcoords.loc[0] - isROIactive = df_roi.at['cropped', 'value'] == 0 - x0, x1, y0, y1 = df_roi['value'][:4] + isROIactive = df_roi.at["cropped", "value"] == 0 + x0, x1, y0, y1 = df_roi["value"][:4] self.image_chName_tracker = None self.do_tracking = False self.tracker = None self.track_params = {} self.tracker_init_params = {} - self.trackerName = '' + self.trackerName = "" self.stopFrames = [1 for _ in range(len(user_ch_file_paths))] if self.posData.SizeT > 1: - win = apps.askStopFrameSegm( - user_ch_file_paths, user_ch_name, parent=self - ) + win = apps.askStopFrameSegm(user_ch_file_paths, user_ch_name, parent=self) win.setFont(font) win.exec_() if win.cancel: @@ -849,16 +865,18 @@ def main(self): self.stopFrames = win.stopFrames # Ask whether to track the frames - trackers = myutils.get_list_of_trackers() - txt = html_utils.paragraph(''' + trackers = utils.get_list_of_trackers() + txt = html_utils.paragraph(""" Do you want to track the objects?

    If yes, select the tracker to use

    - ''') + """) win = widgets.QDialogListbox( - 'Track objects?', txt, - trackers, additionalButtons=['Do not track'], + "Track objects?", + txt, + trackers, + additionalButtons=["Do not track"], multiSelection=False, - parent=self + parent=self, ) win.exec_() if win.cancel: @@ -868,14 +886,14 @@ def main(self): self.image_chName_tracker = None if win.clickedButton in win._additionalButtons: self.do_tracking = False - trackerName = '' + trackerName = "" self.trackerName = trackerName else: self.do_tracking = True trackerName = win.selectedItemsText[0] self.trackerName = trackerName - init_tracker_output = myutils.init_tracker( - self.posData, trackerName, return_init_params=True, qparent=self + init_tracker_output = utils.init_tracker( + self.posData, trackerName, return_init_params=True, qparent=self ) self.tracker, self.track_params, self.tracker_init_params = ( init_tracker_output @@ -883,24 +901,21 @@ def main(self): if self.track_params is None: self.processStopped() return - - if 'image_channel_name' in self.track_params: - # Store the channel name for the tracker for loading it + + if "image_channel_name" in self.track_params: + # Store the channel name for the tracker for loading it # in case of multiple pos self.image_chName_tracker = self.track_params.pop( - 'image_channel_name' + "image_channel_name" ) - self.progressLabel.setText('Starting main worker...') + self.progressLabel.setText("Starting main worker...") max = 0 for i, imgPath in enumerate(user_ch_file_paths): self._posData = load.loadData(imgPath, user_ch_name) self._posData.getBasenameAndChNames(qparent=self) - self._posData.loadOtherFiles( - load_segm_data=False, - load_metadata=True - ) + self._posData.loadOtherFiles(load_segm_data=False, load_metadata=True) if self.posData.SizeT > 1: max += self.stopFrames[i] else: @@ -912,7 +927,7 @@ def main(self): if self.innerPbar_available: self.QPbar.setMaximum(len(user_ch_file_paths)) else: - self.QPbar.setMaximum(max*2) + self.QPbar.setMaximum(max * 2) self.exec_time_per_iter = 0 self.exec_time_per_frame = 0 @@ -923,172 +938,164 @@ def main(self): self.exp_path = exp_path self.user_ch_file_paths = user_ch_file_paths self.user_ch_name = user_ch_name - + proceed, measurements_kernel = self.askSaveMeasurements() if not proceed: - self.logger.info('Segmentation process interrupted.') + self.logger.info("Segmentation process interrupted.") self.close() return self._measurements_kernel = measurements_kernel - + proceed = self.askRunNowOrSaveConfigFile() if not proceed: - self.logger.info('Segmentation process interrupted.') + self.logger.info("Segmentation process interrupted.") self.close() return - + t0 = time.perf_counter() for pos_idx, img_path in enumerate(self.user_ch_file_paths): stop_frame_n = self.stopFrames[pos_idx] - segmWorker, segmThread = self.startSegmWorker( - img_path, stop_frame_n - ) + segmWorker, segmThread = self.startSegmWorker(img_path, stop_frame_n) self.waitSegmWorker(segmWorker) if segmWorker.is_error: break - + t1 = time.perf_counter() - - self.processFinished(t1-t0) + + self.processFinished(t1 - t0) def criticalImagesFolderEmpty(self, images_path): - err_title = 'The images folder is empty' + err_title = "The images folder is empty" err_msg = html_utils.paragraph( - 'The following folder

    ' - f'{images_path}

    ' - 'is empty.

    ' + "The following folder

    " + f"{images_path}

    " + "is empty.

    " ) msg = widgets.myMessageBox() msg.addShowInFileManagerButton(images_path) msg.critical(self, err_title, err_msg) - + def criticalNoTifFound(self, images_path): - err_title = 'No .tif files found in folder.' + err_title = "No .tif files found in folder." err_msg = html_utils.paragraph( - 'The following folder

    ' - f'{images_path}

    ' - 'does not contain .tif or .h5 files.

    ' + "The following folder

    " + f"{images_path}

    " + "does not contain .tif or .h5 files.

    " 'Only .tif or .h5 files can be loaded with "Open Folder" button.

    ' - 'Try with File --> Open image/video file... ' - 'and directly select the file you want to load.' + "Try with File --> Open image/video file... " + "and directly select the file you want to load." ) msg = widgets.myMessageBox() msg.addShowInFileManagerButton(images_path) msg.critical(self, err_title, err_msg) - + def waitSegmWorker(self, worker): worker.loop = QEventLoop(self) worker.loop.exec_() - + def _saveConfigurationFile(self, filepath): init_args = { - 'user_ch_name': self.user_ch_name, - 'segm_endname': self.endFilenameSegm, - 'model_name': self.model_name, - 'tracker_name': self.trackerName, - 'do_tracking': self.do_tracking, - 'do_postprocess': self.applyPostProcessing, - 'do_save': self.save, - 'image_channel_tracker': self.image_chName_tracker, - 'isSegm3D': self.isSegm3D, - 'use_ROI': not self.ROIdeactivatedByUser, - 'second_channel_name': self.secondChannelName, - 'use3DdataFor2Dsegm': self.use3DdataFor2Dsegm, - 'reduce_memory_usage': self.reduce_memory_usage, - } - metadata_params = { - 'SizeT': self.SizeT, - 'SizeZ': self.SizeZ + "user_ch_name": self.user_ch_name, + "segm_endname": self.endFilenameSegm, + "model_name": self.model_name, + "tracker_name": self.trackerName, + "do_tracking": self.do_tracking, + "do_postprocess": self.applyPostProcessing, + "do_save": self.save, + "image_channel_tracker": self.image_chName_tracker, + "isSegm3D": self.isSegm3D, + "use_ROI": not self.ROIdeactivatedByUser, + "second_channel_name": self.secondChannelName, + "use3DdataFor2Dsegm": self.use3DdataFor2Dsegm, + "reduce_memory_usage": self.reduce_memory_usage, } + metadata_params = {"SizeT": self.SizeT, "SizeZ": self.SizeZ} track_params = { - key:value for key, value in self.track_params.items() - if key != 'image' + key: value for key, value in self.track_params.items() if key != "image" } ini_items = { - 'workflow': {'type': 'segmentation and/or tracking'}, - 'initialization': init_args, - 'metadata': metadata_params, - 'init_segmentation_model_params': self.init_model_kwargs, - 'segmentation_model_params': self.model_kwargs, - 'init_tracker_params': self.tracker_init_params, - 'tracker_params': track_params, - 'standard_postprocess_features': self.standardPostProcessKwargs, - 'custom_postprocess_features': self.customPostProcessFeatures, + "workflow": {"type": "segmentation and/or tracking"}, + "initialization": init_args, + "metadata": metadata_params, + "init_segmentation_model_params": self.init_model_kwargs, + "segmentation_model_params": self.model_kwargs, + "init_tracker_params": self.tracker_init_params, + "tracker_params": track_params, + "standard_postprocess_features": self.standardPostProcessKwargs, + "custom_postprocess_features": self.customPostProcessFeatures, } - preprocessing_items = config.preprocess_recipe_to_ini_items( - self.preproc_recipe - ) + preprocessing_items = config.preprocess_recipe_to_ini_items(self.preproc_recipe) ini_items = {**ini_items, **preprocessing_items} - + grouped_features = self.customPostProcessGroupedFeatures for category, metrics_names in grouped_features.items(): category_params = {} if isinstance(metrics_names, dict): for channel, channel_metrics in metrics_names.items(): - values = '\n'.join(channel_metrics) - values = f'\n{values}' + values = "\n".join(channel_metrics) + values = f"\n{values}" category_params[channel] = values else: - values = '\n'.join(metrics_names) - values = f'\n{values}' - category_params['names'] = values - ini_items[f'postprocess_features.{category}'] = category_params + values = "\n".join(metrics_names) + values = f"\n{values}" + category_params["names"] = values + ini_items[f"postprocess_features.{category}"] = category_params if self._measurements_kernel is not None: - ini_items['measurements'] = ( + ini_items["measurements"] = ( self._measurements_kernel.to_workflow_config_params() ) - + load.save_workflow_to_config( filepath, ini_items, self.user_ch_file_paths, self.stopFrames ) - + self.logger.info(f'Segmentation workflow saved to "{filepath}"') - + txt = html_utils.paragraph( - 'Segmentation workflow successfully saved to the following location:

    ' - f'{filepath}

    ' - 'You can run the segmentation workflow with the following command:' + "Segmentation workflow successfully saved to the following location:

    " + f"{filepath}

    " + "You can run the segmentation workflow with the following command:" ) command = f'acdc -p "{filepath}"' msg = widgets.myMessageBox(wrapText=False) msg.information( - self, 'Workflow save', txt, + self, + "Workflow save", + txt, commands=(command,), - path_to_browse=os.path.dirname(filepath) + path_to_browse=os.path.dirname(filepath), ) - + def saveWorkflowToConfigFile(self): - timestamp = datetime.datetime.now().strftime( - r'%Y-%m-%d_%H-%M' - ) + timestamp = datetime.datetime.now().strftime(r"%Y-%m-%d_%H-%M") win = apps.filenameDialog( - parent=self, - ext='.ini', - title='Insert filename for configuration file', - hintText='Insert filename for the configuration file', - allowEmpty=False, - defaultEntry=f'{timestamp}_acdc_segm_track_workflow' + parent=self, + ext=".ini", + title="Insert filename for configuration file", + hintText="Insert filename for the configuration file", + allowEmpty=False, + defaultEntry=f"{timestamp}_acdc_segm_track_workflow", ) win.exec_() if win.cancel: return False - + config_filename = win.filename - mostRecentPath = myutils.getMostRecentPath() + mostRecentPath = utils.getMostRecentPath() folder_path = apps.get_existing_directory( allow_images_path=False, - parent=self, - caption='Select folder where to save configuration file', + parent=self, + caption="Select folder where to save configuration file", basedir=mostRecentPath, # options=QFileDialog.DontUseNativeDialog ) if not folder_path: return False - + config_filepath = os.path.join(folder_path, config_filename) self._saveConfigurationFile(config_filepath) - + def showHelpSaveMeasurements(self, parent=None): msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph(f""" @@ -1099,73 +1106,72 @@ def showHelpSaveMeasurements(self, parent=None): If you plan to visualize and correct segmentation results, and you need the measurements, you will anyway need to compute
    and save them after correcting the segmentations. - """ - ) - msg.information(parent, 'Help - Save measurements', txt) - + """) + msg.information(parent, "Help - Save measurements", txt) + def askSaveMeasurements(self): measurements_kernel = None - + if not self.save: return True, measurements_kernel - - acdcOutputEndname = ( - self.endFilenameSegm.replace('segm', 'acdc_output') - .replace('.npz', '.csv') + + acdcOutputEndname = self.endFilenameSegm.replace("segm", "acdc_output").replace( + ".npz", ".csv" ) txt = html_utils.paragraph(f""" Do you also want to save measurements in the {acdcOutputEndname} table after segmentation? """) msg = widgets.myMessageBox(wrapText=False) - saveButton = widgets.savePushButton('Yes, save measurements') - noSaveButton = widgets.noPushButton('Do not save measurements') - helpButton = widgets.helpPushButton('Help...') + saveButton = widgets.savePushButton("Yes, save measurements") + noSaveButton = widgets.noPushButton("Do not save measurements") + helpButton = widgets.helpPushButton("Help...") msg.question( - self, 'Save measurements?', txt, - buttonsTexts=( - 'Cancel', helpButton, noSaveButton, saveButton - ), - showDialog=False + self, + "Save measurements?", + txt, + buttonsTexts=("Cancel", helpButton, noSaveButton, saveButton), + showDialog=False, ) helpButton.clicked.disconnect() - helpButton.clicked.connect( - partial(self.showHelpSaveMeasurements, parent=msg) - ) + helpButton.clicked.connect(partial(self.showHelpSaveMeasurements, parent=msg)) msg.exec_() if msg.cancel: return False, measurements_kernel - + if not msg.clickedButton == saveButton: return True, measurements_kernel - - self.logger.info('Setting up measurements...') - - segmEndname = self.endFilenameSegm.replace('.npz', '') + + self.logger.info("Setting up measurements...") + + segmEndname = self.endFilenameSegm.replace(".npz", "") images_path = os.path.dirname(self.user_ch_file_paths[0]) pos_path = os.path.dirname(images_path) exp_path = os.path.dirname(pos_path) pos_foldernames = [ - os.path.basename(os.path.dirname(os.path.dirname(img_path))) + os.path.basename(os.path.dirname(os.path.dirname(img_path))) for img_path in self.user_ch_file_paths ] selectedExpPaths = {exp_path: pos_foldernames} - - from .utils import compute as utilsCompute + + from .tools import compute as utilsCompute + self.calcMeasUtility = utilsCompute.computeMeasurmentsUtilWin( - selectedExpPaths, self.app, segmEndname=segmEndname, - parent=self, doRunComputation=False + selectedExpPaths, + self.app, + segmEndname=segmEndname, + parent=self, + doRunComputation=False, ) self.calcMeasUtility.runWorker( - showProgress=False, - stopFrameNumber=self.stopFrames + showProgress=False, stopFrameNumber=self.stopFrames ) self.waitCalcMeasUtility() - + measurements_kernel = self.calcMeasUtility.worker.kernel - + return not self.calcMeasUtility.cancel, measurements_kernel - + def waitCalcMeasUtility(self): self.waitCalcMeasUtilityTimer = QTimer(self) self.waitCalcMeasUtilityTimer.timeout.connect( @@ -1179,7 +1185,7 @@ def checkCalcMeasUtilityFinished(self, calcMeasUtility): if calcMeasUtility.isWorkerFinished: self.waitCalcMeasUtilityLoop.exit() self.waitCalcMeasUtilityTimer.stop() - + def askRunNowOrSaveConfigFile(self): txt = html_utils.paragraph(""" Do you want to run the segmentation process now
    @@ -1190,24 +1196,24 @@ def askRunNowOrSaveConfigFile(self): (i.e., headless).
    """) msg = widgets.myMessageBox(wrapText=False) - saveButton = widgets.savePushButton('Save and run later') - runNowButton = widgets.playPushButton('Run now') + saveButton = widgets.savePushButton("Save and run later") + runNowButton = widgets.playPushButton("Run now") _, saveButton, runNowButton = msg.question( - self, 'Run workflow now?', txt, - buttonsTexts=( - 'Cancel', saveButton, runNowButton - ) + self, + "Run workflow now?", + txt, + buttonsTexts=("Cancel", saveButton, runNowButton), ) if msg.cancel: return False - + if msg.clickedButton == saveButton: saved = self.saveWorkflowToConfigFile() if not saved: return False - + return msg.clickedButton == runNowButton - + def askMultipleSegm(self, segm_files, isTimelapse=True): txt = html_utils.paragraph(""" At least one of the loaded positions already contains a @@ -1216,25 +1222,22 @@ def askMultipleSegm(self, segm_files, isTimelapse=True): NOTE: you will be able to choose a stop frame later.
    """) msg = widgets.myMessageBox(resizeButtons=False) - msg.setWindowTitle('Multiple segmentation files') + msg.setWindowTitle("Multiple segmentation files") msg.addText(txt) if len(segm_files) > 1: - overWriteText = 'Select segm. file to overwrite...' + overWriteText = "Select segm. file to overwrite..." else: - overWriteText = 'Overwrite existing segmentation file' + overWriteText = "Overwrite existing segmentation file" overWriteButton = widgets.savePushButton(overWriteText) - doNotSaveButton = widgets.noPushButton('Do not save') - newButton = widgets.newFilePushButton('Save as...') + doNotSaveButton = widgets.noPushButton("Do not save") + newButton = widgets.newFilePushButton("Save as...") msg.addCancelButton(connect=True) msg.addButton(overWriteButton) msg.addButton(newButton) msg.addButton(doNotSaveButton) - if len(segm_files)>1: + if len(segm_files) > 1: overWriteButton.clicked.disconnect() - func = partial( - self.selectSegmFile, segm_files, True, msg, - overWriteButton - ) + func = partial(self.selectSegmFile, segm_files, True, msg, overWriteButton) overWriteButton.clicked.connect(func) else: self.selectedSegmFile = segm_files[0] @@ -1253,29 +1256,25 @@ def askMultipleSegm(self, segm_files, isTimelapse=True): return askNewName def askHowToHandle2DsegmOn3Ddata(self): - txt = html_utils.paragraph( - 'How do you want to handle 3D data?' - ) - use3DButton = widgets.threeDPushButton( - 'Pass all z-slices to the model' - ) + txt = html_utils.paragraph("How do you want to handle 3D data?") + use3DButton = widgets.threeDPushButton("Pass all z-slices to the model") convertTo2DButton = widgets.twoDPushButton( - 'Use or select z-slices or projection from Data prep' - ) - buttons = ( - 'Cancel', use3DButton, convertTo2DButton + "Use or select z-slices or projection from Data prep" ) + buttons = ("Cancel", use3DButton, convertTo2DButton) msg = widgets.myMessageBox(wrapText=False) - msg.question(self, 'How to handle 3D data', txt, buttonsTexts=buttons) - + msg.question(self, "How to handle 3D data", txt, buttonsTexts=buttons) + return msg.cancel, msg.clickedButton == use3DButton - + def selectSegmFile(self, segm_files, isOverwrite, msg, button): - action = 'overwrite' if isOverwrite else 'concatenate to' + action = "overwrite" if isOverwrite else "concatenate to" selectSegmFileWin = widgets.QDialogListbox( - 'Select segmentation file', - f'Select segmentation file to {action}:\n', - segm_files, multiSelection=False, parent=msg + "Select segmentation file", + f"Select segmentation file to {action}:\n", + segm_files, + multiSelection=False, + parent=msg, ) selectSegmFileWin.exec_() if selectSegmFileWin.cancel: @@ -1287,12 +1286,12 @@ def selectSegmFile(self, segm_files, isOverwrite, msg, button): button.clicked.disconnect() button.clicked.connect(msg.buttonCallBack) button.click() - + def log(self, text): self.logger.info(text) try: self.logTerminal.append(text) - self.logTerminal.append('-'*30) + self.logTerminal.append("-" * 30) maxScrollbar = self.logTerminal.verticalScrollBar().maximum() self.logTerminal.verticalScrollBar().setValue(maxScrollbar) except AttributeError: @@ -1312,7 +1311,7 @@ def reset_innerQPbar(self, num_frames): def create_tqdm_pbar(self, num_frames): self.tqdm_pbar = tqdm( - total=num_frames, unit=' frames', ncols=75, file=self.logTerminal + total=num_frames, unit=" frames", ncols=75, file=self.logTerminal ) def update_tqdm_pbar(self, step): @@ -1322,22 +1321,23 @@ def close_tqdm(self): self.tqdm_pbar.close() def setPredictBuddingModel(self): - self.downloadYeastMate = apps.downloadModel('YeastMate', parent=self) + self.downloadYeastMate = apps.downloadModel("YeastMate", parent=self) self.downloadYeastMate.download() import models.YeastMate.acdcSegment as yeastmate + self.predictCcaState_model = yeastmate.Model() def startSegmWorker(self, img_path, stop_frame_n): thread = QThread() - + worker = SegmWorker(img_path, self, stop_frame_n) worker.is_error = False - + worker.moveToThread(thread) worker.signals.finished.connect(thread.quit) worker.signals.finished.connect(worker.deleteLater) thread.finished.connect(thread.deleteLater) - + worker.signals.finished.connect(self.segmWorkerFinished) worker.signals.progress.connect(self.segmWorkerProgress) worker.signals.progressBar.connect(self.segmWorkerProgressBar) @@ -1347,12 +1347,12 @@ def startSegmWorker(self, img_path, stop_frame_n): worker.signals.progress_tqdm.connect(self.update_tqdm_pbar) worker.signals.signal_close_tqdm.connect(self.close_tqdm) worker.signals.critical.connect(self.workerCritical) - + thread.started.connect(worker.run) thread.start() - + return worker, thread - + @exception_handler def workerCritical(self, out: Tuple[QObject, Exception]): worker, error = out @@ -1364,73 +1364,67 @@ def debugSegmWorker(self, lab): apps.imshow_tk(lab) def segmWorkerProgress(self, text): - print('-----------------------------------------') + print("-----------------------------------------") self.logger.info(text) self.progressLabel.setText(text) def segmWorkerProgressBar(self, step): - self.QPbar.setValue(self.QPbar.value()+step) - steps_left = self.QPbar.maximum()-self.QPbar.value() + self.QPbar.setValue(self.QPbar.value() + step) + steps_left = self.QPbar.maximum() - self.QPbar.value() # Update ETA every two calls of this function - if steps_left%2 == 0: + if steps_left % 2 == 0: t = time.time() self.exec_time_per_iter = t - self.time_last_pbar_update - groups_2steps_left = steps_left/2 - seconds = round(self.exec_time_per_iter*groups_2steps_left) - ETA = myutils.seconds_to_ETA(seconds) - self.ETA_label.setText(f'ETA: {ETA}') + groups_2steps_left = steps_left / 2 + seconds = round(self.exec_time_per_iter * groups_2steps_left) + ETA = utils.seconds_to_ETA(seconds) + self.ETA_label.setText(f"ETA: {ETA}") self.exec_time_per_iter = 0 self.time_last_pbar_update = t def segmWorkerInnerProgressBar(self, step): - self.innerQPbar.setValue(self.innerQPbar.value()+step) + self.innerQPbar.setValue(self.innerQPbar.value() + step) t = time.time() self.exec_time_per_frame = t - self.time_last_innerPbar_update - steps_left = self.QPbar.maximum()-self.QPbar.value() - seconds = round(self.exec_time_per_frame*steps_left) - ETA = myutils.seconds_to_ETA(seconds) - self.innerETA_label.setText(f'ETA: {ETA}') + steps_left = self.QPbar.maximum() - self.QPbar.value() + seconds = round(self.exec_time_per_frame * steps_left) + ETA = utils.seconds_to_ETA(seconds) + self.innerETA_label.setText(f"ETA: {ETA}") self.exec_time_per_frame = 0 self.time_last_innerPbar_update = t # Estimate total ETA current_numFrames = self.QPbar.maximum() - tot_seconds = round(self.exec_time_per_frame*current_numFrames) + tot_seconds = round(self.exec_time_per_frame * current_numFrames) numPos = self.QPbar.maximum() - allPos_seconds = tot_seconds*numPos - tot_seconds_left = allPos_seconds-tot_seconds - ETA = myutils.seconds_to_ETA(round(tot_seconds_left)) - total_ETA = self.ETA_label.setText(f'ETA: {ETA}') + allPos_seconds = tot_seconds * numPos + tot_seconds_left = allPos_seconds - tot_seconds + ETA = utils.seconds_to_ETA(round(tot_seconds_left)) + total_ETA = self.ETA_label.setText(f"ETA: {ETA}") - def segmWorkerFinished(self, worker): + def segmWorkerFinished(self, worker): worker.loop.exit() - + def processFinished(self, total_exec_time): - short_txt = 'Segmentation process finished!' + short_txt = "Segmentation process finished!" exec_time = round(total_exec_time) delta = datetime.timedelta(seconds=exec_time) - exec_time_delta = str(delta).split(',')[-1].strip() - h, m, s = str(exec_time_delta).split(':') - exec_time_delta = f'{int(h):02}h:{int(m):02}m:{int(s):02}s' + exec_time_delta = str(delta).split(",")[-1].strip() + h, m, s = str(exec_time_delta).split(":") + exec_time_delta = f"{int(h):02}h:{int(m):02}m:{int(s):02}s" items = ( - f'Total execution time: {exec_time_delta}
    ', - f'Selected folder: {self.exp_path}' - ) - txt = ( - 'Segmentation task ended.' - f'{html_utils.to_list(items)}' - ) - steps_left = self.QPbar.maximum()-self.QPbar.value() - self.QPbar.setValue(self.QPbar.value()+steps_left) - - txt = html_utils.paragraph( - f'{txt}
    {myutils.get_salute_string()}' + f"Total execution time: {exec_time_delta}
    ", + f"Selected folder: {self.exp_path}", ) + txt = f"Segmentation task ended.{html_utils.to_list(items)}" + steps_left = self.QPbar.maximum() - self.QPbar.value() + self.QPbar.setValue(self.QPbar.value() + steps_left) + + txt = html_utils.paragraph(f"{txt}
    {utils.get_salute_string()}") self.progressLabel.setText(short_txt) msg = widgets.myMessageBox(self, wrapText=False) msg.information( - self, 'Segmentation task ended.', txt, - path_to_browse=self.exp_path + self, "Segmentation task ended.", txt, path_to_browse=self.exp_path ) try: del self.posData @@ -1453,7 +1447,7 @@ def processStopped(self): msg = widgets.myMessageBox(showCentered=False) closeAnswer = msg.warning( - self, 'Execution cancelled', 'Segmentation task cancelled.' + self, "Execution cancelled", "Segmentation task cancelled." ) try: del self.posData @@ -1468,9 +1462,7 @@ def processStopped(self): except AttributeError: pass self.close() - - def warnSegmWorkerStillRunning(self): msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph(""" @@ -1479,38 +1471,36 @@ def warnSegmWorkerStillRunning(self): Are you sure you want to continue? """) noButton, yesButton = msg.warning( - self, 'Process still running', txt, - buttonsTexts=( - 'No, wait for the process to end', - 'Yes, close Cell-ACDC' - ) + self, + "Process still running", + txt, + buttonsTexts=("No, wait for the process to end", "Yes, close Cell-ACDC"), ) if msg.cancel: return False return msg.clickedButton == yesButton def closeEvent(self, event): - print('') - self.log('Closing segmentation module...') + print("") + self.log("Closing segmentation module...") if self.buttonToRestore is not None: button, color, text = self.buttonToRestore button.setText(text) - button.setStyleSheet( - f'QPushButton {{background-color: {color};}}') + button.setStyleSheet(f"QPushButton {{background-color: {color};}}") self.mainWin.setWindowState(Qt.WindowNoState) self.mainWin.setWindowState(Qt.WindowActive) self.mainWin.raise_() - - self.log('Closing segmentation module logger...') + + self.log("Closing segmentation module logger...") handlers = self.logger.handlers[:] for handler in handlers: handler.close() self.logger.removeHandler(handler) - + try: self.model.closeLogger() except Exception as e: pass - - self.log('Segmentation module closed.') + + self.log("Segmentation module closed.") self.sigClosed.emit() diff --git a/cellacdc/segm_utils.py b/cellacdc/segm_utils.py index 790e59719..b1129d3df 100644 --- a/cellacdc/segm_utils.py +++ b/cellacdc/segm_utils.py @@ -7,9 +7,9 @@ import inspect +import os # for dbug +import json # for dbug -import os # for dbug -import json # for dbug def find_overlap(lab_1, lab_2): """ @@ -38,12 +38,14 @@ def find_overlap(lab_1, lab_2): return ID_overlap + def get_obj_from_rps(rps, ID): for obj in rps: if obj.label == ID: return obj return None + def get_box_coords(rps, prev_lab_shape, ID, padding): """ Calculate the coordinates of a bounding box around a given ID in a labeled image, @@ -73,15 +75,16 @@ def get_box_coords(rps, prev_lab_shape, ID, padding): return box_x_min, box_x_max, box_y_min, box_y_max + def find_overlapping_bboxs(IDs, bboxs, order=1): """ Finds and merges overlapping bounding boxes by considering chained overlaps. - + Parameters: - IDs: List of IDs corresponding to the bounding boxes. - bboxs: List of bounding boxes (x_min, x_max, y_min, y_max). - order: Number of times to perform the merging process. - + Returns: - new_bboxs: List of merged bounding boxes. """ @@ -92,16 +95,12 @@ def boxes_overlap(bbox1, bbox2): x_min2, x_max2, y_min2, y_max2 = bbox2 # Check if there's no overlap - if (x_max1 <= x_min2 or - x_max2 <= x_min1 or - y_max1 <= y_min2 or - y_max2 <= y_min1 - ): + if x_max1 <= x_min2 or x_max2 <= x_min1 or y_max1 <= y_min2 or y_max2 <= y_min1: return False else: return True - - IDs = [[ID] for ID in IDs] + + IDs = [[ID] for ID in IDs] for _ in range(order): merged = [False] * len(bboxs) # Keep track of whether a box has been merged @@ -115,7 +114,7 @@ def boxes_overlap(bbox1, bbox2): # Start with the current bbox as the base for merging current_merged_bbox = bbox merged[i] = True # Mark this box as merged - IDs_merged = IDs[i] # Keep track of the IDs that have been merged + IDs_merged = IDs[i] # Keep track of the IDs that have been merged # Try to merge it with all other boxes for j, other_bbox in enumerate(bboxs): @@ -131,17 +130,18 @@ def boxes_overlap(bbox1, bbox2): min(x_min1, x_min2), max(x_max1, x_max2), min(y_min1, y_min2), - max(y_max1, y_max2) + max(y_max1, y_max2), ) merged[j] = True # Mark the other box as merged - IDs_merged.extend(IDs[j]) # Add the IDs of the other box to the merged IDs + IDs_merged.extend( + IDs[j] + ) # Add the IDs of the other box to the merged IDs # Add the merged bbox to the new list new_bboxs.append(current_merged_bbox) new_IDs.append(IDs_merged) - # If no changes occur, break the loop early if len(new_bboxs) == len(bboxs): break @@ -152,6 +152,7 @@ def boxes_overlap(bbox1, bbox2): return IDs, bboxs + # def fast_border_touching_labels(label_img): # # Get unique labels from the four borders # border_labels = np.r_[ @@ -163,13 +164,23 @@ def boxes_overlap(bbox1, bbox2): # # Use np.unique once on the combined array # return np.unique(border_labels[border_labels != 0]) -def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, - win, posData, distance_filler_growth=1, - overlap_threshold=0.5, padding=0.4, - export_bbox_for_training=False, - ): + +def single_cell_seg( + model, + prev_lab, + curr_lab, + curr_img, + IDs, + new_unique_ID, + win, + posData, + distance_filler_growth=1, + overlap_threshold=0.5, + padding=0.4, + export_bbox_for_training=False, +): """ - Function to segment single cells in the current frame using the previous frame segmentation as a reference. + Function to segment single cells in the current frame using the previous frame segmentation as a reference. IDs is from the previous frame segmentation, and the current frame should have already been tracked so the IDs match! Args: model: eval function used to segment the cells @@ -206,10 +217,12 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, bboxs = [get_box_coords(prev_rp, prev_lab_shape, ID, padding) for ID in IDs] IDs_bboxs, bboxs = find_overlapping_bboxs(IDs, bboxs) - + assigned_IDs = [] - uses_diameter = inspect.signature(model.segment).parameters.get('diameter', None) is not None + uses_diameter = ( + inspect.signature(model.segment).parameters.get("diameter", None) is not None + ) for IDs, bbox in zip(IDs_bboxs, bboxs): box_x_min, box_x_max, box_y_min, box_y_max = bbox @@ -220,12 +233,16 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, IDs = np.array(IDs) box_curr_lab_other_IDs[np.isin(box_curr_lab_other_IDs, IDs)] = 0 - box_curr_lab_other_IDs_grown = skimage.segmentation.expand_labels(box_curr_lab_other_IDs, distance=distance_filler_growth) + box_curr_lab_other_IDs_grown = skimage.segmentation.expand_labels( + box_curr_lab_other_IDs, distance=distance_filler_growth + ) # Fill other IDs with random samples from the background indices_to_fill = np.where(box_curr_lab_other_IDs_grown != 0) - box_background = box_curr_img[box_curr_lab_other_IDs_grown==0] - random_samples = np.random.choice(box_background, size=indices_to_fill[0].shape, replace=True) + box_background = box_curr_img[box_curr_lab_other_IDs_grown == 0] + random_samples = np.random.choice( + box_background, size=indices_to_fill[0].shape, replace=True + ) box_curr_img[indices_to_fill] = random_samples # Run model, give it the diameter of cell if possible @@ -234,36 +251,42 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, for ID in IDs: obj = get_obj_from_rps(prev_rp, ID) diameters.append(obj.axis_major_length) - + if len(diameters) == 0: diameter = None else: diameter = np.mean(diameters) - model_kwargs['diameter'] = diameter - + model_kwargs["diameter"] = diameter + box_model_lab = segm_model_segment( - model, box_curr_img, model_kwargs, + model, + box_curr_img, + model_kwargs, preproc_recipe=preproc_recipe, posData=posData, ) if export_bbox_for_training: - bboxs_for_debug.append([IDs, bbox, box_model_lab.copy(), box_curr_lab.copy()]) + bboxs_for_debug.append( + [IDs, bbox, box_model_lab.copy(), box_curr_lab.copy()] + ) - # Post-processing + # Post-processing if applyPostProcessing: box_model_lab = post_process_segm( box_model_lab, **standardPostProcessKwargs ) if customPostProcessFeatures: box_model_lab = custom_post_process_segm( - posData, - customPostProcessGroupedFeatures, - box_model_lab, box_curr_img, posData.frame_i, - posData.filename, - posData.user_ch_name, - customPostProcessFeatures + posData, + customPostProcessGroupedFeatures, + box_model_lab, + box_curr_img, + posData.frame_i, + posData.filename, + posData.user_ch_name, + customPostProcessFeatures, ) ### maybe add roi extension if cells are deleted... @@ -275,7 +298,7 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, for ID, overlap_perc in overlap: if overlap_perc > overlap_threshold: box_model_lab[box_model_lab == ID] = 0 - + rp_model_lab = skimage.measure.regionprops(box_model_lab) for obj in rp_model_lab: box_curr_lab_other_IDs[box_model_lab == obj.label] = new_unique_ID @@ -283,7 +306,9 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, new_unique_ID += 1 positive_mask = box_curr_lab_other_IDs > 0 - curr_lab[box_x_min:box_x_max, box_y_min:box_y_max][positive_mask] = box_curr_lab_other_IDs[positive_mask] + curr_lab[box_x_min:box_x_max, box_y_min:box_y_max][positive_mask] = ( + box_curr_lab_other_IDs[positive_mask] + ) if export_bbox_for_training: bboxs_for_debug[-1].append(box_curr_lab_other_IDs.copy()) @@ -291,13 +316,20 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, if export_bbox_for_training: frame_i = posData.frame_i - os.makedirs(os.path.join(posData.images_path, ".train_box_data", posData.filename), exist_ok=True) + os.makedirs( + os.path.join(posData.images_path, ".train_box_data", posData.filename), + exist_ok=True, + ) - npz_filepath = os.path.join(posData.images_path, ".train_box_data", posData.filename) - json_filepath = os.path.join(posData.images_path, ".train_box_data", posData.filename, 'info.json') + npz_filepath = os.path.join( + posData.images_path, ".train_box_data", posData.filename + ) + json_filepath = os.path.join( + posData.images_path, ".train_box_data", posData.filename, "info.json" + ) try: - with open(json_filepath, 'r') as f: + with open(json_filepath, "r") as f: loaded_dict = json.load(f) except FileNotFoundError: loaded_dict = {} @@ -311,14 +343,21 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, end_i = start_i + len(bboxs_for_debug) for i in range(start_i, end_i): - IDs, bbox, box_model_lab, box_prev_lab, box_final_lab = bboxs_for_debug[i - start_i] + IDs, bbox, box_model_lab, box_prev_lab, box_final_lab = bboxs_for_debug[ + i - start_i + ] npz_path = os.path.join(npz_filepath, f"{frame_i}_{i}.npz") - io.savez_compressed(npz_path, box_model_lab=box_model_lab, box_prev_lab=box_prev_lab, box_final_lab=box_final_lab) + io.savez_compressed( + npz_path, + box_model_lab=box_model_lab, + box_prev_lab=box_prev_lab, + box_final_lab=box_final_lab, + ) bboxs_info.append([IDs, bbox, npz_path]) - + loaded_dict[frame_i] = bboxs_info - with open(json_filepath, 'w') as f: + with open(json_filepath, "w") as f: json.dump(loaded_dict, f, indent=4) - return curr_lab, assigned_IDs, IDs_bboxs, bboxs \ No newline at end of file + return curr_lab, assigned_IDs, IDs_bboxs, bboxs diff --git a/cellacdc/segmentation.py b/cellacdc/segmentation.py index 9bfba56ad..4dfafa578 100644 --- a/cellacdc/segmentation.py +++ b/cellacdc/segmentation.py @@ -5,21 +5,17 @@ import cv2 -def _find_contours_2D( - image, bbox_lower_coords=(0, 0), all=False, closed=True - ): + +def _find_contours_2D(image, bbox_lower_coords=(0, 0), all=False, closed=True): mode = cv2.RETR_CCOMP if all else cv2.RETR_EXTERNAL contours, _ = cv2.findContours(image, mode, cv2.CHAIN_APPROX_NONE) - + if all: all_contours = [ - np.squeeze(contour, axis=1)+bbox_lower_coords - for contour in contours + np.squeeze(contour, axis=1) + bbox_lower_coords for contour in contours ] if closed: - all_contours = [ - np.vstack((contour, contour[0])) for contour in contours - ] + all_contours = [np.vstack((contour, contour[0])) for contour in contours] return all_contours else: contour = np.squeeze(contours[0], axis=1) @@ -28,17 +24,21 @@ def _find_contours_2D( contour = contour + bbox_lower_coords return contour + def find_obj_contour( - obj: skimage.measure._regionprops.RegionProperties, all=False, - local=False, do_z_max_proj=False, closed=True - ): + obj: skimage.measure._regionprops.RegionProperties, + all=False, + local=False, + do_z_max_proj=False, + closed=True, +): is3D = obj.image.ndim == 3 bbox_y_idx = 1 if is3D else 0 if local: - bbox_lower_coords=(0, 0) + bbox_lower_coords = (0, 0) else: - min_y, min_x = obj.bbox[bbox_y_idx:bbox_y_idx+2] + min_y, min_x = obj.bbox[bbox_y_idx : bbox_y_idx + 2] bbox_lower_coords = (min_x, min_y) if is3D and do_z_max_proj: @@ -47,23 +47,18 @@ def find_obj_contour( else: obj_image = obj.image.astype(np.uint8) - kwargs = { - 'bbox_lower_coords': bbox_lower_coords, - 'all':all, 'closed': closed - } + kwargs = {"bbox_lower_coords": bbox_lower_coords, "all": all, "closed": closed} if is3D: - contours = [ - _find_contours_2D(image_z, **kwargs) for image_z in obj_image - ] + contours = [_find_contours_2D(image_z, **kwargs) for image_z in obj_image] else: contours = _find_contours_2D(obj_image, **kwargs) return contours + def find_contours( - label_img, connectivity=1, mode='thick', background=0, - return_coords=False, **kwargs - ): - """Return bool array where boundaries between labeled regions are True. + label_img, connectivity=1, mode="thick", background=0, return_coords=False, **kwargs +): + """Return bool array where boundaries between labeled regions are True. If `return_coords` is True then return also a list of objects' contours coordinates. @@ -92,7 +87,7 @@ def find_contours( marked. - subpixel: return a doubled image, with pixels *between* the original pixels marked as boundary where appropriate., - + By default 'thick' background : int, optional For modes 'inner' and 'outer', a definition of a background @@ -102,8 +97,8 @@ def find_contours( If ``True``, also return a list of objects' contours coordinates, by default False kwargs : dict, optional - Additional arguments passed `acdctools.segmentation.find_obj_contour` - function. This function uses the opencv find contours function + Additional arguments passed `acdctools.segmentation.find_obj_contour` + function. This function uses the opencv find contours function `cv2.findContours`. Used only if `mode='inner'`. Returns @@ -115,25 +110,25 @@ def find_contours( inserted in between all other pairs of pixels). contours_coords: list of ndarray A list of ndarrays with shape (N, n) where `n` is the number of - dimensions of `label_img` and `N` is the number of points in each - object's contour. The list contains one ndarray per object in - `label_img`. - The ordering of columns follows the numpy's order of dimensions - convention, e.g., for 2-D, the first and second column are the - y and x coordinates, respectively. + dimensions of `label_img` and `N` is the number of points in each + object's contour. The list contains one ndarray per object in + `label_img`. + The ordering of columns follows the numpy's order of dimensions + convention, e.g., for 2-D, the first and second column are the + y and x coordinates, respectively. Only provided if `return_coords` is True. - """ + """ boundaries = skimage.segmentation.find_boundaries( label_img, connectivity=connectivity, mode=mode, background=background ) if not return_coords: return boundaries - + is2D = label_img.ndim == 2 rp = skimage.measure.regionprops(label_img) contours_coords = [] for obj in rp: - if mode == 'inner' and is2D: + if mode == "inner" and is2D: pass else: pass diff --git a/cellacdc/models/BABY/__init__.py b/cellacdc/segmenters/BABY/__init__.py similarity index 59% rename from cellacdc/models/BABY/__init__.py rename to cellacdc/segmenters/BABY/__init__.py index 04cce55ec..02f9914df 100644 --- a/cellacdc/models/BABY/__init__.py +++ b/cellacdc/segmenters/BABY/__init__.py @@ -1,2 +1,2 @@ # Installation of BABY is taken care of in the tracker implementation -from cellacdc.trackers.BABY import BABY_MODELS \ No newline at end of file +from cellacdc.trackers.BABY import BABY_MODELS diff --git a/cellacdc/models/BABY/acdcSegment.py b/cellacdc/segmenters/BABY/acdcSegment.py similarity index 61% rename from cellacdc/models/BABY/acdcSegment.py rename to cellacdc/segmenters/BABY/acdcSegment.py index 1d6706d78..d494f62c1 100644 --- a/cellacdc/models/BABY/acdcSegment.py +++ b/cellacdc/segmenters/BABY/acdcSegment.py @@ -3,33 +3,32 @@ from baby import modelsets from baby import BabyCrawler -from cellacdc import myutils +from cellacdc import utils from cellacdc.trackers import BABY from cellacdc.trackers.BABY import BABY_tracker + class AvailableModels: values = BABY.BABY_MODELS + class Model: def __init__( - self, - model_name: AvailableModels='yeast-alcatras-brightfield-sCMOS-60x-5z', - ): + self, + model_name: AvailableModels = "yeast-alcatras-brightfield-sCMOS-60x-5z", + ): self.tracker = BABY_tracker.tracker(model_name) - + def segment( - self, image, - refine_outlines=True, - swap_YX_axes_to_XY=True, - PhysicalSizeX=1.0 - ): + self, image, refine_outlines=True, swap_YX_axes_to_XY=True, PhysicalSizeX=1.0 + ): Y, X = image.shape[-2:] lab = np.zeros((Y, X), dtype=np.uint32) - + image = self.tracker._preprocess(image, swap_YX_axes_to_XY) - + result_generator = self.tracker.crawler.baby_brain.segment( - image[None, ...], + image[None, ...], pixel_size=PhysicalSizeX, overlap_size=48, yield_edgemasks=False, @@ -38,23 +37,20 @@ def segment( yield_volumes=False, refine_outlines=refine_outlines, yield_rescaling=False, - keep_bb_pixel_size=False + keep_bb_pixel_size=False, ) - + for result in result_generator: - masks = result['masks'] + masks = result["masks"] areas_mapper = { - m: (np.count_nonzero(mask), mask) - for m, mask in enumerate(masks) + m: (np.count_nonzero(mask), mask) for m, mask in enumerate(masks) } areas_mapper = dict( - sorted(areas_mapper.items(), - key=lambda item: item[1][0], - reverse=True) + sorted(areas_mapper.items(), key=lambda item: item[1][0], reverse=True) ) for i, (_, mask) in areas_mapper.items(): if swap_YX_axes_to_XY: mask = np.swapaxes(mask, 0, 1) - lab[mask] = i+1 - + lab[mask] = i + 1 + return lab diff --git a/cellacdc/segmenters/Cellpose_germlineNuclei/__init__.py b/cellacdc/segmenters/Cellpose_germlineNuclei/__init__.py new file mode 100644 index 000000000..a9ac3e8ce --- /dev/null +++ b/cellacdc/segmenters/Cellpose_germlineNuclei/__init__.py @@ -0,0 +1,3 @@ +from cellacdc import utils + +utils.check_install_cellpose() diff --git a/cellacdc/segmenters/Cellpose_germlineNuclei/acdcSegment.py b/cellacdc/segmenters/Cellpose_germlineNuclei/acdcSegment.py new file mode 100644 index 000000000..5331a6fd5 --- /dev/null +++ b/cellacdc/segmenters/Cellpose_germlineNuclei/acdcSegment.py @@ -0,0 +1,224 @@ +import os +import numpy as np + +from skimage.measure import label as skiLabel +import math +import scipy +import scipy.ndimage + +import skimage.exposure +import skimage.filters +import skimage.measure + +from cellpose import models +from cellacdc import user_profile_path + +default_model_path = os.path.join( + user_profile_path, "acdc-Cellpose_germlineNuclei", "cellpose_germlineNuclei_2023" +) + + +class Model: + def __init__(self, model_path: os.PathLike = default_model_path, gpu=False): + self.model = models.CellposeModel( + gpu=gpu, diam_mean=30, pretrained_model=model_path + ) + + def setupLogger(self, logger): + models.models_logger = logger + + def setLoggerPropagation(self, propagate: bool): + models.models_logger.propagate = propagate + + def setLoggerLevel(self, level: str): + import logging + + if level == "error": + models.models_logger.setLevel(logging.ERROR) + + def closeLogger(self): + handlers = models.models_logger.handlers[:] + for handler in handlers: + handler.close() + models.models_logger.removeHandler(handler) + + def _eval(self, image, **kwargs): + return self.model.eval(image.astype(np.float32), **kwargs)[0] + + def _initialize_image(self, image): + # See cellpose.gui.io._initialize_images + if image.ndim > 3: + # make tiff Z x channels x W x H + if image.shape[0] < 4: + # tiff is channels x Z x W x H + image = np.transpose(image, (1, 0, 2, 3)) + elif image.shape[-1] < 4: + # tiff is Z x W x H x channels + image = np.transpose(image, (0, 3, 1, 2)) + # fill in with blank channels to make 3 channels + if image.shape[1] < 3: + shape = image.shape + shape_to_concat = (shape[0], 3 - shape[1], shape[2], shape[3]) + to_concat = np.zeros(shape_to_concat, dtype=np.uint8) + image = np.concatenate((image, to_concat), axis=1) + image = np.transpose(image, (0, 2, 3, 1)) + elif image.ndim == 3: + if image.shape[0] < 5: + image = np.transpose(image, (1, 2, 0)) + if image.shape[-1] < 3: + shape = image.shape + # if parent.autochannelbtn.isChecked(): + # image = normalize99(image) * 255 + shape_to_concat = (shape[0], shape[1], 3 - shape[2]) + to_concat = np.zeros(shape_to_concat, dtype=type(image[0, 0, 0])) + image = np.concatenate((image, to_concat), axis=-1) + image = image[np.newaxis, ...] + elif image.shape[-1] < 5 and image.shape[-1] > 2: + image = image[:, :, :3] + # if parent.autochannelbtn.isChecked(): + # image = normalize99(image) * 255 + image = image[np.newaxis, ...] + else: + image = image[np.newaxis, ...] + + if image.ndim < 4: + image = image[:, :, :, np.newaxis] + return image + + def segment( + self, + image, + diameter_um=3.5, + blurfactor=2.50, + PhysicalSizeZ=1.0001, + PhysicalSizeY=1.0001, + PhysicalSizeX=1.0001, + cellprob_threshold=0.0, + clean_borders=False, + ): + """Cellpose model for C. elegans germline nuclei. This model works on a single channel only. + + Parameters + ---------- + diameter_um : float + Expected diameter of a nucleus in micrometer + blurfactor : float + Sigma value of the gaussian filter used for blurring of the data. + PhysicalSizeZ : float + Spacing of slices in z (unit: micrometer/slice). Prepopulated from image metadata + PhysicalSizeY : float + Pixelsize in y (unit: micrometer/pixel). Prepopulated from image metadata + PhysicalSizeX : float + Pixelsize in x (unit: micrometer/pixel). Prepopulated from image metadata + cellprob_threshold : float + cellprob_threshold for cellpose. + clean_borders : bool + Remove masks that touch the top or bottom slice in z, or that are closer than 2 pixels to the edges in x or y. + + Returns + ----- + np.ndarray + Instance segmentation array with the same shape as the input image. + """ + + # Preprocess image + # image = image/image.max() + # image = skimage.filters.gaussian(image, sigma=1) + # image = skimage.exposure.equalize_adapthist(image) + zspacing = PhysicalSizeZ + xysize = np.mean([PhysicalSizeX, PhysicalSizeY]) + + isRGB = image.shape[-1] == 3 or image.shape[-1] == 4 + if isRGB: + raise TypeError( + "This model was trained for 1 channel only. Please specify a single channel (DNA or synaptonemal complex/axis staining). " + ) + + isZstack = (image.ndim == 3 and not isRGB) or (image.ndim == 4) + + anisotropy = math.ceil(abs(zspacing / xysize)) + pxScale = xysize * 30 / diameter_um + + do_3D = True + + # if stitch_threshold > 0: + # do_3D = False + + channels = [0, 0] + + # Run cellpose eval + if not isZstack: + raise TypeError( + "This script is for 3D data (at least 5 slices) only. If needed, please modify the script to segment 2D data." + ) + else: + img_scaled = np.zeros( + ( + image.shape[0], + round(image.shape[1] * pxScale), + round(image.shape[2] * pxScale), + ) + ) + img_blur = np.zeros((img_scaled.shape)) + image[image == 0] = np.quantile(image[image > 0], 0.01) + + if pxScale > 1: + for i in range(image.shape[0]): + img_scaled[i, :, :] = scipy.ndimage.zoom( + image[i, :, :], pxScale, order=3 + ) + img_blur[i, :, :] = scipy.ndimage.gaussian_filter( + img_scaled[i, :, :], blurfactor + ) + + else: + for i in range(image.shape[0]): + img_scaled[i, :, :] = scipy.ndimage.zoom( + image[i, :, :], pxScale, order=3 + ) + img_blur[i, :, :] = scipy.ndimage.gaussian_filter( + img_scaled[i, :, :], blurfactor + ) + img_blur = self._initialize_image(img_blur) + labels_scaled, flows_blur, styles_blur = self.model.eval( + img_blur.astype(np.uint16), + diameter=30, + channels=channels, + do_3D=True, + anisotropy=anisotropy, + batch_size=3, + cellprob_threshold=cellprob_threshold, + ) + + labels = np.zeros(image.shape, dtype=labels_scaled.dtype) + for i in range(image.shape[0]): + labels[i, :, :] = scipy.ndimage.zoom( + labels_scaled[i, :, :], + ( + image.shape[1] / labels_scaled.shape[1], + image.shape[2] / labels_scaled.shape[2], + ), + order=0, + ) + + if clean_borders: + idx = np.unique( + np.concatenate( + [ + np.unique(labels[-1, :, :][labels[-1, :, :] > 0]), + np.unique(labels[0, :, :][labels[0, :, :] > 0]), + np.unique(labels[:, 0:2, :][labels[:, 0:2, :] > 0]), + np.unique(labels[:, -3:-1, :][labels[:, -3:-1, :] > 0]), + np.unique(labels[:, :, 0:2][labels[:, :, 0:2] > 0]), + np.unique(labels[:, :, -3:-1][labels[:, :, -3:-1] > 0]), + ] + ) + ) + + labels[np.isin(labels, idx)] = 0 + + return labels + + +def url_help(): + return "https://cellpose.readthedocs.io/en/latest/api.html" diff --git a/cellacdc/models/DeepSea/__init__.py b/cellacdc/segmenters/DeepSea/__init__.py similarity index 51% rename from cellacdc/models/DeepSea/__init__.py rename to cellacdc/segmenters/DeepSea/__init__.py index d2fcbf0be..c89b5267e 100644 --- a/cellacdc/models/DeepSea/__init__.py +++ b/cellacdc/segmenters/DeepSea/__init__.py @@ -3,62 +3,64 @@ import numpy as np -from cellacdc import myutils +from cellacdc import utils -myutils.check_install_torch() -myutils.check_install_package('deepsea') -myutils.check_install_package('munkres') +utils.check_install_torch() +utils.check_install_package("deepsea") +utils.check_install_package("munkres") import torch import torchvision.transforms as transforms from PIL import Image -_, deepsea_models_path = myutils.get_model_path('deepsea', create_temp_dir=False) +_, deepsea_segmenters_path = utils.get_model_path("deepsea", create_temp_dir=False) -image_size = [383,512] +image_size = [383, 512] image_means = [0.5] image_stds = [0.5] + def _get_segm_transforms(): - return transforms.Compose([ - transforms.ToPILImage(), - transforms.Resize(image_size), - transforms.ToTensor(), - transforms.Normalize(mean=image_means, std=image_stds) - ]) - -def _init_model( - checkpoint_filename, DeepSeaClass, gpu=False - ): - # Initialize torch device + return transforms.Compose( + [ + transforms.ToPILImage(), + transforms.Resize(image_size), + transforms.ToTensor(), + transforms.Normalize(mean=image_means, std=image_stds), + ] + ) + + +def _init_model(checkpoint_filename, DeepSeaClass, gpu=False): + # Initialize torch device if gpu: from cellacdc import is_mac import platform + cpu = platform.processor() - if is_mac and cpu == 'arm': - device = 'cpu' + if is_mac and cpu == "arm": + device = "cpu" else: - device = 'cuda' + device = "cuda" else: - device = 'cpu' - + device = "cpu" + torch_device = torch.device(device) # Initialize checkpoint - checkpoint_path = os.path.join(deepsea_models_path, checkpoint_filename) + checkpoint_path = os.path.join(deepsea_segmenters_path, checkpoint_filename) checkpoint = torch.load(checkpoint_path, map_location=torch_device) - model = DeepSeaClass( - n_channels=1, n_classes=2, bilinear=True - ) + model = DeepSeaClass(n_channels=1, n_classes=2, bilinear=True) model.load_state_dict(checkpoint) model = model.to(torch_device) return torch_device, checkpoint, model + def _resize_img(img: Union[Image.Image, np.ndarray], device, transforms): tensor_img = transforms(img).to(device=device, dtype=torch.float32) - resized_img = tensor_img.cpu().numpy()[0,:,:] + resized_img = tensor_img.cpu().numpy()[0, :, :] img_min = np.min(resized_img) img_max = np.max(resized_img) img_range = img_max - img_min diff --git a/cellacdc/models/DeepSea/acdcSegment.py b/cellacdc/segmenters/DeepSea/acdcSegment.py similarity index 80% rename from cellacdc/models/DeepSea/acdcSegment.py rename to cellacdc/segmenters/DeepSea/acdcSegment.py index a8bdf40a8..ece0496ce 100644 --- a/cellacdc/models/DeepSea/acdcSegment.py +++ b/cellacdc/segmenters/DeepSea/acdcSegment.py @@ -10,7 +10,7 @@ import skimage.measure from deepsea.model import DeepSeaSegmentation -from cellacdc import myutils, printl +from cellacdc import utils, printl from . import _init_model from . import _get_segm_transforms @@ -22,19 +22,20 @@ torch.cuda.manual_seed(SEED) torch.backends.cudnn.deterministic = True + class Model: def __init__(self, gpu=False): torch_device, checkpoint, model = _init_model( - 'segmentation.pth', DeepSeaSegmentation, gpu=gpu + "segmentation.pth", DeepSeaSegmentation, gpu=gpu ) self.torch_device = torch_device self._transforms = _get_segm_transforms() self._checkpoint = checkpoint self.model = model - + def segment(self, image: np.ndarray): is_rgb_image = image.shape[-1] == 3 or image.shape[-1] == 4 - is_z_stack = (image.ndim==3 and not is_rgb_image) or (image.ndim==4) + is_z_stack = (image.ndim == 3 and not is_rgb_image) or (image.ndim == 4) labels = np.zeros(image.shape, dtype=np.uint32) if is_rgb_image: labels = np.zeros(image.shape[:-1], dtype=np.uint32) @@ -49,24 +50,23 @@ def segment(self, image: np.ndarray): else: labels = self._segment_2D_image(image, (Y, X)) return labels - + def _segment_2D_image(self, img: np.ndarray, grayscale_img_shape): try: img = (255 * ((img - img.min()) / img.ptp())).astype(np.uint8) except AttributeError as e: img = (255 * ((img - img.min()) / np.ptp(img))).astype(np.uint8) - tensor_img = ( - self._transforms(img) - .to(device=self.torch_device, dtype=torch.float32) + tensor_img = self._transforms(img).to( + device=self.torch_device, dtype=torch.float32 ) _eval = self.model.eval() mask_pred, edge_pred = _eval(tensor_img.unsqueeze(0)) - mask_pred = transforms.Resize( - grayscale_img_shape, antialias=True - ).forward(mask_pred) + mask_pred = transforms.Resize(grayscale_img_shape, antialias=True).forward( + mask_pred + ) mask_pred = mask_pred.argmax(dim=1).cpu().numpy()[0, :, :] mask_bool = mask_pred > 0 lab = skimage.measure.label(np.squeeze(mask_bool)) - + return lab diff --git a/cellacdc/segmenters/InstanSeg/__init__.py b/cellacdc/segmenters/InstanSeg/__init__.py new file mode 100644 index 000000000..3fc7f61e3 --- /dev/null +++ b/cellacdc/segmenters/InstanSeg/__init__.py @@ -0,0 +1,5 @@ +from cellacdc import utils + +utils.check_install_instanseg() + +INSTANSEG_MODELS = ("fluorescence_nuclei_and_cells", "brightfield_nuclei") diff --git a/cellacdc/models/InstanSeg/acdcSegment.py b/cellacdc/segmenters/InstanSeg/acdcSegment.py similarity index 52% rename from cellacdc/models/InstanSeg/acdcSegment.py rename to cellacdc/segmenters/InstanSeg/acdcSegment.py index 661c32cee..da6a8e2f4 100644 --- a/cellacdc/models/InstanSeg/acdcSegment.py +++ b/cellacdc/segmenters/InstanSeg/acdcSegment.py @@ -2,90 +2,85 @@ from instanseg import InstanSeg -from ... import myutils, printl +from ... import utils, printl from ..._types import SecondChannelImage from . import INSTANSEG_MODELS + class AvailabelModels: values = INSTANSEG_MODELS + class AvailableDevices: - values = ( - 'Auto', 'GPU', 'CPU' - ) + values = ("Auto", "GPU", "CPU") + class VerbosityValues: - values = ( - 'Silent', 'Normal', 'Verbose' - ) + values = ("Silent", "Normal", "Verbose") + class ChannelOrder: - values = ( - 'First channel', 'Second channel' - ) + values = ("First channel", "Second channel") + class Model: def __init__( - self, - model_type: AvailabelModels='fluorescence_nuclei_and_cells', - custom_model_type: str='', - device: AvailableDevices='Auto', - verbosity: VerbosityValues='1' - ) -> None: + self, + model_type: AvailabelModels = "fluorescence_nuclei_and_cells", + custom_model_type: str = "", + device: AvailableDevices = "Auto", + verbosity: VerbosityValues = "1", + ) -> None: if custom_model_type: model_type = custom_model_type - - if device == 'Auto': - device = myutils.get_torch_device(gpu=True) - elif device == 'CPU': - device = 'cpu' - elif device == 'GPU': - device = myutils.get_torch_device(gpu=True) - - self.model = InstanSeg( - model_type, - device=device, - verbosity=verbosity - ) + + if device == "Auto": + device = utils.get_torch_device(gpu=True) + elif device == "CPU": + device = "cpu" + elif device == "GPU": + device = utils.get_torch_device(gpu=True) + + self.model = InstanSeg(model_type, device=device, verbosity=verbosity) def preprocess(self, image, rescale_intensities, warn=True): if rescale_intensities: image_min = image - image.min() - image_float = image_min/image_min.max() + image_float = image_min / image_min.max() else: - image_float = myutils.img_to_float(image, warn=warn) - - return (image_float*255).astype(np.uint8) - + image_float = utils.img_to_float(image, warn=warn) + + return (image_float * 255).astype(np.uint8) + def segment( - self, - image, - second_channel_image: SecondChannelImage=None, - return_masks_for_channel: ChannelOrder='First channel', - PhysicalSizeX: float=1.0, - do_not_resize_to_pixel_size: bool=False, - rescale_intensities: bool=False - ): + self, + image, + second_channel_image: SecondChannelImage = None, + return_masks_for_channel: ChannelOrder = "First channel", + PhysicalSizeX: float = 1.0, + do_not_resize_to_pixel_size: bool = False, + rescale_intensities: bool = False, + ): if do_not_resize_to_pixel_size: PhysicalSizeX = None - + image_in = image if second_channel_image is not None: image_in = self.second_ch_img_to_stack(image, second_channel_image) - + image_in = self.preprocess(image_in, rescale_intensities) - + if image_in.shape[-1] > 2: image_in = image_in[..., np.newaxis] - + is_zstack = image_in.ndim == 4 - + if isinstance(return_masks_for_channel, int): masks_index = return_masks_for_channel else: - masks_index = 0 if return_masks_for_channel == 'First channel' else 1 - + masks_index = 0 if return_masks_for_channel == "First channel" else 1 + if is_zstack: lab = np.zeros((image_in.shape[:3]), dtype=np.uint32) for z, img in enumerate(image_in): @@ -93,26 +88,22 @@ def segment( img, PhysicalSizeX, masks_index=masks_index ) else: - lab = self._segment_2D_img( - image_in, PhysicalSizeX, masks_index=masks_index - ) - + lab = self._segment_2D_img(image_in, PhysicalSizeX, masks_index=masks_index) + return lab - + def _segment_2D_img(self, image, PhysicalSizeX, masks_index=0): - labeled_output, image_tensor = self.model.eval_small_image( - image, PhysicalSizeX - ) + labeled_output, image_tensor = self.model.eval_small_image(image, PhysicalSizeX) labels = labeled_output[0].cpu().detach().numpy() lab = labels[masks_index].astype(np.uint32) return lab - + def second_ch_img_to_stack(self, image, second_image): img_stack = np.zeros((*image.shape, 2)) img_stack[..., 0] = image img_stack[..., 1] = second_image return img_stack - - + + def url_help(): - return 'https://github.com/instanseg/instanseg' \ No newline at end of file + return "https://github.com/instanseg/instanseg" diff --git a/cellacdc/models/StarDist/__init__.py b/cellacdc/segmenters/StarDist/__init__.py similarity index 65% rename from cellacdc/models/StarDist/__init__.py rename to cellacdc/segmenters/StarDist/__init__.py index 14bd152b0..626441990 100755 --- a/cellacdc/models/StarDist/__init__.py +++ b/cellacdc/segmenters/StarDist/__init__.py @@ -2,11 +2,11 @@ import sys import subprocess -from cellacdc import myutils +from cellacdc import utils -note = '' -if sys.platform == 'darwin': - note = (""" +note = "" +if sys.platform == "darwin": + note = """

    NOTE for M1 mac users: if you are on MacOS with an Apple Silicon processor cancel this operation and follow the @@ -15,23 +15,24 @@ here.

    - """) -myutils.check_install_package('tensorflow', note=note) -myutils.check_install_package('numpy', max_version='2.0.0') -myutils.check_install_package('stardist') + """ +utils.check_install_package("tensorflow", note=note) +utils.check_install_package("numpy", max_version="2.0.0") +utils.check_install_package("stardist") import sys import tensorflow import h5py + if sys.version_info.minor < 9: # Tensorflow > 2.3 has the requirement h5py~=3.1.0, # but stardist 0.7.3 with python<3.9 requires h5py<3 # see issue here https://github.com/stardist/stardist/issues/180 - tf_version = tensorflow.__version__.split('.') + tf_version = tensorflow.__version__.split(".") tf_major, tf_minor = [int(v) for v in tf_version][:2] - h5py_version = h5py.__version__.split('.') + h5py_version = h5py.__version__.split(".") h5py_major = int(h5py_version[0]) if tf_major > 1 and tf_minor > 2 and h5py_major >= 3: subprocess.check_call( - [sys.executable, '-m', 'pip', 'install', '--upgrade', 'h5py==2.10.0'] + [sys.executable, "-m", "pip", "install", "--upgrade", "h5py==2.10.0"] ) diff --git a/cellacdc/models/StarDist/acdcSegment.py b/cellacdc/segmenters/StarDist/acdcSegment.py similarity index 66% rename from cellacdc/models/StarDist/acdcSegment.py rename to cellacdc/segmenters/StarDist/acdcSegment.py index 4058789f6..38811c73d 100755 --- a/cellacdc/models/StarDist/acdcSegment.py +++ b/cellacdc/segmenters/StarDist/acdcSegment.py @@ -6,31 +6,31 @@ from cellacdc import models + class AvailableModels: values = models.STARDIST_MODELS + class Model: def __init__( - self, - model_name: AvailableModels='2D_versatile_fluo', - load_stardist_3D=False - ): + self, model_name: AvailableModels = "2D_versatile_fluo", load_stardist_3D=False + ): """_summary_ Parameters ---------- model_name : str, optional - Name of the pre-trained model to load. - - Available models are '2D_versatile_fluo', '2D_versatile_he', and + Name of the pre-trained model to load. + + Available models are '2D_versatile_fluo', '2D_versatile_he', and '2D_paper_dsb2018'. - + Default is '2D_versatile_fluo' - """ - + """ + stardist_default_models = models.STARDIST_MODELS stardist_path = os.path.dirname(os.path.abspath(__file__)) - T_cell_path = os.path.join(stardist_path, 'model', 'T_cell') + T_cell_path = os.path.join(stardist_path, "model", "T_cell") model_class = StarDist3D if load_stardist_3D else StarDist2D if not os.path.exists(T_cell_path): model_name = stardist_default_models[0] @@ -40,39 +40,35 @@ def __init__( else: script_path = os.path.abspath(__file__) stardist_path = os.path.dirname(script_path) - model_path = os.path.join(stardist_path, 'model') - self.model = model_class( - None, name=model_name, basedir=model_path - ) + model_path = os.path.join(stardist_path, "model") + self.model = model_class(None, name=model_name, basedir=model_path) self.load_stardist_3D = load_stardist_3D - def segment( - self, image, prob_thresh=0.0, nms_thresh=0.0, - segment_3D_volume=False - ): + def segment(self, image, prob_thresh=0.0, nms_thresh=0.0, segment_3D_volume=False): # Check on image shape is2D = image.ndim == 2 is3D = image.ndim == 3 calling_stardist3D_on_2D_data = ( (is3D and self.load_stardist_3D and not segment_3D_volume) - or is2D and self.load_stardist_3D + or is2D + and self.load_stardist_3D ) calling_stardist2D_on_3D_data = ( is3D and not self.load_stardist_3D and segment_3D_volume ) if calling_stardist3D_on_2D_data: - print('') - print('='*30) + print("") + print("=" * 30) raise ValueError( - 'StarDist3D cannot segment 2D image data. If you are trying to ' + "StarDist3D cannot segment 2D image data. If you are trying to " 'segment z-slices one by one you need to click "True" at the ' '"Segment 3D Volume" entry.' ) elif calling_stardist2D_on_3D_data: - print('') - print('='*30) + print("") + print("=" * 30) raise ValueError( - 'StarDist2D cannot segment 3D image data. If you are trying to ' + "StarDist2D cannot segment 3D image data. If you are trying to " 'segment z-slices one by one you need to click "False" at the ' '"Segment 3D Volume" entry.' ) @@ -84,16 +80,12 @@ def segment( labels = np.zeros(image.shape, dtype=np.uint32) for i, _img in enumerate(image): lab, _ = self.model.predict_instances( - normalize(_img), - prob_thresh=prob_thresh, - nms_thresh=nms_thresh + normalize(_img), prob_thresh=prob_thresh, nms_thresh=nms_thresh ) labels[i] = lab - labels = skimage.measure.label(labels>0) + labels = skimage.measure.label(labels > 0) else: labels, _ = self.model.predict_instances( - normalize(image), - prob_thresh=prob_thresh, - nms_thresh=nms_thresh + normalize(image), prob_thresh=prob_thresh, nms_thresh=nms_thresh ) return labels.astype(np.uint32) diff --git a/cellacdc/segmenters/YeaZ/__init__.py b/cellacdc/segmenters/YeaZ/__init__.py new file mode 100755 index 000000000..0b298b89f --- /dev/null +++ b/cellacdc/segmenters/YeaZ/__init__.py @@ -0,0 +1,3 @@ +from cellacdc import utils + +utils.check_install_package("tensorflow", max_version="2.17") diff --git a/cellacdc/models/YeaZ/acdcSegment.py b/cellacdc/segmenters/YeaZ/acdcSegment.py similarity index 75% rename from cellacdc/models/YeaZ/acdcSegment.py rename to cellacdc/segmenters/YeaZ/acdcSegment.py index d5d0154f1..a37b4c4bc 100755 --- a/cellacdc/models/YeaZ/acdcSegment.py +++ b/cellacdc/segmenters/YeaZ/acdcSegment.py @@ -14,9 +14,10 @@ from tensorflow import keras from tqdm import tqdm -from cellacdc import myutils +from cellacdc import utils from cellacdc import user_profile_path + class progressCallback(keras.callbacks.Callback): def __init__(self, signals): self.signals = signals @@ -34,27 +35,25 @@ def on_predict_batch_end(self, batch, logs=None): else: self.signals[0].progressBar.emit(1) + class Model: def __init__(self, is_phase_contrast=True): # Initialize model - self.model = model.unet( - pretrained_weights=None, - input_size=(None,None,1) - ) + self.model = model.unet(pretrained_weights=None, input_size=(None, None, 1)) # Get the path where the weights are saved. # We suggest saving the weights files into a 'model' subfolder - model_path = os.path.join(str(user_profile_path), f'acdc-YeaZ') + model_path = os.path.join(str(user_profile_path), f"acdc-YeaZ") if is_phase_contrast: - weights_fn = 'unet_weights_batchsize_25_Nepochs_100_SJR0_10.hdf5' + weights_fn = "unet_weights_batchsize_25_Nepochs_100_SJR0_10.hdf5" else: - weights_fn = 'weights_budding_BF_multilab_0_1.hdf5' + weights_fn = "weights_budding_BF_multilab_0_1.hdf5" weights_path = os.path.join(model_path, weights_fn) if not os.path.exists(model_path): - raise FileNotFoundError(f'Weights file not found in {model_path}') + raise FileNotFoundError(f"Weights file not found in {model_path}") self.model.load_weights(weights_path) @@ -62,7 +61,7 @@ def yeaz_preprocess(self, image, tqdm_pbar=None, warn=True): # image = skimage.filters.gaussian(image, sigma=1) # image = skimage.exposure.equalize_adapthist(image) # image = image/image.max() - image = myutils.img_to_float(image, warn=warn) + image = utils.img_to_float(image, warn=warn) image = skimage.exposure.equalize_adapthist(image) if tqdm_pbar is not None: tqdm_pbar.emit(1) @@ -71,17 +70,16 @@ def yeaz_preprocess(self, image, tqdm_pbar=None, warn=True): def predict3DT(self, timelapse3D): # pad with zeros such that is divisible by 16 (nrow, ncol) = timelapse3D[0].shape - row_add = 16-nrow%16 - col_add = 16-ncol%16 + row_add = 16 - nrow % 16 + col_add = 16 - ncol % 16 pad_info = ((0, 0), (0, row_add), (0, col_add)) - padded = np.pad(timelapse3D, pad_info, 'constant') + padded = np.pad(timelapse3D, pad_info, "constant") x = padded[:, :, :, np.newaxis] prediction = self.model.predict(x, batch_size=1, verbose=1) prediction = prediction[:, 0:-row_add, 0:-col_add, 0] return prediction - def segment2D(self, image, thresh_val=0.0, min_distance=10): # Preprocess image image = self.yeaz_preprocess(image) @@ -91,13 +89,13 @@ def segment2D(self, image, thresh_val=0.0, min_distance=10): # pad with zeros such that is divisible by 16 (nrow, ncol) = image.shape - row_add = 16-nrow%16 - col_add = 16-ncol%16 + row_add = 16 - nrow % 16 + col_add = 16 - ncol % 16 pad_info = ((0, row_add), (0, col_add)) - padded = np.pad(image, pad_info, 'constant') - x = padded[np.newaxis,:,:,np.newaxis] + padded = np.pad(image, pad_info, "constant") + x = padded[np.newaxis, :, :, np.newaxis] - prediction = self.model.predict(x, batch_size=1, verbose=1)[0,:,:,0] + prediction = self.model.predict(x, batch_size=1, verbose=1)[0, :, :, 0] # remove padding with 0s prediction = prediction[0:-row_add, 0:-col_add] @@ -115,26 +113,26 @@ def segment(self, image, thresh_val=0.0, min_distance=10): img, thresh_val=thresh_val, min_distance=min_distance ) labels[z] = lab - labels = skimage.measure.label(labels>0) + labels = skimage.measure.label(labels > 0) else: labels = self.segment2D( image, thresh_val=thresh_val, min_distance=min_distance ) return labels - def segment3DT( - self, timelapse3D, thresh_val=0.0, min_distance=10, signals=None - ): + def segment3DT(self, timelapse3D, thresh_val=0.0, min_distance=10, signals=None): sig_progress_tqdm = None if signals is not None: - signals[0].progress.emit(f'Preprocessing images...') + signals[0].progress.emit(f"Preprocessing images...") signals[0].create_tqdm.emit(len(timelapse3D)) sig_progress_tqdm = signals[0].progress_tqdm - timelapse3D = np.array([ - self.yeaz_preprocess(image, tqdm_pbar=sig_progress_tqdm, warn=i==0) - for i, image in enumerate(timelapse3D) - ]) + timelapse3D = np.array( + [ + self.yeaz_preprocess(image, tqdm_pbar=sig_progress_tqdm, warn=i == 0) + for i, image in enumerate(timelapse3D) + ] + ) if signals is not None: signals[0].signal_close_tqdm.emit() @@ -144,15 +142,15 @@ def segment3DT( # pad with zeros such that is divisible by 16 (nrow, ncol) = timelapse3D[0].shape - row_add = 16-nrow%16 - col_add = 16-ncol%16 + row_add = 16 - nrow % 16 + col_add = 16 - ncol % 16 pad_info = ((0, 0), (0, row_add), (0, col_add)) - padded = np.pad(timelapse3D, pad_info, 'constant') + padded = np.pad(timelapse3D, pad_info, "constant") x = padded[:, :, :, np.newaxis] if signals is not None: - signals[0].progress.emit(f'Predicting (the future) with YeaZ...') + signals[0].progress.emit(f"Predicting (the future) with YeaZ...") callbacks = None if signals is not None: @@ -160,10 +158,10 @@ def segment3DT( prediction = self.model.predict( x, batch_size=1, verbose=1, callbacks=callbacks - )[:,:,:,0] + )[:, :, :, 0] if signals is not None: - signals[0].progress.emit(f'Labelling objects with YeaZ...') + signals[0].progress.emit(f"Labelling objects with YeaZ...") # remove padding with 0s prediction = prediction[:, 0:-row_add, 0:-col_add] @@ -180,5 +178,6 @@ def segment3DT( signals[0].signal_close_tqdm.emit() return lab_timelapse + def url_help(): - return 'https://github.com/rahi-lab/YeaZ-GUI' \ No newline at end of file + return "https://github.com/rahi-lab/YeaZ-GUI" diff --git a/cellacdc/models/YeaZ/unet/LaunchBatchPrediction.py b/cellacdc/segmenters/YeaZ/unet/LaunchBatchPrediction.py similarity index 50% rename from cellacdc/models/YeaZ/unet/LaunchBatchPrediction.py rename to cellacdc/segmenters/YeaZ/unet/LaunchBatchPrediction.py index 21539a693..c4be0d213 100755 --- a/cellacdc/models/YeaZ/unet/LaunchBatchPrediction.py +++ b/cellacdc/segmenters/YeaZ/unet/LaunchBatchPrediction.py @@ -3,68 +3,78 @@ Created on Tue Nov 19 17:38:58 2019 """ -from qtpy.QtWidgets import (QDialog, QDialogButtonBox, QLineEdit, QFormLayout, - QLabel, QListWidget, QAbstractItemView, QCheckBox, - QButtonGroup, QRadioButton) +from qtpy.QtWidgets import ( + QDialog, + QDialogButtonBox, + QLineEdit, + QFormLayout, + QLabel, + QListWidget, + QAbstractItemView, + QCheckBox, + QButtonGroup, + QRadioButton, +) from qtpy import QtGui class CustomDialog(QDialog): def __init__(self, *args, **kwargs): super(CustomDialog, self).__init__(*args, **kwargs) - - app, = args + + (app,) = args maxtimeindex = app.reader.sizet - + self.setWindowTitle("Launch NN") - self.setGeometry(100,100, 500,200) - + self.setGeometry(100, 100, 500, 200) + self.entry1 = QLineEdit() - self.entry1.setValidator(QtGui.QIntValidator(0,int(maxtimeindex-1))) + self.entry1.setValidator(QtGui.QIntValidator(0, int(maxtimeindex - 1))) self.entry2 = QLineEdit() - self.entry2.setValidator(QtGui.QIntValidator(0,int(maxtimeindex-1))) - + self.entry2.setValidator(QtGui.QIntValidator(0, int(maxtimeindex - 1))) + # FOV dialog self.listfov = QListWidget() self.listfov.setSelectionMode(QAbstractItemView.SelectionMode.MultiSelection) for f in range(0, app.reader.Npos): - self.listfov.addItem('Field of View {}'.format(f+1)) + self.listfov.addItem("Field of View {}".format(f + 1)) + + self.labeltime = QLabel( + "Enter range of frames ({}-{}) to segment".format(0, app.reader.sizet - 1) + ) - self.labeltime = QLabel("Enter range of frames ({}-{}) to segment".format(0, app.reader.sizet-1)) - self.entry_threshold = QLineEdit() self.entry_threshold.setValidator(QtGui.QDoubleValidator()) - self.entry_threshold.setText('0.5') - + self.entry_threshold.setText("0.5") + self.entry_segmentation = QLineEdit() self.entry_segmentation.setValidator(QtGui.QIntValidator()) - self.entry_segmentation.setText('5') - + self.entry_segmentation.setText("5") + flo = QFormLayout() flo.addWidget(self.labeltime) - flo.addRow('Start from frame:', self.entry1) - flo.addRow('End at frame:', self.entry2) - flo.addRow('Select field(s) of view:', self.listfov) - flo.addRow('Threshold value:', self.entry_threshold) - flo.addRow('Min. distance between seeds:', self.entry_segmentation) - + flo.addRow("Start from frame:", self.entry1) + flo.addRow("End at frame:", self.entry2) + flo.addRow("Select field(s) of view:", self.listfov) + flo.addRow("Threshold value:", self.entry_threshold) + flo.addRow("Min. distance between seeds:", self.entry_segmentation) + self.radiobuttons = QButtonGroup() - self.buttonBF = QRadioButton('Images are bright-field') - self.buttonPC = QRadioButton('Images are phase contrast') + self.buttonBF = QRadioButton("Images are bright-field") + self.buttonPC = QRadioButton("Images are phase contrast") self.buttonPC.setChecked(True) self.radiobuttons.addButton(self.buttonBF, id=0) self.radiobuttons.addButton(self.buttonPC, id=1) flo.addWidget(self.buttonBF) flo.addWidget(self.buttonPC) - - QBtn = QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel - + + QBtn = ( + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel + ) + self.buttonBox = QDialogButtonBox(QBtn) self.buttonBox.accepted.connect(self.accept) self.buttonBox.rejected.connect(self.reject) flo.addWidget(self.buttonBox) self.setLayout(flo) - - - diff --git a/cellacdc/models/YeaZ/unet/__init__.py b/cellacdc/segmenters/YeaZ/unet/__init__.py similarity index 100% rename from cellacdc/models/YeaZ/unet/__init__.py rename to cellacdc/segmenters/YeaZ/unet/__init__.py diff --git a/cellacdc/models/YeaZ/unet/hungarian.py b/cellacdc/segmenters/YeaZ/unet/hungarian.py similarity index 77% rename from cellacdc/models/YeaZ/unet/hungarian.py rename to cellacdc/segmenters/YeaZ/unet/hungarian.py index 302729eda..d23e6f26d 100755 --- a/cellacdc/models/YeaZ/unet/hungarian.py +++ b/cellacdc/segmenters/YeaZ/unet/hungarian.py @@ -9,16 +9,16 @@ def correspondence(prev, curr): """ Corrects correspondence between previous and current mask, returns current mask with corrected cell values. New cells are given the unique identifier - starting at max(prev)+1. - + starting at max(prev)+1. + This is done by embedding every cell into a feature space consisting of - the center of mass and the area. The pairwise euclidean distance is - calculated between the cells of the previous and current frame. This is + the center of mass and the area. The pairwise euclidean distance is + calculated between the cells of the previous and current frame. This is then used as a cost for the bipartite matching problem which is in turn solved by the Hungarian algorithm as implemented in the munkres package. """ newcell = np.max(prev) + 1 - + hu_dict = hungarian_align(prev, curr) new = curr.copy() for key, val in hu_dict.items(): @@ -26,66 +26,68 @@ def correspondence(prev, curr): if val == -1: val = newcell newcell += 1 - - new[curr==key] = val - + + new[curr == key] = val + return new def hungarian_align(m1, m2): """ - Aligns the cells using the hungarian algorithm using the euclidean distance as - cost. - Returns dictionary of cells in m2 to cells in m1. If a cell is new, the dictionary + Aligns the cells using the hungarian algorithm using the euclidean distance as + cost. + Returns dictionary of cells in m2 to cells in m1. If a cell is new, the dictionary value is -1. """ dist, ix1, ix2 = cell_distance(m1, m2) - - # If dist couldn't be calculated, return dictionary from cells to themselves + + # If dist couldn't be calculated, return dictionary from cells to themselves if dist is None: unique_m2 = np.unique(m2) return dict(zip(unique_m2, unique_m2)) - + solver = Munkres() indexes = solver.compute(make_square(dist)) - + # Create dictionary of cell indicies d = dict([(ix2.get(i2, -1), ix1.get(i1, -1)) for i1, i2 in indexes]) - d.pop(-1, None) + d.pop(-1, None) return d def cell_to_features(im, c, nsamples=None, time=None): """Embeds cell c in image im into feature space""" - coord = np.argwhere(im==c) + coord = np.argwhere(im == c) area = coord.shape[0] - + if nsamples is not None: samples = np.random.choice(area, min(nsamples, area), replace=False) - sampled = coord[samples,:] + sampled = coord[samples, :] else: sampled = coord - + com = sampled.mean(axis=0) - - return {'cell': c, - 'time': time, - 'sqrtarea': np.sqrt(area), - 'area': area, - 'com_x': com[0], - 'com_y': com[1]} - - + + return { + "cell": c, + "time": time, + "sqrtarea": np.sqrt(area), + "area": area, + "com_x": com[0], + "com_y": com[1], + } + + def cell_distance(m1, m2, weight_com=3): """ Gives distance matrix between cells in first and second frame, by embedding all cells into the feature space. Currently uses center of mass and area - as features, with center of mass weighted with factor weight_com (to + as features, with center of mass weighted with factor weight_com (to make it more important). """ # Modify to compute use more computed features - #cols = ['com_x', 'com_y', 'roundness', 'sqrtarea'] - cols = ['com_x', 'com_y', 'area'] + # cols = ['com_x', 'com_y', 'roundness', 'sqrtarea'] + cols = ["com_x", "com_y", "area"] def get_features(m, t): cells = list(np.unique(m)) @@ -93,29 +95,28 @@ def get_features(m, t): cells.remove(0) features = [cell_to_features(m, c, time=t) for c in cells] return pd.DataFrame(features), dict(enumerate(cells)) - + # Create df, rescale feat1, ix_to_cell1 = get_features(m1, 1) feat2, ix_to_cell2 = get_features(m2, 2) - + # Check if one of matrices doesn't contain cells - if len(feat1)==0 or len(feat2)==0: + if len(feat1) == 0 or len(feat2) == 0: return None, None, None - + df = pd.concat((feat1, feat2)) df[cols] = scale(df[cols]) - + # give more importance to center of mass - df[['com_x', 'com_y']] = df[['com_x', 'com_y']] * weight_com + df[["com_x", "com_y"]] = df[["com_x", "com_y"]] * weight_com # pairwise euclidean dist dist = euclidean_distances( - df.loc[df['time']==1][cols], - df.loc[df['time']==2][cols] + df.loc[df["time"] == 1][cols], df.loc[df["time"] == 2][cols] ) return dist, ix_to_cell1, ix_to_cell2 - - + + def zero_pad(m, shape): """Pads matrix with zeros to be of desired shape""" out = np.zeros(shape) @@ -126,12 +127,10 @@ def zero_pad(m, shape): def make_square(m): """Turns matrix into square matrix, as required by Munkres algorithm""" - r,c = m.shape - if r==c: + r, c = m.shape + if r == c: return m - elif r>c: - return zero_pad(m, (r,r)) + elif r > c: + return zero_pad(m, (r, r)) else: - return zero_pad(m, (c,c)) - - + return zero_pad(m, (c, c)) diff --git a/cellacdc/segmenters/YeaZ/unet/model.py b/cellacdc/segmenters/YeaZ/unet/model.py new file mode 100755 index 000000000..feb40151f --- /dev/null +++ b/cellacdc/segmenters/YeaZ/unet/model.py @@ -0,0 +1,140 @@ +""" +Source of the code: https://github.com/zhixuhao/unet +""" + +# Turn off GPU access so can train and use the YeaZ-GUI +import os + +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +# Import tensorflow differently depending on version +from tensorflow import __version__ as tf_version + +tf_version_old = int(tf_version[0]) <= 1 + +from tensorflow.keras.models import Model +from tensorflow.keras.layers import ( + Input, + Conv2D, + MaxPooling2D, + Dropout, + concatenate, + UpSampling2D, +) +from tensorflow.keras.optimizers import Adam +# from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler + +if tf_version_old: + from tensorflow import ConfigProto + from tensorflow import InteractiveSession + +else: + from tensorflow.compat.v1 import ConfigProto + from tensorflow.compat.v1 import InteractiveSession + + +config = ConfigProto() +config.gpu_options.allow_growth = True +session = InteractiveSession(config=config) + + +def unet(pretrained_weights=None, input_size=(256, 256, 1)): + inputs = Input(input_size) + conv1 = Conv2D( + 64, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(inputs) + conv1 = Conv2D( + 64, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(conv1) + pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) + conv2 = Conv2D( + 128, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(pool1) + conv2 = Conv2D( + 128, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(conv2) + pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) + conv3 = Conv2D( + 256, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(pool2) + conv3 = Conv2D( + 256, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(conv3) + pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) + conv4 = Conv2D( + 512, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(pool3) + conv4 = Conv2D( + 512, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(conv4) + drop4 = Dropout(0.5)(conv4) + pool4 = MaxPooling2D(pool_size=(2, 2))(drop4) + + conv5 = Conv2D( + 1024, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(pool4) + conv5 = Conv2D( + 1024, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(conv5) + drop5 = Dropout(0.5)(conv5) + + up6 = Conv2D( + 512, 2, activation="relu", padding="same", kernel_initializer="he_normal" + )(UpSampling2D(size=(2, 2))(drop5)) + merge6 = concatenate([drop4, up6], axis=3) + conv6 = Conv2D( + 512, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(merge6) + conv6 = Conv2D( + 512, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(conv6) + + up7 = Conv2D( + 256, 2, activation="relu", padding="same", kernel_initializer="he_normal" + )(UpSampling2D(size=(2, 2))(conv6)) + merge7 = concatenate([conv3, up7], axis=3) + conv7 = Conv2D( + 256, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(merge7) + conv7 = Conv2D( + 256, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(conv7) + + up8 = Conv2D( + 128, 2, activation="relu", padding="same", kernel_initializer="he_normal" + )(UpSampling2D(size=(2, 2))(conv7)) + merge8 = concatenate([conv2, up8], axis=3) + conv8 = Conv2D( + 128, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(merge8) + conv8 = Conv2D( + 128, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(conv8) + + up9 = Conv2D( + 64, 2, activation="relu", padding="same", kernel_initializer="he_normal" + )(UpSampling2D(size=(2, 2))(conv8)) + merge9 = concatenate([conv1, up9], axis=3) + conv9 = Conv2D( + 64, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(merge9) + conv9 = Conv2D( + 64, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(conv9) + conv9 = Conv2D( + 2, 3, activation="relu", padding="same", kernel_initializer="he_normal" + )(conv9) + conv10 = Conv2D(1, 1, activation="sigmoid")(conv9) + + model = Model(inputs=inputs, outputs=conv10) + + model.compile( + optimizer=Adam(learning_rate=1e-4), + loss="binary_crossentropy", + metrics=["accuracy"], + ) + + if pretrained_weights: + model.load_weights(pretrained_weights) + + return model diff --git a/cellacdc/models/YeaZ/unet/neural_network.py b/cellacdc/segmenters/YeaZ/unet/neural_network.py similarity index 68% rename from cellacdc/models/YeaZ/unet/neural_network.py rename to cellacdc/segmenters/YeaZ/unet/neural_network.py index 6169fcdba..2fd55819a 100755 --- a/cellacdc/models/YeaZ/unet/neural_network.py +++ b/cellacdc/segmenters/YeaZ/unet/neural_network.py @@ -1,9 +1,9 @@ - # -*- coding: utf-8 -*- """ Created on Sat Dec 21 18:54:10 2019 """ + import os import sys import numpy as np @@ -13,17 +13,19 @@ from .model import unet + def determine_path_weights(): script_dirname = os.path.dirname(os.path.realpath(__file__)) main_path = os.path.dirname(os.path.dirname(os.path.dirname(script_dirname))) - model_path = os.path.join(main_path, 'models', 'YeaZ_model') + model_path = os.path.join(main_path, "models", "YeaZ_model") - if getattr(sys, 'frozen', False): - path_weights = os.path.join(sys._MEIPASS, 'unet/') + if getattr(sys, "frozen", False): + path_weights = os.path.join(sys._MEIPASS, "unet/") else: path_weights = model_path return path_weights + def create_directory_if_not_exists(path): """ Create in the file system a new directory if it doesn't exist yet. @@ -62,33 +64,28 @@ def prediction(im, is_pc, path_weights): """ # pad with zeros such that is divisible by 16 (nrow, ncol) = im.shape - row_add = 16-nrow%16 - col_add = 16-ncol%16 - padded = np.pad(im, ((0, row_add), (0, col_add)), 'constant') + row_add = 16 - nrow % 16 + col_add = 16 - ncol % 16 + padded = np.pad(im, ((0, row_add), (0, col_add)), "constant") # WHOLE CELL PREDICTION - model = unet(pretrained_weights = None, - input_size = (None,None,1)) + model = unet(pretrained_weights=None, input_size=(None, None, 1)) if is_pc: path = os.path.join( - path_weights, - 'unet_weights_batchsize_25_Nepochs_100_SJR0_10.hdf5' + path_weights, "unet_weights_batchsize_25_Nepochs_100_SJR0_10.hdf5" ) else: - path = os.path.join( - path_weights, - 'weights_budding_BF_multilab_0_1.hdf5' - ) + path = os.path.join(path_weights, "weights_budding_BF_multilab_0_1.hdf5") if not os.path.exists(path): - raise ValueError(f'Weights file not found in {path}') + raise ValueError(f"Weights file not found in {path}") model.load_weights(path) - results = model.predict(padded[np.newaxis,:,:,np.newaxis], batch_size=1) + results = model.predict(padded[np.newaxis, :, :, np.newaxis], batch_size=1) - res = results[0,:,:,0] + res = results[0, :, :, 0] return res[:nrow, :ncol] @@ -106,23 +103,27 @@ def batch_prediction(im_stack, is_pc, path_weights, batch_size=1): col_add = 16 - ncol % 16 im_stack_padded = [] for im in im_stack: - padded = np.pad(im, ((0, row_add), (0, col_add)), mode='constant') + padded = np.pad(im, ((0, row_add), (0, col_add)), mode="constant") im_stack_padded.append(padded) im_stack_padded = np.array(im_stack_padded) # WHOLE CELL PREDICTION - model = unet(pretrained_weights=None, - input_size=(None, None, 1)) + model = unet(pretrained_weights=None, input_size=(None, None, 1)) if is_pc: - path = os.path.join(path_weights, 'unet_weights_batchsize_25_Nepochs_100_SJR0_10.hdf5') + path = os.path.join( + path_weights, "unet_weights_batchsize_25_Nepochs_100_SJR0_10.hdf5" + ) else: - path = os.path.join(path_weights, 'unet_weights_BF_batchsize_25_Nepochs_100_SJR_0_1.hdf5') + path = os.path.join( + path_weights, "unet_weights_BF_batchsize_25_Nepochs_100_SJR_0_1.hdf5" + ) if not os.path.exists(path): raise ValueError( - 'Weights file not found! Download them from the link ' - f'below and place them into {path_weights}.\n' - 'Link: https://drive.google.com/file/d/1CO7uF-werl9y8s3Fel0cVjRHCdXRf2Ly/view?usp=sharing') + "Weights file not found! Download them from the link " + f"below and place them into {path_weights}.\n" + "Link: https://drive.google.com/file/d/1CO7uF-werl9y8s3Fel0cVjRHCdXRf2Ly/view?usp=sharing" + ) model.load_weights(path) diff --git a/cellacdc/models/YeaZ/unet/segment.py b/cellacdc/segmenters/YeaZ/unet/segment.py similarity index 68% rename from cellacdc/models/YeaZ/unet/segment.py rename to cellacdc/segmenters/YeaZ/unet/segment.py index 0c2d037f4..689d430f9 100755 --- a/cellacdc/models/YeaZ/unet/segment.py +++ b/cellacdc/segmenters/YeaZ/unet/segment.py @@ -31,9 +31,9 @@ def segment(th, pred, min_distance=10, topology=None, merge=True, q=0.75): m[tuple(peak_idx.T)] = True # Uncomment to start with cross for every pixel instead of single pixel - m_lab = label(m) #comment this - #m_dil = dilation(m) - #m_lab = label(m_dil) + m_lab = label(m) # comment this + # m_dil = dilation(m) + # m_lab = label(m_dil) wsh = watershed(topology, m_lab, mask=th, connectivity=2) if merge: merged = cell_merge(wsh, pred, q) @@ -41,6 +41,7 @@ def segment(th, pred, min_distance=10, topology=None, merge=True, q=0.75): merged = wsh return correct_artefacts(merged) + def segment_stack(th, pred, min_distance=10, topology=None, signals=None): """ source: YeaZ @@ -62,9 +63,9 @@ def correct_artefacts(wsh): by another cell. Those are removed here. """ unique, count = np.unique(wsh, return_counts=True) - to_remove = unique[count<=3] + to_remove = unique[count <= 3] for rem in to_remove: - rem_im = wsh==rem + rem_im = wsh == rem rem_cont = dilation(rem_im) & ~rem_im vals, val_counts = np.unique(wsh[rem_cont], return_counts=True) replace_val = vals[np.argmax(val_counts)] @@ -78,51 +79,52 @@ def cell_merge(wsh, pred, q=0.75): Procedure that merges cells if the border between them is predicted to be cell pixels. """ - wshshape=wsh.shape + wshshape = wsh.shape # masks for the original cells - objs = np.zeros((wsh.max()+1,wshshape[0],wshshape[1]), dtype=bool) + objs = np.zeros((wsh.max() + 1, wshshape[0], wshshape[1]), dtype=bool) # masks for dilated cells - dil_objs = np.zeros((wsh.max()+1,wshshape[0],wshshape[1]), dtype=bool) + dil_objs = np.zeros((wsh.max() + 1, wshshape[0], wshshape[1]), dtype=bool) # bounding box coordinates - obj_coords = np.zeros((wsh.max()+1,4)) + obj_coords = np.zeros((wsh.max() + 1, 4)) # cleaned watershed, output of function - wshclean = np.zeros((wshshape[0],wshshape[1])) + wshclean = np.zeros((wshshape[0], wshshape[1])) # kernel to dilate objects - kernel = np.ones((3,3), dtype=bool) + kernel = np.ones((3, 3), dtype=bool) for obj1 in range(wsh.max()): # create masks and dilated masks for obj - objs[obj1,:,:] = wsh==(obj1+1) - dil_objs[obj1,:,:] = dilation(objs[obj1,:,:], kernel) + objs[obj1, :, :] = wsh == (obj1 + 1) + dil_objs[obj1, :, :] = dilation(objs[obj1, :, :], kernel) # bounding box - obj_coords[obj1,:] = get_bounding_box(dil_objs[obj1,:,:]) + obj_coords[obj1, :] = get_bounding_box(dil_objs[obj1, :, :]) - objcounter = 0 # counter for new watershed objects + objcounter = 0 # counter for new watershed objects for obj1 in range(wsh.max()): - dil1 = dil_objs[obj1,:,:] + dil1 = dil_objs[obj1, :, :] # check if mask has been deleted if np.sum(dil1) == 0: continue objcounter = objcounter + 1 - orig1 = objs[obj1,:,:] + orig1 = objs[obj1, :, :] - for obj2 in range(obj1+1,wsh.max()): - dil2 = dil_objs[obj2,:,:] + for obj2 in range(obj1 + 1, wsh.max()): + dil2 = dil_objs[obj2, :, :] # only check border if bounding box overlaps, and second mask # is not yet deleted - if (do_box_overlap(obj_coords[obj1,:], obj_coords[obj2,:]) - and np.sum(dil2) > 0): - + if ( + do_box_overlap(obj_coords[obj1, :], obj_coords[obj2, :]) + and np.sum(dil2) > 0 + ): border = dil1 * dil2 border_pred = pred[border] @@ -137,13 +139,13 @@ def cell_merge(wsh, pred, q=0.75): top_border_area = len(top_border_pred) # merge cells - if top_border_height / top_border_area > .99: - orig1 = np.logical_or(orig1, objs[obj2,:,:]) - dil_objs[obj1,:,:] = np.logical_or(dil1, dil2) - dil_objs[obj2,:,:] = np.zeros((wshshape[0], wshshape[1])) - obj_coords[obj1,:] = get_bounding_box(dil_objs[obj1,:,:]) + if top_border_height / top_border_area > 0.99: + orig1 = np.logical_or(orig1, objs[obj2, :, :]) + dil_objs[obj1, :, :] = np.logical_or(dil1, dil2) + dil_objs[obj2, :, :] = np.zeros((wshshape[0], wshshape[1])) + obj_coords[obj1, :] = get_bounding_box(dil_objs[obj1, :, :]) - wshclean = wshclean + orig1*objcounter + wshclean = wshclean + orig1 * objcounter return wshclean @@ -152,15 +154,22 @@ def do_box_overlap(coord1, coord2): """Checks if boxes, determined by their coordinates, overlap. Safety margin of 2 pixels""" return ( - (coord1[0] - 2 < coord2[0] and coord1[1] + 2 > coord2[0] - or coord2[0] - 2 < coord1[0] and coord2[1] + 2 > coord1[0]) - and (coord1[2] - 2 < coord2[2] and coord1[3] + 2 > coord2[2] - or coord2[2] - 2 < coord1[2] and coord2[3] + 2 > coord1[2])) + coord1[0] - 2 < coord2[0] + and coord1[1] + 2 > coord2[0] + or coord2[0] - 2 < coord1[0] + and coord2[1] + 2 > coord1[0] + ) and ( + coord1[2] - 2 < coord2[2] + and coord1[3] + 2 > coord2[2] + or coord2[2] - 2 < coord1[2] + and coord2[3] + 2 > coord1[2] + ) def get_bounding_box(im): """Returns bounding box of object in boolean image""" coords = np.where(im) - return np.array([np.min(coords[0]), np.max(coords[0]), - np.min(coords[1]), np.max(coords[1])]) + return np.array( + [np.min(coords[0]), np.max(coords[0]), np.min(coords[1]), np.max(coords[1])] + ) diff --git a/cellacdc/models/YeaZ/unet/tracking.py b/cellacdc/segmenters/YeaZ/unet/tracking.py similarity index 82% rename from cellacdc/models/YeaZ/unet/tracking.py rename to cellacdc/segmenters/YeaZ/unet/tracking.py index fa3ce2176..aca6ca05f 100755 --- a/cellacdc/models/YeaZ/unet/tracking.py +++ b/cellacdc/segmenters/YeaZ/unet/tracking.py @@ -15,6 +15,7 @@ except ModuleNotFoundError as e: pass + def correspondence(prev, curr, use_scipy=True, use_modified_yeaz=True): """ source: YeaZ modified by Cell-ACDC developers @@ -41,10 +42,11 @@ def correspondence(prev, curr, use_scipy=True, use_modified_yeaz=True): val = newcell newcell += 1 - new[curr==key] = val + new[curr == key] = val return new + def scipy_align(m1, m2, acdc_yeaz=True): """ source: YeaZ modified by Cell-ACDC @@ -67,6 +69,7 @@ def scipy_align(m1, m2, acdc_yeaz=True): d.pop(-1, None) return d + def correspondence_stack(stack, signals=None): """ source: YeaZ @@ -77,15 +80,16 @@ def correspondence_stack(stack, signals=None): corrected_stack[0] = stack[0] for idx in range(len(stack)): try: - curr = stack[idx+1] + curr = stack[idx + 1] prev = corrected_stack[idx] except IndexError: continue - corrected_stack[idx+1] = correspondence(prev, curr) + corrected_stack[idx + 1] = correspondence(prev, curr) if signals is not None: signals.progressBar.emit(1) return corrected_stack + def hungarian_align(m1, m2, acdc_yeaz=True): """ source: YeaZ @@ -109,50 +113,54 @@ def hungarian_align(m1, m2, acdc_yeaz=True): d.pop(-1, None) return d + def cell_to_features(im, c, nsamples=None, time=None): """ source: YeaZ Embeds cell c in image im into feature space """ - coord = np.argwhere(im==c) + coord = np.argwhere(im == c) area = coord.shape[0] if nsamples is not None: samples = np.random.choice(area, min(nsamples, area), replace=False) - sampled = coord[samples,:] + sampled = coord[samples, :] else: sampled = coord com = sampled.mean(axis=0) - return {'cell': c, - 'time': time, - 'sqrtarea': np.sqrt(area), - 'area': area, - 'com_x': com[0], - 'com_y': com[1]} + return { + "cell": c, + "time": time, + "sqrtarea": np.sqrt(area), + "area": area, + "com_x": com[0], + "com_y": com[1], + } + def get_features_acdc(m, t): rp = regionprops(m) features = { - 'cell': [], - 'time': [], - 'sqrtarea': [], - 'area': [], - 'com_x': [], - 'com_y': [] + "cell": [], + "time": [], + "sqrtarea": [], + "area": [], + "com_x": [], + "com_y": [], } for obj in rp: area = obj.area y, x = obj.centroid - features['cell'].append(obj.label) - features['time'].append(t) - features['sqrtarea'].append(sqrt(area)) - features['area'].append(area) - features['com_x'].append(y) - features['com_y'].append(x) + features["cell"].append(obj.label) + features["time"].append(t) + features["sqrtarea"].append(sqrt(area)) + features["area"].append(area) + features["com_x"].append(y) + features["com_y"].append(x) df = pd.DataFrame(features) - return df, dict(enumerate(features['cell'])) + return df, dict(enumerate(features["cell"])) def get_features(m, t): @@ -162,6 +170,7 @@ def get_features(m, t): features = [cell_to_features(m, c, time=t) for c in cells] return pd.DataFrame(features), dict(enumerate(cells)) + def cell_distance(m1, m2, weight_com=3, acdc_yeaz=True): """ source: YeaZ @@ -171,8 +180,8 @@ def cell_distance(m1, m2, weight_com=3, acdc_yeaz=True): make it more important). """ # Modify to compute use more computed features - #cols = ['com_x', 'com_y', 'roundness', 'sqrtarea'] - cols = ['com_x', 'com_y', 'area'] + # cols = ['com_x', 'com_y', 'roundness', 'sqrtarea'] + cols = ["com_x", "com_y", "area"] get_features_func = get_features_acdc if acdc_yeaz else get_features @@ -183,19 +192,18 @@ def cell_distance(m1, m2, weight_com=3, acdc_yeaz=True): # feat1_acdc, ix_to_cell1_acdc = get_features_acdc(m1, 1) # Check if one of matrices doesn't contain cells - if len(feat1)==0 or len(feat2)==0: + if len(feat1) == 0 or len(feat2) == 0: return None, None, None df = pd.concat((feat1, feat2)) df[cols] = scale(df[cols]) # give more importance to center of mass - df[['com_x', 'com_y']] = df[['com_x', 'com_y']] * weight_com + df[["com_x", "com_y"]] = df[["com_x", "com_y"]] * weight_com # pairwise euclidean dist dist = euclidean_distances( - df.loc[df['time']==1][cols], - df.loc[df['time']==2][cols] + df.loc[df["time"] == 1][cols], df.loc[df["time"] == 2][cols] ) return dist, ix_to_cell1, ix_to_cell2 @@ -216,10 +224,10 @@ def make_square(m): source: YeaZ Turns matrix into square matrix, as required by Munkres algorithm """ - r,c = m.shape - if r==c: + r, c = m.shape + if r == c: return m - elif r>c: - return zero_pad(m, (r,r)) + elif r > c: + return zero_pad(m, (r, r)) else: - return zero_pad(m, (c,c)) + return zero_pad(m, (c, c)) diff --git a/cellacdc/segmenters/YeaZ_v2/__init__.py b/cellacdc/segmenters/YeaZ_v2/__init__.py new file mode 100644 index 000000000..94d05b57d --- /dev/null +++ b/cellacdc/segmenters/YeaZ_v2/__init__.py @@ -0,0 +1,50 @@ +import os + +from cellacdc import utils, load + +utils.check_install_yeaz() + +custom_weights_json_filename = "custom_weights_name_filepath.json" + + +def add_model_filepath(name: str, filepath: os.PathLike): + _, model_folderpath = utils.get_model_path("YeaZ_v2", create_temp_dir=False) + custom_weights_json_file = os.path.join( + model_folderpath, custom_weights_json_filename + ) + custom_weights_mapper = {} + if os.path.exists(custom_weights_json_file): + custom_weights_mapper = load.read_json( + custom_weights_json_file, desc="YeaZ_v2 custom weights filepath info" + ) + + custom_weights_mapper[name] = filepath + load.write_json(custom_weights_mapper, custom_weights_json_file) + + +def load_models_filepath(): + values = ["Phase contrast", "Bright-field", "Fission yeast"] + mapper = { + "Phase contrast": "weights_budding_PhC_multilab_0_1", + "Bright-field": "weights_budding_BF_multilab_0_1", + "Fission yeast": "weights_fission_multilab_0_2", + } + _, model_folderpath = utils.get_model_path("YeaZ_v2", create_temp_dir=False) + mapper = { + name: os.path.join(model_folderpath, filename) + for name, filename in mapper.items() + } + + custom_weights_json_file = os.path.join( + model_folderpath, custom_weights_json_filename + ) + if not os.path.exists(custom_weights_json_file): + return values, mapper + + custom_weights_mapper = load.read_json( + custom_weights_json_file, desc="YeaZ_v2 custom weights filepath info" + ) + values.extend(custom_weights_mapper.keys()) + mapper = {**mapper, **custom_weights_mapper} + + return values, mapper diff --git a/cellacdc/models/YeaZ_v2/acdcSegment.py b/cellacdc/segmenters/YeaZ_v2/acdcSegment.py similarity index 80% rename from cellacdc/models/YeaZ_v2/acdcSegment.py rename to cellacdc/segmenters/YeaZ_v2/acdcSegment.py index 35848d64d..208396176 100644 --- a/cellacdc/models/YeaZ_v2/acdcSegment.py +++ b/cellacdc/segmenters/YeaZ_v2/acdcSegment.py @@ -12,61 +12,66 @@ from yeaz.unet import segment as yeaz_segment import yeaz.unet.neural_network as nn -from cellacdc import myutils, printl, load +from cellacdc import utils, printl, load from . import load_models_filepath -class ModelType: + +class ModelType: isWidget = True def __init__(self): from cellacdc import widgets + self.widget = widgets.YeazV2SelectModelNameCombobox( - custom_select_item_text='Select custom weights file...' + custom_select_item_text="Select custom weights file..." ) + class Model: - def __init__(self, model_type: ModelType='Phase contrast'): + def __init__(self, model_type: ModelType = "Phase contrast"): # Initialize model models_name, models_name_filepath_mapper = load_models_filepath() weights_filepath = models_name_filepath_mapper[model_type] - + self.model = UNet() self.model.load_state_dict(torch.load(weights_filepath)) - + if torch.cuda.is_available(): - device = torch.device('cuda') + device = torch.device("cuda") self._is_gpu = True elif torch.backends.mps.is_available(): - device = torch.device('mps') + device = torch.device("mps") self._is_gpu = True else: - device = torch.device('cpu') + device = torch.device("cpu") self._is_gpu = False - + self.device = device self.model = self.model.to(device) - + def _segment_img_3D(self, image, thresh_val=0.0, min_distance=10): # Preprocess image - image = np.array([ - self._preprocess_image(img, warn=i==0).astype(np.float32) - for i, img in enumerate(image) - ]) - + image = np.array( + [ + self._preprocess_image(img, warn=i == 0).astype(np.float32) + for i, img in enumerate(image) + ] + ) + # pad with zeros such that is divisible by 16 (nrow, ncol) = image.shape[-2:] - row_add = 16-nrow%16 - col_add = 16-ncol%16 + row_add = 16 - nrow % 16 + col_add = 16 - ncol % 16 pad_width = ((0, 0), (0, row_add), (0, col_add)) padded = np.pad(image, pad_width) - + padded = torch.from_numpy(padded) if self._is_gpu: padded = padded.to(self.device) - + self.model.eval() - + with torch.no_grad(): # Convert input tensor to PyTorch tensor input_tensor = padded.unsqueeze(1).float() @@ -75,7 +80,7 @@ def _segment_img_3D(self, image, thresh_val=0.0, min_distance=10): # Convert output tensor to NumPy array output_array = output_tensor.cpu().detach().numpy() result = output_array[:, 0, :, :] - + if self._is_gpu: try: gc.collect() @@ -83,37 +88,35 @@ def _segment_img_3D(self, image, thresh_val=0.0, min_distance=10): except Exception as e: pass prediction = result[:, :nrow, :ncol] - + if thresh_val == 0: thresh_val = None - + labels = np.zeros(prediction.shape, dtype=np.uint32) for i, pred in enumerate(prediction): thresh = nn.threshold(pred, th=thresh_val) - lab = yeaz_segment.segment( - thresh, pred, min_distance=min_distance - ) + lab = yeaz_segment.segment(thresh, pred, min_distance=min_distance) labels[i] = lab.astype(np.uint32) return labels - + def _segment_img_2D(self, image, thresh_val=0.0, min_distance=10, warn=True): # Preprocess image image = self._preprocess_image(image, warn=warn).astype(np.float32) - + # pad with zeros such that is divisible by 16 (nrow, ncol) = image.shape - row_add = 16-nrow%16 - col_add = 16-ncol%16 + row_add = 16 - nrow % 16 + col_add = 16 - ncol % 16 pad_width = ((0, row_add), (0, col_add)) padded = np.pad(image, pad_width) - + padded = torch.from_numpy(padded) if self._is_gpu: padded = padded.to(self.device) - + self.model.eval() - + with torch.no_grad(): # Convert input tensor to PyTorch tensor input_tensor = padded.unsqueeze(0).unsqueeze(0).float() @@ -122,7 +125,7 @@ def _segment_img_2D(self, image, thresh_val=0.0, min_distance=10, warn=True): # Convert output tensor to NumPy array output_array = output_tensor.cpu().detach().numpy() result = output_array[0, 0, :, :] - + if self._is_gpu: try: gc.collect() @@ -130,19 +133,17 @@ def _segment_img_2D(self, image, thresh_val=0.0, min_distance=10, warn=True): except Exception as e: pass prediction = result[:nrow, :ncol] - + if thresh_val == 0: thresh_val = None - + thresholded = nn.threshold(prediction, th=thresh_val) - lab = yeaz_segment.segment( - thresholded, prediction, min_distance=min_distance - ) - + lab = yeaz_segment.segment(thresholded, prediction, min_distance=min_distance) + return lab.astype(np.uint32) - + def _preprocess_image(self, image, tqdm_pbar=None, warn=True): - image = myutils.img_to_float(image, warn=warn) + image = utils.img_to_float(image, warn=warn) image = skimage.exposure.equalize_adapthist(image) if tqdm_pbar is not None: tqdm_pbar.emit(1) @@ -150,19 +151,18 @@ def _preprocess_image(self, image, tqdm_pbar=None, warn=True): # def segment3DT( # self, timelapse3D, thresh_val=0.0, min_distance=10, signals=None - # ): + # ): # lab = self._segment_img_3D( # timelapse3D, thresh_val=thresh_val, min_distance=min_distance # ) # return lab - + def segment(self, image, thresh_val=0.0, min_distance=10): if image.ndim == 3: labels = np.zeros(image.shape, dtype=np.uint32) for z, img in enumerate(image): lab = self._segment_img_2D( - img, thresh_val=thresh_val, min_distance=min_distance, - warn=z==0 + img, thresh_val=thresh_val, min_distance=min_distance, warn=z == 0 ) labels[z] = lab else: @@ -171,5 +171,6 @@ def segment(self, image, thresh_val=0.0, min_distance=10): ) return labels + def url_help(): - return 'https://github.com/rahi-lab/YeaZ-GUI' \ No newline at end of file + return "https://github.com/rahi-lab/YeaZ-GUI" diff --git a/cellacdc/models/YeastMate/__init__.py b/cellacdc/segmenters/YeastMate/__init__.py similarity index 55% rename from cellacdc/models/YeastMate/__init__.py rename to cellacdc/segmenters/YeastMate/__init__.py index 138c1f61c..1276f0781 100755 --- a/cellacdc/models/YeastMate/__init__.py +++ b/cellacdc/segmenters/YeastMate/__init__.py @@ -13,23 +13,24 @@ if QCoreApplication.instance() is None: app = QApplication(sys.argv) - win = warnVisualCppRequired(pkg_name='YeastMate') + win = warnVisualCppRequired(pkg_name="YeastMate") win.exec_() if win.cancel: - raise ModuleNotFoundError( - 'User cancelled Visual C++ installation' - ) + raise ModuleNotFoundError("User cancelled Visual C++ installation") - subprocess.check_call( - [sys.executable, '-m', 'pip', 'install', 'Cython'] - ) + subprocess.check_call([sys.executable, "-m", "pip", "install", "Cython"]) # subprocess.check_call( # [sys.executable, '-m', 'pip', 'install', # 'git+https://github.com/philferriere/cocoapi.git#subdirectory=PythonAPI'] # ) subprocess.check_call( - [sys.executable, '-m', 'pip', 'install', - 'git+https://github.com/facebookresearch/detectron2.git@v0.5'] + [ + sys.executable, + "-m", + "pip", + "install", + "git+https://github.com/facebookresearch/detectron2.git@v0.5", + ] ) try: @@ -41,25 +42,29 @@ if QCoreApplication.instance() is None: app = QApplication(sys.argv) - from cellacdc import myutils - cancel = myutils._install_package_msg('YeastMate') + from cellacdc import utils + + cancel = utils._install_package_msg("YeastMate") if cancel: - raise ModuleNotFoundError( - 'User aborted YeastMate installation' - ) + raise ModuleNotFoundError("User aborted YeastMate installation") subprocess.check_call( - [sys.executable, '-m', 'pip', 'install', - 'git+https://github.com/hoerlteam/YeastMate.git'] + [ + sys.executable, + "-m", + "pip", + "install", + "git+https://github.com/hoerlteam/YeastMate.git", + ] ) # YeastMate installs opencv-python which is not functional with PyQt5 on macOS. # Uninstall it, and reinstall opencv-python-headless subprocess.check_call( - [sys.executable, '-m', 'pip', 'uninstall', '-y', 'opencv-python'] + [sys.executable, "-m", "pip", "uninstall", "-y", "opencv-python"] ) subprocess.check_call( - [sys.executable, '-m', 'pip', 'uninstall', '-y', 'opencv-python-headless'] + [sys.executable, "-m", "pip", "uninstall", "-y", "opencv-python-headless"] ) subprocess.check_call( - [sys.executable, '-m', 'pip', 'install', 'opencv-python-headless'] + [sys.executable, "-m", "pip", "install", "opencv-python-headless"] ) diff --git a/cellacdc/models/YeastMate/acdcSegment.py b/cellacdc/segmenters/YeastMate/acdcSegment.py similarity index 65% rename from cellacdc/models/YeastMate/acdcSegment.py rename to cellacdc/segmenters/YeastMate/acdcSegment.py index ec026c81c..41415c524 100755 --- a/cellacdc/models/YeastMate/acdcSegment.py +++ b/cellacdc/segmenters/YeastMate/acdcSegment.py @@ -15,32 +15,31 @@ from cellacdc.core import getBaseCca_df from cellacdc import user_profile_path + class Model: def __init__(self): - model_path = os.path.join(str(user_profile_path), f'acdc-YeastMate') - yaml_path = os.path.join(model_path, 'yeastmate.yaml') - weights_path = os.path.join(model_path, 'yeastmate_weights.pth') + model_path = os.path.join(str(user_profile_path), f"acdc-YeastMate") + yaml_path = os.path.join(model_path, "yeastmate.yaml") + weights_path = os.path.join(model_path, "yeastmate_weights.pth") - self.model = YeastMatePredictor( - yaml_path, - weights_path - ) + self.model = YeastMatePredictor(yaml_path, weights_path) def segment( - self, image, - score_threshold_0=0.9, - score_thresholds_1=0.75, - score_thresholds_2=0.75, - pixel_size=110, - reference_pixel_size=110, - lower_quantile=1.5, - upper_quantile=98.5 - ): + self, + image, + score_threshold_0=0.9, + score_thresholds_1=0.75, + score_thresholds_2=0.75, + pixel_size=110, + reference_pixel_size=110, + lower_quantile=1.5, + upper_quantile=98.5, + ): score_thresholds = { - 0: score_threshold_0, + 0: score_threshold_0, 1: score_thresholds_1, - 2: score_thresholds_2 + 2: score_thresholds_2, } detections, lab = self.model.inference( @@ -49,7 +48,7 @@ def segment( pixel_size=pixel_size, reference_pixel_size=reference_pixel_size, lower_quantile=lower_quantile, - upper_quantile=upper_quantile + upper_quantile=upper_quantile, ) return lab @@ -68,15 +67,15 @@ def predictCcaState(self, image, precomputed_lab): if info is None: continue - obj_class = info.get('class') + obj_class = info.get("class") if len(obj_class) < 2: continue - is_budding = float(obj_class[1])>2 + is_budding = float(obj_class[1]) > 2 if not is_budding: continue - links = info.get('links') + links = info.get("links") if not links: continue @@ -85,7 +84,7 @@ def predictCcaState(self, image, precomputed_lab): if mother_bud_info is None: continue - mother_bud_ids = mother_bud_info.get('links') + mother_bud_ids = mother_bud_info.get("links") if mother_bud_ids is None: continue @@ -110,14 +109,15 @@ def predictCcaState(self, image, precomputed_lab): if budID not in cca_df.index: continue - cca_df.at[mothID, 'relative_ID'] = budID - cca_df.at[mothID, 'cell_cycle_stage'] = 'S' + cca_df.at[mothID, "relative_ID"] = budID + cca_df.at[mothID, "cell_cycle_stage"] = "S" - cca_df.at[budID, 'relative_ID'] = mothID - cca_df.at[budID, 'cell_cycle_stage'] = 'S' - cca_df.at[budID, 'relationship'] = 'bud' - cca_df.at[budID, 'generation_num'] = 0 + cca_df.at[budID, "relative_ID"] = mothID + cca_df.at[budID, "cell_cycle_stage"] = "S" + cca_df.at[budID, "relationship"] = "bud" + cca_df.at[budID, "generation_num"] = 0 return cca_df + def url_help(): - return 'https://github.com/hoerlteam/YeastMate/blob/main/examples/python_detection.ipynb' + return "https://github.com/hoerlteam/YeastMate/blob/main/examples/python_detection.ipynb" diff --git a/cellacdc/segmenters/__init__.py b/cellacdc/segmenters/__init__.py new file mode 100755 index 000000000..be133a6e6 --- /dev/null +++ b/cellacdc/segmenters/__init__.py @@ -0,0 +1 @@ +STARDIST_MODELS = ["2D_versatile_fluo", "2D_versatile_he", "2D_paper_dsb2018"] diff --git a/cellacdc/segmenters/_cellpose_base/__init__.py b/cellacdc/segmenters/_cellpose_base/__init__.py new file mode 100644 index 000000000..cea3daa92 --- /dev/null +++ b/cellacdc/segmenters/_cellpose_base/__init__.py @@ -0,0 +1,5 @@ +min_target_versions_cp = { + "2": "2.3.2", + "3": "3.1.1.2", + "4": "4.0.6", +} diff --git a/cellacdc/models/_cellpose_base/_directML.py b/cellacdc/segmenters/_cellpose_base/_directML.py similarity index 80% rename from cellacdc/models/_cellpose_base/_directML.py rename to cellacdc/segmenters/_cellpose_base/_directML.py index b71e55027..fafde0b0e 100644 --- a/cellacdc/models/_cellpose_base/_directML.py +++ b/cellacdc/segmenters/_cellpose_base/_directML.py @@ -1,20 +1,22 @@ from cellacdc import printl -from cellacdc.myutils import check_install_package +from cellacdc.utils import check_install_package import sys -def init_directML(): + +def init_directML(): success = True try: import torch_directml except ImportError: py_ver = sys.version_info - #check windows + # check windows from cellacdc import is_win + if is_win and py_ver.major == 3 and py_ver.minor < 13: success = check_install_package( - pkg_name = 'torch-directml', - import_pkg_name = 'torch_directml', - pypi_name = 'torch-directml', + pkg_name="torch-directml", + import_pkg_name="torch_directml", + pypi_name="torch-directml", return_outcome=True, ) else: @@ -28,10 +30,11 @@ def init_directML(): success = False return success + def setup_custom_device(model, device): """ Forces the model to use a custom device (e.g., DirectML) for inference. - This is a workaround, and could be handled better in the future. + This is a workaround, and could be handled better in the future. (Ideally when all parameters are set initially) Args: @@ -41,25 +44,25 @@ def setup_custom_device(model, device): Returns: model (cellpose.CellposeModel): Cellpose model with custom device set. """ - if hasattr(model, 'model'): + if hasattr(model, "model"): model = model.model - + model.gpu = True model.device = device model.mkldnn = False - if hasattr(model, 'net'): + if hasattr(model, "net"): model.net.to(device) model.net.mkldnn = False - if hasattr(model, 'cp'): + if hasattr(model, "cp"): model.cp.gpu = True model.cp.device = device model.cp.mkldnn = False - if hasattr(model.cp, 'net'): + if hasattr(model.cp, "net"): model.cp.net.to(device) model.cp.net.mkldnn = False - if hasattr(model, 'sz'): + if hasattr(model, "sz"): model.sz.device = device - + return model @@ -69,14 +72,13 @@ def setup_directML(acdc_cp_model): Args: model (cellpose.CellposeModel|cellpse.Cellpos): Cellpose model. Should work for v2, v3 and custom. - + Returns: model (cellpose.CellposeModel|cellpse.Cellpos): Cellpose model with DirectML set as the device. """ - print( - 'Using DirectML GPU for Cellpose model inference' - ) + print("Using DirectML GPU for Cellpose model inference") import torch_directml + directml_device = torch_directml.device() acdc_cp_model = setup_custom_device(acdc_cp_model, directml_device) - return acdc_cp_model \ No newline at end of file + return acdc_cp_model diff --git a/cellacdc/models/_cellpose_base/acdcSegment.py b/cellacdc/segmenters/_cellpose_base/acdcSegment.py similarity index 69% rename from cellacdc/models/_cellpose_base/acdcSegment.py rename to cellacdc/segmenters/_cellpose_base/acdcSegment.py index e971cde02..910a17f2e 100644 --- a/cellacdc/models/_cellpose_base/acdcSegment.py +++ b/cellacdc/segmenters/_cellpose_base/acdcSegment.py @@ -4,32 +4,35 @@ from typing import Tuple -from cellacdc import printl, myutils, core +from cellacdc import printl, utils, core import inspect + class BackboneOptions: """Options for cellpose backbone""" - values = ['default', "transformer"] + + values = ["default", "transformer"] + class GPUDirectMLGPUCPU: """Options for DirectML GPU acceleration""" - values = ['cpu', 'gpu','directml_gpu'] + + values = ["cpu", "gpu", "directml_gpu"] def cpu_gpu_directml_gpu( - input_string: str, - ): - """Translate input string to cpu, gpu or directml_gpu. - """ + input_string: str, +): + """Translate input string to cpu, gpu or directml_gpu.""" directml_gpu = False gpu = False input_string = input_string.lower() - if input_string == 'cpu': + if input_string == "cpu": pass - elif input_string == 'gpu': + elif input_string == "gpu": gpu = True - elif input_string == 'directml_gpu': + elif input_string == "directml_gpu": directml_gpu = True else: raise ValueError( @@ -38,12 +41,15 @@ def cpu_gpu_directml_gpu( ) return directml_gpu, gpu + class DealWithSecondChannelOptions: """Options available for dealing with second channel""" - values = ['together','separately', 'ignore'] + + values = ["together", "separately", "ignore"] + def check_deal_with_second_channel( - input_string: DealWithSecondChannelOptions, is_rgb: bool + input_string: DealWithSecondChannelOptions, is_rgb: bool ): if input_string not in DealWithSecondChannelOptions.values: raise ValueError( @@ -51,16 +57,16 @@ def check_deal_with_second_channel( f"Expected one of {DealWithSecondChannelOptions.values}." ) input_string = input_string.lower() - seperatly= False + seperatly = False together = False ignore = False if not is_rgb: pass - elif input_string == 'separately': + elif input_string == "separately": seperatly = True - elif input_string == 'together': + elif input_string == "together": together = True - elif input_string == 'ignore': + elif input_string == "ignore": ignore = True else: raise ValueError( @@ -69,29 +75,28 @@ def check_deal_with_second_channel( ) return seperatly, together, ignore + class Model: def __init__( - self, - ): - """Initialize cellpose base model class, which is used in the cellpose versions - """ + self, + ): + """Initialize cellpose base model class, which is used in the cellpose versions""" self.initConstants() - - + def check_model_path_model_type(self, model_path, model_type): - if model_path == 'None' or not model_path: + if model_path == "None" or not model_path: model_path = None - - if model_type == 'None' or not model_type: + + if model_type == "None" or not model_type: model_type = None - + if model_path is not None and model_type is not None: raise TypeError( "You cannot set both `model_type` and `model_path`. " "Please set only one of them." ) - + if model_path is None and model_type is None: raise TypeError( "You must set either `model_type` or `model_path`. " @@ -104,33 +109,35 @@ def initConstants(self, is_rgb=False): self.img_ndim = None self.z_axis = None self.channel_axis = None - self.cp_version = myutils.get_cellpose_major_version() + self.cp_version = utils.get_cellpose_major_version() self._sizemodelnotfound = True self.batch_size = None self.printed_model_params = False - + def setupLogger(self, logger): from cellpose import models + models.models_logger = logger - + def closeLogger(self): from cellpose import models + handlers = models.models_logger.handlers[:] for handler in handlers: handler.close() models.models_logger.removeHandler(handler) - + def _eval(self, image, **kwargs): if self.batch_size is not None: - kwargs['batch_size'] = self.batch_size + kwargs["batch_size"] = self.batch_size if self.cp_version == 4: - del kwargs['channels'] - kwargs['channel_axis'] = self.channel_axis - kwargs['z_axis'] = self.z_axis + del kwargs["channels"] + kwargs["channel_axis"] = self.channel_axis + kwargs["z_axis"] = self.z_axis if self.cp_version == 3: kwargs["channel_axis"] = self.channel_axis kwargs["z_axis"] = self.z_axis - + if not self.printed_model_params: if isinstance(image, list): sample_img = image[0] @@ -143,16 +150,12 @@ def _eval(self, image, **kwargs): print(f"Running model on image shape: {shape}, kwargs: {kwargs}") if self.is_rgb: for i, subarr in enumerate(np.moveaxis(sample_img, -3, 0)): - print(f"Channel {i+1} min: {subarr.min()}, max: {subarr.max()}") + print(f"Channel {i + 1} min: {subarr.min()}, max: {subarr.max()}") else: print(f"Image min: {sample_img.min()}, max: {sample_img.max()}") self.printed_model_params = True - - out, removed_kwargs = myutils.try_kwargs( - self.model.eval, - image, - **kwargs - ) + + out, removed_kwargs = utils.try_kwargs(self.model.eval, image, **kwargs) segm = out[0] if removed_kwargs: print( @@ -170,74 +173,75 @@ def _eval(self, image, **kwargs): "but was removed from eval method." ) return segm - + def second_ch_img_to_stack(self, first_ch_data, second_ch_data): # The 'cyto' model can work with a second channel (e.g., nucleus). # However, it needs to be encoded into one of the RGB channels - # Here we put the first channel in the 'red' channel and the + # Here we put the first channel in the 'red' channel and the # second channel in the 'green' channel. We then pass # `channels = [1,2]` to the segment method rgb_stack = np.zeros((*first_ch_data.shape, 3), dtype=first_ch_data.dtype) - - R_slice = [slice(None)]*(rgb_stack.ndim) + + R_slice = [slice(None)] * (rgb_stack.ndim) R_slice[-1] = 0 R_slice = tuple(R_slice) rgb_stack[R_slice] = first_ch_data - G_slice = [slice(None)]*(rgb_stack.ndim) + G_slice = [slice(None)] * (rgb_stack.ndim) G_slice[-1] = 1 G_slice = tuple(G_slice) rgb_stack[G_slice] = second_ch_data - + self.is_rgb = True return rgb_stack - + def get_zStack_rgb(self, image): if self.img_shape is None: self.img_shape = image.shape if self.img_ndim is None: self.img_ndim = len(self.img_shape) - self.is_rgb = (self.img_shape[-1] == 3 or self.img_shape[-1] == 4) if not self.is_rgb else self.is_rgb + self.is_rgb = ( + (self.img_shape[-1] == 3 or self.img_shape[-1] == 4) + if not self.is_rgb + else self.is_rgb + ) remaining_dims = self.img_ndim if self.is_rgb: remaining_dims -= 1 if self.timelapse: remaining_dims -= 1 - self.isZstack = ( - remaining_dims == 3 - ) + self.isZstack = remaining_dims == 3 return self.isZstack, self.is_rgb - + def get_eval_kwargs( - self, image, - diameter=0.0, - flow_threshold=0.4, - # cellprob_threshold=0.0, - stitch_threshold=0.0, - # min_size=15, - anisotropy=0.0, - # normalize=True, - # resample=True, - segment_3D_volume=False, - # max_size_fraction=0.4, - # flow3D_smooth=0, - # tile_overlap=0.1, - **kwargs - ): - """Get evaluation kwargs for the model.eval method, accurate for v2. - """ + self, + image, + diameter=0.0, + flow_threshold=0.4, + # cellprob_threshold=0.0, + stitch_threshold=0.0, + # min_size=15, + anisotropy=0.0, + # normalize=True, + # resample=True, + segment_3D_volume=False, + # max_size_fraction=0.4, + # flow3D_smooth=0, + # tile_overlap=0.1, + **kwargs, + ): + """Get evaluation kwargs for the model.eval method, accurate for v2.""" if diameter == 0.0 and self._sizemodelnotfound: raise TypeError( - 'Diameter is 0.0 but size model is not found. ' - 'Please set diameter to a non-zero value.' + "Diameter is 0.0 but size model is not found. " + "Please set diameter to a non-zero value." ) - if self.img_shape is None: self.img_shape = image.shape @@ -249,9 +253,9 @@ def get_eval_kwargs( if anisotropy == 0.0 and segment_3D_volume: if not self.printed_model_params: print( - 'Anisotropy is 0.0 but segment_3D_volume is True. ' - 'Please set anisotropy to a non-zero value.' \ - 'For now set to 1.0, assuming isotropic data.' + "Anisotropy is 0.0 but segment_3D_volume is True. " + "Please set anisotropy to a non-zero value." + "For now set to 1.0, assuming isotropic data." ) anisotropy = 1.0 @@ -259,62 +263,58 @@ def get_eval_kwargs( if not self.printed_model_params: print( """Anisotropy is set to 1.0 (assuming isotropic data), - since data is not a z-stack""") + since data is not a z-stack""" + ) anisotropy = 1.0 - + do_3D = segment_3D_volume if not isZstack: stitch_threshold = 0.0 segment_3D_volume = False do_3D = False - + if stitch_threshold > 0: if not self.printed_model_params: - print( - 'Using stiching mode instead of trying to segment 3D volume.' - ) + print("Using stiching mode instead of trying to segment 3D volume.") do_3D = False - + if isZstack and flow_threshold > 0: if not self.printed_model_params: print( - 'Flow threshold is not used for 3D segmentation. ' - 'Setting it to 0.0.' + "Flow threshold is not used for 3D segmentation. Setting it to 0.0." ) flow_threshold = 0.0 - - if flow_threshold==0.0: + + if flow_threshold == 0.0: flow_threshold = None - channels = [0,0] if not is_rgb else [1,2] + channels = [0, 0] if not is_rgb else [1, 2] eval_kwargs = { - 'channels': channels, - 'diameter': diameter, - 'flow_threshold': flow_threshold, + "channels": channels, + "diameter": diameter, + "flow_threshold": flow_threshold, #'cellprob_threshold': cellprob_threshold, - 'stitch_threshold': stitch_threshold, + "stitch_threshold": stitch_threshold, # 'min_size': min_size, # 'normalize': normalize, - 'do_3D': do_3D, - 'anisotropy': anisotropy, + "do_3D": do_3D, + "anisotropy": anisotropy, # 'resample': resample, # 'max_size_fraction': max_size_fraction, # 'flow3D_smooth': flow3D_smooth, # 'tile_overlap': tile_overlap } - if not segment_3D_volume and isZstack and stitch_threshold>0: + if not segment_3D_volume and isZstack and stitch_threshold > 0: raise TypeError( "`stitch_threshold` must be 0 when segmenting slice-by-slice. " "Alternatively, set `segment_3D_volume = True`." ) - + return eval_kwargs, isZstack - def eval_loop( - self, images, segment_3D_volume, init_imgs=True, **eval_kwargs - ): + def eval_loop(self, images, segment_3D_volume, init_imgs=True, **eval_kwargs): """No support for time lapse. This is handles in self._segment3DT_eval Parameters @@ -331,111 +331,123 @@ def eval_loop( Returns ------- np.ndarray - Segmentation masks array. If `segment_3D_volume` is True, - the shape is (Z, Y, X) or (T, Z, Y, X). If `segment_3D_volume` + Segmentation masks array. If `segment_3D_volume` is True, + the shape is (Z, Y, X) or (T, Z, Y, X). If `segment_3D_volume` is False, the shape is (Y, X) or (T, Y, X). """ if self.img_shape is None: self.img_shape = images.shape - if not segment_3D_volume and self.isZstack: # segment on a per slice basis + if not segment_3D_volume and self.isZstack: # segment on a per slice basis if init_imgs: images, z_axis, channel_axis = _initialize_image( - images, self.is_rgb, iter_axis_zstack=0, + images, + self.is_rgb, + iter_axis_zstack=0, isZstack=self.isZstack, ) else: z_axis = self.z_axis channel_axis = self.channel_axis - self.z_axis = None # since we are segmenting slice-by-slice - self.channel_axis = channel_axis - 1 if channel_axis is not None else None # since we iterate over z-axis + self.z_axis = None # since we are segmenting slice-by-slice + self.channel_axis = ( + channel_axis - 1 if channel_axis is not None else None + ) # since we iterate over z-axis if self.channel_axis is None: labels = np.zeros(images.shape, dtype=np.uint32) else: - shape = images.shape[:channel_axis] + images.shape[channel_axis+1:] + shape = images.shape[:channel_axis] + images.shape[channel_axis + 1 :] labels = np.zeros(shape, dtype=np.uint32) for i, z_img in enumerate(images): lab = self._eval(z_img, **eval_kwargs) labels[i] = lab - labels = skimage.measure.label(labels>0) + labels = skimage.measure.label(labels > 0) else: if init_imgs: - images, z_axis, channel_axis = _initialize_image(images, self.is_rgb, - isZstack=self.isZstack, - ) + images, z_axis, channel_axis = _initialize_image( + images, + self.is_rgb, + isZstack=self.isZstack, + ) self.z_axis = z_axis self.channel_axis = channel_axis else: z_axis = self.z_axis channel_axis = self.channel_axis - + labels = self._eval(images, **eval_kwargs) - + return labels - - def segment3DT_eval( - self, images, eval_kwargs, init_imgs=True, **kwargs - ): - if not kwargs['segment_3D_volume'] and self.isZstack: + + def segment3DT_eval(self, images, eval_kwargs, init_imgs=True, **kwargs): + if not kwargs["segment_3D_volume"] and self.isZstack: if init_imgs: - images, z_axis, channel_axis = _initialize_image(images, self.is_rgb, - iter_axis_time=0, - iter_axis_zstack=1, - timelapse=True, - isZstack=self.isZstack, - ) + images, z_axis, channel_axis = _initialize_image( + images, + self.is_rgb, + iter_axis_time=0, + iter_axis_zstack=1, + timelapse=True, + isZstack=self.isZstack, + ) else: z_axis = self.z_axis channel_axis = self.channel_axis - - self.z_axis = z_axis - 2 if z_axis is not None else None # video doesnt count as dim. iterate over time + + self.z_axis = ( + z_axis - 2 if z_axis is not None else None + ) # video doesnt count as dim. iterate over time self.channel_axis = channel_axis - 2 if channel_axis is not None else None - # Passing entire 4D video and segmenting slice-by-slice is + # Passing entire 4D video and segmenting slice-by-slice is # not possible --> iterate each frame and run normal segment if self.channel_axis is None: labels = np.zeros(images.shape, dtype=np.uint32) else: - shape = images.shape[:channel_axis] + images.shape[channel_axis+1:] + shape = images.shape[:channel_axis] + images.shape[channel_axis + 1 :] labels = np.zeros(shape, dtype=np.uint32) for i, img_t in enumerate(images): lab = self.eval_loop( - img_t, segment_3D_volume=False, - init_imgs=False, - **eval_kwargs + img_t, segment_3D_volume=False, init_imgs=False, **eval_kwargs ) labels[i] = lab else: - eval_kwargs['channels'] = [eval_kwargs['channels']]*len(images) + eval_kwargs["channels"] = [eval_kwargs["channels"]] * len(images) if init_imgs: - images, z_axis, channel_axis = _initialize_image(images, self.is_rgb, - iter_axis_time=0, - timelapse=True, - isZstack=self.isZstack, - ) + images, z_axis, channel_axis = _initialize_image( + images, + self.is_rgb, + iter_axis_time=0, + timelapse=True, + isZstack=self.isZstack, + ) else: z_axis = self.z_axis channel_axis = self.channel_axis - self.z_axis = z_axis - 1 if z_axis is not None else None # video doesnt count as dim + self.z_axis = ( + z_axis - 1 if z_axis is not None else None + ) # video doesnt count as dim self.channel_axis = channel_axis - 1 if channel_axis is not None else None - images = [image.astype(np.float32) for image in images] # convert to list + images = [image.astype(np.float32) for image in images] # convert to list labels = np.array(self._eval(images, **eval_kwargs)) return labels - -def _initialize_image(image:np.ndarray, - is_rgb:bool, - # single_img_shape:Tuple[int], - # single_img_ndim:int, - iter_axis_time:int=None, - iter_axis_zstack:int=None, - target_shape:Tuple[int]=None, - timelapse:bool=False, - isZstack:bool=False, - target_axis_iter:Tuple[int]=None, - add_rgb:bool=False, - ): + + +def _initialize_image( + image: np.ndarray, + is_rgb: bool, + # single_img_shape:Tuple[int], + # single_img_ndim:int, + iter_axis_time: int = None, + iter_axis_zstack: int = None, + target_shape: Tuple[int] = None, + timelapse: bool = False, + isZstack: bool = False, + target_axis_iter: Tuple[int] = None, + add_rgb: bool = False, +): """Tries to initialize image for cellpose. You will have to specify the target shape and the axis to iterate over. Target order of dimensions is (Z x nchan x Y x X) or (T x Z x nchan x Y x X) @@ -482,23 +494,39 @@ def _initialize_image(image:np.ndarray, f"Image is {len(true_img_shape)}D with shape {true_img_shape}. " "It was expected to have 4D shape (T x Z x Y x X x nchan)" ) - + z_axis = 1 if add_rgb: - target_shape = (true_img_shape[0], true_img_shape[1], 3, true_img_shape[-2], true_img_shape[-1]) + target_shape = ( + true_img_shape[0], + true_img_shape[1], + 3, + true_img_shape[-2], + true_img_shape[-1], + ) channel_axis = 2 elif is_rgb: - target_shape = (true_img_shape[0], true_img_shape[1], 3, true_img_shape[-3], true_img_shape[-2]) + target_shape = ( + true_img_shape[0], + true_img_shape[1], + 3, + true_img_shape[-3], + true_img_shape[-2], + ) channel_axis = 2 else: - target_shape = (true_img_shape[0], true_img_shape[1], true_img_shape[-2], true_img_shape[-1]) + target_shape = ( + true_img_shape[0], + true_img_shape[1], + true_img_shape[-2], + true_img_shape[-1], + ) channel_axis = None - if iter_axis_time is not None and iter_axis_zstack is not None: iter_axis = [iter_axis_time, iter_axis_zstack] target_axis_iter = [0, 1] - elif iter_axis_time is not None and iter_axis_zstack is None: + elif iter_axis_time is not None and iter_axis_zstack is None: iter_axis = [iter_axis_time] target_axis_iter = [0] elif iter_axis_time is None and iter_axis_zstack is not None: @@ -507,7 +535,7 @@ def _initialize_image(image:np.ndarray, else: iter_axis = None target_axis_iter = None - + elif timelapse and not isZstack: z_axis = None if len(true_img_shape) < 3 or (is_rgb and len(true_img_shape) < 4): @@ -516,10 +544,20 @@ def _initialize_image(image:np.ndarray, "It was expected to have 3D shape (T x Y x X x nchan)" ) if add_rgb: - target_shape = (true_img_shape[0], 3, true_img_shape[-2], true_img_shape[-1]) + target_shape = ( + true_img_shape[0], + 3, + true_img_shape[-2], + true_img_shape[-1], + ) channel_axis = 1 elif is_rgb: - target_shape = (true_img_shape[0], 3, true_img_shape[-3], true_img_shape[-2]) + target_shape = ( + true_img_shape[0], + 3, + true_img_shape[-3], + true_img_shape[-2], + ) channel_axis = 1 else: target_shape = (true_img_shape[0], true_img_shape[-2], true_img_shape[-1]) @@ -531,7 +569,7 @@ def _initialize_image(image:np.ndarray, else: iter_axis = None target_axis_iter = None - + elif not timelapse and isZstack: z_axis = 0 if len(true_img_shape) < 3 or (is_rgb and len(true_img_shape) < 4): @@ -540,10 +578,20 @@ def _initialize_image(image:np.ndarray, "It was expected to have 3D shape (Z x Y x X x nchan)" ) if add_rgb: - target_shape = (true_img_shape[0], 3, true_img_shape[-2], true_img_shape[-1]) + target_shape = ( + true_img_shape[0], + 3, + true_img_shape[-2], + true_img_shape[-1], + ) channel_axis = 1 elif is_rgb: - target_shape = (true_img_shape[0], 3, true_img_shape[-3], true_img_shape[-2]) + target_shape = ( + true_img_shape[0], + 3, + true_img_shape[-3], + true_img_shape[-2], + ) channel_axis = 1 else: target_shape = (true_img_shape[0], true_img_shape[-2], true_img_shape[-1]) @@ -555,7 +603,7 @@ def _initialize_image(image:np.ndarray, else: iter_axis = None target_axis_iter = None - + elif not timelapse and not isZstack: z_axis = None @@ -569,7 +617,7 @@ def _initialize_image(image:np.ndarray, channel_axis = 0 elif is_rgb: target_shape = (3, true_img_shape[-3], true_img_shape[-2]) - channel_axis = 0 + channel_axis = 0 else: target_shape = (true_img_shape[-2], true_img_shape[-1]) channel_axis = None @@ -580,18 +628,19 @@ def _initialize_image(image:np.ndarray, single_img_from_iter_axis = image[tuple(idx)] else: single_img_from_iter_axis = image - + if single_img_from_iter_axis is not None: single_img_shape = single_img_from_iter_axis.shape single_img_ndim = len(single_img_shape) else: single_img_shape = true_img_shape single_img_ndim = len(single_img_shape) - + single_img_isZstack = isZstack if iter_axis_zstack is None else False single_img_timelapse = timelapse if iter_axis_time is None else False - + from cellacdc._core import _initialize_single_image + image = core.apply_func_to_imgs( image, _initialize_single_image, @@ -604,32 +653,33 @@ def _initialize_image(image:np.ndarray, img_shape=single_img_shape, img_ndim=single_img_ndim, timelapse=single_img_timelapse, - add_rgb=add_rgb, ) return image, z_axis, channel_axis + def check_directml_gpu_gpu(model_name, directml_gpu, gpu, ask_install=True): if ask_install: - proceed, available_frameworks_list = myutils.check_gpu_available( - model_name, - use_gpu=(gpu or directml_gpu), - cuda=gpu, - return_available_gpu_type=True + proceed, available_frameworks_list = utils.check_gpu_available( + model_name, + use_gpu=(gpu or directml_gpu), + cuda=gpu, + return_available_gpu_type=True, ) else: proceed = True - available_frameworks_list = ['cuda', 'directML'] + available_frameworks_list = ["cuda", "directML"] - if 'cuda' not in available_frameworks_list: + if "cuda" not in available_frameworks_list: gpu = False - if 'directML' not in available_frameworks_list: + if "directML" not in available_frameworks_list: directml_gpu = False if not proceed: return directml_gpu, gpu, proceed if directml_gpu: - from cellacdc.models._cellpose_base._directML import init_directML + from cellacdc.segmenters._cellpose_base._directML import init_directML + directml_gpu = init_directML() if directml_gpu and gpu: @@ -640,20 +690,24 @@ def check_directml_gpu_gpu(model_name, directml_gpu, gpu, ask_install=True): """ ) gpu = False - + return directml_gpu, gpu, proceed + def setup_gpu_direct_ml(self, directml_gpu, gpu, device): if directml_gpu: - from cellacdc.models._cellpose_base._directML import setup_directML + from cellacdc.segmenters._cellpose_base._directML import setup_directML + setup_directML(self) from cellacdc.core import fix_sparse_directML + fix_sparse_directML() - - if gpu: # sometimes gpu is not properly set up ^^ - from cellacdc.models._cellpose_base._directML import setup_custom_device + + if gpu: # sometimes gpu is not properly set up ^^ + from cellacdc.segmenters._cellpose_base._directML import setup_custom_device + if device is None: device = 0 try: @@ -662,27 +716,28 @@ def setup_gpu_direct_ml(self, directml_gpu, gpu, device): pass if isinstance(device, int): - device = torch.device(f'cuda:{device}') + device = torch.device(f"cuda:{device}") elif isinstance(device, str): device = torch.device(device) - + setup_custom_device(self, device) + def _get_normalize_params( - image, - normalize=False, - rescale_intensity_low_val_perc=0.0, - rescale_intensity_high_val_perc=100.0, - # sharpen=0, - low_percentile=1.0, - high_percentile=99.0, - norm3D=False, - cp_version=4, - tile_norm_blocksize=0, - ): + image, + normalize=False, + rescale_intensity_low_val_perc=0.0, + rescale_intensity_high_val_perc=100.0, + # sharpen=0, + low_percentile=1.0, + high_percentile=99.0, + norm3D=False, + cp_version=4, + tile_norm_blocksize=0, +): if not normalize: return False - + rescale_intensity_low_val_perc = float(rescale_intensity_low_val_perc) rescale_intensity_high_val_perc = float(rescale_intensity_high_val_perc) low_percentile = float(low_percentile) @@ -690,26 +745,26 @@ def _get_normalize_params( normalize_kwargs = {} do_rescale = ( - rescale_intensity_low_val_perc != 0 - or rescale_intensity_high_val_perc != 100.0 + rescale_intensity_low_val_perc != 0 or rescale_intensity_high_val_perc != 100.0 ) if not do_rescale: - normalize_kwargs['lowhigh'] = None + normalize_kwargs["lowhigh"] = None else: - low = image*rescale_intensity_low_val_perc/100 - high = image*rescale_intensity_high_val_perc/100 - normalize_kwargs['lowhigh'] = (low, high) - + low = image * rescale_intensity_low_val_perc / 100 + high = image * rescale_intensity_high_val_perc / 100 + normalize_kwargs["lowhigh"] = (low, high) + # normalize_kwargs['sharpen'] = sharpen - normalize_kwargs['percentile'] = (low_percentile, high_percentile) - normalize_kwargs['norm3D'] = norm3D + normalize_kwargs["percentile"] = (low_percentile, high_percentile) + normalize_kwargs["norm3D"] = norm3D if cp_version == 4: - normalize_kwargs['tile_norm_blocksize'] = tile_norm_blocksize + normalize_kwargs["tile_norm_blocksize"] = tile_norm_blocksize elif cp_version == 3: - normalize_kwargs['tile_norm'] = tile_norm_blocksize - + normalize_kwargs["tile_norm"] = tile_norm_blocksize + return normalize_kwargs + def url_help(): - return 'https://cellpose.readthedocs.io/en/latest/api.html' \ No newline at end of file + return "https://cellpose.readthedocs.io/en/latest/api.html" diff --git a/cellacdc/segmenters/cellpose_v2/__init__.py b/cellacdc/segmenters/cellpose_v2/__init__.py new file mode 100644 index 000000000..848f98ccc --- /dev/null +++ b/cellacdc/segmenters/cellpose_v2/__init__.py @@ -0,0 +1,12 @@ +import cellacdc.utils as utils + +utils.check_install_cellpose(2) + + +class AvailableModelsv2: + from cellpose.models import MODEL_NAMES + + values = MODEL_NAMES + + is_exclusive_with = ["model_path"] + default_exclusive = "Using custom model" diff --git a/cellacdc/models/cellpose_v2/acdcSegment.py b/cellacdc/segmenters/cellpose_v2/acdcSegment.py similarity index 72% rename from cellacdc/models/cellpose_v2/acdcSegment.py rename to cellacdc/segmenters/cellpose_v2/acdcSegment.py index 561a587fb..13412f330 100644 --- a/cellacdc/models/cellpose_v2/acdcSegment.py +++ b/cellacdc/segmenters/cellpose_v2/acdcSegment.py @@ -1,27 +1,27 @@ import os -from cellacdc.models._cellpose_base.acdcSegment import Model as CellposeBaseModel +from cellacdc.segmenters._cellpose_base.acdcSegment import Model as CellposeBaseModel import torch -from cellacdc import myutils +from cellacdc import utils from . import AvailableModelsv2 + class Model(CellposeBaseModel): def __new__(cls, *args, **kwargs): - myutils.check_install_cellpose(2) + utils.check_install_cellpose(2) return super().__new__(cls) - + def __init__( - self, - model_type: AvailableModelsv2='cyto', - model_path: os.PathLike='', - net_avg:bool=False, - gpu:bool=False, - device:torch.device|int='None', - custom_concatenation:bool=False, - custom_style_on:bool=True, - custom_residual_on:bool=True, - custom_diam_mean:float=30.0, - - ): + self, + model_type: AvailableModelsv2 = "cyto", + model_path: os.PathLike = "", + net_avg: bool = False, + gpu: bool = False, + device: torch.device | int = "None", + custom_concatenation: bool = False, + custom_style_on: bool = True, + custom_residual_on: bool = True, + custom_diam_mean: float = 30.0, + ): """Initialize cellpose 2 model Parameters @@ -35,16 +35,16 @@ def __init__( `model_type`. If you want to use a custom model, set this to the path of the model file. Default is None. gpu : bool, optional - If True and PyTorch for your GPU is correctly installed, - denoising and segmentation processes will run on the GPU. + If True and PyTorch for your GPU is correctly installed, + denoising and segmentation processes will run on the GPU. Default is False directml_gpu : bool, optional If True, will attempt to use DirectML for GPU acceleration. Only for v3 and v4. v2 loads the model later, which causes problems. Dont want to edit cellpose code too much... device : torch.device or int or None If not None, this is the device used for running the model - (torch.device('cuda') or torch.device('cpu')). - It overrides `gpu`, recommended if you want to use a specific GPU + (torch.device('cuda') or torch.device('cpu')). + It overrides `gpu`, recommended if you want to use a specific GPU (e.g. torch.device('cuda:1'). Default is None custom_concatenation : bool, optional Only effects custom trained models. See cellpose v2 for more info. @@ -61,20 +61,24 @@ def __init__( """ self.init_successful = False self.initConstants() - model_type, model_path, device = myutils.translateStrNone(model_type, model_path, device) - + model_type, model_path, device = utils.translateStrNone( + model_type, model_path, device + ) + self.check_model_path_model_type( - model_type=model_type, - model_path=model_path, + model_type=model_type, + model_path=model_path, ) - - print(f'Initializing Cellpose v2...') + + print(f"Initializing Cellpose v2...") from cellpose import models + if model_type: try: self.model = models.Cellpose( - gpu=gpu, net_avg=net_avg, + gpu=gpu, + net_avg=net_avg, model_type=model_type, device=device, ) @@ -82,13 +86,17 @@ def __init__( except FileNotFoundError: self._sizemodelnotfound = True self.model = models.CellposeModel( - gpu=gpu, net_avg=net_avg, model_type=model_type, + gpu=gpu, + net_avg=net_avg, + model_type=model_type, device=device, ) elif model_path is not None: self._sizemodelnotfound = True self.model = models.CellposeModel( - gpu=gpu, net_avg=net_avg, device=device, + gpu=gpu, + net_avg=net_avg, + device=device, pretrained_model=model_path, concatenation=custom_concatenation, style_on=custom_style_on, @@ -96,40 +104,40 @@ def __init__( diam_mean=custom_diam_mean, ) self.init_successful = True - + def _get_eval_kwargs_v2( - self, - cellprob_threshold:float=0.0, - min_size:int=15, - normalize:bool=True, - resample:bool=True, - invert:bool=False, - original_kwargs:dict=None, + self, + cellprob_threshold: float = 0.0, + min_size: int = 15, + normalize: bool = True, + resample: bool = True, + invert: bool = False, + original_kwargs: dict = None, ): additional_kwargs = { - 'cellprob_threshold': cellprob_threshold, - 'min_size': min_size, - 'normalize': normalize, - 'resample': resample, - 'invert': invert, + "cellprob_threshold": cellprob_threshold, + "min_size": min_size, + "normalize": normalize, + "resample": resample, + "invert": invert, } original_kwargs.update(additional_kwargs) return original_kwargs - def segment( - self, image, - diameter:float=0.0, - flow_threshold:float=0.4, - cellprob_threshold:float=0.0, - stitch_threshold:float=0.0, - min_size:int=15, - normalize:bool=True, - resample:bool=True, - segment_3D_volume:bool=False, - anisotropy:float=0.0, - invert:bool=False, - ): + self, + image, + diameter: float = 0.0, + flow_threshold: float = 0.4, + cellprob_threshold: float = 0.0, + stitch_threshold: float = 0.0, + min_size: int = 15, + normalize: bool = True, + resample: bool = True, + segment_3D_volume: bool = False, + anisotropy: float = 0.0, + invert: bool = False, + ): """Segment image using cellpose eval Parameters @@ -137,38 +145,38 @@ def segment( image : (Y, X) or (Z, Y, X) numpy.ndarray Input image. Either 2D or 3D z-stack. diameter : float, optional - Average diameter (in pixels) of the obejcts of interest. + Average diameter (in pixels) of the obejcts of interest. Default is 0.0 flow_threshold : float, optional - Flow error threshold (all cells with errors below threshold are + Flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Default is 0.4 cellprob_threshold : float, optional - All pixels with value above threshold will be part of an object. + All pixels with value above threshold will be part of an object. Decrease this value to find more and larger masks. Default is 0.0 stitch_threshold : float, optional - If `stitch_threshold` is greater than 0.0 and `segment_3D_volume` - is True, masks are stitched in 3D to return volume segmentation. + If `stitch_threshold` is greater than 0.0 and `segment_3D_volume` + is True, masks are stitched in 3D to return volume segmentation. Default is 0.0 min_size : int, optional - Minimum number of pixels per mask, you can turn off this filter + Minimum number of pixels per mask, you can turn off this filter with `min_size = -1`. Default is 15 anisotropy : float, optional - For 3D segmentation, optional rescaling factor (e.g. set to 2.0 if + For 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Default is 0.0 normalize : bool, optional - If True, normalize image using the other parameters. + If True, normalize image using the other parameters. Default is True resample : bool, optional - Run dynamics at original image size (will be slower but create + Run dynamics at original image size (will be slower but create more accurate boundaries). Default is True segment_3D_volume : bool, optional - If True and input `image` is a 3D z-stack the entire z-stack - is passed to cellpose model. If False, Cell-ACDC will force one - z-slice at the time. Best results with 3D data are obtained by - passing the entire z-stack, but with a `stitch_threshold` greater - than 0 (e.g., 0.4). This way cellpose will internally segment - slice-by-slice and it will merge the resulting z-slice masks - belonging to the same object. + If True and input `image` is a 3D z-stack the entire z-stack + is passed to cellpose model. If False, Cell-ACDC will force one + z-slice at the time. Best results with 3D data are obtained by + passing the entire z-stack, but with a `stitch_threshold` greater + than 0 (e.g., 0.4). This way cellpose will internally segment + slice-by-slice and it will merge the resulting z-slice masks + belonging to the same object. Default is False invert : bool, optional If True, invert image pixel intensity before running network. @@ -183,7 +191,7 @@ def segment( ------ TypeError `stitch_threshold` must be 0 when segmenting slice-by-slice. - """ + """ self.timelapse = False self.img_shape = image.shape self.img_ndim = len(self.img_shape) @@ -198,7 +206,7 @@ def segment( anisotropy=anisotropy, # normalize=normalize, # resample=resample, - segment_3D_volume=segment_3D_volume + segment_3D_volume=segment_3D_volume, ) eval_kwargs = self._get_eval_kwargs_v2( @@ -207,20 +215,18 @@ def segment( normalize=normalize, resample=resample, invert=invert, - original_kwargs=eval_kwargs + original_kwargs=eval_kwargs, ) labels = self.eval_loop( - image, - segment_3D_volume=segment_3D_volume, - **eval_kwargs + image, segment_3D_volume=segment_3D_volume, **eval_kwargs ) self.img_shape = None self.img_ndim = None return labels - + def segment3DT(self, video_data, signals=None, **kwargs): self.timelapse = True self.img_shape = video_data[0].shape @@ -228,20 +234,16 @@ def segment3DT(self, video_data, signals=None, **kwargs): eval_kwargs, self.isZstack = self.get_eval_kwargs(video_data[0], **kwargs) eval_kwargs = self._get_eval_kwargs_v2( - cellprob_threshold=kwargs['cellprob_threshold'], - min_size=kwargs['min_size'], - normalize=kwargs['normalize'], - resample=kwargs['resample'], - invert=kwargs['invert'], - original_kwargs=eval_kwargs - ) - - - labels = self.segment3DT_eval( - video_data, eval_kwargs, **kwargs + cellprob_threshold=kwargs["cellprob_threshold"], + min_size=kwargs["min_size"], + normalize=kwargs["normalize"], + resample=kwargs["resample"], + invert=kwargs["invert"], + original_kwargs=eval_kwargs, ) + labels = self.segment3DT_eval(video_data, eval_kwargs, **kwargs) self.img_shape = None self.img_ndim = None - return labels \ No newline at end of file + return labels diff --git a/cellacdc/segmenters/cellpose_v3/__init__.py b/cellacdc/segmenters/cellpose_v3/__init__.py new file mode 100644 index 000000000..4b39378a4 --- /dev/null +++ b/cellacdc/segmenters/cellpose_v3/__init__.py @@ -0,0 +1,21 @@ +import cellacdc.utils as utils + +utils.check_install_cellpose(3) + + +class AvailableModelsv3: + from cellpose.models import MODEL_NAMES + + values = MODEL_NAMES + + is_exclusive_with = ["model_path"] + default_exclusive = "Using custom model" + + +class AvailableModelsv3Denoise: + from cellpose.denoise import MODEL_NAMES + + values = MODEL_NAMES + + is_exclusive_with = ["denoise_model_path"] + default_exclusive = "Using custom denoise model" diff --git a/cellacdc/models/cellpose_v3/_denoise.py b/cellacdc/segmenters/cellpose_v3/_denoise.py similarity index 56% rename from cellacdc/models/cellpose_v3/_denoise.py rename to cellacdc/segmenters/cellpose_v3/_denoise.py index 46f74f8aa..33a98b700 100644 --- a/cellacdc/models/cellpose_v3/_denoise.py +++ b/cellacdc/segmenters/cellpose_v3/_denoise.py @@ -5,31 +5,38 @@ from cellpose.denoise import DenoiseModel from . import AvailableModelsv3Denoise import os -from cellacdc import myutils +from cellacdc import utils -from cellacdc.models._cellpose_base.acdcSegment import (_initialize_image, GPUDirectMLGPUCPU, - cpu_gpu_directml_gpu, check_directml_gpu_gpu, - setup_gpu_direct_ml, _get_normalize_params, - DealWithSecondChannelOptions, check_deal_with_second_channel) +from cellacdc.segmenters._cellpose_base.acdcSegment import ( + _initialize_image, + GPUDirectMLGPUCPU, + cpu_gpu_directml_gpu, + check_directml_gpu_gpu, + setup_gpu_direct_ml, + _get_normalize_params, + DealWithSecondChannelOptions, + check_deal_with_second_channel, +) import torch from _types import NotGUIParam import itertools + class CellposeDenoiseModel(DenoiseModel): def __init__( - self, - device_type: GPUDirectMLGPUCPU='cpu', - device: torch.device | int | None = None, - batch_size: int = 8, - denoise_model: AvailableModelsv3Denoise='denoise_cyto3', - deal_with_second_channel: DealWithSecondChannelOptions = 'together', - denoise_model_path: os.PathLike='', - diam_mean: float = 30.0, - denoise_nchan: int = 1, - is_rgb: NotGUIParam = False, - ask_install_gpu: NotGUIParam = True, - ): + self, + device_type: GPUDirectMLGPUCPU = "cpu", + device: torch.device | int | None = None, + batch_size: int = 8, + denoise_model: AvailableModelsv3Denoise = "denoise_cyto3", + deal_with_second_channel: DealWithSecondChannelOptions = "together", + denoise_model_path: os.PathLike = "", + diam_mean: float = 30.0, + denoise_nchan: int = 1, + is_rgb: NotGUIParam = False, + ask_install_gpu: NotGUIParam = True, + ): """Initialize cellpose 3.0 denoising model Parameters @@ -41,8 +48,8 @@ def __init__( - 'directml': Use DirectML for running the model on GPU. device : torch.device or int or None If not None, this is the device used for running the model - (torch.device('cuda') or torch.device('cpu')). - It overrides `gpu`, recommended if you want to use a specific GPU + (torch.device('cuda') or torch.device('cpu')). + It overrides `gpu`, recommended if you want to use a specific GPU (e.g. torch.device('cuda:1'). Default is None batch_size : int, optional Batch size for running the model on GPU. Reduce to decrease memory usage, but it will slow down the processing. @@ -73,7 +80,7 @@ def __init__( Path to a custom cellpose denoise model file. diam_mean : float, optional Mean diameter of objects in the image for denoising during training. - If using a pretrained model, it is recommended to leave it as 0.0, + If using a pretrained model, it is recommended to leave it as 0.0, which will use the default diameter 30 denoise_nchan : int, optional Number of channels in the denoised image. Default is 1. @@ -89,75 +96,82 @@ def __init__( self.printed_model_params = False - self.is_rgb = is_rgb - self.denoise_second_channel_separately, self.denoise_second_channel_together, self.ignore_second_channel = check_deal_with_second_channel( - deal_with_second_channel, is_rgb) - + ( + self.denoise_second_channel_separately, + self.denoise_second_channel_together, + self.ignore_second_channel, + ) = check_deal_with_second_channel(deal_with_second_channel, is_rgb) + self.batch_size = batch_size - denoise_model, denoise_model_path, device = myutils.translateStrNone(denoise_model, denoise_model_path, device) - directml_gpu, gpu = cpu_gpu_directml_gpu( + denoise_model, denoise_model_path, device = utils.translateStrNone( + denoise_model, denoise_model_path, device + ) + directml_gpu, gpu = cpu_gpu_directml_gpu( input_string=device_type, ) - self.nstr = denoise_model.split('_')[-1] if denoise_model else None + self.nstr = denoise_model.split("_")[-1] if denoise_model else None - directml_gpu, gpu, proceed= check_directml_gpu_gpu( - 'cellpose_v3', directml_gpu, gpu, ask_install=ask_install_gpu + directml_gpu, gpu, proceed = check_directml_gpu_gpu( + "cellpose_v3", directml_gpu, gpu, ask_install=ask_install_gpu ) - + if not proceed: return - if denoise_model_path and denoise_model: + if denoise_model_path and denoise_model: raise ValueError( "You can only specify one of 'denoise_model_path' or 'denoise_model'." ) if diam_mean == 0.0: diam_mean = 30 - + if diam_mean != 30 and denoise_model: printl( f"[WARNING] It is recommended not to set 'denoise_diameter' for pretrained models!" ) - super().__init__(gpu=gpu, pretrained_model=denoise_model_path, diam_mean=diam_mean, chan2=self.denoise_second_channel_together, - nchan=denoise_nchan, device=device, model_type=denoise_model) - - setup_gpu_direct_ml( - self, - directml_gpu, - gpu, device) + super().__init__( + gpu=gpu, + pretrained_model=denoise_model_path, + diam_mean=diam_mean, + chan2=self.denoise_second_channel_together, + nchan=denoise_nchan, + device=device, + model_type=denoise_model, + ) + + setup_gpu_direct_ml(self, directml_gpu, gpu, device) - def run( - self, - image: np.ndarray, - diameter:float=0.0, - do_3D:bool=True, - invert:bool=False, - normalize:bool=True, - rescale_intensity_low_val_perc:float=0.0, - rescale_intensity_high_val_perc:float=100.0, - # sharpen:float=0, - low_percentile:float=1.0, - high_percentile:float=99.0, - norm3D:bool=False, - rescale:float=1.0, - tile_overlap:float=0.1, - isZstack:NotGUIParam=False, - timelapse:NotGUIParam=False, - init_image:NotGUIParam=True, - bsize:int=224, - normalize_dict:NotGUIParam=None, - ): + self, + image: np.ndarray, + diameter: float = 0.0, + do_3D: bool = True, + invert: bool = False, + normalize: bool = True, + rescale_intensity_low_val_perc: float = 0.0, + rescale_intensity_high_val_perc: float = 100.0, + # sharpen:float=0, + low_percentile: float = 1.0, + high_percentile: float = 99.0, + norm3D: bool = False, + rescale: float = 1.0, + tile_overlap: float = 0.1, + isZstack: NotGUIParam = False, + timelapse: NotGUIParam = False, + init_image: NotGUIParam = True, + bsize: int = 224, + normalize_dict: NotGUIParam = None, + ): """Run cellpose 3.0 denoise model Parameters ---------- image : numpy.ndarray - (Y, X) or (Z, Y, X) or (C, Y, X) (Z, C, Y, X). If timelapse, the left most dim is expected to be time + (Y, X) or (Z, Y, X) or (C, Y, X) (Z, C, Y, X). If timelapse, the left most dim is expected to be time diameter : float, optional Diameter of expected objects. If 0.0, cellpose will not try to estimate it (as opposed to the segmentation model) Will use 30 for everything except nuclei, which will use 17.0. @@ -170,20 +184,20 @@ def run( normalize : bool, optional If True, normalize image using the other parameters. Default is True rescale_intensity_low_val_perc : float, optional - Rescale intensities so that this is the minimum value in the image. + Rescale intensities so that this is the minimum value in the image. Default is 0.0 rescale_intensity_high_val_perc : float, optional - Rescale intensities so that this is the maximum value in the image. + Rescale intensities so that this is the maximum value in the image. Default is 100.0 # sharpen : int, optional - # Sharpen image with high pass filter, recommended to be 1/4-1/8 + # Sharpen image with high pass filter, recommended to be 1/4-1/8 # diameter of cells in pixels. Default is 0. low_percentile : float, optional Lower percentile for normalizing image. Default is 1.0 high_percentile : float, optional Higher percentile for normalizing image. Default is 99.0 norm3D : bool, optional - Compute normalization across entire z-stack rather than + Compute normalization across entire z-stack rather than plane-by-plane in stitching mode. Default is False rescale : float, optional Rescale image intensities to this value. Defaults to 1.0. Unless edge cases, should left to default None. @@ -205,92 +219,94 @@ def run( rescale = None if diameter == 0: - diameter = 30.0 if self.nstr != 'nuclei' else 17.0 - + diameter = 30.0 if self.nstr != "nuclei" else 17.0 + is_rgb = self.is_rgb if normalize_dict is None: normalize_params = _get_normalize_params( image, - normalize=normalize, - rescale_intensity_low_val_perc=rescale_intensity_low_val_perc, - rescale_intensity_high_val_perc=rescale_intensity_high_val_perc, + normalize=normalize, + rescale_intensity_low_val_perc=rescale_intensity_low_val_perc, + rescale_intensity_high_val_perc=rescale_intensity_high_val_perc, # sharpen=sharpen, - low_percentile=low_percentile, + low_percentile=low_percentile, high_percentile=high_percentile, - norm3D=norm3D + norm3D=norm3D, ) else: normalize_params = normalize_dict - - normalize_dict['invert'] = invert + + normalize_dict["invert"] = invert eval_kwargs = { - 'diameter': diameter, - 'normalize': normalize_params, - 'rescale': rescale, - 'tile_overlap': tile_overlap, - 'do_3D': do_3D, - 'bsize': bsize, + "diameter": diameter, + "normalize": normalize_params, + "rescale": rescale, + "tile_overlap": tile_overlap, + "do_3D": do_3D, + "bsize": bsize, } if self.batch_size is not None: - eval_kwargs['batch_size'] = self.batch_size - + eval_kwargs["batch_size"] = self.batch_size + self.isZstack = isZstack self.denoise_slices_separately = not do_3D and isZstack if self.denoise_second_channel_together: - eval_kwargs['channels'] = self.cellpose_rgb_channel - elif self.denoise_second_channel_separately or self.ignore_second_channel or not self.is_rgb: - eval_kwargs['channels'] = self.cellpose_greyscale_channel + eval_kwargs["channels"] = self.cellpose_rgb_channel + elif ( + self.denoise_second_channel_separately + or self.ignore_second_channel + or not self.is_rgb + ): + eval_kwargs["channels"] = self.cellpose_greyscale_channel else: - raise ValueError( - f"Invalid channels configuration for denoising!" - ) + raise ValueError(f"Invalid channels configuration for denoising!") iter_axis_zstack = None if not self.denoise_slices_separately else 0 - iter_axis_zstack = iter_axis_zstack + 1 if (timelapse and iter_axis_zstack is not None) else iter_axis_zstack + iter_axis_zstack = ( + iter_axis_zstack + 1 + if (timelapse and iter_axis_zstack is not None) + else iter_axis_zstack + ) if init_image: image, z_axis, channel_axis = _initialize_image( - image, - isZstack=isZstack, - is_rgb=is_rgb, + image, + isZstack=isZstack, + is_rgb=is_rgb, timelapse=timelapse, iter_axis_zstack=iter_axis_zstack, iter_axis_time=0 if timelapse else None, add_rgb=False, ) denoised_img = np.zeros_like(image) - # add proper iterations, check wtf is going on wuit + # add proper iterations, check wtf is going on wuit # (Z x nchan x Y x X) - is_model_given_3D = (isZstack and not self.denoise_slices_separately) - eval_kwargs['z_axis'] = 0 if is_model_given_3D else None - eval_kwargs['channel_axis'] = 1 if is_model_given_3D else 0 - if self.denoise_second_channel_separately or not is_rgb or self.ignore_second_channel: - eval_kwargs['channel_axis'] = None + is_model_given_3D = isZstack and not self.denoise_slices_separately + eval_kwargs["z_axis"] = 0 if is_model_given_3D else None + eval_kwargs["channel_axis"] = 1 if is_model_given_3D else 0 + if ( + self.denoise_second_channel_separately + or not is_rgb + or self.ignore_second_channel + ): + eval_kwargs["channel_axis"] = None if timelapse: - pbartime = tqdm( - total=len(image), ncols=100, desc='Denoising time-lapse: ' - ) + pbartime = tqdm(total=len(image), ncols=100, desc="Denoising time-lapse: ") else: - pbartime = tqdm( - total=1, ncols=100, desc='Denoising image: ' - ) + pbartime = tqdm(total=1, ncols=100, desc="Denoising image: ") if timelapse: for t, img in enumerate(image): - denoised_img = self._eval_image( - img, eval_kwargs, denoised_img, t=t - ) + denoised_img = self._eval_image(img, eval_kwargs, denoised_img, t=t) pbartime.update(1) else: - denoised_img = self._eval_image( - image, eval_kwargs, denoised_img - ) + denoised_img = self._eval_image(image, eval_kwargs, denoised_img) pbartime.update(1) - + pbartime.close() return denoised_img - + def _eval_image(self, image, eval_kwargs, entire_denoised_img, t=None): if t is not None: denoised_img = entire_denoised_img[t] @@ -298,63 +314,94 @@ def _eval_image(self, image, eval_kwargs, entire_denoised_img, t=None): denoised_img = entire_denoised_img # for NOT timelapse images helper funciton if self.denoise_slices_separately: - if self.denoise_second_channel_separately: # dont need to move channel axis in output since I iterate over it and put it back correctly - for z, c in tqdm(itertools.product(range(len(image)), self.first_second_channel), # only denoise channels which are requested - desc=f'Denoising z-slicesand channels', ncols=100): + if self.denoise_second_channel_separately: # dont need to move channel axis in output since I iterate over it and put it back correctly + for z, c in tqdm( + itertools.product( + range(len(image)), self.first_second_channel + ), # only denoise channels which are requested + desc=f"Denoising z-slicesand channels", + ncols=100, + ): img = image[z][c] img = self._acdc_eval(img, eval_kwargs) - img = np.squeeze(img) # remove channel axis if it was added + img = np.squeeze(img) # remove channel axis if it was added denoised_img[z, c] = img - elif self.ignore_second_channel: # dont need to move channel axis in output since I iterate over it and put it back correctly - for z, img_z in tqdm(enumerate(image), desc=f'Denoising z-slices: ', ncols=100): + elif self.ignore_second_channel: # dont need to move channel axis in output since I iterate over it and put it back correctly + for z, img_z in tqdm( + enumerate(image), desc=f"Denoising z-slices: ", ncols=100 + ): img = img_z[self.first_second_channel[0]] img = self._acdc_eval(img, eval_kwargs) - img = np.squeeze(img) # remove channel axis if it was added + img = np.squeeze(img) # remove channel axis if it was added denoised_img[z, self.first_second_channel[0]] = img # copy second channel as it is - denoised_img[:, self.first_second_channel[1]] = image[:, self.first_second_channel[1]] - else: # model either gets single gray or RGB image slices, channels was set correctly before - for z, img_z in tqdm(enumerate(image), desc=f'Denoising z-slices: ', ncols=100): - img = self._acdc_eval(img_z, eval_kwargs) # oputputs rgb last... + denoised_img[:, self.first_second_channel[1]] = image[ + :, self.first_second_channel[1] + ] + else: # model either gets single gray or RGB image slices, channels was set correctly before + for z, img_z in tqdm( + enumerate(image), desc=f"Denoising z-slices: ", ncols=100 + ): + img = self._acdc_eval(img_z, eval_kwargs) # oputputs rgb last... if not self.is_rgb: - img = np.squeeze(img) # remove channel axis if it was added + img = np.squeeze(img) # remove channel axis if it was added else: - img = np.moveaxis(img, -1, 0) # move channel axis to the front - img = self._add_rgb_channels(img, isZstack=False) # add rgb channels if needed + img = np.moveaxis(img, -1, 0) # move channel axis to the front + img = self._add_rgb_channels( + img, isZstack=False + ) # add rgb channels if needed denoised_img[z] = img else: - if self.denoise_second_channel_separately: # dont need to move channel axis in output since I iterate over it and put it back correctly + if self.denoise_second_channel_separately: # dont need to move channel axis in output since I iterate over it and put it back correctly if self.isZstack: - image = np.moveaxis(image, 1, 0) # move channel axis to the front + image = np.moveaxis(image, 1, 0) # move channel axis to the front denoised_img = np.moveaxis(denoised_img, 1, 0) - for c in tqdm(self.first_second_channel, desc=f'Denoising channels: ', ncols=100): + for c in tqdm( + self.first_second_channel, desc=f"Denoising channels: ", ncols=100 + ): img = self._acdc_eval(image[c], eval_kwargs) - denoised_img[c] = np.squeeze(img) # remove channel axis if it was added + denoised_img[c] = np.squeeze( + img + ) # remove channel axis if it was added if self.isZstack: denoised_img = np.moveaxis(denoised_img, 0, 1) - elif self.ignore_second_channel: # dont need to move channel axis in output since I iterate over it and put it back correctly + elif self.ignore_second_channel: # dont need to move channel axis in output since I iterate over it and put it back correctly if self.isZstack: - image = np.moveaxis(image, 1, 0) # move channel axis to the front + image = np.moveaxis(image, 1, 0) # move channel axis to the front denoised_img = np.moveaxis(denoised_img, 1, 0) img = self._acdc_eval(image[self.first_second_channel[0]], eval_kwargs) - img = np.squeeze(img) # remove channel axis if it was added - denoised_img[self.first_second_channel[0]] = img # remove channel axis if it was added + img = np.squeeze(img) # remove channel axis if it was added + denoised_img[self.first_second_channel[0]] = ( + img # remove channel axis if it was added + ) # copy second channel as it is - denoised_img[self.first_second_channel[1]] = image[self.first_second_channel[1]] + denoised_img[self.first_second_channel[1]] = image[ + self.first_second_channel[1] + ] if self.isZstack: - denoised_img = np.moveaxis(denoised_img, 0, 1) # move channel axis to the back + denoised_img = np.moveaxis( + denoised_img, 0, 1 + ) # move channel axis to the back else: - denoised_img = self._acdc_eval(image, eval_kwargs) # pass entire iamge, with or without channels. Channels param set before + denoised_img = self._acdc_eval( + image, eval_kwargs + ) # pass entire iamge, with or without channels. Channels param set before if self.is_rgb: if self.isZstack: - denoised_img = np.moveaxis(denoised_img, -1, 1) # move channel axis to the front after z + denoised_img = np.moveaxis( + denoised_img, -1, 1 + ) # move channel axis to the front after z else: - denoised_img = np.moveaxis(denoised_img, -1, 0) # move channel axis to the front + denoised_img = np.moveaxis( + denoised_img, -1, 0 + ) # move channel axis to the front else: - denoised_img = np.squeeze(denoised_img) # remove channel axis if it was added - + denoised_img = np.squeeze( + denoised_img + ) # remove channel axis if it was added + # make sure that no channel is lost and if true, add it back # should not be needed, as entire_denoised_img has right shape denoised_img = self._add_rgb_channels(denoised_img) @@ -365,7 +412,7 @@ def _eval_image(self, image, eval_kwargs, entire_denoised_img, t=None): return entire_denoised_img - def _add_rgb_channels(self, denoised_img:np.ndarray, isZstack=None): + def _add_rgb_channels(self, denoised_img: np.ndarray, isZstack=None): if not self.is_rgb: return denoised_img @@ -376,25 +423,34 @@ def _add_rgb_channels(self, denoised_img:np.ndarray, isZstack=None): if isZstack: channels = denoised_image_shape[1] if channels < 3: - shape_to_concat = (denoised_image_shape[0], 3 - channels, denoised_image_shape[2], denoised_image_shape[3]) + shape_to_concat = ( + denoised_image_shape[0], + 3 - channels, + denoised_image_shape[2], + denoised_image_shape[3], + ) # put it at position 0 denoised_img = np.concatenate( [denoised_img, np.zeros(shape_to_concat, dtype=denoised_img.dtype)], - axis=1 + axis=1, ) else: channels = denoised_image_shape[0] if channels < 3: - shape_to_concat = (3 - channels, denoised_image_shape[1], denoised_image_shape[2]) + shape_to_concat = ( + 3 - channels, + denoised_image_shape[1], + denoised_image_shape[2], + ) # put it at position 0 denoised_img = np.concatenate( [denoised_img, np.zeros(shape_to_concat, dtype=denoised_img.dtype)], - axis=0 + axis=0, ) return denoised_img def _acdc_eval(self, image, eval_kwargs): - if not self.printed_model_params: + if not self.printed_model_params: if isinstance(image, list): shape = image[0].shape shape = f"{len(image)} images of shape {shape}" @@ -404,12 +460,12 @@ def _acdc_eval(self, image, eval_kwargs): print(f"Running denoise on image shape: {shape}, kwargs: {eval_kwargs}") if self.denoise_second_channel_together: for i, subarr in enumerate(np.moveaxis(image, -3, 0)): - print(f"Channel {i+1} min: {subarr.min()}, max: {subarr.max()}") + print(f"Channel {i + 1} min: {subarr.min()}, max: {subarr.max()}") else: print(f"Image min: {image.min()}, max: {image.max()}") self.printed_model_params = True return self.eval(image, **eval_kwargs) - - + + def url_help(): - return 'https://www.biorxiv.org/content/10.1101/2024.02.10.579780v1' \ No newline at end of file + return "https://www.biorxiv.org/content/10.1101/2024.02.10.579780v1" diff --git a/cellacdc/models/cellpose_v3/acdcSegment.py b/cellacdc/segmenters/cellpose_v3/acdcSegment.py similarity index 65% rename from cellacdc/models/cellpose_v3/acdcSegment.py rename to cellacdc/segmenters/cellpose_v3/acdcSegment.py index 66e4d034e..b65c33a24 100644 --- a/cellacdc/models/cellpose_v3/acdcSegment.py +++ b/cellacdc/segmenters/cellpose_v3/acdcSegment.py @@ -1,39 +1,47 @@ import os -from cellacdc import myutils, printl -from cellacdc.models._cellpose_base.acdcSegment import Model as CellposeBaseModel -from cellacdc.models._cellpose_base.acdcSegment import (BackboneOptions, GPUDirectMLGPUCPU, cpu_gpu_directml_gpu, - check_directml_gpu_gpu, setup_gpu_direct_ml, _get_normalize_params, DealWithSecondChannelOptions) +from cellacdc import utils, printl +from cellacdc.segmenters._cellpose_base.acdcSegment import Model as CellposeBaseModel +from cellacdc.segmenters._cellpose_base.acdcSegment import ( + BackboneOptions, + GPUDirectMLGPUCPU, + cpu_gpu_directml_gpu, + check_directml_gpu_gpu, + setup_gpu_direct_ml, + _get_normalize_params, + DealWithSecondChannelOptions, +) import torch from . import AvailableModelsv3, AvailableModelsv3Denoise -from cellacdc.models._cellpose_base.acdcSegment import _initialize_image +from cellacdc.segmenters._cellpose_base.acdcSegment import _initialize_image from cellacdc._types import NotGUIParam import numpy as np + class Model(CellposeBaseModel): def __new__(cls, *args, **kwargs): - myutils.check_install_cellpose(3) + utils.check_install_cellpose(3) return super().__new__(cls) def __init__( - self, - model_type:AvailableModelsv3='cyto3', - model_path: os.PathLike='', - device_type: GPUDirectMLGPUCPU='cpu', - device: torch.device | int | None = None, - batch_size:int=8, - denoise_before_segmentation:bool=False, - denoise_model: AvailableModelsv3Denoise='denoise_cyto3', - denoise_second_channel: DealWithSecondChannelOptions = 'together', - denoise_model_path: os.PathLike='', - denoise_diameter:float=0.0, - denoise_nchan: int = 1, - backbone: BackboneOptions='default', - is_rgb: NotGUIParam = False, # whether the input image will be rgb - ): + self, + model_type: AvailableModelsv3 = "cyto3", + model_path: os.PathLike = "", + device_type: GPUDirectMLGPUCPU = "cpu", + device: torch.device | int | None = None, + batch_size: int = 8, + denoise_before_segmentation: bool = False, + denoise_model: AvailableModelsv3Denoise = "denoise_cyto3", + denoise_second_channel: DealWithSecondChannelOptions = "together", + denoise_model_path: os.PathLike = "", + denoise_diameter: float = 0.0, + denoise_nchan: int = 1, + backbone: BackboneOptions = "default", + is_rgb: NotGUIParam = False, # whether the input image will be rgb + ): """Initialize cellpose 3 model Parameters @@ -53,8 +61,8 @@ def __init__( - 'directml': Use DirectML for running the model on GPU. device : torch.device or int or None If not None, this is the device used for running the model - (torch.device('cuda') or torch.device('cpu')). - It overrides `gpu`, recommended if you want to use a specific GPU + (torch.device('cuda') or torch.device('cpu')). + It overrides `gpu`, recommended if you want to use a specific GPU (e.g. torch.device('cuda:1'). Default is None batch_size : int, optional Batch size for running the model on GPU. Reduce to decrease memory usage, but it will slow down the processing. @@ -87,7 +95,7 @@ def __init__( denoise_diameter : float, optional Mean diameter of objects in the image for denoising (during training?). If left at 0, will pass 30, the cellpose default value. - It is not clear what exactly this parameter does from cellpose documentation, + It is not clear what exactly this parameter does from cellpose documentation, please don't touch it if you dont know too! denoise_nchan : int, optional Number of channels in the denoised image. Default is 1. @@ -99,34 +107,39 @@ def __init__( self.batch_size = batch_size self.init_successful = False - out = myutils.translateStrNone( + out = utils.translateStrNone( model_type, model_path, device, denoise_model, denoise_model_path ) model_type, model_path, device, denoise_model, denoise_model_path = out self.check_model_path_model_type( - model_type=model_type, - model_path=model_path, + model_type=model_type, + model_path=model_path, ) - directml_gpu, gpu = cpu_gpu_directml_gpu( + directml_gpu, gpu = cpu_gpu_directml_gpu( input_string=device_type, ) directml_gpu, gpu, proceed = check_directml_gpu_gpu( - 'cellpose_v3', directml_gpu=directml_gpu, gpu=gpu, + "cellpose_v3", + directml_gpu=directml_gpu, + gpu=gpu, ) if not proceed: return if denoise_before_segmentation and denoise_model: - denoise_model_type = denoise_model.split('_')[-1] if denoise_model else None + denoise_model_type = denoise_model.split("_")[-1] if denoise_model else None if denoise_model_type != model_type: - printl(f'[WARNING] denoise model type {denoise_model_type} does not match ') - - print(f'Initializing Cellpose v3...') + printl( + f"[WARNING] denoise model type {denoise_model_type} does not match " + ) + + print(f"Initializing Cellpose v3...") import cellpose + if model_type: try: self.model = cellpose.models.Cellpose( @@ -135,17 +148,17 @@ def __init__( model_type=model_type, backbone=backbone, ) - self._sizemodelnotfound = False + self._sizemodelnotfound = False except FileNotFoundError: - printl(f'Size model for {model_type} not found.') + printl(f"Size model for {model_type} not found.") self._sizemodelnotfound = True self.model = cellpose.models.CellposeModel( gpu=gpu, device=device, model_type=model_type, backbone=backbone, - ) + ) elif model_path is not None: self._sizemodelnotfound = True self.model = cellpose.models.CellposeModel( @@ -157,134 +170,132 @@ def __init__( self.denoiseModel = None if denoise_before_segmentation: - from cellacdc.models.cellpose_v3 import _denoise + from cellacdc.segmenters.cellpose_v3 import _denoise + self.denoiseModel = _denoise.CellposeDenoiseModel( device_type=device_type, - device=device, denoise_model=denoise_model, + device=device, + denoise_model=denoise_model, denoise_model_path=denoise_model_path, diam_mean=denoise_diameter, - deal_with_second_channel=denoise_second_channel, + deal_with_second_channel=denoise_second_channel, denoise_nchan=denoise_nchan, - batch_size=batch_size, + batch_size=batch_size, is_rgb=self.is_rgb, ask_install_gpu=False, # don't ask to install cellpose if not installed ) - - setup_gpu_direct_ml( - self, - directml_gpu, - gpu, device) - + + setup_gpu_direct_ml(self, directml_gpu, gpu, device) + self.init_successful = True def _get_eval_kwargs_v3( - self, - eval_kwargs: dict, - **kwargs: dict, + self, + eval_kwargs: dict, + **kwargs: dict, ): eval_kwargs_3 = { - 'cellprob_threshold': kwargs['cellprob_threshold'], - 'min_size': kwargs['min_size'], - 'resample': kwargs['resample'], - 'max_size_fraction': kwargs['max_size_fraction'], - 'flow3D_smooth': kwargs['flow3D_smooth'], - 'tile_overlap': kwargs['tile_overlap'], - 'invert': kwargs['invert'], - + "cellprob_threshold": kwargs["cellprob_threshold"], + "min_size": kwargs["min_size"], + "resample": kwargs["resample"], + "max_size_fraction": kwargs["max_size_fraction"], + "flow3D_smooth": kwargs["flow3D_smooth"], + "tile_overlap": kwargs["tile_overlap"], + "invert": kwargs["invert"], } eval_kwargs.update(eval_kwargs_3) return eval_kwargs - def segment( # 2D, 2D x stacks. 2D over time is in segment3DT, 4D is not supported - self, - image, - diameter:float=0.0, - flow_threshold:float=0.4, - cellprob_threshold:float=0.0, - resample:bool=True, - min_size:int=15, - max_size_fraction:float=0.4, - segment_3D_volume:bool=False, - stitch_threshold:float=0.0, - flow3D_smooth:float=0, - anisotropy:float=0.0, - tile_overlap:float=0.1, - invert:bool=False, - normalize:bool=True, - rescale_intensity_low_val_perc:float=0.0, - rescale_intensity_high_val_perc:float=100.0, - # sharpen:int=0, - low_percentile:float=1.0, - high_percentile:float=99.0, - norm3D:bool=False, - tile_norm_blocksize: int=0, - denoise_rescale:float=1.0, - init_imgs:NotGUIParam=True, - bsize:int=224, - ): + def segment( # 2D, 2D x stacks. 2D over time is in segment3DT, 4D is not supported + self, + image, + diameter: float = 0.0, + flow_threshold: float = 0.4, + cellprob_threshold: float = 0.0, + resample: bool = True, + min_size: int = 15, + max_size_fraction: float = 0.4, + segment_3D_volume: bool = False, + stitch_threshold: float = 0.0, + flow3D_smooth: float = 0, + anisotropy: float = 0.0, + tile_overlap: float = 0.1, + invert: bool = False, + normalize: bool = True, + rescale_intensity_low_val_perc: float = 0.0, + rescale_intensity_high_val_perc: float = 100.0, + # sharpen:int=0, + low_percentile: float = 1.0, + high_percentile: float = 99.0, + norm3D: bool = False, + tile_norm_blocksize: int = 0, + denoise_rescale: float = 1.0, + init_imgs: NotGUIParam = True, + bsize: int = 224, + ): """Run cellpose 3.0 denoising + segmentation model Parameters ---------- image : (Y, X) or (Z, Y, X) numpy.ndarray - 2D or 3D image (z-stack). + 2D or 3D image (z-stack). diameter : float, optional - Diameter of expected objects. If 0.0, it uses 30.0 for "one-click" + Diameter of expected objects. If 0.0, it uses 30.0 for "one-click" and 17.0 for "nuclei". Default is 0.0 flow_threshold : float, optional - Flow error threshold (all cells with errors below threshold are + Flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Default is 0.4 cellprob_threshold : float, optional - All pixels with value above threshold will be part of an object. + All pixels with value above threshold will be part of an object. Decrease this value to find more and larger masks. Default is 0.0 resample : bool, optional - Run dynamics at original image size (will be slower but create + Run dynamics at original image size (will be slower but create more accurate boundaries). Default is True min_size : int, optional - Minimum number of pixels per mask, you can turn off this filter + Minimum number of pixels per mask, you can turn off this filter with `min_size = -1`. Default is 15 max_size_fraction : float, optional Masks larger than this fraction of total image size are removed. Default is 0.4. segment_3D_volume : bool, optional - If True and input `image` is a 3D z-stack the entire z-stack - is passed to cellpose model. If False, Cell-ACDC will force one - z-slice at the time. Best results with cellpose and 3D data are - obtained by passing the entire z-stack, but with a - `stitch_threshold` greater than 0 (e.g., 0.4). This way cellpose - will internally segment slice-by-slice and it will merge the - resulting z-slice masks belonging to the same object. + If True and input `image` is a 3D z-stack the entire z-stack + is passed to cellpose model. If False, Cell-ACDC will force one + z-slice at the time. Best results with cellpose and 3D data are + obtained by passing the entire z-stack, but with a + `stitch_threshold` greater than 0 (e.g., 0.4). This way cellpose + will internally segment slice-by-slice and it will merge the + resulting z-slice masks belonging to the same object. Default is False stitch_threshold : float, optional - If `stitch_threshold` is greater than 0.0 and `segment_3D_volume` - is True, masks are stitched in 3D to return volume segmentation. + If `stitch_threshold` is greater than 0.0 and `segment_3D_volume` + is True, masks are stitched in 3D to return volume segmentation. Default is 0.0 anisotropy : float, optional - For 3D segmentation, optional rescaling factor (e.g. set to 2.0 if + For 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Default is 0.0 tile_overlap : float, optional Fraction of overlap of tiles when computing flows. Defaults to 0.1. invert : bool, optional Invert image pixel intensity before running network. Default is False. normalize : bool, optional - If True, normalize image using the other parameters. + If True, normalize image using the other parameters. Default is True rescale_intensity_low_val_perc : float, optional - Rescale intensities so that this is the minimum value in the image. + Rescale intensities so that this is the minimum value in the image. Default is 0.0 rescale_intensity_high_val_perc : float, optional - Rescale intensities so that this is the maximum value in the image. + Rescale intensities so that this is the maximum value in the image. Default is 100.0 # sharpen : int, optional - # Sharpen image with high pass filter, recommended to be 1/4-1/8 + # Sharpen image with high pass filter, recommended to be 1/4-1/8 # diameter of cells in pixels. Default is 0. low_percentile : float, optional Lower percentile for normalizing image. Default is 1.0 high_percentile : float, optional Higher percentile for normalizing image. Default is 99.0 norm3D : bool, optional - Compute normalization across entire z-stack rather than + Compute normalization across entire z-stack rather than plane-by-plane in stitching mode. Default is False tile_norm_blocksize : int, optional Size of the tiles for normalization. Default is 0, which means no tiling. @@ -294,7 +305,7 @@ def segment( # 2D, 2D x stacks. 2D over time is in segment3DT, 4D is not support Default is 224. Don't change it unless you know what you are doing, please! - """ + """ self.timelapse = False image @@ -334,30 +345,35 @@ def segment( # 2D, 2D x stacks. 2D over time is in segment3DT, 4D is not support high_percentile=high_percentile, norm3D=norm3D, normalize=normalize, - tile_norm_blocksize=tile_norm_blocksize - ) + tile_norm_blocksize=tile_norm_blocksize, + ) if self.denoiseModel is not None: self.isZstack, self.is_rgb = self.get_zStack_rgb( - image,) + image, + ) if init_imgs: if not segment_3D_volume and self.isZstack: - image, z_axis, channel_axis = _initialize_image(image, self.is_rgb, - iter_axis_zstack=0, - isZstack=self.isZstack, - ) - self.channel_axis = channel_axis # changing the axis for cellpose is handled in the eval loop + image, z_axis, channel_axis = _initialize_image( + image, + self.is_rgb, + iter_axis_zstack=0, + isZstack=self.isZstack, + ) + self.channel_axis = channel_axis # changing the axis for cellpose is handled in the eval loop self.z_axis = z_axis else: - image, z_axis, channel_axis = _initialize_image(image, self.is_rgb, - isZstack=self.isZstack, - ) + image, z_axis, channel_axis = _initialize_image( + image, + self.is_rgb, + isZstack=self.isZstack, + ) self.z_axis = z_axis self.channel_axis = channel_axis image = self.denoiseModel.run( image, diameter=diameter, - do_3D=eval_kwargs['do_3D'], + do_3D=eval_kwargs["do_3D"], normalize_dict=norm_kwargs, tile_overlap=tile_overlap, timelapse=False, @@ -368,8 +384,9 @@ def segment( # 2D, 2D x stacks. 2D over time is in segment3DT, 4D is not support invert=invert, ) - eval_kwargs['normalize'] = norm_kwargs if self.denoiseModel is None else True # if denoise model was used, just normalise the image with default parameters - + eval_kwargs["normalize"] = ( + norm_kwargs if self.denoiseModel is None else True + ) # if denoise model was used, just normalise the image with default parameters self.img_shape = image.shape self.img_ndim = len(self.img_shape) @@ -379,77 +396,85 @@ def segment( # 2D, 2D x stacks. 2D over time is in segment3DT, 4D is not support image, segment_3D_volume=segment_3D_volume, init_imgs=init_imgs_eval_loop, - **eval_kwargs + **eval_kwargs, ) self.img_shape = None self.img_ndim = None return labels - - def segment3DT(self, video_data, signals=None, init_imgs=True, **kwargs): # just 2D over time + + def segment3DT( + self, video_data, signals=None, init_imgs=True, **kwargs + ): # just 2D over time self.timelapse = True eval_kwargs, self.isZstack = self.get_eval_kwargs(video_data[0], **kwargs) eval_kwargs = self._get_eval_kwargs_v3( eval_kwargs=eval_kwargs, - cellprob_threshold=kwargs['cellprob_threshold'], - min_size=kwargs['min_size'], - resample=kwargs['resample'], - max_size_fraction=kwargs['max_size_fraction'], - flow3D_smooth=kwargs['flow3D_smooth'], - tile_overlap=kwargs['tile_overlap'], - invert=kwargs['invert'], + cellprob_threshold=kwargs["cellprob_threshold"], + min_size=kwargs["min_size"], + resample=kwargs["resample"], + max_size_fraction=kwargs["max_size_fraction"], + flow3D_smooth=kwargs["flow3D_smooth"], + tile_overlap=kwargs["tile_overlap"], + invert=kwargs["invert"], ) norm_kwargs = _get_normalize_params( image=video_data, - normalize=kwargs['normalize'], - rescale_intensity_low_val_perc=kwargs['rescale_intensity_low_val_perc'], - rescale_intensity_high_val_perc=kwargs['rescale_intensity_high_val_perc'], + normalize=kwargs["normalize"], + rescale_intensity_low_val_perc=kwargs["rescale_intensity_low_val_perc"], + rescale_intensity_high_val_perc=kwargs["rescale_intensity_high_val_perc"], # sharpen=kwargs['sharpen'], - low_percentile=kwargs['low_percentile'], - high_percentile=kwargs['high_percentile'], - norm3D=kwargs['norm3D'], + low_percentile=kwargs["low_percentile"], + high_percentile=kwargs["high_percentile"], + norm3D=kwargs["norm3D"], ) if self.denoiseModel is not None: if init_imgs: - if not kwargs['segment_3D_volume'] and self.isZstack: - video_data, z_axis, channel_axis = _initialize_image(video_data, self.is_rgb, - iter_axis_time=0, - iter_axis_zstack=1, - timelapse=True, - isZstack=self.isZstack, - ) - self.z_axis = z_axis # changing of axis is handled in the eval loop - self.channel_axis = channel_axis + if not kwargs["segment_3D_volume"] and self.isZstack: + video_data, z_axis, channel_axis = _initialize_image( + video_data, + self.is_rgb, + iter_axis_time=0, + iter_axis_zstack=1, + timelapse=True, + isZstack=self.isZstack, + ) + self.z_axis = z_axis # changing of axis is handled in the eval loop + self.channel_axis = channel_axis else: - video_data, z_axis, channel_axis = _initialize_image(video_data, self.is_rgb, - iter_axis_time=0, - timelapse=True, - isZstack=self.isZstack, - ) - self.z_axis = z_axis # changing of axis is handled in the eval loop + video_data, z_axis, channel_axis = _initialize_image( + video_data, + self.is_rgb, + iter_axis_time=0, + timelapse=True, + isZstack=self.isZstack, + ) + self.z_axis = z_axis # changing of axis is handled in the eval loop self.channel_axis = channel_axis - + video_data = self.denoiseModel.run( video_data, - diameter=eval_kwargs['diameter'], - do_3D=eval_kwargs['do_3D'], + diameter=eval_kwargs["diameter"], + do_3D=eval_kwargs["do_3D"], normalize_dict=norm_kwargs, - tile_overlap=kwargs['tile_overlap'], + tile_overlap=kwargs["tile_overlap"], timelapse=True, - bsize=kwargs['bsize'], + bsize=kwargs["bsize"], isZstack=self.isZstack, init_image=False, # Denoise model does not need init_imgs, already done - rescale=kwargs['denoise_rescale'], - ) - + rescale=kwargs["denoise_rescale"], + ) + self.img_shape = video_data[0].shape self.img_ndim = len(self.img_shape) - eval_kwargs['normalize'] = norm_kwargs if self.denoiseModel is None else True # if denoise model was used, just normalise the image with default parameters + eval_kwargs["normalize"] = ( + norm_kwargs if self.denoiseModel is None else True + ) # if denoise model was used, just normalise the image with default parameters init_imgs_segment3DT_eval = init_imgs if self.denoiseModel is None else False labels = self.segment3DT_eval( @@ -460,5 +485,6 @@ def segment3DT(self, video_data, signals=None, init_imgs=True, **kwargs): # just self.img_ndim = None return labels + def url_help(): - return 'https://cellpose.readthedocs.io/en/latest/api.html' + return "https://cellpose.readthedocs.io/en/latest/api.html" diff --git a/cellacdc/segmenters/cellpose_v4/__init__.py b/cellacdc/segmenters/cellpose_v4/__init__.py new file mode 100644 index 000000000..90b83d418 --- /dev/null +++ b/cellacdc/segmenters/cellpose_v4/__init__.py @@ -0,0 +1,12 @@ +import cellacdc.utils as utils + +utils.check_install_cellpose(4) + + +class AvailableModelsv4: + from cellpose.models import MODEL_NAMES + + values = MODEL_NAMES + + is_exclusive_with = ["model_path"] + default_exclusive = "Using custom model" diff --git a/cellacdc/models/cellpose_v4/acdcSegment.py b/cellacdc/segmenters/cellpose_v4/acdcSegment.py similarity index 67% rename from cellacdc/models/cellpose_v4/acdcSegment.py rename to cellacdc/segmenters/cellpose_v4/acdcSegment.py index d0eb2e007..2c0291694 100644 --- a/cellacdc/models/cellpose_v4/acdcSegment.py +++ b/cellacdc/segmenters/cellpose_v4/acdcSegment.py @@ -1,28 +1,30 @@ import os -from cellacdc import myutils, printl +from cellacdc import utils, printl import torch -from cellacdc.models._cellpose_base.acdcSegment import (Model as CellposeBaseModel, - GPUDirectMLGPUCPU, - cpu_gpu_directml_gpu, - check_directml_gpu_gpu, - setup_gpu_direct_ml, - _get_normalize_params) +from cellacdc.segmenters._cellpose_base.acdcSegment import ( + Model as CellposeBaseModel, + GPUDirectMLGPUCPU, + cpu_gpu_directml_gpu, + check_directml_gpu_gpu, + setup_gpu_direct_ml, + _get_normalize_params, +) from . import AvailableModelsv4 + class Model(CellposeBaseModel): def __new__(cls, *args, **kwargs): - myutils.check_install_cellpose(4) + utils.check_install_cellpose(4) return super().__new__(cls) - + def __init__( - self, - model_type: AvailableModelsv4='cpsam', - model_path: os.PathLike='', - device_type: GPUDirectMLGPUCPU='cpu', - device:torch.device|int='None', - batch_size:int=8, - - ): + self, + model_type: AvailableModelsv4 = "cpsam", + model_path: os.PathLike = "", + device_type: GPUDirectMLGPUCPU = "cpu", + device: torch.device | int = "None", + batch_size: int = 8, + ): """Initialize Cellpose 4 (Cellpose-SAM) model Parameters @@ -42,114 +44,113 @@ def __init__( - 'directml': Use DirectML for running the model on GPU. device : torch.device or int or None If not None, this is the device used for running the model - (torch.device('cuda') or torch.device('cpu')). - It overrides `gpu`, recommended if you want to use a specific GPU + (torch.device('cuda') or torch.device('cpu')). + It overrides `gpu`, recommended if you want to use a specific GPU (e.g. torch.device('cuda:1'). Default is None batch_size : int, optional Batch size for running the model on GPU. Reduce to decrease memory usage, but it will slow down the processing. Default is 8. - """ + """ self.init_successful = False self.initConstants() self.batch_size = batch_size - model_type, model_path, device = myutils.translateStrNone(model_type, model_path, device) + model_type, model_path, device = utils.translateStrNone( + model_type, model_path, device + ) self.check_model_path_model_type( - model_type=model_type, - model_path=model_path, + model_type=model_type, + model_path=model_path, ) - directml_gpu, gpu = cpu_gpu_directml_gpu( + directml_gpu, gpu = cpu_gpu_directml_gpu( input_string=device_type, ) directml_gpu, gpu, proceed = check_directml_gpu_gpu( - 'cellpose_v4', directml_gpu=directml_gpu, gpu=gpu, + "cellpose_v4", + directml_gpu=directml_gpu, + gpu=gpu, ) if not proceed: return model_path = model_path or model_type - - major_version = myutils.get_cellpose_major_version() - print(f'Initializing Cellpose v{major_version}...') + + major_version = utils.get_cellpose_major_version() + print(f"Initializing Cellpose v{major_version}...") from cellpose import models + self.model = models.CellposeModel( gpu=gpu, device=device, pretrained_model=model_path, - ) - - setup_gpu_direct_ml( - self, - directml_gpu, - gpu, device) - + ) + + setup_gpu_direct_ml(self, directml_gpu, gpu, device) + self.init_successful = True - + def _get_eval_kwargs_v4( - self, - max_size_fraction:float=0.4, - invert:bool=False, - flow3D_smooth:int=0, - niter:int=0, - augment:bool=False, - tile_overlap:float=0.1, - bsize:int=224, - # interp:bool=True, - min_size:int=15, - cellprob_threshold:float=0.0, - prev_kwargs:dict=None, - **kwargs - ): + self, + max_size_fraction: float = 0.4, + invert: bool = False, + flow3D_smooth: int = 0, + niter: int = 0, + augment: bool = False, + tile_overlap: float = 0.1, + bsize: int = 224, + # interp:bool=True, + min_size: int = 15, + cellprob_threshold: float = 0.0, + prev_kwargs: dict = None, + **kwargs, + ): if niter == 0: niter = None prev_kwargs = self._filter_kwargs(**prev_kwargs) - + additional_kwargs = { - 'max_size_fraction': max_size_fraction, - 'invert': invert, - 'flow3D_smooth': flow3D_smooth, - 'niter': niter, - 'augment': augment, - 'tile_overlap': tile_overlap, - 'bsize': bsize, - 'min_size': min_size, - 'cellprob_threshold': cellprob_threshold, + "max_size_fraction": max_size_fraction, + "invert": invert, + "flow3D_smooth": flow3D_smooth, + "niter": niter, + "augment": augment, + "tile_overlap": tile_overlap, + "bsize": bsize, + "min_size": min_size, + "cellprob_threshold": cellprob_threshold, # 'interp': interp } prev_kwargs.update(additional_kwargs) - + return prev_kwargs - def _filter_kwargs( - self, - **kwargs - ): + def _filter_kwargs(self, **kwargs): kwarg_key_list = [ - 'channels', - 'diameter', - 'flow_threshold', - 'stitch_threshold', - 'do_3D', - 'anisotropy', + "channels", + "diameter", + "flow_threshold", + "stitch_threshold", + "do_3D", + "anisotropy", ] for key in list(kwargs.keys()): if key not in kwarg_key_list: del kwargs[key] - + for key in kwarg_key_list: if key not in kwargs: raise KeyError( f"Key '{key}' not found in kwargs. " "Please provide all required keys." ) - + return kwargs # def _filter_kwargs( @@ -169,101 +170,101 @@ def _filter_kwargs( # for key in list(kwargs.keys()): # if key not in kwarg_key_list: # del kwargs[key] - + # for key in kwarg_key_list: # if key not in kwargs: # raise KeyError( # f"Key '{key}' not found in kwargs. " # "Please provide all required keys." # ) - + # return kwargs - + def segment( - self, image, - diameter:float=0.0, - flow_threshold:float=0.4, - cellprob_threshold:float=0.0, - min_size:int=15, - max_size_fraction:float=0.4, - invert:bool=False, - segment_3D_volume:bool=False, - stitch_threshold:float=0.0, - flow3D_smooth:float=0, - anisotropy:float=0.0, - tile_overlap:float=0.1, - normalize:bool=True, - rescale_intensity_low_val_perc:float=0.0, - rescale_intensity_high_val_perc:float=100.0, - # sharpen:int=0, - low_percentile:float=1.0, - high_percentile:float=99.0, - norm3D:bool=False, - tile_norm_blocksize: int=0, - niter:int=0, - augment:bool=False, - bsize:int=256, - # interp:bool=True, - - ): + self, + image, + diameter: float = 0.0, + flow_threshold: float = 0.4, + cellprob_threshold: float = 0.0, + min_size: int = 15, + max_size_fraction: float = 0.4, + invert: bool = False, + segment_3D_volume: bool = False, + stitch_threshold: float = 0.0, + flow3D_smooth: float = 0, + anisotropy: float = 0.0, + tile_overlap: float = 0.1, + normalize: bool = True, + rescale_intensity_low_val_perc: float = 0.0, + rescale_intensity_high_val_perc: float = 100.0, + # sharpen:int=0, + low_percentile: float = 1.0, + high_percentile: float = 99.0, + norm3D: bool = False, + tile_norm_blocksize: int = 0, + niter: int = 0, + augment: bool = False, + bsize: int = 256, + # interp:bool=True, + ): """Segment an image using Cellpose (see details in v2) Parameters ---------- image : (Y, X) or (Z, Y, X) numpy.ndarray - 2D or 3D image (z-stack). + 2D or 3D image (z-stack). diameter : float, optional - Diameter of expected objects. If 0.0, it uses 30.0 for "one-click" + Diameter of expected objects. If 0.0, it uses 30.0 for "one-click" and 17.0 for "nuclei". Default is 0.0 flow_threshold : float, optional - Flow error threshold (all cells with errors below threshold are + Flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Default is 0.4 cellprob_threshold : float, optional - All pixels with value above threshold will be part of an object. + All pixels with value above threshold will be part of an object. Decrease this value to find more and larger masks. Default is 0.0 min_size : int, optional - Minimum number of pixels per mask, you can turn off this filter + Minimum number of pixels per mask, you can turn off this filter with `min_size = -1`. Default is 15 max_size_fraction : float, optional Masks larger than this fraction of total image size are removed. Default is 0.4. invert : bool, optional Invert image pixel intensity before running network. Default is False. segment_3D_volume : bool, optional - If True and input `image` is a 3D z-stack the entire z-stack - is passed to cellpose model. If False, Cell-ACDC will force one - z-slice at the time. Best results with cellpose and 3D data are - obtained by passing the entire z-stack, but with a - `stitch_threshold` greater than 0 (e.g., 0.4). This way cellpose - will internally segment slice-by-slice and it will merge the - resulting z-slice masks belonging to the same object. + If True and input `image` is a 3D z-stack the entire z-stack + is passed to cellpose model. If False, Cell-ACDC will force one + z-slice at the time. Best results with cellpose and 3D data are + obtained by passing the entire z-stack, but with a + `stitch_threshold` greater than 0 (e.g., 0.4). This way cellpose + will internally segment slice-by-slice and it will merge the + resulting z-slice masks belonging to the same object. Default is False stitch_threshold : float, optional - If `stitch_threshold` is greater than 0.0 and `segment_3D_volume` - is True, masks are stitched in 3D to return volume segmentation. + If `stitch_threshold` is greater than 0.0 and `segment_3D_volume` + is True, masks are stitched in 3D to return volume segmentation. Default is 0.0 anisotropy : float, optional - For 3D segmentation, optional rescaling factor (e.g. set to 2.0 if + For 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Default is 0.0 tile_overlap : float, optional Fraction of overlap of tiles when computing flows. Defaults to 0.1. normalize : bool, optional - If True, normalize image using the other parameters. + If True, normalize image using the other parameters. Default is True rescale_intensity_low_val_perc : float, optional - Rescale intensities so that this is the minimum value in the image. + Rescale intensities so that this is the minimum value in the image. Default is 0.0 rescale_intensity_high_val_perc : float, optional - Rescale intensities so that this is the maximum value in the image. + Rescale intensities so that this is the maximum value in the image. Default is 100.0 # sharpen : int, optional - # Sharpen image with high pass filter, recommended to be 1/4-1/8 + # Sharpen image with high pass filter, recommended to be 1/4-1/8 # diameter of cells in pixels. Default is 0. low_percentile : float, optional Lower percentile for normalizing image. Default is 1.0 high_percentile : float, optional Higher percentile for normalizing image. Default is 99.0 norm3D : bool, optional - Compute normalization across entire z-stack rather than + Compute normalization across entire z-stack rather than plane-by-plane in stitching mode. Default is False tile_norm_blocksize : int, optional Size of the tiles for normalization. Default is 0, which means no tiling. @@ -276,7 +277,7 @@ def segment( bsize : int, optional Block size for tiles, recommended to keep at 224 (as in training). Default is 224. - """ + """ self.timelapse = False self.img_shape = image.shape self.img_ndim = len(self.img_shape) @@ -307,7 +308,7 @@ def segment( # interp=interp, min_size=min_size, cellprob_threshold=cellprob_threshold, - prev_kwargs=eval_kwargs + prev_kwargs=eval_kwargs, ) norm_kwargs = _get_normalize_params( @@ -319,15 +320,13 @@ def segment( high_percentile=high_percentile, norm3D=norm3D, normalize=normalize, - tile_norm_blocksize=tile_norm_blocksize - ) - - eval_kwargs['normalize'] = norm_kwargs - - labs = self.eval_loop( - image, segment_3D_volume, **eval_kwargs + tile_norm_blocksize=tile_norm_blocksize, ) + eval_kwargs["normalize"] = norm_kwargs + + labs = self.eval_loop(image, segment_3D_volume, **eval_kwargs) + self.img_shape = None self.img_ndim = None @@ -337,26 +336,19 @@ def segment3DT(self, video_data, signals=None, **kwargs): self.timelapse = True self.img_shape = video_data[0].shape self.img_ndim = len(self.img_shape) - + image = video_data[0] - eval_kwargs, self.isZstack = self.get_eval_kwargs( - image, - **kwargs - ) + eval_kwargs, self.isZstack = self.get_eval_kwargs(image, **kwargs) - eval_kwargs = self._get_eval_kwargs_v4( - **kwargs, - prev_kwargs=eval_kwargs - ) + eval_kwargs = self._get_eval_kwargs_v4(**kwargs, prev_kwargs=eval_kwargs) - labels = self.segment3DT_eval( - video_data, eval_kwargs, **kwargs - ) + labels = self.segment3DT_eval(video_data, eval_kwargs, **kwargs) self.img_shape = None self.img_ndim = None - return labels + return labels + def url_help(): - return 'https://cellpose.readthedocs.io/en/latest/api.html' \ No newline at end of file + return "https://cellpose.readthedocs.io/en/latest/api.html" diff --git a/cellacdc/models/cellsam/__init__.py b/cellacdc/segmenters/cellsam/__init__.py similarity index 62% rename from cellacdc/models/cellsam/__init__.py rename to cellacdc/segmenters/cellsam/__init__.py index 3faed5aad..c40d16f0c 100644 --- a/cellacdc/models/cellsam/__init__.py +++ b/cellacdc/segmenters/cellsam/__init__.py @@ -1,6 +1,6 @@ -from cellacdc import myutils +from cellacdc import utils -myutils.check_install_cellsam() +utils.check_install_cellsam() import cellSAM @@ -8,6 +8,6 @@ # cellsam_general: trained on datasets from the original publication # cellsam_extra: incorporates additional datasets beyond the paper model_types = { - 'General': 'cellsam_general', - 'Extra': 'cellsam_extra', + "General": "cellsam_general", + "Extra": "cellsam_extra", } diff --git a/cellacdc/models/cellsam/acdcSegment.py b/cellacdc/segmenters/cellsam/acdcSegment.py similarity index 89% rename from cellacdc/models/cellsam/acdcSegment.py rename to cellacdc/segmenters/cellsam/acdcSegment.py index a024b9dc2..6d64afd11 100644 --- a/cellacdc/models/cellsam/acdcSegment.py +++ b/cellacdc/segmenters/cellsam/acdcSegment.py @@ -10,7 +10,7 @@ from cellSAM.cellsam_pipeline import cellsam_pipeline, normalize_image from cellSAM.wsi import segment_wsi -from cellacdc import myutils, printl +from cellacdc import utils, printl class AvailableModels: @@ -27,21 +27,21 @@ class Boolean: class Model: def __init__( - self, - model_type: AvailableModels='General', - model_path: os.PathLike='', - bbox_threshold: float=0.4, - low_contrast_enhancement: bool=False, - use_wsi: bool=True, - gauge_cell_size: bool=False, - block_size: int=400, - overlap: int=56, - iou_depth: int=56, - iou_threshold: float=0.5, - postprocess: bool=False, - remove_boundaries: bool=False, - gpu: bool=True - ): + self, + model_type: AvailableModels = "General", + model_path: os.PathLike = "", + bbox_threshold: float = 0.4, + low_contrast_enhancement: bool = False, + use_wsi: bool = True, + gauge_cell_size: bool = False, + block_size: int = 400, + overlap: int = 56, + iou_depth: int = 56, + iou_threshold: float = 0.5, + postprocess: bool = False, + remove_boundaries: bool = False, + gpu: bool = True, + ): """Initialization of CellSAM Model within Cell-ACDC CellSAM is a foundation model for cell segmentation that achieves @@ -100,9 +100,9 @@ def __init__( Whether to use GPU for inference (if available). Default is True """ if gpu and torch.cuda.is_available(): - self.device = 'cuda' + self.device = "cuda" else: - self.device = 'cpu' + self.device = "cpu" self.bbox_threshold = bbox_threshold self.low_contrast_enhancement = low_contrast_enhancement @@ -115,10 +115,10 @@ def __init__( self.postprocess = postprocess self.remove_boundaries = remove_boundaries - model_path = myutils.translateStrNone(model_path)[0] + model_path = utils.translateStrNone(model_path)[0] if model_path: - print(f'Loading CellSAM model from {model_path}...') + print(f"Loading CellSAM model from {model_path}...") self.model = get_local_model(model_path) else: model_name = model_types[model_type] @@ -127,7 +127,7 @@ def __init__( self.model = get_model(model=model_name) except Exception as e: error_msg = str(e).lower() - if 'token' in error_msg or 'auth' in error_msg or '401' in error_msg: + if "token" in error_msg or "auth" in error_msg or "401" in error_msg: raise RuntimeError( f"Failed to download CellSAM model: {e}\n\n" "Hint: CellSAM requires a DeepCell access token. " @@ -139,15 +139,15 @@ def __init__( self.model = self.model.to(self.device) self.model.bbox_threshold = bbox_threshold - print(f'CellSAM model loaded successfully on {self.device}') + print(f"CellSAM model loaded successfully on {self.device}") def segment( - self, - image: np.ndarray, - frame_i: int=0, - automatic_removal_of_background: Boolean=False, - posData: NotParam=None - ) -> np.ndarray: + self, + image: np.ndarray, + frame_i: int = 0, + automatic_removal_of_background: Boolean = False, + posData: NotParam = None, + ) -> np.ndarray: """Segment image using CellSAM Parameters @@ -223,23 +223,28 @@ def _segment_2D_image(self, image: np.ndarray) -> np.ndarray: if self.use_wsi: # Use WSI pipeline for large images or dense cell populations import dask.array as da + img_normalized = normalize_image(img.astype(np.float32)) if self.low_contrast_enhancement: from cellSAM.utils import enhance_low_contrast + img_normalized = enhance_low_contrast(img_normalized) inp = da.from_array(img_normalized, chunks=256) if self.gauge_cell_size: from cellSAM.cellsam_pipeline import use_cellsize_gaging + labels = use_cellsize_gaging( - inp, self.model, self.device, + inp, + self.model, + self.device, block_size=self.block_size, overlap=self.overlap, iou_depth=self.iou_depth, iou_threshold=self.iou_threshold, - bbox_threshold=self.bbox_threshold + bbox_threshold=self.bbox_threshold, ) else: labels = segment_wsi( @@ -251,7 +256,7 @@ def _segment_2D_image(self, image: np.ndarray) -> np.ndarray: normalize=True, model=self.model, device=self.device, - bbox_threshold=self.bbox_threshold + bbox_threshold=self.bbox_threshold, ).compute() else: # Direct segmentation for smaller images @@ -262,7 +267,7 @@ def _segment_2D_image(self, image: np.ndarray) -> np.ndarray: postprocess=self.postprocess, remove_boundaries=self.remove_boundaries, bbox_threshold=self.bbox_threshold, - device=self.device + device=self.device, ) return labels.astype(np.uint32) @@ -304,7 +309,7 @@ def _prepare_image(self, image: np.ndarray) -> np.ndarray: else: # Pad with zeros img = np.zeros((*image.shape[:-1], 3), dtype=image.dtype) - img[..., :image.shape[-1]] = image + img[..., : image.shape[-1]] = image else: raise ValueError(f"Unexpected image shape: {image.shape}") @@ -350,4 +355,4 @@ def _remove_background_from_labels(self, labels: np.ndarray) -> np.ndarray: def url_help(): - return 'https://github.com/vanvalenlab/cellSAM' + return "https://github.com/vanvalenlab/cellSAM" diff --git a/cellacdc/segmenters/delta/__init__.py b/cellacdc/segmenters/delta/__init__.py new file mode 100644 index 000000000..f18c9391d --- /dev/null +++ b/cellacdc/segmenters/delta/__init__.py @@ -0,0 +1,9 @@ +""" +Installs delta2 into acdc. + +@author: jroberts / jamesr787 +""" + +from cellacdc import utils + +utils.check_install_package("delta", pypi_name="delta2") diff --git a/cellacdc/models/delta/acdcSegment.py b/cellacdc/segmenters/delta/acdcSegment.py similarity index 81% rename from cellacdc/models/delta/acdcSegment.py rename to cellacdc/segmenters/delta/acdcSegment.py index 1cec89787..fd6bd09e0 100644 --- a/cellacdc/models/delta/acdcSegment.py +++ b/cellacdc/segmenters/delta/acdcSegment.py @@ -19,9 +19,7 @@ class Model: - - def __init__(self, - model_type='2D or mothermachine'): + def __init__(self, model_type="2D or mothermachine"): """ Configures data, initializes model, loads weights for model. @@ -50,18 +48,21 @@ def __init__(self, except ValueError: # Downloads model weights and configuration files for 2D and mothermachine - download_assets(load_models=True, - load_sets=False, - load_evals=False, - config_level='local') - - def delta_preprocess(self, - image, - target_size: Tuple[int, int] = (256, 32), - order: int = 1, - rangescale: bool = True, - crop: bool = False, - ): + download_assets( + load_models=True, + load_sets=False, + load_evals=False, + config_level="local", + ) + + def delta_preprocess( + self, + image, + target_size: Tuple[int, int] = (256, 32), + order: int = 1, + rangescale: bool = True, + crop: bool = False, + ): """ Takes image and reformat it @@ -104,7 +105,7 @@ def delta_preprocess(self, for j in range(2) ] img = np.zeros((fill_shape[0], fill_shape[1])) - img[0: i.shape[0], 0: i.shape[1]] = i + img[0 : i.shape[0], 0 : i.shape[1]] = i if rangescale: if np.ptp(img) != 0: @@ -136,26 +137,23 @@ def segment(self, image): original_shape = image.shape if image.ndim != 2: - raise ValueError( - f"""Delta only works with 2 dimensional images.""" - ) + raise ValueError(f"""Delta only works with 2 dimensional images.""") # 2D: Cut into overlapping windows - img = self.delta_preprocess(image=image, - target_size=self.target_size, - crop=True) + img = self.delta_preprocess( + image=image, target_size=self.target_size, crop=True + ) # Process image to use for delta - image = self.delta_preprocess(image=image, - target_size=self.target_size, - crop=cfg.crop_windows) + image = self.delta_preprocess( + image=image, target_size=self.target_size, crop=cfg.crop_windows + ) # Change Dimensions to 4D numpy array image = np.reshape(image, (1,) + image.shape + (1,)) # mother machine: Don't crop images into windows if not cfg.crop_windows: - # Predictions: results = self.model.predict(image, verbose=1)[0, :, :, 0] @@ -190,5 +188,6 @@ def segment(self, image): return lab.astype(np.uint32) + def url_help(): - return 'https://gitlab.com/dunloplab/delta' \ No newline at end of file + return "https://gitlab.com/dunloplab/delta" diff --git a/cellacdc/segmenters/omnipose/__init__.py b/cellacdc/segmenters/omnipose/__init__.py new file mode 100644 index 000000000..d9ffcd128 --- /dev/null +++ b/cellacdc/segmenters/omnipose/__init__.py @@ -0,0 +1,7 @@ +import os +import sys +import subprocess + +from cellacdc import utils + +utils.check_install_omnipose() diff --git a/cellacdc/models/omnipose/acdcSegment.py b/cellacdc/segmenters/omnipose/acdcSegment.py similarity index 54% rename from cellacdc/models/omnipose/acdcSegment.py rename to cellacdc/segmenters/omnipose/acdcSegment.py index 36f1fbdc3..a84e7e652 100644 --- a/cellacdc/models/omnipose/acdcSegment.py +++ b/cellacdc/segmenters/omnipose/acdcSegment.py @@ -12,88 +12,85 @@ from omnipose.core import OMNI_MODELS + class AvailableModels: values = OMNI_MODELS + class Model: def __init__( - self, - model_type: AvailableModels='bact_phase_omni', - net_avg=False, - gpu=False - ): + self, model_type: AvailableModels = "bact_phase_omni", net_avg=False, gpu=False + ): if model_type not in OMNI_MODELS: err_msg = ( - f'"{model_type}" not available. ' - f'Available models are {OMNI_MODELS}' + f'"{model_type}" not available. Available models are {OMNI_MODELS}' ) raise NameError(err_msg) - self.model = models.Cellpose( - gpu=gpu, net_avg=net_avg, model_type=model_type - ) - + self.model = models.Cellpose(gpu=gpu, net_avg=net_avg, model_type=model_type) + def _eval(self, image, **kwargs): - kwargs['omni'] = True + kwargs["omni"] = True return self.model.eval(image.astype(np.float32), **kwargs)[0] - + def _initialize_image(self, image): # See cellpose.io._initialize_images if image.ndim == 2: - image = image[np.newaxis,...] - - img_min = image.min() + image = image[np.newaxis, ...] + + img_min = image.min() img_max = image.max() image = image.astype(np.float32) image -= img_min if img_max > img_min + 1e-3: - image /= (img_max - img_min) + image /= img_max - img_min image *= 255 if image.ndim < 4: - image = image[:,:,:,np.newaxis] + image = image[:, :, :, np.newaxis] return image - + def segment( - self, image, - diameter=0.0, - flow_threshold=0.4, - cellprob_threshold=0.0, - stitch_threshold=0.0, - min_size=15, - anisotropy=0.0, - normalize=True, - resample=True, - segment_3D_volume=False - ): + self, + image, + diameter=0.0, + flow_threshold=0.4, + cellprob_threshold=0.0, + stitch_threshold=0.0, + min_size=15, + anisotropy=0.0, + normalize=True, + resample=True, + segment_3D_volume=False, + ): # Preprocess image # image = image/image.max() # image = skimage.filters.gaussian(image, sigma=1) # image = skimage.exposure.equalize_adapthist(image) if anisotropy == 0 or image.ndim == 2: anisotropy = None - + do_3D = segment_3D_volume if image.ndim == 2: stitch_threshold = 0.0 segment_3D_volume = False do_3D = False - + if stitch_threshold > 0: do_3D = False - - if flow_threshold==0.0 or image.ndim==3: + + if flow_threshold == 0.0 or image.ndim == 3: flow_threshold = None eval_kwargs = { - 'channels': [0,0], - 'diameter': diameter, - 'flow_threshold': flow_threshold, - 'cellprob_threshold': cellprob_threshold, - 'stitch_threshold': stitch_threshold, - 'min_size': min_size, - 'normalize': normalize, - 'do_3D': do_3D, - 'anisotropy': anisotropy, - 'resample': resample + "channels": [0, 0], + "diameter": diameter, + "flow_threshold": flow_threshold, + "cellprob_threshold": cellprob_threshold, + "stitch_threshold": stitch_threshold, + "min_size": min_size, + "normalize": normalize, + "do_3D": do_3D, + "anisotropy": anisotropy, + "resample": resample, } # Run cellpose eval @@ -103,11 +100,12 @@ def segment( _img = self._initialize_image(_img) lab = self._eval(_img, **eval_kwargs) labels[i] = lab - labels = skimage.measure.label(labels>0) + labels = skimage.measure.label(labels > 0) else: - image = self._initialize_image(image) + image = self._initialize_image(image) labels = self._eval(image, **eval_kwargs) return labels + def url_help(): - return 'https://omnipose.readthedocs.io/' + return "https://omnipose.readthedocs.io/" diff --git a/cellacdc/segmenters/omnipose_custom/__init__.py b/cellacdc/segmenters/omnipose_custom/__init__.py new file mode 100644 index 000000000..a0b504c6f --- /dev/null +++ b/cellacdc/segmenters/omnipose_custom/__init__.py @@ -0,0 +1,7 @@ +import os +import sys +import subprocess + +from cellacdc import utils + +utils.check_install_package("omnipose_acdc") diff --git a/cellacdc/models/omnipose_custom/acdcSegment.py b/cellacdc/segmenters/omnipose_custom/acdcSegment.py similarity index 61% rename from cellacdc/models/omnipose_custom/acdcSegment.py rename to cellacdc/segmenters/omnipose_custom/acdcSegment.py index 949f435b6..8948c604e 100644 --- a/cellacdc/models/omnipose_custom/acdcSegment.py +++ b/cellacdc/segmenters/omnipose_custom/acdcSegment.py @@ -9,29 +9,31 @@ from cellpose_omni import models -from cellacdc.models.omnipose import acdcSegment as cp_omni +from cellacdc.segmenters.omnipose import acdcSegment as cp_omni from omnipose.core import OMNI_MODELS from cellacdc import printl + class Model: - def __init__(self, model_path: os.PathLike = '', net_avg=False, gpu=False): + def __init__(self, model_path: os.PathLike = "", net_avg=False, gpu=False): self.acdcCellpose = cp_omni.Model() self.acdcCellpose.model = models.CellposeModel( gpu=gpu, net_avg=net_avg, pretrained_model=model_path ) def segment( - self, image, - diameter=0.0, - flow_threshold=0.4, - cellprob_threshold=0.0, - stitch_threshold=0.0, - min_size=15, - anisotropy=0.0, - normalize=True, - resample=True, - segment_3D_volume=False - ): + self, + image, + diameter=0.0, + flow_threshold=0.4, + cellprob_threshold=0.0, + stitch_threshold=0.0, + min_size=15, + anisotropy=0.0, + normalize=True, + resample=True, + segment_3D_volume=False, + ): labels = self.acdcCellpose.segment( image, diameter=diameter, @@ -42,9 +44,10 @@ def segment( anisotropy=anisotropy, normalize=normalize, resample=resample, - segment_3D_volume=segment_3D_volume + segment_3D_volume=segment_3D_volume, ) return labels + def url_help(): - return 'https://omnipose.readthedocs.io/' + return "https://omnipose.readthedocs.io/" diff --git a/cellacdc/segmenters/pomBseen/__init__.py b/cellacdc/segmenters/pomBseen/__init__.py new file mode 100644 index 000000000..6b779a3d1 --- /dev/null +++ b/cellacdc/segmenters/pomBseen/__init__.py @@ -0,0 +1,3 @@ +from cellacdc import utils + +utils.check_install_package("pombseen", pypi_name="pomBseen") diff --git a/cellacdc/models/pomBseen/acdcSegment.py b/cellacdc/segmenters/pomBseen/acdcSegment.py similarity index 65% rename from cellacdc/models/pomBseen/acdcSegment.py rename to cellacdc/segmenters/pomBseen/acdcSegment.py index 186699441..da482f418 100644 --- a/cellacdc/models/pomBseen/acdcSegment.py +++ b/cellacdc/segmenters/pomBseen/acdcSegment.py @@ -1,32 +1,34 @@ from pombseen.main import pomBseg + class Model: def __init__(self): pass + def segment( - self, - image, - offset = -2.5, - connectivity_remove_small_objects_inverse_bw = 1, - connectivity_label = 1, - connectivity_remove_small_objects_binarize = 1, - sharpen_image = False, - radius=1.0, - amount=1.0, - block_size = 15, - min_pix_inverse_bw = 600, - min_pix_inverted_inverse_bw = 600, - min_pix_thresh_binarize = 600, - footprint = 'default', - clear_border_buffer = 2, - clear_border_max_pix = 1200, - convex_filter_slope = 12.8571, - convex_filter_intercept = 12.5, - min_size = 500, - max_size = 100000, - apply_convex_hull = False, - ): - """Segment the input `image` and returns a labelled array with the same + self, + image, + offset=-2.5, + connectivity_remove_small_objects_inverse_bw=1, + connectivity_label=1, + connectivity_remove_small_objects_binarize=1, + sharpen_image=False, + radius=1.0, + amount=1.0, + block_size=15, + min_pix_inverse_bw=600, + min_pix_inverted_inverse_bw=600, + min_pix_thresh_binarize=600, + footprint="default", + clear_border_buffer=2, + clear_border_max_pix=1200, + convex_filter_slope=12.8571, + convex_filter_intercept=12.5, + min_size=500, + max_size=100000, + apply_convex_hull=False, + ): + """Segment the input `image` and returns a labelled array with the same shape as input image (i.e., instance segmentation). Parameters @@ -76,34 +78,35 @@ def segment( ------- _type_ _description_ - """ - if footprint == 'default': + """ + if footprint == "default": footprint = None - + # Make sure block_size is odd if block_size % 2 == 0: block_size += 1 - - segmented_img = pomBseg(image, - sharpen_image, - radius, - amount, - block_size, - offset, - footprint, - min_pix_inverse_bw, - min_pix_inverted_inverse_bw, - min_pix_thresh_binarize, - connectivity_remove_small_objects_inverse_bw, - connectivity_label, - connectivity_remove_small_objects_binarize, - clear_border_buffer, - clear_border_max_pix, - convex_filter_slope, - convex_filter_intercept, - min_size, - max_size, - apply_convex_hull, + + segmented_img = pomBseg( + image, + sharpen_image, + radius, + amount, + block_size, + offset, + footprint, + min_pix_inverse_bw, + min_pix_inverted_inverse_bw, + min_pix_thresh_binarize, + connectivity_remove_small_objects_inverse_bw, + connectivity_label, + connectivity_remove_small_objects_binarize, + clear_border_buffer, + clear_border_max_pix, + convex_filter_slope, + convex_filter_intercept, + min_size, + max_size, + apply_convex_hull, ) - return segmented_img \ No newline at end of file + return segmented_img diff --git a/cellacdc/segmenters/pomBseen_nuclear/__init__.py b/cellacdc/segmenters/pomBseen_nuclear/__init__.py new file mode 100644 index 000000000..6b779a3d1 --- /dev/null +++ b/cellacdc/segmenters/pomBseen_nuclear/__init__.py @@ -0,0 +1,3 @@ +from cellacdc import utils + +utils.check_install_package("pombseen", pypi_name="pomBseen") diff --git a/cellacdc/models/pomBseen_nuclear/acdcSegment.py b/cellacdc/segmenters/pomBseen_nuclear/acdcSegment.py similarity index 69% rename from cellacdc/models/pomBseen_nuclear/acdcSegment.py rename to cellacdc/segmenters/pomBseen_nuclear/acdcSegment.py index a7f9ce9e0..413fd8223 100644 --- a/cellacdc/models/pomBseen_nuclear/acdcSegment.py +++ b/cellacdc/segmenters/pomBseen_nuclear/acdcSegment.py @@ -1,20 +1,21 @@ from pombseen.main import pomBsegNuc + class Model: def __init__(self, segm_data): self.segm_data = segm_data def segment( - self, - image, - connectivity = 1, - offset = 0, - min_size=5, - max_size=100000, - max_nuclei = 2, - rel_size_max = 0.3 - ): - """Segment the input `image` and returns a labelled array with the same + self, + image, + connectivity=1, + offset=0, + min_size=5, + max_size=100000, + max_nuclei=2, + rel_size_max=0.3, + ): + """Segment the input `image` and returns a labelled array with the same shape as input image (i.e., instance segmentation). Parameters @@ -36,11 +37,17 @@ def segment( ------- _type_ Segmented image - """ + """ segmented_img = pomBsegNuc( - image, self.segm_data, connectivity, offset, min_size, max_size, - max_nuclei, rel_size_max + image, + self.segm_data, + connectivity, + offset, + min_size, + max_size, + max_nuclei, + rel_size_max, ) - return segmented_img \ No newline at end of file + return segmented_img diff --git a/cellacdc/segmenters/sam2/__init__.py b/cellacdc/segmenters/sam2/__init__.py new file mode 100644 index 000000000..97fee8a40 --- /dev/null +++ b/cellacdc/segmenters/sam2/__init__.py @@ -0,0 +1,20 @@ +from cellacdc import utils + +utils.check_install_sam2() +import sam2 + +import os +from pathlib import Path + +# Get SAM2 models path +# Using the same pattern as segment_anything +_, sam_segmenters_path = utils.get_model_path("sam2", create_temp_dir=False) + +# SAM2 model configurations +# Format: 'Display Name': ('config_file', 'checkpoint_filename') +model_types = { + "Large": ("configs/sam2.1/sam2.1_hiera_l.yaml", "sam2.1_hiera_large.pt"), + "Base Plus": ("configs/sam2.1/sam2.1_hiera_b+.yaml", "sam2.1_hiera_base_plus.pt"), + "Small": ("configs/sam2.1/sam2.1_hiera_s.yaml", "sam2.1_hiera_small.pt"), + "Tiny": ("configs/sam2.1/sam2.1_hiera_t.yaml", "sam2.1_hiera_tiny.pt"), +} diff --git a/cellacdc/models/sam2/acdcSegment.py b/cellacdc/segmenters/sam2/acdcSegment.py similarity index 79% rename from cellacdc/models/sam2/acdcSegment.py rename to cellacdc/segmenters/sam2/acdcSegment.py index d50bcb5d3..bd47c1838 100644 --- a/cellacdc/models/sam2/acdcSegment.py +++ b/cellacdc/segmenters/sam2/acdcSegment.py @@ -8,43 +8,49 @@ import skimage.measure -from . import model_types, sam_models_path +from . import model_types, sam_segmenters_path from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator -from cellacdc import myutils, widgets, printl +from cellacdc import utils, widgets, printl + class AvailableModels: values = list(model_types.keys()) + class DataFrame: not_a_param = True + class NotParam: not_a_param = True + class Boolean: not_a_param = True + class Integer: not_a_param = True + class Model: def __init__( - self, - model_type: AvailableModels='Large', - input_points_path: widgets.CsvFilePathControl='', - input_points_df: DataFrame='None', - points_per_side=32, - pred_iou_thresh=0.8, - stability_score_thresh=0.95, - crop_n_layers=0, - crop_n_points_downscale_factor=1, - min_mask_region_area=0, - gpu=True - ): + self, + model_type: AvailableModels = "Large", + input_points_path: widgets.CsvFilePathControl = "", + input_points_df: DataFrame = "None", + points_per_side=32, + pred_iou_thresh=0.8, + stability_score_thresh=0.95, + crop_n_layers=0, + crop_n_points_downscale_factor=1, + min_mask_region_area=0, + gpu=True, + ): """Initialization of Segment Anything Model 2 within Cell-ACDC Parameters @@ -130,34 +136,34 @@ def __init__( """ if gpu: from cellacdc import is_mac_arm64 + if is_mac_arm64: - device = 'cpu' + device = "cpu" else: - device = 'cuda' + device = "cuda" else: - device = 'cpu' + device = "cpu" - if isinstance(input_points_df, str) and input_points_df=='None': + if isinstance(input_points_df, str) and input_points_df == "None": input_points_df = None - load_points_df = ( - input_points_path - and input_points_df is None - ) + load_points_df = input_points_path and input_points_df is None if load_points_df: input_points_df = pd.read_csv(input_points_path) if input_points_df is not None: - if 'z' in input_points_df.columns: - input_points_df = input_points_df.sort_values(['z', 'id']) + if "z" in input_points_df.columns: + input_points_df = input_points_df.sort_values(["z", "id"]) else: - input_points_df = input_points_df.sort_values('id') + input_points_df = input_points_df.sort_values("id") self._input_points_df = input_points_df config_file, sam_checkpoint = model_types[model_type] - sam_checkpoint = os.path.join(sam_models_path, sam_checkpoint) - sam = build_sam2(config_file=config_file, ckpt_path=sam_checkpoint, device=device) + sam_checkpoint = os.path.join(sam_segmenters_path, sam_checkpoint) + sam = build_sam2( + config_file=config_file, ckpt_path=sam_checkpoint, device=device + ) if input_points_df is None: self.model = SAM2AutomaticMaskGenerator( @@ -175,18 +181,17 @@ def __init__( self._embedded_img = None def segment( - self, - image: np.ndarray, - frame_i: int, - automatic_removal_of_background: bool=True, - input_points_df: DataFrame='None', - posData: NotParam=None, - save_embeddings: Boolean=False, - only_embeddings: Boolean=False, - use_loaded_embeddings: Boolean=False, - start_z_slice: Integer=0 - ) -> np.ndarray: - + self, + image: np.ndarray, + frame_i: int, + automatic_removal_of_background: bool = True, + input_points_df: DataFrame = "None", + posData: NotParam = None, + save_embeddings: Boolean = False, + only_embeddings: Boolean = False, + use_loaded_embeddings: Boolean = False, + start_z_slice: Integer = 0, + ) -> np.ndarray: """Segment image using SAM2 image : ([Z], Y, X, [C]) numpy.ndarray @@ -260,7 +265,7 @@ def segment( self._input_points_df = input_points_df is_rgb_image = image.shape[-1] == 3 or image.shape[-1] == 4 - is_z_stack = (image.ndim==3 and not is_rgb_image) or (image.ndim==4) + is_z_stack = (image.ndim == 3 and not is_rgb_image) or (image.ndim == 4) if is_rgb_image: labels = np.zeros(image.shape[:-1], dtype=np.uint32) else: @@ -268,15 +273,13 @@ def segment( if self._input_points_df is None: df_points = None - elif 'frame_i' in self._input_points_df.columns: - mask = self._input_points_df['frame_i'] == frame_i + elif "frame_i" in self._input_points_df.columns: + mask = self._input_points_df["frame_i"] == frame_i df_points = self._input_points_df[mask] else: df_points = self._input_points_df - input_points, input_labels = self._get_input_points( - is_z_stack, df_points - ) + input_points, input_labels = self._get_input_points(is_z_stack, df_points) if is_z_stack: for z, img in enumerate(image): input_points_z = None @@ -287,40 +290,42 @@ def segment( embeddings_init = False if use_loaded_embeddings: embeddings_init = self._get_img_embeddings( - posData, frame_i=frame_i, z=z+start_z_slice + posData, frame_i=frame_i, z=z + start_z_slice ) if only_embeddings: self._init_embeddings(img) else: lab_2D = self._segment_2D_image( - img, input_points_z, input_labels_z, - embeddings_already_init=embeddings_init + img, + input_points_z, + input_labels_z, + embeddings_already_init=embeddings_init, ) labels[z] = lab_2D if save_embeddings or only_embeddings: posData.storeSamEmbeddings( - self, frame_i=frame_i, z=z+start_z_slice + self, frame_i=frame_i, z=z + start_z_slice ) if automatic_removal_of_background and input_points is None: # For z-stacks, remove background after 3D relabeling labels = self._remove_background_from_labels(labels) - - labels = skimage.measure.label(labels>0).astype(np.uint32) + + labels = skimage.measure.label(labels > 0).astype(np.uint32) else: embeddings_init = False if use_loaded_embeddings: - embeddings_init = self._get_img_embeddings( - posData, frame_i=frame_i - ) + embeddings_init = self._get_img_embeddings(posData, frame_i=frame_i) if only_embeddings: self._init_embeddings(image) else: labels = self._segment_2D_image( - image, input_points, input_labels, + image, + input_points, + input_labels, embeddings_already_init=embeddings_init, - automatic_removal_of_background=automatic_removal_of_background + automatic_removal_of_background=automatic_removal_of_background, ) if save_embeddings or only_embeddings: @@ -345,25 +350,21 @@ def _get_input_points(self, is_z_stack, df_points): if is_z_stack: input_points = defaultdict(dict) input_labels = defaultdict(dict) - neg_input_points_df = ( - df_points[df_points['id'] == 0] - .set_index('z') - ) - for (z, id), sub_df in df_points.groupby(['z', 'id']): + neg_input_points_df = df_points[df_points["id"] == 0].set_index("z") + for (z, id), sub_df in df_points.groupby(["z", "id"]): if id == 0: continue # Concatenate negative points - points_data_z = sub_df[['x', 'y']].to_numpy() + points_data_z = sub_df[["x", "y"]].to_numpy() points_labels_z = np.ones(len(sub_df), dtype=int) # 1 = positive try: - neg_points_data_z = ( - neg_input_points_df.loc[z][['x', 'y']].to_numpy()) - points_data_z = np.row_stack(( - neg_points_data_z, points_data_z - )) + neg_points_data_z = neg_input_points_df.loc[z][ + ["x", "y"] + ].to_numpy() + points_data_z = np.row_stack((neg_points_data_z, points_data_z)) points_labels_z = np.concatenate( - ([0]*len(neg_points_data_z), points_labels_z) + ([0] * len(neg_points_data_z), points_labels_z) ) except IndexError: pass @@ -373,23 +374,19 @@ def _get_input_points(self, is_z_stack, df_points): else: input_points = {} input_labels = {} - neg_input_points_df = ( - df_points[df_points['id'] == 0] - ) - neg_input_points_data = neg_input_points_df[['x', 'y']].to_numpy() - for id, df_id in df_points.groupby('id'): + neg_input_points_df = df_points[df_points["id"] == 0] + neg_input_points_data = neg_input_points_df[["x", "y"]].to_numpy() + for id, df_id in df_points.groupby("id"): if id == 0: continue - points_data_id = df_id[['x', 'y']].to_numpy() - points_data_id = np.row_stack(( - neg_input_points_data, points_data_id - )) + points_data_id = df_id[["x", "y"]].to_numpy() + points_data_id = np.row_stack((neg_input_points_data, points_data_id)) # Use 1 for positive labels (not actual IDs) - SAM expects binary 0/1 points_labels_id = np.ones(len(df_id), dtype=int) points_labels_id = np.concatenate( - ([0]*len(neg_input_points_data), points_labels_id) + ([0] * len(neg_input_points_data), points_labels_id) ) input_points[id] = points_data_id input_labels[id] = points_labels_id @@ -398,7 +395,7 @@ def _get_input_points(self, is_z_stack, df_points): def _init_embeddings(self, img_rgb): if img_rgb.ndim == 2: - img_rgb = myutils.to_uint8(img_rgb) + img_rgb = utils.to_uint8(img_rgb) img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_BGR2RGB) # Create embeddings only if new image @@ -407,7 +404,7 @@ def _init_embeddings(self, img_rgb): except Exception as err: init_embeddings = True - if hasattr(self.model, 'predictor'): + if hasattr(self.model, "predictor"): predictor = self.model.predictor else: predictor = self.model @@ -417,14 +414,15 @@ def _init_embeddings(self, img_rgb): self._embedded_img = img_rgb def _segment_2D_image( - self, image: np.ndarray, - input_points: np.ndarray, - input_labels: np.ndarray, - embeddings_already_init: bool=False, - automatic_removal_of_background: bool=False - ) -> np.ndarray: - - img = myutils.to_uint8(image) + self, + image: np.ndarray, + input_points: np.ndarray, + input_labels: np.ndarray, + embeddings_already_init: bool = False, + automatic_removal_of_background: bool = False, + ) -> np.ndarray: + + img = utils.to_uint8(image) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) labels = np.zeros(image.shape[:2], dtype=np.uint32) @@ -439,10 +437,10 @@ def _segment_2D_image( masks = [m for i, m in enumerate(masks) if i != bg_idx] # Sort by area descending so smaller masks overwrite larger ones - masks = sorted(masks, key=lambda m: m['area'], reverse=True) + masks = sorted(masks, key=lambda m: m["area"], reverse=True) for id, mask in enumerate(masks): - obj_image = mask['segmentation'] - labels[obj_image] = id+1 + obj_image = mask["segmentation"] + labels[obj_image] = id + 1 return labels @@ -456,7 +454,7 @@ def _segment_2D_image( for id, point_coords in input_points.items(): point_labels = input_labels[id] - multimask_output = len(point_coords)==1 + multimask_output = len(point_coords) == 1 masks, scores, logits = self.model.predict( point_coords=point_coords, point_labels=point_labels, @@ -470,9 +468,7 @@ def _segment_2D_image( labels[mask] = id return labels - def _find_background_mask_index( - self, masks: list, shape: tuple - ) -> int | None: + def _find_background_mask_index(self, masks: list, shape: tuple) -> int | None: """Find the mask with the most pixels touching the image border.""" if not masks: return None @@ -484,7 +480,7 @@ def _find_background_mask_index( max_border_pixels = 0 bg_idx = None for i, mask in enumerate(masks): - segmentation = mask['segmentation'] + segmentation = mask["segmentation"] border_pixels = np.sum(segmentation & border_mask) if border_pixels > max_border_pixels: max_border_pixels = border_pixels @@ -505,4 +501,4 @@ def _remove_background_from_labels(self, labels: np.ndarray) -> np.ndarray: def url_help(): - return 'https://github.com/facebookresearch/segment-anything-2' + return "https://github.com/facebookresearch/segment-anything-2" diff --git a/cellacdc/segmenters/segment_anything/__init__.py b/cellacdc/segmenters/segment_anything/__init__.py new file mode 100644 index 000000000..c0aa3abcf --- /dev/null +++ b/cellacdc/segmenters/segment_anything/__init__.py @@ -0,0 +1,16 @@ +from cellacdc import utils + +utils.check_install_segment_anything() + +import os +from cellacdc import segment_anything_weights_filenames + +_, sam_segmenters_path = utils.get_model_path( + "segment_anything", create_temp_dir=False +) + +model_types = { + "Large": ("default", segment_anything_weights_filenames[0]), + "Medium": ("vit_l", segment_anything_weights_filenames[1]), + "Small": ("vit_b", segment_anything_weights_filenames[2]), +} diff --git a/cellacdc/models/segment_anything/acdcSegment.py b/cellacdc/segmenters/segment_anything/acdcSegment.py similarity index 68% rename from cellacdc/models/segment_anything/acdcSegment.py rename to cellacdc/segmenters/segment_anything/acdcSegment.py index d0e49480c..aab4cd591 100644 --- a/cellacdc/models/segment_anything/acdcSegment.py +++ b/cellacdc/segmenters/segment_anything/acdcSegment.py @@ -10,159 +10,161 @@ import skimage.measure -from . import model_types, sam_models_path +from . import model_types, sam_segmenters_path + +from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor +from cellacdc import utils, widgets, printl -from segment_anything import ( - sam_model_registry, SamAutomaticMaskGenerator, SamPredictor -) -from cellacdc import myutils, widgets, printl class AvailableModels: values = list(model_types.keys()) + class DataFrame: not_a_param = True + class NotParam: not_a_param = True + class Boolean: not_a_param = True + class Integer: not_a_param = True + class Model: def __init__( - self, - model_type: AvailableModels='Large', - input_points_path: widgets.CsvFilePathControl='', - input_points_df: DataFrame='None', - points_per_side=32, - pred_iou_thresh=0.88, - stability_score_thresh=0.95, - crop_n_layers=0, - crop_n_points_downscale_factor=2, - min_mask_region_area=1, - gpu=False - ): + self, + model_type: AvailableModels = "Large", + input_points_path: widgets.CsvFilePathControl = "", + input_points_df: DataFrame = "None", + points_per_side=32, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + crop_n_layers=0, + crop_n_points_downscale_factor=2, + min_mask_region_area=1, + gpu=False, + ): """Initialization of Segment Anything Model within Cell-ACDC Parameters ---------- points_per_side : int or None, optional - The number of points to be sampled along one side of the image. - The total number of points is points_per_side**2. - If None, 'point_grids' must provide explicit point sampling. - Ignored if `input_points_path` is not empty or `input_points_df` is not + The number of points to be sampled along one side of the image. + The total number of points is points_per_side**2. + If None, 'point_grids' must provide explicit point sampling. + Ignored if `input_points_path` is not empty or `input_points_df` is not 'None'. Default is 32 - pred_iou_thresh : float, optional - A filtering threshold in [0,1], using the model's predicted mask - quality. - Ignored if `input_points_path` is not empty or `input_points_df` is not - 'None'. Default is pred_iou_thresh + pred_iou_thresh : float, optional + A filtering threshold in [0,1], using the model's predicted mask + quality. + Ignored if `input_points_path` is not empty or `input_points_df` is not + 'None'. Default is pred_iou_thresh stability_score_thresh : float, optional - A filtering threshold in [0,1], using the stability of the mask - under changes to the cutoff used to binarize the model's mask - predictions. - Ignored if `input_points_path` is not empty or `input_points_df` is not + A filtering threshold in [0,1], using the stability of the mask + under changes to the cutoff used to binarize the model's mask + predictions. + Ignored if `input_points_path` is not empty or `input_points_df` is not 'None'. Default is 0.95 - crop_n_layers : int - If >0, mask prediction will be run again on crops of the image. - Sets the number of layers to run, where each layer has 2**i_layer + crop_n_layers : int + If >0, mask prediction will be run again on crops of the image. + Sets the number of layers to run, where each layer has 2**i_layer number of image crops. - Ignored if `input_points_path` is not empty or `input_points_df` is not + Ignored if `input_points_path` is not empty or `input_points_df` is not 'None'. Default is 0 crop_n_points_downscale_factor : int, optional - The number of points-per-side sampled in layer n is scaled down by + The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. - Ignored if `input_points_path` is not empty or `input_points_df` is not + Ignored if `input_points_path` is not empty or `input_points_df` is not 'None'. Default is 2 min_mask_region_area: int, optional - If >0, postprocessing will be applied mto remove disconnected - regions and holes in masks with area smaller than - min_mask_region_area. - Ignored if `input_points_path` is not empty or `input_points_df` is not + If >0, postprocessing will be applied mto remove disconnected + regions and holes in masks with area smaller than + min_mask_region_area. + Ignored if `input_points_path` is not empty or `input_points_df` is not 'None'. Default is 1 input_points_path : str, optional - If not empty, this is the path to the CSV file with the coordinates - of the input points for SAM. It must contain the columns - ('x', 'y', 'id') with an optional 'z' column for segmentation of 3D - z-stack data (slice-by-slice) and a 'frame_i' columns for - time-lapse data. - - Note that `id = 0` will be used for the negative points, i.e. those + If not empty, this is the path to the CSV file with the coordinates + of the input points for SAM. It must contain the columns + ('x', 'y', 'id') with an optional 'z' column for segmentation of 3D + z-stack data (slice-by-slice) and a 'frame_i' columns for + time-lapse data. + + Note that `id = 0` will be used for the negative points, i.e. those objects (like the background) that should not be segmented. - - In the Cell-ACDC GUI (module 3) you can click to add points and - save them to a file whose path or endname can be provided for the - `input_points_path`. To do so, click on the "Add points layer" - button on the top toolbar and choose "Add points with mouse clicks". - - To add a new point for a new object click with the mouse left - button. To add points to the same object click with the right - button. The 'id' of the point will be visible next to the point - symbol. To delete a point click on the point. - - To add negative points click with the middle button (Cmd+click on - macOS) or enter 0 in the "Point id" numeric control (top toolbar) - and then right-click to add points with the current id. - + + In the Cell-ACDC GUI (module 3) you can click to add points and + save them to a file whose path or endname can be provided for the + `input_points_path`. To do so, click on the "Add points layer" + button on the top toolbar and choose "Add points with mouse clicks". + + To add a new point for a new object click with the mouse left + button. To add points to the same object click with the right + button. The 'id' of the point will be visible next to the point + symbol. To delete a point click on the point. + + To add negative points click with the middle button (Cmd+click on + macOS) or enter 0 in the "Point id" numeric control (top toolbar) + and then right-click to add points with the current id. + To load the coordinates from a CSV file click on the browse button. - - If empty string and `inputs_points_df` is 'None', SAM will run + + If empty string and `inputs_points_df` is 'None', SAM will run in automatic mode on the entire image. Default is None - + input_points_df : pd.DataFrame or 'None', optional - If not 'None', this is a pandas DataFrame (a table) with the - coordinates of the input points for SAM. - - It must contain the columns ('x', 'y', 'id') with an optional - 'z' column for segmentation of 3D z-stack data (slice-by-slice) and - a 'frame_i' columns for time-lapse data. Note that `id = 0` will - be used for the negative points, i.e. those objects (like the + If not 'None', this is a pandas DataFrame (a table) with the + coordinates of the input points for SAM. + + It must contain the columns ('x', 'y', 'id') with an optional + 'z' column for segmentation of 3D z-stack data (slice-by-slice) and + a 'frame_i' columns for time-lapse data. Note that `id = 0` will + be used for the negative points, i.e. those objects (like the background) that should not be segmented. - - If not 'None', `input_points_path` will be ignored and this will be used - instead. - - If 'None' and `input_points_path` is empty, SAM will run - in automatic mode on the entire image. Default is 'None' - """ + + If not 'None', `input_points_path` will be ignored and this will be used + instead. + + If 'None' and `input_points_path` is empty, SAM will run + in automatic mode on the entire image. Default is 'None' + """ if gpu: from cellacdc import is_mac_arm64 + if is_mac_arm64: - device = 'cpu' + device = "cpu" else: - device = 'cuda' + device = "cuda" else: - device = 'cpu' - - if isinstance(input_points_df, str) and input_points_df=='None': + device = "cpu" + + if isinstance(input_points_df, str) and input_points_df == "None": input_points_df = None - - load_points_df = ( - input_points_path - and input_points_df is None - ) + + load_points_df = input_points_path and input_points_df is None if load_points_df: input_points_df = pd.read_csv(input_points_path) - + if input_points_df is not None: - if 'z' in input_points_df.columns: - input_points_df = input_points_df.sort_values(['z', 'id']) + if "z" in input_points_df.columns: + input_points_df = input_points_df.sort_values(["z", "id"]) else: - input_points_df = input_points_df.sort_values('id') - + input_points_df = input_points_df.sort_values("id") + self._input_points_df = input_points_df - + model_type, sam_checkpoint = model_types[model_type] - sam_checkpoint = os.path.join(sam_models_path, sam_checkpoint) - sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) + sam_checkpoint = os.path.join(sam_segmenters_path, sam_checkpoint) + sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) sam.to(device=device) if input_points_df is None: self.model = SamAutomaticMaskGenerator( - sam, + sam, points_per_side=points_per_side, pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh, @@ -172,115 +174,112 @@ def __init__( ) else: self.model = SamPredictor(sam) - + self._embedded_img = None - + def segment( - self, - image: np.ndarray, - frame_i: int, - automatic_removal_of_background: bool=True, - input_points_df: DataFrame='None', - posData: NotParam=None, - save_embeddings: Boolean=False, - only_embeddings: Boolean=False, - use_loaded_embeddings: Boolean=False, - start_z_slice: Integer=0 - ) -> np.ndarray: - + self, + image: np.ndarray, + frame_i: int, + automatic_removal_of_background: bool = True, + input_points_df: DataFrame = "None", + posData: NotParam = None, + save_embeddings: Boolean = False, + only_embeddings: Boolean = False, + use_loaded_embeddings: Boolean = False, + start_z_slice: Integer = 0, + ) -> np.ndarray: """_summary_ image : ([Z], Y, X, [C]) numpy.ndarray - Input image. It can be grayscale 2D (Y, X), or 3D (Z, Y, X) for - z-stack data, or it can have additional dimension C for the RGB + Input image. It can be grayscale 2D (Y, X), or 3D (Z, Y, X) for + z-stack data, or it can have additional dimension C for the RGB channels. - + frame_i : int - Frame index (starting from 0). Used to get the input points from - `input_points_df` with timelapse data. Ignored if the + Frame index (starting from 0). Used to get the input points from + `input_points_df` with timelapse data. Ignored if the `input_points_df` does not have the 'frame_i' column. - + automatic_removal_of_background : bool, optional - If True, the background object will be removed. The background - object is defined as the largest object touching the borders of the - image. Used only with automatic generator without input prompts, - i.e., `input_points_path` is empty and `input_points_df` is equal + If True, the background object will be removed. The background + object is defined as the largest object touching the borders of the + image. Used only with automatic generator without input prompts, + i.e., `input_points_path` is empty and `input_points_df` is equal to 'None'. - + input_points_df : pd.DataFrame or 'None', optional - If not 'None', this is a pandas DataFrame (a table) with the - coordinates of the input points for SAM. - - It must contain the columns ('x', 'y', 'id') with an optional - 'z' column for segmentation of 3D z-stack data (slice-by-slice) and - a 'frame_i' columns for time-lapse data. Note that `id = 0` will - be used for the negative points, i.e. those objects (like the + If not 'None', this is a pandas DataFrame (a table) with the + coordinates of the input points for SAM. + + It must contain the columns ('x', 'y', 'id') with an optional + 'z' column for segmentation of 3D z-stack data (slice-by-slice) and + a 'frame_i' columns for time-lapse data. Note that `id = 0` will + be used for the negative points, i.e. those objects (like the background) that should not be segmented. - - If not 'None', and there is already an `input_points_df` from the - `__init__` (initialization of the model) method it will be + + If not 'None', and there is already an `input_points_df` from the + `__init__` (initialization of the model) method it will be overwritten with the new table. - + posData : load.loadData or None, optional - This is not a parameter configurable through the GUI. Cell-ACDC - will pass the class of the loaded data from the specific Position. - This is the used internally to add image embeddings if + This is not a parameter configurable through the GUI. Cell-ACDC + will pass the class of the loaded data from the specific Position. + This is the used internally to add image embeddings if `save_embeddings` is True. - + save_embeddings : bool, optional - This is not a parameter configurable through the GUI. If `posData` - is not None, the image embeddings will be stored in the dictionary - `posData.sam_embeddings`. This dictionary can be later used to + This is not a parameter configurable through the GUI. If `posData` + is not None, the image embeddings will be stored in the dictionary + `posData.sam_embeddings`. This dictionary can be later used to save the embeddings to disk. - + only_embeddings : bool, optional - This is not a parameter configurable through the GUI. If `True`, - The labels masks will not be generated and the model will only - be used to generate the image embeddings stored in + This is not a parameter configurable through the GUI. If `True`, + The labels masks will not be generated and the model will only + be used to generate the image embeddings stored in `posData.sam_embeddings`. - - use_loaded_embeddings : bool, optional - This is not a parameter configurable through the GUI. If `posData` - is not None, the image embeddings will be loaded from the dictionary + + use_loaded_embeddings : bool, optional + This is not a parameter configurable through the GUI. If `posData` + is not None, the image embeddings will be loaded from the dictionary `posData.sam_embeddings`. - + start_z_slice : int, optional - This is not a parameter configurable through the GUI. Cell-ACDC - will pass the correct start z-slice to store embeddings at the + This is not a parameter configurable through the GUI. Cell-ACDC + will pass the correct start z-slice to store embeddings at the right z-slice. - + Returns ------- ([Z], Y, X) numpy.ndarray of ints - Output labelled masks with the same shape as input image but without - the channel dimension. Every pixel belonging to the same object + Output labelled masks with the same shape as input image but without + the channel dimension. Every pixel belonging to the same object will have the same integer ID. ID = 0 is for the background. - """ - + """ + if isinstance(input_points_df, pd.DataFrame): self._input_points_df = input_points_df - + is_rgb_image = image.shape[-1] == 3 or image.shape[-1] == 4 - is_z_stack = (image.ndim==3 and not is_rgb_image) or (image.ndim==4) + is_z_stack = (image.ndim == 3 and not is_rgb_image) or (image.ndim == 4) if is_rgb_image: labels = np.zeros(image.shape[:-1], dtype=np.uint32) else: labels = np.zeros(image.shape, dtype=np.uint32) - + if self._input_points_df is None: df_points = None - elif 'frame_i' in self._input_points_df.columns: - mask = self._input_points_df['frame_i'] == frame_i + elif "frame_i" in self._input_points_df.columns: + mask = self._input_points_df["frame_i"] == frame_i df_points = self._input_points_df[mask] else: df_points = self._input_points_df - + auto_remove_bkgr = automatic_removal_of_background - input_points, input_labels = self._get_input_points( - is_z_stack, df_points - ) - if is_z_stack: - pbar_z = tqdm(total=len(image), ncols=100, desc='z-slice') + input_points, input_labels = self._get_input_points(is_z_stack, df_points) + if is_z_stack: + pbar_z = tqdm(total=len(image), ncols=100, desc="z-slice") for z, img in enumerate(image): input_points_z = None input_labels_z = None @@ -288,150 +287,145 @@ def segment( input_points_z = input_points.get(z, None) if input_points_z is not None: input_labels_z = input_labels.get(z, []) - + embeddings_init = False if use_loaded_embeddings: embeddings_init = self._get_img_embeddings( - posData, frame_i=frame_i, z=z+start_z_slice + posData, frame_i=frame_i, z=z + start_z_slice ) - + if only_embeddings: self._init_embeddings(img) else: lab_2D = self._segment_2D_image( - img, input_points_z, input_labels_z, + img, + input_points_z, + input_labels_z, embeddings_already_init=embeddings_init, ) labels[z] = lab_2D if save_embeddings or only_embeddings: posData.storeSamEmbeddings( - self, frame_i=frame_i, z=z+start_z_slice + self, frame_i=frame_i, z=z + start_z_slice ) - + pbar_z.update() pbar_z.close() - + if automatic_removal_of_background and input_points is None: # For z-stacks, remove background after 3D relabeling labels = self._remove_background_from_labels(labels) - - labels = skimage.measure.label(labels>0).astype(np.uint32) + + labels = skimage.measure.label(labels > 0).astype(np.uint32) else: embeddings_init = False if use_loaded_embeddings: - embeddings_init = self._get_img_embeddings( - posData, frame_i=frame_i - ) + embeddings_init = self._get_img_embeddings(posData, frame_i=frame_i) if only_embeddings: self._init_embeddings(image) else: labels = self._segment_2D_image( - image, input_points, input_labels, + image, + input_points, + input_labels, embeddings_already_init=embeddings_init, - automatic_removal_of_background=auto_remove_bkgr + automatic_removal_of_background=auto_remove_bkgr, ) if save_embeddings or only_embeddings: posData.storeSamEmbeddings(self, frame_i=frame_i) - + return labels def _get_img_embeddings(self, posData, frame_i=0, z=0): img_embeddings = posData.getSamEmbeddings(frame_i=frame_i, z=z) if img_embeddings is None: return False - + for key, value in img_embeddings.items(): setattr(self, key, value) - + return True - + def _get_input_points(self, is_z_stack, df_points): if df_points is None: return None, None - + if is_z_stack: input_points = defaultdict(dict) input_labels = defaultdict(dict) - neg_input_points_df = ( - df_points[df_points['id'] == 0] - .set_index('z') - ) - for (z, id), sub_df in df_points.groupby(['z', 'id']): + neg_input_points_df = df_points[df_points["id"] == 0].set_index("z") + for (z, id), sub_df in df_points.groupby(["z", "id"]): if id == 0: continue - + # Concatenate negative points - points_data_z = sub_df[['x', 'y']].to_numpy() + points_data_z = sub_df[["x", "y"]].to_numpy() points_labels_z = np.ones(len(sub_df), dtype=int) # 1 = positive try: - neg_points_data_z = ( - neg_input_points_df.loc[z][['x', 'y']].to_numpy()) - points_data_z = np.row_stack(( - neg_points_data_z, points_data_z - )) + neg_points_data_z = neg_input_points_df.loc[z][ + ["x", "y"] + ].to_numpy() + points_data_z = np.row_stack((neg_points_data_z, points_data_z)) points_labels_z = np.concatenate( - ([0]*len(neg_points_data_z), points_labels_z) + ([0] * len(neg_points_data_z), points_labels_z) ) except IndexError: pass - + input_points[z][id] = points_data_z input_labels[z][id] = points_labels_z else: input_points = {} input_labels = {} - neg_input_points_df = ( - df_points[df_points['id'] == 0] - ) - neg_input_points_data = neg_input_points_df[['x', 'y']].to_numpy() - for id, df_id in df_points.groupby('id'): + neg_input_points_df = df_points[df_points["id"] == 0] + neg_input_points_data = neg_input_points_df[["x", "y"]].to_numpy() + for id, df_id in df_points.groupby("id"): if id == 0: continue - - points_data_id = df_id[['x', 'y']].to_numpy() - points_data_id = np.row_stack(( - neg_input_points_data, points_data_id - )) + + points_data_id = df_id[["x", "y"]].to_numpy() + points_data_id = np.row_stack((neg_input_points_data, points_data_id)) # Use 1 for positive labels (not actual IDs) - SAM expects binary 0/1 points_labels_id = np.ones(len(df_id), dtype=int) points_labels_id = np.concatenate( - ([0]*len(neg_input_points_data), points_labels_id) + ([0] * len(neg_input_points_data), points_labels_id) ) input_points[id] = points_data_id input_labels[id] = points_labels_id - + return input_points, input_labels - - def _init_embeddings(self, img_rgb): + + def _init_embeddings(self, img_rgb): if img_rgb.ndim == 2: - img_rgb = myutils.to_uint8(img_rgb) + img_rgb = utils.to_uint8(img_rgb) img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_BGR2RGB) - - # Create embeddings only if new image + + # Create embeddings only if new image try: init_embeddings = not np.allclose(img_rgb, self._embedded_img) except Exception as err: init_embeddings = True - - if hasattr(self.model, 'predictor'): + + if hasattr(self.model, "predictor"): predictor = self.model.predictor else: predictor = self.model - - if init_embeddings: + + if init_embeddings: predictor.set_image(img_rgb) self._embedded_img = img_rgb - + def _segment_2D_image( - self, image: np.ndarray, - input_points: dict[int, np.ndarray], - input_labels: dict[int, np.ndarray], - embeddings_already_init: bool=False, - automatic_removal_of_background: bool=False - ) -> np.ndarray: - - img = myutils.to_uint8(image) + self, + image: np.ndarray, + input_points: dict[int, np.ndarray], + input_labels: dict[int, np.ndarray], + embeddings_already_init: bool = False, + automatic_removal_of_background: bool = False, + ) -> np.ndarray: + + img = utils.to_uint8(image) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) labels = np.zeros(image.shape[:2], dtype=np.uint32) @@ -446,24 +440,24 @@ def _segment_2D_image( masks = [m for i, m in enumerate(masks) if i != bg_idx] # Sort by area descending so smaller masks overwrite larger ones - masks = sorted(masks, key=lambda m: m['area'], reverse=True) + masks = sorted(masks, key=lambda m: m["area"], reverse=True) for id, mask in enumerate(masks): - obj_image = mask['segmentation'] - labels[obj_image] = id+1 + obj_image = mask["segmentation"] + labels[obj_image] = id + 1 return labels - + # No input points --> return empty labels if len(input_points) == 0: return labels - + # SAM with input points if not embeddings_already_init: self._init_embeddings(img) - + for id, point_coords in input_points.items(): point_labels = input_labels[id] - multimask_output = len(point_coords)==1 + multimask_output = len(point_coords) == 1 masks, scores, logits = self.model.predict( point_coords=point_coords, point_labels=point_labels, @@ -477,9 +471,7 @@ def _segment_2D_image( labels[mask] = id return labels - def _find_background_mask_index( - self, masks: list, shape: tuple - ) -> int | None: + def _find_background_mask_index(self, masks: list, shape: tuple) -> int | None: """Find the mask with the most pixels touching the image border.""" if not masks: return None @@ -491,7 +483,7 @@ def _find_background_mask_index( max_border_pixels = 0 bg_idx = None for i, mask in enumerate(masks): - segmentation = mask['segmentation'] + segmentation = mask["segmentation"] border_pixels = np.sum(segmentation & border_mask) if border_pixels > max_border_pixels: max_border_pixels = border_pixels @@ -512,4 +504,4 @@ def _remove_background_from_labels(self, labels: np.ndarray) -> np.ndarray: def url_help(): - return 'https://github.com/facebookresearch/segment-anything' + return "https://github.com/facebookresearch/segment-anything" diff --git a/cellacdc/models/skip_segmentation/__init__.py b/cellacdc/segmenters/skip_segmentation/__init__.py similarity index 100% rename from cellacdc/models/skip_segmentation/__init__.py rename to cellacdc/segmenters/skip_segmentation/__init__.py diff --git a/cellacdc/models/skip_segmentation/acdcSegment.py b/cellacdc/segmenters/skip_segmentation/acdcSegment.py similarity index 77% rename from cellacdc/models/skip_segmentation/acdcSegment.py rename to cellacdc/segmenters/skip_segmentation/acdcSegment.py index aed82e3f1..7679d63c1 100644 --- a/cellacdc/models/skip_segmentation/acdcSegment.py +++ b/cellacdc/segmenters/skip_segmentation/acdcSegment.py @@ -1,14 +1,13 @@ class Model: def __init__(self, segm_data): self.segm_data = segm_data - def segment( - self, - image, - frame_i, - skip_segmentation = True, - ): + self, + image, + frame_i, + skip_segmentation=True, + ): """Skips the segmentation step and instead uses the provided segmentation data. Parameters @@ -21,6 +20,6 @@ def segment( ------- _type_ Segmented image (same as segm_data) - """ - - return self.segm_data[frame_i] \ No newline at end of file + """ + + return self.segm_data[frame_i] diff --git a/cellacdc/models/thresholding/__init__.py b/cellacdc/segmenters/thresholding/__init__.py similarity index 100% rename from cellacdc/models/thresholding/__init__.py rename to cellacdc/segmenters/thresholding/__init__.py diff --git a/cellacdc/models/thresholding/acdcSegment.py b/cellacdc/segmenters/thresholding/acdcSegment.py similarity index 85% rename from cellacdc/models/thresholding/acdcSegment.py rename to cellacdc/segmenters/thresholding/acdcSegment.py index 3ee0f1ca3..c4d3d1612 100644 --- a/cellacdc/models/thresholding/acdcSegment.py +++ b/cellacdc/segmenters/thresholding/acdcSegment.py @@ -4,10 +4,11 @@ from cellacdc import printl + class Model: def __init__(self): pass - + def _preprocess(self, img, gauss_sigma): if gauss_sigma > 0: filtered = skimage.filters.gaussian(img, sigma=gauss_sigma) @@ -18,12 +19,14 @@ def _preprocess(self, img, gauss_sigma): def _apply_threshold(self, img, threshold_method): thresh_val = getattr(skimage.filters, threshold_method)(img) return img > thresh_val - + def segment( - self, image, gauss_sigma=1.0, - threshold_method='threshold_otsu', - segment_3D_volume=False - ): + self, + image, + gauss_sigma=1.0, + threshold_method="threshold_otsu", + segment_3D_volume=False, + ): is3D = image.ndim > 2 if is3D and not segment_3D_volume: # Segment slice-by-slice @@ -35,7 +38,7 @@ def segment( else: filtered = self._preprocess(image, gauss_sigma) thresh = self._apply_threshold(filtered, threshold_method) - + labels = skimage.measure.label(thresh) - return labels \ No newline at end of file + return labels diff --git a/cellacdc/promptable_models/__init__.py b/cellacdc/segmenters_promptable/__init__.py similarity index 100% rename from cellacdc/promptable_models/__init__.py rename to cellacdc/segmenters_promptable/__init__.py diff --git a/cellacdc/segmenters_promptable/micro-sam/__init__.py b/cellacdc/segmenters_promptable/micro-sam/__init__.py new file mode 100644 index 000000000..eb89a63d9 --- /dev/null +++ b/cellacdc/segmenters_promptable/micro-sam/__init__.py @@ -0,0 +1,3 @@ +import cellacdc.utils as utils + +utils.check_install_microsam() diff --git a/cellacdc/promptable_models/micro-sam/acdcSegment.py b/cellacdc/segmenters_promptable/micro-sam/acdcSegment.py similarity index 100% rename from cellacdc/promptable_models/micro-sam/acdcSegment.py rename to cellacdc/segmenters_promptable/micro-sam/acdcSegment.py diff --git a/cellacdc/segmenters_promptable/nnInteractive/__init__.py b/cellacdc/segmenters_promptable/nnInteractive/__init__.py new file mode 100644 index 000000000..71baa747f --- /dev/null +++ b/cellacdc/segmenters_promptable/nnInteractive/__init__.py @@ -0,0 +1,3 @@ +import cellacdc.utils as utils + +utils.check_install_nnInteractive() diff --git a/cellacdc/promptable_models/nnInteractive/acdcPromptSegment.py b/cellacdc/segmenters_promptable/nnInteractive/acdcPromptSegment.py similarity index 74% rename from cellacdc/promptable_models/nnInteractive/acdcPromptSegment.py rename to cellacdc/segmenters_promptable/nnInteractive/acdcPromptSegment.py index 6fe921112..f7ad7faaa 100644 --- a/cellacdc/promptable_models/nnInteractive/acdcPromptSegment.py +++ b/cellacdc/segmenters_promptable/nnInteractive/acdcPromptSegment.py @@ -3,32 +3,35 @@ import numpy as np -from cellacdc.promptable_models.utils import build_combined_mask +from cellacdc.segmenters_promptable.utils import build_combined_mask import torch from cellacdc import user_profile_path -from cellacdc import myutils +from cellacdc import utils from cellacdc import printl from huggingface_hub import snapshot_download + class AvailableModels: - values = ['nnInteractive_v1.0'] + values = ["nnInteractive_v1.0"] + class GPUorCPU: - values = ['gpu', 'cpu'] + values = ["gpu", "cpu"] + class Model: def __init__( - self, - model_name: AvailableModels = 'nnInteractive_v1.0', - run_on: GPUorCPU = 'cpu', - device: torch.device | int ='None', - verbose: bool = False, - torch_number_of_threads: int = os.cpu_count(), - **kwargs - ): + self, + model_name: AvailableModels = "nnInteractive_v1.0", + run_on: GPUorCPU = "cpu", + device: torch.device | int = "None", + verbose: bool = False, + torch_number_of_threads: int = os.cpu_count(), + **kwargs, + ): """_summary_ Parameters @@ -39,26 +42,26 @@ def __init__( Whether to run on CPU or first GPU available. Default is 'cpu' device : torch.device or int or None If not None, this is the device used for running the model - (torch.device('cuda') or torch.device('cpu')). - It overrides `run_on`, recommended if you want to use a specific GPU + (torch.device('cuda') or torch.device('cpu')). + It overrides `run_on`, recommended if you want to use a specific GPU (e.g. torch.device('cuda:1'). Default is None verbose : bool, optional - If True, more information will be displayed in the terminal. + If True, more information will be displayed in the terminal. Default is False torch_number_of_threads : int, optional - Number of CPU threads to use for the computation. + Number of CPU threads to use for the computation. Default is `os.cpu_count()`, i.e., the maximum available CPU cores. - """ + """ from nnInteractive.inference.inference_session import ( - nnInteractiveInferenceSession + nnInteractiveInferenceSession, ) - - if device == 'None': + + if device == "None": device = None - + if device is None: - device = myutils.get_torch_device(gpu=run_on == 'gpu') - + device = utils.get_torch_device(gpu=run_on == "gpu") + self.model = nnInteractiveInferenceSession( device=device, # Set inference device use_torch_compile=False, # Experimental: Not tested yet @@ -67,55 +70,55 @@ def __init__( do_autozoom=True, # Enables AutoZoom for better patching use_pinned_memory=True, # Optimizes GPU memory transfers ) - - download_dir = os.path.join(user_profile_path, 'acdc-nnInteractive') + + download_dir = os.path.join(user_profile_path, "acdc-nnInteractive") os.makedirs(download_dir, exist_ok=True) - + download_path = snapshot_download( - repo_id='nnInteractive/nnInteractive', + repo_id="nnInteractive/nnInteractive", allow_patterns=[f"{model_name}/*"], - local_dir=download_dir + local_dir=download_dir, ) - + model_path = os.path.join(download_dir, model_name) - + self.model.initialize_from_trained_model_folder(model_path) - + self.prompt_ids_image_mapper = {} self.prompts = defaultdict(list) self.negative_prompts = defaultdict(list) - - def _validate_prompt(self, prompt, prompt_type='point'): - if prompt_type == 'point': + + def _validate_prompt(self, prompt, prompt_type="point"): + if prompt_type == "point": prompt = tuple(prompt) if len(prompt) != 3: raise ValueError( "Point prompt must be a sequence of 3 coordinates (z, y, x)." ) - + def _validate_image(self, image): if image is None: return - + if image.ndim == 3: return - + raise ValueError( "Only 3D images are supported by nnInteractive. " "Please provide a 3D image with (Z, Y, X) dimensions." ) - + def add_prompt( - self, - prompt, - prompt_id: int, - *args, - image=None, - image_origin=(0, 0, 0), - parent_obj_id=0, - prompt_type='point', - **kwargs - ): + self, + prompt, + prompt_id: int, + *args, + image=None, + image_origin=(0, 0, 0), + parent_obj_id=0, + prompt_type="point", + **kwargs, + ): """Add prompt to model Parameters @@ -124,97 +127,83 @@ def add_prompt( Prompt to add. If 'point', this should be a sequence of 3 coordinates (z, y, x). prompt_id : int - Unique identifier for the prompt. If 0, then it will be treated as a + Unique identifier for the prompt. If 0, then it will be treated as a negative prompt (i.e., the background). image : np.ndarray, optional - Image to which the prompt is associated. If None, the prompt will + Image to which the prompt is associated. If None, the prompt will be associated to the entire image passed to the `segment` method. image_origin : tuple of (z0, y0, x0) coordinates, optional - Origin of the image in the global image coordinate system. This - is useful when you want to pass a crop of the image to the model, - but still have the result inserted into the global image by + Origin of the image in the global image coordinate system. This + is useful when you want to pass a crop of the image to the model, + but still have the result inserted into the global image by the `segment` method. Default is (0, 0, 0). parent_obj_id : int, optional - The ID of the parent object. If not 0, this will be used to assign - negative prompts only to the parent object. + The ID of the parent object. If not 0, this will be used to assign + negative prompts only to the parent object. prompt_type : {'point'}, optional The type of prompt to add. Default is 'point'. - """ + """ self._validate_prompt(prompt, prompt_type=prompt_type) self._validate_image(image) - + if prompt_id not in self.prompt_ids_image_mapper and prompt_id != 0: self.prompt_ids_image_mapper[prompt_id] = (image, image_origin) - + if prompt_id != 0: - self.prompts[prompt_id].append( - (prompt, prompt_type) - ) + self.prompts[prompt_id].append((prompt, prompt_type)) elif parent_obj_id != 0: # Negative prompt for a specific parent object - self.negative_prompts[parent_obj_id].append( - (prompt, prompt_type) - ) + self.negative_prompts[parent_obj_id].append((prompt, prompt_type)) else: # Negative prompt for the background self.negative_prompts[0].append((prompt, prompt_type)) - + def _add_object_prompts(self, prompt_id, is_negative=False): prompts = self.prompts[prompt_id] for prompt, prompt_type in prompts: - if prompt_type == 'point': + if prompt_type == "point": # nnInteractive requires (x, y, z) order point_prompt = tuple(prompt[::-1]) self.model.add_point_interaction( point_prompt, include_interaction=not is_negative, - run_prediction=True + run_prediction=True, ) else: raise ValueError(f"Unsupported prompt type: {prompt_type}") - + def _add_object_specific_negative_prompts(self, prompt_id): obj_negative_prompts = self.negative_prompts[prompt_id] for prompt, prompt_type in obj_negative_prompts: - if prompt_type == 'point': + if prompt_type == "point": # nnInteractive requires (x, y, z) order point_prompt = tuple(prompt[::-1]) self.model.add_point_interaction( - point_prompt, - include_interaction=False, - run_prediction=True + point_prompt, include_interaction=False, run_prediction=True ) else: raise ValueError(f"Unsupported prompt type: {prompt_type}") - + def _add_global_negative_prompts(self): global_negative_prompts = self.negative_prompts[0] for prompt, prompt_type in global_negative_prompts: - if prompt_type == 'point': + if prompt_type == "point": # nnInteractive requires (x, y, z) order point_prompt = tuple(prompt[::-1]) self.model.add_point_interaction( - point_prompt, - include_interaction=False, - run_prediction=True + point_prompt, include_interaction=False, run_prediction=True ) else: raise ValueError(f"Unsupported prompt type: {prompt_type}") - + def _add_other_objects_prompts_as_negative(self, current_prompt_id): for prompt_id, prompts in self.prompts.items(): if prompt_id == current_prompt_id: continue - + self._add_object_prompts(prompt_id, is_negative=True) - - def segment( - self, - image, - treat_other_objects_as_background=True, - *args, - **kwargs - ): + + def segment(self, image, treat_other_objects_as_background=True, *args, **kwargs): """Run the segmentation model on the image using the prompts added Parameters @@ -222,15 +211,15 @@ def segment( image : (Z, Y, X) np.ndarray 3D z-stack image to segment. treat_other_objects_as_background : bool, optional - If True, when segmenting an object, the prompts added + If True, when segmenting an object, the prompts added for all the other objects are treated as negative prompts for the current object. Default is True Returns ------- (Z, Y, X) np.ndarray - Labelled array with the segmentation masks of the objects. - Smaller objects are added on top to prevent larger + Labelled array with the segmentation masks of the objects. + Smaller objects are added on top to prevent larger objects from removing smaller ones. Raises @@ -240,58 +229,57 @@ def segment( """ self._validate_image(image) - lab = np.zeros(image.shape, dtype=np.uint32) + lab = np.zeros(image.shape, dtype=np.uint32) for prompt_id, value in self.prompt_ids_image_mapper.items(): prompt_image, image_origin = value - + if prompt_image is None: prompt_image = image - + # Re-order axis from (z, y, x) to (x, y, z) for the model prompt_image = np.moveaxis(prompt_image, (0, 1, 2), (2, 1, 0)) - + prompt_image = prompt_image[np.newaxis] self.model.set_image(prompt_image) - - target_tensor = torch.zeros( - prompt_image.shape[1:], dtype=torch.uint8 - ) + + target_tensor = torch.zeros(prompt_image.shape[1:], dtype=torch.uint8) self.model.set_target_buffer(target_tensor) - + self._add_object_prompts(prompt_id, is_negative=False) self._add_object_specific_negative_prompts(prompt_id) self._add_global_negative_prompts() - + if treat_other_objects_as_background: # Add the other objects prompts as negative self._add_other_objects_prompts_as_negative(prompt_id) - + # self.model._predict() - + result_tensor = target_tensor.clone() - + # Convert to numpy array and re-order axis back to (z, y, x) result_mask = np.moveaxis( result_tensor.numpy(), (2, 1, 0), (0, 1, 2) ).astype(bool) - + # Insert the result into the global label array z0, y0, x0 = image_origin d, h, w = result_mask.shape z1, y1, x1 = z0 + d, y0 + h, x0 + w - + obj_slice = (slice(z0, z1), slice(y0, y1), slice(x0, x1)) lab[obj_slice][result_mask] = prompt_id - + self.model.reset_interactions() - + lab = build_combined_mask(lab) - + self.prompt_ids_image_mapper = {} self.prompts = defaultdict(list) self.negative_prompts = defaultdict(list) - + return lab + def url_help(): - return 'https://github.com/MIC-DKFZ/nnInteractive' \ No newline at end of file + return "https://github.com/MIC-DKFZ/nnInteractive" diff --git a/cellacdc/segmenters_promptable/sam2/__init__.py b/cellacdc/segmenters_promptable/sam2/__init__.py new file mode 100644 index 000000000..be442a127 --- /dev/null +++ b/cellacdc/segmenters_promptable/sam2/__init__.py @@ -0,0 +1,3 @@ +import cellacdc.utils as utils + +utils.check_install_sam2() diff --git a/cellacdc/promptable_models/sam2/acdcPromptSegment.py b/cellacdc/segmenters_promptable/sam2/acdcPromptSegment.py similarity index 90% rename from cellacdc/promptable_models/sam2/acdcPromptSegment.py rename to cellacdc/segmenters_promptable/sam2/acdcPromptSegment.py index a9c7d69c6..9d0e0e1da 100644 --- a/cellacdc/promptable_models/sam2/acdcPromptSegment.py +++ b/cellacdc/segmenters_promptable/sam2/acdcPromptSegment.py @@ -1,7 +1,7 @@ import os from collections import defaultdict -from cellacdc.promptable_models.utils import build_combined_mask, log_mask_selection +from cellacdc.segmenters_promptable.utils import build_combined_mask, log_mask_selection import numpy as np import cv2 @@ -9,8 +9,8 @@ from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor -from cellacdc import myutils -from cellacdc.models.sam2 import model_types, sam_models_path +from cellacdc import utils +from cellacdc.segmenters.sam2 import model_types, sam_segmenters_path class AvailableModels: @@ -43,8 +43,10 @@ def __init__(self, model_type: AvailableModels = "Large", gpu: bool = True): device = "cpu" config_file, sam_checkpoint = model_types[model_type] - sam_checkpoint = os.path.join(sam_models_path, sam_checkpoint) - sam = build_sam2(config_file=config_file, ckpt_path=sam_checkpoint, device=device) + sam_checkpoint = os.path.join(sam_segmenters_path, sam_checkpoint) + sam = build_sam2( + config_file=config_file, ckpt_path=sam_checkpoint, device=device + ) self.model = SAM2ImagePredictor(sam) self._embedded_img = None @@ -65,7 +67,7 @@ def _normalize_prompt(self, prompt): return int(z), float(y), float(x) def _to_rgb(self, image): - img = myutils.to_uint8(image) + img = utils.to_uint8(image) if img.ndim == 2: try: img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) @@ -167,7 +169,10 @@ def segment( else: lab_out = np.zeros(image.shape, dtype=np.uint32) - for prompt_id, (prompt_image, image_origin) in self.prompt_ids_image_mapper.items(): + for prompt_id, ( + prompt_image, + image_origin, + ) in self.prompt_ids_image_mapper.items(): if prompt_id == 0: continue @@ -178,12 +183,9 @@ def segment( prompt_id, treat_other_objects_as_background ) - is_prompt_rgb = ( - prompt_image.ndim >= 3 and prompt_image.shape[-1] in (3, 4) - ) - is_prompt_z_stack = ( - (prompt_image.ndim == 3 and not is_prompt_rgb) - or (prompt_image.ndim == 4) + is_prompt_rgb = prompt_image.ndim >= 3 and prompt_image.shape[-1] in (3, 4) + is_prompt_z_stack = (prompt_image.ndim == 3 and not is_prompt_rgb) or ( + prompt_image.ndim == 4 ) if is_prompt_rgb: @@ -218,7 +220,9 @@ def segment( ) mask_idx = np.argmax(scores) if multimask_output else 0 if multimask_output: - log_mask_selection(prompt_id, masks, scores, mask_idx, z_slice=z) + log_mask_selection( + prompt_id, masks, scores, mask_idx, z_slice=z + ) mask = masks[mask_idx].astype(bool) obj_mask[z][mask] = True else: diff --git a/cellacdc/segmenters_promptable/segment_anything/__init__.py b/cellacdc/segmenters_promptable/segment_anything/__init__.py new file mode 100644 index 000000000..d7cbb0acc --- /dev/null +++ b/cellacdc/segmenters_promptable/segment_anything/__init__.py @@ -0,0 +1,3 @@ +import cellacdc.utils as utils + +utils.check_install_segment_anything() diff --git a/cellacdc/promptable_models/segment_anything/acdcPromptSegment.py b/cellacdc/segmenters_promptable/segment_anything/acdcPromptSegment.py similarity index 91% rename from cellacdc/promptable_models/segment_anything/acdcPromptSegment.py rename to cellacdc/segmenters_promptable/segment_anything/acdcPromptSegment.py index 20ee1014b..922bd4250 100644 --- a/cellacdc/promptable_models/segment_anything/acdcPromptSegment.py +++ b/cellacdc/segmenters_promptable/segment_anything/acdcPromptSegment.py @@ -1,15 +1,15 @@ import os from collections import defaultdict -from cellacdc.promptable_models.utils import build_combined_mask, log_mask_selection +from cellacdc.segmenters_promptable.utils import build_combined_mask, log_mask_selection import numpy as np import cv2 from segment_anything import sam_model_registry, SamPredictor -from cellacdc import myutils -from cellacdc.models.segment_anything import model_types, sam_models_path +from cellacdc import utils +from cellacdc.segmenters.segment_anything import model_types, sam_segmenters_path class AvailableModels: @@ -42,7 +42,7 @@ def __init__(self, model_type: AvailableModels = "Large", gpu: bool = False): device = "cpu" model_key, sam_checkpoint = model_types[model_type] - sam_checkpoint = os.path.join(sam_models_path, sam_checkpoint) + sam_checkpoint = os.path.join(sam_segmenters_path, sam_checkpoint) sam = sam_model_registry[model_key](checkpoint=sam_checkpoint) sam.to(device=device) @@ -65,7 +65,7 @@ def _normalize_prompt(self, prompt): return int(z), float(y), float(x) def _to_rgb(self, image): - img = myutils.to_uint8(image) + img = utils.to_uint8(image) if img.ndim == 2: try: img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) @@ -167,7 +167,10 @@ def segment( else: lab_out = np.zeros(image.shape, dtype=np.uint32) - for prompt_id, (prompt_image, image_origin) in self.prompt_ids_image_mapper.items(): + for prompt_id, ( + prompt_image, + image_origin, + ) in self.prompt_ids_image_mapper.items(): if prompt_id == 0: continue @@ -178,12 +181,9 @@ def segment( prompt_id, treat_other_objects_as_background ) - is_prompt_rgb = ( - prompt_image.ndim >= 3 and prompt_image.shape[-1] in (3, 4) - ) - is_prompt_z_stack = ( - (prompt_image.ndim == 3 and not is_prompt_rgb) - or (prompt_image.ndim == 4) + is_prompt_rgb = prompt_image.ndim >= 3 and prompt_image.shape[-1] in (3, 4) + is_prompt_z_stack = (prompt_image.ndim == 3 and not is_prompt_rgb) or ( + prompt_image.ndim == 4 ) if is_prompt_rgb: @@ -218,7 +218,9 @@ def segment( ) mask_idx = np.argmax(scores) if multimask_output else 0 if multimask_output: - log_mask_selection(prompt_id, masks, scores, mask_idx, z_slice=z) + log_mask_selection( + prompt_id, masks, scores, mask_idx, z_slice=z + ) mask = masks[mask_idx].astype(bool) obj_mask[z][mask] = True else: diff --git a/cellacdc/promptable_models/utils.py b/cellacdc/segmenters_promptable/utils.py similarity index 89% rename from cellacdc/promptable_models/utils.py rename to cellacdc/segmenters_promptable/utils.py index ef23a3830..89afdd2eb 100644 --- a/cellacdc/promptable_models/utils.py +++ b/cellacdc/segmenters_promptable/utils.py @@ -28,11 +28,7 @@ def build_combined_mask(model_out): return combined -def _apply_overlap_rule( - lab_old, - lab_new, - mode: Literal['union', 'intersection'] - ): +def _apply_overlap_rule(lab_old, lab_new, mode: Literal["union", "intersection"]): """ Apply overlap rules between old and new label masks. @@ -72,13 +68,13 @@ def _apply_overlap_rule( p_only = np.logical_and(p_mask, ~q_mask) # Old only q_only = np.logical_and(q_mask, ~p_mask) # New only - if mode == 'union': + if mode == "union": # p OR q → all become p result[p_and_q] = p result[p_only] = p result[q_only] = p - elif mode == 'intersection': + elif mode == "intersection": # Only p AND q → p; p XOR q → 0 result[p_and_q] = p # p_only and q_only become 0 (already 0 in result) @@ -87,7 +83,7 @@ def _apply_overlap_rule( non_overlapping_new_ids = new_ids - overlapping_new_ids for q in non_overlapping_new_ids: q_mask = lab_new == q - if mode == 'union': + if mode == "union": result[q_mask] = q # In 'intersection' mode, non-overlapping new IDs are not added @@ -95,10 +91,10 @@ def _apply_overlap_rule( def insert_model_output_into_labels( - lab, - model_out, - edited_IDs: int | List[int] = 0, - ): + lab, + model_out, + edited_IDs: int | List[int] = 0, +): """ Combine model output with existing labels using three strategies. @@ -122,7 +118,7 @@ def insert_model_output_into_labels( lab_new = build_combined_mask(model_out) # Apply overlap rules for union and intersection - lab_union = _apply_overlap_rule(lab, lab_new, mode='union') - lab_intersection = _apply_overlap_rule(lab, lab_new, mode='intersection') + lab_union = _apply_overlap_rule(lab, lab_new, mode="union") + lab_intersection = _apply_overlap_rule(lab, lab_new, mode="intersection") return lab_new, lab_union, lab_intersection diff --git a/cellacdc/syntax.py b/cellacdc/syntax.py index 09205a00a..718c56d41 100644 --- a/cellacdc/syntax.py +++ b/cellacdc/syntax.py @@ -4,17 +4,17 @@ from qtpy import QtCore, QtGui, QtWidgets -def format(color, style=''): - """Return a QTextCharFormat with the given attributes. - """ + +def format(color, style=""): + """Return a QTextCharFormat with the given attributes.""" _color = QtGui.QColor() _color.setNamedColor(color) _format = QtGui.QTextCharFormat() _format.setForeground(_color) - if 'bold' in style: + if "bold" in style: _format.setFontWeight(QtGui.QFont.Weight.Bold) - if 'italic' in style: + if "italic" in style: _format.setFontItalic(True) return _format @@ -22,97 +22,145 @@ def format(color, style=''): # Syntax styles that can be shared by all languages STYLES = { - 'keyword': format('red'), - 'operator': format('red'), - 'brace': format('darkGray'), - 'defclass': format('darkMagenta'), - 'string': format('green'), - 'string2': format('darkMagenta'), - 'comment': format('darkBlu', 'italic'), - 'self': format('black', 'italic'), - 'numbers': format('brown'), + "keyword": format("red"), + "operator": format("red"), + "brace": format("darkGray"), + "defclass": format("darkMagenta"), + "string": format("green"), + "string2": format("darkMagenta"), + "comment": format("darkBlu", "italic"), + "self": format("black", "italic"), + "numbers": format("brown"), } -class PythonHighlighter (QtGui.QSyntaxHighlighter): - """Syntax highlighter for the Python language. - """ +class PythonHighlighter(QtGui.QSyntaxHighlighter): + """Syntax highlighter for the Python language.""" + # Python keywords keywords = [ - 'and', 'assert', 'break', 'class', 'continue', 'def', - 'del', 'elif', 'else', 'except', 'exec', 'finally', - 'for', 'from', 'global', 'if', 'import', 'in', - 'is', 'lambda', 'not', 'or', 'pass', 'print', - 'raise', 'return', 'try', 'while', 'yield', - 'None', 'True', 'False', + "and", + "assert", + "break", + "class", + "continue", + "def", + "del", + "elif", + "else", + "except", + "exec", + "finally", + "for", + "from", + "global", + "if", + "import", + "in", + "is", + "lambda", + "not", + "or", + "pass", + "print", + "raise", + "return", + "try", + "while", + "yield", + "None", + "True", + "False", ] # Python operators operators = [ - '=', + "=", # Comparison - '==', '!=', '<', '<=', '>', '>=', + "==", + "!=", + "<", + "<=", + ">", + ">=", # Arithmetic - '\+', '-', '\*', '/', '//', '\%', '\*\*', + "\+", + "-", + "\*", + "/", + "//", + "\%", + "\*\*", # In-place - '\+=', '-=', '\*=', '/=', '\%=', + "\+=", + "-=", + "\*=", + "/=", + "\%=", # Bitwise - '\^', '\|', '\&', '\~', '>>', '<<', + "\^", + "\|", + "\&", + "\~", + ">>", + "<<", ] # Python braces braces = [ - '\{', '\}', '\(', '\)', '\[', '\]', + "\{", + "\}", + "\(", + "\)", + "\[", + "\]", ] def __init__(self, parent: QtGui.QTextDocument) -> None: super().__init__(parent) # Multi-line strings (expression, flag, style) - self.tri_single = (QtCore.QRegularExpression("'''"), 1, STYLES['string2']) - self.tri_double = (QtCore.QRegularExpression('"""'), 2, STYLES['string2']) + self.tri_single = (QtCore.QRegularExpression("'''"), 1, STYLES["string2"]) + self.tri_double = (QtCore.QRegularExpression('"""'), 2, STYLES["string2"]) rules = [] # Keyword, operator, and brace rules - rules += [(r'\b%s\b' % w, 0, STYLES['keyword']) - for w in PythonHighlighter.keywords] - rules += [(r'%s' % o, 0, STYLES['operator']) - for o in PythonHighlighter.operators] - rules += [(r'%s' % b, 0, STYLES['brace']) - for b in PythonHighlighter.braces] + rules += [ + (r"\b%s\b" % w, 0, STYLES["keyword"]) for w in PythonHighlighter.keywords + ] + rules += [ + (r"%s" % o, 0, STYLES["operator"]) for o in PythonHighlighter.operators + ] + rules += [(r"%s" % b, 0, STYLES["brace"]) for b in PythonHighlighter.braces] # All other rules rules += [ # 'self' - (r'\bself\b', 0, STYLES['self']), - + (r"\bself\b", 0, STYLES["self"]), # 'def' followed by an identifier - (r'\bdef\b\s*(\w+)', 1, STYLES['defclass']), + (r"\bdef\b\s*(\w+)", 1, STYLES["defclass"]), # 'class' followed by an identifier - (r'\bclass\b\s*(\w+)', 1, STYLES['defclass']), - + (r"\bclass\b\s*(\w+)", 1, STYLES["defclass"]), # Numeric literals - (r'\b[+-]?[0-9]+[lL]?\b', 0, STYLES['numbers']), - (r'\b[+-]?0[xX][0-9A-Fa-f]+[lL]?\b', 0, STYLES['numbers']), - (r'\b[+-]?[0-9]+(?:\.[0-9]+)?(?:[eE][+-]?[0-9]+)?\b', 0, STYLES['numbers']), - + (r"\b[+-]?[0-9]+[lL]?\b", 0, STYLES["numbers"]), + (r"\b[+-]?0[xX][0-9A-Fa-f]+[lL]?\b", 0, STYLES["numbers"]), + (r"\b[+-]?[0-9]+(?:\.[0-9]+)?(?:[eE][+-]?[0-9]+)?\b", 0, STYLES["numbers"]), # Double-quoted string, possibly containing escape sequences - (r'"[^"\\]*(\\.[^"\\]*)*"', 0, STYLES['string']), + (r'"[^"\\]*(\\.[^"\\]*)*"', 0, STYLES["string"]), # Single-quoted string, possibly containing escape sequences - (r"'[^'\\]*(\\.[^'\\]*)*'", 0, STYLES['string']), - + (r"'[^'\\]*(\\.[^'\\]*)*'", 0, STYLES["string"]), # From '#' until a newline - (r'#[^\n]*', 0, STYLES['comment']), + (r"#[^\n]*", 0, STYLES["comment"]), ] # Build a QRegularExpression for each pattern - self.rules = [(QtCore.QRegularExpression(pat), index, fmt) - for (pat, index, fmt) in rules] + self.rules = [ + (QtCore.QRegularExpression(pat), index, fmt) for (pat, index, fmt) in rules + ] def highlightBlock(self, text): - """Apply syntax highlighting to the given block of text. - """ + """Apply syntax highlighting to the given block of text.""" self.tripleQuoutesWithinStrings = [] # Do other syntax formatting for expression, nth, format in self.rules: @@ -121,7 +169,10 @@ def highlightBlock(self, text): # if there is a string we check # if there are some triple quotes within the string # they will be ignored if they are matched again - if expression.pattern() in [r'"[^"\\]*(\\.[^"\\]*)*"', r"'[^'\\]*(\\.[^'\\]*)*'"]: + if expression.pattern() in [ + r'"[^"\\]*(\\.[^"\\]*)*"', + r"'[^'\\]*(\\.[^'\\]*)*'", + ]: innerIndex = self.tri_single[0].indexIn(text, index + 1) if innerIndex == -1: innerIndex = self.tri_double[0].indexIn(text, index + 1) @@ -191,4 +242,4 @@ def match_multiline(self, text, delimiter, in_state, style): if self.currentBlockState() == in_state: return True else: - return False \ No newline at end of file + return False diff --git a/cellacdc/test_segm_model.py b/cellacdc/test_segm_model.py index 6cc950975..02e9bf9f1 100755 --- a/cellacdc/test_segm_model.py +++ b/cellacdc/test_segm_model.py @@ -8,7 +8,7 @@ from importlib import import_module from cellacdc._run import _setup_app -from cellacdc import apps, myutils, widgets, data, core, load +from cellacdc import apps, utils, widgets, data, core, load from cellacdc import prompts import skimage.color @@ -16,7 +16,8 @@ try: import pytest - pytest.skip('skipping this test since it is gui based', allow_module_level=True) + + pytest.skip("skipping this test since it is gui based", allow_module_level=True) except Exception as e: pass @@ -29,40 +30,39 @@ # test_data = data.BABYtestData() # test_data = data.YeastMitoSnapshotData() -app, splashScreen = _setup_app(splashscreen=True) +app, splashScreen = _setup_app(splashscreen=True) splashScreen.close() initialWindow = apps.TestSegmModelInitalDialog() initialWindow.exec_() if initialWindow.cancel: - print('Execution cancelled.') + print("Execution cancelled.") exit() start_frame_idx = initialWindow.start_frame_n if start_frame_idx is not None: start_frame_idx -= 1 - -stop_frame_n = initialWindow.stop_frame_n + +stop_frame_n = initialWindow.stop_frame_n start_z_slice_idx = initialWindow.start_z_slice_n if start_z_slice_idx is not None: start_z_slice_idx -= 1 - + stop_z_slice_n = initialWindow.stop_z_slice_n is_timelapse = initialWindow.is_timelapse if test_data is None: tif_filepath, _ = qtpy.compat.getopenfilename( - basedir=myutils.getMostRecentPath(), - filters=('Images (*.tif)') + basedir=utils.getMostRecentPath(), filters=("Images (*.tif)") ) if not tif_filepath: - exit('Execution cancelled.') - + exit("Execution cancelled.") + images_path = os.path.dirname(tif_filepath) - basename = os.path.commonprefix(myutils.listdir(images_path)) + basename = os.path.commonprefix(utils.listdir(images_path)) filename, ext = os.path.splitext(os.path.basename(tif_filepath)) - channel = filename[len(basename):] + channel = filename[len(basename) :] posData = load.loadData(tif_filepath, channel) posData.loadImgData() image_data = posData.img_data @@ -71,46 +71,44 @@ posData = test_data.posData() image_data = test_data.image_data() images_path = test_data.images_path - + posData.loadOtherFiles(load_segm_data=False, load_metadata=True) posData.buildPaths() if is_timelapse: - img = image_data[ - start_frame_idx:stop_frame_n, - start_z_slice_idx:stop_z_slice_n - ] + img = image_data[start_frame_idx:stop_frame_n, start_z_slice_idx:stop_z_slice_n] else: img = image_data[start_z_slice_idx:stop_z_slice_n] from cellacdc.plot import imshow + imshow(img) cellacdc_path = os.path.dirname(os.path.abspath(__file__)) -models = myutils.get_list_of_models() +models = utils.get_list_of_models() win = widgets.QDialogListbox( - 'Select model', - 'Select model to use for segmentation: ', + "Select model", + "Select model to use for segmentation: ", models, - multiSelection=False + multiSelection=False, ) win.exec_() if win.cancel: - sys.exit('Execution aborted') + sys.exit("Execution aborted") model_name = win.selectedItemsText[0] -if model_name == 'Automatic thresholding': - model_name = 'thresholding' +if model_name == "Automatic thresholding": + model_name = "thresholding" # Check if model needs to be downloaded downloadWin = apps.downloadModel(model_name, parent=None) downloadWin.download() # Load model as a module -acdcSegment = myutils.import_segment_module(model_name) +acdcSegment = utils.import_segment_module(model_name) # Read all models parameters -init_params, segment_params = myutils.getModelArgSpec(acdcSegment) +init_params, segment_params = utils.getModelArgSpec(acdcSegment) # Prompt user to enter the model parameters try: @@ -119,46 +117,54 @@ url = None out = prompts.init_segm_model_params( - posData, model_name, init_params, segment_params, - help_url=url, qparent=None, init_last_params=True + posData, + model_name, + init_params, + segment_params, + help_url=url, + qparent=None, + init_last_params=True, ) -win = out.get('win') +win = out.get("win") if win.cancel: - exit('Execution cancelled.') + exit("Execution cancelled.") # Initialize model segm_data = None init_kwargs = win.init_kwargs -segm_endname = init_kwargs.pop('segm_endname', None) +segm_endname = init_kwargs.pop("segm_endname", None) if segm_endname is not None: segm_filepath, _ = load.get_path_from_endname(segm_endname, images_path) - segm_data = np.load(segm_filepath)['arr_0'] + segm_data = np.load(segm_filepath)["arr_0"] -model = myutils.init_segm_model(acdcSegment, posData, win.init_kwargs) +model = utils.init_segm_model(acdcSegment, posData, win.init_kwargs) if model is None: - sys.exit('Segmentation model was not initialized correctly!') -is_segment3DT_available = any( - [name=='segment3DT' for name in dir(model)] -) + sys.exit("Segmentation model was not initialized correctly!") +is_segment3DT_available = any([name == "segment3DT" for name in dir(model)]) if img.ndim == 3 and (img.shape[-1] == 3 or img.shape[-1] == 4): img = skimage.color.rgb2gray(img) -print('Input image shape: ', img.shape) -print('Segmentation process started...') +print("Input image shape: ", img.shape) +print("Segmentation process started...") lab = core.segm_model_segment( - model, img, win.model_kwargs, frame_i=start_frame_idx, - preproc_recipe=win.preproc_recipe, posData=posData, + model, + img, + win.model_kwargs, + frame_i=start_frame_idx, + preproc_recipe=win.preproc_recipe, + posData=posData, is_timelapse_model_and_data=is_segment3DT_available and is_timelapse, ) from cellacdc.plot import imshow + imshow( - img, - lab, + img, + lab, window_title=f'Result of segmenting with "{model_name}" model', - axis_titles=['Input image', 'Segmentation result'], + axis_titles=["Input image", "Segmentation result"], annotate_labels_idxs=[1], ) diff --git a/cellacdc/test_tracker.py b/cellacdc/test_tracker.py index 2edc18be1..b54c03913 100644 --- a/cellacdc/test_tracker.py +++ b/cellacdc/test_tracker.py @@ -2,7 +2,7 @@ import sys import numpy as np import skimage.measure -from cellacdc import core, myutils, widgets, load, html_utils +from cellacdc import core, utils, widgets, load, html_utils from cellacdc import data, data_path from cellacdc import transformation from cellacdc.plot import imshow @@ -10,7 +10,8 @@ try: import pytest - pytest.skip('skipping this test since it is gui based', allow_module_level=True) + + pytest.skip("skipping this test since it is gui based", allow_module_level=True) except Exception as e: pass @@ -18,12 +19,12 @@ from cellacdc._run import _setup_app # Ask which model to use --> Test if new model is visible -app, splashScreen = _setup_app(splashscreen=True) +app, splashScreen = _setup_app(splashscreen=True) splashScreen.close() -channel_name = 'SiR_Hoechst' -end_filename_segm = 'segm' # 'segm_test' -START_FRAME = 0 +channel_name = "SiR_Hoechst" +end_filename_segm = "segm" # 'segm_test' +START_FRAME = 0 STOP_FRAME = 10 # PLOT_FRAME = 499 SAVE = False @@ -39,16 +40,15 @@ if test_data is None: tif_filepath, _ = qtpy.compat.getopenfilename( - basedir=myutils.getMostRecentPath(), - filters=('Images (*.tif)') + basedir=utils.getMostRecentPath(), filters=("Images (*.tif)") ) if not tif_filepath: - exit('Execution cancelled.') - + exit("Execution cancelled.") + images_path = os.path.dirname(tif_filepath) - basename = os.path.commonprefix(myutils.listdir(images_path)) + basename = os.path.commonprefix(utils.listdir(images_path)) filename, ext = os.path.splitext(os.path.basename(tif_filepath)) - channel = filename[len(basename):] + channel = filename[len(basename) :] posData = load.loadData(tif_filepath, channel) else: posData = test_data.posData() @@ -56,41 +56,39 @@ posData.loadImgData() posData.loadOtherFiles( - load_segm_data=True, - load_metadata=True, - end_filename_segm=end_filename_segm + load_segm_data=True, load_metadata=True, end_filename_segm=end_filename_segm ) -lab_stack = posData.segm_data[START_FRAME:STOP_FRAME+1] +lab_stack = posData.segm_data[START_FRAME : STOP_FRAME + 1] -imshow(lab_stack, axis_titles=['Before tracking'], annotate_labels_idxs=[0]) +imshow(lab_stack, axis_titles=["Before tracking"], annotate_labels_idxs=[0]) -trackers = myutils.get_list_of_trackers() -txt = html_utils.paragraph(''' +trackers = utils.get_list_of_trackers() +txt = html_utils.paragraph(""" Select the tracker to use

    -''') +""") win = widgets.QDialogListbox( - 'Select tracker', txt, trackers, multiSelection=False, parent=None + "Select tracker", txt, trackers, multiSelection=False, parent=None ) win.exec_() if win.cancel: - sys.exit('Execution aborted') + sys.exit("Execution aborted") trackerName = win.selectedItemsText[0] # Load tracker -tracker, track_params = myutils.init_tracker( +tracker, track_params = utils.init_tracker( posData, trackerName, qparent=None, realTime=REAL_TIME_TRACKER ) if track_params is None: - exit('Execution aborted') + exit("Execution aborted") print(posData.segm_data.shape) -lab_stack = posData.segm_data[START_FRAME:STOP_FRAME+1] +lab_stack = posData.segm_data[START_FRAME : STOP_FRAME + 1] if SCRUMBLE_IDs: # Scrumble IDs last frame - + last_lab = lab_stack[-1] last_rp = skimage.measure.regionprops(lab_stack[-1]) IDs = [obj.label for obj in last_rp] @@ -106,53 +104,52 @@ obj_to_del = last_rp[random_idx] last_lab[obj_to_del.slice][obj_to_del.image] = 0 -print(f'Tracking data with shape {lab_stack.shape}') +print(f"Tracking data with shape {lab_stack.shape}") trackerInputImage = None -if 'image' in track_params: - trackerInputImage = track_params.pop('image')[START_FRAME:STOP_FRAME+1] +if "image" in track_params: + trackerInputImage = track_params.pop("image")[START_FRAME : STOP_FRAME + 1] -if 'image_channel_name' in track_params: - # Store the channel name for the tracker for loading it +if "image_channel_name" in track_params: + # Store the channel name for the tracker for loading it # in case of multiple pos - track_params.pop('image_channel_name') + track_params.pop("image_channel_name") tracked_stack = core.tracker_track( - lab_stack, tracker, track_params, - intensity_img=trackerInputImage, - logger_func=print + lab_stack, tracker, track_params, intensity_img=trackerInputImage, logger_func=print ) -if hasattr(posData, 'acdc_output_csv_path'): +if hasattr(posData, "acdc_output_csv_path"): posData.fromTrackerToAcdcDf(tracker, tracked_stack, save=True) first_untracked_lab = lab_stack[0] uniqueID = max(np.max(lab_stack), np.max(tracked_stack)) + 1 retracked_video = transformation.retrack_based_on_untracked_first_frame( - tracked_stack.copy(), first_untracked_lab, uniqueID=uniqueID + tracked_stack.copy(), first_untracked_lab, uniqueID=uniqueID ) if SAVE: try: io.savez_compressed( - posData.segm_npz_path.replace('segm', 'segm_tracked'), - tracked_stack + posData.segm_npz_path.replace("segm", "segm_tracked"), tracked_stack ) except Exception as e: - import pdb; pdb.set_trace() + import pdb + + pdb.set_trace() imshow( - posData.loadChannelData(''), + posData.loadChannelData(""), tracked_stack, retracked_video, lab_stack, axis_titles=[ - 'Intensity channel', - 'After tracking', - 'After re-tracking first frame', - 'Before tracking' + "Intensity channel", + "After tracking", + "After re-tracking first frame", + "Before tracking", ], - annotate_labels_idxs=[1, 2, 3] + annotate_labels_idxs=[1, 2, 3], ) diff --git a/cellacdc/tools/__init__.py b/cellacdc/tools/__init__.py new file mode 100755 index 000000000..e69de29bb diff --git a/cellacdc/utils/acdcToSymDiv.py b/cellacdc/tools/acdcToSymDiv.py similarity index 74% rename from cellacdc/utils/acdcToSymDiv.py rename to cellacdc/tools/acdcToSymDiv.py index 24f514f2f..0d2067acf 100644 --- a/cellacdc/utils/acdcToSymDiv.py +++ b/cellacdc/tools/acdcToSymDiv.py @@ -7,24 +7,20 @@ from tqdm import tqdm from qtpy.QtCore import Signal, QThread -from qtpy.QtWidgets import ( - QDialog, QVBoxLayout, QHBoxLayout, QLabel, QStyle -) +from qtpy.QtWidgets import QDialog, QVBoxLayout, QHBoxLayout, QLabel, QStyle + +from .. import widgets, apps, workers, html_utils, utils, gui, load, printl -from .. import ( - widgets, apps, workers, html_utils, myutils, - gui, load, printl -) class AcdcToSymDivUtil(QDialog): def __init__(self, expPaths, app, parent=None): super().__init__(parent) - self.setWindowTitle('Utility to add symmetric division table') + self.setWindowTitle("Utility to add symmetric division table") self.parent = parent - logger, logs_path, log_path, log_filename = myutils.setupLogger( - module='utils.AcdcToSymDiv' + logger, logs_path, log_path, log_filename = utils.setupLogger( + module="utils.AcdcToSymDiv" ) self.logger = logger self.log_path = log_path @@ -41,11 +37,11 @@ def __init__(self, expPaths, app, parent=None): infoLayout = QHBoxLayout() infoTxt = html_utils.paragraph( - 'Computing lineage tree table for symmetrically dividing cells...' + "Computing lineage tree table for symmetrically dividing cells..." ) iconLabel = QLabel(self) - standardIcon = getattr(QStyle, 'SP_MessageBoxInformation') + standardIcon = getattr(QStyle, "SP_MessageBoxInformation") icon = self.style().standardIcon(standardIcon) pixmap = icon.pixmap(60, 60) iconLabel.setPixmap(pixmap) @@ -54,7 +50,7 @@ def __init__(self, expPaths, app, parent=None): infoLayout.addWidget(QLabel(infoTxt)) buttonsLayout = QHBoxLayout() - cancelButton = widgets.cancelPushButton('Cancel') + cancelButton = widgets.cancelPushButton("Cancel") buttonsLayout.addStretch(1) buttonsLayout.addWidget(cancelButton) @@ -72,8 +68,9 @@ def showEvent(self, event): def runWorker(self): self.progressWin = apps.QDialogWorkerProgress( - title='Building lineage tree table', parent=self, - pbarDesc='Building lineage tree table...' + title="Building lineage tree table", + parent=self, + pbarDesc="Building lineage tree table...", ) self.progressWin.sigClosed.connect(self.progressWinClosed) self.progressWin.show(self.app) @@ -103,35 +100,35 @@ def workerInitProgressbar(self, totalIter): if totalIter == 1: totalIter = 0 self.progressWin.mainPbar.setMaximum(totalIter) - + def workerUpdateProgressbar(self, step): self.progressWin.mainPbar.update(step) - + def workerUpdatePbarDesc(self, desc): self.progressWin.progressLabel.setText(desc) - + def warnPermissionError(self, traceback_str, path): err_msg = html_utils.paragraph( - 'The file below is open in another app ' - '(Excel maybe?).

    ' - f'{path}

    ' + "The file below is open in another app " + "(Excel maybe?).

    " + f"{path}

    " 'Close file and then press "Ok".' ) msg = widgets.myMessageBox(wrapText=False) msg.setDetailedText(traceback_str) - msg.warning(self, 'Permission error', err_msg) + msg.warning(self, "Permission error", err_msg) self.worker.waitCond.wakeAll() - + def selectSegmFileLoadData(self, exp_path, pos_foldernames): # Get end name of every existing segmentation file existingSegmEndNames = set() for p, pos in enumerate(pos_foldernames): pos_path = os.path.join(exp_path, pos) - images_path = os.path.join(pos_path, 'Images') - basename, chNames = myutils.getBasenameAndChNames(images_path) + images_path = os.path.join(pos_path, "Images") + basename, chNames = utils.getBasenameAndChNames(images_path) # Use first found channel, it doesn't matter for metrics for chName in chNames: - filePath = myutils.getChannelFilePath(images_path, chName) + filePath = utils.getChannelFilePath(images_path, chName) if filePath: break else: @@ -142,9 +139,7 @@ def selectSegmFileLoadData(self, exp_path, pos_foldernames): _posData = load.loadData(filePath, chName) _posData.getBasenameAndChNames() segm_files = load.get_segm_files(_posData.images_path) - _existingEndnames = load.get_endnames( - _posData.basename, segm_files - ) + _existingEndnames = load.get_endnames(_posData.basename, segm_files) existingSegmEndNames.update(_existingEndnames) if len(existingSegmEndNames) == 1: @@ -152,26 +147,24 @@ def selectSegmFileLoadData(self, exp_path, pos_foldernames): self.worker.waitCond.wakeAll() return - win = apps.SelectSegmFileDialog( - existingSegmEndNames, exp_path, parent=self - ) + win = apps.SelectSegmFileDialog(existingSegmEndNames, exp_path, parent=self) win.exec_() self.endFilenameSegm = win.selectedItemText self.worker.abort = win.cancel self.worker.waitCond.wakeAll() - + def addRegionPropsErrors(self, traceback_format, error_message): - self.logger.info('') - print('====================================') + self.logger.info("") + print("====================================") self.logger.info(traceback_format) - print('====================================') + print("====================================") self.worker.regionPropsErrors[error_message] = traceback_format - + def addCombinedMetricsError(self, traceback_format, func_name): - self.logger.info('') - print('====================================') + self.logger.info("") + print("====================================") self.logger.info(traceback_format) - print('====================================') + print("====================================") self.worker.customMetricsErrors[func_name] = traceback_format def skipEvent(self, dummy): @@ -189,17 +182,16 @@ def abortCallback(self): self.worker.abort = True else: self.close() - + def warnMissingAnnot(self, missingAnnotErrors): win = apps.ComputeMetricsErrorsDialog( - missingAnnotErrors, self.logs_path, log_type='missing_annot', - parent=self + missingAnnotErrors, self.logs_path, log_type="missing_annot", parent=self ) win.exec_() - + def warnErrors(self, errors): win = apps.ComputeMetricsErrorsDialog( - errors, self.logs_path, log_type='generic', parent=self + errors, self.logs_path, log_type="generic", parent=self ) win.exec_() @@ -208,39 +200,39 @@ def workerCritical(self, error): raise error except: traceback_str = traceback.format_exc() - print('='*20) + print("=" * 20) self.worker.logger.log(traceback_str) - print('='*20) + print("=" * 20) def workerFinished(self, worker): if self.progressWin is not None: self.progressWin.workerFinished = True self.progressWin.close() - + if worker.abort: - txt = 'Adding lineage tree table ABORTED.' + txt = "Adding lineage tree table ABORTED." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) - msg.warning(self, 'Process aborted', html_utils.paragraph(txt)) + msg.warning(self, "Process aborted", html_utils.paragraph(txt)) elif worker.errors or worker.missingAnnotErrors: if worker.errors: self.warnErrors(worker.errors) else: self.warnMissingAnnot(worker.missingAnnotErrors) - txt = 'Adding lineage tree table completed WITH ERRORS.' + txt = "Adding lineage tree table completed WITH ERRORS." msg = widgets.myMessageBox(wrapText=False, showCentered=False) - msg.warning(self, 'Process warning', html_utils.paragraph(txt)) + msg.warning(self, "Process warning", html_utils.paragraph(txt)) else: - txt = 'Adding lineage tree table completed.' + txt = "Adding lineage tree table completed." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) - msg.information(self, 'Process completed', html_utils.paragraph(txt)) + msg.information(self, "Process completed", html_utils.paragraph(txt)) self.worker = None self.progressWin = None self.close() - def workerProgress(self, text, loggerLevel='INFO'): + def workerProgress(self, text, loggerLevel="INFO"): if self.progressWin is not None: self.progressWin.logConsole.append(text) self.logger.log(getattr(logging, loggerLevel), text) diff --git a/cellacdc/utils/align.py b/cellacdc/tools/align.py similarity index 64% rename from cellacdc/utils/align.py rename to cellacdc/tools/align.py index 5938367a6..9cb5f3ee5 100755 --- a/cellacdc/utils/align.py +++ b/cellacdc/tools/align.py @@ -2,32 +2,37 @@ from qtpy.QtWidgets import QFileDialog -from .. import apps, myutils, workers, widgets, html_utils +from .. import apps, utils, workers, widgets, html_utils from .base import NewThreadMultipleExpBaseUtil + class alignWin(NewThreadMultipleExpBaseUtil): def __init__( - self, expPaths, app, title: str, infoText: str, - progressDialogueTitle: str, parent=None - ): - module = myutils.get_module_name(__file__) + self, + expPaths, + app, + title: str, + infoText: str, + progressDialogueTitle: str, + parent=None, + ): + module = utils.get_module_name(__file__) super().__init__( - expPaths, app, title, module, infoText, progressDialogueTitle, - parent=parent + expPaths, app, title, module, infoText, progressDialogueTitle, parent=parent ) self.expPaths = expPaths - + def runWorker(self): self.worker = workers.AlignWorker(self) self.worker.sigAskUseSavedShifts.connect(self.askUseSavedShifts) self.worker.sigAskSelectChannel.connect(self.askSelectChannel) self.worker.sigAborted.connect(self.workerAborted) super().runWorker(self.worker) - + def showEvent(self, event): self.runWorker() - + def askUseSavedShifts(self, exp_path, basename): txt = html_utils.paragraph(f""" Some or all the Positions in this experiment folder

    @@ -37,34 +42,39 @@ def askUseSavedShifts(self, exp_path, basename): """) msg = widgets.myMessageBox(wrapText=False, showCentered=False) _, useShiftsButton, ignoreShiftsButton, revertButton = msg.question( - self, 'Select how saved shifts', txt, + self, + "Select how saved shifts", + txt, buttonsTexts=( - 'Cancel', 'Apply alignment from saved shifts', - 'Ignore saved shifts and compute alignment', - 'Revert alignment using saved shifts' - ) + "Cancel", + "Apply alignment from saved shifts", + "Ignore saved shifts and compute alignment", + "Revert alignment using saved shifts", + ), ) if msg.cancel: self.worker.abort = True self.worker.waitCond.wakeAll() - + self.worker.revertedAlignEndname = None if msg.clickedButton == useShiftsButton: - savedShiftsHow = 'use_saved_shifts' + savedShiftsHow = "use_saved_shifts" elif msg.clickedButton == ignoreShiftsButton: - savedShiftsHow = 'ignore_saved_shifts' + savedShiftsHow = "ignore_saved_shifts" elif msg.clickedButton == revertButton: - savedShiftsHow = 'rever_alignment' + savedShiftsHow = "rever_alignment" txt = html_utils.paragraph(f""" How do you want to save the image file with reverted alignment? """) msg = widgets.myMessageBox(wrapText=False, showCentered=False) - overWriteButton = widgets.savePushButton('Overwrite existing file') - saveAsButton = widgets.newFilePushButton('Save as new file...') + overWriteButton = widgets.savePushButton("Overwrite existing file") + saveAsButton = widgets.newFilePushButton("Save as new file...") _, overWriteButton, saveAsButton = msg.question( - self, 'Select how saved shifts', txt, - buttonsTexts=('Cancel', overWriteButton, saveAsButton), - showDialog=False + self, + "Select how saved shifts", + txt, + buttonsTexts=("Cancel", overWriteButton, saveAsButton), + showDialog=False, ) saveAsButton.clicked.disconnect() saveAsCallback = partial(self.askAppendedName, basename, msg) @@ -73,16 +83,19 @@ def askUseSavedShifts(self, exp_path, basename): if msg.cancel: self.worker.abort = True self.worker.waitCond.wakeAll() - + self.worker.savedShiftsHow = savedShiftsHow self.worker.waitCond.wakeAll() - + def askAppendedName(self, basename, parent): win = apps.filenameDialog( - ext='.tif', title='Reverted alignment data filename', - hintText='Insert a text to append to the filename', - parent=self, basename=basename, allowEmpty=False + ext=".tif", + title="Reverted alignment data filename", + hintText="Insert a text to append to the filename", + parent=self, + basename=basename, + allowEmpty=False, ) win.exec_() if win.cancel: @@ -90,33 +103,36 @@ def askAppendedName(self, basename, parent): self.worker.revertedAlignEndname = win.entryText parent.cancel = False parent.close() - + def askSelectChannel(self, channels): selectChannelWin = apps.QDialogCombobox( - 'Select channel', channels, 'Select reference channel for the aligment', - CbLabel='Select channel: ', parent=self + "Select channel", + channels, + "Select reference channel for the aligment", + CbLabel="Select channel: ", + parent=self, ) selectChannelWin.exec_() if selectChannelWin.cancel: self.worker.abort = True - self.worker.waitCond.wakeAll() - + self.worker.waitCond.wakeAll() + self.worker.chName = selectChannelWin.selectedItemText self.worker.waitCond.wakeAll() def workerAborted(self): self.workerFinished(None, aborted=True) - + def workerFinished(self, worker, aborted=False): if aborted: - txt = 'Aligning frames process CANCELLED.' + txt = "Aligning frames process CANCELLED." else: - txt = 'Aligning frames process completed.' + txt = "Aligning frames process completed." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) if aborted: - msg.warning(self, 'Process completed', html_utils.paragraph(txt)) + msg.warning(self, "Process completed", html_utils.paragraph(txt)) else: - msg.information(self, 'Process completed', html_utils.paragraph(txt)) + msg.information(self, "Process completed", html_utils.paragraph(txt)) super().workerFinished(worker) - self.close() \ No newline at end of file + self.close() diff --git a/cellacdc/utils/applyTrackFromTable.py b/cellacdc/tools/applyTrackFromTable.py similarity index 53% rename from cellacdc/utils/applyTrackFromTable.py rename to cellacdc/tools/applyTrackFromTable.py index 47997efa8..49bffed78 100644 --- a/cellacdc/utils/applyTrackFromTable.py +++ b/cellacdc/tools/applyTrackFromTable.py @@ -4,44 +4,44 @@ import pandas as pd from .. import exception_handler -from .. import myutils, apps, widgets, html_utils, printl, workers -from ..utils import base +from .. import utils, apps, widgets, html_utils, printl, workers +from . import base from qtpy.QtWidgets import QFileDialog + class ApplyTrackingInfoFromTableUtil(base.MainThreadSinglePosUtilBase): def __init__( - self, app, title: str, infoText: str, parent=None, - callbackOnFinished=None - ): - module = myutils.get_module_name(__file__) - super().__init__( - app, title, module, infoText, parent - ) + self, app, title: str, infoText: str, parent=None, callbackOnFinished=None + ): + module = utils.get_module_name(__file__) + super().__init__(app, title, module, infoText, parent) self.sigClose.connect(self.close) self.callbackOnFinished = callbackOnFinished @exception_handler def run(self, posPath): - self.logger.info('Reading exisiting segmentation file names...') + self.logger.info("Reading exisiting segmentation file names...") endFilenameSegm = self.selectSegmFileLoadData(posPath) if not endFilenameSegm: return False - + msg = widgets.myMessageBox(showCentered=False, wrapText=False) txt = html_utils.paragraph( 'After clicking "Ok" you will be asked to select the table ' - 'file (.csv) containing the tracking information.' + "file
    (.csv) containing the tracking information." ) - msg.information(self, 'Instructions', txt) + msg.information(self, "Instructions", txt) if msg.cancel: return False - + csvPath = QFileDialog.getOpenFileName( - self, 'Select table with tracking info', posPath, - "CSV files (*.csv);;All Files (*)" + self, + "Select table with tracking info", + posPath, + "CSV files (*.csv);;All Files (*)", )[0] if not csvPath: return False @@ -55,32 +55,32 @@ def run(self, posPath): win.exec_() if win.cancel: return False - + columnsInfo = { - 'frameIndexCol': win.frameIndexCol, - 'trackIDsCol': win.trackedIDsCol, - 'maskIDsCol': win.maskIDsCol, - 'xCentroidCol': win.xCentroidCol, - 'yCentroidCol': win.yCentroidCol, - 'parentIDcol': win.parentIDcol, - 'isFirstFrameOne': win.isFirstFrameOne, - 'deleteUntrackedIDs': win.deleteUntrackedIDs + "frameIndexCol": win.frameIndexCol, + "trackIDsCol": win.trackedIDsCol, + "maskIDsCol": win.maskIDsCol, + "xCentroidCol": win.xCentroidCol, + "yCentroidCol": win.yCentroidCol, + "parentIDcol": win.parentIDcol, + "isFirstFrameOne": win.isFirstFrameOne, + "deleteUntrackedIDs": win.deleteUntrackedIDs, } - - imagesPath = os.path.join(posPath, 'Images') + + imagesPath = os.path.join(posPath, "Images") segmFilename = [ - f for f in myutils.listdir(imagesPath) - if f.endswith(f'{endFilenameSegm}.npz') + f + for f in utils.listdir(imagesPath) + if f.endswith(f"{endFilenameSegm}.npz") ][0] basename = os.path.splitext(segmFilename)[0] - overWriteButton = widgets.savePushButton( - 'Overwrite existing segmentation file' - ) + overWriteButton = widgets.savePushButton("Overwrite existing segmentation file") win = apps.filenameDialog( - basename=f'{basename}_', - hintText='Insert a filename for the tracked masks file:', - allowEmpty=False, defaultEntry='tracked', - additionalButtons=(overWriteButton, ) + basename=f"{basename}_", + hintText="Insert a filename for the tracked masks file:", + allowEmpty=False, + defaultEntry="tracked", + additionalButtons=(overWriteButton,), ) overWriteButton.clicked.connect(partial(self.overWriteClicked, win)) win.exec_() @@ -94,9 +94,8 @@ def run(self, posPath): self.worker.signals.finished.connect(self.callbackOnFinished) self.runWorker(self.worker) return True - + def overWriteClicked(self, win): win.cancel = False - win.filename = '' + win.filename = "" win.close() - diff --git a/cellacdc/utils/applyTrackFromTrackMateXML.py b/cellacdc/tools/applyTrackFromTrackMateXML.py similarity index 58% rename from cellacdc/utils/applyTrackFromTrackMateXML.py rename to cellacdc/tools/applyTrackFromTrackMateXML.py index c1b96fd6a..ed91d1daf 100644 --- a/cellacdc/utils/applyTrackFromTrackMateXML.py +++ b/cellacdc/tools/applyTrackFromTrackMateXML.py @@ -4,54 +4,54 @@ import pandas as pd from .. import exception_handler -from .. import myutils, apps, widgets, html_utils, printl, workers +from .. import utils, apps, widgets, html_utils, printl, workers from .. import transformation, load -from ..utils import base +from . import base from qtpy.QtWidgets import QFileDialog + class ApplyTrackingInfoFromTrackMateUtil(base.MainThreadSinglePosUtilBase): def __init__( - self, app, title: str, infoText: str, parent=None, - callbackOnFinished=None - ): - module = myutils.get_module_name(__file__) - super().__init__( - app, title, module, infoText, parent - ) + self, app, title: str, infoText: str, parent=None, callbackOnFinished=None + ): + module = utils.get_module_name(__file__) + super().__init__(app, title, module, infoText, parent) self.sigClose.connect(self.close) self.callbackOnFinished = callbackOnFinished @exception_handler def run(self, posPath): - self.logger.info('Reading exisiting segmentation file names...') + self.logger.info("Reading exisiting segmentation file names...") endFilenameSegm = self.selectSegmFileLoadData(posPath) if not endFilenameSegm: return False - + msg = widgets.myMessageBox(showCentered=False, wrapText=False) txt = html_utils.paragraph( 'After clicking "Ok" you will be asked to select the XML ' - 'file (.csv) containing the tracking information.' + "file (.csv) containing the tracking information." ) - msg.information(self, 'Instructions', txt) + msg.information(self, "Instructions", txt) if msg.cancel: return False - + xmlPath = QFileDialog.getOpenFileName( - self, 'Select table with tracking info', posPath, - "XML files (*.xml);;All Files (*)" + self, + "Select table with tracking info", + posPath, + "XML files (*.xml);;All Files (*)", )[0] if not xmlPath: return False xmlName = os.path.basename(xmlPath) self.logger.info(f'Parsing XML file "{xmlName}"...') - + df = transformation.trackmate_xml_to_df(xmlPath) - csvName = xmlName.replace('.xml', '.csv') + csvName = xmlName.replace(".xml", ".csv") csvPath = load.save_df_to_csv_temp_path(df, csvName, index=False) deleteUntrackedIDs, proceed = self.askDeleteUntrackedIDs() @@ -61,32 +61,32 @@ def run(self, posPath): # win.exec_() # if win.cancel: # return False - + columnsInfo = { - 'frameIndexCol': 'frame_i', - 'trackIDsCol': 'ID', - 'maskIDsCol': 'None', - 'xCentroidCol': 'x', - 'yCentroidCol': 'y', - 'parentIDcol': 'None', - 'isFirstFrameOne': False, - 'deleteUntrackedIDs': deleteUntrackedIDs + "frameIndexCol": "frame_i", + "trackIDsCol": "ID", + "maskIDsCol": "None", + "xCentroidCol": "x", + "yCentroidCol": "y", + "parentIDcol": "None", + "isFirstFrameOne": False, + "deleteUntrackedIDs": deleteUntrackedIDs, } - - imagesPath = os.path.join(posPath, 'Images') + + imagesPath = os.path.join(posPath, "Images") segmFilename = [ - f for f in myutils.listdir(imagesPath) - if f.endswith(f'{endFilenameSegm}.npz') + f + for f in utils.listdir(imagesPath) + if f.endswith(f"{endFilenameSegm}.npz") ][0] basename = os.path.splitext(segmFilename)[0] - overWriteButton = widgets.savePushButton( - 'Overwrite existing segmentation file' - ) + overWriteButton = widgets.savePushButton("Overwrite existing segmentation file") win = apps.filenameDialog( - basename=f'{basename}_', - hintText='Insert a filename for the tracked masks file:', - allowEmpty=False, defaultEntry='tracked', - additionalButtons=(overWriteButton, ) + basename=f"{basename}_", + hintText="Insert a filename for the tracked masks file:", + allowEmpty=False, + defaultEntry="tracked", + additionalButtons=(overWriteButton,), ) overWriteButton.clicked.connect(partial(self.overWriteClicked, win)) win.exec_() @@ -100,20 +100,19 @@ def run(self, posPath): self.worker.signals.finished.connect(self.callbackOnFinished) self.runWorker(self.worker) return True - + def overWriteClicked(self, win): win.cancel = False - win.filename = '' + win.filename = "" win.close() - + def askDeleteUntrackedIDs(self): msg = widgets.myMessageBox(wrapText=False) txt = html_utils.paragraph( - 'Do you want to remove objects that were not tracked?' + "Do you want to remove objects that were not tracked?" ) _, yesButton, noButton = msg.question( - self, 'Delete untracked objects?', txt, - buttonsTexts=('Cancel', 'No', 'Yes') + self, "Delete untracked objects?", txt, buttonsTexts=("Cancel", "No", "Yes") ) if msg.cancel: return False, False diff --git a/cellacdc/utils/base.py b/cellacdc/tools/base.py similarity index 72% rename from cellacdc/utils/base.py rename to cellacdc/tools/base.py index 263c9a7aa..9b9d5edbd 100644 --- a/cellacdc/utils/base.py +++ b/cellacdc/tools/base.py @@ -4,12 +4,10 @@ from natsort import natsorted from qtpy.QtCore import Qt, QThread, QSize -from qtpy.QtWidgets import ( - QDialog, QVBoxLayout, QHBoxLayout, QLabel -) +from qtpy.QtWidgets import QDialog, QVBoxLayout, QHBoxLayout, QLabel from qtpy import QtGui -from .. import exception_handler, myutils, html_utils, workers, widgets +from .. import exception_handler, utils, html_utils, workers, widgets from .. import _critical_exception_gui import os @@ -23,40 +21,57 @@ from qtpy.QtCore import Signal, QThread from qtpy.QtWidgets import ( - QDialog, QVBoxLayout, QHBoxLayout, QLabel, QStyle, QApplication + QDialog, + QVBoxLayout, + QHBoxLayout, + QLabel, + QStyle, + QApplication, ) from .. import ( - widgets, apps, workers, html_utils, myutils, - gui, load, printl, exception_handler + widgets, + apps, + workers, + html_utils, + utils, + gui, + load, + printl, + exception_handler, ) + def log_init_util(logger, expPaths: dict, util_title, util_module): exp_paths_str = pprint.pformat(expPaths, indent=1) - + logger.info(f'Utility title: "{util_title}"') logger.info(f'Utility module: "{util_module}"') - logger.info(f'Selected experiments:\n{exp_paths_str}') - - + logger.info(f"Selected experiments:\n{exp_paths_str}") + + class NewThreadMultipleExpBaseUtil(QDialog): def __init__( - self, expPaths, app: QApplication, title: str, module: str, - infoText: str, progressDialogueTitle: str, parent=None - ): + self, + expPaths, + app: QApplication, + title: str, + module: str, + infoText: str, + progressDialogueTitle: str, + parent=None, + ): super().__init__(parent) self.setWindowTitle(title) self._title = title self._parent = parent - self.progressDialogueTitle = progressDialogueTitle + self.progressDialogueTitle = progressDialogueTitle + + logger, logs_path, log_path, log_filename = utils.setupLogger(module=module) - logger, logs_path, log_path, log_filename = myutils.setupLogger( - module=module - ) - log_init_util(logger, expPaths, title, module) - + self.logger = logger self.log_path = log_path self.log_filename = log_filename @@ -75,7 +90,7 @@ def __init__( infoTxt = html_utils.paragraph(infoText) iconLabel = QLabel(self) - standardIcon = getattr(QStyle, 'SP_MessageBoxInformation') + standardIcon = getattr(QStyle, "SP_MessageBoxInformation") icon = self.style().standardIcon(standardIcon) pixmap = icon.pixmap(60, 60) iconLabel.setPixmap(pixmap) @@ -84,7 +99,7 @@ def __init__( infoLayout.addWidget(QLabel(infoTxt)) buttonsLayout = QHBoxLayout() - cancelButton = widgets.cancelPushButton('Cancel') + cancelButton = widgets.cancelPushButton("Cancel") buttonsLayout.addStretch(1) buttonsLayout.addWidget(cancelButton) @@ -99,8 +114,9 @@ def __init__( def runWorker(self, worker): self.progressWin = apps.QDialogWorkerProgress( - title=self.progressDialogueTitle, parent=self, - pbarDesc=f'{self.progressDialogueTitle}...' + title=self.progressDialogueTitle, + parent=self, + pbarDesc=f"{self.progressDialogueTitle}...", ) self.progressWin.sigClosed.connect(self.progressWinClosed) self.progressWin.show(self.app) @@ -115,33 +131,25 @@ def runWorker(self, worker): self.worker.signals.progress.connect(self.workerProgress) self.worker.signals.critical.connect(self.workerCritical) - self.worker.signals.sigSelectSegmFiles.connect( - self.selectSegmFileLoadData - ) + self.worker.signals.sigSelectSegmFiles.connect(self.selectSegmFileLoadData) self.worker.signals.sigSelectFilesWithText.connect( self.selectFileFromFilesWithText ) self.worker.signals.sigSelectAcdcOutputFiles.connect( self.selectAcdcOutputTables - ) - self.worker.signals.sigSelectSpotmaxRun.connect( - self.selectSpotmaxRun - ) - self.worker.signals.sigSelectFile.connect( - self.selectFile - ) + ) + self.worker.signals.sigSelectSpotmaxRun.connect(self.selectSpotmaxRun) + self.worker.signals.sigSelectFile.connect(self.selectFile) self.worker.signals.sigPermissionError.connect(self.warnPermissionError) self.worker.signals.initProgressBar.connect(self.workerInitProgressbar) self.worker.signals.sigInitInnerPbar.connect(self.workerInitInnerPbar) self.worker.signals.progressBar.connect(self.workerUpdateProgressbar) - self.worker.signals.sigUpdateInnerPbar.connect( - self.workerUpdateInnerPbar - ) + self.worker.signals.sigUpdateInnerPbar.connect(self.workerUpdateInnerPbar) self.worker.signals.sigUpdatePbarDesc.connect(self.workerUpdatePbarDesc) self.thread.started.connect(self.worker.run) self.thread.start() - + def workerInitInnerPbar(self, totalIter): if totalIter <= 1: self.progressWin.innerPbar.hide() @@ -155,40 +163,39 @@ def workerInitProgressbar(self, totalIter): if totalIter == 1: totalIter = 0 self.progressWin.mainPbar.setMaximum(totalIter) - + def workerUpdateInnerPbar(self, step): self.progressWin.innerPbar.update(step) - + def workerUpdateProgressbar(self, step): self.progressWin.mainPbar.update(step) - + def workerUpdatePbarDesc(self, desc): self.progressWin.progressLabel.setText(desc) - + def warnPermissionError(self, traceback_str, path): err_msg = html_utils.paragraph( - 'The file below is open in another app ' - '(Excel maybe?).

    ' - f'{path}

    ' + "The file below is open in another app " + "(Excel maybe?).

    " + f"{path}

    " 'Close file and then press "Ok".' ) msg = widgets.myMessageBox(wrapText=False) msg.setDetailedText(traceback_str) - msg.warning(self, 'Permission error', err_msg) + msg.warning(self, "Permission error", err_msg) self.worker.waitCond.wakeAll() - + def selectAcdcOutputTables( - self, exp_path, pos_foldernames, infoText, allowSingleSelection, - multiSelection - ): + self, exp_path, pos_foldernames, infoText, allowSingleSelection, multiSelection + ): existingAcdcOutputEndnames = set() for p, pos in enumerate(pos_foldernames): pos_path = os.path.join(exp_path, pos) - images_path = os.path.join(pos_path, 'Images') - basename, chNames = myutils.getBasenameAndChNames(images_path) + images_path = os.path.join(pos_path, "Images") + basename, chNames = utils.getBasenameAndChNames(images_path) # Use first found channel, it doesn't matter for basename for chName in chNames: - filePath = myutils.getChannelFilePath(images_path, chName) + filePath = utils.getChannelFilePath(images_path, chName) if filePath: break else: @@ -205,7 +212,7 @@ def selectAcdcOutputTables( _posData.basename, acdc_output_files ) existingAcdcOutputEndnames.update(acdc_output_endnames) - + self.existingAcdcOutputEndnames = list(existingAcdcOutputEndnames) if len(self.existingAcdcOutputEndnames) == 1: @@ -213,35 +220,37 @@ def selectAcdcOutputTables( self.selectedAcdcOutputEndnames = self.existingAcdcOutputEndnames self.worker.waitCond.wakeAll() return - + if multiSelection: selectWindow = apps.OrderableListWidgetDialog( - self.existingAcdcOutputEndnames, - title='Select acdc_output files', - infoTxt=( - 'Select acdc_output tables and choose a table number (optional)

    ' - 'Ctrl+Click to select multiple items
    ' - 'Shift+Click to select a range of items
    ' - ), - helpText=( - 'The table number is useful to ensure that you can load the ' - 'same exact equations you used in a previous sessions.

    ' - 'Cell-ACDC will automatically save the equations you enter. ' - 'They will be saved in a file ending with ' - '_equations_appended_name.ini
    and each table will ' - 'be numbered with the number you enter now.

    ' - 'When you reopen the equations dialogue you can select to load ' - 'equations from a saved .ini file, however,
    only the equations that ' - 'used the table ending with the same name you select now
    ' - 'AND same number can be loaded
    .' + self.existingAcdcOutputEndnames, + title="Select acdc_output files", + infoTxt=( + "Select acdc_output tables and choose a table number (optional)

    " + "Ctrl+Click to select multiple items
    " + "Shift+Click to select a range of items
    " + ), + helpText=( + "The table number is useful to ensure that you can load the " + "same exact equations you used in a previous sessions.

    " + "Cell-ACDC will automatically save the equations you enter. " + "They will be saved in a file ending with " + "_equations_appended_name.ini
    and each table will " + "be numbered with the number you enter now.

    " + "When you reopen the equations dialogue you can select to load " + "equations from a saved .ini file, however,
    only the equations that " + "used the table ending with the same name you select now
    " + "AND same number can be loaded
    ." + ), ) - ) else: selectWindow = widgets.QDialogListbox( - 'Select acdc_output files', - f'Select acdc_output files{infoText}\n', - self.existingAcdcOutputEndnames, multiSelection=multiSelection, - parent=self, allowSingleSelection=allowSingleSelection + "Select acdc_output files", + f"Select acdc_output files{infoText}\n", + self.existingAcdcOutputEndnames, + multiSelection=multiSelection, + parent=self, + allowSingleSelection=allowSingleSelection, ) selectWindow.exec_() self.worker.abort = selectWindow.cancel @@ -249,37 +258,42 @@ def selectAcdcOutputTables( self.worker.waitCond.wakeAll() def selectSpotmaxRun( - self, exp_path, pos_foldernames, all_runs, infoText, - allowSingleSelection, multiSelection - ): - items = natsorted([f'{run}_...{desc}' for run, desc in all_runs]) + self, + exp_path, + pos_foldernames, + all_runs, + infoText, + allowSingleSelection, + multiSelection, + ): + items = natsorted([f"{run}_...{desc}" for run, desc in all_runs]) if len(items) == 1: self.selectedSpotmaxRuns = items self.worker.waitCond.wakeAll() return - + selectWindow = widgets.QDialogListbox( - 'Select spotmax run(s)', - f'Select one or more spotmax runs{infoText}\n', - items, multiSelection=multiSelection, - parent=self, allowSingleSelection=allowSingleSelection + "Select spotmax run(s)", + f"Select one or more spotmax runs{infoText}\n", + items, + multiSelection=multiSelection, + parent=self, + allowSingleSelection=allowSingleSelection, ) selectWindow.exec_() if selectWindow.cancel: self.worker.abort = True self.worker.waitCond.wakeAll() return - + self.selectedSpotmaxRuns = selectWindow.selectedItemsText self.worker.waitCond.wakeAll() - + def selectFile(self, start_dir, caption, filters): from qtpy.compat import getopenfilename + filepath = getopenfilename( - parent=self, - caption=caption, - basedir=start_dir, - filters=filters + parent=self, caption=caption, basedir=start_dir, filters=filters )[0] if not filepath: self.worker.abort = True @@ -288,19 +302,17 @@ def selectFile(self, start_dir, caption, filters): self.selectedFilepath = filepath self.worker.waitCond.wakeAll() - - def _selectFileFromFilesWithText( - self, exp_path, pos_foldernames, with_text, ext - ): + + def _selectFileFromFilesWithText(self, exp_path, pos_foldernames, with_text, ext): # Get end name of every existing segmentation file existingEndNames = set() for p, pos in enumerate(pos_foldernames): pos_path = os.path.join(exp_path, pos) - images_path = os.path.join(pos_path, 'Images') - basename, chNames = myutils.getBasenameAndChNames(images_path) + images_path = os.path.join(pos_path, "Images") + basename, chNames = utils.getBasenameAndChNames(images_path) # Use first found channel, it doesn't matter for metrics for chName in chNames: - filePath = myutils.getChannelFilePath(images_path, chName) + filePath = utils.getChannelFilePath(images_path, chName) if filePath: break else: @@ -310,74 +322,71 @@ def _selectFileFromFilesWithText( ) _posData = load.loadData(filePath, chName) _posData.getBasenameAndChNames() - if with_text == 'segm': + if with_text == "segm": found_files = load.get_segm_files(_posData.images_path) else: found_files = load.get_files_with( _posData.images_path, with_text, ext=ext ) - _existingEndnames = load.get_endnames( - _posData.basename, found_files - ) + _existingEndnames = load.get_endnames(_posData.basename, found_files) existingEndNames.update(_existingEndnames) if len(existingEndNames) == 1: return existingEndNames, list(existingEndNames)[0], False - if hasattr(self, 'infoText'): + if hasattr(self, "infoText"): infoText = self.infoText else: infoText = None - if with_text == 'segm': - fileType = 'segmentation' - elif with_text == 'imagej_rois': - fileType = 'ImageJ ROIs' + if with_text == "segm": + fileType = "segmentation" + elif with_text == "imagej_rois": + fileType = "ImageJ ROIs" else: - fileType = with_text.split('_') - + fileType = with_text.split("_") + win = apps.SelectSegmFileDialog( - existingEndNames, exp_path, parent=self, infoText=infoText, - fileType=fileType + existingEndNames, + exp_path, + parent=self, + infoText=infoText, + fileType=fileType, ) win.exec_() return existingEndNames, win.selectedItemText, win.cancel - - def selectFileFromFilesWithText( - self, exp_path, pos_foldernames, with_text, ext - ): - + + def selectFileFromFilesWithText(self, exp_path, pos_foldernames, with_text, ext): + out = self._selectFileFromFilesWithText( exp_path, pos_foldernames, with_text, ext ) existingEndNamesWithText, endFilenameWithText, cancel = out - + self.existingEndNamesWithText = list(existingEndNamesWithText) - self.endFilenameWithText = endFilenameWithText + self.endFilenameWithText = endFilenameWithText self.worker.abort = cancel self.worker.waitCond.wakeAll() - + def selectSegmFileLoadData(self, exp_path, pos_foldernames): - out = self._selectFileFromFilesWithText( - exp_path, pos_foldernames, 'segm', None - ) + out = self._selectFileFromFilesWithText(exp_path, pos_foldernames, "segm", None) existingSegmEndNames, endFilenameSegm, cancel = out - + self.existingSegmEndNames = list(existingSegmEndNames) self.endFilenameSegm = endFilenameSegm - + self.worker.abort = cancel self.worker.waitCond.wakeAll() - + # # Get end name of every existing segmentation file # existingSegmEndNames = set() # for p, pos in enumerate(pos_foldernames): # pos_path = os.path.join(exp_path, pos) # images_path = os.path.join(pos_path, 'Images') - # basename, chNames = myutils.getBasenameAndChNames(images_path) + # basename, chNames = utils.getBasenameAndChNames(images_path) # # Use first found channel, it doesn't matter for metrics # for chName in chNames: - # filePath = myutils.getChannelFilePath(images_path, chName) + # filePath = utils.getChannelFilePath(images_path, chName) # if filePath: # break # else: @@ -441,17 +450,17 @@ def workerCritical(self, error): worker = None raise error except: - print('='*20) - if hasattr(self, 'worker'): + print("=" * 20) + if hasattr(self, "worker"): self.worker.logger.log(traceback.format_exc()) - elif worker is not None and hasattr(worker, 'logger'): + elif worker is not None and hasattr(worker, "logger"): worker.logger.log(traceback.format_exc()) - elif hasattr(self, 'logger'): + elif hasattr(self, "logger"): self.logger.log(traceback.format_exc()) else: print(traceback.format_exc()) - print('='*20) - result = _critical_exception_gui(self, f'{self._title} utility') + print("=" * 20) + result = _critical_exception_gui(self, f"{self._title} utility") # mutex and workerFinished handeling try: worker.workerAborted() @@ -484,38 +493,36 @@ def workerFinished(self, worker): self.worker = None self.progressWin = None - def workerProgress(self, text, loggerLevel='INFO'): + def workerProgress(self, text, loggerLevel="INFO"): if self.progressWin is not None: self.progressWin.logConsole.append(text) self.logger.log(getattr(logging, loggerLevel), text) - + def closeEvent(self, event): - self.logger.info('Closing logger...') + self.logger.info("Closing logger...") handlers = self.logger.handlers[:] for handler in handlers: handler.close() self.logger.removeHandler(handler) + class MainThreadSinglePosUtilBase(QDialog): sigClose = Signal() def __init__( - self, app: QApplication, title: str, module: str, infoText: str, - parent=None - ): + self, app: QApplication, title: str, module: str, infoText: str, parent=None + ): super().__init__(parent) self.setWindowTitle(title) - self.progressDialogueTitle = title + self.progressDialogueTitle = title self._parent = parent - logger, logs_path, log_path, log_filename = myutils.setupLogger( - module=module - ) + logger, logs_path, log_path, log_filename = utils.setupLogger(module=module) logger.info(f'Utility title: "{title}"') logger.info(f'Utility module: "{module}"') - + self.logger = logger self.log_path = log_path self.log_filename = log_filename @@ -532,7 +539,7 @@ def __init__( infoTxt = html_utils.paragraph(infoText) iconLabel = QLabel(self) - standardIcon = getattr(QStyle, 'SP_MessageBoxInformation') + standardIcon = getattr(QStyle, "SP_MessageBoxInformation") icon = self.style().standardIcon(standardIcon) pixmap = icon.pixmap(60, 60) iconLabel.setPixmap(pixmap) @@ -541,7 +548,7 @@ def __init__( infoLayout.addWidget(QLabel(infoTxt)) buttonsLayout = QHBoxLayout() - cancelButton = widgets.cancelPushButton('Close') + cancelButton = widgets.cancelPushButton("Close") buttonsLayout.addStretch(1) buttonsLayout.addWidget(cancelButton) @@ -555,21 +562,22 @@ def __init__( self.worker = None self.setLayout(mainLayout) - + def closeClicked(self): self.sigClose.emit() - + def closeEvent(self, event): - self.logger.info('Closing logger...') + self.logger.info("Closing logger...") handlers = self.logger.handlers[:] for handler in handlers: handler.close() self.logger.removeHandler(handler) - + def runWorker(self, worker): self.progressWin = apps.QDialogWorkerProgress( - title=self.progressDialogueTitle, parent=self, - pbarDesc=f'{self.progressDialogueTitle}...' + title=self.progressDialogueTitle, + parent=self, + pbarDesc=f"{self.progressDialogueTitle}...", ) self.progressWin.sigClosed.connect(self.progressWinClosed) self.progressWin.show(self.app) @@ -583,18 +591,16 @@ def runWorker(self, worker): self.thread.finished.connect(self.thread.deleteLater) self.worker.signals.progress.connect(self.workerProgress) - self.worker.signals.critical.connect(self.workerCritical) + self.worker.signals.critical.connect(self.workerCritical) self.worker.signals.initProgressBar.connect(self.workerInitProgressbar) self.worker.signals.sigInitInnerPbar.connect(self.workerInitInnerPbar) self.worker.signals.progressBar.connect(self.workerUpdateProgressbar) - self.worker.signals.sigUpdateInnerPbar.connect( - self.workerUpdateInnerPbar - ) + self.worker.signals.sigUpdateInnerPbar.connect(self.workerUpdateInnerPbar) self.worker.signals.sigUpdatePbarDesc.connect(self.workerUpdatePbarDesc) self.thread.started.connect(self.worker.run) self.thread.start() - + def workerCritical(self, error): if self.progressWin is not None: self.progressWin.workerFinished = True @@ -603,9 +609,9 @@ def workerCritical(self, error): raise error except: self.traceback_str = traceback.format_exc() - print('='*20) + print("=" * 20) self.worker.logger.log(self.traceback_str) - print('='*20) + print("=" * 20) def workerFinished(self, worker): if self.progressWin is not None: @@ -615,11 +621,11 @@ def workerFinished(self, worker): self.worker = None self.progressWin = None - def workerProgress(self, text, loggerLevel='INFO'): + def workerProgress(self, text, loggerLevel="INFO"): if self.progressWin is not None: self.progressWin.logConsole.append(text) self.logger.log(getattr(logging, loggerLevel), text) - + def workerInitInnerPbar(self, totalIter): if totalIter <= 1: self.progressWin.innerPbar.hide() @@ -633,18 +639,18 @@ def workerInitProgressbar(self, totalIter): if totalIter == 1: totalIter = 0 self.progressWin.mainPbar.setMaximum(totalIter) - + def workerUpdateInnerPbar(self, step): self.progressWin.innerPbar.update(step) - + def workerUpdateProgressbar(self, step): self.progressWin.mainPbar.update(step) - + def workerUpdatePbarDesc(self, desc): self.progressWin.progressLabel.setText(desc) - + def progressWinClosed(self, aborted): self.abort = aborted if aborted and self.worker is not None: self.worker.abort = True - self.close() \ No newline at end of file + self.close() diff --git a/cellacdc/utils/combineChannels.py b/cellacdc/tools/combineChannels.py similarity index 70% rename from cellacdc/utils/combineChannels.py rename to cellacdc/tools/combineChannels.py index cac68e310..ea1315e1d 100644 --- a/cellacdc/utils/combineChannels.py +++ b/cellacdc/tools/combineChannels.py @@ -2,20 +2,25 @@ import pandas as pd -from .. import apps, myutils, workers, widgets, html_utils, load +from .. import apps, utils, workers, widgets, html_utils, load from .. import printl from .base import NewThreadMultipleExpBaseUtil + class CombineChannelsUtil(NewThreadMultipleExpBaseUtil): def __init__( - self, expPaths, app, title: str, infoText: str, - progressDialogueTitle: str, parent=None - ): - module = myutils.get_module_name(__file__) + self, + expPaths, + app, + title: str, + infoText: str, + progressDialogueTitle: str, + parent=None, + ): + module = utils.get_module_name(__file__) super().__init__( - expPaths, app, title, module, infoText, progressDialogueTitle, - parent=parent + expPaths, app, title, module, infoText, progressDialogueTitle, parent=parent ) self.expPaths = expPaths @@ -25,111 +30,103 @@ def runWorker(self): self.worker.sigAskSetup.connect(self.askSetup) self.worker.sigAborted.connect(self.workerAborted) super().runWorker(self.worker) - + def askSetup(self, expPaths): self.images_paths = [] chNames = {} for j, (exp_path, pos_foldernames) in enumerate(expPaths.items()): for i, pos in enumerate(pos_foldernames): pos_path = os.path.join(exp_path, pos) - images_path = os.path.join(pos_path, 'Images') + images_path = os.path.join(pos_path, "Images") self.images_paths.append(images_path) - basename, chNames_loc = myutils.getBasenameAndChNames( - images_path - ) + basename, chNames_loc = utils.getBasenameAndChNames(images_path) segm_files = load.get_segm_files(images_path) - segm_endnames = load.get_endnames( - basename, segm_files - ) + segm_endnames = load.get_endnames(basename, segm_files) if i == 0 and j == 0: chNames = set(chNames_loc) chNames.update(segm_endnames) continue - + chNames_loc = set(chNames_loc) chNames_loc.update(segm_endnames) chNames = chNames.intersection(chNames_loc) chNames = sorted(set(chNames)) - + self.worker.basename = basename df_metadata = load.load_metadata_df(images_path) - + win = apps.CombineChannelsSetupDialogUtil( - chNames, - df_metadata=df_metadata, - parent=self + chNames, df_metadata=df_metadata, parent=self ) win.exec_() - + if win.cancel: self.worker.abort = win.cancel self.worker.waitCond.wakeAll() - return - + return + self.worker.keepInputDataType = win.keepInputDataType self.worker.selectedSteps = win.selectedSteps self.worker.nThreads = win.nThreadsSpinBox.value() self.worker.formula = win.formulaEditWidget.text() self.worker.saveAsSegm = win.saveAsSegm() self.worker.waitCond.wakeAll() - + def showEvent(self, event): self.runWorker() - + def getBasenameExtAndExtensionOutputImage(self): saveAsSegm = self.worker.saveAsSegm if saveAsSegm: - basename_ext = 'segm' - ext = '.npz' + basename_ext = "segm" + ext = ".npz" return basename_ext, ext else: - basename_ext = '' - ext = '.tif' + basename_ext = "" + ext = ".tif" return basename_ext, ext - + def askAppendName(self, basename): basename_ext, ext = self.getBasenameExtAndExtensionOutputImage() saveAsSegm = self.worker.saveAsSegm - helpText = ( - f""" + helpText = f""" The {"combined channels" if not saveAsSegm else "combined segmentation"} file will be saved with a different file name.

    Insert a name to append to the end of the new file name. The rest of the name will be the same as the original file base. """ - ) win = apps.filenameDialog( - basename=f'{basename}{basename_ext}', + basename=f"{basename}{basename_ext}", ext=ext, - hintText=f'Insert a name for the {"combined channels" if not saveAsSegm else "combined segmentation"} file:', - defaultEntry='combined', + hintText=f"Insert a name for the {'combined channels' if not saveAsSegm else 'combined segmentation'} file:", + defaultEntry="combined", helpText=helpText, allowEmpty=False, - parent=self + parent=self, ) win.exec_() if win.cancel: self.worker.abort = True self.worker.waitCond.wakeAll() return - + self.worker.appendedName = win.entryText self.worker.waitCond.wakeAll() - + def workerAborted(self): self.workerFinished(None, aborted=True) - + def workerFinished(self, worker, aborted=False): if aborted: - txt = 'Channel combination aborted.' + txt = "Channel combination aborted." else: - txt = 'Channel combination completed.' + txt = "Channel combination completed." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) if aborted: - msg.warning(self, 'Process completed', html_utils.paragraph(txt)) + msg.warning(self, "Process completed", html_utils.paragraph(txt)) else: - msg.information(self, 'Process completed', html_utils.paragraph(txt)) + msg.information(self, "Process completed", html_utils.paragraph(txt)) super().workerFinished(worker) - self.close() \ No newline at end of file + self.close() diff --git a/cellacdc/utils/compute.py b/cellacdc/tools/compute.py similarity index 70% rename from cellacdc/utils/compute.py rename to cellacdc/tools/compute.py index 8273a6a3a..48c14796a 100755 --- a/cellacdc/utils/compute.py +++ b/cellacdc/tools/compute.py @@ -9,39 +9,43 @@ from tqdm import tqdm from qtpy.QtCore import Signal, QThread -from qtpy.QtWidgets import ( - QDialog, QVBoxLayout, QHBoxLayout, QLabel, QStyle -) +from qtpy.QtWidgets import QDialog, QVBoxLayout, QHBoxLayout, QLabel, QStyle from .base import NewThreadMultipleExpBaseUtil from .. import ( - widgets, apps, workers, html_utils, myutils, - gui, cca_functions, load, printl + widgets, + apps, + workers, + html_utils, + utils, + gui, + cca_functions, + load, + printl, ) from .. import cellacdc_path, settings_folderpath favourite_func_metrics_csv_path = os.path.join( - settings_folderpath, 'favourite_func_metrics.csv' + settings_folderpath, "favourite_func_metrics.csv" ) + class computeMeasurmentsUtilWin(NewThreadMultipleExpBaseUtil): def __init__( - self, expPaths, app, parent=None, segmEndname='', - doRunComputation=True - ): - title = 'Compute measurements utility' - infoText = 'Computing measurements routine running...' - progressDialogueTitle = 'Computing measurements' - module = myutils.get_module_name(__file__) + self, expPaths, app, parent=None, segmEndname="", doRunComputation=True + ): + title = "Compute measurements utility" + infoText = "Computing measurements routine running..." + progressDialogueTitle = "Computing measurements" + module = utils.get_module_name(__file__) super().__init__( - expPaths, app, title, module, infoText, progressDialogueTitle, - parent=parent + expPaths, app, title, module, infoText, progressDialogueTitle, parent=parent ) self.parent = parent - + self.cancel = False self.endFilenameSegm = segmEndname @@ -56,15 +60,16 @@ def runWorker(self, showProgress=True, stopFrameNumber=None): self.gui.logger = self.logger self.progressWin = apps.QDialogWorkerProgress( - title='Computing measurements', parent=self, - pbarDesc='Computing measurements...' + title="Computing measurements", + parent=self, + pbarDesc="Computing measurements...", ) self.progressWin.sigClosed.connect(self.progressWinClosed) self.progressWin.show(self.app) if not showProgress: self.progressWin.hide() - + self.thread = QThread() self.worker = workers.ComputeMetricsWorker(self) self.worker.moveToThread(self.thread) @@ -77,23 +82,17 @@ def runWorker(self, showProgress=True, stopFrameNumber=None): self.worker.signals.progress.connect(self.workerProgress) self.worker.signals.critical.connect(self.workerCritical) if not self.endFilenameSegm: - self.worker.signals.sigSelectSegmFiles.connect( - self.selectSegmFileLoadData - ) + self.worker.signals.sigSelectSegmFiles.connect(self.selectSegmFileLoadData) else: - self.worker.signals.sigSelectSegmFiles.connect( - self.wakeUpWorkerThread - ) + self.worker.signals.sigSelectSegmFiles.connect(self.wakeUpWorkerThread) self.worker.signals.sigInitAddMetrics.connect(self.initAddMetricsWorker) self.worker.signals.sigPermissionError.connect(self.warnPermissionError) self.worker.signals.initProgressBar.connect(self.workerInitProgressbar) self.worker.signals.progressBar.connect(self.workerUpdateProgressbar) self.worker.signals.sigUpdatePbarDesc.connect(self.workerUpdatePbarDesc) self.worker.signals.sigComputeVolume.connect(self.computeVolumeRegionprop) - self.worker.signals.sigAskRunNow.connect( - self.askRunNowOrSaveToConfig - ) - + self.worker.signals.sigAskRunNow.connect(self.askRunNowOrSaveToConfig) + if stopFrameNumber is None: self.worker.signals.sigAskStopFrame.connect(self.workerAskStopFrame) else: @@ -104,10 +103,10 @@ def runWorker(self, showProgress=True, stopFrameNumber=None): self.thread.started.connect(self.worker.run) self.thread.start() - + def askRunNowOrSaveToConfig(self, worker): self.worker.savedToWorkflow = False - + txt = html_utils.paragraph(""" Do you want to compute the measurements now
    or save the workflow to a configuration file and run it @@ -117,42 +116,40 @@ def askRunNowOrSaveToConfig(self, worker): (i.e., headless).
    """) msg = widgets.myMessageBox(wrapText=False) - saveButton = widgets.savePushButton('Save and run later') - runNowButton = widgets.playPushButton('Run now') + saveButton = widgets.savePushButton("Save and run later") + runNowButton = widgets.playPushButton("Run now") _, saveButton, runNowButton = msg.question( - self, 'Run workflow now?', txt, - buttonsTexts=( - 'Cancel', saveButton, runNowButton - ) + self, + "Run workflow now?", + txt, + buttonsTexts=("Cancel", saveButton, runNowButton), ) if not msg.clickedButton == saveButton: self.worker.abort = msg.cancel self.worker.waitCond.wakeAll() return - - timestamp = datetime.datetime.now().strftime( - r'%Y-%m-%d_%H-%M' - ) + + timestamp = datetime.datetime.now().strftime(r"%Y-%m-%d_%H-%M") win = apps.filenameDialog( - parent=self, - ext='.ini', - title='Insert filename for configuration file', - hintText='Insert filename for the configuration file', - allowEmpty=False, - defaultEntry=f'{timestamp}_acdc_measurements_workflow' + parent=self, + ext=".ini", + title="Insert filename for configuration file", + hintText="Insert filename for the configuration file", + allowEmpty=False, + defaultEntry=f"{timestamp}_acdc_measurements_workflow", ) win.exec_() if win.cancel: self.worker.abort = True self.worker.waitCond.wakeAll() return - + config_filename = win.filename - mostRecentPath = myutils.getMostRecentPath() + mostRecentPath = utils.getMostRecentPath() folder_path = apps.get_existing_directory( allow_images_path=False, - parent=self.progressWin, - caption='Select folder where to save configuration file', + parent=self.progressWin, + caption="Select folder where to save configuration file", basedir=mostRecentPath, # options=QFileDialog.DontUseNativeDialog ) @@ -160,46 +157,44 @@ def askRunNowOrSaveToConfig(self, worker): self.worker.abort = True self.worker.waitCond.wakeAll() return - + config_filepath = os.path.join(folder_path, config_filename) kernel = self.worker.kernel self.saveConfigurationFile(config_filepath, kernel) - + self.worker.savedToWorkflow = True self.worker.waitCond.wakeAll() - + def saveConfigurationFile(self, config_filepath, kernel): - ini_items = {'workflow': {'type': 'measurements'}} - ini_items['measurements'] = kernel.to_workflow_config_params() - paths = [] + ini_items = {"workflow": {"type": "measurements"}} + ini_items["measurements"] = kernel.to_workflow_config_params() + paths = [] stopFrames = [] for pathInfo in self.worker.allPosDataInputs: - images_path = os.path.dirname(pathInfo['file_path']) + images_path = os.path.dirname(pathInfo["file_path"]) paths.append(images_path) - stopFrames.append(pathInfo['stopFrameNum']) - + stopFrames.append(pathInfo["stopFrameNum"]) + load.save_workflow_to_config( - config_filepath, - ini_items, - paths, - stopFrames, - type='measure' + config_filepath, ini_items, paths, stopFrames, type="measure" ) self.worker.kernel.setup_done = True - + txt = html_utils.paragraph( - 'Compute measurements workflow successfully saved to the following location:

    ' - f'{config_filepath}

    ' - 'You can run the workflow with the following command:' + "Compute measurements workflow successfully saved to the following location:

    " + f"{config_filepath}

    " + "You can run the workflow with the following command:" ) command = f'acdc -p "{config_filepath}"' msg = widgets.myMessageBox(wrapText=False) msg.information( - self, 'Workflow save', txt, + self, + "Workflow save", + txt, commands=(command,), - path_to_browse=os.path.dirname(config_filepath) + path_to_browse=os.path.dirname(config_filepath), ) - + def setStopFrame(self, posDatas, stopFrameNumber=1): for p, posData in enumerate(posDatas): if isinstance(stopFrameNumber, int): @@ -209,29 +204,30 @@ def setStopFrame(self, posDatas, stopFrameNumber=1): posData.stopFrameNum = stop_frame_n self.worker.waitCond.wakeAll() - + def wakeUpWorkerThread(self, *args, **kwargs): self.worker.waitCond.wakeAll() - - def warnErrors( - self, standardMetricsErrors, customMetricsErrors, regionPropsErrors - ): + + def warnErrors(self, standardMetricsErrors, customMetricsErrors, regionPropsErrors): if standardMetricsErrors: win = apps.ComputeMetricsErrorsDialog( - standardMetricsErrors, self.logs_path, - log_type='standard_metrics', parent=self + standardMetricsErrors, + self.logs_path, + log_type="standard_metrics", + parent=self, ) win.exec_() if regionPropsErrors: win = apps.ComputeMetricsErrorsDialog( - regionPropsErrors, self.logs_path, - log_type='region_props', parent=self + regionPropsErrors, self.logs_path, log_type="region_props", parent=self ) win.exec_() if customMetricsErrors: win = apps.ComputeMetricsErrorsDialog( - customMetricsErrors, self.logs_path, - log_type='custom_metrics', parent=self + customMetricsErrors, + self.logs_path, + log_type="custom_metrics", + parent=self, ) win.exec_() self.worker.waitCond.wakeAll() @@ -256,14 +252,14 @@ def workerUpdatePbarDesc(self, desc): def warnPermissionError(self, traceback_str, path): err_msg = html_utils.paragraph( - 'The file below is open in another app ' - '(Excel maybe?).

    ' - f'{path}

    ' + "The file below is open in another app " + "(Excel maybe?).

    " + f"{path}

    " 'Close file and then press "Ok".' ) msg = widgets.myMessageBox(wrapText=False) msg.setDetailedText(traceback_str) - msg.warning(self, 'Permission error', err_msg) + msg.warning(self, "Permission error", err_msg) self.worker.waitCond.wakeAll() def selectSegmFileLoadData(self, exp_path, pos_foldernames): @@ -271,11 +267,11 @@ def selectSegmFileLoadData(self, exp_path, pos_foldernames): existingSegmEndNames = set() for p, pos in enumerate(pos_foldernames): pos_path = os.path.join(exp_path, pos) - images_path = os.path.join(pos_path, 'Images') - basename, chNames = myutils.getBasenameAndChNames(images_path) + images_path = os.path.join(pos_path, "Images") + basename, chNames = utils.getBasenameAndChNames(images_path) # Use first found channel, it doesn't matter for metrics for chName in chNames: - filePath = myutils.getChannelFilePath(images_path, chName) + filePath = utils.getChannelFilePath(images_path, chName) if filePath: break else: @@ -286,9 +282,7 @@ def selectSegmFileLoadData(self, exp_path, pos_foldernames): _posData = load.loadData(filePath, chName) _posData.getBasenameAndChNames() segm_files = load.get_segm_files(_posData.images_path) - _existingEndnames = load.get_endnames( - _posData.basename, segm_files - ) + _existingEndnames = load.get_endnames(_posData.basename, segm_files) existingSegmEndNames.update(_existingEndnames) if len(existingSegmEndNames) == 1: @@ -296,23 +290,24 @@ def selectSegmFileLoadData(self, exp_path, pos_foldernames): self.worker.waitCond.wakeAll() return - win = apps.SelectSegmFileDialog( - existingSegmEndNames, exp_path, parent=self - ) + win = apps.SelectSegmFileDialog(existingSegmEndNames, exp_path, parent=self) win.exec_() self.endFilenameSegm = win.selectedItemText self.worker.abort = win.cancel self.worker.waitCond.wakeAll() - + def addCombineMetric(self): isZstack = self.posData.SizeZ > 1 self.combineMetricWindow = apps.combineMetricsEquationDialog( - self.posData.chNames, isZstack, self.posData.isSegm3D, - parent=self.measurementsWin, closeOnOk=False + self.posData.chNames, + isZstack, + self.posData.isSegm3D, + parent=self.measurementsWin, + closeOnOk=False, ) self.combineMetricWindow.sigOk.connect(self.saveCombineMetricsToPosData) self.combineMetricWindow.show() - + def saveCombineMetricsToPosData(self, window): for p, _posData in enumerate(self.allPosData): equationsDict, isMixedChannels = window.getEquationsDict() @@ -321,7 +316,7 @@ def saveCombineMetricsToPosData(self, window): equation, newColName, isMixedChannels ) _posData.saveCombineMetrics() - + self.combineMetricWindow.close() self.measurementsWinState = self.measurementsWin.state() self.measurementsWin.restart() @@ -332,20 +327,20 @@ def initAddMetricsWorker(self, posData, allPosDataInputs): # Set measurements try: df_favourite_funcs = pd.read_csv(favourite_func_metrics_csv_path) - favourite_funcs = df_favourite_funcs['favourite_func_name'].to_list() + favourite_funcs = df_favourite_funcs["favourite_func_name"].to_list() except Exception as e: favourite_funcs = None self.posData = posData self.allPosDataInputs = allPosDataInputs - if not hasattr(self, 'allPosData'): + if not hasattr(self, "allPosData"): self.allPosData = [] for p, posDataInputs in enumerate(self.allPosDataInputs): - combineMetricsConfig = posDataInputs['combineMetricsConfig'] - combineMetricsPath = posDataInputs['combineMetricsPath'] + combineMetricsConfig = posDataInputs["combineMetricsConfig"] + combineMetricsPath = posDataInputs["combineMetricsPath"] - # Here we build a placeholder loadData class but we get what is + # Here we build a placeholder loadData class but we get what is # needed to save custom combine metrics from posDataInputs _posData = load.loadData( self.posData.imgPath, self.posData.user_ch_name @@ -355,15 +350,19 @@ def initAddMetricsWorker(self, posData, allPosDataInputs): self.allPosData.append(_posData) self.measurementsWin = apps.SetMeasurementsDialog( - posData.chNames, [], posData.SizeZ > 1, posData.isSegm3D, - favourite_funcs=favourite_funcs, posData=posData, + posData.chNames, + [], + posData.SizeZ > 1, + posData.isSegm3D, + favourite_funcs=favourite_funcs, + posData=posData, addCombineMetricCallback=self.addCombineMetric, - allPosData=self.allPosData + allPosData=self.allPosData, ) self.measurementsWin.sigClosed.connect(self.askSaveObjectsCount) self.measurementsWin.sigCancel.connect(self.abortWorkerMeasurementsWin) self.measurementsWin.show() - + def abortWorkerMeasurementsWin(self): self.worker.abort = self.measurementsWin.cancel self.worker.waitCond.wakeAll() @@ -378,22 +377,22 @@ def askSaveObjectsCount(self): ending with acdc_objects_count. """) noButton, yesButton = msg.question( - self, 'Save objects count?', txt, - buttonsTexts=('No', 'Yes, save objects count') + self, + "Save objects count?", + txt, + buttonsTexts=("No", "Yes, save objects count"), ) if msg.clickedButton == yesButton: self.worker.kernel.set_save_objects_count_table(True) - + self.startSaveDataWorker() - + def startSaveDataWorker(self): - self.worker.kernel.init_args( - self.posData.chNames, self.endFilenameSegm - ) + self.worker.kernel.init_args(self.posData.chNames, self.endFilenameSegm) self.worker.kernel.set_metrics_from_set_measurements_dialog( self.measurementsWin ) - + if not self.doRunComputation: self.worker.setup_done = True self.worker.abort = True @@ -403,7 +402,7 @@ def startSaveDataWorker(self): self.gui.mutex = self.worker.mutex self.gui.waitCond = self.worker.waitCond self.gui.saveWin = self.progressWin - + self.gui.saveDataWorker = workers.saveDataWorker(self.gui) self.gui.saveDataWorker.criticalPermissionError.connect(self.skipEvent) @@ -411,44 +410,42 @@ def startSaveDataWorker(self): self.gui.saveDataWorker.customMetricsCritical.connect( self.addCombinedMetricsError ) - self.gui.saveDataWorker.regionPropsCritical.connect( - self.addRegionPropsErrors - ) + self.gui.saveDataWorker.regionPropsCritical.connect(self.addRegionPropsErrors) self.gui.worker = self.gui.saveDataWorker self.worker.waitCond.wakeAll() - + def addRegionPropsErrors(self, traceback_format, error_message): - self.logger.info('') - print('====================================') + self.logger.info("") + print("====================================") self.logger.info(traceback_format) - print('====================================') + print("====================================") self.worker.regionPropsErrors[error_message] = traceback_format - + def addCombinedMetricsError(self, traceback_format, func_name): - self.logger.info('') - print('====================================') + self.logger.info("") + print("====================================") self.logger.info(traceback_format) - print('====================================') + print("====================================") self.worker.customMetricsErrors[func_name] = traceback_format def skipEvent(self, dummy): self.worker.waitCond.wakeAll() def computeVolumeRegionprop(self, end_frame_i, posData): - if 'cell_vol_vox' not in self.worker.kernel.sizeMetricsToSave: + if "cell_vol_vox" not in self.worker.kernel.sizeMetricsToSave: self.worker.waitCond.wakeAll() return # We compute the cell volume in the main thread because calling # skimage.transform.rotate in a separate thread causes crashes # with segmentation fault on macOS. I don't know why yet. - self.logger.info('Computing cell volume...') + self.logger.info("Computing cell volume...") PhysicalSizeY = posData.PhysicalSizeY PhysicalSizeX = posData.PhysicalSizeX - iterable = enumerate(tqdm(posData.allData_li[:end_frame_i+1], ncols=100)) + iterable = enumerate(tqdm(posData.allData_li[: end_frame_i + 1], ncols=100)) for frame_i, data_dict in iterable: - lab = data_dict['labels'] - rp = data_dict['regionprops'] + lab = data_dict["labels"] + rp = data_dict["regionprops"] obj_iter = tqdm(rp, ncols=100, position=1, leave=False) for i, obj in enumerate(obj_iter): vol_vox, vol_fl = cca_functions._calc_rot_vol( @@ -456,7 +453,7 @@ def computeVolumeRegionprop(self, end_frame_i, posData): ) obj.vol_vox = vol_vox obj.vol_fl = vol_fl - posData.allData_li[frame_i]['regionprops'] = rp + posData.allData_li[frame_i]["regionprops"] = rp self.worker.waitCond.wakeAll() def progressWinClosed(self, aborted): @@ -476,29 +473,29 @@ def workerFinished(self, worker): if self.progressWin is not None: self.progressWin.workerFinished = True self.progressWin.close() - + if worker.setup_done: - txt = 'Measurements set up completed.' + txt = "Measurements set up completed." self.logger.info(txt) elif worker.abort: - txt = 'Computing measurements cancelled.' + txt = "Computing measurements cancelled." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) - msg.warning(self, 'Process cancelled', html_utils.paragraph(txt)) - + msg.warning(self, "Process cancelled", html_utils.paragraph(txt)) + else: - txt = 'Computing measurements completed.' + txt = "Computing measurements completed." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) - msg.information(self, 'Process completed', html_utils.paragraph(txt)) + msg.information(self, "Process completed", html_utils.paragraph(txt)) self.isWorkerFinished = True self.progressWin = None self.close() - def workerProgress(self, text, loggerLevel='INFO'): + def workerProgress(self, text, loggerLevel="INFO"): if self.progressWin is not None: self.progressWin.logConsole.append(text) - if loggerLevel.upper() == 'EXCEPTION': - loggerLevel = 'ERROR' + if loggerLevel.upper() == "EXCEPTION": + loggerLevel = "ERROR" self.logger.log(getattr(logging, loggerLevel.upper()), text) diff --git a/cellacdc/utils/computeMultiChannel.py b/cellacdc/tools/computeMultiChannel.py similarity index 67% rename from cellacdc/utils/computeMultiChannel.py rename to cellacdc/tools/computeMultiChannel.py index 7765a4ed6..b76164b59 100644 --- a/cellacdc/utils/computeMultiChannel.py +++ b/cellacdc/tools/computeMultiChannel.py @@ -1,21 +1,26 @@ import pandas as pd -from .. import apps, myutils, workers, widgets, html_utils, load +from .. import apps, utils, workers, widgets, html_utils, load from .base import NewThreadMultipleExpBaseUtil + class ComputeMetricsMultiChannel(NewThreadMultipleExpBaseUtil): def __init__( - self, expPaths, app, title: str, infoText: str, - progressDialogueTitle: str, parent=None - ): - module = myutils.get_module_name(__file__) + self, + expPaths, + app, + title: str, + infoText: str, + progressDialogueTitle: str, + parent=None, + ): + module = utils.get_module_name(__file__) super().__init__( - expPaths, app, title, module, infoText, progressDialogueTitle, - parent=parent + expPaths, app, title, module, infoText, progressDialogueTitle, parent=parent ) self.expPaths = expPaths - + def runWorker(self): self.worker = workers.ComputeMetricsMultiChannelWorker(self) self.worker.sigAskAppendName.connect(self.askAppendName) @@ -25,14 +30,17 @@ def runWorker(self): self.worker.sigHowCombineMetrics.connect(self.showHowCombineMetrics) self.worker.sigAborted.connect(self.workerAborted) super().runWorker(self.worker) - + def showEvent(self, event): self.runWorker() - + def showHowCombineMetrics( - self, imagesPath, selectedAcdcOutputEndnames, - existingAcdcOutputEndnames, allChNames - ): + self, + imagesPath, + selectedAcdcOutputEndnames, + existingAcdcOutputEndnames, + allChNames, + ): self.imagesPath = imagesPath self.existingAcdcOutputEndnames = existingAcdcOutputEndnames acdcDfsDict = {} @@ -44,12 +52,8 @@ def showHowCombineMetrics( self.combineWindow = apps.CombineMetricsMultiDfsSummaryDialog( acdcDfsDict, allChNames, parent=self ) - self.combineWindow.setLogger( - self.logger, self.logs_path, self.log_path - ) - self.combineWindow.sigLoadAdditionalAcdcDf.connect( - self.loadAdditionalAcdcDf - ) + self.combineWindow.setLogger(self.logger, self.logs_path, self.log_path) + self.combineWindow.sigLoadAdditionalAcdcDf.connect(self.loadAdditionalAcdcDf) self.combineWindow.exec_() if self.combineWindow.cancel: self.worker.abort = True @@ -59,19 +63,21 @@ def showHowCombineMetrics( self.worker.equations = self.combineWindow.equations self.worker.acdcDfs = self.combineWindow.acdcDfs self.worker.waitCond.wakeAll() - + def loadAdditionalAcdcDf(self): selectWindow = widgets.QDialogListbox( - 'Select acdc_output files', - f'Select acdc_output files to load\n', - self.existingAcdcOutputEndnames, multiSelection=True, - parent=self, allowSingleSelection=True + "Select acdc_output files", + f"Select acdc_output files to load\n", + self.existingAcdcOutputEndnames, + multiSelection=True, + parent=self, + allowSingleSelection=True, ) selectWindow.exec_() if selectWindow.cancel or not selectWindow.selectedItemsText: - self.logger.info('Loading additional tables cancelled.') + self.logger.info("Loading additional tables cancelled.") return - + acdcDfsDict = {} for end in selectWindow.selectedItemsText: filePath, _ = load.get_path_from_endname(end, self.imagesPath) @@ -87,76 +93,72 @@ def criticalNotEnoughSegmFiles(self, exp_path): """) msg = widgets.myMessageBox(wrapText=False, showCentered=False) msg.addShowInFileManagerButton(exp_path) - msg.critical( - self, 'Not enough segmentation files!', text - ) + msg.critical(self, "Not enough segmentation files!", text) self.worker.abort = True self.worker.waitCond.wakeAll() - + def askAppendName(self, basename, existingEndnames, selectedEndnames): - helpText = ( - """ + helpText = """ The CSV table file with the combined measurements will be saved with a different file name.

    Insert a name to append to the end of the new name. The rest of the name will have the same basename as all other files. """ - ) - channels = [end.replace('acdc_output_', '') for end in selectedEndnames] - channels = [end.replace('acdc_output', '') for end in channels] - channels = [end if end else 'refCh' for end in channels] + channels = [end.replace("acdc_output_", "") for end in selectedEndnames] + channels = [end.replace("acdc_output", "") for end in channels] + channels = [end if end else "refCh" for end in channels] defaultEntry = f"{'_'.join(channels)}_combined_metrics" win = apps.filenameDialog( basename=basename, - hintText='Insert a name for the new, table file:', - existingNames=existingEndnames, - helpText=helpText, + hintText="Insert a name for the new, table file:", + existingNames=existingEndnames, + helpText=helpText, allowEmpty=False, - ext='.csv', + ext=".csv", defaultEntry=defaultEntry, - resizeOnShow=True + resizeOnShow=True, ) win.exec_() if win.cancel: self.worker.abort = True self.worker.waitCond.wakeAll() return - + self.worker.appendedName = win.entryText self.worker.waitCond.wakeAll() - + def workerAborted(self): self.workerFinished(None, aborted=True) - + def workerCritical(self, error): super().workerCritical(error) self.worker.errors[error] = self.traceback_str - + def warnErrors(self, errors): win = apps.ComputeMetricsErrorsDialog( - errors, self.logs_path, log_type='generic', parent=self + errors, self.logs_path, log_type="generic", parent=self ) win.exec_() - + def workerFinished(self, worker, aborted=False): if aborted: - txt = 'Combining multiple channels measurements aborted.' + txt = "Combining multiple channels measurements aborted." isWarning = True elif worker.errors: - txt = 'Combining multiple channels measurements completed WITH ERRORS.' + txt = "Combining multiple channels measurements completed WITH ERRORS." self.warnErrors(worker.errors) isWarning = True else: txt = html_utils.paragraph( - 'Combining multiple channels measurements completed.

    ' - 'Results were saved in the respective Position folder(s).' + "Combining multiple channels measurements completed.

    " + "Results were saved in the respective Position folder(s)." ) isWarning = False self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) if isWarning: - msg.warning(self, 'Process completed', html_utils.paragraph(txt)) + msg.warning(self, "Process completed", html_utils.paragraph(txt)) else: - msg.information(self, 'Process completed', html_utils.paragraph(txt)) + msg.information(self, "Process completed", html_utils.paragraph(txt)) super().workerFinished(worker) - self.close() \ No newline at end of file + self.close() diff --git a/cellacdc/utils/concat.py b/cellacdc/tools/concat.py similarity index 66% rename from cellacdc/utils/concat.py rename to cellacdc/tools/concat.py index f4d4898ea..069e56b2f 100755 --- a/cellacdc/utils/concat.py +++ b/cellacdc/tools/concat.py @@ -2,30 +2,35 @@ from cellacdc import measurements -from .. import apps, myutils, workers, widgets, html_utils +from .. import apps, utils, workers, widgets, html_utils from .. import printl from .base import NewThreadMultipleExpBaseUtil + class ConcatWin(NewThreadMultipleExpBaseUtil): def __init__( - self, expPaths, app, title: str, infoText: str, - progressDialogueTitle: str, parent=None - ): - module = myutils.get_module_name(__file__) + self, + expPaths, + app, + title: str, + infoText: str, + progressDialogueTitle: str, + parent=None, + ): + module = utils.get_module_name(__file__) super().__init__( - expPaths, app, title, module, infoText, progressDialogueTitle, - parent=parent + expPaths, app, title, module, infoText, progressDialogueTitle, parent=parent ) self.expPaths = expPaths - if title.find('spotMAX') != -1: + if title.find("spotMAX") != -1: self.worker_func = workers.ConcatSpotmaxDfsWorker - self._infoText = 'spotMAX' + self._infoText = "spotMAX" else: self.worker_func = workers.ConcatAcdcDfsWorker - self._infoText = 'acdc_output' - - def runWorker(self, format='CSV'): + self._infoText = "acdc_output" + + def runWorker(self, format="CSV"): self.worker = self.worker_func(self, format=format) self.worker.sigAskFolder.connect(self.askFolderWhereToSaveAllExp) self.worker.sigAborted.connect(self.workerAborted) @@ -33,33 +38,36 @@ def runWorker(self, format='CSV'): self.worker.sigSetMeasurements.connect(self.askSetMeasurements) self.worker.signals.sigAskCopyCca.connect(self.askCopyCcaFromAcdcOutput) super().runWorker(self.worker) - + def askCopyCcaFromAcdcOutput(self, images_path): acdc_output_tables = [] - for file in myutils.listdir(images_path): - if not file.endswith('.csv'): + for file in utils.listdir(images_path): + if not file.endswith(".csv"): continue - - idx = file.find('acdc_output') + + idx = file.find("acdc_output") if idx == -1: continue - + acdc_output_tables.append(file[idx:]) - + if not acdc_output_tables: self.worker.waitCond.wakeAll() return - + txt = html_utils.paragraph( - 'Do you want to copy cell cycle annotations
    ' - 'from one of the tables below?

    ' - 'If yes, please select from which table you want to copy from:' + "Do you want to copy cell cycle annotations
    " + "from one of the tables below?

    " + "If yes, please select from which table you want to copy from:" ) - noButton = widgets.noPushButton('No, do not copy') + noButton = widgets.noPushButton("No, do not copy") selectTableNameWin = widgets.QDialogListbox( - 'Copy cell cycle annotations?', txt, - acdc_output_tables, multiSelection=False, parent=self, - additionalButtons=(noButton,) + "Copy cell cycle annotations?", + txt, + acdc_output_tables, + multiSelection=False, + parent=self, + additionalButtons=(noButton,), ) noButton.clicked.connect(selectTableNameWin.ok_cb) selectTableNameWin.exec_() @@ -72,74 +80,74 @@ def askCopyCcaFromAcdcOutput(self, images_path): if selectTableNameWin.clickedButton == noButton: self.worker.waitCond.wakeAll() return - - self.worker.setAcdcOutputEndname(selectedTableNames[0]) + + self.worker.setAcdcOutputEndname(selectedTableNames[0]) self.worker.waitCond.wakeAll() - + def askSetMeasurements(self, kwargs): - loadedChNames = kwargs['loadedChNames'] - notLoadedChNames = kwargs['notLoadedChNames'] - isZstack = kwargs['isZstack'] - isSegm3D = kwargs['isSegm3D'] + loadedChNames = kwargs["loadedChNames"] + notLoadedChNames = kwargs["notLoadedChNames"] + isZstack = kwargs["isZstack"] + isSegm3D = kwargs["isSegm3D"] self.setMeasurementsWin = apps.SetMeasurementsDialog( - loadedChNames, notLoadedChNames, isZstack, isSegm3D, - is_concat=True, parent=self - ) - existing_colnames = kwargs['existing_colnames'] - self.setMeasurementsWin.addNonMeasurementColumns( - existing_colnames - ) - self.setMeasurementsWin.setDisabledNotExistingMeasurements( - existing_colnames + loadedChNames, + notLoadedChNames, + isZstack, + isSegm3D, + is_concat=True, + parent=self, ) + existing_colnames = kwargs["existing_colnames"] + self.setMeasurementsWin.addNonMeasurementColumns(existing_colnames) + self.setMeasurementsWin.setDisabledNotExistingMeasurements(existing_colnames) self.setMeasurementsWin.sigClosed.connect(self.setMeasurements) self.setMeasurementsWin.sigCancel.connect(self.setMeasurementsCancelled) self.setMeasurementsWin.show() - + def setMeasurements(self): selectedColumns = [] - if hasattr(self.setMeasurementsWin, 'nonMeasurementsGroupbox'): + if hasattr(self.setMeasurementsWin, "nonMeasurementsGroupbox"): if self.setMeasurementsWin.nonMeasurementsGroupbox.isChecked(): groupbox = self.setMeasurementsWin.nonMeasurementsGroupbox - for checkBox in groupbox.checkBoxes: + for checkBox in groupbox.checkBoxes: if not checkBox.isEnabled(): continue if not checkBox.isChecked(): continue colname = checkBox.text() - selectedColumns.append(colname) - + selectedColumns.append(colname) + for chNameGroupbox in self.setMeasurementsWin.chNameGroupboxes: chName = chNameGroupbox.chName if not chNameGroupbox.isChecked(): # Skip entire channel continue - + for checkBox in chNameGroupbox.checkBoxes: if not checkBox.isEnabled(): continue - + if not checkBox.isChecked(): continue colname = checkBox.text() selectedColumns.append(colname) - + if self.setMeasurementsWin.sizeMetricsQGBox.isChecked(): for checkBox in self.setMeasurementsWin.sizeMetricsQGBox.checkBoxes: if not checkBox.isEnabled(): continue - + if not checkBox.isChecked(): continue colname = checkBox.text() selectedColumns.append(colname) - + selectedPropsNames = [] if self.setMeasurementsWin.regionPropsQGBox.isChecked(): for checkBox in self.setMeasurementsWin.regionPropsQGBox.checkBoxes: if not checkBox.isEnabled(): continue - + if not checkBox.isChecked(): continue colname = checkBox.text() @@ -148,7 +156,7 @@ def setMeasurements(self): self.setMeasurementsWin.existing_colnames, selectedPropsNames ) selectedColumns.extend(selectedRpCols) - + checkMixedChannel = ( self.setMeasurementsWin.mixedChannelsCombineMetricsQGBox is not None and self.setMeasurementsWin.mixedChannelsCombineMetricsQGBox.isChecked() @@ -159,87 +167,84 @@ def setMeasurements(self): for checkBox in checkBoxes: if not checkBox.isEnabled(): continue - + if not checkBox.isChecked(): continue colname = checkBox.text() selectedColumns.append(colname) - + self.worker.selectedColumns = selectedColumns self.worker.abort = False self.worker.waitCond.wakeAll() - + def setMeasurementsCancelled(self): self.worker.abort = True self.worker.waitCond.wakeAll() - + def showEvent(self, event): - formats = ( - 'CSV (Comma Separated Values)', - 'XLS (Excel)' - ) + formats = ("CSV (Comma Separated Values)", "XLS (Excel)") selectFormatWin = widgets.QDialogListbox( - 'Select output file format', - 'Select format of the output file\n', - formats, multiSelection=False, parent=self + "Select output file format", + "Select format of the output file\n", + formats, + multiSelection=False, + parent=self, ) selectFormatWin.exec_() if selectFormatWin.cancel: return - - if selectFormatWin.selectedItemsText[0].startswith('CSV'): - self._ext = '.csv' + + if selectFormatWin.selectedItemsText[0].startswith("CSV"): + self._ext = ".csv" else: - self._ext = '.xlsx' - myutils.check_install_package( - 'OpenPyXL', - import_pkg_name='openpyxl', - pypi_name='XlsxWriter' + self._ext = ".xlsx" + utils.check_install_package( + "OpenPyXL", import_pkg_name="openpyxl", pypi_name="XlsxWriter" ) self.runWorker(format=selectFormatWin.selectedItemsText[0]) - + def askAppendName(self, basename, existingEndnames): win = apps.filenameDialog( basename=basename, - hintText='Insert a name for the concatenated table file:', - existingNames=existingEndnames, + hintText="Insert a name for the concatenated table file:", + existingNames=existingEndnames, allowEmpty=True, - ext=self._ext + ext=self._ext, ) win.exec_() if win.cancel: - self.worker.abort = True - else: + self.worker.abort = True + else: self.worker.concat_df_filename = win.filename self.worker.waitCond.wakeAll() - + def askFolderWhereToSaveAllExp(self, allExp_filename): - txt = (""" + txt = """ After clicking "Ok" you will be asked to select a folder where you want to save the file
    with the concatenated tables from the multiple experiments selected
    - """) + """ if allExp_filename: - txt = f'{txt}(the filename will be {allExp_filename})' - + txt = f"{txt}(the filename will be {allExp_filename})" + txt = html_utils.paragraph(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) - msg.information(self, 'Select folder', txt) + msg.information(self, "Select folder", txt) if msg.cancel: self.worker.abort = True self.worker.waitCond.wakeAll() - - - mostRecentPath = myutils.getMostRecentPath() + + mostRecentPath = utils.getMostRecentPath() save_to_dir = QFileDialog.getExistingDirectory( - self, f'Select folder where to save multiple experiments table', - mostRecentPath + self, + f"Select folder where to save multiple experiments table", + mostRecentPath, ) if not save_to_dir: self.worker.abort = True self.worker.waitCond.wakeAll() - + self.worker.allExpSaveFolder = save_to_dir self.worker.waitCond.wakeAll() @@ -247,17 +252,17 @@ def askFolderWhereToSaveAllExp(self, allExp_filename): def workerAborted(self): self.worker.signals.finished.emit(self) self.workerFinished(self.worker, aborted=True) - + def workerFinished(self, worker, aborted=False): if aborted: - txt = f'Concatenating {self._infoText} tables aborted.' + txt = f"Concatenating {self._infoText} tables aborted." else: - txt = f'Concatenating {self._infoText} tables completed.' + txt = f"Concatenating {self._infoText} tables completed." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) if aborted: - msg.warning(self, 'Process completed', html_utils.paragraph(txt)) + msg.warning(self, "Process completed", html_utils.paragraph(txt)) else: - msg.information(self, 'Process completed', html_utils.paragraph(txt)) + msg.information(self, "Process completed", html_utils.paragraph(txt)) super().workerFinished(worker) - self.close() \ No newline at end of file + self.close() diff --git a/cellacdc/utils/convert.py b/cellacdc/tools/convert.py similarity index 62% rename from cellacdc/utils/convert.py rename to cellacdc/tools/convert.py index 987d03487..c18984d23 100755 --- a/cellacdc/utils/convert.py +++ b/cellacdc/tools/convert.py @@ -15,13 +15,19 @@ import skimage.color from qtpy.QtWidgets import ( - QApplication, QMainWindow, QFileDialog, - QVBoxLayout, QPushButton, QLabel, QStyleFactory, - QWidget, QMessageBox, QDialog, QHBoxLayout -) -from qtpy.QtCore import ( - Qt, QEventLoop, QSize, QThread, Signal, QObject + QApplication, + QMainWindow, + QFileDialog, + QVBoxLayout, + QPushButton, + QLabel, + QStyleFactory, + QWidget, + QMessageBox, + QDialog, + QHBoxLayout, ) +from qtpy.QtCore import Qt, QEventLoop, QSize, QThread, Signal, QObject from qtpy import QtGui script_path = os.path.dirname(os.path.realpath(__file__)) @@ -30,26 +36,33 @@ # Custom modules from .. import exception_handler, printl -from .. import prompts, load, myutils, apps, load, widgets, html_utils +from .. import prompts, load, utils, apps, load, widgets, html_utils from .. import workers from .. import cellacdc_path, recentPaths_path, settings_folderpath from .. import io -if os.name == 'nt': +if os.name == "nt": try: # Set taskbar icon in windows import ctypes - myappid = 'schmollerlab.cellacdc.pyqt.v1' # arbitrary string + + myappid = "schmollerlab.cellacdc.pyqt.v1" # arbitrary string ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(myappid) except Exception as e: pass + class convertFileFormatWin(QMainWindow): def __init__( - self, parent=None, allowExit=False, - actionToEnable=None, mainWin=None, - from_='npz', to='npy', info='' - ): + self, + parent=None, + allowExit=False, + actionToEnable=None, + mainWin=None, + from_="npz", + to="npy", + info="", + ): self.from_ = from_ self.to = to self.info = info @@ -68,16 +81,16 @@ def __init__( mainLayout = QVBoxLayout() titleText = html_utils.paragraph( - f'
    Converting .{from_} to .{to} routine running...', - font_size='14px' + f"
    Converting .{from_} to .{to} routine running...", + font_size="14px", ) titleLabel = QLabel(titleText) mainLayout.addWidget(titleLabel) infoTxt = ( - 'Follow the instructions in the pop-up windows.
    ' - 'Note that pop-ups might be minimized or behind other open windows.

    ' - 'Progess is displayed in the terminal/console.' + "Follow the instructions in the pop-up windows.
    " + "Note that pop-ups might be minimized or behind other open windows.

    " + "Progess is displayed in the terminal/console." ) informativeLabel = QLabel(html_utils.paragraph(infoTxt)) @@ -86,7 +99,7 @@ def __init__( informativeLabel.setAlignment(Qt.AlignLeft) mainLayout.addWidget(informativeLabel) - abortButton = QPushButton('Stop processs') + abortButton = QPushButton("Stop processs") abortButton.clicked.connect(self.close) mainLayout.addWidget(abortButton) @@ -95,23 +108,26 @@ def __init__( def getMostRecentPath(self): if os.path.exists(recentPaths_path): - df = pd.read_csv(recentPaths_path, index_col='index') - if 'opened_last_on' in df.columns: - df = df.sort_values('opened_last_on', ascending=False) - self.MostRecentPath = df.iloc[0]['path'] + df = pd.read_csv(recentPaths_path, index_col="index") + if "opened_last_on" in df.columns: + df = df.sort_values("opened_last_on", ascending=False) + self.MostRecentPath = df.iloc[0]["path"] if not isinstance(self.MostRecentPath, str): - self.MostRecentPath = '' + self.MostRecentPath = "" else: - self.MostRecentPath = '' + self.MostRecentPath = "" def main(self): self.getMostRecentPath() exp_path = QFileDialog.getExistingDirectory( - self, 'Select experiment folder containing Position_n folders ' - 'or specific Position_n folder', self.MostRecentPath) + self, + "Select experiment folder containing Position_n folders " + "or specific Position_n folder", + self.MostRecentPath, + ) self.addToRecentPaths(exp_path) - if exp_path == '': + if exp_path == "": abort = self.doAbort() if abort: self.close() @@ -121,25 +137,23 @@ def main(self): f'Cell-ACDC - Convert .{self.from_} to .{self.to} - "{exp_path}"' ) - folder_type = myutils.determine_folder_type(exp_path) + folder_type = utils.determine_folder_type(exp_path) is_pos_folder, is_images_folder, exp_path = folder_type - print('Loading data...') + print("Loading data...") if not is_pos_folder and not is_images_folder: select_folder = load.select_exp_folder() values = select_folder.get_values_segmGUI(exp_path) if not values: txt = html_utils.paragraph( - 'The selected folder:

    ' - f'{exp_path}

    ' - 'is not a valid folder. ' - 'Select a folder that contains the Position_n folders' + "The selected folder:

    " + f"{exp_path}

    " + "is not a valid folder. " + "Select a folder that contains the Position_n folders" ) msg = widgets.myMessageBox() - msg.critical( - self, 'Incompatible folder', txt - ) + msg.critical(self, "Incompatible folder", txt) self.close() return @@ -154,19 +168,20 @@ def main(self): else: pos_foldernames = values - images_paths = [os.path.join(exp_path, pos, 'Images') - for pos in pos_foldernames] + images_paths = [ + os.path.join(exp_path, pos, "Images") for pos in pos_foldernames + ] elif is_pos_folder: pos_foldername = os.path.basename(exp_path) exp_path = os.path.dirname(exp_path) - images_paths = [f'{exp_path}/{pos_foldername}/Images'] + images_paths = [f"{exp_path}/{pos_foldername}/Images"] elif is_images_folder: images_paths = [exp_path] proceed, selectedFilenames = self.selectFiles( - images_paths[0], filterExt=[f'{self.from_}'] + images_paths[0], filterExt=[f"{self.from_}"] ) if not proceed: abort = self.doAbort() @@ -183,43 +198,48 @@ def main(self): self.close() return - print(f'Converting .{self.from_} to .{self.to} started...') + print(f"Converting .{self.from_} to .{self.to} started...") if len(images_paths) > 1: - _endswith = selectedFilenames[0][len(basename):] + _endswith = selectedFilenames[0][len(basename) :] if not _endswith: - self.criticalNoCommonBasename( - selectedFilenames, images_paths[0] - ) + self.criticalNoCommonBasename(selectedFilenames, images_paths[0]) self.close() return for pos_i, images_path in enumerate(tqdm(images_paths, ncols=100)): - ls = myutils.listdir(images_path) - _basename = self.getBasename( - images_path, selectedFilenames - ) + ls = utils.listdir(images_path) + _basename = self.getBasename(images_path, selectedFilenames) for file in ls: if file.endswith(_endswith): proceed = self.convert( - images_path, file, appendedTxt, _basename, - from_=self.from_, to=self.to, prompt=False + images_path, + file, + appendedTxt, + _basename, + from_=self.from_, + to=self.to, + prompt=False, ) if not proceed: self.close() return else: proceed = self.convert( - images_paths[0], selectedFilenames[0], appendedTxt, basename, - from_=self.from_, to=self.to + images_paths[0], + selectedFilenames[0], + appendedTxt, + basename, + from_=self.from_, + to=self.to, ) - + self.success = True self.close() if self.allowExit: - exit('Done.') + exit("Done.") def getBasename(self, images_path, selectedFilenames): - commonStartFilenames = myutils.filterCommonStart(images_path) + commonStartFilenames = utils.filterCommonStart(images_path) selector = prompts.select_channel_name() _, noBasename = selector.get_available_channels( commonStartFilenames, images_path, useExt=None @@ -229,139 +249,150 @@ def getBasename(self, images_path, selectedFilenames): else: basename = selector.basename - if basename.endswith('_'): - if self.info.startswith('_'): - basename = f'{basename}{self.info[1:]}' + if basename.endswith("_"): + if self.info.startswith("_"): + basename = f"{basename}{self.info[1:]}" else: - basename = f'{basename}{self.info}' + basename = f"{basename}{self.info}" else: - basename = f'{basename}_{self.info}' + basename = f"{basename}_{self.info}" return basename def convert( - self, images_path, filename, appendedTxt, basename, - from_='npz', to='npy', prompt=True - ): + self, + images_path, + filename, + appendedTxt, + basename, + from_="npz", + to="npy", + prompt=True, + ): filePath = os.path.join(images_path, filename) - if self.from_ == 'npz': - data = np.load(filePath)['arr_0'] - elif self.from_ == 'npy': + if self.from_ == "npz": + data = np.load(filePath)["arr_0"] + elif self.from_ == "npy": data = np.load(filePath) - elif self.from_ == 'tif': + elif self.from_ == "tif": data = load.imread(filePath) - elif self.from_ == 'h5': + elif self.from_ == "h5": data = load.h5dump_to_arr(filePath) - if self.info.find('segm') != -1: + if self.info.find("segm") != -1: data = data.astype(np.uint32) filename, ext = os.path.splitext(filename) if appendedTxt: - if basename.endswith('_'): + if basename.endswith("_"): basename = basename[:-1] - newFilename = f'{basename}_{appendedTxt}.{self.to}' + newFilename = f"{basename}_{appendedTxt}.{self.to}" else: - newFilename = f'{basename}.{self.to}' + newFilename = f"{basename}.{self.to}" newPath = os.path.join(images_path, newFilename) if os.path.exists(newPath): newPath = self.warnFileExisting(newPath) if not newPath: return False - if self.to == 'npy': + if self.to == "npy": np.save(newPath, data) - elif self.to == 'tif': - myutils.to_tiff(newPath, data) - elif self.to == 'npz': + elif self.to == "tif": + utils.to_tiff(newPath, data) + elif self.to == "npz": io.savez_compressed(newPath, data) - print('') - print('-'*30) + print("") + print("-" * 30) print(f'File "{filePath}" saved to "{newPath}"') - print('-'*30) + print("-" * 30) if prompt: self.conversionDone(filePath, newPath) return True - + def warnFileExisting(self, newFilePath): msg = widgets.myMessageBox(showCentered=False, wrapText=False) txt = html_utils.paragraph(f""" The following file is already existing:

    - {myutils.trim_path(newFilePath, depth=4)}

    + {utils.trim_path(newFilePath, depth=4)}

    What do you want to do? """) msg.addShowInFileManagerButton(newFilePath) _, overwriteButton, renameButton = msg.warning( - self, 'File existing', txt, - buttonsTexts=('Cancel', 'Overwrite existing', 'Rename new file') + self, + "File existing", + txt, + buttonsTexts=("Cancel", "Overwrite existing", "Rename new file"), ) if msg.cancel: - return '' - + return "" + if msg.clickedButton == overwriteButton: return newFilePath - + if msg.clickedButton == renameButton: folderName = os.path.dirname(newFilePath) filename, ext = os.path.splitext(os.path.basename(newFilePath)) win = apps.filenameDialog( - basename=filename, ext=ext, allowEmpty=False, - hintText='Insert a filename for the new file:
    ' + basename=filename, + ext=ext, + allowEmpty=False, + hintText="Insert a filename for the new file:
    ", ) win.exec_() if win.cancel: - return '' + return "" newFilePath = os.path.join(folderName, win.filename) return newFilePath - def conversionDone(self, src, dst): msg = widgets.myMessageBox() msg.setWidth(700) parent_path = os.path.dirname(dst) txt = ( - 'Done!

    ' - f'The file below was converted to .{self.to}, and saved' + "Done!

    " + f"The file below was converted to .{self.to}, and saved" ) msg.addShowInFileManagerButton(parent_path) msg.information( - self, 'Conversion done!', html_utils.paragraph(txt), - path_to_browse=parent_path, - commands=(src, dst) + self, + "Conversion done!", + html_utils.paragraph(txt), + path_to_browse=parent_path, + commands=(src, dst), ) def askTxtAppend(self, basename): hintText = html_utils.paragraph( - 'OPTIONAL: write here an additional text to append ' - 'to the filename' + "OPTIONAL: write here an additional text to append to the filename" ) - if basename.endswith('_'): + if basename.endswith("_"): basename = basename[:-1] win = apps.filenameDialog( - ext=self.to, title='New filename', - hintText=hintText, parent=self, basename=basename + ext=self.to, + title="New filename", + hintText=hintText, + parent=self, + basename=basename, ) win.exec_() if win.cancel: - win.entryText = '' + win.entryText = "" return win.cancel, win.entryText def criticalNoCommonBasename(self, filenames, parent_path): msg = widgets.myMessageBox() txt = html_utils.paragraph( - f'The file name {filenames[0]}
    ' - 'does not follow Cell-ACDC naming convention.

    ' - 'The name must have the same common basename ' - 'as all the other files inside the ' - 'Position_n/Images folder.

    ' - 'For example, if in the Images folder you have two files called ' - 'ASY015_SCD_phase_contr.tif and ' - 'ASY015_SCD_mCitrine.tif then the common basename ' - 'is ASY015_SCD_ and the file that you are tring to ' - 'convert should start with the same common basename.' - ) - msg.critical( - self, 'Name of selected file not compatible', txt + f"The file name {filenames[0]}
    " + "does not follow Cell-ACDC naming convention.

    " + "The name must have the same common basename " + "as all the other files inside the " + "Position_n/Images folder.

    " + "For example, if in the Images folder you have two files called " + "ASY015_SCD_phase_contr.tif and " + "ASY015_SCD_mCitrine.tif then the common basename " + "is ASY015_SCD_ and the file that you are tring to " + "convert should start with the same common basename." ) + msg.critical(self, "Name of selected file not compatible", txt) def selectFiles(self, images_path, filterExt=None): - files = myutils.listdir(images_path) + files = utils.listdir(images_path) if filterExt is not None: items = [] for file in files: @@ -373,12 +404,14 @@ def selectFiles(self, images_path, filterExt=None): items = files selectFilesWidget = widgets.QDialogListbox( - 'Select files', - f'Select the .{self.from_} files you want to convert to ' - f'.{self.to}\n\n' - 'NOTE: if you selected multiple Position folders I will try \n' - 'to convert all selected files in each Position folder', - items, multiSelection=False, parent=self + "Select files", + f"Select the .{self.from_} files you want to convert to " + f".{self.to}\n\n" + "NOTE: if you selected multiple Position folders I will try \n" + "to convert all selected files in each Position folder", + items, + multiSelection=False, + parent=self, ) selectFilesWidget.exec_() @@ -395,12 +428,12 @@ def addToRecentPaths(self, exp_path): if not os.path.exists(exp_path): return if os.path.exists(recentPaths_path): - df = pd.read_csv(recentPaths_path, index_col='index') - recentPaths = df['path'].to_list() - if 'opened_last_on' in df.columns: - openedOn = df['opened_last_on'].to_list() + df = pd.read_csv(recentPaths_path, index_col="index") + recentPaths = df["path"].to_list() + if "opened_last_on" in df.columns: + openedOn = df["opened_last_on"].to_list() else: - openedOn = [np.nan]*len(recentPaths) + openedOn = [np.nan] * len(recentPaths) if exp_path in recentPaths: pop_idx = recentPaths.index(exp_path) recentPaths.pop(pop_idx) @@ -414,17 +447,20 @@ def addToRecentPaths(self, exp_path): else: recentPaths = [exp_path] openedOn = [datetime.datetime.now()] - df = pd.DataFrame({'path': recentPaths, - 'opened_last_on': pd.Series(openedOn, - dtype='datetime64[ns]')}) - df.index.name = 'index' + df = pd.DataFrame( + { + "path": recentPaths, + "opened_last_on": pd.Series(openedOn, dtype="datetime64[ns]"), + } + ) + df.index.name = "index" df.to_csv(recentPaths_path) def doAbort(self): if self.allowExit: - exit('Execution aborted by the user') + exit("Execution aborted by the user") else: - print('Conversion task aborted by the user.') + print("Conversion task aborted by the user.") return True def closeEvent(self, event): @@ -433,20 +469,21 @@ def closeEvent(self, event): txt = html_utils.paragraph(""" Conversion process aborted. """) - msg.warning(self, 'Process aborted', txt) - + msg.warning(self, "Process aborted", txt) + if self.actionToEnable is not None: self.actionToEnable.setDisabled(False) self.mainWin.setWindowState(Qt.WindowNoState) self.mainWin.setWindowState(Qt.WindowActive) self.mainWin.raise_() + class ImagesToPositions(QDialog): def __init__(self, parent=None) -> None: super().__init__(parent) - - logger, logs_path, log_path, log_filename = myutils.setupLogger( - module='converter' + + logger, logs_path, log_path, log_filename = utils.setupLogger( + module="converter" ) self.logger = logger @@ -454,20 +491,19 @@ def __init__(self, parent=None) -> None: self.log_filename = log_filename self.logs_path = logs_path - self.logger.info('Initializing converter...') + self.logger.info("Initializing converter...") self.cancel = True - self.setWindowTitle('Cell-ACDC converter') - self.funcDescription = 'Cell-ACDC converter' + self.setWindowTitle("Cell-ACDC converter") + self.funcDescription = "Cell-ACDC converter" instructions = [ - 'Put all the images into one folder' - 'Press start button', - 'Select folder containing the images', - 'Select where to save the Position folders', - 'Insert a text to append at the end of each image (e.g., the channel name)', - 'Wait that process ends' + "Put all the images into one folderPress start button", + "Select folder containing the images", + "Select where to save the Position folders", + "Insert a text to append at the end of each image (e.g., the channel name)", + "Wait that process ends", ] txt = html_utils.paragraph(f""" @@ -482,7 +518,7 @@ def __init__(self, parent=None) -> None: layout = QVBoxLayout() textLayout = QHBoxLayout() - pixmap = QtGui.QIcon(":cog_play.svg").pixmap(QSize(64,64)) + pixmap = QtGui.QIcon(":cog_play.svg").pixmap(QSize(64, 64)) iconLabel = QLabel() iconLabel.setPixmap(pixmap) @@ -492,9 +528,9 @@ def __init__(self, parent=None) -> None: textLayout.addStretch(1) buttonsLayout = QHBoxLayout() - stopButton = widgets.stopPushButton('Stop process') - startButton = widgets.playPushButton(' Start ') - cancelButton = widgets.cancelPushButton('Close') + stopButton = widgets.stopPushButton("Stop process") + startButton = widgets.playPushButton(" Start ") + cancelButton = widgets.cancelPushButton("Close") buttonsLayout.addStretch(1) buttonsLayout.addWidget(cancelButton) @@ -524,51 +560,53 @@ def __init__(self, parent=None) -> None: cancelButton.clicked.connect(self.close) startButton.clicked.connect(self.start) stopButton.clicked.connect(self.stop) - + def showEvent(self, event: QtGui.QShowEvent) -> None: self.startButton.setFixedWidth(self.stopButton.width()) self.stopButton.hide() return super().showEvent(event) - @exception_handler + @exception_handler def start(self): self.startButton.hide() self.stopButton.show() - MostRecentPath = myutils.getMostRecentPath() + MostRecentPath = utils.getMostRecentPath() folderPath = QFileDialog.getExistingDirectory( - self, 'Select folder containing images', MostRecentPath + self, "Select folder containing images", MostRecentPath ) if not folderPath: - self.logger.info('No path selected. Process stopped.') + self.logger.info("No path selected. Process stopped.") self.stop() return - + tagertFolderPath = QFileDialog.getExistingDirectory( - self, 'Select where to save Position folders', folderPath + self, "Select where to save Position folders", folderPath ) if not tagertFolderPath: - self.logger.info('Target path not selected. Process stopped.') + self.logger.info("Target path not selected. Process stopped.") self.stop() return - - myutils.addToRecentPaths(tagertFolderPath, logger=self.logger) + + utils.addToRecentPaths(tagertFolderPath, logger=self.logger) textToAppendInstructions = html_utils.paragraph( - 'Insert a name to append at the end of each new .tif file.' - '

    ' - 'This name is required because Cell-ACDC needs to load files
    ' - 'that ends with the same common name.

    ' + "Insert a name to append at the end of each new .tif file." + "

    " + "This name is required because Cell-ACDC needs to load files
    " + "that ends with the same common name.

    " 'Typically, you can use this for the channel name, e.g., "GFP".' ) win = apps.filenameDialog( - ext='.tif', title='Insert text to append', + ext=".tif", + title="Insert text to append", hintText=textToAppendInstructions, - parent=self, allowEmpty=False + parent=self, + allowEmpty=False, ) win.exec_() if win.cancel: - self.logger.info('Process cancelled at insert text.') + self.logger.info("Process cancelled at insert text.") self.stop() return @@ -589,37 +627,37 @@ def start(self): self.thread.started.connect(self.worker.run) self.thread.start() - + def stop(self): self.startButton.show() self.stopButton.hide() - if hasattr(self, 'worker'): + if hasattr(self, "worker"): self.worker.abort = True - + @exception_handler def workerInitProgressBar(self, maximum): self.progressBar.setValue(0) self.progressBar.setMaximum(maximum) - + @exception_handler def workerUpdateProgressBar(self): self.progressBar.update(1) - + @exception_handler def workerProgress(self, txt): self.logger.info(txt) self.logConsole.append(txt) - + @exception_handler def workerProgressBar(self, txt): self.logger.info(txt) self.logConsole.write(txt) - + @exception_handler def workerCritical(self, error): raise error - + @exception_handler def workerFinished(self): self.startButton.show() @@ -628,13 +666,14 @@ def workerFinished(self): if self.worker.abort: msg = widgets.myMessageBox() msg.warning( - self, 'Conversion process stopped', - html_utils.paragraph('Conversion process stopped!') + self, + "Conversion process stopped", + html_utils.paragraph("Conversion process stopped!"), ) else: msg = widgets.myMessageBox() msg.information( - self, 'Conversion completed', - html_utils.paragraph('Conversion process completed!') + self, + "Conversion completed", + html_utils.paragraph("Conversion process completed!"), ) - diff --git a/cellacdc/utils/countObjects.py b/cellacdc/tools/countObjects.py similarity index 68% rename from cellacdc/utils/countObjects.py rename to cellacdc/tools/countObjects.py index 92ff1326a..5d412e0b7 100644 --- a/cellacdc/utils/countObjects.py +++ b/cellacdc/tools/countObjects.py @@ -1,30 +1,35 @@ -from .. import apps, myutils, workers, widgets, html_utils +from .. import apps, utils, workers, widgets, html_utils from .base import NewThreadMultipleExpBaseUtil + class CountObjectsInsegm(NewThreadMultipleExpBaseUtil): def __init__( - self, expPaths, app, title: str, infoText: str, - progressDialogueTitle: str, parent=None - ): - module = myutils.get_module_name(__file__) + self, + expPaths, + app, + title: str, + infoText: str, + progressDialogueTitle: str, + parent=None, + ): + module = utils.get_module_name(__file__) super().__init__( - expPaths, app, title, module, infoText, progressDialogueTitle, - parent=parent + expPaths, app, title, module, infoText, progressDialogueTitle, parent=parent ) self.expPaths = expPaths - + def runWorker(self): self.worker = workers.CountObjectsInSegm(self) self.worker.sigAborted.connect(self.workerAborted) super().runWorker(self.worker) - + def showEvent(self, event): self.runWorker() - + def workerAborted(self): self.workerFinished(None, aborted=True) - + def workerFinished(self, worker, aborted=False): if aborted: txt = f'"{self._title}" process cancelled.' @@ -33,8 +38,8 @@ def workerFinished(self, worker, aborted=False): self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) if aborted: - msg.warning(self, 'Process completed', html_utils.paragraph(txt)) + msg.warning(self, "Process completed", html_utils.paragraph(txt)) else: - msg.information(self, 'Process completed', html_utils.paragraph(txt)) + msg.information(self, "Process completed", html_utils.paragraph(txt)) super().workerFinished(worker) - self.close() \ No newline at end of file + self.close() diff --git a/cellacdc/utils/createConnected3Dsegm.py b/cellacdc/tools/createConnected3Dsegm.py similarity index 65% rename from cellacdc/utils/createConnected3Dsegm.py rename to cellacdc/tools/createConnected3Dsegm.py index 1cdde0ffd..3b4b97381 100644 --- a/cellacdc/utils/createConnected3Dsegm.py +++ b/cellacdc/tools/createConnected3Dsegm.py @@ -1,67 +1,70 @@ -from .. import apps, myutils, workers, widgets, html_utils +from .. import apps, utils, workers, widgets, html_utils from .base import NewThreadMultipleExpBaseUtil + class CreateConnected3Dsegm(NewThreadMultipleExpBaseUtil): def __init__( - self, expPaths, app, title: str, infoText: str, - progressDialogueTitle: str, parent=None - ): - module = myutils.get_module_name(__file__) + self, + expPaths, + app, + title: str, + infoText: str, + progressDialogueTitle: str, + parent=None, + ): + module = utils.get_module_name(__file__) super().__init__( - expPaths, app, title, module, infoText, progressDialogueTitle, - parent=parent + expPaths, app, title, module, infoText, progressDialogueTitle, parent=parent ) self.expPaths = expPaths - + def runWorker(self): self.worker = workers.CreateConnected3Dsegm(self) self.worker.sigAskAppendName.connect(self.askAppendName) self.worker.sigAborted.connect(self.workerAborted) super().runWorker(self.worker) - + def showEvent(self, event): self.runWorker() - + def askAppendName(self, basename, existingEndnames): - helpText = ( - """ + helpText = """ The new 3D segmentation file will be saved with a different file name.

    Insert a name to append to the end of the new name. The rest of the name will be the same as the original file. """ - ) win = apps.filenameDialog( basename=basename, - hintText='Insert a name for the new 3D segmentation file:', - existingNames=existingEndnames, - helpText=helpText, + hintText="Insert a name for the new 3D segmentation file:", + existingNames=existingEndnames, + helpText=helpText, allowEmpty=False, - parent=self + parent=self, ) win.exec_() if win.cancel: self.worker.abort = True self.worker.waitCond.wakeAll() return - + self.worker.appendedName = win.entryText self.worker.waitCond.wakeAll() - + def workerAborted(self): self.workerFinished(None, aborted=True) - + def workerFinished(self, worker, aborted=False): if aborted: - txt = '3D segmentation mask creation process aborted.' + txt = "3D segmentation mask creation process aborted." else: - txt = '3D segmentation mask creation process completed.' + txt = "3D segmentation mask creation process completed." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) if aborted: - msg.warning(self, 'Process completed', html_utils.paragraph(txt)) + msg.warning(self, "Process completed", html_utils.paragraph(txt)) else: - msg.information(self, 'Process completed', html_utils.paragraph(txt)) + msg.information(self, "Process completed", html_utils.paragraph(txt)) super().workerFinished(worker) - self.close() \ No newline at end of file + self.close() diff --git a/cellacdc/utils/customPreprocess.py b/cellacdc/tools/customPreprocess.py similarity index 68% rename from cellacdc/utils/customPreprocess.py rename to cellacdc/tools/customPreprocess.py index 24780ad3d..1e1ff95ff 100644 --- a/cellacdc/utils/customPreprocess.py +++ b/cellacdc/tools/customPreprocess.py @@ -2,102 +2,103 @@ import pandas as pd -from .. import apps, myutils, workers, widgets, html_utils, load, printl +from .. import apps, utils, workers, widgets, html_utils, load, printl from .base import NewThreadMultipleExpBaseUtil + class CustomPreprocessUtil(NewThreadMultipleExpBaseUtil): def __init__( - self, expPaths, app, title: str, infoText: str, - progressDialogueTitle: str, parent=None - ): - module = myutils.get_module_name(__file__) + self, + expPaths, + app, + title: str, + infoText: str, + progressDialogueTitle: str, + parent=None, + ): + module = utils.get_module_name(__file__) super().__init__( - expPaths, app, title, module, infoText, progressDialogueTitle, - parent=parent + expPaths, app, title, module, infoText, progressDialogueTitle, parent=parent ) self.expPaths = expPaths - + def runWorker(self): self.worker = workers.CustomPreprocessWorkerUtil(self) self.worker.sigAskAppendName.connect(self.askAppendName) self.worker.sigAskSetupRecipe.connect(self.askSetupRecipe) self.worker.sigAborted.connect(self.workerAborted) super().runWorker(self.worker) - + def askSetupRecipe(self, exp_path, pos_foldernames): channel_names = set() df_metadata = None for p, pos in enumerate(pos_foldernames): pos_path = os.path.join(exp_path, pos) - images_path = os.path.join(pos_path, 'Images') - basename, chNames = myutils.getBasenameAndChNames(images_path) + images_path = os.path.join(pos_path, "Images") + basename, chNames = utils.getBasenameAndChNames(images_path) channel_names.update(chNames) if df_metadata is not None: continue - + self.worker.basename = basename df_metadata = load.load_metadata_df(images_path) - + win = apps.PreProcessRecipeDialogUtil( - channel_names, - df_metadata=df_metadata, - parent=self + channel_names, df_metadata=df_metadata, parent=self ) win.exec_() - + if win.cancel: self.worker.abort = win.cancel self.worker.waitCond.wakeAll() - return - + return + self.worker.selectedChannels = win.selectedChannels self.worker.recipe = win.selectedRecipe self.worker.waitCond.wakeAll() - + def showEvent(self, event): self.runWorker() - + def askAppendName(self, basename): - helpText = ( - """ + helpText = """ The preprocessed image file will be saved with a different file name.

    Insert a name to append to the end of the new file name. The rest of the name will be the same as the original file. """ - ) win = apps.filenameDialog( basename=basename, - ext='.tif', - hintText='Insert a name for the preprocessed image file:', - defaultEntry='preprocessed', - helpText=helpText, + ext=".tif", + hintText="Insert a name for the preprocessed image file:", + defaultEntry="preprocessed", + helpText=helpText, allowEmpty=False, - parent=self + parent=self, ) win.exec_() if win.cancel: self.worker.abort = True self.worker.waitCond.wakeAll() return - + self.worker.appendedName = win.entryText self.worker.waitCond.wakeAll() - + def workerAborted(self): self.workerFinished(None, aborted=True) - + def workerFinished(self, worker, aborted=False): if aborted: - txt = 'Custom pre-processing aborted.' + txt = "Custom pre-processing aborted." else: - txt = 'Custom pre-processing completed.' + txt = "Custom pre-processing completed." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) if aborted: - msg.warning(self, 'Process completed', html_utils.paragraph(txt)) + msg.warning(self, "Process completed", html_utils.paragraph(txt)) else: - msg.information(self, 'Process completed', html_utils.paragraph(txt)) + msg.information(self, "Process completed", html_utils.paragraph(txt)) super().workerFinished(worker) - self.close() \ No newline at end of file + self.close() diff --git a/cellacdc/utils/fillHolesInSegm.py b/cellacdc/tools/fillHolesInSegm.py similarity index 67% rename from cellacdc/utils/fillHolesInSegm.py rename to cellacdc/tools/fillHolesInSegm.py index 40564e708..6517a7e9a 100644 --- a/cellacdc/utils/fillHolesInSegm.py +++ b/cellacdc/tools/fillHolesInSegm.py @@ -1,54 +1,62 @@ -from .. import apps, myutils, workers, widgets, html_utils, load +from .. import apps, utils, workers, widgets, html_utils, load from .base import NewThreadMultipleExpBaseUtil import os + class fillHolesInSegm(NewThreadMultipleExpBaseUtil): def __init__( - self, expPaths, app, title: str, infoText: str, - progressDialogueTitle: str, parent=None - ): - module = myutils.get_module_name(__file__) + self, + expPaths, + app, + title: str, + infoText: str, + progressDialogueTitle: str, + parent=None, + ): + module = utils.get_module_name(__file__) super().__init__( - expPaths, app, title, module, infoText, progressDialogueTitle, - parent=parent + expPaths, app, title, module, infoText, progressDialogueTitle, parent=parent ) self.expPaths = expPaths - + def runWorker(self): self.worker = workers.FillHolesInSegWorker(self) self.worker.sigAskAppendName.connect(self.askAppendName) self.worker.sigAborted.connect(self.workerAborted) self.worker.sigSelectSegmFiles.connect(self.askInputSegm) super().runWorker(self.worker) - + def showEvent(self, event): self.runWorker() - + def workerAborted(self): self.workerFinished(None, aborted=True) - + def workerFinished(self, worker, aborted=False): if aborted: - txt = 'Filling holes in segmentation mask process aborted.' + txt = "Filling holes in segmentation mask process aborted." else: - txt = 'Filling holes in segmentation mask process completed.' + txt = "Filling holes in segmentation mask process completed." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) if aborted: - msg.warning(self, 'Process completed', html_utils.paragraph(txt)) + msg.warning(self, "Process completed", html_utils.paragraph(txt)) else: - msg.information(self, 'Process completed', html_utils.paragraph(txt)) + msg.information(self, "Process completed", html_utils.paragraph(txt)) super().workerFinished(worker) self.close() - + def askInputSegm(self, exp_path, pos_foldernames): existingSegmEndNames = load.get_segm_endnames_from_exp_path( - exp_path, pos_foldernames=pos_foldernames - ) + exp_path, pos_foldernames=pos_foldernames + ) win = apps.SelectSegmFileDialog( - existingSegmEndNames, exp_path, parent=self, allowMultipleSelection=True, - infoText=f"Select the segmentation files for folder {exp_path}." + existingSegmEndNames, + exp_path, + parent=self, + allowMultipleSelection=True, + infoText=f"Select the segmentation files for folder {exp_path}.", ) win.exec_() if win.cancel: @@ -58,28 +66,27 @@ def askInputSegm(self, exp_path, pos_foldernames): return self.worker.endFilenameSegmTemp = win.selectedItemTexts self.worker.waitCond.wakeAll() - + def askAppendName(self, basename): - helpText = ( - """ + helpText = """ The new segmentation file can be saved with a different file name.

    Insert a name if the old segmentation should not be overwritten. """ - ) win = apps.filenameDialog( - hintText='Insert a name extension if the old file should not be overwritten.', - helpText=helpText, basename=basename, - allowEmpty=True + hintText="Insert a name extension if the old file should not be overwritten.", + helpText=helpText, + basename=basename, + allowEmpty=True, ) win.exec_() if win.cancel: self.worker.abort = True self.worker.waitCond.wakeAll() return - + if win.entryText is None: self.worker.appendedName = "" else: self.worker.appendedName = win.entryText - self.worker.waitCond.wakeAll() \ No newline at end of file + self.worker.waitCond.wakeAll() diff --git a/cellacdc/utils/filterObjFromCoordsTable.py b/cellacdc/tools/filterObjFromCoordsTable.py similarity index 67% rename from cellacdc/utils/filterObjFromCoordsTable.py rename to cellacdc/tools/filterObjFromCoordsTable.py index bf576ac18..230df7629 100644 --- a/cellacdc/utils/filterObjFromCoordsTable.py +++ b/cellacdc/tools/filterObjFromCoordsTable.py @@ -1,80 +1,82 @@ -from .. import apps, myutils, workers, widgets, html_utils +from .. import apps, utils, workers, widgets, html_utils from .base import NewThreadMultipleExpBaseUtil + class FilterObjsFromCoordsTable(NewThreadMultipleExpBaseUtil): def __init__( - self, expPaths, app, title: str, infoText: str, - progressDialogueTitle: str, parent=None - ): - module = myutils.get_module_name(__file__) + self, + expPaths, + app, + title: str, + infoText: str, + progressDialogueTitle: str, + parent=None, + ): + module = utils.get_module_name(__file__) super().__init__( - expPaths, app, title, module, infoText, progressDialogueTitle, - parent=parent + expPaths, app, title, module, infoText, progressDialogueTitle, parent=parent ) self.expPaths = expPaths - + def runWorker(self): self.worker = workers.FilterObjsFromCoordsTable(self) self.worker.sigAskAppendName.connect(self.askAppendName) self.worker.sigSetColumnsNames.connect(self.setColumnsNames) self.worker.sigAborted.connect(self.workerAborted) super().runWorker(self.worker) - + def showEvent(self, event): self.runWorker() - + def askAppendName(self, basename, existingEndnames): - helpText = ( - """ + helpText = """ You can choose to save a new file for the filtered segmentation or overwrite the existing one. """ - ) win = apps.filenameDialog( basename=basename, - hintText='Insert a name for the filtered segmentation file:', - existingNames=existingEndnames, - helpText=helpText, + hintText="Insert a name for the filtered segmentation file:", + existingNames=existingEndnames, + helpText=helpText, allowEmpty=False, - parent=self + parent=self, ) win.exec_() if win.cancel: self.worker.abort = True self.worker.waitCond.wakeAll() return - + self.worker.appendedName = win.entryText self.worker.waitCond.wakeAll() - + def setColumnsNames(self, columns, categories, optionalCategories): win = apps.SetColumnNamesDialog( - columns, categories, optionalCategories=optionalCategories, - parent=self + columns, categories, optionalCategories=optionalCategories, parent=self ) win.exec_() if win.cancel: self.worker.abort = True self.worker.waitCond.wakeAll() - return - + return + self.selectedColumnsPerCategory = win.selectedColumns self.worker.waitCond.wakeAll() - + def workerAborted(self): self.workerFinished(None, aborted=True) - + def workerFinished(self, worker, aborted=False): if aborted: - txt = 'Filter segmented objects from coordinates table process aborted.' + txt = "Filter segmented objects from coordinates table process aborted." else: - txt = 'Filter segmented objects from coordinates table process completed.' + txt = "Filter segmented objects from coordinates table process completed." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) if aborted: - msg.warning(self, 'Process completed', html_utils.paragraph(txt)) + msg.warning(self, "Process completed", html_utils.paragraph(txt)) else: - msg.information(self, 'Process completed', html_utils.paragraph(txt)) + msg.information(self, "Process completed", html_utils.paragraph(txt)) super().workerFinished(worker) - self.close() \ No newline at end of file + self.close() diff --git a/cellacdc/utils/fromImageJroiToSegm.py b/cellacdc/tools/fromImageJroiToSegm.py similarity index 72% rename from cellacdc/utils/fromImageJroiToSegm.py rename to cellacdc/tools/fromImageJroiToSegm.py index 04f69492a..bb17bc7fd 100644 --- a/cellacdc/utils/fromImageJroiToSegm.py +++ b/cellacdc/tools/fromImageJroiToSegm.py @@ -1,51 +1,58 @@ -from .. import myutils, workers, widgets, html_utils +from .. import utils, workers, widgets, html_utils from .. import apps from .base import NewThreadMultipleExpBaseUtil + class fromImageJRoiToSegmUtil(NewThreadMultipleExpBaseUtil): def __init__( - self, expPaths, app, title: str, infoText: str, - progressDialogueTitle: str, parent=None): - module = myutils.get_module_name(__file__) + self, + expPaths, + app, + title: str, + infoText: str, + progressDialogueTitle: str, + parent=None, + ): + module = utils.get_module_name(__file__) super().__init__( - expPaths, app, title, module, infoText, progressDialogueTitle, - parent=parent + expPaths, app, title, module, infoText, progressDialogueTitle, parent=parent ) self.qparent = parent self.expPaths = expPaths - + def runWorker(self): self.worker = workers.FromImajeJroiToSegmNpzWorker(self) self.worker.sigSelectRoisProps.connect(self.selectRoisProps) super().runWorker(self.worker) - + def selectRoisProps(self, roi_filepath, TZYX_shape, is_multi_pos): win = apps.ImageJRoisToSegmManager( - roi_filepath, TZYX_shape, + roi_filepath, + TZYX_shape, addUseSamePropsForNextPosButton=is_multi_pos, - parent=self.qparent + parent=self.qparent, ) win.exec_() self.worker.abort = win.cancel if win.cancel: self.worker.waitCond.wakeAll() return - + self.worker.IDsToRoisMapper = win.IDsToRoisMapper self.worker.rescaleRoisSizes = win.rescaleSizes self.worker.repeatRoisZslicesRange = win.repeatRoisZslicesRange self.worker.useSamePropsForNextPos = win.useSamePropsForNextPos self.worker.areAllRoisSelected = win.areAllRoisSelected self.worker.waitCond.wakeAll() - + def showEvent(self, event): self.runWorker() - + def workerFinished(self, worker): super().workerFinished(worker) - txt = 'Converting from ImageJ ROIs to Cell-ACDC segmentation file(s) completed.' + txt = "Converting from ImageJ ROIs to Cell-ACDC segmentation file(s) completed." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) - msg.information(self, 'Process completed', html_utils.paragraph(txt)) - self.close() \ No newline at end of file + msg.information(self, "Process completed", html_utils.paragraph(txt)) + self.close() diff --git a/cellacdc/utils/fucciPreprocess.py b/cellacdc/tools/fucciPreprocess.py similarity index 67% rename from cellacdc/utils/fucciPreprocess.py rename to cellacdc/tools/fucciPreprocess.py index 1724ebe6d..4faabf494 100644 --- a/cellacdc/utils/fucciPreprocess.py +++ b/cellacdc/tools/fucciPreprocess.py @@ -2,116 +2,115 @@ import pandas as pd -from .. import apps, myutils, workers, widgets, html_utils, load +from .. import apps, utils, workers, widgets, html_utils, load from .base import NewThreadMultipleExpBaseUtil + class FucciPreprocessUtil(NewThreadMultipleExpBaseUtil): def __init__( - self, expPaths, app, title: str, infoText: str, - progressDialogueTitle: str, parent=None - ): - module = myutils.get_module_name(__file__) + self, + expPaths, + app, + title: str, + infoText: str, + progressDialogueTitle: str, + parent=None, + ): + module = utils.get_module_name(__file__) super().__init__( - expPaths, app, title, module, infoText, progressDialogueTitle, - parent=parent + expPaths, app, title, module, infoText, progressDialogueTitle, parent=parent ) self.expPaths = expPaths - + def runWorker(self): self.worker = workers.FucciPreprocessWorker(self) self.worker.sigAskAppendName.connect(self.askAppendName) self.worker.sigAskParams.connect(self.askSelectParams) self.worker.sigAborted.connect(self.workerAborted) super().runWorker(self.worker) - + def askSelectParams(self, exp_path, pos_foldernames): channel_names = set() df_metadata = None for p, pos in enumerate(pos_foldernames): pos_path = os.path.join(exp_path, pos) - images_path = os.path.join(pos_path, 'Images') - basename, chNames = myutils.getBasenameAndChNames(images_path) + images_path = os.path.join(pos_path, "Images") + basename, chNames = utils.getBasenameAndChNames(images_path) channel_names.update(chNames) if df_metadata is not None: continue - + self.worker.basename = basename df_metadata = load.load_metadata_df(images_path) - + if len(channel_names) < 2: - txt = ( - 'At least two channels are needed to run the FUCCI ' - 'pre-processing.' - ) + txt = "At least two channels are needed to run the FUCCI pre-processing." self.logger.error(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) - msg.critical(self, 'Error', html_utils.paragraph(txt)) + msg.critical(self, "Error", html_utils.paragraph(txt)) self.worker.abort = True self.worker.waitCond.wakeAll() return - + win = apps.FucciPreprocessDialog( - channel_names, - df_metadata=df_metadata, - parent=self + channel_names, df_metadata=df_metadata, parent=self ) win.exec_() - + self.worker.firstChannelName = win.firstChannelName self.worker.secondChannelName = win.secondChannelName fucciFilterKwargs = win.function_kwargs self.worker.fucciFilterKwargs = fucciFilterKwargs - - if fucciFilterKwargs['do_basicpy_background_correction']: + + if fucciFilterKwargs["do_basicpy_background_correction"]: from cellacdc import preprocess + preprocess._init_basicpy_background_correction(parent=self) - + self.worker.abort = win.cancel self.worker.waitCond.wakeAll() - + def showEvent(self, event): self.runWorker() - + def askAppendName(self, basename): - helpText = ( - """ + helpText = """ The combined and preprocessed image file will be saved with a different file name.

    Insert a name to append to the end of the new name. The rest of the name will be the same as the original file. """ - ) win = apps.filenameDialog( basename=basename, - hintText='Insert a name for the new combined channels file:', - defaultEntry='fucci_combined', - helpText=helpText, + hintText="Insert a name for the new combined channels file:", + defaultEntry="fucci_combined", + helpText=helpText, allowEmpty=False, - parent=self + parent=self, ) win.exec_() if win.cancel: self.worker.abort = True self.worker.waitCond.wakeAll() return - + self.worker.appendedName = win.entryText self.worker.waitCond.wakeAll() - + def workerAborted(self): self.workerFinished(None, aborted=True) - + def workerFinished(self, worker, aborted=False): if aborted: - txt = 'FUCCI pre-processing aborted.' + txt = "FUCCI pre-processing aborted." else: - txt = 'FUCCI pre-processing completed.' + txt = "FUCCI pre-processing completed." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) if aborted: - msg.warning(self, 'Process completed', html_utils.paragraph(txt)) + msg.warning(self, "Process completed", html_utils.paragraph(txt)) else: - msg.information(self, 'Process completed', html_utils.paragraph(txt)) + msg.information(self, "Process completed", html_utils.paragraph(txt)) super().workerFinished(worker) - self.close() \ No newline at end of file + self.close() diff --git a/cellacdc/utils/generateMothBudTotalTable.py b/cellacdc/tools/generateMothBudTotalTable.py similarity index 63% rename from cellacdc/utils/generateMothBudTotalTable.py rename to cellacdc/tools/generateMothBudTotalTable.py index bcee83e43..edd6a4fb1 100644 --- a/cellacdc/utils/generateMothBudTotalTable.py +++ b/cellacdc/tools/generateMothBudTotalTable.py @@ -4,68 +4,65 @@ import pandas as pd from .. import exception_handler -from .. import myutils, apps, widgets, html_utils, printl, workers -from ..utils import base +from .. import utils, apps, widgets, html_utils, printl, workers +from . import base from qtpy.QtWidgets import QFileDialog + class GenerateMothBudTotalUtil(base.MainThreadSinglePosUtilBase): def __init__( - self, app, title: str, infoText: str, parent=None, - callbackOnFinished=None - ): - module = myutils.get_module_name(__file__) - super().__init__( - app, title, module, infoText, parent - ) + self, app, title: str, infoText: str, parent=None, callbackOnFinished=None + ): + module = utils.get_module_name(__file__) + super().__init__(app, title, module, infoText, parent) self.sigClose.connect(self.close) self.callbackOnFinished = callbackOnFinished @exception_handler - def run(self): + def run(self): msg = widgets.myMessageBox(showCentered=False, wrapText=False) txt = html_utils.paragraph( 'After clicking "Ok" you will be asked to select the input table ' - 'file (.csv) containing pedigree information.' + "file (.csv) containing pedigree information." ) - msg.information(self, 'Instructions', txt) + msg.information(self, "Instructions", txt) if msg.cancel: return False - + import qtpy.compat + input_csv_filepath = qtpy.compat.getopenfilename( - parent=self, - caption='Select CSV file to load', - filters='CSV (*.csv);;All Files (*)', - basedir=myutils.getMostRecentPath() + parent=self, + caption="Select CSV file to load", + filters="CSV (*.csv);;All Files (*)", + basedir=utils.getMostRecentPath(), )[0] if input_csv_filepath is None or not input_csv_filepath: return False - myutils.addToRecentPaths(os.path.dirname(input_csv_filepath)) - + utils.addToRecentPaths(os.path.dirname(input_csv_filepath)) + self.logger.info(f'Reading column names in table "{input_csv_filepath}"...') df = pd.read_csv(input_csv_filepath, nrows=2) - win = apps.GenerateMotherBudTotalTableSelectColumnsDialog( - df, parent=self - ) + win = apps.GenerateMotherBudTotalTableSelectColumnsDialog(df, parent=self) win.exec_() if win.cancel: return False - + selected_options = win.selected_options - + csv_filename = os.path.basename(input_csv_filepath) csv_filename_noext, ext = os.path.splitext(csv_filename) win = apps.filenameDialog( - ext='.csv', - basename=f'{csv_filename_noext}_', - hintText='Insert a filename for the output table file:', - allowEmpty=False, - defaultEntry='mother_bud_total', + ext=".csv", + basename=f"{csv_filename_noext}_", + hintText="Insert a filename for the output table file:", + allowEmpty=False, + defaultEntry="mother_bud_total", ) win.exec_() if win.cancel: @@ -75,7 +72,7 @@ def run(self): out_csv_filepath = os.path.join( os.path.dirname(input_csv_filepath), out_csv_filename ) - + self.worker = workers.GenerateMotherBudTotalTableWorker( self, input_csv_filepath, selected_options, out_csv_filepath ) @@ -83,9 +80,8 @@ def run(self): self.worker.signals.finished.connect(self.callbackOnFinished) self.runWorker(self.worker) return True - + def overWriteClicked(self, win): win.cancel = False - win.filename = '' + win.filename = "" win.close() - diff --git a/cellacdc/utils/rename.py b/cellacdc/tools/rename.py similarity index 63% rename from cellacdc/utils/rename.py rename to cellacdc/tools/rename.py index 7801ad696..ec9e74079 100755 --- a/cellacdc/utils/rename.py +++ b/cellacdc/tools/rename.py @@ -13,9 +13,15 @@ from tqdm import tqdm from qtpy.QtWidgets import ( - QApplication, QMainWindow, QFileDialog, - QVBoxLayout, QPushButton, QLabel, QStyleFactory, - QWidget, QMessageBox + QApplication, + QMainWindow, + QFileDialog, + QVBoxLayout, + QPushButton, + QLabel, + QStyleFactory, + QWidget, + QMessageBox, ) from qtpy.QtCore import Qt, QEventLoop from qtpy import QtGui @@ -25,23 +31,22 @@ sys.path.append(cellacdc_path) # Custom modules -from .. import prompts, load, myutils, apps, html_utils, widgets +from .. import prompts, load, utils, apps, html_utils, widgets from .. import recentPaths_path, cellacdc_path, settings_folderpath -if os.name == 'nt': +if os.name == "nt": try: # Set taskbar icon in windows import ctypes - myappid = 'schmollerlab.cellacdc.pyqt.v1' # arbitrary string + + myappid = "schmollerlab.cellacdc.pyqt.v1" # arbitrary string ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(myappid) except Exception as e: pass + class renameFilesWin(QMainWindow): - def __init__( - self, parent=None, allowExit=False, - actionToEnable=None, mainWin=None - ): + def __init__(self, parent=None, allowExit=False, actionToEnable=None, mainWin=None): self.allowExit = allowExit self.processFinished = False self.actionToEnable = actionToEnable @@ -56,21 +61,21 @@ def __init__( mainLayout = QVBoxLayout() titleText = html_utils.paragraph( - '
    Renaming files utility', font_size='14px' + "
    Renaming files utility", font_size="14px" ) titleLabel = QLabel(titleText) mainLayout.addWidget(titleLabel) infoTxt = ( - 'Follow the instructions in the pop-up windows.
    ' - 'Note that pop-ups might be minimized or behind other open windows.

    ' - 'Progess is displayed in the terminal/console.' + "Follow the instructions in the pop-up windows.
    " + "Note that pop-ups might be minimized or behind other open windows.

    " + "Progess is displayed in the terminal/console." ) informativeLabel = QLabel(html_utils.paragraph(infoTxt)) mainLayout.addWidget(informativeLabel) - abortButton = QPushButton('Stop processs') + abortButton = QPushButton("Stop processs") abortButton.clicked.connect(self.close) mainLayout.addWidget(abortButton) @@ -79,55 +84,53 @@ def __init__( def getMostRecentPath(self): if os.path.exists(recentPaths_path): - df = pd.read_csv(recentPaths_path, index_col='index') - if 'opened_last_on' in df.columns: - df = df.sort_values('opened_last_on', ascending=False) - self.MostRecentPath = df.iloc[0]['path'] + df = pd.read_csv(recentPaths_path, index_col="index") + if "opened_last_on" in df.columns: + df = df.sort_values("opened_last_on", ascending=False) + self.MostRecentPath = df.iloc[0]["path"] if not isinstance(self.MostRecentPath, str): - self.MostRecentPath = '' + self.MostRecentPath = "" else: - self.MostRecentPath = '' + self.MostRecentPath = "" def main(self): self.getMostRecentPath() exp_path = QFileDialog.getExistingDirectory( - self, 'Select experiment folder containing Position_n folders ' - 'or specific Position_n folder', self.MostRecentPath) + self, + "Select experiment folder containing Position_n folders " + "or specific Position_n folder", + self.MostRecentPath, + ) self.addToRecentPaths(exp_path) - if exp_path == '': + if exp_path == "": abort = self.doAbort() if abort: self.close() return - self.setWindowTitle( - f'Cell-ACDC - Renaming files - "{exp_path}"' - ) + self.setWindowTitle(f'Cell-ACDC - Renaming files - "{exp_path}"') - folder_type = myutils.determine_folder_type(exp_path) + folder_type = utils.determine_folder_type(exp_path) is_pos_folder, is_images_folder, exp_path = folder_type - print('Loading data...') + print("Loading data...") if not is_pos_folder and not is_images_folder: select_folder = load.select_exp_folder() values = select_folder.get_values_segmGUI(exp_path) if not values: txt = ( - 'The selected folder:\n\n ' - f'{exp_path}\n\n' - 'is not a valid folder. ' - 'Select a folder that contains the Position_n folders' + "The selected folder:\n\n " + f"{exp_path}\n\n" + "is not a valid folder. " + "Select a folder that contains the Position_n folders" ) msg = QMessageBox() - msg.critical( - self, 'Incompatible folder', txt, msg.Ok - ) + msg.critical(self, "Incompatible folder", txt, msg.Ok) self.close() return - select_folder.QtPrompt(self, values, allow_cancel=False, show=True) if select_folder.cancel: abort = self.doAbort() @@ -135,15 +138,15 @@ def main(self): self.close() return - pos_foldernames = select_folder.selected_pos - images_paths = [os.path.join(exp_path, pos, 'Images') - for pos in pos_foldernames] + images_paths = [ + os.path.join(exp_path, pos, "Images") for pos in pos_foldernames + ] elif is_pos_folder: pos_foldername = os.path.basename(exp_path) exp_path = os.path.dirname(exp_path) - images_paths = [f'{exp_path}/{pos_foldername}/Images'] + images_paths = [f"{exp_path}/{pos_foldername}/Images"] elif is_images_folder: images_paths = [exp_path] @@ -155,7 +158,6 @@ def main(self): self.close() return - abort, appendedTxt = self.askTxtAppend(selectedFilenames[0]) if abort: abort = self.doAbort() @@ -166,54 +168,45 @@ def main(self): print(f'Renaming files by appending "_{appendedTxt}"...') if len(selectedFilenames) > 1 or len(images_paths) > 1: ch_name_selector = prompts.select_channel_name() - ls = myutils.listdir(images_paths[0]) + ls = utils.listdir(images_paths[0]) all_channelNames, abort = ch_name_selector.get_available_channels( - ls, images_paths[0], useExt=None + ls, images_paths[0], useExt=None ) if abort: - self.criticalNoCommonBasename( - selectedFilenames, images_paths[0] - ) + self.criticalNoCommonBasename(selectedFilenames, images_paths[0]) self.close() return _endswith_li = [ - f[len(ch_name_selector.basename):] for f in selectedFilenames + f[len(ch_name_selector.basename) :] for f in selectedFilenames ] for images_path in tqdm(images_paths, ncols=100): - ls = myutils.listdir(images_path) + ls = utils.listdir(images_path) _, skip = ch_name_selector.get_available_channels( ls, images_path, useExt=None ) if skip: - print('') - print('-------------------------------------') - print( - f'{images_path} data structure compromised!' - 'Skipping it.' - ) - print('-------------------------------------') + print("") + print("-------------------------------------") + print(f"{images_path} data structure compromised!Skipping it.") + print("-------------------------------------") for _endswith in _endswith_li: for file in ls: if file.endswith(_endswith): - self._rename( - file, images_path, appendedTxt - ) + self._rename(file, images_path, appendedTxt) else: self._rename(selectedFilenames[0], images_paths[0], appendedTxt) msg = widgets.myMessageBox() - txt = html_utils.paragraph( - 'Renaming process completed.

    ' - ) - msg.information(self, 'Renaming process completed', txt) + txt = html_utils.paragraph("Renaming process completed.

    ") + msg.information(self, "Renaming process completed", txt) self.close() if self.allowExit: - exit('Done.') + exit("Done.") def _rename(self, file, parent_path, appendedTxt): filename, ext = os.path.splitext(file) - new_file = f'{filename}_{appendedTxt}{ext}' + new_file = f"{filename}_{appendedTxt}{ext}" src_filepath = os.path.join(parent_path, file) new_filepath = os.path.join(parent_path, new_file) os.rename(src_filepath, new_filepath) @@ -221,23 +214,20 @@ def _rename(self, file, parent_path, appendedTxt): def save(self, alignedData, filePath, appendedTxt, first_call=True): dir = os.path.dirname(filePath) filename, ext = os.path.splitext(os.path.basename(filePath)) - path = os.path.join(dir, f'{filename}_{appendedTxt}{ext}') + path = os.path.join(dir, f"{filename}_{appendedTxt}{ext}") def askTxtAppend(self, filename): font = QtGui.QFont() font.setPixelSize(13) - self.win = apps.QDialogAppendTextFilename( - filename, '', parent=self, font=font - ) + self.win = apps.QDialogAppendTextFilename(filename, "", parent=self, font=font) self.win.exec_() return self.win.cancel, self.win.LE.text() def criticalNoCommonBasename(self, filenames, parent_path): - myutils.checkDataIntegrity(filenames, parent_path, parentQWidget=self) - + utils.checkDataIntegrity(filenames, parent_path, parentQWidget=self) def selectFiles(self, images_path, filterExt=None): - files = myutils.listdir(images_path) + files = utils.listdir(images_path) if filterExt is not None: items = [] for file in files: @@ -249,9 +239,11 @@ def selectFiles(self, images_path, filterExt=None): items = files selectFilesWidget = widgets.QDialogListbox( - 'Select files', - 'Select the files you want to rename', - items, multiSelection=True, parent=self + "Select files", + "Select the files you want to rename", + items, + multiSelection=True, + parent=self, ) selectFilesWidget.exec_() @@ -268,12 +260,12 @@ def addToRecentPaths(self, exp_path): if not os.path.exists(exp_path): return if os.path.exists(recentPaths_path): - df = pd.read_csv(recentPaths_path, index_col='index') - recentPaths = df['path'].to_list() - if 'opened_last_on' in df.columns: - openedOn = df['opened_last_on'].to_list() + df = pd.read_csv(recentPaths_path, index_col="index") + recentPaths = df["path"].to_list() + if "opened_last_on" in df.columns: + openedOn = df["opened_last_on"].to_list() else: - openedOn = [np.nan]*len(recentPaths) + openedOn = [np.nan] * len(recentPaths) if exp_path in recentPaths: pop_idx = recentPaths.index(exp_path) recentPaths.pop(pop_idx) @@ -287,17 +279,20 @@ def addToRecentPaths(self, exp_path): else: recentPaths = [exp_path] openedOn = [datetime.datetime.now()] - df = pd.DataFrame({'path': recentPaths, - 'opened_last_on': pd.Series(openedOn, - dtype='datetime64[ns]')}) - df.index.name = 'index' + df = pd.DataFrame( + { + "path": recentPaths, + "opened_last_on": pd.Series(openedOn, dtype="datetime64[ns]"), + } + ) + df.index.name = "index" df.to_csv(recentPaths_path) def doAbort(self): if self.allowExit: - exit('Execution aborted by the user') + exit("Execution aborted by the user") else: - print('Conversion task aborted by the user.') + print("Conversion task aborted by the user.") return True def closeEvent(self, event): diff --git a/cellacdc/utils/repeat.py b/cellacdc/tools/repeat.py similarity index 72% rename from cellacdc/utils/repeat.py rename to cellacdc/tools/repeat.py index 25383a9bf..926a39a1a 100644 --- a/cellacdc/utils/repeat.py +++ b/cellacdc/tools/repeat.py @@ -5,40 +5,44 @@ from qtpy.QtCore import Qt, QThread, QSize from qtpy.QtWidgets import ( - QDialog, QVBoxLayout, QHBoxLayout, QLabel, QFileDialog, QListWidgetItem + QDialog, + QVBoxLayout, + QHBoxLayout, + QLabel, + QFileDialog, + QListWidgetItem, ) from qtpy import QtGui from .. import exception_handler -from .. import myutils, html_utils, workers, widgets, load, apps +from .. import utils, html_utils, workers, widgets, load, apps + class repeatDataPrepWindow(QDialog): def __init__(self, parent=None) -> None: super().__init__(parent) - name = 'repeat data prep' - - logger, logs_path, log_path, log_filename = myutils.setupLogger( - module=name - ) + name = "repeat data prep" + + logger, logs_path, log_path, log_filename = utils.setupLogger(module=name) self.logger = logger self.log_path = log_path self.log_filename = log_filename self.logs_path = logs_path - self.logger.info(f'Initializing {name}...') + self.logger.info(f"Initializing {name}...") self.cancel = True - self.setWindowTitle(f'Cell-ACDC {name}') - self.funcDescription = f'Cell-ACDC {name}' + self.setWindowTitle(f"Cell-ACDC {name}") + self.funcDescription = f"Cell-ACDC {name}" instructions = [ - 'Press start button', - 'Select experiment folder or specific Position folder', - 'Select which channels or un-prepped .tif file to apply data prep to', - 'Wait until process ends' + "Press start button", + "Select experiment folder or specific Position folder", + "Select which channels or un-prepped .tif file to apply data prep to", + "Wait until process ends", ] txt = html_utils.paragraph(f""" @@ -53,7 +57,7 @@ def __init__(self, parent=None) -> None: layout = QVBoxLayout() textLayout = QHBoxLayout() - pixmap = QtGui.QIcon(":cog_play.svg").pixmap(QSize(64,64)) + pixmap = QtGui.QIcon(":cog_play.svg").pixmap(QSize(64, 64)) iconLabel = QLabel() iconLabel.setPixmap(pixmap) @@ -63,9 +67,9 @@ def __init__(self, parent=None) -> None: textLayout.addStretch(1) buttonsLayout = QHBoxLayout() - stopButton = widgets.stopPushButton('Stop process') - startButton = widgets.playPushButton(' Start ') - cancelButton = widgets.cancelPushButton('Close') + stopButton = widgets.stopPushButton("Stop process") + startButton = widgets.playPushButton(" Start ") + cancelButton = widgets.cancelPushButton("Close") buttonsLayout.addStretch(1) buttonsLayout.addWidget(cancelButton) @@ -94,30 +98,29 @@ def __init__(self, parent=None) -> None: cancelButton.clicked.connect(self.close) startButton.clicked.connect(self.start) stopButton.clicked.connect(self.stop) - + def showEvent(self, event: QtGui.QShowEvent) -> None: self.startButton.setFixedWidth(self.stopButton.width()) self.stopButton.hide() return super().showEvent(event) - @exception_handler + @exception_handler def start(self): self.startButton.hide() self.stopButton.show() - MostRecentPath = myutils.getMostRecentPath() + MostRecentPath = utils.getMostRecentPath() exp_path = QFileDialog.getExistingDirectory( - self, 'Select experiment folder or specific Position folder', - MostRecentPath + self, "Select experiment folder or specific Position folder", MostRecentPath ) if not exp_path: - self.logger.info('No path selected. Process stopped.') + self.logger.info("No path selected. Process stopped.") self.stop() return - - myutils.addToRecentPaths(exp_path, logger=self.logger) - - folder_type = myutils.determine_folder_type(exp_path) + + utils.addToRecentPaths(exp_path, logger=self.logger) + + folder_type = utils.determine_folder_type(exp_path) is_pos_folder, is_images_folder, exp_path = folder_type if is_pos_folder: @@ -137,13 +140,15 @@ def start(self): return if len(values) > 1: select_folder.QtPrompt( - self, values, allow_cancel=False, toggleMulti=True, - CbLabel="Select Position folder(s) to process:" + self, + values, + allow_cancel=False, + toggleMulti=True, + CbLabel="Select Position folder(s) to process:", ) if select_folder.cancel: self.logger.info( - 'Process aborted by the user ' - '(cancelled at Postion selection)' + "Process aborted by the user (cancelled at Postion selection)" ) self.stop() return @@ -152,13 +157,13 @@ def start(self): posFoldernames = select_folder.pos_foldernames self.workerProgress(f'Selected folder: "{exp_path}"') - self.workerProgress(' ') - posListFormat = '\n'.join(posFoldernames) - self.workerProgress(f'Selected Positions:\n{posListFormat}') - self.workerProgress(' ') + self.workerProgress(" ") + posListFormat = "\n".join(posFoldernames) + self.workerProgress(f"Selected Positions:\n{posListFormat}") + self.workerProgress(" ") self.workerInitProgressBar(len(posFoldernames)) - + self.thread = QThread() self.worker = workers.reapplyDataPrepWorker(exp_path, posFoldernames) @@ -177,45 +182,52 @@ def start(self): self.thread.started.connect(self.worker.run) self.thread.start() - + def selectChannels(self, ch_name_selector, ch_names, imagesPath, basename): if basename is not None: self.ch_names = ch_names self.basename = basename self.imagesPath = imagesPath browseButton = widgets.browseFileButton( - 'Select .tif file to add and prep', start_dir=imagesPath, - title='Select .tif file to add and prep', ext={'TIFF files': '.tif'} + "Select .tif file to add and prep", + start_dir=imagesPath, + title="Select .tif file to add and prep", + ext={"TIFF files": ".tif"}, ) browseButton.sigPathSelected.connect(self.selectTifFileToAdd) additionalButtons = (browseButton,) else: additionalButtons = [] self.selectChannelWindow = widgets.QDialogListbox( - 'Select channel', - 'Select channel names to process:\n', - ch_names, multiSelection=True, parent=self, - additionalButtons=additionalButtons + "Select channel", + "Select channel names to process:\n", + ch_names, + multiSelection=True, + parent=self, + additionalButtons=additionalButtons, ) self.selectChannelWindow.exec_() if self.selectChannelWindow.cancel: self.worker.abort = True self.worker.selectedChannels = self.selectChannelWindow.selectedItemsText self.worker.waitCond.wakeAll() - + def selectTifFileToAdd(self, tif_file_path): tif_filename = os.path.splitext(os.path.basename(tif_file_path))[0] win = apps.filenameDialog( - ext='.tif', basename=self.basename, - title='Insert a name for new channel', - hintText='Insert a name for the new channel', - allowEmpty=False, defaultEntry=tif_filename, - existingNames=self.ch_names, parent=self.selectChannelWindow + ext=".tif", + basename=self.basename, + title="Insert a name for new channel", + hintText="Insert a name for the new channel", + allowEmpty=False, + defaultEntry=tif_filename, + existingNames=self.ch_names, + parent=self.selectChannelWindow, ) win.exec_() if win.cancel: return - + newTifFilePath = os.path.join(self.imagesPath, win.filename) try: self.logger.info(f'Copying and renaming "{tif_filename}.tif" file...') @@ -230,7 +242,7 @@ def selectTifFileToAdd(self, tif_file_path): self.selectChannelWindow.listBox.addItem(newItem) self.selectChannelWindow.listBox.clearSelection() newItem.setSelected(True) - + def warnCopyTifFileFailed(self, tif_file_path, newTifFilePath, error): tifFilename = os.path.basename(tif_file_path) msg = widgets.myMessageBox(showCentered=False, wrapText=False) @@ -242,65 +254,62 @@ def warnCopyTifFileFailed(self, tif_file_path, newTifFilePath, error): Copy to: {newTifFilePath} """) msg.setDetailedText(error) - msg.critical(self.selectChannelWindow, 'Copy .tif file failed', txt) - + msg.critical(self.selectChannelWindow, "Copy .tif file failed", txt) + def criticalNotValidFolder(self, path: os.PathLike): txt = html_utils.paragraph( - 'The selected folder:

    ' - f'{path}

    ' - 'is not a valid folder. ' - 'Select a folder that contains the Position_n folders' + "The selected folder:

    " + f"{path}

    " + "is not a valid folder. " + "Select a folder that contains the Position_n folders" ) msg = widgets.myMessageBox() msg.addShowInFileManagerButton(path) - msg.critical( - self, 'Incompatible folder', txt, - buttonsTexts=('Ok',) - ) - + msg.critical(self, "Incompatible folder", txt, buttonsTexts=("Ok",)) + def criticalNoChannelsFound(self, images_path): - err_title = 'Channel names not found' + err_title = "Channel names not found" err_msg = html_utils.paragraph( - 'The following folder

    ' - '{images_path}

    ' - 'does not valid channel files.
    ' + "The following folder

    " + "{images_path}

    " + "does not valid channel files.
    " ) msg = widgets.myMessageBox() msg.addShowInFileManagerButton(images_path) msg.critical(self, err_title, err_msg) self.logger.info(err_title) self.stop() - + def stop(self): self.startButton.show() self.stopButton.hide() - if hasattr(self, 'worker'): + if hasattr(self, "worker"): self.worker.abort = True - + @exception_handler def workerInitProgressBar(self, maximum): self.progressBar.setValue(0) self.progressBar.setMaximum(maximum) - + @exception_handler def workerUpdateProgressBar(self): self.progressBar.update(1) - + @exception_handler def workerProgress(self, txt): self.logger.info(txt) self.logConsole.append(txt) - + @exception_handler def workerProgressBar(self, txt): self.logger.info(txt) self.logConsole.write(txt) - + @exception_handler def workerCritical(self, error): raise error - + @exception_handler def workerFinished(self): self.startButton.show() @@ -309,12 +318,14 @@ def workerFinished(self): if self.worker.abort: msg = widgets.myMessageBox() msg.warning( - self, 'Process stopped', - html_utils.paragraph('Data prep process stopped!') + self, + "Process stopped", + html_utils.paragraph("Data prep process stopped!"), ) else: msg = widgets.myMessageBox() msg.information( - self, 'Process completed', - html_utils.paragraph('Data prep process completed!') - ) \ No newline at end of file + self, + "Process completed", + html_utils.paragraph("Data prep process completed!"), + ) diff --git a/cellacdc/utils/resize/__init__.py b/cellacdc/tools/resize/__init__.py similarity index 69% rename from cellacdc/utils/resize/__init__.py rename to cellacdc/tools/resize/__init__.py index d456c9f4f..e3335d87c 100644 --- a/cellacdc/utils/resize/__init__.py +++ b/cellacdc/tools/resize/__init__.py @@ -9,7 +9,8 @@ import shutil import pandas as pd -from ... import load, myutils, io +from ... import load, utils, io + def process_frame(imgs, images_indx, factor, is_segm): T, Z = images_indx @@ -19,6 +20,7 @@ def process_frame(imgs, images_indx, factor, is_segm): img_resized = ndimage.zoom(imgs[T, Z], factor, order=0) return images_indx, img_resized + def process_frames(imgs, factor, is_segm=False): results = [] @@ -27,18 +29,20 @@ def process_frames(imgs, factor, is_segm=False): images_indxs = list(itertools.product(range(T), range(Z))) images = None with ThreadPoolExecutor() as executor: - futures = [executor.submit(process_frame, imgs, images_indx, factor, is_segm) for images_indx in images_indxs] + futures = [ + executor.submit(process_frame, imgs, images_indx, factor, is_segm) + for images_indx in images_indxs + ] for future in futures: results.append(future.result()) - if not results: raise TypeError("No images to process (or this funciton has a funky error)") - + images_indx, img_resized = results[0] Y, X = img_resized.shape images = np.zeros((T, Z, Y, X), dtype=img_resized.dtype) - + for result in results: images_indx, img_resized = result @@ -47,6 +51,7 @@ def process_frames(imgs, factor, is_segm=False): return images + def load_images(images_path_in, file_path): path = os.path.join(images_path_in, file_path) @@ -64,15 +69,16 @@ def load_images(images_path_in, file_path): return imgs -def save_images(images, filename_in, images_path_out, text_to_append=''): + +def save_images(images, filename_in, images_path_out, text_to_append=""): if images is None: print("No images to save.") return - + images = np.squeeze(images) filename_in_noext, ext = os.path.splitext(filename_in) - filename_out = f'{filename_in_noext}{text_to_append}{ext}' - + filename_out = f"{filename_in_noext}{text_to_append}{ext}" + images_path_out_file = os.path.join(images_path_out, filename_out) if images_path_out_file.endswith(".tif"): @@ -81,26 +87,26 @@ def save_images(images, filename_in, images_path_out, text_to_append=''): io.savez_compressed(images_path_out_file, images) print(f"Sampling completed. File saved in:") - print(f"{images_path_out_file}\n") + print(f"{images_path_out_file}\n") + -def resize_imgs(images_path_in, factor, images_path_out=None, text_to_append=''): +def resize_imgs(images_path_in, factor, images_path_out=None, text_to_append=""): if images_path_out is None: images_path_out = images_path_in - - list_dir = myutils.listdir(images_path_in) - + + list_dir = utils.listdir(images_path_in) + # Get a list of all PNG files in the input folder images_files = [ - file for file in list_dir if ( - file.endswith(".tif") - or file.endswith('aligned.npz') - ) + file + for file in list_dir + if (file.endswith(".tif") or file.endswith("aligned.npz")) ] if not images_files: print("No image files found in the specified folder.") return - + for filename in images_files: print(f"Processing {filename}...") @@ -109,32 +115,32 @@ def resize_imgs(images_path_in, factor, images_path_out=None, text_to_append='') images = process_frames(images, factor) save_images( - images, filename, images_path_out=images_path_out, - text_to_append=text_to_append + images, + filename, + images_path_out=images_path_out, + text_to_append=text_to_append, ) -def edit_subs_bkgrROIs( - images_path_in, factor, images_path_out=None, text_to_append='' - ): + +def edit_subs_bkgrROIs(images_path_in, factor, images_path_out=None, text_to_append=""): if images_path_out is None: images_path_out = images_path_in - - list_dir = myutils.listdir(images_path_in) + + list_dir = utils.listdir(images_path_in) bkgrROIs_jsons = [file for file in list_dir if file.endswith("bkgrROIs.json")] bkgrROIs_npzs = [file for file in list_dir if file.endswith("bkgrROIs.npz")] - + # Is this fine to interpolate bkgrROIs_npzs or do I get the same issues as # with the segmentaion masks?" - if not bkgrROIs_jsons and not bkgrROIs_npzs: return for bkgrROIs_json_file in bkgrROIs_jsons: print(f"Processing {bkgrROIs_json_file}...") bkgrROIs_json_file_path = os.path.join(images_path_in, bkgrROIs_json_file) - with open(bkgrROIs_json_file_path, 'r') as file: + with open(bkgrROIs_json_file_path, "r") as file: data = json.load(file) data_scaled = [] @@ -152,53 +158,51 @@ def edit_subs_bkgrROIs( data_part[key] = value_scaled data_scaled.append(data_part) - - bkgrROIs_json_file_out = myutils.append_text_filename( + + bkgrROIs_json_file_out = utils.append_text_filename( bkgrROIs_json_file, text_to_append ) - images_path_out_file = os.path.join( - images_path_out, bkgrROIs_json_file_out - ) + images_path_out_file = os.path.join(images_path_out, bkgrROIs_json_file_out) - with open(images_path_out_file, 'w') as file: + with open(images_path_out_file, "w") as file: json.dump(data_scaled, file) - print(f'bkgrROIs.json files edited and saved in:') - print(f'{images_path_out_file}\n') + print(f"bkgrROIs.json files edited and saved in:") + print(f"{images_path_out_file}\n") for bkgrROIs_npz_file in bkgrROIs_npzs: - print('WARNING: Not tested yet') + print("WARNING: Not tested yet") print(f"Processing {bkgrROIs_npz_file}...") - + images = load_images(images_path_in, bkgrROIs_npz_file) images = process_frames(images, factor) save_images( - images, bkgrROIs_npz_file, - images_path_out=images_path_out, - text_to_append=text_to_append + images, + bkgrROIs_npz_file, + images_path_out=images_path_out, + text_to_append=text_to_append, ) -def edit_acdc_csvs( - images_path_in, factor, images_path_out=None, text_to_append='' - ): + +def edit_acdc_csvs(images_path_in, factor, images_path_out=None, text_to_append=""): if images_path_out is None: images_path_out = images_path_in - + columns_for_scaling = ["x_centroid", "y_centroid"] acdc_csvs = load.get_acdc_output_files(images_path_in) if not acdc_csvs: return - + for acdc_csv_file in acdc_csvs: print(f"Processing {acdc_csv_file}...") acdc_csv_file_path = os.path.join(images_path_in, acdc_csv_file) if not os.path.exists(acdc_csv_file_path): continue - + try: acdc_df = pd.read_csv(acdc_csv_file_path) except PermissionError as e: @@ -208,26 +212,21 @@ def edit_acdc_csvs( for column in columns_for_scaling: acdc_df[column] = (acdc_df[column] * factor).astype(int) - acdc_csv_file_out = myutils.append_text_filename( - acdc_csv_file, text_to_append - ) + acdc_csv_file_out = utils.append_text_filename(acdc_csv_file, text_to_append) images_path_out_file = os.path.join(images_path_out, acdc_csv_file_out) acdc_df.to_csv(images_path_out_file, index=False) print(f"Modified CSV saved to:") print(f"{images_path_out_file}\n") -def edit_metadata( - images_path_in, factor, images_path_out=None, text_to_append='' - ): + +def edit_metadata(images_path_in, factor, images_path_out=None, text_to_append=""): if images_path_out is None: images_path_out = images_path_in - - list_dir = myutils.listdir(images_path_in) + + list_dir = utils.listdir(images_path_in) data_to_scale_int = ["SizeX", "SizeY"] data_to_scale_float = ["PhysicalSizeY", "PhysicalSizeX"] - metadata_files = [ - file for file in list_dir if file.endswith("metadata.csv") - ] + metadata_files = [file for file in list_dir if file.endswith("metadata.csv")] if not metadata_files: return @@ -235,7 +234,7 @@ def edit_metadata( for metadata_file in metadata_files: print(f"Processing {metadata_file}...") metadata_file_path = os.path.join(images_path_in, metadata_file) - with open(metadata_file_path, 'r') as file: + with open(metadata_file_path, "r") as file: metadata = file.read() new_metadata = "" @@ -249,27 +248,25 @@ def edit_metadata( new_metadata += ",".join(entries) + "\n" - metadata_file_out = myutils.append_text_filename( - metadata_file, text_to_append - ) + metadata_file_out = utils.append_text_filename(metadata_file, text_to_append) images_path_out_file = os.path.join(images_path_out, metadata_file_out) - with open(images_path_out_file, 'w') as file: + with open(images_path_out_file, "w") as file: file.write(new_metadata) print(f"Metadata edited and saved in:") print(f"{images_path_out_file}\n") + def edit_lost_centroids( - images_path_in, factor, images_path_out=None, text_to_append='' - ): + images_path_in, factor, images_path_out=None, text_to_append="" +): if images_path_out is None: images_path_out = images_path_in - - list_dir = myutils.listdir(images_path_in) - + + list_dir = utils.listdir(images_path_in) + lost_centroids_jsons = [ - file for file in list_dir - if file.endswith("tracked_lost_centroids.json") + file for file in list_dir if file.endswith("tracked_lost_centroids.json") ] if not lost_centroids_jsons: @@ -277,10 +274,10 @@ def edit_lost_centroids( for lost_centroids_json in lost_centroids_jsons: print(f"Processing {lost_centroids_json}...") - + lost_centroids_json_path = os.path.join(images_path_in, lost_centroids_json) - with open(lost_centroids_json_path, 'r') as file: + with open(lost_centroids_json_path, "r") as file: lost_centroids = json.load(file) for frame_i, frame in lost_centroids.items(): @@ -293,24 +290,21 @@ def edit_lost_centroids( frame_new.append(new_centroid) lost_centroids[frame_i] = frame_new - lost_centroids_json_out = myutils.append_text_filename( + lost_centroids_json_out = utils.append_text_filename( lost_centroids_json, text_to_append ) - images_path_out_file = os.path.join( - images_path_out, lost_centroids_json_out - ) - with open(images_path_out_file, 'w') as file: + images_path_out_file = os.path.join(images_path_out, lost_centroids_json_out) + with open(images_path_out_file, "w") as file: json.dump(lost_centroids, file, indent=4) - + print(f"Lost centroids edited and saved in:") print(f"{images_path_out_file}\n") -def resize_segms( - images_path_in, factor, images_path_out=None, text_to_append='' - ): + +def resize_segms(images_path_in, factor, images_path_out=None, text_to_append=""): if images_path_out is None: images_path_out = images_path_in - + segm_npzs = load.get_segm_files(images_path_in) if not segm_npzs: @@ -318,26 +312,32 @@ def resize_segms( for segm_npz_file in segm_npzs: print(f"Processing {segm_npz_file}...") - + images = load_images(images_path_in, segm_npz_file) images = process_frames(images, factor, is_segm=True) save_images( - images, segm_npz_file, images_path_out=images_path_out, - text_to_append=text_to_append + images, + segm_npz_file, + images_path_out=images_path_out, + text_to_append=text_to_append, ) + def copy_aux_files(images_path_in, images_path_out=None): if images_path_out is None: images_path_out = images_path_in - list_dir = myutils.listdir(images_path_in) + list_dir = utils.listdir(images_path_in) files_endings = [ - "_last_tracked_i.txt", "_combine_metrics.ini", "_segm_hyperparams.ini" + "_last_tracked_i.txt", + "_combine_metrics.ini", + "_segm_hyperparams.ini", ] aux_files = [ - file for file in list_dir + file + for file in list_dir if any(file.endswith(ending) for ending in files_endings) ] for aux_file in aux_files: @@ -350,31 +350,42 @@ def copy_aux_files(images_path_in, images_path_out=None): print(f"File {aux_file} copied to") print(f"{images_path_out}\n") -def run( - images_path_in, factor, images_path_out=None, text_to_append='' - ): + +def run(images_path_in, factor, images_path_out=None, text_to_append=""): resize_imgs( - images_path_in, factor, text_to_append=text_to_append, - images_path_out=images_path_out + images_path_in, + factor, + text_to_append=text_to_append, + images_path_out=images_path_out, ) edit_subs_bkgrROIs( - images_path_in, factor, text_to_append=text_to_append, - images_path_out=images_path_out + images_path_in, + factor, + text_to_append=text_to_append, + images_path_out=images_path_out, ) copy_aux_files(images_path_in, images_path_out=images_path_out) resize_segms( - images_path_in, factor, text_to_append=text_to_append, - images_path_out=images_path_out + images_path_in, + factor, + text_to_append=text_to_append, + images_path_out=images_path_out, ) edit_acdc_csvs( - images_path_in, factor, text_to_append=text_to_append, - images_path_out=images_path_out + images_path_in, + factor, + text_to_append=text_to_append, + images_path_out=images_path_out, ) edit_metadata( - images_path_in, factor, text_to_append=text_to_append, - images_path_out=images_path_out + images_path_in, + factor, + text_to_append=text_to_append, + images_path_out=images_path_out, ) edit_lost_centroids( - images_path_in, factor, text_to_append=text_to_append, - images_path_out=images_path_out - ) \ No newline at end of file + images_path_in, + factor, + text_to_append=text_to_append, + images_path_out=images_path_out, + ) diff --git a/cellacdc/utils/resize/util.py b/cellacdc/tools/resize/util.py similarity index 74% rename from cellacdc/utils/resize/util.py rename to cellacdc/tools/resize/util.py index 9199f0d1a..0ae039e72 100644 --- a/cellacdc/utils/resize/util.py +++ b/cellacdc/tools/resize/util.py @@ -1,25 +1,31 @@ -from ... import myutils, workers, widgets, html_utils +from ... import utils, workers, widgets, html_utils from ... import apps from ..base import NewThreadMultipleExpBaseUtil + class ResizePositionsUtil(NewThreadMultipleExpBaseUtil): def __init__( - self, expPaths, app, title: str, infoText: str, - progressDialogueTitle: str, parent=None): - module = myutils.get_module_name(__file__) + self, + expPaths, + app, + title: str, + infoText: str, + progressDialogueTitle: str, + parent=None, + ): + module = utils.get_module_name(__file__) super().__init__( - expPaths, app, title, module, infoText, progressDialogueTitle, - parent=parent + expPaths, app, title, module, infoText, progressDialogueTitle, parent=parent ) self.expPaths = expPaths self._parent = parent - + def runWorker(self): self.worker = workers.ResizeUtilWorker(self) self.worker.sigSetResizeProps.connect(self.setResizeProps) super().runWorker(self.worker) - + def setResizeProps(self, input_path): win = apps.ResizeUtilProps(input_path=input_path, parent=self._parent) win.exec_() @@ -27,19 +33,19 @@ def setResizeProps(self, input_path): if win.cancel: self.worker.waitCond.wakeAll() return - + self.worker.resizeFactor = win.resizeFactor self.worker.textToAppend = win.textToAppend self.worker.expFolderpathOut = win.expFolderpathOut self.worker.waitCond.wakeAll() - + def showEvent(self, event): self.runWorker() - + def workerFinished(self, worker): super().workerFinished(worker) - txt = 'Resizing data process completed.' + txt = "Resizing data process completed." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) - msg.information(self, 'Process completed', html_utils.paragraph(txt)) - self.close() \ No newline at end of file + msg.information(self, "Process completed", html_utils.paragraph(txt)) + self.close() diff --git a/cellacdc/utils/stack2Dinto3Dsegm.py b/cellacdc/tools/stack2Dinto3Dsegm.py similarity index 65% rename from cellacdc/utils/stack2Dinto3Dsegm.py rename to cellacdc/tools/stack2Dinto3Dsegm.py index c6c0e6b64..f163336cd 100644 --- a/cellacdc/utils/stack2Dinto3Dsegm.py +++ b/cellacdc/tools/stack2Dinto3Dsegm.py @@ -1,67 +1,71 @@ -from .. import apps, myutils, workers, widgets, html_utils +from .. import apps, utils, workers, widgets, html_utils from .base import NewThreadMultipleExpBaseUtil + class Stack2DsegmTo3Dsegm(NewThreadMultipleExpBaseUtil): def __init__( - self, expPaths, app, title: str, infoText: str, - progressDialogueTitle: str, SizeZ: int, parent=None - ): - module = myutils.get_module_name(__file__) + self, + expPaths, + app, + title: str, + infoText: str, + progressDialogueTitle: str, + SizeZ: int, + parent=None, + ): + module = utils.get_module_name(__file__) super().__init__( - expPaths, app, title, module, infoText, progressDialogueTitle, - parent=parent + expPaths, app, title, module, infoText, progressDialogueTitle, parent=parent ) self.expPaths = expPaths self._SizeZ = SizeZ - + def runWorker(self): self.worker = workers.Stack2DsegmTo3Dsegm(self, self._SizeZ) self.worker.sigAskAppendName.connect(self.askAppendName) self.worker.sigAborted.connect(self.workerAborted) super().runWorker(self.worker) - + def showEvent(self, event): self.runWorker() - + def askAppendName(self, basename, existingEndnames): - helpText = ( - """ + helpText = """ The new 3D segmentation file will be saved with a different file name.

    Insert a name to append to the end of the new name. The rest of the name will be the same as the original file. """ - ) win = apps.filenameDialog( basename=basename, - hintText='Insert a name for the new 3D segmentation file:', - existingNames=existingEndnames, - helpText=helpText, - allowEmpty=False + hintText="Insert a name for the new 3D segmentation file:", + existingNames=existingEndnames, + helpText=helpText, + allowEmpty=False, ) win.exec_() if win.cancel: self.worker.abort = True self.worker.waitCond.wakeAll() return - + self.worker.appendedName = win.entryText self.worker.waitCond.wakeAll() - + def workerAborted(self): self.workerFinished(None, aborted=True) - + def workerFinished(self, worker, aborted=False): if aborted: - txt = '3D segmentation mask creation process aborted.' + txt = "3D segmentation mask creation process aborted." else: - txt = '3D segmentation mask creation process completed.' + txt = "3D segmentation mask creation process completed." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) if aborted: - msg.warning(self, 'Process completed', html_utils.paragraph(txt)) + msg.warning(self, "Process completed", html_utils.paragraph(txt)) else: - msg.information(self, 'Process completed', html_utils.paragraph(txt)) + msg.information(self, "Process completed", html_utils.paragraph(txt)) super().workerFinished(worker) - self.close() \ No newline at end of file + self.close() diff --git a/cellacdc/utils/toImageJroi.py b/cellacdc/tools/toImageJroi.py similarity index 59% rename from cellacdc/utils/toImageJroi.py rename to cellacdc/tools/toImageJroi.py index 7d3724796..c9a864541 100644 --- a/cellacdc/utils/toImageJroi.py +++ b/cellacdc/tools/toImageJroi.py @@ -1,29 +1,35 @@ -from .. import myutils, workers, widgets, html_utils +from .. import utils, workers, widgets, html_utils from .base import NewThreadMultipleExpBaseUtil + class toImageRoiUtil(NewThreadMultipleExpBaseUtil): def __init__( - self, expPaths, app, title: str, infoText: str, - progressDialogueTitle: str, parent=None): - module = myutils.get_module_name(__file__) + self, + expPaths, + app, + title: str, + infoText: str, + progressDialogueTitle: str, + parent=None, + ): + module = utils.get_module_name(__file__) super().__init__( - expPaths, app, title, module, infoText, progressDialogueTitle, - parent=parent + expPaths, app, title, module, infoText, progressDialogueTitle, parent=parent ) self.expPaths = expPaths - + def runWorker(self): self.worker = workers.ToImajeJroiWorker(self) super().runWorker(self.worker) - + def showEvent(self, event): self.runWorker() - + def workerFinished(self, worker): super().workerFinished(worker) - txt = 'Converting to ImageJ ROIs completed.' + txt = "Converting to ImageJ ROIs completed." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) - msg.information(self, 'Process completed', html_utils.paragraph(txt)) - self.close() \ No newline at end of file + msg.information(self, "Process completed", html_utils.paragraph(txt)) + self.close() diff --git a/cellacdc/utils/toObjCoords.py b/cellacdc/tools/toObjCoords.py similarity index 59% rename from cellacdc/utils/toObjCoords.py rename to cellacdc/tools/toObjCoords.py index 7d8cfa59a..632563e69 100644 --- a/cellacdc/utils/toObjCoords.py +++ b/cellacdc/tools/toObjCoords.py @@ -1,29 +1,35 @@ -from .. import myutils, workers, widgets, html_utils +from .. import utils, workers, widgets, html_utils from .base import NewThreadMultipleExpBaseUtil + class toObjCoordsUtil(NewThreadMultipleExpBaseUtil): def __init__( - self, expPaths, app, title: str, infoText: str, - progressDialogueTitle: str, parent=None): - module = myutils.get_module_name(__file__) + self, + expPaths, + app, + title: str, + infoText: str, + progressDialogueTitle: str, + parent=None, + ): + module = utils.get_module_name(__file__) super().__init__( - expPaths, app, title, module, infoText, progressDialogueTitle, - parent=parent + expPaths, app, title, module, infoText, progressDialogueTitle, parent=parent ) self.expPaths = expPaths - + def runWorker(self): self.worker = workers.ToObjCoordsWorker(self) super().runWorker(self.worker) - + def showEvent(self, event): self.runWorker() - + def workerFinished(self, worker): super().workerFinished(worker) - txt = 'Converting to object coordinates completed.' + txt = "Converting to object coordinates completed." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) - msg.information(self, 'Process completed', html_utils.paragraph(txt)) - self.close() \ No newline at end of file + msg.information(self, "Process completed", html_utils.paragraph(txt)) + self.close() diff --git a/cellacdc/utils/trackSubCellObjects.py b/cellacdc/tools/trackSubCellObjects.py similarity index 65% rename from cellacdc/utils/trackSubCellObjects.py rename to cellacdc/tools/trackSubCellObjects.py index 87d25326f..a70c02615 100644 --- a/cellacdc/utils/trackSubCellObjects.py +++ b/cellacdc/tools/trackSubCellObjects.py @@ -1,25 +1,30 @@ -from .. import apps, myutils, workers, widgets, html_utils +from .. import apps, utils, workers, widgets, html_utils from .base import NewThreadMultipleExpBaseUtil + class TrackSubCellFeatures(NewThreadMultipleExpBaseUtil): def __init__( - self, expPaths, app, title: str, infoText: str, - progressDialogueTitle: str, trackSubCellObjParams: dict, - parent=None - ): - module = myutils.get_module_name(__file__) + self, + expPaths, + app, + title: str, + infoText: str, + progressDialogueTitle: str, + trackSubCellObjParams: dict, + parent=None, + ): + module = utils.get_module_name(__file__) super().__init__( - expPaths, app, title, module, infoText, progressDialogueTitle, - parent=parent + expPaths, app, title, module, infoText, progressDialogueTitle, parent=parent ) self.expPaths = expPaths - self.trackingMode = trackSubCellObjParams['how'] - self.IoAthresh = trackSubCellObjParams['IoA'] - self.relabelSubObjLab = trackSubCellObjParams['relabelSubObjLab'] - self.createThirdSegm = trackSubCellObjParams['createThirdSegm'] - self.thirdSegmAppendedText = trackSubCellObjParams['thirdSegmAppendedText'] - + self.trackingMode = trackSubCellObjParams["how"] + self.IoAthresh = trackSubCellObjParams["IoA"] + self.relabelSubObjLab = trackSubCellObjParams["relabelSubObjLab"] + self.createThirdSegm = trackSubCellObjParams["createThirdSegm"] + self.thirdSegmAppendedText = trackSubCellObjParams["thirdSegmAppendedText"] + def runWorker(self): self.worker = workers.TrackSubCellObjectsWorker(self) self.worker.sigAskAppendName.connect(self.askAppendName) @@ -28,10 +33,10 @@ def runWorker(self): ) self.worker.sigAborted.connect(self.workerAborted) super().runWorker(self.worker) - + def showEvent(self, event): self.runWorker() - + def criticalNotEnoughSegmFiles(self, exp_path): text = html_utils.paragraph(f""" The following experiment folder

    @@ -44,50 +49,46 @@ def criticalNotEnoughSegmFiles(self, exp_path): """) msg = widgets.myMessageBox(wrapText=False, showCentered=False) msg.addShowInFileManagerButton(exp_path) - msg.critical( - self, 'Not enough segmentation files!', text - ) + msg.critical(self, "Not enough segmentation files!", text) self.worker.abort = True self.worker.waitCond.wakeAll() - + def askAppendName(self, basename, existingEndnames): - helpText = ( - """ + helpText = """ The segmentation file containing the tracked sub-cellular objects will be saved with a different file name.

    Insert a name to append to the end of the new name. The rest of the name will be the same as the original file. """ - ) win = apps.filenameDialog( basename=basename, - hintText='Insert a name for the new, tracked segmentation file:', - existingNames=existingEndnames, - helpText=helpText, - allowEmpty=False + hintText="Insert a name for the new, tracked segmentation file:", + existingNames=existingEndnames, + helpText=helpText, + allowEmpty=False, ) win.exec_() if win.cancel: self.worker.abort = True self.worker.waitCond.wakeAll() return - + self.worker.appendedName = win.entryText self.worker.waitCond.wakeAll() - + def workerAborted(self): self.workerFinished(None, aborted=True) - + def workerFinished(self, worker, aborted=False): if aborted: - txt = 'Tracking sub-cellular objects aborted.' + txt = "Tracking sub-cellular objects aborted." else: - txt = 'Tracking sub-cellular objects completed.' + txt = "Tracking sub-cellular objects completed." self.logger.info(txt) msg = widgets.myMessageBox(wrapText=False, showCentered=False) if aborted: - msg.warning(self, 'Process completed', html_utils.paragraph(txt)) + msg.warning(self, "Process completed", html_utils.paragraph(txt)) else: - msg.information(self, 'Process completed', html_utils.paragraph(txt)) + msg.information(self, "Process completed", html_utils.paragraph(txt)) super().workerFinished(worker) - self.close() \ No newline at end of file + self.close() diff --git a/cellacdc/trackers/BABY/BABY_tracker.py b/cellacdc/trackers/BABY/BABY_tracker.py index 813f13203..1466eb3e6 100644 --- a/cellacdc/trackers/BABY/BABY_tracker.py +++ b/cellacdc/trackers/BABY/BABY_tracker.py @@ -6,110 +6,104 @@ from baby import modelsets from baby import BabyCrawler -from cellacdc import myutils +from cellacdc import utils from cellacdc.trackers import BABY from ..CellACDC import CellACDC_tracker + class AvailableModels: values = BABY.BABY_MODELS + class tracker: def __init__( - self, - model_type: AvailableModels='yeast-alcatras-brightfield-sCMOS-60x-5z', - ): + self, + model_type: AvailableModels = "yeast-alcatras-brightfield-sCMOS-60x-5z", + ): brain = modelsets.get(model_type) self.crawler = BabyCrawler(brain) - + def _preprocess(self, image, swap_YX_axes_to_XY): if image.ndim == 2: image = image[np.newaxis] - - image = myutils.to_uint16(image) - - # BABY requires z-slices as last dimension while Cell-ACDC takes + image = utils.to_uint16(image) + + # BABY requires z-slices as last dimension while Cell-ACDC takes # Z, Y, X input if swap_YX_axes_to_XY: dst_axes = (2, 1, 0) else: - dst_axes = (1, 2, 0) - + dst_axes = (1, 2, 0) + image = image.transpose(dst_axes) - + return image - + def iterate_result_series(self, result_series, swap_YX_axes_to_XY): for frame_i, result in enumerate(result_series): - contour_masks = result[0]['edgemasks'] - IDs = result[0]['cell_label'] + contour_masks = result[0]["edgemasks"] + IDs = result[0]["cell_label"] for ID, contour_mask in zip(IDs, contour_masks): - mask = scipy.ndimage.binary_fill_holes(contour_mask) + mask = scipy.ndimage.binary_fill_holes(contour_mask) if swap_YX_axes_to_XY: - mask = np.swapaxes(mask, 0, 1) + mask = np.swapaxes(mask, 0, 1) yield frame_i, mask, ID - - def track_baby_segm_data( - self, segm_data, result_series, swap_YX_axes_to_XY - ): + + def track_baby_segm_data(self, segm_data, result_series, swap_YX_axes_to_XY): tracked_data = np.zeros_like(segm_data) - result_generator = self.iterate_result_series( - result_series, swap_YX_axes_to_XY - ) - for frame_i, mask, ID in result_generator: + result_generator = self.iterate_result_series(result_series, swap_YX_axes_to_XY) + for frame_i, mask, ID in result_generator: tracked_data[frame_i][mask] = ID return tracked_data - - def track_external_segm_data( - self, segm_data, result_series, swap_YX_axes_to_XY - ): - result_generator = self.iterate_result_series( - result_series, swap_YX_axes_to_XY - ) + + def track_external_segm_data(self, segm_data, result_series, swap_YX_axes_to_XY): + result_generator = self.iterate_result_series(result_series, swap_YX_axes_to_XY) old_IDs_tracks = {} tracked_IDs_tracks = {} - for frame_i, mask, track_ID in result_generator: + for frame_i, mask, track_ID in result_generator: oldID = segm_data[frame_i][mask][0] if oldID == 0: continue - + if frame_i not in old_IDs_tracks: old_IDs_tracks[frame_i] = [oldID] tracked_IDs_tracks[frame_i] = [track_ID] else: old_IDs_tracks[frame_i].append(oldID) tracked_IDs_tracks[frame_i].append(track_ID) - + tracked_data = segm_data.copy() for frame_i in old_IDs_tracks.keys(): tracked_IDs = tracked_IDs_tracks[frame_i] old_IDs = old_IDs_tracks[frame_i] - + lab = self.segm_video[frame_i] rp = skimage.measure.regionprops(lab) IDs_curr_untracked = [obj.label for obj in rp] - - uniqueID = max((max(tracked_IDs), max(IDs_curr_untracked)))+1 + + uniqueID = max((max(tracked_IDs), max(IDs_curr_untracked))) + 1 tracked_lab = CellACDC_tracker.indexAssignment( - old_IDs, tracked_IDs, IDs_curr_untracked, - lab.copy(), rp, uniqueID + old_IDs, tracked_IDs, IDs_curr_untracked, lab.copy(), rp, uniqueID ) tracked_data[frame_i] = tracked_lab - + return tracked_data - + def track( - self, segm_data, intensity_data, - resegment_data=True, - swap_YX_axes_to_XY=True, - refine_outlines=True, - assign_mothers=True, - with_edgemasks=True, - with_volumes=True, - parallel=False, - signals=None - ): + self, + segm_data, + intensity_data, + resegment_data=True, + swap_YX_axes_to_XY=True, + refine_outlines=True, + assign_mothers=True, + with_edgemasks=True, + with_volumes=True, + parallel=False, + signals=None, + ): """_summary_ Parameters @@ -119,33 +113,33 @@ def track( intensity_data : (T, Y, X) or (T, Z, Y, X) Input intensity data resegment_data : bool, optional - If True, BABY will ignore the input `segm_data` and perform - segmentation de novo. - If False, BABY will only track the input `segm_data`. + If True, BABY will ignore the input `segm_data` and perform + segmentation de novo. + If False, BABY will only track the input `segm_data`. Default is True Returns ------- np.ndarray with the same shape as `segm_data` Tracked data - """ + """ image_series = [ self._preprocess(image, swap_YX_axes_to_XY) for image in intensity_data ] - + result_series = [] for image in image_series: result = self.crawler.step( - image[None, ...], + image[None, ...], refine_outlines=refine_outlines, assign_mothers=assign_mothers, with_edgemasks=with_edgemasks, with_volumes=with_volumes, - parallel=parallel + parallel=parallel, ) result_series.append(result) self.updateGuiProgressBar(signals) - + if resegment_data: tracked_data = self.track_baby_segm_data( segm_data, result_series, swap_YX_axes_to_XY @@ -154,20 +148,18 @@ def track( tracked_data = self.track_external_segm_data( segm_data, result_series, swap_YX_axes_to_XY ) - + return tracked_data def updateGuiProgressBar(self, signals): if signals is None: return - - if hasattr(signals, 'innerPbar_available'): + + if hasattr(signals, "innerPbar_available"): if signals.innerPbar_available: # Use inner pbar of the GUI widget (top pbar is for positions) signals.innerProgressBar.emit(1) return - if hasattr(signals, 'progressBar'): + if hasattr(signals, "progressBar"): signals.progressBar.emit(1) - - diff --git a/cellacdc/trackers/BABY/__init__.py b/cellacdc/trackers/BABY/__init__.py index 6e5b1f9bb..3e1ae7dca 100644 --- a/cellacdc/trackers/BABY/__init__.py +++ b/cellacdc/trackers/BABY/__init__.py @@ -1,8 +1,9 @@ -from cellacdc import myutils +from cellacdc import utils -myutils.check_install_baby() +utils.check_install_baby() from baby import modelsets + meta = modelsets.meta() -BABY_MODELS = list(meta.keys()) \ No newline at end of file +BABY_MODELS = list(meta.keys()) diff --git a/cellacdc/trackers/BayesianTracker/BayesianTracker_tracker.py b/cellacdc/trackers/BayesianTracker/BayesianTracker_tracker.py index 7b55abcca..304e7b71d 100755 --- a/cellacdc/trackers/BayesianTracker/BayesianTracker_tracker.py +++ b/cellacdc/trackers/BayesianTracker/BayesianTracker_tracker.py @@ -15,20 +15,25 @@ from tqdm import tqdm + class tracker: def __init__(self, **params): self.params = params def track( - self, segm_video, image, signals=None, - export_to: os.PathLike=None, verbose=False - ): - FEATURES = self.params['features'] + self, + segm_video, + image, + signals=None, + export_to: os.PathLike = None, + verbose=False, + ): + FEATURES = self.params["features"] if segm_video.ndim == 3: # btrack requires 4D data. Add extra dimension for 3D data segm_video = segm_video[:, np.newaxis, :, :] - + if image is not None: if image.ndim == 3: image = image[:, np.newaxis, :, :] @@ -44,19 +49,18 @@ def track( ) if signals is not None: - signals.progress.emit('Running BayesianTracker...') + signals.progress.emit("Running BayesianTracker...") # initialise a tracker session using a context manager with btrack.BayesianTracker() as tracker: - # configure the tracker using a config file - tracker.configure_from_file(self.params['model_path']) - tracker.verbose = self.params['verbose'] - update_method = self.params['update_method'] + tracker.configure_from_file(self.params["model_path"]) + tracker.verbose = self.params["verbose"] + update_method = self.params["update_method"] - if update_method == 'APPROXIMATE': + if update_method == "APPROXIMATE": tracker.update_method = getattr(BayesianUpdates, update_method) - tracker.max_search_radius = self.params['max_search_radius'] + tracker.max_search_radius = self.params["max_search_radius"] # add features if FEATURES: @@ -66,13 +70,13 @@ def track( tracker.append(obj_from_arr) # set the volume - tracker.volume=self.params['volume'] + tracker.volume = self.params["volume"] # track them (in interactive mode) - tracker.track(step_size=self.params['step_size']) + tracker.track(step_size=self.params["step_size"]) # generate hypotheses and run the global optimizer - if self.params['optimize']: + if self.params["optimize"]: tracker.optimize() # save tracks @@ -88,11 +92,9 @@ def track( ) return tracked_video - def _from_tracks_to_labels( - self, tracks, segm_video, signals=None, verbose=False - ): + def _from_tracks_to_labels(self, tracks, segm_video, signals=None, verbose=False): if signals is not None: - signals.progress.emit('Applying BayesianTracker tracks...') + signals.progress.emit("Applying BayesianTracker tracks...") # Label the segm_video according to tracks tracked_video = np.zeros_like(segm_video) @@ -107,12 +109,12 @@ def _from_tracks_to_labels( tracked_IDs = [] for track in tracks: track_dict = track.to_dict() - if frame_i not in track_dict['t']: + if frame_i not in track_dict["t"]: continue - df = pd.DataFrame(track.to_dict()).set_index('t').loc[frame_i] + df = pd.DataFrame(track.to_dict()).set_index("t").loc[frame_i] - yc, xc = df['y'], df['x'] + yc, xc = df["y"], df["x"] try: old_ID = lab[int(yc), int(xc)] except Exception as e: @@ -122,35 +124,34 @@ def _from_tracks_to_labels( continue old_IDs.append(old_ID) - tracked_IDs.append(df['ID']) + tracked_IDs.append(df["ID"]) if not tracked_IDs: # No cells tracked continue - uniqueID = max((max(tracked_IDs), max(IDs_curr_untracked)))+1 + uniqueID = max((max(tracked_IDs), max(IDs_curr_untracked))) + 1 if verbose: - print('-------------------------') - print(f'Tracking frame n. {frame_i+1}') + print("-------------------------") + print(f"Tracking frame n. {frame_i + 1}") for old_ID, tracked_ID in zip(old_IDs, tracked_IDs): - print(f'Tracking ID {old_ID} --> {tracked_ID}') - print('-------------------------') + print(f"Tracking ID {old_ID} --> {tracked_ID}") + print("-------------------------") tracked_lab = CellACDC_tracker.indexAssignment( - old_IDs, tracked_IDs, IDs_curr_untracked, - lab.copy(), rp, uniqueID + old_IDs, tracked_IDs, IDs_curr_untracked, lab.copy(), rp, uniqueID ) tracked_video[frame_i] = tracked_lab self.updateGuiProgressBar(signals) return tracked_video - + def updateGuiProgressBar(self, signals): if signals is None: return - - if hasattr(signals, 'innerPbar_available'): + + if hasattr(signals, "innerPbar_available"): if signals.innerPbar_available: # Use inner pbar of the GUI widget (top pbar is for positions) signals.innerProgressBar.emit(1) diff --git a/cellacdc/trackers/BayesianTracker/__init__.py b/cellacdc/trackers/BayesianTracker/__init__.py index 8576fda0a..a63358d6f 100755 --- a/cellacdc/trackers/BayesianTracker/__init__.py +++ b/cellacdc/trackers/BayesianTracker/__init__.py @@ -2,19 +2,20 @@ try: import btrack - from cellacdc.myutils import get_package_version - version = get_package_version('btrack') - minor = version.split('.')[1] + from cellacdc.utils import get_package_version + + version = get_package_version("btrack") + minor = version.split(".")[1] if int(minor) < 5: UPGRADE_BTRACK = True except Exception as e: pass -from cellacdc import myutils +from cellacdc import utils -myutils.check_install_package( - 'Bayesian Tracker', - import_pkg_name='btrack', - pypi_name='btrack', - force_upgrade=UPGRADE_BTRACK +utils.check_install_package( + "Bayesian Tracker", + import_pkg_name="btrack", + pypi_name="btrack", + force_upgrade=UPGRADE_BTRACK, ) diff --git a/cellacdc/trackers/CellACDC/CellACDC_tracker.py b/cellacdc/trackers/CellACDC/CellACDC_tracker.py index 07cfe841b..9e9b0c98d 100755 --- a/cellacdc/trackers/CellACDC/CellACDC_tracker.py +++ b/cellacdc/trackers/CellACDC/CellACDC_tracker.py @@ -11,8 +11,16 @@ DEBUG = False -def calc_Io_matrix(lab, prev_lab, rp, prev_rp, IDs_curr_untracked=None, - denom:str='area_prev', IDs=None): + +def calc_Io_matrix( + lab, + prev_lab, + rp, + prev_rp, + IDs_curr_untracked=None, + denom: str = "area_prev", + IDs=None, +): # maybe its faster to calculate IoU not via mask but via area1 / (area1 + area2 - intersection) IDs_prev = [] if IDs_curr_untracked is None: @@ -31,12 +39,10 @@ def calc_Io_matrix(lab, prev_lab, rp, prev_rp, IDs_curr_untracked=None, # prev_lab, prev_rp = lab.copy, rp.copy() # lab, rp = prev_lab_temp, prev_rp_temp - if not denom in ['area_prev', 'union']: - raise ValueError( - "Invalid denom value. Use 'area_prev' or 'union'." - ) + if not denom in ["area_prev", "union"]: + raise ValueError("Invalid denom value. Use 'area_prev' or 'union'.") - # prev_label_positions = {ID_prev: np.where(prev_lab == ID_prev)[0] for ID_prev in set(prev_lab) if ID_prev != 0} + # prev_label_positions = {ID_prev: np.where(prev_lab == ID_prev)[0] for ID_prev in set(prev_lab) if ID_prev != 0} # if denom == 'union': # temp_lab = np.zeros(lab.shape, dtype=bool) for j, obj_prev in enumerate(prev_rp): @@ -45,9 +51,9 @@ def calc_Io_matrix(lab, prev_lab, rp, prev_rp, IDs_curr_untracked=None, # if IDs is not None and ID_prev not in IDs: # continue - if denom == 'area_prev': # or denom == 'area_curr': + if denom == "area_prev": # or denom == 'area_curr': denom_val = obj_prev.area - + # Get intersecting IDs between current and object in previous frame intersect_IDs, intersects = np.unique( lab[obj_prev.slice][obj_prev.image], return_counts=True @@ -59,7 +65,7 @@ def calc_Io_matrix(lab, prev_lab, rp, prev_rp, IDs_curr_untracked=None, if I == 0: continue - if denom == 'union': + if denom == "union": obj_curr = rp_mapper[intersect_ID] # temp_lab[obj_prev.slice][obj_prev.image] = True # temp_lab[obj_curr.slice][obj_curr.image] = True @@ -68,16 +74,23 @@ def calc_Io_matrix(lab, prev_lab, rp, prev_rp, IDs_curr_untracked=None, denom_val = obj_prev.area + obj_curr.area - I if denom_val == 0: continue - - idx = idx_mapper[intersect_ID] - IoA = I/denom_val + + idx = idx_mapper[intersect_ID] + IoA = I / denom_val IoA_matrix[idx, j] = IoA return IoA_matrix, IDs_curr_untracked, IDs_prev + def assign( - IoA_matrix, IDs_curr_untracked, IDs_prev, IoA_thresh=0.4, - aggr_track=None, IoA_thresh_aggr=0.4, daughters_list=None, - IDs=None): + IoA_matrix, + IDs_curr_untracked, + IDs_prev, + IoA_thresh=0.4, + aggr_track=None, + IoA_thresh_aggr=0.4, + daughters_list=None, + IDs=None, +): # Determine max IoA between IDs and assign tracked ID if IoA >= IoA_thresh if IoA_matrix.size == 0: return [], [] @@ -88,7 +101,7 @@ def assign( old_IDs = [] if DEBUG: - printl(f'IDs in previous frame: {IDs_prev}') + printl(f"IDs in previous frame: {IDs_prev}") for i, j in enumerate(max_IoA_col_idx): if daughters_list is not None: @@ -102,94 +115,94 @@ def assign( IoA_thresh_temp = IoA_thresh else: IoA_thresh_temp = IoA_thresh - max_IoU = IoA_matrix[i,j] + max_IoU = IoA_matrix[i, j] count = counts_dict[j] if max_IoU >= IoA_thresh_temp: tracked_ID = IDs_prev[j] if count == 1: old_ID = IDs_curr_untracked[i] elif count > 1: - old_ID_idx = IoA_matrix[:,j].argmax() + old_ID_idx = IoA_matrix[:, j].argmax() old_ID = IDs_curr_untracked[old_ID_idx] tracked_IDs.append(tracked_ID) old_IDs.append(old_ID) return old_IDs, tracked_IDs + def log_debugging(what, **kwargs): if not DEBUG: return - - if what == 'start': - printl('----------------START INDEX ASSIGNMENT----------------') + + if what == "start": + printl("----------------START INDEX ASSIGNMENT----------------") printl( - f'Current IDs: {kwargs["IDs_curr_untracked"]}\n' - f'Previous IDs: {kwargs["old_IDs"]}' - ) - if what == 'assign_unique': - assign_unique_new_IDs = kwargs['assign_unique_new_IDs'] - txt = ( - f'Assign new IDs uniquely = {assign_unique_new_IDs}' + f"Current IDs: {kwargs['IDs_curr_untracked']}\n" + f"Previous IDs: {kwargs['old_IDs']}" ) + if what == "assign_unique": + assign_unique_new_IDs = kwargs["assign_unique_new_IDs"] + txt = f"Assign new IDs uniquely = {assign_unique_new_IDs}" printl(txt) - elif what == 'new_untracked_and_assign_unique': - new_untracked_IDs = kwargs['new_untracked_IDs'] - new_tracked_IDs = kwargs['new_tracked_IDs'] - IDs_curr_untracked = kwargs['IDs_curr_untracked'] - old_IDs = kwargs['old_IDs'] + elif what == "new_untracked_and_assign_unique": + new_untracked_IDs = kwargs["new_untracked_IDs"] + new_tracked_IDs = kwargs["new_tracked_IDs"] + IDs_curr_untracked = kwargs["IDs_curr_untracked"] + old_IDs = kwargs["old_IDs"] txt = ( - f'Current IDs: {IDs_curr_untracked}\n' - f'Previous IDs: {old_IDs}\n' - f'New objects that get a new big ID: {new_untracked_IDs}\n' - f'New unique IDs for the new objects: {new_tracked_IDs}' + f"Current IDs: {IDs_curr_untracked}\n" + f"Previous IDs: {old_IDs}\n" + f"New objects that get a new big ID: {new_untracked_IDs}\n" + f"New unique IDs for the new objects: {new_tracked_IDs}" ) printl(txt) - txt = '' + txt = "" for _ID, replacingID in zip(new_untracked_IDs, new_tracked_IDs): - txt = f'{txt}{_ID} --> {replacingID}\n' + txt = f"{txt}{_ID} --> {replacingID}\n" printl(txt) - elif what == 'new_untracked_and_tracked': - new_untracked_IDs = kwargs['new_untracked_IDs'] - new_tracked_IDs = kwargs['new_tracked_IDs'] - new_IDs_in_trackedIDs = kwargs['new_IDs_in_trackedIDs'] - old_IDs = kwargs['old_IDs'] + elif what == "new_untracked_and_tracked": + new_untracked_IDs = kwargs["new_untracked_IDs"] + new_tracked_IDs = kwargs["new_tracked_IDs"] + new_IDs_in_trackedIDs = kwargs["new_IDs_in_trackedIDs"] + old_IDs = kwargs["old_IDs"] txt = ( - f'New tracked IDs that already exists: {new_IDs_in_trackedIDs}\n' - f'Previous IDs: {old_IDs}\n' - f'New objects that get a new big ID: {new_untracked_IDs}\n' - f'New unique IDs for the new objects: {new_tracked_IDs}' + f"New tracked IDs that already exists: {new_IDs_in_trackedIDs}\n" + f"Previous IDs: {old_IDs}\n" + f"New objects that get a new big ID: {new_untracked_IDs}\n" + f"New unique IDs for the new objects: {new_tracked_IDs}" ) printl(txt) - txt = '' + txt = "" for _ID, replacingID in zip(new_IDs_in_trackedIDs, new_tracked_IDs): - txt = f'{txt}{_ID} --> {replacingID}\n' + txt = f"{txt}{_ID} --> {replacingID}\n" printl(txt) - elif what == 'tracked': - old_IDs = kwargs['old_IDs'] - tracked_IDs = kwargs['tracked_IDs'] + elif what == "tracked": + old_IDs = kwargs["old_IDs"] + tracked_IDs = kwargs["tracked_IDs"] txt = ( - f'Old IDs to be tracked: {old_IDs}\n' - f'New IDs replacing old IDs: {tracked_IDs}' + f"Old IDs to be tracked: {old_IDs}\n" + f"New IDs replacing old IDs: {tracked_IDs}" ) printl(txt) - txt = '' + txt = "" for _ID, replacingID in zip(old_IDs, tracked_IDs): - txt = f'{txt}{_ID} --> {replacingID}\n' + txt = f"{txt}{_ID} --> {replacingID}\n" printl(txt) + def indexAssignment( - old_IDs: List[int], - tracked_IDs: List[int], - IDs_curr_untracked: List[int], - lab: 'np.ndarray[int]', - rp: 'regionprops', - uniqueID: int, - remove_untracked=False, - assign_unique_new_IDs=True, - return_assignments=False, - IDs=None - ): - """Replace `old_IDs` in `lab` with `tracked_IDs` while making sure to + old_IDs: List[int], + tracked_IDs: List[int], + IDs_curr_untracked: List[int], + lab: "np.ndarray[int]", + rp: "regionprops", + uniqueID: int, + remove_untracked=False, + assign_unique_new_IDs=True, + return_assignments=False, + IDs=None, +): + """Replace `old_IDs` in `lab` with `tracked_IDs` while making sure to avoid merging IDs. Parameters @@ -205,17 +218,17 @@ def indexAssignment( rp : list of skimage.measure._regionprops.RegionProperties List of RegionProperties of the objects in `lab` uniqueID : int - Starting unique ID that is going to replace those objects whose ID is + Starting unique ID that is going to replace those objects whose ID is not tracked but they might require a new (unique) one to avoid merging. remove_untracked : bool, optional - If True, those objects that were not tracked will be removed. + If True, those objects that were not tracked will be removed. Default is False assign_unique_new_IDs : bool, optional - If True, uses `uniqueID` to replace the ID of the untracked objects. + If True, uses `uniqueID` to replace the ID of the untracked objects. Default is True return_assignments : bool, optional - If True, returns a dictionary where the keys are the untracked - IDs and the values are the unique IDs that replaced untracked IDs. + If True, returns a dictionary where the keys are the untracked + IDs and the values are the unique IDs that replaced untracked IDs. Default is False IDs : list of ints, optional IDs to be used for the calculation of the IoA matrix. If None, @@ -224,97 +237,97 @@ def indexAssignment( Returns ------- tracked_lab : (Y, X) or (Z, Y, X) array of ints - Segmentation masks with IDs replaced according to input tracking + Segmentation masks with IDs replaced according to input tracking information. assignments: dict Returned only if `return_assignments` is True. - """ - log_debugging( - 'start', - IDs_curr_untracked=IDs_curr_untracked, - old_IDs=old_IDs - ) - + """ + log_debugging("start", IDs_curr_untracked=IDs_curr_untracked, old_IDs=old_IDs) + # Replace untracked IDs with tracked IDs and new IDs with increasing num new_untracked_IDs = [ID for ID in IDs_curr_untracked if ID not in old_IDs] tracked_lab = lab assignments = {} - log_debugging( - 'assign_unique', - assign_unique_new_IDs=assign_unique_new_IDs - ) + log_debugging("assign_unique", assign_unique_new_IDs=assign_unique_new_IDs) if new_untracked_IDs and assign_unique_new_IDs: # Relabel new untracked IDs (i.e., new cells) unique IDs if remove_untracked: - new_tracked_IDs = [0]*len(new_untracked_IDs) + new_tracked_IDs = [0] * len(new_untracked_IDs) else: - new_tracked_IDs = [ - uniqueID+i for i in range(len(new_untracked_IDs)) - ] - core.lab_replace_values( - tracked_lab, rp, new_untracked_IDs, new_tracked_IDs - ) + new_tracked_IDs = [uniqueID + i for i in range(len(new_untracked_IDs))] + core.lab_replace_values(tracked_lab, rp, new_untracked_IDs, new_tracked_IDs) assignments.update(dict(zip(new_untracked_IDs, new_tracked_IDs))) log_debugging( - 'new_untracked_and_assign_unique', + "new_untracked_and_assign_unique", IDs_curr_untracked=IDs_curr_untracked, old_IDs=old_IDs, new_untracked_IDs=new_untracked_IDs, - new_tracked_IDs=new_tracked_IDs + new_tracked_IDs=new_tracked_IDs, ) elif new_untracked_IDs and tracked_IDs: # If we don't replace unique new IDs we check that tracked IDs are # not already existing to avoid duplicates - new_IDs_in_trackedIDs = [ - ID for ID in new_untracked_IDs if ID in tracked_IDs - ] - new_tracked_IDs = [ - uniqueID+i for i in range(len(new_IDs_in_trackedIDs)) - ] - core.lab_replace_values( - tracked_lab, rp, new_IDs_in_trackedIDs, new_tracked_IDs - ) + new_IDs_in_trackedIDs = [ID for ID in new_untracked_IDs if ID in tracked_IDs] + new_tracked_IDs = [uniqueID + i for i in range(len(new_IDs_in_trackedIDs))] + core.lab_replace_values(tracked_lab, rp, new_IDs_in_trackedIDs, new_tracked_IDs) assignments.update(dict(zip(new_IDs_in_trackedIDs, new_tracked_IDs))) log_debugging( - 'new_untracked_and_tracked', + "new_untracked_and_tracked", new_IDs_in_trackedIDs=new_IDs_in_trackedIDs, old_IDs=old_IDs, new_untracked_IDs=new_untracked_IDs, - new_tracked_IDs=new_tracked_IDs + new_tracked_IDs=new_tracked_IDs, ) if tracked_IDs: - core.lab_replace_values( - tracked_lab, rp, old_IDs, tracked_IDs, in_place=True - ) + core.lab_replace_values(tracked_lab, rp, old_IDs, tracked_IDs, in_place=True) assignments.update(dict(zip(old_IDs, tracked_IDs))) log_debugging( - 'tracked', + "tracked", tracked_IDs=tracked_IDs, old_IDs=old_IDs, ) if not return_assignments: return tracked_lab - else: + else: return tracked_lab, assignments + def track_frame( - prev_lab, prev_rp, lab, rp, IDs_curr_untracked=None, - unique_ID=None, setBrushID_func=None, posData=None, - assign_unique_new_IDs=True, IoA_thresh=0.4, debug=False, - return_all=False, aggr_track=None, IoA_matrix=None, - IoA_thresh_aggr=None, IDs_prev=None, return_prev_IDs=False, - mother_daughters=None, denom_overlap_matrix = 'area_prev', - IDs=None - ): + prev_lab, + prev_rp, + lab, + rp, + IDs_curr_untracked=None, + unique_ID=None, + setBrushID_func=None, + posData=None, + assign_unique_new_IDs=True, + IoA_thresh=0.4, + debug=False, + return_all=False, + aggr_track=None, + IoA_matrix=None, + IoA_thresh_aggr=None, + IDs_prev=None, + return_prev_IDs=False, + mother_daughters=None, + denom_overlap_matrix="area_prev", + IDs=None, +): if not np.any(lab): # Skip empty frames return lab if IoA_matrix is None: IoA_matrix, IDs_curr_untracked, IDs_prev = calc_Io_matrix( - lab, prev_lab, rp, prev_rp, IDs_curr_untracked=IDs_curr_untracked, - denom=denom_overlap_matrix, IDs=IDs + lab, + prev_lab, + rp, + prev_rp, + IDs_curr_untracked=IDs_curr_untracked, + denom=denom_overlap_matrix, + IDs=IDs, ) daughters_list = [] @@ -323,63 +336,79 @@ def track_frame( daughters_list.extend(daughters) old_IDs, tracked_IDs = assign( - IoA_matrix, IDs_curr_untracked, IDs_prev, - IoA_thresh=IoA_thresh, aggr_track=aggr_track, - IoA_thresh_aggr=IoA_thresh_aggr, daughters_list=daughters_list, + IoA_matrix, + IDs_curr_untracked, + IDs_prev, + IoA_thresh=IoA_thresh, + aggr_track=aggr_track, + IoA_thresh_aggr=IoA_thresh_aggr, + daughters_list=daughters_list, ) - + if posData is None and unique_ID is None: - unique_ID = max( - (max(IDs_prev, default=0), max(IDs_curr_untracked, default=0)) - ) + 1 + unique_ID = ( + max((max(IDs_prev, default=0), max(IDs_curr_untracked, default=0))) + 1 + ) elif unique_ID is None: # Compute starting unique ID setBrushID_func(useCurrentLab=True) - unique_ID = posData.brushID+1 + unique_ID = posData.brushID + 1 if not return_all: tracked_lab = indexAssignment( - old_IDs, tracked_IDs, IDs_curr_untracked, - lab.copy(), rp, unique_ID, + old_IDs, + tracked_IDs, + IDs_curr_untracked, + lab.copy(), + rp, + unique_ID, assign_unique_new_IDs=assign_unique_new_IDs, ) else: tracked_lab, assignments = indexAssignment( - old_IDs, tracked_IDs, IDs_curr_untracked, - lab.copy(), rp, unique_ID, - assign_unique_new_IDs=assign_unique_new_IDs, + old_IDs, + tracked_IDs, + IDs_curr_untracked, + lab.copy(), + rp, + unique_ID, + assign_unique_new_IDs=assign_unique_new_IDs, return_assignments=return_all, ) # old_new_ids = dict(zip(old_IDs, tracked_IDs)) # for now not used, but could be useful in the future - + if return_all: - return tracked_lab, IoA_matrix, assignments, tracked_IDs # remove tracked_IDs and change code in CellACDC_tracker.py if causing problems + return ( + tracked_lab, + IoA_matrix, + assignments, + tracked_IDs, + ) # remove tracked_IDs and change code in CellACDC_tracker.py if causing problems else: return tracked_lab + class tracker: def __init__(self, **params): self.params = params - def track(self, segm_video, signals=None, export_to: os.PathLike=None): + def track(self, segm_video, signals=None, export_to: os.PathLike = None): tracked_video = np.zeros_like(segm_video) - pbar = tqdm(total=len(segm_video), desc='Tracking', ncols=100) + pbar = tqdm(total=len(segm_video), desc="Tracking", ncols=100) for frame_i, lab in enumerate(segm_video): if frame_i == 0: tracked_video[frame_i] = lab pbar.update() continue - prev_lab = tracked_video[frame_i-1] + prev_lab = tracked_video[frame_i - 1] prev_rp = regionprops(prev_lab) rp = regionprops(lab.copy()) - IoA_thresh = self.params.get('IoA_thresh', 0.4) - tracked_lab = track_frame( - prev_lab, prev_rp, lab, rp, IoA_thresh=IoA_thresh - ) + IoA_thresh = self.params.get("IoA_thresh", 0.4) + tracked_lab = track_frame(prev_lab, prev_rp, lab, rp, IoA_thresh=IoA_thresh) tracked_video[frame_i] = tracked_lab self.updateGuiProgressBar(signals) @@ -387,19 +416,19 @@ def track(self, segm_video, signals=None, export_to: os.PathLike=None): pbar.close() # tracked_video = relabel_sequential(tracked_video)[0] return tracked_video - + def updateGuiProgressBar(self, signals): if signals is None: return - - if hasattr(signals, 'innerPbar_available'): + + if hasattr(signals, "innerPbar_available"): if signals.innerPbar_available: # Use inner pbar of the GUI widget (top pbar is for positions) signals.innerProgressBar.emit(1) return - if hasattr(signals, 'progressBar'): + if hasattr(signals, "progressBar"): signals.progressBar.emit(1) def save_output(self): - pass \ No newline at end of file + pass diff --git a/cellacdc/trackers/CellACDC_2steps/CellACDC_2steps_tracker.py b/cellacdc/trackers/CellACDC_2steps/CellACDC_2steps_tracker.py index 4bef3ec46..1cfacfb68 100644 --- a/cellacdc/trackers/CellACDC_2steps/CellACDC_2steps_tracker.py +++ b/cellacdc/trackers/CellACDC_2steps/CellACDC_2steps_tracker.py @@ -14,53 +14,57 @@ from ..CellACDC import CellACDC_tracker + class SearchRangeUnits: - values = ['pixels', 'micrometre'] + values = ["pixels", "micrometre"] + class Integer: not_a_param = True + class tracker: def __init__( - self, - annotate_objects_tracked_second_step=True, - PhysicalSizeX=1.0, - PhysicalSizeY=1.0, - PhysicalSizeZ=1.0, - ): + self, + annotate_objects_tracked_second_step=True, + PhysicalSizeX=1.0, + PhysicalSizeY=1.0, + PhysicalSizeZ=1.0, + ): """Initialize Cell-ACDC two steps tracker Parameters ---------- annotate_objects_tracked_second_step : bool, optional - If True, Cell-ACDC will draw a line on the GUI between the objects - in previous frame that were lost in current frame according to the - first step (based on overlap) and the objects in current frame that - were matched according to the second step (based on search range). + If True, Cell-ACDC will draw a line on the GUI between the objects + in previous frame that were lost in current frame according to the + first step (based on overlap) and the objects in current frame that + were matched according to the second step (based on search range). Default is True PhysicalSizeX : float, optional - Pixel size in the x-direction in 'micrometre/pixel'. This will be + Pixel size in the x-direction in 'micrometre/pixel'. This will be ignored if `search_range_unit` is `pixels`. Default is 1.0 PhysicalSizeY : float, optional - Pixel size in the y-direction in 'micrometre/pixel'. This will be + Pixel size in the y-direction in 'micrometre/pixel'. This will be ignored if `search_range_unit` is `pixels`. Default is 1.0. PhysicalSizeZ : float, optional - Pixel size in the z-direction in 'micrometre/pixel'. This will be - ignored if `search_range_unit` is `pixels`. Default is 1.0. - """ + Pixel size in the z-direction in 'micrometre/pixel'. This will be + ignored if `search_range_unit` is `pixels`. Default is 1.0. + """ self._annot_obj_2nd_step = annotate_objects_tracked_second_step self._pixel_yx_size = (PhysicalSizeY, PhysicalSizeX) self._voxel_zyx_size = (PhysicalSizeZ, PhysicalSizeY, PhysicalSizeX) - + def track( - self, segm_video, - overlap_threshold=0.4, - search_range_unit: SearchRangeUnits='pixels', - lost_IDs_search_range=10, - signals: cellacdc.workers.signals=None, - export_to_extension='.csv', - export_to: os.PathLike=None, - ): + self, + segm_video, + overlap_threshold=0.4, + search_range_unit: SearchRangeUnits = "pixels", + lost_IDs_search_range=10, + signals: cellacdc.workers.signals = None, + export_to_extension=".csv", + export_to: os.PathLike = None, + ): """Track the objects in `segm_video`. Parameters @@ -68,56 +72,59 @@ def track( segm_video : (T, Y, X) or (T, Z, Y, X) array of ints Input segmentation masks to track. overlap_threshold : float, optional - Minimum overlap between objects of two consecutive frames to - consider the object as not new. The overlap is calculated as the - ratio between the intersection between current object and objects - in previous frame and are of the objects in previous frame. - All new objects will undergo a second step of matching based on + Minimum overlap between objects of two consecutive frames to + consider the object as not new. The overlap is calculated as the + ratio between the intersection between current object and objects + in previous frame and are of the objects in previous frame. + All new objects will undergo a second step of matching based on the `lost_IDs_search_range`. Default is 0.4 search_range_unit : {'pixels', 'micrometre'}, optional - Physical unit of the parameter `lost_IDs_search_range`. If - `micrometre`, distances will be converted using the pixel sizes. - See the parameters `PixelSizeX`, `PixelSizeY`, and `PixelSizeZ`. + Physical unit of the parameter `lost_IDs_search_range`. If + `micrometre`, distances will be converted using the pixel sizes. + See the parameters `PixelSizeX`, `PixelSizeY`, and `PixelSizeZ`. Default is 'pixels' lost_IDs_search_range : int, optional - Maximum distance that a new object (according to `overlap_threshold`) - can travel between two consecutive frames to be considered as - potential candidate to match to a lost object. The unit is + Maximum distance that a new object (according to `overlap_threshold`) + can travel between two consecutive frames to be considered as + potential candidate to match to a lost object. The unit is either `pixels` or `micrometre` (see `search_range_unit` parameter). Default is 10 signals : cellacdc.workers.signals, optional - Class with `qtpy.Signal` attributes used to display progress on the + Class with `qtpy.Signal` attributes used to display progress on the GUI (text and progressbars). Default is None export_to_extension : str, optional - Extension of the optional table that will be saved in the tracking + Extension of the optional table that will be saved in the tracking process. Default is '.csv' export_to : os.PathLike, optional Path of the table to export. Default is None - """ + """ tracked_video = np.copy(segm_video) for frame_i, lab in enumerate(segm_video): if frame_i == 0: continue - prev_frame_lab = tracked_video[frame_i-1] + prev_frame_lab = tracked_video[frame_i - 1] tracked_lab, _ = self.track_frame( - prev_frame_lab, lab, + prev_frame_lab, + lab, search_range_unit=search_range_unit, - overlap_threshold=overlap_threshold, - lost_IDs_search_range=lost_IDs_search_range + overlap_threshold=overlap_threshold, + lost_IDs_search_range=lost_IDs_search_range, ) tracked_video[frame_i] = tracked_lab self.updateGuiProgressBar(signals) return tracked_video - + def track_frame( - self, prev_frame_lab, current_frame_lab, - overlap_threshold=0.4, - search_range_unit: SearchRangeUnits='pixels', - lost_IDs_search_range=10, - unique_ID: Integer=None - ): - """Track two consecutive frames in two steps. First step based on - `overlap_threshold` and second step tracks only lost objects to new + self, + prev_frame_lab, + current_frame_lab, + overlap_threshold=0.4, + search_range_unit: SearchRangeUnits = "pixels", + lost_IDs_search_range=10, + unique_ID: Integer = None, + ): + """Track two consecutive frames in two steps. First step based on + `overlap_threshold` and second step tracks only lost objects to new objects detemined at first step. Parameters @@ -127,90 +134,90 @@ def track_frame( current_frame_lab : (Y, X) or (Z, Y, X) array of ints Segmentation masks of the current frame. overlap_threshold : float, optional - Minimum overlap between objects of two consecutive frames to - consider the object as not new. The overlap is calculated as the - ratio between the intersection between current object and objects - in previous frame and are of the objects in previous frame. - All new objects will undergo a second step of matching based on + Minimum overlap between objects of two consecutive frames to + consider the object as not new. The overlap is calculated as the + ratio between the intersection between current object and objects + in previous frame and are of the objects in previous frame. + All new objects will undergo a second step of matching based on the `lost_IDs_search_range`. Default is 0.4 search_range_unit : {'pixels', 'micrometre'}, optional - Physical unit of the parameter `lost_IDs_search_range`. If - `micrometre`, distances will be converted using the pixel sizes. - See the parameters `PixelSizeX`, `PixelSizeY`, and `PixelSizeZ`. + Physical unit of the parameter `lost_IDs_search_range`. If + `micrometre`, distances will be converted using the pixel sizes. + See the parameters `PixelSizeX`, `PixelSizeY`, and `PixelSizeZ`. Default is 'pixels' lost_IDs_search_range : int, optional - Maximum distance that a new object (according to `overlap_threshold`) - can travel between two consecutive frames to be considered as - potential candidate to match to a lost object. The unit is - either `pixels` or `micrometre`and it is set in the + Maximum distance that a new object (according to `overlap_threshold`) + can travel between two consecutive frames to be considered as + potential candidate to match to a lost object. The unit is + either `pixels` or `micrometre`and it is set in the `search_range_unit` parameter. Default is 10 unique_ID : int, optional If not None, uses this as starting ID for all the untracked objects. If None, this will be calculated based on the two input frames. - """ + """ to_track_tracked_objs_2nd_step = None - + prev_rp = skimage.measure.regionprops(prev_frame_lab) curr_rp = skimage.measure.regionprops(current_frame_lab) - + tracked_lab_1st_step = CellACDC_tracker.track_frame( - prev_frame_lab, - prev_rp, - current_frame_lab, - curr_rp, - IoA_thresh=overlap_threshold, - return_prev_IDs=False, - unique_ID=unique_ID + prev_frame_lab, + prev_rp, + current_frame_lab, + curr_rp, + IoA_thresh=overlap_threshold, + return_prev_IDs=False, + unique_ID=unique_ID, ) - + prev_rp_mapper = {obj.label: obj for obj in prev_rp} - + tracked_rp_1st_step = skimage.measure.regionprops(tracked_lab_1st_step) - tracked_rp_1st_step_mapper = { - obj.label: obj for obj in tracked_rp_1st_step - } - + tracked_rp_1st_step_mapper = {obj.label: obj for obj in tracked_rp_1st_step} + lost_rp_mapper = { - obj.label: obj for obj in prev_rp + obj.label: obj + for obj in prev_rp if tracked_rp_1st_step_mapper.get(obj.label) is None } - + if not lost_rp_mapper: return tracked_lab_1st_step, to_track_tracked_objs_2nd_step - + new_rp_mapper = { - obj.label: obj for obj in tracked_rp_1st_step + obj.label: obj + for obj in tracked_rp_1st_step if prev_rp_mapper.get(obj.label) is None } - + if not new_rp_mapper: return tracked_lab_1st_step, to_track_tracked_objs_2nd_step - + ndim = current_frame_lab.ndim lost_IDs_coords = np.zeros((len(lost_rp_mapper), ndim)) lost_IDs_idx_to_obj_mapper = {} for lost_idx, lost_obj in enumerate(lost_rp_mapper.values()): lost_IDs_coords[lost_idx] = lost_obj.centroid lost_IDs_idx_to_obj_mapper[lost_idx] = lost_obj - + new_IDs_coords = np.zeros((len(new_rp_mapper), ndim)) new_IDs_idx_to_obj_mapper = {} for new_idx, new_obj in enumerate(new_rp_mapper.values()): new_IDs_coords[new_idx] = new_obj.centroid new_IDs_idx_to_obj_mapper[new_idx] = new_obj - - if search_range_unit == 'micrometre': + + if search_range_unit == "micrometre": if ndim == 3: scaling = self._voxel_zyx_size else: scaling = self._pixel_yx_size lost_IDs_coords /= scaling new_IDs_coords /= scaling - + diff = lost_IDs_coords[:, np.newaxis] - new_IDs_coords # dist_matrix[i, j] = euclidean_dist(lost_IDs_coords[i], new_IDs_coords[j]) dist_matrix = np.linalg.norm(diff, axis=2) - + assignments = scipy.optimize.linear_sum_assignment(dist_matrix) IDs_to_track = [] tracked_IDs_2nd_step = [] @@ -221,42 +228,36 @@ def track_frame( dist = dist_matrix[i, j] if dist > lost_IDs_search_range: continue - + IDs_to_track.append(new_IDs_idx_to_obj_mapper[j].label) tracked_IDs_2nd_step.append(lost_IDs_idx_to_obj_mapper[i].label) if self._annot_obj_2nd_step: objs_to_track.append(new_IDs_idx_to_obj_mapper[j]) tracked_objs_2nd_step.append(lost_IDs_idx_to_obj_mapper[i]) - + if not IDs_to_track: return tracked_lab_1st_step, to_track_tracked_objs_2nd_step - + tracked_lab_2nd_step = cellacdc.core.lab_replace_values( - tracked_lab_1st_step, + tracked_lab_1st_step, tracked_rp_1st_step, - IDs_to_track, - tracked_IDs_2nd_step + IDs_to_track, + tracked_IDs_2nd_step, ) - + if self._annot_obj_2nd_step: - to_track_tracked_objs_2nd_step = ( - objs_to_track, tracked_objs_2nd_step - ) - + to_track_tracked_objs_2nd_step = (objs_to_track, tracked_objs_2nd_step) + return tracked_lab_2nd_step, to_track_tracked_objs_2nd_step - + def updateGuiProgressBar(self, signals): if signals is None: return - - if hasattr(signals, 'innerPbar_available'): + + if hasattr(signals, "innerPbar_available"): if signals.innerPbar_available: # Use inner pbar of the GUI widget (top pbar is for positions) signals.innerProgressBar.emit(1) return signals.progressBar.emit(1) - - - - \ No newline at end of file diff --git a/cellacdc/trackers/CellACDC_normal_division/CellACDC_normal_division_tracker.py b/cellacdc/trackers/CellACDC_normal_division/CellACDC_normal_division_tracker.py index 8e47215a9..69a063bd5 100644 --- a/cellacdc/trackers/CellACDC_normal_division/CellACDC_normal_division_tracker.py +++ b/cellacdc/trackers/CellACDC_normal_division/CellACDC_normal_division_tracker.py @@ -2,12 +2,12 @@ from cellacdc.trackers.CellACDC.CellACDC_tracker import calc_Io_matrix from cellacdc.trackers.CellACDC.CellACDC_tracker import track_frame as track_frame_base from cellacdc.core import getBaseCca_df, printl -from cellacdc.myutils import checked_reset_index, checked_reset_index_Cell_ID +from cellacdc.utils import checked_reset_index, checked_reset_index_Cell_ID import numpy as np from skimage.measure import regionprops from tqdm import tqdm import pandas as pd -from cellacdc.myutils import exec_time +from cellacdc.utils import exec_time from cellacdc._types import NotGUIParam import copy import cellacdc.debugutils as debugutils @@ -24,14 +24,15 @@ # Returns: # - pandas.DataFrame: The filtered DataFrame containing only the specified columns. # """ -# lin_tree_cols = {'generation_num_tree', 'root_ID_tree', -# 'sister_ID_tree', 'parent_ID_tree', -# 'parent_ID_tree', 'emerg_frame_i', +# lin_tree_cols = {'generation_num_tree', 'root_ID_tree', +# 'sister_ID_tree', 'parent_ID_tree', +# 'parent_ID_tree', 'emerg_frame_i', # 'division_frame_i', 'is_history_known'} # sis_cols = {col for col in df.columns if col.startswith('sister_ID_tree')} # lin_tree_cols = lin_tree_cols | sis_cols # return df[list(lin_tree_cols)] + def reorg_sister_cells_for_export(lineage_tree_frame): """ Reorganizes the daughter cells in the lineage tree frame for export. @@ -46,20 +47,23 @@ def reorg_sister_cells_for_export(lineage_tree_frame): """ if lineage_tree_frame.empty: return lineage_tree_frame - - old_sister_columns = {col for col in lineage_tree_frame.columns if col.startswith('sister_ID_tree')} - sister_columns = lineage_tree_frame['sister_ID_tree'].apply(pd.Series) + old_sister_columns = { + col for col in lineage_tree_frame.columns if col.startswith("sister_ID_tree") + } + + sister_columns = lineage_tree_frame["sister_ID_tree"].apply(pd.Series) max_daughter = sister_columns.shape[1] - new_columns = [f'sister_ID_tree_{i}' for i in range(max_daughter)] + new_columns = [f"sister_ID_tree_{i}" for i in range(max_daughter)] lineage_tree_frame = lineage_tree_frame.drop(columns=old_sister_columns) lineage_tree_frame[new_columns] = sister_columns - lineage_tree_frame['sister_ID_tree'] = sister_columns[0] + lineage_tree_frame["sister_ID_tree"] = sister_columns[0] return lineage_tree_frame + # def reorg_sister_cells_inner_func(row): # """ # Reorganizes the sister cells in a row of a DataFrame. Used as an inner function for apply. @@ -71,7 +75,7 @@ def reorg_sister_cells_for_export(lineage_tree_frame): # """ # values = [int(i) for i in row if i not in {0, -1} and not np.isnan(i)] or [-1] -# values = list(set(values)) +# values = list(set(values)) # return values @@ -99,7 +103,10 @@ def reorg_sister_cells_for_export(lineage_tree_frame): # df = checked_reset_index_Cell_ID(df) # return df -def mother_daughter_assign(IoA_matrix, IoA_thresh_daughter, min_daughter, max_daughter, IoA_thresh_instant=None): + +def mother_daughter_assign( + IoA_matrix, IoA_thresh_daughter, min_daughter, max_daughter, IoA_thresh_instant=None +): """ Identifies cells that have not undergone division based on the input IoA matrix. @@ -116,7 +123,7 @@ def mother_daughter_assign(IoA_matrix, IoA_thresh_daughter, min_daughter, max_da """ mother_daughters = [] aggr_track = [] - daughter_range = range(min_daughter, max_daughter+1, 1) + daughter_range = range(min_daughter, max_daughter + 1, 1) instant_accept = [] IoA_thresholded = IoA_matrix >= IoA_thresh_daughter @@ -131,7 +138,7 @@ def mother_daughter_assign(IoA_matrix, IoA_thresh_daughter, min_daughter, max_da if IoA_instant_accept[:, j].any(): instant_accept.append(j) continue - + high_IoA_indices = np.where(IoA_thresholded[:, j])[0] if not high_IoA_indices.size: @@ -146,13 +153,15 @@ def mother_daughter_assign(IoA_matrix, IoA_thresh_daughter, min_daughter, max_da for daughter in daughters: high_IoA_greater_1 = np.count_nonzero(IoA_thresholded[daughter]) > 1 if high_IoA_greater_1: - should_remove_idx.append(True) + should_remove_idx.append(True) break else: should_remove_idx.append(False) - + # printl(f'length of mother_daughters: {len(mother_daughters), len(should_remove_idx)}') - mother_daughters = [mother_daughters[i] for i, remove in enumerate(should_remove_idx) if not remove] + mother_daughters = [ + mother_daughters[i] for i, remove in enumerate(should_remove_idx) if not remove + ] # daughters_li = [] # for _, daughters in mother_daughters: @@ -160,6 +169,7 @@ def mother_daughter_assign(IoA_matrix, IoA_thresh_daughter, min_daughter, max_da return aggr_track, mother_daughters + def added_lineage_tree_to_cca_df(added_lineage_tree): """ Converts the added lineage tree into a DataFrame with specific columns. @@ -187,20 +197,25 @@ def added_lineage_tree_to_cca_df(added_lineage_tree): return pd.DataFrame() # Use zip to unpack columns efficiently - emerg_frame_i, cell_id, parent_id, gen_num, root_id, sister_ids = zip(*added_lineage_tree) - cca_df = pd.DataFrame({ - 'Cell_ID': cell_id, - 'emerg_frame_i': emerg_frame_i, - 'division_frame_i': emerg_frame_i, - 'generation_num_tree': gen_num, - 'parent_ID_tree': parent_id, - 'root_ID_tree': root_id, - 'sister_ID_tree': sister_ids, - }) - cca_df['is_history_known'] = (cca_df['parent_ID_tree'] != -1).astype(int) - cca_df = cca_df.set_index('Cell_ID') + emerg_frame_i, cell_id, parent_id, gen_num, root_id, sister_ids = zip( + *added_lineage_tree + ) + cca_df = pd.DataFrame( + { + "Cell_ID": cell_id, + "emerg_frame_i": emerg_frame_i, + "division_frame_i": emerg_frame_i, + "generation_num_tree": gen_num, + "parent_ID_tree": parent_id, + "root_ID_tree": root_id, + "sister_ID_tree": sister_ids, + } + ) + cca_df["is_history_known"] = (cca_df["parent_ID_tree"] != -1).astype(int) + cca_df = cca_df.set_index("Cell_ID") return cca_df + def filter_current_IDs(df, current_IDs): """ Filters for current IDs. @@ -215,6 +230,7 @@ def filter_current_IDs(df, current_IDs): df = checked_reset_index_Cell_ID(df) return df[df.index.isin(current_IDs)] + def IoA_index_daughter_to_ID(daughters, assignments, IDs_curr_untracked): """ Converts a list of daughter indices (IoA Matrix) to their corresponding IDs. @@ -231,7 +247,7 @@ def IoA_index_daughter_to_ID(daughters, assignments, IDs_curr_untracked): if daughters is None: return - + daughter_IDs = [] for daughter in daughters: if assignments: @@ -241,13 +257,14 @@ def IoA_index_daughter_to_ID(daughters, assignments, IDs_curr_untracked): return daughter_IDs + # def update_fam_dynamically(families, fixed_df, Cell_IDs_fixed=None): # if Cell_IDs_fixed is None: # Cell_IDs_fixed = fixed_df.index # for idx, family in enumerate(families): # # Keep only cellinfos where cell_id is in Cell_IDs_fixed # families[idx] = [cellinfo for cellinfo in family if cellinfo[0] not in Cell_IDs_fixed] - + # families = [family for family in families if family] # Remove empty families # handled_cells = set() # for family in families: @@ -260,14 +277,15 @@ def IoA_index_daughter_to_ID(daughters, assignments, IDs_curr_untracked): # # Update the family with the generation number and root ID # family.append((relevant_cell, relevant_cells.loc[relevant_cell, 'generation_num_tree'])) # handled_cells.update(relevant_cells.index) - + # for cell_id in Cell_IDs_fixed: # if cell_id not in handled_cells: # # If the cell is not handled, create a new family for it # families.append([(cell_id, fixed_df.loc[cell_id, 'generation_num_tree'])]) - + # return families + class normal_division_tracker: """ A class that tracks cell divisions in a video sequence. The tracker uses the Intersection over Area (IoA) metric to track cells and identify daughter cells. @@ -293,13 +311,15 @@ class normal_division_tracker: - track_frame(self, frame_i, lab=None, prev_lab=None, rp=None, prev_rp=None): Tracks a single frame in the video sequence. """ - def __init__(self, - segm_video, - IoA_thresh_daughter, - min_daughter, - max_daughter, - IoA_thresh, - IoA_thresh_aggressive): + def __init__( + self, + segm_video, + IoA_thresh_daughter, + min_daughter, + max_daughter, + IoA_thresh, + IoA_thresh_aggressive, + ): """ Initializes the normal_division_tracker object. @@ -322,8 +342,16 @@ def __init__(self, self.tracked_video = np.zeros_like(segm_video) self.tracked_video[0] = segm_video[0] - def track_frame(self, frame_i, lab=None, prev_lab=None, rp=None, prev_rp=None, - IDs=None, unique_ID=None): + def track_frame( + self, + frame_i, + lab=None, + prev_lab=None, + rp=None, + prev_rp=None, + IDs=None, + unique_ID=None, + ): """ Tracks a single frame in the video sequence. @@ -339,7 +367,7 @@ def track_frame(self, frame_i, lab=None, prev_lab=None, rp=None, prev_rp=None, lab = self.segm_video[frame_i] if prev_lab is None: - prev_lab = self.tracked_video[frame_i-1] + prev_lab = self.tracked_video[frame_i - 1] if rp is None: self.rp = regionprops(lab.copy()) @@ -349,36 +377,39 @@ def track_frame(self, frame_i, lab=None, prev_lab=None, rp=None, prev_rp=None, if prev_rp is None: prev_rp = regionprops(prev_lab.copy()) - IoA_matrix, self.IDs_curr_untracked, self.IDs_prev = calc_Io_matrix(lab, - prev_lab, - self.rp, - prev_rp, - IDs=IDs, - ) - self.aggr_track, self.mother_daughters = mother_daughter_assign(IoA_matrix, - IoA_thresh_daughter=self.IoA_thresh_daughter, - min_daughter=self.min_daughter, - max_daughter=self.max_daughter, - IoA_thresh_instant=self.IoA_thresh - ) - self.tracked_lab, IoA_matrix, self.assignments, _ = track_frame_base(prev_lab, - prev_rp, - lab, - self.rp, - IoA_thresh=self.IoA_thresh, - IoA_matrix=IoA_matrix, - aggr_track=self.aggr_track, - IoA_thresh_aggr=self.IoA_thresh_aggressive, - IDs_curr_untracked=self.IDs_curr_untracked, - IDs_prev=self.IDs_prev, - return_all=True, - mother_daughters=self.mother_daughters, - unique_ID=unique_ID - ) - + IoA_matrix, self.IDs_curr_untracked, self.IDs_prev = calc_Io_matrix( + lab, + prev_lab, + self.rp, + prev_rp, + IDs=IDs, + ) + self.aggr_track, self.mother_daughters = mother_daughter_assign( + IoA_matrix, + IoA_thresh_daughter=self.IoA_thresh_daughter, + min_daughter=self.min_daughter, + max_daughter=self.max_daughter, + IoA_thresh_instant=self.IoA_thresh, + ) + self.tracked_lab, IoA_matrix, self.assignments, _ = track_frame_base( + prev_lab, + prev_rp, + lab, + self.rp, + IoA_thresh=self.IoA_thresh, + IoA_matrix=IoA_matrix, + aggr_track=self.aggr_track, + IoA_thresh_aggr=self.IoA_thresh_aggressive, + IDs_curr_untracked=self.IDs_curr_untracked, + IDs_prev=self.IDs_prev, + return_all=True, + mother_daughters=self.mother_daughters, + unique_ID=unique_ID, + ) self.tracked_video[frame_i] = self.tracked_lab + class normal_division_lineage_tree: """ Class for tracking and managing cell lineage trees during normal cell division across multiple frames. @@ -423,10 +454,18 @@ class normal_division_lineage_tree: export_lin_tree_info(frame_i) Return information about new, orphan, and lost cells between two consecutive frames. - """ + """ - def __init__(self, lab=None, first_df=None, frame_i=0, max_daughter=2, min_daughter=2, IoA_thresh_daughter=0.25, - gui=None): + def __init__( + self, + lab=None, + first_df=None, + frame_i=0, + max_daughter=2, + min_daughter=2, + IoA_thresh_daughter=0.25, + gui=None, + ): """ Initialize the lineage tree for normal cell divisions. @@ -449,30 +488,35 @@ def __init__(self, lab=None, first_df=None, frame_i=0, max_daughter=2, min_daugh self.gui = gui self.max_daughters_added = 0 self.gui_mode = True if gui is not None else False - self.mother_daughters = [] # just for the dict_curr_frame stuff... + self.mother_daughters = [] # just for the dict_curr_frame stuff... self.frames_for_dfs = set([frame_i]) - self.need_update_gen_df = False # this is only when using the quick option in update_gen_df_from_df + self.need_update_gen_df = ( + False # this is only when using the quick option in update_gen_df_from_df + ) self.first_frame_i_for_ID = dict() self.ID_frame_i_lookup = {} - - if self.gui_mode: # part of loading for gui + + if self.gui_mode: # part of loading for gui posData = self.gui.data[self.gui.pos_i] for i, data in enumerate(posData.allData_li): - if 'generation_num_tree' in data['acdc_df'].columns and data['acdc_df']['generation_num_tree'].notna().all(): + if ( + "generation_num_tree" in data["acdc_df"].columns + and data["acdc_df"]["generation_num_tree"].notna().all() + ): self.frames_for_dfs.add(i) - + self.init_lineage_tree(lab, first_df, frame_i) - + def _get_first_frame_i_for_ID(self, ID): if self.gui_mode: posData = self.gui.data[self.gui.pos_i] if ID in self.first_frame_i_for_ID: frame_i = self.first_frame_i_for_ID[ID] - if ID in posData.allData_li[frame_i]['acdc_df'].index: + if ID in posData.allData_li[frame_i]["acdc_df"].index: return frame_i for i, data in enumerate(posData.allData_li): - if ID in data['acdc_df'].index: + if ID in data["acdc_df"].index: self.first_frame_i_for_ID[ID] = i return i else: @@ -484,40 +528,44 @@ def _get_first_frame_i_for_ID(self, ID): if ID in df.index: self.first_frame_i_for_ID[ID] = i return i - - def _get_extra_daughter_cols(self,num_daughters=None): + + def _get_extra_daughter_cols(self, num_daughters=None): if num_daughters is not None and self.max_daughters_added < num_daughters: missing_i = range(self.max_daughters_added, num_daughters) - missing_cols = [f'sister_ID_tree_{i}' for i in missing_i] + missing_cols = [f"sister_ID_tree_{i}" for i in missing_i] self.max_daughters_added = num_daughters if self.gui_mode: posData = self.gui.data[self.gui.pos_i] for frame_i in self.frames_for_dfs: - df = posData.allData_li[frame_i]['acdc_df'] - missing_cols_loc = [col for col in missing_cols if col not in df.columns] + df = posData.allData_li[frame_i]["acdc_df"] + missing_cols_loc = [ + col for col in missing_cols if col not in df.columns + ] df[missing_cols_loc] = -1 else: for df in self.lineage_list: - missing_cols_loc = [col for col in missing_cols if col not in df.columns] + missing_cols_loc = [ + col for col in missing_cols if col not in df.columns + ] df[missing_cols_loc] = -1 - - return [f'sister_ID_tree_{i}' for i in range(self.max_daughters_added)] - + + return [f"sister_ID_tree_{i}" for i in range(self.max_daughters_added)] + def _get_df_from_frame_i(self, frame_i): if self.gui_mode: posData = self.gui.data[self.gui.pos_i] - return posData.allData_li[frame_i]['acdc_df'] + return posData.allData_li[frame_i]["acdc_df"] else: return self.lineage_list[frame_i] - + def _get_row_from_ID(self, ID, start_search_frame_i=None): if ID in self.ID_frame_i_lookup: frame_i = self.ID_frame_i_lookup[ID] if self.gui_mode: posData = self.gui.data[self.gui.pos_i] try: - df = posData.allData_li[frame_i]['acdc_df'] + df = posData.allData_li[frame_i]["acdc_df"] row = df.loc[ID] return row except: @@ -532,15 +580,15 @@ def _get_row_from_ID(self, ID, start_search_frame_i=None): if self.gui_mode: posData = self.gui.data[self.gui.pos_i] if start_search_frame_i is not None: - df = posData.allData_li[start_search_frame_i]['acdc_df'] + df = posData.allData_li[start_search_frame_i]["acdc_df"] if ID in df.index: row = df.loc[ID] self.ID_frame_i_lookup[ID] = start_search_frame_i return row - + for i, data in enumerate(posData.allData_li): - if ID in data['acdc_df'].index: - df = data['acdc_df'] + if ID in data["acdc_df"].index: + df = data["acdc_df"] row = df.loc[ID] self.ID_frame_i_lookup[ID] = i return row @@ -551,14 +599,14 @@ def _get_row_from_ID(self, ID, start_search_frame_i=None): row = df.loc[ID] self.ID_frame_i_lookup[ID] = start_search_frame_i return row - + for i, df in enumerate(self.lineage_list): if ID in df.index: row = df.loc[ID] self.ID_frame_i_lookup[ID] = i return row - raise ValueError(f'ID {ID} not found in any frame.') + raise ValueError(f"ID {ID} not found in any frame.") def init_lineage_tree(self, lab=None, first_df=None, frame_i=None): """ @@ -571,57 +619,69 @@ def init_lineage_tree(self, lab=None, first_df=None, frame_i=None): Raises: ValueError: If both lab and first_df are provided. """ - print('Initializing lineage tree...') + print("Initializing lineage tree...") if lab is not None and lab.any() and first_df: - raise ValueError('Only one of lab and first_df can be provided.') - + raise ValueError("Only one of lab and first_df can be provided.") + if frame_i is None: frame_i = 0 - + if self.gui_mode: cca_df = self._get_df_from_frame_i(frame_i) - if 'parent_ID_tree' in cca_df.columns: + if "parent_ID_tree" in cca_df.columns: return - cca_df['emerg_frame_i'] = cca_df['division_frame_i'] = frame_i - cca_df['generation_num_tree'] = 1 - cca_df['parent_ID_tree'] = -1 - cca_df['is_history_known'] = (cca_df['parent_ID_tree'] != -1).astype(int) - cca_df['root_ID_tree'] = cca_df.index - cca_df['sister_ID_tree'] = -1 + cca_df["emerg_frame_i"] = cca_df["division_frame_i"] = frame_i + cca_df["generation_num_tree"] = 1 + cca_df["parent_ID_tree"] = -1 + cca_df["is_history_known"] = (cca_df["parent_ID_tree"] != -1).astype(int) + cca_df["root_ID_tree"] = cca_df.index + cca_df["sister_ID_tree"] = -1 cca_df[self._get_extra_daughter_cols()] = -1 - + return - - if lab is not None: + if lab is not None: rp = regionprops(lab) labels = [obj.label for obj in rp] - cca_df = pd.DataFrame({ - 'Cell_ID': labels, - }) - cca_df = cca_df.set_index('Cell_ID') + cca_df = pd.DataFrame( + { + "Cell_ID": labels, + } + ) + cca_df = cca_df.set_index("Cell_ID") # check if the cca_df already has the lineage columns - cca_df['emerg_frame_i'] = cca_df['division_frame_i'] = frame_i - cca_df['generation_num_tree'] = 1 - cca_df['parent_ID_tree'] = -1 - cca_df['is_history_known'] = (cca_df['parent_ID_tree'] != -1).astype(int) - cca_df['root_ID_tree'] = cca_df.index - - cca_df['sister_ID_tree'] = [[-1] * (self.max_daughter-1) for _ in range(len(cca_df))] + cca_df["emerg_frame_i"] = cca_df["division_frame_i"] = frame_i + cca_df["generation_num_tree"] = 1 + cca_df["parent_ID_tree"] = -1 + cca_df["is_history_known"] = (cca_df["parent_ID_tree"] != -1).astype(int) + cca_df["root_ID_tree"] = cca_df.index + + cca_df["sister_ID_tree"] = [ + [-1] * (self.max_daughter - 1) for _ in range(len(cca_df)) + ] cca_df = checked_reset_index_Cell_ID(cca_df) self.lineage_list = [cca_df] - elif first_df is not None and not first_df.empty: if self.gui_mode: # not yet implemented - raise NotImplementedError('Initializing lineage tree with a DataFrame is not yet implemented in GUI mode.') + raise NotImplementedError( + "Initializing lineage tree with a DataFrame is not yet implemented in GUI mode." + ) first_df = checked_reset_index_Cell_ID(first_df) self.lineage_list = [first_df] - - def add_new_frame(self, frame_i, mother_daughters, IDs_prev, - IDs_curr_untracked, assignments, curr_IDs, new_IDs): + + def add_new_frame( + self, + frame_i, + mother_daughters, + IDs_prev, + IDs_curr_untracked, + assignments, + curr_IDs, + new_IDs, + ): """ Add a new frame to the lineage tree, updating families and tracking new and divided cells. @@ -641,12 +701,14 @@ def add_new_frame(self, frame_i, mother_daughters, IDs_prev, added_lineage_tree = [] else: posData = self.gui.data[self.gui.pos_i] - cca_df = posData.allData_li[frame_i]['acdc_df'] + cca_df = posData.allData_li[frame_i]["acdc_df"] daughter_dict = {} daughter_set = set() for mother, daughters in mother_daughters: - daughter_IDs = IoA_index_daughter_to_ID(daughters, assignments, IDs_curr_untracked) + daughter_IDs = IoA_index_daughter_to_ID( + daughters, assignments, IDs_curr_untracked + ) daughter_dict[mother] = daughter_IDs daughter_set.update(set(daughter_IDs)) @@ -654,19 +716,21 @@ def add_new_frame(self, frame_i, mother_daughters, IDs_prev, if not self.gui_mode: for ID in new_unknown_IDs: - added_lineage_tree.append((frame_i, ID, -1, 1, ID, [-1] * (self.max_daughter-1))) + added_lineage_tree.append( + (frame_i, ID, -1, 1, ID, [-1] * (self.max_daughter - 1)) + ) else: relevant_rows = cca_df.index.isin(new_unknown_IDs) - cca_df.loc[relevant_rows, 'generation_num_tree'] = 1 - cca_df.loc[relevant_rows, 'parent_ID_tree'] = -1 - cca_df.loc[relevant_rows, 'emerg_frame_i'] = frame_i - cca_df.loc[relevant_rows, 'division_frame_i'] = frame_i - cca_df.loc[relevant_rows, 'sister_ID_tree'] = -1 - cca_df.loc[relevant_rows, 'root_ID_tree'] = cca_df.index[relevant_rows] + cca_df.loc[relevant_rows, "generation_num_tree"] = 1 + cca_df.loc[relevant_rows, "parent_ID_tree"] = -1 + cca_df.loc[relevant_rows, "emerg_frame_i"] = frame_i + cca_df.loc[relevant_rows, "division_frame_i"] = frame_i + cca_df.loc[relevant_rows, "sister_ID_tree"] = -1 + cca_df.loc[relevant_rows, "root_ID_tree"] = cca_df.index[relevant_rows] cca_df.loc[relevant_rows, self._get_extra_daughter_cols()] = -1 - cca_df.loc[relevant_rows, 'is_history_known'] = False - + cca_df.loc[relevant_rows, "is_history_known"] = False + for mother, _ in mother_daughters: mother_ID = IDs_prev[mother] daughter_IDs = daughter_dict[mother] @@ -675,19 +739,37 @@ def add_new_frame(self, frame_i, mother_daughters, IDs_prev, for daughter_ID in daughter_IDs: daughter_IDs_copy = daughter_IDs.copy() daughter_IDs_copy.remove(daughter_ID) - daughter_IDs_copy = daughter_IDs_copy + [-1] * (self.max_daughters_added - len(daughter_IDs_copy)) + daughter_IDs_copy = daughter_IDs_copy + [-1] * ( + self.max_daughters_added - len(daughter_IDs_copy) + ) if not self.gui_mode: - added_lineage_tree.append((frame_i, daughter_ID, mother_ID, mother_row['generation_num_tree'] + 1, - mother_ID, daughter_IDs_copy)) + added_lineage_tree.append( + ( + frame_i, + daughter_ID, + mother_ID, + mother_row["generation_num_tree"] + 1, + mother_ID, + daughter_IDs_copy, + ) + ) else: - cca_df.loc[daughter_ID, 'generation_num_tree'] = mother_row['generation_num_tree'] + 1 - cca_df.loc[daughter_ID, 'parent_ID_tree'] = mother_ID - cca_df.loc[daughter_ID, 'emerg_frame_i'] = frame_i - cca_df.loc[daughter_ID, 'division_frame_i'] = frame_i - cca_df.loc[daughter_ID, 'sister_ID_tree'] = daughter_IDs_copy[0] # here we dont need to consider the possibility that the sister is already gone, as its the first frame where the daughters appeared - cca_df.loc[daughter_ID, 'root_ID_tree'] = mother_row['root_ID_tree'] - cca_df.loc[daughter_ID, 'is_history_known'] = True - for i, extra_col in enumerate(self._get_extra_daughter_cols(num_daughters=len(daughter_IDs_copy))): + cca_df.loc[daughter_ID, "generation_num_tree"] = ( + mother_row["generation_num_tree"] + 1 + ) + cca_df.loc[daughter_ID, "parent_ID_tree"] = mother_ID + cca_df.loc[daughter_ID, "emerg_frame_i"] = frame_i + cca_df.loc[daughter_ID, "division_frame_i"] = frame_i + cca_df.loc[daughter_ID, "sister_ID_tree"] = daughter_IDs_copy[ + 0 + ] # here we dont need to consider the possibility that the sister is already gone, as its the first frame where the daughters appeared + cca_df.loc[daughter_ID, "root_ID_tree"] = mother_row["root_ID_tree"] + cca_df.loc[daughter_ID, "is_history_known"] = True + for i, extra_col in enumerate( + self._get_extra_daughter_cols( + num_daughters=len(daughter_IDs_copy) + ) + ): cca_df.loc[daughter_ID, extra_col] = daughter_IDs_copy[i] # copy over old lineage info @@ -701,18 +783,27 @@ def add_new_frame(self, frame_i, mother_daughters, IDs_prev, except IndexError: len_lineage_list = len(self.lineage_list) if frame_i >= len_lineage_list: - self.lineage_list.extend([pd.DataFrame()] * (frame_i + 1 - len_lineage_list)) + self.lineage_list.extend( + [pd.DataFrame()] * (frame_i + 1 - len_lineage_list) + ) self.lineage_list[frame_i] = cca_df else: - prev_df = self._get_df_from_frame_i(frame_i-1) + prev_df = self._get_df_from_frame_i(frame_i - 1) same_IDs = prev_df.index.intersection(cca_df.index) - columns = ['generation_num_tree', 'parent_ID_tree', - 'emerg_frame_i', 'division_frame_i', - 'root_ID_tree', 'sister_ID_tree', 'is_history_known'] - cca_df.loc[same_IDs, columns] = prev_df.loc[same_IDs, - columns].values - cca_df.loc[same_IDs, self._get_extra_daughter_cols()] = prev_df.loc[same_IDs, self._get_extra_daughter_cols()].values + columns = [ + "generation_num_tree", + "parent_ID_tree", + "emerg_frame_i", + "division_frame_i", + "root_ID_tree", + "sister_ID_tree", + "is_history_known", + ] + cca_df.loc[same_IDs, columns] = prev_df.loc[same_IDs, columns].values + cca_df.loc[same_IDs, self._get_extra_daughter_cols()] = prev_df.loc[ + same_IDs, self._get_extra_daughter_cols() + ].values self.frames_for_dfs.add(frame_i) def real_time(self, frame_i, lab, prev_lab, rp=None, prev_rp=None): @@ -735,13 +826,16 @@ def real_time(self, frame_i, lab, prev_lab, rp=None, prev_rp=None): if prev_rp is None: prev_rp = regionprops(prev_lab) - IoA_matrix, self.IDs_curr_untracked, self.IDs_prev = calc_Io_matrix(lab, prev_lab, rp, prev_rp) - - _, self.mother_daughters = mother_daughter_assign(IoA_matrix, - IoA_thresh_daughter=self.IoA_thresh_daughter, - min_daughter=self.min_daughter, - max_daughter=self.max_daughter - ) + IoA_matrix, self.IDs_curr_untracked, self.IDs_prev = calc_Io_matrix( + lab, prev_lab, rp, prev_rp + ) + + _, self.mother_daughters = mother_daughter_assign( + IoA_matrix, + IoA_thresh_daughter=self.IoA_thresh_daughter, + min_daughter=self.min_daughter, + max_daughter=self.max_daughter, + ) # filter mothers which are actually tracked/present (could be after user correction in the GUI) filtered_mother_daughters = [] for mother, daughters in self.mother_daughters: @@ -749,12 +843,20 @@ def real_time(self, frame_i, lab, prev_lab, rp=None, prev_rp=None): if mother_ID not in self._get_df_from_frame_i(frame_i).index: filtered_mother_daughters.append((mother, daughters)) self.mother_daughters = filtered_mother_daughters - + curr_IDs = set(self.IDs_curr_untracked) prev_IDs = {obj.label for obj in prev_rp} new_IDs = curr_IDs - prev_IDs self.frames_for_dfs.add(frame_i) - self.add_new_frame(frame_i, self.mother_daughters, self.IDs_prev, self.IDs_curr_untracked, None, curr_IDs, new_IDs) + self.add_new_frame( + frame_i, + self.mother_daughters, + self.IDs_prev, + self.IDs_curr_untracked, + None, + curr_IDs, + new_IDs, + ) def update_df_li_locally(self, df, frame_i): """ @@ -772,79 +874,98 @@ def update_df_li_locally(self, df, frame_i): # we first need to correct generation_num_tree, root_ID_tree, sister_ID_tree if not self.gui_mode: - df = checked_reset_index(df) + df = checked_reset_index(df) corrected_rows = [] for _, Cell_info in df.iterrows(): - if Cell_info['parent_ID_tree'] == -1: - Cell_info['generation_num_tree'] = 1 - Cell_info['root_ID_tree'] = Cell_info['Cell_ID'] - Cell_info['sister_ID_tree'] = [-1] - Cell_info['is_history_known'] = False + if Cell_info["parent_ID_tree"] == -1: + Cell_info["generation_num_tree"] = 1 + Cell_info["root_ID_tree"] = Cell_info["Cell_ID"] + Cell_info["sister_ID_tree"] = [-1] + Cell_info["is_history_known"] = False corrected_rows.append(Cell_info) continue - Cell_info['is_history_known'] = True + Cell_info["is_history_known"] = True - parent_ID = Cell_info['parent_ID_tree'] + parent_ID = Cell_info["parent_ID_tree"] parent_line = self._get_row_from_ID(parent_ID) - Cell_info['generation_num_tree'] = int(parent_line['generation_num_tree']) + 1 - Cell_info['root_ID_tree'] = parent_line['root_ID_tree'] + Cell_info["generation_num_tree"] = ( + int(parent_line["generation_num_tree"]) + 1 + ) + Cell_info["root_ID_tree"] = parent_line["root_ID_tree"] - first_frame_i = self._get_first_frame_i_for_ID(Cell_info['Cell_ID']) + first_frame_i = self._get_first_frame_i_for_ID(Cell_info["Cell_ID"]) df_sisters = self._get_df_from_frame_i(first_frame_i) - sisters = set(df_sisters.loc[df_sisters['parent_ID_tree'] == parent_ID, 'Cell_ID']) - sisters.discard(Cell_info['Cell_ID']) - Cell_info['sister_ID_tree'] = list(sisters) if sisters else [-1] + sisters = set( + df_sisters.loc[df_sisters["parent_ID_tree"] == parent_ID, "Cell_ID"] + ) + sisters.discard(Cell_info["Cell_ID"]) + Cell_info["sister_ID_tree"] = list(sisters) if sisters else [-1] corrected_rows.append(Cell_info) - corrected_df = pd.DataFrame(corrected_rows).set_index('Cell_ID') + corrected_df = pd.DataFrame(corrected_rows).set_index("Cell_ID") self.lineage_list[frame_i] = corrected_df else: posData = self.gui.data[self.gui.pos_i] - df_data = posData.allData_li[frame_i]['acdc_df'] + df_data = posData.allData_li[frame_i]["acdc_df"] df = checked_reset_index_Cell_ID(df) if set(df.index) != set(df_data.index): - raise ValueError('In GUI mode, the DataFrame index must be Cell_ID for lineage updates to work.') - + raise ValueError( + "In GUI mode, the DataFrame index must be Cell_ID for lineage updates to work." + ) + for ID, Cell_info in df.iterrows(): cell_row = df_data.loc[ID] - if Cell_info['parent_ID_tree'] == -1: - df.loc[ID, ['generation_num_tree', 'root_ID_tree', - 'sister_ID_tree', 'is_history_known', - 'parent_ID_tree']] = [1, ID, -1, False, -1] + if Cell_info["parent_ID_tree"] == -1: + df.loc[ + ID, + [ + "generation_num_tree", + "root_ID_tree", + "sister_ID_tree", + "is_history_known", + "parent_ID_tree", + ], + ] = [1, ID, -1, False, -1] df.loc[ID, self._get_extra_daughter_cols()] = -1 continue - cell_row['is_history_known'] = True + cell_row["is_history_known"] = True - parent_ID = Cell_info['parent_ID_tree'] + parent_ID = Cell_info["parent_ID_tree"] parent_line = self._get_row_from_ID(parent_ID) - cell_row['generation_num_tree'] = int(parent_line['generation_num_tree']) + 1 - cell_row['root_ID_tree'] = parent_line['root_ID_tree'] + cell_row["generation_num_tree"] = ( + int(parent_line["generation_num_tree"]) + 1 + ) + cell_row["root_ID_tree"] = parent_line["root_ID_tree"] first_frame_i = self._get_first_frame_i_for_ID(ID) df_sisters = self._get_df_from_frame_i(first_frame_i) - - sisters = set(df_sisters.loc[df_sisters['parent_ID_tree'] == parent_ID].index) + + sisters = set( + df_sisters.loc[df_sisters["parent_ID_tree"] == parent_ID].index + ) sisters.discard(ID) sisters = list(sisters) - cell_row['sister_ID_tree'] = sisters[0] if sisters else -1 + cell_row["sister_ID_tree"] = sisters[0] if sisters else -1 sisters = sisters + [-1] * (self.max_daughters_added - len(sisters)) cols = self._get_extra_daughter_cols(num_daughters=len(sisters)) for col in cols: if col not in cell_row.index: cell_row[col] = -1 - cell_row[self._get_extra_daughter_cols(num_daughters=len(sisters))] = sisters - + cell_row[self._get_extra_daughter_cols(num_daughters=len(sisters))] = ( + sisters + ) df_data.loc[ID] = cell_row + # This will probably be made obsolete by the gui_mode version - # def insert_lineage_df(self, lineage_df, frame_i, update_fams=True, - # consider_children=True, raw_input=False, propagate=True, + # def insert_lineage_df(self, lineage_df, frame_i, update_fams=True, + # consider_children=True, raw_input=False, propagate=True, # relevant_cells=None): # """ # Insert or replace a lineage DataFrame at a given frame index, optionally updating families and propagating changes. @@ -893,7 +1014,6 @@ def update_df_li_locally(self, df, frame_i): # else: # self.lineage_list = out - # elif frame_i > len_lineage_list: # printl(f'WARNING: Frame_i {frame_i} was inserted. The lineage list was only {len(self.lineage_list)} frames long, so the last known lineage tree was copy pasted up to frame_i {frame_i}') @@ -916,9 +1036,14 @@ def update_df_li_locally(self, df, frame_i): # self.lineage_list, self.families = out # else: # self.lineage_list = out - - def _update_consistency(self, fixed_frame_i=None, fixed_df=None, - Cell_IDs_fixed=None, consider_children=True): + + def _update_consistency( + self, + fixed_frame_i=None, + fixed_df=None, + Cell_IDs_fixed=None, + consider_children=True, + ): """ Updates the consistency of lineage information across a list of DataFrames representing cell tracking over time. @@ -946,66 +1071,115 @@ def _update_consistency(self, fixed_frame_i=None, fixed_df=None, - The function maintains a lookup dictionary and a list of fixed DataFrames to efficiently propagate updates. - Sister IDs are stored as sets, excluding the cell's own ID; if a cell has no sisters, the value is set to {-1}. """ - columns_to_replace = ['generation_num_tree', - 'root_ID_tree', - 'sister_ID_tree', - 'parent_ID_tree'] - + columns_to_replace = [ + "generation_num_tree", + "root_ID_tree", + "sister_ID_tree", + "parent_ID_tree", + ] + if fixed_df is not None: fixed_df = checked_reset_index_Cell_ID(fixed_df) elif fixed_frame_i is not None: if self.gui_mode: posData = self.gui.data[self.gui.pos_i] - fixed_df = checked_reset_index_Cell_ID(posData.allData_li[fixed_frame_i]['acdc_df']) + fixed_df = checked_reset_index_Cell_ID( + posData.allData_li[fixed_frame_i]["acdc_df"] + ) else: fixed_df = checked_reset_index_Cell_ID(self.lineage_list[fixed_frame_i]) else: - raise ValueError('Either fixed_frame_i or fixed_df must be provided.') + raise ValueError("Either fixed_frame_i or fixed_df must be provided.") - if Cell_IDs_fixed is not None: # if we have a list of Cell_IDs to consider + if Cell_IDs_fixed is not None: # if we have a list of Cell_IDs to consider fixed_df = fixed_df[fixed_df.index.isin(Cell_IDs_fixed)] - else: + else: Cell_IDs_fixed = fixed_df.index fixed_dfs = [fixed_df] fixed_dfs_lookup = {fixed_df.index[i]: 0 for i in range(len(fixed_df))} - Cell_IDs_fixed = set(Cell_IDs_fixed) # we convert to a set for faster lookups - - df_li = self.lineage_list if not self.gui_mode else [posData.allData_li[i]['acdc_df'] for i in range(len(posData.allData_li))] + Cell_IDs_fixed = set(Cell_IDs_fixed) # we convert to a set for faster lookups + + df_li = ( + self.lineage_list + if not self.gui_mode + else [ + posData.allData_li[i]["acdc_df"] for i in range(len(posData.allData_li)) + ] + ) for frame_df in df_li: - if 'generation_num_tree' not in frame_df.columns or (not frame_df['generation_num_tree'].notna().any()): + if "generation_num_tree" not in frame_df.columns or ( + not frame_df["generation_num_tree"].notna().any() + ): continue frame_df = checked_reset_index_Cell_ID(frame_df) if consider_children: - children = frame_df[frame_df['parent_ID_tree'].isin(Cell_IDs_fixed)] + children = frame_df[frame_df["parent_ID_tree"].isin(Cell_IDs_fixed)] if not children.empty: - for parent_ID, children in children.groupby('parent_ID_tree'): - parent_cell_loc = fixed_dfs_lookup[parent_ID] # we get the parent cell from the lookup dictionary + for parent_ID, children in children.groupby("parent_ID_tree"): + parent_cell_loc = fixed_dfs_lookup[ + parent_ID + ] # we get the parent cell from the lookup dictionary parent_line = fixed_dfs[parent_cell_loc].loc[parent_ID] - children['root_ID_tree'] = parent_line['root_ID_tree'] - children['generation_num_tree'] = parent_line['generation_num_tree'] + 1 - first_frame_i = self._get_first_frame_i_for_ID(children.index[0]) + children["root_ID_tree"] = parent_line["root_ID_tree"] + children["generation_num_tree"] = ( + parent_line["generation_num_tree"] + 1 + ) + first_frame_i = self._get_first_frame_i_for_ID( + children.index[0] + ) df_sisters = self._get_df_from_frame_i(first_frame_i) if self.gui_mode: - sisters = set(df_sisters.loc[df_sisters['parent_ID_tree'] == parent_ID].index) + sisters = set( + df_sisters.loc[ + df_sisters["parent_ID_tree"] == parent_ID + ].index + ) for child in children.index: - child_sisters = [s for s in sisters if s != child] if len(sisters) > 1 else [-1] - child_sisters = child_sisters + [-1] * (self.max_daughters_added - len(child_sisters)) - children.loc[child, 'sister_ID_tree'] = child_sisters[0] if child_sisters else -1 - children.loc[child, self._get_extra_daughter_cols(num_daughters=len(child_sisters))] = child_sisters + child_sisters = ( + [s for s in sisters if s != child] + if len(sisters) > 1 + else [-1] + ) + child_sisters = child_sisters + [-1] * ( + self.max_daughters_added - len(child_sisters) + ) + children.loc[child, "sister_ID_tree"] = ( + child_sisters[0] if child_sisters else -1 + ) + children.loc[ + child, + self._get_extra_daughter_cols( + num_daughters=len(child_sisters) + ), + ] = child_sisters else: - sisters = set(df_sisters.loc[df_sisters['parent_ID_tree'] == parent_ID, 'Cell_ID']) - children['sister_ID_tree'] = [ - [s for s in sisters if s != cell_id] if len(sisters) > 1 else [-1] + sisters = set( + df_sisters.loc[ + df_sisters["parent_ID_tree"] == parent_ID, "Cell_ID" + ] + ) + children["sister_ID_tree"] = [ + [s for s in sisters if s != cell_id] + if len(sisters) > 1 + else [-1] for cell_id in children.index ] - Cell_IDs_fixed = Cell_IDs_fixed.union(children.index) # we add the children IDs to the set of Cell_IDs_fixed - fixed_dfs.append(children) # we append the children to the fixed_dfs list - indx = len(fixed_dfs) - 1 # we get the index of the children in the fixed_dfs list - fixed_dfs_lookup.update({children.index[i]: indx for i in range(len(children))}) # we update the lookup dictionary with the children + Cell_IDs_fixed = Cell_IDs_fixed.union( + children.index + ) # we add the children IDs to the set of Cell_IDs_fixed + fixed_dfs.append( + children + ) # we append the children to the fixed_dfs list + indx = ( + len(fixed_dfs) - 1 + ) # we get the index of the children in the fixed_dfs list + fixed_dfs_lookup.update( + {children.index[i]: indx for i in range(len(children))} + ) # we update the lookup dictionary with the children relevant_cells_mask = frame_df.index.isin(Cell_IDs_fixed) if not relevant_cells_mask.any(): @@ -1021,7 +1195,9 @@ def _update_consistency(self, fixed_frame_i=None, fixed_df=None, # Find the intersection of indices common_idx = frame_df.index.intersection(fixed_df.index) if not common_idx.empty: - frame_df.loc[common_idx, columns_to_replace] = fixed_df.loc[common_idx, columns_to_replace] + frame_df.loc[common_idx, columns_to_replace] = fixed_df.loc[ + common_idx, columns_to_replace + ] def propagate(self, frame_i, relevant_cells=None): """ @@ -1036,12 +1212,13 @@ def propagate(self, frame_i, relevant_cells=None): """ if self.gui_mode: posData = self.gui.data[self.gui.pos_i] - lineage_df = posData.allData_li[frame_i]['acdc_df'] + lineage_df = posData.allData_li[frame_i]["acdc_df"] else: lineage_df = self.lineage_list[frame_i] self.update_df_li_locally(lineage_df, frame_i) - self._update_consistency(fixed_frame_i=frame_i, - consider_children=True, Cell_IDs_fixed=relevant_cells) + self._update_consistency( + fixed_frame_i=frame_i, consider_children=True, Cell_IDs_fixed=relevant_cells + ) # This will probably be made obsolete by the gui_mode version # def load_lineage_df_list(self, df_li): @@ -1066,11 +1243,11 @@ def propagate(self, frame_i, relevant_cells=None): # for i, df in enumerate(df_li): # if df is None: # continue - + # if 'generation_num_tree' not in df.columns: # continue - # mask = (df['generation_num_tree'].isnull() | + # mask = (df['generation_num_tree'].isnull() | # df["generation_num_tree"].isna()) # if mask.any() or df["generation_num_tree"].empty: @@ -1083,7 +1260,7 @@ def propagate(self, frame_i, relevant_cells=None): # self.frames_for_dfs.add(i) # df_li_new.append(df) - # df_filter = df.index.isin(added_IDs) + # df_filter = df.index.isin(added_IDs) # for root_ID, group in df[df_filter].groupby('root_ID_tree'): # if root_ID not in families_root_IDs: # family = list(zip(group.index, group['generation_num_tree'])) @@ -1093,9 +1270,9 @@ def propagate(self, frame_i, relevant_cells=None): # # If the root_ID is already in families, we just update the family with the new cells # family_index = families_root_IDs.index(root_ID) # families[family_index].extend(zip(group.index, group['generation_num_tree'])) - + # added_IDs.update(group.index) - + # if df_li_new: # self.lineage_list = df_li_new @@ -1113,7 +1290,7 @@ def export_df(self, frame_i): df = self.lineage_list[frame_i].copy() if df.empty: - print(f'Warning: No dataframe for frame {frame_i} found.') + print(f"Warning: No dataframe for frame {frame_i} found.") df = reorg_sister_cells_for_export(df) @@ -1128,7 +1305,7 @@ def export_df(self, frame_i): df = df.drop(columns="frame_i") return df - + def export_lin_tree_info(self, frame_i): """ Return information about new, orphan, and lost cells between two consecutive frames. @@ -1144,17 +1321,17 @@ def export_lin_tree_info(self, frame_i): """ if frame_i == 0: return [], [], [] - + if not self.gui_mode: df_curr = self.lineage_list[frame_i] df_curr = checked_reset_index_Cell_ID(df_curr) - df_prev = self.lineage_list[frame_i-1] + df_prev = self.lineage_list[frame_i - 1] df_prev = checked_reset_index_Cell_ID(df_prev) - + else: posData = self.gui.data[self.gui.pos_i] - df_curr = posData.allData_li[frame_i]['acdc_df'] - df_prev = posData.allData_li[frame_i-1]['acdc_df'] + df_curr = posData.allData_li[frame_i]["acdc_df"] + df_prev = posData.allData_li[frame_i - 1]["acdc_df"] new_cells = set(df_curr.index) - set(df_prev.index) lost_cells = set(df_prev.index) - set(df_curr.index) @@ -1166,9 +1343,9 @@ def export_lin_tree_info(self, frame_i): for cell in new_cells: cell_row = df_curr.loc[cell] try: - mother = cell_row['parent_ID_tree'] + mother = cell_row["parent_ID_tree"] except KeyError: - mother = -1 # check for nan mother + mother = -1 # check for nan mother if mother == -1 or pd.isna(mother): orphan_cells.append(cell) else: @@ -1179,11 +1356,13 @@ def export_lin_tree_info(self, frame_i): lost_cells = [int(cell) for cell in lost_cells] cells_with_parent.sort(key=lambda x: x[1]) # Sort by mother ID - cells_with_parent = [(int(cell), int(mother)) for cell, mother in cells_with_parent] + cells_with_parent = [ + (int(cell), int(mother)) for cell, mother in cells_with_parent + ] orphan_cells = [int(cell) for cell in orphan_cells] return cells_with_parent, orphan_cells, lost_cells - + class tracker: """ @@ -1199,23 +1378,25 @@ class tracker: - updateGuiProgressBar(): Updates the GUI progress bar. (Used for GUI communication) - save_output(): Signals to the rest of the programme that the lineage tree should be saved. (Used for module 2) """ + def __init__(self): """ Initializes the CellACDC_normal_division_tracker object. """ pass - def track(self, - segm_video, - IoA_thresh:float = 0.8, - IoA_thresh_daughter:float = 0.25, - IoA_thresh_aggressive:float = 0.5, - min_daughter:int = 2, - max_daughter:int = 2, - record_lineage:bool = True, - return_tracked_lost_centroids:bool = True, - signals = None, - ): + def track( + self, + segm_video, + IoA_thresh: float = 0.8, + IoA_thresh_daughter: float = 0.25, + IoA_thresh_aggressive: float = 0.5, + min_daughter: int = 2, + max_daughter: int = 2, + record_lineage: bool = True, + return_tracked_lost_centroids: bool = True, + signals=None, + ): """ Tracks the segmented video frames and returns the tracked video. (Used for module 2) @@ -1228,15 +1409,17 @@ def track(self, - min_daughter (int, optional): Minimum number of daughter cells. Used for determining if a cell has divided. Defaults to 2. - max_daughter (int, optional): Maximum number of daughter cells. Used for determining if a cell has divided. Defaults to 2. - record_lineage (bool, optional): Flag to record and save lineage. Defaults to True. - + Returns: - list: Tracked video frames. """ if not record_lineage and return_tracked_lost_centroids: - print('return_tracked_lost_centroids is set to True if record_lineage is True.') + print( + "return_tracked_lost_centroids is set to True if record_lineage is True." + ) record_lineage = True - - pbar = tqdm(total=len(segm_video), desc='Tracking', ncols=100) + + pbar = tqdm(total=len(segm_video), desc="Tracking", ncols=100) if return_tracked_lost_centroids: self.tracked_lost_centroids = { @@ -1246,14 +1429,19 @@ def track(self, for frame_i, lab in enumerate(segm_video): if frame_i == 0: tracker = normal_division_tracker( - segm_video, IoA_thresh_daughter, min_daughter, - max_daughter, IoA_thresh, IoA_thresh_aggressive + segm_video, + IoA_thresh_daughter, + min_daughter, + max_daughter, + IoA_thresh, + IoA_thresh_aggressive, ) if record_lineage or return_tracked_lost_centroids: tree = normal_division_lineage_tree( - lab=lab, max_daughter=max_daughter, - min_daughter=min_daughter, - IoA_thresh_daughter=IoA_thresh_daughter + lab=lab, + max_daughter=max_daughter, + min_daughter=min_daughter, + IoA_thresh_daughter=IoA_thresh_daughter, ) pbar.update() rp = regionprops(segm_video[0]) @@ -1275,13 +1463,18 @@ def track(self, new_IDs = curr_IDs - prev_IDs if record_lineage or return_tracked_lost_centroids: tree.add_new_frame( - frame_i, mother_daughters, IDs_prev, IDs_curr_untracked, - assignments, curr_IDs, new_IDs + frame_i, + mother_daughters, + IDs_prev, + IDs_curr_untracked, + assignments, + curr_IDs, + new_IDs, ) tracked_lost_centroids_loc = [] for mother, _ in mother_daughters: mother_ID = IDs_prev[mother] - + found = False for obj in prev_rp: if obj.label == mother_ID: @@ -1291,12 +1484,15 @@ def track(self, if not found: labels = [obj.label for obj in rp] printl(mother, mother_ID, IDs_curr_untracked, labels) - raise ValueError('Something went wrong with the tracked lost centroids.') - + raise ValueError( + "Something went wrong with the tracked lost centroids." + ) if len(mother_daughters) != len(tracked_lost_centroids_loc): - raise ValueError('Something went wrong with the tracked lost centroids.') - + raise ValueError( + "Something went wrong with the tracked lost centroids." + ) + self.tracked_lost_centroids[frame_i] = tracked_lost_centroids_loc prev_IDs = curr_IDs.copy() @@ -1313,22 +1509,22 @@ def track(self, self.cca_dfs_auto = cca_li # here we would also save make sure to save self.tracked_lost_centroids, but since we already assigned it correctly from the get go we dont need to do that - tracked_video = tracker.tracked_video pbar.close() return tracked_video - def track_frame(self, - previous_frame_labels, - current_frame_labels, - IDs : NotGUIParam =None, - IoA_thresh: float = 0.8, - IoA_thresh_daughter:float = 0.25, - IoA_thresh_aggressive:float = 0.5, - min_daughter:int = 2, - max_daughter:int = 2, - unique_ID: NotGUIParam =None, - ): + def track_frame( + self, + previous_frame_labels, + current_frame_labels, + IDs: NotGUIParam = None, + IoA_thresh: float = 0.8, + IoA_thresh_daughter: float = 0.25, + IoA_thresh_aggressive: float = 0.5, + min_daughter: int = 2, + max_daughter: int = 2, + unique_ID: NotGUIParam = None, + ): """ Tracks cell division in a single frame. (This is used for real time tracking in the GUI) @@ -1351,7 +1547,14 @@ def track_frame(self, return current_frame_labels segm_video = [previous_frame_labels, current_frame_labels] - tracker = normal_division_tracker(segm_video, IoA_thresh_daughter, min_daughter, max_daughter, IoA_thresh, IoA_thresh_aggressive) + tracker = normal_division_tracker( + segm_video, + IoA_thresh_daughter, + min_daughter, + max_daughter, + IoA_thresh, + IoA_thresh_aggressive, + ) tracker.track_frame(1, IDs=IDs, unique_ID=unique_ID) tracked_video = tracker.tracked_video @@ -1374,7 +1577,7 @@ def updateGuiProgressBar(self, signals): if signals is None: return - if hasattr(signals, 'innerPbar_available'): + if hasattr(signals, "innerPbar_available"): if signals.innerPbar_available: # Use inner pbar of the GUI widget (top pbar is for positions) signals.innerProgressBar.emit(1) @@ -1392,4 +1595,4 @@ def save_output(self): Returns: - None """ - pass \ No newline at end of file + pass diff --git a/cellacdc/trackers/DeepSea/DeepSea_tracker.py b/cellacdc/trackers/DeepSea/DeepSea_tracker.py index dd3815bb7..9d35243c6 100644 --- a/cellacdc/trackers/DeepSea/DeepSea_tracker.py +++ b/cellacdc/trackers/DeepSea/DeepSea_tracker.py @@ -14,10 +14,10 @@ from deepsea.model import DeepSeaTracker from deepsea.utils import track_cells -from cellacdc import myutils, printl -from cellacdc.models.DeepSea import _init_model, _resize_img -from cellacdc.models.DeepSea import image_size as segm_image_size -from cellacdc.models.DeepSea import _get_segm_transforms +from cellacdc import utils, printl +from cellacdc.segmenters.DeepSea import _init_model, _resize_img +from cellacdc.segmenters.DeepSea import image_size as segm_image_size +from cellacdc.segmenters.DeepSea import _get_segm_transforms from cellacdc.core import get_labels_to_IDs_mapper from . import _get_tracker_transforms @@ -29,25 +29,28 @@ torch.cuda.manual_seed(SEED) torch.backends.cudnn.deterministic = True + class tracker: def __init__(self, gpu=False): torch_device, checkpoint, model = _init_model( - 'tracker.pth', DeepSeaTracker, gpu=gpu + "tracker.pth", DeepSeaTracker, gpu=gpu ) self.torch_device = torch_device self._transforms = _get_tracker_transforms() self._segm_transforms = _get_segm_transforms() self._checkpoint = checkpoint self.model = model - + def _resize_lab(self, lab, output_shape, rp): _lab_obj_to_resize = np.zeros(lab.shape, dtype=np.float16) lab_resized = np.zeros(output_shape, dtype=np.uint32) for obj in rp: _lab_obj_to_resize[obj.slice][obj.image] = 1.0 _lab_obj_resized = resize( - _lab_obj_to_resize, output_shape, anti_aliasing=True, - preserve_range=True + _lab_obj_to_resize, + output_shape, + anti_aliasing=True, + preserve_range=True, ).round() lab_resized[_lab_obj_resized == 1.0] = obj.label _lab_obj_to_resize[:] = 0.0 @@ -61,22 +64,19 @@ def _relabel_sequential(self, segm_video): return relabelled_video def track( - self, segm_video, image, min_size=10, annotate_lineage_tree=True, - signals=None - ): + self, segm_video, image, min_size=10, annotate_lineage_tree=True, signals=None + ): self.signals = signals segm_video = self._relabel_sequential(segm_video) labels_list = [] resize_img_list = [] pbar = tqdm(total=len(segm_video), ncols=100) if signals is not None: - signals.progress.emit('Resizing objects...') + signals.progress.emit("Resizing objects...") for img, lab in zip(image, segm_video): img = (255 * ((img - img.min()) / img.ptp())).astype(np.uint8) rp = regionprops(lab) - resized_img = _resize_img( - img, self.torch_device, self._segm_transforms - ) + resized_img = _resize_img(img, self.torch_device, self._segm_transforms) resized_lab = self._resize_lab( lab, output_shape=tuple(segm_image_size), rp=rp ) @@ -85,55 +85,67 @@ def track( pbar.update() pbar.close() if signals is not None: - signals.progress.emit('Tracking...') + signals.progress.emit("Tracking...") result = track_cells( - labels_list, resize_img_list, self.model, self.torch_device, - transforms=self._transforms, min_size=min_size + labels_list, + resize_img_list, + self.model, + self.torch_device, + transforms=self._transforms, + min_size=min_size, ) tracked_labels, tracked_centroids, tracked_imgs = result - + labels_to_IDs_mapper = self._get_labels_to_IDs_mapper(tracked_labels) - + if annotate_lineage_tree: self.cca_dfs = self._annotate_lineage_tree( tracked_labels, labels_to_IDs_mapper ) tracked_video = self._replace_tracked_IDs( - labels_list, tracked_labels, tracked_centroids, - labels_to_IDs_mapper, segm_video + labels_list, + tracked_labels, + tracked_centroids, + labels_to_IDs_mapper, + segm_video, ) return tracked_video def _annotate_lineage_tree(self, tracked_labels, labels_to_IDs_mapper): if self.signals is not None: - self.signals.progress.emit('Annotating lineage trees...') + self.signals.progress.emit("Annotating lineage trees...") from cellacdc.core import annotate_lineage_tree_from_labels + cca_dfs = annotate_lineage_tree_from_labels( tracked_labels, labels_to_IDs_mapper - ) + ) return cca_dfs def _get_labels_to_IDs_mapper(self, tracked_labels): if self.signals is not None: - self.signals.progress.emit('Mapping labels to IDs...') + self.signals.progress.emit("Mapping labels to IDs...") labels_to_IDs_mapper = get_labels_to_IDs_mapper(tracked_labels) return labels_to_IDs_mapper def _replace_tracked_IDs( - self, resized_labels_list, tracked_labels, tracked_centroids, - labels_to_IDs_mapper, segm_video - ): + self, + resized_labels_list, + tracked_labels, + tracked_centroids, + labels_to_IDs_mapper, + segm_video, + ): if self.signals is not None: - self.signals.progress.emit('Applying tracking information...') - + self.signals.progress.emit("Applying tracking information...") + _zip = zip(tracked_labels, tracked_centroids) IDs_prev = [] tracked_video = np.zeros_like(segm_video) for frame_i, track_info_frame in enumerate(_zip): tracked_frame_labels, tracked_frame_centroids = track_info_frame tracked_frame_IDs = [ - int(labels_to_IDs_mapper[label].split('_')[0]) + int(labels_to_IDs_mapper[label].split("_")[0]) for label in tracked_frame_labels ] lab = resized_labels_list[frame_i] @@ -141,16 +153,19 @@ def _replace_tracked_IDs( untracked_lab = segm_video[frame_i] rp = regionprops(lab) IDs_curr_untracked = [obj.label for obj in rp] - uniqueID = max( - max(IDs_prev, default=0), - max(IDs_curr_untracked, default=0), - max(tracked_frame_IDs, default=0) - ) + 1 + uniqueID = ( + max( + max(IDs_prev, default=0), + max(IDs_curr_untracked, default=0), + max(tracked_frame_IDs, default=0), + ) + + 1 + ) IDs_to_replace = { - lab[tuple(centr)]:idx + lab[tuple(centr)]: idx for idx, centr in enumerate(tracked_frame_centroids) } - IDs_prev = [] + IDs_prev = [] for obj in rp: idx_ID_to_replace = IDs_to_replace.get(obj.label) if idx_ID_to_replace is None: @@ -160,22 +175,20 @@ def _replace_tracked_IDs( newID = tracked_frame_IDs[idx_ID_to_replace] tracked_lab[untracked_lab == obj.label] = newID IDs_prev.append(newID) - + tracked_video[frame_i] = tracked_lab self.updateGuiProgressBar(self.signals) return tracked_video - + def updateGuiProgressBar(self, signals): if signals is None: return - - if hasattr(signals, 'innerPbar_available'): + + if hasattr(signals, "innerPbar_available"): if signals.innerPbar_available: # Use inner pbar of the GUI widget (top pbar is for positions) signals.innerProgressBar.emit(1) return signals.progressBar.emit(1) - - \ No newline at end of file diff --git a/cellacdc/trackers/DeepSea/__init__.py b/cellacdc/trackers/DeepSea/__init__.py index 477a8e19b..5e9d882c8 100644 --- a/cellacdc/trackers/DeepSea/__init__.py +++ b/cellacdc/trackers/DeepSea/__init__.py @@ -1,17 +1,20 @@ -from cellacdc.models import DeepSea +from cellacdc.segmenters import DeepSea from deepsea import tracker_transforms import torchvision.transforms as transforms -image_size = [128,128] +image_size = [128, 128] image_means = [0.5] image_stds = [0.5] + def _get_tracker_transforms(): - return tracker_transforms.Compose([ - tracker_transforms.Resize(image_size), - tracker_transforms.Grayscale(num_output_channels=1), - tracker_transforms.ToTensor(), - tracker_transforms.Normalize(mean=image_means, std=image_stds) - ]) \ No newline at end of file + return tracker_transforms.Compose( + [ + tracker_transforms.Resize(image_size), + tracker_transforms.Grayscale(num_output_channels=1), + tracker_transforms.ToTensor(), + tracker_transforms.Normalize(mean=image_means, std=image_stds), + ] + ) diff --git a/cellacdc/trackers/TAPIR/TAPIR_tracker.py b/cellacdc/trackers/TAPIR/TAPIR_tracker.py index 567b7c6c0..ccfa70e06 100644 --- a/cellacdc/trackers/TAPIR/TAPIR_tracker.py +++ b/cellacdc/trackers/TAPIR/TAPIR_tracker.py @@ -21,43 +21,50 @@ from . import TAPIR_CHECKPOINT_PATH from .tracking import build_model, inference + class SizesToResize: values = np.arange(256, 1025, 128) + class TrackingInputs: - values = ['Intensity image', 'Segmented objects'] + values = ["Intensity image", "Segmented objects"] + class PointsToTrack: - values = ['Centroids', 'Contours'] + values = ["Centroids", "Contours"] + class tracker: - def __init__( - self, model_checkpoint_path: os.PathLike=TAPIR_CHECKPOINT_PATH - ): + def __init__(self, model_checkpoint_path: os.PathLike = TAPIR_CHECKPOINT_PATH): ckpt_state = np.load(model_checkpoint_path, allow_pickle=True).item() - params, state = ckpt_state['params'], ckpt_state['state'] + params, state = ckpt_state["params"], ckpt_state["state"] model = hk.transform_with_state(build_model) model_apply = jax.jit(model.apply) self.params = params self.state = state self.model_apply = model_apply - + def track( - self, segm_video, video_grayscale, - resize_to_square_with_size: SizesToResize=256, - max_distance=5, save_napari_tracks=False, - use_visibile_information=True, export_to=None, - signals=None, export_to_extension='.csv', - tracking_input: TrackingInputs='Intensity image', - which_points_to_track: PointsToTrack='Centroids', - number_of_points_per_object: int=8 - ): - + self, + segm_video, + video_grayscale, + resize_to_square_with_size: SizesToResize = 256, + max_distance=5, + save_napari_tracks=False, + use_visibile_information=True, + export_to=None, + signals=None, + export_to_extension=".csv", + tracking_input: TrackingInputs = "Intensity image", + which_points_to_track: PointsToTrack = "Centroids", + number_of_points_per_object: int = 8, + ): + if video_grayscale.ndim == 4: ndim = video_grayscale.ndim - msg = f'TAPIR can only track 2D frames over time. Input image is {ndim}D' + msg = f"TAPIR can only track 2D frames over time. Input image is {ndim}D" raise TypeError(msg) - + self._use_visibile_information = use_visibile_information self._which_points_to_track = which_points_to_track self.segm_video = segm_video @@ -68,27 +75,28 @@ def track( frames = skimage.transform.resize( video_grayscale, (num_frames, new_size, new_size) ) - frames = frames/frames.max() - self.resize_ratio_height = height/new_size - self.resize_ratio_width = width/new_size - + frames = frames / frames.max() + self.resize_ratio_height = height / new_size + self.resize_ratio_width = width / new_size + resized_segm_video = np.array( [resize_lab(lab, (new_size, new_size)) for lab in segm_video] ) - + # We track from last frame backwards reversed_resized_frames = frames[::-1] reversed_resized_segm = resized_segm_video[::-1] - + self.reversed_resized_segm = reversed_resized_segm - + frames_rgb = self._get_frames_to_track( - reversed_resized_frames, reversed_resized_segm, - tracking_input + reversed_resized_frames, reversed_resized_segm, tracking_input ) query_points, tracks_start_frames = self._initialize_query_points( - reversed_resized_segm, tracking_input, which_points_to_track, - number_of_points_per_object + reversed_resized_segm, + tracking_input, + which_points_to_track, + number_of_points_per_object, ) self.tracks_start_frames = tracks_start_frames @@ -96,26 +104,24 @@ def track( # plt.imshow(frames_rgb[0]) # plt.plot(query_points[:,2], query_points[:,1], 'r.') # plt.show() - + self.reversed_tracks, self.reversed_visibles = inference( - frames_rgb, query_points, self.model_apply, self.params, - self.state + frames_rgb, query_points, self.model_apply, self.params, self.state ) - + tracked_video = self._apply_tracks() - + if save_napari_tracks: self._save_napari_tracks(export_to) - + if export_to is not None: self._save_tracks(export_to) return tracked_video def _get_frames_to_track( - self, reversed_resized_frames, reversed_resized_segm, - tracking_input - ): - if tracking_input == 'Segmented objects': + self, reversed_resized_frames, reversed_resized_segm, tracking_input + ): + if tracking_input == "Segmented objects": frames = np.zeros(reversed_resized_segm.shape, dtype=np.float32) for frame_i, lab in enumerate(reversed_resized_segm): rp = skimage.measure.regionprops(lab) @@ -125,19 +131,19 @@ def _get_frames_to_track( frames[frame_i][obj.slice][obj.image] = obj_edt[obj.image] else: frames = reversed_resized_frames - frames_rgb = (skimage.color.gray2rgb(frames)*255).astype(np.uint8) + frames_rgb = (skimage.color.gray2rgb(frames) * 255).astype(np.uint8) return frames_rgb - + def _save_napari_tracks(self, export_to): - print('Saving napari tracks...') + print("Saving napari tracks...") napari_tracks = self.to_napari_tracks() if export_to is None: - napari_tracks_path = 'tapir_napari_tracks.csv' + napari_tracks_path = "tapir_napari_tracks.csv" else: - napari_tracks_path = export_to.replace('.csv', '_napari.csv') - df = pd.DataFrame(data=napari_tracks, columns=['ID', 'T', 'Y', 'X']) + napari_tracks_path = export_to.replace(".csv", "_napari.csv") + df = pd.DataFrame(data=napari_tracks, columns=["ID", "T", "Y", "X"]) df.to_csv(napari_tracks_path, index=False) - + def _build_tracks_table(self): tracks = self.reversed_tracks[:, ::-1] visibles = self.reversed_visibles[:, ::-1] @@ -150,9 +156,9 @@ def _build_tracks_table(self): segm_IDs = [] for tr, track in enumerate(tqdm(tracks, ncols=100)): track_ID = self._get_track_ID(resized_segm, track) - for frame_i, (x, y) in enumerate(track): - yc = y*self.resize_ratio_height - xc = x*self.resize_ratio_width + for frame_i, (x, y) in enumerate(track): + yc = y * self.resize_ratio_height + xc = x * self.resize_ratio_width visible = visibles[tr, frame_i] track_IDs.append(track_ID) frames.append(frame_i) @@ -163,22 +169,28 @@ def _build_tracks_table(self): resized_segm[frame_i], y, x, max_dist=self.max_dist ) segm_IDs.append(segm_ID) - df = pd.DataFrame({ - 'frame_i': frames, - 'track_ID': segm_IDs, - 'segm_ID': track_IDs, - 'y_point': yy, - 'x_point': xx, - 'visible': visibles_li - }).set_index(['frame_i', 'track_ID']).sort_index() + df = ( + pd.DataFrame( + { + "frame_i": frames, + "track_ID": segm_IDs, + "segm_ID": track_IDs, + "y_point": yy, + "x_point": xx, + "visible": visibles_li, + } + ) + .set_index(["frame_i", "track_ID"]) + .sort_index() + ) return df - + def _save_tracks(self, export_to): - print('Saving tracks...') + print("Saving tracks...") self.df_tracks.to_csv(export_to) def to_napari_tracks(self, use_centroids=False): - print('Building napari tracks data...') + print("Building napari tracks data...") napari_tracks = [] num_frames = len(self.reversed_resized_segm) Y, X = self.reversed_resized_segm.shape[-2:] @@ -190,16 +202,27 @@ def to_napari_tracks(self, use_centroids=False): if not visible and self._use_visibile_information: continue self._append_napari_point( - napari_tracks, y, x, num_frames, reversed_frame_i, - track_ID, use_centroids=use_centroids + napari_tracks, + y, + x, + num_frames, + reversed_frame_i, + track_ID, + use_centroids=use_centroids, ) napari_tracks = np.array(napari_tracks) return napari_tracks def _append_napari_point( - self, napari_tracks, y, x, num_frames, - reversed_frame_i, track_ID, use_centroids=False - ): + self, + napari_tracks, + y, + x, + num_frames, + reversed_frame_i, + track_ID, + use_centroids=False, + ): frame_i = num_frames - reversed_frame_i - 1 if use_centroids: lab = self.segm_video[frame_i] @@ -210,41 +233,41 @@ def _append_napari_point( napari_tracks.append((track_ID, frame_i, yc, xc)) break else: - yc = y*self.resize_ratio_height - xc = x*self.resize_ratio_width + yc = y * self.resize_ratio_height + xc = x * self.resize_ratio_width napari_tracks.append((track_ID, frame_i, yc, xc)) - + def _get_track_ID(self, resized_segm, track, max_dist=None): Y, X = resized_segm.shape[-2:] x, y = track[-1] # frame_i = self.tracks_start_frames[(round(y), round(x))] - # I still don't know how to get the start frame of each track - # because TAPIR returns a float even for the initialized query + # I still don't know how to get the start frame of each track + # because TAPIR returns a float even for the initialized query # point of each track frame_i = -1 y_int, x_int = round(y), round(x) - y_int = max(0, min(y_int, Y-1)) - x_int = max(0, min(x_int, X-1)) + y_int = max(0, min(y_int, Y - 1)) + x_int = max(0, min(x_int, X - 1)) track_ID = resized_segm[frame_i, y_int, x_int] return track_ID - + def _apply_tracks(self): - print('Applying tracks data...') - - self.df_tracks = self._build_tracks_table() - self.df_tracks = self.df_tracks[self.df_tracks.visible>0] - + print("Applying tracks data...") + + self.df_tracks = self._build_tracks_table() + self.df_tracks = self.df_tracks[self.df_tracks.visible > 0] + # Iterate tracks and determine tracked IDs old_IDs_tracks = {} tracked_IDs_tracks = {} - for (frame_i, track_ID), df in self.df_tracks.groupby(level=(0,1)): + for (frame_i, track_ID), df in self.df_tracks.groupby(level=(0, 1)): if track_ID == 0: continue - - oldID = df['segm_ID'].mode().iloc[0] + + oldID = df["segm_ID"].mode().iloc[0] if oldID == 0: continue - + if frame_i not in old_IDs_tracks: old_IDs_tracks[frame_i] = [oldID] tracked_IDs_tracks[frame_i] = [track_ID] @@ -256,41 +279,43 @@ def _apply_tracks(self): for frame_i in old_IDs_tracks.keys(): tracked_IDs = tracked_IDs_tracks[frame_i] old_IDs = old_IDs_tracks[frame_i] - + lab = self.segm_video[frame_i] rp = skimage.measure.regionprops(lab) IDs_curr_untracked = [obj.label for obj in rp] - - uniqueID = max((max(tracked_IDs), max(IDs_curr_untracked)))+1 + + uniqueID = max((max(tracked_IDs), max(IDs_curr_untracked))) + 1 tracked_lab = CellACDC_tracker.indexAssignment( - old_IDs, tracked_IDs, IDs_curr_untracked, - lab.copy(), rp, uniqueID + old_IDs, tracked_IDs, IDs_curr_untracked, lab.copy(), rp, uniqueID ) tracked_video[frame_i] = tracked_lab return tracked_video - + def _initialize_query_points( - self, reversed_resized_segm, tracking_input, - which_points_to_track, number_of_points_per_object - ): + self, + reversed_resized_segm, + tracking_input, + which_points_to_track, + number_of_points_per_object, + ): first_lab = reversed_resized_segm[0] first_lab_rp = skimage.measure.regionprops(first_lab) num_objs = len(first_lab_rp) tracks_start_frames = {} - if which_points_to_track == 'Centroids': - query_points = np.zeros((num_objs, 3), dtype=int) + if which_points_to_track == "Centroids": + query_points = np.zeros((num_objs, 3), dtype=int) else: - all_contours = [] + all_contours = [] for o, obj in enumerate(first_lab_rp): - if which_points_to_track == 'Centroids': - if tracking_input == 'Segmented objects': + if which_points_to_track == "Centroids": + if tracking_input == "Segmented objects": # Track the center of the edt of the object # since edt is also the input image obj_edt = distance_transform_edt(obj.image) argmax = np.argmax(obj_edt) yc_loc, xc_loc = np.unravel_index(argmax, obj_edt.shape) ymin, xmin, _, _ = obj.bbox - yc, xc = yc_loc+ymin, xc_loc+xmin + yc, xc = yc_loc + ymin, xc_loc + xmin else: # Track the centroid of the object yc, xc = obj.centroid @@ -306,14 +331,15 @@ def _initialize_query_points( all_contours.append(contours) for x, y in contours: tracks_start_frames[(y, x)] = 0 - if which_points_to_track == 'Contours': + if which_points_to_track == "Contours": all_contours = np.concatenate(all_contours) nrows = len(all_contours) - query_points = np.zeros((nrows, 3), dtype=int) - query_points[:, 2] = all_contours[:,0] - query_points[:, 1] = all_contours[:,1] - + query_points = np.zeros((nrows, 3), dtype=int) + query_points[:, 2] = all_contours[:, 0] + query_points[:, 1] = all_contours[:, 1] + return query_points, tracks_start_frames + def url_help(): - return 'https://deepmind-tapir.github.io/' \ No newline at end of file + return "https://deepmind-tapir.github.io/" diff --git a/cellacdc/trackers/TAPIR/__init__.py b/cellacdc/trackers/TAPIR/__init__.py index 526c70b8b..b41d60327 100644 --- a/cellacdc/trackers/TAPIR/__init__.py +++ b/cellacdc/trackers/TAPIR/__init__.py @@ -1,7 +1,7 @@ import os -from cellacdc import myutils +from cellacdc import utils -myutils.check_install_tapir() -_, model_path = myutils.get_model_path('TAPIR', create_temp_dir=False) -TAPIR_CHECKPOINT_PATH = os.path.join(model_path, 'tapir_checkpoint.npy') \ No newline at end of file +utils.check_install_tapir() +_, model_path = utils.get_model_path("TAPIR", create_temp_dir=False) +TAPIR_CHECKPOINT_PATH = os.path.join(model_path, "tapir_checkpoint.npy") diff --git a/cellacdc/trackers/TAPIR/tracking.py b/cellacdc/trackers/TAPIR/tracking.py index a54c62977..6706a81e3 100644 --- a/cellacdc/trackers/TAPIR/tracking.py +++ b/cellacdc/trackers/TAPIR/tracking.py @@ -4,6 +4,7 @@ from tapnet import tapir_model + def build_model(frames, query_points): """Compute point tracks and occlusions given frames and query points.""" model = tapir_model.TAPIR() @@ -15,6 +16,7 @@ def build_model(frames, query_points): ) return outputs + def preprocess_frames(frames): """Preprocess frames to model inputs. @@ -40,9 +42,12 @@ def postprocess_occlusions(occlusions, expected_dist): visibles: [num_points, num_frames], bool """ # visibles = occlusions < 0 - visibles = (1 - jax.nn.sigmoid(occlusions)) * (1 - jax.nn.sigmoid(expected_dist)) > 0.5 + visibles = (1 - jax.nn.sigmoid(occlusions)) * ( + 1 - jax.nn.sigmoid(expected_dist) + ) > 0.5 return visibles + def inference(frames, query_points, model_apply, params, state): """Inference on one video. @@ -65,9 +70,11 @@ def inference(frames, query_points, model_apply, params, state): outputs, _ = model_apply(params, state, rng, frames, query_points) outputs = tree.map_structure(lambda x: np.array(x[0]), outputs) tracks, occlusions, expected_dist = ( - outputs['tracks'], outputs['occlusion'], outputs['expected_dist'] + outputs["tracks"], + outputs["occlusion"], + outputs["expected_dist"], ) # Binarize occlusions visibles = postprocess_occlusions(occlusions, expected_dist) - return tracks, visibles \ No newline at end of file + return tracks, visibles diff --git a/cellacdc/trackers/Trackastra/Trackastra_tracker.py b/cellacdc/trackers/Trackastra/Trackastra_tracker.py index cc6de1d9c..5a59b6ae7 100644 --- a/cellacdc/trackers/Trackastra/Trackastra_tracker.py +++ b/cellacdc/trackers/Trackastra/Trackastra_tracker.py @@ -1,29 +1,32 @@ - import os from trackastra.model import Trackastra from trackastra.tracking import graph_to_ctc -from ... import _types, myutils, core +from ... import _types, utils, core from . import get_pretrained_model_names + class AvailableModels: values = get_pretrained_model_names() + class AvailableLinkingModes: - values = ['greedy', 'greedy_nodiv', 'ilp'] + values = ["greedy", "greedy_nodiv", "ilp"] + class AvailableCellDivisionModes: - values = ['Normal', 'Asymmetric'] + values = ["Normal", "Asymmetric"] + class tracker: def __init__( - self, - pretrained_model_name: AvailableModels='general_2d', - model_folder_path: _types.FolderPath='', - gpu=False - ) -> None: + self, + pretrained_model_name: AvailableModels = "general_2d", + model_folder_path: _types.FolderPath = "", + gpu=False, + ) -> None: """Initialize tracker Parameters @@ -31,29 +34,29 @@ def __init__( pretrained_model_name : AvailableModels, optional Pre-trained model name. Default is 'general_2d' model_folder_path : os.PathLike, optional - Path to the folder containing `config.yaml` file from + Path to the folder containing `config.yaml` file from custom training. Default is '' gpu : bool, optional - If `True`, attempts to try to use the GPU for inference. + If `True`, attempts to try to use the GPU for inference. Default is False - """ - device = myutils.get_torch_device() + """ + device = utils.get_torch_device() if model_folder_path: - self.model = Trackastra.from_folder( - model_folder_path, device=str(device) - ) + self.model = Trackastra.from_folder(model_folder_path, device=str(device)) else: self.model = Trackastra.from_pretrained( pretrained_model_name, device=str(device) ) - + def track( - self, segm_video, video_grayscale, - linking_mode: AvailableLinkingModes='greedy', - prevent_deleting_objects: bool=True, - cell_division_mode: AvailableCellDivisionModes='Normal', - record_lineage=True - ): + self, + segm_video, + video_grayscale, + linking_mode: AvailableLinkingModes = "greedy", + prevent_deleting_objects: bool = True, + cell_division_mode: AvailableCellDivisionModes = "Normal", + record_lineage=True, + ): """Track the objects in `segm_video` Parameters @@ -63,33 +66,31 @@ def track( video_grayscale : (T, Y, X) np.ndarray Input intensity images over time. linking_mode : {'greedy', 'greedy_nodiv', 'ilp'}, optional - Strategy used to link the predicted associations. Note that + Strategy used to link the predicted associations. Note that 'ilp' requires the package `motile`. Default is 'greedy' prevent_deleting_objects : bool, optional - If `True`, prevent Trackastra from removing untracked objects or - merging them with other objects. Note that these added objects + If `True`, prevent Trackastra from removing untracked objects or + merging them with other objects. Note that these added objects will not be tracked. Default is `True`. cell_division_mode : {'Normal', 'Asymmetric'}, optional - Type of cell division. `Normal` is the standard cell division, - where the mother cell divides into two daughter cells. For the - tracking, that means the two daughter cells get a new, unique ID - each. Note that division is not detected if + Type of cell division. `Normal` is the standard cell division, + where the mother cell divides into two daughter cells. For the + tracking, that means the two daughter cells get a new, unique ID + each. Note that division is not detected if `linking_mode == greedy_nodiv`. - - `Asymmetric` means that the mother cell grows one daughter - cell that eventually divides from the mother (e.g., budding yeast). - For the tracking, this means that the mother cell ID keeps - existing after division and the daughter cell gets a new, unique ID. + + `Asymmetric` means that the mother cell grows one daughter + cell that eventually divides from the mother (e.g., budding yeast). + For the tracking, this means that the mother cell ID keeps + existing after division and the daughter cell gets a new, unique ID. record_lineage : bool, optional - If `True`, store a list of cell lineage annotaions (Cell-ACDC format) - in the `self.cca_dfs` list (one DataFrame with index `Cell_ID` per - frame). When used through Cell-ACDC, this list will be saved - to the acdc_output CSV file. - """ - out = self.model.track( - video_grayscale, segm_video, mode=linking_mode - ) - + If `True`, store a list of cell lineage annotaions (Cell-ACDC format) + in the `self.cca_dfs` list (one DataFrame with index `Cell_ID` per + frame). When used through Cell-ACDC, this list will be saved + to the acdc_output CSV file. + """ + out = self.model.track(video_grayscale, segm_video, mode=linking_mode) + try: df_ctc, tracked_video = graph_to_ctc(out, segm_video) except Exception as e: @@ -99,58 +100,60 @@ def track( except Exception as e2: graph = out[1] df_ctc, tracked_video = graph_to_ctc(graph, segm_video) - if prevent_deleting_objects: - tracked_video = core.insert_missing_objects( - tracked_video, segm_video - ) - - if linking_mode == 'greedy_nodiv': + tracked_video = core.insert_missing_objects(tracked_video, segm_video) + + if linking_mode == "greedy_nodiv": return tracked_video - - acdc_df, cca_dfs, asym_segm_tracked = myutils.df_ctc_to_acdc_df( - df_ctc, tracked_video, cell_division_mode=cell_division_mode, - return_list=True, progressbar=True + + acdc_df, cca_dfs, asym_segm_tracked = utils.df_ctc_to_acdc_df( + df_ctc, + tracked_video, + cell_division_mode=cell_division_mode, + return_list=True, + progressbar=True, ) - - if cell_division_mode == 'Asymmetric': + + if cell_division_mode == "Asymmetric": return asym_segm_tracked - + if record_lineage: self.cca_dfs = cca_dfs - + return tracked_video def validate_input(self, segm_video, progress=True): import skimage.measure + warning_text = None if progress: from tqdm import tqdm + pbar = tqdm( - total=len(segm_video), desc='Validating input', unit='frame', - ncols=100 + total=len(segm_video), desc="Validating input", unit="frame", ncols=100 ) - + empty_frames = [] for frame_i, lab in enumerate(segm_video): rp = skimage.measure.regionprops(lab) if len(rp) == 0: - empty_frames.append(frame_i+1) - + empty_frames.append(frame_i + 1) + if progress: pbar.update(1) - + if empty_frames: warning_text = ( - 'Trackastra requires that each frame has at least one object.\n\n' - f'The following frame numbers have no objects:\n\n{empty_frames}' + "Trackastra requires that each frame has at least one object.\n\n" + f"The following frame numbers have no objects:\n\n{empty_frames}" ) - + if progress: pbar.close() - + return warning_text + def url_help(): - return 'https://github.com/weigertlab/trackastra' \ No newline at end of file + return "https://github.com/weigertlab/trackastra" diff --git a/cellacdc/trackers/Trackastra/__init__.py b/cellacdc/trackers/Trackastra/__init__.py index 159d3595b..5e5120d1c 100644 --- a/cellacdc/trackers/Trackastra/__init__.py +++ b/cellacdc/trackers/Trackastra/__init__.py @@ -1,19 +1,20 @@ import os import json -from ... import myutils +from ... import utils -myutils.check_install_trackastra() +utils.check_install_trackastra() import trackastra trackastra_folderpath = os.path.dirname(os.path.abspath(trackastra.__file__)) pretraned_json_filepath = os.path.join( - trackastra_folderpath, 'model', 'pretrained.json' + trackastra_folderpath, "model", "pretrained.json" ) + def get_pretrained_model_names(): - with open(pretraned_json_filepath, encoding='utf-8') as file: + with open(pretraned_json_filepath, encoding="utf-8") as file: json_data = json.load(file) - - return list(json_data.keys()) \ No newline at end of file + + return list(json_data.keys()) diff --git a/cellacdc/trackers/YeaZ/YeaZ_tracker.py b/cellacdc/trackers/YeaZ/YeaZ_tracker.py index dc703c1e5..976fd0c5d 100755 --- a/cellacdc/trackers/YeaZ/YeaZ_tracker.py +++ b/cellacdc/trackers/YeaZ/YeaZ_tracker.py @@ -3,6 +3,7 @@ from . import tracking + class tracker: def __init__(self): pass diff --git a/cellacdc/trackers/YeaZ/tracking.py b/cellacdc/trackers/YeaZ/tracking.py index 7785aadf9..91a6c1976 100755 --- a/cellacdc/trackers/YeaZ/tracking.py +++ b/cellacdc/trackers/YeaZ/tracking.py @@ -18,6 +18,7 @@ except ModuleNotFoundError as e: pass + def correspondence(prev, curr, use_scipy=True, use_modified_yeaz=True): """ source: YeaZ modified by Cell-ACDC developers @@ -37,19 +38,16 @@ def correspondence(prev, curr, use_scipy=True, use_modified_yeaz=True): IDs_curr_untracked = [obj.label for obj in regionprops(curr)] IDs_prev = [obj.label for obj in regionprops(prev)] if IDs_prev or IDs_curr_untracked: - uniqueID = max( - max(IDs_prev, default=0), - max(IDs_curr_untracked, default=0) - ) + 1 + uniqueID = max(max(IDs_prev, default=0), max(IDs_curr_untracked, default=0)) + 1 else: uniqueID = 1 tracked_lab = CellACDC_tracker.indexAssignment( - old_IDs, tracked_IDs, IDs_curr_untracked, - curr.copy(), rp, uniqueID + old_IDs, tracked_IDs, IDs_curr_untracked, curr.copy(), rp, uniqueID ) return tracked_lab + def scipy_align(m1, m2, acdc_yeaz=True): """ source: YeaZ modified by Cell-ACDC @@ -72,11 +70,12 @@ def scipy_align(m1, m2, acdc_yeaz=True): d.pop(-1, None) return d + def updateGuiProgressBar(signals): if signals is None: return - - if hasattr(signals, 'innerPbar_available'): + + if hasattr(signals, "innerPbar_available"): if signals.innerPbar_available: # Use inner pbar of the GUI widget (top pbar is for positions) signals.innerProgressBar.emit(1) @@ -84,6 +83,7 @@ def updateGuiProgressBar(signals): signals.progressBar.emit(1) + def correspondence_stack(stack, signals=None): """ source: YeaZ @@ -94,15 +94,16 @@ def correspondence_stack(stack, signals=None): tracked_stack[0] = stack[0] for idx in tqdm(range(len(stack)), ncols=100): try: - curr = stack[idx+1] + curr = stack[idx + 1] prev = tracked_stack[idx] except IndexError: continue - tracked_stack[idx+1] = correspondence(prev, curr) + tracked_stack[idx + 1] = correspondence(prev, curr) updateGuiProgressBar(signals) # tracked_stack = relabel_sequential(tracked_stack)[0] return tracked_stack + def hungarian_align(m1, m2, acdc_yeaz=True): """ source: YeaZ @@ -126,50 +127,54 @@ def hungarian_align(m1, m2, acdc_yeaz=True): d.pop(-1, None) return d + def cell_to_features(im, c, nsamples=None, time=None): """ source: YeaZ Embeds cell c in image im into feature space """ - coord = np.argwhere(im==c) + coord = np.argwhere(im == c) area = coord.shape[0] if nsamples is not None: samples = np.random.choice(area, min(nsamples, area), replace=False) - sampled = coord[samples,:] + sampled = coord[samples, :] else: sampled = coord com = sampled.mean(axis=0) - return {'cell': c, - 'time': time, - 'sqrtarea': np.sqrt(area), - 'area': area, - 'com_x': com[0], - 'com_y': com[1]} + return { + "cell": c, + "time": time, + "sqrtarea": np.sqrt(area), + "area": area, + "com_x": com[0], + "com_y": com[1], + } + def get_features_acdc(m, t): rp = regionprops(m) features = { - 'cell': [], - 'time': [], - 'sqrtarea': [], - 'area': [], - 'com_x': [], - 'com_y': [] + "cell": [], + "time": [], + "sqrtarea": [], + "area": [], + "com_x": [], + "com_y": [], } for obj in rp: area = obj.area y, x = obj.centroid - features['cell'].append(obj.label) - features['time'].append(t) - features['sqrtarea'].append(sqrt(area)) - features['area'].append(area) - features['com_x'].append(y) - features['com_y'].append(x) + features["cell"].append(obj.label) + features["time"].append(t) + features["sqrtarea"].append(sqrt(area)) + features["area"].append(area) + features["com_x"].append(y) + features["com_y"].append(x) df = pd.DataFrame(features) - return df, dict(enumerate(features['cell'])) + return df, dict(enumerate(features["cell"])) def get_features(m, t): @@ -179,6 +184,7 @@ def get_features(m, t): features = [cell_to_features(m, c, time=t) for c in cells] return pd.DataFrame(features), dict(enumerate(cells)) + def cell_distance(m1, m2, weight_com=3, acdc_yeaz=True): """ source: YeaZ @@ -188,8 +194,8 @@ def cell_distance(m1, m2, weight_com=3, acdc_yeaz=True): make it more important). """ # Modify to compute use more computed features - #cols = ['com_x', 'com_y', 'roundness', 'sqrtarea'] - cols = ['com_x', 'com_y', 'area'] + # cols = ['com_x', 'com_y', 'roundness', 'sqrtarea'] + cols = ["com_x", "com_y", "area"] get_features_func = get_features_acdc if acdc_yeaz else get_features @@ -200,19 +206,18 @@ def cell_distance(m1, m2, weight_com=3, acdc_yeaz=True): # feat1_acdc, ix_to_cell1_acdc = get_features_acdc(m1, 1) # Check if one of matrices doesn't contain cells - if len(feat1)==0 or len(feat2)==0: + if len(feat1) == 0 or len(feat2) == 0: return None, None, None df = pd.concat((feat1, feat2)) df[cols] = scale(df[cols]) # give more importance to center of mass - df[['com_x', 'com_y']] = df[['com_x', 'com_y']] * weight_com + df[["com_x", "com_y"]] = df[["com_x", "com_y"]] * weight_com # pairwise euclidean dist dist = euclidean_distances( - df.loc[df['time']==1][cols], - df.loc[df['time']==2][cols] + df.loc[df["time"] == 1][cols], df.loc[df["time"] == 2][cols] ) return dist, ix_to_cell1, ix_to_cell2 @@ -233,10 +238,10 @@ def make_square(m): source: YeaZ Turns matrix into square matrix, as required by Munkres algorithm """ - r,c = m.shape - if r==c: + r, c = m.shape + if r == c: return m - elif r>c: - return zero_pad(m, (r,r)) + elif r > c: + return zero_pad(m, (r, r)) else: - return zero_pad(m, (c,c)) + return zero_pad(m, (c, c)) diff --git a/cellacdc/trackers/delta/__init__.py b/cellacdc/trackers/delta/__init__.py index 5d7863ed5..f6ae895b6 100644 --- a/cellacdc/trackers/delta/__init__.py +++ b/cellacdc/trackers/delta/__init__.py @@ -4,4 +4,4 @@ @author: jroberts / jamesr787 """ -from cellacdc.models import delta \ No newline at end of file +from cellacdc.segmenters import delta diff --git a/cellacdc/trackers/delta/delta_tracker.py b/cellacdc/trackers/delta/delta_tracker.py index e1ddfc9b8..0c6f8f85f 100644 --- a/cellacdc/trackers/delta/delta_tracker.py +++ b/cellacdc/trackers/delta/delta_tracker.py @@ -19,15 +19,9 @@ class FakeReader: - - def __init__(self, - x, - y, - channels, - timepoints, - filename, - original_video, - starting_frame): + def __init__( + self, x, y, channels, timepoints, filename, original_video, starting_frame + ): """ Initialize experiment reader @@ -61,14 +55,15 @@ def __init__(self, self.original_video = original_video self.starting_frame = starting_frame - def getframes(self, - squeeze_dimensions: bool = True, - resize: Tuple[int, int] = None, - rescale: Tuple[float, float] = None, - globalrescale: Tuple[float, float] = None, - rotate: float = None, - **kwargs - ): + def getframes( + self, + squeeze_dimensions: bool = True, + resize: Tuple[int, int] = None, + rescale: Tuple[float, float] = None, + globalrescale: Tuple[float, float] = None, + rotate: float = None, + **kwargs, + ): """ Get frames from experiment. @@ -102,9 +97,7 @@ def getframes(self, dt: Union[str, type] = self.dtype if rescale is None else np.float32 if resize is None: - output = np.empty( - [self.timepoints, self.y, self.x], dtype=dt - ) + output = np.empty([self.timepoints, self.y, self.x], dtype=dt) else: output = np.empty( [self.timepoints, resize[0], resize[1]], @@ -113,7 +106,6 @@ def getframes(self, # Load images: for f in range(self.timepoints): - idx = f + self.starting_frame frame = self.original_video[idx].astype(np.uint16) @@ -138,7 +130,6 @@ def getframes(self, class tracker: - def __init__(self, **params): """ Initializes Tracker @@ -163,8 +154,7 @@ def __init__(self, **params): self.params = params - def __read_tiff(self, - path): + def __read_tiff(self, path): """ Reads multipage tiff to numpy array. @@ -186,8 +176,7 @@ def __read_tiff(self, images.append(np.array(img)) return np.array(images) - def __load_model_and_presets(self, - model_type): + def __load_model_and_presets(self, model_type): """ Loads Presets for 2D or mothermachine, initializes model for tracking and loads model weights. @@ -222,18 +211,20 @@ def __load_model_and_presets(self, except ValueError: # Downloads model weights and configuration files for 2D and mothermachine - download_assets(load_models=True, - load_sets=False, - load_evals=False, - config_level='local') - - if self.params['single mothermachine chamber'] and model_type == 'mothermachine': - self.models.pop('rois') - - def track(self, - segm_video, - signals=None, - export_to: os.PathLike=None): + download_assets( + load_models=True, + load_sets=False, + load_evals=False, + config_level="local", + ) + + if ( + self.params["single mothermachine chamber"] + and model_type == "mothermachine" + ): + self.models.pop("rois") + + def track(self, segm_video, signals=None, export_to: os.PathLike = None): """ Tracks Cells @@ -249,13 +240,13 @@ def track(self, """ # Loads Presets and Initializes Model - self.__load_model_and_presets(model_type=self.params['model_type']) + self.__load_model_and_presets(model_type=self.params["model_type"]) # Original Shape original_shape = segm_video[0].shape # Get original video and original image size - original_video = self.__read_tiff(self.params['original_images_path']) + original_video = self.__read_tiff(self.params["original_images_path"]) reference = utils.rangescale(original_video[0], (0, 1)) # Preprocess Segmentation Video @@ -266,14 +257,18 @@ def track(self, img = cv2.resize(img, cfg.target_size_seg[::-1]) img_sm = (img > 0.5).astype(np.uint8) if cfg.crop_windows: - img_sm = img_sm[: original_shape[0], : original_shape[1]].astype(np.uint8) + img_sm = img_sm[: original_shape[0], : original_shape[1]].astype( + np.uint8 + ) seg_stack.append(img_sm) segm_video = seg_stack # Preprocess Original Video box = utils.CroppingBox( - xtl=0, ytl=0, - xbr=reference.shape[1], ybr=reference.shape[0], + xtl=0, + ytl=0, + xbr=reference.shape[1], + ybr=reference.shape[0], ) img_stack = [] if len(original_video) != len(segm_video): @@ -283,30 +278,35 @@ def track(self, for frame in range(len(segm_video)): idx = frame + starting_frame # Crop and scale: - i = utils.rangescale(utils.cropbox(original_video[idx], box), rescale=(0, 1)) + i = utils.rangescale( + utils.cropbox(original_video[idx], box), rescale=(0, 1) + ) # Append i as is to input images stack: img_stack.append(i) # Get Save Path (File Name is same as Original Images + .format) - savepath = self.params['original_images_path'] - filename = savepath.replace('.tif', '') + savepath = self.params["original_images_path"] + filename = savepath.replace(".tif", "") # Init reader - xpreader = FakeReader(x=original_shape[1], - y=original_shape[0], - channels=0, - timepoints=len(segm_video), - filename=savepath, - original_video=original_video, - starting_frame=starting_frame - ) + xpreader = FakeReader( + x=original_shape[1], + y=original_shape[0], + channels=0, + timepoints=len(segm_video), + filename=savepath, + original_video=original_video, + starting_frame=starting_frame, + ) # Init Position - xp = pipeline.Position(position_nb=0, - reader=xpreader, - models=self.models, - drift_correction=False, - crop_windows=cfg.crop_windows) + xp = pipeline.Position( + position_nb=0, + reader=xpreader, + models=self.models, + drift_correction=False, + crop_windows=cfg.crop_windows, + ) # Preprocess xp.preprocess(rotation_correction=False) @@ -325,13 +325,14 @@ def track(self, tracked_video = np.array(xp.rois[0].label_stack, dtype=np.uint8) # Save Results - if self.params['legacy']: + if self.params["legacy"]: xp.legacysave(filename + ".mat") - if self.params['pickle']: + if self.params["pickle"]: import pickle + with open(filename + ".pkl", "wb") as file: pickle.dump(self, file) - if self.params['movie']: + if self.params["movie"]: movie = xp.results_movie(frames=list(range(len(segm_video)))) utils.vidwrite(movie, filename + ".mp4", verbose=False) diff --git a/cellacdc/trackers/trackpy/__init__.py b/cellacdc/trackers/trackpy/__init__.py index 4051611f2..466c91042 100644 --- a/cellacdc/trackers/trackpy/__init__.py +++ b/cellacdc/trackers/trackpy/__init__.py @@ -1,3 +1,3 @@ -from cellacdc import myutils +from cellacdc import utils -myutils.check_install_package('trackpy') \ No newline at end of file +utils.check_install_package("trackpy") diff --git a/cellacdc/trackers/trackpy/trackpy_tracker.py b/cellacdc/trackers/trackpy/trackpy_tracker.py index 8912130cb..7aecfc9b6 100644 --- a/cellacdc/trackers/trackpy/trackpy_tracker.py +++ b/cellacdc/trackers/trackpy/trackpy_tracker.py @@ -15,14 +15,18 @@ DEBUG = False + class SearchRangeUnits: - values = ['micrometre', 'pixels'] + values = ["micrometre", "pixels"] + class NeighborStrategies: - values = ['KDTree', 'BTree'] + values = ["KDTree", "BTree"] + class LinkStrategies: - values = ['recursive', 'nonrecursive', 'numba', 'hybrid', 'drop', 'auto'] + values = ["recursive", "nonrecursive", "numba", "hybrid", "drop", "auto"] + class tracker: def __init__(self) -> None: @@ -36,73 +40,73 @@ def _set_frame_features(self, lab, frame_i, tp_df): zc = None else: zc, yc, xc = obj.centroid - tp_df['x'].append(xc) - tp_df['y'].append(yc) + tp_df["x"].append(xc) + tp_df["y"].append(yc) if zc is not None: - tp_df['z'].append(zc) - tp_df['frame'].append(frame_i) - tp_df['ID'].append(obj.label) + tp_df["z"].append(zc) + tp_df["frame"].append(frame_i) + tp_df["ID"].append(obj.label) def _get_pos_columns( - self, tp_df, PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ, - search_range_unit - ): - is_3D = 'z' in tp_df.columns - if search_range_unit == 'pixels': + self, tp_df, PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ, search_range_unit + ): + is_3D = "z" in tp_df.columns + if search_range_unit == "pixels": if is_3D: - return ['x', 'y', 'z'] + return ["x", "y", "z"] else: - return ['x', 'y'] - + return ["x", "y"] + pos_columns = [] if is_3D: - tp_df['z_um'] = tp_df['z'] * PhysicalSizeZ - pos_columns.append('z_um') - - tp_df['x_um'] = tp_df['x'] * PhysicalSizeX - tp_df['y_um'] = tp_df['y'] * PhysicalSizeY - pos_columns = ['x_um', 'y_um', *pos_columns] - return pos_columns - + tp_df["z_um"] = tp_df["z"] * PhysicalSizeZ + pos_columns.append("z_um") + + tp_df["x_um"] = tp_df["x"] * PhysicalSizeX + tp_df["y_um"] = tp_df["y"] * PhysicalSizeY + pos_columns = ["x_um", "y_um", *pos_columns] + return pos_columns + def track( - self, segm_video, - search_range_unit: SearchRangeUnits='micrometre', - search_range=10.0, - memory=0, - adaptive_stop: float=1.0, - adaptive_step=0.95, - dynamic_predictor=False, - neighbor_strategy: NeighborStrategies='KDTree', - link_strategy: LinkStrategies='recursive', - signals=None, - export_to=None, - PhysicalSizeX=1.0, - PhysicalSizeY=1.0, - PhysicalSizeZ=1.0, - export_to_extension='.csv' - ): + self, + segm_video, + search_range_unit: SearchRangeUnits = "micrometre", + search_range=10.0, + memory=0, + adaptive_stop: float = 1.0, + adaptive_step=0.95, + dynamic_predictor=False, + neighbor_strategy: NeighborStrategies = "KDTree", + link_strategy: LinkStrategies = "recursive", + signals=None, + export_to=None, + PhysicalSizeX=1.0, + PhysicalSizeY=1.0, + PhysicalSizeZ=1.0, + export_to_extension=".csv", + ): """_summary_ Parameters ---------- search_range_unit : {'micrometres', 'pixels'}, default 'micrometres' - Physical unit of the `search_range`. If 'pixels', PhysicalSizes will - be ignored. + Physical unit of the `search_range`. If 'pixels', PhysicalSizes will + be ignored. search_range : float, optional - Radius of the circle centerd at the object at previous frame where - to search for the object at current frame. - - This is equivalent to the maximum distance the object is allowed - to travel between frames to be considered the same object. - - The unit is pixels for isotropic data (typically 2D over time) and + Radius of the circle centerd at the object at previous frame where + to search for the object at current frame. + + This is equivalent to the maximum distance the object is allowed + to travel between frames to be considered the same object. + + The unit is pixels for isotropic data (typically 2D over time) and in micrometers for anisotropic data (typically 3D over time). - + Default is 10.0. adaptive_stop : float, default 1.0 - If not None, when encountering an oversize subnet, retry by - progressively reducing search_range until the subnet is solvable. - If search_range becomes less or equal than `adaptive_stop`, give up + If not None, when encountering an oversize subnet, retry by + progressively reducing search_range until the subnet is solvable. + If search_range becomes less or equal than `adaptive_stop`, give up and raise a `SubnetOversizeException`. adaptive_step : float, default 0.95 Reduce search_range by multiplying it by this factor. @@ -117,16 +121,16 @@ def track( ------- (T, Y, X) or (T, Z, Y, X) np.array of ints Tracked segmentation masks with the same shape as input `segm_video`. - """ + """ # Handle string input for adaptive_stop if isinstance(adaptive_stop, str): - if adaptive_stop == 'None': + if adaptive_stop == "None": adaptive_stop = None else: adaptive_stop = float(adaptive_stop) - + self.setProgressBarMaximum(signals, len(segm_video)) - + # Build tp DataFrame --> https://soft-matter.github.io/trackpy/v0.5.0/generated/trackpy.link.html#trackpy.link tp_df = defaultdict(list) pbar = tqdm(total=len(segm_video), ncols=100) @@ -135,33 +139,33 @@ def track( pbar.update() self.updateGuiProgressBar(signals) pbar.close() - + tp_df = pd.DataFrame(tp_df) pos_columns = self._get_pos_columns( - tp_df, PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ, - search_range_unit + tp_df, PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ, search_range_unit ) # Run tracker if dynamic_predictor: predictor = tp.predict.NearestVelocityPredict() else: predictor = tp - + tp_out_df = predictor.link_df( - tp_df, search_range, + tp_df, + search_range, memory=int(memory), - adaptive_stop=adaptive_stop, + adaptive_stop=adaptive_stop, adaptive_step=adaptive_step, neighbor_strategy=neighbor_strategy, link_strategy=link_strategy, - pos_columns=pos_columns - ).set_index('frame') - + pos_columns=pos_columns, + ).set_index("frame") + if export_to is not None: tp_out_df.to_csv(export_to) - - tp_out_df['particle'] += 1 # trackpy starts from 0 with tracked ids - + + tp_out_df["particle"] += 1 # trackpy starts from 0 with tracked ids + # Generate tracked video data tracked_video = np.zeros_like(segm_video) for frame_i, lab in enumerate(segm_video): @@ -177,66 +181,66 @@ def track( IDs_curr_untracked = [obj.label for obj in rp] if DEBUG: - printl(f'Current untracked IDs: {IDs_curr_untracked}') + printl(f"Current untracked IDs: {IDs_curr_untracked}") if not IDs_curr_untracked: # No cells segmented continue - + try: - tracked_IDs = tp_out_df_frame['particle'].astype(int).to_list() - old_IDs = tp_out_df_frame['ID'].astype(int).to_list() + tracked_IDs = tp_out_df_frame["particle"].astype(int).to_list() + old_IDs = tp_out_df_frame["ID"].astype(int).to_list() except AttributeError: # Single cell - tracked_IDs = [int(tp_out_df_frame['particle'])] - old_IDs = [int(tp_out_df_frame['ID'])] - + tracked_IDs = [int(tp_out_df_frame["particle"])] + old_IDs = [int(tp_out_df_frame["ID"])] + if not tracked_IDs: # No cells tracked continue - uniqueID = max((max(tracked_IDs), max(IDs_curr_untracked)))+1 + uniqueID = max((max(tracked_IDs), max(IDs_curr_untracked))) + 1 if DEBUG: - print('-------------------------') - print(f'Tracking frame n. {frame_i+1}') + print("-------------------------") + print(f"Tracking frame n. {frame_i + 1}") for old_ID, tracked_ID in zip(old_IDs, tracked_IDs): - print(f'Tracking ID {old_ID} --> {tracked_ID}') - print('-------------------------') - + print(f"Tracking ID {old_ID} --> {tracked_ID}") + print("-------------------------") + tracked_lab = CellACDC_tracker.indexAssignment( - old_IDs, tracked_IDs, IDs_curr_untracked, - lab.copy(), rp, uniqueID + old_IDs, tracked_IDs, IDs_curr_untracked, lab.copy(), rp, uniqueID ) tracked_video[frame_i] = tracked_lab self.updateGuiProgressBar(signals) - + return tracked_video - + def setProgressBarMaximum(self, signals, num_frames): if signals is None: return - - if hasattr(signals, 'innerPbar_available'): + + if hasattr(signals, "innerPbar_available"): if signals.innerPbar_available: # Use inner pbar of the GUI widget (top pbar is for positions) - signals.sigInitInnerPbar.emit(num_frames*2) + signals.sigInitInnerPbar.emit(num_frames * 2) return - - signals.initProgressBar.emit(num_frames*2) - + + signals.initProgressBar.emit(num_frames * 2) + def updateGuiProgressBar(self, signals): if signals is None: return - - if hasattr(signals, 'innerPbar_available'): + + if hasattr(signals, "innerPbar_available"): if signals.innerPbar_available: # Use inner pbar of the GUI widget (top pbar is for positions) signals.innerProgressBar.emit(1) return signals.progressBar.emit(1) - + + def url_help(): - return 'https://soft-matter.github.io/trackpy/v0.5.0/generated/trackpy.link.html#trackpy.link' \ No newline at end of file + return "https://soft-matter.github.io/trackpy/v0.5.0/generated/trackpy.link.html#trackpy.link" diff --git a/cellacdc/transformation.py b/cellacdc/transformation.py index 113978ba4..4dccc0974 100644 --- a/cellacdc/transformation.py +++ b/cellacdc/transformation.py @@ -1,4 +1,4 @@ -import xml.etree.ElementTree as ET +import xml.etree.ElementTree as ET import math import pandas as pd @@ -13,6 +13,7 @@ from typing import List + def resize_lab(lab, output_shape, rp=None): if rp is None: rp = skimage.measure.regionprops(lab) @@ -21,13 +22,13 @@ def resize_lab(lab, output_shape, rp=None): for obj in rp: _lab_obj_to_resize[obj.slice][obj.image] = 1.0 _lab_obj_resized = resize( - _lab_obj_to_resize, output_shape, anti_aliasing=True, - preserve_range=True + _lab_obj_to_resize, output_shape, anti_aliasing=True, preserve_range=True ).round() lab_resized[_lab_obj_resized == 1.0] = obj.label _lab_obj_to_resize[:] = 0.0 return lab_resized + def crop_2D(img, xy_range, tolerance=0, return_copy=True): (xmin, xmax), (ymin, ymax) = xy_range Y, X = img.shape @@ -39,7 +40,7 @@ def crop_2D(img, xy_range, tolerance=0, return_copy=True): xmax = X if xmax > X else round(xmax) ymin = 0 if ymin < 0 else round(ymin) ymax = Y if ymax > Y else round(ymax) - crop_shape = (ymax-ymin, xmax-xmin) + crop_shape = (ymax - ymin, xmax - xmin) crop_slice = (slice(ymin, ymax, None), slice(xmin, xmax, None)) if return_copy: cropped = np.zeros(crop_shape, dtype=img.dtype) @@ -48,17 +49,19 @@ def crop_2D(img, xy_range, tolerance=0, return_copy=True): cropped = img[crop_slice] return cropped, crop_slice + def del_objs_outside_segm_roi(segm_roi, segm): - del_IDs = np.unique(segm[segm_roi==0]) + del_IDs = np.unique(segm[segm_roi == 0]) cleared_segm = segm.copy() clearedIDs = [] for del_ID in del_IDs: if del_ID == 0: continue - cleared_segm[segm==del_ID] = 0 + cleared_segm[segm == del_ID] = 0 clearedIDs.append(del_ID) return cleared_segm, clearedIDs + def trackmate_xml_to_df(xml_file): IDs = [] xx = [] @@ -69,28 +72,23 @@ def trackmate_xml_to_df(xml_file): Tracks = tree.getroot() for i, particle in enumerate(Tracks): - ID = i+1 + ID = i + 1 for t, detection in enumerate(particle): attrib = detection.attrib IDs.append(ID) - xx.append(attrib['x']) - yy.append(attrib['y']) - zz.append(attrib['z']) - frame_idxs.append(attrib['t']) - - df = pd.DataFrame({ - 'frame_i': frame_idxs, - 'ID': IDs, - 'x': xx, - 'y': yy, - 'z': zz - }) + xx.append(attrib["x"]) + yy.append(attrib["y"]) + zz.append(attrib["z"]) + frame_idxs.append(attrib["t"]) + + df = pd.DataFrame({"frame_i": frame_idxs, "ID": IDs, "x": xx, "y": yy, "z": zz}) return df + def retrack_based_on_untracked_first_frame( - tracked_video, first_untracked_lab, uniqueID=None - ): - """Re-tack the objects in the first frame of `tracked_video` to have the + tracked_video, first_untracked_lab, uniqueID=None +): + """Re-tack the objects in the first frame of `tracked_video` to have the same IDs as in `first_untracked_lab` Parameters @@ -98,30 +96,30 @@ def retrack_based_on_untracked_first_frame( tracked_video : (T, Y, X) or (T, Z, Y, X) of ints Array with the segmentation instances of the tracked objects first_untracked_lab : (Y, X) or (Z, Y, X) of ints - Array with the segmentation instances of the objects in the first + Array with the segmentation instances of the objects in the first frame before they were tracked uniqueID : int, optional - If not None, it will be used as first of the unique IDs. - If None, this will be initialized to the maximum in `tracked_video`. + If not None, it will be used as first of the unique IDs. + If None, this will be initialized to the maximum in `tracked_video`. Default is None. Returns ------- (T, Y, X) or (T, Z, Y, X) of ints - Tracked video where the objects in the first frame has the same IDs as - in `first_untracked_lab`. - + Tracked video where the objects in the first frame has the same IDs as + in `first_untracked_lab`. + Notes ----- - The idea of this function is to ensure that objects in the first frame - before and after tracking have the same IDs. This is needed to ensure - continuity of obejct IDs when tracking portions of the video in - different batches. - """ - + The idea of this function is to ensure that objects in the first frame + before and after tracking have the same IDs. This is needed to ensure + continuity of obejct IDs when tracking portions of the video in + different batches. + """ + first_tracked_lab = tracked_video[0] first_tracked_rp = skimage.measure.regionprops(first_tracked_lab) - + tracked_to_untracked_mapper = {} for obj in first_tracked_rp: untracked_ID = first_untracked_lab[obj.slice][obj.image][0] @@ -131,18 +129,16 @@ def retrack_based_on_untracked_first_frame( if not tracked_to_untracked_mapper: return tracked_video - + first_untracked_rp = skimage.measure.regionprops(first_untracked_lab) first_untracked_IDs = [obj.label for obj in first_untracked_rp] - + if uniqueID is None: uniqueID = np.max(tracked_video) + 1 - uniqueIDs = np.arange(uniqueID, uniqueID+len(first_untracked_IDs)) + uniqueIDs = np.arange(uniqueID, uniqueID + len(first_untracked_IDs)) + + untracked_to_unique_mapper = dict(zip(first_untracked_IDs, uniqueIDs)) - untracked_to_unique_mapper = ( - dict(zip(first_untracked_IDs, uniqueIDs)) - ) - pbar = tqdm(total=len(tracked_video), ncols=100) for frame_i, tracked_lab in enumerate(tracked_video): rp_tracked = skimage.measure.regionprops(tracked_lab) @@ -151,39 +147,39 @@ def retrack_based_on_untracked_first_frame( if new_unique_ID is None: # Untracked ID not present in tracked labels continue - + untracked_ID = tracked_to_untracked_mapper.get(obj_tracked.label) if untracked_ID is None: # No need to make ID unique because it will not change later continue - - # Replace untracked ID with a unique ID to prevent merging when later - # we will replace tracked IDs of first frame to their corresponding + + # Replace untracked ID with a unique ID to prevent merging when later + # we will replace tracked IDs of first frame to their corresponding # untracked ID - tracked_video[tracked_video==obj_tracked.label] = new_unique_ID + tracked_video[tracked_video == obj_tracked.label] = new_unique_ID - # Update tracked to untracked mapper because now tracked_video + # Update tracked to untracked mapper because now tracked_video # changed and we would not find the same ID later tracked_to_untracked_mapper[new_unique_ID] = ( tracked_to_untracked_mapper.pop(obj_tracked.label) ) - + pbar.update() pbar.close() - + uniqueID = np.max(tracked_video) + 1 - + untracked_to_unique_mapper = {} pbar = tqdm(total=len(tracked_video), ncols=100) for frame_i, tracked_lab in enumerate(tracked_video): rp_tracked = skimage.measure.regionprops(tracked_lab) - rp_tracked_dict = {obj.label:obj for obj in rp_tracked} + rp_tracked_dict = {obj.label: obj for obj in rp_tracked} for obj_tracked in rp_tracked: untracked_ID = tracked_to_untracked_mapper.get(obj_tracked.label) if untracked_ID is None: # Untracked ID not present in tracked labels continue - + untracked_obj = rp_tracked_dict.get(untracked_ID) if untracked_obj is not None: new_unique_ID = untracked_to_unique_mapper.get(untracked_ID) @@ -191,23 +187,20 @@ def retrack_based_on_untracked_first_frame( new_unique_ID = uniqueID untracked_to_unique_mapper[untracked_ID] = new_unique_ID uniqueID += 1 - + # Make sure to change existing IDs to unique lab = tracked_video[frame_i] - lab[untracked_obj.slice][untracked_obj.image] = ( - new_unique_ID - ) - - # Replace tracked ID of first frame to the untracked ID of the - # reference - tracked_video[frame_i][obj_tracked.slice][obj_tracked.image] = ( - untracked_ID - ) + lab[untracked_obj.slice][untracked_obj.image] = new_unique_ID + + # Replace tracked ID of first frame to the untracked ID of the + # reference + tracked_video[frame_i][obj_tracked.slice][obj_tracked.image] = untracked_ID pbar.update() pbar.close() - + return tracked_video + def remove_padding_2D(arr, val=0, return_crop_slice=False): crop_slice = [] for a, ax in enumerate((1, 0)): @@ -216,27 +209,28 @@ def remove_padding_2D(arr, val=0, return_crop_slice=False): pad_ax_mask = np.isnan(pad_ax) else: pad_ax_mask = pad_ax == val - + pad_ax_left = 0 for i, val in enumerate(pad_ax_mask): if not val: pad_ax_left = i - break - + break + pad_ax_right = arr.shape[a] for j, val in enumerate(pad_ax_mask[::-1]): if not val: pad_ax_right -= j - break - + break + crop_slice.append(slice(pad_ax_left, pad_ax_right)) - + crop_slice = tuple(crop_slice) if return_crop_slice: return arr[crop_slice], crop_slice - + return arr[tuple(crop_slice)] + def crop_outer_padding(arr, value=0, copy=False): if isinstance(value, (int, float)): if arr.ndim > 2: @@ -251,7 +245,7 @@ def crop_outer_padding(arr, value=0, copy=False): # which rows/cols are entirely padding? row_is_pad = np.all(padding_pixel, axis=1) col_is_pad = np.all(padding_pixel, axis=0) - + # build mask padding_mask = np.zeros_like(padding_pixel) @@ -262,57 +256,53 @@ def crop_outer_padding(arr, value=0, copy=False): is_top_padded = True except ValueError: is_top_padded = False - + try: bottom = len(row_is_pad) - np.argmax(~row_is_pad[::-1]) padding_mask[bottom:, :] = True is_bottom_padded = True except ValueError: is_bottom_padded = False - + try: left = np.argmax(~col_is_pad) padding_mask[:, :left] = True is_left_padded = True except ValueError: is_left_padded = False - + try: right = len(col_is_pad) - np.argmax(~col_is_pad[::-1]) padding_mask[:, right:] = True is_right_padded = True except ValueError: is_right_padded = False - - is_padded = ( - is_top_padded or is_bottom_padded or - is_left_padded or is_right_padded - ) - + + is_padded = is_top_padded or is_bottom_padded or is_left_padded or is_right_padded + if not is_padded: return arr.copy() if copy else arr - + # Crop using regionprops - padding_mask_rp = skimage.measure.regionprops( - skimage.measure.label(~padding_mask) - ) + padding_mask_rp = skimage.measure.regionprops(skimage.measure.label(~padding_mask)) if not padding_mask_rp: return arr.copy() if copy else arr - + padding_mask_obj = padding_mask_rp[0] top, left, bottom, right = padding_mask_obj.bbox - + # Crop cropped_arr = arr[top:bottom, left:right] - + if copy: cropped_arr = cropped_arr.copy() - + return cropped_arr + def snap_xy_to_closest_angle(x0, y0, x1, y1, angle_factor=15): # Snap to closest angle divisible by angle_factor degrees - angle = math.degrees(math.atan2(y1-y0, x1-x0)) + angle = math.degrees(math.atan2(y1 - y0, x1 - x0)) snap_angle = math.radians(core.closest_n_divisible_by_m(angle, angle_factor)) dist = math.dist((x0, y0), (x1, y1)) dx = dist * math.cos(snap_angle) @@ -320,6 +310,7 @@ def snap_xy_to_closest_angle(x0, y0, x1, y1, angle_factor=15): x1, y1 = x0 + dx, y0 + dy return x1, y1 + def correct_img_dimension(image, input_dims: List[str], output_dims: List[str]): """Resort and expand the image to the correct dimensions (output_dims). @@ -343,7 +334,7 @@ def correct_img_dimension(image, input_dims: List[str], output_dims: List[str]): if input_dims == output_dims: return image - + if image.ndim != len(input_dims): raise ValueError( f"Image has {image.ndim} dimensions but expected {len(input_dims)}" @@ -352,16 +343,17 @@ def correct_img_dimension(image, input_dims: List[str], output_dims: List[str]): missing_dims = set(output_dims) - set(input_dims) input_dims = list(input_dims) output_dims = list(output_dims) - + for missing_dim in missing_dims: image = np.expand_dims(image, axis=output_dims.index(missing_dim)) input_dims.insert(output_dims.index(missing_dim), missing_dim) - + dim_map = [input_dims.index(dim) for dim in output_dims] image = np.transpose(image, dim_map) - + return image + def clear_objects_not_in_mask(lab, mask): """Clear objects in lab that are not fully contained in mask. @@ -386,5 +378,5 @@ def clear_objects_not_in_mask(lab, mask): if np.all(mask[obj.slice][obj.image]): continue lab_cleared[obj.slice][obj.image] = 0 - - return lab_cleared \ No newline at end of file + + return lab_cleared diff --git a/cellacdc/urls.py b/cellacdc/urls.py index 094b1cc4d..bcd00a75c 100644 --- a/cellacdc/urls.py +++ b/cellacdc/urls.py @@ -1,29 +1,29 @@ -contribute_url = 'https://github.com/SchmollerLab/Cell_ACDC/blob/main/cellacdc/docs/source/contributing.rst' +contribute_url = "https://github.com/SchmollerLab/Cell_ACDC/blob/main/cellacdc/docs/source/contributing.rst" -github_url = 'https://github.com/SchmollerLab/Cell_ACDC' +github_url = "https://github.com/SchmollerLab/Cell_ACDC" -issues_url = 'https://github.com/SchmollerLab/Cell_ACDC/issues' +issues_url = "https://github.com/SchmollerLab/Cell_ACDC/issues" -forum_url = 'https://github.com/SchmollerLab/Cell_ACDC/discussions' +forum_url = "https://github.com/SchmollerLab/Cell_ACDC/discussions" -resources_url = 'https://github.com/SchmollerLab/Cell_ACDC#resources' +resources_url = "https://github.com/SchmollerLab/Cell_ACDC#resources" -my_contact_url = 'https://www.helmholtz-munich.de/ife/about-us/people/staff-detail/ma/8873/Dr.-Padovani/index.html' +my_contact_url = "https://www.helmholtz-munich.de/ife/about-us/people/staff-detail/ma/8873/Dr.-Padovani/index.html" -user_manual_url = 'https://github.com/SchmollerLab/Cell_ACDC/blob/main/UserManual/Cell-ACDC_User_Manual.pdf' +user_manual_url = "https://github.com/SchmollerLab/Cell_ACDC/blob/main/UserManual/Cell-ACDC_User_Manual.pdf" -cite_url = 'https://bmcbiol.biomedcentral.com/articles/10.1186/s12915-022-01372-6' +cite_url = "https://bmcbiol.biomedcentral.com/articles/10.1186/s12915-022-01372-6" -dataprep_docs = 'https://cell-acdc.readthedocs.io/en/latest/getting-started.html#preparing-data-for-further-analysis-data-prep' +dataprep_docs = "https://cell-acdc.readthedocs.io/en/latest/getting-started.html#preparing-data-for-further-analysis-data-prep" -docs_homepage = 'https://cell-acdc.readthedocs.io/en/latest' +docs_homepage = "https://cell-acdc.readthedocs.io/en/latest" -bioformats_jar_home_url = 'https://downloads.openmicroscopy.org/bio-formats/7.2.0/artifacts/bioformats_package.jar' +bioformats_jar_home_url = "https://downloads.openmicroscopy.org/bio-formats/7.2.0/artifacts/bioformats_package.jar" -bioformats_jar_hmgu_url = 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/NnGCP7nGKHz9Tds/download/bioformats_package.jar' +bioformats_jar_hmgu_url = "https://hmgubox2.helmholtz-muenchen.de/index.php/s/NnGCP7nGKHz9Tds/download/bioformats_package.jar" -bioformats_download_page = 'https://www.openmicroscopy.org/bio-formats/downloads/' +bioformats_download_page = "https://www.openmicroscopy.org/bio-formats/downloads/" -install_pytorch = 'https://pytorch.org/get-started/locally/' +install_pytorch = "https://pytorch.org/get-started/locally/" -fiji_downloads = 'https://imagej.net/software/fiji/downloads' \ No newline at end of file +fiji_downloads = "https://imagej.net/software/fiji/downloads" diff --git a/cellacdc/utils/__init__.py b/cellacdc/utils/__init__.py old mode 100755 new mode 100644 index e69de29bb..54f11cddf --- a/cellacdc/utils/__init__.py +++ b/cellacdc/utils/__init__.py @@ -0,0 +1,520 @@ +"""Cell-ACDC shared helpers (logging, paths, install, models, …).""" + +from .dataframe import ( + are_acdc_dfs_equal, + checked_reset_index, + checked_reset_index_Cell_ID, + df_ctc_to_acdc_df, + fix_acdc_df_dtypes, + format_IDs, + get_cca_colname_desc, +) + +from .install import ( + _apt_install_java_command, + _brew_install_hdf5, + _brew_install_java_command, + _get_pkg_command_pip_install, + _inform_install_package_failed, + _install_deepsea, + _install_homebrew_command, + _install_package_cli_msg, + _install_package_gui_msg, + _install_package_msg, + _install_pip_package, + _install_pytorch_cli, + _install_sam2, + _install_segment_anything, + _install_tensorflow, + _java_exists, + _java_instructions_linux, + _java_instructions_macOS, + _java_instructions_windows, + _warn_dll_torch, + _warn_install_gpu, + check_git_installed, + check_gpu_available, + check_gpu_requested_segm_model, + check_install_baby, + check_install_cellpose, + check_install_cellsam, + check_install_custom_dependencies, + check_install_instanseg, + check_install_microsam, + check_install_nnInteractive, + check_install_omnipose, + check_install_package, + check_install_sam2, + check_install_segment_anything, + check_install_tapir, + check_install_torch, + check_install_trackastra, + check_install_yeaz, + check_upgrade_javabridge, + download_java, + get_java_url, + get_package_info, + get_package_version, + get_pip_conda_prefix, + get_pip_install_cellacdc_version_command, + get_pytorch_command, + get_torch_device, + install_java, + install_javabridge, + install_javabridge_help, + install_javabridge_instructions_text, + install_package_conda, + uninstall_omnipose_acdc, + uninstall_pip_package, + update_editable_package, + update_not_editable_package, + update_package, +) + +from .io import ( + _bytes_to_GB, + _bytes_to_MB, + browse_docs, + browse_url, + getMemoryFootprint, + save_response_content, +) + +from .logging import ( + Logger, + _log_system_info, + delete_older_log_files, + get_logs_path, + log_segm_params, + setupLogger, +) + +from .misc import ( + StdErr, + _apt_gcc_command, + _apt_update_command, + _available_frameworks, + _get_doc_stop_idx, + _init_fiji_cli, + _jdk_exists, + _parse_bool_str, + _relabel_cca_dfs_and_segm_data, + _run_command, + _subprocess_run_command, + addToRecentPaths, + add_segm_data_param, + checkDataIntegrity, + check_napari_plugin, + clipSelemMask, + convert_to_dtype, + cpp_windows_url, + exec_time, + extract_zip, + filterCommonStart, + find_distances_ID, + find_missing_integers, + findalliter, + float_img_to_dtype, + format_cca_manual_changes, + format_commit_date_utc, + from_imagej_rois_to_segm_data, + from_lab_to_imagej_rois, + from_lab_to_obj_coords, + getAcdcDfSegmPaths, + getBaseAcdcDf, + getBasename, + getBasenameAndChNames, + getChannelFilePath, + getCustomAnnotTooltip, + getDefault_SegmInfo_df, + getMostRecentPath, + get_chained_attr, + get_chname_from_basename, + get_confirm_token, + get_empty_stored_data_dict, + get_fiji_base_command, + get_function_argspec, + get_input_output_mapper, + get_linux_distribution_name, + get_module_name, + get_obj_by_label, + get_slices_local_into_global_arr, + get_tiff_metadata, + img_to_float, + import_segment_module, + init_input_points_df, + is_gui_running, + is_in_bounds, + is_iterable, + iterate_along_axes, + jdk_windows_url, + lab2d_to_rois, + pairwise, + purge_module, + remove_known_extension, + reset_settings, + run_fiji_command, + safe_get_or_call, + seconds_to_ETA, + separate_fluo_segment_channels, + setRetainSizePolicy, + showInExplorer, + showUserManual, + sort_IDs_dist, + synthetic_image_geneator, + test_fiji_base_command, + to_tiff, + to_uint16, + to_uint8, + translateStrNone, + try_kwargs, + utilClass, +) + +from .models import ( + _download_cellpose_germlineNuclei_model, + _download_deepsea_models, + _download_omnipose_models, + _download_sam2_models, + _download_segment_anything_models, + _download_tapir_model, + _download_yeaz_models, + _model_url, + _write_model_location_to_txt, + aliases_real_time_trackers, + check_model_exists, + download_bioformats_jar, + download_examples, + download_ffmpeg, + download_fiji, + download_manual, + download_model, + download_url, + getClassArgSpecs, + getModelArgSpec, + getTrackerArgSpec, + get_add_custom_model_instructions, + get_add_custom_prompt_model_instructions, + get_list_of_models, + get_list_of_promptable_models, + get_list_of_real_time_trackers, + get_list_of_trackers, + import_promptable_segment_module, + import_tracker_module, + init_prompt_segm_model, + init_segm_model, + init_tracker, + insertModelArgSpec, + isIntensityImgRequiredForTracker, + params_to_ArgSpec, + parse_model_param_doc, + parse_model_params, + setDefaultValueArgSpecsFromKwargs, + validate_tracker_input, +) + +from .paths import ( + _create_temp_dir, + check_v123_model_path, + determine_folder_type, + get_acdc_data_path, + get_acdc_java_path, + get_examples_path, + get_fiji_binary_filepath_mac, + get_fiji_exec_folderpath, + get_gdrive_path, + get_images_folderpath, + get_model_path, + get_open_filemaneger_os_string, + get_pos_foldernames, + get_pos_status, + get_pos_status_acdc, + get_pos_status_spotmax, + is_old_user_profile_path, + is_pos_folderpath, + listdir, + migrate_to_new_user_profile_path, + store_custom_model_path, + store_custom_promptable_model_path, + to_relative_path, + trim_path, + validate_images_path, +) + +from .qt import ( + get_cli_multi_choice_question, + testQcoreApp, +) + +from .text import ( + append_text_filename, + elided_text, + get_number_fstring_formatter, + get_show_in_file_manager_text, + get_trimmed_dict, + get_trimmed_list, +) + +from .version import ( + _update_repo_with_git_command, + check_cellpose_version, + check_matplotlib_version, + check_pkg_exact_version, + check_pkg_max_version, + check_pkg_version, + get_cellpose_major_version, + get_date_from_version, + get_git_branch_name, + get_git_pull_checkout_cellacdc_version_commands, + get_info_version_text, + get_salute_string, + is_pkg_version_within_range, + is_second_version_greater, + read_version, +) + +__all__ = [ + "are_acdc_dfs_equal", + "checked_reset_index", + "checked_reset_index_Cell_ID", + "df_ctc_to_acdc_df", + "fix_acdc_df_dtypes", + "format_IDs", + "get_cca_colname_desc", + "_apt_install_java_command", + "_brew_install_hdf5", + "_brew_install_java_command", + "_get_pkg_command_pip_install", + "_inform_install_package_failed", + "_install_deepsea", + "_install_homebrew_command", + "_install_package_cli_msg", + "_install_package_gui_msg", + "_install_package_msg", + "_install_pip_package", + "_install_pytorch_cli", + "_install_sam2", + "_install_segment_anything", + "_install_tensorflow", + "_java_exists", + "_java_instructions_linux", + "_java_instructions_macOS", + "_java_instructions_windows", + "_warn_dll_torch", + "_warn_install_gpu", + "check_git_installed", + "check_gpu_available", + "check_gpu_requested_segm_model", + "check_install_baby", + "check_install_cellpose", + "check_install_cellsam", + "check_install_custom_dependencies", + "check_install_instanseg", + "check_install_microsam", + "check_install_nnInteractive", + "check_install_omnipose", + "check_install_package", + "check_install_sam2", + "check_install_segment_anything", + "check_install_tapir", + "check_install_torch", + "check_install_trackastra", + "check_install_yeaz", + "check_upgrade_javabridge", + "download_java", + "get_java_url", + "get_package_info", + "get_package_version", + "get_pip_conda_prefix", + "get_pip_install_cellacdc_version_command", + "get_pytorch_command", + "get_torch_device", + "install_java", + "install_javabridge", + "install_javabridge_help", + "install_javabridge_instructions_text", + "install_package_conda", + "uninstall_omnipose_acdc", + "uninstall_pip_package", + "update_editable_package", + "update_not_editable_package", + "update_package", + "_bytes_to_GB", + "_bytes_to_MB", + "browse_docs", + "browse_url", + "getMemoryFootprint", + "save_response_content", + "Logger", + "_log_system_info", + "delete_older_log_files", + "get_logs_path", + "log_segm_params", + "setupLogger", + "StdErr", + "_apt_gcc_command", + "_apt_update_command", + "_available_frameworks", + "_get_doc_stop_idx", + "_init_fiji_cli", + "_jdk_exists", + "_parse_bool_str", + "_relabel_cca_dfs_and_segm_data", + "_run_command", + "_subprocess_run_command", + "addToRecentPaths", + "add_segm_data_param", + "checkDataIntegrity", + "check_napari_plugin", + "clipSelemMask", + "convert_to_dtype", + "cpp_windows_url", + "exec_time", + "extract_zip", + "filterCommonStart", + "find_distances_ID", + "find_missing_integers", + "findalliter", + "float_img_to_dtype", + "format_cca_manual_changes", + "format_commit_date_utc", + "from_imagej_rois_to_segm_data", + "from_lab_to_imagej_rois", + "from_lab_to_obj_coords", + "getAcdcDfSegmPaths", + "getBaseAcdcDf", + "getBasename", + "getBasenameAndChNames", + "getChannelFilePath", + "getCustomAnnotTooltip", + "getDefault_SegmInfo_df", + "getMostRecentPath", + "get_chained_attr", + "get_chname_from_basename", + "get_confirm_token", + "get_empty_stored_data_dict", + "get_fiji_base_command", + "get_function_argspec", + "get_input_output_mapper", + "get_linux_distribution_name", + "get_module_name", + "get_obj_by_label", + "get_slices_local_into_global_arr", + "get_tiff_metadata", + "img_to_float", + "import_segment_module", + "init_input_points_df", + "is_gui_running", + "is_in_bounds", + "is_iterable", + "iterate_along_axes", + "jdk_windows_url", + "lab2d_to_rois", + "pairwise", + "purge_module", + "remove_known_extension", + "reset_settings", + "run_fiji_command", + "safe_get_or_call", + "seconds_to_ETA", + "separate_fluo_segment_channels", + "setRetainSizePolicy", + "showInExplorer", + "showUserManual", + "sort_IDs_dist", + "synthetic_image_geneator", + "test_fiji_base_command", + "to_tiff", + "to_uint16", + "to_uint8", + "translateStrNone", + "try_kwargs", + "utilClass", + "_download_cellpose_germlineNuclei_model", + "_download_deepsea_models", + "_download_omnipose_models", + "_download_sam2_models", + "_download_segment_anything_models", + "_download_tapir_model", + "_download_yeaz_models", + "_model_url", + "_write_model_location_to_txt", + "aliases_real_time_trackers", + "check_model_exists", + "download_bioformats_jar", + "download_examples", + "download_ffmpeg", + "download_fiji", + "download_manual", + "download_model", + "download_url", + "getClassArgSpecs", + "getModelArgSpec", + "getTrackerArgSpec", + "get_add_custom_model_instructions", + "get_add_custom_prompt_model_instructions", + "get_list_of_models", + "get_list_of_promptable_models", + "get_list_of_real_time_trackers", + "get_list_of_trackers", + "import_promptable_segment_module", + "import_tracker_module", + "init_prompt_segm_model", + "init_segm_model", + "init_tracker", + "insertModelArgSpec", + "isIntensityImgRequiredForTracker", + "params_to_ArgSpec", + "parse_model_param_doc", + "parse_model_params", + "setDefaultValueArgSpecsFromKwargs", + "validate_tracker_input", + "_create_temp_dir", + "check_v123_model_path", + "determine_folder_type", + "get_acdc_data_path", + "get_acdc_java_path", + "get_examples_path", + "get_fiji_binary_filepath_mac", + "get_fiji_exec_folderpath", + "get_gdrive_path", + "get_images_folderpath", + "get_model_path", + "get_open_filemaneger_os_string", + "get_pos_foldernames", + "get_pos_status", + "get_pos_status_acdc", + "get_pos_status_spotmax", + "is_old_user_profile_path", + "is_pos_folderpath", + "listdir", + "migrate_to_new_user_profile_path", + "store_custom_model_path", + "store_custom_promptable_model_path", + "to_relative_path", + "trim_path", + "validate_images_path", + "get_cli_multi_choice_question", + "testQcoreApp", + "append_text_filename", + "elided_text", + "get_number_fstring_formatter", + "get_show_in_file_manager_text", + "get_trimmed_dict", + "get_trimmed_list", + "_update_repo_with_git_command", + "check_cellpose_version", + "check_matplotlib_version", + "check_pkg_exact_version", + "check_pkg_max_version", + "check_pkg_version", + "get_cellpose_major_version", + "get_date_from_version", + "get_git_branch_name", + "get_git_pull_checkout_cellacdc_version_commands", + "get_info_version_text", + "get_salute_string", + "is_pkg_version_within_range", + "is_second_version_greater", + "read_version", +] diff --git a/cellacdc/utils/dataframe.py b/cellacdc/utils/dataframe.py new file mode 100644 index 000000000..c061f8c08 --- /dev/null +++ b/cellacdc/utils/dataframe.py @@ -0,0 +1,352 @@ +"""Cell-ACDC utility helpers: dataframe.""" + +import os +import re +import ast + +import typing +from typing import Literal, List, Callable, Tuple, Dict + +import pathlib +import difflib +import sys +import platform +import tempfile +import shutil +import traceback +import logging +import datetime +import time +import subprocess +import importlib +from uuid import uuid4 +from importlib import import_module +from math import pow, ceil, floor +from functools import wraps, partial +from collections import namedtuple, Counter +from tqdm import tqdm +import requests +import zipfile +import json +import numpy as np +import pandas as pd +import skimage +import inspect + +import traceback +import itertools +from packaging import version as packaging_version + +from natsort import natsorted + +import tifffile +import skimage.io +import skimage.measure + +from .. import GUI_INSTALLED, KNOWN_EXTENSIONS, is_conda_env + +from .. import core, load +from .. import html_utils, is_linux, is_win, is_mac, issues_url, is_mac_arm64 +from .. import cellacdc_path, printl, acdc_fiji_path, logs_path, acdc_ffmpeg_path +from .. import user_profile_path, recentPaths_path +from .. import models_list_file_path, models_path +from .. import promptable_models_list_file_path, promptable_models_path +from .. import github_home_url +from .. import try_input_install_package +from .. import _warnings +from .. import urls +from .. import qrc_resources_path +from .. import settings_folderpath +from ..segmenters._cellpose_base import min_target_versions_cp + +if GUI_INSTALLED: + from qtpy.QtWidgets import QMessageBox + from qtpy.QtCore import Signal, QObject, QCoreApplication + + from .. import widgets, apps + from .. import config + +ArgSpec = namedtuple("ArgSpec", ["name", "default", "type", "desc", "docstring"]) + +def checked_reset_index(df): + if df.index.names is None or df.index.names == [None]: + return df.reset_index(drop=True) + else: + return df.reset_index() + + +def checked_reset_index_Cell_ID(df): + if df.index.names == ["Cell_ID"]: + return df + df = checked_reset_index(df) + return df.set_index("Cell_ID") + + +def get_cca_colname_desc(): + desc = { + "Cell ID": ( + "ID of the segmented cell. All of the other columns " + "are properties of this ID." + ), + "Cell cycle stage": ("G1 if the cell does NOT have a bud. S/G2/M if it does."), + "Relative ID": ( + "ID of the bud related to the Cell ID (row). For cells in G1 write the " + "bud ID it had in the previous cycle." + ), + "Generation number": ( + "Number of times the cell divided from a bud. For cells in the first " + "frame write any number greater than 1." + ), + "Relationship": ( + "Relationship of the current Cell ID (row). " + "Either mother or bud. An object is a bud if " + "it didn't divide from the mother yet. All other instances " + "(e.g., cell in G1) are still labelled as mother." + ), + "Emerging frame num.": ( + "Frame number at which the object emerged/appeared in the scene." + ), + "Division frame num.": ( + "Frame number at which the bud separated from the mother." + ), + "Is history known?": ( + "Cells that are already present in the first frame or appears " + "from outside of the field of view, have some information missing. " + "For example, for cells in the first frame we do not know how many " + "times it budded and divided in the past. " + "In these cases Is history known? is True." + ), + } + return desc + + +def are_acdc_dfs_equal(df_left, df_right): + if df_left.shape != df_right.shape: + return False + + try: + for col in df_left.columns: + if col not in df_right.columns: + return False + + try: + eq_mask = np.isclose(df_left[col], df_right[col], equal_nan=True) + except Exception as err: + # Data type is string + eq_mask = df_left[col] == df_right[col] + + nan_mask = (df_left[col].isna()) & (df_right[col].isna()) + equality_mask = (eq_mask) | (nan_mask) + if not equality_mask.all(): + return False + except Exception as err: + return False + + return True + + +def fix_acdc_df_dtypes(acdc_df): + acdc_df["is_cell_excluded"] = acdc_df["is_cell_excluded"].astype(bool) + return acdc_df + + +def df_ctc_to_acdc_df( + df_ctc, + tracked_segm, + cell_division_mode="Normal", + return_list=False, + progressbar=True, +): + """Convert Cell Tracking Challenge DataFrame with annotated division to + Cell-ACDC cell cycle annotations DataFrame. + + Parameters + ---------- + df_ctc : pd.DataFrame + DataFrame with {'label', 't1', 't2', 'parent'} columns where + 't1' is the frame index of cell division. + tracked_segm : (T, Y, X) array of ints + Array of tracked segmentation labels. + cell_division_mode : {'Normal', 'Asymmetric'}, optional + Type of cell division. `Normal` is the standard cell division, + where the mother cell divides into two daughter cells. For the + tracking, that means the two daughter cells get a new, unique ID + each. + + `Asymmetric` means that the mother cell grows one daughter + cell that eventually divides from the mother (e.g., budding yeast). + For the tracking, this means that the mother cell ID keeps + existing after division and the daughter cell gets a new, unique ID. + + If `Asymmetric`, the third returned element is the segmentation data + with the asymmetric Cell IDs. + return_list : bool, optional + If `True`, the second returned element is the list of created dataframes, + one per frame. Default is False + progressbar : bool, optional + If `True`, displays a tqdm progressbar. Default is True + """ + cca_dfs = [] + keys = [] + df_ctc = df_ctc.set_index(["t1", "parent"]) + + if cell_division_mode == "Asymmetric": + asymm_tracked_segm = tracked_segm.copy() + + asymmetric_IDs_rename_mapper = {} + if progressbar: + pbar = tqdm( + desc="Converting to Cell-ACDC format", total=len(tracked_segm), ncols=100 + ) + for frame_i, lab in enumerate(tracked_segm): + rp = skimage.measure.regionprops(lab) + IDs = [obj.label for obj in rp] + cca_df = core.getBaseCca_df(IDs, with_tree_cols=True) + keys.append(frame_i) + if frame_i == 0: + cca_dfs.append(cca_df) + if progressbar: + pbar.update() + continue + + # Copy annotations from previous frames + prev_cca_df = cca_dfs[frame_i - 1] + old_IDs = cca_df.index.intersection(prev_cca_df.index) + cca_df.loc[old_IDs] = prev_cca_df.loc[old_IDs] + + try: + df_ctc_i = df_ctc.loc[frame_i] + except KeyError as err: + # No division detected --> nothing to annotate + cca_dfs.append(cca_df) + if progressbar: + pbar.update() + continue + + for parent_ID, df_ctc_i_pID in df_ctc_i.groupby(level=0): + daughter_IDs = df_ctc_i_pID["label"].to_list() + + if parent_ID == 0: + continue + + cca_df.loc[daughter_IDs, "parent_ID_tree"] = parent_ID + cca_df.loc[daughter_IDs, "emerg_frame_i"] = frame_i + cca_df.loc[daughter_IDs, "division_frame_i"] = frame_i + + root_ID = prev_cca_df.at[parent_ID, "root_ID_tree"] + if root_ID == -1: + root_ID = parent_ID + cca_df.loc[daughter_IDs, "root_ID_tree"] = root_ID + + cca_df.loc[daughter_IDs[0], "sister_ID_tree"] = daughter_IDs[1] + cca_df.loc[daughter_IDs[1], "sister_ID_tree"] = daughter_IDs[0] + + prev_gen_num = prev_cca_df.loc[parent_ID, "generation_num_tree"] + cca_df.loc[daughter_IDs, "generation_num_tree"] = prev_gen_num + 1 + + # Annotate division from df_ctc_i into + if cell_division_mode == "Asymmetric": + # Recycle the root_ID and assign it to one of the daughters + replaced_daught_ID = daughter_IDs[1] + key = (frame_i, replaced_daught_ID) + asymmetric_IDs_rename_mapper[key] = (root_ID, parent_ID) + + cca_dfs.append(cca_df) + + if progressbar: + pbar.update() + + if progressbar: + pbar.close() + + if asymmetric_IDs_rename_mapper: + _relabel_cca_dfs_and_segm_data( + cca_dfs, + asymmetric_IDs_rename_mapper, + asymm_tracked_segm, + progressbar=True, + ) + + cca_df = pd.concat(cca_dfs, keys=keys, names=["frame_i"]) + + out = [cca_df, None, None] + + if return_list: + out[1] = cca_dfs + + if cell_division_mode == "Asymmetric": + out[2] = asymm_tracked_segm + + return out + + +def format_IDs(IDs): + if isinstance(IDs, str): + raise ValueError("IDs must not be a string") + + IDsRange = [] + text = "" + sorted_vals = sorted(IDs) + for i, e in enumerate(sorted_vals): + e = int(e) + # Get previous and next value (if possible) + if i > 0: + prevVal = sorted_vals[i - 1] + else: + prevVal = -1 + if i < len(sorted_vals) - 1: + nextVal = sorted_vals[i + 1] + else: + nextVal = -1 + + if e - prevVal == 1 or nextVal - e == 1: + if not IDsRange: + if nextVal - e == 1 and e - prevVal != 1: + # Current value is the first value of a new range + IDsRange = [e] + else: + # Current value is the second element of a new range + IDsRange = [prevVal, e] + else: + if e - prevVal == 1: + # Current value is part of an ongoing range + IDsRange.append(e) + else: + # Current value is the first element of a new range + # --> create range text and this element will + # be added to the new range at the next iter + start, stop = IDsRange[0], IDsRange[-1] + if stop - start > 1: + sep = "-" + else: + sep = "," + text = f"{text},{start}{sep}{stop}" + IDsRange = [] + else: + # Current value doesn't belong to a range + if IDsRange: + # There was a range not added to text --> add it now + start, stop = IDsRange[0], IDsRange[-1] + if stop - start > 1: + sep = "-" + else: + sep = "," + text = f"{text},{start}{sep}{stop}" + + text = f"{text},{e}" + IDsRange = [] + + if IDsRange: + # Last range was not added --> add it now + start, stop = IDsRange[0], IDsRange[-1] + text = f"{text},{start}-{stop}" + + text = text[1:] + + return text + +# Sibling imports (deferred to avoid import cycles) +from .misc import ( + _relabel_cca_dfs_and_segm_data, +) + diff --git a/cellacdc/utils/install.py b/cellacdc/utils/install.py new file mode 100644 index 000000000..5944dffa8 --- /dev/null +++ b/cellacdc/utils/install.py @@ -0,0 +1,1724 @@ +"""Cell-ACDC utility helpers: install.""" + +import os +import re +import ast + +import typing +from typing import Literal, List, Callable, Tuple, Dict + +import pathlib +import difflib +import sys +import platform +import tempfile +import shutil +import traceback +import logging +import datetime +import time +import subprocess +import importlib +from uuid import uuid4 +from importlib import import_module +from math import pow, ceil, floor +from functools import wraps, partial +from collections import namedtuple, Counter +from tqdm import tqdm +import requests +import zipfile +import json +import numpy as np +import pandas as pd +import skimage +import inspect + +import traceback +import itertools +from packaging import version as packaging_version + +from natsort import natsorted + +import tifffile +import skimage.io +import skimage.measure + +from .. import GUI_INSTALLED, KNOWN_EXTENSIONS, is_conda_env + +from .. import core, load +from .. import html_utils, is_linux, is_win, is_mac, issues_url, is_mac_arm64 +from .. import cellacdc_path, printl, acdc_fiji_path, logs_path, acdc_ffmpeg_path +from .. import user_profile_path, recentPaths_path +from .. import models_list_file_path, models_path +from .. import promptable_models_list_file_path, promptable_models_path +from .. import github_home_url +from .. import try_input_install_package +from .. import _warnings +from .. import urls +from .. import qrc_resources_path +from .. import settings_folderpath +from ..segmenters._cellpose_base import min_target_versions_cp + +if GUI_INSTALLED: + from qtpy.QtWidgets import QMessageBox + from qtpy.QtCore import Signal, QObject, QCoreApplication + + from .. import widgets, apps + from .. import config + +ArgSpec = namedtuple("ArgSpec", ["name", "default", "type", "desc", "docstring"]) + +def check_git_installed(parent=None): + try: + subprocess.check_call(["git", "--version"], shell=True) + return True + except Exception as e: + print("=" * 20) + traceback.print_exc() + print("=" * 20) + git_url = "https://git-scm.com/book/en/v2/Getting-Started-Installing-Git" + msg = widgets.myMessageBox() + txt = html_utils.paragraph(f""" + In order to install javabridge you first need to install + Git (it was not found).

    + Close Cell-ACDC and follow the instructions + {html_utils.tag("here", f'a href="{git_url}"')}.

    + NOTE: After installing Git you might need to restart the + terminal. + """) + msg.warning(parent, "Git not installed", txt) + return False + + +def install_java(): + try: + subprocess.check_call(["javac", "-version"], shell=True) + return False + except Exception as e: + from . import widgets + + win = widgets.installJavaDialog() + win.exec_() + return win.clickedButton == win.cancelButton + + +def install_javabridge(force_compile=False, attempt_uninstall_first=False): + if attempt_uninstall_first: + try: + subprocess.check_call( + [sys.executable, "-m", "pip", "uninstall", "-y", "javabridge"] + ) + except Exception as e: + pass + if sys.platform.startswith("win"): + if force_compile: + subprocess.check_call( + [ + sys.executable, + "-m", + "pip", + "install", + "-U", + "git+https://github.com/SchmollerLab/python-javabridge-acdc", + ] + ) + else: + subprocess.check_call( + [ + sys.executable, + "-m", + "pip", + "install", + "-U", + "git+https://github.com/SchmollerLab/python-javabridge-windows", + ] + ) + elif is_mac: + subprocess.check_call( + [ + sys.executable, + "-m", + "pip", + "install", + "-U", + "git+https://github.com/SchmollerLab/python-javabridge-acdc", + ] + ) + elif is_linux: + subprocess.check_call( + [ + sys.executable, + "-m", + "pip", + "install", + "-U", + "git+https://github.com/LeeKamentsky/python-javabridge.git@master", + ] + ) + + +def get_java_url(): + is_linux = sys.platform.startswith("linux") + is_mac = sys.platform == "darwin" + is_win = sys.platform.startswith("win") + is_win64 = is_win and (os.environ["PROCESSOR_ARCHITECTURE"] == "AMD64") + + # https://drive.google.com/drive/u/0/folders/1MxhySsxB1aBrqb31QmLfVpq8z1vDyLbo + if is_win64: + os_foldername = "win64" + unzipped_foldername = "java_portable_windows-0.1" + file_size = 214798150 + # url = 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/eMyirTw8qG2wJMt/download/java_portable_windows-0.1.zip' + url = "https://github.com/SchmollerLab/java_portable_windows/archive/refs/tags/v0.1.zip" + elif is_mac: + os_foldername = "macOS" + unzipped_foldername = "java_portable_macos-0.1" + url = "https://github.com/SchmollerLab/java_portable_macos/archive/refs/tags/v0.1.zip" + # url = 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/SjZb8aommXgrECq/download/java_portable_macos-0.1.zip' + file_size = 108478751 + elif is_linux: + os_foldername = "linux" + unzipped_foldername = "java_portable_linux-0.1" + url = "https://github.com/SchmollerLab/java_portable_linux/archive/refs/tags/v0.1.zip" + # url = 'https://hmgubox2.helmholtz-muenchen.de/index.php/s/HjeQagixE2cjbZL/download/java_portable_linux-0.1.zip' + file_size = 92520706 + return url, file_size, os_foldername, unzipped_foldername + + +def get_package_version(import_pkg_name): + import importlib.metadata + + version = importlib.metadata.version(import_pkg_name) + return version + + +def check_upgrade_javabridge(): + try: + version = get_package_version("javabridge") + except Exception as e: + return + patch = int(version.split(".")[2]) + if patch > 18: + return + install_javabridge() + + +def _java_exists(os_foldername): + acdc_java_path, dot_acdc_java_path = get_acdc_java_path() + os_acdc_java_path = os.path.join(acdc_java_path, os_foldername) + if os.path.exists(os_acdc_java_path): + for folder in os.listdir(os_acdc_java_path): + if not folder.startswith("jre"): + continue + dir_path = os.path.join(os_acdc_java_path, folder) + for file in os.listdir(dir_path): + if file == "bin": + return dir_path + + # Some users still has the old .acdc folder --> check + os_dot_acdc_java_path = os.path.join(dot_acdc_java_path, os_foldername) + if os.path.exists(os_dot_acdc_java_path): + for folder in os.listdir(os_dot_acdc_java_path): + if not folder.startswith("jre"): + continue + dir_path = os.path.join(os_dot_acdc_java_path, folder) + for file in os.listdir(dir_path): + if file == "bin": + return dir_path + return "" + + # Check if the user unzipped the javabridge_portable folder and not its content + os_acdc_java_path = os.path.join(acdc_java_path, os_foldername) + if os.path.exists(os_acdc_java_path): + for folder in os.listdir(os_acdc_java_path): + dir_path = os.path.join(os_acdc_java_path, folder) + if folder.startswith("java_portable") and os.path.isdir(dir_path): + # Move files one level up + unzipped_path = os.path.join(os_acdc_java_path, folder) + for name in os.listdir(unzipped_path): + # move files up one level + src = os.path.join(unzipped_path, name) + shutil.move(src, os_acdc_java_path) + try: + shutil.rmtree(unzipped_path) + except PermissionError as e: + pass + # Check if what we moved one level up was actually java + for folder in os.listdir(os_acdc_java_path): + if not folder.startswith("jre"): + continue + dir_path = os.path.join(os_acdc_java_path, folder) + for file in os.listdir(dir_path): + if file == "bin": + return dir_path + return "" + + +def download_java(): + url, file_size, os_foldername, unzipped_foldername = get_java_url() + jre_path = _java_exists(os_foldername) + jdk_path = _jdk_exists(jre_path) + if os_foldername.startswith("win") and jre_path and jdk_path: + return jre_path, jdk_path, url + + if jre_path: + # on macOS jdk is the same as jre + return jre_path, jre_path, url + + acdc_java_path, _ = get_acdc_java_path() + os_acdc_java_path = os.path.join(acdc_java_path, os_foldername) + temp_zip = os.path.join(os_acdc_java_path, "acdc_java_temp.zip") + + if not os.path.exists(os_acdc_java_path): + os.makedirs(os_acdc_java_path, exist_ok=True) + + try: + download_url(url, temp_zip, file_size=file_size, desc="Java") + extract_zip(temp_zip, os_acdc_java_path) + except Exception as e: + print("=======================") + traceback.print_exc() + print("=======================") + finally: + os.remove(temp_zip) + + # Move files one level up + unzipped_path = os.path.join(os_acdc_java_path, unzipped_foldername) + for name in os.listdir(unzipped_path): + # move files up one level + src = os.path.join(unzipped_path, name) + shutil.move(src, os_acdc_java_path) + try: + shutil.rmtree(unzipped_path) + except PermissionError as e: + pass + + jre_path = _java_exists(os_foldername) + jdk_path = _jdk_exists(jre_path) + return jre_path, jdk_path, url + + +def _install_homebrew_command(): + return '/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"' + + +def _brew_install_java_command(): + return "brew install --cask homebrew/cask-versions/adoptopenjdk8" + + +def _brew_install_hdf5(): + return "brew install hdf5" + + +def _apt_install_java_command(): + return "sudo apt-get install openjdk-8-jdk" + + +def _java_instructions_linux(): + s1 = html_utils.paragraph(""" + Run the following commands
    + in the Teminal one by one: + """) + + s2 = html_utils.paragraph(f""" + {_apt_gcc_command().replace(" ", " ")} + """) + + s3 = html_utils.paragraph(f""" + {_apt_update_command().replace(" ", " ")} + """) + + s4 = html_utils.paragraph(f""" + {_apt_install_java_command().replace(" ", " ")} + """) + + s5 = html_utils.paragraph(""" + The first command is used to install GCC, which is needed later.

    + The second and third commands are used is used to install + Java Development Kit 8.

    + Follow the instructions on the terminal to complete + installation.

    + """) + return s1, s2, s3, s4 + + +def _java_instructions_macOS(): + s1 = html_utils.paragraph(""" + Run the following commands
    + in the Teminal one by one: + """) + + s2 = html_utils.paragraph(f""" + {_install_homebrew_command()} + """) + + s3 = html_utils.paragraph(f""" + {_brew_install_java_command().replace(" ", " ")} + """) + + s4 = html_utils.paragraph(""" + The first command is used to install Homebrew
    + a package manager for macOS/Linux.

    + The second command is used to install Java 8.
    + Follow the instructions on the terminal to complete + installation.

    + Alternatively, you can install Java as a regular app
    + by downloading the app from + + here + . + """) + return s1, s2, s3, s4 + + +def _java_instructions_windows(): + jdk_url = f'"{jdk_windows_url()}"' + cpp_url = f'"{cpp_windows_url()}"' + s1 = html_utils.paragraph(""" + Download and install Java Development Kit and
    + Microsoft C++ Build Tools for Windows (links below).

    + IMPORTANT: when installing "Microsoft C++ Build Tools"
    + make sure to select "Desktop development with C++".
    + Click "See the screenshot" for more details.
    + """) + + s2 = html_utils.paragraph(f""" + Java Development Kit: + + here + + """) + + s3 = html_utils.paragraph(f""" + Microsoft C++ Build Tools: + + here + + """) + return s1, s2, s3 + + +def install_javabridge_instructions_text(): + if is_win: + return _java_instructions_windows() + elif is_mac: + return _java_instructions_macOS() + elif is_linux: + return _java_instructions_linux() + + +def install_javabridge_help(parent=None): + msg = widgets.myMessageBox() + txt = html_utils.paragraph(f""" + Cell-ACDC is going to download and install + javabridge.

    + Make sure you have an active internet connection, + before continuing. + Progress will be displayed on the terminal

    + IMPORTANT: If the installation fails, please open an issue + on our + + GitHub page + .

    + Alternatively, you can cancel the process and try later. + """) + msg.setIcon() + msg.setWindowTitle("Installing javabridge") + msg.addText(txt) + msg.addButton(" Ok ") + cancel = msg.addButton(" Cancel ") + msg.exec_() + return msg.clickedButton == cancel + + +def _install_pip_package( + pkg_name: str, + logger: Callable = print, + install_dependencies: bool = True, + force_binary: bool = True, + pref_binary: bool = True, +) -> None: + command = [ + sys.executable, + "-m", + "pip", + "install", + pkg_name, + ] + if force_binary: + command.append("--only-binary=:all:") + elif pref_binary: + command.append("--prefer-binary") + if not install_dependencies: + command.append("--no-deps") + try: + subprocess.check_call(command) + except subprocess.CalledProcessError as e: + if "--only-binary=:all:" in str(e): + logger( + f"Error: {pkg_name} does not have a binary distribution available, trying preferred binary." + ) + _install_pip_package( + pkg_name=pkg_name, + logger=logger, + install_dependencies=install_dependencies, + force_binary=False, + pref_binary=True, + ) + elif "--prefer-binary" in str(e): + logger( + f"Error: {pkg_name} does not have a preferred binary distribution available, trying source." + ) + command.remove("--prefer-binary") + command.append("--no-binary=:all:") + _install_pip_package( + pkg_name=pkg_name, + logger=logger, + install_dependencies=install_dependencies, + force_binary=False, + pref_binary=False, + ) + else: + logger(f"""Error: {pkg_name} installation failed. Please check the error message. This is probably due to the package + not being available for your platform or python version.""") + raise e + + +def uninstall_pip_package(pkg_name): + subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", pkg_name]) + + +def uninstall_omnipose_acdc(): + """Uninstall omnipose-acdc if present. Since v1.5.0 it is not needed.""" + import json + + pip_list_output = subprocess.check_output( + [sys.executable, "-m", "pip", "list", "--format", "json"] + ) + installed_packages = json.loads(pip_list_output) + pkgs_to_uninstall = [] + for package_info in installed_packages: + if package_info["name"] == "omnipose-acdc": + pkgs_to_uninstall.append("omnipose-acdc") + elif package_info["name"] == "cellpose-omni-acdc": + pkgs_to_uninstall.append("cellpose-omni-acdc") + + for pkg_to_uninstall in pkgs_to_uninstall: + uninstall_pip_package(pkg_to_uninstall) + + +def check_install_cellpose( + version: Literal["2.0", "3.0", "4.0", "any"] = "2.0", + version_to_install_if_missing: Literal["2.0", "3.0", "4.0"] = "4.0", +): + if isinstance(version, int): + version = f"{version}.0" + + check_install_torch() + + if version == "any": + try: + from cellpose import models + + return + except Exception as err: + version = version_to_install_if_missing # after this the version will for sure be a valid format and not 'any' + + is_version_correct = check_cellpose_version(version) + if is_version_correct: + return + + major_version = int(version.split(".")[0]) + + next_version = major_version + 1 + + min_version = min_target_versions_cp[str(major_version)] + + check_install_package( + "cellpose", + max_version=f"{next_version}.0", + min_version=min_version, + include_lower_version=True, + ) + + purge_module("cellpose") + + +def check_install_baby(): + check_install_package( + "TensorFlow", + pypi_name="tensorflow", + import_pkg_name="tensorflow", + max_version="2.14", + ) + check_install_package("baby", pypi_name="baby-seg", import_pkg_name="baby") + + +def check_install_nnInteractive(): + check_install_package("huggingface-hub") + check_install_torch() + check_install_package("nnInteractive") + + purge_module("nnInteractive") + + importlib.invalidate_caches() + import nnInteractive + + importlib.reload(nnInteractive) + + +def check_install_microsam(): + check_install_package("micro-sam", pypi_name="micro_sam", installer="conda") + + +def check_install_yeaz(): + check_install_torch() + check_install_package("yeaz") + + +def check_install_segment_anything(): + check_install_torch() + check_install_package("segment_anything") + + +def check_install_sam2(): + check_install_torch() + check_install_package("sam2") + + +def check_install_cellsam(): + check_install_torch() + check_install_package( + "cellSAM", + pypi_name="git+https://github.com/vanvalenlab/cellSAM.git", + import_pkg_name="cellSAM", + note=( + "CellSAM requires a DeepCell access token to download models.\n" + "Set the DEEPCELL_ACCESS_TOKEN environment variable before use.\n" + "Get your token at: https://deepcell.org" + ), + ) + + +def install_package_conda(conda_pkg_name, channel="conda-forge"): + if not is_conda_env(): + raise EnvironmentError("Cell-ACDC is not running in a `conda` environment.") + conda_prefix, pip_prefix = get_pip_conda_prefix() + conda_prefix = re.sub( + r"(-c\sconda-forge\s?|--channel=conda-forge\s?)", f"-c {channel} ", conda_prefix + ) + + command = f"{conda_prefix} -y {conda_pkg_name}" + _subprocess_run_command(command) + + +def check_install_omnipose(): + try: + import_module("omnipose") + return + except ModuleNotFoundError: + pass + + try: + check_install_package("omnipose", pypi_name="omnipose_acdc") + except Exception as err: + install_package_conda("mahotas") + _install_pip_package("omnipose-acdc") + + +def _warn_dll_torch(qparent=None): + msg = widgets.myMessageBox() + txt = html_utils.paragraph(""" + An error message will occur after you close this message.
    + Please save your data and restart Cell-ACDC.
    + Sorry for the inconvenience!
    + This error is not critical for the main functionality of Cell-ACDC, + and only concerns the segmentation model. Your can save your data without + a problem.
    + The specific reason is that PyTorch and QtPy have weird issues with + DLL conflicts. + """) + msg.information( + qparent, + "Please restart Cell-ACDC", + txt, + buttonsTexts=("Ok, I will save my data and restart Cell-ACDC"), + ) + + +def check_install_torch(is_cli=False, caller_name="Cell-ACDC", qparent=None): + try: + import torch + import torchvision + + return + + except OSError as err: + if "dll" in str(err): + _warn_dll_torch(qparent=qparent) + raise err + else: + traceback.print_exc() + except Exception as err: + traceback.print_exc() + + if is_cli: + _install_pytorch_cli(caller_name=caller_name) + return + + win = apps.InstallPyTorchDialog(parent=qparent, caller_name=caller_name) + win.exec_() + if win.cancel: + _warnings.log_pytorch_not_installed() + return + + command = win.command + print(f'Running command: "{command}"') + _run_command(command) + + try: + import torch + except OSError as e: + if "dll" in str(e): + _warn_dll_torch(qparent=qparent) + raise e + + purge_module("torch") + + +def check_install_package( + pkg_name: str, + import_pkg_name: str = "", + pypi_name="", + note="", + parent=None, + raise_on_cancel=True, + logger_func=print, + is_cli=False, + caller_name="Cell-ACDC", + force_upgrade=False, + upgrade=False, + min_version="", + max_version="", + exact_version="", + install_dependencies=True, + return_outcome=False, + installer: Literal["pip", "conda"] = "pip", + include_higher_version: bool = False, + include_lower_version: bool = False, +): + """Try to import a package. If import fails, ask user to install it + automatically. + + Parameters + ---------- + pkg_name : str + The name of the package that is displayed to the user. + import_pkg_name : str, optional + The name of the package as it should be imported (case sensitive). + If empty string, `pkg_name` will be imported instead. Default is '' + pypi_name : str, optional + The name of the package to be installed with pip. + If empty string, `pkg_name` will be installed instead. Default is '' + note : str, optional + Additional text to display to the user. Default is '' + parent : QObject, optional + Calling QtWidget. Default is None + raise_on_cancel : bool, optional + Raise exception if processed cancelled. Default is True + logger_func : callable, optional + Function used to log text. Default is print + is_cli : bool, optional + If True, message will be displayed in the terminal. + If False, message will be displayed in a Qt message box. + Default is False + caller_name : str, optional + Program calling this function. Default is 'Cell-ACDC' + force_upgrade : bool, optional + If True, we force the upgrade even if package is installed. + upgrade : bool, optional + If True, pip will upgrade the package. This value is True if + `force_upgrade` is True. Without min_version and max_version + it will never upgrade or downgrade the package. + min_version : str, optional + If not empty it must be a valid version `major[.minor][.patch]` where + minor and patch are optional. If the installed package is older the + upgrade will be forced. + max_version : str, optional + If not empty it must be a valid version `major[.minor][.patch]` where + minor and patch are optional. If the installed package is newer the + upgrade will be forced. + exact_version : str, optional + If not empty, install this exact version. It must be a valid + `major[.minor][.patch]`. + install_dependencies : bool, optional + If False, the `--no-deps` flag will be added to the pip command. + return_outcome : bool, optional + If True, returns 1 on successfull action + installer : str, optional + Package manager to use to install the package. Either 'pip' or 'conda'. + Default is 'pip' + include_higher_version : bool, optional + If True, if the higher version is installed, it will not be downgraded. + Default is False + include_lower_version : bool, optional + If True, if the lower version is installed, it will not be upgraded. + Default is False + + Raises + ------ + ModuleNotFoundError + Error raised if process is cancelled and `raise_on_cancel=True`. + """ + if not import_pkg_name: + import_pkg_name = pkg_name + + if not is_gui_running(): + is_cli = True + + try: # check_pkg_version and check_pkg_max_version + import_pkg_name = import_pkg_name.replace("-", "_") + import_module(import_pkg_name) + if force_upgrade: + upgrade = True + raise ModuleNotFoundError( + f'User requested to forcefully upgrade the package "{pkg_name}"' + ) + if exact_version: + check_pkg_exact_version(import_pkg_name, exact_version) + if min_version: + check_pkg_version(import_pkg_name, min_version, include_lower_version) + if max_version: + check_pkg_max_version(import_pkg_name, max_version, include_higher_version) + except ModuleNotFoundError: + proceed = _install_package_msg( + pkg_name, + note=note, + parent=parent, + upgrade=upgrade, + is_cli=is_cli, + caller_name=caller_name, + logger_func=logger_func, + pkg_command=pypi_name, + max_version=max_version, + min_version=min_version, + exact_version=exact_version, + installer=installer, + include_higher_version=include_higher_version, + include_lower_version=include_lower_version, + ) + if pypi_name: + pkg_name = pypi_name + if not proceed: + if raise_on_cancel: + raise ModuleNotFoundError(f"User aborted {pkg_name} installation") + else: + return traceback.format_exc() + try: + if pkg_name == "tensorflow": + _install_tensorflow(max_version=max_version, min_version=min_version) + elif pkg_name == "deepsea": + _install_deepsea() + elif pkg_name == "segment_anything": + _install_segment_anything() + elif pkg_name == "sam2": + _install_sam2() + else: + pkg_command = _get_pkg_command_pip_install( + pkg_name, + exact_version=exact_version, + max_version=max_version, + min_version=min_version, + including_higher_version=include_higher_version, + including_lower_version=include_lower_version, + ) + if installer == "pip": + _install_pip_package( + pkg_command, install_dependencies=install_dependencies + ) + else: + install_package_conda(pkg_command) + except Exception as e: + printl(traceback.format_exc()) + _inform_install_package_failed( + pkg_name, parent=parent, do_exit=raise_on_cancel + ) + if return_outcome: + return True + + +def check_install_custom_dependencies(custom_install_requires, *args, **kwargs): + """Used to install a package with custom dependencies, usefull if they have + random pinned versions for their dependencies. + + For *args and **kwargs see `utils.check_install_package`. + + Parameters + ---------- + custom_install_requires : list + list of dependencies. Check either requirements.txt, setup.py, + setup.cfg, pyproject.toml, or any other file that lists the dependencies. + For formatting of the dependencies with min max version, + use _get_pkg_command_pip_install. + """ + kwargs["install_dependencies"] = False + kwargs["return_outcome"] = True + success = check_install_package(*args, **kwargs) + if not success: + return + for pkg_name in custom_install_requires: + _install_pip_package(pkg_name) + + +def _inform_install_package_failed(pkg_name, parent=None, do_exit=True): + conda_prefix, pip_prefix = get_pip_conda_prefix() + + install_command = f"{pip_prefix} --upgrade {pkg_name}" + txt = html_utils.paragraph(f""" + Unfortunately, installation of {pkg_name} returned an error.

    + Try restarting Cell-ACDC. If it doesn't work, + please close Cell-ACDC and, with the acdc environment ACTIVE, + install {pkg_name} manually using the follwing command:

    + {install_command}

    + Thank you for your patience. + """) + msg = widgets.myMessageBox() + msg.critical(parent, f"{pkg_name} installation failed", txt) + print("*" * 50) + print( + f'[ERROR]: Installation of "{pkg_name}" failed. ' + f"Please, close Cell-ACDC and run the command " + f"{pip_prefix} --upgrade {pkg_name}`" + ) + print("^" * 50) + + +def _install_package_msg( + pkg_name, + note="", + parent=None, + upgrade=False, + caller_name="Cell-ACDC", + is_cli=False, + pkg_command="", + logger_func=print, + exact_version="", + max_version="", + min_version="", + installer: Literal["pip", "conda"] = "pip", + include_higher_version: bool = False, + include_lower_version: bool = False, +): + if is_cli: + proceed = _install_package_cli_msg( + pkg_name, + note=note, + upgrade=upgrade, + caller_name=caller_name, + pkg_command=pkg_command, + exact_version=exact_version, + max_version=max_version, + min_version=min_version, + logger_func=logger_func, + installer=installer, + include_higher_version=include_higher_version, + include_lower_version=include_lower_version, + ) + else: + proceed = _install_package_gui_msg( + pkg_name, + note=note, + parent=parent, + upgrade=upgrade, + caller_name=caller_name, + pkg_command=pkg_command, + exact_version=exact_version, + max_version=max_version, + min_version=min_version, + logger_func=logger_func, + installer=installer, + including_higher_version=include_higher_version, + including_lower_version=include_lower_version, + ) + return proceed + + +def _install_pytorch_cli(caller_name="Cell-ACDC", action="install", logger_func=print): + separator = "-" * 60 + txt = ( + f"{separator}\n{caller_name} needs to {action} PyTorch\n\n" + "You can choose to install it now or stop the process and install it " + "later. To install it correctly, we need to know your preferences.\n" + ) + logger_func(txt) + questions = { + "Choose your OS:": ("Windows", "Mac", "Linux"), + "Package manager:": ("Pip"), + "Compute platform:": ( + "CPU", + "CUDA 11.8 (NVIDIA GPU)", + "CUDA 12.1 (NVIDIA GPU)", + ), + } + selected_command = get_pytorch_command() + selected_preferences = [] + for question, choices in questions.items(): + input_txt = get_cli_multi_choice_question(question, choices) + while True: + answer = input(input_txt) + if answer.lower() == "q": + exit("Execution stopped by the user.") + + try: + idx = int(answer) - 1 + if idx >= len(choices): + raise TypeError("Not a valid answer") + except Exception as err: + print("-" * 100) + logger_func( + f'"{answer}" is not a valid answer.' + 'Choose one of the options or "q" to quit.' + ) + print("^" * 100) + continue + + preference = choices[idx] + selected_command = selected_command[preference] + selected_preferences.append(preference) + print("") + break + + print("-" * 100) + selected_preferences = ", ".join(selected_preferences) + logger_func(f"Selected preferences: {selected_preferences}") + print("-" * 100) + logger_func(f"Command:\n\n{selected_command}\n") + while True: + answer = input("Do you want to run the command now ([y]/n)?: ") + if answer.lower() == "n": + exit("Execution stopped by the user.") + + if answer.lower() == "y" or not answer: + break + + print("-" * 100) + print(f'"{answer}" is not a valid answer. Choose "y" for yes or "n" for no.') + print("^" * 100) + + if selected_command.startswith("conda"): + try: + subprocess.check_call([selected_command], shell=True) + except Exception as err: + cmd_list = selected_command.split() + cmd_list = [cmd.strip('"') for cmd in cmd_list] + cmd_list = [cmd.strip("'") for cmd in cmd_list] + cmd_list = [cmd.lstrip(".") for cmd in cmd_list] + subprocess.check_call(cmd_list, shell=True) + else: + cmd_list = selected_command.split()[1:] + cmd_list = [cmd.strip('"') for cmd in cmd_list] + cmd_list = [cmd.strip("'") for cmd in cmd_list] + cmd_list = [cmd.lstrip(".") for cmd in cmd_list] + subprocess.check_call([sys.executable, *cmd_list], shell=True) + + +def _get_pkg_command_pip_install( + pkg_command, + exact_version="", + max_version="", + min_version="", + including_lower_version=False, + including_higher_version=False, +): + if exact_version: + pkg_command = f"{pkg_command}=={exact_version}" + return pkg_command + + if including_higher_version: + sign_max = "<=" + else: + sign_max = "<" + if including_lower_version: + sign_min = ">=" + else: + sign_min = ">" + if min_version: + pkg_command = f"{pkg_command}{sign_min}{min_version}" + if max_version: + pkg_command = f"{pkg_command}," + + if max_version: + pkg_command = f"{pkg_command}{sign_max}{max_version}" + + return pkg_command + + +def _install_package_cli_msg( + pkg_name, + note="", + upgrade=False, + caller_name="Cell-ACDC", + logger_func=print, + pkg_command="", + exact_version="", + max_version="", + min_version="", + installer: Literal["pip", "conda"] = "pip", + include_lower_version=False, + include_higher_version=False, +): + if not pkg_command: + pkg_command = pkg_name + + pkg_command = _get_pkg_command_pip_install( + pkg_command, + exact_version=exact_version, + max_version=max_version, + min_version=min_version, + including_lower_version=include_lower_version, + including_higher_version=include_higher_version, + ) + + if upgrade: + action = "upgrade" + else: + action = "install" + + conda_prefix, pip_prefix = get_pip_conda_prefix() + + if installer == "pip": + install_command = f"{pip_prefix} --upgrade {pkg_command}" + elif installer == "conda": + install_command = f"{conda_prefix} {pkg_command}" + + separator = "-" * 60 + txt = ( + f"{separator}\n{caller_name} needs to {action} {pkg_name}\n\n" + "You can choose to install it now or stop the process and install it " + "later with the following command:\n\n" + f"{install_command}\n" + ) + logger_func(txt) + + while True: + answer = try_input_install_package(pkg_name, install_command) + if not answer or answer.lower() == "y": + return True + + if answer.lower() == "n": + return False + + logger_func( + f'{answer} is not a valid answer. Valid answers are "y" for Yes and ' + '"n" for No.' + ) + + +def _install_package_gui_msg( + pkg_name, + note="", + parent=None, + upgrade=False, + caller_name="Cell-ACDC", + pkg_command="", + logger_func=None, + exact_version="", + max_version="", + min_version="", + including_lower_version=False, + including_higher_version=False, + installer: Literal["pip", "conda"] = "pip", +): + msg = widgets.myMessageBox(parent=parent) + if upgrade: + install_text = "upgrade" + else: + install_text = "install" + if pkg_name == "BayesianTracker": + pkg_name = "btrack" + + if not pkg_command: + pkg_command = pkg_name + + pkg_command = _get_pkg_command_pip_install( + pkg_command, + exact_version=exact_version, + max_version=max_version, + min_version=min_version, + including_lower_version=including_lower_version, + including_higher_version=including_higher_version, + ) + + conda_prefix, pip_prefix = get_pip_conda_prefix() + + if installer == "pip": + command = f"{pip_prefix} --upgrade {pkg_command}" + elif installer == "conda": + command = f"{conda_prefix} {pkg_command}" + + command_html = command.lower().replace("<", "<").replace(">", ">") + + txt = html_utils.paragraph(f""" + {caller_name} is going to download and {install_text} + {pkg_name}.

    + Make sure you have an active internet connection, + before continuing.
    + Progress will be displayed on the terminal

    + You might have to restart {caller_name}.

    + Alternatively, you can cancel the process and try later.

    + To install later, or if the installation fails, run the following + command: + """) + if note: + txt = f"{txt}{note}" + _, okButton = msg.information( + parent, + f"Install {pkg_name}", + txt, + buttonsTexts=("Cancel", "Ok"), + commands=(command_html,), + ) + return msg.clickedButton == okButton + + +def _install_tensorflow(max_version="", min_version=""): + cpu = platform.processor() + pkg_command = _get_pkg_command_pip_install( + "tensorflow", max_version=max_version, min_version=min_version + ) + conda_prefix, pip_prefix = get_pip_conda_prefix() + + if is_mac and cpu == "arm": + args = [f'{conda_prefix} "{pkg_command}"'] + shell = True + else: + args = [sys.executable, "-m", "pip", "install", "-U", pkg_command] + shell = False + subprocess.check_call(args, shell=shell) + + # purge numpy + purge_module("numpy") + + +def _install_segment_anything(): + args = [ + sys.executable, + "-m", + "pip", + "install", + "-U", + "--use-pep517", + "git+https://github.com/facebookresearch/segment-anything.git", + ] + subprocess.check_call(args) + + +def _install_sam2(): + args = [ + sys.executable, + "-m", + "pip", + "install", + "-U", + "--use-pep517", + "git+https://github.com/facebookresearch/sam2.git", + ] + subprocess.check_call(args) + + +def _install_deepsea(): + subprocess.check_call([sys.executable, "-m", "pip", "install", "deepsea"]) + + +def get_pip_conda_prefix(list_return=False): + from .config import parser_args + + try: + cp = parser_args + if cp["install_details"] is not None: + no_cli_install = True + install_details = cp["install_details"] + venv_path = install_details["venv_path"] + conda_path = install_details["conda_path"] + if " " not in conda_path: + conda_path = conda_path.strip('"').strip("'") + else: + no_cli_install = False + except: + no_cli_install = False + pass + + if no_cli_install: + conda_prefix = f"{conda_path} install -y -p {venv_path} -c conda-forge" + exec_path = sys.executable + if " " in exec_path: + exec_path = f'"{exec_path}"' + pip_prefix = f"{exec_path} -m pip install" + else: + conda_prefix = "conda install -y -c conda-forge" + pip_prefix = "pip install" + + pip_list = [sys.executable, "-m", "pip", "install"] + if no_cli_install: + conda_list = [ + conda_path.strip('"').strip("'"), + "install", + "-y", + "-p", + venv_path.strip('"').strip("'"), + "-c", + "conda-forge", + ] + else: + conda_list = ["conda", "install", "-y", "-c", "conda-forge"] + if list_return: + return conda_list, pip_list + else: + return conda_prefix, pip_prefix + + +def _warn_install_gpu(model_name, ask_installs, qparent=None): + + cellpose_cuda_url = ( + r"https://github.com/mouseland/cellpose#gpu-version-cuda-on-windows-or-linux" + ) + torch_cuda_url = r"https://pytorch.org/get-started/locally/" + direct_ml_url = r"https://microsoft.github.io/DirectML/" + torch_directml_url = ( + r"https://learn.microsoft.com/en-us/windows/ai/directml/pytorch-windows" + ) + + cellpose_href = f"{html_utils.href_tag('here', cellpose_cuda_url)}" + torch_href = f"{html_utils.href_tag('here', torch_cuda_url)}" + direct_ml_href = f"{html_utils.href_tag('direct_ml_DirectMLref', direct_ml_url)}" + torch_directml_href = ( + f"{html_utils.href_tag('directml pytorch', torch_directml_url)}" + ) + + conda_prefix, pip_prefix = get_pip_conda_prefix() + + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + txt = html_utils.paragraph(f""" + In order to use {model_name} with the GPU you need + to install a PyTorch version which can use it.
    + We recomment using CUDA over DirectML, but if you are using a Windows + machine with an AMD GPU, you can use DirectML.
    + """) + txt_cuda_title = html_utils.paragraph(f"CUDA", font_size="18px") + + pip_prefix = pip_prefix.replace("install -y", "uninstall") + txt_cuda = html_utils.paragraph(f""" + Check out these instructions {cellpose_href}, and {torch_href}.
    + First, uninstall the CPU version of PyTorch with the following command: + {pip_prefix} uninstall torch +
    Then, install the CUDA version required by your GPU with the following + command (in this case 12.8): + {pip_prefix} torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 +
    + """) + + add_info = html_utils.to_admonition( + f""" + Pleae use the following table to find the correct link for the command. + You can check the highest CUDA
    version supported on your system with the + command nvidia-smi in the terminal.
    + + {html_utils.table_style_header} + + CUDA Version + PyTorch Installation Link + + + CUDA 11.8 + https://download.pytorch.org/whl/cu118 + + + CUDA 12.6 + https://download.pytorch.org/whl/cu126 + + + CUDA 12.8 + https://download.pytorch.org/whl/cu128 + + + """, + "info", + ) + + txt_cuda = f"{txt_cuda}{add_info}" + + txt_directML_title = html_utils.paragraph(f"DirectML", font_size="18px") + txt_directML = html_utils.paragraph(f""" + Check out {direct_ml_href}, and {torch_directml_href} for more info.
    + Only supported on Windows 10/11 with Python 3.8-3.12.
    + Click the Install DirectML button to install DirectML. +

    + """) + + txt_end = html_utils.paragraph(f""" + How do you want to proceed? + """) + + stopButton = widgets.cancelPushButton("Stop the process") + directMLButton = widgets.okPushButton("Install DirectML") + proceedButton = widgets.okPushButton("Proceed without GPU") + + buttons = [stopButton] + + if "cuda" in ask_installs: + txt = f"{txt}{txt_cuda_title}{txt_cuda}" + if "directML" in ask_installs: + txt = f"{txt}{txt_directML_title}{txt_directML}" + buttons.append(directMLButton) + txt = f"{txt}{txt_end}" + buttons.append(proceedButton) + + msg.warning( + qparent, + "PyTorch GPU version not installed", + txt, + buttonsTexts=buttons, + ) + + if msg.cancel: + return False, False + + if msg.clickedButton == directMLButton: + py_ver = sys.version_info + if is_win and py_ver.major == 3 and py_ver.minor < 13: + success = check_install_package( + pkg_name="torch-directml", + import_pkg_name="torch_directml", + pypi_name="torch-directml", + return_outcome=True, + ) + purge_module("torch") + return success, True + else: + msg = widgets.myMessageBox() + msg.warning( + qparent, + "DirectML not supported", + "DirectML is only supported on Python 3.8-3.12 and Windows 10/11", + ) + return False, False + + if msg.clickedButton == stopButton: + return False, False + + if msg.clickedButton == proceedButton: + return True, False + + +def check_gpu_requested_segm_model(init_kwargs): + gpu = init_kwargs.get("gpu", False) + if gpu: + return True + + device_type = init_kwargs.get("device_type", "cpu") + return device_type == "gpu" or device_type == "" + + +def check_gpu_available( + model_name, + use_gpu, + do_not_warn=False, + qparent=None, + cuda=False, + directML=False, + return_available_gpu_type=False, +): + if not use_gpu: + if return_available_gpu_type: + return True, [] + else: + return True + + ask_for_cuda = False + if cuda: + try: + import torch + + if not torch.cuda.is_available(): + ask_for_cuda = True + if not torch.cuda.device_count() > 0: + ask_for_cuda = True + except ModuleNotFoundError: + ask_for_cuda = True + + ask_for_directML = False + if directML: + if is_win: + try: + import torch_directml + + if not torch_directml.is_available(): + ask_for_directML = True + except ModuleNotFoundError: + ask_for_directML = True + + frameworks = _available_frameworks(model_name) + ask_installs = set() if not ask_for_cuda else {"cuda"} + ask_installs.update({"directML"} if ask_for_directML else set()) + framework_available = False + available_frameworks_list = [] + for framework, model_compatible in frameworks.items(): + if not model_compatible: + continue + if framework == "cuda": + import torch + + if not torch.cuda.is_available(): + ask_installs.add("cuda") + elif not torch.cuda.device_count() > 0: + ask_installs.add("cuda") + else: + framework_available = True + available_frameworks_list.append("cuda") + elif framework == "directML": + if is_win: + try: + import torch_directml + + if not torch_directml.is_available(): + ask_installs.add("directML") + else: + framework_available = True + available_frameworks_list.append("directML") + except ModuleNotFoundError: + ask_installs.add("directML") + elif is_mac_arm64: + framework_available = True + break + + if framework_available and not ask_for_cuda and not ask_for_directML: + if return_available_gpu_type: + return True, available_frameworks_list + else: + return True + + elif do_not_warn: + if return_available_gpu_type: + return False, available_frameworks_list + else: + return False + + proceed, directML_installed = _warn_install_gpu( + model_name, ask_installs, qparent=qparent + ) + if return_available_gpu_type: + if directML_installed: + available_frameworks_list.append("directML") + return proceed, available_frameworks_list + else: + return proceed + + +def get_pip_install_cellacdc_version_command(version=None): + conda_prefix, pip_prefix = get_pip_conda_prefix() + + if version is None: + version = read_version() + commit_hash_idx = version.find("+g") + is_dev_version = commit_hash_idx > 0 + if is_dev_version: + commit_hash = version[commit_hash_idx + 2 :].split(".")[0] + command = f'{pip_prefix} --upgrade "git+{github_home_url}.git@{commit_hash}"' + command_github = None + else: + command = f"{pip_prefix} --upgrade cellacdc=={version}" + command_github = f'{pip_prefix} --upgrade "git+{urls.github_url}@{version}"' + return command, command_github + + +def check_install_tapir(): + check_install_package( + "tapnet", pypi_name="git+https://github.com/ElpadoCan/TAPIR.git" + ) + + +def check_install_trackastra(): + check_install_package( + "Trackastra", import_pkg_name="trackastra", pypi_name="trackastra" + ) + + +def get_torch_device(gpu=False): + import torch + + if torch.cuda.is_available() and gpu: + device = torch.device("cuda") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") + return device + + +def check_install_instanseg(): + check_install_package( + pkg_name="InstanSeg", import_pkg_name="instanseg", pypi_name="instanseg-torch" + ) + + +def get_pytorch_command(): + """Get the command to install pytorch CPU or CUDA + + Returns + ------- + dict + Dictionary mapping OS to commands for installing PyTorch + + Notes + ----- + As of Oct 2024, the `pytorch` channel on Anaconda was deprecated. + See here https://github.com/pytorch/pytorch/issues/138506 + """ + conda_prefix, pip_prefix = get_pip_conda_prefix() + + pytorch_commands = { + "Windows": { + # 'Conda': { + # 'CPU': f'{conda_prefix} pytorch torchvision cpuonly -c conda-forge', + # 'CUDA 11.8 (NVIDIA GPU)': f'{conda_prefix} pytorch torchvision pytorch-cuda=11.8 -c conda-forge -c nvidia', + # 'CUDA 12.1 (NVIDIA GPU)': f'{conda_prefix} pytorch torchvision pytorch-cuda=12.1 -c conda-forge -c nvidia' + # }, + "Pip": { + "CPU": f"{pip_prefix} torch torchvision", + "CUDA 11.8 (NVIDIA GPU)": f"{pip_prefix} torch torchvision --index-url https://download.pytorch.org/whl/cu118", + "CUDA 12.1 (NVIDIA GPU)": f"{pip_prefix} torch torchvision --index-url https://download.pytorch.org/whl/cu121", + } + }, + "Mac": { + # 'Conda': { + # 'CPU': f'{conda_prefix} pytorch torchvision cpuonly -c conda-forge', + # 'CUDA 11.8 (NVIDIA GPU)': '[WARNING]: CUDA is not available on MacOS', + # 'CUDA 12.1 (NVIDIA GPU)': '[WARNING]: CUDA is not available on MacOS' + # }, + "Pip": { + "CPU": f"{pip_prefix} torch torchvision", + "CUDA 11.8 (NVIDIA GPU)": "[WARNING]: CUDA is not available on MacOS", + "CUDA 12.1 (NVIDIA GPU)": "[WARNING]: CUDA is not available on MacOS", + } + }, + "Linux": { + # 'Conda': { + # 'CPU': f'{conda_prefix} pytorch torchvision cpuonly -c conda-forge', + # 'CUDA 11.8 (NVIDIA GPU)': f'{conda_prefix} pytorch torchvision pytorch-cuda=11.8 -c conda-forge -c nvidia', + # 'CUDA 12.1 (NVIDIA GPU)': f'{conda_prefix} pytorch torchvision pytorch-cuda=12.1 -c conda-forge -c nvidia' + # }, + "Pip": { + "CPU": f"{pip_prefix} torch torchvision --index-url https://download.pytorch.org/whl/cpu", + "CUDA 11.8 (NVIDIA GPU)": f"{pip_prefix} torch torchvision --index-url https://download.pytorch.org/whl/cu118", + "CUDA 12.1 (NVIDIA GPU)": f"{pip_prefix} torch torchvision", + } + }, + } + + return pytorch_commands + + +def get_package_info(package_name): + try: + result = subprocess.run( + [sys.executable, "-m", "pip", "show", package_name], + capture_output=True, + text=True, + check=True, + ) + + info = {} + for line in result.stdout.split("\n"): + if ":" in line: + key, value = line.split(":", 1) + info[key.strip()] = value.strip() + + # Check if it's editable by looking at the location + location = info.get("Location", "") + editable_location = info.get("Editable project location", "") + + return { + "installed": True, + "editable": bool(editable_location), + "location": location, + "editable_location": editable_location, + } + + except subprocess.CalledProcessError: + return {"installed": False, "editable": False} + + +def update_package(parent, package_name): + package_info = get_package_info(package_name) + if not package_info["installed"]: + printl(f"Package {package_name} is not installed.") + return False + editable = package_info.get("editable", False) + if editable: + return update_editable_package(parent, package_name, package_info) + else: + return update_not_editable_package(package_name, package_info) + + +def update_editable_package(parent, package_name, package_info): + repo_location = package_info.get("editable_location", "") + + if not repo_location or not os.path.exists(repo_location): + print(f"Repository location not found for {package_name}") + return False + + return _update_repo_with_git_command(package_name, repo_location) + + +def update_not_editable_package(package_name, package_info): + """Update a non-editable package using pip""" + try: + _, pip_list = get_pip_conda_prefix(list_return=True) + command = pip_list + ["--upgrade ", package_name] + + print(f"Updating {package_name} using pip...") + result = subprocess.run(command, shell=True, capture_output=True, text=True) + + if result.returncode == 0: + print(f"Successfully updated {package_name}") + return True + else: + print(f"Failed to update {package_name}: {result.stderr}") + return False + + except Exception as e: + print(f"Error updating {package_name}: {e}") + return False + +# Sibling imports (deferred to avoid import cycles) +from .misc import ( + _apt_gcc_command, + _apt_update_command, + _available_frameworks, + _jdk_exists, + _run_command, + _subprocess_run_command, + cpp_windows_url, + extract_zip, + is_gui_running, + jdk_windows_url, + purge_module, +) +from .models import ( + download_url, +) +from .paths import ( + get_acdc_java_path, +) +from .qt import ( + get_cli_multi_choice_question, +) +from .version import ( + _update_repo_with_git_command, + check_cellpose_version, + check_pkg_exact_version, + check_pkg_max_version, + check_pkg_version, + read_version, +) + diff --git a/cellacdc/utils/io.py b/cellacdc/utils/io.py new file mode 100644 index 000000000..8673e30e4 --- /dev/null +++ b/cellacdc/utils/io.py @@ -0,0 +1,130 @@ +"""Cell-ACDC utility helpers: io.""" + +import os +import re +import ast + +import typing +from typing import Literal, List, Callable, Tuple, Dict + +import pathlib +import difflib +import sys +import platform +import tempfile +import shutil +import traceback +import logging +import datetime +import time +import subprocess +import importlib +from uuid import uuid4 +from importlib import import_module +from math import pow, ceil, floor +from functools import wraps, partial +from collections import namedtuple, Counter +from tqdm import tqdm +import requests +import zipfile +import json +import numpy as np +import pandas as pd +import skimage +import inspect + +import traceback +import itertools +from packaging import version as packaging_version + +from natsort import natsorted + +import tifffile +import skimage.io +import skimage.measure + +from .. import GUI_INSTALLED, KNOWN_EXTENSIONS, is_conda_env + +from .. import core, load +from .. import html_utils, is_linux, is_win, is_mac, issues_url, is_mac_arm64 +from .. import cellacdc_path, printl, acdc_fiji_path, logs_path, acdc_ffmpeg_path +from .. import user_profile_path, recentPaths_path +from .. import models_list_file_path, models_path +from .. import promptable_models_list_file_path, promptable_models_path +from .. import github_home_url +from .. import try_input_install_package +from .. import _warnings +from .. import urls +from .. import qrc_resources_path +from .. import settings_folderpath +from ..segmenters._cellpose_base import min_target_versions_cp + +if GUI_INSTALLED: + from qtpy.QtWidgets import QMessageBox + from qtpy.QtCore import Signal, QObject, QCoreApplication + + from .. import widgets, apps + from .. import config + +ArgSpec = namedtuple("ArgSpec", ["name", "default", "type", "desc", "docstring"]) + +def _bytes_to_MB(size_bytes): + factor = pow(2, -20) + size_MB = round(size_bytes * factor) + return size_MB + + +def _bytes_to_GB(size_bytes): + factor = pow(2, -30) + size_GB = round(size_bytes * factor, 2) + return size_GB + + +def getMemoryFootprint(files_list): + required_memory = sum( + [48 if file.endswith(".h5") else os.path.getsize(file) for file in files_list] + ) + return required_memory + + +def browse_url(url): + import webbrowser + + webbrowser.open(url) + + +def browse_docs(): + browse_url(urls.docs_homepage) + + +def save_response_content( + response, destination, file_size=None, model_name="cellpose", progress=None +): + print(f"Downloading {model_name} to: {os.path.dirname(destination)}") + CHUNK_SIZE = 32768 + + # Download to a temp folder in user path + temp_folder = pathlib.Path.home().joinpath(".acdc_temp") + if not os.path.exists(temp_folder): + os.mkdir(temp_folder) + temp_dst = os.path.join(temp_folder, os.path.basename(destination)) + if file_size is not None and progress is not None: + progress.emit(file_size, -1) + pbar = tqdm( + total=file_size, unit="B", unit_scale=True, unit_divisor=1024, ncols=100 + ) + with open(temp_dst, "wb") as f: + for chunk in response.iter_content(CHUNK_SIZE): + if chunk: + f.write(chunk) + pbar.update(len(chunk)) + if progress is not None: + progress.emit(-1, len(chunk)) + pbar.close() + + # Move to destination and delete temp folder + destination_dir = os.path.dirname(destination) + if not os.path.exists(destination_dir): + os.makedirs(destination_dir, exist_ok=True) + shutil.move(temp_dst, destination) + shutil.rmtree(temp_folder) diff --git a/cellacdc/utils/logging.py b/cellacdc/utils/logging.py new file mode 100644 index 000000000..572c242af --- /dev/null +++ b/cellacdc/utils/logging.py @@ -0,0 +1,351 @@ +"""Cell-ACDC utility helpers: logging.""" + +import os +import re +import ast + +import typing +from typing import Literal, List, Callable, Tuple, Dict + +import pathlib +import difflib +import sys +import platform +import tempfile +import shutil +import traceback +import logging +import datetime +import time +import subprocess +import importlib +from uuid import uuid4 +from importlib import import_module +from math import pow, ceil, floor +from functools import wraps, partial +from collections import namedtuple, Counter +from tqdm import tqdm +import requests +import zipfile +import json +import numpy as np +import pandas as pd +import skimage +import inspect + +import traceback +import itertools +from packaging import version as packaging_version + +from natsort import natsorted + +import tifffile +import skimage.io +import skimage.measure + +from .. import GUI_INSTALLED, KNOWN_EXTENSIONS, is_conda_env + +from .. import core, load +from .. import html_utils, is_linux, is_win, is_mac, issues_url, is_mac_arm64 +from .. import cellacdc_path, printl, acdc_fiji_path, logs_path, acdc_ffmpeg_path +from .. import user_profile_path, recentPaths_path +from .. import models_list_file_path, models_path +from .. import promptable_models_list_file_path, promptable_models_path +from .. import github_home_url +from .. import try_input_install_package +from .. import _warnings +from .. import urls +from .. import qrc_resources_path +from .. import settings_folderpath +from ..segmenters._cellpose_base import min_target_versions_cp + +if GUI_INSTALLED: + from qtpy.QtWidgets import QMessageBox + from qtpy.QtCore import Signal, QObject, QCoreApplication + + from .. import widgets, apps + from .. import config + +ArgSpec = namedtuple("ArgSpec", ["name", "default", "type", "desc", "docstring"]) + +def get_logs_path(): + return logs_path + + +class Logger(logging.Logger): + def __init__(self, module="base", name="cellacdc-logger", level=logging.DEBUG): + super().__init__(f"{name}-{module}", level=level) + self._stdout = sys.stdout + self._stderr = StdErr(logger=self) + sys.stderr = self._stderr + self._levelToName = { + 50: "CRITICAL", + 40: "ERROR", + 30: "WARNING", + 20: "INFO", + 10: "DEBUG", + 0: "NOTSET", + } + + def write(self, text, log_to_file=True, write_to_stdout=True): + """Capture print statements, print to terminal and log text to + the open log file + + Parameters + ---------- + text : str + Text to log + log_to_file : bool, optional + If True, call `info` method with `text`. Default is True + """ + if write_to_stdout: + self._stdout.write(text) + + if not log_to_file: + return + + if text == "\n": + return + + if not text: + return + + self.debug(text) + + def close(self): + for handler in self.handlers: + handler.close() + self.removeHandler(handler) + sys.stdout = self._stdout + self._stderr.close() + + def __del__(self): + sys.stdout = self._stdout + self._stderr.close() + + def info(self, text, *args, **kwargs): + super().info(text, *args, **kwargs) + try: + self.write(f"{text}\n", log_to_file=False) + except TypeError: + # Sometimes the logger is patched (e.g., by spotiflow), which + # triggers the TypeError because the patching function does not have + # log_to_file argument + self.write(f"{text}\n") + + def warning(self, text, *args, **kwargs): + super().warning(text, *args, **kwargs) + try: + self.write(f"[WARNING]: {text}\n", log_to_file=False) + except TypeError: + # Sometimes the logger is patched (e.g., by spotiflow), which + # triggers the TypeError because the patching function does not have + # log_to_file argument + self.write(f"[WARNING]: {text}\n") + + def error(self, text, *args, write_traceback=True, **kwargs): + super().error(text, *args, **kwargs) + self.write(traceback.format_exc()) + try: + self.write(f"[ERROR]: {text}\n", log_to_file=False) + except TypeError: + # Sometimes the logger is patched (e.g., by spotiflow), which + # triggers the TypeError because the patching function does not have + # log_to_file argument + self.write(f"[ERROR]: {text}\n") + + def plain(self, text, write_to_stdout=False): + orig_formatters = [handler.formatter for handler in self.handlers] + for handler in self.handlers: + handler.setFormatter(logging.Formatter("%(message)s")) + self.write(text, write_to_stdout=write_to_stdout) + for handler in self.handlers: + handler.setFormatter(orig_formatters.pop(0)) + + def critical(self, text, *args, **kwargs): + super().critical(text, *args, **kwargs) + try: + self.write(f"[CRITICAL]: {text}\n", log_to_file=False) + except TypeError: + # Sometimes the logger is patched (e.g., by spotiflow), which + # triggers the TypeError because the patching function does not have + # log_to_file argument + self.write(f"[CRITICAL]: {text}\n") + + def exception(self, text, *args, write_traceback=True, **kwargs): + super().exception(text, *args, **kwargs) + self.write(traceback.format_exc()) + try: + self.write(f"[ERROR]: {text}\n", log_to_file=False) + except TypeError: + # Sometimes the logger is patched (e.g., by spotiflow), which + # triggers the TypeError because the patching function does not have + # log_to_file argument + self.write(f"[ERROR]: {text}\n") + + def log(self, level, text): + if not isinstance(level, int): + printl(level, text, type(level), type(text), sep="\n") + super().log(level, text) + levelName = self._levelToName.get(level, "INFO") + getattr(self, levelName.lower())(text) + + def flush(self): + self._stdout.flush() + + +def delete_older_log_files(logs_path): + if not os.path.exists(logs_path): + return + + log_files = os.listdir(logs_path) + for log_file in log_files: + if not log_file.endswith(".log"): + continue + + log_filepath = os.path.join(logs_path, log_file) + try: + mtime = os.path.getmtime(log_filepath) + except Exception as err: + continue + + mdatetime = datetime.datetime.fromtimestamp(mtime) + days = (datetime.datetime.now() - mdatetime).days + if days < 7: + continue + + try: + os.remove(log_filepath) + except Exception as err: + continue + + +def _log_system_info(logger, log_path, is_cli=False, also_spotmax=False): + logger.info(f'Initialized log file "{log_path}"') + + info_txt = get_info_version_text(is_cli=is_cli) + + logger.info(info_txt) + + if not also_spotmax: + return + + from spotmax.utils import get_info_version_text as smax_info + + smax_info_txt = smax_info(include_platform=False) + logger.info(smax_info_txt) + + +def setupLogger(module="base", logs_path=None, caller="Cell-ACDC"): + if logs_path is None: + logs_path = get_logs_path() + + logger = Logger(module=module) + sys.stdout = logger + + delete_older_log_files(logs_path) + if not os.path.exists(logs_path): + os.mkdir(logs_path) + + date_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + id = uuid4() + log_filename = f"{date_time}_{module}_{id}_stdout.log" + log_path = os.path.join(logs_path, log_filename) + + output_file_handler = logging.FileHandler(log_path, mode="w") + + # Format your logs (optional) + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s:\n" + "------------------------\n" + "%(message)s\n" + "------------------------\n", + datefmt="%d-%m-%Y, %H:%M:%S", + ) + output_file_handler.setFormatter(formatter) + + logger.addHandler(output_file_handler) + + _log_system_info(logger, log_path, also_spotmax=caller != "Cell-ACDC") + + # if module == 'gui' and GUI_INSTALLED: + # qt_handler = widgets.QtHandler() + # qt_handler.setFormatter(logging.Formatter("%(message)s")) + # logger.addHandler(qt_handler) + + return logger, logs_path, log_path, log_filename + + +def log_segm_params( + model_name, + init_params, + segm_params, + logger_func=print, + preproc_recipe=None, + apply_post_process=False, + standard_postprocess_kwargs=None, + custom_postprocess_features=None, +): + init_params_format = [ + f" * {option} = {value}" for option, value in init_params.items() + ] + init_params_format = "\n".join(init_params_format) + + segm_params_format = [ + f" * {option} = {value}" for option, value in segm_params.items() + ] + segm_params_format = "\n".join(segm_params_format) + + preproc_recipe_format = None + if preproc_recipe is not None: + preproc_recipe_format = [] + for s, step in enumerate(preproc_recipe): + preproc_recipe_format.append(f" * Step {s + 1}") + method = step["method"] + preproc_recipe_format.append(f" - Method: {method}") + for option, value in step["kwargs"].items(): + preproc_recipe_format.append(f" - {option}: {value}") + preproc_recipe_format = "\n".join(preproc_recipe_format) + + standard_postproc_format = None + if apply_post_process and standard_postprocess_kwargs is not None: + standard_postproc_format = [ + f" * {option} = {value}" + for option, value in standard_postprocess_kwargs.items() + ] + standard_postproc_format = "\n".join(standard_postproc_format) + + custom_postproc_format = None + if apply_post_process and custom_postprocess_features is not None: + custom_postproc_format = [ + f" * {feature} = ({low}, {high})" + for feature, (low, high) in custom_postprocess_features.items() + ] + custom_postproc_format = "\n".join(custom_postproc_format) + + separator = "-" * 100 + params_format = ( + f"{separator}\n" + f"Model name: {model_name}\n\n" + "Preprocessing recipe:\n\n" + f"{preproc_recipe_format}\n\n" + "Initialization parameters:\n\n" + f"{init_params_format}\n\n" + "Segmentation parameters:\n\n" + f"{segm_params_format}\n\n" + "Post-processing:\n\n" + f"{standard_postproc_format}\n\n" + "Custom post-processing:\n\n" + f"{custom_postproc_format}\n" + f"{separator}" + ) + logger_func(params_format) + +# Sibling imports (deferred to avoid import cycles) +from .misc import ( + StdErr, +) +from .version import ( + get_info_version_text, +) + diff --git a/cellacdc/utils/misc.py b/cellacdc/utils/misc.py new file mode 100644 index 000000000..c440f36a9 --- /dev/null +++ b/cellacdc/utils/misc.py @@ -0,0 +1,1651 @@ +"""Cell-ACDC utility helpers: misc.""" + +import os +import re +import ast + +import typing +from typing import Literal, List, Callable, Tuple, Dict + +import pathlib +import difflib +import sys +import platform +import tempfile +import shutil +import traceback +import logging +import datetime +import time +import subprocess +import importlib +from uuid import uuid4 +from importlib import import_module +from math import pow, ceil, floor +from functools import wraps, partial +from collections import namedtuple, Counter +from tqdm import tqdm +import requests +import zipfile +import json +import numpy as np +import pandas as pd +import skimage +import inspect + +import traceback +import itertools +from packaging import version as packaging_version + +from natsort import natsorted + +import tifffile +import skimage.io +import skimage.measure + +from .. import GUI_INSTALLED, KNOWN_EXTENSIONS, is_conda_env + +from .. import core, load +from .. import html_utils, is_linux, is_win, is_mac, issues_url, is_mac_arm64 +from .. import cellacdc_path, printl, acdc_fiji_path, logs_path, acdc_ffmpeg_path +from .. import user_profile_path, recentPaths_path +from .. import models_list_file_path, models_path +from .. import promptable_models_list_file_path, promptable_models_path +from .. import github_home_url +from .. import try_input_install_package +from .. import _warnings +from .. import urls +from .. import qrc_resources_path +from .. import settings_folderpath +from ..segmenters._cellpose_base import min_target_versions_cp + +if GUI_INSTALLED: + from qtpy.QtWidgets import QMessageBox + from qtpy.QtCore import Signal, QObject, QCoreApplication + + from .. import widgets, apps + from .. import config + +ArgSpec = namedtuple("ArgSpec", ["name", "default", "type", "desc", "docstring"]) + +def get_module_name(script_file_path): + parts = pathlib.Path(script_file_path).parts + parts = list(parts[parts.index("cellacdc") + 1 :]) + parts[-1] = os.path.splitext(parts[-1])[0] + module = ".".join(parts) + return module + + +def filterCommonStart(images_path): + startNameLen = 6 + ls = listdir(images_path) + if not ls: + return [] + allFilesStartNames = [f[:startNameLen] for f in ls] + mostCommonStart = Counter(allFilesStartNames).most_common(1)[0][0] + commonStartFilenames = [f for f in ls if f.startswith(mostCommonStart)] + return commonStartFilenames + + +def remove_known_extension(name): + for ext in KNOWN_EXTENSIONS: + if name.endswith(ext): + return name[: -len(ext)], ext + + return name, "" + + +def getCustomAnnotTooltip(annotState): + toolTip = ( + f"Name: {annotState['name']}\n\n" + f"Type: {annotState['type']}\n\n" + f"Usage: activate the button and RIGHT-CLICK on cell to annotate\n\n" + f"Description: {annotState['description']}\n\n" + f'SHORTCUT: "{annotState["shortcut"]}"' + ) + return toolTip + + +def is_iterable(item): + try: + iter(item) + return True + except TypeError as e: + return False + + +class utilClass: + pass + + +class StdErr: + def __init__(self, logger: Logger = None): + self._sys_stderr = sys.stderr + self._err_msg_line_buffer = [] + self._logger = logger + + def write(self, text: str): + if text.startswith("Traceback"): + print("-" * 100) + + self._sys_stderr.write(text) + + if not text: + return + + self._err_msg_line_buffer.append(text) + if not text.endswith("\n"): + return + + # If the line ends with a newline, flush the buffer + err_line = "".join(self._err_msg_line_buffer) + if self._logger is not None: + self._logger.plain(err_line, write_to_stdout=False) + else: + print(err_line) + + self._err_msg_line_buffer = [] + + def flush(self): + self._sys_stderr.flush() + + def close(self): + """Close the StdErr stream""" + sys.stderr = self._sys_stderr + + +def getMostRecentPath(): + if os.path.exists(recentPaths_path): + df = pd.read_csv(recentPaths_path, index_col="index") + if "opened_last_on" in df.columns: + df = df.sort_values("opened_last_on", ascending=False) + MostRecentPath = "" + for path in df["path"]: + if os.path.exists(path): + MostRecentPath = path + break + else: + MostRecentPath = "" + return MostRecentPath + + +def addToRecentPaths(exp_path, logger=None): + if not os.path.exists(exp_path): + return + exp_path = exp_path.replace("\\", "/") + if os.path.exists(recentPaths_path): + try: + df = pd.read_csv(recentPaths_path, index_col="index") + recentPaths = df["path"].to_list() + if "opened_last_on" in df.columns: + openedOn = df["opened_last_on"].to_list() + else: + openedOn = [np.nan] * len(recentPaths) + if exp_path in recentPaths: + pop_idx = recentPaths.index(exp_path) + recentPaths.pop(pop_idx) + openedOn.pop(pop_idx) + recentPaths.insert(0, exp_path) + openedOn.insert(0, datetime.datetime.now()) + # Keep max 40 recent paths + if len(recentPaths) > 40: + recentPaths.pop(-1) + openedOn.pop(-1) + except Exception as e: + recentPaths = [exp_path] + openedOn = [datetime.datetime.now()] + else: + recentPaths = [exp_path] + openedOn = [datetime.datetime.now()] + df = pd.DataFrame( + { + "path": recentPaths, + "opened_last_on": pd.Series(openedOn, dtype="datetime64[ns]"), + } + ) + df.index.name = "index" + df.to_csv(recentPaths_path) + + +def checkDataIntegrity(filenames, parent_path, parentQWidget=None): + if not filenames: + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph( + "Cell-ACDC could not find any files in the folder " + f"{parent_path}.

    " + "Please make sure that the folder contains at least one image file.

    " + "Thank you for your patience!" + ) + msg.warning(parentQWidget, "Selected folder is emppty", txt) + raise FileNotFoundError(f"No files found in the folder {parent_path}. ") + + char = filenames[0][:2] + startWithSameChar = all([f.startswith(char) for f in filenames]) + if not startWithSameChar: + msg = widgets.myMessageBox() + txt = html_utils.paragraph( + "Cell-ACDC detected files inside the folder " + "that do not start with the same, common basename.

    " + "To ensure correct loading of the data, the folder where " + "the file(s) is/are should either contain a single image file or" + "only files that start with the same, common basename.

    " + "For example the following filenames:

    " + "F014_s01_phase_contr.tif
    " + "F014_s01_mCitrine.tif

    " + "are named correctly since they all start with the " + 'the common basename "F014_s01_". After the common basename you ' + 'can write whatever text you want. In the example above, "phase_contr" ' + 'and "mCitrine" are the channel names.

    ' + "Data loading may still be successfull, so Cell-ACDC will " + "still try to load data now.
    " + ) + filesFormat = [f" - {file}" for file in filenames] + filesFormat = "\n".join(filesFormat) + detailsText = f"Files present in the folder {parent_path}:\n\n{filesFormat}" + msg.addShowInFileManagerButton(parent_path, txt="Open folder...") + msg.warning( + parentQWidget, + "Data structure compromised", + txt, + detailsText=detailsText, + buttonsTexts=("Cancel", "Ok"), + ) + if msg.cancel: + raise TypeError("Process aborted by the user.") + return False + return True + + +def is_in_bounds(x, y, X, Y): + in_bounds = x >= 0 and x < X and y >= 0 and y < Y + return in_bounds + + +def showInExplorer(path): + if is_mac: + os.system(f'open "{path}"') + elif is_linux: + os.system(f'xdg-open "{path}"') + else: + os.startfile(path) + + +def exec_time(func): + @wraps(func) + def inner_function(self, *args, **kwargs): + t0 = time.perf_counter() + if func.__code__.co_argcount == 1 and func.__defaults__ is None: + result = func(self) + elif func.__code__.co_argcount > 1 and func.__defaults__ is None: + result = func(self, *args) + else: + result = func(self, *args, **kwargs) + t1 = time.perf_counter() + s = f"{func.__name__} execution time = {(t1 - t0) * 1000:.3f} ms" + printl(s, is_decorator=True) + return result + + return inner_function + + +def setRetainSizePolicy(widget, retain=True): + sp = widget.sizePolicy() + sp.setRetainSizeWhenHidden(retain) + widget.setSizePolicy(sp) + + +def getAcdcDfSegmPaths(images_path): + ls = listdir(images_path) + basename = getBasename(ls) + paths = {} + for file in ls: + filePath = os.path.join(images_path, file) + fileName, ext = os.path.splitext(file) + endName = fileName[len(basename) :] + if endName.find("acdc_output") != -1 and ext == ".csv": + info_name = endName.replace("acdc_output", "") + paths.setdefault(info_name, {}) + paths[info_name]["acdc_df_path"] = filePath + paths[info_name]["acdc_df_filename"] = fileName + elif endName.find("segm") != -1 and ext == ".npz": + info_name = endName.replace("segm", "") + paths.setdefault(info_name, {}) + paths[info_name]["segm_path"] = filePath + paths[info_name]["segm_filename"] = fileName + return paths + + +def getChannelFilePath(images_path, chName): + file = "" + alignedFilePath = "" + tifFilePath = "" + h5FilePath = "" + for file in listdir(images_path): + filePath = os.path.join(images_path, file) + if file.endswith(f"{chName}_aligned.npz"): + alignedFilePath = filePath + elif file.endswith(f"{chName}.tif"): + tifFilePath = filePath + elif file.endswith(f"{chName}.h5"): + h5FilePath = filePath + if alignedFilePath: + return alignedFilePath + elif h5FilePath: + return h5FilePath + elif tifFilePath: + return tifFilePath + else: + return "" + + +def get_chname_from_basename(filename, basename, remove_ext=True): + if remove_ext: + filename, ext = os.path.splitext(filename) + chName = filename[len(basename) :] + aligned_idx = chName.find("_aligned") + if aligned_idx != -1: + chName = chName[:aligned_idx] + return chName + + +def getBaseAcdcDf(rp): + zeros_list = [0] * len(rp) + nones_list = [None] * len(rp) + minus1_list = [-1] * len(rp) + IDs = [] + xx_centroid = [] + yy_centroid = [] + zz_centroid = [] + for obj in rp: + xc, yc = obj.centroid[-2:] + IDs.append(obj.label) + xx_centroid.append(xc) + yy_centroid.append(yc) + if len(obj.centroid) == 3: + zc = obj.centroid[0] + zz_centroid.append(zc) + + df = pd.DataFrame( + { + "Cell_ID": IDs, + "is_cell_dead": zeros_list, + "is_cell_excluded": zeros_list, + "x_centroid": xx_centroid, + "y_centroid": yy_centroid, + "was_manually_edited": minus1_list, + } + ).set_index("Cell_ID") + if zz_centroid: + df["z_centroid"] = zz_centroid + + return df + + +def getBasenameAndChNames(images_path, useExt=None): + _tempPosData = utilClass() + _tempPosData.images_path = images_path + load.loadData.getBasenameAndChNames(_tempPosData, useExt=useExt) + return _tempPosData.basename, _tempPosData.chNames + + +def getBasename(files): + basename = files[0] + for file in files: + # Determine the basename based on intersection of all files + _, ext = os.path.splitext(file) + sm = difflib.SequenceMatcher(None, file, basename) + i, j, k = sm.find_longest_match(0, len(file), 0, len(basename)) + basename = file[i : i + k] + return basename + + +def findalliter(patter, string): + """Function used to return all re.findall objects in string""" + m_test = re.findall(r"(\d+)_(.+)", string) + m_iter = [m_test] + while m_test: + m_test = re.findall(r"(\d+)_(.+)", m_test[0][1]) + m_iter.append(m_test) + return m_iter + + +def clipSelemMask(mask, shape, Yc, Xc, copy=True): + if copy: + mask = mask.copy() + + Y, X = shape + h, w = mask.shape + + # Bottom, Left, Top, Right global coordinates of mask + Y0, X0, Y1, X1 = Yc - (h / 2), Xc - (w / 2), Yc + (h / 2), Xc + (w / 2) + mask_limits = [floor(Y0) + 1, floor(X0) + 1, floor(Y1) + 1, floor(X1) + 1] + + if Y0 >= 0 and X0 >= 0 and Y1 <= Y and X1 <= X: + # Mask is withing shape boundaries, no need to clip + ystart, xstart, yend, xend = mask_limits + mask_slice = slice(ystart, yend), slice(xstart, xend) + return mask, mask_slice + + if Y0 < 0: + # Mask is exceeding at the bottom + ystart = floor(abs(Y0)) + mask_limits[0] = 0 + mask = mask[ystart:] + if X0 < 0: + # Mask is exceeding at the left + xstart = floor(abs(X0)) + mask_limits[1] = 0 + mask = mask[:, xstart:] + if Y1 > Y: + # Mask is exceeding at the top + yend = ceil(abs(Y1)) - Y + mask_limits[2] = Y + mask = mask[:-yend] + if X1 > X: + # Mask is exceeding at the right + xend = ceil(abs(X1)) - X + mask_limits[3] = X + mask = mask[:, :-xend] + + ystart, xstart, yend, xend = mask_limits + mask_slice = slice(ystart, yend), slice(xstart, xend) + return mask, mask_slice + + +def get_function_argspec( + function, + args_to_skip={ + "logger_func", + }, +): + argspecs = inspect.getfullargspec(function) + kwargs_type_hints = typing.get_type_hints(function) + docstring = function.__doc__ + params = params_to_ArgSpec( + argspecs, kwargs_type_hints, docstring, args_to_skip=args_to_skip + ) + return params + + +def _get_doc_stop_idx(docstring, start_idx, next_param_name=None, debug=False): + if debug: + import pdb + + pdb.set_trace() + + if next_param_name is not None: + doc_stop_idx = docstring.find(f"{next_param_name} : ") + if doc_stop_idx > 1: + return doc_stop_idx + + docstring_from_start = docstring[start_idx:] + next_param_searched = re.search(r"\w+ : ", docstring_from_start) + if next_param_searched is not None: + return next_param_searched.start(0) + start_idx + + doc_stop_idx = docstring.find("Returns") + if doc_stop_idx > 1: + return doc_stop_idx + + doc_stop_idx = docstring.find("Notes") + if doc_stop_idx > 1: + return doc_stop_idx + + return -1 + + +def add_segm_data_param(init_params, init_argspecs): + if init_argspecs.defaults is None: + num_kwargs = 0 + else: + num_kwargs = len(init_argspecs.defaults) + + # Segm model requires segm data --> add it to params + num_args = len(init_argspecs.args) - num_kwargs + if num_args == 1: + # Args is only self --> segm data not needed + return init_params + + desc = ( + "This model requires an additional segmentation file as input.\n\n" + "Please, select which segmentation file to provide to the model." + ) + + segm_data_argspec = ArgSpec( + name="Auxiliary segmentation file", + default="", + type=str, + desc=desc, + docstring=None, + ) + + init_params.insert(0, segm_data_argspec) + return init_params + + +def getDefault_SegmInfo_df(posData, filename): + mid_slice = int(posData.SizeZ / 2) + df = pd.DataFrame( + { + "filename": [filename] * posData.SizeT, + "frame_i": range(posData.SizeT), + "z_slice_used_dataPrep": [mid_slice] * posData.SizeT, + "which_z_proj": ["single z-slice"] * posData.SizeT, + "z_slice_used_gui": [mid_slice] * posData.SizeT, + "which_z_proj_gui": ["single z-slice"] * posData.SizeT, + "resegmented_in_gui": [False] * posData.SizeT, + "is_from_dataPrep": [False] * posData.SizeT, + } + ).set_index(["filename", "frame_i"]) + return df + + +def _jdk_exists(jre_path): + # If jre_path exists and it's windows search for ~/acdc-java/win64/jdk + # or ~/.acdc-java/win64/jdk. If not Windows return jre_path + if not jre_path: + return "" + os_acdc_java_path = os.path.dirname(jre_path) + os_foldername = os.path.basename(os_acdc_java_path) + if not os_foldername.startswith("win"): + return jre_path + if os.path.exists(os_acdc_java_path): + for folder in os.listdir(os_acdc_java_path): + if not folder.startswith("jdk"): + continue + dir_path = os.path.join(os_acdc_java_path, folder) + for file in os.listdir(dir_path): + if file == "bin": + return dir_path + return "" + + +def showUserManual(): + manual_file_path = download_manual() + showInExplorer(manual_file_path) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith("download_warning"): + return value + return None + + +def extract_zip(zip_path, extract_to_path, verbose=True): + if verbose: + print(f"Extracting to {extract_to_path}...") + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(extract_to_path) + + +def get_tiff_metadata( + image_arr, + SizeT=None, + SizeZ=None, + PhysicalSizeZ=None, + PhysicalSizeX=None, + PhysicalSizeY=None, + TimeIncrement=None, +): + SizeY, SizeX = image_arr.shape[-2:] + Type = str(image_arr.dtype) + + metadata = {"Pixels": {"SizeX": SizeX, "SizeY": SizeY, "Type": Type}} + + axes = "YX" + if SizeZ is not None and SizeZ > 1: + axes = f"Z{axes}" + metadata["Pixels"]["SizeZ"] = SizeZ + + if SizeT is not None and SizeT > 1: + axes = f"T{axes}" + metadata["Pixels"]["SizeT"] = SizeT + + metadata["axes"] = axes + + if PhysicalSizeX is not None: + metadata["Pixels"]["PhysicalSizeX"] = PhysicalSizeX + + if PhysicalSizeY is not None: + metadata["Pixels"]["PhysicalSizeY"] = PhysicalSizeY + + if PhysicalSizeZ is not None: + metadata["Pixels"]["PhysicalSizeZ"] = PhysicalSizeZ + + if TimeIncrement is not None: + metadata["Pixels"]["TimeIncrement"] = TimeIncrement + + return metadata + + +def to_tiff( + new_path, + data, + SizeT=None, + SizeZ=None, + PhysicalSizeZ=None, + PhysicalSizeX=None, + PhysicalSizeY=None, + TimeIncrement=None, +): + valid_dtypes = (np.uint8, np.uint16, np.float32) + is_valid_dtype = False + for valid_dtype in valid_dtypes: + if np.issubdtype(data.dtype, valid_dtype): + is_valid_dtype = True + break + + if not is_valid_dtype: + data = data.astype(np.float32) + + metadata = get_tiff_metadata( + data, + SizeT=SizeT, + SizeZ=SizeZ, + PhysicalSizeZ=PhysicalSizeZ, + PhysicalSizeX=PhysicalSizeX, + PhysicalSizeY=PhysicalSizeY, + TimeIncrement=TimeIncrement, + ) + + # # Potential alternative + # hyperstack = tifffile.memmap( + # new_path, + # shape=img.shape, + # dtype=img.dtype, + # imagej=True, + # metadata={'axes': 'TZYX'}, + # ) + # hyperstack[:] = img + # hyperstack.flush() + + try: + tifffile.imwrite(new_path, data, metadata=metadata, imagej=True) + except Exception as err: + tifffile.imwrite(new_path, data) + + +def from_lab_to_obj_coords(lab): + rp = skimage.measure.regionprops(lab) + dfs = [] + keys = [] + for obj in rp: + keys.append(obj.label) + obj_coords = obj.coords + ndim = obj_coords.shape[1] + if ndim == 3: + columns = ["z", "y", "x"] + else: + columns = ["y", "x"] + df_obj = pd.DataFrame(data=obj_coords, columns=columns) + dfs.append(df_obj) + df = pd.concat(dfs, keys=keys, names=["Cell_ID", "idx"]).droplevel("idx") + return df + + +def lab2d_to_rois(ImagejRoi, lab2D, ndigits, t=None, z=None): + rp = skimage.measure.regionprops(lab2D) + rois = [] + for obj in rp: + cont = core.get_obj_contours(obj) + yc, xc = obj.centroid + x_str = str((int(xc))).zfill(ndigits) + y_str = str((int(yc))).zfill(ndigits) + name = f"{x_str}-{y_str}" + if z is not None: + z_str = str(z).zfill(ndigits) + name = f"{z_str}-{name}" + + if t is not None: + t_str = str(t).zfill(ndigits) + name = f"{t_str}-{name}" + + name = f"id={obj.label}-{name}" + + roi = ImagejRoi.frompoints(cont, name=name, t=t, z=z, index=obj.label) + rois.append(roi) + return rois + + +def from_lab_to_imagej_rois(lab, ImagejRoi, t=0, SizeT=1, max_ID=None): + if max_ID is None: + max_ID = lab.max() + + if SizeT == 1: + t = None + + SizeY, SizeX = lab.shape[-2:] + ndigitsT = len(str(SizeT)) + ndigitsY = len(str(SizeY)) + ndigitsX = len(str(SizeX)) + + if lab.ndim == 3: + rois = [] + SizeZ = len(lab) + ndigitsZ = len(str(SizeZ)) + ndigits = max(ndigitsT, ndigitsZ, ndigitsY, ndigitsX) + for z, lab2D in enumerate(lab): + z_rois = lab2d_to_rois(ImagejRoi, lab2D, ndigits, t=t, z=z) + rois.extend(z_rois) + else: + ndigits = max(ndigitsT, ndigitsY, ndigitsX) + rois = lab2d_to_rois(ImagejRoi, lab, ndigits, t=t) + return rois + + +def from_imagej_rois_to_segm_data( + TZYX_shape, ID_to_roi_mapper, rescale_rois_sizes, repeat_2d_rois_zslices_range +): + SizeT, SizeZ, SizeY, SizeX = TZYX_shape + segm_data = np.zeros(TZYX_shape, dtype=np.uint32) + for ID, roi in ID_to_roi_mapper.items(): + name = roi.name + name_parts = name.split("-") + zz = [0] + if len(name_parts) == 2 and SizeZ > 1: + # 2D roi in 3D segm data --> place 2D roi on each z-slice + zz = range(*repeat_2d_rois_zslices_range) + + elif len(name_parts) > 2 and SizeZ > 1: + # 2D roi from a 3D roi --> place at requested z-slice + zz = [int(name_parts[-3])] + + tt = [0] * len(zz) + if SizeT > 1: + tt = [roi.t_position] * len(zz) + + y0, x0 = roi.top, roi.left + contours = roi.integer_coordinates + (x0, y0) + xx = contours[:, 0] + yy = contours[:, 1] + if rescale_rois_sizes is not None: + rescale_z = rescale_rois_sizes["Z"] + rescale_y = rescale_rois_sizes["Y"] + rescale_x = rescale_rois_sizes["X"] + + factor_z = rescale_z[1] / rescale_z[0] + factor_y = rescale_y[1] / rescale_y[0] + factor_x = rescale_x[1] / rescale_x[0] + + xx = np.clip(np.round(xx * factor_x).astype(int), 0, SizeX - 1) + yy = np.clip(np.round(yy * factor_y).astype(int), 0, SizeY - 1) + + for t, z in zip(tt, zz): + if rescale_rois_sizes is not None: + z = round(z * factor_z) + z = z if z < SizeZ else SizeZ + z = z if z >= 0 else 0 + + rr, cc = skimage.draw.polygon(yy, xx) + segm_data[t, z, rr, cc] = ID + + return np.squeeze(segm_data) + + +def seconds_to_ETA(seconds): + seconds = round(seconds) + ETA = datetime.timedelta(seconds=seconds) + ETA_split = str(ETA).split(":") + if seconds < 0: + ETA = "00h:00m:00s" + elif seconds >= 86400: + days, hhmmss = str(ETA).split(",") + h, m, s = hhmmss.split(":") + ETA = f"{days}, {int(h):02}h:{int(m):02}m:{int(s):02}s" + else: + h, m, s = str(ETA).split(":") + ETA = f"{int(h):02}h:{int(m):02}m:{int(s):02}s" + return ETA + + +def to_uint8(img): + if img.dtype == np.uint8: + return img + img = np.round(img_to_float(img) * 255).astype(np.uint8) + return img + + +def to_uint16(img): + if img.dtype == np.uint16: + return img + img = np.round(img_to_float(img) * 65535).astype(np.uint16) + return img + + +def img_to_float(img, force_dtype=None, force_missing_dtype=None, warn=True): + input_img_dtype = img.dtype + value = img[(0,) * img.ndim] + img_max = np.max(img) + # Check if float outside of -1, 1 + if img_max <= 1.0 and isinstance(value, (np.floating, float)): + return img + + uint8_max = np.iinfo(np.uint8).max + uint16_max = np.iinfo(np.uint16).max + uint32_max = np.iinfo(np.uint32).max + + img = img.astype(float) + if force_dtype is not None: + dtype_max = np.iinfo(force_dtype).max + img = img / dtype_max + elif input_img_dtype == np.uint8: + # Input image is 8-bit + img = img / uint8_max + elif input_img_dtype == np.uint16: + # Input image is 16-bit + img = img / uint16_max + elif input_img_dtype == np.uint32: + # Input image is 32-bit + img = img / uint32_max + elif force_missing_dtype is not None: + img = img.astype(force_dtype) + elif img_max <= uint8_max: + # Input image is probably 8-bit + if warn: + _warnings.warn_image_overflow_dtype(input_img_dtype, img_max, "8-bit") + img = img / uint8_max + elif img_max <= uint16_max: + # Input image is probably 16-bit + if warn: + _warnings.warn_image_overflow_dtype(input_img_dtype, img_max, "16-bit") + img = img / uint16_max + elif img_max <= uint32_max: + # Input image is probably 32-bit + if warn: + _warnings.warn_image_overflow_dtype(input_img_dtype, img_max, "32-bit") + img = img / uint32_max + else: + # Input image is a non-supported data type + raise TypeError( + f"The maximum value in the image is {img_max} which is greater than the " + f"maximum value supported of {uint32_max} (32-bit). " + "Please consider converting your images to 32-bit or 16-bit first." + ) + return img + + +def float_img_to_dtype(img, dtype): + if img.dtype == dtype: + return img + + img_max = img.max() + if img_max > 1.0: + raise TypeError( + "Images of float data type with values greater than 1.0 cannot " + f"be safely casted to {dtype}. " + f"The max value of the input image is {img_max:.3f}" + ) + + img_min = img.min() + if img_min < -1.0: + raise TypeError( + "Images of float data type with values smaller than -1.0 cannot " + f"be safely casted to {dtype}." + f"The minumum value of the input image is {img_min:.3f}" + ) + + if dtype == np.uint8: + return skimage.img_as_ubyte(img) + + if dtype == np.uint16: + return skimage.img_as_uint(img) + + if dtype == np.float32: + return img.astype(np.float32) + + if dtype == np.float64: + return img.astype(np.float64) + + raise TypeError( + f"Invalid output data type `{dtype}`. " + "Valid output data types are `np.uint8` and `np.uint16`" + ) + + +def convert_to_dtype(data: np.ndarray, dtype): + if data.dtype == dtype: + return data + val = data[tuple([0] * data.ndim)] + if isinstance(val, (np.floating, float)): + data = float_img_to_dtype(data, dtype) + elif dtype == np.uint8: + data = np.round(img_to_float(data) * 255).astype(np.uint8) + elif dtype == np.uint16: + data = np.round(img_to_float(data) * 65535).astype(np.uint16) + else: + raise TypeError( + f"Invalid output data type `{dtype}`. " + "Valid data types are floating-point format, `np.uint8` " + "and `np.uint16`" + ) + return data + + +def _apt_update_command(): + return "sudo apt-get update" + + +def _apt_gcc_command(): + return "sudo apt install python-dev gcc" + + +def jdk_windows_url(): + return "https://hmgubox2.helmholtz-muenchen.de/index.php/s/R62Ktcda6jWea2s" + + +def cpp_windows_url(): + return "https://visualstudio.microsoft.com/visual-cpp-build-tools/" + + +def check_napari_plugin(plugin_name, module_name, parent=None): + try: + import_module(module_name) + except ModuleNotFoundError as e: + url = "https://napari.org/stable/plugins/find_and_install_plugin.html#find-and-install-plugins" + href = html_utils.href_tag("this guide", url) + txt = html_utils.paragraph(f""" + To correctly use this napari utility you need to install the + plugin called {plugin_name}.

    + Please, read {href} on how to install plugins in napari.

    + You will need to restart both napari and Cell-ACDC after installing + the plugin.

    + NOTE: in the text box in napari you will need to write the full name + {plugin_name} becasue it is NOT A SEARCH BOX. + """) + msg = widgets.myMessageBox() + msg.critical(parent, f"Napari plugin required", txt) + raise e + + +def purge_module(module_name): + to_delete = [ + mod + for mod in sys.modules + if mod == module_name or mod.startswith(module_name + ".") + ] + for mod in to_delete: + del sys.modules[mod] + + importlib.invalidate_caches() + importlib.import_module(module_name) + if module_name in sys.modules: + importlib.reload(sys.modules[module_name]) + else: + raise ModuleNotFoundError(f"Module '{module_name}' not found in sys.modules.") + + +def is_gui_running(): + if not GUI_INSTALLED: + return False + + return QCoreApplication.instance() is not None + + +def _subprocess_run_command(command, shell=True, callback="check_call"): + func = getattr(subprocess, callback) + try: + out = func(command, shell=shell) + except Exception as err: + print( + f"[WARNING]: Command `{command}` failed. Trying with `{command.split()}`..." + ) + out = func(command.split(), shell=shell) + + return out + + +def _run_command(command: str | list[str], shell=False): + if not isinstance(command, (str, list)): + raise TypeError( + f"Command must be a string or a list of strings, not {type(command)}" + ) + + command_str = None + if isinstance(command, str): + args_list = [command] + command_str = command + else: + args_list = command + if len(command) == 1: + command_str = command[0] + + try: + subprocess.check_call(args_list, shell=shell) + return + except Exception as err: + pass + + if command_str is None: + return + + try: + subprocess.check_call(command_str, shell=shell) + return + except Exception as err: + pass + + try: + from . import acdc_regex + + args = acdc_regex.RE_SPLIT_SPACES_IGNORE_QUOTES.split(command_str)[1::2] + subprocess.check_call(args, shell=shell) + return + except Exception as err: + pass + + +def get_chained_attr(_object, _name): + for attr in _name.split("."): + _object = getattr(_object, attr) + return _object + + +def get_fiji_base_command(): + command = None + if is_mac: + command = get_fiji_exec_folderpath() + + return command + + +def _init_fiji_cli(): + if is_win: + return True + + fiji_app_folderpath = get_fiji_exec_folderpath() + args_add_to_path = [f"chmod 755 {fiji_app_folderpath}"] + try: + subprocess.check_call(args_add_to_path, shell=True) + return True + except Exception as e: + printl(f"Error occurred while setting permissions: {e}") + return False + + +def test_fiji_base_command(logger_func=print): + base_command = get_fiji_base_command() + + if base_command is None: + logger_func("[WARNING]: Fiji is not present.") + return False + + command = f"{base_command} --headless" + return run_fiji_command(command=command, logger_func=logger_func) + + +def run_fiji_command(command=None, logger_func=print): + if command is None: + command = f"{get_fiji_base_command()} --headless" + + init_success = _init_fiji_cli() + if not init_success: + return False + + separator = "-" * 100 + commands = (command, command.split()) + for args in commands: + logger_func(f'{separator}\nTrying Fiji command: "{args}"...\n{separator}\n') + try: + subprocess.check_call(args, shell=True) + return True + except Exception as err: + continue + return False + + +def import_segment_module(model_name): + try: + acdcSegment = import_module(f"cellacdc.segmenters.{model_name}.acdcSegment") + except ModuleNotFoundError as e: + # Check if custom model + cp = config.ConfigParser() + cp.read(models_list_file_path) + model_path = cp[model_name]["path"] + spec = importlib.util.spec_from_file_location("acdcSegment", model_path) + acdcSegment = importlib.util.module_from_spec(spec) + sys.modules["acdcSegment"] = acdcSegment + spec.loader.exec_module(acdcSegment) + return acdcSegment + + +def _available_frameworks(model_name): + frameworks = { + "cuda": ( + model_name.lower().find("cellpose") != -1 + or model_name.lower().find("omnipose") != -1 + or model_name.lower().find("deepsea") != -1 + or model_name.lower().find("segment_anything") != -1 + or model_name.lower().find("sam2") != -1 + or model_name.lower().find("yeaz") != -1 + or model_name.lower().find("yeaz_v2") != -1 + ), + "directML": ( + model_name.lower().find("cellpose_v4") != -1 + or model_name.lower().find("cellpose_v3") != -1 # has its own way to check + ), + } + return frameworks + + +def find_missing_integers(lst, max_range=None): + if max_range is not None: + max_range = lst[-1] + 1 + return [x for x in range(lst[0], max_range) if x not in lst] + + +def synthetic_image_geneator(size=(512, 512), f_x=1, f_y=1): + Y, X = size + x = np.linspace(0, 10, Y) + y = np.linspace(0, 10, X) + xx, yy = np.meshgrid(x, y) + img = np.sin(f_x * xx) * np.cos(f_y * yy) + return img + + +def get_slices_local_into_global_arr(bbox_coords, global_shape): + slice_global_to_local = [] + slice_crop_local = [] + for (_min, _max), _D in zip(bbox_coords, global_shape): + _min_crop, _max_crop = None, None + if _min < 0: + _min_crop = abs(_min) + _min = 0 + if _max > _D: + _max_crop = _D - _max + _max = _D + + slice_global_to_local.append(slice(_min, _max)) + slice_crop_local.append(slice(_min_crop, _max_crop)) + + return tuple(slice_global_to_local), tuple(slice_crop_local) + + +def format_cca_manual_changes(changes: dict): + txt = "" + for ID, changes_ID in changes.items(): + txt = f"{txt}* ID {ID}:\n" + for col, (old_val, new_val) in changes_ID.items(): + txt = f"{txt} - {col}: {old_val} --> {new_val}\n" + txt = f"{txt}--------------------------------\n\n" + return txt + + +def _parse_bool_str(value): + if isinstance(value, bool): + return value + + if value == "True": + return True + elif value == "False": + return False + + +def init_input_points_df(posData, input_points_filepath): + input_points_df = None + if os.path.exists(input_points_filepath): + input_points_df = pd.read_csv(input_points_filepath) + else: + # input_points_filepath is actually and endname + for file in listdir(posData.images_path): + if file.endswith(input_points_filepath): + filepath = os.path.join(posData.images_path, file) + input_points_df = pd.read_csv(filepath) + break + + if input_points_df is None: + raise FileNotFoundError( + f'Could not find input points table from file "input_points_filepath" ' + "Perhaps, you forgot to save the table?" + ) + + for col in ("x", "y", "id"): + if col not in input_points_df.columns: + raise KeyError( + f"Input points table is missing colum {col}. It must have " + "the colums (x, y, id)" + ) + + return input_points_df + + +def pairwise(iterable): + # pairwise('ABCDEFG') → AB BC CD DE EF FG + iterator = iter(iterable) + a = next(iterator, None) + for b in iterator: + yield a, b + a = b + + +def _relabel_cca_dfs_and_segm_data( + cca_dfs, + IDs_mapper, + asymm_tracked_segm, + progressbar=True, +): + # Rename Cell_ID index according to asymmetric cell div convention + if progressbar: + pbar = tqdm( + desc="Applying asymmetric division", total=len(IDs_mapper), ncols=100 + ) + for key, (root_ID, parent_ID) in IDs_mapper.items(): + div_frame_i, daughter_ID = key + for frame_i in range(div_frame_i, len(asymm_tracked_segm)): + lab = asymm_tracked_segm[frame_i] + rp = skimage.measure.regionprops(lab) + rp_mapper = {obj.label: obj for obj in rp} + obj_daught = rp_mapper.get(daughter_ID) + mother_ID = root_ID if rp_mapper.get(root_ID) is None else parent_ID + + cca_dfs[frame_i].rename(index={daughter_ID: mother_ID}, inplace=True) + + if obj_daught is None: + continue + + lab[obj_daught.slice][obj_daught.image] = mother_ID + + if progressbar: + pbar.update() + + if progressbar: + pbar.close() + + +def get_empty_stored_data_dict(): + return { + "regionprops": None, + "labels": None, + "acdc_df": None, + "delROIs_info": {"rois": [], "delMasks": [], "delIDsROI": [], "state": []}, + "IDs": [], + "manually_edited_lab": {"lab": {}, "zoom_slice": None}, + } + + +def iterate_along_axes(arr, axes, arr_ndim=None): + if arr_ndim is None: + arr_ndim = arr.ndim + axes = list(axes) + front_axes = axes + [i for i in range(arr_ndim) if i not in axes] + arr_moved = np.moveaxis(arr, front_axes, range(arr_ndim)) + iter_shape = arr_moved.shape[: len(axes)] + for idx in np.ndindex(iter_shape): + # Build the index for the original array + full_idx = [slice(None)] * arr_ndim + for axis, i in zip(axes, idx): + full_idx[axis] = i + yield tuple(full_idx) + + +def get_input_output_mapper( + input_shape: Tuple[int], + iterate_axes: Tuple[int], + output_shape: Tuple[int], + output_axes: Tuple[int], +) -> List[Tuple[Tuple[int, ...], Tuple[int, ...]]]: + """Creates list of tuples with the input and output indices + + Parameters + ---------- + input_shape : Tuple[int] + Shape of the input array + iterate_axes : Tuple[int] + Axes to iterate over + output_shape : Tuple[int] + Shape of the output array + output_axes : Tuple[int] + Axes of the output array + """ + assert len(iterate_axes) == len(output_axes) + + iterate_shape = tuple(input_shape[axis] for axis in iterate_axes) + mapper = [] + + for idx_vals in itertools.product(*[range(s) for s in iterate_shape]): + # Build full input index + input_index = [slice(None)] * len(input_shape) + for axis in iterate_axes: + i = iterate_axes.index(axis) + input_index[axis] = idx_vals[i] + + # Build full output index + output_index = [slice(None)] * len(output_shape) + for axis in output_axes: + i = output_axes.index(axis) + output_index[axis] = idx_vals[i] + + input_index = tuple(input_index) + output_index = tuple(output_index) + + mapper.append((input_index, output_index)) + + return mapper + + +def translateStrNone(*args): + args = list(args) + for i, arg in enumerate(args): + if isinstance(arg, str): + if arg.lower() == "none": + args[i] = None + elif arg.lower() == "true": + args[i] = True + elif arg.lower() == "false": + args[i] = False + + return args + + +def try_kwargs(func, *args, **kwargs): + """ + Attempt to call a function with the provided arguments and keyword arguments. + + If the function raises a TypeError due to unexpected keyword arguments, + those arguments are dynamically removed, and the function is retried. + This process continues until the function succeeds or no keyword arguments + remain, in which case the exception is re-raised. + + Args: + func (Callable): The function to call. + *args: Positional arguments to pass to the function. + **kwargs: Keyword arguments to pass to the function. + + Returns: + Tuple[Any, List[str]]: A tuple containing: + - The result of the function call (or None if it fails). + - A list of keyword arguments that were removed. + + Raises: + ValueError: If a keyword argument mentioned in the error message + is not found in the provided kwargs. + TypeError: If the function fails with a TypeError after all keyword + arguments have been removed. + """ + + kwargs = kwargs.copy() # Create a copy to avoid modifying the original + removed_kwargs = [] + pattern = r"unexpected keyword argument ['\"](\w+)['\"]" + while True: + try: + return func(*args, **kwargs), removed_kwargs + except TypeError as e: + match = re.search(pattern, str(e)) + if match: + kwarg_name = match.group(1) + if kwarg_name in kwargs: + del kwargs[kwarg_name] + removed_kwargs.append(kwarg_name) + else: + raise ValueError( + f"Keyword argument '{kwarg_name}' not found in kwargs." + ) + else: + raise e + + if len(kwargs) == 0: + print(f"Function {func.__name__} failed with TypeError: {e}") + raise e + + +def get_obj_by_label(rp, target_label): + """ + Returns the object with the specified label from the given list of objects. + + Parameters + ---------- + rp : list + The list of objects to search through. + target_label : str + The label of the object to find. + + Returns + ------- + object + The object with the specified label, or None if not found. + """ + for obj in rp: + if obj.label == target_label: + return obj + return None + + +def find_distances_ID(rps, point=None, ID=None): + """ + Calculate the distances between a given point and the centroids of a list of regionprops. + + Parameters + ---------- + rps : list + List of regionprops objects. + point : tuple, optional + The coordinates of the point. Defaults to None. + ID : int, optional + The label ID of the regionprops object. Defaults to None. + + Returns + ------- + numpy.ndarray + A matrix of distances between the point and the centroids. + + Raises + ------ + ValueError + If ID is not found in the list of regionprops (list of cells). + ValueError + If neither ID nor point is provided. + ValueError + If both ID and point are provided. + """ + + if ID is not None and point is None: + try: + point = [rp.centroid for rp in rps if rp.label == ID][0] + except IndexError: + raise ValueError(f"ID {ID} not found in regionprops (list of cells).") + + elif ID is None and point is None: + raise ValueError("Either ID or point must be provided.") + + elif ID is not None and point is not None: + raise ValueError("Only one of ID or point must be provided.") + + point = point[ + ::-1 + ] # rp are in (y, x) format (or (z, y, x) for 3D data) so I need to reverse order + point = np.array([point]) + centroids = np.array([rp.centroid for rp in rps]) + diff = point[:, np.newaxis] - centroids + dist_matrix = np.linalg.norm(diff, axis=2) + return dist_matrix + + +def sort_IDs_dist(rps, point=None, ID=None): + """Sorts the IDs of regionprops based on their distances to a given point. + + Parameters + ---------- + rps : list + A list of regionprops objects representing cells. + point : tuple, optional + The coordinates of the point to calculate distances from. + If not provided, it will be calculated based on the given ID. + ID : int, optional + The ID of the regionprops object to calculate distances from. + If this and point are both provided, or neither, an error will be + raised. + + Returns + ------- + list + A sorted list of IDs based on their distances to the given point. + + Raises + ------ + ValueError + If ID is not found in the list of regionprops objects. + ValueError + If neither ID nor point is provided. + ValueError + If both ID and point are provided. + + """ + if ID is not None and point is None: + try: + point = [rp.centroid for rp in rps if rp.label == ID][0] + except IndexError: + raise ValueError(f"ID {ID} not found in regionprops (list of cells).") + + elif ID is None and point is None: + raise ValueError("Either ID or point must be provided.") + + elif ID is not None and point is not None: + raise ValueError("Only one of ID or point must be provided.") + + IDs = [rp.label for rp in rps] + if len(IDs) == 0: + return [] + elif len(IDs) == 1: + return IDs + dist_matrix = find_distances_ID(rps, point=point) + dist_matrix = np.squeeze(dist_matrix) + + sorted_ids = sorted(zip(dist_matrix, IDs)) + sorted_ids = [ID for _, ID in sorted_ids] + return sorted_ids + + +def safe_get_or_call(obj, path: str): + """Safely get nested attributes or call methods with literal args from a string path.""" + expr = ast.parse(path, mode="eval").body + + def _eval(node, current_obj): + if isinstance(node, ast.Attribute): + return getattr(_eval(node.value, current_obj), node.attr) + elif isinstance(node, ast.Call): + func = _eval(node.func, current_obj) + args = [ast.literal_eval(arg) for arg in node.args] + kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords} + return func(*args, **kwargs) + elif isinstance(node, ast.Name): + # First name in chain is assumed to be from `obj` + return getattr(current_obj, node.id) + else: + raise ValueError(f"Unsupported syntax: {ast.dump(node)}") + + return _eval(expr, obj) + + +def format_commit_date_utc(utc_str): + # Parse the UTC date string (ISO 8601 format) + dt = datetime.datetime.fromisoformat(utc_str.replace("Z", "+00:00")) + + # Convert to your local time zone (optional) + local_dt = dt.astimezone() # removes UTC offset if local + + # Format nicely + return local_dt.strftime(r"%A %d %B %Y at %H:%M") + + +def get_linux_distribution_name(): + import csv + + RELEASE_DATA = {} + with open("/etc/os-release") as f: + reader = csv.reader(f, delimiter="=") + for row in reader: + if row: + RELEASE_DATA[row[0]] = row[1] + if RELEASE_DATA["ID"] in ["debian", "raspbian"]: + with open("/etc/debian_version") as f: + DEBIAN_VERSION = f.readline().strip() + major_version = DEBIAN_VERSION.split(".")[0] + version_split = RELEASE_DATA["VERSION"].split(" ", maxsplit=1) + if version_split[0] == major_version: + # Just major version shown, replace it with the full version + RELEASE_DATA["VERSION"] = " ".join([DEBIAN_VERSION] + version_split[1:]) + + name_version = f"{RELEASE_DATA['NAME']} {RELEASE_DATA['VERSION']}" + + return name_version + + +def reset_settings(): + question = ( + 'Do you want to reset Cell-ACDC settings- type "h" for help - (y/[n]/h)? ' + ) + info_txt = ( + "If you reset Cell-ACDC settings, the folder below will be deleted.\n\n" + "This means deeleting things like custom shortcuts, recent paths, last " + "selections, and GUI preferences.\n\n" + f'Settings folder path: "{settings_folderpath}"' + ) + answer = "y" + while True: + try: + answer = input(f"\n{question}") + except Exception as err: + break + + if answer == "n": + print("*" * 100) + return "Resetting Cell-ACDC settings cancelled." + + if answer == "y": + break + + if answer == "h": + print("-" * 100) + print(f"\n{info_txt}") + print("=" * 100) + + print( + f'"{answer}" is not a valid answer. ' + 'Type "y" for "yes", "n" for "no", or "h" for help.' + ) + + try: + os.remove(settings_folderpath) + print("*" * 100) + out_txt = ( + "Cell-ACDC settings have been reset.\n\n" + "The following folder was deleted:\n\n" + f"{settings_folderpath}" + ) + except Exception as err: + traceback.print_exc() + print("*" * 100) + out_txt = ( + "**ERROR** occured when trying to remove the settings folder.\n\n" + "To reset Cell-ACDC settings, please remove this folder:\n\n" + f"{settings_folderpath}\n" + ) + return out_txt + + +def separate_fluo_segment_channels(channels): + segms_to_load = [] + channels_to_load = [] + current_segm = False + for ch in channels: + if ch == "current segm.": + current_segm = True + elif "segm" in ch: + segms_to_load.append(ch) + else: + channels_to_load.append(ch) + return segms_to_load, channels_to_load, current_segm + +# Sibling imports (deferred to avoid import cycles) +from .logging import ( + Logger, +) +from .models import ( + download_manual, + params_to_ArgSpec, +) +from .paths import ( + get_fiji_exec_folderpath, + listdir, +) + diff --git a/cellacdc/utils/models.py b/cellacdc/utils/models.py new file mode 100644 index 000000000..8a5262a57 --- /dev/null +++ b/cellacdc/utils/models.py @@ -0,0 +1,1150 @@ +"""Cell-ACDC utility helpers: models.""" + +import os +import re +import ast + +import typing +from typing import Literal, List, Callable, Tuple, Dict + +import pathlib +import difflib +import sys +import platform +import tempfile +import shutil +import traceback +import logging +import datetime +import time +import subprocess +import importlib +from uuid import uuid4 +from importlib import import_module +from math import pow, ceil, floor +from functools import wraps, partial +from collections import namedtuple, Counter +from tqdm import tqdm +import requests +import zipfile +import json +import numpy as np +import pandas as pd +import skimage +import inspect + +import traceback +import itertools +from packaging import version as packaging_version + +from natsort import natsorted + +import tifffile +import skimage.io +import skimage.measure + +from .. import GUI_INSTALLED, KNOWN_EXTENSIONS, is_conda_env + +from .. import core, load +from .. import html_utils, is_linux, is_win, is_mac, issues_url, is_mac_arm64 +from .. import cellacdc_path, printl, acdc_fiji_path, logs_path, acdc_ffmpeg_path +from .. import user_profile_path, recentPaths_path +from .. import models_list_file_path, models_path +from .. import promptable_models_list_file_path, promptable_models_path +from .. import github_home_url +from .. import try_input_install_package +from .. import _warnings +from .. import urls +from .. import qrc_resources_path +from .. import settings_folderpath +from ..segmenters._cellpose_base import min_target_versions_cp + +if GUI_INSTALLED: + from qtpy.QtWidgets import QMessageBox + from qtpy.QtCore import Signal, QObject, QCoreApplication + + from .. import widgets, apps + from .. import config + +ArgSpec = namedtuple("ArgSpec", ["name", "default", "type", "desc", "docstring"]) + +def get_add_custom_prompt_model_instructions(): + init_sh = html_utils.init_sh + segment_sh = html_utils.segment_sh + add_prompt_sh = html_utils.add_prompt_sh + href = f'here' + text = html_utils.paragraph(f""" + To use a custom prompt model, you need to create a Python file with the name + acdcPromptModel.py.
    + Note that the folder name where you place this file will be used as the + model name.

    + In this file, you will implement a class called Model with + at least the {init_sh} to initialise the model,
    + the {add_prompt_sh} method to add prompts (points, boxes, etc.) + to the model, and the {segment_sh} method to run the + segmentation.

    + Have a look at the existing models in the promptable_models + folder for examples.

    + If it doesn't work, please report the issue {href} with the + code you wrote. Thanks! + """) + return text + + +def get_add_custom_model_instructions(): + user_manual_url = "https://github.com/SchmollerLab/Cell_ACDC/blob/main/UserManual/Cell-ACDC_User_Manual.pdf" + href_user_manual = f'user manual' + href = f'here' + class_sh = html_utils.class_sh + def_sh = html_utils.def_sh + kwargs_sh = html_utils.kwargs_sh + Model_sh = html_utils.Model_sh + segment_sh = html_utils.segment_sh + predict_sh = html_utils.predict_sh + init_sh = html_utils.init_sh + myModel_sh = html_utils.myModel_sh + return_sh = html_utils.return_sh + equal_sh = html_utils.equal_sh + open_par_sh = html_utils.open_par_sh + close_par_sh = html_utils.close_par_sh + image_sh = html_utils.image_sh + from_sh = html_utils.from_sh + import_sh = html_utils.import_sh + s = html_utils.paragraph(f""" + To use a custom model first create a folder with the name of your model.

    + Inside this new folder create a file named acdcSegment.py.

    + In the acdcSegment.py file you will implement the model class.

    + Have a look at the other existing models, but essentially you have to create + a class called Model with at least
    + the {init_sh} and the {segment_sh} method.

    + The {segment_sh} method takes the image (2D or 3D) as an input and return the segmentation mask.

    + You can find more details in the {href_user_manual} at the section + called Adding segmentation models to the pipeline.

    + Pseudo-code for the acdcSegment.py file: +
    
    +    {from_sh} myModel {import_sh} {myModel_sh}
    +
    +    {class_sh} {Model_sh}:
    +        {def_sh} {init_sh}(self, {kwargs_sh}):
    +            self.model {equal_sh} {myModel_sh}{open_par_sh}{close_par_sh}
    +
    +        {def_sh} {segment_sh}(self, {image_sh}, {kwargs_sh}):
    +            labels {equal_sh} self.model.{predict_sh}{open_par_sh}{image_sh}{close_par_sh}
    +            {return_sh} labels
    +    
    + + If it doesn't work, please report the issue {href} with the + code you wrote. Thanks. + """) + return s + + +def setDefaultValueArgSpecsFromKwargs(params: List[ArgSpec], kwargs: Dict[str, object]): + new_params = [] + for param in params: + new_value = kwargs.get(param.name) + if new_value is None: + new_params.append(param) + continue + + new_param = ArgSpec( + name=param.name, + default=new_value, + type=param.type, + desc=param.desc, + docstring=param.docstring, + ) + new_params.append(new_param) + return new_params + + +def insertModelArgSpec( + params, param_name, param_value, param_type=None, desc="", docstring="" +): + updated_params = [] + for param in params: + if param.name == param_name: + if param_type is None: + param_type = param.type + new_param = ArgSpec( + name=param_name, + default=param_value, + type=param_type, + desc=desc, + docstring=docstring, + ) + updated_params.append(new_param) + else: + updated_params.append(param) + return updated_params + + +def getModelArgSpec(acdcSegment): + init_ArgSpec = inspect.getfullargspec(acdcSegment.Model.__init__) + init_kwargs_type_hints = typing.get_type_hints(acdcSegment.Model.__init__) + init_doc = acdcSegment.Model.__init__.__doc__ + init_params = params_to_ArgSpec(init_ArgSpec, init_kwargs_type_hints, init_doc) + init_params = add_segm_data_param(init_params, init_ArgSpec) + + segment_ArgSpec = inspect.getfullargspec(acdcSegment.Model.segment) + segment_kwargs_type_hints = typing.get_type_hints(acdcSegment.Model.segment) + try: + segment_ArgSpec.args.remove("frame_i") + except Exception as e: + pass + + segment_doc = acdcSegment.Model.segment.__doc__ + segment_params = params_to_ArgSpec( + segment_ArgSpec, + segment_kwargs_type_hints, + segment_doc, + ) + + return init_params, segment_params + + +def parse_model_param_doc(name, next_param_name=None, docstring=None): + if not docstring: + return "" + + try: + # Extract parameter description from 'param : ...' + start_text = f"{name} : " + if docstring.find(start_text) == -1: + # Parameter not present in docstring + return "" + + doc_start_idx = docstring.find(start_text) + len(start_text) + + doc_stop_idx = _get_doc_stop_idx( + docstring, doc_start_idx, next_param_name=next_param_name + ) + if doc_stop_idx == -1: + doc_stop_idx = len(docstring) + + param_doc = docstring[doc_start_idx:doc_stop_idx] + + # Start at first end of line + param_doc = param_doc[param_doc.find("\n") + 1 :] + + # Replace multiples spaces with single space + param_doc = re.sub(" +", " ", param_doc) + + # Remove trailing spaces + param_doc = param_doc.strip() + except Exception as err: + param_doc = "" + + param_doc = param_doc.replace(", optional", "") + + return param_doc + + +def params_to_ArgSpec(fullargspecs, type_hints, docstring, args_to_skip=None): + params = [] + + if fullargspecs.defaults is None: + return params + + if args_to_skip is None: + args_to_skip = set() + + num_params = len(fullargspecs.args) + ip = num_params - len(fullargspecs.defaults) + if ip < 0: + return params + + for arg, default in zip(fullargspecs.args[ip:], fullargspecs.defaults): + if arg in args_to_skip: + continue + + if arg in type_hints: + _type = type_hints[arg] + else: + _type = type(default) + + next_param_name = None + if ip + 1 < num_params: + next_param_name = fullargspecs.args[ip + 1] + + param_doc = parse_model_param_doc( + arg, next_param_name=next_param_name, docstring=docstring + ) + param = ArgSpec( + name=arg, default=default, type=_type, desc=param_doc, docstring=docstring + ) + params.append(param) + ip += 1 + return params + + +def getClassArgSpecs(classModule, runMethodName="run"): + init_ArgSpec = inspect.getfullargspec(classModule.__init__) + init_kwargs_type_hints = typing.get_type_hints(classModule.__init__) + init_doc = classModule.__init__.__doc__ + init_params = params_to_ArgSpec(init_ArgSpec, init_kwargs_type_hints, init_doc) + + run_ArgSpec = inspect.getfullargspec(getattr(classModule, runMethodName)) + run_kwargs_type_hints = typing.get_type_hints(getattr(classModule, runMethodName)) + run_doc = getattr(classModule, runMethodName).__doc__ + run_params = params_to_ArgSpec( + run_ArgSpec, + run_kwargs_type_hints, + run_doc, + args_to_skip={"signals", "export_to"}, + ) + return init_params, run_params + + +def getTrackerArgSpec(trackerModule, realTime=False): + init_ArgSpec = inspect.getfullargspec(trackerModule.tracker.__init__) + init_kwargs_type_hints = typing.get_type_hints(trackerModule.tracker.__init__) + init_doc = trackerModule.tracker.__init__.__doc__ + init_params = params_to_ArgSpec(init_ArgSpec, init_kwargs_type_hints, init_doc) + if realTime: + track_ArgSpec = inspect.getfullargspec(trackerModule.tracker.track_frame) + track_kwargs_type_hints = typing.get_type_hints( + trackerModule.tracker.track_frame + ) + track_doc = trackerModule.tracker.track_frame.__doc__ + else: + track_ArgSpec = inspect.getfullargspec(trackerModule.tracker.track) + track_kwargs_type_hints = typing.get_type_hints(trackerModule.tracker.track) + track_doc = trackerModule.tracker.track.__doc__ + + track_params = params_to_ArgSpec( + track_ArgSpec, + track_kwargs_type_hints, + track_doc, + args_to_skip={"signals", "export_to"}, + ) + return init_params, track_params + + +def isIntensityImgRequiredForTracker(trackerModule): + track_ArgSpec = inspect.getfullargspec(trackerModule.tracker.track) + num_args = len(track_ArgSpec.args) - len(track_ArgSpec.defaults) + # If the number of args is 3 then we have `self, labels, image` as args + # which means the tracker requires the image + return num_args == 3 + + +def download_examples(which="time_lapse_2D", progress=None): + examples_path, example_path, url, file_size = get_examples_path(which) + if os.path.exists(example_path): + if progress is not None: + # display 100% progressbar + progress.emit(0, 0) + return example_path + + zip_dst = os.path.join(examples_path, "example_temp.zip") + + if not os.path.exists(examples_path): + os.makedirs(examples_path, exist_ok=True) + + print(f"Downloading example to {example_path}") + + download_url(url, zip_dst, verbose=False, file_size=file_size, progress=progress) + exctract_to = examples_path + extract_zip(zip_dst, exctract_to) + + if progress is not None: + # display 100% progressbar + progress.emit(0, 0) + + # Remove downloaded zip archive + os.remove(zip_dst) + print("Example downloaded successfully") + return example_path + + +def check_model_exists(model_path, model_name): + try: + import cellacdc + + m = model_name.lower() + weights_filenames = getattr(cellacdc, f"{m}_weights_filenames") + files_present = listdir(model_path) + return all([f in files_present for f in weights_filenames]) + except Exception as e: + return True + + +def _model_url(model_name, return_alternative=False): + if model_name == "YeaZ": + url = "https://hmgubox2.helmholtz-muenchen.de/index.php/s/8PMePcwJXmaMMS6/download/YeaZ_weights.zip" + alternative_url = ( + "https://zenodo.org/record/6125825/files/YeaZ_weights.zip?download=1" + ) + file_size = 693685011 + elif model_name == "YeastMate": + url = "https://hmgubox2.helmholtz-muenchen.de/index.php/s/pMT8pAmMkNtN8BP/download/yeastmate_weights.zip" + alternative_url = ( + "https://zenodo.org/record/6140067/files/yeastmate_weights.zip?download=1" + ) + file_size = 164911104 + elif model_name == "segment_anything": + url = [ + "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", + "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", + "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", + ] + file_size = [2564550879, 1249524736, 375042383] + alternative_url = "" + elif model_name == "YeaZ_v2": + url = [ + "https://hmgubox2.helmholtz-muenchen.de/index.php/s/5PARckkcJcN9D3S/download/weights_budding_BF_multilab_0_1", + "https://hmgubox2.helmholtz-muenchen.de/index.php/s/CTHq4HN3adyFbnE/download/weights_budding_PhC_multilab_0_1", + "https://hmgubox2.helmholtz-muenchen.de/index.php/s/QTtBJycYnLQZsHQ/download/weights_fission_multilab_0_2", + ] + file_size = [124142981, 124143031, 124144759] + alternative_url = "https://github.com/rahi-lab/YeaZ-GUI#installation" + elif model_name == "DeepSea": + url = [ + "https://github.com/abzargar/DeepSea/raw/master/deepsea/trained_models/segmentation.pth", + "https://github.com/abzargar/DeepSea/raw/master/deepsea/trained_models/tracker.pth", + ] + file_size = [7988969, 8637439] + alternative_url = "" + elif model_name == "TAPIR": + url = ["https://storage.googleapis.com/dm-tapnet/tapir_checkpoint.npy"] + file_size = [124408122] + alternative_url = "" + elif model_name == "Cellpose_germlineNuclei": + url = [ + "https://hmgubox2.helmholtz-muenchen.de/index.php/s/AXG6fFfD8o5GZ83/download/cellpose_germlineNuclei_2023" + ] + file_size = [26570752] + alternative_url = "" + elif model_name == "omnipose": + url = [ + "https://hmgubox2.helmholtz-muenchen.de/index.php/s/DynLkocWRbQfyRp/download/bact_fluor_cptorch_0" + "https://hmgubox2.helmholtz-muenchen.de/index.php/s/2248Eoyozp3Ezj2/download/bact_fluor_omnitorch_0", + "https://hmgubox2.helmholtz-muenchen.de/index.php/s/GiacDfXGerxE7PT/download/bact_phase_omnitorch_0", + "https://hmgubox2.helmholtz-muenchen.de/index.php/s/DDq8s3CgnG2Yw6H/download/cyto2_omnitorch_0", + "https://hmgubox2.helmholtz-muenchen.de/index.php/s/MM5meM2J5HbWqXR/download/plant_cptorch_0", + "https://hmgubox2.helmholtz-muenchen.de/index.php/s/aap7znrWq5sE6JQ/download/plant_omnitorch_0", + "https://hmgubox2.helmholtz-muenchen.de/index.php/s/w5M46x9qr8zLHZH/download/size_cyto2_omnitorch_0.npy", + ] + file_size = [26558464, 26558464, 26558464, 26558464, 26558464, 75071488, 4096] + alternative_url = "" + elif model_name == "sam2": + url = [ + "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt", + "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt", + "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt", + "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt", + ] + file_size = [155233385, 184211977, 319128965, 910600801] + alternative_url = "" + else: + return + if return_alternative: + return url, alternative_url + else: + return url, file_size + + +def _download_segment_anything_models(): + urls, file_sizes = _model_url("segment_anything") + temp_model_path = tempfile.mkdtemp() + _, final_model_path = get_model_path("segment_anything", create_temp_dir=False) + for url, file_size in zip(urls, file_sizes): + filename = url.split("/")[-1] + final_dst = os.path.join(final_model_path, filename) + if os.path.exists(final_dst): + continue + + temp_dst = os.path.join(temp_model_path, filename) + download_url( + url, temp_dst, file_size=file_size, desc="segment_anything", verbose=False + ) + + shutil.move(temp_dst, final_dst) + + +def _download_sam2_models(): + urls, file_sizes = _model_url("sam2") + temp_model_path = tempfile.mkdtemp() + _, final_model_path = get_model_path("sam2", create_temp_dir=False) + for url, file_size in zip(urls, file_sizes): + filename = url.split("/")[-1] + final_dst = os.path.join(final_model_path, filename) + if os.path.exists(final_dst): + continue + + temp_dst = os.path.join(temp_model_path, filename) + download_url(url, temp_dst, file_size=file_size, desc="sam2", verbose=False) + + shutil.move(temp_dst, final_dst) + + +def _download_deepsea_models(): + urls, file_sizes = _model_url("DeepSea") + temp_model_path = tempfile.mkdtemp() + _, final_model_path = get_model_path("deepsea", create_temp_dir=False) + for url, file_size in zip(urls, file_sizes): + filename = url.split("/")[-1] + final_dst = os.path.join(final_model_path, filename) + if os.path.exists(final_dst): + continue + + temp_dst = os.path.join(temp_model_path, filename) + download_url(url, temp_dst, file_size=file_size, desc="deepsea", verbose=False) + + shutil.move(temp_dst, final_dst) + + +def download_manual(): + manual_folder_path = os.path.join(user_profile_path, "acdc-manual") + if not os.path.exists(manual_folder_path): + os.makedirs(manual_folder_path, exist_ok=True) + + manual_file_path = os.path.join(user_profile_path, "Cell-ACDC_User_Manual.pdf") + if not os.path.exists(manual_file_path): + url = "https://github.com/SchmollerLab/Cell_ACDC/raw/main/UserManual/Cell-ACDC_User_Manual.pdf" + download_url(url, manual_file_path, file_size=1727470) + return manual_file_path + + +def download_bioformats_jar(qparent=None, logger_info=print, logger_exception=print): + dst_filepath = os.path.join( + cellacdc_path, "bioformats", "jars", "bioformats_package.jar" + ) + if os.path.exists(dst_filepath): + return True, dst_filepath + urls_to_try = (urls.bioformats_jar_home_url, urls.bioformats_jar_hmgu_url) + success = False + for url in urls_to_try: + try: + logger_info(f"Downloading `bioformats_package.jar`...") + download_url(url, dst_filepath, file_size=43233280) + success = True + break + except Exception as err: + success = False + traceback_str = traceback.format_exc() + logger_exception(traceback_str) + continue + + if success: + return True, dst_filepath + + _warnings.warn_download_bioformats_jar_failed(dst_filepath, qparent=qparent) + raise ModuleNotFoundError( + "Bioformats package jar could not be downloaded. Please, " + f"download it from here {urls.bioformats_download_page} and " + f'place it in the following path "{dst_filepath}". ' + "Thank you for your patience!" + ) + return False, dst_filepath + + +def download_url(url, dst, desc="", file_size=None, verbose=True, progress=None): + import urllib3 + + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + CHUNK_SIZE = 32768 + if verbose: + print(f"Downloading {desc} to: {os.path.dirname(dst)}") + response = requests.get(url, stream=True, timeout=20, verify=False) + if file_size is not None and progress is not None: + progress.emit(file_size, -1) + pbar = tqdm( + total=file_size, unit="B", unit_scale=True, unit_divisor=1024, ncols=100 + ) + with open(dst, "wb") as f: + for chunk in response.iter_content(CHUNK_SIZE): + # if chunk: + f.write(chunk) + pbar.update(len(chunk)) + if progress is not None: + progress.emit(-1, len(chunk)) + pbar.close() + + +def _write_model_location_to_txt(model_name): + model_info_path = os.path.join(models_path, model_name, "model") + model_path = os.path.join(user_profile_path, f"acdc-{model_name}") + file = "weights_location_path.txt" + with open(os.path.join(model_info_path, file), "w") as txt: + txt.write(model_path) + return os.path.expanduser(model_path) + + +def download_model(model_name): + if model_name == "segment_anything": + try: + _download_segment_anything_models() + return True + except Exception as e: + traceback.print_exc() + return False + elif model_name == "sam2": + try: + _download_sam2_models() + return True + except Exception as e: + traceback.print_exc() + return False + elif model_name == "DeepSea": + try: + _download_deepsea_models() + return True + except Exception as e: + traceback.print_exc() + return False + elif model_name == "TAPIR": + try: + _download_tapir_model() + return True + except Exception as e: + traceback.print_exc() + return False + elif model_name == "YeaZ_v2": + try: + _download_yeaz_models() + return True + except Exception as e: + traceback.print_exc() + return False + elif model_name == "Cellpose_germlineNuclei": + try: + _download_cellpose_germlineNuclei_model() + return True + except Exception as e: + traceback.print_exc() + return False + elif model_name == "omnipose": + try: + _download_omnipose_models() + return True + except Exception as err: + return False + elif model_name != "YeastMate" and model_name != "YeaZ": + # We manage only YeastMate and YeaZ + return True + + try: + # Check if model exists + temp_zip_path, model_path = get_model_path(model_name) + if not temp_zip_path: + # Model exists return + return True + + # Check if user has model in the old v1.2.3 location + v123_model_path = check_v123_model_path(model_name) + if v123_model_path: + print(f"Weights files found in {v123_model_path}") + print(f"--> moving to new location: {model_path}...") + for file in listdir(v123_model_path): + src = os.path.join(v123_model_path, file) + dst = os.path.join(model_path, file) + shutil.copy(src, dst) + return True + + # Download model from url to tempDir/model_temp.zip + temp_dir = os.path.dirname(temp_zip_path) + url, file_size = _model_url(model_name) + print(f"Downloading {model_name} to {model_path}") + download_url( + url, temp_zip_path, file_size=file_size, desc=model_name, verbose=False + ) + + # Extract zip file inside temp dir + print(f"Extracting model...") + extract_zip(temp_zip_path, temp_dir, verbose=False) + + # Move unzipped files to ~/acdc-{model_name} folder + print(f"Moving files from temporary folder to {model_path}...") + for file in listdir(temp_dir): + if file.endswith(".zip"): + continue + src = os.path.join(temp_dir, file) + dst = os.path.join(model_path, file) + shutil.move(src, dst) + + # Remove temp directory + print(f"Removing temporary folder...") + shutil.rmtree(temp_dir) + return True + + except Exception as e: + traceback.print_exc() + return False + + +def aliases_real_time_trackers(reverse=False): + """ + Returns a dictionary with aliases for real-time trackers. + """ + + aliases = { + "CellACDC_normal_division": "Cell-ACDC symmetric division", + "CellACDC_2steps": "Cell-ACDC 2 steps", + } + + if reverse: + aliases = {v: k for k, v in aliases.items()} + + return aliases + + +def get_list_of_real_time_trackers(): + trackers = get_list_of_trackers() + rt_trackers = [] + aliases = aliases_real_time_trackers() + for tracker in trackers: + if tracker == "CellACDC": + continue + if tracker == "YeaZ": + continue + tracker_filename = f"{tracker}_tracker.py" + tracker_path = os.path.join( + cellacdc_path, "trackers", tracker, tracker_filename + ) + try: + with open(tracker_path) as file: + txt = file.read() + if txt.find("def track_frame") != -1: + rt_trackers.append(tracker) + except Exception as e: + continue + + for i, tracker in enumerate(rt_trackers): + if tracker in aliases: + rt_trackers[i] = aliases[tracker] + + return natsorted(rt_trackers, key=str.casefold) + + +def get_list_of_trackers(): + trackers_path = os.path.join(cellacdc_path, "trackers") + trackers = [] + for name in listdir(trackers_path): + _path = os.path.join(trackers_path, name) + tracker_script_path = os.path.join(_path, f"{name}_tracker.py") + is_valid_tracker = ( + os.path.isdir(_path) + and os.path.exists(tracker_script_path) + and not name.endswith("__") + ) + + if name.startswith("_"): + continue + + if is_valid_tracker: + trackers.append(name) + return natsorted(trackers, key=str.casefold) + + +def get_list_of_models(): + models = set() + for name in listdir(models_path): + _path = os.path.join(models_path, name) + if not os.path.exists(_path): + continue + + if not os.path.isdir(_path): + continue + + if name.endswith("__"): + continue + + if name.startswith("_"): + continue + + if name == "skip_segmentation": + continue + + if not os.path.exists(os.path.join(_path, "acdcSegment.py")): + continue + + if name == "thresholding": + name = "Automatic thresholding" + + models.add(name) + + if not os.path.exists(models_list_file_path): + return natsorted(list(models), key=str.casefold) + + cp = config.ConfigParser() + cp.read(models_list_file_path) + models.update(cp.sections()) + return natsorted(list(models), key=str.casefold) + + +def get_list_of_promptable_models(): + models = set() + for name in listdir(promptable_models_path): + _path = os.path.join(promptable_models_path, name) + if not os.path.exists(_path): + continue + + if not os.path.isdir(_path): + continue + + if name.endswith("__"): + continue + + if not os.path.exists(os.path.join(_path, "acdcPromptSegment.py")): + continue + + models.add(name) + + if not os.path.exists(promptable_models_list_file_path): + return natsorted(list(models), key=str.casefold) + + cp = config.ConfigParser() + cp.read(promptable_models_list_file_path) + models.update(cp.sections()) + return natsorted(list(models), key=str.casefold) + + +def download_fiji(logger_func=print): + url = None + if is_mac: + url = "https://downloads.micron.ox.ac.uk/fiji_update/mirrors/fiji-latest/fiji-macosx.zip" + file_size = 474_525_405 + + if url is None: + return + + if os.path.exists(get_fiji_exec_folderpath()): + return + + os.makedirs(acdc_fiji_path) + + temp_dir = tempfile.mkdtemp() + zip_dst = os.path.join(temp_dir, "fiji-macosx.zip") + logger_func(f'Downloading Fiji to "{acdc_fiji_path}"...') + download_url(url, zip_dst, verbose=False, file_size=file_size) + extract_zip(zip_dst, acdc_fiji_path) + + return acdc_fiji_path + + +def import_tracker_module(tracker_name): + module_name = f"cellacdc.trackers.{tracker_name}.{tracker_name}_tracker" + tracker_module = import_module(module_name) + return tracker_module + + +def download_ffmpeg(): + ffmpeg_folderpath = acdc_ffmpeg_path + if is_win: + url = "https://hmgubox2.helmholtz-muenchen.de/index.php/s/rXioWZpwjwn9JTT/download/windows_ffmpeg-7.0-full_build.zip" + file_size = 173477888 + ffmep_exec_path = os.path.join(ffmpeg_folderpath, "bin", "ffmpeg.exe") + elif is_mac: + url = "https://hmgubox2.helmholtz-muenchen.de/index.php/s/We7rcTLzqAP4zf7/download/mac_ffmpeg.zip" + file_size = 25288704 + ffmep_exec_path = os.path.join(ffmpeg_folderpath, "ffmpeg") + elif is_linux: + ffmep_exec_path = "" + return ffmep_exec_path + + if os.path.exists(ffmep_exec_path): + return ffmep_exec_path.replace("\\", os.sep).replace("/", os.sep) + + print("Downloading FFMPEG...") + temp_dir = tempfile.mkdtemp() + temp_zip_path = os.path.join(temp_dir, "acdc-ffmpeg.zip") + + download_url( + url, + temp_zip_path, + verbose=True, + file_size=file_size, + ) + extract_zip(temp_zip_path, ffmpeg_folderpath) + + return ffmep_exec_path.replace("\\", os.sep).replace("/", os.sep) + + +def import_promptable_segment_module(model_name): + try: + acdcPromptSegment = import_module( + f"cellacdc.segmenters_promptable.{model_name}.acdcPromptSegment" + ) + except ModuleNotFoundError as e: + # Check if custom model + cp = config.ConfigParser() + cp.read(promptable_models_list_file_path) + model_path = cp[model_name]["path"] + spec = importlib.util.spec_from_file_location("acdcPromptSegment", model_path) + acdcPromptSegment = importlib.util.module_from_spec(spec) + sys.modules["acdcPromptSegment"] = acdcPromptSegment + spec.loader.exec_module(acdcPromptSegment) + return acdcPromptSegment + + +def init_tracker( + posData, trackerName, realTime=False, qparent=None, return_init_params=False +): + from . import apps + + downloadWin = apps.downloadModel(trackerName, parent=qparent) + downloadWin.download() + + trackerModule = import_tracker_module(trackerName) + init_params = {} + track_params = {} + paramsWin = None + if trackerName == "BayesianTracker": + Y, X = posData.img_data_shape[-2:] + if posData.isSegm3D: + labShape = (posData.SizeZ, Y, X) + else: + labShape = (1, Y, X) + paramsWin = apps.BayesianTrackerParamsWin( + labShape, + parent=qparent, + channels=posData.chNames, + currentChannelName=posData.user_ch_name, + ) + paramsWin.exec_() + if not paramsWin.cancel: + init_params = paramsWin.params + track_params["export_to"] = posData.get_btrack_export_path() + if paramsWin.intensityImageChannel is not None: + chName = paramsWin.intensityImageChannel + track_params["image"] = posData.loadChannelData(chName) + track_params["image_channel_name"] = chName + elif trackerName == "CellACDC": + paramsWin = apps.CellACDCTrackerParamsWin(parent=qparent) + paramsWin.exec_() + if not paramsWin.cancel: + init_params = paramsWin.params + elif trackerName == "delta": + paramsWin = apps.DeltaTrackerParamsWin(posData=posData, parent=qparent) + paramsWin.exec_() + if not paramsWin.cancel: + init_params = paramsWin.params + else: + init_argspecs, track_argspecs = getTrackerArgSpec( + trackerModule, realTime=realTime + ) + intensityImgRequiredForTracker = isIntensityImgRequiredForTracker(trackerModule) + if init_argspecs or track_argspecs: + try: + url = trackerModule.url_help() + except AttributeError: + url = None + try: + channels = posData.chNames + except Exception as e: + channels = None + try: + currentChannelName = posData.user_ch_name + except Exception as e: + currentChannelName = None + try: + df_metadata = posData.metadata_df + except Exception as e: + df_metadata = None + + if not intensityImgRequiredForTracker: + currentChannelName = None + + paramsWin = apps.QDialogModelParams( + init_argspecs, + track_argspecs, + trackerName, + url=url, + channels=channels, + is_tracker=True, + currentChannelName=currentChannelName, + df_metadata=df_metadata, + posData=posData, + ) + if not intensityImgRequiredForTracker and channels is not None: + paramsWin.channelCombobox.setDisabled(True) + + paramsWin.exec_() + if not paramsWin.cancel: + init_params = paramsWin.init_kwargs + track_params = paramsWin.model_kwargs + if paramsWin.inputChannelName != "None": + chName = paramsWin.inputChannelName + track_params["image"] = posData.loadChannelData(chName) + track_params["image_channel_name"] = chName + if "export_to_extension" in track_params: + ext = track_params["export_to_extension"] + track_params["export_to"] = posData.get_tracker_export_path( + trackerName, ext + ) + + if paramsWin is not None and paramsWin.cancel: + tracker = (None,) + track_params = None + init_params = None + else: + tracker = trackerModule.tracker(**init_params) + + if return_init_params: + return tracker, track_params, init_params + else: + return tracker, track_params + + +def _download_tapir_model(): + urls, file_sizes = _model_url("TAPIR") + temp_model_path = tempfile.mkdtemp() + _, final_model_path = get_model_path("TAPIR", create_temp_dir=False) + for url, file_size in zip(urls, file_sizes): + filename = url.split("/")[-1] + final_dst = os.path.join(final_model_path, filename) + if os.path.exists(final_dst): + continue + + temp_dst = os.path.join(temp_model_path, filename) + download_url(url, temp_dst, file_size=file_size, desc="TAPIR", verbose=False) + + shutil.move(temp_dst, final_dst) + + +def _download_yeaz_models(): + urls, file_sizes = _model_url("YeaZ_v2") + temp_model_path = tempfile.mkdtemp() + _, final_model_path = get_model_path("YeaZ_v2", create_temp_dir=False) + for url, file_size in zip(urls, file_sizes): + filename = url.split("/")[-1] + final_dst = os.path.join(final_model_path, filename) + if os.path.exists(final_dst): + continue + + temp_dst = os.path.join(temp_model_path, filename) + download_url(url, temp_dst, file_size=file_size, desc="YeaZ_v2", verbose=False) + + shutil.move(temp_dst, final_dst) + + +def _download_cellpose_germlineNuclei_model(): + urls, file_sizes = _model_url("Cellpose_germlineNuclei") + temp_model_path = tempfile.mkdtemp() + _, final_model_path = get_model_path( + "Cellpose_germlineNuclei", create_temp_dir=False + ) + for url, file_size in zip(urls, file_sizes): + filename = url.split("/")[-1] + final_dst = os.path.join(final_model_path, filename) + if os.path.exists(final_dst): + continue + + temp_dst = os.path.join(temp_model_path, filename) + download_url( + url, + temp_dst, + file_size=file_size, + desc="Cellpose_germlineNuclei", + verbose=False, + ) + + shutil.move(temp_dst, final_dst) + + +def _download_omnipose_models(): + urls, file_sizes = _model_url("omnipose") + temp_model_path = tempfile.mkdtemp() + final_model_path = os.path.expanduser(r"~\.cellpose\models") + for url, file_size in zip(urls, file_sizes): + filename = url.split("/")[-1] + final_dst = os.path.join(final_model_path, filename) + if os.path.exists(final_dst): + continue + + temp_dst = os.path.join(temp_model_path, filename) + download_url(url, temp_dst, file_size=file_size, desc="omnipose", verbose=False) + + shutil.move(temp_dst, final_dst) + + +def init_prompt_segm_model(acdcPromptSegment, posData, init_kwargs): + model = acdcPromptSegment.Model(**init_kwargs) + return model + + +def init_segm_model(acdcSegment, posData, init_kwargs): + segm_endname = init_kwargs.pop("segm_endname", "None") + if segm_endname != "None": + load_segm = True + if not hasattr(posData, "segm_data"): + load_segm = True + elif posData.segm_npz_path.endswith(f"{segm_endname}.npz"): + load_segm = False + if not load_segm: + segm_data = np.squeeze(posData.segm_data) + else: + segm_filepath, _ = load.get_path_from_endname( + segm_endname, posData.images_path + ) + printl(f'Loading segmentation data from "{segm_filepath}"...') + segm_data = np.load(segm_filepath)["arr_0"] + else: + segm_data = None + + # Initialize input_points_df for models promptable with points + input_points_filepath = init_kwargs.pop("input_points_path", "") + if input_points_filepath: + input_points_df = init_input_points_df(posData, input_points_filepath) + init_kwargs["input_points_df"] = input_points_df + + try: + # Models introduced before 1.3.2 do not have the segm_data as input + kwargs = inspect.getfullargspec(acdcSegment.Model.__init__).args + if "is_rgb" not in kwargs and "is_rgb" in init_kwargs: + del init_kwargs["is_rgb"] + model = acdcSegment.Model(**init_kwargs) + + except Exception as e: + model = acdcSegment.Model(segm_data, **init_kwargs) + + if hasattr(model, "init_successful"): + if not model.init_successful: + return None + return model + + +def parse_model_params(model_argspecs, model_params): + parsed_model_params = {} + for row, argspec in enumerate(model_argspecs): + value = model_params.get(argspec.name) + if value is None: + continue + if argspec.type == bool: + value = _parse_bool_str(value) + elif argspec.type == int: + value = int(value) + elif argspec.type == float: + value = float(value) + parsed_model_params[argspec.name] = value + return parsed_model_params + + +def validate_tracker_input(tracker, segm_video_to_track): + try: + warning_text = tracker.validate_input(segm_video_to_track) + return warning_text + except Exception as err: + printl(traceback.format_exc()) + pass + return + +# Sibling imports (deferred to avoid import cycles) +from .misc import ( + _get_doc_stop_idx, + _parse_bool_str, + add_segm_data_param, + extract_zip, + init_input_points_df, +) +from .paths import ( + check_v123_model_path, + get_examples_path, + get_fiji_exec_folderpath, + get_model_path, + listdir, +) + diff --git a/cellacdc/utils/paths.py b/cellacdc/utils/paths.py new file mode 100644 index 000000000..da35742ef --- /dev/null +++ b/cellacdc/utils/paths.py @@ -0,0 +1,455 @@ +"""Cell-ACDC utility helpers: paths.""" + +import os +import re +import ast + +import typing +from typing import Literal, List, Callable, Tuple, Dict + +import pathlib +import difflib +import sys +import platform +import tempfile +import shutil +import traceback +import logging +import datetime +import time +import subprocess +import importlib +from uuid import uuid4 +from importlib import import_module +from math import pow, ceil, floor +from functools import wraps, partial +from collections import namedtuple, Counter +from tqdm import tqdm +import requests +import zipfile +import json +import numpy as np +import pandas as pd +import skimage +import inspect + +import traceback +import itertools +from packaging import version as packaging_version + +from natsort import natsorted + +import tifffile +import skimage.io +import skimage.measure + +from .. import GUI_INSTALLED, KNOWN_EXTENSIONS, is_conda_env + +from .. import core, load +from .. import html_utils, is_linux, is_win, is_mac, issues_url, is_mac_arm64 +from .. import cellacdc_path, printl, acdc_fiji_path, logs_path, acdc_ffmpeg_path +from .. import user_profile_path, recentPaths_path +from .. import models_list_file_path, models_path +from .. import promptable_models_list_file_path, promptable_models_path +from .. import github_home_url +from .. import try_input_install_package +from .. import _warnings +from .. import urls +from .. import qrc_resources_path +from .. import settings_folderpath +from ..segmenters._cellpose_base import min_target_versions_cp + +if GUI_INSTALLED: + from qtpy.QtWidgets import QMessageBox + from qtpy.QtCore import Signal, QObject, QCoreApplication + + from .. import widgets, apps + from .. import config + +ArgSpec = namedtuple("ArgSpec", ["name", "default", "type", "desc", "docstring"]) + +def get_pos_status_acdc(pos_path): + images_path = os.path.join(pos_path, "Images") + ls = listdir(images_path) + for file in ls: + if file.endswith("acdc_output.csv"): + acdc_df_path = os.path.join(images_path, file) + break + else: + return "" + + acdc_df = pd.read_csv(acdc_df_path) + last_tracked_i = acdc_df["frame_i"].max() + last_cca_i = 0 + if "cell_cycle_stage" in acdc_df.columns: + cca_df = acdc_df[["frame_i", "cell_cycle_stage"]].dropna() + last_cca_i = cca_df["frame_i"].max() + if last_cca_i > 0: + return ( + f" (last tracked frame = {last_tracked_i + 1}, " + f"last annotated frame = {last_cca_i + 1})" + ) + else: + return f" (last tracked frame = {last_tracked_i + 1})" + + +def get_pos_status_spotmax(pos_path): + spotmax_out_path = os.path.join(pos_path, "spotMAX_output") + is_smax_out_present = "Yes" if os.path.exists(spotmax_out_path) else "No" + if os.path.exists(spotmax_out_path): + return " (SpotMAX output exists)" + else: + return "" + + +def get_pos_status(pos_path, caller: Literal["Cell-ACDC", "SpotMAX"] = "Cell-ACDC"): + if caller == "Cell-ACDC": + return get_pos_status_acdc(pos_path) + + if caller == "SpotMAX": + return get_pos_status_spotmax(pos_path) + + +def get_gdrive_path(): + if is_win: + return os.path.join(f"G:{os.sep}", "My Drive") + elif is_mac: + return os.path.join( + "/Users/francesco.padovani/Library/CloudStorage/" + "GoogleDrive-padovaf@tcd.ie/My Drive" + ) + + +def get_acdc_data_path(): + Cell_ACDC_path = os.path.dirname(cellacdc_path) + return os.path.join(Cell_ACDC_path, "data") + + +def get_open_filemaneger_os_string(): + if is_win: + return "Show in Explorer..." + elif is_mac: + return "Reveal in Finder..." + elif is_linux: + return "Show in File Manager..." + + +def trim_path(path, depth=3, start_with_dots=True): + path_li = os.path.abspath(path).split(os.sep) + rel_path = f"{f'{os.sep}'.join(path_li[-depth:])}" + if start_with_dots: + return f"...{os.sep}{rel_path}" + else: + return rel_path + + +def get_pos_foldernames(exp_path, check_if_is_sub_folder=False): + if not check_if_is_sub_folder: + ls = listdir(exp_path) + pos_foldernames = [ + pos for pos in ls if is_pos_folderpath(os.path.join(exp_path, pos)) + ] + else: + folder_type = determine_folder_type(exp_path) + is_pos_folder, is_images_folder, _ = folder_type + if is_pos_folder: + return [os.path.basename(exp_path)] + elif is_images_folder: + pos_path = os.path.dirname(exp_path) + if is_pos_folderpath(pos_path): + return [os.path.basename(pos_path)] + else: + return [] + else: + return get_pos_foldernames(exp_path) + return pos_foldernames + + +def get_images_folderpath(folderpath): + if os.path.isfile(folderpath): + folderpath = os.path.dirname(folderpath) + + if folderpath.endswith("Images"): + return folderpath + + images_folderpath = os.path.join(folderpath, "Images") + if os.path.exists(images_folderpath): + return images_folderpath + + return "" + + +def store_custom_model_path(model_file_path): + model_file_path = model_file_path.replace("\\", "/") + model_name = os.path.basename(os.path.dirname(model_file_path)) + cp = config.ConfigParser() + if os.path.exists(models_list_file_path): + cp.read(models_list_file_path) + if model_name not in cp: + cp[model_name] = {} + cp[model_name]["path"] = model_file_path + with open(models_list_file_path, "w") as configFile: + cp.write(configFile) + + +def store_custom_promptable_model_path(promptable_model_file_path): + model_file_path = promptable_model_file_path.replace("\\", "/") + model_name = os.path.basename(os.path.dirname(model_file_path)) + cp = config.ConfigParser() + if os.path.exists(promptable_models_list_file_path): + cp.read(promptable_models_list_file_path) + if model_name not in cp: + cp[model_name] = {} + cp[model_name]["path"] = model_file_path + with open(promptable_models_list_file_path, "w") as configFile: + cp.write(configFile) + + +def listdir(path) -> List[str]: + return natsorted( + [ + f + for f in os.listdir(path) + if not f.startswith(".") + and not f == "desktop.ini" + and not f == "recovery" + and not f.endswith(".new.npz") + ] + ) + + +def get_examples_path(which): + if which == "time_lapse_2D": + foldername = "TimeLapse_2D" + url = "https://hmgubox2.helmholtz-muenchen.de/index.php/s/KgJQtsQKZJnWZjL/download/TimeLapse_2D.zip" + file_size = 45143552 + elif which == "snapshots_3D": + foldername = "Multi_3D_zStack_Analysed" + url = "https://hmgubox2.helmholtz-muenchen.de/index.php/s/3RNjGiPwKcdnGtj/download/Yeast_Analysed_multi3D_zStacks.zip" + file_size = 124822528 + else: + return "" + + examples_path = os.path.join(user_profile_path, "acdc-examples") + example_path = os.path.join(examples_path, foldername) + return examples_path, example_path, url, file_size + + +def get_acdc_java_path(): + acdc_java_path = os.path.join(user_profile_path, "acdc-java") + dot_acdc_java_path = os.path.join(user_profile_path, ".acdc-java") + return acdc_java_path, dot_acdc_java_path + + +def get_model_path(model_name, create_temp_dir=True): + if model_name == "Automatic thresholding": + model_name == "thresholding" + + model_info_path = os.path.join(models_path, model_name, "model") + + if os.path.exists(model_info_path): + for file in listdir(model_info_path): + if file != "weights_location_path.txt": + continue + with open(os.path.join(model_info_path, file), "r") as txt: + model_path = txt.read() + model_path = os.path.expanduser(model_path) + if not os.path.exists(model_path): + model_path = _write_model_location_to_txt(model_name) + else: + break + else: + model_path = _write_model_location_to_txt(model_name) + else: + os.makedirs(model_info_path, exist_ok=True) + model_path = _write_model_location_to_txt(model_name) + + model_path = migrate_to_new_user_profile_path(model_path) + + if not os.path.exists(model_path): + os.makedirs(model_path, exist_ok=True) + + if not create_temp_dir: + return "", model_path + + exists = check_model_exists(model_path, model_name) + if exists: + return "", model_path + + temp_zip_path = _create_temp_dir() + return temp_zip_path, model_path + + +def _create_temp_dir(): + temp_model_path = tempfile.mkdtemp() + temp_zip_path = os.path.join(temp_model_path, "model_temp.zip") + return temp_zip_path + + +def check_v123_model_path(model_name): + # Cell-ACDC v1.2.3 saved the weights inside the package, + # while from v1.2.4 we save them on user folder. If we find the + # weights in the package we move them to user folder without downloading + # new ones. + v123_model_path = os.path.join(models_path, model_name, "model") + exists = check_model_exists(v123_model_path, model_name) + if exists: + return v123_model_path + else: + return "" + + +def is_old_user_profile_path(path_to_check: os.PathLike): + from . import user_data_dir + + user_data_folderpath = user_data_dir() + user_profile_path_txt = os.path.join( + user_data_folderpath, "acdc_user_profile_location.txt" + ) + if os.path.exists(user_profile_path_txt): + return False + + from . import user_home_path + + user_home_path = user_home_path.replace("\\", "/") + path_to_check = path_to_check.replace("\\", "/") + return user_home_path == path_to_check + + +def migrate_to_new_user_profile_path(path_to_migrate: os.PathLike): + parent_dir = os.path.dirname(path_to_migrate) + if not is_old_user_profile_path(parent_dir): + return path_to_migrate + folder = os.path.basename(path_to_migrate) + return os.path.join(user_profile_path, folder) + + +def determine_folder_type(folder_path): + is_pos_folder = is_pos_folderpath(folder_path) + is_images_folder = folder_path.endswith("Images") and listdir(folder_path) + contains_images_folder = os.path.exists(os.path.join(folder_path, "Images")) + contains_pos_folders = len(get_pos_foldernames(folder_path)) > 0 + if contains_pos_folders: + is_pos_folder = False + is_images_folder = False + elif contains_images_folder and not is_pos_folder: + # Folder created by loading an image + is_images_folder = True + folder_path = os.path.join(folder_path, "Images") + + return is_pos_folder, is_images_folder, folder_path + + +def to_relative_path(path, levels=3, prefix="..."): + path = path.replace("\\", "/") + parts = path.split("/") + if levels >= len(parts): + return path + parts = parts[-levels:] + rel_path = "/".join(parts) + rel_path.replace("/", os.sep) + if prefix: + rel_path = f"{prefix}{os.sep}{rel_path}" + return rel_path + + +def get_fiji_binary_filepath_mac(fiji_app_filepath): + if not is_mac: + return "" + + fiji_binary_path = os.path.join( + fiji_app_filepath, "Contents", "MacOS", "ImageJ-macosx" + ) + if os.path.exists(fiji_binary_path): + return fiji_binary_path + + fiji_binary_path = os.path.join( + fiji_app_filepath, "Contents", "MacOS", "fiji-macos" + ) + if os.path.exists(fiji_binary_path): + return fiji_binary_path + + return "" + + +def get_fiji_exec_folderpath() -> str: + if not is_mac: + return "" + + from cellacdc import fiji_location_filepath + + if os.path.exists(fiji_location_filepath): + with open(fiji_location_filepath, "r") as txt: + fiji_app_filepath = txt.read() + + return get_fiji_binary_filepath_mac(fiji_app_filepath) + + if os.path.exists("/Applications/Fiji.app"): + return get_fiji_binary_filepath_mac("/Applications/Fiji.app") + + acdc_fiji_app_path = os.path.join(acdc_fiji_path, "Fiji.app") + acdc_fiji_binary_path = get_fiji_binary_filepath_mac(acdc_fiji_app_path) + + return acdc_fiji_binary_path + + +def is_pos_folderpath(folderpath): + """Determine if a path is a valid Cell-ACDC Position folder + + Parameters + ---------- + folderpath : PathLike + Path to check + + Returns + ------- + bool + True if the path is a valid Cell-ACDC Position folder, False otherwise + + Notes + ----- + A valid Cell-ACDC Position folder must: + - Have a name matching the pattern 'Position_' + - Be a directory + - Contain an 'Images' subdirectory + - The 'Images' subdirectory must not be empty + """ + foldername = os.path.basename(folderpath) + is_valid_pos_folder = ( + re.search(r"^Position_(\d+)$", foldername) is not None + and os.path.isdir(folderpath) + and os.path.exists(os.path.join(folderpath, "Images")) + and listdir(os.path.join(folderpath, "Images")) + ) + return is_valid_pos_folder + + +def validate_images_path(input_path: os.PathLike, create_dirs_tree=False): + is_images_path = input_path.endswith("Images") + parent_dir = os.path.dirname(input_path) + parent_foldername = os.path.basename(parent_dir) + is_pos_folder = re.search( + r"^Position_(\d+)$", parent_foldername + ) is not None and os.path.isdir(parent_dir) + if not is_pos_folder: + existing_pos_foldernames = get_pos_foldernames(input_path) + pos_n = len(existing_pos_foldernames) + 1 + pos_folderpath = os.path.join(input_path, f"Position_{pos_n}") + images_path = os.path.join(pos_folderpath, "Images") + elif is_images_path: + pos_folderpath = input_path + images_path = os.path.join(pos_folderpath, "Images") + else: + images_path = input_path + + if create_dirs_tree: + os.makedirs(images_path, exist_ok=True) + + return images_path + +# Sibling imports (deferred to avoid import cycles) +from .models import ( + _write_model_location_to_txt, + check_model_exists, +) + diff --git a/cellacdc/utils/qt.py b/cellacdc/utils/qt.py new file mode 100644 index 000000000..4d809615a --- /dev/null +++ b/cellacdc/utils/qt.py @@ -0,0 +1,80 @@ +"""Cell-ACDC utility helpers: qt.""" + +import os +import re +import ast + +import typing +from typing import Literal, List, Callable, Tuple, Dict + +import pathlib +import difflib +import sys +import platform +import tempfile +import shutil +import traceback +import logging +import datetime +import time +import subprocess +import importlib +from uuid import uuid4 +from importlib import import_module +from math import pow, ceil, floor +from functools import wraps, partial +from collections import namedtuple, Counter +from tqdm import tqdm +import requests +import zipfile +import json +import numpy as np +import pandas as pd +import skimage +import inspect + +import traceback +import itertools +from packaging import version as packaging_version + +from natsort import natsorted + +import tifffile +import skimage.io +import skimage.measure + +from .. import GUI_INSTALLED, KNOWN_EXTENSIONS, is_conda_env + +from .. import core, load +from .. import html_utils, is_linux, is_win, is_mac, issues_url, is_mac_arm64 +from .. import cellacdc_path, printl, acdc_fiji_path, logs_path, acdc_ffmpeg_path +from .. import user_profile_path, recentPaths_path +from .. import models_list_file_path, models_path +from .. import promptable_models_list_file_path, promptable_models_path +from .. import github_home_url +from .. import try_input_install_package +from .. import _warnings +from .. import urls +from .. import qrc_resources_path +from .. import settings_folderpath +from ..segmenters._cellpose_base import min_target_versions_cp + +if GUI_INSTALLED: + from qtpy.QtWidgets import QMessageBox + from qtpy.QtCore import Signal, QObject, QCoreApplication + + from .. import widgets, apps + from .. import config + +ArgSpec = namedtuple("ArgSpec", ["name", "default", "type", "desc", "docstring"]) + +def testQcoreApp(): + print(QCoreApplication.instance()) + + +def get_cli_multi_choice_question(question, choices): + choices_format = [f"{i + 1}) {choice}." for i, choice in enumerate(choices)] + choices_format = " ".join(choices_format) + choices_opts = "/".join([str(i) for i in range(1, len(choices) + 1)]) + text = f"{question} {choices_format} q) Quit. ({choices_opts})?: " + return text diff --git a/cellacdc/utils/text.py b/cellacdc/utils/text.py new file mode 100644 index 000000000..0faff3a72 --- /dev/null +++ b/cellacdc/utils/text.py @@ -0,0 +1,141 @@ +"""Cell-ACDC utility helpers: text.""" + +import os +import re +import ast + +import typing +from typing import Literal, List, Callable, Tuple, Dict + +import pathlib +import difflib +import sys +import platform +import tempfile +import shutil +import traceback +import logging +import datetime +import time +import subprocess +import importlib +from uuid import uuid4 +from importlib import import_module +from math import pow, ceil, floor +from functools import wraps, partial +from collections import namedtuple, Counter +from tqdm import tqdm +import requests +import zipfile +import json +import numpy as np +import pandas as pd +import skimage +import inspect + +import traceback +import itertools +from packaging import version as packaging_version + +from natsort import natsorted + +import tifffile +import skimage.io +import skimage.measure + +from .. import GUI_INSTALLED, KNOWN_EXTENSIONS, is_conda_env + +from .. import core, load +from .. import html_utils, is_linux, is_win, is_mac, issues_url, is_mac_arm64 +from .. import cellacdc_path, printl, acdc_fiji_path, logs_path, acdc_ffmpeg_path +from .. import user_profile_path, recentPaths_path +from .. import models_list_file_path, models_path +from .. import promptable_models_list_file_path, promptable_models_path +from .. import github_home_url +from .. import try_input_install_package +from .. import _warnings +from .. import urls +from .. import qrc_resources_path +from .. import settings_folderpath +from ..segmenters._cellpose_base import min_target_versions_cp + +if GUI_INSTALLED: + from qtpy.QtWidgets import QMessageBox + from qtpy.QtCore import Signal, QObject, QCoreApplication + + from .. import widgets, apps + from .. import config + +ArgSpec = namedtuple("ArgSpec", ["name", "default", "type", "desc", "docstring"]) + +def get_trimmed_list(li: list, max_num_digits=10): + if len(li) == 0: + return "[]" + + tom_num_digits = sum([len(str(val)) for val in li]) + + if tom_num_digits == 0: + return f"[{', '.join(map(str, li))}]" + + avg_num_digits = tom_num_digits / len(li) + max_num_vals = int(round(max_num_digits / avg_num_digits)) + + if tom_num_digits > max_num_digits: + front_vals = ceil(max_num_vals / 2) + back_vals = max_num_vals // 2 + + if front_vals + back_vals >= len(li): + return f"[{', '.join(map(str, li))}]" + + li = li[:front_vals] + ["..."] + li[len(li) - back_vals :] + + return f"[{', '.join(map(str, li))}]" + + +def get_trimmed_dict(di: dict, max_num_digits=10): + di_str = di.copy() + total_num_digits = sum([len(str(key)) + len(str(val)) for key, val in di.items()]) + avg_num_digits = total_num_digits / len(di) + max_num_vals = int(round(max_num_digits / avg_num_digits)) + if total_num_digits > max_num_digits: + keys = list(di_str.keys()) + for key in keys[max_num_vals:-max_num_vals]: + del di_str[key] + di_str[keys[max_num_vals]] = "..." + return f"[{', '.join([f'{key} -> {val}' for key, val in di_str.items()])}]" + + +def get_number_fstring_formatter(dtype, precision=4): + if np.issubdtype(dtype, np.integer): + return "d" + else: + return f".{precision}f" + + +def elided_text(text, max_len=50, elid_idx=None): + if len(text) <= max_len: + return text + + if elid_idx is None: + elid_idx = int(max_len / 2) + if elid_idx >= max_len: + elid_idx = max_len - 1 + idx1 = elid_idx + idx2 = elid_idx - max_len + text = f"{text[:idx1]}...{text[idx2:]}" + return text + + +def get_show_in_file_manager_text(): + if is_mac: + return "Reveal in Finder" + elif is_linux: + return "Show in File Manager" + elif is_win: + return "Show in File Explorer" + + +def append_text_filename(filename: str, text_to_append: str): + filename_noext, ext = os.path.splitext(filename) + filename_out = f"{filename_noext}{text_to_append}{ext}" + return filename_out diff --git a/cellacdc/utils/version.py b/cellacdc/utils/version.py new file mode 100644 index 000000000..fe12d9a54 --- /dev/null +++ b/cellacdc/utils/version.py @@ -0,0 +1,555 @@ +"""Cell-ACDC utility helpers: version.""" + +import os +import re +import ast + +import typing +from typing import Literal, List, Callable, Tuple, Dict + +import pathlib +import difflib +import sys +import platform +import tempfile +import shutil +import traceback +import logging +import datetime +import time +import subprocess +import importlib +from uuid import uuid4 +from importlib import import_module +from math import pow, ceil, floor +from functools import wraps, partial +from collections import namedtuple, Counter +from tqdm import tqdm +import requests +import zipfile +import json +import numpy as np +import pandas as pd +import skimage +import inspect + +import traceback +import itertools +from packaging import version as packaging_version + +from natsort import natsorted + +import tifffile +import skimage.io +import skimage.measure + +from .. import GUI_INSTALLED, KNOWN_EXTENSIONS, is_conda_env + +from .. import core, load +from .. import html_utils, is_linux, is_win, is_mac, issues_url, is_mac_arm64 +from .. import cellacdc_path, printl, acdc_fiji_path, logs_path, acdc_ffmpeg_path +from .. import user_profile_path, recentPaths_path +from .. import models_list_file_path, models_path +from .. import promptable_models_list_file_path, promptable_models_path +from .. import github_home_url +from .. import try_input_install_package +from .. import _warnings +from .. import urls +from .. import qrc_resources_path +from .. import settings_folderpath +from ..segmenters._cellpose_base import min_target_versions_cp + +if GUI_INSTALLED: + from qtpy.QtWidgets import QMessageBox + from qtpy.QtCore import Signal, QObject, QCoreApplication + + from .. import widgets, apps + from .. import config + +ArgSpec = namedtuple("ArgSpec", ["name", "default", "type", "desc", "docstring"]) + +def get_salute_string(): + time_now = datetime.datetime.now().time() + time_end_morning = datetime.time(12, 00, 00) + time_end_lunch = datetime.time(13, 00, 00) + time_end_afternoon = datetime.time(15, 00, 00) + time_end_evening = datetime.time(20, 00, 00) + time_end_night = datetime.time(4, 00, 00) + if time_now >= time_end_night and time_now < time_end_morning: + return "Have a good day!" + elif time_now >= time_end_morning and time_now < time_end_lunch: + return "Enjoy your lunch!" + elif time_now >= time_end_lunch and time_now < time_end_afternoon: + return "Have a good afternoon!" + elif time_now >= time_end_afternoon and time_now < time_end_evening: + return "Have a good evening!" + else: + return "Have a good night!" + + +def get_info_version_text(is_cli=False, cli_formatted_text=True): + version = read_version() + release_date = get_date_from_version(version, package="cellacdc") + py_ver = sys.version_info + env_folderpath = sys.prefix + python_version = f"{py_ver.major}.{py_ver.minor}.{py_ver.micro}" + info_txts = [ + f"Version {version}", + f"Released on: {release_date}", + f'Installed in "{cellacdc_path}"', + f'Environment folder: "{env_folderpath}"', + f'User profile folder: "{user_profile_path}"', + f'Settings folder: "{settings_folderpath}"', + f"Python {python_version}", + f"Platform: {platform.platform()}", + f"System: {platform.system()}", + ] + if is_linux: + try: + distro_name = get_linux_distribution_name() + except Exception as err: + distro_name = "Undetermined" + + info_txts.append(f"Linux distribution: {distro_name}") + + if GUI_INSTALLED and not is_cli: + info_txts.append(f'Icons from: "{qrc_resources_path}"') + try: + from qtpy import QtCore + + info_txts.append(f"Qt {QtCore.__version__}") + except Exception as err: + info_txts.append("Qt: Not installed") + + try: + branch_name = get_git_branch_name() + info_txts.append(f'Git branch: "{branch_name}"') + except Exception as err: + pass + + info_txts.append(f"Working directory: {os.getcwd()}") + + if not cli_formatted_text: + return info_txts + + info_txts = [f" - {txt}" for txt in info_txts] + + max_len = max([len(txt) for txt in info_txts]) + 2 + + formatted_info_txts = [] + for txt in info_txts: + horiz_spacing = " " * (max_len - len(txt)) + txt = f"{txt}{horiz_spacing}|" + formatted_info_txts.append(txt) + + formatted_info_txts.insert(0, "Cell-ACDC info:\n") + formatted_info_txts.insert(0, "=" * max_len) + formatted_info_txts.append("=" * max_len) + info_txt = "\n".join(formatted_info_txts) + + try: + from spotmax.utils import get_info_version_text as smax_info + + smax_info_txt = smax_info(include_platform=False, is_cli=is_cli) + info_txt += "\n\n" + smax_info_txt + except ImportError: + pass + + return info_txt + + +def read_version(logger=None, return_success=False): + cellacdc_parent_path = os.path.dirname(cellacdc_path) + cellacdc_parent_folder = os.path.basename(cellacdc_parent_path) + if cellacdc_parent_folder == "site-packages": + from . import __version__ + + version = __version__ + success = True + else: + try: + from setuptools_scm import get_version + + version = get_version(root="..", relative_to=__file__) + success = True + except Exception as e: + if logger is None: + logger = print + logger("*" * 40) + logger(traceback.format_exc()) + logger("-" * 40) + logger( + "[WARNING]: Cell-ACDC could not determine the current version. " + "Returning the version determined at installation time. " + "See details above." + ) + logger("=" * 40) + try: + from . import _version + + version = _version.version + success = False + except Exception as e: + version = "ND" + success = False + + if return_success: + return version, success + else: + return version + + +def get_date_from_version(version: str, package="cellacdc", debug=False): + try: + response = requests.get(f"https://pypi.org/pypi/{package}/json", timeout=2) + res_json = response.json() + pypi_releases_json = res_json["releases"] + version_json = pypi_releases_json[version][0] + upload_time = version_json["upload_time_iso_8601"] + date = datetime.datetime.strptime(upload_time, r"%Y-%m-%dT%H:%M:%S.%fZ") + date_str = date.strftime(r"%A %d %B %Y at %H:%M") + return date_str + except Exception as err: + if debug: + traceback.print_exc() + + try: + # Locate the direct_url.json file for the package + # installed with pip git+ + dist = importlib.metadata.distribution(package) + dist_info_dir = dist._path # internal path to .dist-info + direct_url_path = os.path.join(dist_info_dir, "direct_url.json") + + with open(direct_url_path) as f: + data = json.load(f) + + vcs_info = data["vcs_info"] + commit_id = vcs_info.get("commit_id") + url = data.get("url") + + parts = url.split("github.com/")[1].split(".git")[0] + owner, repo = parts.split("/", 1) + + # Query GitHub API for commit date + api_url = f"https://api.github.com/repos/{owner}/{repo}/commits/{commit_id}" + response = requests.get(api_url) + response.raise_for_status() + + commit_data = response.json() + date_utc = commit_data["commit"]["committer"]["date"] + + date_str = format_commit_date_utc(date_utc) + + return date_str + except Exception as err: + if debug: + traceback.print_exc() + + try: + if package == "cellacdc": + pkg_path = cellacdc_path + elif package == "spotmax": + from spotmax import spotmax_path + + pkg_path = spotmax_path + commit_hash = re.findall(r"\+g([A-Za-z0-9]+)(\.d)?", version)[0][0] + git_path = os.path.dirname(pkg_path) + command = f"git -C {git_path} show {commit_hash}" + commit_log = _subprocess_run_command( + command, shell=False, callback="check_output" + ) + commit_log = commit_log.decode() + date_log = re.findall(r"Date:(.*) \+", commit_log)[0].strip() + date = datetime.datetime.strptime(date_log, r"%a %b %d %H:%M:%S %Y") + date_str = date.strftime(r"%A %d %B %Y at %H:%M") + return date_str + except Exception as err: + if debug: + traceback.print_exc() + + return "ND" + + +def get_git_branch_name(): + command = "git rev-parse --abbrev-ref HEAD" + output = _subprocess_run_command(command, shell=False, callback="check_output") + branch_name = output.decode().strip() + return branch_name + + +def get_cellpose_major_version(errors="raise"): + major_installed = None + try: + installed_version = get_package_version("cellpose") + major_installed = int(installed_version.split(".")[0]) + except Exception as err: + if errors == "raise": + raise err + + return major_installed + + +def check_cellpose_version(version: str): + if isinstance(version, int): + version = f"{version}.0" + + major_requested = int(version.split(".")[0]) + cancel = False + try: + installed_version = get_package_version("cellpose") + major_installed = int(installed_version.split(".")[0]) + is_version_correct = major_installed == major_requested + if not is_version_correct: + cancel = _warnings.warn_installing_different_cellpose_version( + version, installed_version + ) + if not is_second_version_greater( + min_target_versions_cp[str(major_requested)], installed_version + ): + is_version_correct = False + except Exception as err: + is_version_correct = False + + if cancel: + raise ModuleNotFoundError("Cellpose installation cancelled by the user.") + return is_version_correct + + +def is_second_version_greater( + target_version: str, + current_version: str, +): + """ + Compares two model versions and returns True if the current version is + greater than or equal to the target version. + """ + target_version = packaging_version.parse(target_version) + current_version = packaging_version.parse(current_version) + + return current_version >= target_version + + +def is_pkg_version_within_range(package_version: str, min_version="", max_version=""): + package_version_number = packaging_version.parse(package_version) + is_greater_than_min = True + if min_version: + min_version_number = packaging_version.parse(min_version) + is_greater_than_min = package_version_number >= min_version_number + + is_less_than_max = True + if max_version: + max_version_number = packaging_version.parse(max_version) + is_less_than_max = package_version_number <= max_version_number + + return is_greater_than_min and is_less_than_max + + +def check_pkg_version( + import_pkg_name, min_version, include_lower_version, raise_err=True +): + is_version_correct = False + try: + installed_version = get_package_version(import_pkg_name) + if include_lower_version: + is_version_correct = packaging_version.parse( + installed_version + ) >= packaging_version.parse(min_version) + else: + is_version_correct = packaging_version.parse( + installed_version + ) > packaging_version.parse(min_version) + except Exception as err: + is_version_correct = False + + if raise_err and not is_version_correct: + raise ModuleNotFoundError(f"{import_pkg_name}>{min_version} not installed.") + else: + return is_version_correct + + +def check_pkg_exact_version(import_pkg_name, version: str, raise_err=True): + is_version_correct = False + try: + installed_version = get_package_version(import_pkg_name) + is_version_correct = packaging_version.parse( + installed_version + ) == packaging_version.parse(version) + except Exception as err: + is_version_correct = False + + if raise_err and not is_version_correct: + raise ModuleNotFoundError(f"{import_pkg_name}=={version} not installed.") + else: + return is_version_correct + + +def check_pkg_max_version( + import_pkg_name, max_version, include_higher_version, raise_err=True +): + is_version_correct = False + try: + from packaging import version + + installed_version = get_package_version(import_pkg_name) + if include_higher_version: + is_version_correct = packaging_version.parse( + installed_version + ) <= packaging_version.parse(max_version) + else: + is_version_correct = packaging_version.parse( + installed_version + ) < packaging_version.parse(max_version) + except Exception as err: + is_version_correct = False + + if raise_err and not is_version_correct: + raise ModuleNotFoundError(f"{import_pkg_name}<={max_version} not installed.") + else: + return is_version_correct + + +def check_matplotlib_version(qparent=None): + mpl_version = get_package_version("matplotlib") + mpl_version_digits = mpl_version.split(".") + + mpl_major = int(mpl_version_digits[0]) + mpl_minor = int(mpl_version_digits[1]) + is_less_than_3_5 = mpl_major < 3 or (mpl_major >= 3 and mpl_minor < 5) + if not is_less_than_3_5: + return + + proceed = _install_package_msg("matplotlib", parent=qparent, upgrade=True) + if not proceed: + raise ModuleNotFoundError(f'User aborted "matplotlib" installation') + import subprocess + + try: + subprocess.check_call( + [sys.executable, "-m", "pip", "install", "-U", "matplotlib"] + ) + except Exception as e: + printl(traceback.format_exc()) + _inform_install_package_failed("matplotlib", parent=qparent, do_exit=False) + + +def get_git_pull_checkout_cellacdc_version_commands(version=None): + if version is None: + version = read_version() + commit_hash_idx = version.find("+g") + is_dev_version = commit_hash_idx > 0 + if not is_dev_version: + return [] + commit_hash = version[commit_hash_idx + 2 :].split(".")[0] + commands = ( + f'cd "{os.path.dirname(cellacdc_path)}"', + "git pull", + f"git checkout {commit_hash}", + ) + return commands + + +def _update_repo_with_git_command(package_name, repo_location): + """Update repository using git command""" + try: + print( + f"Updating {package_name} repository at {repo_location} using git command..." + ) + + # Change to repository directory + original_cwd = os.getcwd() + os.chdir(repo_location) + + stashed_changes = False + + # check if there is a portable git + from .config import parser_args + + try: + cp = parser_args + if cp["install_details"] is not None: + no_cli_install = True + install_details = cp["install_details"] + target_dir = install_details.get("target_dir", "") + target_dir = target_dir.strip().strip('"').strip("'") + target_dir = os.path.abspath(target_dir) + else: + no_cli_install = False + except: + no_cli_install = False + pass + + if is_win and no_cli_install: + git_loc = os.path.join(target_dir, "portable_git", "cmd", "git.exe") + if not os.path.exists(git_loc): + print(f"Portable git not found at {git_loc}. Using system git.") + git_loc = "git" + else: + git_loc = "git" + + # Check if git is available + if not shutil.which(git_loc): + print( + f"Git command not found. Please install git to update {package_name}." + ) + return False + + try: + # Check for uncommitted changes + + branch_result = subprocess.run( + [git_loc, "branch", "--show-current"], + capture_output=True, + text=True, + check=True, + ) + current_branch = branch_result.stdout.strip() + print(f"Current branch: {current_branch}") + + result = subprocess.run( + [git_loc, "status", "--porcelain"], + capture_output=True, + text=True, + check=True, + ) + if result.stdout.strip(): + print(f"Repository {package_name} has uncommitted changes") + print("Stashing changes before update...") + subprocess.run([git_loc, "stash"], check=True) + stashed_changes = True + + # Pull changes + subprocess.run([git_loc, "pull"], check=True) + print(f"Successfully updated {package_name}") + + # Pop stashed changes if any were stashed + if stashed_changes: + try: + subprocess.run([git_loc, "stash", "pop"], check=True) + print("Restored stashed changes") + except subprocess.CalledProcessError as pop_error: + print(f"Warning: Could not restore stashed changes: {pop_error}") + + return True + + except subprocess.CalledProcessError as e: + print(f"Git command failed for {package_name}: {e}") + return False + finally: + os.chdir(original_cwd) + + except Exception as e: + print(f"Error updating {package_name} with git command: {e}") + return False + +# Sibling imports (deferred to avoid import cycles) +from .install import ( + _inform_install_package_failed, + _install_package_msg, + get_package_version, +) +from .misc import ( + _subprocess_run_command, + format_commit_date_utc, + get_linux_distribution_name, +) + diff --git a/cellacdc/viewer.py b/cellacdc/viewer.py new file mode 100644 index 000000000..b99b68b51 --- /dev/null +++ b/cellacdc/viewer.py @@ -0,0 +1,100 @@ +"""Napari-style script API for launching the Cell-ACDC GUI.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from weakref import WeakSet + +from cellacdc.data_source import ExperimentData + +if TYPE_CHECKING: + from cellacdc.gui import guiWin + +_DEFAULT_MODE = "Segmentation and Tracking" + + +def _check_gui_installed() -> None: + from cellacdc import GUI_INSTALLED + + if not GUI_INSTALLED: + raise RuntimeError( + "Cell-ACDC GUI dependencies are not installed. " + 'Install them with `pip install "cellacdc[gui]"`.' + ) + + +def _read_version() -> str: + from cellacdc import utils + + return utils.read_version() + + +def _create_gui_window(app, version: str): + from cellacdc import gui + + win = gui.guiWin(app, mainWin=None, version=version) + win.run() + return win + + +class Viewer: + """Launch the Cell-ACDC annotation GUI from a script or notebook.""" + + _instances: WeakSet[Viewer] = WeakSet() + + def __init__( + self, + data: ExperimentData | None = None, + *, + show: bool = True, + mode: str = _DEFAULT_MODE, + ): + _check_gui_installed() + + from cellacdc._event_loop import get_qapp + + app = get_qapp() + version = _read_version() + win = _create_gui_window(app, version) + win.modeComboBox.setCurrentText(mode) + + self._data = data + if data is not None: + data.load_into(win) + + if show: + win.raise_() + win.activateWindow() + + self._window = win + self._instances.add(self) + + @property + def data(self) -> ExperimentData | None: + return self._data + + @property + def window(self) -> guiWin: + return self._window + + def close(self) -> None: + self._window.close() + + +def current_viewer() -> Viewer | None: + """Return the most recently created viewer, if any.""" + instances = list(Viewer._instances) + if not instances: + return None + return instances[-1] + + +def imshow( + data: ExperimentData, + *, + show: bool = True, + mode: str = _DEFAULT_MODE, +) -> tuple[Viewer, ExperimentData]: + """Open the GUI with an :class:`ExperimentData` instance.""" + viewer = Viewer(data, show=show, mode=mode) + return viewer, data diff --git a/cellacdc/whitelist.py b/cellacdc/whitelist.py index 0621f4514..78eb632b3 100644 --- a/cellacdc/whitelist.py +++ b/cellacdc/whitelist.py @@ -1,28 +1,29 @@ import os import numpy as np import skimage.measure -from . import printl, myutils +from . import printl, utils import json from typing import Set, List, Tuple import time from . import ( - html_utils, - apps, - widgets, - exception_handler, - disableWindow, + html_utils, + apps, + widgets, + exception_handler, + disableWindow, gui_utils, - exec_time + exec_time, ) from .trackers.CellACDC import CellACDC_tracker + class Whitelist: - """A class to manage the whitelist of IDs for a video. - """ + """A class to manage the whitelist of IDs for a video.""" + def __init__(self, total_frames: int | list | set, debug=False): """Initializes the whitelist with the total number of frames. - The whitelist is a dictionary with the frame index + The whitelist is a dictionary with the frame index as the key and a set of IDs as the value. Also the original not whitelisted labs are stored in the originalLabs variable. @@ -50,7 +51,7 @@ def __init__(self, total_frames: int | list | set, debug=False): self.initialized_i = set() self.new_centroids = None - def __getitem__(self, index:int): + def __getitem__(self, index: int): """Gets a whitelist for a given index. Parameters @@ -65,7 +66,7 @@ def __getitem__(self, index:int): """ return self.get(index) - def __setitem__(self, index:int, value:set): + def __setitem__(self, index: int, value: set): """Sets a whitelist for a given index. Parameters @@ -77,8 +78,8 @@ def __setitem__(self, index:int, value:set): """ self.whitelistIDs[index] = set(value) - def loadOGLabs(self, selected_path:str=None, og_data:np.ndarray=None): - """Loads the original labels from a .npz file, + def loadOGLabs(self, selected_path: str = None, og_data: np.ndarray = None): + """Loads the original labels from a .npz file, or from the provided og_data. Parameters @@ -93,9 +94,12 @@ def loadOGLabs(self, selected_path:str=None, og_data:np.ndarray=None): og_data = og_data[og_data.files[0]] self.originalLabs = og_data - self.originalLabsIDs = [{obj.label for obj in skimage.measure.regionprops(frame)} for frame in og_data] - - def saveOGLabs(self, save_path:str): + self.originalLabsIDs = [ + {obj.label for obj in skimage.measure.regionprops(frame)} + for frame in og_data + ] + + def saveOGLabs(self, save_path: str): """Saves the original labels to a .npz file. Parameters @@ -103,25 +107,27 @@ def saveOGLabs(self, save_path:str): save_path : str desired save path for the original labels """ - # original_frames = np.array(list(self.originalLabs.values())) - # the above is not necessary anymore, - #since I changed the originalLabs to be a np.ndarray + # original_frames = np.array(list(self.originalLabs.values())) + # the above is not necessary anymore, + # since I changed the originalLabs to be a np.ndarray np.savez_compressed(save_path, self.originalLabs) - - def load(self, whitelist_path:str, - new_centroids_path:str, - segm_data:np.ndarray, - allData_li:list=None, - ): + + def load( + self, + whitelist_path: str, + new_centroids_path: str, + segm_data: np.ndarray, + allData_li: list = None, + ): """Loads the whitelist from a json file. If the file does not exist, it initializes the whitelist to None. - If the file exists, it loads the whitelist and initializes + If the file exists, it loads the whitelist and initializes the originalLabs variable. Parameters ---------- whitelist_path : str - path to the whitelist json file (should be in accordance to the + path to the whitelist json file (should be in accordance to the one provided in save) segm_data : np.ndarray segmentation data for the video @@ -137,15 +143,17 @@ def load(self, whitelist_path:str, if not os.path.exists(whitelist_path): self.whitelistIDs = None return False - + any_whitelist_added = False - with open(whitelist_path, 'r') as json_file: + with open(whitelist_path, "r") as json_file: whitelist = json.load(json_file) wl_processed = dict() for key, val in whitelist.items(): if val is None: wl_processed[int(key)] = None - elif val == "None": # if the string "none" is present in the json file, it will be converted to None + elif ( + val == "None" + ): # if the string "none" is present in the json file, it will be converted to None wl_processed[int(key)] = None else: wl_processed[int(key)] = set(val) @@ -156,37 +164,43 @@ def load(self, whitelist_path:str, else: self.whitelistIDs = None return False - - self.makeOriginalLabsAndIDs(segm_data, allData_li - ) - + + self.makeOriginalLabsAndIDs(segm_data, allData_li) + self.load_centroids(new_centroids_path=new_centroids_path) return True - - def load_centroids(self, new_centroids_path:str): + + def load_centroids(self, new_centroids_path: str): if os.path.exists(new_centroids_path): - with open(new_centroids_path, 'r') as json_file: + with open(new_centroids_path, "r") as json_file: self.new_centroids = json.load(json_file) - - self.new_centroids = list(self.new_centroids) if isinstance(self.new_centroids, list) else self.new_centroids + + self.new_centroids = ( + list(self.new_centroids) + if isinstance(self.new_centroids, list) + else self.new_centroids + ) for i, val in enumerate(self.new_centroids): if isinstance(val, str) and val.lower() == "none": self.new_centroids[i] = {} elif val is None: self.new_centroids[i] = {} - else: # convert to integers - self.new_centroids[i] = {tuple(map(int, centroid)) for centroid in val} + else: # convert to integers + self.new_centroids[i] = { + tuple(map(int, centroid)) for centroid in val + } else: - printl('No new centroids file found, initializing new centroids.') + printl("No new centroids file found, initializing new centroids.") self.create_new_centroids() - - def create_new_centroids(self, - curr_rp=None, - frame_i:int=None, - ): + + def create_new_centroids( + self, + curr_rp=None, + frame_i: int = None, + ): """ Creates self.new_centroids based on the input data. - + Parameters ---------- @@ -202,38 +216,37 @@ def create_new_centroids(self, """ if self.new_centroids is not None: return - + if frame_i is None and curr_rp is not None: - raise ValueError( - 'If curr_rp is provided, frame_i must also be provided.' - ) - + raise ValueError("If curr_rp is provided, frame_i must also be provided.") + self.new_centroids = [] for i in self.total_frames: if i == 0: self.new_centroids.append({}) continue - - all_there = (self.originalLabsIDs[i] is not None and - self.originalLabsIDs[i-1] is not None) + + all_there = ( + self.originalLabsIDs[i] is not None + and self.originalLabsIDs[i - 1] is not None + ) if all_there is False: self.new_centroids.append({}) continue - - new_IDs = self.originalLabsIDs[i] - self.originalLabsIDs[i-1] - + + new_IDs = self.originalLabsIDs[i] - self.originalLabsIDs[i - 1] + rp = None - if frame_i==i and curr_rp is not None: + if frame_i == i and curr_rp is not None: rp = curr_rp else: rp = skimage.measure.regionprops(self.originalLabs[i]) - self.new_centroids.append({ - tuple(map(int, obj.centroid)) for obj in rp if obj.label in new_IDs - }) - + self.new_centroids.append( + {tuple(map(int, obj.centroid)) for obj in rp if obj.label in new_IDs} + ) - def save(self, whitelist_path:str, new_centroids_path:str): + def save(self, whitelist_path: str, new_centroids_path: str): """Saves the whitelist to a json file. If the whitelist is None, it will not be saved. Make sure that the path is in accordance to the one provided in load. @@ -241,7 +254,7 @@ def save(self, whitelist_path:str, new_centroids_path:str): Parameters ---------- whitelist_path : str - path to the whitelist json file (should be in accordance to the + path to the whitelist json file (should be in accordance to the one provided in load) """ if not self.whitelistIDs: @@ -252,20 +265,20 @@ def save(self, whitelist_path:str, new_centroids_path:str): wl_copy[key] = "None" else: wl_copy[key] = list(val) - json.dump(wl_copy, open(whitelist_path, 'w+'), indent=4) - + json.dump(wl_copy, open(whitelist_path, "w+"), indent=4) + for i, val in enumerate(self.new_centroids): if val is None: self.new_centroids[i] = "None" else: self.new_centroids[i] = list(val) - with open(new_centroids_path, 'w+') as json_file: + with open(new_centroids_path, "w+") as json_file: json.dump(self.new_centroids, json_file, indent=4) - def checkOriginalLabels(self, frame_i:int): + def checkOriginalLabels(self, frame_i: int): """Checks if there are no original labels for the current frame. - + Parameters ---------- frame_i : int @@ -275,22 +288,28 @@ def checkOriginalLabels(self, frame_i:int): bool True if there are original labels, False otherwise. """ - if len(self.originalLabsIDs) <= frame_i or self.originalLabsIDs is None or self.originalLabsIDs[frame_i] is None: + if ( + len(self.originalLabsIDs) <= frame_i + or self.originalLabsIDs is None + or self.originalLabsIDs[frame_i] is None + ): return False return True - def addNewIDs(self, frame_i:int, - allData_li: list, - IDs_curr: List[int] | Set[int]=None, - index_lab_combo: Tuple[int, np.ndarray]=None, - curr_rp: list=None, - curr_lab: np.ndarray=None, - # per_frame_IDs=None, - # labs=None - ): + def addNewIDs( + self, + frame_i: int, + allData_li: list, + IDs_curr: List[int] | Set[int] = None, + index_lab_combo: Tuple[int, np.ndarray] = None, + curr_rp: list = None, + curr_lab: np.ndarray = None, + # per_frame_IDs=None, + # labs=None + ): """Adds new IDs to the whitelist for a given frame based on the - original labels. The IDs are added to the whitelist for the + original labels. The IDs are added to the whitelist for the current frame. Also propagates. @@ -302,57 +321,62 @@ def addNewIDs(self, frame_i:int, passed to self.propagateIDs(), see rest of ACDC: posData. allData_li IDs_curr : list | set, optional - Currently present IDs, passed to self.propagateIDs(). by default + Currently present IDs, passed to self.propagateIDs(). by default None index_lab_combo: Tuple[int, np.ndarray]=None, - Combination of frame_i and current frame, + Combination of frame_i and current frame, passed to self.propagateIDs(), by default None curr_rp : list, optional - Region properties for the current frame, passed to + Region properties for the current frame, passed to self.propagateIDs(). by default None curr_lab : np.ndarray, optional - Labels for the current frame, passed to self.propagateIDs(). + Labels for the current frame, passed to self.propagateIDs(). by default None """ - - for i in [frame_i, frame_i-1]: + + for i in [frame_i, frame_i - 1]: if not self.checkOriginalLabels(i): return - + if curr_lab is None: - curr_lab = allData_li[frame_i]['labels'] - + curr_lab = allData_li[frame_i]["labels"] + new_centroids = self.new_centroids[frame_i] if not new_centroids: return - - new_IDs = {gui_utils.nearest_ID_to_centroid(curr_lab, *new_centroid) for new_centroid in new_centroids} - - self.propagateIDs(IDs_to_add=new_IDs, - curr_frame_only=False, - frame_i=frame_i, - allData_li=allData_li, - IDs_curr=IDs_curr, - index_lab_combo=index_lab_combo, - allow_only_current_IDs=False, - curr_rp=curr_rp, - curr_lab=curr_lab, - # per_frame_IDs=per_frame_IDs, - # labs=labs - ) - - def IDsAccepted(self, - whitelistIDs: Set[int] | List[int], - frame_i: int, - allData_li: list, - segm_data: np.ndarray, - curr_lab: np.ndarray=None, - index_lab_combo: Tuple[int, np.ndarray]=None, - IDs_curr: Set[int] | List[int]=None, - curr_rp: list=None, - # labs=None - ): - """Called if the user accepted IDs. + + new_IDs = { + gui_utils.nearest_ID_to_centroid(curr_lab, *new_centroid) + for new_centroid in new_centroids + } + + self.propagateIDs( + IDs_to_add=new_IDs, + curr_frame_only=False, + frame_i=frame_i, + allData_li=allData_li, + IDs_curr=IDs_curr, + index_lab_combo=index_lab_combo, + allow_only_current_IDs=False, + curr_rp=curr_rp, + curr_lab=curr_lab, + # per_frame_IDs=per_frame_IDs, + # labs=labs + ) + + def IDsAccepted( + self, + whitelistIDs: Set[int] | List[int], + frame_i: int, + allData_li: list, + segm_data: np.ndarray, + curr_lab: np.ndarray = None, + index_lab_combo: Tuple[int, np.ndarray] = None, + IDs_curr: Set[int] | List[int] = None, + curr_rp: list = None, + # labs=None + ): + """Called if the user accepted IDs. This can also be called if one wants forced propagation of IDs. Parameters @@ -366,31 +390,29 @@ def IDsAccepted(self, segm_data : np.ndarray The segmentation data for the video. Fallback to when allData_li is not provided. curr_lab : np.ndarray, optional - Labels for the current frame. Use instead of allData_li/segm_data + Labels for the current frame. Use instead of allData_li/segm_data for current frame_i Also passed to self.propagateIDs(), by default None index_lab_combo : Tuple[int, np.ndarray], optional - Combination of frame_i and current frame, + Combination of frame_i and current frame, passed to self.propagateIDs(), by default None IDs_curr : list | set, optional Currently present IDs, passed to self.propagateIDs(), by default None curr_rp : list, optional Region properties for the current frame, passed to self.propagateIDs(), by default None """ - + # if allData_li is None and labs is None: # raise ValueError('Either allData_li or curr_labs must be provided') # elif allData_li is not None and labs is not None: # raise ValueError('Either allData_li or curr_labs must be provided, not both') if self.whitelistIDs is None: - self.whitelistIDs = { - i: None for i in self.total_frames - } + self.whitelistIDs = {i: None for i in self.total_frames} if IDs_curr: if self._debug: - printl('Using IDs_curr') + printl("Using IDs_curr") try: IDs_curr = IDs_curr.copy() except AttributeError: @@ -399,49 +421,56 @@ def IDsAccepted(self, elif index_lab_combo and index_lab_combo[0] == frame_i: lab = index_lab_combo[1] if self._debug: - printl('Using index_lab_combo') + printl("Using index_lab_combo") IDs_curr = {obj.label for obj in skimage.measure.regionprops(lab)} elif curr_rp is not None: IDs_curr = {obj.label for obj in curr_rp} if self._debug: - printl('Using rp') + printl("Using rp") elif curr_lab is not None: lab = curr_lab if self._debug: - printl('Using curr_lab') + printl("Using curr_lab") IDs_curr = {obj.label for obj in skimage.measure.regionprops(lab)} else: - IDs_curr = allData_li[frame_i]['IDs'] + IDs_curr = allData_li[frame_i]["IDs"] if self._debug: - printl('Using allData_li') - - IDs_curr = set(IDs_curr) + printl("Using allData_li") + IDs_curr = set(IDs_curr) - self.makeOriginalLabsAndIDs(segm_data, allData_li=allData_li, - frame_i=frame_i, curr_lab=curr_lab, - IDs_curr=IDs_curr, - ) + self.makeOriginalLabsAndIDs( + segm_data, + allData_li=allData_li, + frame_i=frame_i, + curr_lab=curr_lab, + IDs_curr=IDs_curr, + ) self.create_new_centroids() whitelistIDs = set(whitelistIDs) - self.propagateIDs(frame_i, - allData_li, - new_whitelist=whitelistIDs, - try_create_new_whitelists=True, - force_not_dynamic_update=True, - index_lab_combo=index_lab_combo, - IDs_curr=IDs_curr, - curr_rp=curr_rp, - curr_lab=curr_lab, - # labs=labs, - ) - - def makeOriginalLabsAndIDs(self, segm_data: np.ndarray, - allData_li: list=None, frame_i: int=None, - curr_lab: np.ndarray=None, - IDs_curr: set | list=None,): - """ Initializes the originalLabs and originalLabsIDs variables. + self.propagateIDs( + frame_i, + allData_li, + new_whitelist=whitelistIDs, + try_create_new_whitelists=True, + force_not_dynamic_update=True, + index_lab_combo=index_lab_combo, + IDs_curr=IDs_curr, + curr_rp=curr_rp, + curr_lab=curr_lab, + # labs=labs, + ) + + def makeOriginalLabsAndIDs( + self, + segm_data: np.ndarray, + allData_li: list = None, + frame_i: int = None, + curr_lab: np.ndarray = None, + IDs_curr: set | list = None, + ): + """Initializes the originalLabs and originalLabsIDs variables. Parameters ---------- @@ -461,41 +490,44 @@ def makeOriginalLabsAndIDs(self, segm_data: np.ndarray, if IDs_curr is not None or curr_lab is not None: if IDs_curr is None or curr_lab is None or frame_i is None: raise ValueError( - 'If IDs_curr, curr_lab or frame_i are provided, all must be provided.' + "If IDs_curr, curr_lab or frame_i are provided, all must be provided." ) - + self.originalLabs = np.copy(segm_data) self.originalLabsIDs = [None] * len(self.total_frames) - + if IDs_curr is not None: self.originalLabsIDs[frame_i] = IDs_curr - + if allData_li is not None: for i in range(len(allData_li)): - if i == frame_i and IDs_curr is not None: # already set + if i == frame_i and IDs_curr is not None: # already set continue lab = None try: - lab = allData_li[i]['labels'] + lab = allData_li[i]["labels"] except: pass if lab is not None: self.originalLabs[i] = lab.copy() - + for i in range(len(segm_data)): IDs = None if IDs_curr is not None and i == frame_i: IDs = set(IDs_curr) elif allData_li is not None: try: - IDs = set(allData_li[i]['IDs']) + IDs = set(allData_li[i]["IDs"]) except KeyError: pass if IDs is None: - IDs = {obj.label for obj in skimage.measure.regionprops(self.originalLabs[i])} + IDs = { + obj.label + for obj in skimage.measure.regionprops(self.originalLabs[i]) + } self.originalLabsIDs[i] = IDs - - def get(self,frame_i:int,try_create_new_whitelists:bool=False): + + def get(self, frame_i: int, try_create_new_whitelists: bool = False): """Gets the whitelist for a given frame index. If the whitelist is not initialized, and try_create_new_whitelists is True, it will create a new whitelist empty for that frame. @@ -515,7 +547,7 @@ def get(self,frame_i:int,try_create_new_whitelists:bool=False): """ try: - old_whitelistIDs =self.whitelistIDs[frame_i] + old_whitelistIDs = self.whitelistIDs[frame_i] except Exception as e: if not try_create_new_whitelists: raise e @@ -525,21 +557,22 @@ def get(self,frame_i:int,try_create_new_whitelists:bool=False): old_whitelistIDs = set() else: raise e - + if old_whitelistIDs is None: old_whitelistIDs = set() else: old_whitelistIDs = set(old_whitelistIDs) - + return old_whitelistIDs - def initNewFrames(self, - frame_i: int, - force: bool = False, - ): + def initNewFrames( + self, + frame_i: int, + force: bool = False, + ): """Initialize the whitelists for all new frame. All frames up to and including frame_i will be initialized. - Unless forced, it will only initialize the whitelist if the frame is not + Unless forced, it will only initialize the whitelist if the frame is not already initialized, (tracked with self.initialized_i). Parameters @@ -547,7 +580,7 @@ def initNewFrames(self, frame_i : int The frame index for where the initialization should be done. force : bool, optional - If the frame_i (only this frame_i in that case) + If the frame_i (only this frame_i in that case) should be reinit, by default False Returns @@ -555,10 +588,10 @@ def initNewFrames(self, bool True if a new frame was initialized, False if not. """ - - missing_frames = set(range(frame_i+1)) - self.initialized_i + + missing_frames = set(range(frame_i + 1)) - self.initialized_i update_frames = [] - + if self._debug: printl(missing_frames, self.initialized_i, frame_i) @@ -575,10 +608,10 @@ def initNewFrames(self, if i == 0: prev_wl = set() else: - prev_wl = self.whitelistIDs[i-1] + prev_wl = self.whitelistIDs[i - 1] if prev_wl is None: prev_wl = set() - + if not self.checkOriginalLabels(i): available_IDs = set() else: @@ -591,46 +624,50 @@ def initNewFrames(self, self.whitelistIDs[i] = new_wl else: self.whitelistIDs[i] = set() - + self.initialized_i.add(i) update_frames.append((i, None, None, True)) if self._debug: - printl('Whitelist IDs new frame (without adding new IDs):', self.whitelistIDs[frame_i]) + printl( + "Whitelist IDs new frame (without adding new IDs):", + self.whitelistIDs[frame_i], + ) return new_frame, update_frames - - def propagateIDs(self, - frame_i: int, - allData_li: list, - new_whitelist: Set[int] | List[int] = None, - IDs_to_add: Set[int] = None, - IDs_to_remove: Set[int] = None, - try_create_new_whitelists: bool = False, - curr_frame_only: bool = False, - force_not_dynamic_update: bool = False, - only_future_frames: bool = True, - allow_only_current_IDs: bool = True, - IDs_curr: Set[int] | List[int] = None, - index_lab_combo: Tuple[int, np.ndarray] = None, - curr_rp: list = None, - curr_lab: np.ndarray = None, - update_frames: list = None, - ): + + def propagateIDs( + self, + frame_i: int, + allData_li: list, + new_whitelist: Set[int] | List[int] = None, + IDs_to_add: Set[int] = None, + IDs_to_remove: Set[int] = None, + try_create_new_whitelists: bool = False, + curr_frame_only: bool = False, + force_not_dynamic_update: bool = False, + only_future_frames: bool = True, + allow_only_current_IDs: bool = True, + IDs_curr: Set[int] | List[int] = None, + index_lab_combo: Tuple[int, np.ndarray] = None, + curr_rp: list = None, + curr_lab: np.ndarray = None, + update_frames: list = None, + ): """ Propagates whitelist IDs across frames in the dataset. (Doesn't update labs) Should also be called when viewing a new frame! This function updates whitelist. If curr_frame_only is True, it only updates the - whitelist of the current frame. If the frame changes, this function should be called + whitelist of the current frame. If the frame changes, this function should be called again to update the whitelist for the new frame (without this argument). It should also handle cases were this is not done, but this is less safe. Then, all the additions and removals are propagated to the other frames. - If force_not_dynamic_update is True, the function will propagate the entire whitelist to + If force_not_dynamic_update is True, the function will propagate the entire whitelist to frames, and not only the IDs which were added or removed. Hierarchy of arguments for current_IDs: 1. IDs_curr (if provided) - (2. index_lab_combo (if provided) (is also passed to not current frame only + (2. index_lab_combo (if provided) (is also passed to not current frame only propagation if that propagation is necessary, and used when the frame_i matches)) 3. curr_rp (if provided) 4. curr_lab (if provided) @@ -641,43 +678,43 @@ def propagateIDs(self, frame_i : int The frame index for the propagation. allData_li : list - See rest of ACDC. posData.allData_li. + See rest of ACDC. posData.allData_li. Used to get the IDs for the current frame. Especially for when propagating after curr_frame_only was changed. Strictly speaking could be substituted with the correct index_lab_combo if necessary in the future. new_whitelist : Set[int] | List[int], optional - A new set of whitelist IDs to replace the current whitelist. Cannot be + A new set of whitelist IDs to replace the current whitelist. Cannot be used together with `IDs_to_add` or `IDs_to_remove`, by default None. IDs_to_add : Set[int], optional A set of IDs to add to the current whitelist, by default None. IDs_to_remove : Set[int], optional A set of IDs to remove from the current whitelist, by default None. try_create_new_whitelists : bool, optional - If True, creates new whitelist entries for frames that do not already + If True, creates new whitelist entries for frames that do not already have them. Should only be necessary when its initialized, by default False. curr_frame_only : bool, optional - If True, only updates the whitelist for the current frame. + If True, only updates the whitelist for the current frame. (See description of function), by default False. force_not_dynamic_update : bool, optional - If True, disables dynamic updates to the whitelist. + If True, disables dynamic updates to the whitelist. (See description of function), by default False. only_future_frames : bool, optional If True, propagates changes only to future frames, by default True. allow_only_current_IDs : bool, optional - If True, only allows IDs that are present in the current frame + If True, only allows IDs that are present in the current frame to be added to the whitelist, by default True. IDs_curr : Set[int] | List[int], optional - A set of IDs for the current frame, if None, + A set of IDs for the current frame, if None, will be calculated from other stuff (see description), by default None. index_lab_combo : Tuple[int, np.ndarray], optional - Combination of frame_i and current frame, + Combination of frame_i and current frame, Used to get IDs_curr (see description), when the frame_i matches - (is also passed to not current frame only - propagation if that propagation is necessary, + (is also passed to not current frame only + propagation if that propagation is necessary, and used when the frame_i matches), by default None. curr_rp : list, optional - Region properties for the current frame. For IDs_curr. (see description), + Region properties for the current frame. For IDs_curr. (see description), by default None. curr_lab : np.ndarray, optional Labels for the current frame for IDs_curr. (see description), @@ -709,7 +746,7 @@ def propagateIDs(self, This would also propagate the changes to all other frames. """ - #doesn't update the frame displayed, only wl + # doesn't update the frame displayed, only wl # if allData_li is not None and per_frame_IDs is not None: # raise ValueError('Cannot provide both allData_li and per_frame_IDs') @@ -717,50 +754,50 @@ def propagateIDs(self, # raise ValueError('Either allData_li or per_frame_IDs or labs must be provided') # elif not allData_li and not per_frame_IDs: # per_frame_IDs = [set() for _ in labs] - + if not update_frames: update_frames = [] if self._debug: - printl('Propagating IDs...') - myutils.print_call_stack() + printl("Propagating IDs...") + utils.print_call_stack() printl(new_whitelist, IDs_to_add, IDs_to_remove) # if labs is None and not allData_li and not IDs_curr: # raise ValueError('Either labs or allData_li or IDs_curr/must be provided') # elif labs is not None and allData_li: - # raise ValueError('Cannot provide both labs and allData_li') - # elif + # raise ValueError('Cannot provide both labs and allData_li') + # elif if IDs_curr: if self._debug: - printl('Using IDs_curr') + printl("Using IDs_curr") try: IDs_curr = IDs_curr.copy() except AttributeError: pass IDs_curr = set(IDs_curr) - + elif index_lab_combo and index_lab_combo[0] == frame_i: lab = index_lab_combo[1] if self._debug: - printl('Using index_lab_combo') + printl("Using index_lab_combo") IDs_curr = {obj.label for obj in skimage.measure.regionprops(lab)} elif curr_rp is not None: IDs_curr = {obj.label for obj in curr_rp} if self._debug: - printl('Using rp') + printl("Using rp") elif curr_lab is not None: lab = curr_lab if self._debug: - printl('Using curr_lab') + printl("Using curr_lab") IDs_curr = {obj.label for obj in skimage.measure.regionprops(lab)} else: - IDs_curr = allData_li[frame_i]['IDs'] + IDs_curr = allData_li[frame_i]["IDs"] if self._debug: - printl('Using allData_li') - + printl("Using allData_li") + IDs_curr = set(IDs_curr) - + # else: # lab = labs[frame_i] # if self._debug: @@ -779,30 +816,39 @@ def propagateIDs(self, self.whitelistOriginalIDs = self.whitelistIDs[frame_i].copy() elif self.whitelistOriginalFrame_i != frame_i: if self._debug: - printl('Frame changed, whitelist was not propagated, propagating...') - new_update_frames = self.propagateIDs(self.whitelistOriginalFrame_i, - allData_li, - index_lab_combo=index_lab_combo, - update_frames=update_frames) + printl( + "Frame changed, whitelist was not propagated, propagating..." + ) + new_update_frames = self.propagateIDs( + self.whitelistOriginalFrame_i, + allData_li, + index_lab_combo=index_lab_combo, + update_frames=update_frames, + ) update_frames.extend(new_update_frames) else: if self.whitelistOriginalFrame_i is not None: if self.whitelistOriginalFrame_i != frame_i: if self._debug: - printl('Frame changed, whitelist was not propagated, propagating...') - new_update_frames = self.propagateIDs(self.whitelistOriginalFrame_i, - allData_li, - index_lab_combo=index_lab_combo, - update_frames=update_frames - ) + printl( + "Frame changed, whitelist was not propagated, propagating..." + ) + new_update_frames = self.propagateIDs( + self.whitelistOriginalFrame_i, + allData_li, + index_lab_combo=index_lab_combo, + update_frames=update_frames, + ) update_frames.extend(new_update_frames) else: propagate_after_curr_frame_only_flag = True self.whitelistOriginalFrame_i = None - + # see what the situation is with adding/removing IDs if new_whitelist and (IDs_to_add is not None or IDs_to_remove is not None): - raise ValueError('Cannot provide both new_whitelist and IDs_to_add or IDs_to_remove') + raise ValueError( + "Cannot provide both new_whitelist and IDs_to_add or IDs_to_remove" + ) # figure out what old wl supposed to be... if force_not_dynamic_update: @@ -810,14 +856,14 @@ def propagateIDs(self, elif propagate_after_curr_frame_only_flag: old_whitelist = self.whitelistOriginalIDs else: - old_whitelist = self.get(frame_i,try_create_new_whitelists) + old_whitelist = self.get(frame_i, try_create_new_whitelists) # construct new_whitelist if new_whitelist is not None: new_whitelist = set(new_whitelist) - else: # updated later if IDs_to_add or IDs_to_remove are provided - new_whitelist = self.get(frame_i,try_create_new_whitelists) - + else: # updated later if IDs_to_add or IDs_to_remove are provided + new_whitelist = self.get(frame_i, try_create_new_whitelists) + if IDs_to_add is not None or IDs_to_remove is not None: if IDs_to_add is None: IDs_to_add = set() @@ -848,7 +894,7 @@ def propagateIDs(self, if self._debug: printl(IDs_to_add, IDs_to_remove) - + prop_to_frame_i = last_frame_i if curr_frame_only: @@ -872,7 +918,7 @@ def propagateIDs(self, if frame_i == i: IDs_curr_loc = IDs_curr else: - IDs_curr_loc = set(allData_li[i]['IDs']) + IDs_curr_loc = set(allData_li[i]["IDs"]) new_whitelist = self.get(i, try_create_new_whitelists).copy() old_whitelist = new_whitelist.copy() @@ -880,1128 +926,20 @@ def propagateIDs(self, removed_IDs = [] if IDs_to_add: # intersection with... all possible IDs ...plus all old_whitelistIDs - new_whitelist = IDs_to_add.intersection(IDs_curr_loc.union(IDs_og)) | old_whitelist + new_whitelist = ( + IDs_to_add.intersection(IDs_curr_loc.union(IDs_og)) | old_whitelist + ) # IDs_curr.union(IDs_og) are all possible IDs, IDs_to_add.intersection(IDs_curr.union(IDs_og)) is for finding all possible IDs which want ot be propagated added_IDs = new_whitelist - old_whitelist if IDs_to_remove: new_whitelist = new_whitelist - IDs_to_remove removed_IDs = old_whitelist - new_whitelist - + self.whitelistIDs[i] = new_whitelist if added_IDs or removed_IDs: - update_frames.append((i,added_IDs, removed_IDs, False)) + update_frames.append((i, added_IDs, removed_IDs, False)) if self._debug: printl(self.whitelistIDs[frame_i]) - - return update_frames - -class WhitelistGUIElements: - """A class to manage the whitelist GUI elements. - """ - def whitelistCheckOriginalLabels(self, warning:bool=True, - frame_i:int=None): - """Warns the user that there are no original labels labels are present - for the frame""" - posData = self.data[self.pos_i] - if posData.whitelist is None: - return False - - if frame_i is None: - frame_i = posData.frame_i - - if posData.whitelist.originalLabsIDs is None: - return False - - if (frame_i >= len(posData.whitelist.originalLabsIDs) or - posData.whitelist.originalLabsIDs[frame_i] is None): - txt = """ - No original labels are present for the current frame, - this action cannot be performed.""" - self.logger.warning(txt) - if not warning: - return False - msg = widgets.myMessageBox.warning( - self, 'No original labels', txt, - ) - - return False - else: - return True - - @disableWindow - def whitelistTrackOGagainstPreviousFrame_cb(self, signal_slot=None): - """Tracks the original labels against the previous frame. - This is used as a callback for sigTrackOGagainstPreviousFrame signal - """ - posData = self.data[self.pos_i] - frame_i = posData.frame_i - if not self.whitelistCheckOriginalLabels(): - return - old_cell_IDs = posData.whitelist.originalLabsIDs[frame_i] - prev_cell_IDs = posData.allData_li[frame_i-1]['IDs'] - self.whitelistTrackOGCurr(against_prev=True) - new_cell_IDs = posData.whitelist.originalLabsIDs[frame_i] - - new_IDs = new_cell_IDs - old_cell_IDs - new_IDs = new_IDs & set(prev_cell_IDs) - - self.whitelistUpdateLab( - track_og_curr=False, IDs_to_add=new_IDs, - ) - - def whitelistLoadOGLabs_cb(self): - """Generates a dialog to load the original (not whitelisted) labels - """ - posData = self.data[self.pos_i] - curr_seg_path = posData.segm_npz_path - - segmFilename = os.path.basename(curr_seg_path) - custom_first = f"{segmFilename[:-4]}_not_whitelisted.npz" - images_path = posData.images_path - existingEndnames = [ - files for files in os.listdir(images_path) if files.endswith('.npz') - ] - if custom_first not in existingEndnames: - custom_first = None - - infoText = html_utils.paragraph( - 'Select the segmentation file containing the original labels ' - 'of the objects. Pleae note that the current saved "original" ' - 'labels will be replaced with the new ones, but the filtered ' - 'labels will be kept.' - ) - - win = apps.SelectSegmFileDialog( - existingEndnames, images_path, parent=self, - basename=posData.basename, infoText=infoText, - custom_first=custom_first - ) - win.exec_() - if win.cancel: - self.logger.info('Loading original labels canceled.') - return - selected = win.selectedItemText - self.logger.info(f'Loading original labels from {selected}...') - self.whitelistLoadOGLabs(selected) - - @disableWindow - def whitelistLoadOGLabs(self, selected:str): - """Loads the original labels from the selected files - - Parameters - ---------- - selected : str - Selected file name from the dialog. - """ - posData = self.data[self.pos_i] - images_path = posData.images_path - - selected_path = os.path.join(images_path, selected) - posData.whitelist.loadOGLabs(selected_path) - - self.whitelistIDsToolbar.viewOGToggle.setCheckable(True) - - @exception_handler - @disableWindow - def whitelistViewOGIDs(self, checked:bool): - """Switch between selected and original labels. - Uses self.viewOriginalLabels to see what has to be done. - - Parameters - ---------- - checked : bool - True if the original labels have to be shown, False otherwise. - """ - switch_to_og = checked and not self.viewOriginalLabels - switch_to_seg = not checked and self.viewOriginalLabels - - if not switch_to_og and not switch_to_seg: - return - - posData = self.data[self.pos_i] - if posData.whitelist is None: - return - - if posData.whitelist._debug: - printl('whitelistViewOGIDs', checked) - - frame_i = posData.frame_i - if frame_i > 0: - frames_range = [frame_i-1, frame_i] - else: - frames_range = [frame_i] - - self.store_data(autosave=False) - - if not self.whitelistCheckOriginalLabels(): - return - if switch_to_og: - self.setFrameNavigationDisabled(True, why='Viewing original labels') - self.viewOriginalLabels = True - - for i in frames_range: - posData.frame_i = i - self.get_data() - self.whitelistTrackOGCurr(frame_i=i) - - IDs = posData.IDs - - og_frame = posData.whitelist.originalLabs[i].copy() - IDs_to_uppdate = posData.whitelist.whitelistIDs[i] & posData.whitelist.originalLabsIDs[i] - if IDs_to_uppdate: - mask = np.isin(og_frame, list(IDs_to_uppdate)) - og_frame[mask] = 0 - - mask = np.isin(posData.lab, list(IDs_to_uppdate)) - og_frame[mask] = posData.lab[mask] - - IDs_to_add = posData.whitelist.whitelistIDs[i] - posData.whitelist.originalLabsIDs[i] - if IDs_to_add: - mask = np.isin(posData.lab, list(IDs_to_add)) - og_frame[mask] = posData.lab[mask] - - posData.lab = og_frame - self.update_rp(wl_update=False) - self.store_data(autosave=False) - - if frame_i > 0: - missing_IDs = set(posData.IDs) - set(posData.allData_li[frame_i-1]['IDs']) - self.trackManuallyAddedObject(missing_IDs,isNewID=True, wl_update=False) - - self.setAllTextAnnotations() - self.updateAllImages() - - elif switch_to_seg: - self.viewOriginalLabels = False - self.setFrameNavigationDisabled(False, why='Viewing original labels') - - for i in frames_range: - posData.frame_i = i - self.get_data() - try: - posData.whitelist.originalLabs[i] = posData.lab.copy() - posData.whitelist.originalLabsIDs[i] = set(posData.IDs) - except AttributeError: - lab = posData.segm_data[i].copy() - IDs = [obj.label for obj in skimage.measure.regionprops(lab)] - posData.whitelist.originalLabs[i] = lab - posData.whitelist.originalLabsIDs[i] = set(IDs) - - # self.whitelistTrackCurrOG() - self.update_rp(wl_update=False) - self.store_data(autosave=False) - self.whitelistUpdateLab(frame_i=i) #has update_rp and store data - self.setAllTextAnnotations() - self.updateAllImages() - - def whitelistSetViewOGIDsToggle(self, checked: bool): - """Set the view original labels toggle button to checked or unchecked. - This also updates the self.viewOriginalLabels variable. - !!! Doesn't change the actually displayed labels, use self.whitelistViewOGIDs - to do that.!!! - - Parameters - ---------- - checked : bool - True if the original labels are shown, False otherwise. - """ - self.viewOriginalLabels = checked - self.whitelistIDsToolbar.viewOGToggle.blockSignals(True) - self.whitelistIDsToolbar.viewOGToggle.setChecked(checked) - self.whitelistIDsToolbar.viewOGToggle.blockSignals(False) - - def whitelistAddNewIDsToggled(self, checked: bool): - """Will set self.addNewIDsWhitelistToggle to checked and call - whitelistAddNewIDs if checked is True. - - Parameters - ---------- - checked : bool - True if the add new IDs toggle is checked, False otherwise. - """ - self.addNewIDsWhitelistToggle = checked - if checked: - self.df_settings.at['addNewIDsWhitelistToggle', 'value'] = 'Yes' - else: - self.df_settings.at['addNewIDsWhitelistToggle', 'value'] = 'No' - self.df_settings.to_csv(self.settings_csv_path) - if checked: - self.whitelistAddNewIDs(ignore_not_first_time=True) - self.whitelistPropagateIDs() - self.updateAllImages() - self.whitelistIDsUpdateText() - - def whitelistAddNewIDs(self, ignore_not_first_time:bool=False): - """Function which adds new IDs to the whitelist, based on the original labels. - It will check if the frame is visited the first time, unless - ignore_not_first_time is True. - It does nothing if self.addNewIDsWhitelistToggle is False. - !!!Careful, does not change the lab, just the whitelist!!! - - Parameters - ---------- - ignore_not_first_time : bool, optional - Weather it should be checked if the frame is visited - the first time, by default False - """ - mode = self.modeComboBox.currentText() - if mode != 'Segmentation and Tracking': - return - - if not self.addNewIDsWhitelistToggle: - return - - posData = self.data[self.pos_i] - if posData.whitelist is None: - return - - debug = posData.whitelist._debug - - if debug: - printl('whitelistAddNewIDs') - - posData = self.data[self.pos_i] - frame_i = posData.frame_i - - if self.get_last_tracked_i() > frame_i and not ignore_not_first_time: - return - - if frame_i == 0: - return - - if self.whitelistAddNewIDsFrame is not None and frame_i == self.whitelistAddNewIDsFrame: - return - - self.whitelistAddNewIDsFrame = frame_i - - curr_lab = self.get_curr_lab() - - posData.whitelist.addNewIDs(frame_i=frame_i, - allData_li=posData.allData_li, - IDs_curr=posData.IDs, - curr_lab=curr_lab) - - def whitelistIDsAccepted(self, - whitelistIDs: Set[int] | List[int]): - """Function which is called when the user accepts a whitelist. - Also initializes the whitelist if it is not already initialized. (Aka not loaded) - - Parameters - ---------- - whitelistIDs : set | list - The accepted IDs from the whitelist dialog. - """ - # Store undo state before modifying stuff - self.storeUndoRedoStates(False) - - self.whitelistIDsToolbar.viewOGToggle.setCheckable(True) - self.whitelistSetViewOGIDsToggle(False) - self.setFrameNavigationDisabled(False, why='Viewing original labels') - - self.store_data(autosave=False) - - posData = self.data[self.pos_i] - - if not posData.whitelist: - posData.whitelist = Whitelist( - total_frames=posData.SizeT, - ) - - if posData.whitelist._debug: - printl('whitelistIDsAccepted', whitelistIDs) - - whitelistIDs = set(whitelistIDs) - - IDs_curr = set(posData.IDs) - - posData.whitelist.IDsAccepted( - whitelistIDs, - segm_data=posData.segm_data, - frame_i=posData.frame_i, - allData_li=posData.allData_li, - IDs_curr=IDs_curr, - curr_lab=posData.lab, - - ) - - # self.whitelistPropagateIDs(new_whitelist=whitelistIDs, - # try_create_new_whitelists=True, - # only_future_frames=True, - # force_not_dynamic_update=True, - # update_lab=True - # ) - self.whitelistUpdateLab(track_og_curr=True) - - self.whitelistIDsUpdateText() - self.keepIDsTempLayerLeft.clear() - - def whitelistUpdateLab(self, frame_i: int=None, - track_og_curr=False, new_frame:bool=False, - IDs_to_add:List[int] | Set[int]=None, - IDs_to_remove:List[int]|Set[int]=None, - ): - # this should also work for 3D i think... - """Updates the displayed lab based on the whitelist. - - Parameters - ---------- - frame_i : int, optional - frame which should be updated. If not provided, - uses posData.frame_i, by default None - track_og_curr : bool, optional - if True, will track the original current IDs, by default False - new_frame : bool, optional - if True, will set the frame to the new frame, by default False - IDs_to_add : list, optional - IDs to add to the whitelist, by default None - IDs_to_remove : list, optional - IDs to remove from the whitelist, by default None - """ - got_data = False - benchmark = False - if benchmark: - ts = [time.perf_counter()] - titles = [ - '', - 'store_data', - 'whitelistSetViewOGIDsToggle', - 'get_data', - 'get what to add/remove', - 'track_og_curr', - 'get current lab', - 'add/remove IDs', - 'store data', - 'update images', - ] - - mode = self.modeComboBox.currentText() - if mode != 'Segmentation and Tracking': - return - - posData = self.data[self.pos_i] - if posData.whitelist is None: - return - - if frame_i is None: - frame_i = posData.frame_i - og_frame_i = frame_i - else: - og_frame_i = posData.frame_i - posData.frame_i = frame_i - # getting data is handles later in the code - - debug = posData.whitelist._debug - if debug: - printl('whitelistUpdateLab', frame_i, og_frame_i) - from . import debugutils - debugutils.print_call_stack() - - if benchmark: - ts.append(time.perf_counter()) - - self.whitelistSetViewOGIDsToggle(False) ### - - if benchmark: - ts.append(time.perf_counter()) - - if self.whitelistCheckOriginalLabels(warning=False, frame_i=frame_i): - og_lab = posData.whitelist.originalLabs[frame_i] ### - else: - og_lab = None - if benchmark: - ts.append(time.perf_counter()) - - #### - whitelist = posData.whitelist.get(frame_i=frame_i) - IDs_to_add_remove_provided = IDs_to_add is not None or IDs_to_remove is not None - if not IDs_to_add_remove_provided: - self.get_data() - got_data = True - current_IDs = set(posData.IDs) - missing_IDs = list(whitelist - current_IDs) - to_be_removed_IDs = list(current_IDs - whitelist) - else: - missing_IDs = list(IDs_to_add) if IDs_to_add is not None else [] - to_be_removed_IDs = list(IDs_to_remove) if IDs_to_remove is not None else [] - - ### - - if benchmark: - ts.append(time.perf_counter()) - - ### - if not missing_IDs and not to_be_removed_IDs: # nothing to do - if og_frame_i != frame_i: - posData.frame_i = og_frame_i - if got_data and og_frame_i != frame_i: - self.get_data() - if benchmark: - print('No IDs to add/remove') - ts.append(time.perf_counter()) - indx = titles.index('track_og_curr') - titles[indx + 1] = 'store_data' - time_taken = time.perf_counter() - ts[0] - print(f'\nTotal time for whitelistUpdateLab: {time_taken:.2f}s') - for i in range(1, len(ts)): - time_taken = ts[i] - ts[i-1] - print(f'Time taken for {titles[i]}: {time_taken:.2f}s') - print('') - return - - if not got_data and og_frame_i != frame_i: - self.get_data() - got_data = True - - if benchmark: - ts.append(time.perf_counter()) - - ### - if missing_IDs and track_og_curr and not new_frame: - self.whitelistTrackOGCurr(frame_i=frame_i, - lab = posData.lab, - rp = posData.rp) - - missing_IDs = np.array(missing_IDs, dtype=np.int32) - to_be_removed_IDs = np.array(to_be_removed_IDs, dtype=np.int32) - - if debug: - printl(missing_IDs, to_be_removed_IDs) - - curr_lab = posData.lab # or curr_lab = posData.lab??? - # convert values to int if they are not already - if curr_lab is None: - try: - curr_lab = posData.allData_li[frame_i]['labels'].copy() - except: - pass - if curr_lab is None: - try: - curr_lab = posData.segm_data[frame_i].copy() - except: - pass - if curr_lab is None: - printl('No current lab?') - curr_lab = np.zeros_like(posData.segm_data[0]) - curr_lab = curr_lab.astype(np.int32) - if benchmark: - ts.append(time.perf_counter()) - - if missing_IDs.size > 0 and og_lab is not None: - mask = np.isin(og_lab, missing_IDs) # add missing_IDs - curr_lab[mask] = og_lab[mask] - - if to_be_removed_IDs.size > 0: - curr_lab[np.isin(curr_lab, to_be_removed_IDs)] = 0 # remove to_be_removed_IDs - - if benchmark: - ts.append(time.perf_counter()) - - posData.lab = curr_lab - - self.update_rp(wl_update=False) - self.store_data() - - if benchmark: - ts.append(time.perf_counter()) - if og_frame_i != frame_i: - posData.frame_i = og_frame_i - self.get_data() - - self.updateAllImages() - self.setAllTextAnnotations() - - if benchmark: - ts.append(time.perf_counter()) - time_taken = time.perf_counter() - ts[0] - print(f'\nTotal time for whitelistUpdateLab: {time_taken:.2f}s') - for i in range(1, len(ts)): - time_taken = ts[i] - ts[i-1] - print(f'Time taken for {titles[i]}: {time_taken:.2f}s') - print('') - - def whitelistIDsUpdateText(self): - """Updates the text. Carefull, triggers whitelistLineEdit.textChanged! - """ - mode = self.modeComboBox.currentText() - if mode != 'Segmentation and Tracking': - return - - posData = self.data[self.pos_i] - if posData.whitelist is None: - return - - if posData.whitelist._debug: - printl('whitelistIDsUpdateText') - - frame_i = posData.frame_i - whitelist = posData.whitelist.get(frame_i=frame_i) - - self.whitelistIDsToolbar.whitelistLineEdit.setText(whitelist) - - def whitelistTrackOGCurr(self, frame_i:int=None, - against_prev:bool=False, - lab:np.ndarray=None, - rp:list=None, - IDs: Set[int] | List[int] =None): - """Track the original labels in relation to the current (whitelisted) - labels. - Parameters - - Parameters - ---------- - frame_i : int, optional - frame_i to be tracked, posData.frame_i if not provided, - by default None - against_prev : bool, optional - if the original frame should be tracked against frame_i-1. - Cannot be used with rp or lab, by default False - lab : np.ndarray, optional - lab to be tracked against, by default None - rp : list, optional - regionprops for this lab, by default None - IDs : Set[int] | List[int], optional - IDs that should be tracked based on og - - Raises - ------ - ValueError - Cannot provide both rp and lab when tracking against previous frame. - Instead only provide rp and lab, and dont set against_prev. - """ - posData = self.data[self.pos_i] - if posData.whitelist is None: - return - - debug = posData.whitelist._debug - - if debug: - from . import debugutils - debugutils.print_call_stack(depth=2) - printl('whitelistTrackOGCurr', against_prev) - - if against_prev and (rp is not None or lab is not None): - raise ValueError('Cannot provide both rp and lab when tracking' - ' against previous frame.' - 'Instead only provide rp and lab, and dont set against_prev.') - - if frame_i is None: - frame_i = posData.frame_i - - if against_prev and frame_i == 0: - return - - if not self.whitelistCheckOriginalLabels(warning=False, - frame_i=frame_i): - if debug: - printl('No original labels, cannot track.') - return - - og_frame_i = posData.frame_i - ### against what should I track? - - if lab is not None and not rp: - rp = skimage.measure.regionprops(lab) - - changed_frame = False - if lab is None: - if debug: - printl('No lab and no rp provided.') - if against_prev: - rp = posData.allData_li[frame_i-1]['regionprops'] - lab = posData.allData_li[frame_i-1]['labels'] - else: - if frame_i != og_frame_i: - self.store_data(autosave=False) - posData.frame_i = frame_i - self.get_data() - changed_frame = True - rp = posData.rp - lab = posData.lab - og_lab = posData.whitelist.originalLabs[frame_i] - og_rp = skimage.measure.regionprops(og_lab) - # lab = lab.copy() - - denom_overlap_matrix = 'union' if not against_prev else 'area_prev' - - og_lab = CellACDC_tracker.track_frame( - lab, rp, og_lab, og_rp, - denom_overlap_matrix=denom_overlap_matrix, - posData = posData, - setBrushID_func=self.setBrushID, - IDs=IDs, - # assign_unique_new_IDs=False, - ) - - posData.whitelist.originalLabs[frame_i] = og_lab - posData.whitelist.originalLabsIDs[frame_i] = {obj.label for obj in skimage.measure.regionprops(og_lab)} - - if changed_frame: - posData.frame_i = og_frame_i - self.get_data() - - def whitelistTrackCurrOG(self, frame_i:int=None, against_prev:bool=False): - """Track the current (whitelisted) labels in relation to the original labels. - Parameters - ---------- - frame_i : int, optional - frame_i to be tracked, posData.frame_i if not provided, by default None - against_prev : bool, optional - if the original frame should be tracked against frame_i-1. - """ - posData = self.data[self.pos_i] - if posData.whitelist is None: - return - - if posData.whitelist._debug: - printl('whitelistTrackCurrOG', frame_i, against_prev) - - if frame_i is None: - frame_i = posData.frame_i - - if against_prev and frame_i == 0: - return - - og_frame = posData.frame_i - if frame_i != og_frame: - self.store_data(autosave=False) - posData.frame_i = frame_i - self.get_data() - - lab = posData.lab - rp = posData.rp - - if not self.whitelistCheckOriginalLabels(warning=False, - frame_i=frame_i if not against_prev else frame_i-1): - if posData.whitelist._debug: - printl('No original labels, cannot track.') - return - - if against_prev: - og_lab = posData.whitelist.originalLabs[frame_i-1] - else: - og_lab = posData.whitelist.originalLabs[frame_i] - - og_rp = skimage.measure.regionprops(og_lab) - - denom_overlap_matrix = 'union' if not against_prev else 'area_prev' - - lab = CellACDC_tracker.track_frame( - og_lab, og_rp, lab, rp, - denom_overlap_matrix=denom_overlap_matrix, - posData = posData, - setBrushID_func=self.setBrushID - ) - - posData.lab = lab - - self.update_rp(wl_update=False) - self.store_data(autosave=False) - - if frame_i != og_frame: - posData.frame_i = og_frame - self.get_data() - - def whitelistSyncIDsOG(self, - frame_is: List[int]=None, - against_prev: bool=False,): - """Interates over the frames and calls whitelistTrackOGCurr for each frame. - - Parameters - ---------- - frame_is : List[int], optional - list of frame_i, if None goes through all, by default None - against_prev : bool, optional - if the original frame should be tracked against frame_i-1. - """ - posData = self.data[self.pos_i] - if frame_is is None: - frame_is = range(posData.SizeT) - - for frame_i in frame_is: - self.whitelistTrackOGCurr(frame_i=frame_i, against_prev=against_prev) - - def whitelistInitNewFrames(self, frame_i:int=None, force:bool=False): - """Initialize the whitelist for a new frame. The class whitelist keeps track - of the init frames and doesnt try to init them again, unless forced. - Does not init the class! - - Parameters - ---------- - frame_i : int, optional - frame_i to be init, posData.frame_i if not provided, by default None - force : bool, optional - if the init should be forced, by default False - - Returns - ------- - bool - if the frame was new or not - list - list of frames that were updated, and info about added/removed IDs - """ - - posData = self.data[self.pos_i] - if posData.whitelist is None: - return False, [] - - if frame_i is None: - frame_i = posData.frame_i - - if posData.whitelist._debug: - printl('whitelistInitNewFrames', frame_i, force) - - if frame_i not in posData.whitelist.initialized_i: - self.whitelistTrackOGCurr(frame_i=frame_i, against_prev=True) - - new_frame, update_frames = posData.whitelist.initNewFrames( - frame_i=frame_i, force=force) - - self.whitelistAddNewIDs() - return new_frame, update_frames - - # @exec_time - def whitelistPropagateIDs(self, - new_whitelist: Set[int] | List[int] = None, - IDs_to_add: Set[int] = None, - IDs_to_remove: Set[int] = None, - frame_i: int = None, - try_create_new_whitelists: bool = False, - curr_frame_only: bool = False, - force_not_dynamic_update: bool = False, - only_future_frames: bool = True, - allow_only_current_IDs: bool = False, - track_og_curr: bool = True, - IDs_curr: Set[int] | List[int] = None, - index_lab_combo: Tuple[int, np.ndarray] = None, - curr_rp: list = None, - curr_lab: np.ndarray = None, - store_data: bool = True, - update_lab: bool = False, - ): - """ - Propagates whitelist IDs across frames in the dataset. (Doesnt update labs) - Should also be called when viewing a new frame! - - This function updates whitelist. If curr_frame_only is True, it only updates the - whitelist of the current frame. If the frame changes, this function should be called - again to update the whitelist for the new frame (without this argument). - It should also handle cases were this is not done, but this is less safe. - Then, all the additions and removals are propagated to the other frames. - If force_not_dynamic_update is True, the function will propagate the entire whitelist to - frames, and not only the IDs which were added or removed. - - Hierarchy of arguments for current_IDs: - 1. IDs_curr (if provided) - (2. index_lab_combo (if provided) (is also passed to not current frame only - propagation if that propagation is necessary, and used when the frame_i matches)) - 3. curr_rp (if provided) - 4. curr_lab (if provided) - 5. allData_li - - Parameters - ---------- - new_whitelist : Set[int] | List[int], optional - A new set of whitelist IDs to replace the current whitelist. Cannot be - used together with `IDs_to_add` or `IDs_to_remove`, by default None. - IDs_to_add : Set[int], optional - A set of IDs to add to the current whitelist, by default None. - IDs_to_remove : Set[int], optional - A set of IDs to remove from the current whitelist, by default None. - frame_i : int, optional - The frame index for the propagation. - If None, uses posData.frame_i, by default None. - try_create_new_whitelists : bool, optional - If True, creates new whitelist entries for frames that do not already - have them. Should only be necessary when its initialized, by default False. - curr_frame_only : bool, optional - If True, only updates the whitelist for the current frame. - (See description of function), by default False. - force_not_dynamic_update : bool, optional - If True, disables dynamic updates to the whitelist. - (See description of function), by default False. - only_future_frames : bool, optional - If True, propagates changes only to future frames, by default True. - allow_only_current_IDs : bool, optional - If True, only allows IDs that are present in the current frame - to be added to the whitelist, by default True. - track_og_curr : bool, optional - If True, tracks the original labels in relation to the current - (whitelisted) labels. This is done by calling whitelistTrackOGCurr. - If its a new frame, this is done in whitelistInitNewFrames against the - previous frame, - by default True. - IDs_curr : Set[int] | List[int], optional - A set of IDs for the current frame, if None, - will be calculated from other stuff (see description), by default None. - index_lab_combo : Tuple[int, np.ndarray], optional - Combination of frame_i and current frame, - Used to get IDs_curr (see description), when the frame_i matches - (is also passed to not current frame only - propagation if that propagation is necessary, - and used when the frame_i matches), by default None. - curr_rp : list, optional - Region properties for the current frame. For IDs_curr. (see description), - by default None. - curr_lab : np.ndarray, optional - Labels for the current frame for IDs_curr. (see description), - by default None. - store_data : bool, optional - If True, stores the data before propagating the IDs. - update_lab : bool, optional - If True, updates the labels after propagating the IDs. - Will always update labels for newly init frames, by default False. - - Raises - ------ - ValueError - If both `new_whitelistIDs` and `IDs_to_add`/`IDs_to_remove` are provided. - - Example - ------- - To add IDs 5 and 6 to the whitelist for the current frame: - ```python - self.whitelistPropagateIDs(IDs_to_add={5, 6}, curr_frame_only=True) - ``` - Then when the frame changes: - ```python - self.whitelistPropagateIDs() - ``` - - To replace the whitelist for frame 10 with a new set of IDs: - ```python - self.whitelistPropagateIDs(new_whitelistIDs={1, 2, 3}, frame_i=10) - ``` - This would also propagate the changes to all other frames. - - """ - #doesnt update the frame displayed, only wl - try: # safety XD - IDs_curr = IDs_curr.copy() - except AttributeError: - pass - - IDs_curr = set(IDs_curr) if IDs_curr is not None else None - - posData = self.data[self.pos_i] - - debug = posData.whitelist._debug if posData.whitelist is not None else False - - if debug: - printl('Propagating IDs...') - from . import debugutils - debugutils.print_call_stack() - printl(new_whitelist, IDs_to_add, IDs_to_remove) - - if posData.whitelist is None: - return - - # og_frame_i = posData.frame_i - if frame_i is None: - frame_i = posData.frame_i - - new_frame, update_frames_init = self.whitelistInitNewFrames(frame_i=frame_i) - - if new_frame: - self.update_rp(wl_update=False) - # if track_og_curr and not new_frame: - # self.whitelistTrackOGCurr(frame_i=frame_i, rp=curr_rp, lab=curr_lab) - - update_frames = posData.whitelist.propagateIDs( - frame_i, - posData.allData_li, - new_whitelist=new_whitelist, - IDs_to_add=IDs_to_add, - IDs_to_remove=IDs_to_remove, - try_create_new_whitelists=try_create_new_whitelists, - curr_frame_only=curr_frame_only, - force_not_dynamic_update=force_not_dynamic_update, - only_future_frames=only_future_frames, - allow_only_current_IDs=allow_only_current_IDs, - IDs_curr=IDs_curr, - index_lab_combo=index_lab_combo, - curr_rp=curr_rp, - curr_lab=curr_lab, - ) - if update_lab: - update_frames = update_frames_init + update_frames - else: - update_frames = update_frames_init - # printl(posData.whitelistIDs[frame_i]) - # posData.frame_i = og_frame_i - self.whitelistIDsUpdateText() - if store_data: - self.store_data(autosave=False) - - for frame_i, IDs_to_add, IDs_to_remove, new_frame in update_frames: - self.whitelistUpdateLab(frame_i=frame_i, track_og_curr=track_og_curr, - new_frame=new_frame, IDs_to_add=IDs_to_add, - IDs_to_remove=IDs_to_remove, ) - - def whitelistIDs_cb(self, checked:bool): - """Callback for when the whitelist IDs button is checked or unchecked. - Initialises the pointlayer and the whitelist IDs toolbar if checked. - - Parameters - ---------- - checked : bool - True if the whitelist IDs button is checked, False otherwise. - """ - if checked: - self.initKeepObjLabelsLayers() - self.disconnectLeftClickButtons() - self.uncheckLeftClickButtons(self.whitelistIDsButton) - self.connectLeftClickButtons() - - self.whitelistIDsToolbar.setVisible(checked) - self.whitelistHighlightIDs(checked) - self.whitelistIDsUpdateText() - self.whitelistUpdateTempLayer() - - if not checked: - self.setLostNewOldPrevIDs() - self.updateAllImages() - - def whitelistHighlightIDs(self, checked:bool=True): - """Highlights the IDs in the current frame based on the whitelist. - - Parameters - ---------- - checked : bool, optional - If False, will delete all highlights, by default True - """ - if not checked: - self.removeHighlightLabelID() - return - - posData = self.data[self.pos_i] - - if posData.whitelist is None: - if not hasattr(self, 'tempWhitelistIDs'): - self.tempWhitelistIDs = set() # not updated, only use in this context - current_whitelist = self.tempWhitelistIDs - else: - current_whitelist = self.tempWhitelistIDs - else: - current_whitelist = posData.whitelist.get( - frame_i=posData.frame_i) - - for ID in current_whitelist: - self.highlightLabelID(ID) - - def whitelistIDsChanged(self, - whitelistIDs: Set[int] | List[int], - debug: bool=False): - """Callback for when the whitelist IDs are changed. - This is called when the user changed the IDs in the whitelist IDs toolbar - (or when its programmatically changed, but if its not - visible it should return instantly) - Will update the temp layer and also complain when IDs - are not valid/present in the current lab - - Parameters - ---------- - whitelistIDs : set | list - The IDs that are currently in the whitelist. - debug : bool, optional - debug, by default False - """ - if not self.whitelistIDsButton.isChecked(): - return - - posData = self.data[self.pos_i] - - if posData.whitelist: - debug = posData.whitelist._debug - if debug: - printl('whitelistIDsChanged', whitelistIDs) - - if posData.whitelist is None: - wl_init = False - if not hasattr(self, 'tempWhitelistIDs'): - self.tempWhitelistIDs = set() # not updated, only use in this context - current_whitelist = self.tempWhitelistIDs - else: - current_whitelist = self.tempWhitelistIDs - else: - wl_init = True - current_whitelist = posData.whitelist.get( - frame_i=posData.frame_i) - - current_whitelist_copy = current_whitelist.copy() - if not hasattr(posData, 'originalLabsIDs') or posData.whitelist.originalLabsIDs is None: - possible_IDs = posData.IDs.copy() - else: - if not self.whitelistCheckOriginalLabels(warning=False): - possible_IDs = set(posData.IDs) - else: - possible_IDs = posData.whitelist.originalLabsIDs[posData.frame_i] - possible_IDs.update(posData.IDs) - - isAnyIDnotExisting = False - for ID in whitelistIDs: - if ID not in possible_IDs: - isAnyIDnotExisting = True - continue - if ID not in current_whitelist_copy: - current_whitelist.add(ID) - self.highlightLabelID(ID) - - for ID in current_whitelist_copy: - if ID not in possible_IDs: - isAnyIDnotExisting = True - continue - if ID not in whitelistIDs: - current_whitelist.remove(ID) - self.removeHighlightLabelID(IDs=[ID]) - - if wl_init: - posData.whitelist.whitelistIDs[posData.frame_i] = current_whitelist - else: - self.tempWhitelistIDs = current_whitelist - - self.whitelistUpdateTempLayer() - if isAnyIDnotExisting: - self.whitelistIDsToolbar.whitelistLineEdit.warnNotExistingID() - else: - self.whitelistIDsToolbar.whitelistLineEdit.setInstructionsText() - - # @exec_time - def whitelistUpdateTempLayer(self): - """Updates the temp layer with the current whitelist IDs. - """ - if not self.whitelistIDsButton.isChecked(): - self.keepIDsTempLayerLeft.clear() - return - - if not hasattr(self, 'keptLab'): - self.keptLab = np.zeros_like(self.currentLab2D) - keptLab = self.keptLab - else: - keptLab = self.keptLab - keptLab[:] = 0 - - posData = self.data[self.pos_i] - if posData.whitelist is None: - if not hasattr(self, 'tempWhitelistIDs'): - self.tempWhitelistIDs = set() # not updated, only use in this context - current_whitelist = self.tempWhitelistIDs - else: - current_whitelist = self.tempWhitelistIDs - else: - current_whitelist = posData.whitelist.get(posData.frame_i) - - for obj in posData.rp: - if obj.label not in current_whitelist: - continue - - if not self.isObjVisible(obj.bbox): - continue - - _slice = self.getObjSlice(obj.slice) - _objMask = self.getObjImage(obj.image, obj.bbox) - - keptLab[_slice][_objMask] = obj.label - - self.keepIDsTempLayerLeft.setImage(keptLab, autoLevels=False) \ No newline at end of file + return update_frames diff --git a/cellacdc/widgets.py b/cellacdc/widgets.py deleted file mode 100755 index b22e1cec4..000000000 --- a/cellacdc/widgets.py +++ /dev/null @@ -1,12086 +0,0 @@ -from collections import defaultdict, deque -from typing import Dict, List, Union, Iterable, Sequence -import os -import sys -import operator -import time -import re -import datetime -import numpy as np -import pandas as pd -import math -import traceback -import logging -import textwrap -import random - -from functools import partial -from math import ceil - -import skimage.draw -import skimage.morphology - -from matplotlib.colors import ListedColormap, LinearSegmentedColormap -import matplotlib.pyplot as plt -import matplotlib -from matplotlib.backends.backend_agg import FigureCanvasAgg - -from qtpy.QtCore import ( - Signal, QTimer, Qt, QPoint, QUrl, Property, - QPropertyAnimation, QEasingCurve, QLocale, - QSize, QRect, QPointF, QRect, QPoint, QEasingCurve, QRegularExpression, - QEvent, QEventLoop, QPropertyAnimation, QObject, - QItemSelectionModel, QAbstractListModel, QModelIndex, - QByteArray, QDataStream, QMimeData, QAbstractItemModel, - QIODevice, QItemSelection, PYQT6, QRectF -) -from qtpy.QtGui import ( - QFont, QPalette, QColor, QPen, QKeyEvent, QBrush, QPainter, - QRegularExpressionValidator, QIcon, QPixmap, QKeySequence, QLinearGradient, - QShowEvent, QDesktopServices, QFontMetrics, QGuiApplication, QLinearGradient, - QImage, QCursor, QPicture -) -from qtpy.QtWidgets import ( - QTextEdit, QLabel, QProgressBar, QHBoxLayout, QToolButton, QCheckBox, - QApplication, QWidget, QVBoxLayout, QMainWindow, QTreeWidgetItemIterator, - QLineEdit, QSlider, QSpinBox, QGridLayout, QRadioButton, - QScrollArea, QSizePolicy, QComboBox, QPushButton, QScrollBar, - QGroupBox, QAbstractSlider, QDoubleSpinBox, QWidgetAction, - QAction, QTabWidget, QAbstractSpinBox, QToolBar, QStyleOptionSpinBox, - QStyle, QDialog, QSpacerItem, QFrame, QMenu, QActionGroup, - QListWidget, QPlainTextEdit, QFileDialog, QListView, QAbstractItemView, - QTreeWidget, QTreeWidgetItem, QListWidgetItem, QLayout, QStylePainter, - QGraphicsBlurEffect, QGraphicsProxyWidget, QGraphicsObject, - QButtonGroup, QStyleOptionSlider -) -import qtpy.compat - -import pyqtgraph as pg -pg.setConfigOption('imageAxisOrder', 'row-major') - -from . import myutils, measurements, is_mac, is_win, html_utils, is_linux -from . import printl, settings_folderpath -from . import colors, config -from . import html_path -from . import _palettes -from . import load -from . import apps -from . import plot -from . import annotate -from . import urls -from . import _core, core -from . import QtScoped -from . import prompts -from .acdc_regex import float_regex -from .config import PREPROCESS_MAPPER -from . import _base_widgets - -LINEEDIT_WARNING_STYLESHEET = _palettes.lineedit_warning_stylesheet() -LINEEDIT_INVALID_ENTRY_STYLESHEET = _palettes.lineedit_invalid_entry_stylesheet() -TREEWIDGET_STYLESHEET = _palettes.TreeWidgetStyleSheet() -LISTWIDGET_STYLESHEET = _palettes.ListWidgetStyleSheet() -BASE_COLOR = _palettes.base_color() -PROGRESSBAR_QCOLOR = _palettes.QProgressBarColor() -PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR = _palettes.QProgressBarHighlightedTextColor() -TEXT_COLOR = _palettes.text_float_rgba() - -font = QFont() -font.setPixelSize(12) - -custom_cmaps_filepath = os.path.join(settings_folderpath, 'custom_colormaps.ini') - -str_to_operator_mapper = { - "+": operator.add, - "-": operator.sub -} - -sign_int_mapper = { - '+': 1, '-': -1 -} - -def removeHSVcmaps(): - hsv_cmaps = [] - for g, grad in pg.graphicsItems.GradientEditorItem.Gradients.items(): - if grad['mode'] == 'hsv': - hsv_cmaps.append(g) - for g in hsv_cmaps: - del pg.graphicsItems.GradientEditorItem.Gradients[g] - -def renamePgCmaps(): - Gradients = pg.graphicsItems.GradientEditorItem.Gradients - try: - Gradients['hot'] = Gradients.pop('thermal') - except KeyError: - pass - try: - Gradients.pop('greyclip') - except KeyError: - pass - -def _tab20gradient(): - cmap = plt.get_cmap('tab20') - ticks = [ - (t, tuple([int(v*255) for v in cmap(t)])) for t in np.linspace(0,1,20) - ] - gradient = {'ticks': ticks, 'mode': 'rgb'} - return gradient - -def _tab10gradient(): - cmap = plt.get_cmap('tab10') - ticks = [ - (t, tuple([int(v*255) for v in cmap(t)])) for t in np.linspace(0,1,20) - ] - gradient = {'ticks': ticks, 'mode': 'rgb'} - return gradient - -def getCustomGradients(name='image'): - CustomGradients = {} - if not os.path.exists(custom_cmaps_filepath): - return CustomGradients - - cp = config.ConfigParser() - cp.read(custom_cmaps_filepath) - for section in cp.sections(): - if not section.startswith(f'{name}'): - continue - - cmap_name = section[len(f'{name}.'):] - CustomGradients[cmap_name] = {'ticks': [], 'mode': 'rgb'} - for option in cp.options(section): - value = cp[section][option] - pos, *rgb = value.split(',') - rgb = tuple([int(c) for c in rgb]) - pos = float(pos) - CustomGradients[cmap_name]['ticks'].append((pos, rgb)) - return CustomGradients - -def addGradients(): - Gradients = pg.graphicsItems.GradientEditorItem.Gradients - Gradients['cividis'] = { - 'ticks': [ - (0.0, (0, 34, 78, 255)), - (0.25, (66, 78, 108, 255)), - (0.5, (124, 123, 120, 255)), - (0.75, (187, 173, 108, 255)), - (1.0, (254, 232, 56, 255))], - 'mode': 'rgb' - } - Gradients['cool'] = { - 'ticks': [ - (0.0, (0, 255, 255, 255)), - (1.0, (255, 0, 255, 255))], - 'mode': 'rgb' - } - Gradients['sunset'] = { - 'ticks': [ - (0.0, (71, 118, 148, 255)), - (0.4, (222, 213, 141, 255)), - (0.8, (229, 184, 155, 255)), - (1.0, (240, 127, 97, 255))], - 'mode': 'rgb' - } - Gradients['tab20'] = _tab20gradient() - Gradients['tab10'] = _tab10gradient() - cmaps = {} - for name, gradient in Gradients.items(): - ticks = gradient['ticks'] - colors = [tuple([v/255 for v in tick[1]]) for tick in ticks] - cmaps[name] = LinearSegmentedColormap.from_list(name, colors, N=256) - return cmaps, Gradients - -nonInvertibleCmaps = ['cool', 'sunset', 'bipolar'] - -renamePgCmaps() -removeHSVcmaps() -cmaps, Gradients = addGradients() -GradientsLabels = Gradients.copy() -GradientsImage = Gradients.copy() - -class XStream(QObject): - _stdout = None - _stderr = None - messageWritten = Signal(str) - - def flush( self ): - pass - - def fileno( self ): - return -1 - - def write(self, msg): - if not self.signalsBlocked(): - self.messageWritten.emit(msg) - - @staticmethod - def stdout(): - if not XStream._stdout: - XStream._stdout = XStream() - sys.stdout = XStream._stdout - return XStream._stdout - - @staticmethod - def stderr(): - if not XStream._stderr: - XStream._stderr = XStream() - sys.stderr = XStream._stderr - return XStream._stderr - -class QtHandler(logging.Handler): - def __init__(self): - super().__init__() - - def emit(self, record): - record = self.format(record) - if record: - XStream.stdout().write('%s\n'%record) - -class QLog(QPlainTextEdit): - sigClose = Signal() - - def __init__(self, *args, logger=None): - super().__init__(*args) - self.logger = logger - self.setReadOnly(True) - - def connect(self): - XStream.stdout().messageWritten.connect(self.writeStdOutput) - # XStream.stderr().messageWritten.connect(self.writeStdErr) - - def writeStdOutput(self, text: str) -> None: - super().insertPlainText(text) - self.verticalScrollBar().setValue(self.verticalScrollBar().maximum()) - - def writeStdErr(self, text: str) -> None: - super().insertPlainText(text) - self.verticalScrollBar().setValue(self.verticalScrollBar().maximum()) - if self.logger is not None: - self.logger.exception(text) - - def insertPlainText(self, text: str) -> None: - super().insertPlainText(f'{text}\n') - self.verticalScrollBar().setValue(self.verticalScrollBar().maximum()) - - def closeEvent(self, event) -> None: - super().closeEvent(event) - self.sigClose.emit() - -class PushButton(QPushButton): - def __init__( - self, *args, icon=None, alignIconLeft=False, - flat=False, hoverable=False - ): - super().__init__(*args) - if icon is not None: - self.setIcon(icon) - self.alignIconLeft = alignIconLeft - self._text = None - if flat: - self.setFlat(True) - if hoverable: - self.installEventFilter(self) - - def setRetainSizeWhenHidden(self, retainSize): - sp = self.sizePolicy() - sp.setRetainSizeWhenHidden(retainSize) - self.setSizePolicy(sp) - - def eventFilter(self, object, event): - if event.type() == QEvent.Type.HoverEnter: - self.setFlat(False) - elif event.type() == QEvent.Type.HoverLeave: - self.setFlat(True) - return False - - def show(self): - text = self.text() - if not self.alignIconLeft: - super().show() - return - - self._text = text - self.setStyleSheet('text-align:left;') - self.setLayout(QGridLayout()) - textLabel = QLabel(self._text) - textLabel.setAlignment(Qt.AlignRight | Qt.AlignVCenter) - textLabel.setAttribute(Qt.WA_TransparentForMouseEvents, True) - self._layout().addWidget(textLabel) - super().show() - - def confirmAction(self): - self.baseIcon = self.icon() - self.setIcon(QIcon(':greenTick.svg')) - QTimer.singleShot(2000, self.resetButton) - - def resetButton(self): - self.setIcon(self.baseIcon) - - def setText(self, text): - if self._text is None: - super().setText(text) - else: - super().setText(self._text) - -class LoadPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':fork_lift.svg')) - -class mergePushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':merge-IDs.svg')) - -class okPushButton(PushButton): - def __init__(self, *args, isDefault=True, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':yesGray.svg')) - if isDefault: - self.setDefault(True) - # QShortcut(Qt.Key_Return, self, self.click) - # QShortcut(Qt.Key_Enter, self, self.click) - -class MagnifyingGlassPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':magnGlass.svg')) - -class MagnifyingGlassAllPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':magnGlass_all.svg')) - -class AssignNewIDButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':assign_new_id.svg')) - -class LockPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':lock.svg')) - self.toggled.connect(self.onToggled) - - def onToggled(self, checked): - if not self.isCheckable(): - return - - if checked: - self.setIcon(QIcon(':lock_closed.svg')) - else: - self.setIcon(QIcon(':lock_open.svg')) - - def setCheckable(self, checkable: bool): - super().setCheckable(checkable) - if checkable: - self.setIcon(QIcon(':lock_open.svg')) - else: - self.setIcon(QIcon(':lock.svg')) - - -class SkipPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':skip_arrow.svg')) - -class BedPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':bed.svg')) - -class BedPlusLabelPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':bed_plus_label.svg')) - iconH = self.iconSize().height() - iconW = int(iconH*2.5) - self.setIconSize(QSize(iconW, iconH)) - -class NoBedPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':no_bed.svg')) - -class NavigatePushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':navigate.svg')) - -class SwitchPlaneButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':switch_2d_plane.svg')) - self._planes = ('xy', 'zy', 'zx') - self._idx = 0 - - def switchPlane(self): - self._idx += 1 - - def setPlane(self, plane): - self._idx = self._planes.index(plane) - - def plane(self): - return self._planes[self._idx % 3] - - def depthAxes(self): - plane = self.plane() - for axes in 'xyz': - if axes not in plane: - return axes - -class zoomPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':zoom_out.svg')) - - def setIconZoomOut(self): - self.setIcon(QIcon(':zoom_out.svg')) - - def setIconZoomIn(self): - self.setIcon(QIcon(':zoom_in.svg')) - -class WarningButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':warning.svg')) - -class reloadPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':reload.svg')) - -class savePushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':file-save.svg')) - -class autoPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':cog_play.svg')) - -class newFilePushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':file-new.svg')) - -class helpPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':help.svg')) - -class viewPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':eye.svg')) - -class infoPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':info.svg')) - -class threeDPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':3d.svg')) - -class twoDPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':2d.svg')) - -class addPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':add.svg')) - -class futurePushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':arrow_future.svg')) - -class FutureAllPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':arrow_future_all.svg')) - -class currentPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':arrow_current.svg')) - -class arrowUpPushButton(PushButton): - def __init__(self, *args, **kwargs): - alignIconLeft = kwargs.get('alignIconLeft', False) - super().__init__( - *args, icon=QIcon(':arrow-up.svg'), alignIconLeft=alignIconLeft - ) - -class arrowDownPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':arrow-down.svg')) - -class selectAllPushButton(PushButton): - sigClicked = Signal(object, bool) - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._status = 'deselect' - self.setIcon(QIcon(':deselect_all.svg')) - self.setText('Deselect all') - self.clicked.connect(self.onClicked) - self.setMinimumWidth(self.sizeHint().width()) - - def setChecked(self, checked): - if checked: - self._status == 'deselect' - else: - self._status == 'select' - self.click() - - def onClicked(self): - if self._status == 'select': - icon_fn = ':deselect_all.svg' - self._status = 'deselect' - checked = True - text = 'Deselect all' - else: - icon_fn = ':select_all.svg' - text = 'Select all' - self._status = 'select' - checked = False - self.setIcon(QIcon(icon_fn)) - self.setText(text) - self.sigClicked.emit(self, checked) - -class subtractPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':subtract.svg')) - -class continuePushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':continue.svg')) - -class calcPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':calc.svg')) - -class playPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':play.svg')) - -class stopPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':stop.svg')) - -class copyPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':edit-copy.svg')) - self.clicked.connect(self.onClicked) - self._text_to_copy = None - - def setTextToCopy(self, text): - self._text_to_copy = text - - def onClicked(self): - self._original_text = self.text() - if self._text_to_copy is not None: - cb = QApplication.clipboard() - cb.clear(mode=cb.Clipboard) - cb.setText(self._text_to_copy, mode=cb.Clipboard) - - super().setText('Copied!') - self.setIcon(QIcon(':greenTick.svg')) - QTimer.singleShot(2000, self.resetButton) - - def resetButton(self): - self.setText(self._original_text) - self.setIcon(QIcon(':edit-copy.svg')) - -class OpenFilePushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':folder-open.svg')) - -class movePushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':folder-move.svg')) - -class DownloadPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':download.svg')) - -class showInFileManagerButton(PushButton): - def __init__(self, *args, setDefaultText=False, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':drawer.svg')) - self._path_to_browse = None - if setDefaultText: - self.setDefaultText() - - def setDefaultText(self): - self._text = myutils.get_show_in_file_manager_text() - self.setText(self._text) - - def setPathToBrowse(self, path: os.PathLike): - self._path_to_browse = path - self.clicked.connect(partial(myutils.showInExplorer, path)) - - - -class OpenUrlButton(PushButton): - def __init__(self, url, *args, **kwargs): - self._url = url - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':browser.svg')) - self.clicked.connect(self.openUrl) - - def openUrl(self): - QDesktopServices.openUrl(QUrl(self._url)) - -class LessThanPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':less_than.svg')) - flat = kwargs.get('flat') - if flat is not None: - self.setFlat(True) - -class showDetailsButton(PushButton): - sigToggled = Signal(bool) - - def __init__(self, *args, txt='Show details...', parent=None): - super().__init__(txt, parent) - # self.setText(txt) - self.txt = txt - self.checkedIcon = QIcon(':hideUp.svg') - self.uncheckedIcon = QIcon(':showDown.svg') - self.setIcon(self.uncheckedIcon) - self.toggled.connect(self.onClicked) - self.setCheckable(True) - w = self.sizeHint().width() + 10 - self.setFixedWidth(w) - - def onClicked(self, checked): - if checked: - self.setText(self.txt.replace('Show', 'Hide')) - self.setIcon(self.checkedIcon) - else: - self.setText(self.txt) - self.setIcon(self.uncheckedIcon) - - self.sigToggled.emit(checked) - -class cancelPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':cancelButton.svg')) - -class setPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':cog.svg')) - -class TrainPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':train.svg')) - -class noPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':no.svg')) - -class editPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':edit-id.svg')) - -class delPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':bin.svg')) - -class eraserPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':eraser.svg')) - -class CrossCursorPointButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':cross_cursor.svg')) - -class TestPushButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':test.svg')) - -class browseFileButton(PushButton): - sigPathSelected = Signal(str) - - def __init__( - self, *args, ext=None, title='Select file', start_dir='', - openFolder=False, **kwargs - ): - """PushButton with sigPathSelected Signal to select file or folder - - Parameters - ---------- - ext : dict or None, optional - If not None, this is a dictionary of - {'FILE NAME': ['.ext1', '.ext2', ...]}. - For example, to allow only selection of CSV files, - pass {'CSV': ['.csv']}. - - Note that the 'FILE NAME' is arbitrary. Default is None - title : str, optional - Title of the File Manager window. Default is 'Select file' - start_dir : str, optional - Directory where the File Manager window will initially be open. - Default is '' - openFolder : bool, optional - If True, allows for selection of folders instead of files. - Default is False - """ - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':folder-open.svg')) - self.clicked.connect(self.browse) - - self._title = title - self._start_dir = start_dir - self._openFolder = openFolder - self._file_types = 'All Files (*)' - if ext is not None: - s_li = [] - for name, extensions in ext.items(): - _s = '' - if isinstance(extensions, str): - extensions = [extensions] - for ext in extensions: - _s = f'{_s}*{ext} ' - s_li.append(f'{name} {_s.strip()}') - - self._file_types = ';;'.join(s_li) - self._file_types = f'{self._file_types};;All Files (*)' - - def setStartPath(self, start_path): - self._start_dir = start_path - - def browse(self): - if self._openFolder: - fileDialog = QFileDialog.getExistingDirectory - args = (self, self._title, self._start_dir) - else: - fileDialog = QFileDialog.getOpenFileName - args = (self, self._title, self._start_dir, self._file_types) - file_path = fileDialog(*args) - if not isinstance(file_path, str): - file_path = file_path[0] - if file_path: - self.sigPathSelected.emit(file_path) - -def getPushButton(buttonText, qparent=None): - isCancelButton = ( - buttonText.lower().find('cancel') != -1 - or buttonText.lower().find('abort') != -1 - ) - isYesButton = ( - buttonText.lower().find('yes') != -1 - or buttonText.lower().find('ok') != -1 - or buttonText.lower().find('continue') != -1 - or buttonText.lower().find('recommended') != -1 - ) - isSettingsButton = buttonText.lower().find('set') != -1 - isNoButton = ( - buttonText.replace(' ', '').lower() == 'no' - or buttonText.lower().find('Do not ') != -1 - or buttonText.lower().find('no, ') != -1 - ) - isDelButton = buttonText.lower().find('delete') != -1 - isAddButton = buttonText.lower().find('add ') != -1 - is3Dbutton = buttonText.find(' 3D ') != -1 - is2Dbutton = buttonText.find(' 2D ') != -1 - isSaveButton = buttonText.lower().find('overwrite') != -1 - isNewFileButton = buttonText.lower().find('rename') != -1 - isTryAgainButton = buttonText.lower().find('try again') != -1 - - if isCancelButton: - button = cancelPushButton(buttonText, qparent) - if qparent is not None: - qparent.addCancelButton(button=button) - elif isYesButton: - button = okPushButton(buttonText, qparent) - if qparent is not None: - qparent.okButton = button - elif isSettingsButton: - button = setPushButton(buttonText, qparent) - elif isNoButton: - button = noPushButton(buttonText, qparent) - elif isDelButton: - button = delPushButton(buttonText, qparent) - elif isAddButton: - button = addPushButton(buttonText, qparent) - elif is3Dbutton: - button = threeDPushButton(buttonText, qparent) - elif is2Dbutton: - button = twoDPushButton(buttonText, qparent) - elif isSaveButton: - button = savePushButton(buttonText, qparent) - elif isNewFileButton: - button = newFilePushButton(buttonText, qparent) - elif isTryAgainButton: - button = reloadPushButton(buttonText, qparent) - else: - button = QPushButton(buttonText, qparent) - - return button, isCancelButton - -def CustomGradientMenuAction(gradient: QLinearGradient, name: str, parent): - pixmap = QPixmap(100, 15) - painter = QPainter(pixmap) - brush = QBrush(gradient) - painter.fillRect(QRect(0, 0, 100, 15), brush) - painter.end() - label = QLabel() - label.setPixmap(pixmap) - label.setContentsMargins(1, 1, 1, 1) - labelName = QLabel(name) - hbox = QHBoxLayout() - delButton = delPushButton() - hbox.addWidget(labelName) - hbox.addStretch(1) - hbox.addWidget(label) - hbox.addWidget(delButton) - widget = QWidget() - widget.setLayout(hbox) - action = QWidgetAction(parent) - action.name = name - action.setDefaultWidget(widget) - action.delButton = delButton - delButton.action = action - return action - -class ContourItem(pg.PlotCurveItem): - def __init__(self, *args, **kargs): - super().__init__(*args, **kargs) - self._prevData = None - - def clear(self): - try: - self.setData([], []) - except AttributeError as e: - pass - - def tempClear(self): - try: - self._prevData = [d.copy() for d in self.getData()] - self.clear() - except Exception as e: - pass - - def restore(self): - if self._prevData is not None: - if self._prevData[0] is not None: - self.setData(*self._prevData) - -class BaseScatterPlotItem(pg.ScatterPlotItem): - def __init__(self, *args, **kargs): - super().__init__(*args, **kargs) - - def tempClear(self): - try: - self._prevData = [d.copy() for d in self.getData()] - self.setData([], []) - except Exception as e: - pass - - def restore(self): - if self._prevData is not None: - if self._prevData[0] is not None: - self.setData(*self._prevData) - -class VerticalSpacerEmptyWidget(QWidget): - def __init__(self, parent=None, height=5) -> None: - super().__init__(parent) - self.setSizePolicy( - QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum - ) - self.setFixedHeight(height) - -class CustomAnnotationScatterPlotItem(BaseScatterPlotItem): - def __init__(self, *args, **kargs): - super().__init__(*args, **kargs) - -class ElidingLineEdit(QLineEdit): - def __init__(self, parent=None, minWidth=None): - super().__init__(parent) - self._text = '' - self._minWidth = minWidth - if minWidth is not None: - self.setMinimumWidth(minWidth) - - self.textEdited.connect(self.setText) - self.installEventFilter(self) - self._elide = True - - def setText(self, text: str, width=None, elide=True) -> None: - if width is None: - width = self._minWidth - - if width is None: - try: - textToPrevRatio = len(text)/len(self.text()) - width = round(self.width()*textToPrevRatio) - except ZeroDivisionError: - width = self.width() - - if width > self.width(): - width = self.width() - - self._text = text - if not elide or not self._elide: - super().setText(text) - return - - fm = QFontMetrics(self.font()) - elidedText = fm.elidedText(text, Qt.ElideLeft, width) - - super().setText(elidedText) - self.setToolTip(text) - - def text(self): - return self._text - - def resizeEvent(self, event): - newWidth = event.size().width() - self.setText(self._text, width=newWidth) - event.accept() - - def eventFilter(self, a0: 'QObject', a1: 'QEvent') -> bool: - isFocusIn = a1.type() == QEvent.Type.FocusIn - if isFocusIn and (self.isReadOnly() or not self.isEnabled()): - self.clearFocus() - return True - return super().eventFilter(a0, a1) - - def focusInEvent(self, event): - super().focusInEvent(event) - self._elide = False - self.setText(self._text, elide=False) - self.setCursorPosition(len(self.text())) - - def focusOutEvent(self, event): - self._elide = True - super().focusOutEvent(event) - self.setText(self._text) - -class ValidLineEdit(QLineEdit): - def __init__(self, parent=None): - super().__init__(parent) - - def setInvalidStyleSheet(self): - self.setStyleSheet(LINEEDIT_INVALID_ENTRY_STYLESHEET) - - def setValidStyleSheet(self): - self.setStyleSheet('') - -class KeepIDsLineEdit(ValidLineEdit): - sigIDsChanged = Signal(list) - sigSort = Signal() - sigEnterPressed = Signal() - - def __init__(self, instructionsLabel, parent=None): - super().__init__(parent) - - self.validPattern = '^[0-9-, ]+$' - regExpr = QRegularExpression(self.validPattern) - self.setValidator(QRegularExpressionValidator(regExpr)) - - self.textChanged.connect(self.onTextChanged) - self.editingFinished.connect(self.onEditingFinished) - - self.instructionsText = instructionsLabel.text() - self._label = instructionsLabel - - def keyPressEvent(self, event) -> None: - super().keyPressEvent(event) - if event.text() == ',': - self.sigSort.emit() - elif event.key() == Qt.Key.Key_Return or event.key() == Qt.Key.Key_Enter: - self.sigEnterPressed.emit() - - def onTextChanged(self, text): - IDs = [] - rangesMatch = re.findall(r'(\d+-\d+)', text) - if rangesMatch: - for rangeText in rangesMatch: - start, stop = rangeText.split('-') - start, stop = int(start), int(stop) - IDs.extend(range(start, stop+1)) - text = re.sub(r'(\d+)-(\d+)', '', text) - IDsMatch = re.findall(r'(\d+)', text) - if IDsMatch: - for ID in IDsMatch: - IDs.append(int(ID)) - self.IDs = sorted(list(set(IDs))) - self.sigIDsChanged.emit(self.IDs) - - def onEditingFinished(self): - self.sigSort.emit() - - def warnNotExistingID(self): - self.setInvalidStyleSheet() - self._label.setText( - ' Some of the IDs are not existing --> they will be IGNORED' - ) - self._label.setStyleSheet('color: red') - - def setInstructionsText(self): - self.setValidStyleSheet() - self._label.setText(self.instructionsText) - self._label.setStyleSheet('') - -class ScrollBar(QScrollBar): - def __init__(self, *args): - super().__init__(*args) - self.installEventFilter(self) - self.setContextMenuPolicy(Qt.NoContextMenu) - - def eventFilter(self, object, event) -> bool: - if event.type() == QEvent.Type.Wheel: - return True - elif event.type() == QEvent.Type.MouseButtonPress: - # Filter right-click to prevent context menu - return event.button() == Qt.MouseButton.RightButton - elif event.type() == QEvent.Type.MouseButtonRelease: - # Filter right-click to prevent context menu - return event.button() == Qt.MouseButton.RightButton - return False - -class _ReorderableListModel(QAbstractListModel): - ''' - ReorderableListModel is a list model which implements reordering of its - items via drag-n-drop - ''' - dragDropFinished = Signal() - - def __init__(self, items, parent=None): - QAbstractItemModel.__init__(self, parent) - self.nodes = items - self.lastDroppedItems = [] - self.pendingRemoveRowsAfterDrop = False - - def rowForItem(self, text): - ''' - rowForItem method returns the row corresponding to the passed in item - or None if no such item exists in the model - ''' - try: - row = self.nodes.index(text) - except ValueError: - return None - return row - - def index(self, row, column, parent): - if row < 0 or row >= len(self.nodes): - return QModelIndex() - return self.createIndex(row, column) - - def parent(self, index): - return QModelIndex() - - def rowCount(self, index): - if index.isValid(): - return 0 - return len(self.nodes) - - def data(self, index, role): - if not index.isValid(): - return None - if role == Qt.DisplayRole: - row = index.row() - if row < 0 or row >= len(self.nodes): - return None - return self.nodes[row] - elif role == Qt.SizeHintRole: - return QSize(48, 32) - else: - return None - - def supportedDropActions(self): - return Qt.MoveAction - - def flags(self, index): - if not index.isValid(): - return Qt.ItemIsEnabled - return Qt.ItemIsEnabled | Qt.ItemIsSelectable | \ - Qt.ItemIsDragEnabled | Qt.ItemIsDropEnabled - - def insertRows(self, row, count, index): - if index.isValid(): - return False - if count <= 0: - return False - # inserting 'count' empty rows starting at 'row' - self.beginInsertRows(QModelIndex(), row, row + count - 1) - for i in range(0, count): - self.nodes.insert(row + i, '') - self.endInsertRows() - return True - - def removeRows(self, row, count, index): - if index.isValid(): - return False - if count <= 0: - return False - num_rows = self.rowCount(QModelIndex()) - self.beginRemoveRows(QModelIndex(), row, row + count - 1) - for i in range(count, 0, -1): - self.nodes.pop(row - i + 1) - self.endRemoveRows() - - if self.pendingRemoveRowsAfterDrop: - ''' - If we got here, it means this call to removeRows is the automatic - 'cleanup' action after drag-n-drop performed by Qt - ''' - self.pendingRemoveRowsAfterDrop = False - self.dragDropFinished.emit() - - return True - - def setData(self, index, value, role): - if not index.isValid(): - return False - if index.row() < 0 or index.row() > len(self.nodes): - return False - self.nodes[index.row()] = str(value) - self.dataChanged.emit(index, index) - return True - - def mimeTypes(self): - return ['application/vnd.treeviewdragdrop.list'] - - def mimeData(self, indexes): - mimedata = QMimeData() - encoded_data = QByteArray() - stream = QDataStream(encoded_data, QIODevice.WriteOnly) - for index in indexes: - if index.isValid(): - text = self.data(index, 0) - stream << QByteArray(text.encode('utf-8')) - mimedata.setData('application/vnd.treeviewdragdrop.list', encoded_data) - return mimedata - - def dropMimeData(self, data, action, row, column, parent): - if action == Qt.IgnoreAction: - return True - if not data.hasFormat('application/vnd.treeviewdragdrop.list'): - return False - if column > 0: - return False - - num_rows = self.rowCount(QModelIndex()) - if num_rows <= 0: - return False - - if row < 0: - if parent.isValid(): - row = parent.row() - else: - return False - - encoded_data = data.data('application/vnd.treeviewdragdrop.list') - stream = QDataStream(encoded_data, QIODevice.ReadOnly) - - new_items = [] - rows = 0 - while not stream.atEnd(): - text = QByteArray() - stream >> text - text = bytes(text).decode('utf-8') - index = self.nodes.index(text) - new_items.append((text, index)) - rows += 1 - - self.lastDroppedItems = [] - for (text, index) in new_items: - target_row = row - if index < row: - target_row += 1 - self.beginInsertRows(QModelIndex(), target_row, target_row) - self.nodes.insert(target_row, self.nodes[index]) - self.endInsertRows() - self.lastDroppedItems.append(text) - row += 1 - - self.pendingRemoveRowsAfterDrop = True - return True - -class _SelectionModel(QItemSelectionModel): - def __init__(self, parent=None, isSingleSelection=False): - QItemSelectionModel.__init__(self, parent) - self.isSingleSelection = isSingleSelection - - def onModelItemsReordered(self): - new_selection = QItemSelection() - new_index = QModelIndex() - for item in self.model().lastDroppedItems: - row = self.model().rowForItem(item) - if row is None: - continue - new_index = self.model().index(row, 0, QModelIndex()) - new_selection.select(new_index, new_index) - - self.clearSelection() - flags = ( - QItemSelectionModel.SelectionFlag.ClearAndSelect - | QItemSelectionModel.SelectionFlag.Rows - | QItemSelectionModel.SelectionFlag.Current - ) - self.select(new_selection, flags) - self.setCurrentIndex(new_index, flags) - if not self.isSingleSelection: - self.reset() - -class ReorderableListView(QListView): - def __init__( - self, items=None, parent=None, isSingleSelection=False - ) -> None: - super().__init__(parent) - if items is None: - items = [] - - self.isSingleSelection = isSingleSelection - self._model = _ReorderableListModel(items) - self._selectionModel = _SelectionModel(self._model) - self._model.dragDropFinished.connect( - self._selectionModel.onModelItemsReordered - ) - self.setModel(self._model) - self.setSelectionModel(self._selectionModel) - self.setDragDropMode(QAbstractItemView.DragDropMode.InternalMove) - self.setDragDropOverwriteMode(False) - styleSheet = (f""" - QListView {{ - selection-background-color: rgba(200, 200, 200, 0.30); - selection-color: black; - show-decoration-selected: 1; - }} - QListView::item {{ - border-bottom: 1px solid rgba(180, 180, 180, 0.5); - }} - QListView::item:hover {{ - background-color: rgba(200, 200, 200, 0.30); - }} - """) - self.setStyleSheet(styleSheet) - - def setItems(self, items): - self._model.nodes = items - - def items(self): - return self._model.nodes - - # def mouseReleaseEvent(self, e: QMouseEvent) -> None: - # super().mouseReleaseEvent(e) - # self._selectionModel.reset() - -class QDialogListbox(QDialog): - sigSelectionConfirmed = Signal(list) - - def __init__( - self, title, text, items, cancelText='Cancel', - multiSelection=True, parent=None, - additionalButtons=(), includeSelectionHelp=False, - allowSingleSelection=True, preSelectedItems=None, - allowEmptySelection=True - ): - self.cancel = True - items = list(items) - - super().__init__(parent) - self.setWindowTitle(title) - - if preSelectedItems is None: - if items: - preSelectedItems = (items[0],) - else: - preSelectedItems = set() - - self.allowSingleSelection = allowSingleSelection - self.allowEmptySelection = allowEmptySelection - - mainLayout = QVBoxLayout() - topLayout = QVBoxLayout() - bottomLayout = QHBoxLayout() - - self.mainLayout = mainLayout - - label = QLabel(text) - _font = QFont() - _font.setPixelSize(13) - label.setFont(_font) - # padding: top, left, bottom, right - label.setStyleSheet("padding:0px 0px 3px 0px;") - topLayout.addWidget(label, alignment=Qt.AlignCenter) - - if includeSelectionHelp: - selectionHelpLabel = QLabel() - txt = html_utils.paragraph("""
    - Ctrl+Click to select multiple items
    - Shift+Click to select a range of items
    - """) - selectionHelpLabel.setText(txt) - topLayout.addWidget(label, alignment=Qt.AlignCenter) - - listBox = listWidget() - listBox.setFont(_font) - listBox.addItems(items) - if multiSelection: - listBox.setSelectionMode( - QAbstractItemView.SelectionMode.ExtendedSelection) - else: - listBox.setSelectionMode( - QAbstractItemView.SelectionMode.SingleSelection) - listBox.setCurrentRow(0) - for i in range(listBox.count()): - item = listBox.item(i) - item.setSelected(item.text() in preSelectedItems) - - self.listBox = listBox - if not multiSelection: - listBox.itemDoubleClicked.connect(self.ok_cb) - topLayout.addWidget(listBox) - - if cancelText.lower().find('cancel') != -1: - cancelButton = cancelPushButton(cancelText) - else: - cancelButton = QPushButton(cancelText) - okButton = okPushButton(' Ok ') - - bottomLayout.addStretch(1) - bottomLayout.addWidget(cancelButton) - bottomLayout.addSpacing(20) - - if additionalButtons: - self._additionalButtons = [] - for button in additionalButtons: - if isinstance(button, str): - _button, isCancelButton = getPushButton(button) - self._additionalButtons.append(_button) - bottomLayout.addWidget(_button) - _button.clicked.connect(self.ok_cb) - else: - bottomLayout.addWidget(button) - - bottomLayout.addWidget(okButton) - bottomLayout.setContentsMargins(0, 10, 0, 0) - - mainLayout.addLayout(topLayout) - mainLayout.addLayout(bottomLayout) - self.setLayout(mainLayout) - - # Connect events - okButton.clicked.connect(self.ok_cb) - cancelButton.clicked.connect(self.cancel_cb) - - if multiSelection: - listBox.itemClicked.connect(self.onItemClicked) - listBox.itemSelectionChanged.connect(self.onItemSelectionChanged) - - self.setStyleSheet(LISTWIDGET_STYLESHEET) - self.areItemsSelected = [ - listBox.item(i).isSelected() for i in range(listBox.count()) - ] - self.setFont(font) - - def keyPressEvent(self, event) -> None: - mod = event.modifiers() - if mod == Qt.ShiftModifier or mod == Qt.ControlModifier: - self.listBox.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) - elif event.key() == Qt.Key_Escape: - self.listBox.clearSelection() - event.ignore() - return - super().keyPressEvent(event) - - def onItemSelectionChanged(self): - if not self.listBox.selectedItems(): - self.areItemsSelected = [ - False for i in range(self.listBox.count()) - ] - - def onItemClicked(self, item): - mod = QGuiApplication.keyboardModifiers() - if mod == Qt.ShiftModifier or mod == Qt.ControlModifier: - self.listBox.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) - return - - self.listBox.setSelectionMode(QAbstractItemView.SelectionMode.MultiSelection) - itemIdx = self.listBox.row(item) - wasSelected = self.areItemsSelected[itemIdx] - if wasSelected: - item.setSelected(False) - - self.areItemsSelected = [ - self.listBox.item(i).isSelected() - for i in range(self.listBox.count()) - ] - # self.listBox.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) - # else: - # selectedItems.append(item) - - # self.listBox.clearSelection() - # for i in range(self.listBox.count()): - # item = self.listBox.item(i).setSelected(True) - - # print(self.listBox.selectedItems()) - - def setSelectedItems(self, itemsTexts): - for i in range(self.listBox.count()): - item = self.listBox.item(i) - if item.text() in itemsTexts: - item.setSelected(True) - self.listBox.update() - - def warnSelectionEmpty(self): - msg = myMessageBox(wrapText=False, showCentered=False) - txt = html_utils.paragraph( - 'You need to select at least one item!.

    ' - 'Use Ctrl+Click to select multiple items
    ' - 'or Shift+Click to select a range of items' - ) - msg.warning(self, 'Selection cannot be empty!', txt) - - def ok_cb(self, checked=False): - self.clickedButton = self.sender() - self.cancel = False - selectedItems = self.listBox.selectedItems() - self.selectedItemsText = [item.text() for item in selectedItems] - if not self.allowSingleSelection and len(self.selectedItemsText) < 2: - msg = myMessageBox(wrapText=False, showCentered=False) - txt = html_utils.paragraph( - 'You need to select two or more items.

    ' - 'Use Ctrl+Click to select multiple items
    , or
    ' - 'Shift+Click to select a range of items' - ) - msg.warning(self, 'Select two or more items', txt) - return - - if not self.allowEmptySelection and not self.selectedItemsText: - self.warnSelectionEmpty() - return - - self.sigSelectionConfirmed.emit(self.selectedItemsText) - self.close() - - def cancel_cb(self, event): - self.cancel = True - self.selectedItemsText = None - self.close() - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show() - - horizontal_sb = self.listBox.horizontalScrollBar() - while horizontal_sb.isVisible(): - self.resize(self.height(), self.width() + 10) - - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - - -class ExpandableListBox(QComboBox): - def __init__(self, parent=None, centered=True) -> None: - super().__init__(parent) - - self.setEditable(True) - self.lineEdit().setReadOnly(True) - - infoTxt = html_utils.paragraph( - 'Select Positions to save

    ' - 'Ctrl+Click to select multiple items
    ' - 'Shift+Click to select a range of items
    ', - center=True - ) - - self.listW = QDialogListbox( - 'Select Positions to save', infoTxt, - [], multiSelection=True, parent=self - ) - - self.listW.listBox.itemClicked.connect(self.listItemClicked) - self.listW.sigSelectionConfirmed.connect(self.updateCombobox) - - self.centered = centered - - def listItemClicked(self, item): - if item.text().find('All') == -1: - return - - for i in range(self.listW.listBox.count()): - _item = self.listW.listBox.item(i) - _item.setSelected(True) - - def clear(self) -> None: - self.listW.listBox.clear() - return super().clear() - - def setItems(self, items): - self.clear() - self.addItems(items) - - def addItems(self, items): - super().addItems(items) - self.listW.listBox.addItems(items) - self.listW.listBox.setCurrentRow(self.currentIndex()) - self.listItemClicked(self.listW.listBox.currentItem()) - if self.centered: - self.centerItems() - - def updateCombobox(self, selectedItemsText): - isAllItem = [ - i for i, t in enumerate(selectedItemsText) if t.find('All') != -1 - ] - if len(selectedItemsText) == 1: - self.setCurrentText(selectedItemsText[0]) - elif isAllItem: - idx = isAllItem[0] - self.setCurrentText(selectedItemsText[idx]) - else: - super().clear() - super().addItems(['Custom selection']) - - def centerItems(self, idx=None): - self.lineEdit().setAlignment(Qt.AlignCenter) - - def selectedItems(self): - return self.listW.listBox.selectedItems() - - def selectedItemsText(self): - return [item.text() for item in self.selectedItems()] - - def showPopup(self) -> None: - self.listW.show() - -class filePathControl(QFrame): - sigValueChanged = Signal(str) - - def __init__( - self, parent=None, browseFolder=False, - fileManagerTitle='Select file', - validExtensions=None, - startFolder='', - elide=False - ): - super().__init__(parent) - - layout = QHBoxLayout() - if elide: - self.le = ElidingLineEdit() - else: - self.le = QLineEdit() - - self.browseButton = browseFileButton( - openFolder=browseFolder, title=fileManagerTitle, - ext=validExtensions, start_dir=startFolder - ) - - layout.addWidget(self.le) - layout.addWidget(self.browseButton) - self.setLayout(layout) - - self.le.editingFinished.connect(self.setTextTooltip) - self.browseButton.sigPathSelected.connect(self.setText) - - self.setFrameStyle(QFrame.Shape.StyledPanel) - - def setText(self, text): - self.le.setText(text) - self.le.setToolTip(text) - self.sigValueChanged.emit(self.le.text()) - - def setTextTooltip(self): - self.le.setToolTip(self.le.text()) - self.sigValueChanged.emit(self.le.text()) - - def path(self): - return self.le.text() - - def showEvent(self, a0: QShowEvent) -> None: - self.le.setFixedHeight(self.browseButton.height()) - return super().showEvent(a0) - -class FolderPathControl(filePathControl): - def __init__(self, **kwargs): - super().__init__( - browseFolder=True, - fileManagerTitle='Select folder', - **kwargs - ) - -class CsvFilePathControl(filePathControl): - def __init__(self, **kwargs): - super().__init__( - browseFolder=False, - fileManagerTitle='Select a CSV file', - validExtensions={'CSV files': ['.csv', '.CSV']}, - **kwargs - ) - -class QHWidgetSpacer(QWidget): - def __init__(self, width=10, parent=None) -> None: - super().__init__(parent) - self.setFixedWidth(width) - -class QVWidgetSpacer(QWidget): - def __init__(self, height=10, parent=None) -> None: - super().__init__(parent) - self.setFixedHeight(height) - -class QHLine(QFrame): - def __init__(self, shadow='Sunken', parent=None, color=None): - super().__init__(parent) - self.setFrameShape(QFrame.Shape.HLine) - self.setFrameShadow(getattr(QFrame, shadow)) - if color is not None: - self.setColor(color) - - def setColor(self, color): - qcolor = pg.mkColor(color) - pal = self.palette() - pal.setColor(QPalette.ColorRole.WindowText, qcolor) - self.setPalette(pal) - -class QVLine(QFrame): - def __init__(self, shadow='Plain', parent=None, color=None): - super().__init__(parent) - self.setFrameShape(QFrame.Shape.VLine) - self.setFrameShadow(getattr(QFrame.Shadow, shadow)) - if color is not None: - self.setColor(color) - - def setColor(self, color): - qcolor = pg.mkColor(color) - pal = self.palette() - pal.setColor(QPalette.ColorRole.WindowText, qcolor) - self.setPalette(pal) - -class VerticalResizeHline(QFrame): - dragged = Signal(object) - clicked = Signal(object) - released = Signal(object) - - def __init__(self): - super().__init__() - self.setCursor(Qt.SplitVCursor) - self.setFrameShape(QFrame.Shape.HLine) - self.setFrameShadow(QFrame.Shadow.Sunken) - self.installEventFilter(self) - self.isMousePressed = False - self._height = 4 - self.setMinimumHeight(self._height) - - def mousePressEvent(self, event) -> None: - self.isMousePressed = True - self.clicked.emit(event) - return super().mousePressEvent(event) - - def mouseMoveEvent(self, event) -> None: - self.dragged.emit(event) - return super().mouseMoveEvent(event) - - def mouseReleaseEvent(self, event) -> None: - self.isMousePressed = False - self.released.emit(event) - return super().mouseReleaseEvent(event) - - def eventFilter(self, object, event): - if event.type() == QEvent.Type.Enter: - self.setLineWidth(0) - self.setMidLineWidth(self._height) - pal = self.palette() - pal.setColor(QPalette.ColorRole.WindowText, QColor(BASE_COLOR)) - self.setPalette(pal) - # self.setStyleSheet('background-color: #4d4d4d') - elif event.type() == QEvent.Type.Leave: - self.setMidLineWidth(0) - self.setLineWidth(1) - return False - -class GroupBox(QGroupBox): - def __init__(self, *args, keyPressCallback=None): - super().__init__(*args) - self.keyPressCallback = None - self.setFocusPolicy(Qt.NoFocus) - - def keyPressEvent(self, event) -> None: - event.ignore() - if self.keyPressCallback is None: - return - - self.keyPressCallback() - -class CheckBox(QCheckBox): - def __init__(self, *args, keyPressCallback=None): - super().__init__(*args) - self.keyPressCallback = None - self.setFocusPolicy(Qt.NoFocus) - - def keyPressEvent(self, event) -> None: - event.ignore() - if self.keyPressCallback is None: - return - - self.keyPressCallback() - -class ScrollArea(QScrollArea): - sigLeaveEvent = Signal() - - def __init__( - self, parent=None, resizeVerticalOnShow=False, - dropArrowKeyEvents=False - ) -> None: - super().__init__(parent) - self.setWidgetResizable(True) - self.setFrameStyle(QFrame.Shape.NoFrame) - self.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded) - self.containerWidget = None - self.resizeVerticalOnShow = resizeVerticalOnShow - self.isOnlyVertical = False - self.dropArrowKeyEvents = dropArrowKeyEvents - - def setVerticalLayout(self, layout, widget=None): - if widget is None: - self.containerWidget = QWidget() - else: - self.containerWidget = widget - self.containerWidget.setLayout(layout) - self.containerWidget.setSizePolicy( - QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred - ) - self.setWidget(self.containerWidget) - self.containerWidget.installEventFilter(self) - self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) - self.isOnlyVertical = True - - def setWidget(self, widget): - self.containerWidget = widget - super().setWidget(widget) - - def _resizeHorizontal(self): - self.setMinimumWidth( - self.containerWidget.minimumSizeHint().width() - + self.verticalScrollBar().width() - ) - - def minimumWidthNoScrollbar(self) -> int: - width = ( - self.containerWidget.minimumSizeHint().width() - + self.verticalScrollBar().width() - ) - return width - - def minimumHeightNoScrollbar(self) -> int: - height = ( - self.containerWidget.minimumSizeHint().height() - + self.horizontalScrollBar().height() - ) - return height - - def _resizeVertical(self): - height = ( - self.containerWidget.minimumSizeHint().height() - + self.horizontalScrollBar().height() - ) - self.containerWidget.setSizePolicy( - QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred - ) - - self.setFixedHeight(height) - - def eventFilter(self, object, event: QEvent): - if event.type() == QEvent.Type.Leave: - self.sigLeaveEvent.emit() - - if object != self.containerWidget: - return False - - isResize = event.type() == QEvent.Type.Resize - isShow = event.type() == QEvent.Type.Show - if isResize and self.isOnlyVertical: - self._resizeHorizontal() - elif isShow and self.resizeVerticalOnShow: - self._resizeVertical() - return False - -class QClickableLabel(QLabel): - clicked = Signal(object) - - def __init__(self, parent=None): - self._parent = parent - super().__init__(parent) - self._checkableItem = None - - def setCheckableItem(self, widget): - self._checkableItem = widget - - def mousePressEvent(self, event): - self.clicked.emit(self) - if self._checkableItem is not None: - status = not self._checkableItem.isChecked() - self._checkableItem.setChecked(status) - - def setChecked(self, checked): - self._checkableItem.setChecked(checked) - -class QCenteredComboBox(QComboBox): - def __init__(self, parent=None) -> None: - super().__init__(parent) - - self.setEditable(True) - self.lineEdit().setReadOnly(True) - self.lineEdit().setAlignment(Qt.AlignCenter) - self.lineEdit().installEventFilter(self) - - self.currentIndexChanged.connect(self.centerItems) - - self._isPopupVisibile = False - - def centerItems(self, idx): - for i in range(self.count()): - self.setItemData(i, Qt.AlignCenter, Qt.TextAlignmentRole) - - def eventFilter(self, lineEdit, event): - # Reimplement show popup on click - if event.type() == QEvent.Type.MouseButtonPress and self.isEnabled(): - if self._isPopupVisibile: - self.hidePopup() - self._isPopupVisibile = False - else: - self.showPopup() - self._isPopupVisibile = True - return True - return False - -class AlphaNumericComboBox(QCenteredComboBox): - def __init__(self, parent=None) -> None: - super().__init__(parent=parent) - - def addItems(self, items): - self._dtype = type(items[0]) - super().addItems([str(item) for item in items]) - - def setCurrentValue(self, value): - super().setCurrentText(str(value)) - - def currentValue(self): - return self._dtype(super().currentText()) - -class statusBarPermanentLabel(QWidget): - def __init__(self, parent=None): - super().__init__(parent) - - self.rightLabel = QLabel('') - self.leftLabel = QLabel('') - - layout = QHBoxLayout() - layout.addWidget(self.leftLabel) - layout.addStretch(10) - layout.addWidget(self.rightLabel) - - self.setLayout(layout) - -class listWidget(QListWidget): - def __init__( - self, - *args, - isMultipleSelection=False, - minimizeHeight=False, - **kwargs - ): - super().__init__(*args, **kwargs) - self.itemHeight = None - self.setStyleSheet(LISTWIDGET_STYLESHEET) - self.setFont(font) - if isMultipleSelection: - self.setSelectionMode( - QAbstractItemView.SelectionMode.ExtendedSelection - ) - - self.minimizeHeight = minimizeHeight - - def setSelectedAll(self, selected): - for i in range(self.count()): - self.item(i).setSelected(selected) - - def setSelectedItems(self, itemsText): - for i in range(self.count()): - item = self.item(i) - item.setSelected(item.text() in itemsText) - - def addItems(self, labels) -> None: - super().addItems(labels) - if self.itemHeight is not None: - self.setItemHeight() - - if self.minimizeHeight: - itemHeight = self.sizeHintForRow(0) - self.setMaximumHeight(itemHeight * self.count() + itemHeight*2) - - def addItem(self, text): - super().addItem(text) - if self.itemHeight is None: - return - self.setItemHeight() - - def setItemHeight(self, height=40): - self.itemHeight = height - for i in range(self.count()): - item = self.item(i) - item.setSizeHint(QSize(0, height)) - - def selectedItemsText(self): - return [item.text() for item in self.selectedItems()] - -class OrderableListWidget(QWidget): - sigEnterEvent = Signal(object) - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self._labels = [] - - def setParentItem(self, item): - self._item = item - - def setLabelsColor(self, selected): - if selected: - stylesheet = 'color : black' - else: - stylesheet = '' - - for label in self._labels: - label.setStyleSheet(stylesheet) - - def enterEvent(self, event): - super().enterEvent(event) - self.setLabelsColor(True) - self.sigEnterEvent.emit(self._item) - - # def leaveEvent(self, event): - # super().leaveEvent(event) - # self.setLabelsColor(self._item.isSelected()) - # printl('leave', self._item.isSelected()) - - def addLabel(self, label): - self._labels.append(label) - -class OrderableList(listWidget): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.setMouseTracking(True) - self.itemEntered.connect(self.onItemEntered) - - def onItemEntered(self, enteredItem): - enteredRow = self.row(enteredItem) - for i in range(self.count()): - item = self.item(i) - item._container.setLabelsColor(i == enteredRow or item.isSelected()) - - def leaveEvent(self, event): - super().leaveEvent(event) - for i in range(self.count()): - item = self.item(i) - item._container.setLabelsColor(item.isSelected()) - - def addItems(self, items): - self.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) - nr_items = len(items) - nn = [str(n) for n in range(1, nr_items+1)] - for i, item in enumerate(items): - itemW = QListWidgetItem() - itemContainer = OrderableListWidget() - itemContainer.setParentItem(itemW) - itemText = QLabel(item) - tableNrLabel = QLabel('| Table nr.') - itemContainer.addLabel(tableNrLabel) - itemContainer.addLabel(itemText) - itemLayout = QHBoxLayout() - itemNumberWidget = QComboBox() - itemNumberWidget.addItems(nn) - itemLayout.addWidget(itemText) - itemLayout.addWidget(tableNrLabel) - itemLayout.addWidget(itemNumberWidget) - itemContainer.setLayout(itemLayout) - itemLayout.setSizeConstraint(QLayout.SizeConstraint.SetFixedSize) - itemW.setSizeHint(itemContainer.sizeHint()) - self.addItem(itemW) - self.setItemWidget(itemW, itemContainer) - itemW._text = item - itemW._nrWidget = itemNumberWidget - itemW._container = itemContainer - itemNumberWidget.setDisabled(True) - itemNumberWidget.textActivated.connect(self.onTextActivated) - itemNumberWidget._currentNr = 1 - itemNumberWidget.row = i - itemContainer.sigEnterEvent.connect(self.onItemEntered) - - self.itemSelectionChanged.connect(self.onItemSelectionChanged) - - def keyPressEvent(self, event) -> None: - if event.key() == Qt.Key_Escape: - self.clearSelection() - event.ignore() - return - super().keyPressEvent(event) - - def updateNr(self): - for i in range(self.count()): - item = self.item(i) - item._currentNr = int(item._nrWidget.currentText()) - - def onItemSelectionChanged(self): - for i in range(self.count()): - item = self.item(i) - item._container.setLabelsColor(item.isSelected()) - item._nrWidget.setDisabled(not item.isSelected()) - if item._nrWidget.currentText() != '1': - item._nrWidget.setCurrentText('1') - item._currentNr = 1 - - for i, item in enumerate(self.selectedItems()): - item._nrWidget.setCurrentText(f'{i+1}') - item._currentNr = i+1 - - def onTextActivated(self, text): - changedNr = self.sender()._currentNr - for item in self.selectedItems(): - row = self.row(item) - if self.sender().row == row: - changedNr = item._currentNr - continue - - for item in self.selectedItems(): - row = self.row(item) - if self.sender().row == row: - continue - nr = int(item._nrWidget.currentText()) - if nr == int(text): - item._nrWidget.setCurrentText(str(changedNr)) - break - - self.updateNr() - - -class TreeWidget(QTreeWidget): - def __init__(self, *args, multiSelection=False): - super().__init__(*args) - self.setStyleSheet(TREEWIDGET_STYLESHEET) - self.setFont(font) - if multiSelection: - self.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) - self.itemClicked.connect(self.selectAllChildren) - - self.isCtrlDown = False - self.isShiftDown = False - - def keyPressEvent(self, ev): - if ev.key() == Qt.Key_Escape: - self.clearSelection() - elif ev.key() == Qt.Key_Control: - self.isCtrlDown = True - elif ev.key() == Qt.Key_Shift: - self.isShiftDown = True - - def keyReleaseEvent(self, ev): - if ev.key() == Qt.Key_Control: - self.isCtrlDown = False - elif ev.key() == Qt.Key_Shift: - self.isShiftDown = False - - def onFocusChanged(self): - self.isCtrlDown = False - self.isShiftDown = False - - def selectAllChildren(self, item_or_label): - label = None - if isinstance(item_or_label, QLabel): - label = item_or_label - else: - item = item_or_label - if item.childCount() == 0: - return - - if label is not None: - if not self.isCtrlDown and not self.isShiftDown: - self.clearSelection() - label.item.setSelected(True) - if self.isShiftDown: - selectionStarted = False - it = QTreeWidgetItemIterator(self) - while it: - item = it.value() - if item is None: - break - if item.isSelected(): - selectionStarted = not selectionStarted - if selectionStarted: - item.setSelected(True) - it += 1 - - for item in self.selectedItems(): - if item.parent() is None: - for i in range(item.childCount()): - item.child(i).setSelected(True) - -class CancelOkButtonsLayout(QHBoxLayout): - def __init__(self, *args, additionalButtons=None): - super().__init__(*args) - - self.cancelButton = cancelPushButton('Cancel') - self.okButton = okPushButton(' Ok ') - - self.addStretch(1) - self.addWidget(self.cancelButton) - self.addSpacing(20) - - if additionalButtons is not None: - for button in additionalButtons: - self.addWidget(button) - - self.addWidget(self.okButton) - -class TreeWidgetItem(QTreeWidgetItem): - def __init__(self, *args, columnColors=None): - super().__init__(*args) - - if columnColors is not None: - for c, color in enumerate(columnColors): - if color is None: - continue - self.setBackground(c, QBrush(color)) - -class FilterObject(QObject): - sigFilteredEvent = Signal(object, object) - - def __init__(self) -> None: - super().__init__() - - def eventFilter(self, object, event): - self.sigFilteredEvent.emit(object, event) - return super().eventFilter(object, event) - -class readOnlyQList(QTextEdit): - def __init__(self, parent=None): - super().__init__(parent) - self.setReadOnly(True) - self.items = [] - - def addItems(self, items): - self.items.extend(items) - items = [str(item) for item in self.items] - columnList = html_utils.paragraph('
    '.join(items)) - self.setText(columnList) - -class pgScatterSymbolsCombobox(QComboBox): - def __init__(self, parent=None): - super().__init__(parent) - - symbols = [ - "'o' circle (default)", - "'s' square", - "'t' triangle", - "'d' diamond", - "'+' plus", - "'t1' triangle pointing upwards", - "'t2' triangle pointing right side", - "'t3' triangle pointing left side", - "'p' pentagon", - "'h' hexagon", - "'star'", - "'x' cross", - "'arrow_up'", - "'arrow_right'", - "'arrow_down'", - "'arrow_left'", - "'crosshair'" - ] - self.addItems(symbols) - - -class alphaNumericLineEdit(QLineEdit): - sigInvalidCharacterPressed = Signal(str) - sigInvalidCharactersEntered = Signal(object) - - def __init__(self, parent=None, additionalChars='', onlyWarn=False): - super().__init__(parent) - self.validPattern = fr'^[a-zA-Z0-9{additionalChars}_\-]+$' - self.invalidPattern = fr'[^a-zA-Z0-9{additionalChars}_\-]' - - if not onlyWarn: - regExp = QRegularExpression(self.validPattern) - self.setValidator(QRegularExpressionValidator(regExp)) - else: - self.textChanged.connect(self.emitInvalidCharactersEntered) - - def emitInvalidCharactersEntered(self, text): - invalidCharacters = self.invalidCharacters() - if not invalidCharacters: - return - - self.sigInvalidCharactersEntered.emit(set(invalidCharacters)) - - def invalidCharacters(self): - return re.findall(fr'{self.invalidPattern}', self.text()) - - def keyPressEvent(self, event: QKeyEvent): - if not event.text(): - return super().keyPressEvent(event) - - if event.modifiers() & ( - Qt.KeyboardModifier.ControlModifier - | Qt.KeyboardModifier.AltModifier - | Qt.KeyboardModifier.MetaModifier - ): - return super().keyPressEvent(event) - - if not event.text().isprintable(): - return super().keyPressEvent(event) - - super().keyPressEvent(event) - - if event.text() in self.text(): - return - - self.sigInvalidCharacterPressed.emit(event.text()) - -class NumericCommaLineEdit(QLineEdit): - def __init__(self, parent=None): - super().__init__(parent) - - self.validPattern = r'^[0-9,\.]+$' - regExp = QRegularExpression(self.validPattern) - self.setValidator(QRegularExpressionValidator(regExp)) - - def values(self): - try: - vals = [float(c) for c in self.text().split(',')] - except Exception as e: - vals = [] - return vals - -class mySpinBox(QSpinBox): - sigTabEvent = Signal(object, object) - - def __init__(self, *args) -> None: - super().__init__(*args) - - def event(self, event): - if event.type()==QEvent.Type.KeyPress and event.key() == Qt.Key_Tab: - self.sigTabEvent.emit(event, self) - return True - - return super().event(event) - -class KeptObjectIDsList(list): - def __init__(self, lineEdit, confirmSelectionAction, *args): - self.lineEdit = lineEdit - self.lineEdit.setText('') - self.confirmSelectionAction = confirmSelectionAction - confirmSelectionAction.setDisabled(True) - super().__init__(*args) - - def setText(self): - text = myutils.format_IDs(self) - - self.lineEdit.setText(text) - - def append(self, element, editText=True): - super().append(element) - if editText: - self.setText() - if not self.confirmSelectionAction.isEnabled(): - self.confirmSelectionAction.setEnabled(True) - - def remove(self, element, editText=True): - super().remove(element) - if editText: - self.setText() - if not self: - self.confirmSelectionAction.setEnabled(False) - -class ScatterPlotItem(pg.ScatterPlotItem): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.updateBrushAndPen(**kwargs) - - def updateBrushAndPen(self, **kwargs): - brush = kwargs.get('brush') - if brush is not None: - self._itemBrush = brush - pen = kwargs.get('pen') - if pen is not None: - self._itemPen = pen - - def setData(self, *args, **kwargs): - super().setData(*args, **kwargs) - self.updateBrushAndPen(**kwargs) - - def itemBrush(self): - return self._itemBrush - - def itemPen(self): - return self._itemPen - - def removePoint(self, index): - newData = np.delete(self.data, index) - # Update the index of current points - for i in range(index, len(newData)): - spotItem = newData[i]['item'] - spotItem._index = i - newData[i]['item'] = spotItem - - self.data = newData - self.prepareGeometryChange() - self.informViewBoundsChanged() - self.bounds = [None, None] - self.invalidate() - self.updateSpots(newData) - self.sigPlotChanged.emit(self) - - def coordsToNumpy(self, includeData=False, rounded=True, decimals=None): - points = self.points() - nrows = len(points) - coords_arr = np.zeros((nrows, 2)) - data_arr = None - for p, point in enumerate(points): - pos = point.pos() - x, y = pos.x(), pos.y() - if includeData: - data = point.data() - if data_arr is None: - try: - ncols = len(data) - except Exception as e: - data = [data] - ncols = 1 - data_arr = np.zeros((nrows, ncols)) - for j, data_j in enumerate(data): - data_arr[p, j] = data_j - - coords_arr[p, 0] = y - coords_arr[p, 1] = x - if not includeData: - out_arr = coords_arr - elif data_arr is not None: - out_arr = np.column_stack((data_arr, coords_arr)) - else: - out_arr = coords_arr - cast_to_int = decimals is None - decimals = decimals if decimals is not None else 0 - if rounded: - out_arr = np.round(out_arr, decimals) - if cast_to_int: - out_arr = out_arr.astype(int) - return out_arr - -class myLabelItem(pg.LabelItem): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._prevText = '' - - def setText(self, text, **args): - self.text = text - opts = self.opts - for k in args: - opts[k] = args[k] - - if 'size' in self.opts: - size = self.opts['size'] - if size == '0pt' or size == '0px': - self.opts['size'] = '1pt' - super().setText('', size='1pt') - return - - optlist = [] - - color = self.opts['color'] - if color is None: - color = pg.getConfigOption('foreground') - color = pg.functions.mkColor(color) - optlist.append('color: ' + color.name(QColor.NameFormat.HexArgb)) - if 'size' in opts: - size = opts['size'] - if not isinstance(size, str): - size = f'{size}px' - optlist.append('font-size: ' + size) - if 'bold' in opts and opts['bold'] in [True, False]: - optlist.append('font-weight: ' + {True:'bold', False:'normal'}[opts['bold']]) - if 'italic' in opts and opts['italic'] in [True, False]: - optlist.append('font-style: ' + {True:'italic', False:'normal'}[opts['italic']]) - full = "%s" % ('; '.join(optlist), text) - #print full - self.item.setHtml(full) - self.updateMin() - self.resizeEvent(None) - self.updateGeometry() - - def tempClearText(self): - if self.text: - self._prevText = self.text - self.setText('') - - def restoreText(self): - if self._prevText: - self.setText(self._prevText) - - -class myMessageBox(_base_widgets.QBaseDialog): - def __init__( - self, parent=None, showCentered=True, wrapText=True, - scrollableText=False, enlargeWidthFactor=0, - resizeButtons=True, allowClose=True - ): - super().__init__(parent) - - self.wrapText = wrapText - self.enlargeWidthFactor = enlargeWidthFactor - self.resizeButtons = resizeButtons - - self.cancel = True - self.cancelButton = None - self.okButton = None - self.clickedButton = None - self.alreadyShown = False - self.allowClose = allowClose - - self.showCentered = showCentered - - self.scrollableText = scrollableText - - self._layout = QGridLayout() - self.commandsLayout = None - self._layout.setHorizontalSpacing(20) - self.buttonsLayout = QHBoxLayout() - self.buttonsLayout.setSpacing(2) - self.buttons = [] - self.widgets = [] - self.layouts = [] - self.labels = [] - self.labelsWidgets = [] - self._pixmapLabels = [] - self.detailsTextWidget = None - self.showInFileManagButton = None - self.visibleDetails = False - self.doNotShowAgainCheckbox = None - - self.currentRow = 0 - self.textWidget = None - self._w = None - - self.textLayout = QVBoxLayout() - - self._layout.setColumnStretch(1, 1) - self.setLayout(self._layout) - - self.setFont(font) - - def mousePressEvent(self, event): - for label in self.labels: - label.setTextInteractionFlags( - Qt.TextBrowserInteraction | Qt.TextSelectableByKeyboard - ) - - def setIcon(self, iconName='SP_MessageBoxInformation'): - label = QLabel(self) - - standardIcon = getattr(QStyle, iconName) - icon = self.style().standardIcon(standardIcon) - pixmap = icon.pixmap(60, 60) - label.setPixmap(pixmap) - - self._layout.addWidget(label, 0, 0, alignment=Qt.AlignTop) - - def addImage(self, image_path): - pixmap = QPixmap(image_path) - label = QLabel() - label.setPixmap(pixmap) - self._layout.addWidget(label, self.currentRow, 1) - self.currentRow += 1 - - def addShowInFileManagerButton(self, path, txt=None): - if txt is None: - txt = 'Reveal in Finder...' if is_mac else 'Show in Explorer...' - self.showInFileManagButton = showInFileManagerButton(txt) - self.buttonsLayout.addWidget(self.showInFileManagButton) - func = partial(myutils.showInExplorer, path) - self.showInFileManagButton.clicked.connect(func) - - def addBrowseUrlButton(self, url, button_text=''): - self.openUrlButton = OpenUrlButton(url, button_text) - self.buttonsLayout.addWidget(self.openUrlButton) - - def addCancelButton(self, button=None, connect=False): - if button is None: - self.cancelButton = cancelPushButton('Cancel') - else: - self.cancelButton = button - self.cancelButton.setIcon(QIcon(':cancelButton.svg')) - - self.buttonsLayout.insertWidget(0, self.cancelButton) - self.buttonsLayout.insertSpacing(1, 20) - if connect: - self.cancelButton.clicked.connect(self.buttonCallBack) - - def splitLatexBlocks(self, text): - texts = re.split(r"(.+?)", text) - return texts - - def splitCopiableBlocks(self, texts: Sequence[str] | str): - if isinstance(texts, str): - texts = (texts,) - - texts_out = [] - for text in texts: - texts_out.extend(re.split(r"(.+?)", text)) - return texts_out - - def addText(self, text): - texts = self.splitLatexBlocks(text) - texts = self.splitCopiableBlocks(texts) - - labelsWidget = LabelsWidget(texts, wrapText=self.wrapText) - self.labelsWidgets.append(labelsWidget) - self.labels.extend(labelsWidget.labels) - if self.scrollableText: - textWidget = QScrollArea() - textWidget.setFrameStyle(QFrame.Shape.NoFrame) - textWidget.setWidget(labelsWidget) - else: - textWidget = labelsWidget - - self.textLayout.addWidget(textWidget) - - if self.textWidget is None: - self.textWidget = QWidget() - self.textWidget.setLayout(self.textLayout) - self._layout.addWidget(self.textWidget, self.currentRow, 1) - self.textRow = self.currentRow - self.currentRow += 1 - - return labelsWidget - - def addCopiableCommand(self, command): - copiableCommandWidget = CopiableCommandWidget(command) - screenWidth = self.screen().size().width() - maxWidth = int(0.75*screenWidth) - sizeHint = copiableCommandWidget.sizeHint() - width = sizeHint.width() - if width > maxWidth: - copiableCommandWidget = addWidgetToScrollArea( - copiableCommandWidget, - resizeMinHeightNoVerticalScrollbar=True - ) - self._layout.addWidget(copiableCommandWidget, self.currentRow, 1) - self.currentRow += 1 - - def copyToClipboard(self): - cb = QApplication.clipboard() - cb.clear(mode=cb.Clipboard) - cb.setText(self.sender()._command, mode=cb.Clipboard) - print('Command copied!') - - def addButton(self, buttonText): - if not isinstance(buttonText, str): - # Passing button directly - button = buttonText - self.buttonsLayout.addWidget(button) - button.clicked.connect(self.buttonCallBack) - self.buttons.append(button) - return button - - button, isCancelButton = getPushButton(buttonText, qparent=self) - if not isCancelButton: - self.buttonsLayout.addWidget(button) - - button.clicked.connect(self.buttonCallBack) - self.buttons.append(button) - return button - - def addDoNotShowAgainCheckbox(self, text='Do not show again'): - self.doNotShowAgainCheckbox = QCheckBox(text) - - def addWidget(self, widget): - self._layout.addWidget(widget, self.currentRow, 1) - self.widgets.append(widget) - self.currentRow += 1 - - def addLayout(self, layout): - self._layout.addLayout(layout, self.currentRow, 1) - self.layouts.append(layout) - self.currentRow += 1 - - def setWidth(self, w): - self._w = w - - def show(self, block=False): - self.endOfScrollableRow = self.currentRow - - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - # spacer - spacer = QSpacerItem(10, 10) - self._layout.addItem(spacer, self.currentRow, 1) - self._layout.setRowStretch(self.currentRow, 0) - - # buttons - self.currentRow += 1 - - if self.detailsTextWidget is not None: - self.buttonsLayout.insertWidget(1, self.detailsButton) - - # Do not show again checkbox - if self.doNotShowAgainCheckbox is not None: - self._layout.addWidget( - self.doNotShowAgainCheckbox, self.currentRow, 1, 1, 2 - ) - self.currentRow += 1 - - # spacer - self._layout.addItem(QSpacerItem(10, 10), self.currentRow, 1) - self.currentRow += 1 - - # buttons - self._layout.addLayout( - self.buttonsLayout, self.currentRow, 0, 1, 2, - alignment=Qt.AlignRight - ) - - # Details - if self.detailsTextWidget is not None: - # spacer - self.currentRow += 1 - self._layout.addItem(QSpacerItem(20, 20), self.currentRow, 1) - - # detailsTextWidget - self.currentRow += 1 - self._layout.addWidget( - self.detailsTextWidget, self.currentRow, 0, 1, 2 - ) - - # spacer - self.currentRow += 1 - spacer = QSpacerItem(10, 10) - self._layout.addItem(spacer, self.currentRow, 1) - self._layout.setRowStretch(self.currentRow, 0) - - screenHeight = self.screen().size().height() - dialogHeight = self.sizeHint().height() - dialogWidth = self.sizeHint().width() - screenWidth = self.screen().size().width() - - # Check if scrollbar is needed - if dialogHeight > screenHeight and self.textWidget is not None: - textScrollArea = ScrollArea() - textScrollArea.setWidget(self.textWidget) - scrollAreaWidthNoSB = textScrollArea.minimumWidthNoScrollbar() - scrollAreaWidth = textScrollArea.sizeHint().width() - desiredDeltaWidth = scrollAreaWidthNoSB - scrollAreaWidth - if desiredDeltaWidth > 0: - desiredWidth = dialogWidth + desiredDeltaWidth - if desiredWidth < screenWidth: - self._w = desiredWidth - - self._layout.removeWidget(self.textWidget) - self._layout.addWidget(textScrollArea, self.textRow, 1) - - super().show() - QTimer.singleShot(5, self._resize) - - self.alreadyShown = True - - if block: - self._block() - - def setDetailedText(self, text, visible=False, wrap=True): - text = text.replace('\n', '
    ') - self.detailsTextWidget = QTextEdit(text) - self.detailsTextWidget.setReadOnly(True) - if not wrap: - self.detailsTextWidget.setLineWrapMode(QTextEdit.NoWrap) - self.detailsButton = showDetailsButton() - self.detailsButton.setCheckable(True) - self.detailsButton.clicked.connect(self._showDetails) - self.detailsTextWidget.hide() - self.visibleDetails = visible - - def _showDetails(self, checked): - if checked: - self.origHeight = self.height() - self.resize(self.width(), self.height()+300) - self.detailsTextWidget.show() - else: - self.detailsTextWidget.hide() - func = partial(self.resize, self.width(), self.origHeight) - QTimer.singleShot(10, func) - - def _resize(self): - if self.resizeButtons: - widths = [button.width() for button in self.buttons] - if widths: - max_width = max(widths) - for button in self.buttons: - if button == self.cancelButton: - continue - button.setMinimumWidth(max_width) - - heights = [button.height() for button in self.buttons] - if heights: - max_h = max(heights) - for button in self.buttons: - button.setMinimumHeight(max_h) - if self.detailsTextWidget is not None: - self.detailsButton.setMinimumHeight(max_h) - if self.showInFileManagButton is not None: - self.showInFileManagButton.setMinimumHeight(max_h) - - if self._w is not None and self.width() < self._w: - self.resize(self._w, self.height()) - - if self.width() < 350: - self.resize(350, self.height()) - - if self.enlargeWidthFactor > 0: - self.resize(int(self.width()*self.enlargeWidthFactor), self.height()) - - if self.visibleDetails: - self.detailsButton.click() - - if self.showCentered: - screen = self.screen() - screenWidth = screen.size().width() - screenHeight = screen.size().height() - screenLeft = screen.geometry().x() - screenTop = screen.geometry().y() - w, h = self.width(), self.height() - left = int(screenLeft + screenWidth/2 - w/2) - top = int(screenTop + screenHeight/2 - h/2) - if top < screenTop: - top = screenTop - if left < screenLeft: - left = screenLeft - self.move(left, top) - - self._h = self.height() - - if self.okButton is not None: - self.okButton.setFocus() - - screen = self.screen() - screenWidth = screen.size().width() - screenHeight = screen.size().height() - - # Check Force wrap Text - for labelWidget in self.labelsWidgets: - textWidth = labelWidget.width() - if not textWidth > screenWidth-10: - continue - factor = np.ceil(textWidth/screenWidth) - lineLength = int(labelWidget.nCharsLongestLine/factor) - for label in labelWidget.labels: - if isinstance(label, CopiableCommandWidget): - continue - - text = label.text() - chunks = textwrap.wrap(text, lineLength) - text = '
    '.join(chunks) - label.setText(text) - - QTimer.singleShot(100, self._resizeWrappedText) - - if self.widgets: - return - - if self.layouts: - return - - # # Start resizing height every 1 ms - # self.resizeCallsCount = 0 - # self.timer = QTimer() - # from config import warningHandler - # warningHandler.sigGeometryWarning.connect(self.timer.stop) - # self.timer.timeout.connect(self._resizeHeight) - # self.timer.start(1) - - def _resizeWrappedText(self): - screenWidth = self.screen().size().width() - 5 - self.resize(screenWidth, self.height()) - screenLeft = self.screen().geometry().left() - self.move(screenLeft, self.geometry().top()) - - def _resizeHeight(self): - try: - # Resize until a "Unable to set geometry" warning is captured - # by copnfig.warningHandler._resizeWarningHandler or # - # height doesn't change anymore - self.resize(self.width(), self.height()-1) - if self.height() == self._h or self.resizeCallsCount > 100: - self.timer.stop() - return - - self.resizeCallsCount += 1 - self._h = self.height() - except Exception as e: - # traceback.format_exc() - self.timer.stop() - - def _template( - self, parent, title, message, detailsText=None, - buttonsTexts=None, layouts=None, widgets=None, - commands=None, path_to_browse=None, browse_button_text=None, - url_to_open=None, open_url_button_text='Open url', - image_paths=None, wrapDetails=True, - add_do_not_show_again_checkbox=False - ): - if parent is not None: - self.setParent(parent) - self.setWindowTitle(title) - self.addText(message) - if commands is not None: - if isinstance(commands, str): - commands = (commands,) - for command in commands: - self.addCopiableCommand(command) - - if image_paths is not None: - if isinstance(image_paths, str): - image_paths = (image_paths,) - for image_path in image_paths: - self.addImage(image_path) - - if layouts is not None: - if myutils.is_iterable(layouts): - for layout in layouts: - self.addLayout(layout) - else: - self.addLayout(layout) - - if widgets is not None: - self._layout.addItem(QSpacerItem(20, 20), self.currentRow, 1) - self.currentRow += 1 - if myutils.is_iterable(widgets): - for widget in widgets: - self.addWidget(widget) - else: - self.addWidget(widgets) - - if path_to_browse is not None: - self.addShowInFileManagerButton( - path_to_browse, txt=browse_button_text - ) - - if url_to_open is not None: - self.addBrowseUrlButton( - url_to_open, button_text=open_url_button_text - ) - - buttons = [] - if buttonsTexts is None: - okButton = self.addButton(' Ok ') - buttons.append(okButton) - elif isinstance(buttonsTexts, str): - button = self.addButton(buttonsTexts) - buttons.append(button) - else: - for buttonText in buttonsTexts: - button = self.addButton(buttonText) - buttons.append(button) - - if detailsText is not None: - self.setDetailedText(detailsText, visible=True, wrap=wrapDetails) - - if add_do_not_show_again_checkbox: - self.addDoNotShowAgainCheckbox() - - return buttons - - def critical(self, *args, showDialog=True, **kwargs): - self.setIcon(iconName='SP_MessageBoxCritical') - buttons = self._template(*args, **kwargs) - if showDialog: - self.exec_() - return buttons - - def information(self, *args, showDialog=True, **kwargs): - self.setIcon(iconName='SP_MessageBoxInformation') - buttons = self._template(*args, **kwargs) - if showDialog: - self.exec_() - return buttons - - def warning(self, *args, showDialog=True, **kwargs): - self.setIcon(iconName='SP_MessageBoxWarning') - buttons = self._template(*args, **kwargs) - if showDialog: - self.exec_() - return buttons - - def question(self, *args, showDialog=True, **kwargs): - self.setIcon(iconName='SP_MessageBoxQuestion') - buttons = self._template(*args, **kwargs) - if showDialog: - self.exec_() - return buttons - - def _block(self): - self.loop = QEventLoop() - self.loop.exec_() - - def exec_(self): - self.show(block=True) - - def clickButtonFromText(self, buttonText): - for button in self.buttons: - if button.text() == buttonText: - button.click() - return - - def buttonCallBack(self, checked=True): - self.clickedButton = self.sender() - if self.clickedButton != self.cancelButton: - self.cancel = False - self.allowClose = True - self.close() - - def closeEvent(self, event): - if not self.allowClose: - event.ignore() - return - super().closeEvent(event) - -class FormLayout(QGridLayout): - def __init__(self): - QGridLayout.__init__(self) - - def addFormWidget( - self, formWidget, - leftLabelAlignment=Qt.AlignRight, - align=None, - row=0 - ): - for col, item in enumerate(formWidget.items): - if col==0: - alignment = leftLabelAlignment - elif col==2: - alignment = Qt.AlignLeft - else: - alignment = align - try: - if alignment is None: - self.addWidget(item, row, col) - else: - self.addWidget(item, row, col, alignment=alignment) - except TypeError: - self.addLayout(item, row, col) - -def macShortcutToWindows(shortcut: str): - if shortcut is None: - return - - s = (shortcut - .replace('Control', 'Meta') - .replace('Option', 'Alt') - .replace('Command', 'Ctrl') - ) - return s - -def windowsShortcutToMac(shortcut: str): - if shortcut is None: - return - - if not is_mac: - return shortcut - - s = (shortcut - .replace('Meta', 'Control') - .replace('Alt', 'Option') - .replace('Ctrl', 'Command') - ) - return s - -class ToolBarSeparator: - def __init__(self, width=5, toolbar: QToolBar=None): - self._parts = ( - QHWidgetSpacer(width=width), - QVLine(), - QHWidgetSpacer(width=width) - ) - self._actions = [] - self._toolbar = None - if toolbar is not None: - self.addToToolbar(toolbar) - - def addToToolbar(self, toolbar): - self._toolbar = toolbar - for part in self._parts: - action = toolbar.addWidget(part) - self._actions.append(action) - - def removeFromToolbar(self): - if self._toolbar is None: - return - - for action in self._actions: - self._toolbar.removeAction(action) - -class ToolBar(QToolBar): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - self.widgetsWithShortcut = {} - - for child in self.children(): - if child.objectName() == 'qt_toolbar_ext_button': - self.extendButton = child - self.extendButton.setIcon(QIcon(":expand.svg")) - break - - def addSeparator(self, width=5): - separator = ToolBarSeparator(width=width, toolbar=self) - return separator - - def removeSeparator(self, separator): - separator.removeFromToolbar() - - def addSpinBox(self, label=''): - spinbox = SpinBox(disableKeyPress=True) - if label: - spinbox.label = QLabel(label) - spinbox.labelAction = self.addWidget(spinbox.label) - - spinbox.action = self.addWidget(spinbox) - return spinbox - - def addButton(self, icon_str: str, text='', checkable=False): - action = QAction(QIcon(icon_str), text, self) - action.setCheckable(checkable) - self.addAction(action) - return action - - def addComboBox(self, items=None, label=''): - combobox = ComboBox() - - if items is not None: - combobox.addItems(items) - - if label: - combobox.label = QLabel(label) - combobox.labelAction = self.addWidget(combobox.label) - - combobox.action = self.addWidget(combobox) - return combobox - - def addLabel(self, text=''): - label = QLabel(text) - label.action = self.addWidget(label) - return label - - def addCheckBox(self, text='', checked=False): - checkbox = QCheckBox(text) - checkbox.setChecked(checked) - checkbox.action = self.addWidget(checkbox) - return checkbox - -class ManualTrackingToolBar(ToolBar): - sigIDchanged = Signal(int) - sigDisableGhost = Signal() - sigClearGhostContour = Signal() - sigClearGhostMask = Signal() - sigGhostOpacityChanged = Signal(int) - - def __init__(self, *args) -> None: - super().__init__(*args) - self.spinboxID = self.addSpinBox(label='ID to track: ') - self.spinboxID.setMinimum(1) - - self.addSeparator() - - self.showGhostCheckbox = QCheckBox('Show ghost object') - self.showGhostCheckbox.setChecked(True) - self.addWidget(self.showGhostCheckbox) - - self.ghostContourRadiobutton = QRadioButton('Contour') - self.ghostMaskRadiobutton = QRadioButton('Mask ; ') - self.ghostMaskRadiobutton.setChecked(True) - self.addWidget(self.ghostContourRadiobutton) - self.addWidget(self.ghostMaskRadiobutton) - - self.ghostMaskOpacitySpinbox = self.addSpinBox('Mask opacity: ') - self.ghostMaskOpacitySpinbox.setMaximum(100) - self.ghostMaskOpacitySpinbox.setValue(30) - - self.showGhostCheckbox.toggled.connect(self.showGhostCheckboxToggled) - self.ghostContourRadiobutton.toggled.connect( - self.ghostContourRadiobuttonToggled - ) - self.spinboxID.valueChanged.connect(self.IDchanged) - - self.ghostMaskOpacitySpinbox.valueChanged.connect( - self.ghostOpacityValueChanged - ) - - self.addSeparator() - - self.infoLabel = QLabel('') - self.addWidget(self.infoLabel) - - def showInfo(self, text): - text = html_utils.paragraph(text, font_color='black') - self.infoLabel.setText(text) - - def showWarning(self, text): - text = html_utils.paragraph(f'WARNING: {text}', font_color='red') - self.infoLabel.setText(text) - - def clearInfoText(self): - self.infoLabel.setText('') - - def IDchanged(self, value): - self.sigIDchanged.emit(value) - - def showGhostCheckboxToggled(self, checked): - disabled = not checked - self.ghostContourRadiobutton.setDisabled(disabled) - self.ghostMaskRadiobutton.setDisabled(disabled) - self.ghostMaskOpacitySpinbox.setDisabled(disabled) - self.ghostMaskOpacitySpinbox.label.setDisabled(disabled) - if disabled: - self.sigDisableGhost.emit() - - def ghostContourRadiobuttonToggled(self, checked): - self.ghostMaskOpacitySpinbox.setDisabled(checked) - self.ghostMaskOpacitySpinbox.label.setDisabled(checked) - if checked: - self.sigClearGhostMask.emit() - else: - self.sigClearGhostContour.emit() - - def ghostOpacityValueChanged(self, value): - self.sigGhostOpacityChanged.emit(value) - -class CopyLostObjectToolbar(ToolBar): - sigCopyAllObjects = Signal(int, int) - - def __init__(self, *args) -> None: - super().__init__(*args) - - action = self.addButton(':copyContour_all.svg') - # action.setShortcut('Alt+C') - action.keyPressShortcut = KeySequenceFromText('Alt+C') - action.setToolTip( - 'Copy all lost objects\n\n' - 'Shortcut: Alt+C' - ) - self.widgetsWithShortcut['Copy all lost objects'] = action - - action.triggered.connect(self.emitSigCopyAllObjects) - - self.addSeparator() - - self.maxOverlapNumberControl = self.addSpinBox( - label='Maximum overlap to accept lost object [%]: ' - ) - self.maxOverlapNumberControl.setMinimum(0) - self.maxOverlapNumberControl.setValue(10) - tooltip = ( - 'Maximum overlap to accept lost object [%]\n\n' - 'If the overlap between the lost object and an object already ' - 'existing is greater than this value,\n' - 'the lost object will not be added.' - ) - self.maxOverlapNumberControl.setToolTip(tooltip) - self.maxOverlapNumberControl.label.setToolTip(tooltip) - - self.addSeparator() - - self.untilFrameNumberControl = self.addSpinBox( - label='Copy lost object(s) for the next number of frames: ' - ) - self.untilFrameNumberControl.setMinimum(0) - self.untilFrameNumberControl.setValue(0) - - def emitSigCopyAllObjects(self): - self.sigCopyAllObjects.emit( - self.untilFrameNumberControl.value(), - self.maxOverlapNumberControl.value() - ) - -class DrawClearRegionToolbar(ToolBar): - def __init__(self, *args) -> None: - super().__init__(*args) - - group = QButtonGroup() - group.setExclusive(True) - self.clearTouchingObjsRadioButton = QRadioButton( - 'Clear all touching objects' - ) - self.clearOnlyEnclosedObjsRadioButton = QRadioButton( - 'Clear only fully enclosed objects' - ) - self.clearOnlyEnclosedObjsRadioButton.setChecked(True) - group.addButton(self.clearTouchingObjsRadioButton) - group.addButton(self.clearOnlyEnclosedObjsRadioButton) - - self.addWidget(self.clearTouchingObjsRadioButton) - self.addWidget(self.clearOnlyEnclosedObjsRadioButton) - - self.addSeparator() - - self.numZslicesUpSpinbox = self.addSpinBox( - label='Num. of z-slices to clear upwards: ' - ) - self.numZslicesUpSpinbox.setMinimum(0) - self.numZslicesUpSpinbox.setValue(0) - - self.numZslicesDownSpinbox = self.addSpinBox( - label='Num. of z-slices to clear downwards: ' - ) - self.numZslicesDownSpinbox.setMinimum(0) - self.numZslicesDownSpinbox.setValue(0) - - def setZslicesControlEnabled(self, enabled, SizeZ=None): - self.numZslicesUpSpinbox.labelAction.setVisible(enabled) - self.numZslicesUpSpinbox.action.setVisible(enabled) - - self.numZslicesDownSpinbox.labelAction.setVisible(enabled) - self.numZslicesDownSpinbox.action.setVisible(enabled) - - if SizeZ is None: - return - - self.numZslicesUpSpinbox.setMaximum(SizeZ) - self.numZslicesDownSpinbox.setMaximum(SizeZ) - - def zRange(self, z_slice, SizeZ): - if z_slice is None: - zRange = (0, SizeZ) - return zRange - - numZslicesUp = self.numZslicesUpSpinbox.value() - numZslicesDown = self.numZslicesDownSpinbox.value() - - zmin = z_slice - numZslicesDown - zmax = z_slice + numZslicesDown + 1 - - zmin = zmin if zmin >= 0 else 0 - zmax = zmax if zmax <= SizeZ else SizeZ - - return (zmin, zmax) - -class ManualBackgroundToolBar(ToolBar): - sigIDchanged = Signal(int) - - def __init__(self, *args) -> None: - super().__init__(*args) - self.spinboxID = self.addSpinBox(label='Set background of ID ') - self.spinboxID.setMinimum(1) - self.spinboxID.valueChanged.connect(self.IDchanged) - - self.infoLabel = QLabel('') - self.addWidget(self.infoLabel) - - def IDchanged(self, value): - self.sigIDchanged.emit(value) - - def showWarning(self, text): - text = html_utils.paragraph(f'WARNING: {text}', font_color='red') - self.infoLabel.setText(text) - - def clearInfoText(self): - self.infoLabel.setText('') - - -class rightClickToolButton(QToolButton): - sigRightClick = Signal(object) - sigLeftClick = Signal(object, object) - - def __init__(self, parent=None): - super().__init__(parent) - - def mousePressEvent(self, event): - if event.button() == Qt.MouseButton.LeftButton: - super().mousePressEvent(event) - self.sigLeftClick.emit(self, event) - elif event.button() == Qt.MouseButton.RightButton: - self.sigRightClick.emit(event) - -class SavePointsLayerButton(rightClickToolButton): - sigRenameTableAction = Signal(object, str) - - def __init__(self, table_endname, parent=None): - super().__init__(parent=parent) - self.setIcon(QIcon(':file-save.svg')) - - self.table_endname = table_endname - - self.setToolTip( - "Save annotated points in the CSV file ending " - f"with '{self.table_endname}.csv'" - ) - - self.sigRightClick.connect(self.showContextMenu) - - def showContextMenu(self, event): - contextMenu = QMenu(self) - contextMenu.addSeparator() - - renameAction = QAction('Rename points layer table') - renameAction.triggered.connect(self.renameTable) - contextMenu.addAction(renameAction) - - contextMenu.exec(event.globalPos()) - - def renameTable(self): - win = apps.filenameDialog( - parent=self, - title='Rename points layer table file', - allowEmpty=False, - defaultEntry=self.table_endname, - ext='.csv', - ) - win.exec_() - if win.cancel: - return - - self.table_endname = win.entryText - self.setToolTip( - "Save annotated points in the CSV file ending " - f"with '{self.table_endname}.csv'" - ) - self.sigRenameTableAction.emit(self, self.table_endname) - -class ToolButtonCustomColor(rightClickToolButton): - def __init__(self, symbol, color='r', parent=None): - super().__init__(parent=parent) - if not isinstance(color, QColor): - color = pg.mkColor(color) - self.symbol = symbol - self.setColor(color) - - def setColor(self, color): - self.penColor = color - self.brushColor = [0, 0, 0, 100] - self.brushColor[:3] = color.getRgb()[:3] - - def updateSymbol(self, symbol, update=True): - self.symbol = symbol - if not update: - return - self.update() - - def updateColor(self, color, update=True): - self.setColor(color) - if not update: - return - self.update() - - def updateIcon(self, symbol, color): - self.updateSymbol(symbol) - self.updateColor(color) - self.update() - - def paintEvent(self, event): - QToolButton.paintEvent(self, event) - p = QPainter(self) - w, h = self.width(), self.height() - sf = 0.6 - p.scale(w*sf, h*sf) - p.translate(0.5/sf, 0.5/sf) - symbol = pg.graphicsItems.ScatterPlotItem.Symbols[self.symbol] - pen = pg.mkPen(color=self.penColor, width=2) - brush = pg.mkBrush(color=self.brushColor) - try: - p.setRenderHint(QPainter.RenderHint.Antialiasing) - p.setPen(pen) - p.setBrush(brush) - p.drawPath(symbol) - except Exception as e: - traceback.print_exc() - finally: - p.end() - -class GradientToolButton(rightClickToolButton): - def __init__(self, colors=((255, 0, 0),), parent=None): - super().__init__(parent=parent) - self._qcolors = [pg.mkColor(c) for c in colors] - if len(self._qcolors) < 2: - self._qcolors.append(self._qcolors[0]) - - def paintEvent(self, event): - super().paintEvent(event) - - painter = QPainter(self) - painter.setRenderHint(QPainter.Antialiasing) - - pen = pg.mkPen(color=self._qcolors[-1], width=2) - - pad = 7 - - rect = self.rect().adjusted(pad, pad, -pad, -pad) # A little padding - - # Gradient: bottom to top - gradient = QLinearGradient( - QPointF(rect.bottomLeft()), QPointF(rect.topLeft()) - ) - - # Set color stops evenly distributed - num_colors = len(self._qcolors) - for i, color in enumerate(self._qcolors): - gradient.setColorAt(i / (num_colors - 1), color) - - if not self.isChecked(): - painter.setOpacity(0.4) - - painter.setBrush(gradient) - painter.setPen(pen) - painter.drawRect(rect) - - painter.end() - -class PointsLayerToolButton(ToolButtonCustomColor): - sigEditAppearance = Signal(object) - sigShowIdsToggled = Signal(object, bool) - sigRemove = Signal(object) - - def __init__(self, symbol, color='r', parent=None): - super().__init__(symbol, color=color, parent=parent) - self.sigRightClick.connect(self.showContextMenu) - - def showContextMenu(self, event): - contextMenu = QMenu(self) - contextMenu.addSeparator() - - editAction = QAction('Edit points appearance...') - editAction.triggered.connect(self.editAppearance) - contextMenu.addAction(editAction) - - removeAction = QAction('Remove points') - removeAction.triggered.connect(self.emitRemove) - contextMenu.addAction(removeAction) - - showIdsAction = QAction('Show point ids') - showIdsAction.setCheckable(True) - showIdsAction.setChecked(True) - contextMenu.addAction(showIdsAction) - showIdsAction.toggled.connect(self.emitShowIdsToggled) - - contextMenu.exec(event.globalPos()) - - def emitRemove(self): - self.sigRemove.emit(self) - - def emitShowIdsToggled(self, checked): - self.sigShowIdsToggled.emit(self, checked) - - def editAppearance(self): - self.sigEditAppearance.emit(self) - -class customAnnotToolButton(ToolButtonCustomColor): - sigRemoveAction = Signal(object) - sigKeepActiveAction = Signal(object) - sigModifyAction = Signal(object) - sigHideAction = Signal(object) - - def __init__( - self, symbol, color, keepToolActive=True, parent=None, - isHideChecked=True - ): - super().__init__(symbol, color=color, parent=parent) - self.symbol = symbol - self.keepToolActive = keepToolActive - self.isHideChecked = isHideChecked - self.sigRightClick.connect(self.showContextMenu) - - def showContextMenu(self, event): - contextMenu = QMenu(self) - contextMenu.addSeparator() - - removeAction = QAction('Remove annotation') - removeAction.triggered.connect(self.removeAction) - contextMenu.addAction(removeAction) - - editAction = QAction('Modify annotation parameters...') - editAction.triggered.connect(self.modifyAction) - contextMenu.addAction(editAction) - - hideAction = QAction('Hide annotations') - hideAction.setCheckable(True) - hideAction.setChecked(self.isHideChecked) - hideAction.triggered.connect(self.hideAction) - contextMenu.addAction(hideAction) - - keepActiveAction = QAction('Keep tool active after using it') - keepActiveAction.setCheckable(True) - keepActiveAction.setChecked(self.keepToolActive) - keepActiveAction.triggered.connect(self.keepToolActiveActionToggled) - contextMenu.addAction(keepActiveAction) - - contextMenu.exec(event.globalPos()) - - def keepToolActiveActionToggled(self, checked): - self.keepToolActive = checked - self.sigKeepActiveAction.emit(self) - - def modifyAction(self): - self.sigModifyAction.emit(self) - - def removeAction(self): - self.sigRemoveAction.emit(self) - - def hideAction(self, checked): - self.isHideChecked = checked - self.sigHideAction.emit(self) - -class LabelRoiCircularItem(pg.ScatterPlotItem): - def __init__(self, *args, **kargs): - super().__init__(*args, **kargs) - - def setImageShape(self, shape): - self._shape = shape - - def slice(self, zRange=None, tRange=None): - self.mask() - if zRange is None: - _slice = self._slice - else: - zmin, zmax = zRange - _slice = (slice(zmin, zmax), *self._slice) - - if tRange is not None: - tmin, tmax = tRange - _slice = (slice(tmin, tmax), *_slice) - - return _slice - - def mask(self): - shape = self._shape - radius = int(self.opts['size']/2) - mask = skimage.morphology.disk(radius, dtype=bool) - xx, yy = self.getData() - Yc, Xc = yy[0], xx[0] - mask, self._slice = myutils.clipSelemMask(mask, shape, Yc, Xc, copy=False) - return mask - -class Toggle(QCheckBox): - def __init__( - self, - label_text='', - initial=None, - width=80, - bg_color='#b3b3b3', - circle_color='#ffffff', - active_color='#26dd66',# '#005ce6', - animation_curve=QEasingCurve.Type.InOutQuad - ): - QCheckBox.__init__(self) - - # self.setFixedSize(width, 28) - self.setCursor(Qt.PointingHandCursor) - - self._label_text = label_text - self._bg_color = bg_color - self._circle_color = circle_color - self._active_color = active_color - self._disabled_active_color = colors.lighten_color(active_color) - self._disabled_circle_color = colors.lighten_color(circle_color) - self._disabled_bg_color = colors.lighten_color(bg_color, amount=0.5) - self._circle_margin = 4 - - self._circle_position = int(self._circle_margin/2) - self.animation = QPropertyAnimation(self, b'circle_position', self) - self.animation.setEasingCurve(animation_curve) - self.animation.setDuration(200) - - self.stateChanged.connect(self.start_transition) - self.requestedState = None - - self.installEventFilter(self) - self._isChecked = False - - if initial is not None: - self.setChecked(initial) - - def sizeHint(self): - return QSize(36, 18) - - def eventFilter(self, object, event): - # To get the actual position of the circle we need to wait that - # the widget is visible before setting the state - if event.type() == QEvent.Type.Show and self.requestedState is not None: - self.setChecked(self.requestedState) - return False - - def setChecked(self, state): - # To get the actual position of the circle we need to wait that - # the widget is visible before setting the state - self._isChecked = state - if self.isVisible(): - self.requestedState = None - QCheckBox.setChecked(self, state>0) - else: - self.requestedState = state - - def isChecked(self): - if self.isVisible(): - return super().isChecked() - else: - return self._isChecked - - def circlePos(self, state: bool): - start = int(self._circle_margin/2) - if state: - if self.isVisible(): - height, width = self.height(), self.width() - else: - sizeHint = self.sizeHint() - height, width = sizeHint.height(), sizeHint.width() - circle_diameter = height-self._circle_margin - pos = width-start-circle_diameter - else: - pos = start - return pos - - @Property(float) - def circle_position(self): - return self._circle_position - - @circle_position.setter - def circle_position(self, pos): - self._circle_position = pos - self.update() - - def start_transition(self, state): - self.animation.stop() - pos = self.circlePos(state) - self.animation.setEndValue(pos) - self.animation.start() - - def hitButton(self, pos: QPoint): - return self.contentsRect().contains(pos) - - def setDisabled(self, disable): - QCheckBox.setDisabled(self, disable) - if hasattr(self, 'label'): - self.label.setDisabled(disable) - self.update() - - def paintEvent(self, e): - circle_color = ( - self._circle_color if self.isEnabled() - else self._disabled_circle_color - ) - active_color = ( - self._active_color if self.isEnabled() - else self._disabled_active_color - ) - unchecked_color = ( - self._bg_color if self.isEnabled() - else self._disabled_bg_color - ) - - # set painter - p = QPainter(self) - p.setRenderHint(QPainter.RenderHint.Antialiasing) - - # set no pen - p.setPen(Qt.NoPen) - - # draw rectangle - rect = QRect(0, 0, self.width(), self.height()) - - if not self.isChecked(): - # Draw background - p.setBrush(QColor(unchecked_color)) - half_h = int(self.height()/2) - p.drawRoundedRect( - 0, 0, rect.width(), self.height(), half_h, half_h - ) - - # Draw circle - p.setBrush(QColor(circle_color)) - p.drawEllipse( - int(self._circle_position), int(self._circle_margin/2), - self.height()-self._circle_margin, - self.height()-self._circle_margin - ) - else: - # Draw background - p.setBrush(QColor(active_color)) - half_h = int(self.height()/2) - p.drawRoundedRect( - 0, 0, rect.width(), self.height(), half_h, half_h - ) - - # Draw circle - p.setBrush(QColor(circle_color)) - p.drawEllipse( - int(self._circle_position), int(self._circle_margin/2), - self.height()-self._circle_margin, - self.height()-self._circle_margin - ) - - p.end() - -def QKeyEventToString(event: QKeyEvent, notAllowedModifier=None): - isAltKey = event.key()==Qt.Key_Alt - isCtrlKey = event.key()==Qt.Key_Control - isShiftKey = event.key()==Qt.Key_Shift - isModifierKey = isAltKey or isCtrlKey or isShiftKey - - modifiers = event.modifiers() - isNotAllowedMod = ( - notAllowedModifier is not None and modifiers == notAllowedModifier - ) - if isNotAllowedMod: - return - - modifers_value = modifiers.value if PYQT6 else modifiers - if isModifierKey: - keySequenceText = KeySequenceFromText(modifers_value).toString() - else: - keySequenceText = QKeySequence(modifers_value | event.key()).toString() - - keySequenceText = keySequenceText.encode('ascii', 'ignore').decode('utf-8') - - return keySequenceText - -class ShortcutLineEdit(QLineEdit): - def __init__( - self, parent=None, allowModifiers=False, notAllowedModifier=None - ): - self.keySequence = None - super().__init__(parent) - self._allowModifiers = allowModifiers - self._notAllowedModifier = notAllowedModifier - self.setAlignment(Qt.AlignCenter) - - def text(self): - text = macShortcutToWindows(super().text()) - - return text - - def setText(self, text): - text = windowsShortcutToMac(text) - - super().setText(text) - if not text: - self.keySequence = None - return - try: - self.keySequence = KeySequenceFromText(self.text()) - except Exception as e: - pass - - def keyPressEvent(self, event: QKeyEvent): - if event.key() == Qt.Key_Backspace or event.key() == Qt.Key_Delete: - self.setText('') - return - - keySequenceText = QKeyEventToString( - event, notAllowedModifier=self._notAllowedModifier - ) - self.setText(keySequenceText) - self.key = event.key() - - def keyReleaseEvent(self, event: QKeyEvent) -> None: - if self.text().endswith('+'): - if not self._allowModifiers: - self.setText('') - else: - self.setText(self.text().rstrip('+').strip()) - - -class selectStartStopFrames(QGroupBox): - def __init__(self, SizeT, currentFrameNum=0, parent=None): - super().__init__(parent) - selectFramesLayout = QGridLayout() - - self.startFrame_SB = QSpinBox() - self.startFrame_SB.setAlignment(Qt.AlignCenter) - self.startFrame_SB.setMinimum(1) - self.startFrame_SB.setMaximum(SizeT-1) - self.startFrame_SB.setValue(currentFrameNum) - - self.stopFrame_SB = QSpinBox() - self.stopFrame_SB.setAlignment(Qt.AlignCenter) - self.stopFrame_SB.setMinimum(1) - self.stopFrame_SB.setMaximum(SizeT) - self.stopFrame_SB.setValue(SizeT) - - selectFramesLayout.addWidget(QLabel('Start frame n.'), 0, 0) - selectFramesLayout.addWidget(self.startFrame_SB, 1, 0) - - selectFramesLayout.addWidget(QLabel('Stop frame n.'), 0, 1) - selectFramesLayout.addWidget(self.stopFrame_SB, 1, 1) - - self.warningLabel = QLabel() - palette = self.warningLabel.palette() - palette.setColor(self.warningLabel.backgroundRole(), Qt.red) - palette.setColor(self.warningLabel.foregroundRole(), Qt.red) - self.warningLabel.setPalette(palette) - selectFramesLayout.addWidget( - self.warningLabel, 2, 0, 1, 2, alignment=Qt.AlignCenter - ) - - self.setLayout(selectFramesLayout) - - self.stopFrame_SB.valueChanged.connect(self._checkRange) - - def _checkRange(self): - start = self.startFrame_SB.value() - stop = self.stopFrame_SB.value() - if stop <= start: - self.warningLabel.setText( - 'stop frame smaller than start frame' - ) - else: - self.warningLabel.setText('') - -class formWidget(QWidget): - sigApplyButtonClicked = Signal(object) - sigComputeButtonClicked = Signal(object) - - def __init__( - self, widget, - initialVal=None, - stretchWidget=True, - widgetAlignment=None, - labelTextLeft='', - labelTextRight='', - font=None, - addInfoButton=False, - addApplyButton=False, - addComputeButton=False, - addActivateCheckbox=False, - key='', - infoTxt='', - valueGetterName='value', - toolTip='', - parent=None - ): - QWidget.__init__(self, parent) - self.widget = widget - self.key = key - self.infoTxt = infoTxt - self.widgetAlignment = widgetAlignment - self.valueGetterName = valueGetterName - - widget.setParent(self) - - if isinstance(initialVal, bool): - widget.setChecked(initialVal) - elif isinstance(initialVal, str): - widget.setCurrentText(initialVal) - elif isinstance(initialVal, float) or isinstance(initialVal, int): - widget.setValue(initialVal) - - self.items = [] - - if font is None: - font = QFont() - font.setPixelSize(13) - - self.labelLeft = QClickableLabel(widget) - self.labelLeft.setText(labelTextLeft) - self.labelLeft.setFont(font) - self.items.append(self.labelLeft) - - if not stretchWidget: - widgetLayout = QHBoxLayout() - if widgetAlignment != 'left': - widgetLayout.addStretch(1) - widgetLayout.addWidget(widget) - if widgetAlignment != 'right': - widgetLayout.addStretch(1) - self.items.append(widgetLayout) - else: - self.items.append(widget) - - self.labelRight = QClickableLabel(widget) - self.labelRight.setText(labelTextRight) - self.labelRight.setFont(font) - self.items.append(self.labelRight) - - if toolTip: - self.labelLeft.setToolTip(toolTip) - self.widget.setToolTip(toolTip) - self.labelRight.setToolTip(toolTip) - - if addInfoButton: - infoButton = QPushButton(self) - infoButton.setCursor(Qt.WhatsThisCursor) - infoButton.setIcon(QIcon(":info.svg")) - if labelTextLeft: - infoButton.setToolTip( - f'Info about "{self.labelLeft.text()}" parameter' - ) - else: - infoButton.setToolTip( - f'Info about "{self.labelRight.text()}" measurement' - ) - infoButton.clicked.connect(self.showInfo) - self.infoButton = infoButton - self.items.append(infoButton) - - if addApplyButton: - applyButton = QPushButton(self) - applyButton.setCursor(Qt.PointingHandCursor) - applyButton.setCheckable(True) - applyButton.setIcon(QIcon(":apply.svg")) - applyButton.setToolTip(f'Apply this step and visualize results') - applyButton.clicked.connect(self.applyButtonClicked) - self.items.append(applyButton) - - if addComputeButton: - computeButton = QPushButton(self) - computeButton.setCursor(Qt.BusyCursor) - computeButton.setIcon(QIcon(":compute.svg")) - computeButton.setToolTip(f'Compute this step and visualize results') - computeButton.clicked.connect(self.computeButtonClicked) - self.items.append(computeButton) - - self.activateCheckbox = None - if addActivateCheckbox: - self.activateCheckbox = QCheckBox('Activate') - self.activateCheckbox.setChecked(False) - self.widget.setDisabled(True) - self.activateCheckbox.toggled.connect(self.setWidgetEnabled) - self.items.append(self.activateCheckbox) - - self.labelLeft.clicked.connect(self.tryChecking) - self.labelRight.clicked.connect(self.tryChecking) - - def setWidgetEnabled(self, checked): - self.widget.setDisabled(not checked) - - def value(self): - if self.activateCheckbox is None: - return getattr(self.widget, self.valueGetterName)() - - if not self.activateCheckbox.isChecked(): - return - - return getattr(self.widget, self.valueGetterName)() - - def tryChecking(self, label): - try: - self.widget.setChecked(not self.widget.isChecked()) - except AttributeError as e: - pass - - def applyButtonClicked(self): - self.sigApplyButtonClicked.emit(self) - - def computeButtonClicked(self): - self.sigComputeButtonClicked.emit(self) - - def showInfo(self): - msg = myMessageBox() - msg.setIcon() - msg.setWindowTitle(f'{self.labelLeft.text()} info') - msg.addText(self.infoTxt) - msg.addButton(' Ok ') - msg.exec_() - - def setDisabled(self, disabled: bool) -> None: - for item in self.items: - try: - item.setDisabled(disabled) - except Exception as err: - pass - -class ToggleTerminalButton(PushButton): - sigClicked = Signal(bool) - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setIcon(QIcon(':terminal_up.svg')) - self.setFixedSize(34,18) - self.setIconSize(QSize(30, 14)) - self.setFlat(True) - self.terminalVisible = False - self.clicked.connect(self.mouseClick) - - def mouseClick(self): - if self.terminalVisible: - self.setIcon(QIcon(':terminal_up.svg')) - self.terminalVisible = False - else: - self.setIcon(QIcon(':terminal_down.svg')) - self.terminalVisible = True - self.sigClicked.emit(self.terminalVisible) - - def showEvent(self, a0) -> None: - self.idlePalette = self.palette() - return super().showEvent(a0) - - def enterEvent(self, event) -> None: - self.setFlat(False) - # pal = self.palette() - # pal.setColor(QPalette.ColorRole.Button, QColor(200, 200, 200)) - # self.setAutoFillBackground(True) - # self.setPalette(pal) - self.update() - return super().enterEvent(event) - - def leaveEvent(self, event) -> None: - self.setFlat(True) - # self.setPalette(self.idlePalette) - self.update() - return super().leaveEvent(event) - -class CenteredDoubleSpinbox(QDoubleSpinBox): - def __init__(self, parent=None): - super().__init__(parent=parent) - self.setAlignment(Qt.AlignCenter) - self.setMaximum(2**31-1) - -class readOnlyDoubleSpinbox(QDoubleSpinBox): - def __init__(self, parent=None): - super().__init__(parent=parent) - self.setReadOnly(True) - self.setButtonSymbols(QAbstractSpinBox.ButtonSymbols.NoButtons) - self.setAlignment(Qt.AlignCenter) - self.setMaximum(2**31-1) - # self.setStyleSheet('background-color: rgba(240, 240, 240, 200);') - -class readOnlySpinbox(QSpinBox): - def __init__(self, parent=None): - super().__init__(parent=parent) - self.setReadOnly(True) - self.setButtonSymbols(QAbstractSpinBox.ButtonSymbols.NoButtons) - self.setAlignment(Qt.AlignCenter) - self.setMaximum(2**31-1) - # self.setStyleSheet('background-color: rgba(240, 240, 240, 200);') - -class DoubleSpinBox(QDoubleSpinBox): - sigValueChanged = Signal(int) - - def __init__(self, parent=None, disableKeyPress=False): - super().__init__(parent=parent) - self.setAlignment(Qt.AlignCenter) - self.setMaximum(2**31-1) - self.setMinimum(-2**31) - self._valueChangedFunction = None - self.disableKeyPress = disableKeyPress - - def keyPressEvent(self, event) -> None: - isBackSpaceKey = event.key() == Qt.Key_Backspace - isDeleteKey = event.key() == Qt.Key_Delete - try: - int(event.text()) - isIntegerKey = True - except: - isIntegerKey = False - acceptEvent = isBackSpaceKey or isDeleteKey or isIntegerKey - if self.disableKeyPress and not acceptEvent: - event.ignore() - self.clearFocus() - else: - super().keyPressEvent(event) - - def textFromValue(self, value: float) -> str: - text = super().textFromValue(value) - return text.replace(QLocale().decimalPoint(), '.') - - def valueFromText(self, text: str) -> float: - text = text.replace('.', QLocale().decimalPoint()) - return super().valueFromText(text) - -class SpinBox(QSpinBox): - sigValueChanged = Signal(int) - sigUpClicked = Signal() - sigDownClicked = Signal() - - def __init__( - self, - parent=None, - disableKeyPress=False, - allowNegative=True - ): - super().__init__(parent=parent) - self.setAlignment(Qt.AlignCenter) - self.setMaximum(2**31-1) - if allowNegative: - self.setMinimum(-2**31) - else: - self.setMinimum(0) - self._valueChangedFunction = None - self.disableKeyPress = disableKeyPress - self._linkedWidget = None - - def mousePressEvent(self, event) -> None: - super().mousePressEvent(event) - opt = QStyleOptionSpinBox() - self.initStyleOption(opt) - - control = self.style().hitTestComplexControl( - QStyle.ComplexControl.CC_SpinBox, opt, event.pos(), self - ) - if control == QStyle.SubControl.SC_SpinBoxUp: - self.sigUpClicked.emit() - elif control == QStyle.SubControl.SC_SpinBoxDown: - self.sigDownClicked.emit() - - # def focusOutEvent(self, event): - # self.editingFinished.emit() - # super().focusOutEvent(event) - # printl('emitted') - - def keyPressEvent(self, event) -> None: - isBackSpaceKey = event.key() == Qt.Key_Backspace - isDeleteKey = event.key() == Qt.Key_Delete - try: - int(event.text()) - isIntegerKey = True - except: - isIntegerKey = False - acceptEvent = isBackSpaceKey or isDeleteKey or isIntegerKey - if self.disableKeyPress and not acceptEvent: - event.ignore() - self.clearFocus() - else: - super().keyPressEvent(event) - - def connectValueChanged(self, function): - self._valueChangedFunction = function - self.valueChanged.connect(function) - - def setValue(self, value, setLinkedWidget=True): - super().setValue(int(value)) - if self._linkedWidget is not None and setLinkedWidget: - self._linkedWidget.setValue(value) - - def setValueNoEmit(self, value): - if self._valueChangedFunction is None: - self.setValue(value) - return - try: - self.valueChanged.disconnect() - except TypeError as e: # this fails if its not cennected yet - pass - - self.setValue(value) - self.valueChanged.connect(self._valueChangedFunction) - - def wheelEvent(self, event): - event.ignore() - - def setLinkedValueWidget(self, widget): - self._linkedWidget = widget - -class ReadOnlyLineEdit(QLineEdit): - def __init__(self, parent=None): - super().__init__(parent=parent) - self.setReadOnly(True) - # self.setStyleSheet( - # 'background-color: rgba(240, 240, 240, 200);' - # ) - self.installEventFilter(self) - - def eventFilter(self, a0: 'QObject', a1: 'QEvent') -> bool: - if a1.type() == QEvent.Type.FocusIn: - return True - return super().eventFilter(a0, a1) - - def setValue(self, value): - self.setText(str(value)) - - def value(self, casting_func: callable = None): - text = self.text() - if casting_func is not None: - return casting_func(text) - return text - -class FloatLineEdit(QLineEdit): - valueChanged = Signal(float) - - def __init__( - self, *args, notAllowed=None, allowNegative=True, initial=None, - readOnly=False, decimals=6, warningValues=None - ): - QLineEdit.__init__(self, *args) - if readOnly: - self.setReadOnly(readOnly) - self.notAllowed = notAllowed - self.warningValues = warningValues - self._maximum = np.inf - self._minimum = -np.inf - self._decimals = decimals - - self.isNumericRegExp = rf'^{float_regex(allow_negative=allowNegative)}$' - regExp = QRegularExpression(self.isNumericRegExp) - self.setValidator(QRegularExpressionValidator(regExp)) - self.setAlignment(Qt.AlignCenter) - - font = QFont() - font.setPixelSize(11) - self.setFont(font) - - self.textChanged.connect(self.emitValueChanged) - - if initial is not None: - self.setValue(initial) - else: - self.setValue(0) - - def setDecimals(self, decimals): - self._decimals = 6 - - def castMinMax(self, value: int): - if value > self._maximum: - value = self._maximum - if value < self._minimum: - value = self._minimum - return value - - def setValue(self, value: float): - value = self.castMinMax(value) - self.setText(str(round(value, self._decimals))) - - def value(self): - m = re.match(self.isNumericRegExp, self.text()) - if m is not None: - text = m.group(0) - try: - val = float(text) - except ValueError: - val = 0.0 - else: - val = 0.0 - - return self.castMinMax(val) - - def setMaximum(self, maximum): - self._maximum = maximum - self.setValue(self.value()) - - def setMinimum(self, minimum): - self._minimum = minimum - self.setValue(self.value()) - - def emitValueChanged(self, text): - val = self.value() - reset_stylesheet = True - if self.warningValues is not None and val in self.warningValues: - self.setStyleSheet(LINEEDIT_WARNING_STYLESHEET) - reset_stylesheet = False - - if self.notAllowed is not None and val in self.notAllowed: - self.setStyleSheet(LINEEDIT_INVALID_ENTRY_STYLESHEET) - reset_stylesheet = False - else: - self.valueChanged.emit(self.value()) - - if reset_stylesheet: - self.setStyleSheet('') - -class IntLineEdit(QLineEdit): - valueChanged = Signal(float) - - def __init__( - self, *args, notAllowed=None, allowNegative=True, initial=None, - readOnly=False - ): - QLineEdit.__init__(self, *args) - self.notAllowed = notAllowed - if readOnly: - self.setReadOnly(readOnly) - - self._maximum = np.inf - self._minimum = -np.inf - - self._regExp = r'\d+' - if allowNegative: - self._regExp = r'-?\d+' - - regExp = QRegularExpression(self._regExp) - self.setValidator(QRegularExpressionValidator(regExp)) - self.setAlignment(Qt.AlignCenter) - - font = QFont() - font.setPixelSize(11) - self.setFont(font) - - self.textChanged.connect(self.emitValueChanged) - - if initial is not None: - self.setValue(initial) - else: - self.setValue(0) - - def setMaximum(self, maximum): - self._maximum = maximum - self.setValue(self.value()) - - def setMinimum(self, minimum): - self._minimum = minimum - self.setValue(self.value()) - - def castMinMax(self, value: int): - if value > self._maximum: - value = self._maximum - if value < self._minimum: - value = self._minimum - return value - - def setValue(self, value: int): - value = self.castMinMax(value) - self.setText(str(value)) - - def value(self): - m = re.match(self._regExp, self.text()) - if m is not None: - text = m.group(0) - try: - val = int(text) - except ValueError: - val = 0 - else: - val = 0 - - return self.castMinMax(val) - - def emitValueChanged(self, text): - if not text: - return - - val = self.value() - self.setValue(val) - if self.notAllowed is not None and val in self.notAllowed: - self.setStyleSheet(LINEEDIT_INVALID_ENTRY_STYLESHEET) - else: - self.setStyleSheet('') - self.valueChanged.emit(self.value()) - -class CheckboxesGroupBox(QGroupBox): - def __init__( - self, texts, title='', checkable=False, parent=None - ): - super().__init__(parent) - - self.setTitle(title) - self.setCheckable(checkable) - layout = QVBoxLayout() - - scrollLayout = QVBoxLayout() - container = QWidget() - scrollarea = QScrollArea() - - self.checkBoxes = [] - for text in texts: - checkbox = QCheckBox(text) - checkbox.setChecked(True) - scrollLayout.addWidget(checkbox) - self.checkBoxes.append(checkbox) - - container.setLayout(scrollLayout) - scrollarea.setWidget(container) - layout.addWidget(scrollarea) - - buttonsLayout = QHBoxLayout() - selectAllButton = selectAllPushButton() - selectAllButton.sigClicked.connect(self.checkAll) - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(selectAllButton) - layout.addLayout(buttonsLayout) - - self.setLayout(layout) - - def checkAll(self, button, checked): - for checkBox in self.checkBoxes: - checkBox.setChecked(checked) - -class _metricsQGBox(QGroupBox): - sigDelClicked = Signal(str, object) - - def __init__( - self, desc_dict, title, favourite_funcs=None, isZstack=False, - equations=None, addDelButton=False, delButtonMetricsDesc=None, - parent=None, addCalcForEachZsliceToggle=False - ): - QGroupBox.__init__(self, parent) - - highlightRgba = _palettes._highlight_rgba() - r, g, b, a = highlightRgba - self._highlightStylesheetColor = f'rgb({r}, {g}, {b})' - - self._parent = parent - self.scrollArea = QScrollArea() - self.scrollAreaWidget = QWidget() - self.favourite_funcs = favourite_funcs - - self.doNotWarn = False - - layout = QVBoxLayout() - inner_layout = QVBoxLayout() - self.inner_layout = inner_layout - if delButtonMetricsDesc is None: - delButtonMetricsDesc = [] - - self.checkBoxes = [] - self.checkedState = {} - for metric_colname, metric_desc in desc_dict.items(): - rowLayout = QHBoxLayout() - - checkBox = QCheckBox(metric_colname) - checkBox.setChecked(True) - checkBox.scrollArea = self.scrollArea - self.checkBoxes.append(checkBox) - self.checkedState[checkBox] = True - - try: - checkBox.equation = equations[metric_colname] - except Exception as e: - pass - - if addDelButton or metric_colname in delButtonMetricsDesc: - delButton = delPushButton() - delButton.setToolTip('Delete custom combined measurement') - delButton.colname = metric_colname - delButton.checkbox = checkBox - delButton.clicked.connect(self.onDelClicked) - delButton._layout = rowLayout - rowLayout.addWidget(delButton) - - infoButton = infoPushButton() - infoButton.setCursor(Qt.WhatsThisCursor) - infoButton.info = metric_desc - infoButton.colname = metric_colname - infoButton.clicked.connect(self.showInfo) - - rowLayout.addWidget(infoButton) - rowLayout.addWidget(checkBox) - rowLayout.addStretch(1) - - inner_layout.addLayout(rowLayout) - - self.scrollAreaWidget.setLayout(inner_layout) - self.scrollArea.setWidget(self.scrollAreaWidget) - layout.addWidget(self.scrollArea) - - buttonsLayout = QHBoxLayout() - - buttonsLayout.addStretch(1) - - self.selectAllButton = selectAllPushButton() - self.selectAllButton.sigClicked.connect(self.checkAll) - - buttonsLayout.addWidget(self.selectAllButton) - - if favourite_funcs is not None: - self.loadFavouritesButton = reloadPushButton( - ' Load last selection... ' - ) - self.loadFavouritesButton.clicked.connect(self.checkFavouriteFuncs) - # self.checkFavouriteFuncs() - buttonsLayout.addWidget(self.loadFavouritesButton) - - layout.addLayout(buttonsLayout) - - self.calcForEachZsliceToggle = None - if addCalcForEachZsliceToggle: - buttonsLayout = QHBoxLayout() - self.calcForEachZsliceToggle = Toggle() - tooltip = ( - 'Calculate `cell_area` for each z-slice.\n\n' - 'The measurements will be saved in the column with name\n' - 'ending with `_zsliceN` where N is the z-slice number\n' - '(starting from 0).' - ) - calcForEachZsliceLabel = QClickableLabel( - 'Calculate for each z-slice' - ) - calcForEachZsliceLabel.setToolTip(tooltip) - self.calcForEachZsliceToggle.setToolTip(tooltip) - buttonsLayout.addWidget(self.calcForEachZsliceToggle) - buttonsLayout.addWidget(calcForEachZsliceLabel) - buttonsLayout.addStretch(1) - layout.addLayout(buttonsLayout) - calcForEachZsliceLabel.clicked.connect( - partial( - self.toggleCalcForEachZslice, - toggle=self.calcForEachZsliceToggle - ) - ) - - self.setTitle(title) - self.setCheckable(True) - self.setLayout(layout) - _font = QFont() - _font.setPixelSize(11) - self.setFont(_font) - - self.toggled.connect(self.toggled_cb) - - def toggleCalcForEachZslice(self, label, toggle=None): - if toggle is None: - toggle = self.calcForEachZsliceToggle - - toggle.setChecked(not toggle.isChecked()) - - def isCalcForEachZsliceRequested(self): - if self.calcForEachZsliceToggle is None: - return False - - return self.calcForEachZsliceToggle.isChecked() - - def highlightCheckboxesFromSearchText(self, text): - for checkbox in self.checkBoxes: - if not text: - highlighted = False - else: - highlighted = checkbox.text().lower().find(text.lower()) != -1 - - self.setCheckboxHighlighted(highlighted, checkbox) - - def setCheckboxHighlighted(self, highlighted, checkbox): - if highlighted: - checkbox.setStyleSheet( - f'background: {self._highlightStylesheetColor}; color: black' - ) - self.scrollArea.ensureWidgetVisible(checkbox) - else: - checkbox.setStyleSheet('') - - def onDelClicked(self): - button = self.sender() - button.checkbox.setChecked(False) - self.sigDelClicked.emit(button.colname, button._layout) - - def toggled_cb(self, checked): - for checkbox in self.checkBoxes: - if not checked: - self.checkedState[checkbox] = checkbox.isChecked() - checkbox.setChecked(False) - else: - checkbox.setChecked(self.checkedState[checkbox]) - - def checkFavouriteFuncs(self, checked=True, isZstack=False): - self.doNotWarn = True - if self._parent is not None: - self._parent.doNotWarn = True - for checkBox in self.checkBoxes: - checkBox.setChecked(False) - for favourite_func in self.favourite_funcs: - func_name = checkBox.text() - if func_name.endswith(favourite_func): - checkBox.setChecked(True) - break - self.doNotWarn = False - if self._parent is not None: - self._parent.doNotWarn = False - - def checkAll(self, button, checked): - if self._parent is not None: - self._parent.doNotWarn = True - for checkBox in self.checkBoxes: - checkBox.setChecked(checked) - if self._parent is not None: - self._parent.doNotWarn = False - - def showInfo(self, checked=False): - info_txt = self.sender().info - msg = myMessageBox() - msg.setWidth(600) - msg.setIcon() - msg.setWindowTitle(f'{self.sender().colname} info') - msg.addText(info_txt) - msg.addButton(' Ok ') - msg.exec_() - - def show(self): - super().show() - fw = self.inner_layout.contentsRect().width() - sw = self.scrollArea.verticalScrollBar().sizeHint().width() - self.minWidth = fw + sw - -class channelMetricsQGBox(QGroupBox): - sigDelClicked = Signal(str, object) - sigCheckboxToggled = Signal(object) - - def __init__( - self, isZstack, chName, isSegm3D, is_concat=False, - posData=None, favourite_funcs=None - ): - QGroupBox.__init__(self) - - self.doNotWarn = False - self.is_concat = is_concat - isManualBackgrPresent = False - if posData is not None: - if posData.manualBackgroundLab is not None: - isManualBackgrPresent = True - - layout = QVBoxLayout() - metrics_desc, bkgr_val_desc = measurements.standard_metrics_desc( - isZstack, chName, isSegm3D=isSegm3D, - isManualBackgrPresent=isManualBackgrPresent - ) - - metricsQGBox = _metricsQGBox( - metrics_desc, 'Standard measurements', - favourite_funcs=favourite_funcs, - parent=self, isZstack=isZstack - ) - self.metricsQGBox = metricsQGBox - - bkgrValsQGBox = _metricsQGBox( - bkgr_val_desc, 'Background values', - favourite_funcs=favourite_funcs, - parent=self, isZstack=isZstack - ) - self.bkgrValsQGBox = bkgrValsQGBox - - self.checkBoxes = metricsQGBox.checkBoxes.copy() - self.checkBoxes.extend(bkgrValsQGBox.checkBoxes) - - self.uncheckAndDisableDataPrepIfPosNotPrepped(posData) - - self.groupboxes = [metricsQGBox, bkgrValsQGBox] - - for checkbox in metricsQGBox.checkBoxes: - checkbox.toggled.connect(self.standardMetricToggled) - self.standardMetricToggled(checkbox.isChecked(), checkbox=checkbox) - - for bkgrCheckbox in bkgrValsQGBox.checkBoxes: - bkgrCheckbox.toggled.connect(self.backgroundMetricToggled) - - layout.addWidget(metricsQGBox) - layout.addWidget(bkgrValsQGBox) - - items = measurements.custom_metrics_desc( - isZstack, chName, posData=posData, isSegm3D=isSegm3D, - return_combine=True - ) - custom_metrics_desc, combine_metrics_desc = items - - if custom_metrics_desc: - customMetricsQGBox = _metricsQGBox( - custom_metrics_desc, 'Custom measurements', - delButtonMetricsDesc=combine_metrics_desc, - favourite_funcs=favourite_funcs, - isZstack=isZstack - ) - layout.addWidget(customMetricsQGBox) - self.checkBoxes.extend(customMetricsQGBox.checkBoxes) - customMetricsQGBox.sigDelClicked.connect(self.onDelClicked) - self.customMetricsQGBox = customMetricsQGBox - - self.calcForEachZsliceToggle = None - if isZstack: - buttonsLayout = QHBoxLayout() - self.calcForEachZsliceToggle = Toggle() - tooltip = ( - 'Calculate the selected measurements for each z-slice.\n\n' - 'The measurements will be saved in the column with name\n' - 'ending with `_zsliceN` where N is the z-slice number\n' - '(starting from 0).' - ) - calcForEachZsliceLabel = QClickableLabel( - 'Calculate for each z-slice' - ) - calcForEachZsliceLabel.setToolTip(tooltip) - self.calcForEachZsliceToggle.setToolTip(tooltip) - buttonsLayout.addWidget(self.calcForEachZsliceToggle) - buttonsLayout.addWidget(calcForEachZsliceLabel) - buttonsLayout.addStretch(1) - layout.addLayout(buttonsLayout) - calcForEachZsliceLabel.clicked.connect( - partial( - self.toggleCalcForEachZslice, - toggle=self.calcForEachZsliceToggle - ) - ) - - - self.setTitle(f'{chName} metrics') - self.setCheckable(True) - self.setLayout(layout) - - def toggleCalcForEachZslice(self, label, toggle=None): - if toggle is None: - toggle = self.calcForEachZsliceToggle - - toggle.setChecked(not toggle.isChecked()) - - def isCalcForEachZsliceRequested(self): - if self.calcForEachZsliceToggle is None: - return False - - return self.calcForEachZsliceToggle.isChecked() - - def uncheckAndDisableDataPrepIfPosNotPrepped(self, posData): - # Uncheck and disable dataprep metrics if pos is not prepped - if posData is None: - return - - if posData.isBkgrROIpresent(): - return - - for checkbox in self.checkBoxes: - if checkbox.text().find('dataPrep') == -1: - continue - - checkbox.setChecked(False) - checkbox.isDataPrepDisabled = True - - def _warnDataPrepCannotBeChecked(self): - if self.doNotWarn: - return - txt = html_utils.paragraph(""" - Data prep measurements cannot be saved because you did - not select any background ROI at the data prep step.

    - - You can read more details about data prep metrics by clicking - on the info button besides the measurement's name.

    - - Thank you for you patience! - """) - msg = myMessageBox(showCentered=False) - msg.warning(self, 'Metric cannot be saved', txt) - - def standardMetricToggled(self, checked, checkbox=None): - """Method called when a check-box is toggled. It performs the following - actions: - 1. If the user try to check a data prep measurement, such as - dataPrep_amount, and this cannot be saved (checkbox has the attr - `isDataPrepDisabled`) then it warns and explains why it cannot be saved - 2. Make sure that background value median is checked if the user - requires amount or concentration metric. - 3. Do not allow unchecking background value median and explain why. - - Parameters - ---------- - checked : bool - State of the checkbox toggled - checkbox : QtWidgets.QCheckBox, optional - The checkbox that has been toggled. Default is None. If None - use `self.sender()` - """ - if self.is_concat: - return - - if checkbox is None: - checkbox = self.sender() - - if hasattr(checkbox, 'isDataPrepDisabled'): - # Warn that user cannot check data prep metrics and uncheck it - if not checkbox.isChecked(): - return - checkbox.setChecked(False) - self._warnDataPrepCannotBeChecked() - return - - self.sigCheckboxToggled.emit(checkbox) - if checkbox.text().find('amount_') == -1: - return - pattern = r'amount_([A-Za-z]+)(_?[A-Za-z0-9]*)' - repl = r'\g<1>_bkgrVal_median\g<2>' - bkgrValMetric = s1 = re.sub(pattern, repl, checkbox.text()) - for bkgrCheckbox in self.groupboxes[1].checkBoxes: - if bkgrCheckbox.text() == bkgrValMetric: - break - else: - # Make sure to not check for similarly named custom metrics - return - - if checked: - bkgrCheckbox.setChecked(True) - bkgrCheckbox.isRequired = True - else: - bkgrCheckbox.setDisabled(False) - bkgrCheckbox.isRequired = False - - def backgroundMetricToggled(self, checked): - """Method called when a checkbox of a background metric is toggled. - Check if the background value is required and explain why it cannot be - unchecked. - - Parameters - ---------- - checked : bool - State of the checkbox toggled - """ - if self.is_concat: - return - - checkbox = self.sender() - if not hasattr(checkbox, 'isRequired'): - return - - if not checkbox.isRequired: - return - - if checkbox.isChecked(): - return - - if self.doNotWarn: - return - - checkbox.setChecked(True) - txt = html_utils.paragraph(""" - This background value cannot be unchecked because it is required - by the _amount and _concentration measurements - that you requested to save.

    - - Thank you for you patience! - """) - msg = myMessageBox(showCentered=False) - msg.warning(self, 'Background value required', txt) - - def onDelClicked(self, colname_to_del, hlayout): - self.sigDelClicked.emit(colname_to_del, hlayout) - - def checkFavouriteFuncs(self): - self.doNotWarn = True - for groupbox in self.groupboxes: - groupbox.checkFavouriteFuncs() - self.doNotWarn = False - -class PixelSizeGroupbox(QGroupBox): - sigValueChanged = Signal(float, float, float) - sigReset = Signal() - - def __init__(self, parent=None): - super().__init__('Pixel size', parent) - - mainLayout = QGridLayout() - - row = 0 - label = QLabel('Pixel width (μm): ') - self.pixelWidthWidget = FloatLineEdit(initial=1.0) - mainLayout.addWidget(label, row, 0) - mainLayout.addWidget(self.pixelWidthWidget, row, 1) - - row += 1 - label = QLabel('Pixel height (μm): ') - self.pixelHeightWidget = FloatLineEdit(initial=1.0) - mainLayout.addWidget(label, row, 0) - mainLayout.addWidget(self.pixelHeightWidget, row, 1) - - row += 1 - label = QLabel('Voxel depth (μm): ') - self.voxelDepthWidget = FloatLineEdit(initial=1.0) - mainLayout.addWidget(label, row, 0) - mainLayout.addWidget(self.voxelDepthWidget, row, 1) - - row += 1 - resetButton = reloadPushButton('Reset') - mainLayout.addWidget( - resetButton, row, 1, alignment=Qt.AlignRight - ) - - row += 1 - mainLayout.addWidget(QHLine(), row, 0, 1, 2) - - mainLayout.setColumnStretch(0, 0) - mainLayout.setColumnStretch(1, 1) - - self.setLayout(mainLayout) - - self.pixelWidthWidget.valueChanged.connect(self.emitValueChanged) - self.pixelHeightWidget.valueChanged.connect(self.emitValueChanged) - self.voxelDepthWidget.valueChanged.connect(self.emitValueChanged) - resetButton.clicked.connect(self.emitReset) - - def emitReset(self): - self.sigReset.emit() - - def emitValueChanged(self, value): - PhysicalSizeX = self.pixelWidthWidget.value() - PhysicalSizeY = self.pixelHeightWidget.value() - PhysicalSizeZ = self.voxelDepthWidget.value() - self.sigValueChanged.emit(PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ) - -class objPropsQGBox(QGroupBox): - def __init__(self, parent=None): - QGroupBox.__init__(self, 'Properties', parent) - - mainLayout = QGridLayout() - - row = 0 - label = QLabel('Object ID: ') - self.idSB = IntLineEdit() - mainLayout.addWidget(label, row, 0) - mainLayout.addWidget(self.idSB, row, 1) - - row += 1 - mainLayout.addWidget(QHLine(), row, 0, 1, 2) - - row += 1 - self.notExistingIDLabel = QLabel() - self.notExistingIDLabel.setStyleSheet( - 'font-size:11px; color: rgb(255, 0, 0);' - ) - mainLayout.addWidget( - self.notExistingIDLabel, row, 0, 1, 2, alignment=Qt.AlignCenter - ) - - row += 1 - label = QLabel('Area (pixel): ') - self.cellAreaPxlSB = IntLineEdit(readOnly=True) - mainLayout.addWidget(label, row, 0) - mainLayout.addWidget(self.cellAreaPxlSB, row, 1) - - row += 1 - label = QLabel('Area (µm2): ') - self.cellAreaUm2DSB = FloatLineEdit(readOnly=True) - mainLayout.addWidget(label, row, 0) - mainLayout.addWidget(self.cellAreaUm2DSB, row, 1) - - row += 1 - mainLayout.addWidget(QHLine(), row, 0, 1, 2) - - row += 1 - label = QLabel('Rotational volume (voxel): ') - self.cellVolVoxSB = IntLineEdit(readOnly=True) - mainLayout.addWidget(label, row, 0) - mainLayout.addWidget(self.cellVolVoxSB, row, 1) - - row += 1 - label = QLabel('3D volume (voxel): ') - self.cellVolVox3D_SB = IntLineEdit(readOnly=True) - self.cellVolVox3D_SB.label = label - mainLayout.addWidget(label, row, 0) - mainLayout.addWidget(self.cellVolVox3D_SB, row, 1) - - row += 1 - label = QLabel('Rotational volume (fl): ') - self.cellVolFlDSB = FloatLineEdit(readOnly=True) - mainLayout.addWidget(label, row, 0) - mainLayout.addWidget(self.cellVolFlDSB, row, 1) - - row += 1 - label = QLabel('3D volume (fl): ') - self.cellVolFl3D_DSB = FloatLineEdit(readOnly=True) - self.cellVolFl3D_DSB.label = label - mainLayout.addWidget(label, row, 0) - mainLayout.addWidget(self.cellVolFl3D_DSB, row, 1) - - row += 1 - mainLayout.addWidget(QHLine(), row, 0, 1, 2) - - row += 1 - label = QLabel('Solidity: ') - self.solidityDSB = FloatLineEdit(readOnly=True) - self.solidityDSB.setMaximum(1) - mainLayout.addWidget(label, row, 0) - mainLayout.addWidget(self.solidityDSB, row, 1) - - row += 1 - label = QLabel('Elongation: ') - self.elongationDSB = FloatLineEdit(readOnly=True) - mainLayout.addWidget(label, row, 0) - mainLayout.addWidget(self.elongationDSB, row, 1) - - row += 1 - mainLayout.addWidget(QHLine(), row, 0, 1, 2) - - row += 1 - propsNames = measurements.get_props_names()[1:] - self.additionalPropsCombobox = QComboBox() - self.additionalPropsCombobox.addItems(propsNames) - self.additionalPropsCombobox.indicator = FloatLineEdit(readOnly=True) - mainLayout.addWidget(self.additionalPropsCombobox, row, 0) - mainLayout.addWidget(self.additionalPropsCombobox.indicator, row, 1) - - row += 1 - mainLayout.addWidget(QHLine(), row, 0, 1, 2) - - mainLayout.setColumnStretch(0, 0) - mainLayout.setColumnStretch(1, 1) - - self.setLayout(mainLayout) - -class objIntesityMeasurQGBox(QGroupBox): - def __init__(self, parent=None): - QGroupBox.__init__(self, 'Intensity measurements', parent) - - mainLayout = QGridLayout() - - row = 0 - label = QLabel('Raw intensity measurements') - - row += 1 - label = QLabel('Channel: ') - self.channelCombobox = QComboBox() - self.channelCombobox.addItem('placeholderlong') - mainLayout.addWidget(label, row, 0) - mainLayout.addWidget(self.channelCombobox, row, 1) - - row += 1 - label = QLabel('Minimum: ') - self.minimumDSB = FloatLineEdit(readOnly=True) - mainLayout.addWidget(label, row, 0) - mainLayout.addWidget(self.minimumDSB, row, 1) - - row += 1 - label = QLabel('Maximum: ') - self.maximumDSB = FloatLineEdit(readOnly=True) - mainLayout.addWidget(label, row, 0) - mainLayout.addWidget(self.maximumDSB, row, 1) - - row += 1 - label = QLabel('Mean: ') - self.meanDSB = FloatLineEdit(readOnly=True) - mainLayout.addWidget(label, row, 0) - mainLayout.addWidget(self.meanDSB, row, 1) - - row += 1 - label = QLabel('Median: ') - self.medianDSB = FloatLineEdit(readOnly=True) - mainLayout.addWidget(label, row, 0) - mainLayout.addWidget(self.medianDSB, row, 1) - - row += 1 - metricsDesc = measurements._get_metrics_names() - metricsFunc, _ = measurements.standard_metrics_func() - items = list(set([metricsDesc[key] for key in metricsFunc.keys()])) - items.append('Concentration') - items.sort() - nameFuncDict = {} - for name, desc in metricsDesc.items(): - if name.find('_dataPrepBkgr')!=-1 or name.find('_manualBkgr')!=-1: - # Skip dataPrepBkgr and manualBkgr since in the dock widget - # we display only autoBkgr metrics - continue - if name.startswith('concentration_'): - # We use amount function because dividing by volume is taken - # care in the GUI - name = 'amount_autoBkgr' - nameFuncDict[desc] = metricsFunc[name] - - funcionCombobox = QComboBox() - funcionCombobox.addItems(items) - self.additionalMeasCombobox = funcionCombobox - self.additionalMeasCombobox.indicator = FloatLineEdit(readOnly=True) - self.additionalMeasCombobox.functions = nameFuncDict - mainLayout.addWidget(funcionCombobox, row, 0) - mainLayout.addWidget(self.additionalMeasCombobox.indicator, row, 1) - - self.setLayout(mainLayout) - - def addChannels(self, channels): - self.channelCombobox.clear() - self.channelCombobox.addItems(channels) - -class guiTabControl(QTabWidget): - def __init__(self, *args): - super().__init__(args[0]) - - self._defaultPixelSize = None - - self.propsTab = QScrollArea(self) - - container = QWidget() - layout = QVBoxLayout() - - self.pixelSizeQGBox = PixelSizeGroupbox(parent=self.propsTab) - self.propsQGBox = objPropsQGBox(parent=self.propsTab) - self.intensMeasurQGBox = objIntesityMeasurQGBox(parent=self.propsTab) - - self.highlightCheckbox = QCheckBox('Highlight objects on mouse hover') - self.highlightCheckbox.setChecked(False) - - self.highlightSearchedCheckbox = QCheckBox('Highlight searched object') - self.highlightSearchedCheckbox.setChecked(True) - - highlightLayout = QHBoxLayout() - highlightLayout.addWidget(self.highlightCheckbox) - highlightLayout.addStretch(1) - highlightLayout.addWidget(QLabel('|')) - highlightLayout.addStretch(1) - highlightLayout.addWidget(self.highlightSearchedCheckbox) - - layout.addLayout(highlightLayout) - layout.addWidget(self.pixelSizeQGBox) - layout.addWidget(self.propsQGBox) - layout.addWidget(self.intensMeasurQGBox) - layout.addStretch(1) - container.setLayout(layout) - - self.propsTab.setWidgetResizable(True) - self.propsTab.setWidget(container) - self.addTab(self.propsTab, 'Measurements') - - self.pixelSizeQGBox.sigValueChanged.connect(self.pixelSizeChanged) - self.pixelSizeQGBox.sigReset.connect(self.resetPixelSize) - - def addChannels(self, channels): - self.intensMeasurQGBox.addChannels(channels) - - def resetPixelSize(self): - if self._defaultPixelSize is None: - return - - self.initPixelSize(*self._defaultPixelSize) - - def initPixelSize(self, PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ): - self.pixelSizeQGBox.pixelWidthWidget.setValue(PhysicalSizeX) - self.pixelSizeQGBox.pixelHeightWidget.setValue(PhysicalSizeY) - self.pixelSizeQGBox.voxelDepthWidget.setValue(PhysicalSizeZ) - self._defaultPixelSize = (PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ) - - def pixelSizeChanged(self, PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ): - propsQGBox = self.propsQGBox - yx_pxl_to_um2 = PhysicalSizeY*PhysicalSizeX - vox_rot_to_fl = float(PhysicalSizeY)*pow(float(PhysicalSizeX), 2) - vox_3D_to_fl = PhysicalSizeZ*PhysicalSizeY*PhysicalSizeX - - area_pxl = propsQGBox.cellAreaPxlSB.value() - area_um2 = area_pxl*yx_pxl_to_um2 - propsQGBox.cellAreaUm2DSB.setValue(area_um2) - - vol_rot_vox = propsQGBox.cellVolVoxSB.value() - vol_rot_fl = vol_rot_vox*vox_rot_to_fl - propsQGBox.cellVolFlDSB.setValue(vol_rot_fl) - - vol_3D_vox = propsQGBox.cellVolVox3D_SB.value() - vol_3D_fl = vol_3D_vox*vox_3D_to_fl - propsQGBox.cellVolFl3D_DSB.setValue(vol_3D_fl) - - -class expandCollapseButton(PushButton): - sigClicked = Signal() - - def __init__(self, parent=None, **kwargs): - super().__init__(parent, **kwargs) - self.setIcon(QIcon(":expand.svg")) - self.setFlat(True) - self.installEventFilter(self) - self.isExpand = True - self.clicked.connect(self.buttonClicked) - - def buttonClicked(self, checked=False): - if self.isExpand: - self.setIcon(QIcon(":collapse.svg")) - self.isExpand = False - if self.text(): - self.setText(self.text().replace('Hide', 'Show')) - else: - self.setIcon(QIcon(":expand.svg")) - self.isExpand = True - if self.text(): - self.setText(self.text().replace('Show', 'Hide')) - self.sigClicked.emit() - - def eventFilter(self, object, event): - if event.type() == QEvent.Type.HoverEnter: - self.setFlat(False) - elif event.type() == QEvent.Type.HoverLeave: - self.setFlat(True) - return False - -class view_visualcpp_screenshot(QDialog): - def __init__(self, parent=None): - super().__init__(parent) - layout = QHBoxLayout() - - self.setWindowTitle('Visual Studio Builld Tools installation') - - pixmap = QPixmap(':visualcpp.png') - label = QLabel() - label.setPixmap(pixmap) - - layout.addWidget(label) - self.setLayout(layout) - -class PolyLineROI(pg.PolyLineROI): - def __init__(self, positions, closed=False, pos=None, **args): - super().__init__(positions, closed, pos, **args) - -class BaseGradientEditorItemImage(pg.GradientEditorItem): - def __init__(self, *args, **kargs): - super().__init__(*args, **kargs) - - def restoreState(self, state): - pg.graphicsItems.GradientEditorItem.Gradients = GradientsImage - return super().restoreState(state) - -class MouseCursor(QWidget): - def __init__(self, parent=None) -> None: - super().__init__(parent) - self._x = None - self._y = None - self.setMouseTracking(True) - - def mouseMoveEvent(self, event) -> None: - self.move(event.pos()) - self.update() - return super().mouseMoveEvent(event) - - # def drawAtPos(self, x, y): - # self._x = x - # self._y = y - # self.update() - - def paintEvent(self, event) -> None: - p = QPainter(self) - # p.setPen(QPen(QColor(0,0,0))) - # p.setBrush(QBrush(QColor(70,70,70,200))) - p.drawLine(0,0,200,0) - p.end() - -class BaseGradientEditorItemLabels(pg.GradientEditorItem): - def __init__(self, *args, **kargs): - super().__init__(*args, **kargs) - - def restoreState(self, state): - pg.graphicsItems.GradientEditorItem.Gradients = GradientsLabels - return super().restoreState(state) - -class baseHistogramLUTitem(pg.HistogramLUTItem): - sigAddColormap = Signal(object, str) - sigRescaleIntes = Signal(object) - - def __init__(self, name='image', axisLabel='', parent=None, **kwargs): - pg.GradientEditorItem = BaseGradientEditorItemLabels - - super().__init__(**kwargs) - - self.labelStyle = {'color': '#ffffff', 'font-size': '11px'} - - if axisLabel: - self.setAxisLabel(axisLabel) - - self.cmaps = cmaps - self._parent = parent - self.name = name - - self.gradient.colorDialog.setWindowFlags( - Qt.Dialog | Qt.WindowStaysOnTopHint - ) - self.gradient.colorDialog.accepted.disconnect() - self.gradient.colorDialog.accepted.connect(self.tickColorAccepted) - - self.isInverted = False - self.lastGradientName = 'grey' - self.lastGradient = Gradients['grey'] - - for action in self.gradient.menu.actions(): - if action.text() == 'HSV': - HSV_action = action - elif action.text() == 'RGB': - RGB_ation = action - self.gradient.menu.removeAction(HSV_action) - self.gradient.menu.removeAction(RGB_ation) - - # Rescale intensities (LUT) - rescaleIntensMenu = self.gradient.menu.addMenu( - 'Rescale intensities (LUT)' - ) - rescaleActionGroup = QActionGroup(self) - rescaleActionGroup.setExclusive(True) - - self.rescaleEach2DimgAction = QAction( - 'Rescale each 2D image', rescaleIntensMenu - ) - self.rescaleEach2DimgAction.setCheckable(True) - self.rescaleEach2DimgAction.setChecked(True) - rescaleActionGroup.addAction(self.rescaleEach2DimgAction) - rescaleIntensMenu.addAction(self.rescaleEach2DimgAction) - - self.rescaleAcrossZstackAction = QAction( - 'Rescale across z-stack', rescaleIntensMenu - ) - self.rescaleAcrossZstackAction.setCheckable(True) - self.rescaleAcrossZstackAction.setChecked(False) - rescaleActionGroup.addAction(self.rescaleAcrossZstackAction) - rescaleIntensMenu.addAction(self.rescaleAcrossZstackAction) - - self.rescaleAcrossTimeAction = QAction( - 'Rescale across time frames', rescaleIntensMenu - ) - self.rescaleAcrossTimeAction.setCheckable(True) - self.rescaleAcrossTimeAction.setChecked(False) - rescaleActionGroup.addAction(self.rescaleAcrossTimeAction) - rescaleIntensMenu.addAction(self.rescaleAcrossTimeAction) - - self.customRescaleAction = QAction( - 'Choose custom levels...', rescaleIntensMenu - ) - self.customRescaleAction.setCheckable(True) - rescaleActionGroup.addAction(self.customRescaleAction) - rescaleIntensMenu.addAction(self.customRescaleAction) - - self.doNotRescaleAction = QAction( - 'Do no rescale, display raw image', rescaleIntensMenu - ) - self.doNotRescaleAction.setCheckable(True) - rescaleActionGroup.addAction(self.doNotRescaleAction) - rescaleIntensMenu.addAction(self.doNotRescaleAction) - - self.rescaleActionGroup = rescaleActionGroup - rescaleActionGroup.triggered.connect(self.rescaleActionTriggered) - - # Add custom colormap action - self.customCmapsMenu = self.gradient.menu.addMenu('Custom colormaps') - self.customCmapsMenu.aboutToShow.connect(self.onShowCustomCmapsMenu) - self.customCmapsMenu.triggered.connect(self.customCmapsMenuTriggered) - - self.saveColormapAction = QAction( - 'Save current colormap...', self - ) - self.gradient.menu.addAction(self.saveColormapAction) - self.saveColormapAction.triggered.connect( - self.saveColormap - ) - - self.addCustomGradients() - - # Set inverted gradients for invert bw action - self.addInvertedColorMaps() - - self.gradient.menu.addSeparator() - - # hide histogram tool - self.vb.hide() - - # Disable moving the axis up and down - self.axis.unlinkFromView() - - # Disable histogram default context Menu event - self.vb.raiseContextMenu = lambda x: None - - def rescaleActionTriggered(self, action): - self.sigRescaleIntes.emit(action) - - def onShowCustomCmapsMenu(self): - self.customCmapsMenu.show() - - def customCmapsMenuTriggered(self, action): - cmap = action.cmap - self.gradient.colorMapMenuClicked(cmap) - self.gradient.showTicks(True) - - def setAxisLabel(self, text): - self.labelText = text - self.axis.setLabel(text, **self.labelStyle) - - def updateAxisLabel(self): - text = self.axis.label.toPlainText() - if not text: - return - self.setAxisLabel(text) - - def setGradient(self, gradient): - self.gradient.restoreState(gradient) - self.lastGradient = gradient - - def colormapClicked(self, checked=False, name=None): - name = self.sender().name - self.lastGradientName = name - if self.isInverted: - self.setGradient(self.invertedGradients[name]) - else: - self.setGradient(Gradients[name]) - - def sortTicks(self, ticks): - sortedTicks = sorted(ticks, key=operator.itemgetter(0)) - return sortedTicks - - def getInvertedGradients(self): - invertedGradients = {} - for name, gradient in Gradients.items(): - ticks = gradient['ticks'] - sortedTicks = self.sortTicks(ticks) - if name in nonInvertibleCmaps: - invertedColors = sortedTicks - else: - invertedColors = [ - (t[0], ti[1]) - for t, ti in zip(sortedTicks, sortedTicks[::-1]) - ] - invertedGradient = {} - invertedGradient['ticks'] = invertedColors - invertedGradient['mode'] = gradient['mode'] - invertedGradients[name] = invertedGradient - return invertedGradients - - def addInvertedColorMaps(self): - self.invertedGradients = self.getInvertedGradients() - for action in self.gradient.menu.actions(): - if not hasattr(action, 'name'): - continue - - if action.name not in self.cmaps: - continue - - action.triggered.disconnect() - action.triggered.connect(self.colormapClicked) - - px = QPixmap(100, 15) - p = QPainter(px) - invertedGradient = self.invertedGradients[action.name] - qtGradient = QLinearGradient(QPointF(0,0), QPointF(100,0)) - ticks = self.sortTicks(invertedGradient['ticks']) - qtGradient.setStops([(x, QColor(*color)) for x,color in ticks]) - brush = QBrush(qtGradient) - p.fillRect(QRect(0, 0, 100, 15), brush) - p.end() - widget = action.defaultWidget() - hbox = widget.layout() - rectLabelWidget = QLabel() - rectLabelWidget.setPixmap(px) - hbox.addWidget(rectLabelWidget) - rectLabelWidget.hide() - - def setInvertedColorMaps(self, inverted): - if inverted: - showIdx = 2 - hideIdx = 1 - self.labelStyle['color'] = '#000000' - else: - showIdx = 1 - hideIdx = 2 - self.labelStyle['color'] = '#ffffff' - - for action in self.gradient.menu.actions(): - if not hasattr(action, 'name'): - continue - - if action.name not in self.cmaps: - continue - - widget = action.defaultWidget() - hbox = widget.layout() - hideCmapRect = hbox.itemAt(hideIdx).widget() - showCmapRect = hbox.itemAt(showIdx).widget() - hideCmapRect.hide() - showCmapRect.show() - - self.updateAxisLabel() - self.isInverted = inverted - - def invertGradient(self, gradient): - ticks = gradient['ticks'] - sortedTicks = self.sortTicks(ticks) - invertedColors = [ - (t[0], ti[1]) - for t, ti in zip(sortedTicks, sortedTicks[::-1]) - ] - invertedGradient = {} - invertedGradient['ticks'] = invertedColors - invertedGradient['mode'] = gradient['mode'] - return invertedGradient - - def invertCurrentColormap(self, inverted, debug=False): - self.setGradient(self.invertGradient(self.lastGradient)) - - def addCustomGradient(self, gradient_name, gradient_ticks, restore=True): - self.originalLength = self.gradient.length - self.gradient.length = 100 - if restore: - self.gradient.restoreState(gradient_ticks) - gradient = self.gradient.getGradient() - action = CustomGradientMenuAction(gradient, gradient_name, self.gradient) - # action.triggered.connect(self.gradient.contextMenuClicked) - action.delButton.clicked.connect(self.removeCustomGradient) - action.cmap = colors.pg_ticks_to_colormap(gradient_ticks['ticks']) - # self.gradient.menu.insertAction(self.saveColormapAction, action) - self.customCmapsMenu.addAction(action) - self.gradient.length = self.originalLength - GradientsImage[gradient_name] = gradient_ticks - - def removeCustomGradient(self): - button = self.sender() - action = button.action - self.customCmapsMenu.removeAction(action) - cp = config.ConfigParser() - cp.read(custom_cmaps_filepath) - cp.remove_section(f'image.{action.name}') - with open(custom_cmaps_filepath, mode='w') as file: - cp.write(file) - - def addCustomGradients(self): - try: - CustomGradients = getCustomGradients(name='image') - if not CustomGradients: - return - for gradient_name, gradient_ticks in CustomGradients.items(): - self.addCustomGradient(gradient_name, gradient_ticks) - except Exception as e: - printl(traceback.format_exc()) - pass - - def _askNameColormap(self): - inputWin = apps.QInput(parent=self._parent, title='Colormap name') - inputWin.askText('Insert a name for the colormap: ', allowEmpty=False) - if inputWin.cancel: - return - cmapName = inputWin.answer - return cmapName - - def saveColormap(self): - cmapName = self._askNameColormap() - if cmapName is None: - return - - cp = config.ConfigParser() - if os.path.exists(custom_cmaps_filepath): - cp.read(custom_cmaps_filepath) - - SECTION = f'{self.name}.{cmapName}' - cp[SECTION] = {} - - # gradient_ticks = [] - state = self.gradient.saveState() - for key, value in state.items(): - if key != 'ticks': - continue - for t, tick in enumerate(value): - pos, rgb = tick - # gradient_ticks.append((pos, rgb)) - rgb = ','.join([str(c) for c in rgb]) - val = f'{pos},{rgb}' - cp[SECTION][f'tick_{t}_pos_rgb'] = val - - with open(custom_cmaps_filepath, mode='w') as file: - cp.write(file) - - self.addCustomGradient(cmapName, state, restore=False) - - def tickColorAccepted(self): - self.gradient.currentColorAccepted() - # self.sigTickColorAccepted.emit(self.gradient.colorDialog.color().getRgb()) - - def setRescaleIntensitiesHow(self, how): - for action in self.rescaleActionGroup.actions(): - if action.text() == how: - action.setChecked(True) - return - -class ROI(pg.ROI): - def __init__( - self, pos, size=pg.Point(1, 1), angle=0, invertible=False, - maxBounds=None, snapSize=1, scaleSnap=False, translateSnap=False, - rotateSnap=False, parent=None, pen=None, hoverPen=None, - handlePen=None, handleHoverPen=None, movable=True, rotatable=True, - resizable=True, removable=False, aspectLocked=False - ): - super().__init__( - pos, size, angle, invertible, maxBounds, snapSize, scaleSnap, - translateSnap, rotateSnap, parent, pen, hoverPen, handlePen, - handleHoverPen, movable, rotatable, resizable, removable, - aspectLocked - ) - - def slice(self, zRange=None, tRange=None): - x0, y0 = [int(round(c)) for c in self.pos()] - w, h = [int(round(c)) for c in self.size()] - xmin, xmax = x0, x0+w - if xmin > xmax: - xmin, xmax = xmax, xmin - ymin, ymax = y0, y0+h - if ymin > ymax: - ymin, ymax = ymax, ymin - if zRange is not None: - zmin, zmax = zRange - _slice = (slice(zmin, zmax), slice(ymin, ymax), slice(xmin, xmax)) - else: - _slice = (slice(ymin, ymax), slice(xmin, xmax)) - if tRange is not None: - tmin, tmax = tRange - _slice = (slice(tmin, tmax), *_slice) - return _slice - - def bbox(self): - x0, y0 = [int(round(c)) for c in self.pos()] - w, h = [int(round(c)) for c in self.size()] - xmin, xmax = x0, x0+w - if xmin > xmax: - xmin, xmax = xmax, xmin - ymin, ymax = y0, y0+h - if ymin > ymax: - ymin, ymax = ymax, ymin - - return ymin, xmin, ymax, xmax - -class ZoomROI(ROI): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.viewRangesQueue = deque() - - def getLastRange(self): - xRange, yRange = self.viewRangesQueue.pop() - return xRange, yRange - - def storeLastRange(self, xRange, yRange): - self.viewRangesQueue.append((xRange, yRange)) - -class DelROI(pg.ROI): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def clearPoints(self): - """ - Remove all handles and segments. - """ - while len(self.handles) > 0: - self.removeHandle(self.handles[0]['item']) - -class PlotCurveItem(pg.PlotCurveItem): - def __init__(self, *args, **kargs): - super().__init__(*args, **kargs) - - def addPoint(self, x, y, **kargs): - _xx, _yy = self.getData() - if _xx is None or len(_xx) == 0: - self.xData = np.array([x], dtype=int) - self.yData = np.array([y], dtype=int) - return - if _xx[-1] == x and _yy[-1] == y: - # Do not append same point - return - - # Pre-allocate array and insert data (faster than append) - xx = np.zeros(len(_xx)+1, dtype=_xx.dtype) - xx[:-1] = _xx - xx[-1] = x - yy = np.zeros(len(_yy)+1, dtype=_xx.dtype) - yy[:-1] = _yy - yy[-1] = y - self.setData(xx, yy, **kargs) - - def clear(self): - try: - self.setData([], []) - except Exception as e: - pass - super().clear() - - - def closeCurve(self): - _xx, _yy = self.getData() - self.addPoint(_xx[0], _yy[0]) - - def mask(self): - ymin, xmin, ymax, xmax = self.bbox() - _mask = np.zeros((ymax-ymin+1, xmax-xmin+1), dtype=bool) - local_xx, local_yy = self.getLocalData() - rr, cc = skimage.draw.polygon(local_yy, local_xx) - _mask[rr, cc] = True - return _mask - - def getLocalData(self): - _xx, _yy = self.getData() - return _xx - _xx.min(), _yy - _yy.min() - - def slice(self, zRange=None, tRange=None): - ymin, xmin, ymax, xmax = self.bbox() - if zRange is not None: - zmin, zmax = zRange - _slice = (slice(zmin, zmax), slice(ymin, ymax+1), slice(xmin, xmax+1)) - else: - _slice = (slice(ymin, ymax+1), slice(xmin, xmax+1)) - if tRange is not None: - tmin, tmax = tRange - _slice = (slice(tmin, tmax), *_slice) - return _slice - - def bbox(self): - _xx, _yy = self.getData() - return _yy.min(), _xx.min(), _yy.max(), _xx.max() - -class ToggleVisibilityButton(PushButton): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setFlat(True) - # self.setCheckable(True) - self._state = False - self.setIcon(QIcon(':unchecked.svg')) - self.clicked.connect(self.onClicked) - self.setStyleSheet(""" - QPushButton::pressed { - background-color: none; - border-style: none; - } - """) - - def onClicked(self): - self._state = not self._state - if self._state: - self.setIcon(QIcon(':eye-checked.svg')) - else: - self.setIcon(QIcon(':unchecked.svg')) - -class ToggleVisibilityCheckBox(QCheckBox): - def __init__(self, *args, pixelSize=24): - super().__init__(*args) - self._pixelSize = pixelSize - self.onToggled(False) - self.toggled.connect(self.onToggled) - - def setPixelSize(self, pixelSize): - self._pixelSize = pixelSize - - def onToggled(self, checked): - if checked: - self.setStyleSheet(f""" - QCheckBox::indicator {{ - width: {self._pixelSize}px; - height: {self._pixelSize}px; - }} - - QCheckBox::indicator:checked - {{ - image: url(:eye-checked.svg); - }} - """) - else: - self.setStyleSheet(f""" - QCheckBox::indicator {{ - width: {self._pixelSize}px; - height: {self._pixelSize}px; - }} - - QCheckBox::indicator:unchecked - {{ - image: url(:unchecked.svg); - }} - """) - - -class myHistogramLUTitem(baseHistogramLUTitem): - sigGradientMenuEvent = Signal(object) - sigGradientChanged = Signal(object) - sigTickColorAccepted = Signal(object) - sigAddScaleBar = Signal(bool) - sigAddTimestamp = Signal(bool) - - def __init__( - self, parent=None, name='image', axisLabel='', isViewer=False, - **kwargs - ): - super().__init__( - parent=parent, name=name, axisLabel=axisLabel, **kwargs - ) - - self.name = name - self._parent = parent - - self.childLutItem = None - - self.isViewer = isViewer - if isViewer: - # In the viewer we don't allow additional settings from the menu - return - - # Add scale bar action - self.addScaleBarAction = QAction('Add scale bar', self) - self.addScaleBarAction.setCheckable(True) - self.addScaleBarAction.triggered.connect(self.emitAddScaleBar) - self.gradient.menu.addAction(self.addScaleBarAction) - - # Add timestamp action - self.addTimestampAction = QAction('Add timestamp', self) - self.addTimestampAction.setCheckable(True) - self.addTimestampAction.triggered.connect(self.emitAddTimestamp) - self.gradient.menu.addAction(self.addTimestampAction) - - # Invert bw action - self.invertBwAction = QAction('Invert black/white', self) - self.invertBwAction.setCheckable(True) - self.gradient.menu.addAction(self.invertBwAction) - - # Font size menu action - self.fontSizeMenu = QMenu('Text font size') - self.gradient.menu.addMenu(self.fontSizeMenu) - - # Text color button - hbox = QHBoxLayout() - hbox.addWidget(QLabel('Text color: ')) - self.textColorButton = myColorButton(color=(255,255,255)) - hbox.addStretch(1) - hbox.addWidget(self.textColorButton) - widget = QWidget() - widget.setLayout(hbox) - act = highlightableQWidgetAction(self) - act.setDefaultWidget(widget) - act.triggered.connect(self.textColorButton.click) - self.gradient.menu.addAction(act) - - # Contours line weight - contLineWeightMenu = QMenu('Contours line weight', self.gradient.menu) - self.contLineWightActionGroup = QActionGroup(self) - self.contLineWightActionGroup.setExclusionPolicy( - QActionGroup.ExclusionPolicy.Exclusive - ) - for w in range(1, 11): - action = QAction(str(w)) - action.setCheckable(True) - if w == 2: - action.setChecked(True) - action.lineWeight = w - self.contLineWightActionGroup.addAction(action) - action = contLineWeightMenu.addAction(action) - self.gradient.menu.addMenu(contLineWeightMenu) - - # Contours color button - hbox = QHBoxLayout() - hbox.addWidget(QLabel('Contours color: ')) - self.contoursColorButton = myColorButton(color=(25,25,25)) - hbox.addStretch(1) - hbox.addWidget(self.contoursColorButton) - widget = QWidget() - widget.setLayout(hbox) - act = highlightableQWidgetAction(self) - act.setDefaultWidget(widget) - act.triggered.connect(self.contoursColorButton.click) - self.gradient.menu.addAction(act) - - # Mother-bud line weight - mothBudLineWeightMenu = QMenu('Mother-bud line weight', self.gradient.menu) - self.mothBudLineWightActionGroup = QActionGroup(self) - self.mothBudLineWightActionGroup.setExclusionPolicy( - QActionGroup.ExclusionPolicy.Exclusive - ) - for w in range(1, 11): - action = QAction(str(w)) - action.setCheckable(True) - if w == 2: - action.setChecked(True) - action.lineWeight = w - self.mothBudLineWightActionGroup.addAction(action) - action = mothBudLineWeightMenu.addAction(action) - self.gradient.menu.addMenu(mothBudLineWeightMenu) - - # Mother-bud line color - hbox = QHBoxLayout() - hbox.addWidget(QLabel('Mother-bud line color: ')) - self.mothBudLineColorButton = myColorButton(color=(255,0,0)) - hbox.addStretch(1) - hbox.addWidget(self.mothBudLineColorButton) - widget = QWidget() - widget.setLayout(hbox) - act = highlightableQWidgetAction(self) - act.setDefaultWidget(widget) - act.triggered.connect(self.mothBudLineColorButton.click) - self.gradient.menu.addAction(act) - - self.labelsAlphaMenu = self.gradient.menu.addMenu( - 'Segm. masks overlay alpha...' - ) - # self.labelsAlphaMenu.setDisabled(True) - hbox = QHBoxLayout() - self.labelsAlphaSlider = sliderWithSpinBox( - title='Alpha', title_loc='in_line', isFloat=True, - normalize=True - ) - self.labelsAlphaSlider.setMaximum(100) - self.labelsAlphaSlider.setSingleStep(0.05) - self.labelsAlphaSlider.setValue(0.3) - hbox.addWidget(self.labelsAlphaSlider) - shortCutText = 'Command+Up/Down' if is_mac else 'Ctrl+Up/Down' - hbox.addWidget(QLabel(f'({shortCutText})')) - widget = QWidget() - widget.setLayout(hbox) - act = QWidgetAction(self) - act.setDefaultWidget(widget) - self.labelsAlphaMenu.addSeparator() - self.labelsAlphaMenu.addAction(act) - - # Default settings - self.defaultSettingsAction = QAction('Restore default settings...', self) - self.gradient.menu.addAction(self.defaultSettingsAction) - - self.filterObject = FilterObject() - self.filterObject.sigFilteredEvent.connect(self.gradientMenuEventFilter) - self.gradient.menu.installEventFilter(self.filterObject) - self.highlightedAction = None - self.lastHoveredAction = None - - def setChildLutItem(self, childLutItem): - self.childLutItem = childLutItem - - def removeAddScaleBarAction(self): - self.gradient.menu.removeAction(self.addScaleBarAction) - - def removeAddTimestampAction(self): - self.gradient.menu.removeAction(self.addTimestampAction) - - def emitAddScaleBar(self): - self.sigAddScaleBar.emit(self.addScaleBarAction.isChecked()) - - def emitAddTimestamp(self): - self.sigAddTimestamp.emit(self.addTimestampAction.isChecked()) - - def gradientChanged(self): - super().gradientChanged() - self.sigGradientChanged.emit(self) - - def gradientMenuEventFilter(self, object, event): - if event.type() == QEvent.Type.MouseMove: - hoveredAction = self.gradient.menu.actionAt(event.pos()) - isActionEntered = ( - hoveredAction != self.lastHoveredAction - ) - if isActionEntered: - if isinstance(hoveredAction, highlightableQWidgetAction): - # print('Entered a custom action') - pass - isActionLeft = ( - self.highlightedAction is not None - and self.highlightedAction != hoveredAction - ) - if isActionLeft: - if isinstance( - self.highlightedAction, highlightableQWidgetAction - ): - # print('Left a custom action') - pass - self.highlightedAction = hoveredAction - - self.lastHoveredAction = hoveredAction - - def addOverlayColorButton(self, rgbColor, channelName): - # Overlay color button - hbox = QHBoxLayout() - hbox.addWidget(QLabel('Overlay color: ')) - self.overlayColorButton = myColorButton(color=rgbColor) - self.overlayColorButton.channel = channelName - hbox.addStretch(1) - hbox.addWidget(self.overlayColorButton) - widget = QWidget() - widget.setLayout(hbox) - act = highlightableQWidgetAction(self) - act.setDefaultWidget(widget) - act.triggered.connect(self.overlayColorButton.click) - self.gradient.menu.addAction(act) - - def uncheckContLineWeightActions(self): - for act in self.contLineWightActionGroup.actions(): - try: - act.toggled.disconnect() - except Exception as e: - pass - act.setChecked(False) - - def uncheckMothBudLineLineWeightActions(self): - for act in self.mothBudLineWightActionGroup.actions(): - try: - act.toggled.disconnect() - except Exception as e: - pass - act.setChecked(False) - - def restoreState(self, df): - if 'textIDsColor' in df.index: - rgbString = df.at['textIDsColor', 'value'] - r, g, b = colors.rgb_str_to_values(rgbString) - self.textColorButton.setColor((r, g, b)) - - if 'contLineColor' in df.index: - rgba_str = df.at['contLineColor', 'value'] - rgb = colors.rgba_str_to_values(rgba_str)[:3] - self.contoursColorButton.setColor(rgb) - - if 'contLineWeight' in df.index: - w = df.at['contLineWeight', 'value'] - w = int(w) - for action in self.contLineWightActionGroup.actions(): - if action.lineWeight == w: - action.setChecked(True) - break - - if 'mothBudLineWeight' in df.index: - w = df.at['mothBudLineWeight', 'value'] - w = int(w) - for action in self.mothBudLineWightActionGroup.actions(): - if action.lineWeight == w: - action.setChecked(True) - break - - if 'overlaySegmMasksAlpha' in df.index: - alpha = df.at['overlaySegmMasksAlpha', 'value'] - self.labelsAlphaSlider.setValue(float(alpha)) - - if 'mothBudLineColor' in df.index: - rgba_str = df.at['mothBudLineColor', 'value'] - rgb = colors.rgba_str_to_values(rgba_str)[:3] - self.mothBudLineColorButton.setColor(rgb) - - checked = df.at['is_bw_inverted', 'value'] == 'Yes' - self.invertBwAction.setChecked(checked) - - self.restoreColormap(df) - - def saveState(self, df): - # remove previous state - df = df[~df.index.str.contains('img_cmap')].copy() - - state = self.gradient.saveState() - for key, value in state.items(): - if key == 'ticks': - for t, tick in enumerate(value): - pos, rgb = tick - df.at[f'img_cmap_tick{t}_rgb', 'value'] = rgb - df.at[f'img_cmap_tick{t}_pos', 'value'] = pos - else: - if isinstance(value, bool): - value = 'Yes' if value else 'No' - df.at[f'img_cmap_{key}', 'value'] = value - return df - - def restoreColormap(self, df): - state = {'mode': 'rgb', 'ticksVisible': True, 'ticks': []} - ticks_pos = {} - ticks_rgb = {} - stateFound = False - for setting, value in df.itertuples(): - idx = setting.find('img_cmap_') - if idx == -1: - continue - - stateFound = True - m = re.findall(r'tick(\d+)_(\w+)', setting) - if m: - tick_idx, tick_type = m[0] - if tick_type == 'pos': - ticks_pos[int(tick_idx)] = float(value) - elif tick_type == 'rgb': - ticks_rgb[int(tick_idx)] = colors.rgba_str_to_values(value) - else: - key = setting[9:] - if value == 'Yes': - value = True - elif value == 'No': - value = False - state[key] = value - - if stateFound: - ticks = [(0, 0)]*len(ticks_pos) - for idx, val in ticks_pos.items(): - pos = val - rgb = ticks_rgb[idx] - ticks[idx] = (pos, rgb) - - state['ticks'] = ticks - self.gradient.restoreState(state) - - def regionChanged(self): - super().regionChanged() - if self.childLutItem is None: - return - - imageItem = self.imageItem() - try: - mn, mx = imageItem.quickMinMax(targetSize=65536) - # mn and mx can still be NaN if the data is all-NaN - if mn == mx or imageItem._xp.isnan(mn) or imageItem._xp.isnan(mx): - mn = 0 - mx = 255 - except AttributeError as err: - mn, mx = self.getLevels() - - self.childLutItem.setLevels(min=mn, max=mx) - - -class labelledQScrollbar(ScrollBar): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._label = None - - def setLabel(self, label): - self._label = label - - def updateLabel(self): - if self._label is not None: - position = self.sliderPosition() - s = self._label.text() - s = re.sub(r'(\d+)/(\d+)', fr'{position+1:02}/\2', s) - self._label.setText(s) - - def setSliderPosition(self, position): - QScrollBar.setSliderPosition(self, position) - self.updateLabel() - - def setValue(self, value): - QScrollBar.setValue(self, value) - self.updateLabel() - -class navigateQScrollBar(ScrollBar): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._disableCustomPressEvent = False - self.signal_slot_mapper = {} - - def disableCustomPressEvent(self): - self._disableCustomPressEvent = True - - def enableCustomPressEvent(self): - self._disableCustomPressEvent = False - - def setAbsoluteMaximum(self, absoluteMaximum): - self._absoluteMaximum = absoluteMaximum - - def absoluteMaximum(self): - return self._absoluteMaximum - - def mousePressEvent(self, event): - super().mousePressEvent(event) - if self.maximum() == self._absoluteMaximum: - return - - if self._disableCustomPressEvent: - return - - def setValueNoSignal(self, value): - for signal_name, slot in self.signal_slot_mapper.items(): - signal = getattr(self, signal_name) - try: - signal.disconnect() - except Exception as e: - pass - - self.setSliderPosition(value) - self.connectEvents(self.signal_slot_mapper) - - def connectEvents(self, signal_slot_mapper: dict): - self.signal_slot_mapper = signal_slot_mapper - for signal_name, slot in signal_slot_mapper.items(): - signal = getattr(self, signal_name) - try: - signal.disconnect() - except Exception as e: - pass - signal.connect(slot) - -class linkedQScrollbar(ScrollBar): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._linkedScrollBar = None - - def linkScrollBar(self, scrollbar): - self._linkedScrollBar = scrollbar - scrollbar.setSliderPosition(self.sliderPosition()) - - def unlinkScrollBar(self): - self._linkedScrollBar = None - - def setSliderPosition(self, position): - QScrollBar.setSliderPosition(self, position) - if self._linkedScrollBar is not None: - self._linkedScrollBar.setSliderPosition(position) - - def setMaximum(self, max): - QScrollBar.setMaximum(self, max) - if self._linkedScrollBar is not None: - self._linkedScrollBar.setMaximum(max) - -class myColorButton(pg.ColorButton): - def __init__(self, parent=None, color=(128,128,128), padding=5): - super().__init__(parent=parent, color=color) - if isinstance(padding, (int, float)): - self.padding = (padding, padding, -padding, -padding) - else: - self.padding = padding - self._c = 225 - self._hoverDeltaC = 30 - self._alpha = 100 - self._bkgrColor = QColor(self._c, self._c, self._c, self._alpha) - self._borderColor = QColor(171, 171, 171) - self._rectBorderPen = QPen(QBrush(QColor(0,0,0)), 0.3) - - def paintEvent(self, event): - # QPushButton.paintEvent(self, ev) - p = QStylePainter(self) - p.setRenderHint(QPainter.RenderHint.Antialiasing) - rect = self.rect() - p.setBrush(QBrush(self._bkgrColor)) - p.setPen(QPen(self._borderColor)) - p.drawRoundedRect(rect, 5, 5) - # p.fillRect(self.rect(), self._bkgrColor) - rect = self.rect().adjusted(*self.padding) - ## draw white base, then texture for indicating transparency, then actual color - p.setBrush(pg.mkBrush('w')) - p.drawRect(rect) - p.setBrush(QBrush(Qt.BrushStyle.DiagCrossPattern)) - p.drawRect(rect) - p.setPen(self._rectBorderPen) - p.setBrush(pg.mkBrush(self._color)) - p.drawRect(rect) - p.end() - - def enterEvent(self, event): - c = self._c + self._hoverDeltaC - self._bkgrColor = QColor(c, c, c, self._alpha) - self.update() - - def leaveEvent(self, event): - c = self._c - self._bkgrColor = QColor(c, c, c, self._alpha) - self.update() - -class highlightableQWidgetAction(QWidgetAction): - def __init__(self, parent) -> None: - super().__init__(parent) - -class overlayLabelsGradientWidget(pg.GradientWidget): - def __init__( - self, imageItem, selectActionGroup, segmEndname, - parent=None, orientation='right' - ): - pg.GradientWidget.__init__(self, parent=parent, orientation=orientation) - - self.imageItem = imageItem - self.selectActionGroup = selectActionGroup - - for action in self.menu.actions(): - if action.text() == 'HSV': - HSV_action = action - elif action.text() == 'RGB': - RGB_ation = action - self.menu.removeAction(HSV_action) - self.menu.removeAction(RGB_ation) - - # Shuffle colors action - self.shuffleCmapAction = QAction( - 'Randomly shuffle colormap (Shift+S)', self - ) - self.menu.addAction(self.shuffleCmapAction) - - # Drawing mode - drawModeMenu = QMenu('Drawing mode', self) - self.drawModeActionGroup = QActionGroup(self) - contoursDrawModeAction = QAction('Draw contours', drawModeMenu) - contoursDrawModeAction.setCheckable(True) - contoursDrawModeAction.setChecked(True) - contoursDrawModeAction.segmEndname = segmEndname - self.drawModeActionGroup.addAction(contoursDrawModeAction) - drawModeMenu.addAction(contoursDrawModeAction) - olDrawModeAction = QAction('Overlay labels', drawModeMenu) - olDrawModeAction.setCheckable(True) - olDrawModeAction.segmEndname = segmEndname - self.drawModeActionGroup.addAction(olDrawModeAction) - drawModeMenu.addAction(olDrawModeAction) - self.menu.addMenu(drawModeMenu) - - self.labelsAlphaMenu = self.menu.addMenu( - 'Overlay labels alpha...' - ) - hbox = QHBoxLayout() - self.labelsAlphaSlider = sliderWithSpinBox( - title='Alpha', title_loc='in_line', isFloat=True, - normalize=True - ) - self.labelsAlphaSlider.setMaximum(100) - self.labelsAlphaSlider.setSingleStep(0.05) - self.labelsAlphaSlider.setValue(0.3) - hbox.addWidget(self.labelsAlphaSlider) - widget = QWidget() - widget.setLayout(hbox) - act = QWidgetAction(self) - act.setDefaultWidget(widget) - self.labelsAlphaMenu.addSeparator() - self.labelsAlphaMenu.addAction(act) - - self.menu.addSeparator() - self.menu.addSection('Select segm. file to adjust:') - for action in selectActionGroup.actions(): - self.menu.addAction(action) - - self.item.loadPreset('viridis') - self.updateImageLut(None) - self.updateImageOpacity(0.3) - - # Connect events - self.sigGradientChangeFinished.connect(self.updateImageLut) - self.labelsAlphaSlider.valueChanged.connect(self.updateImageOpacity) - self.shuffleCmapAction.triggered.connect(self.shuffleCmap) - - def shuffleCmap(self): - lut = self.imageItem.lut - np.random.shuffle(lut) - lut[0] = [0,0,0,0] - self.imageItem.setLookupTable(lut) - self.imageItem.update() - - def updateImageLut(self, gradientItem): - lut = np.zeros((255, 4), dtype=np.uint8) - lut[:,-1] = 255 - lut[:,:-1] = self.item.colorMap().getLookupTable(0,1,255) - np.random.shuffle(lut) - lut[0] = [0,0,0,0] - self.imageItem.setLookupTable(lut) - self.imageItem.setLevels([0, 255]) - - def updateImageOpacity(self, value): - self.imageItem.setOpacity(value) - -class labelsGradientWidget(pg.GradientWidget): - sigShowRightImgToggled = Signal(bool) - sigShowLabelsImgToggled = Signal(bool) - sigShowNextFrameToggled = Signal(bool) - - def __init__( self, *args, parent=None, orientation='right', **kargs): - pg.GradientEditorItem = BaseGradientEditorItemLabels - - pg.GradientWidget.__init__( - self, *args, parent=parent, orientation=orientation, **kargs - ) - - self._parent = parent - self.name = 'labels' - - for action in self.menu.actions(): - if action.text() == 'HSV': - HSV_action = action - elif action.text() == 'RGB': - RGB_ation = action - self.menu.removeAction(HSV_action) - self.menu.removeAction(RGB_ation) - - # Add custom colormap action - self.customCmapsMenu = self.menu.addMenu('Custom colormaps') - self.customCmapsMenu.aboutToShow.connect(self.onShowCustomCmapsMenu) - self.customCmapsMenu.triggered.connect(self.customCmapsMenuTriggered) - - self.saveColormapAction = QAction( - 'Save current colormap...', self - ) - self.menu.addAction(self.saveColormapAction) - self.saveColormapAction.triggered.connect( - self.saveColormap - ) - - self.addCustomGradients() - - # Background color button - hbox = QHBoxLayout() - hbox.addWidget(QLabel('Background color: ')) - self.colorButton = myColorButton(color=(25,25,25)) - hbox.addStretch(1) - hbox.addWidget(self.colorButton) - widget = QWidget() - widget.setLayout(hbox) - act = highlightableQWidgetAction(self) - act.setDefaultWidget(widget) - act.triggered.connect(self.colorButton.click) - self.menu.addAction(act) - - # Font size menu action - self.fontSizeMenu = QMenu('Text font size', self) - self.menu.addMenu(self.fontSizeMenu) - - # IDs color button - hbox = QHBoxLayout() - hbox.addWidget(QLabel('Text color: ')) - self.textColorButton = myColorButton(color=(25,25,25)) - hbox.addStretch(1) - hbox.addWidget(self.textColorButton) - widget = QWidget() - widget.setLayout(hbox) - act = highlightableQWidgetAction(self) - act.setDefaultWidget(widget) - act.triggered.connect(self.textColorButton.click) - self.menu.addAction(act) - self.menu.addSeparator() - - # Shuffle colors action - self.shuffleCmapAction = QAction( - 'Randomly shuffle colormap (Shift+S)', self - ) - self.menu.addAction(self.shuffleCmapAction) - - self.greedyShuffleCmapAction = QAction( - 'Greedily shuffle colormap (Alt+Shift+S)', self - ) - self.menu.addAction(self.greedyShuffleCmapAction) - - self.permanentGreedyCmapAction = QAction( - 'Always use greedy colormap', self - ) - self.permanentGreedyCmapAction.setCheckable(True) - self.menu.addAction(self.permanentGreedyCmapAction) - - # Invert bw action - self.invertBwAction = QAction('Invert black/white', self) - self.invertBwAction.setCheckable(True) - self.menu.addAction(self.invertBwAction) - - # Show labels action - self.showLabelsImgAction = QAction('Show segmentation image', self) - self.showLabelsImgAction.setCheckable(True) - self.menu.addAction(self.showLabelsImgAction) - - # Show right image action - self.showRightImgAction = QAction('Show duplicated left image', self) - self.showRightImgAction.setCheckable(True) - self.menu.addAction(self.showRightImgAction) - - # Show next frame action - self.showNextFrameAction = QAction('Show next frame', self) - self.showNextFrameAction.setCheckable(True) - self.menu.addAction(self.showNextFrameAction) - - # Default settings - self.defaultSettingsAction = QAction('Restore default settings...', self) - self.menu.addAction(self.defaultSettingsAction) - - self.menu.addSeparator() - - self.showRightImgAction.toggled.connect(self.showRightImageToggled) - self.showLabelsImgAction.toggled.connect(self.showLabelsImageToggled) - self.showNextFrameAction.toggled.connect(self.showNextFrameToggled) - - def onShowCustomCmapsMenu(self): - self.customCmapsMenu.show() - - def customCmapsMenuTriggered(self, action): - cmap = action.cmap - self.item.colorMapMenuClicked(cmap) - self.item.showTicks(True) - - def addCustomGradient(self, gradient_name, gradient_ticks, restore=True): - currentState = self.item.saveState() - self.originalLength = self.item.length - self.item.length = 100 - if restore: - self.item.restoreState(gradient_ticks) - gradient = self.item.getGradient() - action = CustomGradientMenuAction(gradient, gradient_name, self.item) - # action.triggered.connect(self.item.contextMenuClicked) - action.delButton.clicked.connect(self.removeCustomGradient) - action.cmap = colors.pg_ticks_to_colormap(gradient_ticks['ticks']) - # self.item.menu.insertAction(self.saveColormapAction, action) - self.customCmapsMenu.addAction(action) - self.item.length = self.originalLength - self.item.restoreState(currentState) - GradientsLabels[gradient_name] = gradient_ticks - - def removeCustomGradient(self): - button = self.sender() - action = button.action - self.customCmapsMenu.removeAction(action) - cp = config.ConfigParser() - cp.read(custom_cmaps_filepath) - cp.remove_section(f'labels.{action.name}') - with open(custom_cmaps_filepath, mode='w') as file: - cp.write(file) - - def addCustomGradients(self): - try: - CustomGradients = getCustomGradients(name='labels') - if not CustomGradients: - return - for gradient_name, gradient_ticks in CustomGradients.items(): - self.addCustomGradient(gradient_name, gradient_ticks) - except Exception as e: - printl(traceback.format_exc()) - pass - - def _askNameColormap(self): - inputWin = apps.QInput(parent=self._parent, title='Colormap name') - inputWin.askText('Insert a name for the colormap: ', allowEmpty=False) - if inputWin.cancel: - return - cmapName = inputWin.answer - return cmapName - - def saveColormap(self): - cmapName = self._askNameColormap() - if cmapName is None: - return - - cp = config.ConfigParser() - if os.path.exists(custom_cmaps_filepath): - cp.read(custom_cmaps_filepath) - - SECTION = f'{self.name}.{cmapName}' - cp[SECTION] = {} - - state = self.item.saveState() - for key, value in state.items(): - if key != 'ticks': - continue - for t, tick in enumerate(value): - pos, rgb = tick - rgb = ','.join([str(c) for c in rgb]) - val = f'{pos},{rgb}' - cp[SECTION][f'tick_{t}_pos_rgb'] = val - - with open(custom_cmaps_filepath, mode='w') as file: - cp.write(file) - - self.addCustomGradient(cmapName, state, restore=False) - - def isRightImageVisible(self): - return ( - self.showLabelsImgAction.isChecked() - or self.showNextFrameAction.isChecked() - ) - - def showRightImageToggled(self, checked): - if checked and self.isRightImageVisible(): - # Hide the right labels image before showing right image - self.showLabelsImgAction.setChecked(False) - self.showNextFrameAction.setChecked(False) - self.sigShowLabelsImgToggled.emit(False) - self.sigShowNextFrameToggled.emit(checked) - self.sigShowRightImgToggled.emit(checked) - - def showLabelsImageToggled(self, checked): - if checked and self.isRightImageVisible(): - # Hide the right image before showing labels image - self.showRightImgAction.setChecked(False) - self.showNextFrameAction.setChecked(False) - self.sigShowRightImgToggled.emit(False) - self.sigShowNextFrameToggled.emit(False) - self.sigShowLabelsImgToggled.emit(checked) - - def showNextFrameToggled(self, checked): - if checked and self.isRightImageVisible(): - # Hide the right image before showing labels image - self.showRightImgAction.setChecked(False) - self.showLabelsImgAction.setChecked(False) - self.sigShowRightImgToggled.emit(False) - self.sigShowLabelsImgToggled.emit(False) - self.sigShowNextFrameToggled.emit(checked) - - def saveState(self, df): - # remove previous state - df = df[~df.index.str.contains('lab_cmap')].copy() - - state = self.item.saveState() - for key, value in state.items(): - if key == 'ticks': - for t, tick in enumerate(value): - pos, rgb = tick - df.at[f'lab_cmap_tick{t}_rgb', 'value'] = rgb - df.at[f'lab_cmap_tick{t}_pos', 'value'] = pos - else: - if isinstance(value, bool): - value = 'Yes' if value else 'No' - df.at[f'lab_cmap_{key}', 'value'] = value - return df - - def restoreState(self, df, loadCmap=True): - # Insert background color - if 'labels_bkgrColor' in df.index: - rgbString = df.at['labels_bkgrColor', 'value'] - r, g, b = colors.rgb_str_to_values(rgbString) - self.colorButton.setColor((r, g, b)) - - if 'labels_text_color' in df.index: - rgbString = df.at['labels_text_color', 'value'] - r, g, b = colors.rgb_str_to_values(rgbString) - self.textColorButton.setColor((r, g, b)) - else: - self.textColorButton.setColor((255, 0, 0)) - - checked = df.at['is_bw_inverted', 'value'] == 'Yes' - self.invertBwAction.setChecked(checked) - - if not loadCmap: - return - - state = {'mode': 'rgb', 'ticksVisible': True, 'ticks': []} - ticks_pos = {} - ticks_rgb = {} - stateFound = False - for setting, value in df.itertuples(): - idx = setting.find('lab_cmap_') - if idx == -1: - continue - - stateFound = True - m = re.findall(r'tick(\d+)_(\w+)', setting) - if m: - tick_idx, tick_type = m[0] - if tick_type == 'pos': - ticks_pos[int(tick_idx)] = float(value) - elif tick_type == 'rgb': - ticks_rgb[int(tick_idx)] = colors.rgba_str_to_values(value) - else: - key = setting[9:] - if value == 'Yes': - value = True - elif value == 'No': - value = False - state[key] = value - - if stateFound: - ticks = [(0, 0)]*len(ticks_pos) - for idx, val in ticks_pos.items(): - pos = val - rgb = ticks_rgb[idx] - ticks[idx] = (pos, rgb) - - state['ticks'] = ticks - self.item.restoreState(state) - else: - self.item.loadPreset('viridis') - - return stateFound - - def showMenu(self, ev): - try: - # Convert QPointF to QPoint - self.menu.popup(ev.screenPos().toPoint()) - except AttributeError: - self.menu.popup(ev.screenPos()) - -class QLogConsole(QTextEdit): - def __init__(self, parent=None): - super().__init__(parent) - self.setReadOnly(True) - font = QFont() - font.setPixelSize(13) - self.setFont(font) - - def write(self, message): - # Method required by tqdm pbar - message = message.replace('\r ', '') - if message: - self.apppendText(message) - - def append(self, text: str) -> None: - super().append(text) - self.verticalScrollBar().setValue(self.verticalScrollBar().maximum()) - - def insertPlainText(self, text: str) -> None: - super().append(text) - self.verticalScrollBar().setValue(self.verticalScrollBar().maximum()) - -class ProgressBar(QProgressBar): - def __init__(self, parent=None): - super().__init__(parent) - palette = self.palette() - palette.setColor( - QPalette.ColorRole.Highlight, - PROGRESSBAR_QCOLOR - ) - palette.setColor( - QPalette.ColorRole.HighlightedText, - PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR - ) - self.setPalette(palette) - -class ProgressBarWithETA(ProgressBar): - def __init__(self, parent=None): - self.parent = parent - super().__init__(parent=parent) - self.ETA_label = QLabel('NDh:NDm:NDs') - - def update(self, step: int): - self.setValue(self.value()+step) - t = time.perf_counter() - if not hasattr(self, 'last_time_update'): - self.last_time_update = t - self.mean_value_duration = None - return - seconds_per_value = (t - self.last_time_update)/step - value_left = self.maximum() - self.value() - if self.mean_value_duration is None: - self.mean_value_duration = seconds_per_value - else: - self.mean_value_duration = ( - self.mean_value_duration*(self.value()-1) + seconds_per_value - )/self.value() - - seconds_left = self.mean_value_duration*value_left - ETA = myutils.seconds_to_ETA(seconds_left) - self.ETA_label.setText(ETA) - self.last_time_update = t - return ETA - - def show(self): - QProgressBar.show(self) - self.ETA_label.show() - - def hide(self): - QProgressBar.hide(self) - self.ETA_label.hide() - -class NoneWidget: - def __init__(self): - pass - - def value(self): - return None - - def setValue(self, value): - return - -class MainPlotItem(pg.PlotItem): - def __init__( - self, parent=None, name=None, labels=None, title=None, - viewBox=None, axisItems=None, enableMenu=True, - showWelcomeText=False, **kargs - ): - super().__init__( - parent, name, labels, title, viewBox, axisItems, enableMenu, - **kargs - ) - # Overwrite zoom out button behaviour to disable autoRange after - # clicking it. - # If autorange is enabled, it is called everytime the brush or eraser - # scatter plot items touches the border causing flickering - self.disableAutoRange() - self.autoBtn.mode = 'manual' - if showWelcomeText: - self.infoTextItem = pg.TextItem() - self.addItem(self.infoTextItem) - html_filepath = os.path.join(html_path, 'gui_welcome.html') - with open(html_filepath) as html_file: - htmlText = html_file.read() - self.infoTextItem.setHtml(htmlText) - self.infoTextItem.setPos(0,0) - - self.delRoiItems = {} - self.highlightingRectItems = None - self._baseImageItem = None - self._imageItems = [] - self.highlightingRectItemsColor = None - - def addHighlightingRectItems(self, color=None): - self.highlightingRectItems = { - 'left': RectItem(QRectF()), - 'right': RectItem(QRectF()), - 'top': RectItem(QRectF()), - 'bottom': RectItem(QRectF()) - } - for rect in self.highlightingRectItems.values(): - self.addItem(rect) - - if color is None: - return - - self.setHighlightingRectItemsColor(color) - - def setHighlightingRectItemsColor(self, color): - if color == self.highlightingRectItemsColor: - return - - for item in self.highlightingRectItems.values(): - item.setColor(color) - - self.highlightingRectItemsColor = color - - def addBaseImageItem(self, baseImageItem): - self._baseImageItem = baseImageItem - self._imageItems.append(baseImageItem) - self.addItem(baseImageItem) - - def addImageItem(self, imageItem): - self._imageItems.append(imageItem) - self.addItem(imageItem) - - def setHighlighted(self, highlighted, color=None): - if color is None: - color = self.highlightingRectItemsColor - - if color is None: - color = 'green' - - if self.highlightingRectItems is None: - self.addHighlightingRectItems(color=color) - - if not highlighted: - for rect in self.highlightingRectItems.values(): - rect.setQRect(QRectF()) - return - - self.setHighlightingRectItemsColor(color) - - ((xmin, xmax), (ymin, ymax)) = self.viewRange() - xmin = xmin if xmin >= 0 else 0 - ymin = ymin if ymin >= 0 else 0 - if self._baseImageItem is not None: - Y, X = self._baseImageItem.image.shape[:2] - xmax = min(xmax, X) - ymax = min(ymax, Y) - - w = xmax - xmin - h = ymax - ymin - - bs = round(((w + h) / 2) * 0.02) - if bs < 1: - bs = 1 - - x0 = xmin - x1 = xmin + bs - x2 = xmax - bs - x3 = xmax - - y0 = ymin - y1 = ymin + bs - y2 = ymax - bs - y3 = ymax - - self.highlightingRectItems['left'].setRect(x0, y0, bs, y3-y0) - self.highlightingRectItems['top'].setRect(x1, y0, x3-x1, bs) - self.highlightingRectItems['right'].setRect(x2, y1, bs, y3-y1) - self.highlightingRectItems['bottom'].setRect(x1, y2, x2-x1, bs) - self.update() - - def clear(self): - super().clear() - - self.delRoiItems = {} - self.highlightingRectItems = None - self._baseImageItem = None - self._imageItems = [] - self.highlightingRectItemsColor = None - - try: - self.removeItem(self.infoTextItem) - except Exception as e: - pass - - def autoBtnClicked(self): - self.vb.autoRange() - self.autoBtn.hide() - - def addDelRoiItem(self, roiItem, key): - if self.isDelRoiItemPresent(roiItem): - return - - self.delRoiItems[key] = roiItem - roiItem.key = key - self.addItem(roiItem) - - def removeDelRoiItem(self, roiItem): - key = roiItem.key - self.delRoiItems.pop(key, None) - try: - self.removeItem(roiItem) - except Exception as err: - return - - def isDelRoiItemPresent(self, roiItem): - try: - key = roiItem.key - except AttributeError as e: - return False - - try: - roi = self.delRoiItems[key] - except Exception as err: - return False - - return True - - def viewRange(self, mask_img=None): - if mask_img is None: - return super().viewRange() - - mask_rp = skimage.measure.regionprops( - skimage.measure.label(mask_img) - ) - if not mask_rp: - return super().viewRange() - - mask_obj = mask_rp[0] - ymin, xmin, ymax, xmax = mask_obj.bbox - return (xmin, xmax), (ymin, ymax) - -class sliderWithSpinBox(QWidget): - sigValueChange = Signal(object) - valueChanged = Signal(object) - editingFinished = Signal() - - def __init__(self, *args, **kwargs): - super().__init__(*args) - - layout = QGridLayout() - - title = kwargs.get('title') - row = 0 - col = 0 - if title is not None: - titleLabel = QLabel(self) - titleLabel.setText(title) - loc = kwargs.get('title_loc', 'top') - if loc == 'top': - layout.addWidget(titleLabel, 0, col, alignment=Qt.AlignLeft) - elif loc=='in_line': - row = -1 - col = 1 - layout.addWidget(titleLabel, 0, 0, alignment=Qt.AlignLeft) - layout.setColumnStretch(0, 0) - - self._normalize = False - normalize = kwargs.get('normalize') - if normalize is not None and normalize: - self._normalize = True - self._isFloat = True - - self._isFloat = False - isFloat = kwargs.get('isFloat') - if isFloat is not None and isFloat: - self._isFloat = True - - self.slider = QSlider(Qt.Horizontal, self) - - if self._normalize or self._isFloat: - self.spinBox = DoubleSpinBox(self) - else: - self.spinBox = SpinBox(self) - self.spinBox.setAlignment(Qt.AlignCenter) - self.spinBox.setMaximum(2**31-1) - - maximum_on_label = kwargs.get('maximum_on_label') - spinbox_loc = kwargs.get('spinbox_loc', 'right') - if spinbox_loc == 'right': - spinbox_col = col+1 - slider_col = col - if maximum_on_label is not None: - maximum_on_label_col = spinbox_col + 1 - elif spinbox_loc == 'left': - spinbox_col = col - slider_col = col + 1 - if maximum_on_label is not None: - maximum_on_label_col = spinbox_col + 1 - slider_col += 1 - - if maximum_on_label is not None: - self.labelMaximum = QLabel() - layout.addWidget(self.labelMaximum, row+1, maximum_on_label_col) - layout.addWidget(self.slider, row+1, slider_col) - layout.addWidget(self.spinBox, row+1, spinbox_col) - - if title is not None: - layout.setRowStretch(0, 1) - layout.setRowStretch(row+1, 1) - layout.setColumnStretch(slider_col, 6) - layout.setColumnStretch(spinbox_col, 1) - - self._layout = layout - self.lastCol = col+1 - self.sliderCol = slider_col - - self.slider.valueChanged.connect(self.sliderValueChanged) - self.slider.sliderReleased.connect(self.onEditingFinished) - self.spinBox.valueChanged.connect(self.spinboxValueChanged) - self.spinBox.editingFinished.connect(self.onEditingFinished) - - layout.setContentsMargins(5, 0, 5, 0) - - self.setLayout(layout) - - - if maximum_on_label is not None: - self.setMaximum(maximum_on_label) - self.labelMaximum.setText(f'/{maximum_on_label}') - - def onEditingFinished(self): - self.editingFinished.emit() - - def maximum(self): - return self.slider.maximum() - - def minimum(self): - return self.slider.minimum() - - def setValue(self, value, emitSignal=False): - valueInt = value - if self._normalize: - valueInt = int(value*self.slider.maximum()) - elif self._isFloat: - valueInt = int(value) - - self.spinBox.valueChanged.disconnect() - self.spinBox.setValue(value) - self.spinBox.valueChanged.connect(self.spinboxValueChanged) - - self.slider.valueChanged.disconnect() - if valueInt > self.slider.maximum(): - self.slider.setMaximum(valueInt) - self.slider.setValue(valueInt) - self.slider.valueChanged.connect(self.sliderValueChanged) - - if emitSignal: - self.sigValueChange.emit(self.value()) - self.valueChanged.emit(self.value()) - - def setMaximum(self, max, including_spinbox=False): - self.slider.setMaximum(max) - if including_spinbox: - self.spinBox.setMaximum(max) - - def setSingleStep(self, step): - self.spinBox.setSingleStep(step) - - def setMinimum(self, min, including_spinbox=False): - self.slider.setMinimum(min) - if including_spinbox: - self.spinBox.setMinimum(min) - - def setSingleStep(self, step): - self.spinBox.setSingleStep(step) - - def setDecimals(self, decimals): - self.spinBox.setDecimals(decimals) - - def setTickPosition(self, position): - self.slider.setTickPosition(position) - - def setTickInterval(self, interval): - self.slider.setTickInterval(interval) - - def sliderValueChanged(self, val): - self.spinBox.valueChanged.disconnect() - if self._normalize: - valF = val/self.slider.maximum() - self.spinBox.setValue(valF) - else: - self.spinBox.setValue(val) - self.spinBox.valueChanged.connect(self.spinboxValueChanged) - self.sigValueChange.emit(self.value()) - self.valueChanged.emit(self.value()) - - def spinboxValueChanged(self, val): - if self._normalize: - val = int(val*self.slider.maximum()) - elif self._isFloat: - val = int(val) - - self.slider.valueChanged.disconnect() - self.slider.setValue(val) - self.slider.valueChanged.connect(self.sliderValueChanged) - self.sigValueChange.emit(self.value()) - self.valueChanged.emit(self.value()) - - def value(self): - return self.spinBox.value() - - def setDisabled(self, disabled) -> None: - self.slider.setDisabled(disabled) - self.spinBox.setDisabled(disabled) - -class BaseImageItem(pg.ImageItem): - def __init__( - self, image=None, **kargs - ): - self.minMaxValuesMapper = None - self.minMaxValuesMapperPreproc = None - self.minMaxValuesMapperCombined = None - self.minMaxValuesMapperEqualized = None - self.pos_i = 0 - self.z = 0 - self.frame_i = 0 - self.usePreprocessed = False - self.useEqualized = False - self.useCombined = False - self._isRgba = False - - super().__init__(image, **kargs) - self.autoLevelsEnabled = None - - def isRgba(self): - return self._isRgba - - def setEnableAutoLevels(self, enabled: bool): - self.autoLevelsEnabled = enabled - - def setImage( - self, image=None, autoLevels=None, **kargs - ): - if autoLevels is None: - autoLevels = self.autoLevelsEnabled - - if image is not None and image.ndim == 3 and image.shape[2] in (3, 4): - self._isRgba = True - - super().setImage(image, autoLevels=autoLevels, **kargs) - - def preComputedMinMaxValues(self, data: List['load.loadData']): - self.minMaxValuesMapper = {} - for pos_i, posData in enumerate(data): - img_data = posData.img_data - requires_time_dim = ( - posData.img_data.ndim == 2 - or (posData.img_data.ndim == 3 and posData.SizeZ > 1) - ) - if requires_time_dim: - img_data = (img_data,) - - for frame_i, image in enumerate(img_data): - if image.ndim == 3: - self._updateMinMaxValuesProjections( - image, pos_i, frame_i, self.minMaxValuesMapper - ) - - if image.ndim == 2: - image = (image,) - - for z, img in enumerate(image): - self.minMaxValuesMapper[(pos_i, frame_i, z)] = ( - np.nanmin(img), np.nanmax(img) - ) - - def updateMinMaxValuesEqualizedData( - self, - data: List['load.loadData'], - pos_i: int, - frame_i: int, - z_slice: Union[int, str], - ): - if self.minMaxValuesMapperEqualized is None: - self.minMaxValuesMapperEqualized = {} - - posData = data[pos_i] - img = posData.equalized_img_data[frame_i][z_slice] - key = (pos_i, frame_i, z_slice) - self.minMaxValuesMapperEqualized[key] = (np.nanmin(img), np.nanmax(img)) - - def updateMinMaxValuesEqualizedDataProjections( - self, - data: List['load.loadData'], - pos_i: int, - frame_i: int, - ): - posData = data[pos_i] - eq_zstack = posData.equalized_img_data[frame_i] - - self._updateMinMaxValuesProjections( - eq_zstack, pos_i, frame_i, self.minMaxValuesMapperEqualized - ) - - def _updateMinMaxValuesProjections(self, zstack, pos_i, frame_i, mapper): - max_proj = zstack.max(axis=0) - key = (pos_i, frame_i, 'max z-projection') - mapper[key] = np.nanmin(max_proj), np.nanmax(max_proj) - - mean_proj = zstack.mean(axis=0) - key = (pos_i, frame_i, 'mean z-projection') - mapper[key] = np.nanmin(mean_proj), np.nanmax(mean_proj) - - median_proj = np.median(zstack, axis=0) - key = (pos_i, frame_i, 'median z-proj.') - mapper[key] = np.nanmin(median_proj), np.nanmax(median_proj) - - def updateMinMaxValuesPreprocessedData( - self, - data: List['load.loadData'], - pos_i: int, - frame_i: int, - z_slice: Union[int, str], - ): - if self.minMaxValuesMapperPreproc is None: - self.minMaxValuesMapperPreproc = {} - - posData = data[pos_i] - img = posData.preproc_img_data[frame_i][z_slice] - key = (pos_i, frame_i, z_slice) - self.minMaxValuesMapperPreproc[key] = (np.nanmin(img), np.nanmax(img)) - - def updateMinMaxValuesPreprocessedProjections( - self, - data: List['load.loadData'], - pos_i: int, - frame_i: int, - ): - posData = data[pos_i] - zstack = posData.preproc_img_data[frame_i] - - self._updateMinMaxValuesProjections( - zstack, pos_i, frame_i, self.minMaxValuesMapperPreproc - ) - - def updateMinMaxValuesCombinedData( - self, - data: List['load.loadData'], - pos_i: int, - frame_i: int, - z_slice: Union[int, str], - ): - if self.minMaxValuesMapperCombined is None: - self.minMaxValuesMapperCombined = {} - - posData = data[pos_i] - img = posData.combine_img_data[frame_i][z_slice] - key = (pos_i, frame_i, z_slice) - self.minMaxValuesMapperCombined[key] = (np.nanmin(img), np.nanmax(img)) - - def updateMinMaxValuesCombinedDataProjections( - self, - data: List['load.loadData'], - pos_i: int, - frame_i: int, - ): - posData = data[pos_i] - zstack = posData.combine_img_data[frame_i] - - self._updateMinMaxValuesProjections( - zstack, pos_i, frame_i, self.minMaxValuesMapperCombined - ) - - def setCurrentPosIndex(self, pos_i: int): - self.pos_i = pos_i - - def setCurrentFrameIndex(self, frame_i: int): - self.frame_i = frame_i - - def setCurrentZsliceIndex(self, z: int): - self.z = z - - def quickMinMax(self, targetSize=1e6): - if self.isRgba(): - return super().quickMinMax(targetSize=targetSize) - - if self.usePreprocessed and self.minMaxValuesMapperPreproc is not None: - minMaxValuesMapper = self.minMaxValuesMapperPreproc - elif self.useCombined and self.minMaxValuesMapperCombined is not None: - minMaxValuesMapper = self.minMaxValuesMapperCombined - elif self.useEqualized and self.minMaxValuesMapperEqualized is not None: - minMaxValuesMapper = self.minMaxValuesMapperEqualized - else: - minMaxValuesMapper = self.minMaxValuesMapper - - if minMaxValuesMapper is None: - return super().quickMinMax(targetSize=targetSize) - - try: - key = (self.pos_i, self.frame_i, self.z) - levels = minMaxValuesMapper[key] - return levels - except Exception as err: - pass - - try: - key = (self.pos_i, self.frame_i, self.z) - levels = self.minMaxValuesMapper[key] - return levels - except Exception as err: - return super().quickMinMax(targetSize=targetSize) - - def setOpacity(self, value, **kwargs): - if value == 0: - value = 0.001 - - if value == 1: - value = 0.999 - - super().setOpacity(value) - - -class BaseLabelsImageItem(pg.ImageItem): - def __init__( - self, image=None, **kargs - ): - super().__init__(image, **kargs) - - def setImage(self, image=None, **kwargs): - if image is None: - return - autoLevels = kwargs.get('autoLevels') - if autoLevels is None: - kwargs['autoLevels'] = False - super().setImage(image, **kwargs) - -class OverlayImageItem(pg.ImageItem): - def __init__( - self, image=None, **kargs - ): - super().__init__(image, **kargs) - self.autoLevelsEnabled = None - - def setEnableAutoLevels(self, enabled: bool): - self.autoLevelsEnabled = enabled - - def setImage( - self, image=None, autoLevels=None, **kargs - ): - if autoLevels is None: - autoLevels = self.autoLevelsEnabled - - super().setImage(image, autoLevels=autoLevels, **kargs) - - def setOpacity(self, value, **kwargs): - if value == 0: - value = 0.001 - - if value == 1: - value = 0.999 - - super().setOpacity(value) - -class ParentImageItem(BaseImageItem): - def __init__( - self, image=None, linkedImageItem=None, activatingActions=None, - debug=False, **kargs - ): - super().__init__(image, **kargs) - self.linkedImageItem = linkedImageItem - self.activatingActions = activatingActions - self.debug = debug - self._forceDoNotUpdateLinked = False - self.autoLevelsEnabled = None - - def clear(self): - if self.linkedImageItem is not None: - self.linkedImageItem.clear() - return super().clear() - - def isLinkedImageItemActive(self): - if self._forceDoNotUpdateLinked: - return False - - if self.linkedImageItem is None: - return False - - if self.activatingActions is None: - return False - - for action in self.activatingActions: - if action.isChecked(): - return True - - return False - - def setEnableAutoLevels(self, enabled: bool): - self.autoLevelsEnabled = enabled - - def setUsePreprocessed(self, usePreprocessed): - self.usePreprocessed = usePreprocessed - if self.linkedImageItem is None: - return - - self.linkedImageItem.usePreprocessed = usePreprocessed - - def setUseCombined(self, useCombined): - self.useCombined = useCombined - if self.linkedImageItem is None: - return - - self.linkedImageItem.useCombined = useCombined - - def preComputedMinMaxValues(self, *args, **kwargs): - super().preComputedMinMaxValues(*args, **kwargs) - if self.linkedImageItem is None: - return - - self.linkedImageItem.minMaxValuesMapper = self.minMaxValuesMapper - - def updateMinMaxValuesPreprocessedData(self, *args, **kwargs): - super().updateMinMaxValuesPreprocessedData(*args, **kwargs) - - if self.linkedImageItem is None: - return - - self.linkedImageItem.minMaxValuesMapper = self.minMaxValuesMapper - - def updateMinMaxValuesCombinedData(self, *args, **kwargs): - super().updateMinMaxValuesCombinedData(*args, **kwargs) - - if self.linkedImageItem is None: - return - - self.linkedImageItem.minMaxValuesMapperCombined = ( - self.minMaxValuesMapperCombined - ) - - def updateMinMaxValuesCombinedDataProjections(self, *args, **kwargs): - super().updateMinMaxValuesCombinedDataProjections(*args, **kwargs) - - if self.linkedImageItem is None: - return - - self.linkedImageItem.minMaxValuesMapperCombined = ( - self.minMaxValuesMapperCombined - ) - - def updateMinMaxValuesEqualizedDataProjections(self, *args, **kwargs): - super().updateMinMaxValuesEqualizedDataProjections(*args, **kwargs) - - if self.linkedImageItem is None: - return - - self.linkedImageItem.minMaxValuesMapperEqualized = ( - self.minMaxValuesMapperEqualized - ) - - def updateMinMaxValuesEqualizedData(self, *args, **kwargs): - super().updateMinMaxValuesEqualizedData(*args, **kwargs) - - if self.linkedImageItem is None: - return - - self.linkedImageItem.minMaxValuesMapperEqualized = ( - self.minMaxValuesMapperEqualized - ) - - def setCurrentPosIndex(self, *args, **kwargs): - super().setCurrentPosIndex(*args, **kwargs) - - if self.linkedImageItem is None: - return - - self.linkedImageItem.pos_i = self.pos_i - - def setCurrentFrameIndex(self, *args, **kwargs): - super().setCurrentFrameIndex(*args, **kwargs) - - if self.linkedImageItem is None: - return - - self.linkedImageItem.frame_i = self.frame_i + 1 - - def setCurrentZsliceIndex(self, *args, **kwargs): - super().setCurrentZsliceIndex(*args, **kwargs) - - if self.linkedImageItem is None: - return - - self.linkedImageItem.z = self.z - - def setImage( - self, image=None, autoLevels=None, next_frame_image=None, - scrollbar_value=None, force_set_linked=False, **kargs - ): - if autoLevels is None: - autoLevels = self.autoLevelsEnabled - - super().setImage(image, autoLevels=autoLevels, **kargs) - - if self.linkedImageItem is None: - return - - if not self.isLinkedImageItemActive() and not force_set_linked: - return - - if next_frame_image is not None: - self.linkedImageItem.setImage( - next_frame_image, - scrollbar_value=scrollbar_value, - autoLevels=autoLevels - ) - elif image is not None: - self.linkedImageItem.setImage(image) - - def updateImage(self, *args, **kargs): - if self.isLinkedImageItemActive(): - self.linkedImageItem.image = self.image - self.linkedImageItem.updateImage(*args, **kargs) - return super().updateImage(*args, **kargs) - - def setOpacity(self, value, applyToLinked=True): - super().setOpacity(value) - if not applyToLinked: - return - - if self.linkedImageItem is None: - return - - self.linkedImageItem.setOpacity(value) - - def setLookupTable(self, lut): - super().setLookupTable(lut) - # if self.linkedImageItem is not None: - # self.linkedImageItem.setLookupTable(lut) - -class ChildImageItem(BaseImageItem): - def __init__(self, *args, linkedScrollbar=None, **kwargs): - BaseImageItem.__init__(self, *args, **kwargs) - self.linkedScrollbar = linkedScrollbar - - def setImage(self, img=None, z=None, scrollbar_value=None, **kargs): - autoLevels = kargs.get('autoLevels') - if autoLevels is None: - kargs['autoLevels'] = False - - if img is None: - BaseImageItem.setImage(self, img, **kargs) - return - - if img.ndim == 3 and img.shape[-1] > 4 and z is not None: - BaseImageItem.setImage(self, img[z], **kargs) - else: - BaseImageItem.setImage(self, img, **kargs) - - if self.linkedScrollbar is None: - return - - if not self.linkedScrollbar.isEnabled(): - return - - if scrollbar_value is None: - return - - self.linkedScrollbar.setValueNoSignal(scrollbar_value) - -class labImageItem(pg.ImageItem): - def __init__(self, *args, **kwargs): - pg.ImageItem.__init__(self, *args, **kwargs) - - def setImage(self, img=None, z=None, **kargs): - autoLevels = kargs.get('autoLevels') - if autoLevels is None: - kargs['autoLevels'] = False - - if img is None: - pg.ImageItem.setImage(self, img, **kargs) - return - - if img.ndim == 3 and img.shape[-1] > 4 and z is not None: - pg.ImageItem.setImage(self, img[z], **kargs) - else: - pg.ImageItem.setImage(self, img, **kargs) - - - -class PostProcessSegmSlider(sliderWithSpinBox): - def __init__(self, *args, label=None, **kwargs): - super().__init__(*args, **kwargs) - - self.label = label - self.checkbox = QCheckBox('Disable') - self._layout.addWidget(self.checkbox, self.sliderCol, self.lastCol+1) - self.checkbox.toggled.connect(self.onCheckBoxToggled) - self.valueChanged.connect(self.checkExpandRange) - - def onCheckBoxToggled(self, checked: bool) -> None: - super().setDisabled(checked) - if self.label is not None: - self.label.setDisabled(checked) - self.onValueChanged(None) - self.onEditingFinished() - - def onValueChanged(self, value): - self.valueChanged.emit(value) - - def checkExpandRange(self, value): - if value == self.maximum(): - range = int(self.maximum() - self.minimum()) - half_range = int(range/2) - newMinimum = self.minimum() + half_range - newMaximum = self.maximum() + half_range - self.setMaximum(newMaximum) - self.setMinimum(newMinimum) - elif value == self.minimum(): - range = int(self.maximum() - self.minimum()) - half_range = int(range/2) - newMinimum = self.minimum() - half_range - newMaximum = self.maximum() - half_range - self.setMaximum(newMaximum) - self.setMinimum(newMinimum) - - def onEditingFinished(self): - self.editingFinished.emit() - - def value(self): - if self.checkbox.isChecked(): - return None - else: - return super().value() - -class GhostContourItem(pg.PlotDataItem): - def __init__( - self, ParentPlotItem, penColor=(245, 184, 0, 100), - textColor=(245, 184, 0) - ): - super().__init__() - # Yellow pen - self.setPen(pg.mkPen(width=2, color=penColor)) - self.label = myLabelItem() - self.label.setAttr('bold', True) - self.label.setAttr('color', textColor) - self._ParentPlotItem = ParentPlotItem - - def addToPlotItem(self): - self._ParentPlotItem.addItem(self) - self._ParentPlotItem.addItem(self.label) - - def removeFromPlotItem(self): - self._ParentPlotItem.removeItem(self.label) - self._ParentPlotItem.removeItem(self) - - def setData( - self, xx=None, yy=None, fontSize=11, ID=0, - y_cursor=None, x_cursor=None - ): - if xx is None: - xx = [] - if yy is None: - yy = [] - super().setData(xx, yy) - if not hasattr(self, 'label'): - return - - if ID == 0: - self.label.setText('') - else: - self.label.setText(f'{ID}', size=fontSize) - w, h = self.label.itemRect().width(), self.label.itemRect().height() - self.label.setPos(x_cursor, y_cursor-h) - - def clear(self): - self.setData([], []) - -class GhostMaskItem(pg.ImageItem): - def __init__(self, ParentPlotItem): - super().__init__() - self.label = myLabelItem() - self.label.setAttr('bold', True) - self.label.setAttr('color', (245, 184, 0)) - self._ParentPlotItem = ParentPlotItem - - def initImage(self, imgShape): - image = np.zeros(imgShape, dtype=np.uint32) - self.setImage(image) - - def initLookupTable(self, rgbaColor): - lut = np.zeros((2, 4), dtype=np.uint8) - lut[1,-1] = 255 - lut[1,:-1] = rgbaColor - self.setLookupTable(lut) - - def addToPlotItem(self): - self._ParentPlotItem.addItem(self) - self._ParentPlotItem.addItem(self.label) - - def removeFromPlotItem(self): - self._ParentPlotItem.removeItem(self.label) - self._ParentPlotItem.removeItem(self) - - def updateGhostImage(self, ID=0, y_cursor=None, x_cursor=None, fontSize=None): - self.setImage(self.image) - - if ID == 0: - self.label.setText('') - return - - self.label.setText(f'{ID}', size=fontSize) - w, h = self.label.itemRect().width(), self.label.itemRect().height() - self.label.item.setPos(x_cursor, y_cursor-h) - - def clear(self): - if hasattr(self, 'label'): - self.label.setText('') - if self.image is None: - return - self.image[:] = 0 - self.setImage(self.image) - -class PostProcessSegmSpinbox(QWidget): - valueChanged = Signal(int) - editingFinished = Signal() - sigCheckboxToggled = Signal() - - def __init__(self, *args, isFloat=False, label=None, **kwargs): - super().__init__(*args, **kwargs) - - layout = QHBoxLayout() - - if isFloat: - self.spinBox = DoubleSpinBox() - else: - self.spinBox = SpinBox() - - self.spinBox.valueChanged.connect(self.onValueChanged) - self.spinBox.editingFinished.connect(self.onEditingFinished) - - layout.addWidget(self.spinBox) - self.checkbox = QCheckBox('Disable') - layout.addWidget(self.checkbox) - layout.setStretch(0,1) - layout.setStretch(1,0) - - self.label = label - - self.checkbox.toggled.connect(self.onCheckBoxToggled) - - layout.setContentsMargins(5, 0, 5, 0) - - self.setLayout(layout) - - def onCheckBoxToggled(self, checked: bool) -> None: - self.spinBox.setDisabled(checked) - if self.label is not None: - self.label.setDisabled(checked) - self.onValueChanged(None) - self.onEditingFinished() - - def onValueChanged(self, value): - self.valueChanged.emit(value) - - def onEditingFinished(self): - self.editingFinished.emit() - - def maximum(self): - return self.spinBox.maximum() - - def setValue(self, value): - self.spinBox.setValue(value) - - def sizeHint(self): - return self.spinBox.sizeHint() - - def setMaximum(self, max): - self.spinBox.setMaximum(max) - - def setSingleStep(self, step): - self.spinBox.setSingleStep(step) - - def setMinimum(self, min): - self.spinBox.setMinimum(min) - - def setSingleStep(self, step): - self.spinBox.setSingleStep(step) - - def setDecimals(self, decimals): - self.spinBox.setDecimals(decimals) - - def value(self): - if self.checkbox.isChecked(): - return None - else: - return self.spinBox.value() - -class CopiableCommandWidget(QGroupBox): - def __init__(self, command='', parent=None, font_size='13px'): - super().__init__(parent) - - layout = QHBoxLayout() - - label = QLabel(self) - self.label = label - self._font_size = font_size - self.setCommand(command, font_size=font_size) - label.setTextInteractionFlags( - Qt.TextBrowserInteraction | Qt.TextSelectableByKeyboard - ) - layout.addWidget(label) - layout.addWidget(QVLine(shadow='Plain', color='#4d4d4d')) - copyButton = copyPushButton('Copy', flat=True, hoverable=True) - copyButton.clicked.connect(self.copyToClipboard) - layout.addWidget(copyButton) - layout.addStretch(1) - - self.setLayout(layout) - - def setWordWrap(self, wordWrap): - self.label.setWordWrap(wordWrap) - - def copyToClipboard(self): - cb = QApplication.clipboard() - cb.clear(mode=cb.Clipboard) - cb.setText(self._command, mode=cb.Clipboard) - print('Command copied!') - - def setCommand(self, command, font_size=None): - if font_size is None: - font_size = self._font_size - - self._command = command - txt = html_utils.paragraph( - f'{command}', font_size=font_size - ) - self.label.setText(txt) - - def command(self): - return self._command - - def text(self): - return self.label.text() - - def setTextInteractionFlags(self, flags): - self.label.setTextInteractionFlags(flags) - -def PostProcessSegmWidget( - minimum, maximum, value, useSliders, isFloat=False, normalize=False, - label=None - ): - if useSliders: - if normalize: - maximum = int(maximum*100) - widget = PostProcessSegmSlider( - normalize=normalize, isFloat=isFloat, label=label - ) - else: - widget = PostProcessSegmSpinbox(label=label, isFloat=isFloat) - widget.setMinimum(minimum) - widget.setMaximum(maximum) - widget.setValue(value) - return widget - -# class Spinner(QLabel): -# def __init__(self, size=150, parent=None): -# super().__init__(parent) -# # layout = QHBoxLayout() - -# # self._label = QLabel() -# self.setAlignment(Qt.AlignCenter) -# # self._label.setText('Ciao') -# self._pixmap = QPixmap(':spinner.svg') - -# self._pixmapSize = size + size%2 -# self._halfPixmapSize = int(self._pixmapSize/2) -# printl(self._pixmapSize, self._halfPixmapSize) -# self.setPixmap(self._pixmap.scaled(self._pixmapSize, self._pixmapSize)) - -# # self.setFixedSize(160, 160) -# self._angle = 0 - -# blurEffect = QGraphicsBlurEffect() -# blurEffect.setBlurRadius(1.4) -# self.setGraphicsEffect(blurEffect) - -# # layout.addWidget(self._label) -# # self.setLayout(layout) - -# self.animation = QPropertyAnimation(self, b"angle", self) -# self.animation.setStartValue(0) -# self.animation.setEndValue(360) -# self.animation.setLoopCount(-1) -# self.animation.setDuration(1700) -# self.animation.start() - -# @Property(int) -# def angle(self): -# return self._angle - -# @angle.setter -# def angle(self, value): -# self._angle = value -# self.update() - -# def paintEvent(self, ev=None): -# width, height = self.size().width(), self.size().height() -# radius_x = int(width/2) -# radius_y = int(height/2) -# x = radius_x-self._halfPixmapSize -# y = radius_y-self._halfPixmapSize -# painter = QPainter(self) -# painter.setRenderHint(QPainter.Antialiasing) -# painter.translate(radius_x, radius_y) -# painter.rotate(self._angle) -# painter.translate(-radius_x, -radius_y) -# painter.drawPixmap(x, y, self._pixmap.scaled(self._pixmapSize, self._pixmapSize)) -# painter.end() - -class LoadingCircleAnimation(QLabel): - def __init__(self, size=32, motionBlur=False, parent=None): - super().__init__(parent) - # layout = QHBoxLayout() - - # self._label = QLabel() - self.setAlignment(Qt.AlignCenter) - self._size = size + size%2 - self._radius = int(self._size/2) - self.setFixedSize(self._size, self._size) - self._dotDiameter = int(self._size*0.15) - self._dotDiameter = self._dotDiameter + self._dotDiameter%2 - self._dotRadius = int(self._dotDiameter/2) - - self._rgb = _palettes.getPainterColor()[:3] - self._index = 0 - - self.setBrushesAndAngles() - - if motionBlur: - blurEffect = QGraphicsBlurEffect() - blurRadius = self._size*0.02 - if blurRadius < 1: - blurRadius = 1 - blurEffect.setBlurRadius(blurRadius) - self.setGraphicsEffect(blurEffect) - - self.animation = QPropertyAnimation(self, b"index", self) - self.animation.setStartValue(0) - self.animation.setEndValue(11) - self.animation.setLoopCount(-1) - self.animation.setDuration(1200) - self.animation.start() - - self.update() - - def setVisible(self, visible): - if visible: - self.animation.start() - else: - self.animation.stop() - super().setVisible(visible) - - def setBrushesAndAngles(self): - self._brushes = [] - self._pens = [] - alphas = np.round(np.linspace(0, 255, 12)).astype(int) - self._angles = np.arange(0, 360, 30) - for alpha in alphas: - color = QColor(*self._rgb, alpha) - self._brushes.append(pg.mkBrush(color)) - self._pens.append(pg.mkPen(color)) - - @Property(int) - def index(self): - return self._index - - @index.setter - def index(self, value): - self._index = value - self.update() - - def paintEvent(self, event): - painter = QPainter(self) - painter.setRenderHint(QPainter.Antialiasing) - painter.translate(self._radius, self._radius) - for i in range(12): - idx = i - self._index - angle = self._angles[i] - painter.setBrush(self._brushes[idx]) - painter.setPen(self._pens[idx]) - x = (self._radius-self._dotRadius)*math.cos(angle*math.pi/180) - y = (self._radius-self._dotRadius)*math.sin(angle*math.pi/180) - painter.drawEllipse(QPointF(x, y), self._dotRadius, self._dotRadius) - - painter.end() - -class QBaseWindow(QMainWindow): - def __init__(self, parent=None): - super().__init__(parent) - - def exec_(self): - self.show(block=True) - - def show(self, block=False): - self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - super().show() - if block: - self.loop = QEventLoop() - self.loop.exec_() - - def closeEvent(self, event): - if hasattr(self, 'loop'): - self.loop.exit() - - def keyPressEvent(self, event) -> None: - if event.key() == Qt.Key_Escape: - event.ignore() - return - - super().keyPressEvent(event) - -class ScrollBarWithNumericControl(QWidget): - sigValueChanged = Signal(int) - sigMaxProjToggled = Signal(bool, object) - - def __init__( - self, orientation=Qt.Horizontal, add_max_proj_button=False, - parent=None, labelText='' - ) -> None: - super().__init__(parent) - - self._slot = None - - layout = QHBoxLayout() - self.scrollbar = QScrollBar(orientation, self) - self.spinbox = QSpinBox(self) - self.maxLabel = QLabel(self) - idx = 0 - if labelText: - layout.addWidget(QLabel(labelText)) - layout.setStretch(idx, 0) - idx += 1 - - layout.addWidget(self.spinbox) - layout.setStretch(idx,0) - idx += 1 - - layout.addWidget(self.maxLabel) - layout.setStretch(idx,0) - idx += 1 - - layout.addWidget(self.scrollbar) - layout.setStretch(idx,1) - idx += 1 - - if add_max_proj_button: - self.maxProjCheckbox = QCheckBox('MAX') - self.scrollbar.maxProjCheckbox = self.maxProjCheckbox - layout.addWidget(self.maxProjCheckbox) - layout.setStretch(idx,0) - - layout.setContentsMargins(5, 0, 5, 0) - - self.setLayout(layout) - - self.spinbox.valueChanged.connect(self.spinboxValueChanged) - self.scrollbar.valueChanged.connect(self.scrollbarValueChanged) - - if add_max_proj_button: - self.maxProjCheckbox.toggled.connect(self.maxProjToggled) - - def connectValueChanged(self, slot): - self.sigValueChanged.connect(slot) - self._slot = slot - - def setValueNoSignal(self, value): - if self._slot is None: - return - self.sigValueChanged.disconnect() - self.setValue(value) - self.sigValueChanged.connect(self._slot) - - def maxProjToggled(self, checked): - self.scrollbar.setDisabled(checked) - self.sigMaxProjToggled.emit(checked, self) - - def showEvent(self, event) -> None: - super().showEvent(event) - - self.scrollbar.setMinimumHeight(self.spinbox.height()) - - def setMaximum(self, maximum): - self.maxLabel.setText(f'/{maximum}') - self.scrollbar.setMaximum(maximum) - self.spinbox.setMaximum(maximum) - - def setMinimum(self, minumum): - self.scrollbar.setMinimum(minumum) - self.spinbox.setMinimum(minumum) - - def spinboxValueChanged(self, value): - self.scrollbar.setValue(value) - - def scrollbarValueChanged(self, value): - self.spinbox.setValue(value) - self.sigValueChanged.emit(value) - - def setValue(self, value): - self.scrollbar.setValue(value) - - def value(self): - return self.scrollbar.value() - - def maximum(self): - return self.scrollbar.maximum() - -class ImShowPlotItem(pg.PlotItem): - def __init__( - self, parent=None, name=None, labels=None, title=None, - viewBox=None, axisItems=None, enableMenu=True, **kargs - ): - super().__init__( - parent, name, labels, title, viewBox, axisItems, enableMenu, - **kargs - ) - # Overwrite zoom out button behaviour to disable autoRange after - # clicking it. - # If autorange is enabled, it is called everytime the brush or eraser - # scatter plot items touches the border causing flickering - self.disableAutoRange() - self.autoBtn.mode = 'manual' - self.invertY(True) - self.setAspectLocked(True) - self.addImageItem(kargs.get('imageItem')) - - self._selected = False - self.selectingRects = [] - - def setSelectableTitle(self, title: QGraphicsProxyWidget, **kwargs): - self.layout.removeItem(self.titleLabel) - self.layout.addItem(title, 0, 1, alignment=Qt.AlignCenter) - - def isSelected(self): - return self._selected - - def setSelected( - self, selected: bool, - xlim=(-np.inf, np.inf), - ylim=(-np.inf, np.inf) - ): - if selected == self._selected: - return - - if selected: - ((xmin, xmax), (ymin, ymax)) = self.viewRange() - ylim_min, ylim_max = ylim - xlim_min, xlim_max = xlim - - xmin = max(xlim_min, xmin) - xmax = min(xlim_max, xmax) - ymin = max(ylim_min, ymin) - ymax = min(ylim_max, ymax) - - w = xmax - xmin - h = ymax - ymin - - bs = round(((w + h) / 2) * 0.02) - if bs < 1: - bs = 1 - - rect_left = RectItem(QRectF(xmin, ymin, bs, h)) - rect_top = RectItem(QRectF(xmin+bs, ymin, w-bs-bs, bs)) - rect_right = RectItem(QRectF(xmax-bs, ymin, bs, h)) - rect_bottom = RectItem(QRectF(xmin+bs, ymax-bs, w-bs-bs, bs)) - self.selectingRects.append(rect_left) - self.selectingRects.append(rect_top) - self.selectingRects.append(rect_right) - self.selectingRects.append(rect_bottom) - - self.addItem(rect_left) - self.addItem(rect_top) - self.addItem(rect_right) - self.addItem(rect_bottom) - else: - for rect in self.selectingRects: - self.removeItem(rect) - self.selectingRects = [] - - self._selected = selected - - def addImageItem(self, imageItem): - self.imageItem = imageItem - if imageItem is None: - return - - self.setupContextMenu() - self.addItem(imageItem) - - def setupContextMenu(self): - shuffleCmapAction = QAction('Shuffle colormap', self.vb.menu) - shuffleCmapAction.triggered.connect(self.shuffleColormap) - self.vb.menu.addAction(shuffleCmapAction) - - self.resetCmapAction = QAction('Reset colormap', self.vb.menu) - self.resetCmapAction.triggered.connect(self.resetColormap) - self.vb.menu.addAction(self.resetCmapAction) - self.resetCmapAction.setDisabled(True) - - def shuffleColormap(self): - N = self.imageItem._numLevels - colors = self.imageItem.lut/255 - cmap = LinearSegmentedColormap.from_list('shuffled', colors, N=N) - lut = plot.matplotlib_cmap_to_lut(cmap, n_colors=N) - if not self.resetCmapAction.isEnabled(): - self._defaultLut = lut.copy() - bkgrColor = lut[0].copy() - np.random.shuffle(lut) - lut[0] = bkgrColor - self.imageItem.setLookupTable(lut) - self.imageItem.update() - self.resetCmapAction.setDisabled(False) - - def resetColormap(self): - self.imageItem.setLookupTable(self._defaultLut) - - def autoBtnClicked(self): - self.autoRange() - - def autoRange(self): - self.vb.autoRange() - self.autoBtn.hide() - -class _ImShowImageItem(pg.ImageItem): - sigDataHover = Signal(str) - sigHoverEvent = Signal(object, object) - sigMousePressEvent = Signal(object, object) - - def __init__(self, idx) -> None: - super().__init__() - self._idx = idx - self._cursors = [] - self._autoLevels = True - - def _getHoverImageValue(self, xdata, ydata): - try: - value = self.image[ydata, xdata] - return value - except Exception as err: - return - - def setAutoLevels(self, autoLevels): - self._autoLevels = autoLevels - - def mousePressEvent(self, event): - self.sigMousePressEvent.emit(self, event) - super().mousePressEvent(event) - - def setOtherImagesCursors(self, cursors): - self._cursors = cursors - - def clearCursors(self): - for p, cursor in enumerate(self._cursors): - if p == self._idx: - continue - - cursor.setData([], []) - - def setImage(self, *args, **kwargs): - if 'autoLevels' not in kwargs: - kwargs['autoLevels'] = self._autoLevels - - super().setImage(*args, **kwargs) - if not args: - return - - if not kwargs['autoLevels']: - return - - image = args[0] - self._imageMax = image.max() - self._imageMin = image.min() - self._numLevels = self._imageMax - self._imageMin - - def hoverEvent(self, event): - self.sigHoverEvent.emit(self, event) - - if event.isExit(): - self.clearCursors() - self.sigDataHover.emit('') - return - - x, y = event.pos() - xdata, ydata = int(x), int(y) - value = self._getHoverImageValue(xdata, ydata) - if value is None: - self.clearCursors() - self.sigDataHover.emit('') - return - - try: - self.sigDataHover.emit( - f'x={xdata}, y={ydata}, {value = :.4f}' - ) - except Exception as e: - self.sigDataHover.emit( - f'x={xdata}, y={ydata}, {[val for val in value]}' - ) - - for p, cursor in enumerate(self._cursors): - if p == self._idx: - continue - - cursor.setData([x], [y]) - -class ImShow(QBaseWindow): - def __init__( - self, - parent=None, - link_scrollbars=True, - infer_rgb=True, - figure_title='', - selectable_images=False - ): - super().__init__(parent=parent) - self._linkedScrollbars = link_scrollbars - self._infer_rgb = infer_rgb - self._figure_title = figure_title - self._selectable_images = True - self.selected_idx = None - - self._autoLevels = True - - self.textItems = [] - self.group_to_idx_mapper = {'': 0} - - def _getGraphicsScrollbar(self, idx, image, imageItem, maximum): - proxy = QGraphicsProxyWidget(imageItem) - scrollbar = ScrollBarWithNumericControl( - orientation=Qt.Horizontal, add_max_proj_button=True - ) - scrollbar.sigValueChanged.connect(self.OnScrollbarValueChanged) - scrollbar.sigMaxProjToggled.connect(self.onMaxProjToggled) - scrollbar.idx = idx - scrollbar.image = image - scrollbar.imageItem = imageItem - scrollbar.setMaximum(maximum) - proxy.setWidget(scrollbar) - proxy.scrollbar = scrollbar - return proxy - - def OnScrollbarValueChanged(self, value): - scrollbar = self.sender() - imageItem = scrollbar.imageItem - img = self._get2Dimg(imageItem, scrollbar.image) - imageItem.setImage(img) # , autoLevels=self._autoLevels) - - overlayLab = self._get2DlabOverlay(imageItem) - if overlayLab is not None: - imageItem.labImageItem.setImage(overlayLab, autoLevels=False) - - self.setPointsVisible(imageItem) - - self.updateIDs() - - if not self._linkedScrollbars: - return - if len(self.ImageItems) == 1: - return - - self._linkedScrollbars = False - try: - idx = scrollbar.idx - for otherImageItem in self.ImageItems: - if otherImageItem.gridPos == imageItem.gridPos: - continue - if otherImageItem.image.shape != imageItem.image.shape: - continue - for otherScrollbar in otherImageItem.ScrollBars: - if otherScrollbar.idx != idx: - continue - otherScrollbar.setValue(scrollbar.value()) - except Exception as e: - pass - finally: - self._linkedScrollbars = True - - def _get2Dimg(self, imageItem, image): - for scrollbar in imageItem.ScrollBars: - if scrollbar.maxProjCheckbox.isChecked(): - image = image.max(axis=0) - else: - image = image[scrollbar.value()] - return image - - def _get2DlabOverlay(self, imageItem): - try: - lab = imageItem.lab - except Exception as err: - return - - for scrollbar in imageItem.ScrollBars: - if scrollbar.maxProjCheckbox.isChecked(): - lab = lab.max(axis=0) - else: - lab = lab[scrollbar.value()] - - return lab - - def isObjVisible(self, obj, imageItem): - if len(obj.centroid) == 2: - return True - - z_scrollbar = imageItem.ScrollBars[-1] - if z_scrollbar.maxProjCheckbox.isChecked(): - return True - - z_slice = z_scrollbar.value() - min_z, min_y, min_x, max_z, max_y, max_x = obj.bbox - if z_slice >= min_z and z_slice < max_z: - return True - - return False - - def onMaxProjToggled(self, checked, scrollbar): - imageItem = scrollbar.imageItem - img = self._get2Dimg(imageItem, scrollbar.image) - imageItem.setImage(img) # , autoLevels=self._autoLevels) - overlayLab = self._get2DlabOverlay(imageItem) - if overlayLab is not None: - imageItem.labImageItem.setImage(overlayLab, autoLevels=False) - self.setPointsVisible(imageItem) - if not self._linkedScrollbars: - return - if len(self.ImageItems) == 1: - return - - self._linkedScrollbars = False - try: - idx = scrollbar.idx - for otherImageItem in self.ImageItems: - if otherImageItem.gridPos == imageItem.gridPos: - continue - if otherImageItem.image.shape != imageItem.image.shape: - continue - for otherScrollbar in otherImageItem.ScrollBars: - if otherScrollbar.idx != idx: - continue - otherScrollbar.maxProjCheckbox.setChecked(checked) - except Exception as e: - pass - finally: - self._linkedScrollbars = True - - self.updateIDs() - - def setPointsVisible(self, imageItem): - if not hasattr(imageItem, 'pointsItems'): - return - - first_coord = imageItem.ScrollBars[0].value() - isMaxProj = imageItem.ScrollBars[0].maxProjCheckbox.isChecked() - for pointsItems in imageItem.pointsItems.values(): - for p, plotItem in enumerate(pointsItems): - plotItem.setVisible((isMaxProj) or (p == first_coord)) - - def setupStatusBar(self): - self.statusbar = self.statusBar() - self.wcLabel = QLabel(f"") - self.statusbar.addPermanentWidget(self.wcLabel) - - def setupMainLayout(self): - self._layout = QHBoxLayout() - self._container = QWidget() - self._container.setLayout(self._layout) - self.setCentralWidget(self._container) - - def setupGraphicLayout( - self, *images, hide_axes=True, max_ncols=4, color_scheme='light' - ): - self.graphicLayout = pg.GraphicsLayoutWidget() - self._colorScheme = color_scheme - - # Set a light background - if color_scheme == 'light': - self.graphicLayout.setBackground((235, 235, 235)) - else: - self.graphicLayout.setBackground((30, 30, 30)) - - ncells = max_ncols * ceil(len(images)/max_ncols) - - nrows = ncells // max_ncols - nrows = nrows if nrows > 0 else 1 - ncols = max_ncols if len(images) > max_ncols else len(images) - - if color_scheme == 'light': - color = 'black' - else: - color = 'white' - - self.titleLabel = pg.LabelItem( - justify='center', color=color, size='14pt' - ) - self.titleLabel.setText(self._figure_title) - self.graphicLayout.addItem(self.titleLabel, row=0, col=0, colspan=ncols) - start_row = 1 - - # Check if additional rows are needed for the scrollbars - max_ndim = max([image.ndim for image in images]) - if max_ndim > 4: - raise TypeError('One or more of the images have more than 4 dimensions.') - if max_ndim == 4: - rows_range = range(0, (nrows-1)*3+1, 3) - elif max_ndim == 3: - rows_range = range(0, (nrows-1)*2+1, 2) - else: - rows_range = range(nrows) - - self.PlotItems = [] - self.ImageItems = [] - self.ScrollBars = [] - i = 0 - for r in rows_range: - row = r + start_row - for col in range(ncols): - try: - image = images[i] - except IndexError: - break - plotItem = ImShowPlotItem() - if hide_axes: - plotItem.hideAxis('bottom') - plotItem.hideAxis('left') - self.graphicLayout.addItem(plotItem, row=row, col=col) - plotItem.loc = (row, col) - self.PlotItems.append(plotItem) - - imageItem = _ImShowImageItem(i) - plotItem.addImageItem(imageItem) - imageItem.plot = plotItem - imageItem.sigHoverEvent.connect( - self.onImageItemHoverEvent - ) - imageItem.sigMousePressEvent.connect( - self.onImageItemMousePressEvent - ) - self.ImageItems.append(imageItem) - imageItem.gridPos = (row, col) - imageItem.ScrollBars = [] - - is_rgb = image.shape[-1] == 3 and self._infer_rgb - is_rgba = image.shape[-1] == 4 and self._infer_rgb - does_not_require_scrollbars = ( - image.ndim == 2 - or (image.ndim == 3 and (is_rgb or is_rgba)) - ) - if does_not_require_scrollbars: - i += 1 - continue - - idx_image = 3 if (is_rgb or is_rgba) else 2 - for s in range(image.ndim-idx_image): - maximum = image.shape[s]-1 - scrollbarProxy = self._getGraphicsScrollbar( - s, image, imageItem, maximum - ) - self.graphicLayout.addItem( - scrollbarProxy, row=row+s+1, col=col - ) - imageItem.ScrollBars.append(scrollbarProxy.scrollbar) - - i += 1 - - self._layout.addWidget(self.graphicLayout) - - def onImageItemMousePressEvent(self, imageItem, event): - if not self._selectable_images: - return - - plotItem = imageItem.plot - if not plotItem.isSelected(): - return - - self.selected_idx = self.PlotItems.index(plotItem) - event.ignore() - self.close() - - def onImageItemHoverEvent(self, imageItem, event): - if not self._selectable_images: - return - - modifiers = QGuiApplication.keyboardModifiers() - isCtrl = modifiers == Qt.ControlModifier - plotItem = imageItem.plot - Y, X = imageItem.image.shape[:2] - plotItem.setSelected( - isCtrl and not event.isExit(), - xlim=(0, X), - ylim=(0, Y) - ) - - def movePlotItem(self, title): - combobox = self.sender() - plotItem = combobox.plotItem - row, col = plotItem.loc - - otherPlotItemIdx = combobox.titles.index(title) - otherPlotItem = self.PlotItems[otherPlotItemIdx] - other_row, other_col = otherPlotItem.loc - - self.graphicLayout.removeItem(plotItem) - self.graphicLayout.removeItem(otherPlotItem) - self.graphicLayout.addItem(otherPlotItem, row=row, col=col) - self.graphicLayout.addItem(plotItem, row=other_row, col=other_col) - - combobox.blockSignals(True) - combobox.setCurrentText(combobox.default_text) - combobox.blockSignals(False) - - plotItemIdx = combobox.titles.index(combobox.default_text) - - otherPlotItem.loc = (row, col) - plotItem.loc = (other_row, other_col) - - def setupTitles(self, *titles): - for plotItem, title in zip(self.PlotItems, titles): - combobox = ComboBox() - combobox.default_text = title - combobox.titles = list(titles) - combobox.addItems(titles) - combobox.setMaximumWidth(combobox.sizeHint().width()) - combobox.setCurrentText(title) - comboboxGraphicsItem = QGraphicsProxyWidget() - comboboxGraphicsItem.setWidget(combobox) - combobox.plotItem = plotItem - plotItem.setSelectableTitle(comboboxGraphicsItem) - combobox.currentTextChanged.connect(self.movePlotItem) - - # color = 'k' if self._colorScheme == 'light' else 'w' - # for plotItem, title in zip(self.PlotItems, titles): - # plotItem.setSelectableTitle(title, color=color) - - def updateStatusBarLabel(self, text): - self.wcLabel.setText(text) - - def autoRange(self): - for plot in self.PlotItems: - plot.autoRange() - - def showImages( - self, *images, - labels_overlays: np.ndarray | List[np.ndarray]=None, - luts=None, - labels_overlays_luts=None, - autoLevels=True, - autoLevelsOnScroll=False - ): - from .plot import matplotlib_cmap_to_lut - - images = [np.squeeze(img) for img in images] - self.luts = luts - self._autoLevels = autoLevels - self._autoLevelsOnScroll = autoLevelsOnScroll - for image in images: - if image.ndim > 5 or image.ndim < 2: - raise TypeError( - f'Input image has {image.ndim} dimensions. ' - 'Only 2-D, 3-D, and 4-D images are supported' - ) - - if isinstance(labels_overlays, np.ndarray): - labels_overlays = [labels_overlays] - - if isinstance(labels_overlays_luts, np.ndarray): - labels_overlays_luts = [labels_overlays_luts] - - if ( - labels_overlays_luts is not None - and labels_overlays is not None - and (len(labels_overlays_luts) != len(labels_overlays)) - ): - raise TypeError( - f'Number of lables_overlays_luts is {len(labels_overlays_luts)}, ' - f'while number of labels_overaly is {len(labels_overlays)}. ' - 'Pass `None` if you want to use default lut for the labels_overlays.' - ) - - if labels_overlays is not None and (len(labels_overlays) != len(images)): - raise TypeError( - f'Number of images is {len(images)}, ' - f'while number of labels_overaly is {len(labels_overlays)}. ' - 'Pass `None` if you do not need overlaid labeles.' - ) - - for i, (image, imageItem) in enumerate(zip(images, self.ImageItems)): - if luts is not None: - _autoLevels = autoLevels - lut = luts[i] - if not autoLevels and lut is not None: - imageItem.setLevels([0, len(lut)]) - else: - _autoLevels = True - if lut is None: - lut = matplotlib_cmap_to_lut('viridis') - imageItem.setLookupTable(lut) - else: - _autoLevels = True - - is_rgb = image.shape[-1] == 3 and self._infer_rgb - is_rgba = image.shape[-1] == 4 and self._infer_rgb - does_not_require_scrollbars = ( - image.ndim == 2 - or (image.ndim == 3 and (is_rgb or is_rgba)) - ) - - if does_not_require_scrollbars: - imageItem.setAutoLevels(_autoLevels) - imageItem.setImage(image) - else: - if not self._autoLevelsOnScroll and not _autoLevels: - imageItem.setAutoLevels(False) - imageItem.setLevels([image.min(), image.max()]) - for scrollbar in imageItem.ScrollBars: - scrollbar.setValue(int(scrollbar.maximum()/2)) - - imageItem.sigDataHover.connect(self.updateStatusBarLabel) - - if labels_overlays is None: - continue - - lab_overlay = labels_overlays[i] - if lab_overlay is None: - continue - - if lab_overlay.shape != image.shape: - raise TypeError( - f'`lab_overlay` at index {i} has shape ' - f'{lab_overlay.shape} which is different ' - f'from image shape {image.shape}. ' - 'The image and the `lab_overlay` must ' - 'have the same shape.' - ) - - plot = imageItem.plot - labImageItem = pg.ImageItem() - labImageItem.setOpacity(0.4) - plot.addImageItem(labImageItem) - - if labels_overlays_luts is not None: - labels_overlays_lut = labels_overlays_luts[i] - else: - labels_overlays_lut = self._getDefaultLabelsOverlayLut( - lab_overlay - ) - - labImageItem.setLookupTable(labels_overlays_lut) - labImageItem.setLevels([0, len(labels_overlays_lut)]) - - imageItem.lab = lab_overlay - imageItem.labImageItem = labImageItem - - overlayLab = self._get2DlabOverlay(imageItem) - labImageItem.setImage(overlayLab, autoLevels=False) - - # Share axis between images with same X, Y shape - all_shapes = [image.shape[-2:] for image in images] - unique_shapes = set(all_shapes) - shame_shape_plots = [] - for unique_shape in unique_shapes: - plots = [ - self.PlotItems[i] for i, shape in enumerate(all_shapes) - if shape==unique_shape - ] - shame_shape_plots.append(plots) - - for plots in shame_shape_plots: - for plot in plots: - plot.vb.setYLink(plots[0].vb) - plot.vb.setXLink(plots[0].vb) - - def _getDefaultLabelsOverlayLut(self, lab_overlay): - IDs = [obj.label for obj in skimage.measure.regionprops(lab_overlay)] - n_objs = len(IDs) - lut = np.zeros((n_objs+1, 4), dtype=np.uint8) - rgbas = colors.plt_colormap_to_pg_lut('tab20', ncolors=n_objs) - np.random.shuffle(rgbas) - lut[1:] = rgbas - return lut - - def _createPointsScatterItem(self, xx, yy, group, colors=None, data=None): - if colors is None: - cmap = matplotlib.colormaps['jet_r'] - idx = self.group_to_idx_mapper[group] - r, g, b = [round(c*255) for c in cmap(idx)][:3] - brush = pg.mkBrush(color=(r,g,b,100)) - pen = pg.mkPen(width=2, color=(r,g,b)) - hoverBrush = pg.mkBrush((r,g,b,200)) - else: - brush = [] - pen = [] - hoverBrush = None - for color in colors: - rgb = matplotlib.colors.to_rgb(color) - rgb = [round(c*255) for c in rgb] - _brush = pg.mkBrush(color=(*rgb,100)) - _pen = pg.mkPen(width=2, color=rgb) - brush.append(_brush) - pen.append(_pen) - - item = pg.ScatterPlotItem( - xx, yy, symbol='o', pxMode=False, size=3, - brush=brush, pen=pen, - hoverable=True, hoverBrush=hoverBrush, - data=data - ) - return item - - def drawPointsFromDf( - self, - points_df: pd.DataFrame | List[pd.DataFrame], - points_groups=None - ): - if not isinstance(points_df, (list, tuple)): - points_df = [points_df]*len(self.PlotItems) - - for p, df in enumerate(points_df): - if isinstance(points_groups, str): - points_groups = [points_groups] - - if points_groups is None: - grouped = [('', df)] - groups = [''] - else: - grouped = df.groupby(points_groups) - groups = grouped.groups.keys() - - idxs_space = np.linspace(0, 1, len(groups)) - self.group_to_idx_mapper = dict(zip(groups, idxs_space)) - - for group, df in grouped: - yy = df['y'].values - xx = df['x'].values - points_coords = np.column_stack((yy, xx)) - if 'z' in df.columns: - zz = df['z'].values - points_coords = np.column_stack((zz, points_coords)) - if len(group) == 1: - group = group[0] - - colors = None - if 'color' in df.columns: - colors = df['color'].values - - data = None - if 'data' in df.columns: - data = df['data'].values - - self.drawPoints( - points_coords, - colors=colors, - group=group, - idx=p, - data=data - ) - - def drawPoints( - self, - points_coords: np.ndarray, - group='', - idx=None, - colors=None, - data=None, - ): - offset = 0.5 if np.issubdtype(points_coords.dtype, np.integer) else 0 - n_dim = points_coords.shape[1] - - if idx is not None: - PlotItems = [self.PlotItems[idx]] - ImageItems = [self.ImageItems[idx]] - else: - PlotItems = self.PlotItems - ImageItems = self.ImageItems - - if n_dim == 2: - if data is None: - data = group - - zz = [0]*len(points_coords) - self.points_coords = np.column_stack((zz, points_coords)) - for p, plotItem in enumerate(PlotItems): - imageItem = ImageItems[p] - xx = points_coords[:, 1] + offset - yy = points_coords[:, 0] + offset - pointsItem = self._createPointsScatterItem( - xx, yy, group, data=data, colors=colors - ) - pointsItem.z = 0 - plotItem.addItem(pointsItem) - imageItem.pointsItems = {group: [pointsItem]} - elif n_dim == 3: - self.points_coords = points_coords - for p, plotItem in enumerate(PlotItems): - imageItem = ImageItems[p] - imageItem.pointsItems = defaultdict(list) - scrollbar = imageItem.ScrollBars[0] - for first_coord in range(scrollbar.maximum()+1): - coords_idx = np.nonzero(points_coords[:,0] == first_coord) - coords = points_coords[coords_idx] - if colors is None: - _colors = None - else: - _colors = np.asarray(colors)[coords_idx] - if len(_colors) == 0: - _colors = None - - _data = group - if data is not None: - _data = data[coords_idx] - if len(_data) == 0: - _data = group - - xx = coords[:, 2] + offset - yy = coords[:, 1] + offset - pointsItem = self._createPointsScatterItem( - xx, yy, group, data=_data, colors=_colors - ) - pointsItem.z = first_coord - plotItem.addItem(pointsItem) - pointsItem.setVisible(False) - imageItem.pointsItems[group].append(pointsItem) - self.setPointsVisible(imageItem) - - def setupDuplicatedCursors(self): - self.cursors = [] - for p, plotItem in enumerate(self.PlotItems): - cursor = pg.ScatterPlotItem( - symbol='+', pxMode=True, pen=pg.mkPen('k', width=1), - brush=pg.mkBrush('w'), size=16, tip=None - ) - self.cursors.append(cursor) - plotItem.addItem(cursor) - - for imageItem in self.ImageItems: - imageItem.setOtherImagesCursors(self.cursors) - - def setPointsData(self, points_data): - points_df = pd.DataFrame({ - 'z': self.points_coords[:, 0], - 'y': self.points_coords[:, 1], - 'x': self.points_coords[:, 2] - }) - if isinstance(points_data, pd.Series): - points_df[points_data.name] = points_data.values - elif isinstance(points_data, pd.DataFrame): - points_df = points_df.join(points_data) - elif isinstance(points_data, np.ndarray): - if points_data.ndim == 1: - points_data = points_data[np.newaxis] - else: - points_data = points_data.T - for i, values in enumerate(points_data): - points_df[f'col_{i}'] = values - - self.points_df = points_df.set_index(['z', 'y', 'x']).sort_index() - - for p, plotItem in enumerate(self.PlotItems): - imageItem = self.ImageItems[p] - for pointsItems in imageItem.pointsItems.values(): - for pointsItem in pointsItems: - pointsItem.sigClicked.connect(self.pointsClicked) - - def pointsClicked(self, item, points, event): - point = points[0] - x, y = point.pos() - coords = (item.z, int(y), int(x)) - point_data = self.points_df.loc[[coords]] - now = datetime.datetime.now().strftime('%H:%M:%S') - print('*'*60) - print(f'Point clicked at {now}. Data:') - print('-'*60) - print(point_data) - print('') - print('*'*60) - - def annotateObjectIDs(self, annotate_labels_idxs=None, init=False): - if init: - self.annotate_labels_idxs = annotate_labels_idxs - self.textItems = [{} for _ in self.PlotItems] - if self.annotate_labels_idxs is None: - return - for i, plotItem in enumerate(self.PlotItems): - if i not in self.annotate_labels_idxs: - continue - plotTextItems = self.textItems[i] - imageItem = self.ImageItems[i] - try: - if init: - # 3D labels (if 3D) - lab = imageItem.lab - else: - lab = imageItem.labImageItem.image - except Exception as err: - lab = imageItem.image - - rp = skimage.measure.regionprops(lab) - for obj in rp: - textItem = plotTextItems.get(obj.label) - yc, xc = obj.centroid[-2:] - if textItem is None: - textItem = pg.TextItem( - text='', anchor=(0.5,0.5), color='r' - ) - plotItem.addItem(textItem) - plotTextItems[obj.label] = textItem - - if self.isObjVisible(obj, imageItem): - text = str(obj.label) - else: - text = '' - - textItem.setText(text) - textItem.setPos(xc, yc) - - # plotItem.enableAutoRange() - - def clearLabels(self): - for textItems in self.textItems: - for textItem in textItems.values(): - textItem.setText('') - - def updateIDs(self): - self.clearLabels() - try: - self.annotateObjectIDs( - annotate_labels_idxs=self.annotate_labels_idxs - ) - except Exception as err: - pass - - def show(self, block=False, screenToWindowRatio=None): - super().show(block=block) - if screenToWindowRatio is None: - return - screenGeometry = self.screen().geometry() - screenWidth = screenGeometry.width() - screenHeight = screenGeometry.height() - finalWidth = int(screenToWindowRatio*screenWidth) - finalHeight = int(screenToWindowRatio*screenHeight) - screenTop = screenGeometry.top() - screenLeft = screenGeometry.left() - xc, yc = screenLeft + screenWidth/2, screenTop + screenHeight/2 - winLeft = int(xc - finalWidth/2) - winTop = int(yc - finalHeight/2) - self.setGeometry(winLeft, winTop, finalWidth, finalHeight) - - def run(self, block=False, showMaximised=False, screenToWindowRatio=None): - if showMaximised: - self.showMaximized() - else: - self.show(screenToWindowRatio=screenToWindowRatio) - QTimer.singleShot(100, self.autoRange) - - if block: - self.exec_() - - def resizeEvent(self, event) -> None: - self.PlotItems[0].autoRange() - return super().resizeEvent(event) - -class FeatureSelectorButton(QPushButton): - def __init__(self, text, parent=None, alignment=''): - super().__init__(text, parent=parent) - self._isFeatureSet = False - self._alignment = alignment - self.setCursor(Qt.PointingHandCursor) - - def setFeatureText(self, text): - self.setText(text) - self.setFlat(True) - self._isFeatureSet = True - if self._alignment: - self.setStyleSheet(f'text-align:{self._alignment};') - - def enterEvent(self, event) -> None: - if self._isFeatureSet: - self.setFlat(False) - return super().enterEvent(event) - - def leaveEvent(self, event) -> None: - if self._isFeatureSet: - self.setFlat(True) - self.update() - return super().leaveEvent(event) - - def setSizeLongestText(self, longestText): - currentText = self.text() - self.setText(longestText) - w, h = self.sizeHint().width(), self.sizeHint().height() - self.setMinimumWidth(w+10) - # self.setMinimumHeight(h+5) - self.setText(currentText) - -class CheckableSpinBoxWidgets: - def __init__(self, isFloat=True): - if isFloat: - self.spinbox = FloatLineEdit() - else: - self.spinbox = SpinBox() - self.checkbox = QCheckBox('Activate') - self.spinbox.setEnabled(False) - self.checkbox.toggled.connect(self.spinbox.setEnabled) - - def value(self): - if not self.checkbox.isChecked(): - return - return self.spinbox.value() - -class Label(QLabel): - def __init__(self, parent=None, force_html=False): - super().__init__(parent) - self._force_html = force_html - - def setText(self, text): - if self._force_html: - text = html_utils.paragraph(text) - super().setText(text) - - -class LabelItem(pg.LabelItem): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def bbox(self): - xl, yl = self.pos().x(), self.pos().y() - wl, hl = self.itemRect().width(), self.itemRect().height() - return yl, xl, yl+hl, xl+wl - - def setBold(self, bold=True): - self.origPos = self.pos() - self.setText(self.text, bold=bold) - self.setPos(self.origPos) - -class ScaleBar(QGraphicsObject): - sigEditProperties = Signal(object) - sigRemove = Signal(object) - - def __init__(self, imageShape, viewRange, parent=None): - super().__init__(parent) - self.SizeY, self.SizeX = imageShape - self.updateViewRange(viewRange) - self.plotItem = PlotCurveItem() - self.labelItem = LabelItem() - self._x_pad = 5 - self._y_pad = 3 - self._highlighted = False - self._parent = parent - self.clicked = False - self.createContextMenu() - - def updateViewRange(self, viewRange): - xRange, yRange = viewRange - x0, x1 = xRange - y0, y1 = yRange - if x0 < 0: - x0 = 0 - - if x1 > self.SizeX: - x1 = self.SizeX - - if y0 < 0: - y0 = 0 - - if y1 > self.SizeY: - y1 = self.SizeY - - self.xmax = x1 - self.xmin = x0 - - self.ymax = y1 - self.ymin = y0 - - def createContextMenu(self): - self.contextMenu = QMenu() - action = QAction('Edit properties...', self.contextMenu) - action.triggered.connect(self.emitEditProperties) - self.contextMenu.addSeparator() - action = QAction('Remove', self.contextMenu) - action.triggered.connect(self.emitRemove) - self.contextMenu.addAction(action) - - def emitEditProperties(self): - self.setHighlighted(False) - self.sigEditProperties.emit(self.properties()) - - def emitRemove(self): - self.sigRemove.emit(self) - - def isHighlighted(self): - return self._highlighted - - def setHighlighted(self, highlighted): - if self._highlighted and highlighted: - return - - if not self._highlighted and not highlighted: - return - - pen = self.highlightPen if highlighted else self.pen - self.labelItem.setBold(bold=highlighted) - self.plotItem.setPen(pen) - - self._highlighted = highlighted - - def showContextMenu(self, x, y): - self.contextMenu.popup(QPoint(int(x), int(y))) - - def properties(self): - properties = { - 'thickness': self._thickness, - 'length_pixel': self._length, - 'length_unit': self._length_unit, - 'is_text_visible': self._is_text_visible, - 'color': self._color, - 'loc': self._loc, - 'font_size': float(self._font_size[:-2]), - 'unit': self._unit, - 'num_decimals': self._num_decimals, - 'move_with_zoom': self._move_with_zoom, - } - return properties - - def move(self, xm, ym): - self._loc = 'Custom' - - Dy = ym - self.yc - Dx = xm - self.xc - - x0 = self.x0c + Dx - x1 = x0 + self._length - y0 = y1 = self.y0c + Dy - self.plotItem.setData([x0, x1], [y0, y1]) - self.setTextPos() - - def paint(self, painter, option, widget): - pass - - def boundingRect(self): - ymin, xmin, ymax, xmax = self.bbox() - return QRectF(xmin, ymin, xmax-xmin, ymax-ymin) - - def setLocationProperty(self, loc: str): - self._loc = loc - - def setMoveWithZoomProperty(self, move_with_zoom): - self._move_with_zoom = move_with_zoom - - def setProperties( - self, - length_pixel, - length_unit, - thickness=3, - color='w', - is_text_visible=True, - loc='top-left', - font_size=12, - unit='', - num_decimals=0, - move_with_zoom=False - ): - self._loc = loc - self._color = color - self._length = length_pixel - self._length_unit = length_unit - self._is_text_visible = is_text_visible - self._font_size = f'{font_size}px' - self._unit = unit - self._num_decimals = num_decimals - self._move_with_zoom = move_with_zoom - self._thickness = thickness - self.pen = pg.mkPen(width=thickness, color=color, cosmetic=False) - self.highlightPen = pg.mkPen( - width=thickness+2, color=color, cosmetic=False - ) - self.pen.setCapStyle(Qt.PenCapStyle.FlatCap) - self.highlightPen.setCapStyle(Qt.PenCapStyle.FlatCap) - self.plotItem.setPen(self.pen) - - def updatePhysicalLength(self, PhysicalSizeX): - length_unit = self._length_unit - unit = self._unit - length_um = _core.convert_length(length_unit, unit, 'μm') - length_pixel = length_um/PhysicalSizeX - self._length = length_pixel - self.update() - - def addToAxis(self, ax): - ax.addItem(self.plotItem) - ax.addItem(self.labelItem) - - def setText(self): - if self._is_text_visible: - number = round(self._length_unit, self._num_decimals) - if self._num_decimals == 0: - number = int(number) - text = f'{number} {self._unit}' - else: - text = '' - self.labelItem.setText( - text, color=self._color, size=self._font_size - ) - - def setTextPos(self): - xx, yy = self.plotItem.getData() - x0 = xx[0] - y0 = yy[0] - xc = x0 + self._length/2 - wl = self.labelItem.itemRect().width() - hl = self.labelItem.itemRect().height() - xl = xc-wl/2 - yt = y0-hl - self.labelItem.setPos(xl, yt) - - def updatePosViewRangeChanged(self, viewRange): - if self._loc == 'custom': - xx, yy = self.plotItem.getData() - x0p = xx[0] - y0p = yy[0] - xcp = x0p + self._length/2 - hl = self.labelItem.itemRect().height() - ycp = y0p - hl/2 - x0 = self.xmin - y0 = self.ymin - x_range = self.xmax - x0 - y_range = self.ymax - y0 - Dx_perc = (xcp - x0)/x_range - Dy_perc = (ycp - y0)/y_range - - self.updateViewRange(viewRange) - - X0 = self.xmin - Y0 = self.ymin - - X_range = self.xmax - X0 - Y_range = self.ymax - Y0 - - Xcp = X0 + (Dx_perc*X_range) - Ycp = Y0 + (Dy_perc*Y_range) - X0p = Xcp - (self._length/2) - Y0p = Ycp + (hl/2) - - X1p = X0p + self._length - Y1p = Y0p - - self.plotItem.setData([X0p, X1p], [Y0p, Y1p]) - else: - self.updateViewRange(viewRange) - self.update() - - def getStartXCoordFromLoc(self, loc): - if loc == 'custom': - xx, yy = self.plotItem.getData() - x0 = xx[0] - return x0 - - self.setText() - wl = self.labelItem.itemRect().width() - if loc.find('left') != -1: - x0 = self._x_pad + self.xmin - xc = x0 + self._length/2 - xl = xc-wl/2 - if xl < x0: - # Text is larger than line --> move line to the right - x0 = self._x_pad + abs(xl-self._x_pad) - else: - x0 = self.xmax - self._length - self._x_pad - xc = x0 + self._length/2 - x1 = x0 + self._length - xr = xc+wl/2 - if xr > x1: - # Text is larger than line --> move line to the left - delta_overshoot = xr - x1 - x0 = x0 - delta_overshoot - return x0 - - def getStartYCoordFromLoc(self, loc): - if loc == 'custom': - xx, yy = self.plotItem.getData() - y0 = yy[0] - return y0 - - self.setText() - textHeight = self.labelItem.itemRect().height() - if loc.find('top') != -1: - return textHeight + self._y_pad + self.ymin - else: - return self.ymax - self._y_pad - self._thickness - - def update(self): - x0 = self.getStartXCoordFromLoc(self._loc) # + self._thickness/2 - y0 = self.getStartYCoordFromLoc(self._loc) - - x1 = x0 + self._length # - self._thickness/2 - self.plotItem.setData([x0, x1], [y0, y0]) - - self.setText() - self.setTextPos() - - def draw(self, length_pixel, length_unit, **kwargs): - self.setProperties(length_pixel, length_unit, **kwargs) - self.update() - - def bbox(self): - y_line_min, x_line_min, y_line_max, x_line_max = self.plotItem.bbox() - y_lab_min, x_lab_min, y_lab_max, x_lab_max = self.labelItem.bbox() - ymin = min(y_line_min, y_lab_min) - xmin = min(x_line_min, x_lab_min) - ymax = max(y_line_max, y_lab_max) - xmax = max(x_line_max, x_lab_max) - return ymin, xmin, ymax, xmax - - def mousePressed(self, x, y): - self.clicked = True - self.xc, self.yc = x, y - xx, yy = self.plotItem.getData() - self.x0c = xx[0] - self.y0c = yy[0] - - def removeFromAxis(self, ax): - ax.removeItem(self.labelItem) - ax.removeItem(self.plotItem) - -class ComboBox(QComboBox): - sigTextChanged = Signal(str) - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._previousText = None - self._valueChanged = False - self.currentTextChanged.connect(self.emitTextChanged) - self.installEventFilter(self) - - def eventFilter(self, object, event) -> bool: - if object == self and event.type() == QEvent.Type.Wheel: - # Forward event to parent so QScrollArea can scroll - QApplication.sendEvent(self.parent(), event) - return True # Consume for the combo itself - - return super().eventFilter(object, event) - - def text(self): - return self.currentText() - - def emitTextChanged(self, text): - self._valueChanged = True - self.sigTextChanged.emit(text) - - def mousePressEvent(self, event): - self._previousText = self.currentText() - super().mousePressEvent(event) - - def previousText(self): - return self._previousText - - def addItems(self, items): - super().addItems(items) - self._previousText = items[0] - - def itemsText(self): - return [self.itemText(i) for i in range(self.count())] - - def setCurrentIndex(self, idx): - itemsText = self.itemsText() - currentText = itemsText[idx] - self._valueChanged = currentText != self._previousText - self._previousText = self.currentText() - super().setCurrentIndex(idx) - - def setCurrentText(self, text): - currentText = text - self._valueChanged = currentText != self._previousText - self._previousText = self.currentText() - super().setCurrentText(text) - -class SetMeasurementsGroupBox(QGroupBox): - def __init__( - self, title, itemsText, checkable=True, itemsInfo=None, - lastSelection=None, itemsInfoUrls=None, parent=None - ): - super().__init__(parent) - - if itemsInfo is None: - itemsInfo = {} - - if itemsInfo is None: - itemsInfoUrls = {} - - highlightRgba = _palettes._highlight_rgba() - r, g, b, a = highlightRgba - self._highlightStylesheetColor = f'rgb({r}, {g}, {b})' - - self.setTitle(title) - self.setCheckable(checkable) - - mainLayout = QVBoxLayout() - - scrollArea = QScrollArea() - scrollArea.setWidgetResizable(True) - scrollAreaLayout = QVBoxLayout() - scrollAreaWidget = QWidget() - self.scrollAreaWidget = scrollAreaWidget - self.scrollAreaLayout = scrollAreaLayout - - self.checkboxes = {} - for text in itemsText: - rowLayout = QHBoxLayout() - infoText = itemsInfo.get(text) - infoUrl = itemsInfoUrls.get(text) - if infoText is not None or infoUrl is not None: - infoButton = infoPushButton() - infoButton.setCursor(Qt.WhatsThisCursor) - rowLayout.addWidget(infoButton) - - if infoText is not None: - infoButton.itemText = text - infoButton.infoText = infoText - infoButton.clicked.connect(self.showInfo) - - if infoUrl is not None: - infoButton.itemText = text - infoButton.infoUrl = infoUrl - infoButton.clicked.connect(self.openInfoUrl) - - checkbox = QCheckBox(text) - checkbox.setParent(self.scrollAreaWidget) - checkbox.setChecked(True) - rowLayout.addWidget(checkbox) - rowLayout.addStretch(1) - - self.checkboxes[text] = checkbox - - scrollAreaLayout.addLayout(rowLayout) - - scrollAreaLayout.addStretch(1) - - scrollAreaWidget.setLayout(scrollAreaLayout) - scrollArea.setWidget(scrollAreaWidget) - self.scrollArea = scrollArea - - buttonsLayout = QHBoxLayout() - self.selectAllButton = selectAllPushButton() - self.selectAllButton.sigClicked.connect(self.setCheckedAll) - - buttonsLayout.addStretch(1) - buttonsLayout.addWidget(self.selectAllButton) - self.buttonsLayout = buttonsLayout - - if lastSelection is not None: - self.lastSelection = lastSelection - self.loadLastSelButton = reloadPushButton( - ' Load last selection... ' - ) - self.loadLastSelButton.clicked.connect(self.loadLastSelection) - buttonsLayout.addWidget(self.loadLastSelButton) - - mainLayout.addWidget(scrollArea) - mainLayout.addSpacing(10) - mainLayout.addLayout(buttonsLayout) - - self.setLayout(mainLayout) - - def openInfoUrl(self): - url = self.sender().infoUrl - QDesktopServices.openUrl(QUrl(url)) - # import webbrowser - # url = self.sender().infoUrl - # webbrowser.open(url) - - def getWidthNoScrollBarNeeded(self): - width = ( - self.scrollArea.verticalScrollBar().sizeHint().width() - # self.scrollAreaLayout.contentsRect().width() - + self.scrollAreaWidget.sizeHint().width() - + 30 - ) - buttonsWidth = 0 - for i in range(self.buttonsLayout.count()): - widget = self.buttonsLayout.itemAt(i).widget() - if not isinstance(widget, QPushButton): - continue - buttonsWidth += widget.sizeHint().width() + 16 - largerWidth = max(width, buttonsWidth) - return largerWidth - - def resizeWidthNoScrollBarNeeded(self): - width = self.getWidthNoScrollBarNeeded() - self.setMinimumWidth(width) - # self.setFixedWidth(width) - - def loadLastSelection(self): - for text, checkbox in self.checkboxes.items(): - checked = self.lastSelection.get(text, False) - checkbox.setChecked(checked) - - def showInfo(self): - infoText = self.sender().infoText - itemText = self.sender().itemText - - title = f'{itemText} description' - msg = myMessageBox() - msg.setWidth(int(self.screen().size().width()/2)) - msg.information(self, title, infoText) - - def setCheckedAll(self, button, checked): - for checkbox in self.checkboxes.values(): - checkbox.setChecked(checked) - - def highlightCheckboxesFromSearchText(self, text): - for checkbox in self.checkboxes.values(): - if not text: - highlighted = False - else: - highlighted = checkbox.text().lower().find(text.lower()) != -1 - - self.setCheckboxHighlighted(highlighted, checkbox) - - def setCheckboxHighlighted(self, highlighted, checkbox): - if highlighted: - checkbox.setStyleSheet( - f'background: {self._highlightStylesheetColor}; color: black' - ) - self.scrollArea.ensureWidgetVisible(checkbox) - else: - checkbox.setStyleSheet('') - -class SearchLineEdit(QLineEdit): - def __init__(self, parent=None): - super().__init__(parent) - - self.initSearch() - self.setFocusPolicy(Qt.ClickFocus) - - def focusInEvent(self, event) -> None: - super().focusInEvent(event) - if super().text() == 'Search...': - self.setText('') - self.setStyleSheet('') - - def focusOutEvent(self, event) -> None: - super().focusOutEvent(event) - if not super().text(): - self.initSearch() - - def initSearch(self): - self.setText('Search...') - self.setStyleSheet('color: rgb(150, 150, 150)') - self.clearFocus() - - def text(self): - if super().text() == 'Search...': - return '' - return super().text() - -class ToolButtonTextIcon(rightClickToolButton): - def __init__(self, text='', parent=None): - super().__init__(parent=parent) - self._text = text - self._penColor = _palettes.text_pen_color() - - def setText(self, text): - self._text = text - self.update() - - def text(self): - return self._text - - def paintEvent(self, event): - QToolButton.paintEvent(self, event) - p = QPainter(self) - - pen = pg.mkPen(color=self._penColor, width=2) - p.setPen(pen) - - w, h = self.width(), self.height() - sf = 0.7 - rect_w = w*sf - rect_h = h*sf - x = (w-rect_w)/2 - y = (h-rect_h)/2 - rect = QRectF(x, y, rect_w, rect_h) - - font = p.font() - font.setBold(True) - font.setPixelSize(int(h/len(self._text))) - p.setFont(font) - - p.drawText(rect, Qt.AlignCenter, self._text) - p.end() - -class RulerPlotItem(pg.PlotDataItem): - def __init__(self, *args, **kwargs): - self.labelItem = pg.LabelItem() - super().__init__(*args, **kwargs) - - def setData(self, *args, lengthText='', **kwargs): - super().setData(*args, **kwargs) - self.labelItem.setText('') - if not lengthText: - return - self.setLengthText(lengthText) - - def setLengthText(self, lengthText): - xx, yy = self.getData() - x0, x1 = sorted(xx) - y0, y1 = sorted(yy) - xc = round(x0 + (x1-x0)/2) - yc = round(y0 + (y1-y0)/2) - self.labelItem.setText(lengthText, size='11px', color='r') - # xc = x0 + self._length/2 - wl = self.labelItem.itemRect().width() - hl = self.labelItem.itemRect().height() - xl = xc-wl/2 - yt = y0-hl - self.labelItem.setPos(xl, yt) - -class VectorLineEdit(QLineEdit): - valueChanged = Signal(object) - valueChangeFinished = Signal(object) - - def __init__(self, parent=None, initial=None): - super().__init__(parent) - - self._minimum = -np.inf - - float_re = float_regex() - vector_regex = fr'\(?\[?{float_re}(,\s?{float_re})+\)?\]?' - regex = fr'^{vector_regex}$|^{float_re}$' - self.validRegex = regex - - regExp = QRegularExpression(regex) - self.setValidator(QRegularExpressionValidator(regExp)) - self.setAlignment(Qt.AlignCenter) - - self.textChanged.connect(self.emitValueChanged) - self.editingFinished.connect(self.emitValueChangeFinished) - if initial is None: - self.setText('0.0') - - font = QFont() - font.setPixelSize(11) - self.setFont(font) - - def emitValueChangeFinished(self): - value = self.value() - self.textChanged.disconnect() - self.editingFinished.disconnect() - self.setValue(value) - self.textChanged.connect(self.emitValueChanged) - self.editingFinished.connect(self.emitValueChangeFinished) - - self.emitValueChanged(self.text(), signal=self.valueChangeFinished) - - def emitValueChanged(self, text, signal=None): - m = re.match(self.validRegex, text) - if m is None: - self.setStyleSheet(LINEEDIT_INVALID_ENTRY_STYLESHEET) - return - - if signal is None: - signal = self.valueChanged - - self.setStyleSheet('') - signal.emit(self.value()) - - def increaseValue(self, step): - value = self.value() - if isinstance(value, (float, int)): - value += step - else: - value = [val+step for val in value] - value = str(value).lstrip('[').rstrip(']') - self.setValue(value) - self.emitValueChangeFinished() - - def decreaseValue(self, step): - value = self.value() - if isinstance(value, (float, int)): - value -= step - else: - value = [val-step for val in value] - value = str(value).lstrip('[').rstrip(']') - self.setText(value) - self.emitValueChangeFinished() - - def setValue(self, value): - if isinstance(value, (float, int)): - if value < self._minimum: - value = self._minimum - else: - clipped = [] - for val in value: - if val < self._minimum: - val = self._minimum - clipped.append(val) - value = str(clipped).lstrip('[').rstrip(']') - self.setText(value) - - def setText(self, text): - super().setText(str(text)) - - def clipValue(self, val: float): - if val < self._minimum: - val = self._minimum - return val - - def value(self): - m = re.match(self.validRegex, self.text()) - if m is None: - return 0.0 - - try: - value = self.clipValue(float(self.text())) - return value - except Exception as e: - text = self.text() - text = text.replace('(', '') - text = text.replace(')', '') - text = text.replace('[', '') - text = text.replace(']', '') - values = text.split(',') - return [self.clipValue(float(value)) for value in values] - - def setMinimum(self, minimum): - self._minimum = float(minimum) - -class LatexLabel(QLabel): - def __init__(self, latexText, parent=None): - super().__init__(parent) - - latexText = latexText.replace('', '$') - if not latexText.startswith('$'): - latexText = f'${latexText}' - - if not latexText.endswith('$'): - latexText = f'{latexText}$' - - latexText = latexText.replace('
    ', '\n') - - pixmap = self.mathTex_to_QPixmap(latexText) - self.setPixmap(pixmap) - - def mathTex_to_QPixmap(self, mathTex): - #---- set up a mpl figure instance ---- - - fig = matplotlib.figure.Figure() - fig.patch.set_facecolor('none') - fig.set_canvas(FigureCanvasAgg(fig)) - renderer = fig.canvas.get_renderer() - - #---- plot the mathTex expression ---- - - ax = fig.add_axes([0, 0, 1, 1]) - ax.axis('off') - ax.patch.set_facecolor('none') - t = ax.text( - 0, 0, mathTex, - ha='left', va='bottom', - fontsize=13, - color=TEXT_COLOR - ) - - #---- fit figure size to text artist ---- - - fwidth, fheight = fig.get_size_inches() - fig_bbox = fig.get_window_extent(renderer) - - text_bbox = t.get_window_extent(renderer) - - tight_fwidth = text_bbox.width * fwidth / fig_bbox.width - tight_fheight = text_bbox.height * fheight / fig_bbox.height - - fig.set_size_inches(tight_fwidth, tight_fheight) - - #---- convert mpl figure to QPixmap ---- - - buf, size = fig.canvas.print_to_buffer() - qimage = QImage.rgbSwapped(QImage( - buf, size[0], size[1], QImage.Format_ARGB32) - ) - qpixmap = QPixmap(qimage) - - return qpixmap - - -class LabelsWidget(QWidget): - def __init__(self, texts, wrapText=False, parent=None): - super().__init__(parent=parent) - - layout = QVBoxLayout() - - texts = self.fixParagraphTags(texts) - - self.textLengths = [] - self.labels = [] - for t, text in enumerate(texts): - if not text: - continue - - if text.startswith(''): - layout.addSpacing(10) - label = LatexLabel(text) - layout.addWidget(label, alignment=Qt.AlignCenter) - try: - # Add spacing only if next text is not a formula - nextText = texts[t+1] - if not nextText.startswith(''): - layout.addSpacing(10) - except IndexError: - layout.addSpacing(10) - elif text.startswith(''): - text = ( - text.removeprefix('') - .removeprefix('') - ) - label = CopiableCommandWidget(command=text, parent=self) - layout.addWidget(label) - else: - label = QLabel(text) - label.setWordWrap(wrapText) - label.setOpenExternalLinks(True) - layout.addWidget(label) - if wrapText: - self.textLengths.append(1) - self.textLengths.extend( - [len(line) for line in text.split('
    ')] - ) - - self.labels.append(label) - - self.nCharsLongestLine = max(self.textLengths, default=1) - - layout.setContentsMargins(0, 0, 0, 0) - self.setLayout(layout) - - def setWordWrap(self, wordWrap): - for label in self.labels: - label.setWordWrap(wordWrap) - - def fixParagraphTags(self, texts): - firstText = texts[0] - if firstText.find('

    ', firstText) - if searched is None: - openTag = '

    ' - else: - openTag = searched.group() - - not_allowed = {' ', '\n'} - - fixedTexts = [] - for text in texts: - if text.startswith('') or text.startswith(''): - fixedTexts.append(text) - continue - - if set(text) <= not_allowed: - # Ignore texts that are made of only \n and spaces - continue - - if text.find('

    ') == -1: - text = rf'{text}<\p>' - - if text.find(openTag) == -1: - text = f'{openTag}{text}' - - text = text.replace('\n', '') - - fixedTexts.append(text) - return fixedTexts - -class SwitchPlaneCombobox(QComboBox): - sigPlaneChanged = Signal(str, str) - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.addItems(['xy', 'zy', 'zx']) - self._previousPlane = 'xy' - self.currentTextChanged.connect(self.emitPlaneChanged) - - def emitPlaneChanged(self, plane): - self.sigPlaneChanged.emit(self._previousPlane, plane) - self._previousPlane = plane - - def setPlane(self, plane): - self.setCurrentText(plane) - - def setCurrentText(self, text): - self._previousPlane = self.plane() - super().setCurrentText(text) - - def plane(self): - return self.currentText() - - def depthAxes(self): - plane = self.plane() - for axes in 'xyz': - if axes not in plane: - return axes - -class SamInputPointsWidget(QWidget): - sigValueChanged = Signal(str) - - def __init__(self, parent=None): - super().__init__(parent) - - _layout = QHBoxLayout() - - self.lineEntry = ElidingLineEdit(parent=self) - self.lineEntry.setAlignment(Qt.AlignCenter) - self.lineEntry.editingFinished.connect(self.emitValueChanged) - - self.editButton = editPushButton() - self.browseButton = browseFileButton( - ext={'CSV': '.csv'}, - start_dir=myutils.getMostRecentPath() - ) - - _layout.addWidget(self.lineEntry) - _layout.addWidget(self.editButton) - _layout.addWidget(self.browseButton) - - _layout.setStretch(0, 1) - _layout.setStretch(1, 0) - _layout.setStretch(1, 0) - - self.browseButton.sigPathSelected.connect(self.browseCsvFiles) - self.editButton.clicked.connect(self.showInfoEditPoints) - - _layout.setContentsMargins(0, 0, 0, 0) - self.setLayout(_layout) - - def emitValueChanged(self, text): - self.sigValueChanged.emit(text) - - def showInfoEditPoints(self): - note = html_utils.to_note( - 'When adding points with the mouse left button you will create a ' - 'new object for each point. To add multiple points for the same ' - 'object click the right button.' - ) - txt = html_utils.paragraph(f""" - To add input points for Segment Anything open the GUI (module 3), - load the data, and then click on the button
    - on the top toolbar called Add points layer.

    - Select the option "Add points by clicking" and click on the image - to add points.

    - Finally, save the table and browse to the saved file on this widget. -
    {note} - """) - msg = myMessageBox(wrapText=False) - msg.information(self, 'Info edit points', txt) - - def criticalMissingColumn(self, filepath, missing_col): - txt = html_utils.paragraph(f""" - [ERROR]: The selected table does not contain the column - {missing_col}.

    - A valid table must contain the columns (x, y, id) - with an additional z column for 3D z-stacks data. - """) - msg = myMessageBox(wrapText=False) - msg.critical(self, 'Invalid table', txt) - - def setValue(self, value: str): - self.lineEntry.setText(value) - - def value(self): - return self.lineEntry.text() - - def cast_dtype(self, value) -> str: - return str(value) - - def browseCsvFiles(self, filepath): - # Check if metadata.csv file exists with basename and set only the - # endname of the file - df_points = pd.read_csv(filepath) - for col in ('x', 'y', 'id'): - if col not in df_points.columns: - self.criticalMissingColumn(filepath, col) - return - - # Check if basename is present in metadata - folderpath = os.path.dirname(filepath) - basename = None - for file in myutils.listdir(folderpath): - if file.endswith('metadata.csv'): - metadata_csv_path = os.path.join(folderpath, file) - df = pd.read_csv(metadata_csv_path, index_col='Description') - try: - basename = df.at['basename', 'values'] - except Exception as e: - basename = None - break - - # Check if file is inside images folder and get basename - is_images_folder = folderpath.endswith('Images') - if is_images_folder: - images_path = folderpath - img_filepath = None - for file in myutils.listdir(images_path): - if file.endswith('.tif'): - img_filepath = os.path.join(images_path, file) - break - - if file.endswith('aligned.npz'): - img_filepath = os.path.join(images_path, file) - break - - if img_filepath is not None: - posData = load.loadData(img_filepath, '', QParent=self) - posData.getBasenameAndChNames() - filename = os.path.basename(filepath) - if filename.startswith(posData.basename): - basename = posData.basename - - if basename is None: - self.lineEntry.setText(filepath) - else: - filename = os.path.basename(filepath) - endname = filename[len(basename):] - self.lineEntry.setText(endname) - -class PointsScatterPlotItem(pg.ScatterPlotItem): - sigHoverEntered = Signal(object, object, object) - - def __init__(self, *args, ax=None, show_data_as_tip=False, **kwargs): - self.textItem = annotate.TextAnnotationsScatterItem( - size=12, anchor=(1.0, 1.0) - ) - self.textItem.createSymbols( - [str(int_id) for int_id in range(200)], includeBold=False - ) - # self._textItems = {} - super().__init__(*args, **kwargs) - self.textItem.setParentItem(self) - self._font = QFont() - self._font.setPixelSize(12) - self.show_data_as_tip = show_data_as_tip - self.drawIds = True - self.ax = ax - self.sigHovered.connect(self.onHover) - self.lastHoveredPoint = None - - def onHover(self, item, points, event): - if len(points) == 0: - vb = self.getViewBox() - vb.setToolTip('') - return - - if self.lastHoveredPoint != points[0]: - self.sigHoverEntered.emit(item, points, event) - self.lastHoveredPoint = points[0] - - if not self.opts['hoverable']: - return - - if not self.show_data_as_tip: - return - - tip_li = [str(point.data()) for point in points] - tip = '\n\n'.join(tip_li) - - vb = self.getViewBox() - vb.setToolTip(tip) - - - def setData(self, *args, **kwargs): - self.clearTextItems() - super().setData(*args, **kwargs) - data = kwargs.get('data') - if data is None: - return - - if len(data) == 0: - return - - first_point_data = data[0] - if not isinstance(first_point_data, (int, str)): - return - - if not self.drawIds: - return - - if self.show_data_as_tip: - return - - color = self.opts['brush'].color() - self.textItem.setColors({'id': color.getRgb()}) - size = self.opts['size'] - radius = size/2 - # xx, yy = args - # for x, y, point_data in zip(xx, yy, data): - for point in self.points(): - text = str(point.data()) - if not text: - continue - - x, y = point.pos().x(), point.pos().y() - xt, yt = x+radius-0.5, y-radius+0.5 - opts = { - 'text': text, - 'bold': False, - 'color_name': 'id', - } - data = self.textItem.addObjAnnot( - (xt, yt), anchor=(-0.3, 1.3), **opts - ) - self.textItem.appendData(data, opts['text']) - - self.textItem.draw() - # hexColor = color.name() - # htmlText = html_utils.span( - # text, color=hexColor, font_size='13pt', bold=True - # ) - - # textItem = self._textItems.get((x, y)) - # if textItem is None: - # textItem = pg.TextItem(html=htmlText, anchor=(0, 1)) - # textItem.setParentItem(self) - # self._textItems[(x, y)] = textItem - # self.ax.addItem(textItem) - # else: - # textItem.setHtml(htmlText) - # textItem.setPos(x+radius-0.5, y-radius+0.5) - - def clearTextItems(self): - self.textItem.clearData() - # for textItem in self._textItems.values(): - # textItem.setText('') - - def clear(self): - super().clear() - self.clearTextItems() - - def setVisible(self, visible): - super().setVisible(visible) - self.textItem.setVisible(visible) - -class installJavaDialog(myMessageBox): - def __init__(self, parent=None): - super().__init__(parent) - - self.setWindowTitle('Install Java') - self.setIcon('SP_MessageBoxWarning') - - txt_macOS = html_utils.paragraph(""" - Your system doesn't have the Java Development Kit - installed
    and/or a C++ compiler which is required for the installation of - javabridge

    - Cell-ACDC is now going to install Java for you.

    - NOTE: After clicking on "Install", follow the instructions
    - on the terminal
    . You will be asked to confirm steps and insert
    - your password to allow the installation.


    - If you prefer to do it manually, cancel the process
    - and follow the instructions below. - """) - - txt_windows = html_utils.paragraph(""" - Unfortunately, installing pre-compiled version of - javabridge failed.

    - Cell-ACDC is going to try to compile it now.

    - However, before proceeding, you need to install - Java Development Kit
    and a C++ compiler.

    - See instructions below on how to install it. - """) - - if not is_win: - self.instructionsButton = self.addButton('Show intructions...') - self.instructionsButton.setCheckable(True) - self.instructionsButton.disconnect() - self.instructionsButton.clicked.connect(self.showInstructions) - installButton = self.addButton('Install') - installButton.disconnect() - installButton.clicked.connect(self.installJava) - txt = txt_macOS - else: - okButton = self.addButton('Ok') - txt = txt_windows - - self.cancelButton = self.addButton('Cancel') - - label = self.addText(txt) - label.setWordWrap(False) - - self.resizeCount = 0 - - def addInstructionsWindows(self): - self.scrollArea = QScrollArea() - _container = QWidget() - _layout = QVBoxLayout() - for t, text in enumerate(myutils.install_javabridge_instructions_text()): - label = QLabel() - label.setText(text) - if (t == 1 or t == 2): - label.setOpenExternalLinks(True) - label.setTextInteractionFlags(Qt.TextBrowserInteraction) - code_layout = QHBoxLayout() - code_layout.addWidget(label) - copyButton = QToolButton() - copyButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon) - copyButton.setIcon(QIcon(':edit-copy.svg')) - copyButton.setText('Copy link') - if t==1: - copyButton.textToCopy = myutils.jdk_windows_url() - code_layout.addWidget(copyButton, alignment=Qt.AlignLeft) - else: - copyButton.textToCopy = myutils.cpp_windows_url() - screenshotButton = QToolButton() - screenshotButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon) - screenshotButton.setIcon(QIcon(':cog.svg')) - screenshotButton.setText('See screenshot') - code_layout.addWidget(screenshotButton, alignment=Qt.AlignLeft) - code_layout.addWidget(copyButton, alignment=Qt.AlignLeft) - screenshotButton.clicked.connect(self.viewScreenshot) - copyButton.clicked.connect(self.copyToClipboard) - code_layout.setStretch(0, 2) - code_layout.setStretch(1, 0) - _layout.addLayout(code_layout) - else: - _layout.addWidget(label) - - - _container.setLayout(_layout) - self.scrollArea.setWidget(_container) - self.currentRow += 1 - self._layout.addWidget( - self.scrollArea, self.currentRow, 1, alignment=Qt.AlignTop - ) - - # Stretch last row - self.currentRow += 1 - self._layout.setRowStretch(self.currentRow, 1) - - def viewScreenshot(self, checked=False): - self.screenShotWin = view_visualcpp_screenshot(parent=self) - self.screenShotWin.show() - - def addInstructionsMacOS(self): - self.scrollArea = QScrollArea() - _container = QWidget() - _layout = QVBoxLayout() - for t, text in enumerate(myutils.install_javabridge_instructions_text()): - label = QLabel() - label.setText(text) - # label.setWordWrap(True) - if (t == 1 or t == 2): - label.setWordWrap(True) - code_layout = QHBoxLayout() - code_layout.addWidget(label) - copyButton = QToolButton() - copyButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon) - copyButton.setIcon(QIcon(':edit-copy.svg')) - copyButton.setText('Copy') - if t==1: - copyButton.textToCopy = myutils._install_homebrew_command() - else: - copyButton.textToCopy = myutils._brew_install_java_command() - copyButton.clicked.connect(self.copyToClipboard) - code_layout.addWidget(copyButton, alignment=Qt.AlignLeft) - # code_layout.addStretch(1) - code_layout.setStretch(0, 2) - code_layout.setStretch(1, 0) - _layout.addLayout(code_layout) - else: - _layout.addWidget(label) - _container.setLayout(_layout) - self.scrollArea.setWidget(_container) - self.currentRow += 1 - self._layout.addWidget( - self.scrollArea, self.currentRow, 1, alignment=Qt.AlignTop - ) - - # Stretch last row - self.currentRow += 1 - self._layout.setRowStretch(self.currentRow, 1) - self.scrollArea.hide() - - def addInstructionsLinux(self): - self.scrollArea = QScrollArea() - _container = QWidget() - _layout = QVBoxLayout() - for t, text in enumerate(myutils.install_javabridge_instructions_text()): - label = QLabel() - label.setText(text) - # label.setWordWrap(True) - if (t == 1 or t == 2 or t==3): - label.setWordWrap(True) - code_layout = QHBoxLayout() - code_layout.addWidget(label) - copyButton = QToolButton() - copyButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon) - copyButton.setIcon(QIcon(':edit-copy.svg')) - copyButton.setText('Copy') - if t==1: - copyButton.textToCopy = myutils._apt_update_command() - elif t==2: - copyButton.textToCopy = myutils._apt_install_java_command() - elif t==3: - copyButton.textToCopy = myutils._apt_gcc_command() - copyButton.clicked.connect(self.copyToClipboard) - code_layout.addWidget(copyButton, alignment=Qt.AlignLeft) - # code_layout.addStretch(1) - code_layout.setStretch(0, 2) - code_layout.setStretch(1, 0) - _layout.addLayout(code_layout) - else: - _layout.addWidget(label) - _container.setLayout(_layout) - self.scrollArea.setWidget(_container) - self.currentRow += 1 - self._layout.addWidget( - self.scrollArea, self.currentRow, 1, alignment=Qt.AlignTop - ) - - # Stretch last row - self.currentRow += 1 - self._layout.setRowStretch(self.currentRow, 1) - self.scrollArea.hide() - - def copyToClipboard(self): - cb = QApplication.clipboard() - cb.clear(mode=cb.Clipboard) - cb.setText(self.sender().textToCopy, mode=cb.Clipboard) - print('Command copied!') - - def showInstructions(self, checked): - if checked: - self.instructionsButton.setText('Hide instructions') - self.origHeight = self.height() - self.resize(self.width(), self.height()+300) - self.scrollArea.show() - else: - self.instructionsButton.setText('Show instructions...') - self.scrollArea.hide() - func = partial(self.resize, self.width(), self.origHeight) - QTimer.singleShot(50, func) - - def installJava(self): - import subprocess - try: - if is_mac: - try: - subprocess.check_call(['brew', 'update']) - except Exception as e: - subprocess.run( - myutils._install_homebrew_command(), - check=True, text=True, shell=True - ) - subprocess.run( - myutils._brew_install_java_command(), - check=True, text=True, shell=True - ) - elif is_linux: - subprocess.run( - myutils._apt_gcc_command()(), - check=True, text=True, shell=True - ) - subprocess.run( - myutils._apt_update_command()(), - check=True, text=True, shell=True - ) - subprocess.run( - myutils._apt_install_java_command()(), - check=True, text=True, shell=True - ) - self.close() - except Exception as e: - print('=======================') - traceback.print_exc() - print('=======================') - msg = myMessageBox(wrapText=False) - err_msg = html_utils.paragraph(""" - Automatic installation of Java failed.

    - Please, try manually by following the instructions provided - below (click on "Show instructions..." button). Thanks - """) - msg.critical( - self, 'Java installation failed', err_msg - ) - - def show(self, block=False): - super().show(block=False) - print(is_linux) - if is_win: - self.addInstructionsWindows() - elif is_mac: - self.addInstructionsMacOS() - elif is_linux: - self.addInstructionsLinux() - self.move(self.pos().x(), 20) - if is_win: - self.resize(self.width(), self.height()+200) - if block: - self._block() - - def exec_(self): - self.show(block=True) - -class selectTrackerGUI(QDialogListbox): - def __init__( - self, SizeT, currentFrameNo=1, parent=None - ): - trackers = myutils.get_list_of_trackers() - super().__init__( - 'Select tracker', 'Select one of the following trackers', - trackers, multiSelection=False, parent=parent - ) - self.setWindowTitle('Select tracker') - - self.selectFramesGroupbox = selectStartStopFrames( - SizeT, currentFrameNum=currentFrameNo, parent=parent - ) - - self.mainLayout.insertWidget(1, self.selectFramesGroupbox) - - def ok_cb(self, event): - if self.selectFramesGroupbox.warningLabel.text(): - return - else: - self.startFrame = self.selectFramesGroupbox.startFrame_SB.value() - self.stopFrame = self.selectFramesGroupbox.stopFrame_SB.value() - super().ok_cb(event) - -def addWidgetToScrollArea( - widget, - resizeMinWidthNoHorizontalScrollbar=False, - resizeMinHeightNoVerticalScrollbar=False - ): - container = QWidget() - layout = QVBoxLayout() - layout.addWidget(widget) - layout.addStretch(1) - container.setLayout(layout) - scrollArea = QScrollArea() - scrollArea.setWidgetResizable(True) - scrollArea.setWidget(container) - - if resizeMinWidthNoHorizontalScrollbar: - scrollArea.setMinimumWidth( - container.sizeHint().width() - + scrollArea.verticalScrollBar().sizeHint().width() - ) - - if resizeMinHeightNoVerticalScrollbar: - scrollArea.setMinimumHeight( - container.sizeHint().height() - + scrollArea.horizontalScrollBar().sizeHint().height() - ) - - return scrollArea - -class CheckableAction(QAction): - clicked = Signal(bool) - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.setCheckable(True) - self.toggled.connect(self.emitClicked) - - def emitClicked(self, checked): - self.clicked.emit(checked) - - def setChecked(self, checked): - self.toggled.disconnect() - super().setChecked(checked) - self.toggled.connect(self.emitClicked) - -class OddSpinBox(SpinBox): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.setSingleStep(2) - self.editingFinished.connect(self.roundToOdd) - - def roundToOdd(self): - if self.value() % 2 == 1: - return - - self.setValue(self.value()+1) - -class TimestampItem(LabelItem): - sigEditProperties = Signal(object) - sigRemove = Signal(object) - - def __init__( - self, SizeY, SizeX, viewRange, - secondsPerFrame=1, - parent=None, - start_timedelta=None - ): - self._secondsPerFrame = secondsPerFrame - self._x_pad = 3 - self._y_pad = 2 - self.xmin, self.ymin = 0, 0 - self.SizeY = SizeY - self.SizeX = SizeX - self._highlighted = False - self._parent = parent - if start_timedelta is None: - start_timedelta = datetime.timedelta(seconds=0) - self._start_timedelta = start_timedelta - self.clicked = False - super().__init__(self) - self.updateViewRange(viewRange) - self.createContextMenu() - - def setSecondsPerFrame(self, secondsPerFrame): - self._secondsPerFrame = secondsPerFrame - - def getBboxViewRange(self, viewRange): - xRange, yRange = viewRange - x0, x1 = xRange - y0, y1 = yRange - if x0 < 0: - x0 = 0 - - if x1 > self.SizeX: - x1 = self.SizeX - - if y0 < 0: - y0 = 0 - - if y1 > self.SizeY: - y1 = self.SizeY - - return x0, y0, x1, y1 - - def updateViewRange(self, viewRange): - x0, y0, x1, y1 = self.getBboxViewRange(viewRange) - - self.xmax = x1 - self.xmin = x0 - - self.ymax = y1 - self.ymin = y0 - - def createContextMenu(self): - self.contextMenu = QMenu() - action = QAction('Edit properties...', self.contextMenu) - action.triggered.connect(self.emitEditProperties) - self.contextMenu.addSeparator() - action = QAction('Remove', self.contextMenu) - action.triggered.connect(self.emitRemove) - self.contextMenu.addAction(action) - - def emitRemove(self): - self.sigRemove.emit(self) - - def mousePressed(self, x, y): - self.clicked = True - - def emitEditProperties(self): - self.setHighlighted(False) - self.sigEditProperties.emit(self.properties()) - - def isHighlighted(self): - return self._highlighted - - def setHighlighted(self, highlighted): - if self._highlighted and highlighted: - return - - if not self._highlighted and not highlighted: - return - - super().setText(self.text, bold=highlighted) - - self._highlighted = highlighted - - def showContextMenu(self, x, y): - self.contextMenu.popup(QPoint(int(x), int(y))) - - def setLocationProperty(self, loc: str): - self._loc = loc - - def properties(self): - properties = { - 'color': self._color, - 'loc': self._loc, - 'font_size': int(self._font_size[:-2]), - 'start_timedelta': self._start_timedelta, - 'move_with_zoom': self._move_with_zoom, - } - return properties - - def draw(self, frame_i, **kwargs): - self.setProperties(**kwargs) - self.update(frame_i) - - def update(self, frame_i): - self.setPosFromLoc() - self.setText(frame_i) - - def setMoveWithZoomProperty(self, move_with_zoom): - self._move_with_zoom = move_with_zoom - - def updatePosViewRangeChanged(self, viewRange): - if self._loc == 'custom': - textHeight = self.itemRect().height() - textWidth = self.itemRect().width() - x0p = self.pos().x() - y0p = self.pos().y() - xcp = x0p + textWidth/2 - ycp = y0p + textHeight/2 - x0 = self.xmin - y0 = self.ymin - x_range = self.xmax - x0 - y_range = self.ymax - y0 - Dx_perc = (xcp - x0)/x_range - Dy_perc = (ycp - y0)/y_range - - self.updateViewRange(viewRange) - - X0 = self.xmin - Y0 = self.ymin - - X_range = self.xmax - X0 - Y_range = self.ymax - Y0 - - Xcp = X0 + (Dx_perc*X_range) - Ycp = Y0 + (Dy_perc*Y_range) - X0p = Xcp - (textWidth/2) - Y0p = Ycp - (textHeight/2) - - y_pos_max = self.ymax - textHeight - self._y_pad - if Y0p > y_pos_max: - Y0p = y_pos_max - - x_pos_max = self.xmax - textWidth - self._x_pad - if X0p > x_pos_max: - X0p = x_pos_max - - self.setPos(X0p, Y0p) - else: - self.updateViewRange(viewRange) - self.setPosFromLoc() - - - def setPosFromLoc(self): - textHeight = self.itemRect().height() - textWidth = self.itemRect().width() - if self._loc == 'custom': - return - - if self._loc.find('top') != -1: - y0 = self._y_pad + self.ymin - else: - y0 = self.ymax - textHeight - self._y_pad - - if self._loc.find('left') != -1: - x0 = self._x_pad + self.xmin - else: - x0 = self.xmax - textWidth - self._x_pad - - self.setPos(x0, y0) - - def setProperties( - self, - color=(255, 255, 255), - font_size='13px', - loc='top-left', - start_timedelta=None, - move_with_zoom=False - ): - if start_timedelta is not None: - self._start_timedelta = start_timedelta - self._color = color - self._loc = loc - self._font_size = font_size - self._move_with_zoom = move_with_zoom - - def move(self, xm, ym): - Dy = ym - self.yc - Dx = xm - self.xc - x0 = self.x0c + Dx - y0 = self.y0c + Dy - self.setPos(x0, y0) - - def mousePressed(self, x, y): - self.clicked = True - self.xc, self.yc = x, y - self.x0c = self.pos().x() - self.y0c = self.pos().y() - - def setText(self, frame_i): - if not isinstance(frame_i, int): - return - - seconds = frame_i*self._secondsPerFrame - timedelta = datetime.timedelta(seconds=round(seconds)) - - diff_seconds = ( - timedelta.total_seconds() - + self._start_timedelta.total_seconds() - ) - if diff_seconds >= 0: - timedelta = datetime.timedelta(seconds=round(diff_seconds)) - text = str(timedelta) - else: - abs_diff = abs( - timedelta.total_seconds() - + self._start_timedelta.total_seconds() - ) - abs_timedelta = datetime.timedelta(seconds=round(abs_diff)) - text = f'-{abs_timedelta}' - - # printl(timedelta) - super().setText( - text, color=self._color, size=self._font_size - ) - - def addToAxis(self, ax): - ax.addItem(self) - - def removeFromAxis(self, ax): - ax.removeItem(self) - -class FontSizeWidget(QWidget): - sigTextChanged = Signal(str) - - def __init__(self, parent=None, unit='px', initalVal=12): - super().__init__(parent) - - layout = QHBoxLayout() - - self.spinbox = SpinBox() - self.spinbox.setValue(initalVal) - layout.addWidget(self.spinbox) - - self.unitLabel = QLabel(unit) - layout.addWidget(self.unitLabel) - - layout.setContentsMargins(0, 0, 0, 0) - layout.setStretch(0, 1) - layout.setStretch(1, 0) - - self.setLayout(layout) - - self.spinbox.valueChanged.connect(self.emitTextChanged) - - def emitTextChanged(self, value): - self.sigTextChanged.emit(self.text()) - - def setValue(self, value): - if isinstance(value, str): - value = int(value.replace(self.unitLabel.text(), '').strip()) - self.spinbox.setValue(value) - - def setText(self, text): - value = int(text.replace(self.unitLabel.text(), '').strip()) - self.setValue(value) - - def text(self): - return f'{self.spinbox.value()}{self.unitLabel.text()}' - - def value(self): - return self.spinbox.value() - -class RangeSelector(QWidget): - sigRangeChanged = Signal(object, object) - sigLowValueChanged = Signal(object) - sigHighValueChanged = Signal(object) - sigRangeManuallyChanged = Signal(object, object) - - def __init__(self, parent=None, integers=False, ordered=True): - super().__init__(parent) - - self._integers = integers - self._ordered = ordered - - layout = QHBoxLayout() - - if integers: - self.lowSpinbox = SpinBox() - self.highSpinbox = SpinBox() - else: - self.lowSpinbox = DoubleSpinBox() - self.highSpinbox = DoubleSpinBox() - - layout.addWidget(self.lowSpinbox) - layout.addWidget(self.highSpinbox) - - layout.setContentsMargins(0, 0, 0, 0) - self.setLayout(layout) - - self.lowSpinbox.valueChanged.connect(self.lowValueChanged) - self.highSpinbox.valueChanged.connect(self.highValueChanged) - - self.lowSpinbox.editingFinished.connect(self.lowValueEditingFinished) - self.highSpinbox.editingFinished.connect(self.highValueEditingFinished) - - def lowValueEditingFinished(self): - self.sigRangeManuallyChanged.emit(*self.range()) - self.emitRangeChanged() - - def highValueEditingFinished(self): - self.sigRangeManuallyChanged.emit(*self.range()) - self.emitRangeChanged() - - def lowValueChanged(self, value): - self.emitRangeChanged() - self.sigLowValueChanged.emit(value) - - def highValueChanged(self, value): - self.emitRangeChanged() - self.sigHighValueChanged.emit(value) - - def emitRangeChanged(self): - self.sigRangeChanged.emit(*self.range()) - - def setRangeNoEmit(self, lowValue, highValue, decimals=3): - self.lowSpinbox.valueChanged.disconnect() - self.highSpinbox.valueChanged.disconnect() - - self.setRange(round(lowValue, 3), round(highValue, 3)) - - self.lowSpinbox.valueChanged.connect(self.lowValueChanged) - self.highSpinbox.valueChanged.connect(self.highValueChanged) - - def setRange(self, lowValue, highValue): - # if lowValue > highValue and self._ordered: - # highValue = lowValue + 1 - - if self._integers: - lowValue = round(lowValue) - highValue = round(highValue) - - self.lowSpinbox.setValue(lowValue) - self.highSpinbox.setValue(highValue) - - def range(self): - return self.lowSpinbox.value(), self.highSpinbox.value() - -class LineEdit(QLineEdit): - def __init__(self, parent=None): - super().__init__(parent) - self.setAlignment(Qt.AlignCenter) - - def value(self): - return self.text() - - def setValue(self, value): - self.setText(str(value)) - -class PreProcessingSelector(QComboBox): - sigValuesChanged = Signal(dict, int) - - def __init__(self, parent=None): - super().__init__(parent) - self._parent = parent - - self.addItems(PREPROCESS_MAPPER.keys()) - self.methodToDefaultValuesMapper = {} - self.step_n = -1 - self.setParamsWindow = None - - def htmlInfo(self): - href = html_utils.href_tag('GitHub page', urls.issues_url) - docstring = PREPROCESS_MAPPER[self.currentText()]['docstring'] - if docstring is None: - text = 'This function is not documented, yet. Sorry :(' - else: - text = html_utils.rst_docstring_to_html(docstring) - text = ( - f'{text}

    ' - f'Feel free to submit an issue on our {href} if you ' - 'need help with this filter.' - ) - return text - - def setParams(self, method: str, kwargToValueMapper: Dict[str, str]): - self.methodToDefaultValuesMapper[method] = kwargToValueMapper - - def askSetParams(self, df_metadata=None, addApplyButton=False): - method = self.currentText() - function = PREPROCESS_MAPPER[method]['function'] - params_argspecs = myutils.get_function_argspec( - function, - args_to_skip={ - 'logger_func', - 'apply_to_all_zslices', - 'apply_to_all_frames' - } - ) - default_values = self.methodToDefaultValuesMapper.get(method, {}) - for kwarg, value in default_values.items(): - for p, param_argspec in enumerate(params_argspecs): - if param_argspec.name != kwarg: - continue - - if hasattr(param_argspec.type, 'cast_dtype'): - cls = param_argspec.type - value = cls.cast_dtype(value) - else: - value = param_argspec.type(value) - - if value == param_argspec.default: - continue - param_argspec = param_argspec._replace(default=value) - params_argspecs[p] = param_argspec - - if self.setParamsWindow is not None: - self.setParamsWindow.raise_() - self.setParamsWindow.activateWindow() - return - - self.setParamsWindow = apps.FunctionParamsDialog( - params_argspecs, - df_metadata=df_metadata, - function_name=method, - addApplyButton=addApplyButton, - parent=self._parent - ) - self.setParamsWindow.sigValuesChanged.connect(self.emitValuesChanged) - self.setParamsWindow.emitValuesChanged() - self.setParamsWindow.exec_() - if self.setParamsWindow.cancel: - return - - self.setParams(method, self.setParamsWindow.function_kwargs) - - function_kwargs = self.setParamsWindow.function_kwargs - self.setParamsWindow = None - - return function_kwargs - - def emitValuesChanged(self, functionKwargs: dict): - self.sigValuesChanged.emit(functionKwargs, self.step_n) - -class RescaleImageJroisGroupbox(QGroupBox): - def __init__(self, TZYX_out_shape, parent=None): - super().__init__(parent) - - self.setTitle('Rescale ROIs') - self.setCheckable(True) - - gridLayout = QGridLayout() - - dims = ('Z', 'Y', 'X') - self.widgets = {} - for row, SizeD in enumerate(TZYX_out_shape[1:]): - if SizeD == 1: - continue - - dim = dims[row] - inputSpinbox = SpinBox() - inputSpinbox.setMinimum(1) - inputSpinbox.setValue(SizeD) - - outZwidget = QLineEdit() - outZwidget.setReadOnly(True) - outZwidget.setAlignment(Qt.AlignCenter) - # outZwidget.setValue(SizeD) - outZwidget.setText(str(SizeD)) - - row0 = row*2 - row1 = row0+1 - gridLayout.addWidget(QLabel(f'{dim}-dimension: '), row1, 0) - - gridLayout.addWidget(QLabel('Input size'), row0, 1) - gridLayout.addWidget(inputSpinbox, row1, 1) - - gridLayout.addWidget(QLabel('Output size'), row0, 2) - gridLayout.addWidget(outZwidget, row1, 2) - - self.widgets[dim] = (inputSpinbox, SizeD) - - self.setLayout(gridLayout) - - def inputOutputSizes(self): - if not self.isChecked(): - return - - sizes = { - dim: (spinbox.value(), int(SizeD)) - for dim, (spinbox, SizeD) in self.widgets.items() - } - return sizes - -class WhitelistLineEdit(KeepIDsLineEdit): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def setText(self, IDs): - if not isinstance(IDs, set) and not isinstance(IDs, list): - raise TypeError('IDs must be a set or list') - - formatted_text = myutils.format_IDs(IDs) - super().setText(formatted_text) - -class WhitelistIDsToolbar(ToolBar): - sigWhitelistChanged = Signal(list) - sigViewOGIDs = Signal(bool) - sigWhitelistAccepted = Signal(list) - sigAddNewIDs = Signal(bool) - sigLoadOGLabs = Signal() - sigTrackOGagainstPreviousFrame = Signal(bool) - - def __init__(self, addNewIDToggleState, *args) -> None: - super().__init__(*args) - - whitelistLineEditLabel = QLabel('Whitelist IDs: ') - self.addWidget(whitelistLineEditLabel) - - self.whitelistLineEdit = WhitelistLineEdit( - whitelistLineEditLabel, parent=self - ) - self.whitelistLineEdit.sigEnterPressed.connect(self.accept) - self.whitelistLineEdit.sigIDsChanged.connect(self.emitWhitelistChanged) - self.addWidget(self.whitelistLineEdit) - - # accept button - self.acceptButton = self.addButton(':greenTick.svg') - self.acceptButton.triggered.connect(self.accept) - - # add a view OG toggle - self.viewOGToggle = self.addButton(':eye.svg', checkable=True) - viewOGTooltip = ( - 'View the non-whitelisted segmentation mask.\n\n' - 'You can activate this to add new IDs to the whitelist,\n' - 'correct tracking errors, etc.' - ) - self.viewOGToggle.setChecked(True) - self.viewOGToggle.setToolTip(viewOGTooltip) - self.viewOGToggle.setShortcut('Shift+K') - key = 'View the non-whitelisted segmentation mask' - self.widgetsWithShortcut[key] = self.viewOGToggle - - self.viewOGToggle.toggled.connect(self.emitViewOGIDs) - self.emitViewOGIDs(True) - - # add a Toggle to add new IDs - self.addNewIDToggle = QCheckBox( - 'Automatically add new IDs to whitelist' - ) - self.addNewIDToggle.setChecked(addNewIDToggleState) - self.addWidget(self.addNewIDToggle) - self.addNewIDToggle.toggled.connect(self.emitAddNewIDs) - self.emitAddNewIDs(addNewIDToggleState) - - self.addSeparator() - - # add a button to load og df - self.loadOGButton = self.addButton(':open_file.svg') - self.loadOGButton.triggered.connect(self.sigLoadOGLabs.emit) - self.loadOGButton.setToolTip( - 'Select which segmentation mask file to load ' - 'as the non-whitelisted masks' - ) - - self.TrackOGagainstPreviousFrameButton = self.addButton(':segment.svg') - self.TrackOGagainstPreviousFrameButton.triggered.connect( - self.sigTrackOGagainstPreviousFrame.emit - ) - self.TrackOGagainstPreviousFrameButton.setToolTip( - 'Track the non-whitelisted segmentation masks against the previous frame and copy over successfull tacks' - ) - - self.addSeparator() - - # add an info button - self.infoButton = self.addButton(':info.svg') - self.infoButton.triggered.connect(self.showInfo) - - # add a spacer to the toolbar - spacer = QWidget() - spacer.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Preferred) - self.addWidget(spacer) - - def emitWhitelistChanged(self, whitelist): - self.sigWhitelistChanged.emit(whitelist) - - def emitViewOGIDs(self, checked): - self.sigViewOGIDs.emit(checked) - - def accept(self): - try: - whitelist = self.whitelistLineEdit.IDs - except AttributeError as e: - if "has no attribute 'IDs'" in str(e): - whitelist = list() - self.viewOGToggle.toggled.disconnect() - self.viewOGToggle.setChecked(False) - self.viewOGToggle.toggled.connect(self.emitViewOGIDs) - self.sigWhitelistAccepted.emit(whitelist) - - def emitAddNewIDs(self, checked): - self.sigAddNewIDs.emit(checked) - - def showInfo(self): - msg = myMessageBox(wrapText=False) - txt = html_utils.paragraph(""" - This function is used to track a subset of segmented objects.

    - - To add new IDs to the white list, click with left mouse button on the - object to add.
    - You can also write directly into the Whitelist IDs widget
    - and separate the IDs by commas.

    - - After adding the IDs, click on the "Accept" button to remove the - non-whitelisted objects.
    - Every time you visit a new frame, the non-whitelisted objects will - be removed automatically.

    - Use the "Eye" button to view the non-whitelisted segmentation masks.
    - This will allow you to correct tracking errors, add new IDs to the - white list, etc.

    - - If you previously saved the whitelisted masks, you can load the - non-whitelisted file
    - by clicking on the "Load file" button to restart from where you - left last time. - """ - ) - msg.information(self, 'White list IDs', txt) - -class MagicPromptsToolbar(ToolBar): - sigPromptTypeChanged = Signal(object, str) - sigComputeOnZoom = Signal(object) - sigComputeOnImage = Signal(object) - sigClearPoints = Signal(object) - sigClearPointsOnZmom = Signal(object) - sigInitSelectedModel = Signal( - str, object, list, list, str, object - ) - sigViewModelParams = Signal( - str, object, list, list, str, object, object, object - ) - sigInterpolateZslice = Signal(bool) - - def __init__(self, parent=None): - super().__init__(parent) - - self._parent = parent - - prompt_types = ( - 'Points', - ) - - self.selectModelAction = self.addButton(':select-list.svg') - self.selectModelAction.setToolTip( - 'Select the promptable model to use' - ) - - self.viewModelParamsAction = self.addButton(':view.svg') - self.viewModelParamsAction.setToolTip( - 'View the currently selected model parameters' - ) - self.viewModelParamsAction.setDisabled(True) - - self.addSeparator() - - self.promptTypeCombobox = self.addComboBox( - prompt_types, label='Prompt type: ', - ) - - self.addSeparator() - - self.interpolateZslicesCheckbox = self.addCheckBox( - 'Interpolate points on missing z-slices', checked=False - ) - self.interpolateZslicesCheckbox.setToolTip( - 'If checked, when working with 3D segmentation masks, you can ' - 'add points on some z-slices only and the points on the missing ' - 'z-slices will be determined by linear interpolation.\n\n' - 'This is useful when working with 2D models that segments ' - 'each z-slice independently.\n\n' - 'NOTE: The points will be added only when running the model and ' - 'removed afterwards.' - ) - - self.addSeparator() - - self.computeOnZoomAction = self.addButton(':compute-zoom.svg') - self.computeOnZoomAction.setToolTip( - 'Compute the segmentation on the zoomed area of the image ' - '(faster)' - ) - - self.computeAction = self.addButton(':compute.svg') - self.computeAction.setToolTip( - 'Compute the segmentation on the whole image' - ) - - self.clearPointsAction = self.addButton(':clear-points.svg') - self.clearPointsAction.setToolTip( - 'Clear all points' - ) - self.clearPointsAction.setDisabled(True) - - self.clearPointsActionOnZoom = self.addButton(':clear-points-zoom.svg') - self.clearPointsActionOnZoom.setToolTip( - 'Clear all points on the zoomed area of the image' - ) - self.clearPointsActionOnZoom.setDisabled(True) - - self.addSeparator() - - self.infoAction = self.addButton(':info.svg') - self.infoAction.setToolTip( - 'Show instructions how to use promptable models' - ) - - self.addSeparator() - - self.infoAction.triggered.connect(self.showHelp) - self.selectModelAction.triggered.connect(self.selectModel) - self.viewModelParamsAction.triggered.connect(self.viewModelParams) - self.promptTypeCombobox.sigTextChanged.connect( - self.emitPromptTypeChanged - ) - self.computeOnZoomAction.triggered.connect( - self.emitSigComputeOnZoom - ) - self.computeAction.triggered.connect( - self.emitSigComputeOnImage - ) - self.clearPointsAction.triggered.connect( - self.emitSigClearPoints - ) - self.clearPointsActionOnZoom.triggered.connect( - self.emitSigClearPointsOnZoom - ) - self.interpolateZslicesCheckbox.toggled.connect( - self.sigInterpolateZslice.emit - ) - - def showHelp(self): - msg = myMessageBox(wrapText=False) - txt = html_utils.paragraph(""" - This toolbar allows you to use promptable models for - segmentation.

    - - To use a promptable model, first select the model by clicking on the - "Select model" button.
    - This will open a dialog where you can select the model to use.

    - - After selecting the model, you can view the model parameters - by clicking on the "View model parameters" button.

    - - To add points to the image, make sure you have points layer correctly - initialised. You should see controls
    - called "Left-click ID" and "Right-click ID".

    - - You can add points for a new object by left-clicking on the image, - while you can add points
    - for the same object by right-clicking. - To delete a point, click on it again.

    - - To change the right-click ID, - you can either type in the corresponding control,
    - or type the object id on the keyboard followed by "Enter".

    - - To add negative prompts (i.e., for the background), use the - same action you use to delete objects
    - (default is middle-click on Windows and Cmd+Click on MacOS).

    - Note that you can also add object-specific negative prompts (i.e., - they affect only that object)
    - by adding the negative prompt on the newly segmented object - directly.

    - - Once you are happy with the added points, click either the - "Compute on zoomed area"
    - button or the "Compute on whole image" button.

    - - Finally, you can clear all points by clicking on the - "Clear points" button.

    - - Note that you can also save the points by clicking on the - "Save points" button to load them later and start from - where you left.

    - """) - msg.information( - self, 'Promptable models help', txt - ) - - def emitSigClearPoints(self): - self.sigClearPoints.emit(self) - - def emitSigClearPointsOnZoom(self): - self.sigClearPointsOnZmom.emit(self) - - def emitSigComputeOnZoom(self): - self.sigComputeOnZoom.emit(self) - - def emitSigComputeOnImage(self): - self.sigComputeOnImage.emit(self) - - def selectModel(self): - win = apps.SelectPromptableModelDialog(parent=self._parent) - win.exec_() - if win.cancel: - print('Promptable model selection cancelled') - return - - model_name = win.model_name - print(f'Importing promptable model {model_name}...') - - # Download model weights, consistent with gui.py - downloadWin = apps.downloadModel(model_name, parent=self._parent) - downloadWin.download() - - acdcPromptSegment = myutils.import_promptable_segment_module(model_name) - init_argspecs, segment_argspecs = myutils.getModelArgSpec( - acdcPromptSegment - ) - - try: - help_url = acdcPromptSegment.url_help() - except AttributeError: - help_url = None - - self._model_name = model_name - self._acdcPromptSegment = acdcPromptSegment - self._init_argspecs = init_argspecs - self._segment_argspecs = segment_argspecs - self._help_url = help_url - - self.sigInitSelectedModel.emit( - model_name, - acdcPromptSegment, - init_argspecs, - segment_argspecs, - help_url, - self - ) - - def setInitializedModel(self, init_kwargs, segment_kwargs): - self._init_kwargs = init_kwargs - self._segment_kwargs = segment_kwargs - - def viewModelParams(self): - self.sigViewModelParams.emit( - self._model_name, - self._acdcPromptSegment, - self._init_argspecs, - self._segment_argspecs, - self._help_url, - self._init_kwargs, - self._segment_kwargs, - self - ) - - def emitPromptTypeChanged(self, text): - self.sigPromptTypeChanged.emit(self, text) - -class KeySequenceFromText(QKeySequence): - def __init__(self, text: str): - if isinstance(text, str): - text = macShortcutToWindows(text) - super().__init__(text) - self._text = text - - def toString(self): - if isinstance(self._text, str): - return windowsShortcutToMac(self._text) - else: - return windowsShortcutToMac(super().toString()) - -def modifierKeyToText(modifierKey: int): - if modifierKey == Qt.ControlModifier: - return 'Ctrl' - elif modifierKey == Qt.AltModifier: - return 'Alt' - elif modifierKey == Qt.ShiftModifier: - return 'Shift' - elif modifierKey == Qt.MetaModifier: - return 'Meta' - else: - return '' - -class TimeWidget(QGroupBox): - sigValueChanged = Signal(object) - - def __init__(self, parent=None, orientation='vertical'): - super().__init__(parent) - - mainLayout = QHBoxLayout() - - if orientation == 'vertical': - spinboxesLayout = QVBoxLayout() - elif orientation == 'horizontal': - spinboxesLayout = QHBoxLayout() - else: - raise ValueError('orientation must be "vertical" or "horizontal"') - - self.signCombobox = QComboBox() - self.signCombobox.addItems(('+', '-')) - self.signCombobox.currentTextChanged.connect(self.emitValueChanged) - - mainLayout.addWidget(self.signCombobox) - - self.spinboxesMapper = {} - units = ('days', 'hours', 'minutes', 'seconds') - for unit in units: - layout = QHBoxLayout() - spinbox = SpinBox() - spinbox.setMinimum(0) - label = QLabel(unit) - layout.addWidget(spinbox) - layout.addWidget(label) - spinbox.valueChanged.connect(self.emitValueChanged) - self.spinboxesMapper[unit] = spinbox - spinboxesLayout.addLayout(layout) - - mainLayout.addLayout(spinboxesLayout) - - self.setLayout(mainLayout) - mainLayout.setContentsMargins(5, 5, 5, 5) - - def values(self): - values = {} - for unit, spinbox in self.spinboxesMapper.items(): - values[unit] = spinbox.value() - - signText = self.signCombobox.currentText() - return values, sign_int_mapper[signText] - - def setValuesFromTimedelta(self, timedelta): - total_seconds = timedelta.total_seconds() - sign = 1 if total_seconds > 0 else -1 - days = timedelta.days - hours, remainder = divmod(timedelta.seconds, 3600) - minutes, seconds = divmod(remainder, 60) - - values = { - 'days': days, - 'hours': hours, - 'minutes': minutes, - 'seconds': seconds - } - - self.setValues(values, sign=sign) - - def timedelta(self): - values, sign = self.values() - return datetime.timedelta(**values)*sign - - def setValues(self, values: dict[str, int | float], sign=1): - signText = '+' if sign > 0 else '-' - self.signCombobox.setCurrentText(signText) - for unit, value in values.items(): - spinbox = self.spinboxesMapper[unit] - spinbox.setValue(value) - - def emitValueChanged(self, value): - self.sigValueChanged.emit(self.values()) - -class PointsLayersToolbar(ToolBar): - sigAddPointsLayer = Signal() - - def __init__(self, name='Points layers', parent=None): - - super().__init__(name, parent) - - self.guiWin = parent - - self.setContextMenuPolicy(Qt.PreventContextMenu) - - self.addPointsLayerAction = self.addButton(':addPointsLayer.svg') - - self.addSeparator() - - self.pointsLayersLabel = self.addLabel('Points layers: ') - - self.addPointsLayerAction.triggered.connect( - self.emitAddPointsLayer - ) - self.doAddPointsZslicesInterpolation = False - - def emitAddPointsLayer(self): - self.sigAddPointsLayer.emit() - - def fromActionToDataFrame(self, action, posData, isSegm3D=False): - df = pd.DataFrame( - columns=['frame_i', 'Cell_ID', 'z', 'y', 'x', 'id'] - ) - frames_vals = [] - IDs = [] - zz = [] - yy = [] - xx = [] - ids = [] - pos_i = self.guiWin.pos_i - if pos_i not in action.pointsData: - printl('No points data for position', pos_i) # should really not happen, but its not a disaster if it does - return df - pointsDataPos = action.pointsData[pos_i] - for frame_i, framePointsData in pointsDataPos.items(): - if posData.SizeZ > 1: - for z, zSlicePointsData in framePointsData.items(): - yyxx = zip( - zSlicePointsData['y'], zSlicePointsData['x'] - ) - for y, x in yyxx: - if isSegm3D: - ID = posData.lab[int(z), int(y), int(x)] - else: - ID = posData.lab[int(y), int(x)] - frames_vals.append(frame_i) - IDs.append(ID) - zz.append(z) - yy.append(y) - xx.append(x) - ids.extend(zSlicePointsData['id']) - else: - yyxx = zip(framePointsData['y'], framePointsData['x']) - for y, x in yyxx: - ID = posData.lab[int(y), int(x)] - frames_vals.append(frame_i) - IDs.append(ID) - yy.append(y) - xx.append(x) - ids.extend(framePointsData['id']) - df['frame_i'] = frames_vals - df['Cell_ID'] = IDs - df['y'] = yy - df['x'] = xx - df['id'] = ids - if zz: - df['z'] = zz - - df = self.addPointsZslicesInterpolation(df, posData.lab, isSegm3D) - - return df - - def addPointsZslicesInterpolation( - self, - df: pd.DataFrame, - lab: np.ndarray, - isSegm3D: bool - ): - if not self.doAddPointsZslicesInterpolation: - return df - - if not isSegm3D: - return df - - if 'z' not in df.columns: - return df - - df_new_rows = [] - for (frame_i, point_id), df_id in df.groupby(['frame_i', 'id']): - xx = df_id['x'].values - yy = df_id['y'].values - zz = df_id['z'].values - - p0, d = core.linear_fit_3d(xx, yy, zz) - - new_row_df = df_id.iloc[[0]].copy() - - z0, z1 = int(np.min(zz)), int(np.max(zz)) - for z in range(z0, z1+1): - if z in zz: - continue - - t_int = (z - p0[2]) / d[2] - x_new, y_new, z_new = p0 + t_int * d - new_row_df['z'] = round(z_new) - new_row_df['y'] = round(y_new) - new_row_df['x'] = round(x_new) - - Cell_ID = lab[ - int(round(z_new)), - int(round(y_new)), - int(round(x_new)) - ] - new_row_df['Cell_ID'] = Cell_ID - - df_new_rows.append(new_row_df.copy()) - - if not df_new_rows: - return df - - df_new = pd.concat(df_new_rows, ignore_index=True) - df = pd.concat([df, df_new], ignore_index=True) - df = df.sort_values(by=['frame_i', 'id', 'z']).reset_index(drop=True) - - return df - -class PromptableModelPointsLayerToolbar(PointsLayersToolbar): - def __init__(self, name='Promptable model points layers', parent=None): - super().__init__(name, parent=parent) - - self.isPointsLayerInit = False - - self.addPointsLayerAction.setDisabled(True) - self.addPointsLayerAction.setVisible(False) - - def pointsLayerDf(self, posData, isSegm3D=False): - for action in self.actions()[1:]: - if not hasattr(action, 'button'): - continue - - df = self.fromActionToDataFrame( - action, posData, isSegm3D=isSegm3D - ) - return df - - def scatterItem(self): - for action in self.actions()[1:]: - if not hasattr(action, 'button'): - continue - - return action.scatterItem - -class RectItem(pg.GraphicsObject): - def __init__(self, rect, pen=None, brush=(255, 0, 0, 100), parent=None): - super().__init__(parent) - self._rect = rect - self._pen = pg.mkPen(pen) - self._brush = pg.mkBrush(brush) - self.picture = QPicture() - self._generate_picture() - - def setColor(self, color): - rgba = matplotlib.colors.to_rgba(color, alpha=100/255) - rgba = [round(c*255) for c in rgba] - self._brush = pg.mkBrush(rgba) - self._generate_picture() - self.update() - - def setRect(self, x, y, width, height): - self._rect = QRectF(x, y, width, height) - self._generate_picture() - self.update() - - def setQRect(self, qrect): - self._rect = qrect - self._generate_picture() - self.update() - - @property - def rect(self): - return self._rect - - def _generate_picture(self): - painter = QPainter(self.picture) - painter.setPen(self._pen) - painter.setBrush(self._brush) - painter.drawRect(self._rect) - painter.end() - - def paint(self, painter, option, widget=None): - painter.drawPicture(0, 0, self.picture) - - def boundingRect(self): - return QRectF(self.picture.boundingRect()) - -def get_min_width_for_no_scrollbar(list_widget: QListWidget) -> int: - """ - Calculate the minimum width needed for the QListWidget - so that the horizontal scrollbar will not be required. - """ - font_metrics = QFontMetrics(list_widget.font()) - max_width = 0 - - for i in range(list_widget.count()): - item = list_widget.item(i) - text_width = font_metrics.horizontalAdvance(item.text()) - max_width = max(max_width, text_width) - - # Add padding for icon, scrollbar margin, and frame - padding = 30 # Adjust as needed (depends on style and icons) - return max_width + padding - -class OverlayToolbar(ToolBar): - sigSetTranspacency = Signal(bool) - sigSetSingleChannel = Signal(bool) - - def __init__(self, name='Overlay tools', parent=None): - - super().__init__(name, parent) - - self.guiWin = parent - - self.setContextMenuPolicy(Qt.PreventContextMenu) - - self.addSeparator() - - self.transparencyCheckbox = self.addCheckBox( - text='True transparency (RGBA composite)' - ) - - self.transparencyCheckbox.setToolTip( - 'Activate to achieve true pixel-wise transparency where ' - 'the pixel intensity is 0 or set to 0 using the ' - 'LUT sliders on the left of the images.\n\n' - 'Since it is significantly slower, we recommended to activate this ' - 'only if you need to export images for figures.' - ) - - self.addSeparator() - - self.singleChannelCheckbox = self.addCheckBox( - text='Single channel' - ) - - self.singleChannelCheckbox.setToolTip( - 'When single channel mode is activated, selecting a channel ' - 'will display only that channel in the overlay.' - ) - - self.transparencyCheckbox.toggled.connect(self.sigSetTranspacency.emit) - self.singleChannelCheckbox.toggled.connect( - self.sigSetSingleChannel.emit - ) - - def setTransparent(self, transparent: bool): - self.transparencyCheckbox.setChecked(transparent) - - def isTransparent(self): - return self.transparencyCheckbox.isChecked() - - def isSingleChannel(self): - return self.singleChannelCheckbox.isChecked() - -class OverlayChannelToolButton(GradientToolButton): - def __init__( - self, - channel_name: str, - lut_item: myHistogramLUTitem, - shortcut='0', - parent=None, - ): - super().__init__( - colors=lut_item.gradient.getLookupTable(256), - parent=parent - ) - self._channel_name = channel_name - - lut_item.sigGradientChanged.connect(self.updateColors) - - self.setToolTip( - f'Show/hide "{channel_name}" channel\n\n' - f'Shortcut: {shortcut}' - ) - - self.setCheckable(True) - - def channelName(self): - return self._channel_name - - def updateColors(self, lut_item): - colors = lut_item.gradient.getLookupTable(256) - self._qcolors = [pg.mkColor(c) for c in colors] - self.update() - - def setVisible(self, visible: bool): - super().setVisible(visible) - if not hasattr(self, 'action'): - return - - self.action.setVisible(visible) - -class YeazV2SelectModelNameCombobox(ComboBox): - sigValueChanged = Signal(str) - - def __init__( - self, *args, - custom_select_item_text='Select custom weights file...', - **kwargs - ): - super().__init__(*args, **kwargs) - self._csi_text = custom_select_item_text - self.sigTextChanged.connect(self.onTextChanged) - self.initItems() - - def initItems(self): - from cellacdc.models.YeaZ_v2 import load_models_filepath - models_name, models_name_filepath_mapper = load_models_filepath() - self.addItems(models_name) - - def onTextChanged(self, text): - if text != self._csi_text: - return - - start_dir = myutils.getMostRecentPath() - model_filepath = qtpy.compat.getopenfilename( - parent=self, - caption='Select YeaZ weights file', - filters='All Files (*)', - basedir=start_dir - )[0] - if not model_filepath: - self.setCurrentIndex(0) - return - - msg = html_utils.paragraph(f""" - Insert a name for the following YeaZ model:

    - {model_filepath}
    - """) - modelNameWindow = apps.QLineEditDialog( - title='Insert a name for the model', - msg=msg, - allowEmpty=False, - parent=self - ) - modelNameWindow.exec_() - if modelNameWindow.cancel: - self.setCurrentIndex(0) - return - - model_name = modelNameWindow.enteredValue - - from cellacdc.models.YeaZ_v2 import add_model_filepath - add_model_filepath(model_name, model_filepath) - - self.addItem(model_name) - self.setCurrentText(model_name) - - print( - 'YeaZ_v2 model added!\n\n' - f' * Name: {model_name}\n' - f' * File path: {model_filepath}\n' - ) - - def addItem(self, item): - idx = self.count() - 1 - self.insertItem(idx, item) - - def addItems(self, items): - super().clear() - super().addItems(items) - super().addItem(self._csi_text) - idx = len(items) - font = self.font() - font.setItalic(True) - self.setItemData(idx, font, Qt.FontRole) - - def setValue(self, value: str): - self.setCurrentText(value) - - def value(self, *args): - return self.currentText() - - -class HighlightedIDToolbar(ToolBar): - sigIDChanged = Signal(int) - - def __init__(self, name='Highlighted ID', parent=None): - - super().__init__(name, parent) - - self.spinbox = self.addSpinBox('Highlighted ID: ') - self.spinbox.valueChanged.connect(self.emitSigIDChanged) - - self.addSeparator() - - def emitSigIDChanged(self, *args, **kwargs): - self.sigIDChanged.emit(self.spinbox.value()) - - def setIDNoSignals(self, ID: int): - self.spinbox.blockSignals(True) - self.spinbox.setValue(ID) - self.spinbox.blockSignals(False) - - -class AutoSaveIntervalWidget(QWidget): - sigValueChanged = Signal(float, str) - - def __init__(self, parent=None): - super().__init__(parent) - - layout = QHBoxLayout() - - autoSaveIntervalTooltip = ( - 'Autosave every minutes or frames specified here.' - ) - - self.setToolTip(autoSaveIntervalTooltip) - - self.spinbox = DoubleSpinBox() - self.spinbox.setMinimum(0) - self.spinbox.setValue(2) - self.spinbox.setDecimals(2) - self.spinbox.setSingleStep(1.0) - - layout.addWidget(self.spinbox) - - self.unitCombobox = ComboBox() - self.unitCombobox.addItems(['minutes', 'frames']) - layout.addWidget(self.unitCombobox) - - layout.setStretch(0, 1) - layout.setStretch(1, 0) - layout.setContentsMargins(5, 0, 5, 0) - - self.setLayout(layout) - - self.spinbox.sigValueChanged.connect(self.emitSigValueChanged) - self.unitCombobox.sigTextChanged.connect(self.emitSigValueChanged) - - def emitSigValueChanged(self, *args, **kwargs): - self.sigValueChanged.emit( - self.spinbox.value(), - self.unitCombobox.currentText() - ) - -class CheckableWidget(QWidget): - def __init__(self, widget, valueGetterName='value', parent=None): - super().__init__(parent) - - self.widget = widget - self.valueGetterName = valueGetterName - - widget.setDisabled(True) - - layout = QHBoxLayout() - - layout.addWidget(widget) - - self.checkbox = QCheckBox('Activate') - self.checkbox.toggled.connect(self.setWidgetEnabled) - - layout.addSpacing(5) - layout.addWidget(self.checkbox) - - layout.setContentsMargins(5, 0, 5, 0) - - - self.setLayout(layout) - - def setWidgetEnabled(self, checked): - self.widget.setDisabled(not checked) - - def value(self): - if not self.checkbox.isChecked(): - return - - return getattr(self.widget, self.valueGetterName)() - - -class WandControlsToolbar(ToolBar): - def __init__(self, name='Magic wand controls', parent=None): - super().__init__(name, parent) - - self.toleranceSpinbox = self.addSpinBox('Tolerance [%]: ') - self.toleranceSpinbox.setMinimum(0) - self.toleranceSpinbox.setMaximum(100) - self.toleranceSpinbox.setValue(5) - self.toleranceSpinbox.setToolTip( - 'The tolerance is calculated as a percentage of the minimum-maximum ' - 'pixel values range of the loaded dataset.\n\n' - 'If tolerance is greater than 0, the pixels adjacent to the added ' - 'pixels with value within +- tolerance will be considered part of ' - 'the object.' - ) - self.addLabel(r'% of min-max intensity range ') - - self.addSeparator() - - self.autoFillHolesCheckbox = self.addCheckBox( - 'Auto-fill holes' - ) - - self.addSeparator() - - self.useConvexHullCheckbox = self.addCheckBox( - 'Use convex hull mask' - ) - - self.addSeparator() - -class warnVisualCppRequired(myMessageBox): - def __init__(self, pkg_name='javabridge', parent=None): - super().__init__(parent) - self.screenShotWin = None - - self.setIcon(iconName='SP_MessageBoxWarning') - self.setWindowTitle(f'Installation of {pkg_name} info') - txt = html_utils.paragraph(f""" - Installation of {pkg_name} on Windows requires - Microsoft Visual C++ 14.0 or higher.

    - Cell-ACDC will anyway try to install {pkg_name} now.

    - If the installation fails, please close Cell-ACDC, - then download and install "Microsoft C++ Build Tools" - from the link below - before trying this module again.

    - - https://visualstudio.microsoft.com/visual-cpp-build-tools/ -

    - IMPORTANT: when installing "Microsoft C++ Build Tools" - make sure to select "Desktop development with C++". - Click "See the screenshot" for more details. - """) - seeScreenshotButton = QPushButton('See screenshot...') - okButton = okPushButton('Ok') - okButton = self.addButton('Ok') - okButton.disconnect() - okButton.clicked.connect(self.ok_cb) - self.addButton(seeScreenshotButton) - seeScreenshotButton.disconnect() - seeScreenshotButton.clicked.connect( - self.viewScreenshot - ) - self.addCancelButton(connect=True) - self.addText(txt) - - def ok_cb(self): - self.cancel = False - self.close() - - def viewScreenshot(self, checked=False): - self.screenShotWin = view_visualcpp_screenshot(self) - self.screenShotWin.show() - - def closeEvent(self, event): - if self.screenShotWin is not None: - self.screenShotWin.close() - - return super().closeEvent(event) \ No newline at end of file diff --git a/cellacdc/widgets/__init__.py b/cellacdc/widgets/__init__.py new file mode 100644 index 000000000..85c7bcf3c --- /dev/null +++ b/cellacdc/widgets/__init__.py @@ -0,0 +1,280 @@ +"""GUI widgets package (canvas, controls, toolbars) + components re-exports.""" + +from ..components.palette import * # noqa: F403 +from ..components.progress import * # noqa: F403 +from ..components.buttons import * # noqa: F403 +from ..components.layout import * # noqa: F403 +from ..components.inputs_basic import * # noqa: F403 +from ..components.path_controls import * # noqa: F403 +from ..components.lists import * # noqa: F403 +from ..components.base import QBaseWindow, QBaseDialog # noqa: F401 + +from .canvas import ( + BaseGradientEditorItemImage, + BaseGradientEditorItemLabels, + baseHistogramLUTitem, + labelsGradientWidget, + myColorButton, + myHistogramLUTitem, + overlayLabelsGradientWidget, + BaseImageItem, + BaseLabelsImageItem, + ChildImageItem, + GhostMaskItem, + OverlayImageItem, + ParentImageItem, + _ImShowImageItem, + labImageItem, + ImShow, + ImShowPlotItem, + BaseScatterPlotItem, + ContourItem, + CustomAnnotationScatterPlotItem, + GhostContourItem, + LabelItem, + LabelRoiCircularItem, + MainPlotItem, + PlotCurveItem, + PointsScatterPlotItem, + RectItem, + RulerPlotItem, + ScaleBar, + ScatterPlotItem, + myLabelItem, + DelROI, + PolyLineROI, + ROI, + ZoomROI, + MouseCursor, + ScrollBarWithNumericControl, + labelledQScrollbar, + linkedQScrollbar, + navigateQScrollBar, + sliderWithSpinBox, +) + +from .controls import ( + QDialogListbox, + installJavaDialog, + myMessageBox, + selectTrackerGUI, + view_visualcpp_screenshot, + warnVisualCppRequired, + AutoSaveIntervalWidget, + CheckableWidget, + CheckboxesGroupBox, + CopiableCommandWidget, + FontSizeWidget, + LabelsWidget, + PostProcessSegmSlider, + PostProcessSegmSpinbox, + PreProcessingSelector, + RangeSelector, + RescaleImageJroisGroupbox, + SamInputPointsWidget, + TimeWidget, + YeazV2SelectModelNameCombobox, + formWidget, + guiTabControl, + selectStartStopFrames, + AlphaNumericComboBox, + CenteredDoubleSpinbox, + ComboBox, + DoubleSpinBox, + ExpandableListBox, + FloatLineEdit, + IntLineEdit, + KeySequenceFromText, + LineEdit, + OddSpinBox, + QCenteredComboBox, + QClickableLabel, + ReadOnlyLineEdit, + SearchLineEdit, + ShortcutLineEdit, + SpinBox, + VectorLineEdit, + WhitelistLineEdit, + highlightableQWidgetAction, + mySpinBox, + readOnlyDoubleSpinbox, + readOnlySpinbox, + PixelSizeGroupbox, + SetMeasurementsGroupBox, + _metricsQGBox, + channelMetricsQGBox, + objIntesityMeasurQGBox, + objPropsQGBox, + CheckableAction, + CheckableSpinBoxWidgets, + FeatureSelectorButton, + KeptObjectIDsList, + Label, + LatexLabel, + OrderableListWidget, + SwitchPlaneCombobox, + TimestampItem, + Toggle, + ToggleTerminalButton, + ToggleVisibilityButton, + ToggleVisibilityCheckBox, + expandCollapseButton, + listWidget, + statusBarPermanentLabel, +) + +from .toolbars import ( + GradientToolButton, + ManualBackgroundToolBar, + ManualTrackingToolBar, + OverlayChannelToolButton, + PointsLayerToolButton, + SavePointsLayerButton, + ToolBar, + ToolBarSeparator, + ToolButtonCustomColor, + ToolButtonTextIcon, + customAnnotToolButton, + rightClickToolButton, + CopyLostObjectToolbar, + DrawClearRegionToolbar, + HighlightedIDToolbar, + MagicPromptsToolbar, + OverlayToolbar, + PointsLayersToolbar, + PromptableModelPointsLayerToolbar, + WandControlsToolbar, + WhitelistIDsToolbar, +) + +__all__ = [ + "BaseGradientEditorItemImage", + "BaseGradientEditorItemLabels", + "baseHistogramLUTitem", + "labelsGradientWidget", + "myColorButton", + "myHistogramLUTitem", + "overlayLabelsGradientWidget", + "BaseImageItem", + "BaseLabelsImageItem", + "ChildImageItem", + "GhostMaskItem", + "OverlayImageItem", + "ParentImageItem", + "_ImShowImageItem", + "labImageItem", + "ImShow", + "ImShowPlotItem", + "BaseScatterPlotItem", + "ContourItem", + "CustomAnnotationScatterPlotItem", + "GhostContourItem", + "LabelItem", + "LabelRoiCircularItem", + "MainPlotItem", + "PlotCurveItem", + "PointsScatterPlotItem", + "RectItem", + "RulerPlotItem", + "ScaleBar", + "ScatterPlotItem", + "myLabelItem", + "DelROI", + "PolyLineROI", + "ROI", + "ZoomROI", + "MouseCursor", + "ScrollBarWithNumericControl", + "labelledQScrollbar", + "linkedQScrollbar", + "navigateQScrollBar", + "sliderWithSpinBox", + "QDialogListbox", + "installJavaDialog", + "myMessageBox", + "selectTrackerGUI", + "view_visualcpp_screenshot", + "warnVisualCppRequired", + "AutoSaveIntervalWidget", + "CheckableWidget", + "CheckboxesGroupBox", + "CopiableCommandWidget", + "FontSizeWidget", + "LabelsWidget", + "PostProcessSegmSlider", + "PostProcessSegmSpinbox", + "PreProcessingSelector", + "RangeSelector", + "RescaleImageJroisGroupbox", + "SamInputPointsWidget", + "TimeWidget", + "YeazV2SelectModelNameCombobox", + "formWidget", + "guiTabControl", + "selectStartStopFrames", + "AlphaNumericComboBox", + "CenteredDoubleSpinbox", + "ComboBox", + "DoubleSpinBox", + "ExpandableListBox", + "FloatLineEdit", + "IntLineEdit", + "KeySequenceFromText", + "LineEdit", + "OddSpinBox", + "QCenteredComboBox", + "QClickableLabel", + "ReadOnlyLineEdit", + "SearchLineEdit", + "ShortcutLineEdit", + "SpinBox", + "VectorLineEdit", + "WhitelistLineEdit", + "highlightableQWidgetAction", + "mySpinBox", + "readOnlyDoubleSpinbox", + "readOnlySpinbox", + "PixelSizeGroupbox", + "SetMeasurementsGroupBox", + "_metricsQGBox", + "channelMetricsQGBox", + "objIntesityMeasurQGBox", + "objPropsQGBox", + "CheckableAction", + "CheckableSpinBoxWidgets", + "FeatureSelectorButton", + "KeptObjectIDsList", + "Label", + "LatexLabel", + "OrderableListWidget", + "SwitchPlaneCombobox", + "TimestampItem", + "Toggle", + "ToggleTerminalButton", + "ToggleVisibilityButton", + "ToggleVisibilityCheckBox", + "expandCollapseButton", + "listWidget", + "statusBarPermanentLabel", + "GradientToolButton", + "ManualBackgroundToolBar", + "ManualTrackingToolBar", + "OverlayChannelToolButton", + "PointsLayerToolButton", + "SavePointsLayerButton", + "ToolBar", + "ToolBarSeparator", + "ToolButtonCustomColor", + "ToolButtonTextIcon", + "customAnnotToolButton", + "rightClickToolButton", + "CopyLostObjectToolbar", + "DrawClearRegionToolbar", + "HighlightedIDToolbar", + "MagicPromptsToolbar", + "OverlayToolbar", + "PointsLayersToolbar", + "PromptableModelPointsLayerToolbar", + "WandControlsToolbar", + "WhitelistIDsToolbar", +] diff --git a/cellacdc/widgets/canvas/__init__.py b/cellacdc/widgets/canvas/__init__.py new file mode 100644 index 000000000..7f36811a2 --- /dev/null +++ b/cellacdc/widgets/canvas/__init__.py @@ -0,0 +1,104 @@ +"""Canvas widgets.""" + +from .histogram import ( + BaseGradientEditorItemImage, + BaseGradientEditorItemLabels, + baseHistogramLUTitem, + labelsGradientWidget, + myColorButton, + myHistogramLUTitem, + overlayLabelsGradientWidget, +) + +from .images import ( + BaseImageItem, + BaseLabelsImageItem, + ChildImageItem, + GhostMaskItem, + OverlayImageItem, + ParentImageItem, + _ImShowImageItem, + labImageItem, +) + +from .imshow import ( + ImShow, + ImShowPlotItem, +) + +from .plot_items import ( + BaseScatterPlotItem, + ContourItem, + CustomAnnotationScatterPlotItem, + GhostContourItem, + LabelItem, + LabelRoiCircularItem, + MainPlotItem, + PlotCurveItem, + PointsScatterPlotItem, + RectItem, + RulerPlotItem, + ScaleBar, + ScatterPlotItem, + myLabelItem, +) + +from .rois import ( + DelROI, + PolyLineROI, + ROI, + ZoomROI, +) + +from .scrollbars import ( + MouseCursor, + ScrollBarWithNumericControl, + labelledQScrollbar, + linkedQScrollbar, + navigateQScrollBar, + sliderWithSpinBox, +) + +__all__ = [ + "BaseGradientEditorItemImage", + "BaseGradientEditorItemLabels", + "baseHistogramLUTitem", + "labelsGradientWidget", + "myColorButton", + "myHistogramLUTitem", + "overlayLabelsGradientWidget", + "BaseImageItem", + "BaseLabelsImageItem", + "ChildImageItem", + "GhostMaskItem", + "OverlayImageItem", + "ParentImageItem", + "_ImShowImageItem", + "labImageItem", + "ImShow", + "ImShowPlotItem", + "BaseScatterPlotItem", + "ContourItem", + "CustomAnnotationScatterPlotItem", + "GhostContourItem", + "LabelItem", + "LabelRoiCircularItem", + "MainPlotItem", + "PlotCurveItem", + "PointsScatterPlotItem", + "RectItem", + "RulerPlotItem", + "ScaleBar", + "ScatterPlotItem", + "myLabelItem", + "DelROI", + "PolyLineROI", + "ROI", + "ZoomROI", + "MouseCursor", + "ScrollBarWithNumericControl", + "labelledQScrollbar", + "linkedQScrollbar", + "navigateQScrollBar", + "sliderWithSpinBox", +] diff --git a/cellacdc/widgets/canvas/histogram.py b/cellacdc/widgets/canvas/histogram.py new file mode 100644 index 000000000..df7f710b6 --- /dev/null +++ b/cellacdc/widgets/canvas/histogram.py @@ -0,0 +1,1301 @@ +"""Canvas widgets: histogram.""" + +"""GUI widgets: canvas.""" + +from collections import defaultdict, deque +from typing import Dict, List, Union, Iterable, Sequence +import os +import sys +import operator +import time +import re +import datetime +import numpy as np +import pandas as pd +import math +import traceback +import logging +import textwrap +import random + +from functools import partial +from math import ceil + +import skimage.draw +import skimage.morphology + +from matplotlib.colors import ListedColormap, LinearSegmentedColormap +import matplotlib.pyplot as plt +import matplotlib +from matplotlib.backends.backend_agg import FigureCanvasAgg + +from qtpy.QtCore import ( + Signal, + QTimer, + Qt, + QPoint, + QUrl, + Property, + QPropertyAnimation, + QEasingCurve, + QLocale, + QSize, + QRect, + QPointF, + QRect, + QPoint, + QEasingCurve, + QRegularExpression, + QEvent, + QEventLoop, + QPropertyAnimation, + QObject, + QItemSelectionModel, + QAbstractListModel, + QModelIndex, + QByteArray, + QDataStream, + QMimeData, + QAbstractItemModel, + QIODevice, + QItemSelection, + PYQT6, + QRectF, +) +from qtpy.QtGui import ( + QFont, + QPalette, + QColor, + QPen, + QKeyEvent, + QBrush, + QPainter, + QRegularExpressionValidator, + QIcon, + QPixmap, + QKeySequence, + QLinearGradient, + QShowEvent, + QDesktopServices, + QFontMetrics, + QGuiApplication, + QLinearGradient, + QImage, + QCursor, + QPicture, +) +from qtpy.QtWidgets import ( + QTextEdit, + QLabel, + QProgressBar, + QHBoxLayout, + QToolButton, + QCheckBox, + QApplication, + QWidget, + QVBoxLayout, + QMainWindow, + QTreeWidgetItemIterator, + QLineEdit, + QSlider, + QSpinBox, + QGridLayout, + QRadioButton, + QScrollArea, + QSizePolicy, + QComboBox, + QPushButton, + QScrollBar, + QGroupBox, + QAbstractSlider, + QDoubleSpinBox, + QWidgetAction, + QAction, + QTabWidget, + QAbstractSpinBox, + QToolBar, + QStyleOptionSpinBox, + QStyle, + QDialog, + QSpacerItem, + QFrame, + QMenu, + QActionGroup, + QListWidget, + QPlainTextEdit, + QFileDialog, + QListView, + QAbstractItemView, + QTreeWidget, + QTreeWidgetItem, + QListWidgetItem, + QLayout, + QStylePainter, + QGraphicsBlurEffect, + QGraphicsProxyWidget, + QGraphicsObject, + QButtonGroup, + QStyleOptionSlider, +) +import qtpy.compat + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from ... import utils, measurements, is_mac, is_win, html_utils, is_linux +from ... import printl, settings_folderpath +from ... import colors, config +from ... import html_path +from ... import _palettes +from ... import load +from ... import apps +from ... import plot +from ... import annotate +from ... import urls +from ... import _core, core +from ... import QtScoped +from ... import prompts +from ...acdc_regex import float_regex +from ...config import PREPROCESS_MAPPER +from ... import _base_widgets + +from ...components.palette import ( # noqa: E402 + BASE_COLOR, + Gradients, + GradientsImage, + GradientsLabels, + LINEEDIT_INVALID_ENTRY_STYLESHEET, + LINEEDIT_WARNING_STYLESHEET, + LISTWIDGET_STYLESHEET, + PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR, + PROGRESSBAR_QCOLOR, + TEXT_COLOR, + TREEWIDGET_STYLESHEET, + cmaps, + font, + getCustomGradients, + nonInvertibleCmaps, + sign_int_mapper, + str_to_operator_mapper, +) +from ...components.progress import QtHandler, QLog, XStream # noqa: E402 +from ...components.buttons import * # noqa: E402, F403 +from ...components.layout import * # noqa: E402, F403 +from ...components.inputs_basic import * # noqa: E402, F403 +from ...components.path_controls import * # noqa: E402, F403 + +from ...components.lists import * # noqa: E402, F403 +from ...components.base import QBaseWindow # noqa: E402 +from ...components.progress import ( # noqa: E402 + LoadingCircleAnimation, + NoneWidget, + ProgressBar, + ProgressBarWithETA, + QLogConsole, +) + +class BaseGradientEditorItemImage(pg.GradientEditorItem): + def __init__(self, *args, **kargs): + super().__init__(*args, **kargs) + + def restoreState(self, state): + pg.graphicsItems.GradientEditorItem.Gradients = GradientsImage + return super().restoreState(state) + + +class BaseGradientEditorItemLabels(pg.GradientEditorItem): + def __init__(self, *args, **kargs): + super().__init__(*args, **kargs) + + def restoreState(self, state): + pg.graphicsItems.GradientEditorItem.Gradients = GradientsLabels + return super().restoreState(state) + + +class baseHistogramLUTitem(pg.HistogramLUTItem): + sigAddColormap = Signal(object, str) + sigRescaleIntes = Signal(object) + + def __init__(self, name="image", axisLabel="", parent=None, **kwargs): + pg.GradientEditorItem = BaseGradientEditorItemLabels + + super().__init__(**kwargs) + + self.labelStyle = {"color": "#ffffff", "font-size": "11px"} + + if axisLabel: + self.setAxisLabel(axisLabel) + + self.cmaps = cmaps + self._parent = parent + self.name = name + + self.gradient.colorDialog.setWindowFlags(Qt.Dialog | Qt.WindowStaysOnTopHint) + self.gradient.colorDialog.accepted.disconnect() + self.gradient.colorDialog.accepted.connect(self.tickColorAccepted) + + self.isInverted = False + self.lastGradientName = "grey" + self.lastGradient = Gradients["grey"] + + for action in self.gradient.menu.actions(): + if action.text() == "HSV": + HSV_action = action + elif action.text() == "RGB": + RGB_ation = action + self.gradient.menu.removeAction(HSV_action) + self.gradient.menu.removeAction(RGB_ation) + + # Rescale intensities (LUT) + rescaleIntensMenu = self.gradient.menu.addMenu("Rescale intensities (LUT)") + rescaleActionGroup = QActionGroup(self) + rescaleActionGroup.setExclusive(True) + + self.rescaleEach2DimgAction = QAction( + "Rescale each 2D image", rescaleIntensMenu + ) + self.rescaleEach2DimgAction.setCheckable(True) + self.rescaleEach2DimgAction.setChecked(True) + rescaleActionGroup.addAction(self.rescaleEach2DimgAction) + rescaleIntensMenu.addAction(self.rescaleEach2DimgAction) + + self.rescaleAcrossZstackAction = QAction( + "Rescale across z-stack", rescaleIntensMenu + ) + self.rescaleAcrossZstackAction.setCheckable(True) + self.rescaleAcrossZstackAction.setChecked(False) + rescaleActionGroup.addAction(self.rescaleAcrossZstackAction) + rescaleIntensMenu.addAction(self.rescaleAcrossZstackAction) + + self.rescaleAcrossTimeAction = QAction( + "Rescale across time frames", rescaleIntensMenu + ) + self.rescaleAcrossTimeAction.setCheckable(True) + self.rescaleAcrossTimeAction.setChecked(False) + rescaleActionGroup.addAction(self.rescaleAcrossTimeAction) + rescaleIntensMenu.addAction(self.rescaleAcrossTimeAction) + + self.customRescaleAction = QAction("Choose custom levels...", rescaleIntensMenu) + self.customRescaleAction.setCheckable(True) + rescaleActionGroup.addAction(self.customRescaleAction) + rescaleIntensMenu.addAction(self.customRescaleAction) + + self.doNotRescaleAction = QAction( + "Do no rescale, display raw image", rescaleIntensMenu + ) + self.doNotRescaleAction.setCheckable(True) + rescaleActionGroup.addAction(self.doNotRescaleAction) + rescaleIntensMenu.addAction(self.doNotRescaleAction) + + self.rescaleActionGroup = rescaleActionGroup + rescaleActionGroup.triggered.connect(self.rescaleActionTriggered) + + # Add custom colormap action + self.customCmapsMenu = self.gradient.menu.addMenu("Custom colormaps") + self.customCmapsMenu.aboutToShow.connect(self.onShowCustomCmapsMenu) + self.customCmapsMenu.triggered.connect(self.customCmapsMenuTriggered) + + self.saveColormapAction = QAction("Save current colormap...", self) + self.gradient.menu.addAction(self.saveColormapAction) + self.saveColormapAction.triggered.connect(self.saveColormap) + + self.addCustomGradients() + + # Set inverted gradients for invert bw action + self.addInvertedColorMaps() + + self.gradient.menu.addSeparator() + + # hide histogram tool + self.vb.hide() + + # Disable moving the axis up and down + self.axis.unlinkFromView() + + # Disable histogram default context Menu event + self.vb.raiseContextMenu = lambda x: None + + def rescaleActionTriggered(self, action): + self.sigRescaleIntes.emit(action) + + def onShowCustomCmapsMenu(self): + self.customCmapsMenu.show() + + def customCmapsMenuTriggered(self, action): + cmap = action.cmap + self.gradient.colorMapMenuClicked(cmap) + self.gradient.showTicks(True) + + def setAxisLabel(self, text): + self.labelText = text + self.axis.setLabel(text, **self.labelStyle) + + def updateAxisLabel(self): + text = self.axis.label.toPlainText() + if not text: + return + self.setAxisLabel(text) + + def setGradient(self, gradient): + self.gradient.restoreState(gradient) + self.lastGradient = gradient + + def colormapClicked(self, checked=False, name=None): + name = self.sender().name + self.lastGradientName = name + if self.isInverted: + self.setGradient(self.invertedGradients[name]) + else: + self.setGradient(Gradients[name]) + + def sortTicks(self, ticks): + sortedTicks = sorted(ticks, key=operator.itemgetter(0)) + return sortedTicks + + def getInvertedGradients(self): + invertedGradients = {} + for name, gradient in Gradients.items(): + ticks = gradient["ticks"] + sortedTicks = self.sortTicks(ticks) + if name in nonInvertibleCmaps: + invertedColors = sortedTicks + else: + invertedColors = [ + (t[0], ti[1]) for t, ti in zip(sortedTicks, sortedTicks[::-1]) + ] + invertedGradient = {} + invertedGradient["ticks"] = invertedColors + invertedGradient["mode"] = gradient["mode"] + invertedGradients[name] = invertedGradient + return invertedGradients + + def addInvertedColorMaps(self): + self.invertedGradients = self.getInvertedGradients() + for action in self.gradient.menu.actions(): + if not hasattr(action, "name"): + continue + + if action.name not in self.cmaps: + continue + + action.triggered.disconnect() + action.triggered.connect(self.colormapClicked) + + px = QPixmap(100, 15) + p = QPainter(px) + invertedGradient = self.invertedGradients[action.name] + qtGradient = QLinearGradient(QPointF(0, 0), QPointF(100, 0)) + ticks = self.sortTicks(invertedGradient["ticks"]) + qtGradient.setStops([(x, QColor(*color)) for x, color in ticks]) + brush = QBrush(qtGradient) + p.fillRect(QRect(0, 0, 100, 15), brush) + p.end() + widget = action.defaultWidget() + hbox = widget.layout() + rectLabelWidget = QLabel() + rectLabelWidget.setPixmap(px) + hbox.addWidget(rectLabelWidget) + rectLabelWidget.hide() + + def setInvertedColorMaps(self, inverted): + if inverted: + showIdx = 2 + hideIdx = 1 + self.labelStyle["color"] = "#000000" + else: + showIdx = 1 + hideIdx = 2 + self.labelStyle["color"] = "#ffffff" + + for action in self.gradient.menu.actions(): + if not hasattr(action, "name"): + continue + + if action.name not in self.cmaps: + continue + + widget = action.defaultWidget() + hbox = widget.layout() + hideCmapRect = hbox.itemAt(hideIdx).widget() + showCmapRect = hbox.itemAt(showIdx).widget() + hideCmapRect.hide() + showCmapRect.show() + + self.updateAxisLabel() + self.isInverted = inverted + + def invertGradient(self, gradient): + ticks = gradient["ticks"] + sortedTicks = self.sortTicks(ticks) + invertedColors = [ + (t[0], ti[1]) for t, ti in zip(sortedTicks, sortedTicks[::-1]) + ] + invertedGradient = {} + invertedGradient["ticks"] = invertedColors + invertedGradient["mode"] = gradient["mode"] + return invertedGradient + + def invertCurrentColormap(self, inverted, debug=False): + self.setGradient(self.invertGradient(self.lastGradient)) + + def addCustomGradient(self, gradient_name, gradient_ticks, restore=True): + self.originalLength = self.gradient.length + self.gradient.length = 100 + if restore: + self.gradient.restoreState(gradient_ticks) + gradient = self.gradient.getGradient() + action = CustomGradientMenuAction(gradient, gradient_name, self.gradient) + # action.triggered.connect(self.gradient.contextMenuClicked) + action.delButton.clicked.connect(self.removeCustomGradient) + action.cmap = colors.pg_ticks_to_colormap(gradient_ticks["ticks"]) + # self.gradient.menu.insertAction(self.saveColormapAction, action) + self.customCmapsMenu.addAction(action) + self.gradient.length = self.originalLength + GradientsImage[gradient_name] = gradient_ticks + + def removeCustomGradient(self): + button = self.sender() + action = button.action + self.customCmapsMenu.removeAction(action) + cp = config.ConfigParser() + cp.read(custom_cmaps_filepath) + cp.remove_section(f"image.{action.name}") + with open(custom_cmaps_filepath, mode="w") as file: + cp.write(file) + + def addCustomGradients(self): + try: + CustomGradients = getCustomGradients(name="image") + if not CustomGradients: + return + for gradient_name, gradient_ticks in CustomGradients.items(): + self.addCustomGradient(gradient_name, gradient_ticks) + except Exception as e: + printl(traceback.format_exc()) + pass + + def _askNameColormap(self): + inputWin = apps.QInput(parent=self._parent, title="Colormap name") + inputWin.askText("Insert a name for the colormap: ", allowEmpty=False) + if inputWin.cancel: + return + cmapName = inputWin.answer + return cmapName + + def saveColormap(self): + cmapName = self._askNameColormap() + if cmapName is None: + return + + cp = config.ConfigParser() + if os.path.exists(custom_cmaps_filepath): + cp.read(custom_cmaps_filepath) + + SECTION = f"{self.name}.{cmapName}" + cp[SECTION] = {} + + # gradient_ticks = [] + state = self.gradient.saveState() + for key, value in state.items(): + if key != "ticks": + continue + for t, tick in enumerate(value): + pos, rgb = tick + # gradient_ticks.append((pos, rgb)) + rgb = ",".join([str(c) for c in rgb]) + val = f"{pos},{rgb}" + cp[SECTION][f"tick_{t}_pos_rgb"] = val + + with open(custom_cmaps_filepath, mode="w") as file: + cp.write(file) + + self.addCustomGradient(cmapName, state, restore=False) + + def tickColorAccepted(self): + self.gradient.currentColorAccepted() + # self.sigTickColorAccepted.emit(self.gradient.colorDialog.color().getRgb()) + + def setRescaleIntensitiesHow(self, how): + for action in self.rescaleActionGroup.actions(): + if action.text() == how: + action.setChecked(True) + return + + +class myHistogramLUTitem(baseHistogramLUTitem): + sigGradientMenuEvent = Signal(object) + sigGradientChanged = Signal(object) + sigTickColorAccepted = Signal(object) + sigAddScaleBar = Signal(bool) + sigAddTimestamp = Signal(bool) + + def __init__( + self, parent=None, name="image", axisLabel="", isViewer=False, **kwargs + ): + super().__init__(parent=parent, name=name, axisLabel=axisLabel, **kwargs) + + self.name = name + self._parent = parent + + self.childLutItem = None + + self.isViewer = isViewer + if isViewer: + # In the viewer we don't allow additional settings from the menu + return + + # Add scale bar action + self.addScaleBarAction = QAction("Add scale bar", self) + self.addScaleBarAction.setCheckable(True) + self.addScaleBarAction.triggered.connect(self.emitAddScaleBar) + self.gradient.menu.addAction(self.addScaleBarAction) + + # Add timestamp action + self.addTimestampAction = QAction("Add timestamp", self) + self.addTimestampAction.setCheckable(True) + self.addTimestampAction.triggered.connect(self.emitAddTimestamp) + self.gradient.menu.addAction(self.addTimestampAction) + + # Invert bw action + self.invertBwAction = QAction("Invert black/white", self) + self.invertBwAction.setCheckable(True) + self.gradient.menu.addAction(self.invertBwAction) + + # Font size menu action + self.fontSizeMenu = QMenu("Text font size") + self.gradient.menu.addMenu(self.fontSizeMenu) + + # Text color button + hbox = QHBoxLayout() + hbox.addWidget(QLabel("Text color: ")) + self.textColorButton = myColorButton(color=(255, 255, 255)) + hbox.addStretch(1) + hbox.addWidget(self.textColorButton) + widget = QWidget() + widget.setLayout(hbox) + act = highlightableQWidgetAction(self) + act.setDefaultWidget(widget) + act.triggered.connect(self.textColorButton.click) + self.gradient.menu.addAction(act) + + # Contours line weight + contLineWeightMenu = QMenu("Contours line weight", self.gradient.menu) + self.contLineWightActionGroup = QActionGroup(self) + self.contLineWightActionGroup.setExclusionPolicy( + QActionGroup.ExclusionPolicy.Exclusive + ) + for w in range(1, 11): + action = QAction(str(w)) + action.setCheckable(True) + if w == 2: + action.setChecked(True) + action.lineWeight = w + self.contLineWightActionGroup.addAction(action) + action = contLineWeightMenu.addAction(action) + self.gradient.menu.addMenu(contLineWeightMenu) + + # Contours color button + hbox = QHBoxLayout() + hbox.addWidget(QLabel("Contours color: ")) + self.contoursColorButton = myColorButton(color=(25, 25, 25)) + hbox.addStretch(1) + hbox.addWidget(self.contoursColorButton) + widget = QWidget() + widget.setLayout(hbox) + act = highlightableQWidgetAction(self) + act.setDefaultWidget(widget) + act.triggered.connect(self.contoursColorButton.click) + self.gradient.menu.addAction(act) + + # Mother-bud line weight + mothBudLineWeightMenu = QMenu("Mother-bud line weight", self.gradient.menu) + self.mothBudLineWightActionGroup = QActionGroup(self) + self.mothBudLineWightActionGroup.setExclusionPolicy( + QActionGroup.ExclusionPolicy.Exclusive + ) + for w in range(1, 11): + action = QAction(str(w)) + action.setCheckable(True) + if w == 2: + action.setChecked(True) + action.lineWeight = w + self.mothBudLineWightActionGroup.addAction(action) + action = mothBudLineWeightMenu.addAction(action) + self.gradient.menu.addMenu(mothBudLineWeightMenu) + + # Mother-bud line color + hbox = QHBoxLayout() + hbox.addWidget(QLabel("Mother-bud line color: ")) + self.mothBudLineColorButton = myColorButton(color=(255, 0, 0)) + hbox.addStretch(1) + hbox.addWidget(self.mothBudLineColorButton) + widget = QWidget() + widget.setLayout(hbox) + act = highlightableQWidgetAction(self) + act.setDefaultWidget(widget) + act.triggered.connect(self.mothBudLineColorButton.click) + self.gradient.menu.addAction(act) + + self.labelsAlphaMenu = self.gradient.menu.addMenu( + "Segm. masks overlay alpha..." + ) + # self.labelsAlphaMenu.setDisabled(True) + hbox = QHBoxLayout() + self.labelsAlphaSlider = sliderWithSpinBox( + title="Alpha", title_loc="in_line", isFloat=True, normalize=True + ) + self.labelsAlphaSlider.setMaximum(100) + self.labelsAlphaSlider.setSingleStep(0.05) + self.labelsAlphaSlider.setValue(0.3) + hbox.addWidget(self.labelsAlphaSlider) + shortCutText = "Command+Up/Down" if is_mac else "Ctrl+Up/Down" + hbox.addWidget(QLabel(f"({shortCutText})")) + widget = QWidget() + widget.setLayout(hbox) + act = QWidgetAction(self) + act.setDefaultWidget(widget) + self.labelsAlphaMenu.addSeparator() + self.labelsAlphaMenu.addAction(act) + + # Default settings + self.defaultSettingsAction = QAction("Restore default settings...", self) + self.gradient.menu.addAction(self.defaultSettingsAction) + + self.filterObject = FilterObject() + self.filterObject.sigFilteredEvent.connect(self.gradientMenuEventFilter) + self.gradient.menu.installEventFilter(self.filterObject) + self.highlightedAction = None + self.lastHoveredAction = None + + def setChildLutItem(self, childLutItem): + self.childLutItem = childLutItem + + def removeAddScaleBarAction(self): + self.gradient.menu.removeAction(self.addScaleBarAction) + + def removeAddTimestampAction(self): + self.gradient.menu.removeAction(self.addTimestampAction) + + def emitAddScaleBar(self): + self.sigAddScaleBar.emit(self.addScaleBarAction.isChecked()) + + def emitAddTimestamp(self): + self.sigAddTimestamp.emit(self.addTimestampAction.isChecked()) + + def gradientChanged(self): + super().gradientChanged() + self.sigGradientChanged.emit(self) + + def gradientMenuEventFilter(self, object, event): + if event.type() == QEvent.Type.MouseMove: + hoveredAction = self.gradient.menu.actionAt(event.pos()) + isActionEntered = hoveredAction != self.lastHoveredAction + if isActionEntered: + if isinstance(hoveredAction, highlightableQWidgetAction): + # print('Entered a custom action') + pass + isActionLeft = ( + self.highlightedAction is not None + and self.highlightedAction != hoveredAction + ) + if isActionLeft: + if isinstance(self.highlightedAction, highlightableQWidgetAction): + # print('Left a custom action') + pass + self.highlightedAction = hoveredAction + + self.lastHoveredAction = hoveredAction + + def addOverlayColorButton(self, rgbColor, channelName): + # Overlay color button + hbox = QHBoxLayout() + hbox.addWidget(QLabel("Overlay color: ")) + self.overlayColorButton = myColorButton(color=rgbColor) + self.overlayColorButton.channel = channelName + hbox.addStretch(1) + hbox.addWidget(self.overlayColorButton) + widget = QWidget() + widget.setLayout(hbox) + act = highlightableQWidgetAction(self) + act.setDefaultWidget(widget) + act.triggered.connect(self.overlayColorButton.click) + self.gradient.menu.addAction(act) + + def uncheckContLineWeightActions(self): + for act in self.contLineWightActionGroup.actions(): + try: + act.toggled.disconnect() + except Exception as e: + pass + act.setChecked(False) + + def uncheckMothBudLineLineWeightActions(self): + for act in self.mothBudLineWightActionGroup.actions(): + try: + act.toggled.disconnect() + except Exception as e: + pass + act.setChecked(False) + + def restoreState(self, df): + if "textIDsColor" in df.index: + rgbString = df.at["textIDsColor", "value"] + r, g, b = colors.rgb_str_to_values(rgbString) + self.textColorButton.setColor((r, g, b)) + + if "contLineColor" in df.index: + rgba_str = df.at["contLineColor", "value"] + rgb = colors.rgba_str_to_values(rgba_str)[:3] + self.contoursColorButton.setColor(rgb) + + if "contLineWeight" in df.index: + w = df.at["contLineWeight", "value"] + w = int(w) + for action in self.contLineWightActionGroup.actions(): + if action.lineWeight == w: + action.setChecked(True) + break + + if "mothBudLineWeight" in df.index: + w = df.at["mothBudLineWeight", "value"] + w = int(w) + for action in self.mothBudLineWightActionGroup.actions(): + if action.lineWeight == w: + action.setChecked(True) + break + + if "overlaySegmMasksAlpha" in df.index: + alpha = df.at["overlaySegmMasksAlpha", "value"] + self.labelsAlphaSlider.setValue(float(alpha)) + + if "mothBudLineColor" in df.index: + rgba_str = df.at["mothBudLineColor", "value"] + rgb = colors.rgba_str_to_values(rgba_str)[:3] + self.mothBudLineColorButton.setColor(rgb) + + checked = df.at["is_bw_inverted", "value"] == "Yes" + self.invertBwAction.setChecked(checked) + + self.restoreColormap(df) + + def saveState(self, df): + # remove previous state + df = df[~df.index.str.contains("img_cmap")].copy() + + state = self.gradient.saveState() + for key, value in state.items(): + if key == "ticks": + for t, tick in enumerate(value): + pos, rgb = tick + df.at[f"img_cmap_tick{t}_rgb", "value"] = rgb + df.at[f"img_cmap_tick{t}_pos", "value"] = pos + else: + if isinstance(value, bool): + value = "Yes" if value else "No" + df.at[f"img_cmap_{key}", "value"] = value + return df + + def restoreColormap(self, df): + state = {"mode": "rgb", "ticksVisible": True, "ticks": []} + ticks_pos = {} + ticks_rgb = {} + stateFound = False + for setting, value in df.itertuples(): + idx = setting.find("img_cmap_") + if idx == -1: + continue + + stateFound = True + m = re.findall(r"tick(\d+)_(\w+)", setting) + if m: + tick_idx, tick_type = m[0] + if tick_type == "pos": + ticks_pos[int(tick_idx)] = float(value) + elif tick_type == "rgb": + ticks_rgb[int(tick_idx)] = colors.rgba_str_to_values(value) + else: + key = setting[9:] + if value == "Yes": + value = True + elif value == "No": + value = False + state[key] = value + + if stateFound: + ticks = [(0, 0)] * len(ticks_pos) + for idx, val in ticks_pos.items(): + pos = val + rgb = ticks_rgb[idx] + ticks[idx] = (pos, rgb) + + state["ticks"] = ticks + self.gradient.restoreState(state) + + def regionChanged(self): + super().regionChanged() + if self.childLutItem is None: + return + + imageItem = self.imageItem() + try: + mn, mx = imageItem.quickMinMax(targetSize=65536) + # mn and mx can still be NaN if the data is all-NaN + if mn == mx or imageItem._xp.isnan(mn) or imageItem._xp.isnan(mx): + mn = 0 + mx = 255 + except AttributeError as err: + mn, mx = self.getLevels() + + self.childLutItem.setLevels(min=mn, max=mx) + + +class myColorButton(pg.ColorButton): + def __init__(self, parent=None, color=(128, 128, 128), padding=5): + super().__init__(parent=parent, color=color) + if isinstance(padding, (int, float)): + self.padding = (padding, padding, -padding, -padding) + else: + self.padding = padding + self._c = 225 + self._hoverDeltaC = 30 + self._alpha = 100 + self._bkgrColor = QColor(self._c, self._c, self._c, self._alpha) + self._borderColor = QColor(171, 171, 171) + self._rectBorderPen = QPen(QBrush(QColor(0, 0, 0)), 0.3) + + def paintEvent(self, event): + # QPushButton.paintEvent(self, ev) + p = QStylePainter(self) + p.setRenderHint(QPainter.RenderHint.Antialiasing) + rect = self.rect() + p.setBrush(QBrush(self._bkgrColor)) + p.setPen(QPen(self._borderColor)) + p.drawRoundedRect(rect, 5, 5) + # p.fillRect(self.rect(), self._bkgrColor) + rect = self.rect().adjusted(*self.padding) + ## draw white base, then texture for indicating transparency, then actual color + p.setBrush(pg.mkBrush("w")) + p.drawRect(rect) + p.setBrush(QBrush(Qt.BrushStyle.DiagCrossPattern)) + p.drawRect(rect) + p.setPen(self._rectBorderPen) + p.setBrush(pg.mkBrush(self._color)) + p.drawRect(rect) + p.end() + + def enterEvent(self, event): + c = self._c + self._hoverDeltaC + self._bkgrColor = QColor(c, c, c, self._alpha) + self.update() + + def leaveEvent(self, event): + c = self._c + self._bkgrColor = QColor(c, c, c, self._alpha) + self.update() + + +class overlayLabelsGradientWidget(pg.GradientWidget): + def __init__( + self, + imageItem, + selectActionGroup, + segmEndname, + parent=None, + orientation="right", + ): + pg.GradientWidget.__init__(self, parent=parent, orientation=orientation) + + self.imageItem = imageItem + self.selectActionGroup = selectActionGroup + + for action in self.menu.actions(): + if action.text() == "HSV": + HSV_action = action + elif action.text() == "RGB": + RGB_ation = action + self.menu.removeAction(HSV_action) + self.menu.removeAction(RGB_ation) + + # Shuffle colors action + self.shuffleCmapAction = QAction("Randomly shuffle colormap (Shift+S)", self) + self.menu.addAction(self.shuffleCmapAction) + + # Drawing mode + drawModeMenu = QMenu("Drawing mode", self) + self.drawModeActionGroup = QActionGroup(self) + contoursDrawModeAction = QAction("Draw contours", drawModeMenu) + contoursDrawModeAction.setCheckable(True) + contoursDrawModeAction.setChecked(True) + contoursDrawModeAction.segmEndname = segmEndname + self.drawModeActionGroup.addAction(contoursDrawModeAction) + drawModeMenu.addAction(contoursDrawModeAction) + olDrawModeAction = QAction("Overlay labels", drawModeMenu) + olDrawModeAction.setCheckable(True) + olDrawModeAction.segmEndname = segmEndname + self.drawModeActionGroup.addAction(olDrawModeAction) + drawModeMenu.addAction(olDrawModeAction) + self.menu.addMenu(drawModeMenu) + + self.labelsAlphaMenu = self.menu.addMenu("Overlay labels alpha...") + hbox = QHBoxLayout() + self.labelsAlphaSlider = sliderWithSpinBox( + title="Alpha", title_loc="in_line", isFloat=True, normalize=True + ) + self.labelsAlphaSlider.setMaximum(100) + self.labelsAlphaSlider.setSingleStep(0.05) + self.labelsAlphaSlider.setValue(0.3) + hbox.addWidget(self.labelsAlphaSlider) + widget = QWidget() + widget.setLayout(hbox) + act = QWidgetAction(self) + act.setDefaultWidget(widget) + self.labelsAlphaMenu.addSeparator() + self.labelsAlphaMenu.addAction(act) + + self.menu.addSeparator() + self.menu.addSection("Select segm. file to adjust:") + for action in selectActionGroup.actions(): + self.menu.addAction(action) + + self.item.loadPreset("viridis") + self.updateImageLut(None) + self.updateImageOpacity(0.3) + + # Connect events + self.sigGradientChangeFinished.connect(self.updateImageLut) + self.labelsAlphaSlider.valueChanged.connect(self.updateImageOpacity) + self.shuffleCmapAction.triggered.connect(self.shuffleCmap) + + def shuffleCmap(self): + lut = self.imageItem.lut + np.random.shuffle(lut) + lut[0] = [0, 0, 0, 0] + self.imageItem.setLookupTable(lut) + self.imageItem.update() + + def updateImageLut(self, gradientItem): + lut = np.zeros((255, 4), dtype=np.uint8) + lut[:, -1] = 255 + lut[:, :-1] = self.item.colorMap().getLookupTable(0, 1, 255) + np.random.shuffle(lut) + lut[0] = [0, 0, 0, 0] + self.imageItem.setLookupTable(lut) + self.imageItem.setLevels([0, 255]) + + def updateImageOpacity(self, value): + self.imageItem.setOpacity(value) + + +class labelsGradientWidget(pg.GradientWidget): + sigShowRightImgToggled = Signal(bool) + sigShowLabelsImgToggled = Signal(bool) + sigShowNextFrameToggled = Signal(bool) + + def __init__(self, *args, parent=None, orientation="right", **kargs): + pg.GradientEditorItem = BaseGradientEditorItemLabels + + pg.GradientWidget.__init__( + self, *args, parent=parent, orientation=orientation, **kargs + ) + + self._parent = parent + self.name = "labels" + + for action in self.menu.actions(): + if action.text() == "HSV": + HSV_action = action + elif action.text() == "RGB": + RGB_ation = action + self.menu.removeAction(HSV_action) + self.menu.removeAction(RGB_ation) + + # Add custom colormap action + self.customCmapsMenu = self.menu.addMenu("Custom colormaps") + self.customCmapsMenu.aboutToShow.connect(self.onShowCustomCmapsMenu) + self.customCmapsMenu.triggered.connect(self.customCmapsMenuTriggered) + + self.saveColormapAction = QAction("Save current colormap...", self) + self.menu.addAction(self.saveColormapAction) + self.saveColormapAction.triggered.connect(self.saveColormap) + + self.addCustomGradients() + + # Background color button + hbox = QHBoxLayout() + hbox.addWidget(QLabel("Background color: ")) + self.colorButton = myColorButton(color=(25, 25, 25)) + hbox.addStretch(1) + hbox.addWidget(self.colorButton) + widget = QWidget() + widget.setLayout(hbox) + act = highlightableQWidgetAction(self) + act.setDefaultWidget(widget) + act.triggered.connect(self.colorButton.click) + self.menu.addAction(act) + + # Font size menu action + self.fontSizeMenu = QMenu("Text font size", self) + self.menu.addMenu(self.fontSizeMenu) + + # IDs color button + hbox = QHBoxLayout() + hbox.addWidget(QLabel("Text color: ")) + self.textColorButton = myColorButton(color=(25, 25, 25)) + hbox.addStretch(1) + hbox.addWidget(self.textColorButton) + widget = QWidget() + widget.setLayout(hbox) + act = highlightableQWidgetAction(self) + act.setDefaultWidget(widget) + act.triggered.connect(self.textColorButton.click) + self.menu.addAction(act) + self.menu.addSeparator() + + # Shuffle colors action + self.shuffleCmapAction = QAction("Randomly shuffle colormap (Shift+S)", self) + self.menu.addAction(self.shuffleCmapAction) + + self.greedyShuffleCmapAction = QAction( + "Greedily shuffle colormap (Alt+Shift+S)", self + ) + self.menu.addAction(self.greedyShuffleCmapAction) + + self.permanentGreedyCmapAction = QAction("Always use greedy colormap", self) + self.permanentGreedyCmapAction.setCheckable(True) + self.menu.addAction(self.permanentGreedyCmapAction) + + # Invert bw action + self.invertBwAction = QAction("Invert black/white", self) + self.invertBwAction.setCheckable(True) + self.menu.addAction(self.invertBwAction) + + # Show labels action + self.showLabelsImgAction = QAction("Show segmentation image", self) + self.showLabelsImgAction.setCheckable(True) + self.menu.addAction(self.showLabelsImgAction) + + # Show right image action + self.showRightImgAction = QAction("Show duplicated left image", self) + self.showRightImgAction.setCheckable(True) + self.menu.addAction(self.showRightImgAction) + + # Show next frame action + self.showNextFrameAction = QAction("Show next frame", self) + self.showNextFrameAction.setCheckable(True) + self.menu.addAction(self.showNextFrameAction) + + # Default settings + self.defaultSettingsAction = QAction("Restore default settings...", self) + self.menu.addAction(self.defaultSettingsAction) + + self.menu.addSeparator() + + self.showRightImgAction.toggled.connect(self.showRightImageToggled) + self.showLabelsImgAction.toggled.connect(self.showLabelsImageToggled) + self.showNextFrameAction.toggled.connect(self.showNextFrameToggled) + + def onShowCustomCmapsMenu(self): + self.customCmapsMenu.show() + + def customCmapsMenuTriggered(self, action): + cmap = action.cmap + self.item.colorMapMenuClicked(cmap) + self.item.showTicks(True) + + def addCustomGradient(self, gradient_name, gradient_ticks, restore=True): + currentState = self.item.saveState() + self.originalLength = self.item.length + self.item.length = 100 + if restore: + self.item.restoreState(gradient_ticks) + gradient = self.item.getGradient() + action = CustomGradientMenuAction(gradient, gradient_name, self.item) + # action.triggered.connect(self.item.contextMenuClicked) + action.delButton.clicked.connect(self.removeCustomGradient) + action.cmap = colors.pg_ticks_to_colormap(gradient_ticks["ticks"]) + # self.item.menu.insertAction(self.saveColormapAction, action) + self.customCmapsMenu.addAction(action) + self.item.length = self.originalLength + self.item.restoreState(currentState) + GradientsLabels[gradient_name] = gradient_ticks + + def removeCustomGradient(self): + button = self.sender() + action = button.action + self.customCmapsMenu.removeAction(action) + cp = config.ConfigParser() + cp.read(custom_cmaps_filepath) + cp.remove_section(f"labels.{action.name}") + with open(custom_cmaps_filepath, mode="w") as file: + cp.write(file) + + def addCustomGradients(self): + try: + CustomGradients = getCustomGradients(name="labels") + if not CustomGradients: + return + for gradient_name, gradient_ticks in CustomGradients.items(): + self.addCustomGradient(gradient_name, gradient_ticks) + except Exception as e: + printl(traceback.format_exc()) + pass + + def _askNameColormap(self): + inputWin = apps.QInput(parent=self._parent, title="Colormap name") + inputWin.askText("Insert a name for the colormap: ", allowEmpty=False) + if inputWin.cancel: + return + cmapName = inputWin.answer + return cmapName + + def saveColormap(self): + cmapName = self._askNameColormap() + if cmapName is None: + return + + cp = config.ConfigParser() + if os.path.exists(custom_cmaps_filepath): + cp.read(custom_cmaps_filepath) + + SECTION = f"{self.name}.{cmapName}" + cp[SECTION] = {} + + state = self.item.saveState() + for key, value in state.items(): + if key != "ticks": + continue + for t, tick in enumerate(value): + pos, rgb = tick + rgb = ",".join([str(c) for c in rgb]) + val = f"{pos},{rgb}" + cp[SECTION][f"tick_{t}_pos_rgb"] = val + + with open(custom_cmaps_filepath, mode="w") as file: + cp.write(file) + + self.addCustomGradient(cmapName, state, restore=False) + + def isRightImageVisible(self): + return ( + self.showLabelsImgAction.isChecked() or self.showNextFrameAction.isChecked() + ) + + def showRightImageToggled(self, checked): + if checked and self.isRightImageVisible(): + # Hide the right labels image before showing right image + self.showLabelsImgAction.setChecked(False) + self.showNextFrameAction.setChecked(False) + self.sigShowLabelsImgToggled.emit(False) + self.sigShowNextFrameToggled.emit(checked) + self.sigShowRightImgToggled.emit(checked) + + def showLabelsImageToggled(self, checked): + if checked and self.isRightImageVisible(): + # Hide the right image before showing labels image + self.showRightImgAction.setChecked(False) + self.showNextFrameAction.setChecked(False) + self.sigShowRightImgToggled.emit(False) + self.sigShowNextFrameToggled.emit(False) + self.sigShowLabelsImgToggled.emit(checked) + + def showNextFrameToggled(self, checked): + if checked and self.isRightImageVisible(): + # Hide the right image before showing labels image + self.showRightImgAction.setChecked(False) + self.showLabelsImgAction.setChecked(False) + self.sigShowRightImgToggled.emit(False) + self.sigShowLabelsImgToggled.emit(False) + self.sigShowNextFrameToggled.emit(checked) + + def saveState(self, df): + # remove previous state + df = df[~df.index.str.contains("lab_cmap")].copy() + + state = self.item.saveState() + for key, value in state.items(): + if key == "ticks": + for t, tick in enumerate(value): + pos, rgb = tick + df.at[f"lab_cmap_tick{t}_rgb", "value"] = rgb + df.at[f"lab_cmap_tick{t}_pos", "value"] = pos + else: + if isinstance(value, bool): + value = "Yes" if value else "No" + df.at[f"lab_cmap_{key}", "value"] = value + return df + + def restoreState(self, df, loadCmap=True): + # Insert background color + if "labels_bkgrColor" in df.index: + rgbString = df.at["labels_bkgrColor", "value"] + r, g, b = colors.rgb_str_to_values(rgbString) + self.colorButton.setColor((r, g, b)) + + if "labels_text_color" in df.index: + rgbString = df.at["labels_text_color", "value"] + r, g, b = colors.rgb_str_to_values(rgbString) + self.textColorButton.setColor((r, g, b)) + else: + self.textColorButton.setColor((255, 0, 0)) + + checked = df.at["is_bw_inverted", "value"] == "Yes" + self.invertBwAction.setChecked(checked) + + if not loadCmap: + return + + state = {"mode": "rgb", "ticksVisible": True, "ticks": []} + ticks_pos = {} + ticks_rgb = {} + stateFound = False + for setting, value in df.itertuples(): + idx = setting.find("lab_cmap_") + if idx == -1: + continue + + stateFound = True + m = re.findall(r"tick(\d+)_(\w+)", setting) + if m: + tick_idx, tick_type = m[0] + if tick_type == "pos": + ticks_pos[int(tick_idx)] = float(value) + elif tick_type == "rgb": + ticks_rgb[int(tick_idx)] = colors.rgba_str_to_values(value) + else: + key = setting[9:] + if value == "Yes": + value = True + elif value == "No": + value = False + state[key] = value + + if stateFound: + ticks = [(0, 0)] * len(ticks_pos) + for idx, val in ticks_pos.items(): + pos = val + rgb = ticks_rgb[idx] + ticks[idx] = (pos, rgb) + + state["ticks"] = ticks + self.item.restoreState(state) + else: + self.item.loadPreset("viridis") + + return stateFound + + def showMenu(self, ev): + try: + # Convert QPointF to QPoint + self.menu.popup(ev.screenPos().toPoint()) + except AttributeError: + self.menu.popup(ev.screenPos()) + +# Cross-module imports (deferred to avoid import cycles) +from .scrollbars import ( + sliderWithSpinBox, +) +from ..controls.inputs import ( + highlightableQWidgetAction, +) + diff --git a/cellacdc/widgets/canvas/images.py b/cellacdc/widgets/canvas/images.py new file mode 100644 index 000000000..34ef94680 --- /dev/null +++ b/cellacdc/widgets/canvas/images.py @@ -0,0 +1,796 @@ +"""Canvas widgets: images.""" + +"""GUI widgets: canvas.""" + +from collections import defaultdict, deque +from typing import Dict, List, Union, Iterable, Sequence +import os +import sys +import operator +import time +import re +import datetime +import numpy as np +import pandas as pd +import math +import traceback +import logging +import textwrap +import random + +from functools import partial +from math import ceil + +import skimage.draw +import skimage.morphology + +from matplotlib.colors import ListedColormap, LinearSegmentedColormap +import matplotlib.pyplot as plt +import matplotlib +from matplotlib.backends.backend_agg import FigureCanvasAgg + +from qtpy.QtCore import ( + Signal, + QTimer, + Qt, + QPoint, + QUrl, + Property, + QPropertyAnimation, + QEasingCurve, + QLocale, + QSize, + QRect, + QPointF, + QRect, + QPoint, + QEasingCurve, + QRegularExpression, + QEvent, + QEventLoop, + QPropertyAnimation, + QObject, + QItemSelectionModel, + QAbstractListModel, + QModelIndex, + QByteArray, + QDataStream, + QMimeData, + QAbstractItemModel, + QIODevice, + QItemSelection, + PYQT6, + QRectF, +) +from qtpy.QtGui import ( + QFont, + QPalette, + QColor, + QPen, + QKeyEvent, + QBrush, + QPainter, + QRegularExpressionValidator, + QIcon, + QPixmap, + QKeySequence, + QLinearGradient, + QShowEvent, + QDesktopServices, + QFontMetrics, + QGuiApplication, + QLinearGradient, + QImage, + QCursor, + QPicture, +) +from qtpy.QtWidgets import ( + QTextEdit, + QLabel, + QProgressBar, + QHBoxLayout, + QToolButton, + QCheckBox, + QApplication, + QWidget, + QVBoxLayout, + QMainWindow, + QTreeWidgetItemIterator, + QLineEdit, + QSlider, + QSpinBox, + QGridLayout, + QRadioButton, + QScrollArea, + QSizePolicy, + QComboBox, + QPushButton, + QScrollBar, + QGroupBox, + QAbstractSlider, + QDoubleSpinBox, + QWidgetAction, + QAction, + QTabWidget, + QAbstractSpinBox, + QToolBar, + QStyleOptionSpinBox, + QStyle, + QDialog, + QSpacerItem, + QFrame, + QMenu, + QActionGroup, + QListWidget, + QPlainTextEdit, + QFileDialog, + QListView, + QAbstractItemView, + QTreeWidget, + QTreeWidgetItem, + QListWidgetItem, + QLayout, + QStylePainter, + QGraphicsBlurEffect, + QGraphicsProxyWidget, + QGraphicsObject, + QButtonGroup, + QStyleOptionSlider, +) +import qtpy.compat + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from ... import utils, measurements, is_mac, is_win, html_utils, is_linux +from ... import printl, settings_folderpath +from ... import colors, config +from ... import html_path +from ... import _palettes +from ... import load +from ... import apps +from ... import plot +from ... import annotate +from ... import urls +from ... import _core, core +from ... import QtScoped +from ... import prompts +from ...acdc_regex import float_regex +from ...config import PREPROCESS_MAPPER +from ... import _base_widgets + +from ...components.palette import ( # noqa: E402 + BASE_COLOR, + Gradients, + GradientsImage, + GradientsLabels, + LINEEDIT_INVALID_ENTRY_STYLESHEET, + LINEEDIT_WARNING_STYLESHEET, + LISTWIDGET_STYLESHEET, + PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR, + PROGRESSBAR_QCOLOR, + TEXT_COLOR, + TREEWIDGET_STYLESHEET, + cmaps, + font, + getCustomGradients, + nonInvertibleCmaps, + sign_int_mapper, + str_to_operator_mapper, +) +from ...components.progress import QtHandler, QLog, XStream # noqa: E402 +from ...components.buttons import * # noqa: E402, F403 +from ...components.layout import * # noqa: E402, F403 +from ...components.inputs_basic import * # noqa: E402, F403 +from ...components.path_controls import * # noqa: E402, F403 + +from ...components.lists import * # noqa: E402, F403 +from ...components.base import QBaseWindow # noqa: E402 +from ...components.progress import ( # noqa: E402 + LoadingCircleAnimation, + NoneWidget, + ProgressBar, + ProgressBarWithETA, + QLogConsole, +) + +class BaseImageItem(pg.ImageItem): + def __init__(self, image=None, **kargs): + self.minMaxValuesMapper = None + self.minMaxValuesMapperPreproc = None + self.minMaxValuesMapperCombined = None + self.minMaxValuesMapperEqualized = None + self.pos_i = 0 + self.z = 0 + self.frame_i = 0 + self.usePreprocessed = False + self.useEqualized = False + self.useCombined = False + self._isRgba = False + + super().__init__(image, **kargs) + self.autoLevelsEnabled = None + + def isRgba(self): + return self._isRgba + + def setEnableAutoLevels(self, enabled: bool): + self.autoLevelsEnabled = enabled + + def setImage(self, image=None, autoLevels=None, **kargs): + if autoLevels is None: + autoLevels = self.autoLevelsEnabled + + if image is not None and image.ndim == 3 and image.shape[2] in (3, 4): + self._isRgba = True + + super().setImage(image, autoLevels=autoLevels, **kargs) + + def preComputedMinMaxValues(self, data: List["load.loadData"]): + self.minMaxValuesMapper = {} + for pos_i, posData in enumerate(data): + img_data = posData.img_data + requires_time_dim = posData.img_data.ndim == 2 or ( + posData.img_data.ndim == 3 and posData.SizeZ > 1 + ) + if requires_time_dim: + img_data = (img_data,) + + for frame_i, image in enumerate(img_data): + if image.ndim == 3: + self._updateMinMaxValuesProjections( + image, pos_i, frame_i, self.minMaxValuesMapper + ) + + if image.ndim == 2: + image = (image,) + + for z, img in enumerate(image): + self.minMaxValuesMapper[(pos_i, frame_i, z)] = ( + np.nanmin(img), + np.nanmax(img), + ) + + def updateMinMaxValuesEqualizedData( + self, + data: List["load.loadData"], + pos_i: int, + frame_i: int, + z_slice: Union[int, str], + ): + if self.minMaxValuesMapperEqualized is None: + self.minMaxValuesMapperEqualized = {} + + posData = data[pos_i] + img = posData.equalized_img_data[frame_i][z_slice] + key = (pos_i, frame_i, z_slice) + self.minMaxValuesMapperEqualized[key] = (np.nanmin(img), np.nanmax(img)) + + def updateMinMaxValuesEqualizedDataProjections( + self, + data: List["load.loadData"], + pos_i: int, + frame_i: int, + ): + posData = data[pos_i] + eq_zstack = posData.equalized_img_data[frame_i] + + self._updateMinMaxValuesProjections( + eq_zstack, pos_i, frame_i, self.minMaxValuesMapperEqualized + ) + + def _updateMinMaxValuesProjections(self, zstack, pos_i, frame_i, mapper): + max_proj = zstack.max(axis=0) + key = (pos_i, frame_i, "max z-projection") + mapper[key] = np.nanmin(max_proj), np.nanmax(max_proj) + + mean_proj = zstack.mean(axis=0) + key = (pos_i, frame_i, "mean z-projection") + mapper[key] = np.nanmin(mean_proj), np.nanmax(mean_proj) + + median_proj = np.median(zstack, axis=0) + key = (pos_i, frame_i, "median z-proj.") + mapper[key] = np.nanmin(median_proj), np.nanmax(median_proj) + + def updateMinMaxValuesPreprocessedData( + self, + data: List["load.loadData"], + pos_i: int, + frame_i: int, + z_slice: Union[int, str], + ): + if self.minMaxValuesMapperPreproc is None: + self.minMaxValuesMapperPreproc = {} + + posData = data[pos_i] + img = posData.preproc_img_data[frame_i][z_slice] + key = (pos_i, frame_i, z_slice) + self.minMaxValuesMapperPreproc[key] = (np.nanmin(img), np.nanmax(img)) + + def updateMinMaxValuesPreprocessedProjections( + self, + data: List["load.loadData"], + pos_i: int, + frame_i: int, + ): + posData = data[pos_i] + zstack = posData.preproc_img_data[frame_i] + + self._updateMinMaxValuesProjections( + zstack, pos_i, frame_i, self.minMaxValuesMapperPreproc + ) + + def updateMinMaxValuesCombinedData( + self, + data: List["load.loadData"], + pos_i: int, + frame_i: int, + z_slice: Union[int, str], + ): + if self.minMaxValuesMapperCombined is None: + self.minMaxValuesMapperCombined = {} + + posData = data[pos_i] + img = posData.combine_img_data[frame_i][z_slice] + key = (pos_i, frame_i, z_slice) + self.minMaxValuesMapperCombined[key] = (np.nanmin(img), np.nanmax(img)) + + def updateMinMaxValuesCombinedDataProjections( + self, + data: List["load.loadData"], + pos_i: int, + frame_i: int, + ): + posData = data[pos_i] + zstack = posData.combine_img_data[frame_i] + + self._updateMinMaxValuesProjections( + zstack, pos_i, frame_i, self.minMaxValuesMapperCombined + ) + + def setCurrentPosIndex(self, pos_i: int): + self.pos_i = pos_i + + def setCurrentFrameIndex(self, frame_i: int): + self.frame_i = frame_i + + def setCurrentZsliceIndex(self, z: int): + self.z = z + + def quickMinMax(self, targetSize=1e6): + if self.isRgba(): + return super().quickMinMax(targetSize=targetSize) + + if self.usePreprocessed and self.minMaxValuesMapperPreproc is not None: + minMaxValuesMapper = self.minMaxValuesMapperPreproc + elif self.useCombined and self.minMaxValuesMapperCombined is not None: + minMaxValuesMapper = self.minMaxValuesMapperCombined + elif self.useEqualized and self.minMaxValuesMapperEqualized is not None: + minMaxValuesMapper = self.minMaxValuesMapperEqualized + else: + minMaxValuesMapper = self.minMaxValuesMapper + + if minMaxValuesMapper is None: + return super().quickMinMax(targetSize=targetSize) + + try: + key = (self.pos_i, self.frame_i, self.z) + levels = minMaxValuesMapper[key] + return levels + except Exception as err: + pass + + try: + key = (self.pos_i, self.frame_i, self.z) + levels = self.minMaxValuesMapper[key] + return levels + except Exception as err: + return super().quickMinMax(targetSize=targetSize) + + def setOpacity(self, value, **kwargs): + if value == 0: + value = 0.001 + + if value == 1: + value = 0.999 + + super().setOpacity(value) + + +class BaseLabelsImageItem(pg.ImageItem): + def __init__(self, image=None, **kargs): + super().__init__(image, **kargs) + + def setImage(self, image=None, **kwargs): + if image is None: + return + autoLevels = kwargs.get("autoLevels") + if autoLevels is None: + kwargs["autoLevels"] = False + super().setImage(image, **kwargs) + + +class OverlayImageItem(pg.ImageItem): + def __init__(self, image=None, **kargs): + super().__init__(image, **kargs) + self.autoLevelsEnabled = None + + def setEnableAutoLevels(self, enabled: bool): + self.autoLevelsEnabled = enabled + + def setImage(self, image=None, autoLevels=None, **kargs): + if autoLevels is None: + autoLevels = self.autoLevelsEnabled + + super().setImage(image, autoLevels=autoLevels, **kargs) + + def setOpacity(self, value, **kwargs): + if value == 0: + value = 0.001 + + if value == 1: + value = 0.999 + + super().setOpacity(value) + + +class ParentImageItem(BaseImageItem): + def __init__( + self, + image=None, + linkedImageItem=None, + activatingActions=None, + debug=False, + **kargs, + ): + super().__init__(image, **kargs) + self.linkedImageItem = linkedImageItem + self.activatingActions = activatingActions + self.debug = debug + self._forceDoNotUpdateLinked = False + self.autoLevelsEnabled = None + + def clear(self): + if self.linkedImageItem is not None: + self.linkedImageItem.clear() + return super().clear() + + def isLinkedImageItemActive(self): + if self._forceDoNotUpdateLinked: + return False + + if self.linkedImageItem is None: + return False + + if self.activatingActions is None: + return False + + for action in self.activatingActions: + if action.isChecked(): + return True + + return False + + def setEnableAutoLevels(self, enabled: bool): + self.autoLevelsEnabled = enabled + + def setUsePreprocessed(self, usePreprocessed): + self.usePreprocessed = usePreprocessed + if self.linkedImageItem is None: + return + + self.linkedImageItem.usePreprocessed = usePreprocessed + + def setUseCombined(self, useCombined): + self.useCombined = useCombined + if self.linkedImageItem is None: + return + + self.linkedImageItem.useCombined = useCombined + + def preComputedMinMaxValues(self, *args, **kwargs): + super().preComputedMinMaxValues(*args, **kwargs) + if self.linkedImageItem is None: + return + + self.linkedImageItem.minMaxValuesMapper = self.minMaxValuesMapper + + def updateMinMaxValuesPreprocessedData(self, *args, **kwargs): + super().updateMinMaxValuesPreprocessedData(*args, **kwargs) + + if self.linkedImageItem is None: + return + + self.linkedImageItem.minMaxValuesMapper = self.minMaxValuesMapper + + def updateMinMaxValuesCombinedData(self, *args, **kwargs): + super().updateMinMaxValuesCombinedData(*args, **kwargs) + + if self.linkedImageItem is None: + return + + self.linkedImageItem.minMaxValuesMapperCombined = ( + self.minMaxValuesMapperCombined + ) + + def updateMinMaxValuesCombinedDataProjections(self, *args, **kwargs): + super().updateMinMaxValuesCombinedDataProjections(*args, **kwargs) + + if self.linkedImageItem is None: + return + + self.linkedImageItem.minMaxValuesMapperCombined = ( + self.minMaxValuesMapperCombined + ) + + def updateMinMaxValuesEqualizedDataProjections(self, *args, **kwargs): + super().updateMinMaxValuesEqualizedDataProjections(*args, **kwargs) + + if self.linkedImageItem is None: + return + + self.linkedImageItem.minMaxValuesMapperEqualized = ( + self.minMaxValuesMapperEqualized + ) + + def updateMinMaxValuesEqualizedData(self, *args, **kwargs): + super().updateMinMaxValuesEqualizedData(*args, **kwargs) + + if self.linkedImageItem is None: + return + + self.linkedImageItem.minMaxValuesMapperEqualized = ( + self.minMaxValuesMapperEqualized + ) + + def setCurrentPosIndex(self, *args, **kwargs): + super().setCurrentPosIndex(*args, **kwargs) + + if self.linkedImageItem is None: + return + + self.linkedImageItem.pos_i = self.pos_i + + def setCurrentFrameIndex(self, *args, **kwargs): + super().setCurrentFrameIndex(*args, **kwargs) + + if self.linkedImageItem is None: + return + + self.linkedImageItem.frame_i = self.frame_i + 1 + + def setCurrentZsliceIndex(self, *args, **kwargs): + super().setCurrentZsliceIndex(*args, **kwargs) + + if self.linkedImageItem is None: + return + + self.linkedImageItem.z = self.z + + def setImage( + self, + image=None, + autoLevels=None, + next_frame_image=None, + scrollbar_value=None, + force_set_linked=False, + **kargs, + ): + if autoLevels is None: + autoLevels = self.autoLevelsEnabled + + super().setImage(image, autoLevels=autoLevels, **kargs) + + if self.linkedImageItem is None: + return + + if not self.isLinkedImageItemActive() and not force_set_linked: + return + + if next_frame_image is not None: + self.linkedImageItem.setImage( + next_frame_image, scrollbar_value=scrollbar_value, autoLevels=autoLevels + ) + elif image is not None: + self.linkedImageItem.setImage(image) + + def updateImage(self, *args, **kargs): + if self.isLinkedImageItemActive(): + self.linkedImageItem.image = self.image + self.linkedImageItem.updateImage(*args, **kargs) + return super().updateImage(*args, **kargs) + + def setOpacity(self, value, applyToLinked=True): + super().setOpacity(value) + if not applyToLinked: + return + + if self.linkedImageItem is None: + return + + self.linkedImageItem.setOpacity(value) + + def setLookupTable(self, lut): + super().setLookupTable(lut) + + +class ChildImageItem(BaseImageItem): + def __init__(self, *args, linkedScrollbar=None, **kwargs): + BaseImageItem.__init__(self, *args, **kwargs) + self.linkedScrollbar = linkedScrollbar + + def setImage(self, img=None, z=None, scrollbar_value=None, **kargs): + autoLevels = kargs.get("autoLevels") + if autoLevels is None: + kargs["autoLevels"] = False + + if img is None: + BaseImageItem.setImage(self, img, **kargs) + return + + if img.ndim == 3 and img.shape[-1] > 4 and z is not None: + BaseImageItem.setImage(self, img[z], **kargs) + else: + BaseImageItem.setImage(self, img, **kargs) + + if self.linkedScrollbar is None: + return + + if not self.linkedScrollbar.isEnabled(): + return + + if scrollbar_value is None: + return + + self.linkedScrollbar.setValueNoSignal(scrollbar_value) + + +class labImageItem(pg.ImageItem): + def __init__(self, *args, **kwargs): + pg.ImageItem.__init__(self, *args, **kwargs) + + def setImage(self, img=None, z=None, **kargs): + autoLevels = kargs.get("autoLevels") + if autoLevels is None: + kargs["autoLevels"] = False + + if img is None: + pg.ImageItem.setImage(self, img, **kargs) + return + + if img.ndim == 3 and img.shape[-1] > 4 and z is not None: + pg.ImageItem.setImage(self, img[z], **kargs) + else: + pg.ImageItem.setImage(self, img, **kargs) + + +class GhostMaskItem(pg.ImageItem): + def __init__(self, ParentPlotItem): + super().__init__() + self.label = myLabelItem() + self.label.setAttr("bold", True) + self.label.setAttr("color", (245, 184, 0)) + self._ParentPlotItem = ParentPlotItem + + def initImage(self, imgShape): + image = np.zeros(imgShape, dtype=np.uint32) + self.setImage(image) + + def initLookupTable(self, rgbaColor): + lut = np.zeros((2, 4), dtype=np.uint8) + lut[1, -1] = 255 + lut[1, :-1] = rgbaColor + self.setLookupTable(lut) + + def addToPlotItem(self): + self._ParentPlotItem.addItem(self) + self._ParentPlotItem.addItem(self.label) + + def removeFromPlotItem(self): + self._ParentPlotItem.removeItem(self.label) + self._ParentPlotItem.removeItem(self) + + def updateGhostImage(self, ID=0, y_cursor=None, x_cursor=None, fontSize=None): + self.setImage(self.image) + + if ID == 0: + self.label.setText("") + return + + self.label.setText(f"{ID}", size=fontSize) + w, h = self.label.itemRect().width(), self.label.itemRect().height() + self.label.item.setPos(x_cursor, y_cursor - h) + + def clear(self): + if hasattr(self, "label"): + self.label.setText("") + if self.image is None: + return + self.image[:] = 0 + self.setImage(self.image) + + +class _ImShowImageItem(pg.ImageItem): + sigDataHover = Signal(str) + sigHoverEvent = Signal(object, object) + sigMousePressEvent = Signal(object, object) + + def __init__(self, idx) -> None: + super().__init__() + self._idx = idx + self._cursors = [] + self._autoLevels = True + + def _getHoverImageValue(self, xdata, ydata): + try: + value = self.image[ydata, xdata] + return value + except Exception as err: + return + + def setAutoLevels(self, autoLevels): + self._autoLevels = autoLevels + + def mousePressEvent(self, event): + self.sigMousePressEvent.emit(self, event) + super().mousePressEvent(event) + + def setOtherImagesCursors(self, cursors): + self._cursors = cursors + + def clearCursors(self): + for p, cursor in enumerate(self._cursors): + if p == self._idx: + continue + + cursor.setData([], []) + + def setImage(self, *args, **kwargs): + if "autoLevels" not in kwargs: + kwargs["autoLevels"] = self._autoLevels + + super().setImage(*args, **kwargs) + if not args: + return + + if not kwargs["autoLevels"]: + return + + image = args[0] + self._imageMax = image.max() + self._imageMin = image.min() + self._numLevels = self._imageMax - self._imageMin + + def hoverEvent(self, event): + self.sigHoverEvent.emit(self, event) + + if event.isExit(): + self.clearCursors() + self.sigDataHover.emit("") + return + + x, y = event.pos() + xdata, ydata = int(x), int(y) + value = self._getHoverImageValue(xdata, ydata) + if value is None: + self.clearCursors() + self.sigDataHover.emit("") + return + + try: + self.sigDataHover.emit(f"x={xdata}, y={ydata}, {value = :.4f}") + except Exception as e: + self.sigDataHover.emit(f"x={xdata}, y={ydata}, {[val for val in value]}") + + for p, cursor in enumerate(self._cursors): + if p == self._idx: + continue + + cursor.setData([x], [y]) + +# Cross-module imports (deferred to avoid import cycles) +from .plot_items import ( + myLabelItem, +) + diff --git a/cellacdc/widgets/canvas/imshow.py b/cellacdc/widgets/canvas/imshow.py new file mode 100644 index 000000000..f975a8bed --- /dev/null +++ b/cellacdc/widgets/canvas/imshow.py @@ -0,0 +1,1075 @@ +"""Canvas widgets: imshow.""" + +"""GUI widgets: canvas.""" + +from collections import defaultdict, deque +from typing import Dict, List, Union, Iterable, Sequence +import os +import sys +import operator +import time +import re +import datetime +import numpy as np +import pandas as pd +import math +import traceback +import logging +import textwrap +import random + +from functools import partial +from math import ceil + +import skimage.draw +import skimage.morphology + +from matplotlib.colors import ListedColormap, LinearSegmentedColormap +import matplotlib.pyplot as plt +import matplotlib +from matplotlib.backends.backend_agg import FigureCanvasAgg + +from qtpy.QtCore import ( + Signal, + QTimer, + Qt, + QPoint, + QUrl, + Property, + QPropertyAnimation, + QEasingCurve, + QLocale, + QSize, + QRect, + QPointF, + QRect, + QPoint, + QEasingCurve, + QRegularExpression, + QEvent, + QEventLoop, + QPropertyAnimation, + QObject, + QItemSelectionModel, + QAbstractListModel, + QModelIndex, + QByteArray, + QDataStream, + QMimeData, + QAbstractItemModel, + QIODevice, + QItemSelection, + PYQT6, + QRectF, +) +from qtpy.QtGui import ( + QFont, + QPalette, + QColor, + QPen, + QKeyEvent, + QBrush, + QPainter, + QRegularExpressionValidator, + QIcon, + QPixmap, + QKeySequence, + QLinearGradient, + QShowEvent, + QDesktopServices, + QFontMetrics, + QGuiApplication, + QLinearGradient, + QImage, + QCursor, + QPicture, +) +from qtpy.QtWidgets import ( + QTextEdit, + QLabel, + QProgressBar, + QHBoxLayout, + QToolButton, + QCheckBox, + QApplication, + QWidget, + QVBoxLayout, + QMainWindow, + QTreeWidgetItemIterator, + QLineEdit, + QSlider, + QSpinBox, + QGridLayout, + QRadioButton, + QScrollArea, + QSizePolicy, + QComboBox, + QPushButton, + QScrollBar, + QGroupBox, + QAbstractSlider, + QDoubleSpinBox, + QWidgetAction, + QAction, + QTabWidget, + QAbstractSpinBox, + QToolBar, + QStyleOptionSpinBox, + QStyle, + QDialog, + QSpacerItem, + QFrame, + QMenu, + QActionGroup, + QListWidget, + QPlainTextEdit, + QFileDialog, + QListView, + QAbstractItemView, + QTreeWidget, + QTreeWidgetItem, + QListWidgetItem, + QLayout, + QStylePainter, + QGraphicsBlurEffect, + QGraphicsProxyWidget, + QGraphicsObject, + QButtonGroup, + QStyleOptionSlider, +) +import qtpy.compat + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from ... import utils, measurements, is_mac, is_win, html_utils, is_linux +from ... import printl, settings_folderpath +from ... import colors, config +from ... import html_path +from ... import _palettes +from ... import load +from ... import apps +from ... import plot +from ... import annotate +from ... import urls +from ... import _core, core +from ... import QtScoped +from ... import prompts +from ...acdc_regex import float_regex +from ...config import PREPROCESS_MAPPER +from ... import _base_widgets + +from ...components.palette import ( # noqa: E402 + BASE_COLOR, + Gradients, + GradientsImage, + GradientsLabels, + LINEEDIT_INVALID_ENTRY_STYLESHEET, + LINEEDIT_WARNING_STYLESHEET, + LISTWIDGET_STYLESHEET, + PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR, + PROGRESSBAR_QCOLOR, + TEXT_COLOR, + TREEWIDGET_STYLESHEET, + cmaps, + font, + getCustomGradients, + nonInvertibleCmaps, + sign_int_mapper, + str_to_operator_mapper, +) +from ...components.progress import QtHandler, QLog, XStream # noqa: E402 +from ...components.buttons import * # noqa: E402, F403 +from ...components.layout import * # noqa: E402, F403 +from ...components.inputs_basic import * # noqa: E402, F403 +from ...components.path_controls import * # noqa: E402, F403 + +from ...components.lists import * # noqa: E402, F403 +from ...components.base import QBaseWindow # noqa: E402 +from ...components.progress import ( # noqa: E402 + LoadingCircleAnimation, + NoneWidget, + ProgressBar, + ProgressBarWithETA, + QLogConsole, +) + +class ImShowPlotItem(pg.PlotItem): + def __init__( + self, + parent=None, + name=None, + labels=None, + title=None, + viewBox=None, + axisItems=None, + enableMenu=True, + **kargs, + ): + super().__init__( + parent, name, labels, title, viewBox, axisItems, enableMenu, **kargs + ) + # Overwrite zoom out button behaviour to disable autoRange after + # clicking it. + # If autorange is enabled, it is called everytime the brush or eraser + # scatter plot items touches the border causing flickering + self.disableAutoRange() + self.autoBtn.mode = "manual" + self.invertY(True) + self.setAspectLocked(True) + self.addImageItem(kargs.get("imageItem")) + + self._selected = False + self.selectingRects = [] + + def setSelectableTitle(self, title: QGraphicsProxyWidget, **kwargs): + self.layout.removeItem(self.titleLabel) + self.layout.addItem(title, 0, 1, alignment=Qt.AlignCenter) + + def isSelected(self): + return self._selected + + def setSelected( + self, selected: bool, xlim=(-np.inf, np.inf), ylim=(-np.inf, np.inf) + ): + if selected == self._selected: + return + + if selected: + ((xmin, xmax), (ymin, ymax)) = self.viewRange() + ylim_min, ylim_max = ylim + xlim_min, xlim_max = xlim + + xmin = max(xlim_min, xmin) + xmax = min(xlim_max, xmax) + ymin = max(ylim_min, ymin) + ymax = min(ylim_max, ymax) + + w = xmax - xmin + h = ymax - ymin + + bs = round(((w + h) / 2) * 0.02) + if bs < 1: + bs = 1 + + rect_left = RectItem(QRectF(xmin, ymin, bs, h)) + rect_top = RectItem(QRectF(xmin + bs, ymin, w - bs - bs, bs)) + rect_right = RectItem(QRectF(xmax - bs, ymin, bs, h)) + rect_bottom = RectItem(QRectF(xmin + bs, ymax - bs, w - bs - bs, bs)) + self.selectingRects.append(rect_left) + self.selectingRects.append(rect_top) + self.selectingRects.append(rect_right) + self.selectingRects.append(rect_bottom) + + self.addItem(rect_left) + self.addItem(rect_top) + self.addItem(rect_right) + self.addItem(rect_bottom) + else: + for rect in self.selectingRects: + self.removeItem(rect) + self.selectingRects = [] + + self._selected = selected + + def addImageItem(self, imageItem): + self.imageItem = imageItem + if imageItem is None: + return + + self.setupContextMenu() + self.addItem(imageItem) + + def setupContextMenu(self): + shuffleCmapAction = QAction("Shuffle colormap", self.vb.menu) + shuffleCmapAction.triggered.connect(self.shuffleColormap) + self.vb.menu.addAction(shuffleCmapAction) + + self.resetCmapAction = QAction("Reset colormap", self.vb.menu) + self.resetCmapAction.triggered.connect(self.resetColormap) + self.vb.menu.addAction(self.resetCmapAction) + self.resetCmapAction.setDisabled(True) + + def shuffleColormap(self): + N = self.imageItem._numLevels + colors = self.imageItem.lut / 255 + cmap = LinearSegmentedColormap.from_list("shuffled", colors, N=N) + lut = plot.matplotlib_cmap_to_lut(cmap, n_colors=N) + if not self.resetCmapAction.isEnabled(): + self._defaultLut = lut.copy() + bkgrColor = lut[0].copy() + np.random.shuffle(lut) + lut[0] = bkgrColor + self.imageItem.setLookupTable(lut) + self.imageItem.update() + self.resetCmapAction.setDisabled(False) + + def resetColormap(self): + self.imageItem.setLookupTable(self._defaultLut) + + def autoBtnClicked(self): + self.autoRange() + + def autoRange(self): + self.vb.autoRange() + self.autoBtn.hide() + + +class ImShow(QBaseWindow): + def __init__( + self, + parent=None, + link_scrollbars=True, + infer_rgb=True, + figure_title="", + selectable_images=False, + ): + super().__init__(parent=parent) + self._linkedScrollbars = link_scrollbars + self._infer_rgb = infer_rgb + self._figure_title = figure_title + self._selectable_images = True + self.selected_idx = None + + self._autoLevels = True + + self.textItems = [] + self.group_to_idx_mapper = {"": 0} + + def _getGraphicsScrollbar(self, idx, image, imageItem, maximum): + proxy = QGraphicsProxyWidget(imageItem) + scrollbar = ScrollBarWithNumericControl( + orientation=Qt.Horizontal, add_max_proj_button=True + ) + scrollbar.sigValueChanged.connect(self.OnScrollbarValueChanged) + scrollbar.sigMaxProjToggled.connect(self.onMaxProjToggled) + scrollbar.idx = idx + scrollbar.image = image + scrollbar.imageItem = imageItem + scrollbar.setMaximum(maximum) + proxy.setWidget(scrollbar) + proxy.scrollbar = scrollbar + return proxy + + def OnScrollbarValueChanged(self, value): + scrollbar = self.sender() + imageItem = scrollbar.imageItem + img = self._get2Dimg(imageItem, scrollbar.image) + imageItem.setImage(img) # , autoLevels=self._autoLevels) + + overlayLab = self._get2DlabOverlay(imageItem) + if overlayLab is not None: + imageItem.labImageItem.setImage(overlayLab, autoLevels=False) + + self.setPointsVisible(imageItem) + + self.updateIDs() + + if not self._linkedScrollbars: + return + if len(self.ImageItems) == 1: + return + + self._linkedScrollbars = False + try: + idx = scrollbar.idx + for otherImageItem in self.ImageItems: + if otherImageItem.gridPos == imageItem.gridPos: + continue + if otherImageItem.image.shape != imageItem.image.shape: + continue + for otherScrollbar in otherImageItem.ScrollBars: + if otherScrollbar.idx != idx: + continue + otherScrollbar.setValue(scrollbar.value()) + except Exception as e: + pass + finally: + self._linkedScrollbars = True + + def _get2Dimg(self, imageItem, image): + for scrollbar in imageItem.ScrollBars: + if scrollbar.maxProjCheckbox.isChecked(): + image = image.max(axis=0) + else: + image = image[scrollbar.value()] + return image + + def _get2DlabOverlay(self, imageItem): + try: + lab = imageItem.lab + except Exception as err: + return + + for scrollbar in imageItem.ScrollBars: + if scrollbar.maxProjCheckbox.isChecked(): + lab = lab.max(axis=0) + else: + lab = lab[scrollbar.value()] + + return lab + + def isObjVisible(self, obj, imageItem): + if len(obj.centroid) == 2: + return True + + z_scrollbar = imageItem.ScrollBars[-1] + if z_scrollbar.maxProjCheckbox.isChecked(): + return True + + z_slice = z_scrollbar.value() + min_z, min_y, min_x, max_z, max_y, max_x = obj.bbox + if z_slice >= min_z and z_slice < max_z: + return True + + return False + + def onMaxProjToggled(self, checked, scrollbar): + imageItem = scrollbar.imageItem + img = self._get2Dimg(imageItem, scrollbar.image) + imageItem.setImage(img) # , autoLevels=self._autoLevels) + overlayLab = self._get2DlabOverlay(imageItem) + if overlayLab is not None: + imageItem.labImageItem.setImage(overlayLab, autoLevels=False) + self.setPointsVisible(imageItem) + if not self._linkedScrollbars: + return + if len(self.ImageItems) == 1: + return + + self._linkedScrollbars = False + try: + idx = scrollbar.idx + for otherImageItem in self.ImageItems: + if otherImageItem.gridPos == imageItem.gridPos: + continue + if otherImageItem.image.shape != imageItem.image.shape: + continue + for otherScrollbar in otherImageItem.ScrollBars: + if otherScrollbar.idx != idx: + continue + otherScrollbar.maxProjCheckbox.setChecked(checked) + except Exception as e: + pass + finally: + self._linkedScrollbars = True + + self.updateIDs() + + def setPointsVisible(self, imageItem): + if not hasattr(imageItem, "pointsItems"): + return + + first_coord = imageItem.ScrollBars[0].value() + isMaxProj = imageItem.ScrollBars[0].maxProjCheckbox.isChecked() + for pointsItems in imageItem.pointsItems.values(): + for p, plotItem in enumerate(pointsItems): + plotItem.setVisible((isMaxProj) or (p == first_coord)) + + def setupStatusBar(self): + self.statusbar = self.statusBar() + self.wcLabel = QLabel(f"") + self.statusbar.addPermanentWidget(self.wcLabel) + + def setupMainLayout(self): + self._layout = QHBoxLayout() + self._container = QWidget() + self._container.setLayout(self._layout) + self.setCentralWidget(self._container) + + def setupGraphicLayout( + self, *images, hide_axes=True, max_ncols=4, color_scheme="light" + ): + self.graphicLayout = pg.GraphicsLayoutWidget() + self._colorScheme = color_scheme + + # Set a light background + if color_scheme == "light": + self.graphicLayout.setBackground((235, 235, 235)) + else: + self.graphicLayout.setBackground((30, 30, 30)) + + ncells = max_ncols * ceil(len(images) / max_ncols) + + nrows = ncells // max_ncols + nrows = nrows if nrows > 0 else 1 + ncols = max_ncols if len(images) > max_ncols else len(images) + + if color_scheme == "light": + color = "black" + else: + color = "white" + + self.titleLabel = pg.LabelItem(justify="center", color=color, size="14pt") + self.titleLabel.setText(self._figure_title) + self.graphicLayout.addItem(self.titleLabel, row=0, col=0, colspan=ncols) + start_row = 1 + + # Check if additional rows are needed for the scrollbars + max_ndim = max([image.ndim for image in images]) + if max_ndim > 4: + raise TypeError("One or more of the images have more than 4 dimensions.") + if max_ndim == 4: + rows_range = range(0, (nrows - 1) * 3 + 1, 3) + elif max_ndim == 3: + rows_range = range(0, (nrows - 1) * 2 + 1, 2) + else: + rows_range = range(nrows) + + self.PlotItems = [] + self.ImageItems = [] + self.ScrollBars = [] + i = 0 + for r in rows_range: + row = r + start_row + for col in range(ncols): + try: + image = images[i] + except IndexError: + break + plotItem = ImShowPlotItem() + if hide_axes: + plotItem.hideAxis("bottom") + plotItem.hideAxis("left") + self.graphicLayout.addItem(plotItem, row=row, col=col) + plotItem.loc = (row, col) + self.PlotItems.append(plotItem) + + imageItem = _ImShowImageItem(i) + plotItem.addImageItem(imageItem) + imageItem.plot = plotItem + imageItem.sigHoverEvent.connect(self.onImageItemHoverEvent) + imageItem.sigMousePressEvent.connect(self.onImageItemMousePressEvent) + self.ImageItems.append(imageItem) + imageItem.gridPos = (row, col) + imageItem.ScrollBars = [] + + is_rgb = image.shape[-1] == 3 and self._infer_rgb + is_rgba = image.shape[-1] == 4 and self._infer_rgb + does_not_require_scrollbars = image.ndim == 2 or ( + image.ndim == 3 and (is_rgb or is_rgba) + ) + if does_not_require_scrollbars: + i += 1 + continue + + idx_image = 3 if (is_rgb or is_rgba) else 2 + for s in range(image.ndim - idx_image): + maximum = image.shape[s] - 1 + scrollbarProxy = self._getGraphicsScrollbar( + s, image, imageItem, maximum + ) + self.graphicLayout.addItem(scrollbarProxy, row=row + s + 1, col=col) + imageItem.ScrollBars.append(scrollbarProxy.scrollbar) + + i += 1 + + self._layout.addWidget(self.graphicLayout) + + def onImageItemMousePressEvent(self, imageItem, event): + if not self._selectable_images: + return + + plotItem = imageItem.plot + if not plotItem.isSelected(): + return + + self.selected_idx = self.PlotItems.index(plotItem) + event.ignore() + self.close() + + def onImageItemHoverEvent(self, imageItem, event): + if not self._selectable_images: + return + + modifiers = QGuiApplication.keyboardModifiers() + isCtrl = modifiers == Qt.ControlModifier + plotItem = imageItem.plot + Y, X = imageItem.image.shape[:2] + plotItem.setSelected(isCtrl and not event.isExit(), xlim=(0, X), ylim=(0, Y)) + + def movePlotItem(self, title): + combobox = self.sender() + plotItem = combobox.plotItem + row, col = plotItem.loc + + otherPlotItemIdx = combobox.titles.index(title) + otherPlotItem = self.PlotItems[otherPlotItemIdx] + other_row, other_col = otherPlotItem.loc + + self.graphicLayout.removeItem(plotItem) + self.graphicLayout.removeItem(otherPlotItem) + self.graphicLayout.addItem(otherPlotItem, row=row, col=col) + self.graphicLayout.addItem(plotItem, row=other_row, col=other_col) + + combobox.blockSignals(True) + combobox.setCurrentText(combobox.default_text) + combobox.blockSignals(False) + + plotItemIdx = combobox.titles.index(combobox.default_text) + + otherPlotItem.loc = (row, col) + plotItem.loc = (other_row, other_col) + + def setupTitles(self, *titles): + for plotItem, title in zip(self.PlotItems, titles): + combobox = ComboBox() + combobox.default_text = title + combobox.titles = list(titles) + combobox.addItems(titles) + combobox.setMaximumWidth(combobox.sizeHint().width()) + combobox.setCurrentText(title) + comboboxGraphicsItem = QGraphicsProxyWidget() + comboboxGraphicsItem.setWidget(combobox) + combobox.plotItem = plotItem + plotItem.setSelectableTitle(comboboxGraphicsItem) + combobox.currentTextChanged.connect(self.movePlotItem) + + # color = 'k' if self._colorScheme == 'light' else 'w' + # for plotItem, title in zip(self.PlotItems, titles): + # plotItem.setSelectableTitle(title, color=color) + + def updateStatusBarLabel(self, text): + self.wcLabel.setText(text) + + def autoRange(self): + for plot in self.PlotItems: + plot.autoRange() + + def showImages( + self, + *images, + labels_overlays: np.ndarray | List[np.ndarray] = None, + luts=None, + labels_overlays_luts=None, + autoLevels=True, + autoLevelsOnScroll=False, + ): + from .plot import matplotlib_cmap_to_lut + + images = [np.squeeze(img) for img in images] + self.luts = luts + self._autoLevels = autoLevels + self._autoLevelsOnScroll = autoLevelsOnScroll + for image in images: + if image.ndim > 5 or image.ndim < 2: + raise TypeError( + f"Input image has {image.ndim} dimensions. " + "Only 2-D, 3-D, and 4-D images are supported" + ) + + if isinstance(labels_overlays, np.ndarray): + labels_overlays = [labels_overlays] + + if isinstance(labels_overlays_luts, np.ndarray): + labels_overlays_luts = [labels_overlays_luts] + + if ( + labels_overlays_luts is not None + and labels_overlays is not None + and (len(labels_overlays_luts) != len(labels_overlays)) + ): + raise TypeError( + f"Number of lables_overlays_luts is {len(labels_overlays_luts)}, " + f"while number of labels_overaly is {len(labels_overlays)}. " + "Pass `None` if you want to use default lut for the labels_overlays." + ) + + if labels_overlays is not None and (len(labels_overlays) != len(images)): + raise TypeError( + f"Number of images is {len(images)}, " + f"while number of labels_overaly is {len(labels_overlays)}. " + "Pass `None` if you do not need overlaid labeles." + ) + + for i, (image, imageItem) in enumerate(zip(images, self.ImageItems)): + if luts is not None: + _autoLevels = autoLevels + lut = luts[i] + if not autoLevels and lut is not None: + imageItem.setLevels([0, len(lut)]) + else: + _autoLevels = True + if lut is None: + lut = matplotlib_cmap_to_lut("viridis") + imageItem.setLookupTable(lut) + else: + _autoLevels = True + + is_rgb = image.shape[-1] == 3 and self._infer_rgb + is_rgba = image.shape[-1] == 4 and self._infer_rgb + does_not_require_scrollbars = image.ndim == 2 or ( + image.ndim == 3 and (is_rgb or is_rgba) + ) + + if does_not_require_scrollbars: + imageItem.setAutoLevels(_autoLevels) + imageItem.setImage(image) + else: + if not self._autoLevelsOnScroll and not _autoLevels: + imageItem.setAutoLevels(False) + imageItem.setLevels([image.min(), image.max()]) + for scrollbar in imageItem.ScrollBars: + scrollbar.setValue(int(scrollbar.maximum() / 2)) + + imageItem.sigDataHover.connect(self.updateStatusBarLabel) + + if labels_overlays is None: + continue + + lab_overlay = labels_overlays[i] + if lab_overlay is None: + continue + + if lab_overlay.shape != image.shape: + raise TypeError( + f"`lab_overlay` at index {i} has shape " + f"{lab_overlay.shape} which is different " + f"from image shape {image.shape}. " + "The image and the `lab_overlay` must " + "have the same shape." + ) + + plot = imageItem.plot + labImageItem = pg.ImageItem() + labImageItem.setOpacity(0.4) + plot.addImageItem(labImageItem) + + if labels_overlays_luts is not None: + labels_overlays_lut = labels_overlays_luts[i] + else: + labels_overlays_lut = self._getDefaultLabelsOverlayLut(lab_overlay) + + labImageItem.setLookupTable(labels_overlays_lut) + labImageItem.setLevels([0, len(labels_overlays_lut)]) + + imageItem.lab = lab_overlay + imageItem.labImageItem = labImageItem + + overlayLab = self._get2DlabOverlay(imageItem) + labImageItem.setImage(overlayLab, autoLevels=False) + + # Share axis between images with same X, Y shape + all_shapes = [image.shape[-2:] for image in images] + unique_shapes = set(all_shapes) + shame_shape_plots = [] + for unique_shape in unique_shapes: + plots = [ + self.PlotItems[i] + for i, shape in enumerate(all_shapes) + if shape == unique_shape + ] + shame_shape_plots.append(plots) + + for plots in shame_shape_plots: + for plot in plots: + plot.vb.setYLink(plots[0].vb) + plot.vb.setXLink(plots[0].vb) + + def _getDefaultLabelsOverlayLut(self, lab_overlay): + IDs = [obj.label for obj in skimage.measure.regionprops(lab_overlay)] + n_objs = len(IDs) + lut = np.zeros((n_objs + 1, 4), dtype=np.uint8) + rgbas = colors.plt_colormap_to_pg_lut("tab20", ncolors=n_objs) + np.random.shuffle(rgbas) + lut[1:] = rgbas + return lut + + def _createPointsScatterItem(self, xx, yy, group, colors=None, data=None): + if colors is None: + cmap = matplotlib.colormaps["jet_r"] + idx = self.group_to_idx_mapper[group] + r, g, b = [round(c * 255) for c in cmap(idx)][:3] + brush = pg.mkBrush(color=(r, g, b, 100)) + pen = pg.mkPen(width=2, color=(r, g, b)) + hoverBrush = pg.mkBrush((r, g, b, 200)) + else: + brush = [] + pen = [] + hoverBrush = None + for color in colors: + rgb = matplotlib.colors.to_rgb(color) + rgb = [round(c * 255) for c in rgb] + _brush = pg.mkBrush(color=(*rgb, 100)) + _pen = pg.mkPen(width=2, color=rgb) + brush.append(_brush) + pen.append(_pen) + + item = pg.ScatterPlotItem( + xx, + yy, + symbol="o", + pxMode=False, + size=3, + brush=brush, + pen=pen, + hoverable=True, + hoverBrush=hoverBrush, + data=data, + ) + return item + + def drawPointsFromDf( + self, points_df: pd.DataFrame | List[pd.DataFrame], points_groups=None + ): + if not isinstance(points_df, (list, tuple)): + points_df = [points_df] * len(self.PlotItems) + + for p, df in enumerate(points_df): + if isinstance(points_groups, str): + points_groups = [points_groups] + + if points_groups is None: + grouped = [("", df)] + groups = [""] + else: + grouped = df.groupby(points_groups) + groups = grouped.groups.keys() + + idxs_space = np.linspace(0, 1, len(groups)) + self.group_to_idx_mapper = dict(zip(groups, idxs_space)) + + for group, df in grouped: + yy = df["y"].values + xx = df["x"].values + points_coords = np.column_stack((yy, xx)) + if "z" in df.columns: + zz = df["z"].values + points_coords = np.column_stack((zz, points_coords)) + if len(group) == 1: + group = group[0] + + colors = None + if "color" in df.columns: + colors = df["color"].values + + data = None + if "data" in df.columns: + data = df["data"].values + + self.drawPoints( + points_coords, colors=colors, group=group, idx=p, data=data + ) + + def drawPoints( + self, + points_coords: np.ndarray, + group="", + idx=None, + colors=None, + data=None, + ): + offset = 0.5 if np.issubdtype(points_coords.dtype, np.integer) else 0 + n_dim = points_coords.shape[1] + + if idx is not None: + PlotItems = [self.PlotItems[idx]] + ImageItems = [self.ImageItems[idx]] + else: + PlotItems = self.PlotItems + ImageItems = self.ImageItems + + if n_dim == 2: + if data is None: + data = group + + zz = [0] * len(points_coords) + self.points_coords = np.column_stack((zz, points_coords)) + for p, plotItem in enumerate(PlotItems): + imageItem = ImageItems[p] + xx = points_coords[:, 1] + offset + yy = points_coords[:, 0] + offset + pointsItem = self._createPointsScatterItem( + xx, yy, group, data=data, colors=colors + ) + pointsItem.z = 0 + plotItem.addItem(pointsItem) + imageItem.pointsItems = {group: [pointsItem]} + elif n_dim == 3: + self.points_coords = points_coords + for p, plotItem in enumerate(PlotItems): + imageItem = ImageItems[p] + imageItem.pointsItems = defaultdict(list) + scrollbar = imageItem.ScrollBars[0] + for first_coord in range(scrollbar.maximum() + 1): + coords_idx = np.nonzero(points_coords[:, 0] == first_coord) + coords = points_coords[coords_idx] + if colors is None: + _colors = None + else: + _colors = np.asarray(colors)[coords_idx] + if len(_colors) == 0: + _colors = None + + _data = group + if data is not None: + _data = data[coords_idx] + if len(_data) == 0: + _data = group + + xx = coords[:, 2] + offset + yy = coords[:, 1] + offset + pointsItem = self._createPointsScatterItem( + xx, yy, group, data=_data, colors=_colors + ) + pointsItem.z = first_coord + plotItem.addItem(pointsItem) + pointsItem.setVisible(False) + imageItem.pointsItems[group].append(pointsItem) + self.setPointsVisible(imageItem) + + def setupDuplicatedCursors(self): + self.cursors = [] + for p, plotItem in enumerate(self.PlotItems): + cursor = pg.ScatterPlotItem( + symbol="+", + pxMode=True, + pen=pg.mkPen("k", width=1), + brush=pg.mkBrush("w"), + size=16, + tip=None, + ) + self.cursors.append(cursor) + plotItem.addItem(cursor) + + for imageItem in self.ImageItems: + imageItem.setOtherImagesCursors(self.cursors) + + def setPointsData(self, points_data): + points_df = pd.DataFrame( + { + "z": self.points_coords[:, 0], + "y": self.points_coords[:, 1], + "x": self.points_coords[:, 2], + } + ) + if isinstance(points_data, pd.Series): + points_df[points_data.name] = points_data.values + elif isinstance(points_data, pd.DataFrame): + points_df = points_df.join(points_data) + elif isinstance(points_data, np.ndarray): + if points_data.ndim == 1: + points_data = points_data[np.newaxis] + else: + points_data = points_data.T + for i, values in enumerate(points_data): + points_df[f"col_{i}"] = values + + self.points_df = points_df.set_index(["z", "y", "x"]).sort_index() + + for p, plotItem in enumerate(self.PlotItems): + imageItem = self.ImageItems[p] + for pointsItems in imageItem.pointsItems.values(): + for pointsItem in pointsItems: + pointsItem.sigClicked.connect(self.pointsClicked) + + def pointsClicked(self, item, points, event): + point = points[0] + x, y = point.pos() + coords = (item.z, int(y), int(x)) + point_data = self.points_df.loc[[coords]] + now = datetime.datetime.now().strftime("%H:%M:%S") + print("*" * 60) + print(f"Point clicked at {now}. Data:") + print("-" * 60) + print(point_data) + print("") + print("*" * 60) + + def annotateObjectIDs(self, annotate_labels_idxs=None, init=False): + if init: + self.annotate_labels_idxs = annotate_labels_idxs + self.textItems = [{} for _ in self.PlotItems] + if self.annotate_labels_idxs is None: + return + for i, plotItem in enumerate(self.PlotItems): + if i not in self.annotate_labels_idxs: + continue + plotTextItems = self.textItems[i] + imageItem = self.ImageItems[i] + try: + if init: + # 3D labels (if 3D) + lab = imageItem.lab + else: + lab = imageItem.labImageItem.image + except Exception as err: + lab = imageItem.image + + rp = skimage.measure.regionprops(lab) + for obj in rp: + textItem = plotTextItems.get(obj.label) + yc, xc = obj.centroid[-2:] + if textItem is None: + textItem = pg.TextItem(text="", anchor=(0.5, 0.5), color="r") + plotItem.addItem(textItem) + plotTextItems[obj.label] = textItem + + if self.isObjVisible(obj, imageItem): + text = str(obj.label) + else: + text = "" + + textItem.setText(text) + textItem.setPos(xc, yc) + + # plotItem.enableAutoRange() + + def clearLabels(self): + for textItems in self.textItems: + for textItem in textItems.values(): + textItem.setText("") + + def updateIDs(self): + self.clearLabels() + try: + self.annotateObjectIDs(annotate_labels_idxs=self.annotate_labels_idxs) + except Exception as err: + pass + + def show(self, block=False, screenToWindowRatio=None): + super().show(block=block) + if screenToWindowRatio is None: + return + screenGeometry = self.screen().geometry() + screenWidth = screenGeometry.width() + screenHeight = screenGeometry.height() + finalWidth = int(screenToWindowRatio * screenWidth) + finalHeight = int(screenToWindowRatio * screenHeight) + screenTop = screenGeometry.top() + screenLeft = screenGeometry.left() + xc, yc = screenLeft + screenWidth / 2, screenTop + screenHeight / 2 + winLeft = int(xc - finalWidth / 2) + winTop = int(yc - finalHeight / 2) + self.setGeometry(winLeft, winTop, finalWidth, finalHeight) + + def run(self, block=False, showMaximised=False, screenToWindowRatio=None): + if showMaximised: + self.showMaximized() + else: + self.show(screenToWindowRatio=screenToWindowRatio) + QTimer.singleShot(100, self.autoRange) + + if block: + self.exec_() + + def resizeEvent(self, event) -> None: + self.PlotItems[0].autoRange() + return super().resizeEvent(event) + +# Cross-module imports (deferred to avoid import cycles) +from .images import ( + _ImShowImageItem, + labImageItem, +) +from .plot_items import ( + RectItem, +) +from .scrollbars import ( + ScrollBarWithNumericControl, +) +from ..controls.inputs import ( + ComboBox, +) + diff --git a/cellacdc/widgets/canvas/plot_items.py b/cellacdc/widgets/canvas/plot_items.py new file mode 100644 index 000000000..29c8b2bf8 --- /dev/null +++ b/cellacdc/widgets/canvas/plot_items.py @@ -0,0 +1,1170 @@ +"""Canvas widgets: plot_items.""" + +"""GUI widgets: canvas.""" + +from collections import defaultdict, deque +from typing import Dict, List, Union, Iterable, Sequence +import os +import sys +import operator +import time +import re +import datetime +import numpy as np +import pandas as pd +import math +import traceback +import logging +import textwrap +import random + +from functools import partial +from math import ceil + +import skimage.draw +import skimage.morphology + +from matplotlib.colors import ListedColormap, LinearSegmentedColormap +import matplotlib.pyplot as plt +import matplotlib +from matplotlib.backends.backend_agg import FigureCanvasAgg + +from qtpy.QtCore import ( + Signal, + QTimer, + Qt, + QPoint, + QUrl, + Property, + QPropertyAnimation, + QEasingCurve, + QLocale, + QSize, + QRect, + QPointF, + QRect, + QPoint, + QEasingCurve, + QRegularExpression, + QEvent, + QEventLoop, + QPropertyAnimation, + QObject, + QItemSelectionModel, + QAbstractListModel, + QModelIndex, + QByteArray, + QDataStream, + QMimeData, + QAbstractItemModel, + QIODevice, + QItemSelection, + PYQT6, + QRectF, +) +from qtpy.QtGui import ( + QFont, + QPalette, + QColor, + QPen, + QKeyEvent, + QBrush, + QPainter, + QRegularExpressionValidator, + QIcon, + QPixmap, + QKeySequence, + QLinearGradient, + QShowEvent, + QDesktopServices, + QFontMetrics, + QGuiApplication, + QLinearGradient, + QImage, + QCursor, + QPicture, +) +from qtpy.QtWidgets import ( + QTextEdit, + QLabel, + QProgressBar, + QHBoxLayout, + QToolButton, + QCheckBox, + QApplication, + QWidget, + QVBoxLayout, + QMainWindow, + QTreeWidgetItemIterator, + QLineEdit, + QSlider, + QSpinBox, + QGridLayout, + QRadioButton, + QScrollArea, + QSizePolicy, + QComboBox, + QPushButton, + QScrollBar, + QGroupBox, + QAbstractSlider, + QDoubleSpinBox, + QWidgetAction, + QAction, + QTabWidget, + QAbstractSpinBox, + QToolBar, + QStyleOptionSpinBox, + QStyle, + QDialog, + QSpacerItem, + QFrame, + QMenu, + QActionGroup, + QListWidget, + QPlainTextEdit, + QFileDialog, + QListView, + QAbstractItemView, + QTreeWidget, + QTreeWidgetItem, + QListWidgetItem, + QLayout, + QStylePainter, + QGraphicsBlurEffect, + QGraphicsProxyWidget, + QGraphicsObject, + QButtonGroup, + QStyleOptionSlider, +) +import qtpy.compat + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from ... import utils, measurements, is_mac, is_win, html_utils, is_linux +from ... import printl, settings_folderpath +from ... import colors, config +from ... import html_path +from ... import _palettes +from ... import load +from ... import apps +from ... import plot +from ... import annotate +from ... import urls +from ... import _core, core +from ... import QtScoped +from ... import prompts +from ...acdc_regex import float_regex +from ...config import PREPROCESS_MAPPER +from ... import _base_widgets + +from ...components.palette import ( # noqa: E402 + BASE_COLOR, + Gradients, + GradientsImage, + GradientsLabels, + LINEEDIT_INVALID_ENTRY_STYLESHEET, + LINEEDIT_WARNING_STYLESHEET, + LISTWIDGET_STYLESHEET, + PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR, + PROGRESSBAR_QCOLOR, + TEXT_COLOR, + TREEWIDGET_STYLESHEET, + cmaps, + font, + getCustomGradients, + nonInvertibleCmaps, + sign_int_mapper, + str_to_operator_mapper, +) +from ...components.progress import QtHandler, QLog, XStream # noqa: E402 +from ...components.buttons import * # noqa: E402, F403 +from ...components.layout import * # noqa: E402, F403 +from ...components.inputs_basic import * # noqa: E402, F403 +from ...components.path_controls import * # noqa: E402, F403 + +from ...components.lists import * # noqa: E402, F403 +from ...components.base import QBaseWindow # noqa: E402 +from ...components.progress import ( # noqa: E402 + LoadingCircleAnimation, + NoneWidget, + ProgressBar, + ProgressBarWithETA, + QLogConsole, +) + +class ContourItem(pg.PlotCurveItem): + def __init__(self, *args, **kargs): + super().__init__(*args, **kargs) + self._prevData = None + + def clear(self): + try: + self.setData([], []) + except AttributeError as e: + pass + + def tempClear(self): + try: + self._prevData = [d.copy() for d in self.getData()] + self.clear() + except Exception as e: + pass + + def restore(self): + if self._prevData is not None: + if self._prevData[0] is not None: + self.setData(*self._prevData) + + +class BaseScatterPlotItem(pg.ScatterPlotItem): + def __init__(self, *args, **kargs): + super().__init__(*args, **kargs) + + def tempClear(self): + try: + self._prevData = [d.copy() for d in self.getData()] + self.setData([], []) + except Exception as e: + pass + + def restore(self): + if self._prevData is not None: + if self._prevData[0] is not None: + self.setData(*self._prevData) + + +class CustomAnnotationScatterPlotItem(BaseScatterPlotItem): + def __init__(self, *args, **kargs): + super().__init__(*args, **kargs) + + +class ScatterPlotItem(pg.ScatterPlotItem): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.updateBrushAndPen(**kwargs) + + def updateBrushAndPen(self, **kwargs): + brush = kwargs.get("brush") + if brush is not None: + self._itemBrush = brush + pen = kwargs.get("pen") + if pen is not None: + self._itemPen = pen + + def setData(self, *args, **kwargs): + super().setData(*args, **kwargs) + self.updateBrushAndPen(**kwargs) + + def itemBrush(self): + return self._itemBrush + + def itemPen(self): + return self._itemPen + + def removePoint(self, index): + newData = np.delete(self.data, index) + # Update the index of current points + for i in range(index, len(newData)): + spotItem = newData[i]["item"] + spotItem._index = i + newData[i]["item"] = spotItem + + self.data = newData + self.prepareGeometryChange() + self.informViewBoundsChanged() + self.bounds = [None, None] + self.invalidate() + self.updateSpots(newData) + self.sigPlotChanged.emit(self) + + def coordsToNumpy(self, includeData=False, rounded=True, decimals=None): + points = self.points() + nrows = len(points) + coords_arr = np.zeros((nrows, 2)) + data_arr = None + for p, point in enumerate(points): + pos = point.pos() + x, y = pos.x(), pos.y() + if includeData: + data = point.data() + if data_arr is None: + try: + ncols = len(data) + except Exception as e: + data = [data] + ncols = 1 + data_arr = np.zeros((nrows, ncols)) + for j, data_j in enumerate(data): + data_arr[p, j] = data_j + + coords_arr[p, 0] = y + coords_arr[p, 1] = x + if not includeData: + out_arr = coords_arr + elif data_arr is not None: + out_arr = np.column_stack((data_arr, coords_arr)) + else: + out_arr = coords_arr + cast_to_int = decimals is None + decimals = decimals if decimals is not None else 0 + if rounded: + out_arr = np.round(out_arr, decimals) + if cast_to_int: + out_arr = out_arr.astype(int) + return out_arr + + +class myLabelItem(pg.LabelItem): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._prevText = "" + + def setText(self, text, **args): + self.text = text + opts = self.opts + for k in args: + opts[k] = args[k] + + if "size" in self.opts: + size = self.opts["size"] + if size == "0pt" or size == "0px": + self.opts["size"] = "1pt" + super().setText("", size="1pt") + return + + optlist = [] + + color = self.opts["color"] + if color is None: + color = pg.getConfigOption("foreground") + color = pg.functions.mkColor(color) + optlist.append("color: " + color.name(QColor.NameFormat.HexArgb)) + if "size" in opts: + size = opts["size"] + if not isinstance(size, str): + size = f"{size}px" + optlist.append("font-size: " + size) + if "bold" in opts and opts["bold"] in [True, False]: + optlist.append( + "font-weight: " + {True: "bold", False: "normal"}[opts["bold"]] + ) + if "italic" in opts and opts["italic"] in [True, False]: + optlist.append( + "font-style: " + {True: "italic", False: "normal"}[opts["italic"]] + ) + full = "%s" % ("; ".join(optlist), text) + # print full + self.item.setHtml(full) + self.updateMin() + self.resizeEvent(None) + self.updateGeometry() + + def tempClearText(self): + if self.text: + self._prevText = self.text + self.setText("") + + def restoreText(self): + if self._prevText: + self.setText(self._prevText) + + +class LabelRoiCircularItem(pg.ScatterPlotItem): + def __init__(self, *args, **kargs): + super().__init__(*args, **kargs) + + def setImageShape(self, shape): + self._shape = shape + + def slice(self, zRange=None, tRange=None): + self.mask() + if zRange is None: + _slice = self._slice + else: + zmin, zmax = zRange + _slice = (slice(zmin, zmax), *self._slice) + + if tRange is not None: + tmin, tmax = tRange + _slice = (slice(tmin, tmax), *_slice) + + return _slice + + def mask(self): + shape = self._shape + radius = int(self.opts["size"] / 2) + mask = skimage.morphology.disk(radius, dtype=bool) + xx, yy = self.getData() + Yc, Xc = yy[0], xx[0] + mask, self._slice = utils.clipSelemMask(mask, shape, Yc, Xc, copy=False) + return mask + + +class PlotCurveItem(pg.PlotCurveItem): + def __init__(self, *args, **kargs): + super().__init__(*args, **kargs) + + def addPoint(self, x, y, **kargs): + _xx, _yy = self.getData() + if _xx is None or len(_xx) == 0: + self.xData = np.array([x], dtype=int) + self.yData = np.array([y], dtype=int) + return + if _xx[-1] == x and _yy[-1] == y: + # Do not append same point + return + + # Pre-allocate array and insert data (faster than append) + xx = np.zeros(len(_xx) + 1, dtype=_xx.dtype) + xx[:-1] = _xx + xx[-1] = x + yy = np.zeros(len(_yy) + 1, dtype=_xx.dtype) + yy[:-1] = _yy + yy[-1] = y + self.setData(xx, yy, **kargs) + + def clear(self): + try: + self.setData([], []) + except Exception as e: + pass + super().clear() + + def closeCurve(self): + _xx, _yy = self.getData() + self.addPoint(_xx[0], _yy[0]) + + def mask(self): + ymin, xmin, ymax, xmax = self.bbox() + _mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=bool) + local_xx, local_yy = self.getLocalData() + rr, cc = skimage.draw.polygon(local_yy, local_xx) + _mask[rr, cc] = True + return _mask + + def getLocalData(self): + _xx, _yy = self.getData() + return _xx - _xx.min(), _yy - _yy.min() + + def slice(self, zRange=None, tRange=None): + ymin, xmin, ymax, xmax = self.bbox() + if zRange is not None: + zmin, zmax = zRange + _slice = (slice(zmin, zmax), slice(ymin, ymax + 1), slice(xmin, xmax + 1)) + else: + _slice = (slice(ymin, ymax + 1), slice(xmin, xmax + 1)) + if tRange is not None: + tmin, tmax = tRange + _slice = (slice(tmin, tmax), *_slice) + return _slice + + def bbox(self): + _xx, _yy = self.getData() + return _yy.min(), _xx.min(), _yy.max(), _xx.max() + + +class MainPlotItem(pg.PlotItem): + def __init__( + self, + parent=None, + name=None, + labels=None, + title=None, + viewBox=None, + axisItems=None, + enableMenu=True, + showWelcomeText=False, + **kargs, + ): + super().__init__( + parent, name, labels, title, viewBox, axisItems, enableMenu, **kargs + ) + # Overwrite zoom out button behaviour to disable autoRange after + # clicking it. + # If autorange is enabled, it is called everytime the brush or eraser + # scatter plot items touches the border causing flickering + self.disableAutoRange() + self.autoBtn.mode = "manual" + if showWelcomeText: + self.infoTextItem = pg.TextItem() + self.addItem(self.infoTextItem) + html_filepath = os.path.join(html_path, "gui_welcome.html") + with open(html_filepath) as html_file: + htmlText = html_file.read() + self.infoTextItem.setHtml(htmlText) + self.infoTextItem.setPos(0, 0) + + self.delRoiItems = {} + self.highlightingRectItems = None + self._baseImageItem = None + self._imageItems = [] + self.highlightingRectItemsColor = None + + def addHighlightingRectItems(self, color=None): + self.highlightingRectItems = { + "left": RectItem(QRectF()), + "right": RectItem(QRectF()), + "top": RectItem(QRectF()), + "bottom": RectItem(QRectF()), + } + for rect in self.highlightingRectItems.values(): + self.addItem(rect) + + if color is None: + return + + self.setHighlightingRectItemsColor(color) + + def setHighlightingRectItemsColor(self, color): + if color == self.highlightingRectItemsColor: + return + + for item in self.highlightingRectItems.values(): + item.setColor(color) + + self.highlightingRectItemsColor = color + + def addBaseImageItem(self, baseImageItem): + self._baseImageItem = baseImageItem + self._imageItems.append(baseImageItem) + self.addItem(baseImageItem) + + def addImageItem(self, imageItem): + self._imageItems.append(imageItem) + self.addItem(imageItem) + + def setHighlighted(self, highlighted, color=None): + if color is None: + color = self.highlightingRectItemsColor + + if color is None: + color = "green" + + if self.highlightingRectItems is None: + self.addHighlightingRectItems(color=color) + + if not highlighted: + for rect in self.highlightingRectItems.values(): + rect.setQRect(QRectF()) + return + + self.setHighlightingRectItemsColor(color) + + ((xmin, xmax), (ymin, ymax)) = self.viewRange() + xmin = xmin if xmin >= 0 else 0 + ymin = ymin if ymin >= 0 else 0 + if self._baseImageItem is not None: + Y, X = self._baseImageItem.image.shape[:2] + xmax = min(xmax, X) + ymax = min(ymax, Y) + + w = xmax - xmin + h = ymax - ymin + + bs = round(((w + h) / 2) * 0.02) + if bs < 1: + bs = 1 + + x0 = xmin + x1 = xmin + bs + x2 = xmax - bs + x3 = xmax + + y0 = ymin + y1 = ymin + bs + y2 = ymax - bs + y3 = ymax + + self.highlightingRectItems["left"].setRect(x0, y0, bs, y3 - y0) + self.highlightingRectItems["top"].setRect(x1, y0, x3 - x1, bs) + self.highlightingRectItems["right"].setRect(x2, y1, bs, y3 - y1) + self.highlightingRectItems["bottom"].setRect(x1, y2, x2 - x1, bs) + self.update() + + def clear(self): + super().clear() + + self.delRoiItems = {} + self.highlightingRectItems = None + self._baseImageItem = None + self._imageItems = [] + self.highlightingRectItemsColor = None + + try: + self.removeItem(self.infoTextItem) + except Exception as e: + pass + + def autoBtnClicked(self): + self.vb.autoRange() + self.autoBtn.hide() + + def addDelRoiItem(self, roiItem, key): + if self.isDelRoiItemPresent(roiItem): + return + + self.delRoiItems[key] = roiItem + roiItem.key = key + self.addItem(roiItem) + + def removeDelRoiItem(self, roiItem): + key = roiItem.key + self.delRoiItems.pop(key, None) + try: + self.removeItem(roiItem) + except Exception as err: + return + + def isDelRoiItemPresent(self, roiItem): + try: + key = roiItem.key + except AttributeError as e: + return False + + try: + roi = self.delRoiItems[key] + except Exception as err: + return False + + return True + + def viewRange(self, mask_img=None): + if mask_img is None: + return super().viewRange() + + mask_rp = skimage.measure.regionprops(skimage.measure.label(mask_img)) + if not mask_rp: + return super().viewRange() + + mask_obj = mask_rp[0] + ymin, xmin, ymax, xmax = mask_obj.bbox + return (xmin, xmax), (ymin, ymax) + + +class GhostContourItem(pg.PlotDataItem): + def __init__( + self, ParentPlotItem, penColor=(245, 184, 0, 100), textColor=(245, 184, 0) + ): + super().__init__() + # Yellow pen + self.setPen(pg.mkPen(width=2, color=penColor)) + self.label = myLabelItem() + self.label.setAttr("bold", True) + self.label.setAttr("color", textColor) + self._ParentPlotItem = ParentPlotItem + + def addToPlotItem(self): + self._ParentPlotItem.addItem(self) + self._ParentPlotItem.addItem(self.label) + + def removeFromPlotItem(self): + self._ParentPlotItem.removeItem(self.label) + self._ParentPlotItem.removeItem(self) + + def setData( + self, xx=None, yy=None, fontSize=11, ID=0, y_cursor=None, x_cursor=None + ): + if xx is None: + xx = [] + if yy is None: + yy = [] + super().setData(xx, yy) + if not hasattr(self, "label"): + return + + if ID == 0: + self.label.setText("") + else: + self.label.setText(f"{ID}", size=fontSize) + w, h = self.label.itemRect().width(), self.label.itemRect().height() + self.label.setPos(x_cursor, y_cursor - h) + + def clear(self): + self.setData([], []) + + +class LabelItem(pg.LabelItem): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def bbox(self): + xl, yl = self.pos().x(), self.pos().y() + wl, hl = self.itemRect().width(), self.itemRect().height() + return yl, xl, yl + hl, xl + wl + + def setBold(self, bold=True): + self.origPos = self.pos() + self.setText(self.text, bold=bold) + self.setPos(self.origPos) + + +class ScaleBar(QGraphicsObject): + sigEditProperties = Signal(object) + sigRemove = Signal(object) + + def __init__(self, imageShape, viewRange, parent=None): + super().__init__(parent) + self.SizeY, self.SizeX = imageShape + self.updateViewRange(viewRange) + self.plotItem = PlotCurveItem() + self.labelItem = LabelItem() + self._x_pad = 5 + self._y_pad = 3 + self._highlighted = False + self._parent = parent + self.clicked = False + self.createContextMenu() + + def updateViewRange(self, viewRange): + xRange, yRange = viewRange + x0, x1 = xRange + y0, y1 = yRange + if x0 < 0: + x0 = 0 + + if x1 > self.SizeX: + x1 = self.SizeX + + if y0 < 0: + y0 = 0 + + if y1 > self.SizeY: + y1 = self.SizeY + + self.xmax = x1 + self.xmin = x0 + + self.ymax = y1 + self.ymin = y0 + + def createContextMenu(self): + self.contextMenu = QMenu() + action = QAction("Edit properties...", self.contextMenu) + action.triggered.connect(self.emitEditProperties) + self.contextMenu.addSeparator() + action = QAction("Remove", self.contextMenu) + action.triggered.connect(self.emitRemove) + self.contextMenu.addAction(action) + + def emitEditProperties(self): + self.setHighlighted(False) + self.sigEditProperties.emit(self.properties()) + + def emitRemove(self): + self.sigRemove.emit(self) + + def isHighlighted(self): + return self._highlighted + + def setHighlighted(self, highlighted): + if self._highlighted and highlighted: + return + + if not self._highlighted and not highlighted: + return + + pen = self.highlightPen if highlighted else self.pen + self.labelItem.setBold(bold=highlighted) + self.plotItem.setPen(pen) + + self._highlighted = highlighted + + def showContextMenu(self, x, y): + self.contextMenu.popup(QPoint(int(x), int(y))) + + def properties(self): + properties = { + "thickness": self._thickness, + "length_pixel": self._length, + "length_unit": self._length_unit, + "is_text_visible": self._is_text_visible, + "color": self._color, + "loc": self._loc, + "font_size": float(self._font_size[:-2]), + "unit": self._unit, + "num_decimals": self._num_decimals, + "move_with_zoom": self._move_with_zoom, + } + return properties + + def move(self, xm, ym): + self._loc = "Custom" + + Dy = ym - self.yc + Dx = xm - self.xc + + x0 = self.x0c + Dx + x1 = x0 + self._length + y0 = y1 = self.y0c + Dy + self.plotItem.setData([x0, x1], [y0, y1]) + self.setTextPos() + + def paint(self, painter, option, widget): + pass + + def boundingRect(self): + ymin, xmin, ymax, xmax = self.bbox() + return QRectF(xmin, ymin, xmax - xmin, ymax - ymin) + + def setLocationProperty(self, loc: str): + self._loc = loc + + def setMoveWithZoomProperty(self, move_with_zoom): + self._move_with_zoom = move_with_zoom + + def setProperties( + self, + length_pixel, + length_unit, + thickness=3, + color="w", + is_text_visible=True, + loc="top-left", + font_size=12, + unit="", + num_decimals=0, + move_with_zoom=False, + ): + self._loc = loc + self._color = color + self._length = length_pixel + self._length_unit = length_unit + self._is_text_visible = is_text_visible + self._font_size = f"{font_size}px" + self._unit = unit + self._num_decimals = num_decimals + self._move_with_zoom = move_with_zoom + self._thickness = thickness + self.pen = pg.mkPen(width=thickness, color=color, cosmetic=False) + self.highlightPen = pg.mkPen(width=thickness + 2, color=color, cosmetic=False) + self.pen.setCapStyle(Qt.PenCapStyle.FlatCap) + self.highlightPen.setCapStyle(Qt.PenCapStyle.FlatCap) + self.plotItem.setPen(self.pen) + + def updatePhysicalLength(self, PhysicalSizeX): + length_unit = self._length_unit + unit = self._unit + length_um = _core.convert_length(length_unit, unit, "μm") + length_pixel = length_um / PhysicalSizeX + self._length = length_pixel + self.update() + + def addToAxis(self, ax): + ax.addItem(self.plotItem) + ax.addItem(self.labelItem) + + def setText(self): + if self._is_text_visible: + number = round(self._length_unit, self._num_decimals) + if self._num_decimals == 0: + number = int(number) + text = f"{number} {self._unit}" + else: + text = "" + self.labelItem.setText(text, color=self._color, size=self._font_size) + + def setTextPos(self): + xx, yy = self.plotItem.getData() + x0 = xx[0] + y0 = yy[0] + xc = x0 + self._length / 2 + wl = self.labelItem.itemRect().width() + hl = self.labelItem.itemRect().height() + xl = xc - wl / 2 + yt = y0 - hl + self.labelItem.setPos(xl, yt) + + def updatePosViewRangeChanged(self, viewRange): + if self._loc == "custom": + xx, yy = self.plotItem.getData() + x0p = xx[0] + y0p = yy[0] + xcp = x0p + self._length / 2 + hl = self.labelItem.itemRect().height() + ycp = y0p - hl / 2 + x0 = self.xmin + y0 = self.ymin + x_range = self.xmax - x0 + y_range = self.ymax - y0 + Dx_perc = (xcp - x0) / x_range + Dy_perc = (ycp - y0) / y_range + + self.updateViewRange(viewRange) + + X0 = self.xmin + Y0 = self.ymin + + X_range = self.xmax - X0 + Y_range = self.ymax - Y0 + + Xcp = X0 + (Dx_perc * X_range) + Ycp = Y0 + (Dy_perc * Y_range) + X0p = Xcp - (self._length / 2) + Y0p = Ycp + (hl / 2) + + X1p = X0p + self._length + Y1p = Y0p + + self.plotItem.setData([X0p, X1p], [Y0p, Y1p]) + else: + self.updateViewRange(viewRange) + self.update() + + def getStartXCoordFromLoc(self, loc): + if loc == "custom": + xx, yy = self.plotItem.getData() + x0 = xx[0] + return x0 + + self.setText() + wl = self.labelItem.itemRect().width() + if loc.find("left") != -1: + x0 = self._x_pad + self.xmin + xc = x0 + self._length / 2 + xl = xc - wl / 2 + if xl < x0: + # Text is larger than line --> move line to the right + x0 = self._x_pad + abs(xl - self._x_pad) + else: + x0 = self.xmax - self._length - self._x_pad + xc = x0 + self._length / 2 + x1 = x0 + self._length + xr = xc + wl / 2 + if xr > x1: + # Text is larger than line --> move line to the left + delta_overshoot = xr - x1 + x0 = x0 - delta_overshoot + return x0 + + def getStartYCoordFromLoc(self, loc): + if loc == "custom": + xx, yy = self.plotItem.getData() + y0 = yy[0] + return y0 + + self.setText() + textHeight = self.labelItem.itemRect().height() + if loc.find("top") != -1: + return textHeight + self._y_pad + self.ymin + else: + return self.ymax - self._y_pad - self._thickness + + def update(self): + x0 = self.getStartXCoordFromLoc(self._loc) # + self._thickness/2 + y0 = self.getStartYCoordFromLoc(self._loc) + + x1 = x0 + self._length # - self._thickness/2 + self.plotItem.setData([x0, x1], [y0, y0]) + + self.setText() + self.setTextPos() + + def draw(self, length_pixel, length_unit, **kwargs): + self.setProperties(length_pixel, length_unit, **kwargs) + self.update() + + def bbox(self): + y_line_min, x_line_min, y_line_max, x_line_max = self.plotItem.bbox() + y_lab_min, x_lab_min, y_lab_max, x_lab_max = self.labelItem.bbox() + ymin = min(y_line_min, y_lab_min) + xmin = min(x_line_min, x_lab_min) + ymax = max(y_line_max, y_lab_max) + xmax = max(x_line_max, x_lab_max) + return ymin, xmin, ymax, xmax + + def mousePressed(self, x, y): + self.clicked = True + self.xc, self.yc = x, y + xx, yy = self.plotItem.getData() + self.x0c = xx[0] + self.y0c = yy[0] + + def removeFromAxis(self, ax): + ax.removeItem(self.labelItem) + ax.removeItem(self.plotItem) + + +class RulerPlotItem(pg.PlotDataItem): + def __init__(self, *args, **kwargs): + self.labelItem = pg.LabelItem() + super().__init__(*args, **kwargs) + + def setData(self, *args, lengthText="", **kwargs): + super().setData(*args, **kwargs) + self.labelItem.setText("") + if not lengthText: + return + self.setLengthText(lengthText) + + def setLengthText(self, lengthText): + xx, yy = self.getData() + x0, x1 = sorted(xx) + y0, y1 = sorted(yy) + xc = round(x0 + (x1 - x0) / 2) + yc = round(y0 + (y1 - y0) / 2) + self.labelItem.setText(lengthText, size="11px", color="r") + # xc = x0 + self._length/2 + wl = self.labelItem.itemRect().width() + hl = self.labelItem.itemRect().height() + xl = xc - wl / 2 + yt = y0 - hl + self.labelItem.setPos(xl, yt) + + +class PointsScatterPlotItem(pg.ScatterPlotItem): + sigHoverEntered = Signal(object, object, object) + + def __init__(self, *args, ax=None, show_data_as_tip=False, **kwargs): + self.textItem = annotate.TextAnnotationsScatterItem(size=12, anchor=(1.0, 1.0)) + self.textItem.createSymbols( + [str(int_id) for int_id in range(200)], includeBold=False + ) + # self._textItems = {} + super().__init__(*args, **kwargs) + self.textItem.setParentItem(self) + self._font = QFont() + self._font.setPixelSize(12) + self.show_data_as_tip = show_data_as_tip + self.drawIds = True + self.ax = ax + self.sigHovered.connect(self.onHover) + self.lastHoveredPoint = None + + def onHover(self, item, points, event): + if len(points) == 0: + vb = self.getViewBox() + vb.setToolTip("") + return + + if self.lastHoveredPoint != points[0]: + self.sigHoverEntered.emit(item, points, event) + self.lastHoveredPoint = points[0] + + if not self.opts["hoverable"]: + return + + if not self.show_data_as_tip: + return + + tip_li = [str(point.data()) for point in points] + tip = "\n\n".join(tip_li) + + vb = self.getViewBox() + vb.setToolTip(tip) + + def setData(self, *args, **kwargs): + self.clearTextItems() + super().setData(*args, **kwargs) + data = kwargs.get("data") + if data is None: + return + + if len(data) == 0: + return + + first_point_data = data[0] + if not isinstance(first_point_data, (int, str)): + return + + if not self.drawIds: + return + + if self.show_data_as_tip: + return + + color = self.opts["brush"].color() + self.textItem.setColors({"id": color.getRgb()}) + size = self.opts["size"] + radius = size / 2 + # xx, yy = args + # for x, y, point_data in zip(xx, yy, data): + for point in self.points(): + text = str(point.data()) + if not text: + continue + + x, y = point.pos().x(), point.pos().y() + xt, yt = x + radius - 0.5, y - radius + 0.5 + opts = { + "text": text, + "bold": False, + "color_name": "id", + } + data = self.textItem.addObjAnnot((xt, yt), anchor=(-0.3, 1.3), **opts) + self.textItem.appendData(data, opts["text"]) + + self.textItem.draw() + # hexColor = color.name() + # htmlText = html_utils.span( + # text, color=hexColor, font_size='13pt', bold=True + # ) + + # textItem = self._textItems.get((x, y)) + # if textItem is None: + # textItem = pg.TextItem(html=htmlText, anchor=(0, 1)) + # textItem.setParentItem(self) + # self._textItems[(x, y)] = textItem + # self.ax.addItem(textItem) + # else: + # textItem.setHtml(htmlText) + # textItem.setPos(x+radius-0.5, y-radius+0.5) + + def clearTextItems(self): + self.textItem.clearData() + # for textItem in self._textItems.values(): + # textItem.setText('') + + def clear(self): + super().clear() + self.clearTextItems() + + def setVisible(self, visible): + super().setVisible(visible) + self.textItem.setVisible(visible) + + +class RectItem(pg.GraphicsObject): + def __init__(self, rect, pen=None, brush=(255, 0, 0, 100), parent=None): + super().__init__(parent) + self._rect = rect + self._pen = pg.mkPen(pen) + self._brush = pg.mkBrush(brush) + self.picture = QPicture() + self._generate_picture() + + def setColor(self, color): + rgba = matplotlib.colors.to_rgba(color, alpha=100 / 255) + rgba = [round(c * 255) for c in rgba] + self._brush = pg.mkBrush(rgba) + self._generate_picture() + self.update() + + def setRect(self, x, y, width, height): + self._rect = QRectF(x, y, width, height) + self._generate_picture() + self.update() + + def setQRect(self, qrect): + self._rect = qrect + self._generate_picture() + self.update() + + @property + def rect(self): + return self._rect + + def _generate_picture(self): + painter = QPainter(self.picture) + painter.setPen(self._pen) + painter.setBrush(self._brush) + painter.drawRect(self._rect) + painter.end() + + def paint(self, painter, option, widget=None): + painter.drawPicture(0, 0, self.picture) + + def boundingRect(self): + return QRectF(self.picture.boundingRect()) diff --git a/cellacdc/widgets/canvas/rois.py b/cellacdc/widgets/canvas/rois.py new file mode 100644 index 000000000..643e736ef --- /dev/null +++ b/cellacdc/widgets/canvas/rois.py @@ -0,0 +1,303 @@ +"""Canvas widgets: rois.""" + +"""GUI widgets: canvas.""" + +from collections import defaultdict, deque +from typing import Dict, List, Union, Iterable, Sequence +import os +import sys +import operator +import time +import re +import datetime +import numpy as np +import pandas as pd +import math +import traceback +import logging +import textwrap +import random + +from functools import partial +from math import ceil + +import skimage.draw +import skimage.morphology + +from matplotlib.colors import ListedColormap, LinearSegmentedColormap +import matplotlib.pyplot as plt +import matplotlib +from matplotlib.backends.backend_agg import FigureCanvasAgg + +from qtpy.QtCore import ( + Signal, + QTimer, + Qt, + QPoint, + QUrl, + Property, + QPropertyAnimation, + QEasingCurve, + QLocale, + QSize, + QRect, + QPointF, + QRect, + QPoint, + QEasingCurve, + QRegularExpression, + QEvent, + QEventLoop, + QPropertyAnimation, + QObject, + QItemSelectionModel, + QAbstractListModel, + QModelIndex, + QByteArray, + QDataStream, + QMimeData, + QAbstractItemModel, + QIODevice, + QItemSelection, + PYQT6, + QRectF, +) +from qtpy.QtGui import ( + QFont, + QPalette, + QColor, + QPen, + QKeyEvent, + QBrush, + QPainter, + QRegularExpressionValidator, + QIcon, + QPixmap, + QKeySequence, + QLinearGradient, + QShowEvent, + QDesktopServices, + QFontMetrics, + QGuiApplication, + QLinearGradient, + QImage, + QCursor, + QPicture, +) +from qtpy.QtWidgets import ( + QTextEdit, + QLabel, + QProgressBar, + QHBoxLayout, + QToolButton, + QCheckBox, + QApplication, + QWidget, + QVBoxLayout, + QMainWindow, + QTreeWidgetItemIterator, + QLineEdit, + QSlider, + QSpinBox, + QGridLayout, + QRadioButton, + QScrollArea, + QSizePolicy, + QComboBox, + QPushButton, + QScrollBar, + QGroupBox, + QAbstractSlider, + QDoubleSpinBox, + QWidgetAction, + QAction, + QTabWidget, + QAbstractSpinBox, + QToolBar, + QStyleOptionSpinBox, + QStyle, + QDialog, + QSpacerItem, + QFrame, + QMenu, + QActionGroup, + QListWidget, + QPlainTextEdit, + QFileDialog, + QListView, + QAbstractItemView, + QTreeWidget, + QTreeWidgetItem, + QListWidgetItem, + QLayout, + QStylePainter, + QGraphicsBlurEffect, + QGraphicsProxyWidget, + QGraphicsObject, + QButtonGroup, + QStyleOptionSlider, +) +import qtpy.compat + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from ... import utils, measurements, is_mac, is_win, html_utils, is_linux +from ... import printl, settings_folderpath +from ... import colors, config +from ... import html_path +from ... import _palettes +from ... import load +from ... import apps +from ... import plot +from ... import annotate +from ... import urls +from ... import _core, core +from ... import QtScoped +from ... import prompts +from ...acdc_regex import float_regex +from ...config import PREPROCESS_MAPPER +from ... import _base_widgets + +from ...components.palette import ( # noqa: E402 + BASE_COLOR, + Gradients, + GradientsImage, + GradientsLabels, + LINEEDIT_INVALID_ENTRY_STYLESHEET, + LINEEDIT_WARNING_STYLESHEET, + LISTWIDGET_STYLESHEET, + PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR, + PROGRESSBAR_QCOLOR, + TEXT_COLOR, + TREEWIDGET_STYLESHEET, + cmaps, + font, + getCustomGradients, + nonInvertibleCmaps, + sign_int_mapper, + str_to_operator_mapper, +) +from ...components.progress import QtHandler, QLog, XStream # noqa: E402 +from ...components.buttons import * # noqa: E402, F403 +from ...components.layout import * # noqa: E402, F403 +from ...components.inputs_basic import * # noqa: E402, F403 +from ...components.path_controls import * # noqa: E402, F403 + +from ...components.lists import * # noqa: E402, F403 +from ...components.base import QBaseWindow # noqa: E402 +from ...components.progress import ( # noqa: E402 + LoadingCircleAnimation, + NoneWidget, + ProgressBar, + ProgressBarWithETA, + QLogConsole, +) + +class PolyLineROI(pg.PolyLineROI): + def __init__(self, positions, closed=False, pos=None, **args): + super().__init__(positions, closed, pos, **args) + + +class ROI(pg.ROI): + def __init__( + self, + pos, + size=pg.Point(1, 1), + angle=0, + invertible=False, + maxBounds=None, + snapSize=1, + scaleSnap=False, + translateSnap=False, + rotateSnap=False, + parent=None, + pen=None, + hoverPen=None, + handlePen=None, + handleHoverPen=None, + movable=True, + rotatable=True, + resizable=True, + removable=False, + aspectLocked=False, + ): + super().__init__( + pos, + size, + angle, + invertible, + maxBounds, + snapSize, + scaleSnap, + translateSnap, + rotateSnap, + parent, + pen, + hoverPen, + handlePen, + handleHoverPen, + movable, + rotatable, + resizable, + removable, + aspectLocked, + ) + + def slice(self, zRange=None, tRange=None): + x0, y0 = [int(round(c)) for c in self.pos()] + w, h = [int(round(c)) for c in self.size()] + xmin, xmax = x0, x0 + w + if xmin > xmax: + xmin, xmax = xmax, xmin + ymin, ymax = y0, y0 + h + if ymin > ymax: + ymin, ymax = ymax, ymin + if zRange is not None: + zmin, zmax = zRange + _slice = (slice(zmin, zmax), slice(ymin, ymax), slice(xmin, xmax)) + else: + _slice = (slice(ymin, ymax), slice(xmin, xmax)) + if tRange is not None: + tmin, tmax = tRange + _slice = (slice(tmin, tmax), *_slice) + return _slice + + def bbox(self): + x0, y0 = [int(round(c)) for c in self.pos()] + w, h = [int(round(c)) for c in self.size()] + xmin, xmax = x0, x0 + w + if xmin > xmax: + xmin, xmax = xmax, xmin + ymin, ymax = y0, y0 + h + if ymin > ymax: + ymin, ymax = ymax, ymin + + return ymin, xmin, ymax, xmax + + +class ZoomROI(ROI): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.viewRangesQueue = deque() + + def getLastRange(self): + xRange, yRange = self.viewRangesQueue.pop() + return xRange, yRange + + def storeLastRange(self, xRange, yRange): + self.viewRangesQueue.append((xRange, yRange)) + + +class DelROI(pg.ROI): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def clearPoints(self): + """ + Remove all handles and segments. + """ + while len(self.handles) > 0: + self.removeHandle(self.handles[0]["item"]) diff --git a/cellacdc/widgets/canvas/scrollbars.py b/cellacdc/widgets/canvas/scrollbars.py new file mode 100644 index 000000000..b7803e122 --- /dev/null +++ b/cellacdc/widgets/canvas/scrollbars.py @@ -0,0 +1,595 @@ +"""Canvas widgets: scrollbars.""" + +"""GUI widgets: canvas.""" + +from collections import defaultdict, deque +from typing import Dict, List, Union, Iterable, Sequence +import os +import sys +import operator +import time +import re +import datetime +import numpy as np +import pandas as pd +import math +import traceback +import logging +import textwrap +import random + +from functools import partial +from math import ceil + +import skimage.draw +import skimage.morphology + +from matplotlib.colors import ListedColormap, LinearSegmentedColormap +import matplotlib.pyplot as plt +import matplotlib +from matplotlib.backends.backend_agg import FigureCanvasAgg + +from qtpy.QtCore import ( + Signal, + QTimer, + Qt, + QPoint, + QUrl, + Property, + QPropertyAnimation, + QEasingCurve, + QLocale, + QSize, + QRect, + QPointF, + QRect, + QPoint, + QEasingCurve, + QRegularExpression, + QEvent, + QEventLoop, + QPropertyAnimation, + QObject, + QItemSelectionModel, + QAbstractListModel, + QModelIndex, + QByteArray, + QDataStream, + QMimeData, + QAbstractItemModel, + QIODevice, + QItemSelection, + PYQT6, + QRectF, +) +from qtpy.QtGui import ( + QFont, + QPalette, + QColor, + QPen, + QKeyEvent, + QBrush, + QPainter, + QRegularExpressionValidator, + QIcon, + QPixmap, + QKeySequence, + QLinearGradient, + QShowEvent, + QDesktopServices, + QFontMetrics, + QGuiApplication, + QLinearGradient, + QImage, + QCursor, + QPicture, +) +from qtpy.QtWidgets import ( + QTextEdit, + QLabel, + QProgressBar, + QHBoxLayout, + QToolButton, + QCheckBox, + QApplication, + QWidget, + QVBoxLayout, + QMainWindow, + QTreeWidgetItemIterator, + QLineEdit, + QSlider, + QSpinBox, + QGridLayout, + QRadioButton, + QScrollArea, + QSizePolicy, + QComboBox, + QPushButton, + QScrollBar, + QGroupBox, + QAbstractSlider, + QDoubleSpinBox, + QWidgetAction, + QAction, + QTabWidget, + QAbstractSpinBox, + QToolBar, + QStyleOptionSpinBox, + QStyle, + QDialog, + QSpacerItem, + QFrame, + QMenu, + QActionGroup, + QListWidget, + QPlainTextEdit, + QFileDialog, + QListView, + QAbstractItemView, + QTreeWidget, + QTreeWidgetItem, + QListWidgetItem, + QLayout, + QStylePainter, + QGraphicsBlurEffect, + QGraphicsProxyWidget, + QGraphicsObject, + QButtonGroup, + QStyleOptionSlider, +) +import qtpy.compat + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from ... import utils, measurements, is_mac, is_win, html_utils, is_linux +from ... import printl, settings_folderpath +from ... import colors, config +from ... import html_path +from ... import _palettes +from ... import load +from ... import apps +from ... import plot +from ... import annotate +from ... import urls +from ... import _core, core +from ... import QtScoped +from ... import prompts +from ...acdc_regex import float_regex +from ...config import PREPROCESS_MAPPER +from ... import _base_widgets + +from ...components.palette import ( # noqa: E402 + BASE_COLOR, + Gradients, + GradientsImage, + GradientsLabels, + LINEEDIT_INVALID_ENTRY_STYLESHEET, + LINEEDIT_WARNING_STYLESHEET, + LISTWIDGET_STYLESHEET, + PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR, + PROGRESSBAR_QCOLOR, + TEXT_COLOR, + TREEWIDGET_STYLESHEET, + cmaps, + font, + getCustomGradients, + nonInvertibleCmaps, + sign_int_mapper, + str_to_operator_mapper, +) +from ...components.progress import QtHandler, QLog, XStream # noqa: E402 +from ...components.buttons import * # noqa: E402, F403 +from ...components.layout import * # noqa: E402, F403 +from ...components.inputs_basic import * # noqa: E402, F403 +from ...components.path_controls import * # noqa: E402, F403 + +from ...components.lists import * # noqa: E402, F403 +from ...components.base import QBaseWindow # noqa: E402 +from ...components.progress import ( # noqa: E402 + LoadingCircleAnimation, + NoneWidget, + ProgressBar, + ProgressBarWithETA, + QLogConsole, +) + +class MouseCursor(QWidget): + def __init__(self, parent=None) -> None: + super().__init__(parent) + self._x = None + self._y = None + self.setMouseTracking(True) + + def mouseMoveEvent(self, event) -> None: + self.move(event.pos()) + self.update() + return super().mouseMoveEvent(event) + + # def drawAtPos(self, x, y): + # self._x = x + # self._y = y + # self.update() + + def paintEvent(self, event) -> None: + p = QPainter(self) + # p.setPen(QPen(QColor(0,0,0))) + # p.setBrush(QBrush(QColor(70,70,70,200))) + p.drawLine(0, 0, 200, 0) + p.end() + + +class labelledQScrollbar(ScrollBar): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._label = None + + def setLabel(self, label): + self._label = label + + def updateLabel(self): + if self._label is not None: + position = self.sliderPosition() + s = self._label.text() + s = re.sub(r"(\d+)/(\d+)", rf"{position + 1:02}/\2", s) + self._label.setText(s) + + def setSliderPosition(self, position): + QScrollBar.setSliderPosition(self, position) + self.updateLabel() + + def setValue(self, value): + QScrollBar.setValue(self, value) + self.updateLabel() + + +class navigateQScrollBar(ScrollBar): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._disableCustomPressEvent = False + self.signal_slot_mapper = {} + + def disableCustomPressEvent(self): + self._disableCustomPressEvent = True + + def enableCustomPressEvent(self): + self._disableCustomPressEvent = False + + def setAbsoluteMaximum(self, absoluteMaximum): + self._absoluteMaximum = absoluteMaximum + + def absoluteMaximum(self): + return self._absoluteMaximum + + def mousePressEvent(self, event): + super().mousePressEvent(event) + if self.maximum() == self._absoluteMaximum: + return + + if self._disableCustomPressEvent: + return + + def setValueNoSignal(self, value): + for signal_name, slot in self.signal_slot_mapper.items(): + signal = getattr(self, signal_name) + try: + signal.disconnect() + except Exception as e: + pass + + self.setSliderPosition(value) + self.connectEvents(self.signal_slot_mapper) + + def connectEvents(self, signal_slot_mapper: dict): + self.signal_slot_mapper = signal_slot_mapper + for signal_name, slot in signal_slot_mapper.items(): + signal = getattr(self, signal_name) + try: + signal.disconnect() + except Exception as e: + pass + signal.connect(slot) + + +class linkedQScrollbar(ScrollBar): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._linkedScrollBar = None + + def linkScrollBar(self, scrollbar): + self._linkedScrollBar = scrollbar + scrollbar.setSliderPosition(self.sliderPosition()) + + def unlinkScrollBar(self): + self._linkedScrollBar = None + + def setSliderPosition(self, position): + QScrollBar.setSliderPosition(self, position) + if self._linkedScrollBar is not None: + self._linkedScrollBar.setSliderPosition(position) + + def setMaximum(self, max): + QScrollBar.setMaximum(self, max) + if self._linkedScrollBar is not None: + self._linkedScrollBar.setMaximum(max) + + +class sliderWithSpinBox(QWidget): + sigValueChange = Signal(object) + valueChanged = Signal(object) + editingFinished = Signal() + + def __init__(self, *args, **kwargs): + super().__init__(*args) + + layout = QGridLayout() + + title = kwargs.get("title") + row = 0 + col = 0 + if title is not None: + titleLabel = QLabel(self) + titleLabel.setText(title) + loc = kwargs.get("title_loc", "top") + if loc == "top": + layout.addWidget(titleLabel, 0, col, alignment=Qt.AlignLeft) + elif loc == "in_line": + row = -1 + col = 1 + layout.addWidget(titleLabel, 0, 0, alignment=Qt.AlignLeft) + layout.setColumnStretch(0, 0) + + self._normalize = False + normalize = kwargs.get("normalize") + if normalize is not None and normalize: + self._normalize = True + self._isFloat = True + + self._isFloat = False + isFloat = kwargs.get("isFloat") + if isFloat is not None and isFloat: + self._isFloat = True + + self.slider = QSlider(Qt.Horizontal, self) + + if self._normalize or self._isFloat: + self.spinBox = DoubleSpinBox(self) + else: + self.spinBox = SpinBox(self) + self.spinBox.setAlignment(Qt.AlignCenter) + self.spinBox.setMaximum(2**31 - 1) + + maximum_on_label = kwargs.get("maximum_on_label") + spinbox_loc = kwargs.get("spinbox_loc", "right") + if spinbox_loc == "right": + spinbox_col = col + 1 + slider_col = col + if maximum_on_label is not None: + maximum_on_label_col = spinbox_col + 1 + elif spinbox_loc == "left": + spinbox_col = col + slider_col = col + 1 + if maximum_on_label is not None: + maximum_on_label_col = spinbox_col + 1 + slider_col += 1 + + if maximum_on_label is not None: + self.labelMaximum = QLabel() + layout.addWidget(self.labelMaximum, row + 1, maximum_on_label_col) + layout.addWidget(self.slider, row + 1, slider_col) + layout.addWidget(self.spinBox, row + 1, spinbox_col) + + if title is not None: + layout.setRowStretch(0, 1) + layout.setRowStretch(row + 1, 1) + layout.setColumnStretch(slider_col, 6) + layout.setColumnStretch(spinbox_col, 1) + + self._layout = layout + self.lastCol = col + 1 + self.sliderCol = slider_col + + self.slider.valueChanged.connect(self.sliderValueChanged) + self.slider.sliderReleased.connect(self.onEditingFinished) + self.spinBox.valueChanged.connect(self.spinboxValueChanged) + self.spinBox.editingFinished.connect(self.onEditingFinished) + + layout.setContentsMargins(5, 0, 5, 0) + + self.setLayout(layout) + + if maximum_on_label is not None: + self.setMaximum(maximum_on_label) + self.labelMaximum.setText(f"/{maximum_on_label}") + + def onEditingFinished(self): + self.editingFinished.emit() + + def maximum(self): + return self.slider.maximum() + + def minimum(self): + return self.slider.minimum() + + def setValue(self, value, emitSignal=False): + valueInt = value + if self._normalize: + valueInt = int(value * self.slider.maximum()) + elif self._isFloat: + valueInt = int(value) + + self.spinBox.valueChanged.disconnect() + self.spinBox.setValue(value) + self.spinBox.valueChanged.connect(self.spinboxValueChanged) + + self.slider.valueChanged.disconnect() + if valueInt > self.slider.maximum(): + self.slider.setMaximum(valueInt) + self.slider.setValue(valueInt) + self.slider.valueChanged.connect(self.sliderValueChanged) + + if emitSignal: + self.sigValueChange.emit(self.value()) + self.valueChanged.emit(self.value()) + + def setMaximum(self, max, including_spinbox=False): + self.slider.setMaximum(max) + if including_spinbox: + self.spinBox.setMaximum(max) + + def setSingleStep(self, step): + self.spinBox.setSingleStep(step) + + def setMinimum(self, min, including_spinbox=False): + self.slider.setMinimum(min) + if including_spinbox: + self.spinBox.setMinimum(min) + + def setSingleStep(self, step): + self.spinBox.setSingleStep(step) + + def setDecimals(self, decimals): + self.spinBox.setDecimals(decimals) + + def setTickPosition(self, position): + self.slider.setTickPosition(position) + + def setTickInterval(self, interval): + self.slider.setTickInterval(interval) + + def sliderValueChanged(self, val): + self.spinBox.valueChanged.disconnect() + if self._normalize: + valF = val / self.slider.maximum() + self.spinBox.setValue(valF) + else: + self.spinBox.setValue(val) + self.spinBox.valueChanged.connect(self.spinboxValueChanged) + self.sigValueChange.emit(self.value()) + self.valueChanged.emit(self.value()) + + def spinboxValueChanged(self, val): + if self._normalize: + val = int(val * self.slider.maximum()) + elif self._isFloat: + val = int(val) + + self.slider.valueChanged.disconnect() + self.slider.setValue(val) + self.slider.valueChanged.connect(self.sliderValueChanged) + self.sigValueChange.emit(self.value()) + self.valueChanged.emit(self.value()) + + def value(self): + return self.spinBox.value() + + def setDisabled(self, disabled) -> None: + self.slider.setDisabled(disabled) + self.spinBox.setDisabled(disabled) + + +class ScrollBarWithNumericControl(QWidget): + sigValueChanged = Signal(int) + sigMaxProjToggled = Signal(bool, object) + + def __init__( + self, + orientation=Qt.Horizontal, + add_max_proj_button=False, + parent=None, + labelText="", + ) -> None: + super().__init__(parent) + + self._slot = None + + layout = QHBoxLayout() + self.scrollbar = QScrollBar(orientation, self) + self.spinbox = QSpinBox(self) + self.maxLabel = QLabel(self) + idx = 0 + if labelText: + layout.addWidget(QLabel(labelText)) + layout.setStretch(idx, 0) + idx += 1 + + layout.addWidget(self.spinbox) + layout.setStretch(idx, 0) + idx += 1 + + layout.addWidget(self.maxLabel) + layout.setStretch(idx, 0) + idx += 1 + + layout.addWidget(self.scrollbar) + layout.setStretch(idx, 1) + idx += 1 + + if add_max_proj_button: + self.maxProjCheckbox = QCheckBox("MAX") + self.scrollbar.maxProjCheckbox = self.maxProjCheckbox + layout.addWidget(self.maxProjCheckbox) + layout.setStretch(idx, 0) + + layout.setContentsMargins(5, 0, 5, 0) + + self.setLayout(layout) + + self.spinbox.valueChanged.connect(self.spinboxValueChanged) + self.scrollbar.valueChanged.connect(self.scrollbarValueChanged) + + if add_max_proj_button: + self.maxProjCheckbox.toggled.connect(self.maxProjToggled) + + def connectValueChanged(self, slot): + self.sigValueChanged.connect(slot) + self._slot = slot + + def setValueNoSignal(self, value): + if self._slot is None: + return + self.sigValueChanged.disconnect() + self.setValue(value) + self.sigValueChanged.connect(self._slot) + + def maxProjToggled(self, checked): + self.scrollbar.setDisabled(checked) + self.sigMaxProjToggled.emit(checked, self) + + def showEvent(self, event) -> None: + super().showEvent(event) + + self.scrollbar.setMinimumHeight(self.spinbox.height()) + + def setMaximum(self, maximum): + self.maxLabel.setText(f"/{maximum}") + self.scrollbar.setMaximum(maximum) + self.spinbox.setMaximum(maximum) + + def setMinimum(self, minumum): + self.scrollbar.setMinimum(minumum) + self.spinbox.setMinimum(minumum) + + def spinboxValueChanged(self, value): + self.scrollbar.setValue(value) + + def scrollbarValueChanged(self, value): + self.spinbox.setValue(value) + self.sigValueChanged.emit(value) + + def setValue(self, value): + self.scrollbar.setValue(value) + + def value(self): + return self.scrollbar.value() + + def maximum(self): + return self.scrollbar.maximum() + +# Cross-module imports (deferred to avoid import cycles) +from ..controls.inputs import ( + DoubleSpinBox, + SpinBox, +) + diff --git a/cellacdc/widgets/controls/__init__.py b/cellacdc/widgets/controls/__init__.py new file mode 100644 index 000000000..9772cfda1 --- /dev/null +++ b/cellacdc/widgets/controls/__init__.py @@ -0,0 +1,153 @@ +"""Composite controls.""" + +from .dialogs import ( + QDialogListbox, + installJavaDialog, + myMessageBox, + selectTrackerGUI, + view_visualcpp_screenshot, + warnVisualCppRequired, +) + +from .forms import ( + AutoSaveIntervalWidget, + CheckableWidget, + CheckboxesGroupBox, + CopiableCommandWidget, + FontSizeWidget, + LabelsWidget, + PostProcessSegmSlider, + PostProcessSegmSpinbox, + PreProcessingSelector, + RangeSelector, + RescaleImageJroisGroupbox, + SamInputPointsWidget, + TimeWidget, + YeazV2SelectModelNameCombobox, + formWidget, + guiTabControl, + selectStartStopFrames, +) + +from .inputs import ( + AlphaNumericComboBox, + CenteredDoubleSpinbox, + ComboBox, + DoubleSpinBox, + ExpandableListBox, + FloatLineEdit, + IntLineEdit, + KeySequenceFromText, + LineEdit, + OddSpinBox, + QCenteredComboBox, + QClickableLabel, + ReadOnlyLineEdit, + SearchLineEdit, + ShortcutLineEdit, + SpinBox, + VectorLineEdit, + WhitelistLineEdit, + highlightableQWidgetAction, + mySpinBox, + readOnlyDoubleSpinbox, + readOnlySpinbox, +) + +from .metrics import ( + PixelSizeGroupbox, + SetMeasurementsGroupBox, + _metricsQGBox, + channelMetricsQGBox, + objIntesityMeasurQGBox, + objPropsQGBox, +) + +from .panels import ( + CheckableAction, + CheckableSpinBoxWidgets, + FeatureSelectorButton, + KeptObjectIDsList, + Label, + LatexLabel, + OrderableListWidget, + SwitchPlaneCombobox, + TimestampItem, + Toggle, + ToggleTerminalButton, + ToggleVisibilityButton, + ToggleVisibilityCheckBox, + expandCollapseButton, + listWidget, + statusBarPermanentLabel, +) + +__all__ = [ + "QDialogListbox", + "installJavaDialog", + "myMessageBox", + "selectTrackerGUI", + "view_visualcpp_screenshot", + "warnVisualCppRequired", + "AutoSaveIntervalWidget", + "CheckableWidget", + "CheckboxesGroupBox", + "CopiableCommandWidget", + "FontSizeWidget", + "LabelsWidget", + "PostProcessSegmSlider", + "PostProcessSegmSpinbox", + "PreProcessingSelector", + "RangeSelector", + "RescaleImageJroisGroupbox", + "SamInputPointsWidget", + "TimeWidget", + "YeazV2SelectModelNameCombobox", + "formWidget", + "guiTabControl", + "selectStartStopFrames", + "AlphaNumericComboBox", + "CenteredDoubleSpinbox", + "ComboBox", + "DoubleSpinBox", + "ExpandableListBox", + "FloatLineEdit", + "IntLineEdit", + "KeySequenceFromText", + "LineEdit", + "OddSpinBox", + "QCenteredComboBox", + "QClickableLabel", + "ReadOnlyLineEdit", + "SearchLineEdit", + "ShortcutLineEdit", + "SpinBox", + "VectorLineEdit", + "WhitelistLineEdit", + "highlightableQWidgetAction", + "mySpinBox", + "readOnlyDoubleSpinbox", + "readOnlySpinbox", + "PixelSizeGroupbox", + "SetMeasurementsGroupBox", + "_metricsQGBox", + "channelMetricsQGBox", + "objIntesityMeasurQGBox", + "objPropsQGBox", + "CheckableAction", + "CheckableSpinBoxWidgets", + "FeatureSelectorButton", + "KeptObjectIDsList", + "Label", + "LatexLabel", + "OrderableListWidget", + "SwitchPlaneCombobox", + "TimestampItem", + "Toggle", + "ToggleTerminalButton", + "ToggleVisibilityButton", + "ToggleVisibilityCheckBox", + "expandCollapseButton", + "listWidget", + "statusBarPermanentLabel", +] diff --git a/cellacdc/widgets/controls/dialogs.py b/cellacdc/widgets/controls/dialogs.py new file mode 100644 index 000000000..bc817e377 --- /dev/null +++ b/cellacdc/widgets/controls/dialogs.py @@ -0,0 +1,1309 @@ +"""Composite controls: dialogs.""" + +"""GUI widgets: controls.""" + +from collections import defaultdict, deque +from typing import Dict, List, Union, Iterable, Sequence +import os +import sys +import operator +import time +import re +import datetime +import numpy as np +import pandas as pd +import math +import traceback +import logging +import textwrap +import random + +from functools import partial +from math import ceil + +import skimage.draw +import skimage.morphology + +from matplotlib.colors import ListedColormap, LinearSegmentedColormap +import matplotlib.pyplot as plt +import matplotlib +from matplotlib.backends.backend_agg import FigureCanvasAgg + +from qtpy.QtCore import ( + Signal, + QTimer, + Qt, + QPoint, + QUrl, + Property, + QPropertyAnimation, + QEasingCurve, + QLocale, + QSize, + QRect, + QPointF, + QRect, + QPoint, + QEasingCurve, + QRegularExpression, + QEvent, + QEventLoop, + QPropertyAnimation, + QObject, + QItemSelectionModel, + QAbstractListModel, + QModelIndex, + QByteArray, + QDataStream, + QMimeData, + QAbstractItemModel, + QIODevice, + QItemSelection, + PYQT6, + QRectF, +) +from qtpy.QtGui import ( + QFont, + QPalette, + QColor, + QPen, + QKeyEvent, + QBrush, + QPainter, + QRegularExpressionValidator, + QIcon, + QPixmap, + QKeySequence, + QLinearGradient, + QShowEvent, + QDesktopServices, + QFontMetrics, + QGuiApplication, + QLinearGradient, + QImage, + QCursor, + QPicture, +) +from qtpy.QtWidgets import ( + QTextEdit, + QLabel, + QProgressBar, + QHBoxLayout, + QToolButton, + QCheckBox, + QApplication, + QWidget, + QVBoxLayout, + QMainWindow, + QTreeWidgetItemIterator, + QLineEdit, + QSlider, + QSpinBox, + QGridLayout, + QRadioButton, + QScrollArea, + QSizePolicy, + QComboBox, + QPushButton, + QScrollBar, + QGroupBox, + QAbstractSlider, + QDoubleSpinBox, + QWidgetAction, + QAction, + QTabWidget, + QAbstractSpinBox, + QToolBar, + QStyleOptionSpinBox, + QStyle, + QDialog, + QSpacerItem, + QFrame, + QMenu, + QActionGroup, + QListWidget, + QPlainTextEdit, + QFileDialog, + QListView, + QAbstractItemView, + QTreeWidget, + QTreeWidgetItem, + QListWidgetItem, + QLayout, + QStylePainter, + QGraphicsBlurEffect, + QGraphicsProxyWidget, + QGraphicsObject, + QButtonGroup, + QStyleOptionSlider, +) +import qtpy.compat + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from ... import utils, measurements, is_mac, is_win, html_utils, is_linux +from ... import printl, settings_folderpath +from ... import colors, config +from ... import html_path +from ... import _palettes +from ... import load +from ... import apps +from ... import plot +from ... import annotate +from ... import urls +from ... import _core, core +from ... import QtScoped +from ... import prompts +from ...acdc_regex import float_regex +from ...config import PREPROCESS_MAPPER +from ... import _base_widgets + +from ...components.palette import ( # noqa: E402 + BASE_COLOR, + Gradients, + GradientsImage, + GradientsLabels, + LINEEDIT_INVALID_ENTRY_STYLESHEET, + LINEEDIT_WARNING_STYLESHEET, + LISTWIDGET_STYLESHEET, + PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR, + PROGRESSBAR_QCOLOR, + TEXT_COLOR, + TREEWIDGET_STYLESHEET, + cmaps, + font, + getCustomGradients, + nonInvertibleCmaps, + sign_int_mapper, + str_to_operator_mapper, +) +from ...components.progress import QtHandler, QLog, XStream # noqa: E402 +from ...components.buttons import * # noqa: E402, F403 +from ...components.layout import * # noqa: E402, F403 +from ...components.inputs_basic import * # noqa: E402, F403 +from ...components.path_controls import * # noqa: E402, F403 + +from ...components.lists import * # noqa: E402, F403 +from ...components.base import QBaseWindow # noqa: E402 +from ...components.progress import ( # noqa: E402 + LoadingCircleAnimation, + NoneWidget, + ProgressBar, + ProgressBarWithETA, + QLogConsole, +) + +class QDialogListbox(QDialog): + sigSelectionConfirmed = Signal(list) + + def __init__( + self, + title, + text, + items, + cancelText="Cancel", + multiSelection=True, + parent=None, + additionalButtons=(), + includeSelectionHelp=False, + allowSingleSelection=True, + preSelectedItems=None, + allowEmptySelection=True, + ): + self.cancel = True + items = list(items) + + super().__init__(parent) + self.setWindowTitle(title) + + if preSelectedItems is None: + if items: + preSelectedItems = (items[0],) + else: + preSelectedItems = set() + + self.allowSingleSelection = allowSingleSelection + self.allowEmptySelection = allowEmptySelection + + mainLayout = QVBoxLayout() + topLayout = QVBoxLayout() + bottomLayout = QHBoxLayout() + + self.mainLayout = mainLayout + + label = QLabel(text) + _font = QFont() + _font.setPixelSize(13) + label.setFont(_font) + # padding: top, left, bottom, right + label.setStyleSheet("padding:0px 0px 3px 0px;") + topLayout.addWidget(label, alignment=Qt.AlignCenter) + + if includeSelectionHelp: + selectionHelpLabel = QLabel() + txt = html_utils.paragraph("""
    + Ctrl+Click to select multiple items
    + Shift+Click to select a range of items
    + """) + selectionHelpLabel.setText(txt) + topLayout.addWidget(label, alignment=Qt.AlignCenter) + + listBox = listWidget() + listBox.setFont(_font) + listBox.addItems(items) + if multiSelection: + listBox.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) + else: + listBox.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection) + listBox.setCurrentRow(0) + for i in range(listBox.count()): + item = listBox.item(i) + item.setSelected(item.text() in preSelectedItems) + + self.listBox = listBox + if not multiSelection: + listBox.itemDoubleClicked.connect(self.ok_cb) + topLayout.addWidget(listBox) + + if cancelText.lower().find("cancel") != -1: + cancelButton = cancelPushButton(cancelText) + else: + cancelButton = QPushButton(cancelText) + okButton = okPushButton(" Ok ") + + bottomLayout.addStretch(1) + bottomLayout.addWidget(cancelButton) + bottomLayout.addSpacing(20) + + if additionalButtons: + self._additionalButtons = [] + for button in additionalButtons: + if isinstance(button, str): + _button, isCancelButton = getPushButton(button) + self._additionalButtons.append(_button) + bottomLayout.addWidget(_button) + _button.clicked.connect(self.ok_cb) + else: + bottomLayout.addWidget(button) + + bottomLayout.addWidget(okButton) + bottomLayout.setContentsMargins(0, 10, 0, 0) + + mainLayout.addLayout(topLayout) + mainLayout.addLayout(bottomLayout) + self.setLayout(mainLayout) + + # Connect events + okButton.clicked.connect(self.ok_cb) + cancelButton.clicked.connect(self.cancel_cb) + + if multiSelection: + listBox.itemClicked.connect(self.onItemClicked) + listBox.itemSelectionChanged.connect(self.onItemSelectionChanged) + + self.setStyleSheet(LISTWIDGET_STYLESHEET) + self.areItemsSelected = [ + listBox.item(i).isSelected() for i in range(listBox.count()) + ] + self.setFont(font) + + def keyPressEvent(self, event) -> None: + mod = event.modifiers() + if mod == Qt.ShiftModifier or mod == Qt.ControlModifier: + self.listBox.setSelectionMode( + QAbstractItemView.SelectionMode.ExtendedSelection + ) + elif event.key() == Qt.Key_Escape: + self.listBox.clearSelection() + event.ignore() + return + super().keyPressEvent(event) + + def onItemSelectionChanged(self): + if not self.listBox.selectedItems(): + self.areItemsSelected = [False for i in range(self.listBox.count())] + + def onItemClicked(self, item): + mod = QGuiApplication.keyboardModifiers() + if mod == Qt.ShiftModifier or mod == Qt.ControlModifier: + self.listBox.setSelectionMode( + QAbstractItemView.SelectionMode.ExtendedSelection + ) + return + + self.listBox.setSelectionMode(QAbstractItemView.SelectionMode.MultiSelection) + itemIdx = self.listBox.row(item) + wasSelected = self.areItemsSelected[itemIdx] + if wasSelected: + item.setSelected(False) + + self.areItemsSelected = [ + self.listBox.item(i).isSelected() for i in range(self.listBox.count()) + ] + # self.listBox.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) + # else: + # selectedItems.append(item) + + # self.listBox.clearSelection() + # for i in range(self.listBox.count()): + # item = self.listBox.item(i).setSelected(True) + + # print(self.listBox.selectedItems()) + + def setSelectedItems(self, itemsTexts): + for i in range(self.listBox.count()): + item = self.listBox.item(i) + if item.text() in itemsTexts: + item.setSelected(True) + self.listBox.update() + + def warnSelectionEmpty(self): + msg = myMessageBox(wrapText=False, showCentered=False) + txt = html_utils.paragraph( + "You need to select at least one item!.

    " + "Use Ctrl+Click to select multiple items
    " + "or Shift+Click to select a range of items" + ) + msg.warning(self, "Selection cannot be empty!", txt) + + def ok_cb(self, checked=False): + self.clickedButton = self.sender() + self.cancel = False + selectedItems = self.listBox.selectedItems() + self.selectedItemsText = [item.text() for item in selectedItems] + if not self.allowSingleSelection and len(self.selectedItemsText) < 2: + msg = myMessageBox(wrapText=False, showCentered=False) + txt = html_utils.paragraph( + "You need to select two or more items.

    " + "Use Ctrl+Click to select multiple items
    , or
    " + "Shift+Click to select a range of items" + ) + msg.warning(self, "Select two or more items", txt) + return + + if not self.allowEmptySelection and not self.selectedItemsText: + self.warnSelectionEmpty() + return + + self.sigSelectionConfirmed.emit(self.selectedItemsText) + self.close() + + def cancel_cb(self, event): + self.cancel = True + self.selectedItemsText = None + self.close() + + def exec_(self): + self.show(block=True) + + def show(self, block=False): + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + super().show() + + horizontal_sb = self.listBox.horizontalScrollBar() + while horizontal_sb.isVisible(): + self.resize(self.height(), self.width() + 10) + + if block: + self.loop = QEventLoop() + self.loop.exec_() + + def closeEvent(self, event): + if hasattr(self, "loop"): + self.loop.exit() + + +class myMessageBox(_base_widgets.QBaseDialog): + def __init__( + self, + parent=None, + showCentered=True, + wrapText=True, + scrollableText=False, + enlargeWidthFactor=0, + resizeButtons=True, + allowClose=True, + ): + super().__init__(parent) + + self.wrapText = wrapText + self.enlargeWidthFactor = enlargeWidthFactor + self.resizeButtons = resizeButtons + + self.cancel = True + self.cancelButton = None + self.okButton = None + self.clickedButton = None + self.alreadyShown = False + self.allowClose = allowClose + + self.showCentered = showCentered + + self.scrollableText = scrollableText + + self._layout = QGridLayout() + self.commandsLayout = None + self._layout.setHorizontalSpacing(20) + self.buttonsLayout = QHBoxLayout() + self.buttonsLayout.setSpacing(2) + self.buttons = [] + self.widgets = [] + self.layouts = [] + self.labels = [] + self.labelsWidgets = [] + self._pixmapLabels = [] + self.detailsTextWidget = None + self.showInFileManagButton = None + self.visibleDetails = False + self.doNotShowAgainCheckbox = None + + self.currentRow = 0 + self.textWidget = None + self._w = None + + self.textLayout = QVBoxLayout() + + self._layout.setColumnStretch(1, 1) + self.setLayout(self._layout) + + self.setFont(font) + + def mousePressEvent(self, event): + for label in self.labels: + label.setTextInteractionFlags( + Qt.TextBrowserInteraction | Qt.TextSelectableByKeyboard + ) + + def setIcon(self, iconName="SP_MessageBoxInformation"): + label = QLabel(self) + + standardIcon = getattr(QStyle, iconName) + icon = self.style().standardIcon(standardIcon) + pixmap = icon.pixmap(60, 60) + label.setPixmap(pixmap) + + self._layout.addWidget(label, 0, 0, alignment=Qt.AlignTop) + + def addImage(self, image_path): + pixmap = QPixmap(image_path) + label = QLabel() + label.setPixmap(pixmap) + self._layout.addWidget(label, self.currentRow, 1) + self.currentRow += 1 + + def addShowInFileManagerButton(self, path, txt=None): + if txt is None: + txt = "Reveal in Finder..." if is_mac else "Show in Explorer..." + self.showInFileManagButton = showInFileManagerButton(txt) + self.buttonsLayout.addWidget(self.showInFileManagButton) + func = partial(utils.showInExplorer, path) + self.showInFileManagButton.clicked.connect(func) + + def addBrowseUrlButton(self, url, button_text=""): + self.openUrlButton = OpenUrlButton(url, button_text) + self.buttonsLayout.addWidget(self.openUrlButton) + + def addCancelButton(self, button=None, connect=False): + if button is None: + self.cancelButton = cancelPushButton("Cancel") + else: + self.cancelButton = button + self.cancelButton.setIcon(QIcon(":cancelButton.svg")) + + self.buttonsLayout.insertWidget(0, self.cancelButton) + self.buttonsLayout.insertSpacing(1, 20) + if connect: + self.cancelButton.clicked.connect(self.buttonCallBack) + + def splitLatexBlocks(self, text): + texts = re.split(r"(.+?)
    ", text) + return texts + + def splitCopiableBlocks(self, texts: Sequence[str] | str): + if isinstance(texts, str): + texts = (texts,) + + texts_out = [] + for text in texts: + texts_out.extend(re.split(r"(.+?)", text)) + return texts_out + + def addText(self, text): + texts = self.splitLatexBlocks(text) + texts = self.splitCopiableBlocks(texts) + + labelsWidget = LabelsWidget(texts, wrapText=self.wrapText) + self.labelsWidgets.append(labelsWidget) + self.labels.extend(labelsWidget.labels) + if self.scrollableText: + textWidget = QScrollArea() + textWidget.setFrameStyle(QFrame.Shape.NoFrame) + textWidget.setWidget(labelsWidget) + else: + textWidget = labelsWidget + + self.textLayout.addWidget(textWidget) + + if self.textWidget is None: + self.textWidget = QWidget() + self.textWidget.setLayout(self.textLayout) + self._layout.addWidget(self.textWidget, self.currentRow, 1) + self.textRow = self.currentRow + self.currentRow += 1 + + return labelsWidget + + def addCopiableCommand(self, command): + copiableCommandWidget = CopiableCommandWidget(command) + screenWidth = self.screen().size().width() + maxWidth = int(0.75 * screenWidth) + sizeHint = copiableCommandWidget.sizeHint() + width = sizeHint.width() + if width > maxWidth: + copiableCommandWidget = addWidgetToScrollArea( + copiableCommandWidget, resizeMinHeightNoVerticalScrollbar=True + ) + self._layout.addWidget(copiableCommandWidget, self.currentRow, 1) + self.currentRow += 1 + + def copyToClipboard(self): + cb = QApplication.clipboard() + cb.clear(mode=cb.Clipboard) + cb.setText(self.sender()._command, mode=cb.Clipboard) + print("Command copied!") + + def addButton(self, buttonText): + if not isinstance(buttonText, str): + # Passing button directly + button = buttonText + self.buttonsLayout.addWidget(button) + button.clicked.connect(self.buttonCallBack) + self.buttons.append(button) + return button + + button, isCancelButton = getPushButton(buttonText, qparent=self) + if not isCancelButton: + self.buttonsLayout.addWidget(button) + + button.clicked.connect(self.buttonCallBack) + self.buttons.append(button) + return button + + def addDoNotShowAgainCheckbox(self, text="Do not show again"): + self.doNotShowAgainCheckbox = QCheckBox(text) + + def addWidget(self, widget): + self._layout.addWidget(widget, self.currentRow, 1) + self.widgets.append(widget) + self.currentRow += 1 + + def addLayout(self, layout): + self._layout.addLayout(layout, self.currentRow, 1) + self.layouts.append(layout) + self.currentRow += 1 + + def setWidth(self, w): + self._w = w + + def show(self, block=False): + self.endOfScrollableRow = self.currentRow + + self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + # spacer + spacer = QSpacerItem(10, 10) + self._layout.addItem(spacer, self.currentRow, 1) + self._layout.setRowStretch(self.currentRow, 0) + + # buttons + self.currentRow += 1 + + if self.detailsTextWidget is not None: + self.buttonsLayout.insertWidget(1, self.detailsButton) + + # Do not show again checkbox + if self.doNotShowAgainCheckbox is not None: + self._layout.addWidget( + self.doNotShowAgainCheckbox, self.currentRow, 1, 1, 2 + ) + self.currentRow += 1 + + # spacer + self._layout.addItem(QSpacerItem(10, 10), self.currentRow, 1) + self.currentRow += 1 + + # buttons + self._layout.addLayout( + self.buttonsLayout, self.currentRow, 0, 1, 2, alignment=Qt.AlignRight + ) + + # Details + if self.detailsTextWidget is not None: + # spacer + self.currentRow += 1 + self._layout.addItem(QSpacerItem(20, 20), self.currentRow, 1) + + # detailsTextWidget + self.currentRow += 1 + self._layout.addWidget(self.detailsTextWidget, self.currentRow, 0, 1, 2) + + # spacer + self.currentRow += 1 + spacer = QSpacerItem(10, 10) + self._layout.addItem(spacer, self.currentRow, 1) + self._layout.setRowStretch(self.currentRow, 0) + + screenHeight = self.screen().size().height() + dialogHeight = self.sizeHint().height() + dialogWidth = self.sizeHint().width() + screenWidth = self.screen().size().width() + + # Check if scrollbar is needed + if dialogHeight > screenHeight and self.textWidget is not None: + textScrollArea = ScrollArea() + textScrollArea.setWidget(self.textWidget) + scrollAreaWidthNoSB = textScrollArea.minimumWidthNoScrollbar() + scrollAreaWidth = textScrollArea.sizeHint().width() + desiredDeltaWidth = scrollAreaWidthNoSB - scrollAreaWidth + if desiredDeltaWidth > 0: + desiredWidth = dialogWidth + desiredDeltaWidth + if desiredWidth < screenWidth: + self._w = desiredWidth + + self._layout.removeWidget(self.textWidget) + self._layout.addWidget(textScrollArea, self.textRow, 1) + + super().show() + QTimer.singleShot(5, self._resize) + + self.alreadyShown = True + + if block: + self._block() + + def setDetailedText(self, text, visible=False, wrap=True): + text = text.replace("\n", "
    ") + self.detailsTextWidget = QTextEdit(text) + self.detailsTextWidget.setReadOnly(True) + if not wrap: + self.detailsTextWidget.setLineWrapMode(QTextEdit.NoWrap) + self.detailsButton = showDetailsButton() + self.detailsButton.setCheckable(True) + self.detailsButton.clicked.connect(self._showDetails) + self.detailsTextWidget.hide() + self.visibleDetails = visible + + def _showDetails(self, checked): + if checked: + self.origHeight = self.height() + self.resize(self.width(), self.height() + 300) + self.detailsTextWidget.show() + else: + self.detailsTextWidget.hide() + func = partial(self.resize, self.width(), self.origHeight) + QTimer.singleShot(10, func) + + def _resize(self): + if self.resizeButtons: + widths = [button.width() for button in self.buttons] + if widths: + max_width = max(widths) + for button in self.buttons: + if button == self.cancelButton: + continue + button.setMinimumWidth(max_width) + + heights = [button.height() for button in self.buttons] + if heights: + max_h = max(heights) + for button in self.buttons: + button.setMinimumHeight(max_h) + if self.detailsTextWidget is not None: + self.detailsButton.setMinimumHeight(max_h) + if self.showInFileManagButton is not None: + self.showInFileManagButton.setMinimumHeight(max_h) + + if self._w is not None and self.width() < self._w: + self.resize(self._w, self.height()) + + if self.width() < 350: + self.resize(350, self.height()) + + if self.enlargeWidthFactor > 0: + self.resize(int(self.width() * self.enlargeWidthFactor), self.height()) + + if self.visibleDetails: + self.detailsButton.click() + + if self.showCentered: + screen = self.screen() + screenWidth = screen.size().width() + screenHeight = screen.size().height() + screenLeft = screen.geometry().x() + screenTop = screen.geometry().y() + w, h = self.width(), self.height() + left = int(screenLeft + screenWidth / 2 - w / 2) + top = int(screenTop + screenHeight / 2 - h / 2) + if top < screenTop: + top = screenTop + if left < screenLeft: + left = screenLeft + self.move(left, top) + + self._h = self.height() + + if self.okButton is not None: + self.okButton.setFocus() + + screen = self.screen() + screenWidth = screen.size().width() + screenHeight = screen.size().height() + + # Check Force wrap Text + for labelWidget in self.labelsWidgets: + textWidth = labelWidget.width() + if not textWidth > screenWidth - 10: + continue + factor = np.ceil(textWidth / screenWidth) + lineLength = int(labelWidget.nCharsLongestLine / factor) + for label in labelWidget.labels: + if isinstance(label, CopiableCommandWidget): + continue + + text = label.text() + chunks = textwrap.wrap(text, lineLength) + text = "
    ".join(chunks) + label.setText(text) + + QTimer.singleShot(100, self._resizeWrappedText) + + if self.widgets: + return + + if self.layouts: + return + + # # Start resizing height every 1 ms + # self.resizeCallsCount = 0 + # self.timer = QTimer() + # from config import warningHandler + # warningHandler.sigGeometryWarning.connect(self.timer.stop) + # self.timer.timeout.connect(self._resizeHeight) + # self.timer.start(1) + + def _resizeWrappedText(self): + screenWidth = self.screen().size().width() - 5 + self.resize(screenWidth, self.height()) + screenLeft = self.screen().geometry().left() + self.move(screenLeft, self.geometry().top()) + + def _resizeHeight(self): + try: + # Resize until a "Unable to set geometry" warning is captured + # by copnfig.warningHandler._resizeWarningHandler or # + # height doesn't change anymore + self.resize(self.width(), self.height() - 1) + if self.height() == self._h or self.resizeCallsCount > 100: + self.timer.stop() + return + + self.resizeCallsCount += 1 + self._h = self.height() + except Exception as e: + # traceback.format_exc() + self.timer.stop() + + def _template( + self, + parent, + title, + message, + detailsText=None, + buttonsTexts=None, + layouts=None, + widgets=None, + commands=None, + path_to_browse=None, + browse_button_text=None, + url_to_open=None, + open_url_button_text="Open url", + image_paths=None, + wrapDetails=True, + add_do_not_show_again_checkbox=False, + ): + if parent is not None: + self.setParent(parent) + self.setWindowTitle(title) + self.addText(message) + if commands is not None: + if isinstance(commands, str): + commands = (commands,) + for command in commands: + self.addCopiableCommand(command) + + if image_paths is not None: + if isinstance(image_paths, str): + image_paths = (image_paths,) + for image_path in image_paths: + self.addImage(image_path) + + if layouts is not None: + if utils.is_iterable(layouts): + for layout in layouts: + self.addLayout(layout) + else: + self.addLayout(layout) + + if widgets is not None: + self._layout.addItem(QSpacerItem(20, 20), self.currentRow, 1) + self.currentRow += 1 + if utils.is_iterable(widgets): + for widget in widgets: + self.addWidget(widget) + else: + self.addWidget(widgets) + + if path_to_browse is not None: + self.addShowInFileManagerButton(path_to_browse, txt=browse_button_text) + + if url_to_open is not None: + self.addBrowseUrlButton(url_to_open, button_text=open_url_button_text) + + buttons = [] + if buttonsTexts is None: + okButton = self.addButton(" Ok ") + buttons.append(okButton) + elif isinstance(buttonsTexts, str): + button = self.addButton(buttonsTexts) + buttons.append(button) + else: + for buttonText in buttonsTexts: + button = self.addButton(buttonText) + buttons.append(button) + + if detailsText is not None: + self.setDetailedText(detailsText, visible=True, wrap=wrapDetails) + + if add_do_not_show_again_checkbox: + self.addDoNotShowAgainCheckbox() + + return buttons + + def critical(self, *args, showDialog=True, **kwargs): + self.setIcon(iconName="SP_MessageBoxCritical") + buttons = self._template(*args, **kwargs) + if showDialog: + self.exec_() + return buttons + + def information(self, *args, showDialog=True, **kwargs): + self.setIcon(iconName="SP_MessageBoxInformation") + buttons = self._template(*args, **kwargs) + if showDialog: + self.exec_() + return buttons + + def warning(self, *args, showDialog=True, **kwargs): + self.setIcon(iconName="SP_MessageBoxWarning") + buttons = self._template(*args, **kwargs) + if showDialog: + self.exec_() + return buttons + + def question(self, *args, showDialog=True, **kwargs): + self.setIcon(iconName="SP_MessageBoxQuestion") + buttons = self._template(*args, **kwargs) + if showDialog: + self.exec_() + return buttons + + def _block(self): + self.loop = QEventLoop() + self.loop.exec_() + + def exec_(self): + self.show(block=True) + + def clickButtonFromText(self, buttonText): + for button in self.buttons: + if button.text() == buttonText: + button.click() + return + + def buttonCallBack(self, checked=True): + self.clickedButton = self.sender() + if self.clickedButton != self.cancelButton: + self.cancel = False + self.allowClose = True + self.close() + + def closeEvent(self, event): + if not self.allowClose: + event.ignore() + return + super().closeEvent(event) + + +class view_visualcpp_screenshot(QDialog): + def __init__(self, parent=None): + super().__init__(parent) + layout = QHBoxLayout() + + self.setWindowTitle("Visual Studio Builld Tools installation") + + pixmap = QPixmap(":visualcpp.png") + label = QLabel() + label.setPixmap(pixmap) + + layout.addWidget(label) + self.setLayout(layout) + + +class installJavaDialog(myMessageBox): + def __init__(self, parent=None): + super().__init__(parent) + + self.setWindowTitle("Install Java") + self.setIcon("SP_MessageBoxWarning") + + txt_macOS = html_utils.paragraph(""" + Your system doesn't have the Java Development Kit + installed
    and/or a C++ compiler which is required for the installation of + javabridge

    + Cell-ACDC is now going to install Java for you.

    + NOTE: After clicking on "Install", follow the instructions
    + on the terminal
    . You will be asked to confirm steps and insert
    + your password to allow the installation.


    + If you prefer to do it manually, cancel the process
    + and follow the instructions below. + """) + + txt_windows = html_utils.paragraph(""" + Unfortunately, installing pre-compiled version of + javabridge failed.

    + Cell-ACDC is going to try to compile it now.

    + However, before proceeding, you need to install + Java Development Kit
    and a C++ compiler.

    + See instructions below on how to install it. + """) + + if not is_win: + self.instructionsButton = self.addButton("Show intructions...") + self.instructionsButton.setCheckable(True) + self.instructionsButton.disconnect() + self.instructionsButton.clicked.connect(self.showInstructions) + installButton = self.addButton("Install") + installButton.disconnect() + installButton.clicked.connect(self.installJava) + txt = txt_macOS + else: + okButton = self.addButton("Ok") + txt = txt_windows + + self.cancelButton = self.addButton("Cancel") + + label = self.addText(txt) + label.setWordWrap(False) + + self.resizeCount = 0 + + def addInstructionsWindows(self): + self.scrollArea = QScrollArea() + _container = QWidget() + _layout = QVBoxLayout() + for t, text in enumerate(utils.install_javabridge_instructions_text()): + label = QLabel() + label.setText(text) + if t == 1 or t == 2: + label.setOpenExternalLinks(True) + label.setTextInteractionFlags(Qt.TextBrowserInteraction) + code_layout = QHBoxLayout() + code_layout.addWidget(label) + copyButton = QToolButton() + copyButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon) + copyButton.setIcon(QIcon(":edit-copy.svg")) + copyButton.setText("Copy link") + if t == 1: + copyButton.textToCopy = utils.jdk_windows_url() + code_layout.addWidget(copyButton, alignment=Qt.AlignLeft) + else: + copyButton.textToCopy = utils.cpp_windows_url() + screenshotButton = QToolButton() + screenshotButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon) + screenshotButton.setIcon(QIcon(":cog.svg")) + screenshotButton.setText("See screenshot") + code_layout.addWidget(screenshotButton, alignment=Qt.AlignLeft) + code_layout.addWidget(copyButton, alignment=Qt.AlignLeft) + screenshotButton.clicked.connect(self.viewScreenshot) + copyButton.clicked.connect(self.copyToClipboard) + code_layout.setStretch(0, 2) + code_layout.setStretch(1, 0) + _layout.addLayout(code_layout) + else: + _layout.addWidget(label) + + _container.setLayout(_layout) + self.scrollArea.setWidget(_container) + self.currentRow += 1 + self._layout.addWidget( + self.scrollArea, self.currentRow, 1, alignment=Qt.AlignTop + ) + + # Stretch last row + self.currentRow += 1 + self._layout.setRowStretch(self.currentRow, 1) + + def viewScreenshot(self, checked=False): + self.screenShotWin = view_visualcpp_screenshot(parent=self) + self.screenShotWin.show() + + def addInstructionsMacOS(self): + self.scrollArea = QScrollArea() + _container = QWidget() + _layout = QVBoxLayout() + for t, text in enumerate(utils.install_javabridge_instructions_text()): + label = QLabel() + label.setText(text) + # label.setWordWrap(True) + if t == 1 or t == 2: + label.setWordWrap(True) + code_layout = QHBoxLayout() + code_layout.addWidget(label) + copyButton = QToolButton() + copyButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon) + copyButton.setIcon(QIcon(":edit-copy.svg")) + copyButton.setText("Copy") + if t == 1: + copyButton.textToCopy = utils._install_homebrew_command() + else: + copyButton.textToCopy = utils._brew_install_java_command() + copyButton.clicked.connect(self.copyToClipboard) + code_layout.addWidget(copyButton, alignment=Qt.AlignLeft) + # code_layout.addStretch(1) + code_layout.setStretch(0, 2) + code_layout.setStretch(1, 0) + _layout.addLayout(code_layout) + else: + _layout.addWidget(label) + _container.setLayout(_layout) + self.scrollArea.setWidget(_container) + self.currentRow += 1 + self._layout.addWidget( + self.scrollArea, self.currentRow, 1, alignment=Qt.AlignTop + ) + + # Stretch last row + self.currentRow += 1 + self._layout.setRowStretch(self.currentRow, 1) + self.scrollArea.hide() + + def addInstructionsLinux(self): + self.scrollArea = QScrollArea() + _container = QWidget() + _layout = QVBoxLayout() + for t, text in enumerate(utils.install_javabridge_instructions_text()): + label = QLabel() + label.setText(text) + # label.setWordWrap(True) + if t == 1 or t == 2 or t == 3: + label.setWordWrap(True) + code_layout = QHBoxLayout() + code_layout.addWidget(label) + copyButton = QToolButton() + copyButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon) + copyButton.setIcon(QIcon(":edit-copy.svg")) + copyButton.setText("Copy") + if t == 1: + copyButton.textToCopy = utils._apt_update_command() + elif t == 2: + copyButton.textToCopy = utils._apt_install_java_command() + elif t == 3: + copyButton.textToCopy = utils._apt_gcc_command() + copyButton.clicked.connect(self.copyToClipboard) + code_layout.addWidget(copyButton, alignment=Qt.AlignLeft) + # code_layout.addStretch(1) + code_layout.setStretch(0, 2) + code_layout.setStretch(1, 0) + _layout.addLayout(code_layout) + else: + _layout.addWidget(label) + _container.setLayout(_layout) + self.scrollArea.setWidget(_container) + self.currentRow += 1 + self._layout.addWidget( + self.scrollArea, self.currentRow, 1, alignment=Qt.AlignTop + ) + + # Stretch last row + self.currentRow += 1 + self._layout.setRowStretch(self.currentRow, 1) + self.scrollArea.hide() + + def copyToClipboard(self): + cb = QApplication.clipboard() + cb.clear(mode=cb.Clipboard) + cb.setText(self.sender().textToCopy, mode=cb.Clipboard) + print("Command copied!") + + def showInstructions(self, checked): + if checked: + self.instructionsButton.setText("Hide instructions") + self.origHeight = self.height() + self.resize(self.width(), self.height() + 300) + self.scrollArea.show() + else: + self.instructionsButton.setText("Show instructions...") + self.scrollArea.hide() + func = partial(self.resize, self.width(), self.origHeight) + QTimer.singleShot(50, func) + + def installJava(self): + import subprocess + + try: + if is_mac: + try: + subprocess.check_call(["brew", "update"]) + except Exception as e: + subprocess.run( + utils._install_homebrew_command(), + check=True, + text=True, + shell=True, + ) + subprocess.run( + utils._brew_install_java_command(), + check=True, + text=True, + shell=True, + ) + elif is_linux: + subprocess.run( + utils._apt_gcc_command()(), check=True, text=True, shell=True + ) + subprocess.run( + utils._apt_update_command()(), check=True, text=True, shell=True + ) + subprocess.run( + utils._apt_install_java_command()(), + check=True, + text=True, + shell=True, + ) + self.close() + except Exception as e: + print("=======================") + traceback.print_exc() + print("=======================") + msg = myMessageBox(wrapText=False) + err_msg = html_utils.paragraph(""" + Automatic installation of Java failed.

    + Please, try manually by following the instructions provided + below (click on "Show instructions..." button). Thanks + """) + msg.critical(self, "Java installation failed", err_msg) + + def show(self, block=False): + super().show(block=False) + print(is_linux) + if is_win: + self.addInstructionsWindows() + elif is_mac: + self.addInstructionsMacOS() + elif is_linux: + self.addInstructionsLinux() + self.move(self.pos().x(), 20) + if is_win: + self.resize(self.width(), self.height() + 200) + if block: + self._block() + + def exec_(self): + self.show(block=True) + + +class selectTrackerGUI(QDialogListbox): + def __init__(self, SizeT, currentFrameNo=1, parent=None): + trackers = utils.get_list_of_trackers() + super().__init__( + "Select tracker", + "Select one of the following trackers", + trackers, + multiSelection=False, + parent=parent, + ) + self.setWindowTitle("Select tracker") + + self.selectFramesGroupbox = selectStartStopFrames( + SizeT, currentFrameNum=currentFrameNo, parent=parent + ) + + self.mainLayout.insertWidget(1, self.selectFramesGroupbox) + + def ok_cb(self, event): + if self.selectFramesGroupbox.warningLabel.text(): + return + else: + self.startFrame = self.selectFramesGroupbox.startFrame_SB.value() + self.stopFrame = self.selectFramesGroupbox.stopFrame_SB.value() + super().ok_cb(event) + + +class warnVisualCppRequired(myMessageBox): + def __init__(self, pkg_name="javabridge", parent=None): + super().__init__(parent) + self.screenShotWin = None + + self.setIcon(iconName="SP_MessageBoxWarning") + self.setWindowTitle(f"Installation of {pkg_name} info") + txt = html_utils.paragraph(f""" + Installation of {pkg_name} on Windows requires + Microsoft Visual C++ 14.0 or higher.

    + Cell-ACDC will anyway try to install {pkg_name} now.

    + If the installation fails, please close Cell-ACDC, + then download and install "Microsoft C++ Build Tools" + from the link below + before trying this module again.

    + + https://visualstudio.microsoft.com/visual-cpp-build-tools/ +

    + IMPORTANT: when installing "Microsoft C++ Build Tools" + make sure to select "Desktop development with C++". + Click "See the screenshot" for more details. + """) + seeScreenshotButton = QPushButton("See screenshot...") + okButton = okPushButton("Ok") + okButton = self.addButton("Ok") + okButton.disconnect() + okButton.clicked.connect(self.ok_cb) + self.addButton(seeScreenshotButton) + seeScreenshotButton.disconnect() + seeScreenshotButton.clicked.connect(self.viewScreenshot) + self.addCancelButton(connect=True) + self.addText(txt) + + def ok_cb(self): + self.cancel = False + self.close() + + def viewScreenshot(self, checked=False): + self.screenShotWin = view_visualcpp_screenshot(self) + self.screenShotWin.show() + + def closeEvent(self, event): + if self.screenShotWin is not None: + self.screenShotWin.close() + + return super().closeEvent(event) + +# Cross-module imports (deferred to avoid import cycles) +from .forms import ( + CopiableCommandWidget, + LabelsWidget, + selectStartStopFrames, +) +from .panels import ( + listWidget, +) + diff --git a/cellacdc/widgets/controls/forms.py b/cellacdc/widgets/controls/forms.py new file mode 100644 index 000000000..1a268672e --- /dev/null +++ b/cellacdc/widgets/controls/forms.py @@ -0,0 +1,1382 @@ +"""Composite controls: forms.""" + +"""GUI widgets: controls.""" + +from collections import defaultdict, deque +from typing import Dict, List, Union, Iterable, Sequence +import os +import sys +import operator +import time +import re +import datetime +import numpy as np +import pandas as pd +import math +import traceback +import logging +import textwrap +import random + +from functools import partial +from math import ceil + +import skimage.draw +import skimage.morphology + +from matplotlib.colors import ListedColormap, LinearSegmentedColormap +import matplotlib.pyplot as plt +import matplotlib +from matplotlib.backends.backend_agg import FigureCanvasAgg + +from qtpy.QtCore import ( + Signal, + QTimer, + Qt, + QPoint, + QUrl, + Property, + QPropertyAnimation, + QEasingCurve, + QLocale, + QSize, + QRect, + QPointF, + QRect, + QPoint, + QEasingCurve, + QRegularExpression, + QEvent, + QEventLoop, + QPropertyAnimation, + QObject, + QItemSelectionModel, + QAbstractListModel, + QModelIndex, + QByteArray, + QDataStream, + QMimeData, + QAbstractItemModel, + QIODevice, + QItemSelection, + PYQT6, + QRectF, +) +from qtpy.QtGui import ( + QFont, + QPalette, + QColor, + QPen, + QKeyEvent, + QBrush, + QPainter, + QRegularExpressionValidator, + QIcon, + QPixmap, + QKeySequence, + QLinearGradient, + QShowEvent, + QDesktopServices, + QFontMetrics, + QGuiApplication, + QLinearGradient, + QImage, + QCursor, + QPicture, +) +from qtpy.QtWidgets import ( + QTextEdit, + QLabel, + QProgressBar, + QHBoxLayout, + QToolButton, + QCheckBox, + QApplication, + QWidget, + QVBoxLayout, + QMainWindow, + QTreeWidgetItemIterator, + QLineEdit, + QSlider, + QSpinBox, + QGridLayout, + QRadioButton, + QScrollArea, + QSizePolicy, + QComboBox, + QPushButton, + QScrollBar, + QGroupBox, + QAbstractSlider, + QDoubleSpinBox, + QWidgetAction, + QAction, + QTabWidget, + QAbstractSpinBox, + QToolBar, + QStyleOptionSpinBox, + QStyle, + QDialog, + QSpacerItem, + QFrame, + QMenu, + QActionGroup, + QListWidget, + QPlainTextEdit, + QFileDialog, + QListView, + QAbstractItemView, + QTreeWidget, + QTreeWidgetItem, + QListWidgetItem, + QLayout, + QStylePainter, + QGraphicsBlurEffect, + QGraphicsProxyWidget, + QGraphicsObject, + QButtonGroup, + QStyleOptionSlider, +) +import qtpy.compat + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from ... import utils, measurements, is_mac, is_win, html_utils, is_linux +from ... import printl, settings_folderpath +from ... import colors, config +from ... import html_path +from ... import _palettes +from ... import load +from ... import apps +from ... import plot +from ... import annotate +from ... import urls +from ... import _core, core +from ... import QtScoped +from ... import prompts +from ...acdc_regex import float_regex +from ...config import PREPROCESS_MAPPER +from ... import _base_widgets + +from ...components.palette import ( # noqa: E402 + BASE_COLOR, + Gradients, + GradientsImage, + GradientsLabels, + LINEEDIT_INVALID_ENTRY_STYLESHEET, + LINEEDIT_WARNING_STYLESHEET, + LISTWIDGET_STYLESHEET, + PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR, + PROGRESSBAR_QCOLOR, + TEXT_COLOR, + TREEWIDGET_STYLESHEET, + cmaps, + font, + getCustomGradients, + nonInvertibleCmaps, + sign_int_mapper, + str_to_operator_mapper, +) +from ...components.progress import QtHandler, QLog, XStream # noqa: E402 +from ...components.buttons import * # noqa: E402, F403 +from ...components.layout import * # noqa: E402, F403 +from ...components.inputs_basic import * # noqa: E402, F403 +from ...components.path_controls import * # noqa: E402, F403 + +from ...components.lists import * # noqa: E402, F403 +from ...components.base import QBaseWindow # noqa: E402 +from ...components.progress import ( # noqa: E402 + LoadingCircleAnimation, + NoneWidget, + ProgressBar, + ProgressBarWithETA, + QLogConsole, +) + +from ..canvas.scrollbars import ( + sliderWithSpinBox, +) +from .inputs import ( + ComboBox, +) + +class selectStartStopFrames(QGroupBox): + def __init__(self, SizeT, currentFrameNum=0, parent=None): + super().__init__(parent) + selectFramesLayout = QGridLayout() + + self.startFrame_SB = QSpinBox() + self.startFrame_SB.setAlignment(Qt.AlignCenter) + self.startFrame_SB.setMinimum(1) + self.startFrame_SB.setMaximum(SizeT - 1) + self.startFrame_SB.setValue(currentFrameNum) + + self.stopFrame_SB = QSpinBox() + self.stopFrame_SB.setAlignment(Qt.AlignCenter) + self.stopFrame_SB.setMinimum(1) + self.stopFrame_SB.setMaximum(SizeT) + self.stopFrame_SB.setValue(SizeT) + + selectFramesLayout.addWidget(QLabel("Start frame n."), 0, 0) + selectFramesLayout.addWidget(self.startFrame_SB, 1, 0) + + selectFramesLayout.addWidget(QLabel("Stop frame n."), 0, 1) + selectFramesLayout.addWidget(self.stopFrame_SB, 1, 1) + + self.warningLabel = QLabel() + palette = self.warningLabel.palette() + palette.setColor(self.warningLabel.backgroundRole(), Qt.red) + palette.setColor(self.warningLabel.foregroundRole(), Qt.red) + self.warningLabel.setPalette(palette) + selectFramesLayout.addWidget( + self.warningLabel, 2, 0, 1, 2, alignment=Qt.AlignCenter + ) + + self.setLayout(selectFramesLayout) + + self.stopFrame_SB.valueChanged.connect(self._checkRange) + + def _checkRange(self): + start = self.startFrame_SB.value() + stop = self.stopFrame_SB.value() + if stop <= start: + self.warningLabel.setText("stop frame smaller than start frame") + else: + self.warningLabel.setText("") + + +class formWidget(QWidget): + sigApplyButtonClicked = Signal(object) + sigComputeButtonClicked = Signal(object) + + def __init__( + self, + widget, + initialVal=None, + stretchWidget=True, + widgetAlignment=None, + labelTextLeft="", + labelTextRight="", + font=None, + addInfoButton=False, + addApplyButton=False, + addComputeButton=False, + addActivateCheckbox=False, + key="", + infoTxt="", + valueGetterName="value", + toolTip="", + parent=None, + ): + QWidget.__init__(self, parent) + self.widget = widget + self.key = key + self.infoTxt = infoTxt + self.widgetAlignment = widgetAlignment + self.valueGetterName = valueGetterName + + widget.setParent(self) + + if isinstance(initialVal, bool): + widget.setChecked(initialVal) + elif isinstance(initialVal, str): + widget.setCurrentText(initialVal) + elif isinstance(initialVal, float) or isinstance(initialVal, int): + widget.setValue(initialVal) + + self.items = [] + + if font is None: + font = QFont() + font.setPixelSize(13) + + self.labelLeft = QClickableLabel(widget) + self.labelLeft.setText(labelTextLeft) + self.labelLeft.setFont(font) + self.items.append(self.labelLeft) + + if not stretchWidget: + widgetLayout = QHBoxLayout() + if widgetAlignment != "left": + widgetLayout.addStretch(1) + widgetLayout.addWidget(widget) + if widgetAlignment != "right": + widgetLayout.addStretch(1) + self.items.append(widgetLayout) + else: + self.items.append(widget) + + self.labelRight = QClickableLabel(widget) + self.labelRight.setText(labelTextRight) + self.labelRight.setFont(font) + self.items.append(self.labelRight) + + if toolTip: + self.labelLeft.setToolTip(toolTip) + self.widget.setToolTip(toolTip) + self.labelRight.setToolTip(toolTip) + + if addInfoButton: + infoButton = QPushButton(self) + infoButton.setCursor(Qt.WhatsThisCursor) + infoButton.setIcon(QIcon(":info.svg")) + if labelTextLeft: + infoButton.setToolTip(f'Info about "{self.labelLeft.text()}" parameter') + else: + infoButton.setToolTip( + f'Info about "{self.labelRight.text()}" measurement' + ) + infoButton.clicked.connect(self.showInfo) + self.infoButton = infoButton + self.items.append(infoButton) + + if addApplyButton: + applyButton = QPushButton(self) + applyButton.setCursor(Qt.PointingHandCursor) + applyButton.setCheckable(True) + applyButton.setIcon(QIcon(":apply.svg")) + applyButton.setToolTip(f"Apply this step and visualize results") + applyButton.clicked.connect(self.applyButtonClicked) + self.items.append(applyButton) + + if addComputeButton: + computeButton = QPushButton(self) + computeButton.setCursor(Qt.BusyCursor) + computeButton.setIcon(QIcon(":compute.svg")) + computeButton.setToolTip(f"Compute this step and visualize results") + computeButton.clicked.connect(self.computeButtonClicked) + self.items.append(computeButton) + + self.activateCheckbox = None + if addActivateCheckbox: + self.activateCheckbox = QCheckBox("Activate") + self.activateCheckbox.setChecked(False) + self.widget.setDisabled(True) + self.activateCheckbox.toggled.connect(self.setWidgetEnabled) + self.items.append(self.activateCheckbox) + + self.labelLeft.clicked.connect(self.tryChecking) + self.labelRight.clicked.connect(self.tryChecking) + + def setWidgetEnabled(self, checked): + self.widget.setDisabled(not checked) + + def value(self): + if self.activateCheckbox is None: + return getattr(self.widget, self.valueGetterName)() + + if not self.activateCheckbox.isChecked(): + return + + return getattr(self.widget, self.valueGetterName)() + + def tryChecking(self, label): + try: + self.widget.setChecked(not self.widget.isChecked()) + except AttributeError as e: + pass + + def applyButtonClicked(self): + self.sigApplyButtonClicked.emit(self) + + def computeButtonClicked(self): + self.sigComputeButtonClicked.emit(self) + + def showInfo(self): + msg = myMessageBox() + msg.setIcon() + msg.setWindowTitle(f"{self.labelLeft.text()} info") + msg.addText(self.infoTxt) + msg.addButton(" Ok ") + msg.exec_() + + def setDisabled(self, disabled: bool) -> None: + for item in self.items: + try: + item.setDisabled(disabled) + except Exception as err: + pass + + +class CheckboxesGroupBox(QGroupBox): + def __init__(self, texts, title="", checkable=False, parent=None): + super().__init__(parent) + + self.setTitle(title) + self.setCheckable(checkable) + layout = QVBoxLayout() + + scrollLayout = QVBoxLayout() + container = QWidget() + scrollarea = QScrollArea() + + self.checkBoxes = [] + for text in texts: + checkbox = QCheckBox(text) + checkbox.setChecked(True) + scrollLayout.addWidget(checkbox) + self.checkBoxes.append(checkbox) + + container.setLayout(scrollLayout) + scrollarea.setWidget(container) + layout.addWidget(scrollarea) + + buttonsLayout = QHBoxLayout() + selectAllButton = selectAllPushButton() + selectAllButton.sigClicked.connect(self.checkAll) + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(selectAllButton) + layout.addLayout(buttonsLayout) + + self.setLayout(layout) + + def checkAll(self, button, checked): + for checkBox in self.checkBoxes: + checkBox.setChecked(checked) + + +class guiTabControl(QTabWidget): + def __init__(self, *args): + super().__init__(args[0]) + + self._defaultPixelSize = None + + self.propsTab = QScrollArea(self) + + container = QWidget() + layout = QVBoxLayout() + + self.pixelSizeQGBox = PixelSizeGroupbox(parent=self.propsTab) + self.propsQGBox = objPropsQGBox(parent=self.propsTab) + self.intensMeasurQGBox = objIntesityMeasurQGBox(parent=self.propsTab) + + self.highlightCheckbox = QCheckBox("Highlight objects on mouse hover") + self.highlightCheckbox.setChecked(False) + + self.highlightSearchedCheckbox = QCheckBox("Highlight searched object") + self.highlightSearchedCheckbox.setChecked(True) + + highlightLayout = QHBoxLayout() + highlightLayout.addWidget(self.highlightCheckbox) + highlightLayout.addStretch(1) + highlightLayout.addWidget(QLabel("|")) + highlightLayout.addStretch(1) + highlightLayout.addWidget(self.highlightSearchedCheckbox) + + layout.addLayout(highlightLayout) + layout.addWidget(self.pixelSizeQGBox) + layout.addWidget(self.propsQGBox) + layout.addWidget(self.intensMeasurQGBox) + layout.addStretch(1) + container.setLayout(layout) + + self.propsTab.setWidgetResizable(True) + self.propsTab.setWidget(container) + self.addTab(self.propsTab, "Measurements") + + self.pixelSizeQGBox.sigValueChanged.connect(self.pixelSizeChanged) + self.pixelSizeQGBox.sigReset.connect(self.resetPixelSize) + + def addChannels(self, channels): + self.intensMeasurQGBox.addChannels(channels) + + def resetPixelSize(self): + if self._defaultPixelSize is None: + return + + self.initPixelSize(*self._defaultPixelSize) + + def initPixelSize(self, PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ): + self.pixelSizeQGBox.pixelWidthWidget.setValue(PhysicalSizeX) + self.pixelSizeQGBox.pixelHeightWidget.setValue(PhysicalSizeY) + self.pixelSizeQGBox.voxelDepthWidget.setValue(PhysicalSizeZ) + self._defaultPixelSize = (PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ) + + def pixelSizeChanged(self, PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ): + propsQGBox = self.propsQGBox + yx_pxl_to_um2 = PhysicalSizeY * PhysicalSizeX + vox_rot_to_fl = float(PhysicalSizeY) * pow(float(PhysicalSizeX), 2) + vox_3D_to_fl = PhysicalSizeZ * PhysicalSizeY * PhysicalSizeX + + area_pxl = propsQGBox.cellAreaPxlSB.value() + area_um2 = area_pxl * yx_pxl_to_um2 + propsQGBox.cellAreaUm2DSB.setValue(area_um2) + + vol_rot_vox = propsQGBox.cellVolVoxSB.value() + vol_rot_fl = vol_rot_vox * vox_rot_to_fl + propsQGBox.cellVolFlDSB.setValue(vol_rot_fl) + + vol_3D_vox = propsQGBox.cellVolVox3D_SB.value() + vol_3D_fl = vol_3D_vox * vox_3D_to_fl + propsQGBox.cellVolFl3D_DSB.setValue(vol_3D_fl) + + +class PostProcessSegmSlider(sliderWithSpinBox): + def __init__(self, *args, label=None, **kwargs): + super().__init__(*args, **kwargs) + + self.label = label + self.checkbox = QCheckBox("Disable") + self._layout.addWidget(self.checkbox, self.sliderCol, self.lastCol + 1) + self.checkbox.toggled.connect(self.onCheckBoxToggled) + self.valueChanged.connect(self.checkExpandRange) + + def onCheckBoxToggled(self, checked: bool) -> None: + super().setDisabled(checked) + if self.label is not None: + self.label.setDisabled(checked) + self.onValueChanged(None) + self.onEditingFinished() + + def onValueChanged(self, value): + self.valueChanged.emit(value) + + def checkExpandRange(self, value): + if value == self.maximum(): + range = int(self.maximum() - self.minimum()) + half_range = int(range / 2) + newMinimum = self.minimum() + half_range + newMaximum = self.maximum() + half_range + self.setMaximum(newMaximum) + self.setMinimum(newMinimum) + elif value == self.minimum(): + range = int(self.maximum() - self.minimum()) + half_range = int(range / 2) + newMinimum = self.minimum() - half_range + newMaximum = self.maximum() - half_range + self.setMaximum(newMaximum) + self.setMinimum(newMinimum) + + def onEditingFinished(self): + self.editingFinished.emit() + + def value(self): + if self.checkbox.isChecked(): + return None + else: + return super().value() + + +class PostProcessSegmSpinbox(QWidget): + valueChanged = Signal(int) + editingFinished = Signal() + sigCheckboxToggled = Signal() + + def __init__(self, *args, isFloat=False, label=None, **kwargs): + super().__init__(*args, **kwargs) + + layout = QHBoxLayout() + + if isFloat: + self.spinBox = DoubleSpinBox() + else: + self.spinBox = SpinBox() + + self.spinBox.valueChanged.connect(self.onValueChanged) + self.spinBox.editingFinished.connect(self.onEditingFinished) + + layout.addWidget(self.spinBox) + self.checkbox = QCheckBox("Disable") + layout.addWidget(self.checkbox) + layout.setStretch(0, 1) + layout.setStretch(1, 0) + + self.label = label + + self.checkbox.toggled.connect(self.onCheckBoxToggled) + + layout.setContentsMargins(5, 0, 5, 0) + + self.setLayout(layout) + + def onCheckBoxToggled(self, checked: bool) -> None: + self.spinBox.setDisabled(checked) + if self.label is not None: + self.label.setDisabled(checked) + self.onValueChanged(None) + self.onEditingFinished() + + def onValueChanged(self, value): + self.valueChanged.emit(value) + + def onEditingFinished(self): + self.editingFinished.emit() + + def maximum(self): + return self.spinBox.maximum() + + def setValue(self, value): + self.spinBox.setValue(value) + + def sizeHint(self): + return self.spinBox.sizeHint() + + def setMaximum(self, max): + self.spinBox.setMaximum(max) + + def setSingleStep(self, step): + self.spinBox.setSingleStep(step) + + def setMinimum(self, min): + self.spinBox.setMinimum(min) + + def setSingleStep(self, step): + self.spinBox.setSingleStep(step) + + def setDecimals(self, decimals): + self.spinBox.setDecimals(decimals) + + def value(self): + if self.checkbox.isChecked(): + return None + else: + return self.spinBox.value() + + +class CopiableCommandWidget(QGroupBox): + def __init__(self, command="", parent=None, font_size="13px"): + super().__init__(parent) + + layout = QHBoxLayout() + + label = QLabel(self) + self.label = label + self._font_size = font_size + self.setCommand(command, font_size=font_size) + label.setTextInteractionFlags( + Qt.TextBrowserInteraction | Qt.TextSelectableByKeyboard + ) + layout.addWidget(label) + layout.addWidget(QVLine(shadow="Plain", color="#4d4d4d")) + copyButton = copyPushButton("Copy", flat=True, hoverable=True) + copyButton.clicked.connect(self.copyToClipboard) + layout.addWidget(copyButton) + layout.addStretch(1) + + self.setLayout(layout) + + def setWordWrap(self, wordWrap): + self.label.setWordWrap(wordWrap) + + def copyToClipboard(self): + cb = QApplication.clipboard() + cb.clear(mode=cb.Clipboard) + cb.setText(self._command, mode=cb.Clipboard) + print("Command copied!") + + def setCommand(self, command, font_size=None): + if font_size is None: + font_size = self._font_size + + self._command = command + txt = html_utils.paragraph(f"{command}", font_size=font_size) + self.label.setText(txt) + + def command(self): + return self._command + + def text(self): + return self.label.text() + + def setTextInteractionFlags(self, flags): + self.label.setTextInteractionFlags(flags) + + +class LabelsWidget(QWidget): + def __init__(self, texts, wrapText=False, parent=None): + super().__init__(parent=parent) + + layout = QVBoxLayout() + + texts = self.fixParagraphTags(texts) + + self.textLengths = [] + self.labels = [] + for t, text in enumerate(texts): + if not text: + continue + + if text.startswith(""): + layout.addSpacing(10) + label = LatexLabel(text) + layout.addWidget(label, alignment=Qt.AlignCenter) + try: + # Add spacing only if next text is not a formula + nextText = texts[t + 1] + if not nextText.startswith(""): + layout.addSpacing(10) + except IndexError: + layout.addSpacing(10) + elif text.startswith(""): + text = text.removeprefix("").removeprefix("") + label = CopiableCommandWidget(command=text, parent=self) + layout.addWidget(label) + else: + label = QLabel(text) + label.setWordWrap(wrapText) + label.setOpenExternalLinks(True) + layout.addWidget(label) + if wrapText: + self.textLengths.append(1) + self.textLengths.extend([len(line) for line in text.split("
    ")]) + + self.labels.append(label) + + self.nCharsLongestLine = max(self.textLengths, default=1) + + layout.setContentsMargins(0, 0, 0, 0) + self.setLayout(layout) + + def setWordWrap(self, wordWrap): + for label in self.labels: + label.setWordWrap(wordWrap) + + def fixParagraphTags(self, texts): + firstText = texts[0] + if firstText.find("

    ', firstText) + if searched is None: + openTag = '

    ' + else: + openTag = searched.group() + + not_allowed = {" ", "\n"} + + fixedTexts = [] + for text in texts: + if text.startswith("") or text.startswith(""): + fixedTexts.append(text) + continue + + if set(text) <= not_allowed: + # Ignore texts that are made of only \n and spaces + continue + + if text.find("

    ") == -1: + text = rf"{text}<\p>" + + if text.find(openTag) == -1: + text = f"{openTag}{text}" + + text = text.replace("\n", "") + + fixedTexts.append(text) + return fixedTexts + + +class SamInputPointsWidget(QWidget): + sigValueChanged = Signal(str) + + def __init__(self, parent=None): + super().__init__(parent) + + _layout = QHBoxLayout() + + self.lineEntry = ElidingLineEdit(parent=self) + self.lineEntry.setAlignment(Qt.AlignCenter) + self.lineEntry.editingFinished.connect(self.emitValueChanged) + + self.editButton = editPushButton() + self.browseButton = browseFileButton( + ext={"CSV": ".csv"}, start_dir=utils.getMostRecentPath() + ) + + _layout.addWidget(self.lineEntry) + _layout.addWidget(self.editButton) + _layout.addWidget(self.browseButton) + + _layout.setStretch(0, 1) + _layout.setStretch(1, 0) + _layout.setStretch(1, 0) + + self.browseButton.sigPathSelected.connect(self.browseCsvFiles) + self.editButton.clicked.connect(self.showInfoEditPoints) + + _layout.setContentsMargins(0, 0, 0, 0) + self.setLayout(_layout) + + def emitValueChanged(self, text): + self.sigValueChanged.emit(text) + + def showInfoEditPoints(self): + note = html_utils.to_note( + "When adding points with the mouse left button you will create a " + "new object for each point. To add multiple points for the same " + "object click the right button." + ) + txt = html_utils.paragraph(f""" + To add input points for Segment Anything open the GUI (module 3), + load the data, and then click on the button
    + on the top toolbar called Add points layer.

    + Select the option "Add points by clicking" and click on the image + to add points.

    + Finally, save the table and browse to the saved file on this widget. +
    {note} + """) + msg = myMessageBox(wrapText=False) + msg.information(self, "Info edit points", txt) + + def criticalMissingColumn(self, filepath, missing_col): + txt = html_utils.paragraph(f""" + [ERROR]: The selected table does not contain the column + {missing_col}.

    + A valid table must contain the columns (x, y, id) + with an additional z column for 3D z-stacks data. + """) + msg = myMessageBox(wrapText=False) + msg.critical(self, "Invalid table", txt) + + def setValue(self, value: str): + self.lineEntry.setText(value) + + def value(self): + return self.lineEntry.text() + + def cast_dtype(self, value) -> str: + return str(value) + + def browseCsvFiles(self, filepath): + # Check if metadata.csv file exists with basename and set only the + # endname of the file + df_points = pd.read_csv(filepath) + for col in ("x", "y", "id"): + if col not in df_points.columns: + self.criticalMissingColumn(filepath, col) + return + + # Check if basename is present in metadata + folderpath = os.path.dirname(filepath) + basename = None + for file in utils.listdir(folderpath): + if file.endswith("metadata.csv"): + metadata_csv_path = os.path.join(folderpath, file) + df = pd.read_csv(metadata_csv_path, index_col="Description") + try: + basename = df.at["basename", "values"] + except Exception as e: + basename = None + break + + # Check if file is inside images folder and get basename + is_images_folder = folderpath.endswith("Images") + if is_images_folder: + images_path = folderpath + img_filepath = None + for file in utils.listdir(images_path): + if file.endswith(".tif"): + img_filepath = os.path.join(images_path, file) + break + + if file.endswith("aligned.npz"): + img_filepath = os.path.join(images_path, file) + break + + if img_filepath is not None: + posData = load.loadData(img_filepath, "", QParent=self) + posData.getBasenameAndChNames() + filename = os.path.basename(filepath) + if filename.startswith(posData.basename): + basename = posData.basename + + if basename is None: + self.lineEntry.setText(filepath) + else: + filename = os.path.basename(filepath) + endname = filename[len(basename) :] + self.lineEntry.setText(endname) + + +class FontSizeWidget(QWidget): + sigTextChanged = Signal(str) + + def __init__(self, parent=None, unit="px", initalVal=12): + super().__init__(parent) + + layout = QHBoxLayout() + + self.spinbox = SpinBox() + self.spinbox.setValue(initalVal) + layout.addWidget(self.spinbox) + + self.unitLabel = QLabel(unit) + layout.addWidget(self.unitLabel) + + layout.setContentsMargins(0, 0, 0, 0) + layout.setStretch(0, 1) + layout.setStretch(1, 0) + + self.setLayout(layout) + + self.spinbox.valueChanged.connect(self.emitTextChanged) + + def emitTextChanged(self, value): + self.sigTextChanged.emit(self.text()) + + def setValue(self, value): + if isinstance(value, str): + value = int(value.replace(self.unitLabel.text(), "").strip()) + self.spinbox.setValue(value) + + def setText(self, text): + value = int(text.replace(self.unitLabel.text(), "").strip()) + self.setValue(value) + + def text(self): + return f"{self.spinbox.value()}{self.unitLabel.text()}" + + def value(self): + return self.spinbox.value() + + +class RangeSelector(QWidget): + sigRangeChanged = Signal(object, object) + sigLowValueChanged = Signal(object) + sigHighValueChanged = Signal(object) + sigRangeManuallyChanged = Signal(object, object) + + def __init__(self, parent=None, integers=False, ordered=True): + super().__init__(parent) + + self._integers = integers + self._ordered = ordered + + layout = QHBoxLayout() + + if integers: + self.lowSpinbox = SpinBox() + self.highSpinbox = SpinBox() + else: + self.lowSpinbox = DoubleSpinBox() + self.highSpinbox = DoubleSpinBox() + + layout.addWidget(self.lowSpinbox) + layout.addWidget(self.highSpinbox) + + layout.setContentsMargins(0, 0, 0, 0) + self.setLayout(layout) + + self.lowSpinbox.valueChanged.connect(self.lowValueChanged) + self.highSpinbox.valueChanged.connect(self.highValueChanged) + + self.lowSpinbox.editingFinished.connect(self.lowValueEditingFinished) + self.highSpinbox.editingFinished.connect(self.highValueEditingFinished) + + def lowValueEditingFinished(self): + self.sigRangeManuallyChanged.emit(*self.range()) + self.emitRangeChanged() + + def highValueEditingFinished(self): + self.sigRangeManuallyChanged.emit(*self.range()) + self.emitRangeChanged() + + def lowValueChanged(self, value): + self.emitRangeChanged() + self.sigLowValueChanged.emit(value) + + def highValueChanged(self, value): + self.emitRangeChanged() + self.sigHighValueChanged.emit(value) + + def emitRangeChanged(self): + self.sigRangeChanged.emit(*self.range()) + + def setRangeNoEmit(self, lowValue, highValue, decimals=3): + self.lowSpinbox.valueChanged.disconnect() + self.highSpinbox.valueChanged.disconnect() + + self.setRange(round(lowValue, 3), round(highValue, 3)) + + self.lowSpinbox.valueChanged.connect(self.lowValueChanged) + self.highSpinbox.valueChanged.connect(self.highValueChanged) + + def setRange(self, lowValue, highValue): + # if lowValue > highValue and self._ordered: + # highValue = lowValue + 1 + + if self._integers: + lowValue = round(lowValue) + highValue = round(highValue) + + self.lowSpinbox.setValue(lowValue) + self.highSpinbox.setValue(highValue) + + def range(self): + return self.lowSpinbox.value(), self.highSpinbox.value() + + +class PreProcessingSelector(QComboBox): + sigValuesChanged = Signal(dict, int) + + def __init__(self, parent=None): + super().__init__(parent) + self._parent = parent + + self.addItems(PREPROCESS_MAPPER.keys()) + self.methodToDefaultValuesMapper = {} + self.step_n = -1 + self.setParamsWindow = None + + def htmlInfo(self): + href = html_utils.href_tag("GitHub page", urls.issues_url) + docstring = PREPROCESS_MAPPER[self.currentText()]["docstring"] + if docstring is None: + text = "This function is not documented, yet. Sorry :(" + else: + text = html_utils.rst_docstring_to_html(docstring) + text = ( + f"{text}

    " + f"Feel free to submit an issue on our {href} if you " + "need help with this filter." + ) + return text + + def setParams(self, method: str, kwargToValueMapper: Dict[str, str]): + self.methodToDefaultValuesMapper[method] = kwargToValueMapper + + def askSetParams(self, df_metadata=None, addApplyButton=False): + method = self.currentText() + function = PREPROCESS_MAPPER[method]["function"] + params_argspecs = utils.get_function_argspec( + function, + args_to_skip={"logger_func", "apply_to_all_zslices", "apply_to_all_frames"}, + ) + default_values = self.methodToDefaultValuesMapper.get(method, {}) + for kwarg, value in default_values.items(): + for p, param_argspec in enumerate(params_argspecs): + if param_argspec.name != kwarg: + continue + + if hasattr(param_argspec.type, "cast_dtype"): + cls = param_argspec.type + value = cls.cast_dtype(value) + else: + value = param_argspec.type(value) + + if value == param_argspec.default: + continue + param_argspec = param_argspec._replace(default=value) + params_argspecs[p] = param_argspec + + if self.setParamsWindow is not None: + self.setParamsWindow.raise_() + self.setParamsWindow.activateWindow() + return + + self.setParamsWindow = apps.FunctionParamsDialog( + params_argspecs, + df_metadata=df_metadata, + function_name=method, + addApplyButton=addApplyButton, + parent=self._parent, + ) + self.setParamsWindow.sigValuesChanged.connect(self.emitValuesChanged) + self.setParamsWindow.emitValuesChanged() + self.setParamsWindow.exec_() + if self.setParamsWindow.cancel: + return + + self.setParams(method, self.setParamsWindow.function_kwargs) + + function_kwargs = self.setParamsWindow.function_kwargs + self.setParamsWindow = None + + return function_kwargs + + def emitValuesChanged(self, functionKwargs: dict): + self.sigValuesChanged.emit(functionKwargs, self.step_n) + + +class RescaleImageJroisGroupbox(QGroupBox): + def __init__(self, TZYX_out_shape, parent=None): + super().__init__(parent) + + self.setTitle("Rescale ROIs") + self.setCheckable(True) + + gridLayout = QGridLayout() + + dims = ("Z", "Y", "X") + self.widgets = {} + for row, SizeD in enumerate(TZYX_out_shape[1:]): + if SizeD == 1: + continue + + dim = dims[row] + inputSpinbox = SpinBox() + inputSpinbox.setMinimum(1) + inputSpinbox.setValue(SizeD) + + outZwidget = QLineEdit() + outZwidget.setReadOnly(True) + outZwidget.setAlignment(Qt.AlignCenter) + # outZwidget.setValue(SizeD) + outZwidget.setText(str(SizeD)) + + row0 = row * 2 + row1 = row0 + 1 + gridLayout.addWidget(QLabel(f"{dim}-dimension: "), row1, 0) + + gridLayout.addWidget(QLabel("Input size"), row0, 1) + gridLayout.addWidget(inputSpinbox, row1, 1) + + gridLayout.addWidget(QLabel("Output size"), row0, 2) + gridLayout.addWidget(outZwidget, row1, 2) + + self.widgets[dim] = (inputSpinbox, SizeD) + + self.setLayout(gridLayout) + + def inputOutputSizes(self): + if not self.isChecked(): + return + + sizes = { + dim: (spinbox.value(), int(SizeD)) + for dim, (spinbox, SizeD) in self.widgets.items() + } + return sizes + + +class TimeWidget(QGroupBox): + sigValueChanged = Signal(object) + + def __init__(self, parent=None, orientation="vertical"): + super().__init__(parent) + + mainLayout = QHBoxLayout() + + if orientation == "vertical": + spinboxesLayout = QVBoxLayout() + elif orientation == "horizontal": + spinboxesLayout = QHBoxLayout() + else: + raise ValueError('orientation must be "vertical" or "horizontal"') + + self.signCombobox = QComboBox() + self.signCombobox.addItems(("+", "-")) + self.signCombobox.currentTextChanged.connect(self.emitValueChanged) + + mainLayout.addWidget(self.signCombobox) + + self.spinboxesMapper = {} + units = ("days", "hours", "minutes", "seconds") + for unit in units: + layout = QHBoxLayout() + spinbox = SpinBox() + spinbox.setMinimum(0) + label = QLabel(unit) + layout.addWidget(spinbox) + layout.addWidget(label) + spinbox.valueChanged.connect(self.emitValueChanged) + self.spinboxesMapper[unit] = spinbox + spinboxesLayout.addLayout(layout) + + mainLayout.addLayout(spinboxesLayout) + + self.setLayout(mainLayout) + mainLayout.setContentsMargins(5, 5, 5, 5) + + def values(self): + values = {} + for unit, spinbox in self.spinboxesMapper.items(): + values[unit] = spinbox.value() + + signText = self.signCombobox.currentText() + return values, sign_int_mapper[signText] + + def setValuesFromTimedelta(self, timedelta): + total_seconds = timedelta.total_seconds() + sign = 1 if total_seconds > 0 else -1 + days = timedelta.days + hours, remainder = divmod(timedelta.seconds, 3600) + minutes, seconds = divmod(remainder, 60) + + values = {"days": days, "hours": hours, "minutes": minutes, "seconds": seconds} + + self.setValues(values, sign=sign) + + def timedelta(self): + values, sign = self.values() + return datetime.timedelta(**values) * sign + + def setValues(self, values: dict[str, int | float], sign=1): + signText = "+" if sign > 0 else "-" + self.signCombobox.setCurrentText(signText) + for unit, value in values.items(): + spinbox = self.spinboxesMapper[unit] + spinbox.setValue(value) + + def emitValueChanged(self, value): + self.sigValueChanged.emit(self.values()) + + +class YeazV2SelectModelNameCombobox(ComboBox): + sigValueChanged = Signal(str) + + def __init__( + self, *args, custom_select_item_text="Select custom weights file...", **kwargs + ): + super().__init__(*args, **kwargs) + self._csi_text = custom_select_item_text + self.sigTextChanged.connect(self.onTextChanged) + self.initItems() + + def initItems(self): + from cellacdc.segmenters.YeaZ_v2 import load_models_filepath + + models_name, models_name_filepath_mapper = load_models_filepath() + self.addItems(models_name) + + def onTextChanged(self, text): + if text != self._csi_text: + return + + start_dir = utils.getMostRecentPath() + model_filepath = qtpy.compat.getopenfilename( + parent=self, + caption="Select YeaZ weights file", + filters="All Files (*)", + basedir=start_dir, + )[0] + if not model_filepath: + self.setCurrentIndex(0) + return + + msg = html_utils.paragraph(f""" + Insert a name for the following YeaZ model:

    + {model_filepath}
    + """) + modelNameWindow = apps.QLineEditDialog( + title="Insert a name for the model", msg=msg, allowEmpty=False, parent=self + ) + modelNameWindow.exec_() + if modelNameWindow.cancel: + self.setCurrentIndex(0) + return + + model_name = modelNameWindow.enteredValue + + from cellacdc.segmenters.YeaZ_v2 import add_model_filepath + + add_model_filepath(model_name, model_filepath) + + self.addItem(model_name) + self.setCurrentText(model_name) + + print( + "YeaZ_v2 model added!\n\n" + f" * Name: {model_name}\n" + f" * File path: {model_filepath}\n" + ) + + def addItem(self, item): + idx = self.count() - 1 + self.insertItem(idx, item) + + def addItems(self, items): + super().clear() + super().addItems(items) + super().addItem(self._csi_text) + idx = len(items) + font = self.font() + font.setItalic(True) + self.setItemData(idx, font, Qt.FontRole) + + def setValue(self, value: str): + self.setCurrentText(value) + + def value(self, *args): + return self.currentText() + + +class AutoSaveIntervalWidget(QWidget): + sigValueChanged = Signal(float, str) + + def __init__(self, parent=None): + super().__init__(parent) + + layout = QHBoxLayout() + + autoSaveIntervalTooltip = "Autosave every minutes or frames specified here." + + self.setToolTip(autoSaveIntervalTooltip) + + self.spinbox = DoubleSpinBox() + self.spinbox.setMinimum(0) + self.spinbox.setValue(2) + self.spinbox.setDecimals(2) + self.spinbox.setSingleStep(1.0) + + layout.addWidget(self.spinbox) + + self.unitCombobox = ComboBox() + self.unitCombobox.addItems(["minutes", "frames"]) + layout.addWidget(self.unitCombobox) + + layout.setStretch(0, 1) + layout.setStretch(1, 0) + layout.setContentsMargins(5, 0, 5, 0) + + self.setLayout(layout) + + self.spinbox.sigValueChanged.connect(self.emitSigValueChanged) + self.unitCombobox.sigTextChanged.connect(self.emitSigValueChanged) + + def emitSigValueChanged(self, *args, **kwargs): + self.sigValueChanged.emit(self.spinbox.value(), self.unitCombobox.currentText()) + + +class CheckableWidget(QWidget): + def __init__(self, widget, valueGetterName="value", parent=None): + super().__init__(parent) + + self.widget = widget + self.valueGetterName = valueGetterName + + widget.setDisabled(True) + + layout = QHBoxLayout() + + layout.addWidget(widget) + + self.checkbox = QCheckBox("Activate") + self.checkbox.toggled.connect(self.setWidgetEnabled) + + layout.addSpacing(5) + layout.addWidget(self.checkbox) + + layout.setContentsMargins(5, 0, 5, 0) + + self.setLayout(layout) + + def setWidgetEnabled(self, checked): + self.widget.setDisabled(not checked) + + def value(self): + if not self.checkbox.isChecked(): + return + + return getattr(self.widget, self.valueGetterName)() + +# Cross-module imports (deferred to avoid import cycles) +from .dialogs import ( + myMessageBox, +) +from .inputs import ( + DoubleSpinBox, + QClickableLabel, + SpinBox, +) +from .metrics import ( + PixelSizeGroupbox, + objIntesityMeasurQGBox, + objPropsQGBox, +) +from .panels import ( + LatexLabel, +) + diff --git a/cellacdc/widgets/controls/inputs.py b/cellacdc/widgets/controls/inputs.py new file mode 100644 index 000000000..a52b4c743 --- /dev/null +++ b/cellacdc/widgets/controls/inputs.py @@ -0,0 +1,976 @@ +"""Composite controls: inputs.""" + +"""GUI widgets: controls.""" + +from collections import defaultdict, deque +from typing import Dict, List, Union, Iterable, Sequence +import os +import sys +import operator +import time +import re +import datetime +import numpy as np +import pandas as pd +import math +import traceback +import logging +import textwrap +import random + +from functools import partial +from math import ceil + +import skimage.draw +import skimage.morphology + +from matplotlib.colors import ListedColormap, LinearSegmentedColormap +import matplotlib.pyplot as plt +import matplotlib +from matplotlib.backends.backend_agg import FigureCanvasAgg + +from qtpy.QtCore import ( + Signal, + QTimer, + Qt, + QPoint, + QUrl, + Property, + QPropertyAnimation, + QEasingCurve, + QLocale, + QSize, + QRect, + QPointF, + QRect, + QPoint, + QEasingCurve, + QRegularExpression, + QEvent, + QEventLoop, + QPropertyAnimation, + QObject, + QItemSelectionModel, + QAbstractListModel, + QModelIndex, + QByteArray, + QDataStream, + QMimeData, + QAbstractItemModel, + QIODevice, + QItemSelection, + PYQT6, + QRectF, +) +from qtpy.QtGui import ( + QFont, + QPalette, + QColor, + QPen, + QKeyEvent, + QBrush, + QPainter, + QRegularExpressionValidator, + QIcon, + QPixmap, + QKeySequence, + QLinearGradient, + QShowEvent, + QDesktopServices, + QFontMetrics, + QGuiApplication, + QLinearGradient, + QImage, + QCursor, + QPicture, +) +from qtpy.QtWidgets import ( + QTextEdit, + QLabel, + QProgressBar, + QHBoxLayout, + QToolButton, + QCheckBox, + QApplication, + QWidget, + QVBoxLayout, + QMainWindow, + QTreeWidgetItemIterator, + QLineEdit, + QSlider, + QSpinBox, + QGridLayout, + QRadioButton, + QScrollArea, + QSizePolicy, + QComboBox, + QPushButton, + QScrollBar, + QGroupBox, + QAbstractSlider, + QDoubleSpinBox, + QWidgetAction, + QAction, + QTabWidget, + QAbstractSpinBox, + QToolBar, + QStyleOptionSpinBox, + QStyle, + QDialog, + QSpacerItem, + QFrame, + QMenu, + QActionGroup, + QListWidget, + QPlainTextEdit, + QFileDialog, + QListView, + QAbstractItemView, + QTreeWidget, + QTreeWidgetItem, + QListWidgetItem, + QLayout, + QStylePainter, + QGraphicsBlurEffect, + QGraphicsProxyWidget, + QGraphicsObject, + QButtonGroup, + QStyleOptionSlider, +) +import qtpy.compat + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from ... import utils, measurements, is_mac, is_win, html_utils, is_linux +from ... import printl, settings_folderpath +from ... import colors, config +from ... import html_path +from ... import _palettes +from ... import load +from ... import apps +from ... import plot +from ... import annotate +from ... import urls +from ... import _core, core +from ... import QtScoped +from ... import prompts +from ...acdc_regex import float_regex +from ...config import PREPROCESS_MAPPER +from ... import _base_widgets + +from ...components.palette import ( # noqa: E402 + BASE_COLOR, + Gradients, + GradientsImage, + GradientsLabels, + LINEEDIT_INVALID_ENTRY_STYLESHEET, + LINEEDIT_WARNING_STYLESHEET, + LISTWIDGET_STYLESHEET, + PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR, + PROGRESSBAR_QCOLOR, + TEXT_COLOR, + TREEWIDGET_STYLESHEET, + cmaps, + font, + getCustomGradients, + nonInvertibleCmaps, + sign_int_mapper, + str_to_operator_mapper, +) +from ...components.progress import QtHandler, QLog, XStream # noqa: E402 +from ...components.buttons import * # noqa: E402, F403 +from ...components.layout import * # noqa: E402, F403 +from ...components.inputs_basic import * # noqa: E402, F403 +from ...components.path_controls import * # noqa: E402, F403 + +from ...components.lists import * # noqa: E402, F403 +from ...components.base import QBaseWindow # noqa: E402 +from ...components.progress import ( # noqa: E402 + LoadingCircleAnimation, + NoneWidget, + ProgressBar, + ProgressBarWithETA, + QLogConsole, +) + +class ExpandableListBox(QComboBox): + def __init__(self, parent=None, centered=True) -> None: + super().__init__(parent) + + self.setEditable(True) + self.lineEdit().setReadOnly(True) + + infoTxt = html_utils.paragraph( + "Select Positions to save

    " + "Ctrl+Click to select multiple items
    " + "Shift+Click to select a range of items
    ", + center=True, + ) + + self.listW = QDialogListbox( + "Select Positions to save", infoTxt, [], multiSelection=True, parent=self + ) + + self.listW.listBox.itemClicked.connect(self.listItemClicked) + self.listW.sigSelectionConfirmed.connect(self.updateCombobox) + + self.centered = centered + + def listItemClicked(self, item): + if item.text().find("All") == -1: + return + + for i in range(self.listW.listBox.count()): + _item = self.listW.listBox.item(i) + _item.setSelected(True) + + def clear(self) -> None: + self.listW.listBox.clear() + return super().clear() + + def setItems(self, items): + self.clear() + self.addItems(items) + + def addItems(self, items): + super().addItems(items) + self.listW.listBox.addItems(items) + self.listW.listBox.setCurrentRow(self.currentIndex()) + self.listItemClicked(self.listW.listBox.currentItem()) + if self.centered: + self.centerItems() + + def updateCombobox(self, selectedItemsText): + isAllItem = [i for i, t in enumerate(selectedItemsText) if t.find("All") != -1] + if len(selectedItemsText) == 1: + self.setCurrentText(selectedItemsText[0]) + elif isAllItem: + idx = isAllItem[0] + self.setCurrentText(selectedItemsText[idx]) + else: + super().clear() + super().addItems(["Custom selection"]) + + def centerItems(self, idx=None): + self.lineEdit().setAlignment(Qt.AlignCenter) + + def selectedItems(self): + return self.listW.listBox.selectedItems() + + def selectedItemsText(self): + return [item.text() for item in self.selectedItems()] + + def showPopup(self) -> None: + self.listW.show() + + +class QClickableLabel(QLabel): + clicked = Signal(object) + + def __init__(self, parent=None): + self._parent = parent + super().__init__(parent) + self._checkableItem = None + + def setCheckableItem(self, widget): + self._checkableItem = widget + + def mousePressEvent(self, event): + self.clicked.emit(self) + if self._checkableItem is not None: + status = not self._checkableItem.isChecked() + self._checkableItem.setChecked(status) + + def setChecked(self, checked): + self._checkableItem.setChecked(checked) + + +class QCenteredComboBox(QComboBox): + def __init__(self, parent=None) -> None: + super().__init__(parent) + + self.setEditable(True) + self.lineEdit().setReadOnly(True) + self.lineEdit().setAlignment(Qt.AlignCenter) + self.lineEdit().installEventFilter(self) + + self.currentIndexChanged.connect(self.centerItems) + + self._isPopupVisibile = False + + def centerItems(self, idx): + for i in range(self.count()): + self.setItemData(i, Qt.AlignCenter, Qt.TextAlignmentRole) + + def eventFilter(self, lineEdit, event): + # Reimplement show popup on click + if event.type() == QEvent.Type.MouseButtonPress and self.isEnabled(): + if self._isPopupVisibile: + self.hidePopup() + self._isPopupVisibile = False + else: + self.showPopup() + self._isPopupVisibile = True + return True + return False + + +class AlphaNumericComboBox(QCenteredComboBox): + def __init__(self, parent=None) -> None: + super().__init__(parent=parent) + + def addItems(self, items): + self._dtype = type(items[0]) + super().addItems([str(item) for item in items]) + + def setCurrentValue(self, value): + super().setCurrentText(str(value)) + + def currentValue(self): + return self._dtype(super().currentText()) + + +class mySpinBox(QSpinBox): + sigTabEvent = Signal(object, object) + + def __init__(self, *args) -> None: + super().__init__(*args) + + def event(self, event): + if event.type() == QEvent.Type.KeyPress and event.key() == Qt.Key_Tab: + self.sigTabEvent.emit(event, self) + return True + + return super().event(event) + + +class ShortcutLineEdit(QLineEdit): + def __init__(self, parent=None, allowModifiers=False, notAllowedModifier=None): + self.keySequence = None + super().__init__(parent) + self._allowModifiers = allowModifiers + self._notAllowedModifier = notAllowedModifier + self.setAlignment(Qt.AlignCenter) + + def text(self): + text = macShortcutToWindows(super().text()) + + return text + + def setText(self, text): + text = windowsShortcutToMac(text) + + super().setText(text) + if not text: + self.keySequence = None + return + try: + self.keySequence = KeySequenceFromText(self.text()) + except Exception as e: + pass + + def keyPressEvent(self, event: QKeyEvent): + if event.key() == Qt.Key_Backspace or event.key() == Qt.Key_Delete: + self.setText("") + return + + keySequenceText = QKeyEventToString( + event, notAllowedModifier=self._notAllowedModifier + ) + self.setText(keySequenceText) + self.key = event.key() + + def keyReleaseEvent(self, event: QKeyEvent) -> None: + if self.text().endswith("+"): + if not self._allowModifiers: + self.setText("") + else: + self.setText(self.text().rstrip("+").strip()) + + +class CenteredDoubleSpinbox(QDoubleSpinBox): + def __init__(self, parent=None): + super().__init__(parent=parent) + self.setAlignment(Qt.AlignCenter) + self.setMaximum(2**31 - 1) + + +class readOnlyDoubleSpinbox(QDoubleSpinBox): + def __init__(self, parent=None): + super().__init__(parent=parent) + self.setReadOnly(True) + self.setButtonSymbols(QAbstractSpinBox.ButtonSymbols.NoButtons) + self.setAlignment(Qt.AlignCenter) + self.setMaximum(2**31 - 1) + + +class readOnlySpinbox(QSpinBox): + def __init__(self, parent=None): + super().__init__(parent=parent) + self.setReadOnly(True) + self.setButtonSymbols(QAbstractSpinBox.ButtonSymbols.NoButtons) + self.setAlignment(Qt.AlignCenter) + self.setMaximum(2**31 - 1) + + +class DoubleSpinBox(QDoubleSpinBox): + sigValueChanged = Signal(int) + + def __init__(self, parent=None, disableKeyPress=False): + super().__init__(parent=parent) + self.setAlignment(Qt.AlignCenter) + self.setMaximum(2**31 - 1) + self.setMinimum(-(2**31)) + self._valueChangedFunction = None + self.disableKeyPress = disableKeyPress + + def keyPressEvent(self, event) -> None: + isBackSpaceKey = event.key() == Qt.Key_Backspace + isDeleteKey = event.key() == Qt.Key_Delete + try: + int(event.text()) + isIntegerKey = True + except: + isIntegerKey = False + acceptEvent = isBackSpaceKey or isDeleteKey or isIntegerKey + if self.disableKeyPress and not acceptEvent: + event.ignore() + self.clearFocus() + else: + super().keyPressEvent(event) + + def textFromValue(self, value: float) -> str: + text = super().textFromValue(value) + return text.replace(QLocale().decimalPoint(), ".") + + def valueFromText(self, text: str) -> float: + text = text.replace(".", QLocale().decimalPoint()) + return super().valueFromText(text) + + +class SpinBox(QSpinBox): + sigValueChanged = Signal(int) + sigUpClicked = Signal() + sigDownClicked = Signal() + + def __init__(self, parent=None, disableKeyPress=False, allowNegative=True): + super().__init__(parent=parent) + self.setAlignment(Qt.AlignCenter) + self.setMaximum(2**31 - 1) + if allowNegative: + self.setMinimum(-(2**31)) + else: + self.setMinimum(0) + self._valueChangedFunction = None + self.disableKeyPress = disableKeyPress + self._linkedWidget = None + + def mousePressEvent(self, event) -> None: + super().mousePressEvent(event) + opt = QStyleOptionSpinBox() + self.initStyleOption(opt) + + control = self.style().hitTestComplexControl( + QStyle.ComplexControl.CC_SpinBox, opt, event.pos(), self + ) + if control == QStyle.SubControl.SC_SpinBoxUp: + self.sigUpClicked.emit() + elif control == QStyle.SubControl.SC_SpinBoxDown: + self.sigDownClicked.emit() + + # def focusOutEvent(self, event): + # self.editingFinished.emit() + # super().focusOutEvent(event) + # printl('emitted') + + def keyPressEvent(self, event) -> None: + isBackSpaceKey = event.key() == Qt.Key_Backspace + isDeleteKey = event.key() == Qt.Key_Delete + try: + int(event.text()) + isIntegerKey = True + except: + isIntegerKey = False + acceptEvent = isBackSpaceKey or isDeleteKey or isIntegerKey + if self.disableKeyPress and not acceptEvent: + event.ignore() + self.clearFocus() + else: + super().keyPressEvent(event) + + def connectValueChanged(self, function): + self._valueChangedFunction = function + self.valueChanged.connect(function) + + def setValue(self, value, setLinkedWidget=True): + super().setValue(int(value)) + if self._linkedWidget is not None and setLinkedWidget: + self._linkedWidget.setValue(value) + + def setValueNoEmit(self, value): + if self._valueChangedFunction is None: + self.setValue(value) + return + try: + self.valueChanged.disconnect() + except TypeError as e: # this fails if its not cennected yet + pass + + self.setValue(value) + self.valueChanged.connect(self._valueChangedFunction) + + def wheelEvent(self, event): + event.ignore() + + def setLinkedValueWidget(self, widget): + self._linkedWidget = widget + + +class ReadOnlyLineEdit(QLineEdit): + def __init__(self, parent=None): + super().__init__(parent=parent) + self.setReadOnly(True) + # self.setStyleSheet( + # 'background-color: rgba(240, 240, 240, 200);' + # ) + self.installEventFilter(self) + + def eventFilter(self, a0: "QObject", a1: "QEvent") -> bool: + if a1.type() == QEvent.Type.FocusIn: + return True + return super().eventFilter(a0, a1) + + def setValue(self, value): + self.setText(str(value)) + + def value(self, casting_func: callable = None): + text = self.text() + if casting_func is not None: + return casting_func(text) + return text + + +class FloatLineEdit(QLineEdit): + valueChanged = Signal(float) + + def __init__( + self, + *args, + notAllowed=None, + allowNegative=True, + initial=None, + readOnly=False, + decimals=6, + warningValues=None, + ): + QLineEdit.__init__(self, *args) + if readOnly: + self.setReadOnly(readOnly) + self.notAllowed = notAllowed + self.warningValues = warningValues + self._maximum = np.inf + self._minimum = -np.inf + self._decimals = decimals + + self.isNumericRegExp = rf"^{float_regex(allow_negative=allowNegative)}$" + regExp = QRegularExpression(self.isNumericRegExp) + self.setValidator(QRegularExpressionValidator(regExp)) + self.setAlignment(Qt.AlignCenter) + + font = QFont() + font.setPixelSize(11) + self.setFont(font) + + self.textChanged.connect(self.emitValueChanged) + + if initial is not None: + self.setValue(initial) + else: + self.setValue(0) + + def setDecimals(self, decimals): + self._decimals = 6 + + def castMinMax(self, value: int): + if value > self._maximum: + value = self._maximum + if value < self._minimum: + value = self._minimum + return value + + def setValue(self, value: float): + value = self.castMinMax(value) + self.setText(str(round(value, self._decimals))) + + def value(self): + m = re.match(self.isNumericRegExp, self.text()) + if m is not None: + text = m.group(0) + try: + val = float(text) + except ValueError: + val = 0.0 + else: + val = 0.0 + + return self.castMinMax(val) + + def setMaximum(self, maximum): + self._maximum = maximum + self.setValue(self.value()) + + def setMinimum(self, minimum): + self._minimum = minimum + self.setValue(self.value()) + + def emitValueChanged(self, text): + val = self.value() + reset_stylesheet = True + if self.warningValues is not None and val in self.warningValues: + self.setStyleSheet(LINEEDIT_WARNING_STYLESHEET) + reset_stylesheet = False + + if self.notAllowed is not None and val in self.notAllowed: + self.setStyleSheet(LINEEDIT_INVALID_ENTRY_STYLESHEET) + reset_stylesheet = False + else: + self.valueChanged.emit(self.value()) + + if reset_stylesheet: + self.setStyleSheet("") + + +class IntLineEdit(QLineEdit): + valueChanged = Signal(float) + + def __init__( + self, *args, notAllowed=None, allowNegative=True, initial=None, readOnly=False + ): + QLineEdit.__init__(self, *args) + self.notAllowed = notAllowed + if readOnly: + self.setReadOnly(readOnly) + + self._maximum = np.inf + self._minimum = -np.inf + + self._regExp = r"\d+" + if allowNegative: + self._regExp = r"-?\d+" + + regExp = QRegularExpression(self._regExp) + self.setValidator(QRegularExpressionValidator(regExp)) + self.setAlignment(Qt.AlignCenter) + + font = QFont() + font.setPixelSize(11) + self.setFont(font) + + self.textChanged.connect(self.emitValueChanged) + + if initial is not None: + self.setValue(initial) + else: + self.setValue(0) + + def setMaximum(self, maximum): + self._maximum = maximum + self.setValue(self.value()) + + def setMinimum(self, minimum): + self._minimum = minimum + self.setValue(self.value()) + + def castMinMax(self, value: int): + if value > self._maximum: + value = self._maximum + if value < self._minimum: + value = self._minimum + return value + + def setValue(self, value: int): + value = self.castMinMax(value) + self.setText(str(value)) + + def value(self): + m = re.match(self._regExp, self.text()) + if m is not None: + text = m.group(0) + try: + val = int(text) + except ValueError: + val = 0 + else: + val = 0 + + return self.castMinMax(val) + + def emitValueChanged(self, text): + if not text: + return + + val = self.value() + self.setValue(val) + if self.notAllowed is not None and val in self.notAllowed: + self.setStyleSheet(LINEEDIT_INVALID_ENTRY_STYLESHEET) + else: + self.setStyleSheet("") + self.valueChanged.emit(self.value()) + + +class highlightableQWidgetAction(QWidgetAction): + def __init__(self, parent) -> None: + super().__init__(parent) + + +class ComboBox(QComboBox): + sigTextChanged = Signal(str) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._previousText = None + self._valueChanged = False + self.currentTextChanged.connect(self.emitTextChanged) + self.installEventFilter(self) + + def eventFilter(self, object, event) -> bool: + if object == self and event.type() == QEvent.Type.Wheel: + # Forward event to parent so QScrollArea can scroll + QApplication.sendEvent(self.parent(), event) + return True # Consume for the combo itself + + return super().eventFilter(object, event) + + def text(self): + return self.currentText() + + def emitTextChanged(self, text): + self._valueChanged = True + self.sigTextChanged.emit(text) + + def mousePressEvent(self, event): + self._previousText = self.currentText() + super().mousePressEvent(event) + + def previousText(self): + return self._previousText + + def addItems(self, items): + super().addItems(items) + self._previousText = items[0] + + def itemsText(self): + return [self.itemText(i) for i in range(self.count())] + + def setCurrentIndex(self, idx): + itemsText = self.itemsText() + currentText = itemsText[idx] + self._valueChanged = currentText != self._previousText + self._previousText = self.currentText() + super().setCurrentIndex(idx) + + def setCurrentText(self, text): + currentText = text + self._valueChanged = currentText != self._previousText + self._previousText = self.currentText() + super().setCurrentText(text) + + +class SearchLineEdit(QLineEdit): + def __init__(self, parent=None): + super().__init__(parent) + + self.initSearch() + self.setFocusPolicy(Qt.ClickFocus) + + def focusInEvent(self, event) -> None: + super().focusInEvent(event) + if super().text() == "Search...": + self.setText("") + self.setStyleSheet("") + + def focusOutEvent(self, event) -> None: + super().focusOutEvent(event) + if not super().text(): + self.initSearch() + + def initSearch(self): + self.setText("Search...") + self.setStyleSheet("color: rgb(150, 150, 150)") + self.clearFocus() + + def text(self): + if super().text() == "Search...": + return "" + return super().text() + + +class VectorLineEdit(QLineEdit): + valueChanged = Signal(object) + valueChangeFinished = Signal(object) + + def __init__(self, parent=None, initial=None): + super().__init__(parent) + + self._minimum = -np.inf + + float_re = float_regex() + vector_regex = rf"\(?\[?{float_re}(,\s?{float_re})+\)?\]?" + regex = rf"^{vector_regex}$|^{float_re}$" + self.validRegex = regex + + regExp = QRegularExpression(regex) + self.setValidator(QRegularExpressionValidator(regExp)) + self.setAlignment(Qt.AlignCenter) + + self.textChanged.connect(self.emitValueChanged) + self.editingFinished.connect(self.emitValueChangeFinished) + if initial is None: + self.setText("0.0") + + font = QFont() + font.setPixelSize(11) + self.setFont(font) + + def emitValueChangeFinished(self): + value = self.value() + self.textChanged.disconnect() + self.editingFinished.disconnect() + self.setValue(value) + self.textChanged.connect(self.emitValueChanged) + self.editingFinished.connect(self.emitValueChangeFinished) + + self.emitValueChanged(self.text(), signal=self.valueChangeFinished) + + def emitValueChanged(self, text, signal=None): + m = re.match(self.validRegex, text) + if m is None: + self.setStyleSheet(LINEEDIT_INVALID_ENTRY_STYLESHEET) + return + + if signal is None: + signal = self.valueChanged + + self.setStyleSheet("") + signal.emit(self.value()) + + def increaseValue(self, step): + value = self.value() + if isinstance(value, (float, int)): + value += step + else: + value = [val + step for val in value] + value = str(value).lstrip("[").rstrip("]") + self.setValue(value) + self.emitValueChangeFinished() + + def decreaseValue(self, step): + value = self.value() + if isinstance(value, (float, int)): + value -= step + else: + value = [val - step for val in value] + value = str(value).lstrip("[").rstrip("]") + self.setText(value) + self.emitValueChangeFinished() + + def setValue(self, value): + if isinstance(value, (float, int)): + if value < self._minimum: + value = self._minimum + else: + clipped = [] + for val in value: + if val < self._minimum: + val = self._minimum + clipped.append(val) + value = str(clipped).lstrip("[").rstrip("]") + self.setText(value) + + def setText(self, text): + super().setText(str(text)) + + def clipValue(self, val: float): + if val < self._minimum: + val = self._minimum + return val + + def value(self): + m = re.match(self.validRegex, self.text()) + if m is None: + return 0.0 + + try: + value = self.clipValue(float(self.text())) + return value + except Exception as e: + text = self.text() + text = text.replace("(", "") + text = text.replace(")", "") + text = text.replace("[", "") + text = text.replace("]", "") + values = text.split(",") + return [self.clipValue(float(value)) for value in values] + + def setMinimum(self, minimum): + self._minimum = float(minimum) + + +class OddSpinBox(SpinBox): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.setSingleStep(2) + self.editingFinished.connect(self.roundToOdd) + + def roundToOdd(self): + if self.value() % 2 == 1: + return + + self.setValue(self.value() + 1) + + +class LineEdit(QLineEdit): + def __init__(self, parent=None): + super().__init__(parent) + self.setAlignment(Qt.AlignCenter) + + def value(self): + return self.text() + + def setValue(self, value): + self.setText(str(value)) + + +class WhitelistLineEdit(KeepIDsLineEdit): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def setText(self, IDs): + if not isinstance(IDs, set) and not isinstance(IDs, list): + raise TypeError("IDs must be a set or list") + + formatted_text = utils.format_IDs(IDs) + super().setText(formatted_text) + + +class KeySequenceFromText(QKeySequence): + def __init__(self, text: str): + if isinstance(text, str): + text = macShortcutToWindows(text) + super().__init__(text) + self._text = text + + def toString(self): + if isinstance(self._text, str): + return windowsShortcutToMac(self._text) + else: + return windowsShortcutToMac(super().toString()) + +# Cross-module imports (deferred to avoid import cycles) +from .dialogs import ( + QDialogListbox, +) + diff --git a/cellacdc/widgets/controls/metrics.py b/cellacdc/widgets/controls/metrics.py new file mode 100644 index 000000000..89e9e8271 --- /dev/null +++ b/cellacdc/widgets/controls/metrics.py @@ -0,0 +1,1049 @@ +"""Composite controls: metrics.""" + +"""GUI widgets: controls.""" + +from collections import defaultdict, deque +from typing import Dict, List, Union, Iterable, Sequence +import os +import sys +import operator +import time +import re +import datetime +import numpy as np +import pandas as pd +import math +import traceback +import logging +import textwrap +import random + +from functools import partial +from math import ceil + +import skimage.draw +import skimage.morphology + +from matplotlib.colors import ListedColormap, LinearSegmentedColormap +import matplotlib.pyplot as plt +import matplotlib +from matplotlib.backends.backend_agg import FigureCanvasAgg + +from qtpy.QtCore import ( + Signal, + QTimer, + Qt, + QPoint, + QUrl, + Property, + QPropertyAnimation, + QEasingCurve, + QLocale, + QSize, + QRect, + QPointF, + QRect, + QPoint, + QEasingCurve, + QRegularExpression, + QEvent, + QEventLoop, + QPropertyAnimation, + QObject, + QItemSelectionModel, + QAbstractListModel, + QModelIndex, + QByteArray, + QDataStream, + QMimeData, + QAbstractItemModel, + QIODevice, + QItemSelection, + PYQT6, + QRectF, +) +from qtpy.QtGui import ( + QFont, + QPalette, + QColor, + QPen, + QKeyEvent, + QBrush, + QPainter, + QRegularExpressionValidator, + QIcon, + QPixmap, + QKeySequence, + QLinearGradient, + QShowEvent, + QDesktopServices, + QFontMetrics, + QGuiApplication, + QLinearGradient, + QImage, + QCursor, + QPicture, +) +from qtpy.QtWidgets import ( + QTextEdit, + QLabel, + QProgressBar, + QHBoxLayout, + QToolButton, + QCheckBox, + QApplication, + QWidget, + QVBoxLayout, + QMainWindow, + QTreeWidgetItemIterator, + QLineEdit, + QSlider, + QSpinBox, + QGridLayout, + QRadioButton, + QScrollArea, + QSizePolicy, + QComboBox, + QPushButton, + QScrollBar, + QGroupBox, + QAbstractSlider, + QDoubleSpinBox, + QWidgetAction, + QAction, + QTabWidget, + QAbstractSpinBox, + QToolBar, + QStyleOptionSpinBox, + QStyle, + QDialog, + QSpacerItem, + QFrame, + QMenu, + QActionGroup, + QListWidget, + QPlainTextEdit, + QFileDialog, + QListView, + QAbstractItemView, + QTreeWidget, + QTreeWidgetItem, + QListWidgetItem, + QLayout, + QStylePainter, + QGraphicsBlurEffect, + QGraphicsProxyWidget, + QGraphicsObject, + QButtonGroup, + QStyleOptionSlider, +) +import qtpy.compat + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from ... import utils, measurements, is_mac, is_win, html_utils, is_linux +from ... import printl, settings_folderpath +from ... import colors, config +from ... import html_path +from ... import _palettes +from ... import load +from ... import apps +from ... import plot +from ... import annotate +from ... import urls +from ... import _core, core +from ... import QtScoped +from ... import prompts +from ...acdc_regex import float_regex +from ...config import PREPROCESS_MAPPER +from ... import _base_widgets + +from ...components.palette import ( # noqa: E402 + BASE_COLOR, + Gradients, + GradientsImage, + GradientsLabels, + LINEEDIT_INVALID_ENTRY_STYLESHEET, + LINEEDIT_WARNING_STYLESHEET, + LISTWIDGET_STYLESHEET, + PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR, + PROGRESSBAR_QCOLOR, + TEXT_COLOR, + TREEWIDGET_STYLESHEET, + cmaps, + font, + getCustomGradients, + nonInvertibleCmaps, + sign_int_mapper, + str_to_operator_mapper, +) +from ...components.progress import QtHandler, QLog, XStream # noqa: E402 +from ...components.buttons import * # noqa: E402, F403 +from ...components.layout import * # noqa: E402, F403 +from ...components.inputs_basic import * # noqa: E402, F403 +from ...components.path_controls import * # noqa: E402, F403 + +from ...components.lists import * # noqa: E402, F403 +from ...components.base import QBaseWindow # noqa: E402 +from ...components.progress import ( # noqa: E402 + LoadingCircleAnimation, + NoneWidget, + ProgressBar, + ProgressBarWithETA, + QLogConsole, +) + +class _metricsQGBox(QGroupBox): + sigDelClicked = Signal(str, object) + + def __init__( + self, + desc_dict, + title, + favourite_funcs=None, + isZstack=False, + equations=None, + addDelButton=False, + delButtonMetricsDesc=None, + parent=None, + addCalcForEachZsliceToggle=False, + ): + QGroupBox.__init__(self, parent) + + highlightRgba = _palettes._highlight_rgba() + r, g, b, a = highlightRgba + self._highlightStylesheetColor = f"rgb({r}, {g}, {b})" + + self._parent = parent + self.scrollArea = QScrollArea() + self.scrollAreaWidget = QWidget() + self.favourite_funcs = favourite_funcs + + self.doNotWarn = False + + layout = QVBoxLayout() + inner_layout = QVBoxLayout() + self.inner_layout = inner_layout + if delButtonMetricsDesc is None: + delButtonMetricsDesc = [] + + self.checkBoxes = [] + self.checkedState = {} + for metric_colname, metric_desc in desc_dict.items(): + rowLayout = QHBoxLayout() + + checkBox = QCheckBox(metric_colname) + checkBox.setChecked(True) + checkBox.scrollArea = self.scrollArea + self.checkBoxes.append(checkBox) + self.checkedState[checkBox] = True + + try: + checkBox.equation = equations[metric_colname] + except Exception as e: + pass + + if addDelButton or metric_colname in delButtonMetricsDesc: + delButton = delPushButton() + delButton.setToolTip("Delete custom combined measurement") + delButton.colname = metric_colname + delButton.checkbox = checkBox + delButton.clicked.connect(self.onDelClicked) + delButton._layout = rowLayout + rowLayout.addWidget(delButton) + + infoButton = infoPushButton() + infoButton.setCursor(Qt.WhatsThisCursor) + infoButton.info = metric_desc + infoButton.colname = metric_colname + infoButton.clicked.connect(self.showInfo) + + rowLayout.addWidget(infoButton) + rowLayout.addWidget(checkBox) + rowLayout.addStretch(1) + + inner_layout.addLayout(rowLayout) + + self.scrollAreaWidget.setLayout(inner_layout) + self.scrollArea.setWidget(self.scrollAreaWidget) + layout.addWidget(self.scrollArea) + + buttonsLayout = QHBoxLayout() + + buttonsLayout.addStretch(1) + + self.selectAllButton = selectAllPushButton() + self.selectAllButton.sigClicked.connect(self.checkAll) + + buttonsLayout.addWidget(self.selectAllButton) + + if favourite_funcs is not None: + self.loadFavouritesButton = reloadPushButton(" Load last selection... ") + self.loadFavouritesButton.clicked.connect(self.checkFavouriteFuncs) + # self.checkFavouriteFuncs() + buttonsLayout.addWidget(self.loadFavouritesButton) + + layout.addLayout(buttonsLayout) + + self.calcForEachZsliceToggle = None + if addCalcForEachZsliceToggle: + buttonsLayout = QHBoxLayout() + self.calcForEachZsliceToggle = Toggle() + tooltip = ( + "Calculate `cell_area` for each z-slice.\n\n" + "The measurements will be saved in the column with name\n" + "ending with `_zsliceN` where N is the z-slice number\n" + "(starting from 0)." + ) + calcForEachZsliceLabel = QClickableLabel("Calculate for each z-slice") + calcForEachZsliceLabel.setToolTip(tooltip) + self.calcForEachZsliceToggle.setToolTip(tooltip) + buttonsLayout.addWidget(self.calcForEachZsliceToggle) + buttonsLayout.addWidget(calcForEachZsliceLabel) + buttonsLayout.addStretch(1) + layout.addLayout(buttonsLayout) + calcForEachZsliceLabel.clicked.connect( + partial( + self.toggleCalcForEachZslice, toggle=self.calcForEachZsliceToggle + ) + ) + + self.setTitle(title) + self.setCheckable(True) + self.setLayout(layout) + _font = QFont() + _font.setPixelSize(11) + self.setFont(_font) + + self.toggled.connect(self.toggled_cb) + + def toggleCalcForEachZslice(self, label, toggle=None): + if toggle is None: + toggle = self.calcForEachZsliceToggle + + toggle.setChecked(not toggle.isChecked()) + + def isCalcForEachZsliceRequested(self): + if self.calcForEachZsliceToggle is None: + return False + + return self.calcForEachZsliceToggle.isChecked() + + def highlightCheckboxesFromSearchText(self, text): + for checkbox in self.checkBoxes: + if not text: + highlighted = False + else: + highlighted = checkbox.text().lower().find(text.lower()) != -1 + + self.setCheckboxHighlighted(highlighted, checkbox) + + def setCheckboxHighlighted(self, highlighted, checkbox): + if highlighted: + checkbox.setStyleSheet( + f"background: {self._highlightStylesheetColor}; color: black" + ) + self.scrollArea.ensureWidgetVisible(checkbox) + else: + checkbox.setStyleSheet("") + + def onDelClicked(self): + button = self.sender() + button.checkbox.setChecked(False) + self.sigDelClicked.emit(button.colname, button._layout) + + def toggled_cb(self, checked): + for checkbox in self.checkBoxes: + if not checked: + self.checkedState[checkbox] = checkbox.isChecked() + checkbox.setChecked(False) + else: + checkbox.setChecked(self.checkedState[checkbox]) + + def checkFavouriteFuncs(self, checked=True, isZstack=False): + self.doNotWarn = True + if self._parent is not None: + self._parent.doNotWarn = True + for checkBox in self.checkBoxes: + checkBox.setChecked(False) + for favourite_func in self.favourite_funcs: + func_name = checkBox.text() + if func_name.endswith(favourite_func): + checkBox.setChecked(True) + break + self.doNotWarn = False + if self._parent is not None: + self._parent.doNotWarn = False + + def checkAll(self, button, checked): + if self._parent is not None: + self._parent.doNotWarn = True + for checkBox in self.checkBoxes: + checkBox.setChecked(checked) + if self._parent is not None: + self._parent.doNotWarn = False + + def showInfo(self, checked=False): + info_txt = self.sender().info + msg = myMessageBox() + msg.setWidth(600) + msg.setIcon() + msg.setWindowTitle(f"{self.sender().colname} info") + msg.addText(info_txt) + msg.addButton(" Ok ") + msg.exec_() + + def show(self): + super().show() + fw = self.inner_layout.contentsRect().width() + sw = self.scrollArea.verticalScrollBar().sizeHint().width() + self.minWidth = fw + sw + + +class channelMetricsQGBox(QGroupBox): + sigDelClicked = Signal(str, object) + sigCheckboxToggled = Signal(object) + + def __init__( + self, + isZstack, + chName, + isSegm3D, + is_concat=False, + posData=None, + favourite_funcs=None, + ): + QGroupBox.__init__(self) + + self.doNotWarn = False + self.is_concat = is_concat + isManualBackgrPresent = False + if posData is not None: + if posData.manualBackgroundLab is not None: + isManualBackgrPresent = True + + layout = QVBoxLayout() + metrics_desc, bkgr_val_desc = measurements.standard_metrics_desc( + isZstack, + chName, + isSegm3D=isSegm3D, + isManualBackgrPresent=isManualBackgrPresent, + ) + + metricsQGBox = _metricsQGBox( + metrics_desc, + "Standard measurements", + favourite_funcs=favourite_funcs, + parent=self, + isZstack=isZstack, + ) + self.metricsQGBox = metricsQGBox + + bkgrValsQGBox = _metricsQGBox( + bkgr_val_desc, + "Background values", + favourite_funcs=favourite_funcs, + parent=self, + isZstack=isZstack, + ) + self.bkgrValsQGBox = bkgrValsQGBox + + self.checkBoxes = metricsQGBox.checkBoxes.copy() + self.checkBoxes.extend(bkgrValsQGBox.checkBoxes) + + self.uncheckAndDisableDataPrepIfPosNotPrepped(posData) + + self.groupboxes = [metricsQGBox, bkgrValsQGBox] + + for checkbox in metricsQGBox.checkBoxes: + checkbox.toggled.connect(self.standardMetricToggled) + self.standardMetricToggled(checkbox.isChecked(), checkbox=checkbox) + + for bkgrCheckbox in bkgrValsQGBox.checkBoxes: + bkgrCheckbox.toggled.connect(self.backgroundMetricToggled) + + layout.addWidget(metricsQGBox) + layout.addWidget(bkgrValsQGBox) + + items = measurements.custom_metrics_desc( + isZstack, chName, posData=posData, isSegm3D=isSegm3D, return_combine=True + ) + custom_metrics_desc, combine_metrics_desc = items + + if custom_metrics_desc: + customMetricsQGBox = _metricsQGBox( + custom_metrics_desc, + "Custom measurements", + delButtonMetricsDesc=combine_metrics_desc, + favourite_funcs=favourite_funcs, + isZstack=isZstack, + ) + layout.addWidget(customMetricsQGBox) + self.checkBoxes.extend(customMetricsQGBox.checkBoxes) + customMetricsQGBox.sigDelClicked.connect(self.onDelClicked) + self.customMetricsQGBox = customMetricsQGBox + + self.calcForEachZsliceToggle = None + if isZstack: + buttonsLayout = QHBoxLayout() + self.calcForEachZsliceToggle = Toggle() + tooltip = ( + "Calculate the selected measurements for each z-slice.\n\n" + "The measurements will be saved in the column with name\n" + "ending with `_zsliceN` where N is the z-slice number\n" + "(starting from 0)." + ) + calcForEachZsliceLabel = QClickableLabel("Calculate for each z-slice") + calcForEachZsliceLabel.setToolTip(tooltip) + self.calcForEachZsliceToggle.setToolTip(tooltip) + buttonsLayout.addWidget(self.calcForEachZsliceToggle) + buttonsLayout.addWidget(calcForEachZsliceLabel) + buttonsLayout.addStretch(1) + layout.addLayout(buttonsLayout) + calcForEachZsliceLabel.clicked.connect( + partial( + self.toggleCalcForEachZslice, toggle=self.calcForEachZsliceToggle + ) + ) + + self.setTitle(f"{chName} metrics") + self.setCheckable(True) + self.setLayout(layout) + + def toggleCalcForEachZslice(self, label, toggle=None): + if toggle is None: + toggle = self.calcForEachZsliceToggle + + toggle.setChecked(not toggle.isChecked()) + + def isCalcForEachZsliceRequested(self): + if self.calcForEachZsliceToggle is None: + return False + + return self.calcForEachZsliceToggle.isChecked() + + def uncheckAndDisableDataPrepIfPosNotPrepped(self, posData): + # Uncheck and disable dataprep metrics if pos is not prepped + if posData is None: + return + + if posData.isBkgrROIpresent(): + return + + for checkbox in self.checkBoxes: + if checkbox.text().find("dataPrep") == -1: + continue + + checkbox.setChecked(False) + checkbox.isDataPrepDisabled = True + + def _warnDataPrepCannotBeChecked(self): + if self.doNotWarn: + return + txt = html_utils.paragraph(""" + Data prep measurements cannot be saved because you did + not select any background ROI at the data prep step.

    + + You can read more details about data prep metrics by clicking + on the info button besides the measurement's name.

    + + Thank you for you patience! + """) + msg = myMessageBox(showCentered=False) + msg.warning(self, "Metric cannot be saved", txt) + + def standardMetricToggled(self, checked, checkbox=None): + """Method called when a check-box is toggled. It performs the following + actions: + 1. If the user try to check a data prep measurement, such as + dataPrep_amount, and this cannot be saved (checkbox has the attr + `isDataPrepDisabled`) then it warns and explains why it cannot be saved + 2. Make sure that background value median is checked if the user + requires amount or concentration metric. + 3. Do not allow unchecking background value median and explain why. + + Parameters + ---------- + checked : bool + State of the checkbox toggled + checkbox : QtWidgets.QCheckBox, optional + The checkbox that has been toggled. Default is None. If None + use `self.sender()` + """ + if self.is_concat: + return + + if checkbox is None: + checkbox = self.sender() + + if hasattr(checkbox, "isDataPrepDisabled"): + # Warn that user cannot check data prep metrics and uncheck it + if not checkbox.isChecked(): + return + checkbox.setChecked(False) + self._warnDataPrepCannotBeChecked() + return + + self.sigCheckboxToggled.emit(checkbox) + if checkbox.text().find("amount_") == -1: + return + pattern = r"amount_([A-Za-z]+)(_?[A-Za-z0-9]*)" + repl = r"\g<1>_bkgrVal_median\g<2>" + bkgrValMetric = s1 = re.sub(pattern, repl, checkbox.text()) + for bkgrCheckbox in self.groupboxes[1].checkBoxes: + if bkgrCheckbox.text() == bkgrValMetric: + break + else: + # Make sure to not check for similarly named custom metrics + return + + if checked: + bkgrCheckbox.setChecked(True) + bkgrCheckbox.isRequired = True + else: + bkgrCheckbox.setDisabled(False) + bkgrCheckbox.isRequired = False + + def backgroundMetricToggled(self, checked): + """Method called when a checkbox of a background metric is toggled. + Check if the background value is required and explain why it cannot be + unchecked. + + Parameters + ---------- + checked : bool + State of the checkbox toggled + """ + if self.is_concat: + return + + checkbox = self.sender() + if not hasattr(checkbox, "isRequired"): + return + + if not checkbox.isRequired: + return + + if checkbox.isChecked(): + return + + if self.doNotWarn: + return + + checkbox.setChecked(True) + txt = html_utils.paragraph(""" + This background value cannot be unchecked because it is required + by the _amount and _concentration measurements + that you requested to save.

    + + Thank you for you patience! + """) + msg = myMessageBox(showCentered=False) + msg.warning(self, "Background value required", txt) + + def onDelClicked(self, colname_to_del, hlayout): + self.sigDelClicked.emit(colname_to_del, hlayout) + + def checkFavouriteFuncs(self): + self.doNotWarn = True + for groupbox in self.groupboxes: + groupbox.checkFavouriteFuncs() + self.doNotWarn = False + + +class PixelSizeGroupbox(QGroupBox): + sigValueChanged = Signal(float, float, float) + sigReset = Signal() + + def __init__(self, parent=None): + super().__init__("Pixel size", parent) + + mainLayout = QGridLayout() + + row = 0 + label = QLabel("Pixel width (μm): ") + self.pixelWidthWidget = FloatLineEdit(initial=1.0) + mainLayout.addWidget(label, row, 0) + mainLayout.addWidget(self.pixelWidthWidget, row, 1) + + row += 1 + label = QLabel("Pixel height (μm): ") + self.pixelHeightWidget = FloatLineEdit(initial=1.0) + mainLayout.addWidget(label, row, 0) + mainLayout.addWidget(self.pixelHeightWidget, row, 1) + + row += 1 + label = QLabel("Voxel depth (μm): ") + self.voxelDepthWidget = FloatLineEdit(initial=1.0) + mainLayout.addWidget(label, row, 0) + mainLayout.addWidget(self.voxelDepthWidget, row, 1) + + row += 1 + resetButton = reloadPushButton("Reset") + mainLayout.addWidget(resetButton, row, 1, alignment=Qt.AlignRight) + + row += 1 + mainLayout.addWidget(QHLine(), row, 0, 1, 2) + + mainLayout.setColumnStretch(0, 0) + mainLayout.setColumnStretch(1, 1) + + self.setLayout(mainLayout) + + self.pixelWidthWidget.valueChanged.connect(self.emitValueChanged) + self.pixelHeightWidget.valueChanged.connect(self.emitValueChanged) + self.voxelDepthWidget.valueChanged.connect(self.emitValueChanged) + resetButton.clicked.connect(self.emitReset) + + def emitReset(self): + self.sigReset.emit() + + def emitValueChanged(self, value): + PhysicalSizeX = self.pixelWidthWidget.value() + PhysicalSizeY = self.pixelHeightWidget.value() + PhysicalSizeZ = self.voxelDepthWidget.value() + self.sigValueChanged.emit(PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ) + + +class objPropsQGBox(QGroupBox): + def __init__(self, parent=None): + QGroupBox.__init__(self, "Properties", parent) + + mainLayout = QGridLayout() + + row = 0 + label = QLabel("Object ID: ") + self.idSB = IntLineEdit() + mainLayout.addWidget(label, row, 0) + mainLayout.addWidget(self.idSB, row, 1) + + row += 1 + mainLayout.addWidget(QHLine(), row, 0, 1, 2) + + row += 1 + self.notExistingIDLabel = QLabel() + self.notExistingIDLabel.setStyleSheet("font-size:11px; color: rgb(255, 0, 0);") + mainLayout.addWidget( + self.notExistingIDLabel, row, 0, 1, 2, alignment=Qt.AlignCenter + ) + + row += 1 + label = QLabel("Area (pixel): ") + self.cellAreaPxlSB = IntLineEdit(readOnly=True) + mainLayout.addWidget(label, row, 0) + mainLayout.addWidget(self.cellAreaPxlSB, row, 1) + + row += 1 + label = QLabel("Area (µm2): ") + self.cellAreaUm2DSB = FloatLineEdit(readOnly=True) + mainLayout.addWidget(label, row, 0) + mainLayout.addWidget(self.cellAreaUm2DSB, row, 1) + + row += 1 + mainLayout.addWidget(QHLine(), row, 0, 1, 2) + + row += 1 + label = QLabel("Rotational volume (voxel): ") + self.cellVolVoxSB = IntLineEdit(readOnly=True) + mainLayout.addWidget(label, row, 0) + mainLayout.addWidget(self.cellVolVoxSB, row, 1) + + row += 1 + label = QLabel("3D volume (voxel): ") + self.cellVolVox3D_SB = IntLineEdit(readOnly=True) + self.cellVolVox3D_SB.label = label + mainLayout.addWidget(label, row, 0) + mainLayout.addWidget(self.cellVolVox3D_SB, row, 1) + + row += 1 + label = QLabel("Rotational volume (fl): ") + self.cellVolFlDSB = FloatLineEdit(readOnly=True) + mainLayout.addWidget(label, row, 0) + mainLayout.addWidget(self.cellVolFlDSB, row, 1) + + row += 1 + label = QLabel("3D volume (fl): ") + self.cellVolFl3D_DSB = FloatLineEdit(readOnly=True) + self.cellVolFl3D_DSB.label = label + mainLayout.addWidget(label, row, 0) + mainLayout.addWidget(self.cellVolFl3D_DSB, row, 1) + + row += 1 + mainLayout.addWidget(QHLine(), row, 0, 1, 2) + + row += 1 + label = QLabel("Solidity: ") + self.solidityDSB = FloatLineEdit(readOnly=True) + self.solidityDSB.setMaximum(1) + mainLayout.addWidget(label, row, 0) + mainLayout.addWidget(self.solidityDSB, row, 1) + + row += 1 + label = QLabel("Elongation: ") + self.elongationDSB = FloatLineEdit(readOnly=True) + mainLayout.addWidget(label, row, 0) + mainLayout.addWidget(self.elongationDSB, row, 1) + + row += 1 + mainLayout.addWidget(QHLine(), row, 0, 1, 2) + + row += 1 + propsNames = measurements.get_props_names()[1:] + self.additionalPropsCombobox = QComboBox() + self.additionalPropsCombobox.addItems(propsNames) + self.additionalPropsCombobox.indicator = FloatLineEdit(readOnly=True) + mainLayout.addWidget(self.additionalPropsCombobox, row, 0) + mainLayout.addWidget(self.additionalPropsCombobox.indicator, row, 1) + + row += 1 + mainLayout.addWidget(QHLine(), row, 0, 1, 2) + + mainLayout.setColumnStretch(0, 0) + mainLayout.setColumnStretch(1, 1) + + self.setLayout(mainLayout) + + +class objIntesityMeasurQGBox(QGroupBox): + def __init__(self, parent=None): + QGroupBox.__init__(self, "Intensity measurements", parent) + + mainLayout = QGridLayout() + + row = 0 + label = QLabel("Raw intensity measurements") + + row += 1 + label = QLabel("Channel: ") + self.channelCombobox = QComboBox() + self.channelCombobox.addItem("placeholderlong") + mainLayout.addWidget(label, row, 0) + mainLayout.addWidget(self.channelCombobox, row, 1) + + row += 1 + label = QLabel("Minimum: ") + self.minimumDSB = FloatLineEdit(readOnly=True) + mainLayout.addWidget(label, row, 0) + mainLayout.addWidget(self.minimumDSB, row, 1) + + row += 1 + label = QLabel("Maximum: ") + self.maximumDSB = FloatLineEdit(readOnly=True) + mainLayout.addWidget(label, row, 0) + mainLayout.addWidget(self.maximumDSB, row, 1) + + row += 1 + label = QLabel("Mean: ") + self.meanDSB = FloatLineEdit(readOnly=True) + mainLayout.addWidget(label, row, 0) + mainLayout.addWidget(self.meanDSB, row, 1) + + row += 1 + label = QLabel("Median: ") + self.medianDSB = FloatLineEdit(readOnly=True) + mainLayout.addWidget(label, row, 0) + mainLayout.addWidget(self.medianDSB, row, 1) + + row += 1 + metricsDesc = measurements._get_metrics_names() + metricsFunc, _ = measurements.standard_metrics_func() + items = list(set([metricsDesc[key] for key in metricsFunc.keys()])) + items.append("Concentration") + items.sort() + nameFuncDict = {} + for name, desc in metricsDesc.items(): + if name.find("_dataPrepBkgr") != -1 or name.find("_manualBkgr") != -1: + # Skip dataPrepBkgr and manualBkgr since in the dock widget + # we display only autoBkgr metrics + continue + if name.startswith("concentration_"): + # We use amount function because dividing by volume is taken + # care in the GUI + name = "amount_autoBkgr" + nameFuncDict[desc] = metricsFunc[name] + + funcionCombobox = QComboBox() + funcionCombobox.addItems(items) + self.additionalMeasCombobox = funcionCombobox + self.additionalMeasCombobox.indicator = FloatLineEdit(readOnly=True) + self.additionalMeasCombobox.functions = nameFuncDict + mainLayout.addWidget(funcionCombobox, row, 0) + mainLayout.addWidget(self.additionalMeasCombobox.indicator, row, 1) + + self.setLayout(mainLayout) + + def addChannels(self, channels): + self.channelCombobox.clear() + self.channelCombobox.addItems(channels) + + +class SetMeasurementsGroupBox(QGroupBox): + def __init__( + self, + title, + itemsText, + checkable=True, + itemsInfo=None, + lastSelection=None, + itemsInfoUrls=None, + parent=None, + ): + super().__init__(parent) + + if itemsInfo is None: + itemsInfo = {} + + if itemsInfo is None: + itemsInfoUrls = {} + + highlightRgba = _palettes._highlight_rgba() + r, g, b, a = highlightRgba + self._highlightStylesheetColor = f"rgb({r}, {g}, {b})" + + self.setTitle(title) + self.setCheckable(checkable) + + mainLayout = QVBoxLayout() + + scrollArea = QScrollArea() + scrollArea.setWidgetResizable(True) + scrollAreaLayout = QVBoxLayout() + scrollAreaWidget = QWidget() + self.scrollAreaWidget = scrollAreaWidget + self.scrollAreaLayout = scrollAreaLayout + + self.checkboxes = {} + for text in itemsText: + rowLayout = QHBoxLayout() + infoText = itemsInfo.get(text) + infoUrl = itemsInfoUrls.get(text) + if infoText is not None or infoUrl is not None: + infoButton = infoPushButton() + infoButton.setCursor(Qt.WhatsThisCursor) + rowLayout.addWidget(infoButton) + + if infoText is not None: + infoButton.itemText = text + infoButton.infoText = infoText + infoButton.clicked.connect(self.showInfo) + + if infoUrl is not None: + infoButton.itemText = text + infoButton.infoUrl = infoUrl + infoButton.clicked.connect(self.openInfoUrl) + + checkbox = QCheckBox(text) + checkbox.setParent(self.scrollAreaWidget) + checkbox.setChecked(True) + rowLayout.addWidget(checkbox) + rowLayout.addStretch(1) + + self.checkboxes[text] = checkbox + + scrollAreaLayout.addLayout(rowLayout) + + scrollAreaLayout.addStretch(1) + + scrollAreaWidget.setLayout(scrollAreaLayout) + scrollArea.setWidget(scrollAreaWidget) + self.scrollArea = scrollArea + + buttonsLayout = QHBoxLayout() + self.selectAllButton = selectAllPushButton() + self.selectAllButton.sigClicked.connect(self.setCheckedAll) + + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(self.selectAllButton) + self.buttonsLayout = buttonsLayout + + if lastSelection is not None: + self.lastSelection = lastSelection + self.loadLastSelButton = reloadPushButton(" Load last selection... ") + self.loadLastSelButton.clicked.connect(self.loadLastSelection) + buttonsLayout.addWidget(self.loadLastSelButton) + + mainLayout.addWidget(scrollArea) + mainLayout.addSpacing(10) + mainLayout.addLayout(buttonsLayout) + + self.setLayout(mainLayout) + + def openInfoUrl(self): + url = self.sender().infoUrl + QDesktopServices.openUrl(QUrl(url)) + # import webbrowser + # url = self.sender().infoUrl + # webbrowser.open(url) + + def getWidthNoScrollBarNeeded(self): + width = ( + self.scrollArea.verticalScrollBar().sizeHint().width() + # self.scrollAreaLayout.contentsRect().width() + + self.scrollAreaWidget.sizeHint().width() + + 30 + ) + buttonsWidth = 0 + for i in range(self.buttonsLayout.count()): + widget = self.buttonsLayout.itemAt(i).widget() + if not isinstance(widget, QPushButton): + continue + buttonsWidth += widget.sizeHint().width() + 16 + largerWidth = max(width, buttonsWidth) + return largerWidth + + def resizeWidthNoScrollBarNeeded(self): + width = self.getWidthNoScrollBarNeeded() + self.setMinimumWidth(width) + # self.setFixedWidth(width) + + def loadLastSelection(self): + for text, checkbox in self.checkboxes.items(): + checked = self.lastSelection.get(text, False) + checkbox.setChecked(checked) + + def showInfo(self): + infoText = self.sender().infoText + itemText = self.sender().itemText + + title = f"{itemText} description" + msg = myMessageBox() + msg.setWidth(int(self.screen().size().width() / 2)) + msg.information(self, title, infoText) + + def setCheckedAll(self, button, checked): + for checkbox in self.checkboxes.values(): + checkbox.setChecked(checked) + + def highlightCheckboxesFromSearchText(self, text): + for checkbox in self.checkboxes.values(): + if not text: + highlighted = False + else: + highlighted = checkbox.text().lower().find(text.lower()) != -1 + + self.setCheckboxHighlighted(highlighted, checkbox) + + def setCheckboxHighlighted(self, highlighted, checkbox): + if highlighted: + checkbox.setStyleSheet( + f"background: {self._highlightStylesheetColor}; color: black" + ) + self.scrollArea.ensureWidgetVisible(checkbox) + else: + checkbox.setStyleSheet("") + +# Cross-module imports (deferred to avoid import cycles) +from .dialogs import ( + myMessageBox, +) +from .inputs import ( + FloatLineEdit, + IntLineEdit, + QClickableLabel, +) +from .panels import ( + Toggle, +) + diff --git a/cellacdc/widgets/controls/panels.py b/cellacdc/widgets/controls/panels.py new file mode 100644 index 000000000..0da52dec0 --- /dev/null +++ b/cellacdc/widgets/controls/panels.py @@ -0,0 +1,1025 @@ +"""Composite controls: panels.""" + +"""GUI widgets: controls.""" + +from collections import defaultdict, deque +from typing import Dict, List, Union, Iterable, Sequence +import os +import sys +import operator +import time +import re +import datetime +import numpy as np +import pandas as pd +import math +import traceback +import logging +import textwrap +import random + +from functools import partial +from math import ceil + +import skimage.draw +import skimage.morphology + +from matplotlib.colors import ListedColormap, LinearSegmentedColormap +import matplotlib.pyplot as plt +import matplotlib +from matplotlib.backends.backend_agg import FigureCanvasAgg + +from qtpy.QtCore import ( + Signal, + QTimer, + Qt, + QPoint, + QUrl, + Property, + QPropertyAnimation, + QEasingCurve, + QLocale, + QSize, + QRect, + QPointF, + QRect, + QPoint, + QEasingCurve, + QRegularExpression, + QEvent, + QEventLoop, + QPropertyAnimation, + QObject, + QItemSelectionModel, + QAbstractListModel, + QModelIndex, + QByteArray, + QDataStream, + QMimeData, + QAbstractItemModel, + QIODevice, + QItemSelection, + PYQT6, + QRectF, +) +from qtpy.QtGui import ( + QFont, + QPalette, + QColor, + QPen, + QKeyEvent, + QBrush, + QPainter, + QRegularExpressionValidator, + QIcon, + QPixmap, + QKeySequence, + QLinearGradient, + QShowEvent, + QDesktopServices, + QFontMetrics, + QGuiApplication, + QLinearGradient, + QImage, + QCursor, + QPicture, +) +from qtpy.QtWidgets import ( + QTextEdit, + QLabel, + QProgressBar, + QHBoxLayout, + QToolButton, + QCheckBox, + QApplication, + QWidget, + QVBoxLayout, + QMainWindow, + QTreeWidgetItemIterator, + QLineEdit, + QSlider, + QSpinBox, + QGridLayout, + QRadioButton, + QScrollArea, + QSizePolicy, + QComboBox, + QPushButton, + QScrollBar, + QGroupBox, + QAbstractSlider, + QDoubleSpinBox, + QWidgetAction, + QAction, + QTabWidget, + QAbstractSpinBox, + QToolBar, + QStyleOptionSpinBox, + QStyle, + QDialog, + QSpacerItem, + QFrame, + QMenu, + QActionGroup, + QListWidget, + QPlainTextEdit, + QFileDialog, + QListView, + QAbstractItemView, + QTreeWidget, + QTreeWidgetItem, + QListWidgetItem, + QLayout, + QStylePainter, + QGraphicsBlurEffect, + QGraphicsProxyWidget, + QGraphicsObject, + QButtonGroup, + QStyleOptionSlider, +) +import qtpy.compat + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from ... import utils, measurements, is_mac, is_win, html_utils, is_linux +from ... import printl, settings_folderpath +from ... import colors, config +from ... import html_path +from ... import _palettes +from ... import load +from ... import apps +from ... import plot +from ... import annotate +from ... import urls +from ... import _core, core +from ... import QtScoped +from ... import prompts +from ...acdc_regex import float_regex +from ...config import PREPROCESS_MAPPER +from ... import _base_widgets + +from ...components.palette import ( # noqa: E402 + BASE_COLOR, + Gradients, + GradientsImage, + GradientsLabels, + LINEEDIT_INVALID_ENTRY_STYLESHEET, + LINEEDIT_WARNING_STYLESHEET, + LISTWIDGET_STYLESHEET, + PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR, + PROGRESSBAR_QCOLOR, + TEXT_COLOR, + TREEWIDGET_STYLESHEET, + cmaps, + font, + getCustomGradients, + nonInvertibleCmaps, + sign_int_mapper, + str_to_operator_mapper, +) +from ...components.progress import QtHandler, QLog, XStream # noqa: E402 +from ...components.buttons import * # noqa: E402, F403 +from ...components.layout import * # noqa: E402, F403 +from ...components.inputs_basic import * # noqa: E402, F403 +from ...components.path_controls import * # noqa: E402, F403 + +from ...components.lists import * # noqa: E402, F403 +from ...components.base import QBaseWindow # noqa: E402 +from ...components.progress import ( # noqa: E402 + LoadingCircleAnimation, + NoneWidget, + ProgressBar, + ProgressBarWithETA, + QLogConsole, +) + +from ..canvas.plot_items import ( + LabelItem, +) + +class statusBarPermanentLabel(QWidget): + def __init__(self, parent=None): + super().__init__(parent) + + self.rightLabel = QLabel("") + self.leftLabel = QLabel("") + + layout = QHBoxLayout() + layout.addWidget(self.leftLabel) + layout.addStretch(10) + layout.addWidget(self.rightLabel) + + self.setLayout(layout) + + +class listWidget(QListWidget): + def __init__( + self, *args, isMultipleSelection=False, minimizeHeight=False, **kwargs + ): + super().__init__(*args, **kwargs) + self.itemHeight = None + self.setStyleSheet(LISTWIDGET_STYLESHEET) + self.setFont(font) + if isMultipleSelection: + self.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) + + self.minimizeHeight = minimizeHeight + + def setSelectedAll(self, selected): + for i in range(self.count()): + self.item(i).setSelected(selected) + + def setSelectedItems(self, itemsText): + for i in range(self.count()): + item = self.item(i) + item.setSelected(item.text() in itemsText) + + def addItems(self, labels) -> None: + super().addItems(labels) + if self.itemHeight is not None: + self.setItemHeight() + + if self.minimizeHeight: + itemHeight = self.sizeHintForRow(0) + self.setMaximumHeight(itemHeight * self.count() + itemHeight * 2) + + def addItem(self, text): + super().addItem(text) + if self.itemHeight is None: + return + self.setItemHeight() + + def setItemHeight(self, height=40): + self.itemHeight = height + for i in range(self.count()): + item = self.item(i) + item.setSizeHint(QSize(0, height)) + + def selectedItemsText(self): + return [item.text() for item in self.selectedItems()] + + +class OrderableListWidget(QWidget): + sigEnterEvent = Signal(object) + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._labels = [] + + def setParentItem(self, item): + self._item = item + + def setLabelsColor(self, selected): + if selected: + stylesheet = "color : black" + else: + stylesheet = "" + + for label in self._labels: + label.setStyleSheet(stylesheet) + + def enterEvent(self, event): + super().enterEvent(event) + self.setLabelsColor(True) + self.sigEnterEvent.emit(self._item) + + # def leaveEvent(self, event): + # super().leaveEvent(event) + # self.setLabelsColor(self._item.isSelected()) + # printl('leave', self._item.isSelected()) + + def addLabel(self, label): + self._labels.append(label) + self.validPattern = r"^[0-9,\.]+$" + regExp = QRegularExpression(self.validPattern) + self.setValidator(QRegularExpressionValidator(regExp)) + + def values(self): + try: + vals = [float(c) for c in self.text().split(",")] + except Exception as e: + vals = [] + return vals + + +class KeptObjectIDsList(list): + def __init__(self, lineEdit, confirmSelectionAction, *args): + self.lineEdit = lineEdit + self.lineEdit.setText("") + self.confirmSelectionAction = confirmSelectionAction + confirmSelectionAction.setDisabled(True) + super().__init__(*args) + + def setText(self): + text = utils.format_IDs(self) + + self.lineEdit.setText(text) + + def append(self, element, editText=True): + super().append(element) + if editText: + self.setText() + if not self.confirmSelectionAction.isEnabled(): + self.confirmSelectionAction.setEnabled(True) + + def remove(self, element, editText=True): + super().remove(element) + if editText: + self.setText() + if not self: + self.confirmSelectionAction.setEnabled(False) + + +class Toggle(QCheckBox): + def __init__( + self, + label_text="", + initial=None, + width=80, + bg_color="#b3b3b3", + circle_color="#ffffff", + active_color="#26dd66", # '#005ce6', + animation_curve=QEasingCurve.Type.InOutQuad, + ): + QCheckBox.__init__(self) + + # self.setFixedSize(width, 28) + self.setCursor(Qt.PointingHandCursor) + + self._label_text = label_text + self._bg_color = bg_color + self._circle_color = circle_color + self._active_color = active_color + self._disabled_active_color = colors.lighten_color(active_color) + self._disabled_circle_color = colors.lighten_color(circle_color) + self._disabled_bg_color = colors.lighten_color(bg_color, amount=0.5) + self._circle_margin = 4 + + self._circle_position = int(self._circle_margin / 2) + self.animation = QPropertyAnimation(self, b"circle_position", self) + self.animation.setEasingCurve(animation_curve) + self.animation.setDuration(200) + + self.stateChanged.connect(self.start_transition) + self.requestedState = None + + self.installEventFilter(self) + self._isChecked = False + + if initial is not None: + self.setChecked(initial) + + def sizeHint(self): + return QSize(36, 18) + + def eventFilter(self, object, event): + # To get the actual position of the circle we need to wait that + # the widget is visible before setting the state + if event.type() == QEvent.Type.Show and self.requestedState is not None: + self.setChecked(self.requestedState) + return False + + def setChecked(self, state): + # To get the actual position of the circle we need to wait that + # the widget is visible before setting the state + self._isChecked = state + if self.isVisible(): + self.requestedState = None + QCheckBox.setChecked(self, state > 0) + else: + self.requestedState = state + + def isChecked(self): + if self.isVisible(): + return super().isChecked() + else: + return self._isChecked + + def circlePos(self, state: bool): + start = int(self._circle_margin / 2) + if state: + if self.isVisible(): + height, width = self.height(), self.width() + else: + sizeHint = self.sizeHint() + height, width = sizeHint.height(), sizeHint.width() + circle_diameter = height - self._circle_margin + pos = width - start - circle_diameter + else: + pos = start + return pos + + @Property(float) + def circle_position(self): + return self._circle_position + + @circle_position.setter + def circle_position(self, pos): + self._circle_position = pos + self.update() + + def start_transition(self, state): + self.animation.stop() + pos = self.circlePos(state) + self.animation.setEndValue(pos) + self.animation.start() + + def hitButton(self, pos: QPoint): + return self.contentsRect().contains(pos) + + def setDisabled(self, disable): + QCheckBox.setDisabled(self, disable) + if hasattr(self, "label"): + self.label.setDisabled(disable) + self.update() + + def paintEvent(self, e): + circle_color = ( + self._circle_color if self.isEnabled() else self._disabled_circle_color + ) + active_color = ( + self._active_color if self.isEnabled() else self._disabled_active_color + ) + unchecked_color = ( + self._bg_color if self.isEnabled() else self._disabled_bg_color + ) + + # set painter + p = QPainter(self) + p.setRenderHint(QPainter.RenderHint.Antialiasing) + + # set no pen + p.setPen(Qt.NoPen) + + # draw rectangle + rect = QRect(0, 0, self.width(), self.height()) + + if not self.isChecked(): + # Draw background + p.setBrush(QColor(unchecked_color)) + half_h = int(self.height() / 2) + p.drawRoundedRect(0, 0, rect.width(), self.height(), half_h, half_h) + + # Draw circle + p.setBrush(QColor(circle_color)) + p.drawEllipse( + int(self._circle_position), + int(self._circle_margin / 2), + self.height() - self._circle_margin, + self.height() - self._circle_margin, + ) + else: + # Draw background + p.setBrush(QColor(active_color)) + half_h = int(self.height() / 2) + p.drawRoundedRect(0, 0, rect.width(), self.height(), half_h, half_h) + + # Draw circle + p.setBrush(QColor(circle_color)) + p.drawEllipse( + int(self._circle_position), + int(self._circle_margin / 2), + self.height() - self._circle_margin, + self.height() - self._circle_margin, + ) + + p.end() + + +class ToggleTerminalButton(PushButton): + sigClicked = Signal(bool) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setIcon(QIcon(":terminal_up.svg")) + self.setFixedSize(34, 18) + self.setIconSize(QSize(30, 14)) + self.setFlat(True) + self.terminalVisible = False + self.clicked.connect(self.mouseClick) + + def mouseClick(self): + if self.terminalVisible: + self.setIcon(QIcon(":terminal_up.svg")) + self.terminalVisible = False + else: + self.setIcon(QIcon(":terminal_down.svg")) + self.terminalVisible = True + self.sigClicked.emit(self.terminalVisible) + + def showEvent(self, a0) -> None: + self.idlePalette = self.palette() + return super().showEvent(a0) + + def enterEvent(self, event) -> None: + self.setFlat(False) + # pal = self.palette() + # pal.setColor(QPalette.ColorRole.Button, QColor(200, 200, 200)) + # self.setAutoFillBackground(True) + # self.setPalette(pal) + self.update() + return super().enterEvent(event) + + def leaveEvent(self, event) -> None: + self.setFlat(True) + # self.setPalette(self.idlePalette) + self.update() + return super().leaveEvent(event) + + +class expandCollapseButton(PushButton): + sigClicked = Signal() + + def __init__(self, parent=None, **kwargs): + super().__init__(parent, **kwargs) + self.setIcon(QIcon(":expand.svg")) + self.setFlat(True) + self.installEventFilter(self) + self.isExpand = True + self.clicked.connect(self.buttonClicked) + + def buttonClicked(self, checked=False): + if self.isExpand: + self.setIcon(QIcon(":collapse.svg")) + self.isExpand = False + if self.text(): + self.setText(self.text().replace("Hide", "Show")) + else: + self.setIcon(QIcon(":expand.svg")) + self.isExpand = True + if self.text(): + self.setText(self.text().replace("Show", "Hide")) + self.sigClicked.emit() + + def eventFilter(self, object, event): + if event.type() == QEvent.Type.HoverEnter: + self.setFlat(False) + elif event.type() == QEvent.Type.HoverLeave: + self.setFlat(True) + return False + + +class ToggleVisibilityButton(PushButton): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setFlat(True) + # self.setCheckable(True) + self._state = False + self.setIcon(QIcon(":unchecked.svg")) + self.clicked.connect(self.onClicked) + self.setStyleSheet(""" + QPushButton::pressed { + background-color: none; + border-style: none; + } + """) + + def onClicked(self): + self._state = not self._state + if self._state: + self.setIcon(QIcon(":eye-checked.svg")) + else: + self.setIcon(QIcon(":unchecked.svg")) + + +class ToggleVisibilityCheckBox(QCheckBox): + def __init__(self, *args, pixelSize=24): + super().__init__(*args) + self._pixelSize = pixelSize + self.onToggled(False) + self.toggled.connect(self.onToggled) + + def setPixelSize(self, pixelSize): + self._pixelSize = pixelSize + + def onToggled(self, checked): + if checked: + self.setStyleSheet(f""" + QCheckBox::indicator {{ + width: {self._pixelSize}px; + height: {self._pixelSize}px; + }} + + QCheckBox::indicator:checked + {{ + image: url(:eye-checked.svg); + }} + """) + else: + self.setStyleSheet(f""" + QCheckBox::indicator {{ + width: {self._pixelSize}px; + height: {self._pixelSize}px; + }} + + QCheckBox::indicator:unchecked + {{ + image: url(:unchecked.svg); + }} + """) + + +class FeatureSelectorButton(QPushButton): + def __init__(self, text, parent=None, alignment=""): + super().__init__(text, parent=parent) + self._isFeatureSet = False + self._alignment = alignment + self.setCursor(Qt.PointingHandCursor) + + def setFeatureText(self, text): + self.setText(text) + self.setFlat(True) + self._isFeatureSet = True + if self._alignment: + self.setStyleSheet(f"text-align:{self._alignment};") + + def enterEvent(self, event) -> None: + if self._isFeatureSet: + self.setFlat(False) + return super().enterEvent(event) + + def leaveEvent(self, event) -> None: + if self._isFeatureSet: + self.setFlat(True) + self.update() + return super().leaveEvent(event) + + def setSizeLongestText(self, longestText): + currentText = self.text() + self.setText(longestText) + w, h = self.sizeHint().width(), self.sizeHint().height() + self.setMinimumWidth(w + 10) + # self.setMinimumHeight(h+5) + self.setText(currentText) + + +class CheckableSpinBoxWidgets: + def __init__(self, isFloat=True): + if isFloat: + self.spinbox = FloatLineEdit() + else: + self.spinbox = SpinBox() + self.checkbox = QCheckBox("Activate") + self.spinbox.setEnabled(False) + self.checkbox.toggled.connect(self.spinbox.setEnabled) + + def value(self): + if not self.checkbox.isChecked(): + return + return self.spinbox.value() + + +class Label(QLabel): + def __init__(self, parent=None, force_html=False): + super().__init__(parent) + self._force_html = force_html + + def setText(self, text): + if self._force_html: + text = html_utils.paragraph(text) + super().setText(text) + + +class LatexLabel(QLabel): + def __init__(self, latexText, parent=None): + super().__init__(parent) + + latexText = latexText.replace("", "$") + if not latexText.startswith("$"): + latexText = f"${latexText}" + + if not latexText.endswith("$"): + latexText = f"{latexText}$" + + latexText = latexText.replace("
    ", "\n") + + pixmap = self.mathTex_to_QPixmap(latexText) + self.setPixmap(pixmap) + + def mathTex_to_QPixmap(self, mathTex): + # ---- set up a mpl figure instance ---- + + fig = matplotlib.figure.Figure() + fig.patch.set_facecolor("none") + fig.set_canvas(FigureCanvasAgg(fig)) + renderer = fig.canvas.get_renderer() + + # ---- plot the mathTex expression ---- + + ax = fig.add_axes([0, 0, 1, 1]) + ax.axis("off") + ax.patch.set_facecolor("none") + t = ax.text( + 0, 0, mathTex, ha="left", va="bottom", fontsize=13, color=TEXT_COLOR + ) + + # ---- fit figure size to text artist ---- + + fwidth, fheight = fig.get_size_inches() + fig_bbox = fig.get_window_extent(renderer) + + text_bbox = t.get_window_extent(renderer) + + tight_fwidth = text_bbox.width * fwidth / fig_bbox.width + tight_fheight = text_bbox.height * fheight / fig_bbox.height + + fig.set_size_inches(tight_fwidth, tight_fheight) + + # ---- convert mpl figure to QPixmap ---- + + buf, size = fig.canvas.print_to_buffer() + qimage = QImage.rgbSwapped(QImage(buf, size[0], size[1], QImage.Format_ARGB32)) + qpixmap = QPixmap(qimage) + + return qpixmap + + +class SwitchPlaneCombobox(QComboBox): + sigPlaneChanged = Signal(str, str) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.addItems(["xy", "zy", "zx"]) + self._previousPlane = "xy" + self.currentTextChanged.connect(self.emitPlaneChanged) + + def emitPlaneChanged(self, plane): + self.sigPlaneChanged.emit(self._previousPlane, plane) + self._previousPlane = plane + + def setPlane(self, plane): + self.setCurrentText(plane) + + def setCurrentText(self, text): + self._previousPlane = self.plane() + super().setCurrentText(text) + + def plane(self): + return self.currentText() + + def depthAxes(self): + plane = self.plane() + for axes in "xyz": + if axes not in plane: + return axes + + +class CheckableAction(QAction): + clicked = Signal(bool) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.setCheckable(True) + self.toggled.connect(self.emitClicked) + + def emitClicked(self, checked): + self.clicked.emit(checked) + + def setChecked(self, checked): + self.toggled.disconnect() + super().setChecked(checked) + self.toggled.connect(self.emitClicked) + + +class TimestampItem(LabelItem): + sigEditProperties = Signal(object) + sigRemove = Signal(object) + + def __init__( + self, + SizeY, + SizeX, + viewRange, + secondsPerFrame=1, + parent=None, + start_timedelta=None, + ): + self._secondsPerFrame = secondsPerFrame + self._x_pad = 3 + self._y_pad = 2 + self.xmin, self.ymin = 0, 0 + self.SizeY = SizeY + self.SizeX = SizeX + self._highlighted = False + self._parent = parent + if start_timedelta is None: + start_timedelta = datetime.timedelta(seconds=0) + self._start_timedelta = start_timedelta + self.clicked = False + super().__init__(self) + self.updateViewRange(viewRange) + self.createContextMenu() + + def setSecondsPerFrame(self, secondsPerFrame): + self._secondsPerFrame = secondsPerFrame + + def getBboxViewRange(self, viewRange): + xRange, yRange = viewRange + x0, x1 = xRange + y0, y1 = yRange + if x0 < 0: + x0 = 0 + + if x1 > self.SizeX: + x1 = self.SizeX + + if y0 < 0: + y0 = 0 + + if y1 > self.SizeY: + y1 = self.SizeY + + return x0, y0, x1, y1 + + def updateViewRange(self, viewRange): + x0, y0, x1, y1 = self.getBboxViewRange(viewRange) + + self.xmax = x1 + self.xmin = x0 + + self.ymax = y1 + self.ymin = y0 + + def createContextMenu(self): + self.contextMenu = QMenu() + action = QAction("Edit properties...", self.contextMenu) + action.triggered.connect(self.emitEditProperties) + self.contextMenu.addSeparator() + action = QAction("Remove", self.contextMenu) + action.triggered.connect(self.emitRemove) + self.contextMenu.addAction(action) + + def emitRemove(self): + self.sigRemove.emit(self) + + def mousePressed(self, x, y): + self.clicked = True + + def emitEditProperties(self): + self.setHighlighted(False) + self.sigEditProperties.emit(self.properties()) + + def isHighlighted(self): + return self._highlighted + + def setHighlighted(self, highlighted): + if self._highlighted and highlighted: + return + + if not self._highlighted and not highlighted: + return + + super().setText(self.text, bold=highlighted) + + self._highlighted = highlighted + + def showContextMenu(self, x, y): + self.contextMenu.popup(QPoint(int(x), int(y))) + + def setLocationProperty(self, loc: str): + self._loc = loc + + def properties(self): + properties = { + "color": self._color, + "loc": self._loc, + "font_size": int(self._font_size[:-2]), + "start_timedelta": self._start_timedelta, + "move_with_zoom": self._move_with_zoom, + } + return properties + + def draw(self, frame_i, **kwargs): + self.setProperties(**kwargs) + self.update(frame_i) + + def update(self, frame_i): + self.setPosFromLoc() + self.setText(frame_i) + + def setMoveWithZoomProperty(self, move_with_zoom): + self._move_with_zoom = move_with_zoom + + def updatePosViewRangeChanged(self, viewRange): + if self._loc == "custom": + textHeight = self.itemRect().height() + textWidth = self.itemRect().width() + x0p = self.pos().x() + y0p = self.pos().y() + xcp = x0p + textWidth / 2 + ycp = y0p + textHeight / 2 + x0 = self.xmin + y0 = self.ymin + x_range = self.xmax - x0 + y_range = self.ymax - y0 + Dx_perc = (xcp - x0) / x_range + Dy_perc = (ycp - y0) / y_range + + self.updateViewRange(viewRange) + + X0 = self.xmin + Y0 = self.ymin + + X_range = self.xmax - X0 + Y_range = self.ymax - Y0 + + Xcp = X0 + (Dx_perc * X_range) + Ycp = Y0 + (Dy_perc * Y_range) + X0p = Xcp - (textWidth / 2) + Y0p = Ycp - (textHeight / 2) + + y_pos_max = self.ymax - textHeight - self._y_pad + if Y0p > y_pos_max: + Y0p = y_pos_max + + x_pos_max = self.xmax - textWidth - self._x_pad + if X0p > x_pos_max: + X0p = x_pos_max + + self.setPos(X0p, Y0p) + else: + self.updateViewRange(viewRange) + self.setPosFromLoc() + + def setPosFromLoc(self): + textHeight = self.itemRect().height() + textWidth = self.itemRect().width() + if self._loc == "custom": + return + + if self._loc.find("top") != -1: + y0 = self._y_pad + self.ymin + else: + y0 = self.ymax - textHeight - self._y_pad + + if self._loc.find("left") != -1: + x0 = self._x_pad + self.xmin + else: + x0 = self.xmax - textWidth - self._x_pad + + self.setPos(x0, y0) + + def setProperties( + self, + color=(255, 255, 255), + font_size="13px", + loc="top-left", + start_timedelta=None, + move_with_zoom=False, + ): + if start_timedelta is not None: + self._start_timedelta = start_timedelta + self._color = color + self._loc = loc + self._font_size = font_size + self._move_with_zoom = move_with_zoom + + def move(self, xm, ym): + Dy = ym - self.yc + Dx = xm - self.xc + x0 = self.x0c + Dx + y0 = self.y0c + Dy + self.setPos(x0, y0) + + def mousePressed(self, x, y): + self.clicked = True + self.xc, self.yc = x, y + self.x0c = self.pos().x() + self.y0c = self.pos().y() + + def setText(self, frame_i): + if not isinstance(frame_i, int): + return + + seconds = frame_i * self._secondsPerFrame + timedelta = datetime.timedelta(seconds=round(seconds)) + + diff_seconds = timedelta.total_seconds() + self._start_timedelta.total_seconds() + if diff_seconds >= 0: + timedelta = datetime.timedelta(seconds=round(diff_seconds)) + text = str(timedelta) + else: + abs_diff = abs( + timedelta.total_seconds() + self._start_timedelta.total_seconds() + ) + abs_timedelta = datetime.timedelta(seconds=round(abs_diff)) + text = f"-{abs_timedelta}" + + # printl(timedelta) + super().setText(text, color=self._color, size=self._font_size) + + def addToAxis(self, ax): + ax.addItem(self) + + def removeFromAxis(self, ax): + ax.removeItem(self) + +# Cross-module imports (deferred to avoid import cycles) +from .inputs import ( + FloatLineEdit, + SpinBox, +) + diff --git a/cellacdc/widgets/toolbars/__init__.py b/cellacdc/widgets/toolbars/__init__.py new file mode 100644 index 000000000..41e8862da --- /dev/null +++ b/cellacdc/widgets/toolbars/__init__.py @@ -0,0 +1,52 @@ +"""Toolbars.""" + +from ._base import ( + GradientToolButton, + ManualBackgroundToolBar, + ManualTrackingToolBar, + OverlayChannelToolButton, + PointsLayerToolButton, + SavePointsLayerButton, + ToolBar, + ToolBarSeparator, + ToolButtonCustomColor, + ToolButtonTextIcon, + customAnnotToolButton, + rightClickToolButton, +) + +from .feature import ( + CopyLostObjectToolbar, + DrawClearRegionToolbar, + HighlightedIDToolbar, + MagicPromptsToolbar, + OverlayToolbar, + PointsLayersToolbar, + PromptableModelPointsLayerToolbar, + WandControlsToolbar, + WhitelistIDsToolbar, +) + +__all__ = [ + "GradientToolButton", + "ManualBackgroundToolBar", + "ManualTrackingToolBar", + "OverlayChannelToolButton", + "PointsLayerToolButton", + "SavePointsLayerButton", + "ToolBar", + "ToolBarSeparator", + "ToolButtonCustomColor", + "ToolButtonTextIcon", + "customAnnotToolButton", + "rightClickToolButton", + "CopyLostObjectToolbar", + "DrawClearRegionToolbar", + "HighlightedIDToolbar", + "MagicPromptsToolbar", + "OverlayToolbar", + "PointsLayersToolbar", + "PromptableModelPointsLayerToolbar", + "WandControlsToolbar", + "WhitelistIDsToolbar", +] diff --git a/cellacdc/widgets/toolbars/_base.py b/cellacdc/widgets/toolbars/_base.py new file mode 100644 index 000000000..78a43f843 --- /dev/null +++ b/cellacdc/widgets/toolbars/_base.py @@ -0,0 +1,558 @@ +"""Toolbars: _base.""" + +"""GUI widgets: toolbars.""" + +from collections import defaultdict, deque +from typing import Dict, List, Union, Iterable, Sequence +import os +import sys +import operator +import time +import re +import datetime +import numpy as np +import pandas as pd +import math +import traceback +import logging +import textwrap +import random + +from functools import partial +from math import ceil + +import skimage.draw +import skimage.morphology + +from matplotlib.colors import ListedColormap, LinearSegmentedColormap +import matplotlib.pyplot as plt +import matplotlib +from matplotlib.backends.backend_agg import FigureCanvasAgg + +from qtpy.QtCore import ( + Signal, + QTimer, + Qt, + QPoint, + QUrl, + Property, + QPropertyAnimation, + QEasingCurve, + QLocale, + QSize, + QRect, + QPointF, + QRect, + QPoint, + QEasingCurve, + QRegularExpression, + QEvent, + QEventLoop, + QPropertyAnimation, + QObject, + QItemSelectionModel, + QAbstractListModel, + QModelIndex, + QByteArray, + QDataStream, + QMimeData, + QAbstractItemModel, + QIODevice, + QItemSelection, + PYQT6, + QRectF, +) +from qtpy.QtGui import ( + QFont, + QPalette, + QColor, + QPen, + QKeyEvent, + QBrush, + QPainter, + QRegularExpressionValidator, + QIcon, + QPixmap, + QKeySequence, + QLinearGradient, + QShowEvent, + QDesktopServices, + QFontMetrics, + QGuiApplication, + QLinearGradient, + QImage, + QCursor, + QPicture, +) +from qtpy.QtWidgets import ( + QTextEdit, + QLabel, + QProgressBar, + QHBoxLayout, + QToolButton, + QCheckBox, + QApplication, + QWidget, + QVBoxLayout, + QMainWindow, + QTreeWidgetItemIterator, + QLineEdit, + QSlider, + QSpinBox, + QGridLayout, + QRadioButton, + QScrollArea, + QSizePolicy, + QComboBox, + QPushButton, + QScrollBar, + QGroupBox, + QAbstractSlider, + QDoubleSpinBox, + QWidgetAction, + QAction, + QTabWidget, + QAbstractSpinBox, + QToolBar, + QStyleOptionSpinBox, + QStyle, + QDialog, + QSpacerItem, + QFrame, + QMenu, + QActionGroup, + QListWidget, + QPlainTextEdit, + QFileDialog, + QListView, + QAbstractItemView, + QTreeWidget, + QTreeWidgetItem, + QListWidgetItem, + QLayout, + QStylePainter, + QGraphicsBlurEffect, + QGraphicsProxyWidget, + QGraphicsObject, + QButtonGroup, + QStyleOptionSlider, +) +import qtpy.compat + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from ... import utils, measurements, is_mac, is_win, html_utils, is_linux +from ... import printl, settings_folderpath +from ... import colors, config +from ... import html_path +from ... import _palettes +from ... import load +from ... import apps +from ... import plot +from ... import annotate +from ... import urls +from ... import _core, core +from ... import QtScoped +from ... import prompts +from ...acdc_regex import float_regex +from ...config import PREPROCESS_MAPPER +from ... import _base_widgets + +from ...components.palette import ( # noqa: E402 + BASE_COLOR, + Gradients, + GradientsImage, + GradientsLabels, + LINEEDIT_INVALID_ENTRY_STYLESHEET, + LINEEDIT_WARNING_STYLESHEET, + LISTWIDGET_STYLESHEET, + PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR, + PROGRESSBAR_QCOLOR, + TEXT_COLOR, + TREEWIDGET_STYLESHEET, + cmaps, + font, + getCustomGradients, + nonInvertibleCmaps, + sign_int_mapper, + str_to_operator_mapper, +) +from ...components.progress import QtHandler, QLog, XStream # noqa: E402 +from ...components.buttons import * # noqa: E402, F403 +from ...components.layout import * # noqa: E402, F403 +from ...components.inputs_basic import * # noqa: E402, F403 +from ...components.path_controls import * # noqa: E402, F403 + +from ...components.lists import * # noqa: E402, F403 +from ...components.base import QBaseWindow # noqa: E402 +from ...components.progress import ( # noqa: E402 + LoadingCircleAnimation, + NoneWidget, + ProgressBar, + ProgressBarWithETA, + QLogConsole, +) + +class ToolBarSeparator: + def __init__(self, width=5, toolbar: QToolBar = None): + self._parts = ( + QHWidgetSpacer(width=width), + QVLine(), + QHWidgetSpacer(width=width), + ) + self._actions = [] + self._toolbar = None + if toolbar is not None: + self.addToToolbar(toolbar) + + def addToToolbar(self, toolbar): + self._toolbar = toolbar + for part in self._parts: + action = toolbar.addWidget(part) + self._actions.append(action) + + def removeFromToolbar(self): + if self._toolbar is None: + return + + for action in self._actions: + self._toolbar.removeAction(action) + + +class ToolBar(QToolBar): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.widgetsWithShortcut = {} + + for child in self.children(): + if child.objectName() == "qt_toolbar_ext_button": + self.extendButton = child + self.extendButton.setIcon(QIcon(":expand.svg")) + break + + def addSeparator(self, width=5): + separator = ToolBarSeparator(width=width, toolbar=self) + return separator + + def removeSeparator(self, separator): + separator.removeFromToolbar() + + def addSpinBox(self, label=""): + spinbox = SpinBox(disableKeyPress=True) + if label: + spinbox.label = QLabel(label) + spinbox.labelAction = self.addWidget(spinbox.label) + + spinbox.action = self.addWidget(spinbox) + return spinbox + + def addButton(self, icon_str: str, text="", checkable=False): + action = QAction(QIcon(icon_str), text, self) + action.setCheckable(checkable) + self.addAction(action) + return action + + def addComboBox(self, items=None, label=""): + combobox = ComboBox() + + if items is not None: + combobox.addItems(items) + + if label: + combobox.label = QLabel(label) + combobox.labelAction = self.addWidget(combobox.label) + + combobox.action = self.addWidget(combobox) + return combobox + + def addLabel(self, text=""): + label = QLabel(text) + label.action = self.addWidget(label) + return label + + def addCheckBox(self, text="", checked=False): + checkbox = QCheckBox(text) + checkbox.setChecked(checked) + checkbox.action = self.addWidget(checkbox) + return checkbox + + +class rightClickToolButton(QToolButton): + sigRightClick = Signal(object) + sigLeftClick = Signal(object, object) + + def __init__(self, parent=None): + super().__init__(parent) + + def mousePressEvent(self, event): + if event.button() == Qt.MouseButton.LeftButton: + super().mousePressEvent(event) + self.sigLeftClick.emit(self, event) + elif event.button() == Qt.MouseButton.RightButton: + self.sigRightClick.emit(event) + + +class ToolButtonCustomColor(rightClickToolButton): + def __init__(self, symbol, color="r", parent=None): + super().__init__(parent=parent) + if not isinstance(color, QColor): + color = pg.mkColor(color) + self.symbol = symbol + self.setColor(color) + + def setColor(self, color): + self.penColor = color + self.brushColor = [0, 0, 0, 100] + self.brushColor[:3] = color.getRgb()[:3] + + def updateSymbol(self, symbol, update=True): + self.symbol = symbol + if not update: + return + self.update() + + def updateColor(self, color, update=True): + self.setColor(color) + if not update: + return + self.update() + + def updateIcon(self, symbol, color): + self.updateSymbol(symbol) + self.updateColor(color) + self.update() + + def paintEvent(self, event): + QToolButton.paintEvent(self, event) + p = QPainter(self) + w, h = self.width(), self.height() + sf = 0.6 + p.scale(w * sf, h * sf) + p.translate(0.5 / sf, 0.5 / sf) + symbol = pg.graphicsItems.ScatterPlotItem.Symbols[self.symbol] + pen = pg.mkPen(color=self.penColor, width=2) + brush = pg.mkBrush(color=self.brushColor) + try: + p.setRenderHint(QPainter.RenderHint.Antialiasing) + p.setPen(pen) + p.setBrush(brush) + p.drawPath(symbol) + except Exception as e: + traceback.print_exc() + finally: + p.end() + + +class GradientToolButton(rightClickToolButton): + def __init__(self, colors=((255, 0, 0),), parent=None): + super().__init__(parent=parent) + self._qcolors = [pg.mkColor(c) for c in colors] + if len(self._qcolors) < 2: + self._qcolors.append(self._qcolors[0]) + + def paintEvent(self, event): + super().paintEvent(event) + + painter = QPainter(self) + painter.setRenderHint(QPainter.Antialiasing) + + pen = pg.mkPen(color=self._qcolors[-1], width=2) + + pad = 7 + + rect = self.rect().adjusted(pad, pad, -pad, -pad) # A little padding + + # Gradient: bottom to top + gradient = QLinearGradient(QPointF(rect.bottomLeft()), QPointF(rect.topLeft())) + + # Set color stops evenly distributed + num_colors = len(self._qcolors) + for i, color in enumerate(self._qcolors): + gradient.setColorAt(i / (num_colors - 1), color) + + if not self.isChecked(): + painter.setOpacity(0.4) + + painter.setBrush(gradient) + painter.setPen(pen) + painter.drawRect(rect) + + painter.end() + + +class PointsLayerToolButton(ToolButtonCustomColor): + sigEditAppearance = Signal(object) + sigShowIdsToggled = Signal(object, bool) + sigRemove = Signal(object) + + def __init__(self, symbol, color="r", parent=None): + super().__init__(symbol, color=color, parent=parent) + self.sigRightClick.connect(self.showContextMenu) + + def showContextMenu(self, event): + contextMenu = QMenu(self) + contextMenu.addSeparator() + + editAction = QAction("Edit points appearance...") + editAction.triggered.connect(self.editAppearance) + contextMenu.addAction(editAction) + + removeAction = QAction("Remove points") + removeAction.triggered.connect(self.emitRemove) + contextMenu.addAction(removeAction) + + showIdsAction = QAction("Show point ids") + showIdsAction.setCheckable(True) + showIdsAction.setChecked(True) + contextMenu.addAction(showIdsAction) + showIdsAction.toggled.connect(self.emitShowIdsToggled) + + contextMenu.exec(event.globalPos()) + + def emitRemove(self): + self.sigRemove.emit(self) + + def emitShowIdsToggled(self, checked): + self.sigShowIdsToggled.emit(self, checked) + + def editAppearance(self): + self.sigEditAppearance.emit(self) + + +class customAnnotToolButton(ToolButtonCustomColor): + sigRemoveAction = Signal(object) + sigKeepActiveAction = Signal(object) + sigModifyAction = Signal(object) + sigHideAction = Signal(object) + + def __init__( + self, symbol, color, keepToolActive=True, parent=None, isHideChecked=True + ): + super().__init__(symbol, color=color, parent=parent) + self.symbol = symbol + self.keepToolActive = keepToolActive + self.isHideChecked = isHideChecked + self.sigRightClick.connect(self.showContextMenu) + + def showContextMenu(self, event): + contextMenu = QMenu(self) + contextMenu.addSeparator() + + removeAction = QAction("Remove annotation") + removeAction.triggered.connect(self.removeAction) + contextMenu.addAction(removeAction) + + editAction = QAction("Modify annotation parameters...") + editAction.triggered.connect(self.modifyAction) + contextMenu.addAction(editAction) + + hideAction = QAction("Hide annotations") + hideAction.setCheckable(True) + hideAction.setChecked(self.isHideChecked) + hideAction.triggered.connect(self.hideAction) + contextMenu.addAction(hideAction) + + keepActiveAction = QAction("Keep tool active after using it") + keepActiveAction.setCheckable(True) + keepActiveAction.setChecked(self.keepToolActive) + keepActiveAction.triggered.connect(self.keepToolActiveActionToggled) + contextMenu.addAction(keepActiveAction) + + contextMenu.exec(event.globalPos()) + + def keepToolActiveActionToggled(self, checked): + self.keepToolActive = checked + self.sigKeepActiveAction.emit(self) + + def modifyAction(self): + self.sigModifyAction.emit(self) + + def removeAction(self): + self.sigRemoveAction.emit(self) + + def hideAction(self, checked): + self.isHideChecked = checked + self.sigHideAction.emit(self) + + +class ToolButtonTextIcon(rightClickToolButton): + def __init__(self, text="", parent=None): + super().__init__(parent=parent) + self._text = text + self._penColor = _palettes.text_pen_color() + + def setText(self, text): + self._text = text + self.update() + + def text(self): + return self._text + + def paintEvent(self, event): + QToolButton.paintEvent(self, event) + p = QPainter(self) + + pen = pg.mkPen(color=self._penColor, width=2) + p.setPen(pen) + + w, h = self.width(), self.height() + sf = 0.7 + rect_w = w * sf + rect_h = h * sf + x = (w - rect_w) / 2 + y = (h - rect_h) / 2 + rect = QRectF(x, y, rect_w, rect_h) + + font = p.font() + font.setBold(True) + font.setPixelSize(int(h / len(self._text))) + p.setFont(font) + + p.drawText(rect, Qt.AlignCenter, self._text) + p.end() + + +class OverlayChannelToolButton(GradientToolButton): + def __init__( + self, + channel_name: str, + lut_item: myHistogramLUTitem, + shortcut="0", + parent=None, + ): + super().__init__(colors=lut_item.gradient.getLookupTable(256), parent=parent) + self._channel_name = channel_name + + lut_item.sigGradientChanged.connect(self.updateColors) + + self.setToolTip(f'Show/hide "{channel_name}" channel\n\nShortcut: {shortcut}') + + self.setCheckable(True) + + def channelName(self): + return self._channel_name + + def updateColors(self, lut_item): + colors = lut_item.gradient.getLookupTable(256) + self._qcolors = [pg.mkColor(c) for c in colors] + self.update() + + def setVisible(self, visible: bool): + super().setVisible(visible) + if not hasattr(self, "action"): + return + + self.action.setVisible(visible) + +# Cross-module imports (deferred to avoid import cycles) +from ..canvas.histogram import ( + myHistogramLUTitem, +) +from ..controls.inputs import ( + ComboBox, + SpinBox, +) + diff --git a/cellacdc/widgets/toolbars/feature.py b/cellacdc/widgets/toolbars/feature.py new file mode 100644 index 000000000..26c1e3be7 --- /dev/null +++ b/cellacdc/widgets/toolbars/feature.py @@ -0,0 +1,879 @@ +"""Toolbars: feature.""" + +"""GUI widgets: toolbars.""" + +from collections import defaultdict, deque +from typing import Dict, List, Union, Iterable, Sequence +import os +import sys +import operator +import time +import re +import datetime +import numpy as np +import pandas as pd +import math +import traceback +import logging +import textwrap +import random + +from functools import partial +from math import ceil + +import skimage.draw +import skimage.morphology + +from matplotlib.colors import ListedColormap, LinearSegmentedColormap +import matplotlib.pyplot as plt +import matplotlib +from matplotlib.backends.backend_agg import FigureCanvasAgg + +from qtpy.QtCore import ( + Signal, + QTimer, + Qt, + QPoint, + QUrl, + Property, + QPropertyAnimation, + QEasingCurve, + QLocale, + QSize, + QRect, + QPointF, + QRect, + QPoint, + QEasingCurve, + QRegularExpression, + QEvent, + QEventLoop, + QPropertyAnimation, + QObject, + QItemSelectionModel, + QAbstractListModel, + QModelIndex, + QByteArray, + QDataStream, + QMimeData, + QAbstractItemModel, + QIODevice, + QItemSelection, + PYQT6, + QRectF, +) +from qtpy.QtGui import ( + QFont, + QPalette, + QColor, + QPen, + QKeyEvent, + QBrush, + QPainter, + QRegularExpressionValidator, + QIcon, + QPixmap, + QKeySequence, + QLinearGradient, + QShowEvent, + QDesktopServices, + QFontMetrics, + QGuiApplication, + QLinearGradient, + QImage, + QCursor, + QPicture, +) +from qtpy.QtWidgets import ( + QTextEdit, + QLabel, + QProgressBar, + QHBoxLayout, + QToolButton, + QCheckBox, + QApplication, + QWidget, + QVBoxLayout, + QMainWindow, + QTreeWidgetItemIterator, + QLineEdit, + QSlider, + QSpinBox, + QGridLayout, + QRadioButton, + QScrollArea, + QSizePolicy, + QComboBox, + QPushButton, + QScrollBar, + QGroupBox, + QAbstractSlider, + QDoubleSpinBox, + QWidgetAction, + QAction, + QTabWidget, + QAbstractSpinBox, + QToolBar, + QStyleOptionSpinBox, + QStyle, + QDialog, + QSpacerItem, + QFrame, + QMenu, + QActionGroup, + QListWidget, + QPlainTextEdit, + QFileDialog, + QListView, + QAbstractItemView, + QTreeWidget, + QTreeWidgetItem, + QListWidgetItem, + QLayout, + QStylePainter, + QGraphicsBlurEffect, + QGraphicsProxyWidget, + QGraphicsObject, + QButtonGroup, + QStyleOptionSlider, +) +import qtpy.compat + +import pyqtgraph as pg + +pg.setConfigOption("imageAxisOrder", "row-major") + +from ... import utils, measurements, is_mac, is_win, html_utils, is_linux +from ... import printl, settings_folderpath +from ... import colors, config +from ... import html_path +from ... import _palettes +from ... import load +from ... import apps +from ... import plot +from ... import annotate +from ... import urls +from ... import _core, core +from ... import QtScoped +from ... import prompts +from ...acdc_regex import float_regex +from ...config import PREPROCESS_MAPPER +from ... import _base_widgets + +from ...components.palette import ( # noqa: E402 + BASE_COLOR, + Gradients, + GradientsImage, + GradientsLabels, + LINEEDIT_INVALID_ENTRY_STYLESHEET, + LINEEDIT_WARNING_STYLESHEET, + LISTWIDGET_STYLESHEET, + PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR, + PROGRESSBAR_QCOLOR, + TEXT_COLOR, + TREEWIDGET_STYLESHEET, + cmaps, + font, + getCustomGradients, + nonInvertibleCmaps, + sign_int_mapper, + str_to_operator_mapper, +) +from ...components.progress import QtHandler, QLog, XStream # noqa: E402 +from ...components.buttons import * # noqa: E402, F403 +from ...components.layout import * # noqa: E402, F403 +from ...components.inputs_basic import * # noqa: E402, F403 +from ...components.path_controls import * # noqa: E402, F403 + +from ...components.lists import * # noqa: E402, F403 +from ...components.base import QBaseWindow # noqa: E402 +from ...components.progress import ( # noqa: E402 + LoadingCircleAnimation, + NoneWidget, + ProgressBar, + ProgressBarWithETA, + QLogConsole, +) + +from ._base import ( + ToolBar, +) + +class CopyLostObjectToolbar(ToolBar): + sigCopyAllObjects = Signal(int, int) + + def __init__(self, *args) -> None: + super().__init__(*args) + + action = self.addButton(":copyContour_all.svg") + # action.setShortcut('Alt+C') + action.keyPressShortcut = KeySequenceFromText("Alt+C") + action.setToolTip("Copy all lost objects\n\nShortcut: Alt+C") + self.widgetsWithShortcut["Copy all lost objects"] = action + + action.triggered.connect(self.emitSigCopyAllObjects) + + self.addSeparator() + + self.maxOverlapNumberControl = self.addSpinBox( + label="Maximum overlap to accept lost object [%]: " + ) + self.maxOverlapNumberControl.setMinimum(0) + self.maxOverlapNumberControl.setValue(10) + tooltip = ( + "Maximum overlap to accept lost object [%]\n\n" + "If the overlap between the lost object and an object already " + "existing is greater than this value,\n" + "the lost object will not be added." + ) + self.maxOverlapNumberControl.setToolTip(tooltip) + self.maxOverlapNumberControl.label.setToolTip(tooltip) + + self.addSeparator() + + self.untilFrameNumberControl = self.addSpinBox( + label="Copy lost object(s) for the next number of frames: " + ) + self.untilFrameNumberControl.setMinimum(0) + self.untilFrameNumberControl.setValue(0) + + def emitSigCopyAllObjects(self): + self.sigCopyAllObjects.emit( + self.untilFrameNumberControl.value(), self.maxOverlapNumberControl.value() + ) + + +class DrawClearRegionToolbar(ToolBar): + def __init__(self, *args) -> None: + super().__init__(*args) + + group = QButtonGroup() + group.setExclusive(True) + self.clearTouchingObjsRadioButton = QRadioButton("Clear all touching objects") + self.clearOnlyEnclosedObjsRadioButton = QRadioButton( + "Clear only fully enclosed objects" + ) + self.clearOnlyEnclosedObjsRadioButton.setChecked(True) + group.addButton(self.clearTouchingObjsRadioButton) + group.addButton(self.clearOnlyEnclosedObjsRadioButton) + + self.addWidget(self.clearTouchingObjsRadioButton) + self.addWidget(self.clearOnlyEnclosedObjsRadioButton) + + self.addSeparator() + + self.numZslicesUpSpinbox = self.addSpinBox( + label="Num. of z-slices to clear upwards: " + ) + self.numZslicesUpSpinbox.setMinimum(0) + self.numZslicesUpSpinbox.setValue(0) + + self.numZslicesDownSpinbox = self.addSpinBox( + label="Num. of z-slices to clear downwards: " + ) + self.numZslicesDownSpinbox.setMinimum(0) + self.numZslicesDownSpinbox.setValue(0) + + def setZslicesControlEnabled(self, enabled, SizeZ=None): + self.numZslicesUpSpinbox.labelAction.setVisible(enabled) + self.numZslicesUpSpinbox.action.setVisible(enabled) + + self.numZslicesDownSpinbox.labelAction.setVisible(enabled) + self.numZslicesDownSpinbox.action.setVisible(enabled) + + if SizeZ is None: + return + + self.numZslicesUpSpinbox.setMaximum(SizeZ) + self.numZslicesDownSpinbox.setMaximum(SizeZ) + + def zRange(self, z_slice, SizeZ): + if z_slice is None: + zRange = (0, SizeZ) + return zRange + + numZslicesUp = self.numZslicesUpSpinbox.value() + numZslicesDown = self.numZslicesDownSpinbox.value() + + zmin = z_slice - numZslicesDown + zmax = z_slice + numZslicesDown + 1 + + zmin = zmin if zmin >= 0 else 0 + zmax = zmax if zmax <= SizeZ else SizeZ + + return (zmin, zmax) + + +class WhitelistIDsToolbar(ToolBar): + sigWhitelistChanged = Signal(list) + sigViewOGIDs = Signal(bool) + sigWhitelistAccepted = Signal(list) + sigAddNewIDs = Signal(bool) + sigLoadOGLabs = Signal() + sigTrackOGagainstPreviousFrame = Signal(bool) + + def __init__(self, addNewIDToggleState, *args) -> None: + super().__init__(*args) + + whitelistLineEditLabel = QLabel("Whitelist IDs: ") + self.addWidget(whitelistLineEditLabel) + + self.whitelistLineEdit = WhitelistLineEdit(whitelistLineEditLabel, parent=self) + self.whitelistLineEdit.sigEnterPressed.connect(self.accept) + self.whitelistLineEdit.sigIDsChanged.connect(self.emitWhitelistChanged) + self.addWidget(self.whitelistLineEdit) + + # accept button + self.acceptButton = self.addButton(":greenTick.svg") + self.acceptButton.triggered.connect(self.accept) + + # add a view OG toggle + self.viewOGToggle = self.addButton(":eye.svg", checkable=True) + viewOGTooltip = ( + "View the non-whitelisted segmentation mask.\n\n" + "You can activate this to add new IDs to the whitelist,\n" + "correct tracking errors, etc." + ) + self.viewOGToggle.setChecked(True) + self.viewOGToggle.setToolTip(viewOGTooltip) + self.viewOGToggle.setShortcut("Shift+K") + key = "View the non-whitelisted segmentation mask" + self.widgetsWithShortcut[key] = self.viewOGToggle + + self.viewOGToggle.toggled.connect(self.emitViewOGIDs) + self.emitViewOGIDs(True) + + # add a Toggle to add new IDs + self.addNewIDToggle = QCheckBox("Automatically add new IDs to whitelist") + self.addNewIDToggle.setChecked(addNewIDToggleState) + self.addWidget(self.addNewIDToggle) + self.addNewIDToggle.toggled.connect(self.emitAddNewIDs) + self.emitAddNewIDs(addNewIDToggleState) + + self.addSeparator() + + # add a button to load og df + self.loadOGButton = self.addButton(":open_file.svg") + self.loadOGButton.triggered.connect(self.sigLoadOGLabs.emit) + self.loadOGButton.setToolTip( + "Select which segmentation mask file to load as the non-whitelisted masks" + ) + + self.TrackOGagainstPreviousFrameButton = self.addButton(":segment.svg") + self.TrackOGagainstPreviousFrameButton.triggered.connect( + self.sigTrackOGagainstPreviousFrame.emit + ) + self.TrackOGagainstPreviousFrameButton.setToolTip( + "Track the non-whitelisted segmentation masks against the previous frame and copy over successfull tacks" + ) + + self.addSeparator() + + # add an info button + self.infoButton = self.addButton(":info.svg") + self.infoButton.triggered.connect(self.showInfo) + + # add a spacer to the toolbar + spacer = QWidget() + spacer.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Preferred) + self.addWidget(spacer) + + def emitWhitelistChanged(self, whitelist): + self.sigWhitelistChanged.emit(whitelist) + + def emitViewOGIDs(self, checked): + self.sigViewOGIDs.emit(checked) + + def accept(self): + try: + whitelist = self.whitelistLineEdit.IDs + except AttributeError as e: + if "has no attribute 'IDs'" in str(e): + whitelist = list() + self.viewOGToggle.toggled.disconnect() + self.viewOGToggle.setChecked(False) + self.viewOGToggle.toggled.connect(self.emitViewOGIDs) + self.sigWhitelistAccepted.emit(whitelist) + + def emitAddNewIDs(self, checked): + self.sigAddNewIDs.emit(checked) + + def showInfo(self): + msg = myMessageBox(wrapText=False) + txt = html_utils.paragraph(""" + This function is used to track a subset of segmented objects.

    + + To add new IDs to the white list, click with left mouse button on the + object to add.
    + You can also write directly into the Whitelist IDs widget
    + and separate the IDs by commas.

    + + After adding the IDs, click on the "Accept" button to remove the + non-whitelisted objects.
    + Every time you visit a new frame, the non-whitelisted objects will + be removed automatically.

    + Use the "Eye" button to view the non-whitelisted segmentation masks.
    + This will allow you to correct tracking errors, add new IDs to the + white list, etc.

    + + If you previously saved the whitelisted masks, you can load the + non-whitelisted file
    + by clicking on the "Load file" button to restart from where you + left last time. + """) + msg.information(self, "White list IDs", txt) + + +class MagicPromptsToolbar(ToolBar): + sigPromptTypeChanged = Signal(object, str) + sigComputeOnZoom = Signal(object) + sigComputeOnImage = Signal(object) + sigClearPoints = Signal(object) + sigClearPointsOnZmom = Signal(object) + sigInitSelectedModel = Signal(str, object, list, list, str, object) + sigViewModelParams = Signal(str, object, list, list, str, object, object, object) + sigInterpolateZslice = Signal(bool) + + def __init__(self, parent=None): + super().__init__(parent) + + self._parent = parent + + prompt_types = ("Points",) + + self.selectModelAction = self.addButton(":select-list.svg") + self.selectModelAction.setToolTip("Select the promptable model to use") + + self.viewModelParamsAction = self.addButton(":view.svg") + self.viewModelParamsAction.setToolTip( + "View the currently selected model parameters" + ) + self.viewModelParamsAction.setDisabled(True) + + self.addSeparator() + + self.promptTypeCombobox = self.addComboBox( + prompt_types, + label="Prompt type: ", + ) + + self.addSeparator() + + self.interpolateZslicesCheckbox = self.addCheckBox( + "Interpolate points on missing z-slices", checked=False + ) + self.interpolateZslicesCheckbox.setToolTip( + "If checked, when working with 3D segmentation masks, you can " + "add points on some z-slices only and the points on the missing " + "z-slices will be determined by linear interpolation.\n\n" + "This is useful when working with 2D models that segments " + "each z-slice independently.\n\n" + "NOTE: The points will be added only when running the model and " + "removed afterwards." + ) + + self.addSeparator() + + self.computeOnZoomAction = self.addButton(":compute-zoom.svg") + self.computeOnZoomAction.setToolTip( + "Compute the segmentation on the zoomed area of the image (faster)" + ) + + self.computeAction = self.addButton(":compute.svg") + self.computeAction.setToolTip("Compute the segmentation on the whole image") + + self.clearPointsAction = self.addButton(":clear-points.svg") + self.clearPointsAction.setToolTip("Clear all points") + self.clearPointsAction.setDisabled(True) + + self.clearPointsActionOnZoom = self.addButton(":clear-points-zoom.svg") + self.clearPointsActionOnZoom.setToolTip( + "Clear all points on the zoomed area of the image" + ) + self.clearPointsActionOnZoom.setDisabled(True) + + self.addSeparator() + + self.infoAction = self.addButton(":info.svg") + self.infoAction.setToolTip("Show instructions how to use promptable models") + + self.addSeparator() + + self.infoAction.triggered.connect(self.showHelp) + self.selectModelAction.triggered.connect(self.selectModel) + self.viewModelParamsAction.triggered.connect(self.viewModelParams) + self.promptTypeCombobox.sigTextChanged.connect(self.emitPromptTypeChanged) + self.computeOnZoomAction.triggered.connect(self.emitSigComputeOnZoom) + self.computeAction.triggered.connect(self.emitSigComputeOnImage) + self.clearPointsAction.triggered.connect(self.emitSigClearPoints) + self.clearPointsActionOnZoom.triggered.connect(self.emitSigClearPointsOnZoom) + self.interpolateZslicesCheckbox.toggled.connect(self.sigInterpolateZslice.emit) + + def showHelp(self): + msg = myMessageBox(wrapText=False) + txt = html_utils.paragraph(""" + This toolbar allows you to use promptable models for + segmentation.

    + + To use a promptable model, first select the model by clicking on the + "Select model" button.
    + This will open a dialog where you can select the model to use.

    + + After selecting the model, you can view the model parameters + by clicking on the "View model parameters" button.

    + + To add points to the image, make sure you have points layer correctly + initialised. You should see controls
    + called "Left-click ID" and "Right-click ID".

    + + You can add points for a new object by left-clicking on the image, + while you can add points
    + for the same object by right-clicking. + To delete a point, click on it again.

    + + To change the right-click ID, + you can either type in the corresponding control,
    + or type the object id on the keyboard followed by "Enter".

    + + To add negative prompts (i.e., for the background), use the + same action you use to delete objects
    + (default is middle-click on Windows and Cmd+Click on MacOS).

    + Note that you can also add object-specific negative prompts (i.e., + they affect only that object)
    + by adding the negative prompt on the newly segmented object + directly.

    + + Once you are happy with the added points, click either the + "Compute on zoomed area"
    + button or the "Compute on whole image" button.

    + + Finally, you can clear all points by clicking on the + "Clear points" button.

    + + Note that you can also save the points by clicking on the + "Save points" button to load them later and start from + where you left.

    + """) + msg.information(self, "Promptable models help", txt) + + def emitSigClearPoints(self): + self.sigClearPoints.emit(self) + + def emitSigClearPointsOnZoom(self): + self.sigClearPointsOnZmom.emit(self) + + def emitSigComputeOnZoom(self): + self.sigComputeOnZoom.emit(self) + + def emitSigComputeOnImage(self): + self.sigComputeOnImage.emit(self) + + def selectModel(self): + win = apps.SelectPromptableModelDialog(parent=self._parent) + win.exec_() + if win.cancel: + print("Promptable model selection cancelled") + return + + model_name = win.model_name + print(f"Importing promptable model {model_name}...") + + # Download model weights, consistent with gui.py + downloadWin = apps.downloadModel(model_name, parent=self._parent) + downloadWin.download() + + acdcPromptSegment = utils.import_promptable_segment_module(model_name) + init_argspecs, segment_argspecs = utils.getModelArgSpec(acdcPromptSegment) + + try: + help_url = acdcPromptSegment.url_help() + except AttributeError: + help_url = None + + self._model_name = model_name + self._acdcPromptSegment = acdcPromptSegment + self._init_argspecs = init_argspecs + self._segment_argspecs = segment_argspecs + self._help_url = help_url + + self.sigInitSelectedModel.emit( + model_name, + acdcPromptSegment, + init_argspecs, + segment_argspecs, + help_url, + self, + ) + + def setInitializedModel(self, init_kwargs, segment_kwargs): + self._init_kwargs = init_kwargs + self._segment_kwargs = segment_kwargs + + def viewModelParams(self): + self.sigViewModelParams.emit( + self._model_name, + self._acdcPromptSegment, + self._init_argspecs, + self._segment_argspecs, + self._help_url, + self._init_kwargs, + self._segment_kwargs, + self, + ) + + def emitPromptTypeChanged(self, text): + self.sigPromptTypeChanged.emit(self, text) + + +class PointsLayersToolbar(ToolBar): + sigAddPointsLayer = Signal() + + def __init__(self, name="Points layers", parent=None): + + super().__init__(name, parent) + + self.guiWin = parent + + self.setContextMenuPolicy(Qt.PreventContextMenu) + + self.addPointsLayerAction = self.addButton(":addPointsLayer.svg") + + self.addSeparator() + + self.pointsLayersLabel = self.addLabel("Points layers: ") + + self.addPointsLayerAction.triggered.connect(self.emitAddPointsLayer) + self.doAddPointsZslicesInterpolation = False + + def emitAddPointsLayer(self): + self.sigAddPointsLayer.emit() + + def fromActionToDataFrame(self, action, posData, isSegm3D=False): + df = pd.DataFrame(columns=["frame_i", "Cell_ID", "z", "y", "x", "id"]) + frames_vals = [] + IDs = [] + zz = [] + yy = [] + xx = [] + ids = [] + pos_i = self.guiWin.pos_i + if pos_i not in action.pointsData: + printl( + "No points data for position", pos_i + ) # should really not happen, but its not a disaster if it does + return df + pointsDataPos = action.pointsData[pos_i] + for frame_i, framePointsData in pointsDataPos.items(): + if posData.SizeZ > 1: + for z, zSlicePointsData in framePointsData.items(): + yyxx = zip(zSlicePointsData["y"], zSlicePointsData["x"]) + for y, x in yyxx: + if isSegm3D: + ID = posData.lab[int(z), int(y), int(x)] + else: + ID = posData.lab[int(y), int(x)] + frames_vals.append(frame_i) + IDs.append(ID) + zz.append(z) + yy.append(y) + xx.append(x) + ids.extend(zSlicePointsData["id"]) + else: + yyxx = zip(framePointsData["y"], framePointsData["x"]) + for y, x in yyxx: + ID = posData.lab[int(y), int(x)] + frames_vals.append(frame_i) + IDs.append(ID) + yy.append(y) + xx.append(x) + ids.extend(framePointsData["id"]) + df["frame_i"] = frames_vals + df["Cell_ID"] = IDs + df["y"] = yy + df["x"] = xx + df["id"] = ids + if zz: + df["z"] = zz + + df = self.addPointsZslicesInterpolation(df, posData.lab, isSegm3D) + + return df + + def addPointsZslicesInterpolation( + self, df: pd.DataFrame, lab: np.ndarray, isSegm3D: bool + ): + if not self.doAddPointsZslicesInterpolation: + return df + + if not isSegm3D: + return df + + if "z" not in df.columns: + return df + + df_new_rows = [] + for (frame_i, point_id), df_id in df.groupby(["frame_i", "id"]): + xx = df_id["x"].values + yy = df_id["y"].values + zz = df_id["z"].values + + p0, d = core.linear_fit_3d(xx, yy, zz) + + new_row_df = df_id.iloc[[0]].copy() + + z0, z1 = int(np.min(zz)), int(np.max(zz)) + for z in range(z0, z1 + 1): + if z in zz: + continue + + t_int = (z - p0[2]) / d[2] + x_new, y_new, z_new = p0 + t_int * d + new_row_df["z"] = round(z_new) + new_row_df["y"] = round(y_new) + new_row_df["x"] = round(x_new) + + Cell_ID = lab[int(round(z_new)), int(round(y_new)), int(round(x_new))] + new_row_df["Cell_ID"] = Cell_ID + + df_new_rows.append(new_row_df.copy()) + + if not df_new_rows: + return df + + df_new = pd.concat(df_new_rows, ignore_index=True) + df = pd.concat([df, df_new], ignore_index=True) + df = df.sort_values(by=["frame_i", "id", "z"]).reset_index(drop=True) + + return df + + +class PromptableModelPointsLayerToolbar(PointsLayersToolbar): + def __init__(self, name="Promptable model points layers", parent=None): + super().__init__(name, parent=parent) + + self.isPointsLayerInit = False + + self.addPointsLayerAction.setDisabled(True) + self.addPointsLayerAction.setVisible(False) + + def pointsLayerDf(self, posData, isSegm3D=False): + for action in self.actions()[1:]: + if not hasattr(action, "button"): + continue + + df = self.fromActionToDataFrame(action, posData, isSegm3D=isSegm3D) + return df + + def scatterItem(self): + for action in self.actions()[1:]: + if not hasattr(action, "button"): + continue + + return action.scatterItem + + +class OverlayToolbar(ToolBar): + sigSetTranspacency = Signal(bool) + sigSetSingleChannel = Signal(bool) + + def __init__(self, name="Overlay tools", parent=None): + + super().__init__(name, parent) + + self.guiWin = parent + + self.setContextMenuPolicy(Qt.PreventContextMenu) + + self.addSeparator() + + self.transparencyCheckbox = self.addCheckBox( + text="True transparency (RGBA composite)" + ) + + self.transparencyCheckbox.setToolTip( + "Activate to achieve true pixel-wise transparency where " + "the pixel intensity is 0 or set to 0 using the " + "LUT sliders on the left of the images.\n\n" + "Since it is significantly slower, we recommended to activate this " + "only if you need to export images for figures." + ) + + self.addSeparator() + + self.singleChannelCheckbox = self.addCheckBox(text="Single channel") + + self.singleChannelCheckbox.setToolTip( + "When single channel mode is activated, selecting a channel " + "will display only that channel in the overlay." + ) + + self.transparencyCheckbox.toggled.connect(self.sigSetTranspacency.emit) + self.singleChannelCheckbox.toggled.connect(self.sigSetSingleChannel.emit) + + def setTransparent(self, transparent: bool): + self.transparencyCheckbox.setChecked(transparent) + + def isTransparent(self): + return self.transparencyCheckbox.isChecked() + + def isSingleChannel(self): + return self.singleChannelCheckbox.isChecked() + + +class HighlightedIDToolbar(ToolBar): + sigIDChanged = Signal(int) + + def __init__(self, name="Highlighted ID", parent=None): + + super().__init__(name, parent) + + self.spinbox = self.addSpinBox("Highlighted ID: ") + self.spinbox.valueChanged.connect(self.emitSigIDChanged) + + self.addSeparator() + + def emitSigIDChanged(self, *args, **kwargs): + self.sigIDChanged.emit(self.spinbox.value()) + + def setIDNoSignals(self, ID: int): + self.spinbox.blockSignals(True) + self.spinbox.setValue(ID) + self.spinbox.blockSignals(False) + + +class WandControlsToolbar(ToolBar): + def __init__(self, name="Magic wand controls", parent=None): + super().__init__(name, parent) + + self.toleranceSpinbox = self.addSpinBox("Tolerance [%]: ") + self.toleranceSpinbox.setMinimum(0) + self.toleranceSpinbox.setMaximum(100) + self.toleranceSpinbox.setValue(5) + self.toleranceSpinbox.setToolTip( + "The tolerance is calculated as a percentage of the minimum-maximum " + "pixel values range of the loaded dataset.\n\n" + "If tolerance is greater than 0, the pixels adjacent to the added " + "pixels with value within +- tolerance will be considered part of " + "the object." + ) + self.addLabel(r"% of min-max intensity range ") + + self.addSeparator() + + self.autoFillHolesCheckbox = self.addCheckBox("Auto-fill holes") + + self.addSeparator() + + self.useConvexHullCheckbox = self.addCheckBox("Use convex hull mask") + + self.addSeparator() + +# Cross-module imports (deferred to avoid import cycles) +from ..controls.dialogs import ( + myMessageBox, +) +from ..controls.inputs import ( + KeySequenceFromText, + WhitelistLineEdit, +) + diff --git a/cellacdc/workers.py b/cellacdc/workers.py deleted file mode 100755 index 25feb6f72..000000000 --- a/cellacdc/workers.py +++ /dev/null @@ -1,6745 +0,0 @@ -import re -import os -import shutil -import time -import json -import concurrent.futures -from functools import partial -from collections import defaultdict, deque -import itertools - -from typing import Union, List, Dict, Callable, Any, Tuple, Iterable - -from functools import wraps -import numpy as np -import pandas as pd -import h5py -import traceback - -import skimage.io -import skimage.measure -import skimage.exposure - -import queue - -from tqdm import tqdm - -from qtpy.QtCore import ( - Signal, QObject, QMutex, QWaitCondition -) - -from cellacdc import html_utils - -from . import ( - load, myutils, core, prompts, printl, config, - segm_re_pattern, io -) -from . import transformation, measurements, cca_functions -from .path import copy_or_move_tree -from . import features, plot -from . import core -from . import cca_df_colnames, lineage_tree_cols, default_annot_df -from . import cca_df_colnames_with_tree -from . import cli -from .utils import resize -from . import segm_utils - -DEBUG = False - -def worker_exception_handler(func): - @wraps(func) - def run(self): - try: - func(self) - except Exception as error: - printl(traceback.format_exc()) - try: - self.dataQ.clear() - except Exception as err: - pass - - # Some workers have both self.critical and self.signals.critical - # errors but only one of them is connected --> emit both just - # in case - try: - self.critical.emit((self, error)) - except Exception as err: - self.signals.critical.emit((self, error)) - - try: - self.signals.critical.emit((self, error)) - except Exception as err: - self.critical.emit((self, error)) - - try: - self.mutex.unlock() - except Exception as err: - pass - return run - -class workerLogger: - def __init__(self, sigProcess): - self.sigProcess = sigProcess - - def log(self, message, level='INFO'): - try: - self.sigProcess.emit(str(message), level) - except Exception as err: - print(message, level) - try: - traceback_format = traceback.format_exc() - print(traceback_format) - except Exception as err: - pass - printl(err) - finally: - pass - - def info(self, message): - self.log(message, level='INFO') - - def warning(self, message): - self.log(message, level='WARNING') - - def exception(self, message): - self.log(message, level='EXCEPTION') - -class signals(QObject): - progress = Signal(str, object) - finished = Signal(object) - initProgressBar = Signal(int) - progressBar = Signal(int) - critical = Signal(object) - dataIntegrityWarning = Signal(str) - dataIntegrityCritical = Signal() - sigLoadingFinished = Signal() - sigLoadingNewChunk = Signal(object) - resetInnerPbar = Signal(int) - progress_tqdm = Signal(int) - signal_close_tqdm = Signal() - create_tqdm = Signal(int) - innerProgressBar = Signal(int) - sigPermissionError = Signal(str, object) - sigSelectSegmFiles = Signal(object, object) - sigSelectAcdcOutputFiles = Signal(object, object, str, bool, bool) - sigSelectSpotmaxRun = Signal(object, object, object, str, bool, bool) - sigSetMeasurements = Signal(object) - sigInitAddMetrics = Signal(object, object) - sigUpdatePbarDesc = Signal(str) - sigComputeVolume = Signal(int, object) - sigAskStopFrame = Signal(object) - sigWarnMismatchSegmDataShape = Signal(object) - sigErrorsReport = Signal(dict, dict, dict) - sigMissingAcdcAnnot = Signal(dict) - sigRecovery = Signal(object) - sigInitInnerPbar = Signal(int) - sigUpdateInnerPbar = Signal(int) - sigSelectFile = Signal(str, str, str) - sigAskCopyCca = Signal(str) - sigSelectFilesWithText = Signal(str, object, str, object) - sigAskRunNow = Signal(object) - -class AutoPilotWorker(QObject): - finished = Signal() - critical = Signal(object) - progress = Signal(str, object) - sigStarted = Signal() - sigStopTimer = Signal() - - def __init__(self, guiWin): - QObject.__init__(self) - self.logger = workerLogger(self.progress) - self.guiWin = guiWin - self.app = guiWin.app - # self.timer = timer - - def timerCallback(self): - pass - - def stop(self): - self.sigStopTimer.emit() - self.finished.emit() - - def run(self): - self.sigStarted.emit() - -class FindNextNewIdWorker(QObject): - def __init__(self, posData, guiWin): - QObject.__init__(self) - self.signals = signals() - self.logger = workerLogger(self.signals.progress) - self.posData = posData - self.guiWin = guiWin - - @worker_exception_handler - def run(self): - prev_IDs = None - next_frame_i = -1 - for frame_i, data_dict in enumerate(self.posData.allData_li): - lab = data_dict['labels'] - rp = data_dict['regionprops'] - IDs = data_dict['IDs'] - if lab is None: - lab = self.posData.segm_data[frame_i] - rp = skimage.measure.regionprops(lab) - IDs = [obj.label for obj in rp] - - if prev_IDs is None: - prev_IDs = IDs - continue - - newIDs = [ID for ID in IDs if ID not in prev_IDs] - if newIDs: - next_frame_i = frame_i - break - prev_IDs = IDs - - self.signals.finished.emit(next_frame_i) - -class SegForLostIDsWorker(QObject): - sigAskInit = Signal() - sigAskInstallModel = Signal(str) - sigshowImageDebug = Signal(object) - sigStoreData = Signal(bool) - sigUpdateRP = Signal(bool, bool) - # sigGetData = Signal() - # sigGet2Dlab = Signal() - # sigGetTrackedLostIDs = Signal() - # sigGetBrushID = Signal() - sigSegForLostIDsWorkerAskInstallGPU = Signal(str, bool) - sigTrackManuallyAddedObject = Signal(object, object, bool, bool) - - def __init__(self, guiWin, mutex, waitCond, debug=False): - QObject.__init__(self) - self.signals = signals() - self.logger = workerLogger(self.signals.progress) - self.guiWin = guiWin - self.mutex = mutex - self.waitCond = waitCond - self._debug = debug - - def emitSigAskInit(self): - self.mutex.lock() - self.sigAskInit.emit() - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - def emitSigShowImageDebug(self, img): - # self.mutex.lock() - self.sigshowImageDebug.emit(img) - # self.waitCond.wait(self.mutex) - # self.mutex.unlock() - - def emitSigStoreData(self, autosave): - self.mutex.lock() - self.sigStoreData.emit(autosave) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - def emitSigUpdateRP(self, wl_track_og_curr, wl_update): - self.mutex.lock() - self.sigUpdateRP.emit(wl_track_og_curr, wl_update) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - # def emitSigGetData(self): - # self.mutex.lock() - # self.sigGetData.emit() - # self.waitCond.wait(self.mutex) - # self.mutex.unlock() - - def emitSigAskInstallModel(self, model_name): - self.mutex.lock() - self.sigAskInstallModel.emit(model_name) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - def emitSigAskInstallGPU(self, base_model_name, use_gpu): - self.mutex.lock() - self.sigSegForLostIDsWorkerAskInstallGPU.emit(base_model_name, - use_gpu) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - # def emitGet2Dlab(self): - # self.mutex.lock() - # self.sigGet2Dlab.emit() - # self.waitCond.wait(self.mutex) - # self.mutex.unlock() - - # def emitGetTrackedLostIDs(self): - # self.mutex.lock() - # self.sigGetTrackedLostIDs.emit() - # self.waitCond.wait(self.mutex) - # self.mutex.unlock() - - # def emitGetBrushID(self): - # self.mutex.lock() - # self.sigGetBrushID.emit() - # self.waitCond.wait(self.mutex) - # self.mutex.unlock() - - def emitTrackManuallyAddedObject(self, IDs, isLost, wl_update, wl_track_og_curr): - self.mutex.lock() - self.sigTrackManuallyAddedObject.emit(IDs, isLost, wl_update, wl_track_og_curr) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - @worker_exception_handler - def run(self): - posData = self.guiWin.data[self.guiWin.pos_i] - frame_i = posData.frame_i - - if not self.guiWin.SegForLostIDsSettings: - self.emitSigAskInit() - - if not self.guiWin.SegForLostIDsSettings: - self.signals.finished.emit(self) - return - - self.logger.info('Segmentation for lost IDs started.') - model_name = 'local_seg' - base_model_name = self.guiWin.SegForLostIDsSettings['base_model_name'] - idx = self.guiWin.modelNames.index(model_name) - acdcSegment = self.guiWin.acdcSegment_li[idx] - - init_kwargs = self.guiWin.SegForLostIDsSettings['win'].init_kwargs - - use_gpu = init_kwargs.get('device_type', 'cpu') != 'cpu' - use_gpu = use_gpu or init_kwargs.get('use_gpu', False) - - self.emitSigAskInstallGPU(base_model_name, use_gpu) - - if not self.gpu_go: - self.signals.finished.emit(self) - return - - if not self.dont_force_cpu: - if 'device' in init_kwargs: - init_kwargs['device'] = 'cpu' - if 'use_gpu' in init_kwargs: - init_kwargs['use_gpu'] = False - - if acdcSegment is None or base_model_name != self.guiWin.local_seg_base_model_name: - try: - self.logger.info(f'Importing {base_model_name}...') - self.emitSigAskInstallModel(base_model_name) - acdcSegment = myutils.import_segment_module(base_model_name) - self.guiWin.acdcSegment_li[idx] = acdcSegment - self.guiWin.local_seg_base_model_name = base_model_name - except (IndexError, ImportError, KeyError) as e: - self.logger.warning( - f'Cannot import {base_model_name} model. ' - 'Please install it first.' - ) - self.signals.critical.emit( - (self, f'Cannot import {base_model_name} model. ' - 'Please install it first.') - ) - self.signals.finished.emit(self) - return - - win = self.guiWin.SegForLostIDsSettings['win'] - init_kwargs_new = self.guiWin.SegForLostIDsSettings['init_kwargs_new'] - args_new = self.guiWin.SegForLostIDsSettings['args_new'] - - model = myutils.init_segm_model(acdcSegment, posData, init_kwargs_new) - if model is None: - self.logger.info('Segmentation model was not initialized correctly!') - self.signals.critical.emit( - (self, 'Segmentation model was not initialized correctly!') - ) - self.signals.finished.emit(self) - return - if self._debug: - try: - model.setupLogger(self.guiwin.logger) - except Exception as e: - pass - - assigned_IDs = [] - missing_IDs_global = set() - original_lab = posData.lab.copy() - IDs_bboxs_list = [] - bboxs_list = [] - - curr_img = self.guiWin.getDisplayedImg1() - prev_lab = self.guiWin.get_2Dlab(posData.allData_li[frame_i-1]['labels']) - prev_IDs = set(posData.allData_li[frame_i-1]['IDs']) - - # should probably not paly so much with posData.lab, instead handle stuff myself - self.signals.initProgressBar.emit(2 * args_new['max_iterations']) - new_labs = np.zeros([args_new['max_iterations'], *posData.lab.shape], dtype=np.uint32) - for i in range(args_new['max_iterations']): - curr_lab = self.guiWin.get_2Dlab(posData.lab) - tracked_lost_IDs = self.guiWin.getTrackedLostIDs() - new_unique_ID = self.guiWin.setBrushID(useCurrentLab=True, return_val=True) - - missing_IDs = prev_IDs - set(posData.IDs) - set(tracked_lost_IDs) - missing_IDs_global.update(missing_IDs) - - assigned_IDs_prev = assigned_IDs.copy() - out = segm_utils.single_cell_seg( - model, prev_lab, curr_lab, curr_img, - missing_IDs, new_unique_ID, - win, posData, - distance_filler_growth=args_new['distance_filler_growth'], - overlap_threshold=args_new['overlap_threshold'], - padding=args_new['padding'], - ) - new_lab, assigned_IDs, IDs_bboxs, bboxs = out - - IDs_bboxs_list.append(IDs_bboxs) - bboxs_list.append(bboxs) - posData.lab = new_lab - self.emitSigUpdateRP(wl_update=True, wl_track_og_curr=False) - newly_assigned_IDs = set(assigned_IDs) - set(assigned_IDs_prev) - self.emitTrackManuallyAddedObject(newly_assigned_IDs, True, False, False) - new_labs[i] = posData.lab.copy() - self.signals.progressBar.emit(1) - - if self._debug: - originals = [] - models = [] - - posData.lab = original_lab.copy() - - global_area_mean = np.mean([obj.area for obj in posData.rp]) - for IDs_bboxs, bboxs in zip(IDs_bboxs_list, bboxs_list): - model_lab = new_labs[i] - if self._debug: - originals.append(original_lab.copy()) - models.append(posData.lab.copy()) - - for IDs, bbox in zip(IDs_bboxs, bboxs): - - box_x_min, box_x_max, box_y_min, box_y_max = bbox - original_bbox_lab = original_lab[box_x_min:box_x_max, box_y_min:box_y_max] - original_bbox_lab_cleared_borders = skimage.segmentation.clear_border(original_bbox_lab) - box_model_lab = model_lab[box_x_min:box_x_max, box_y_min:box_y_max] - - # original_bbox_lab[np.isin(original_bbox_lab, IDs)] = 0 should be a given. If not seg for lost IDs this recommended - - box_model_lab = skimage.segmentation.clear_border(box_model_lab, buffer_size=1) - - rp_model_lab = skimage.measure.regionprops(box_model_lab) - rp_original_lab = skimage.measure.regionprops(original_bbox_lab) - rp_original_lab_cleared = skimage.measure.regionprops(original_bbox_lab_cleared_borders) - - original_IDs = [obj.label for obj in rp_original_lab] - areas = [obj.area for obj in rp_original_lab_cleared] - if len(areas) > 0: - area_mean = np.mean(areas) - else: - area_mean = global_area_mean - if args_new['allow_only_tracked_cells']: - filtered_IDs = [obj.label for obj in rp_model_lab - if obj.area > (1 - args_new['size_perc_diff']) * area_mean - and obj.area < (1 + args_new['size_perc_diff']) * area_mean - and obj.label not in original_IDs - and obj.label in missing_IDs_global] - else: - filtered_IDs = [obj.label for obj in rp_model_lab - if obj.area > (1 - args_new['size_perc_diff']) * area_mean - and obj.area < (1 + args_new['size_perc_diff']) * area_mean - and obj.label not in original_IDs] - - if self._debug or DEBUG: - filtered_sizes = [(obj.label, obj.area) for obj in rp_model_lab if obj.label in filtered_IDs] - self.logger.info(f"Filtered sizes: {filtered_sizes}") - for label in filtered_IDs: - original_bbox_lab[box_model_lab == label] = label # here the stuff should be tracked, so we keep the ID! - - # original_lab[box_x_min:box_x_max, box_y_min:box_y_max] = original_bbox_lab - - self.signals.progressBar.emit(1) - - posData.lab = original_lab - - # if self._debug: - # originals = np.concatenate(originals, axis=0) - # models = np.concatenate(models, axis=0) - # self.emitSigShowImageDebug(originals) - # self.emitSigShowImageDebug(models) - - self.emitSigUpdateRP(wl_track_og_curr=True, wl_update=True) - self.emitSigStoreData(autosave=True) - - self.logger.info('Segmentation for lost IDs done.') - - self.signals.finished.emit(self) - -class AlignDataWorker(QObject): - sigWarnTifAligned = Signal(object, object, object) - sigAskAlignSegmData = Signal() - - def __init__(self, posData, dataPrepWin, mutex, waitCond): - QObject.__init__(self) - self.signals = signals() - self.logger = workerLogger(self.signals.progress) - self.posData = posData - self.dataPrepWin = dataPrepWin - self.mutex = mutex - self.waitCond = waitCond - self.doNotAlignSegmData = False - self.doAbort = False - - def set_attr(self, align, user_ch_name): - self.align = align - self.user_ch_name = user_ch_name - - def pause(self): - self.mutex.lock() - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - def restart(self): - self.waitCond.wakeAll() - - def emitWarnTifAligned(self, numFramesWith0s, tif, posData): - self.sigWarnTifAligned.emit(numFramesWith0s, tif, posData) - self.pause() - - def emitSigAskAlignSegmData(self): - self.sigAskAlignSegmData.emit() - self.pause() - - def _align_data(self): - _zip = zip(self.posData.tif_paths, self.posData.npz_paths) - aligned = False - self.posData.all_npz_paths = [ - tif.replace('.tif', '_aligned.npz') for tif in self.posData.tif_paths - ] - for i, (tif, npz) in enumerate(_zip): - doAlign = npz is None or self.posData.loaded_shifts is None - - filename_tif = os.path.basename(tif) - user_ch_filename = f'{self.posData.basename}{self.user_ch_name}.tif' - - if not doAlign: - _npz = f'{os.path.splitext(tif)[0]}_aligned.npz' - if os.path.exists(_npz): - self.posData.all_npz_paths[i] = _npz - continue - - if filename_tif != user_ch_filename: - continue - - if not self.align: - continue - - # Align based on user_ch_name - aligned = True - self.logger.log(f'Aligning: {tif}') - - tif_data = load.imread(tif) - numFramesWith0s = self.dataPrepWin.detectTifAlignment( - tif_data, self.posData - ) - if self.align: - self.emitWarnTifAligned( - numFramesWith0s, tif, self.posData - ) - if self.doAbort: - return - - # Alignment routine - if self.posData.SizeZ>1: - align_func = core.align_frames_3D - df = self.posData.segmInfo_df.loc[self.posData.filename] - zz = df['z_slice_used_dataPrep'].to_list() - if not self.posData.filename.endswith('aligned') and self.align: - # Add aligned channel to segmInfo - df_aligned = self.posData.segmInfo_df.rename( - index={self.posData.filename: f'{self.posData.filename}_aligned'} - ) - self.posData.segmInfo_df = pd.concat( - [self.posData.segmInfo_df, df_aligned] - ) - self.posData.segmInfo_df.to_csv(self.posData.segmInfo_df_csv_path) - else: - align_func = core.align_frames_2D - zz = None - - if self.align: - self.signals.initProgressBar.emit(len(tif_data)) - aligned_frames, shifts = align_func( - tif_data, slices=zz, user_shifts=self.posData.loaded_shifts, - sigPyqt=self.signals.progressBar - ) - self.posData.loaded_shifts = shifts - else: - aligned_frames = tif_data - - if self.align: - self.signals.initProgressBar.emit(0) - _npz = f'{os.path.splitext(tif)[0]}_aligned.npz' - self.logger.log(f'Storing temporary file: {_npz}') - temp_npz = self.dataPrepWin.getTempfilePath(_npz) - io.savez_compressed(temp_npz, aligned_frames) - self.dataPrepWin.storeTempFileMove(temp_npz, _npz) - np.save( - self.posData.align_shifts_path, self.posData.loaded_shifts - ) - self.posData.all_npz_paths[i] = _npz - - self.logger.log(f'Storing temporary file: {tif}') - temp_tif = self.dataPrepWin.getTempfilePath(tif) - myutils.to_tiff(temp_tif, aligned_frames) - self.dataPrepWin.storeTempFileMove(temp_tif, tif) - self.posData.img_data = load.imread(temp_tif) - - _zip = zip(self.posData.tif_paths, self.posData.npz_paths) - for i, (tif, npz) in enumerate(_zip): - doAlign = npz is None or aligned - - if not doAlign: - continue - - if tif.endswith(f'{self.user_ch_name}.tif'): - continue - - if not self.align: - continue - - # Align the other channels - if self.posData.loaded_shifts is None: - break - - if self.align: - self.logger.log(f'Aligning: {tif}') - tif_data = load.imread(tif) - - # Alignment routine - if self.posData.SizeZ>1: - align_func = core.align_frames_3D - df = self.posData.segmInfo_df.loc[self.posData.filename] - zz = df['z_slice_used_dataPrep'].to_list() - else: - align_func = core.align_frames_2D - zz = None - if self.align: - self.signals.initProgressBar.emit(len(tif_data)) - aligned_frames, shifts = align_func( - tif_data, slices=zz, user_shifts=self.posData.loaded_shifts, - sigPyqt=self.signals.progressBar - ) - else: - aligned_frames = tif_data - - _npz = f'{os.path.splitext(tif)[0]}_aligned.npz' - - if self.align: - self.signals.initProgressBar.emit(0) - self.logger.log(f'Saving: {_npz}') - temp_npz = self.dataPrepWin.getTempfilePath(_npz) - io.savez_compressed(temp_npz, aligned_frames) - self.dataPrepWin.storeTempFileMove(temp_npz, _npz) - self.posData.all_npz_paths[i] = _npz - - self.logger.log(f'Saving: {tif}') - temp_tif = self.dataPrepWin.getTempfilePath(tif) - myutils.to_tiff(temp_tif, aligned_frames) - self.dataPrepWin.storeTempFileMove(temp_tif, tif) - - if not aligned: - return - - if not self.posData.segmFound: - return - - # Align segmentation data accordingly - self.segmAligned = False - if self.posData.loaded_shifts is None or not self.align: - return - - self.emitSigAskAlignSegmData() - if self.doNotAlignSegmData: - return - - self.dataPrepWin.segmAligned = True - self.logger.log(f'Aligning: {self.posData.segm_npz_path}') - self.posData.segm_data, shifts = core.align_frames_2D( - self.posData.segm_data, slices=None, - user_shifts=self.posData.loaded_shifts - ) - self.logger.log(f'Saving: {self.posData.segm_npz_path}') - temp_npz = self.dataPrepWin.getTempfilePath(self.posData.segm_npz_path) - io.savez_compressed(temp_npz, self.posData.segm_data) - self.dataPrepWin.storeTempFileMove(temp_npz, self.posData.segm_npz_path) - - @worker_exception_handler - def run(self): - self._align_data() - self.signals.finished.emit(self) - -class LabelRoiWorker(QObject): - finished = Signal() - critical = Signal(object) - progress = Signal(str, object) - sigProgressBar = Signal(int) - sigLabellingDone = Signal(object, bool) - - def __init__(self, Gui): - QObject.__init__(self) - self.logger = workerLogger(self.progress) - self.Gui = Gui - self.mutex = Gui.labelRoiMutex - self.waitCond = Gui.labelRoiWaitCond - self.exit = False - self.started = False - - def pause(self): - self.logger.log('Draw box around object to start magic labeller.') - self.mutex.lock() - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - def start(self, roiImg, posData, roiSecondChannel=None, isTimelapse=False): - self.posData = posData - self.isTimelapse = isTimelapse - self.imageData = roiImg - self.roiSecondChannel = roiSecondChannel - self.restart() - - def restart(self, log=True): - if log: - self.logger.log('Magic labeller started...') - self.started = True - self.waitCond.wakeAll() - - def _stop(self): - self.logger.log('Magic labeller backend process done. Closing it...') - self.exit = True - self.waitCond.wakeAll() - - def _segment_image(self, img, secondChannelImg): - if secondChannelImg is not None: - img = self.Gui.labelRoiModel.second_ch_img_to_stack( - img, secondChannelImg - ) - - lab = core.segm_model_segment( - self.Gui.labelRoiModel, img, self.Gui.model_kwargs, - preproc_recipe=self.Gui.preproc_recipe, - posData=self.posData - ) - if self.Gui.applyPostProcessing: - lab = core.post_process_segm( - lab, **self.Gui.standardPostProcessKwargs - ) - if self.Gui.customPostProcessFeatures: - lab = features.custom_post_process_segm( - self.posData, self.Gui.customPostProcessGroupedFeatures, - lab, img, self.posData.frame_i, self.posData.filename, - self.posData.user_ch_name, - self.Gui.customPostProcessFeatures - ) - return lab - - @worker_exception_handler - def run(self): - while not self.exit: - if self.exit: - break - elif self.started: - self.logger.log('Magic labeller is doing its magic...') - if self.isTimelapse: - segmData = np.zeros(self.imageData.shape, dtype=np.uint32) - for frame_i, img in enumerate(self.imageData): - if self.roiSecondChannel is not None: - secondChannelImg = self.roiSecondChannel[frame_i] - else: - secondChannelImg = None - lab = self._segment_image(img, secondChannelImg) - segmData[frame_i] = lab - self.sigProgressBar.emit(1) - else: - img = self.imageData - secondChannelImg = self.roiSecondChannel - segmData = self._segment_image(img, secondChannelImg) - - self.sigLabellingDone.emit(segmData, self.isTimelapse) - self.started = False - self.pause() - self.finished.emit() - -class StoreGuiStateWorker(QObject): - finished = Signal(object) - sigDone = Signal() - progress = Signal(str, object) - - def __init__(self, mutex, waitCond): - QObject.__init__(self) - self.mutex = mutex - self.waitCond = waitCond - self.exit = False - self.isFinished = False - self.q = queue.Queue() - self.logger = workerLogger(self.progress) - - def pause(self): - self.mutex.lock() - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - def enqueue(self, posData, img1): - self.q.put((posData, img1)) - self.waitCond.wakeAll() - - def _stop(self): - self.exit = True - self.waitCond.wakeAll() - - def run(self): - while True: - if self.exit: - self.logger.log('Closing store state worker...') - break - elif not self.q.empty(): - posData, img1 = self.q.get() - # self.logger.log('Storing state...') - if posData.cca_df is not None: - cca_df = posData.cca_df.copy() - else: - cca_df = None - - state = { - 'image': img1.copy(), - 'labels': posData.storedLab.copy(), - 'editID_info': posData.editID_info.copy(), - 'binnedIDs': posData.binnedIDs.copy(), - 'ripIDs': posData.ripIDs.copy(), - 'cca_df': cca_df - } - posData.UndoRedoStates[posData.frame_i].insert(0, state) - if self.q.empty(): - # self.logger.log('State stored...') - self.sigDone.emit() - else: - self.pause() - - self.isFinished = True - self.finished.emit(self) - -class AutoSaveWorker(QObject): - finished = Signal(object) - sigDone = Signal() - critical = Signal(object) - progress = Signal(str, object) - sigStartTimer = Signal(object, object) - sigStopTimer = Signal() - sigAutoSaveCannotProceed = Signal() - - def __init__(self, mutex, waitCond, savedSegmData): - QObject.__init__(self) - self.savedSegmData = savedSegmData - self.logger = workerLogger(self.progress) - self.mutex = mutex - self.waitCond = waitCond - self.exit = False - self.isFinished = False - self.stopSaving = False - self.isSaving = False - self.isPaused = False - self.dataQ = deque(maxlen=5) - self.isAutoSaveON = False - self.isAutoSaveAnnotON = True - self.debug = False - - def pause(self): - if self.debug: - self.logger.log('Autosaving is idle.') - self.mutex.lock() - self.isPaused = True - self.waitCond.wait(self.mutex) - self.mutex.unlock() - self.isPaused = False - - def enqueue(self, posData): - # First stop previously saving data - if self.isSaving: - self.stopSaving = True - self._enqueue(posData) - - def _enqueue(self, posData): - if self.debug: - self.logger.log('Enqueing posData autosave...') - self.dataQ.append(posData) - if len(self.dataQ) == 1: - # Wake up worker upon inserting first element - self.stopSaving = False - self.waitCond.wakeAll() - - def _stop(self): - self.exit = True - self.waitCond.wakeAll() - - def stop(self): - self.stopSaving = True - while not len(self.dataQ) == 0: - data = self.dataQ.pop() - del data - self._stop() - - def cancelSaving(self): - ... - - @worker_exception_handler - def run(self): - while True: - if self.exit: - self.logger.log('Closing autosaving worker...') - break - elif not len(self.dataQ) == 0: - if self.debug: - self.logger.log('Autosaving...') - data = self.dataQ.pop() - self.isSaving = True - try: - self.saveData(data) - except Exception as e: - error = traceback.format_exc() - print('*'*40) - self.logger.log(error) - print('='*40) - self.isSaving = False - - if len(self.dataQ) == 0: - self.sigDone.emit() - else: - self.pause() - self.isFinished = True - self.finished.emit(self) - if self.debug: - self.logger.log('Autosave finished signal emitted') - - def getLastTrackedFrame(self, posData): - last_tracked_i = 0 - for frame_i, data_dict in enumerate(posData.allData_li): - lab = data_dict['labels'] - if lab is None: - frame_i -= 1 - break - if frame_i > 0: - return frame_i - else: - return last_tracked_i - - def saveData(self, posData): - if self.debug: - self.logger.log('Started autosaving...') - - if not self.isAutoSaveON and not self.isAutoSaveAnnotON: - return - - try: - posData.setTempPaths() - except Exception as e: - self.logger.log( - '[WARNING]: Cell-ACDC cannot create the recovery folder for ' - 'the autosaving process. Autosaving will be turned off.' - ) - self.sigAutoSaveCannotProceed.emit() - return - segm_npz_path = posData.segm_npz_temp_path - - end_i = self.getLastTrackedFrame(posData) - - saved_segm_data = None - if self.isAutoSaveON: - if end_i < len(posData.segm_data): - saved_segm_data = posData.segm_data - else: - frame_shape = posData.segm_data.shape[1:] - segm_shape = (end_i+1, *frame_shape) - saved_segm_data = np.zeros(segm_shape, dtype=np.uint32) - - keys = [] - acdc_df_li = [] - - for frame_i, data_dict in enumerate(posData.allData_li[:end_i+1]): - if self.stopSaving: - break - - # Build saved_segm_data - lab = data_dict['labels'] - if lab is None: - break - - if self.isAutoSaveON and saved_segm_data is not None: - if posData.SizeT > 1: - saved_segm_data[frame_i] = lab - else: - saved_segm_data = lab - - if self.isAutoSaveAnnotON: - acdc_df = data_dict['acdc_df'] - - if acdc_df is None: - continue - - if not np.any(lab): - continue - - if self.isAutoSaveAnnotON: - acdc_df = load.pd_bool_and_float_to_int_to_str( - acdc_df, inplace=False, colsToCastInt=[] - ) - - acdc_df_li.append(acdc_df) - key = (frame_i, posData.TimeIncrement*frame_i) - keys.append(key) - - if self.stopSaving: - break - - if not self.stopSaving: - if self.isAutoSaveON: - segm_data = np.squeeze(saved_segm_data) - self._saveSegm(segm_npz_path, segm_data) - - if acdc_df_li: - all_frames_acdc_df = pd.concat( - acdc_df_li, keys=keys, - names=['frame_i', 'time_seconds', 'Cell_ID'] - ) - self._save_acdc_df(all_frames_acdc_df, posData) - - if self.debug: - self.logger.log(f'Autosaving done.') - self.logger.log(f'Stopped autosaving {self.stopSaving}.') - - self.stopSaving = False - - def _saveSegm(self, recovery_path, data): - try: - equalToSavedSegm = np.all(self.savedSegmData == data) - except Exception as err: - return - - if equalToSavedSegm: - return - else: - io.savez_compressed(recovery_path, np.squeeze(data)) - - def _save_acdc_df(self, recovery_acdc_df: pd.DataFrame, posData): - recovery_folderpath = posData.recoveryFolderpath() - if not os.path.exists(posData.acdc_output_csv_path): - load.store_unsaved_acdc_df(recovery_folderpath, recovery_acdc_df) - return - - saved_acdc_df_path = posData.acdc_output_csv_path - saved_acdc_df = ( - pd.read_csv(saved_acdc_df_path, dtype=load.acdc_df_str_cols) - .set_index(['frame_i', 'Cell_ID']) - ) - - recovery_acdc_df = ( - recovery_acdc_df.reset_index(allow_duplicates=True) - .set_index(['frame_i', 'Cell_ID']) - ) - recovery_acdc_df = recovery_acdc_df.loc[ - :, ~recovery_acdc_df.columns.duplicated() - ] - try: - # Try to insert into the recovery_acdc_df any column that was saved - # but is not in the recovered df (e.g., metrics) - df_left = recovery_acdc_df - existing_cols = df_left.columns.intersection(saved_acdc_df.columns) - df_right = saved_acdc_df.drop(columns=existing_cols) - recovery_acdc_df = df_left.join(df_right, how='left') - except Exception as error: - self.logger.log(f'[WARNING]: {error}') - - # Check if last saved acdc_df is equal - last_unsaved_csv_path = load.get_last_stored_unsaved_acdc_df_filepath( - recovery_folderpath - ) - if last_unsaved_csv_path is None: - reference_acdc_df = saved_acdc_df - else: - try: - reference_acdc_df = ( - pd.read_csv(last_unsaved_csv_path, dtype=load.acdc_df_str_cols) - .set_index(['frame_i', 'Cell_ID']) - ) - except Exception as e: - self.logger.log(f'[WARNING]: {e}') - reference_acdc_df = saved_acdc_df - - if myutils.are_acdc_dfs_equal(recovery_acdc_df, reference_acdc_df): - return - - load.store_unsaved_acdc_df(recovery_folderpath, recovery_acdc_df) - -class segmWorker(QObject): - finished = Signal(np.ndarray, float) - debug = Signal(object) - critical = Signal(object) - - def __init__( - self, mainWin, - secondChannelData=None, - mutex: QWaitCondition=None, - waitCond: QMutex=None - ): - QObject.__init__(self) - self.mainWin = mainWin - self.logger = self.mainWin.logger - self.z_range = None - self.secondChannelData = secondChannelData - self.mutex = mutex - self.waitCond = waitCond - - def emitDebug(self, to_debug): - if self.mutex is None: - return - - self.mutex.lock() - self.debug.emit(to_debug) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - @worker_exception_handler - def run(self): - t0 = time.perf_counter() - if self.mainWin.segment3D: - img = self.mainWin.getDisplayedZstack() - if self.z_range is not None: - startZ, stopZ = self.z_range - img = img[startZ:stopZ+1] - else: - img = self.mainWin.getDisplayedImg1() - - posData = self.mainWin.data[self.mainWin.pos_i] - lab = np.zeros_like(posData.segm_data[0]) - - # self.emitDebug((img, self.secondChannelData)) - - if self.secondChannelData is not None: - img = self.mainWin.model.second_ch_img_to_stack( - img, self.secondChannelData - ) - - start_z_slice = 0 - if self.z_range is not None: - start_z_slice, _ = self.z_range - elif not self.mainWin.segment3D and posData.isSegm3D: - idx = (posData.filename, posData.frame_i) - start_z_slice = posData.segmInfo_df.at[idx, 'z_slice_used_gui'] - - _lab = core.segm_model_segment( - self.mainWin.model, img, - self.mainWin.model_kwargs, - frame_i=posData.frame_i, - posData=posData, - start_z_slice=start_z_slice - ) - posData.saveSamEmbeddings(logger_func=self.logger.info) - - if self.mainWin.applyPostProcessing: - _lab = core.post_process_segm( - _lab, **self.mainWin.standardPostProcessKwargs - ) - if self.mainWin.customPostProcessFeatures: - _lab = features.custom_post_process_segm( - posData, self.mainWin.customPostProcessGroupedFeatures, - _lab, img, posData.frame_i, posData.filename, - posData.user_ch_name, self.mainWin.customPostProcessFeatures - ) - - if self.z_range is not None: - # 3D segmentation of a z-slices subset - startZ, stopZ = self.z_range - lab[startZ:stopZ+1] = _lab - elif not self.mainWin.segment3D and posData.isSegm3D: - # 3D segmentation but segmented current z-slice - idx = (posData.filename, posData.frame_i) - z = posData.segmInfo_df.at[idx, 'z_slice_used_gui'] - lab[z] = _lab - else: - # Either whole z-stack or 2D segmentation - lab = _lab - - t1 = time.perf_counter() - exec_time = t1-t0 - self.finished.emit(lab, exec_time) - -class segmVideoWorker(QObject): - finished = Signal(float) - debug = Signal(object) - critical = Signal(object) - progressBar = Signal(int) - progress = Signal(str, object) - - def __init__(self, posData, paramWin, model, startFrameNum, stopFrameNum): - QObject.__init__(self) - self.standardPostProcessKwargs = paramWin.standardPostProcessKwargs - self.applyPostProcessing = paramWin.applyPostProcessing - self.customPostProcessFeatures = paramWin.customPostProcessFeatures - self.customPostProcessGroupedFeatures = ( - paramWin.customPostProcessGroupedFeatures - ) - self.model_kwargs = paramWin.model_kwargs - self.preproc_recipe = paramWin.preproc_recipe - self.secondChannelName = paramWin.secondChannelName - self.model = model - self.posData = posData - self.startFrameNum = startFrameNum - self.stopFrameNum = stopFrameNum - self.logger = workerLogger(self.progress) - - def _check_extend_segm_data(self, segm_data, stop_frame_num): - if stop_frame_num <= len(segm_data): - return segm_data - extended_shape = (stop_frame_num, *segm_data.shape[1:]) - extended_segm_data = np.zeros(extended_shape, dtype=segm_data.dtype) - extended_segm_data[:len(segm_data)] = segm_data - if len(extended_shape) == 4: - return extended_segm_data - if self.posData.SizeZ == 1: - return extended_segm_data - else: - num_added_frames = len(extended_segm_data) - len(segm_data) - half_z = int(self.posData.SizeZ/2) - # 2D segm on 3D over time data --> fix segmInfo - segmInfo_extended = pd.DataFrame({ - 'filename': [self.posData.filename]*num_added_frames, - 'frame_i': list(range(len(segm_data), len(extended_segm_data))), - 'z_slice_used_gui': [half_z]*num_added_frames, - 'which_z_proj_gui': ['single z-slice']*num_added_frames - }).set_index(['filename', 'frame_i']) - segmInfo_df = pd.concat([self.posData.segmInfo_df, segmInfo_extended]) - self.posData.segmInfo_df = segmInfo_df - self.posData.segmInfo_df.to_csv(self.posData.segmInfo_df_csv_path) - return extended_segm_data - - @worker_exception_handler - def run(self): - t0 = time.perf_counter() - self.posData.segm_data = self._check_extend_segm_data( - self.posData.segm_data, self.stopFrameNum - ) - img_data = self.posData.img_data[self.startFrameNum-1:self.stopFrameNum] - is4D = img_data.ndim == 4 - is2D_segm = self.posData.segm_data.ndim == 3 - if is4D and is2D_segm: - filename = self.posData.filename - zz = self.posData.segmInfo_df.loc[filename, 'z_slice_used_gui'] - else: - zz = None - for i, img in enumerate(img_data): - frame_i = i+self.startFrameNum-1 - if self.secondChannelData is not None: - img = self.model.second_ch_img_to_stack( - img, self.secondChannelData - ) - if zz is not None: - z_slice = zz.loc[frame_i] - img = img[z_slice] - - lab = core.segm_model_segment( - self.model, img, self.model_kwargs, frame_i=frame_i, - preproc_recipe=self.preproc_recipe, - posData=self.posData - ) - self.posData.saveSamEmbeddings(logger_func=self.logger.log) - if self.applyPostProcessing: - lab = core.post_process_segm( - lab, **self.standardPostProcessKwargs - ) - if self.customPostProcessFeatures: - lab = features.custom_post_process_segm( - self.posData, - self.customPostProcessGroupedFeatures, - lab, img, self.posData.frame_i, - self.posData.filename, - self.posData.user_ch_name, - self.customPostProcessFeatures - ) - self.posData.segm_data[frame_i] = lab - self.progressBar.emit(1) - t1 = time.perf_counter() - exec_time = t1-t0 - self.finished.emit(exec_time) - -class ComputeMetricsWorker(QObject): - progressBar = Signal(int, int, float) - - def __init__(self, mainWin): - QObject.__init__(self) - self.signals = signals() - self.abort = False - self.setup_done = False - self.logger = workerLogger(self.signals.progress) - self.mutex = QMutex() - self.waitCond = QWaitCondition() - self.mainWin = mainWin - - def emitSelectSegmFiles(self, exp_path, pos_foldernames): - self.mutex.lock() - self.signals.sigSelectSegmFiles.emit(exp_path, pos_foldernames) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - if self.abort: - return True - else: - return False - - @worker_exception_handler - def run(self): - np.seterr(invalid='ignore') - debugging = False - expPaths = self.mainWin.expPaths - tot_exp = len(expPaths) - self.signals.initProgressBar.emit(0) - for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): - self.standardMetricsErrors = {} - self.customMetricsErrors = {} - self.regionPropsErrors = {} - tot_pos = len(pos_foldernames) - self.allPosDataInputs = [] - posDatas = [] - self.logger.log('-'*30) - expFoldername = os.path.basename(exp_path) - - if i == 0: - abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) - if abort: - self.signals.finished.emit(self) - return - - for p, pos in enumerate(pos_foldernames): - if self.abort: - self.signals.finished.emit(self) - return - - self.logger.log( - f'Processing experiment n. {i+1}/{tot_exp}, ' - f'{pos} ({p+1}/{tot_pos})' - ) - - pos_path = os.path.join(exp_path, pos) - images_path = os.path.join(pos_path, 'Images') - basename, chNames = myutils.getBasenameAndChNames( - images_path, useExt=('.tif', '.h5') - ) - - self.signals.sigUpdatePbarDesc.emit(f'Loading {pos_path}...') - - # Use first found channel, it doesn't matter for metrics - chName = chNames[0] - file_path = myutils.getChannelFilePath(images_path, chName) - - # Load data - posData = load.loadData(file_path, chName) - posData.getBasenameAndChNames(useExt=('.tif', '.h5')) - posData.buildPaths() - - posData.loadOtherFiles( - load_segm_data=False, - load_acdc_df=True, - load_metadata=True, - loadSegmInfo=True, - load_customCombineMetrics=True - ) - - posDatas.append(posData) - - self.allPosDataInputs.append({ - 'file_path': file_path, - 'chName': chName, - 'combineMetricsConfig': posData.combineMetricsConfig, - 'combineMetricsPath': posData.custom_combine_metrics_path - }) - - if any([posData.SizeT > 1 for posData in posDatas]): - self.mutex.lock() - self.signals.sigAskStopFrame.emit(posDatas) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - if self.abort: - self.signals.finished.emit(self) - return - for p, posData in enumerate(posDatas): - self.allPosDataInputs[p]['stopFrameNum'] = ( - posData.stopFrameNum - ) - else: - for p, posData in enumerate(posDatas): - self.allPosDataInputs[p]['stopFrameNum'] = 1 - - self.kernel = cli.ComputeMeasurementsKernel( - self.logger, - self.mainWin.log_path, - False, - ) - - # Iterate pos and calculate metrics - numPos = len(self.allPosDataInputs) - for p, posDataInputs in enumerate(self.allPosDataInputs): - self.logger.log('='*40) - file_path = posDataInputs['file_path'] - chName = posDataInputs['chName'] - stopFrameNum = posDataInputs['stopFrameNum'] - - self.kernel.run( - img_path=file_path, - stop_frame_n=stopFrameNum, - end_filename_segm=self.mainWin.endFilenameSegm, - computeMetricsWorker=self, - do_init_metrics=p == 0, - ) - - if self.kernel.setup_done: - return - - if self.abort: - self.signals.finished.emit(self) - return - - self.logger.log('*'*30) - - self.mutex.lock() - self.signals.sigErrorsReport.emit( - self.standardMetricsErrors, - self.customMetricsErrors, - self.regionPropsErrors - ) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - self.signals.finished.emit(self) - - def emitSigComputeVolume(self, posData, stop_frame_n): - # Recreate allData_li attribute of the gui - posData.allData_li = [] - for frame_i, lab in enumerate(posData.segm_data[:stop_frame_n]): - data_dict = { - 'labels': lab, - 'regionprops': skimage.measure.regionprops(lab) - } - posData.allData_li.append(data_dict) - self.mutex.lock() - self.signals.sigComputeVolume.emit( - stop_frame_n, posData - ) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - def emitSigPermissionErrorAndSave( - self, posData, traceback_str, all_frames_acdc_df, - custom_annot_columns - ): - self.mutex.lock() - self.signals.sigPermissionError.emit( - traceback_str, posData.acdc_output_csv_path - ) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - load.save_acdc_df_file( - all_frames_acdc_df, posData.acdc_output_csv_path, - custom_annot_columns=custom_annot_columns - ) - - def emitSigInitMetricsDialog(self, posData): - self.mainWin.gui.data = [posData] - self.mainWin.gui.pos_i = 0 - self.mainWin.gui.isSegm3D = posData.getIsSegm3D() - self.mutex.lock() - self.signals.sigInitAddMetrics.emit( - posData, self.allPosDataInputs - ) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - def emitSigAskRunNow(self): - self.mutex.lock() - self.signals.sigAskRunNow.emit(self) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - -class loadDataWorker(QObject): - def __init__(self, mainWin, user_ch_file_paths, user_ch_name, firstPosData): - QObject.__init__(self) - self.signals = signals() - self.mainWin = mainWin - self.user_ch_file_paths = user_ch_file_paths - self.user_ch_name = user_ch_name - self.logger = workerLogger(self.signals.progress) - self.mutex = self.mainWin.loadDataMutex - self.waitCond = self.mainWin.loadDataWaitCond - self.firstPosData = firstPosData - self.abort = False - self.loadUnsaved = False - self.recoveryAsked = False - self.loadSafeOverwriteNpz = False - - def pause(self): - self.mutex.lock() - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - def checkSelectedDataShape(self, posData, numPos): - skipPos = False - abort = False - emitWarning = ( - not posData.segmFound and posData.SizeT > 1 - and not self.mainWin.isNewFile - ) - if emitWarning: - self.signals.dataIntegrityWarning.emit(posData.pos_foldername) - self.pause() - abort = self.abort - return skipPos, abort - - def warnMismatchSegmDataShape(self, posData): - self.skipPos = False - self.mutex.lock() - self.signals.sigWarnMismatchSegmDataShape.emit(posData) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.skipPos - - @worker_exception_handler - def run(self): - data = [] - user_ch_file_paths = self.user_ch_file_paths - numPos = len(self.user_ch_file_paths) - user_ch_name = self.user_ch_name - self.signals.initProgressBar.emit(len(user_ch_file_paths)) - for i, file_path in enumerate(user_ch_file_paths): - if i == 0: - posData = self.firstPosData - segmFound = self.firstPosData.segmFound - loadSegm = False - else: - posData = load.loadData(file_path, user_ch_name) - loadSegm = True - - self.logger.log(f'Loading {posData.relPath}...') - - posData.loadSizeS = self.mainWin.loadSizeS - posData.loadSizeT = self.mainWin.loadSizeT - posData.loadSizeZ = self.mainWin.loadSizeZ - posData.SizeT = self.mainWin.SizeT - posData.SizeZ = self.mainWin.SizeZ - posData.isSegm3D = self.mainWin.isSegm3D - - if i > 0: - # First pos was already loaded in the main thread - # see loadSelectedData function in gui.py - posData.getBasenameAndChNames() - posData.buildPaths() - if not self.firstPosData.onlyEditMetadata: - posData.loadImgData() - - if self.firstPosData.onlyEditMetadata: - loadSegm = False - - posData.loadOtherFiles( - load_segm_data=loadSegm, - load_acdc_df=True, - load_shifts=True, - loadSegmInfo=True, - load_delROIsInfo=True, - load_bkgr_data=True, - loadBkgrROIs=True, - load_dataPrep_ROIcoords=True, - load_last_tracked_i=True, - load_metadata=True, - load_customAnnot=True, - load_customCombineMetrics=True, - end_filename_segm=self.mainWin.selectedSegmEndName, - create_new_segm=self.mainWin.isNewFile, - new_endname=self.mainWin.newSegmEndName, - labelBoolSegm=self.mainWin.labelBoolSegm, - ) - posData.labelSegmData() - - if i == 0: - posData.segmFound = segmFound - - posData.addYXcentroidColsIfMissing(show_progress=True) - - isPosSegm3D = posData.getIsSegm3D() - isMismatch = ( - isPosSegm3D != self.mainWin.isSegm3D - and isPosSegm3D is not None - and not self.mainWin.isNewFile - ) - if isMismatch: - skipPos = self.warnMismatchSegmDataShape(posData) - if skipPos: - self.logger.log( - f'Skipping "{posData.relPath}" because segmentation ' - 'data shape different from first Position loaded.' - ) - continue - else: - data = 'abort' - break - - self.logger.log( - 'Loaded paths:\n' - f'Segmentation file name: {os.path.basename(posData.segm_npz_path)}\n' - f'ACDC output file name {os.path.basename(posData.acdc_output_csv_path)}' - ) - - posData.SizeT = self.mainWin.SizeT - if self.mainWin.SizeZ > 1: - SizeZ = posData.img_data_shape[-3] - posData.SizeZ = SizeZ - else: - posData.SizeZ = 1 - posData.TimeIncrement = self.mainWin.TimeIncrement - posData.PhysicalSizeZ = self.mainWin.PhysicalSizeZ - posData.PhysicalSizeY = self.mainWin.PhysicalSizeY - posData.PhysicalSizeX = self.mainWin.PhysicalSizeX - posData.isSegm3D = self.mainWin.isSegm3D - posData.saveMetadata( - signals=self.signals, mutex=self.mutex, waitCond=self.waitCond, - additionalMetadata=self.firstPosData._additionalMetadataValues - ) - if hasattr(posData, 'img_data_shape'): - SizeY, SizeX = posData.img_data_shape[-2:] - - if posData.SizeZ > 1 and posData.img_data.ndim < 3: - posData.SizeZ = 1 - posData.segmInfo_df = None - try: - os.remove(posData.segmInfo_df_csv_path) - except FileNotFoundError: - pass - - posData.setBlankSegmData( - posData.SizeT, posData.SizeZ, SizeY, SizeX - ) - if not self.firstPosData.onlyEditMetadata: - skipPos, abort = self.checkSelectedDataShape(posData, numPos) - else: - skipPos, abort = False, False - - if skipPos: - continue - elif abort: - data = 'abort' - break - - posData.setTempPaths(createFolder=False) - isRecoveredDataPresent = ( - os.path.exists(posData.segm_npz_temp_path) - or posData.isRecoveredAcdcDfPresent() - or posData.isSafeNpzOverwritePresent() - ) - if isRecoveredDataPresent and not self.mainWin.newSegmEndName: - if not self.recoveryAsked: - self.mutex.lock() - self.signals.sigRecovery.emit(posData) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - self.recoveryAsked = True - if self.abort: - data = 'abort' - break - if self.loadUnsaved: - self.logger.log('Loading unsaved data...') - if os.path.exists(posData.segm_npz_temp_path): - segm_npz_path = posData.segm_npz_temp_path - posData.segm_data = np.load(segm_npz_path)['arr_0'] - segm_filename = os.path.basename(segm_npz_path) - posData.segm_npz_path = os.path.join( - posData.images_path, segm_filename - ) - - posData.loadMostRecentUnsavedAcdcDf() - elif self.loadSafeOverwriteNpz: - self.logger.log('Loading safe npz overwrite...') - segm_safe_npz_path = posData.getSafeNpzOverwritePath() - posData.segm_data = np.load(segm_safe_npz_path)['arr_0'] - - # Allow single 2D/3D image - if posData.SizeT == 1: - posData.img_data = posData.img_data[np.newaxis] - posData.segm_data = posData.segm_data[np.newaxis] - if hasattr(posData, 'img_data_shape'): - img_shape = posData.img_data_shape - img_shape = 'Not Loaded' - if hasattr(posData, 'img_data_shape'): - datasetShape = posData.img_data.shape - else: - datasetShape = 'Not Loaded' - if posData.segm_data is not None: - posData.segmSizeT = len(posData.segm_data) - SizeT = posData.SizeT - SizeZ = posData.SizeZ - self.logger.log(f'Full dataset shape = {img_shape}') - self.logger.log(f'Loaded dataset shape = {datasetShape}') - self.logger.log(f'Number of frames = {SizeT}') - self.logger.log(f'Number of z-slices per frame = {SizeZ}') - data.append(posData) - self.signals.progressBar.emit(1) - - if not data: - data = None - self.signals.dataIntegrityCritical.emit() - - self.signals.finished.emit(data) - -class trackingWorker(QObject): - finished = Signal() - critical = Signal(object) - progress = Signal(str) - debug = Signal(object) - - def __init__(self, posData, mainWin, video_to_track): - QObject.__init__(self) - self.mainWin = mainWin - self.posData = posData - self.mutex = QMutex() - self.signals = signals() - self.waitCond = QWaitCondition() - self.tracker = self.mainWin.tracker - self.track_params = self.mainWin.track_params - self.video_to_track = video_to_track - - def _get_first_untracked_lab(self): - start_frame_i = self.mainWin.start_n - 1 - frameData = self.posData.allData_li[start_frame_i] - lab = frameData['labels'] - if lab is not None: - return lab - else: - return self.posData.segm_data[start_frame_i] - - def _relabel_first_frame_labels(self, tracked_video): - first_untracked_lab = self._get_first_untracked_lab() - self.mainWin.setAllIDs() - max_allIDs = max(self.posData.allIDs, default=0) - max_tracked_video = tracked_video.max() - overall_max = max(max_allIDs, max_tracked_video) - uniqueID = overall_max + 1 - - tracked_video = transformation.retrack_based_on_untracked_first_frame( - tracked_video, first_untracked_lab, uniqueID=uniqueID - ) - return tracked_video - - def _setProgressBarIndefiniteWait(self): - try: - if hasattr(self.signals, 'innerPbar_available'): - if self.signals.innerPbar_available: - # Use inner pbar of the GUI widget (top pbar is for positions) - self.signals.sigInitInnerPbar.emit(1) - return - else: - self.signals.initProgressBar.emit(1) - except Exception as err: - pass - - @worker_exception_handler - def run(self): - self.mutex.lock() - self.progress.emit( - 'Tracking process started (more details in the terminal)...') - - trackerInputImage = None - self.track_params['signals'] = self.signals - if 'image' in self.track_params: - trackerInputImage = self.track_params.pop('image') - start_frame_i = self.mainWin.start_n-1 - stop_frame_n = self.mainWin.stop_n - - trackerInputImage = trackerInputImage[start_frame_i:stop_frame_n] - - tracked_video = core.tracker_track( - self.video_to_track, self.tracker, self.track_params, - intensity_img=trackerInputImage, - logger_func=self.progress.emit - ) - - self._setProgressBarIndefiniteWait() - - # self.debug.emit((tracked_video, self)) - # self.waitCond.wait(self.mutex) - - self.progress.emit('Re-tracking first frame to ensure continuity...') - # Relabel first frame objects back to IDs they had before tracking - # (to ensure continuity with past untracked frames) - tracked_video = self._relabel_first_frame_labels(tracked_video) - - print('') - self.progress.emit('Generating annotations...') - acdc_df = self.posData.fromTrackerToAcdcDf( - self.tracker, tracked_video, start_frame_i=self.mainWin.start_n-1 - ) - # Store new tracked video - current_frame_i = self.posData.frame_i - self.trackingOnNeverVisitedFrames = False - print('') - self.progress.emit('Storing tracked video...') - pbar = tqdm(total=len(tracked_video), ncols=100) - for rel_frame_i, lab in enumerate(tracked_video): - frame_i = rel_frame_i + self.mainWin.start_n - 1 - - if acdc_df is not None: - cca_cols = acdc_df.columns.intersection( - cca_df_colnames_with_tree - ) - # Store cca_df if it is an output of the tracker - cca_df = acdc_df.loc[frame_i][cca_cols] - self.mainWin.store_cca_df( - frame_i=frame_i, cca_df=cca_df, mainThread=False, - autosave=False - ) - - if self.posData.allData_li[frame_i]['labels'] is None: - # repeating tracking on a never visited frame - # --> modify only raw data and ask later what to do - self.posData.segm_data[frame_i] = lab - self.trackingOnNeverVisitedFrames = True - else: - # Get the rest of the stored metadata based on the new lab - self.posData.allData_li[frame_i]['labels'] = lab - self.posData.frame_i = frame_i - self.mainWin.get_data() - self.mainWin.store_data(autosave=False) - - pbar.update() - pbar.close() - - # Back to current frame - self.posData.frame_i = current_frame_i - self.mainWin.get_data() - self.mainWin.store_data(autosave=True) - self.mutex.unlock() - self.finished.emit() - -class reapplyDataPrepWorker(QObject): - finished = Signal() - debug = Signal(object) - critical = Signal(object) - progress = Signal(str) - initPbar = Signal(int) - updatePbar = Signal() - sigCriticalNoChannels = Signal(str) - sigSelectChannels = Signal(object, object, object, str) - - def __init__(self, expPath, posFoldernames): - super().__init__() - self.expPath = expPath - self.posFoldernames = posFoldernames - self.abort = False - self.mutex = QMutex() - self.waitCond = QWaitCondition() - - def raiseSegmInfoNotFound(self, path): - raise FileNotFoundError( - 'The following file is required for the alignment of 4D data ' - f'but it was not found: "{path}"' - ) - - def saveBkgrData(self, imageData, posData, isAligned=False): - bkgrROI_data = {} - for r, roi in enumerate(posData.bkgrROIs): - xl, yt = [int(round(c)) for c in roi.pos()] - w, h = [int(round(c)) for c in roi.size()] - if not yt+h>yt or not xl+w>xl: - # Prevent 0 height or 0 width roi - continue - is4D = posData.SizeT > 1 and posData.SizeZ > 1 - is3Dz = posData.SizeT == 1 and posData.SizeZ > 1 - is3Dt = posData.SizeT > 1 and posData.SizeZ == 1 - is2D = posData.SizeT == 1 and posData.SizeZ == 1 - if is4D: - bkgr_data = imageData[:, :, yt:yt+h, xl:xl+w] - elif is3Dz or is3Dt: - bkgr_data = imageData[:, yt:yt+h, xl:xl+w] - elif is2D: - bkgr_data = imageData[yt:yt+h, xl:xl+w] - bkgrROI_data[f'roi{r}_data'] = bkgr_data - - if not bkgrROI_data: - return - - if isAligned: - bkgr_data_fn = f'{posData.filename}_aligned_bkgrRoiData.npz' - else: - bkgr_data_fn = f'{posData.filename}_bkgrRoiData.npz' - bkgr_data_path = os.path.join(posData.images_path, bkgr_data_fn) - self.progress.emit('Saving background data to:') - self.progress.emit(bkgr_data_path) - io.savez_compressed(bkgr_data_path, **bkgrROI_data) - - def run(self): - ch_name_selector = prompts.select_channel_name( - which_channel='segm', allow_abort=False - ) - for p, pos in enumerate(self.posFoldernames): - if self.abort: - break - - self.progress.emit(f'Processing {pos}...') - - posPath = os.path.join(self.expPath, pos) - imagesPath = os.path.join(posPath, 'Images') - - ls = myutils.listdir(imagesPath) - if p == 0: - ch_names, basenameNotFound = ( - ch_name_selector.get_available_channels(ls, imagesPath) - ) - if not ch_names: - self.sigCriticalNoChannels.emit(imagesPath) - break - self.mutex.lock() - if len(self.posFoldernames) == 1: - # User selected only one pos --> allow selecting and adding - # and external .tif file that will be renamed with the basename - basename = ch_name_selector.basename - else: - basename = None - self.sigSelectChannels.emit( - ch_name_selector, ch_names, imagesPath, basename - ) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - if self.abort: - break - - self.progress.emit( - f'Selected channels: {self.selectedChannels}' - ) - - for chName in self.selectedChannels: - filePath = load.get_filename_from_channel(imagesPath, chName) - posData = load.loadData(filePath, chName) - posData.getBasenameAndChNames() - posData.buildPaths() - posData.loadImgData() - posData.loadOtherFiles( - load_segm_data=False, - getTifPath=True, - load_metadata=True, - load_shifts=True, - load_dataPrep_ROIcoords=True, - loadBkgrROIs=True - ) - - imageData = posData.img_data - - prepped = False - isAligned = False - # Align - if posData.loaded_shifts is not None: - self.progress.emit('Aligning frames...') - shifts = posData.loaded_shifts - if imageData.ndim == 4: - align_func = core.align_frames_3D - else: - align_func = core.align_frames_2D - imageData, _ = align_func(imageData, user_shifts=shifts) - prepped = True - isAligned = True - - # Crop and save background - if posData.dataPrep_ROIcoords is not None: - df = posData.dataPrep_ROIcoords - isCropped = int(df.at['cropped', 'value']) == 1 - if isCropped: - self.saveBkgrData(imageData, posData, isAligned) - self.progress.emit('Cropping...') - x0 = int(df.at['x_left', 'value']) - y0 = int(df.at['y_top', 'value']) - x1 = int(df.at['x_right', 'value']) - y1 = int(df.at['y_bottom', 'value']) - if imageData.ndim == 4: - imageData = imageData[:, :, y0:y1, x0:x1] - elif imageData.ndim == 3: - imageData = imageData[:, y0:y1, x0:x1] - elif imageData.ndim == 2: - imageData = imageData[y0:y1, x0:x1] - prepped = True - else: - filename = os.path.basename(posData.dataPrepBkgrROis_path) - self.progress.emit( - f'WARNING: the file "{filename}" was not found. ' - 'I cannot crop the data.' - ) - - if prepped: - self.progress.emit('Saving prepped data...') - io.savez_compressed(posData.align_npz_path, imageData) - if hasattr(posData, 'tif_path'): - myutils.to_tiff( - posData.tif_path, imageData - ) - - self.updatePbar.emit() - if self.abort: - break - self.finished.emit() - -class LazyLoader(QObject): - sigLoadingFinished = Signal() - - def __init__(self, mutex, waitCond, readH5mutex, waitReadH5cond): - QObject.__init__(self) - self.signals = signals() - self.mutex = mutex - self.waitCond = waitCond - self.exit = False - self.salute = True - self.sender = None - self.H5readWait = False - self.waitReadH5cond = waitReadH5cond - self.readH5mutex = readH5mutex - - def setArgs(self, posData, current_idx, axis, updateImgOnFinished): - self.wait = False - self.updateImgOnFinished = updateImgOnFinished - self.posData = posData - self.current_idx = current_idx - self.axis = axis - - def pauseH5read(self): - self.readH5mutex.lock() - self.waitReadH5cond.wait(self.mutex) - self.readH5mutex.unlock() - - def pause(self): - self.mutex.lock() - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - @worker_exception_handler - def run(self): - while True: - if self.exit: - self.signals.progress.emit( - 'Closing lazy loader...', 'INFO' - ) - break - elif self.wait: - self.signals.progress.emit( - 'Lazy loader paused.', 'INFO' - ) - self.pause() - else: - self.signals.progress.emit( - 'Lazy loader resumed.', 'INFO' - ) - self.posData.loadChannelDataChunk( - self.current_idx, axis=self.axis, worker=self - ) - self.sigLoadingFinished.emit() - self.wait = True - - self.signals.finished.emit(None) - - -class ImagesToPositionsWorker(QObject): - finished = Signal() - debug = Signal(object) - critical = Signal(object) - progress = Signal(str) - initPbar = Signal(int) - updatePbar = Signal() - - def __init__(self, folderPath, targetFolderPath, appendText): - super().__init__() - self.abort = False - self.folderPath = folderPath - self.targetFolderPath = targetFolderPath - self.appendText = appendText - - @worker_exception_handler - def run(self): - self.progress.emit(f'Selected folder: "{self.folderPath}"') - self.progress.emit(f'Target folder: "{self.targetFolderPath}"') - self.progress.emit(' ') - ls = myutils.listdir(self.folderPath) - numFiles = len(ls) - self.initPbar.emit(numFiles) - numPosDigits = len(str(numFiles)) - if numPosDigits == 1: - numPosDigits = 2 - pos = 1 - for file in ls: - if self.abort: - break - - filePath = os.path.join(self.folderPath, file) - if os.path.isdir(filePath): - # Skip directories - self.updatePbar.emit() - continue - - self.progress.emit(f'Loading file: {file}') - filename, ext = os.path.splitext(file) - s0p = str(pos).zfill(numPosDigits) - try: - data = load.imread(filePath) - if data.ndim == 3 and (data.shape[-1] == 3 or data.shape[-1] == 4): - self.progress.emit('Converting RGB image to grayscale...') - data = skimage.color.rgb2gray(data) - data = skimage.img_as_ubyte(data) - - posName = f'Position_{pos}' - posPath = os.path.join(self.targetFolderPath, posName) - imagesPath = os.path.join(posPath, 'Images') - if not os.path.exists(imagesPath): - os.makedirs(imagesPath, exist_ok=True) - newFilename = f's{s0p}_{filename}_{self.appendText}.tif' - relPath = os.path.join(posName, 'Images', newFilename) - tifFilePath = os.path.join(imagesPath, newFilename) - self.progress.emit(f'Saving to file: ...{os.sep}{relPath}') - myutils.to_tiff( - tifFilePath, data - ) - pos += 1 - except Exception as e: - self.progress.emit( - f'WARNING: {file} is not a valid image file. Skipping it.' - ) - - self.progress.emit(' ') - self.updatePbar.emit() - - if self.abort: - break - self.finished.emit() - -class BaseWorkerUtil(QObject): - progressBar = Signal(int, int, float) - - def __init__(self, mainWin): - QObject.__init__(self) - self.signals = signals() - self.abort = False - self.skipExp = False - self.logger = workerLogger(self.signals.progress) - self.mutex = QMutex() - self.waitCond = QWaitCondition() - self.mainWin = mainWin - - def emitSelectSegmFiles(self, exp_path, pos_foldernames): - self.mutex.lock() - self.signals.sigSelectSegmFiles.emit(exp_path, pos_foldernames) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.abort - - def emitSelectFilesWithText( - self, exp_path, pos_foldernames, with_text, ext=None - ): - self.mutex.lock() - self.signals.sigSelectFilesWithText.emit( - exp_path, pos_foldernames, with_text, ext - ) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.abort - - def emitSelectFile(self, start_dir, caption='', filters='All files (*.)'): - self.mutex.lock() - self.signals.sigSelectFile.emit(start_dir, caption, filters) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.abort - - def emitSelectAcdcOutputFiles( - self, exp_path, pos_foldernames, infoText='', - allowSingleSelection=False, multiSelection=True - ): - self.mutex.lock() - self.signals.sigSelectAcdcOutputFiles.emit( - exp_path, pos_foldernames, infoText, allowSingleSelection, - multiSelection - ) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.abort - - def emitSelectSpotmaxRun( - self, exp_path, pos_foldernames, all_runs, infoText='', - allowSingleSelection=True, multiSelection=True - ): - self.mutex.lock() - self.signals.sigSelectSpotmaxRun.emit( - exp_path, pos_foldernames, all_runs, infoText, allowSingleSelection, - multiSelection - ) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.abort - -class DataPrepSaveBkgrDataWorker(QObject): - def __init__(self, posData, dataPrepWin): - QObject.__init__(self) - self.signals = signals() - self.logger = workerLogger(self.signals.progress) - self.posData = posData - self.dataPrepWin = dataPrepWin - - @worker_exception_handler - def run(self): - self.dataPrepWin.saveBkgrData(self.posData) - self.signals.finished.emit(self) - -class DataPrepCropWorker(QObject): - def __init__(self, posData, dataPrepWin, dstPath): - QObject.__init__(self) - self.signals = signals() - self.logger = workerLogger(self.signals.progress) - self.posData = posData - self.dataPrepWin = dataPrepWin - self.dstPath = dstPath - - @worker_exception_handler - def run(self): - self.dataPrepWin.saveSingleCrop( - self.posData, self.posData.cropROIs[0], self.dstPath - ) - self.signals.finished.emit(self) - -class TrackSubCellObjectsWorker(BaseWorkerUtil): - sigAskAppendName = Signal(str, list) - sigCriticalNotEnoughSegmFiles = Signal(str) - sigAborted = Signal() - - def __init__(self, mainWin): - super().__init__(mainWin) - if mainWin.trackingMode.find('Delete both') != -1: - self.trackingMode = 'delete_both' - elif mainWin.trackingMode.find('Delete sub-cellular') != -1: - self.trackingMode = 'delete_sub' - elif mainWin.trackingMode.find('Delete cells') != -1: - self.trackingMode = 'delete_cells' - elif mainWin.trackingMode.find('Only track') != -1: - self.trackingMode = 'only_track' - - self.relabelSubObjLab = mainWin.relabelSubObjLab - self.IoAthresh = mainWin.IoAthresh - self.createThirdSegm = mainWin.createThirdSegm - self.thirdSegmAppendedText = mainWin.thirdSegmAppendedText - - @worker_exception_handler - def run(self): - debugging = False - expPaths = self.mainWin.expPaths - tot_exp = len(expPaths) - self.signals.initProgressBar.emit(0) - for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): - self.errors = {} - tot_pos = len(pos_foldernames) - - red_text = html_utils.span('OF THE CELLs') - self.mainWin.infoText = f'Select segmentation file {red_text}' - abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) - if abort: - self.sigAborted.emit() - return - - # Critical --> there are not enough segm files - if len(self.mainWin.existingSegmEndNames) < 2: - self.mutex.lock() - self.sigCriticalNotEnoughSegmFiles.emit(exp_path) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - self.sigAborted.emit() - return - - self.cellsSegmEndFilename = self.mainWin.endFilenameSegm - - red_text = html_utils.span('OF THE SUB-CELLULAR OBJECTS') - self.mainWin.infoText = ( - f'Select segmentation file {red_text}' - ) - abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) - if abort: - self.sigAborted.emit() - return - - # Ask appendend name - self.mutex.lock() - self.sigAskAppendName.emit( - self.mainWin.endFilenameSegm, self.mainWin.existingSegmEndNames - ) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - if self.abort: - self.sigAborted.emit() - return - - appendedName = self.appendedName - self.signals.initProgressBar.emit(len(pos_foldernames)) - for p, pos in enumerate(pos_foldernames): - if self.abort: - self.sigAborted.emit() - return - - self.logger.log( - f'Processing experiment n. {i+1}/{tot_exp}, ' - f'{pos} ({p+1}/{tot_pos})' - ) - - images_path = os.path.join(exp_path, pos, 'Images') - endFilenameSegm = self.mainWin.endFilenameSegm - ls = myutils.listdir(images_path) - file_path = [ - os.path.join(images_path, f) for f in ls - if f.endswith(f'{endFilenameSegm}.npz') - ][0] - - posData = load.loadData(file_path, '') - - self.signals.sigUpdatePbarDesc.emit(f'Processing {posData.pos_path}') - - posData.getBasenameAndChNames() - posData.buildPaths() - - posData.loadOtherFiles( - load_segm_data=True, - load_acdc_df=True, - load_metadata=True, - end_filename_segm=endFilenameSegm - ) - - # Load cells segmentation file - segmDataCells, segmCellsPath = load.load_segm_file( - images_path, end_name_segm_file=self.cellsSegmEndFilename, - return_path=True - ) - acdc_df_cells_endname = self.cellsSegmEndFilename.replace( - '_segm', '_acdc_output' - ) - acdc_df_cell, acdc_df_cells_path = load.load_acdc_df_file( - images_path, end_name_acdc_df_file=acdc_df_cells_endname, - return_path=True - ) - - if posData.SizeT > 1: - numFrames = min((len(segmDataCells), len(posData.segm_data))) - segmDataCells = segmDataCells[:numFrames] - posData.segm_data = posData.segm_data[:numFrames] - else: - numFrames = 1 - - self.signals.sigInitInnerPbar.emit(numFrames*2) - - self.logger.log('Tracking sub-cellular objects...') - tracked = core.track_sub_cell_objects( - segmDataCells, posData.segm_data, self.IoAthresh, - how=self.trackingMode, SizeT=numFrames, - sigProgress=self.signals.sigUpdateInnerPbar, - relabel_sub_obj_lab=self.relabelSubObjLab - ) - (trackedSubSegmData, trackedCellsSegmData, numSubObjPerCell, - replacedSubIds) = tracked - - self.logger.log('Saving tracked segmentation files...') - subSegmFilename, ext = os.path.splitext(posData.segm_npz_path) - trackedSubPath = f'{subSegmFilename}_{appendedName}.npz' - io.savez_compressed(trackedSubPath, trackedSubSegmData) - posData.saveIsSegm3Dmetadata(trackedSubPath) - - if trackedCellsSegmData is not None: - cellsSegmFilename, ext = os.path.splitext(segmCellsPath) - trackedCellsPath = f'{cellsSegmFilename}_{appendedName}.npz' - io.savez_compressed(trackedCellsPath, trackedCellsSegmData) - - if self.createThirdSegm: - self.logger.log( - f'Generating segmentation from ' - f'"{self.cellsSegmEndFilename} - {appendedName}" ' - 'difference...' - ) - if trackedCellsSegmData is not None: - parentSegmData = trackedCellsSegmData - else: - parentSegmData = segmDataCells - diffSegmData = parentSegmData.copy() - diffSegmData[trackedSubSegmData != 0] = 0 - - self.logger.log('Saving difference segmentation file...') - diffSegmPath = ( - f'{subSegmFilename}_{appendedName}' - f'_{self.thirdSegmAppendedText}.npz' - ) - io.savez_compressed(diffSegmPath, diffSegmData) - posData.saveIsSegm3Dmetadata(diffSegmPath) - del diffSegmData - - if self.relabelSubObjLab: - # When we relabel the sub-cell objs acdc_df is not valid anymore - # because IDs could be different - posData.acdc_df = None - - self.logger.log('Generating acdc_output tables...') - # Update or create acdc_df for sub-cellular objects - acdc_dfs_tracked = core.track_sub_cell_objects_acdc_df( - trackedSubSegmData, posData.acdc_df, - replacedSubIds, numSubObjPerCell, - tracked_cells_segm_data=trackedCellsSegmData, - cells_acdc_df=acdc_df_cell, SizeT=posData.SizeT, - sigProgress=self.signals.sigUpdateInnerPbar - ) - subTrackedAcdcDf, trackedAcdcDf = acdc_dfs_tracked - - self.logger.log('Saving acdc_output tables...') - subAcdcDfFilename, _ = os.path.splitext( - posData.acdc_output_csv_path - ) - subTrackedAcdcDfPath = f'{subAcdcDfFilename}_{appendedName}.csv' - subTrackedAcdcDf.to_csv(subTrackedAcdcDfPath) - - if trackedAcdcDf is not None: - basen = posData.basename - cellsSegmFilename = os.path.basename(segmCellsPath) - cellsSegmFilename, ext = os.path.splitext(cellsSegmFilename) - cellsSegmEndname = cellsSegmFilename[len(basen):] - trackedAcdcDfEndname = cellsSegmEndname.replace( - 'segm', 'acdc_output' - ) - trackedAcdcDfFilename = f'{basen}{trackedAcdcDfEndname}' - trackedAcdcDfFilename = f'{trackedAcdcDfFilename}_{appendedName}.csv' - trackedAcdcDfPath = os.path.join( - posData.images_path, trackedAcdcDfFilename - ) - trackedAcdcDf.to_csv(trackedAcdcDfPath) - - if self.createThirdSegm: - if posData.SizeT == 1: - parentSegmData = parentSegmData[np.newaxis] - subAcdcDfFilename = ( - subSegmFilename.replace('.npz', '.csv') - .replace('segm', 'acdc_output') - ) - diffAcdcDfPath = ( - f'{subAcdcDfFilename}_{appendedName}' - f'_{self.thirdSegmAppendedText}.csv' - ) - third_segm_acdc_df = ( - core.track_sub_cell_objects_third_segm_acdc_df( - parentSegmData, trackedAcdcDf - ) - ) - third_segm_acdc_df.to_csv(diffAcdcDfPath) - - self.signals.progressBar.emit(1) - - self.signals.finished.emit(self) - -class PostProcessSegmWorker(QObject): - def __init__( - self, - postProcessKwargs, - customPostProcessGroupedFeatures, - customPostProcessFeatures, - mainWin - ): - super().__init__() - self.signals = signals() - self.logger = workerLogger(self.signals.progress) - self.kwargs = postProcessKwargs - self.customPostProcessGroupedFeatures = customPostProcessGroupedFeatures - self.customPostProcessFeatures = customPostProcessFeatures - self.mainWin = mainWin - - @worker_exception_handler - def run(self): - mainWin = self.mainWin - data = mainWin.data - posData = data[mainWin.pos_i] - if len(data) > 1: - self.signals.initProgressBar.emit(len(data)) - else: - current_frame_i = posData.frame_i - self.signals.initProgressBar.emit(posData.SizeT - current_frame_i) - - self.logger.log('Post-process segmentation process started.') - self._run() - self.signals.finished.emit(None) - - def _run(self): - kwargs = self.kwargs - mainWin = self.mainWin - data = mainWin.data - - for posData in data: - current_frame_i = posData.frame_i - data_li = posData.allData_li[current_frame_i:] - for i, data_dict in enumerate(data_li): - frame_i = current_frame_i + i - visited = True - lab = data_dict['labels'] - if lab is None: - visited = False - try: - lab = posData.segm_data[frame_i] - except Exception as e: - return - - image = posData.img_data[frame_i] - - processed_lab = core.post_process_segm( - lab, return_delIDs=False, **kwargs - ) - if self.customPostProcessFeatures: - processed_lab = features.custom_post_process_segm( - posData, - self.customPostProcessGroupedFeatures, - processed_lab, - image, - posData.frame_i, - posData.filename, - posData.user_ch_name, - self.customPostProcessFeatures - ) - if visited: - posData.allData_li[frame_i]['labels'] = processed_lab - # Get the rest of the stored metadata based on the new lab - posData.frame_i = frame_i - mainWin.get_data() - mainWin.store_data(autosave=False) - else: - posData.segm_data[frame_i] = lab - - self.signals.progressBar.emit(1) - - posData.frame_i = current_frame_i - -class CreateConnected3Dsegm(BaseWorkerUtil): - sigAskAppendName = Signal(str, list) - sigAborted = Signal() - - def __init__(self, mainWin): - super().__init__(mainWin) - - def criticalSegmIsNot3D(self): - raise TypeError( - 'Input segmentation masks are not 3D. You can use this utility ' - 'only on 3D z-stack data or 4D z-stack over time data.' - ) - - @worker_exception_handler - def run(self): - debugging = False - expPaths = self.mainWin.expPaths - tot_exp = len(expPaths) - self.signals.initProgressBar.emit(0) - for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): - self.errors = {} - tot_pos = len(pos_foldernames) - - self.mainWin.infoText = f'Select 3D segmentation file to connect' - abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) - if abort: - self.sigAborted.emit() - return - - # Ask appendend name - self.mutex.lock() - self.sigAskAppendName.emit( - self.mainWin.endFilenameSegm, self.mainWin.existingSegmEndNames - ) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - if self.abort: - self.sigAborted.emit() - return - - appendedName = self.appendedName - self.signals.initProgressBar.emit(len(pos_foldernames)) - for p, pos in enumerate(pos_foldernames): - if self.abort: - self.sigAborted.emit() - return - - self.logger.log( - f'Processing experiment n. {i+1}/{tot_exp}, ' - f'{pos} ({p+1}/{tot_pos})' - ) - - images_path = os.path.join(exp_path, pos, 'Images') - endFilenameSegm = self.mainWin.endFilenameSegm - ls = myutils.listdir(images_path) - file_path = [ - os.path.join(images_path, f) for f in ls - if f.endswith(f'{endFilenameSegm}.npz') - ][0] - - posData = load.loadData(file_path, '') - - self.signals.sigUpdatePbarDesc.emit( - f'Processing {posData.pos_path}') - - posData.getBasenameAndChNames() - posData.buildPaths() - - posData.loadOtherFiles( - load_segm_data=True, - load_acdc_df=True, - load_metadata=True, - end_filename_segm=endFilenameSegm - ) - if posData.segm_data.ndim == 3: - posData.segm_data = posData.segm_data[np.newaxis] - - self.logger.log('Connecting 3D objects...') - - numFrames = len(posData.segm_data) - self.signals.sigInitInnerPbar.emit(numFrames) - connectedSegmData = np.zeros_like(posData.segm_data) - for frame_i, lab in enumerate(posData.segm_data): - if lab.ndim != 3: - self.criticalSegmIsNot3D() - - connected_lab = core.connect_3Dlab_zboundaries(lab) - connectedSegmData[frame_i] = connected_lab - - self.signals.sigUpdateInnerPbar.emit(1) - - self.logger.log('Saving connected 3D segmentation file...') - segmFilename, ext = os.path.splitext(posData.segm_npz_path) - newSegmFilepath = f'{segmFilename}_{appendedName}.npz' - connectedSegmData = np.squeeze(connectedSegmData) - io.savez_compressed(newSegmFilepath, connectedSegmData) - - self.signals.progressBar.emit(1) - - self.signals.finished.emit(self) - -class ApplyTrackInfoWorker(BaseWorkerUtil): - def __init__( - self, parentWin, endFilenameSegm, trackInfoCsvPath, - trackedSegmFilename, trackColsInfo, posPath - ): - super().__init__(parentWin) - self.endFilenameSegm = endFilenameSegm - self.trackInfoCsvPath = trackInfoCsvPath - self.trackedSegmFilename = trackedSegmFilename - self.trackColsInfo = trackColsInfo - self.posPath = posPath - - @worker_exception_handler - def run(self): - self.logger.log('Loading segmentation file...') - self.signals.initProgressBar.emit(0) - imagesPath = os.path.join(self.posPath, 'Images') - segmFilename = [ - f for f in myutils.listdir(imagesPath) - if f.endswith(f'{self.endFilenameSegm}.npz') - ][0] - segmFilePath = os.path.join(imagesPath, segmFilename) - segmData = np.load(segmFilePath)['arr_0'] - - self.logger.log('Loading table containing tracking info...') - df = pd.read_csv(self.trackInfoCsvPath) - - frameIndexCol = self.trackColsInfo['frameIndexCol'] - - parentIDcol = self.trackColsInfo['parentIDcol'] - pbarMax = len(df[frameIndexCol].unique()) - self.signals.initProgressBar.emit(pbarMax) - - # Apply tracking info - result = core.apply_tracking_from_table( - segmData, self.trackColsInfo, df, signal=self.signals.progressBar, - logger=self.logger.log, pbarMax=pbarMax - ) - trackedData, trackedIDsMapper, deleteIDsMapper = result - - if self.trackedSegmFilename: - trackedSegmFilepath = os.path.join( - imagesPath, self.trackedSegmFilename - ) - else: - trackedSegmFilepath = os.path.join(segmFilePath) - - self.signals.initProgressBar.emit(0) - self.logger.log('Saving tracked segmentation file...') - io.savez_compressed(trackedSegmFilepath, trackedData) - - - mapperPath = os.path.splitext(trackedSegmFilepath)[0] - mapperJsonPath = f'{mapperPath}_deletedIDs_mapper.json' - mapperJsonName = os.path.basename(mapperJsonPath) - self.logger.log(f'Saving deleted IDs to {mapperJsonName}...') - with open(mapperJsonPath, 'w') as file: - file.write(json.dumps(deleteIDsMapper)) - - mapperPath = os.path.splitext(trackedSegmFilepath)[0] - mapperJsonPath = f'{mapperPath}_replacedIDs_mapper.json' - mapperJsonName = os.path.basename(mapperJsonPath) - self.logger.log(f'Saving IDs replacements to {mapperJsonName}...') - with open(mapperJsonPath, 'w') as file: - file.write(json.dumps(trackedIDsMapper)) - - self.logger.log('Generating acdc_output table...') - acdc_df = None - if not self.trackedSegmFilename: - # Fix existing acdc_df - acdcEndname = self.endFilenameSegm.replace('_segm', '_acdc_output') - acdcFilename = [ - f for f in myutils.listdir(imagesPath) - if f.endswith(f'{acdcEndname}.csv') - ] - if acdcFilename: - acdcFilePath = os.path.join(imagesPath, acdcFilename[0]) - acdc_df = pd.read_csv( - acdcFilePath, index_col=['frame_i', 'Cell_ID'] - ) - - if acdc_df is not None: - acdc_df = core.apply_trackedIDs_mapper_to_acdc_df( - trackedIDsMapper, deleteIDsMapper, acdc_df - ) - else: - acdc_dfs = [] - keys = [] - for frame_i, lab in enumerate(trackedData): - rp = skimage.measure.regionprops(lab) - acdc_df_frame_i = myutils.getBaseAcdcDf(rp) - acdc_dfs.append(acdc_df_frame_i) - keys.append(frame_i) - - acdc_df = pd.concat(acdc_dfs, keys=keys, names=['frame_i', 'Cell_ID']) - segmFilename = os.path.basename(trackedSegmFilepath) - acdcFilename = re.sub(segm_re_pattern, '_acdc_output', segmFilename) - acdcFilePath = os.path.join(imagesPath, acdcFilename) - - self.signals.initProgressBar.emit(pbarMax) - parentIDcol = self.trackColsInfo['parentIDcol'] - trackIDsCol = self.trackColsInfo['trackIDsCol'] - if parentIDcol != 'None': - self.logger.log(f'Adding lineage info from "{parentIDcol}" column...') - acdc_df = core.add_cca_info_from_parentID_col( - df, acdc_df, frameIndexCol, trackIDsCol, parentIDcol, - len(segmData), signal=self.signals.progressBar, - maskID_colname=self.trackColsInfo['maskIDsCol'], - x_colname=self.trackColsInfo['xCentroidCol'], - y_colname=self.trackColsInfo['yCentroidCol'] - ) - - self.logger.log('Saving acdc_output table...') - acdc_df.to_csv(acdcFilePath) - - self.signals.finished.emit(self) - -class RestructMultiPosWorker(BaseWorkerUtil): - sigSaveTiff = Signal(str, object, object) - - def __init__(self, rootFolderPath, dstFolderPath, action='copy'): - super().__init__(None) - self.rootFolderPath = rootFolderPath - self.dstFolderPath = dstFolderPath - self.mutex = QMutex() - self.waitCond = QWaitCondition() - self.action = action - - @worker_exception_handler - def run(self): - load._restructure_multi_files_multi_pos( - self.rootFolderPath, self.dstFolderPath, signals=self.signals, - logger=self.logger.log, action=self.action - ) - self.signals.finished.emit(self) - - -class RestructMultiTimepointsWorker(BaseWorkerUtil): - sigSaveTiff = Signal(str, object, object) - - def __init__( - self, allChannels, frame_name_pattern, basename, validFilenames, - rootFolderPath, dstFolderPath, segmFolderPath='' - ): - super().__init__(None) - self.allChannels = allChannels - self.frame_name_pattern = frame_name_pattern - self.basename = basename - self.validFilenames = validFilenames - self.rootFolderPath = rootFolderPath - self.dstFolderPath = dstFolderPath - self.segmFolderPath = segmFolderPath - self.mutex = QMutex() - self.waitCond = QWaitCondition() - - @worker_exception_handler - def run(self): - allChannels = self.allChannels - frame_name_pattern = self.frame_name_pattern - rootFolderPath = self.rootFolderPath - dstFolderPath = self.dstFolderPath - segmFolderPath = self.segmFolderPath - filesInfo = {} - self.signals.initProgressBar.emit(len(self.validFilenames)+1) - for file in self.validFilenames: - try: - # Determine which channel is this file - for ch in allChannels: - m = re.findall(rf'(.*)_{ch}{frame_name_pattern}', file) - if m: - break - else: - raise FileNotFoundError( - f'The file name "{file}" does not contain any channel name' - ) - posName, _, frameName = m[0] - frameNumber = int(frameName) - if posName not in filesInfo: - filesInfo[posName] = {ch: [(file, frameNumber)]} - elif ch not in filesInfo[posName]: - filesInfo[posName][ch] = [(file, frameNumber)] - else: - filesInfo[posName][ch].append((file, frameNumber)) - except Exception as e: - self.logger.log(traceback.format_exc()) - self.logger.log( - f'WARNING: File "{file}" does not contain valid pattern. ' - 'Skipping it.' - ) - continue - - self.signals.progressBar.emit(1) - - df_metadata = None - partial_basename = self.basename - allPosDataInfo = [] - for p, (posName, channelInfo) in enumerate(filesInfo.items()): - self.logger.log(f'='*40) - self.logger.log(f'Processing position "{posName}"...') - - for _, filesList in channelInfo.items(): - # Get info from first file - filePath = os.path.join(rootFolderPath, filesList[0][0]) - try: - img = load.imread(filePath) - break - except Exception as e: - self.logger.log(traceback.format_exc()) - continue - else: - self.logger.log( - f'WARNING: No valid image files found for position {posName}' - ) - continue - - # Get basename - if partial_basename: - basename = f'{partial_basename}_{posName}_' - else: - basename = f'{posName}_' - - # Get SizeT from first file - SizeT = len(filesList) - - # Save metadata.csv - df_metadata = pd.DataFrame({ - 'SizeT': SizeT, - 'basename': basename - }, index=['values']) - - # Iterate channels - for c, (channelName, filesList) in enumerate(channelInfo.items()): - self.logger.log( - f' Processing channel "{channelName}"...' - ) - # Sort by frame number - sortedFilesList = sorted(filesList, key=lambda t:t[1]) - - df_metadata[f'channel_{c}_name'] = [channelName] - - imagesPath = os.path.join(dstFolderPath, f'Position_{p+1}', 'Images') - if not os.path.exists(imagesPath): - os.makedirs(imagesPath, exist_ok=True) - - # Iterate frames - videoData = None - srcSegmPaths = ['']*SizeT - frameNumbers = [] - for frame_i, fileInfo in enumerate(sortedFilesList): - file, _ = fileInfo - ext = os.path.splitext(file)[1] - srcImgFilePath = os.path.join(rootFolderPath, file) - try: - img = load.imread(srcImgFilePath) - if videoData is None: - shape = (SizeT, *img.shape) - videoData = np.zeros(shape, dtype=img.dtype) - videoData[frame_i] = img - pattern = self.frame_name_pattern - frameNumberMatch = re.findall(pattern, file)[0][1] - frameNumber = int(frameNumberMatch) - frameNumbers.append(frameNumber) - except Exception as e: - self.logger.log(traceback.format_exc()) - continue - - if segmFolderPath and c==0: - srcSegmFilePath = os.path.join(segmFolderPath, file) - srcSegmPaths[frame_i] = srcSegmFilePath - - SizeZ = 1 - if img.ndim == 3: - SizeZ = len(img) - - df_metadata['SizeZ'] = [SizeZ] - - self.signals.progressBar.emit(1) - - if videoData is None: - self.logger.log( - f'WARNING: No valid image files found for position ' - f'"{posName}", channel "{channelName}"' - ) - continue - else: - imgFileName = f'{basename}{channelName}.tif' - dstImgFilePath = os.path.join(imagesPath, imgFileName) - dstSegmFileName = f'{basename}segm_{channelName}.npz' - dstSegmPath = os.path.join(imagesPath, dstSegmFileName) - imgDataInfo = { - 'path': dstImgFilePath, 'SizeT': SizeT, 'SizeZ': SizeZ, - 'data': videoData, 'frameNumbers': frameNumbers, - 'dst_segm_path': dstSegmPath, - 'src_segm_paths': srcSegmPaths - } - allPosDataInfo.append(imgDataInfo) - - if df_metadata is not None: - metadata_csv_path = os.path.join( - imagesPath, f'{basename}metadata.csv' - ) - df_metadata = df_metadata.T - df_metadata.index.name = 'Description' - df_metadata.to_csv(metadata_csv_path) - - self.logger.log(f'*'*40) - - if not allPosDataInfo: - self.signals.finished.emit(self) - return - - self.signals.initProgressBar.emit(len(allPosDataInfo)) - self.logger.log('Saving image files...') - maxSizeT = max([d['SizeT'] for d in allPosDataInfo]) - minFrameNumber = min([d['frameNumbers'][0] for d in allPosDataInfo]) - # Pad missing frames in video files according to frame number - for p, imgDataInfo in enumerate(allPosDataInfo): - SizeT = imgDataInfo['SizeT'] - SizeZ = imgDataInfo['SizeZ'] - dstImgFilePath = imgDataInfo['path'] - videoData = imgDataInfo['data'] - frameNumbers = imgDataInfo['frameNumbers'] - paddedShape = (maxSizeT, *videoData.shape[1:]) - imgDataInfo['paddedShape'] = paddedShape - dtype = videoData.dtype - paddedVideoData = np.zeros(paddedShape, dtype=dtype) - for n, img in zip(frameNumbers, videoData): - frame_i = n - minFrameNumber - paddedVideoData[frame_i] = img - - del videoData - imgDataInfo['data'] = None - - self.mutex.lock() - self.sigSaveTiff.emit(dstImgFilePath, paddedVideoData, self.waitCond) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - self.signals.progressBar.emit(1) - - if not segmFolderPath: - self.signals.finished.emit(self) - return - - self.signals.initProgressBar.emit(len(allPosDataInfo)) - self.logger.log('Saving segmentation files...') - for p, imgDataInfo in enumerate(allPosDataInfo): - SizeT = imgDataInfo['SizeT'] - frameNumbers = imgDataInfo['frameNumbers'] - SizeT = imgDataInfo['SizeT'] - SizeZ = imgDataInfo['SizeZ'] - frameNumbers = imgDataInfo['frameNumbers'] - paddedShape = imgDataInfo['paddedShape'] - segmData = np.zeros(paddedShape, dtype=np.uint32) - for n, segmFilePath in zip(frameNumbers, imgDataInfo['src_segm_paths']): - frame_i = n - minFrameNumber - try: - lab = load.imread(segmFilePath).astype(np.uint32) - segmData[frame_i] = lab - except Exception as e: - self.logger.log(traceback.format_exc()) - self.logger.log( - 'WARNING: The following segmentation file does not ' - f'exist, saving empty masks: "{srcSegmFilePath}"' - ) - - io.savez_compressed(imgDataInfo['dst_segm_path'], segmData) - del segmData - - self.signals.finished.emit(self) - -class ComputeMetricsMultiChannelWorker(BaseWorkerUtil): - sigAskAppendName = Signal(str, list, list) - sigCriticalNotEnoughSegmFiles = Signal(str) - sigAborted = Signal() - sigHowCombineMetrics = Signal(str, list, list, list) - - def __init__(self, mainWin): - super().__init__(mainWin) - - def emitHowCombineMetrics( - self, imagesPath, selectedAcdcOutputEndnames, - existingAcdcOutputEndnames, allChNames - ): - self.mutex.lock() - self.sigHowCombineMetrics.emit( - imagesPath, selectedAcdcOutputEndnames, - existingAcdcOutputEndnames, allChNames - ) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.abort - - def loadAcdcDfs(self, imagesPath, selectedAcdcOutputEndnames): - for end in selectedAcdcOutputEndnames: - filePath, _ = load.get_path_from_endname(end, imagesPath) - acdc_df = pd.read_csv(filePath) - yield acdc_df - - def run_iter_exp(self, exp_path, pos_foldernames, i, tot_exp): - tot_pos = len(pos_foldernames) - - abort = self.emitSelectAcdcOutputFiles( - exp_path, pos_foldernames, infoText=' to combine', - allowSingleSelection=False - ) - if abort: - self.sigAborted.emit() - return - - # Ask appendend name - self.mutex.lock() - self.sigAskAppendName.emit( - f'{self.mainWin.basename_pos1}acdc_output', - self.mainWin.existingAcdcOutputEndnames, - self.mainWin.selectedAcdcOutputEndnames - ) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - if self.abort: - self.sigAborted.emit() - return - - selectedAcdcOutputEndnames = self.mainWin.selectedAcdcOutputEndnames - existingAcdcOutputEndnames = self.mainWin.existingAcdcOutputEndnames - appendedName = self.appendedName - - self.signals.initProgressBar.emit(len(pos_foldernames)) - for p, pos in enumerate(pos_foldernames): - if self.abort: - self.sigAborted.emit() - return - - self.logger.log( - f'Processing experiment n. {i+1}/{tot_exp}, ' - f'{pos} ({p+1}/{tot_pos})' - ) - - imagesPath = os.path.join(exp_path, pos, 'Images') - basename, chNames = myutils.getBasenameAndChNames( - imagesPath, useExt=('.tif', '.h5') - ) - - if p == 0: - abort = self.emitHowCombineMetrics( - imagesPath, selectedAcdcOutputEndnames, - existingAcdcOutputEndnames, chNames - ) - if abort: - self.sigAborted.emit() - return - acdcDfs = self.acdcDfs.values() - # Update selected acdc_dfs since the user could have - # loaded additional ones inside the emitHowCombineMetrics - # dialog - selectedAcdcOutputEndnames = self.acdcDfs.keys() - else: - acdcDfs = self.loadAcdcDfs( - imagesPath, selectedAcdcOutputEndnames - ) - - dfs = [] - for i, acdc_df in enumerate(acdcDfs): - dfs.append(acdc_df.add_suffix(f'_table{i+1}')) - combined_df = pd.concat(dfs, axis=1) - - newAcdcDf = pd.DataFrame(index=combined_df.index) - for newColname, equation in self.equations.items(): - newAcdcDf[newColname] = combined_df.eval(equation) - - newAcdcDfPath = os.path.join( - imagesPath, f'{basename}acdc_output_{appendedName}.csv' - ) - newAcdcDf.to_csv(newAcdcDfPath) - - equationsIniPath = os.path.join( - imagesPath, f'{basename}equations_{appendedName}.ini' - ) - equationsConfig = config.ConfigParser() - if os.path.exists(equationsIniPath): - equationsConfig.read(equationsIniPath) - equationsConfig = self.addEquationsToConfigPars( - equationsConfig, selectedAcdcOutputEndnames, self.equations - ) - with open(equationsIniPath, 'w') as configfile: - equationsConfig.write(configfile) - - self.signals.progressBar.emit(1) - - return True - - def addEquationsToConfigPars(self, cp, selectedAcdcOutputEndnames, equations): - section = [ - f'df{i+1}:{end}' for i, end in enumerate(selectedAcdcOutputEndnames) - ] - section = ';'.join(section) - if section not in cp: - cp[section] = {} - - for metricName, expression in equations.items(): - cp[section][metricName] = expression - - return cp - - @worker_exception_handler - def run(self): - debugging = False - expPaths = self.mainWin.expPaths - tot_exp = len(expPaths) - self.signals.initProgressBar.emit(0) - self.errors = {} - for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): - try: - result = self.run_iter_exp(exp_path, pos_foldernames, i, tot_exp) - if result is None: - return - except Exception as e: - traceback_str = traceback.format_exc() - self.errors[e] = traceback_str - self.logger.log(traceback_str) - - self.signals.finished.emit(self) - -class ConcatAcdcDfsWorker(BaseWorkerUtil): - sigAborted = Signal() - sigAskFolder = Signal(str) - sigSetMeasurements = Signal(object) - sigAskAppendName = Signal(str, list) - - def __init__(self, mainWin, format='CSV'): - super().__init__(mainWin) - if format.startswith('CSV'): - self._to_format = 'to_csv' - elif format.startswith('XLS'): - self._to_format = 'to_excel' - - def emitSetMeasurements(self, kwargs): - self.mutex.lock() - self.sigSetMeasurements.emit(kwargs) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - def emitAskAppendName(self, allPos_acdc_df_basename): - # Ask appendend name - self.mutex.lock() - self.sigAskAppendName.emit(allPos_acdc_df_basename, []) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - @worker_exception_handler - def run(self): - debugging = False - expPaths = self.mainWin.expPaths - tot_exp = len(expPaths) - - self.signals.initProgressBar.emit(0) - acdc_dfs_allexp = [] - acdc_objs_count_dfs_allexp = {} - keys_exp = [] - for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): - self.errors = {} - tot_pos = len(pos_foldernames) - - if i == 0: - abort = self.emitSelectAcdcOutputFiles( - exp_path, pos_foldernames, infoText=' to combine', - allowSingleSelection=True, multiSelection=False - ) - if abort: - self.sigAborted.emit() - return - - selectedAcdcOutputEndname = self.mainWin.selectedAcdcOutputEndnames[0] - selectedAcdcObjsCountEndname = selectedAcdcOutputEndname.replace( - 'acdc_output', 'acdc_objects_count' - ) - - self.signals.initProgressBar.emit(len(pos_foldernames)) - acdc_dfs = [] - acdc_objs_count_dfs = {} - keys = [] - for p, pos in enumerate(pos_foldernames): - if self.abort: - self.sigAborted.emit() - return - - self.logger.log( - f'Processing experiment n. {i+1}/{tot_exp}, ' - f'{pos} ({p+1}/{tot_pos})' - ) - - images_path = os.path.join(exp_path, pos, 'Images') - - ls = myutils.listdir(images_path) - - acdc_output_file = [ - f for f in ls - if f.endswith(f'{selectedAcdcOutputEndname}.csv') - ] - if not acdc_output_file: - self.logger.log( - f'{pos} does not contain any ' - f'{selectedAcdcOutputEndname}.csv file. ' - 'Skipping it.' - ) - self.signals.progressBar.emit(1) - continue - - acdc_objs_count_file = [ - f for f in ls - if f.endswith(f'{selectedAcdcObjsCountEndname}.csv') - ] - if acdc_objs_count_file: - df_count_filepath = os.path.join( - images_path, acdc_objs_count_file[0] - ) - df_count = pd.read_csv(df_count_filepath) - acdc_objs_count_dfs[pos] = df_count - - acdc_df_filepath = os.path.join(images_path, acdc_output_file[0]) - acdc_df = pd.read_csv(acdc_df_filepath).set_index('Cell_ID') - acdc_dfs.append(acdc_df) - keys.append(pos) - - self.signals.progressBar.emit(1) - - self.signals.initProgressBar.emit(0) - acdc_df_allpos = pd.concat( - acdc_dfs, keys=keys, names=['Position_n', 'Cell_ID'] - ) - acdc_df_allpos['experiment_folderpath'] = exp_path - - basename, chNames = myutils.getBasenameAndChNames( - images_path, useExt=('.tif', '.h5') - ) - df_metadata = load.load_metadata_df(images_path) - SizeZ = df_metadata.at['SizeZ', 'values'] - SizeZ = int(float(SizeZ)) - existing_colnames = acdc_df_allpos.columns - isSegm3D = any([col.endswith('3D') for col in existing_colnames]) - - if i == 0: - kwargs = { - 'loadedChNames': chNames, - 'notLoadedChNames': [], - 'isZstack': SizeZ > 1, - 'isSegm3D': isSegm3D, - 'existing_colnames': existing_colnames - } - self.emitSetMeasurements(kwargs) - if self.abort: - self.sigAborted.emit() - return - - selected_cols = [ - col for col in self.selectedColumns - if col in acdc_df_allpos.columns - ] - acdc_df_allpos = acdc_df_allpos[selected_cols] - acdc_dfs_allexp.append(acdc_df_allpos) - exp_name = os.path.basename(exp_path) - keys_exp.append((exp_path, exp_name)) - - allpos_dir = os.path.join(exp_path, 'AllPos_acdc_output') - if not os.path.exists(allpos_dir): - os.mkdir(allpos_dir) - - allPos_acdc_df_basename = f'AllPos_{selectedAcdcOutputEndname}' - if i == 0: - self.emitAskAppendName(allPos_acdc_df_basename) - if self.abort: - self.sigAborted.emit() - return - - acdc_objs_count_df_allpos_filename = ( - self.concat_df_filename.replace( - 'acdc_output', 'acdc_objects_count' - ) - ) - - acdc_dfs_allpos_filepath = os.path.join( - allpos_dir, self.concat_df_filename - ) - - self.logger.log( - 'Saving all positions concatenated file to ' - f'"{acdc_dfs_allpos_filepath}"' - ) - to_format_func = getattr(acdc_df_allpos, self._to_format) - to_format_func(acdc_dfs_allpos_filepath) - self.acdc_dfs_allpos_filepath = acdc_dfs_allpos_filepath - - if not acdc_objs_count_dfs: - continue - - acdc_objs_count_df_allpos = pd.concat( - acdc_objs_count_dfs, names=['Position_n'] - ) - acdc_objs_count_df_allpos['experiment_folderpath'] = exp_path - - acdc_objs_count_df_allpos_filepath = os.path.join( - allpos_dir, acdc_objs_count_df_allpos_filename - ) - - self.logger.log( - 'Saving all positions objects count file to ' - f'"{acdc_objs_count_df_allpos_filepath}"' - ) - to_format_func = getattr(acdc_objs_count_df_allpos, self._to_format) - to_format_func(acdc_objs_count_df_allpos_filepath) - - acdc_objs_count_dfs_allexp[(exp_path, exp_name)] = ( - acdc_objs_count_df_allpos - ) - - if len(keys_exp) <= 1: - self.signals.finished.emit(self) - return - - allExp_filename = f'multiExp_{self.concat_df_filename}' - self.mutex.lock() - self.sigAskFolder.emit(allExp_filename) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - if self.abort: - self.sigAborted.emit() - return - - acdc_df_allexp = pd.concat( - acdc_dfs_allexp, keys=keys_exp, - names=['experiment_folderpath', 'experiment_foldername'] - ) - acdc_dfs_allexp_filepath = os.path.join( - self.allExpSaveFolder, allExp_filename - ) - self.logger.log( - 'Saving multiple experiments concatenated file to ' - f'"{acdc_dfs_allexp_filepath}"' - ) - to_format_func = getattr(acdc_df_allexp, self._to_format) - to_format_func(acdc_dfs_allexp_filepath) - - if acdc_objs_count_dfs_allexp: - allexp_count_df_filename = ( - f'multiExp_{acdc_objs_count_df_allpos_filename}' - ) - acdc_objs_count_df_allexp = pd.concat( - acdc_objs_count_dfs_allexp, - names=['experiment_folderpath', 'experiment_foldername'] - ) - acdc_objs_count_df_allexp_filepath = os.path.join( - self.allExpSaveFolder, allexp_count_df_filename - ) - self.logger.log( - 'Saving multiple experiments concatenated file to ' - f'"{acdc_objs_count_df_allexp_filepath}"' - ) - to_format_func = getattr(acdc_objs_count_df_allexp, self._to_format) - to_format_func(acdc_objs_count_df_allexp_filepath) - - self.signals.finished.emit(self) - -class FromImajeJroiToSegmNpzWorker(BaseWorkerUtil): - sigSelectRoisProps = Signal(str, object, bool) - - def __init__(self, mainWin): - super().__init__(mainWin) - - def emitSelectRoisProps(self, roi_filepath, TZYX_shape, is_multi_pos): - self.mutex.lock() - self.sigSelectRoisProps.emit(roi_filepath, TZYX_shape, is_multi_pos) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.abort - - @worker_exception_handler - def run(self): - import roifile - expPaths = self.mainWin.expPaths - tot_exp = len(expPaths) - self.signals.initProgressBar.emit(0) - for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): - self.errors = {} - tot_pos = len(pos_foldernames) - - abort = self.emitSelectFilesWithText( - exp_path, pos_foldernames, 'imagej_rois', ext='.zip' - ) - if abort: - self.signals.finished.emit(self) - return - - self.askRoiPreferences = True - for p, pos in enumerate(pos_foldernames): - if self.abort: - self.signals.finished.emit(self) - return - - self.logger.log( - f'Processing experiment n. {i+1}/{tot_exp}, ' - f'{pos} ({p+1}/{tot_pos})' - ) - - images_path = os.path.join(exp_path, pos, 'Images') - endFilenameRoi = self.mainWin.endFilenameWithText - ls = myutils.listdir(images_path) - rois_filepaths = [ - os.path.join(images_path, f) for f in ls - if f.endswith(f'{endFilenameRoi}.zip') - ] - - if not rois_filepaths: - self.logger.log( - '[WARNING]: The following Position folder does not ' - f'contain any file ending with {endFilenameRoi}. ' - f'Skipping it. "{os.path.join(exp_path, pos)}")' - ) - continue - - rois_filepath = rois_filepaths[0] - - if self.askRoiPreferences: - is_multi_pos = len(pos_foldernames) > 1 - self.logger.log('Loading image data to get image shape...') - TZYX_shape = load.get_tzyx_shape(images_path) - abort = self.emitSelectRoisProps( - rois_filepath, TZYX_shape, is_multi_pos - ) - if abort: - self.signals.finished.emit(self) - return - - self.askRoiPreferences = not self.useSamePropsForNextPos - elif self.areAllRoisSelected: - rois = roifile.roiread(rois_filepath) - self.IDsToRoisMapper = {i+i: roi for roi in enumerate(rois)} - else: - # Use same ID of previous position - rois = roifile.roiread(rois_filepath) - IDsToRoisMapper = {i+i: roi for i, roi in enumerate(rois)} - self.IDsToRoisMapper = { - ID: IDsToRoisMapper[ID] - for ID in self.IDsToRoisMapper.keys() - } - - self.logger.log('Generating segm mask from ROIs...') - segm_data = myutils.from_imagej_rois_to_segm_data( - TZYX_shape, self.IDsToRoisMapper, self.rescaleRoisSizes, - self.repeatRoisZslicesRange - ) - - - segm_filepath = (rois_filepath - .replace('imagej_rois', 'segm') - .replace('.zip', '.npz') - ) - self.logger.log(f'Saving segm mask to "{segm_filepath}"...') - io.savez_compressed(segm_filepath, segm_data) - - self.signals.finished.emit(self) - - -class ToImajeJroiWorker(BaseWorkerUtil): - def __init__(self, mainWin): - super().__init__(mainWin) - - @worker_exception_handler - def run(self): - from roifile import ImagejRoi, roiwrite - - debugging = False - expPaths = self.mainWin.expPaths - tot_exp = len(expPaths) - self.signals.initProgressBar.emit(0) - for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): - self.errors = {} - tot_pos = len(pos_foldernames) - - abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) - if abort: - self.signals.finished.emit(self) - return - - for p, pos in enumerate(pos_foldernames): - if self.abort: - self.signals.finished.emit(self) - return - - self.logger.log( - f'Processing experiment n. {i+1}/{tot_exp}, ' - f'{pos} ({p+1}/{tot_pos})' - ) - - images_path = os.path.join(exp_path, pos, 'Images') - endFilenameSegm = self.mainWin.endFilenameSegm - ls = myutils.listdir(images_path) - - files_path = [ - os.path.join(images_path, f) for f in ls - if f.endswith(f'{endFilenameSegm}.npz') - ] - - if not files_path: - self.logger.log( - '[WARNING]: The following Position folder does not ' - f'contain any file ending with {endFilenameSegm}. ' - f'Skipping it. "{os.path.join(exp_path, pos)}")' - ) - continue - - file_path = files_path[0] - - posData = load.loadData(file_path, '') - - self.signals.sigUpdatePbarDesc.emit(f'Processing {posData.pos_path}') - - posData.getBasenameAndChNames() - posData.buildPaths() - - posData.loadOtherFiles( - load_segm_data=True, - load_metadata=True, - end_filename_segm=endFilenameSegm - ) - - if posData.SizeT > 1: - rois = [] - max_ID = posData.segm_data.max() - for t, lab in enumerate(posData.segm_data): - rois_t = myutils.from_lab_to_imagej_rois( - lab, ImagejRoi, t=t, SizeT=posData.SizeT, - max_ID=max_ID - ) - rois.extend(rois_t) - else: - rois = myutils.from_lab_to_imagej_rois( - posData.segm_data, ImagejRoi - ) - - roi_filepath = posData.segm_npz_path.replace('.npz', '.zip') - roi_filepath = roi_filepath.replace('_segm', '_imagej_rois') - - try: - os.remove(roi_filepath) - except Exception as e: - pass - - roiwrite(roi_filepath, rois) - - self.signals.finished.emit(self) - - -class ToSymDivWorker(QObject): - progressBar = Signal(int, int, float) - - def __init__(self, mainWin): - QObject.__init__(self) - self.signals = signals() - self.abort = False - self.logger = workerLogger(self.signals.progress) - self.mutex = QMutex() - self.waitCond = QWaitCondition() - self.mainWin = mainWin - - def emitSelectSegmFiles(self, exp_path, pos_foldernames): - self.mutex.lock() - self.signals.sigSelectSegmFiles.emit(exp_path, pos_foldernames) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.abort - - @worker_exception_handler - def run(self): - debugging = False - expPaths = self.mainWin.expPaths - tot_exp = len(expPaths) - self.signals.initProgressBar.emit(0) - for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): - self.errors = {} - self.missingAnnotErrors = {} - tot_pos = len(pos_foldernames) - self.allPosDataInputs = [] - posDatas = [] - self.logger.log('-'*30) - expFoldername = os.path.basename(exp_path) - - abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) - if abort: - self.signals.finished.emit(self) - return - - self.signals.initProgressBar.emit(len(pos_foldernames)) - for p, pos in enumerate(pos_foldernames): - if self.abort: - self.signals.finished.emit(self) - return - - self.logger.log( - f'Processing experiment n. {i+1}/{tot_exp}, ' - f'{pos} ({p+1}/{tot_pos})' - ) - - pos_path = os.path.join(exp_path, pos) - images_path = os.path.join(pos_path, 'Images') - basename, chNames = myutils.getBasenameAndChNames( - images_path, useExt=('.tif', '.h5') - ) - - self.signals.sigUpdatePbarDesc.emit(f'Loading {pos_path}...') - - # Use first found channel, it doesn't matter for metrics - for chName in chNames: - file_path = myutils.getChannelFilePath(images_path, chName) - if file_path: - break - else: - raise FileNotFoundError( - f'None of the channels "{chNames}" were found in the path ' - f'"{images_path}".' - ) - - # Load data - posData = load.loadData(file_path, chName) - posData.getBasenameAndChNames(useExt=('.tif', '.h5')) - - posData.loadOtherFiles( - load_segm_data=False, - load_acdc_df=True, - load_metadata=True, - loadSegmInfo=True - ) - - posDatas.append(posData) - - self.allPosDataInputs.append({ - 'file_path': file_path, - 'chName': chName - }) - - # Iterate pos and calculate metrics - numPos = len(self.allPosDataInputs) - for p, posDataInputs in enumerate(self.allPosDataInputs): - file_path = posDataInputs['file_path'] - chName = posDataInputs['chName'] - - posData = load.loadData(file_path, chName) - - self.signals.sigUpdatePbarDesc.emit(f'Processing {posData.pos_path}') - - posData.getBasenameAndChNames(useExt=('.tif', '.h5')) - posData.buildPaths() - posData.loadImgData() - - posData.loadOtherFiles( - load_segm_data=False, - load_acdc_df=True, - end_filename_segm=self.mainWin.endFilenameSegm - ) - if not posData.acdc_df_found: - relPath = ( - f'...{os.sep}{expFoldername}' - f'{os.sep}{posData.pos_foldername}' - ) - self.logger.log( - f'WARNING: Skipping "{relPath}" ' - f'because acdc_output.csv file was not found.' - ) - self.missingAnnotErrors[relPath] = ( - f'
    FileNotFoundError: the Positon "{relPath}" ' - 'does not have the acdc_output.csv file.
    ') - - continue - - acdc_df_filename = os.path.basename(posData.acdc_output_csv_path) - self.logger.log( - 'Loaded path:\n' - f'ACDC output file name: "{acdc_df_filename}"' - ) - - self.logger.log('Building tree...') - try: - tree = core.LineageTree(posData.acdc_df) - error = tree.build() - if isinstance(error, KeyError): - self.logger.log(str(error)) - - self.logger.log( - 'WARNING: Annotations missing in ' - f'"{posData.acdc_output_csv_path}"' - ) - self.missingAnnotErrors[acdc_df_filename] = str(error) - continue - elif error is not None: - raise error - posData.acdc_df = tree.df - except Exception as error: - traceback_format = traceback.format_exc() - self.logger.log(traceback_format) - self.errors[error] = traceback_format - - try: - posData.acdc_df.to_csv(posData.acdc_output_csv_path) - except PermissionError: - traceback_str = traceback.format_exc() - self.mutex.lock() - self.signals.sigPermissionError.emit( - traceback_str, posData.acdc_output_csv_path - ) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - posData.acdc_df.to_csv(posData.acdc_output_csv_path) - - self.signals.progressBar.emit(1) - - self.signals.finished.emit(self) - -class AlignWorker(BaseWorkerUtil): - sigAborted = Signal() - sigAskUseSavedShifts = Signal(str, str) - sigAskSelectChannel = Signal(list) - - def __init__(self, mainWin): - super().__init__(mainWin) - - def emitAskUseSavedShifts(self, expPath, basename): - self.mutex.lock() - self.sigAskUseSavedShifts.emit(expPath, basename) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.abort - - def emitAskSelectChannel(self, channels): - self.mutex.lock() - self.sigAskSelectChannel.emit(channels) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.abort - - @worker_exception_handler - def run(self): - expPaths = self.mainWin.expPaths - tot_exp = len(expPaths) - self.signals.initProgressBar.emit(0) - for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): - self.errors = {} - tot_pos = len(pos_foldernames) - - shiftsFound = False - for pos in pos_foldernames: - images_path = os.path.join(exp_path, pos, 'Images') - ls = myutils.listdir(images_path) - for file in ls: - if file.endswith('align_shift.npy'): - shiftsFound = True - basename, chNames = myutils.getBasenameAndChNames( - images_path, useExt=('.tif', '.h5') - ) - break - if shiftsFound: - break - - savedShiftsHow = None - if shiftsFound: - basename_ch0 = f'{basename}{chNames[0]}_' - abort = self.emitAskUseSavedShifts(exp_path, basename_ch0) - if abort: - self.sigAborted.emit() - return - - savedShiftsHow = self.savedShiftsHow - - self.signals.initProgressBar.emit(len(pos_foldernames)) - for p, pos in enumerate(pos_foldernames): - if self.abort: - self.sigAborted.emit() - return - - self.logger.log('*'*40) - self.logger.log( - f'Processing experiment n. {i+1}/{tot_exp}, ' - f'{pos} ({p+1}/{tot_pos})' - ) - - pos_path = os.path.join(exp_path, pos) - images_path = os.path.join(pos_path, 'Images') - basename, chNames = myutils.getBasenameAndChNames( - images_path, useExt=('.tif', '.h5') - ) - - self.signals.sigUpdatePbarDesc.emit(f'Loading {pos_path}...') - - if p == 0: - self.logger.log(f'Asking to select reference channel...') - abort = self.emitAskSelectChannel(chNames) - if abort: - self.sigAborted.emit() - return - chName = self.chName - - file_path = myutils.getChannelFilePath(images_path, chName) - - # Load data - posData = load.loadData(file_path, chName) - posData.getBasenameAndChNames(useExt=('.tif', '.h5')) - posData.buildPaths() - posData.loadImgData() - - posData.loadOtherFiles( - load_segm_data=False, - load_shifts=True, - loadSegmInfo=True - ) - - if posData.img_data.ndim == 4: - align_func = core.align_frames_3D - if posData.segmInfo_df is None: - raise FileNotFoundError( - 'To align 4D data you need to select which z-slice ' - 'you want to use for alignment. Please run the module ' - '`1. Launch data prep module...` before aligning the ' - 'frames. (z-slice info MISSING from position ' - f'"{posData.relPath}")' - ) - df = posData.segmInfo_df.loc[posData.filename] - zz = df['z_slice_used_dataPrep'].to_list() - elif posData.img_data.ndim == 3: - align_func = core.align_frames_2D - zz = None - - useSavedShifts = ( - savedShiftsHow == 'use_saved_shifts' - and posData.loaded_shifts is not None - ) - if useSavedShifts: - user_shifts = posData.loaded_shifts - else: - user_shifts = None - - if savedShiftsHow == 'rever_alignment': - if posData.loaded_shifts is None: - self.logger.log( - f'WARNING: Cannot revert alignment in "{posData.relPath}" ' - 'since it is missing previously computed shifts. ' - 'Skipping this positon.' - ) - continue - - # Revert alignment and save selected channel - for chName in chNames: - self.logger.log( - f'Reverting alignment on "{chName}"...' - ) - if chName == posData.user_ch_name: - data = posData.img_data - else: - file_path = myutils.getChannelFilePath( - images_path, chName - ) - data = load.load_image_file(file_path) - - self.signals.sigInitInnerPbar.emit(len(data)-1) - revertedData = core.revert_alignment( - posData.loaded_shifts, data, - sigPyqt=self.signals.sigUpdateInnerPbar - ) - self.logger.log( - f'Saving "{chName}"...' - ) - self.signals.sigInitInnerPbar.emit(0) - self.saveAlignedData( - revertedData, images_path, posData.basename, - chName, self.revertedAlignEndname, - ext=posData.ext - ) - del revertedData, data - else: - for chName in chNames: - self.logger.log( - f'Aligning "{chName}"...' - ) - if chName == posData.user_ch_name: - data = posData.img_data - else: - file_path = myutils.getChannelFilePath( - images_path, chName - ) - data = load.load_image_file(file_path) - self.signals.sigInitInnerPbar.emit(len(data)-1) - - alignedImgData, shifts = align_func( - data, slices=zz, user_shifts=user_shifts, - sigPyqt=self.signals.sigUpdateInnerPbar - ) - self.logger.log(f'Saving "{chName}"...') - np.save(posData.align_shifts_path, shifts) - - self.signals.sigInitInnerPbar.emit(0) - self.saveAlignedData( - alignedImgData, images_path, posData.basename, - chName, '', ext=posData.non_aligned_ext - ) - self.saveAlignedData( - alignedImgData, images_path, posData.basename, - chName, 'aligned', ext='.npz' - ) - del alignedImgData, data - - self.signals.finished.emit(self) - - def saveAlignedData( - self, data, imagesPath, basename, chName, endname, ext='.tif' - ): - if endname: - newFilename = f'{basename}{chName}_{endname}{ext}' - else: - newFilename = f'{basename}{chName}{ext}' - - filePath = os.path.join(imagesPath, newFilename) - - if ext == '.tif': - SizeT = data.shape[0] - SizeZ = 1 - if data.ndim == 4: - SizeZ = data.shape[1] - myutils.to_tiff(filePath, data) - elif ext == '.npz': - io.savez_compressed(filePath, data) - elif ext == '.h5': - load.save_to_h5(filePath, data) - -class ToObjCoordsWorker(BaseWorkerUtil): - def __init__(self, mainWin): - super().__init__(mainWin) - - @worker_exception_handler - def run(self): - debugging = False - expPaths = self.mainWin.expPaths - tot_exp = len(expPaths) - self.signals.initProgressBar.emit(0) - for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): - self.errors = {} - tot_pos = len(pos_foldernames) - - abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) - if abort: - self.signals.finished.emit(self) - return - - for p, pos in enumerate(pos_foldernames): - if self.abort: - self.signals.finished.emit(self) - return - - self.logger.log( - f'Processing experiment n. {i+1}/{tot_exp}, ' - f'{pos} ({p+1}/{tot_pos})' - ) - - images_path = os.path.join(exp_path, pos, 'Images') - endFilenameSegm = self.mainWin.endFilenameSegm - ls = myutils.listdir(images_path) - file_path = [ - os.path.join(images_path, f) for f in ls - if f.endswith(f'{endFilenameSegm}.npz') - ][0] - - posData = load.loadData(file_path, '') - - self.signals.sigUpdatePbarDesc.emit(f'Processing {posData.pos_path}') - - posData.getBasenameAndChNames() - posData.buildPaths() - - posData.loadOtherFiles( - load_segm_data=True, - load_metadata=True, - end_filename_segm=endFilenameSegm - ) - - if posData.SizeT == 1: - posData.segm_data = posData.segm_data[np.newaxis] - - dfs = [] - n_frames = len(posData.segm_data) - self.signals.initProgressBar.emit(n_frames) - for frame_i, lab in enumerate(posData.segm_data): - df_coords_i = myutils.from_lab_to_obj_coords(lab) - dfs.append(df_coords_i) - self.signals.progressBar.emit(1) - df_filepath = posData.segm_npz_path.replace('.npz', '.csv') - df_filepath = df_filepath.replace('_segm', '_objects_coordinates') - - keys = list(range(len(posData.segm_data))) - df = pd.concat(dfs, keys=keys, names=['frame_i']) - - self.signals.initProgressBar.emit(0) - df.to_csv(df_filepath) - - self.signals.finished.emit(self) - -class Stack2DsegmTo3Dsegm(BaseWorkerUtil): - sigAskAppendName = Signal(str, list) - sigAborted = Signal() - - def __init__(self, mainWin, SizeZ): - super().__init__(mainWin) - self.SizeZ = SizeZ - - @worker_exception_handler - def run(self): - debugging = False - expPaths = self.mainWin.expPaths - tot_exp = len(expPaths) - self.signals.initProgressBar.emit(0) - for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): - self.errors = {} - tot_pos = len(pos_foldernames) - - self.mainWin.infoText = f'Select 2D segmentation file to stack' - abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) - if abort: - self.sigAborted.emit() - return - - # Ask appendend name - self.mutex.lock() - self.sigAskAppendName.emit( - self.mainWin.endFilenameSegm, self.mainWin.existingSegmEndNames - ) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - if self.abort: - self.sigAborted.emit() - return - - appendedName = self.appendedName - self.signals.initProgressBar.emit(len(pos_foldernames)) - for p, pos in enumerate(pos_foldernames): - if self.abort: - self.sigAborted.emit() - return - - self.logger.log( - f'Processing experiment n. {i+1}/{tot_exp}, ' - f'{pos} ({p+1}/{tot_pos})' - ) - - images_path = os.path.join(exp_path, pos, 'Images') - endFilenameSegm = self.mainWin.endFilenameSegm - ls = myutils.listdir(images_path) - file_path = [ - os.path.join(images_path, f) for f in ls - if f.endswith(f'{endFilenameSegm}.npz') - ][0] - - posData = load.loadData(file_path, '') - - self.signals.sigUpdatePbarDesc.emit(f'Processing {posData.pos_path}') - - posData.getBasenameAndChNames() - posData.buildPaths() - - posData.loadOtherFiles( - load_segm_data=True, - load_acdc_df=True, - load_metadata=True, - end_filename_segm=endFilenameSegm - ) - if posData.segm_data.ndim == 2: - posData.segm_data = posData.segm_data[np.newaxis] - - self.logger.log('Stacking 2D into 3D objects...') - - numFrames = len(posData.segm_data) - self.signals.sigInitInnerPbar.emit(numFrames) - T, Y, X = posData.segm_data.shape - newShape = (T, self.SizeZ, Y, X) - segmData2D = np.zeros(newShape, dtype=np.uint32) - for frame_i, lab in enumerate(posData.segm_data): - stacked_lab = core.stack_2Dlab_to_3D(lab, self.SizeZ) - segmData2D[frame_i] = stacked_lab - - self.signals.sigUpdateInnerPbar.emit(1) - - self.logger.log('Saving stacked 3D segmentation file...') - segmFilename, ext = os.path.splitext(posData.segm_npz_path) - newSegmFilepath = f'{segmFilename}_{appendedName}.npz' - segmData2D = np.squeeze(segmData2D) - io.savez_compressed(newSegmFilepath, segmData2D) - - self.signals.progressBar.emit(1) - - self.signals.finished.emit(self) - -class MigrateUserProfileWorker(QObject): - finished = Signal(object) - critical = Signal(object) - progress = Signal(str) - debug = Signal(object) - - def __init__(self, src_path, dst_path, acdc_folders): - QObject.__init__(self) - self.signals = signals() - self.src_path = src_path - self.dst_path = dst_path - self.acdc_folders = acdc_folders - - @worker_exception_handler - def run(self): - import shutil - from . import models_path - - self.progress.emit( - 'Migrating user profile data from ' - f'"{self.src_path}" to "{self.dst_path}"...' - ) - acdc_folders = self.acdc_folders - self.signals.initProgressBar.emit(2*len(acdc_folders)) - dst_folder = os.path.basename(self.dst_path) - folders_to_remove = [] - for acdc_folder in acdc_folders: - if acdc_folder == dst_folder: - # Skip the destination folder that would be picked up if the - # user called it with acdc at the start of the name - self.signals.progressBar.emit(2) - continue - src = os.path.join(self.src_path, acdc_folder) - dst = os.path.join(self.dst_path, acdc_folder) - self.progress.emit(f'Copying {src} to {dst}...') - files_failed_move = copy_or_move_tree( - src, dst, copy=False, - sigInitPbar=self.signals.sigInitInnerPbar, - sigUpdatePbar=self.signals.sigUpdateInnerPbar - ) - folders_to_remove.append(src) - self.signals.progressBar.emit(1) - - for to_remove in folders_to_remove: - try: - self.progress.emit(f'Removing "{to_remove}"...') - shutil.rmtree(to_remove) - except Exception as err: - self.progress.emit( - '--------------------------------------------------------\n' - f'[WARNING]: Removal of the folder "{to_remove}" failed. ' - 'Please remove manually.\n' - '--------------------------------------------------------' - ) - finally: - self.signals.progressBar.emit(1) - - # Update model's paths - load.migrate_models_paths(self.dst_path) - - # Store user profile data folder path - from . import user_profile_path_txt - os.makedirs(os.path.dirname(user_profile_path_txt), exist_ok=True) - with open(user_profile_path_txt, 'w') as txt: - txt.write(self.dst_path) - - self.finished.emit(self) - -class DelObjectsOutsideSegmROIWorker(QObject): - finished = Signal(object) - critical = Signal(object) - progress = Signal(str) - debug = Signal(object) - - def __init__( - self, - segm_roi_endname: os.PathLike, - segm_data: np.ndarray, - images_path: os.PathLike - ): - QObject.__init__(self) - self.signals = signals() - self.segm_roi_endname = segm_roi_endname - self.segm_data = segm_data - self.images_path = images_path - - @worker_exception_handler - def run(self): - segm_roi_endname = self.segm_roi_endname - segm_roi_filepath, _ = load.get_path_from_endname( - segm_roi_endname, self.images_path - ) - self.progress.emit(f'Loading segmentation file "{segm_roi_filepath}"...') - segm_roi_data = load.load_image_file(segm_roi_filepath) - - self.progress.emit(f'Deleting objects outside of selected ROIs...') - cleared_segm_data, delIDs = transformation.del_objs_outside_segm_roi( - segm_roi_data, self.segm_data - ) - - self.finished.emit((self, cleared_segm_data, delIDs)) - -class ConcatSpotmaxDfsWorker(BaseWorkerUtil): - sigAborted = Signal() - sigAskFolder = Signal(str) - sigSetMeasurements = Signal(object) - sigAskAppendName = Signal(str, list) - - def __init__(self, mainWin, format='CSV'): - super().__init__(mainWin) - if format.startswith('CSV'): - self._final_ext = '.csv' - elif format.startswith('XLS'): - self._final_ext = '.xlsx' - self.acdcOutputEndname = None - - def emitSetMeasurements(self, kwargs): - self.mutex.lock() - self.sigSetMeasurements.emit(kwargs) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - def emitAskAppendName(self, allPos_spotmax_df_basename): - # Ask appendend name - self.mutex.lock() - self.sigAskAppendName.emit(allPos_spotmax_df_basename, []) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - def emitAskCopyCca(self, images_path): - self.mutex.lock() - self.signals.sigAskCopyCca.emit(images_path) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - def setAcdcOutputEndname(self, acdcOutputEndname): - self.acdcOutputEndname = acdcOutputEndname - - def getAcdcDf(self, images_path): - if self.acdcOutputEndname is None: - return - - for file in myutils.listdir(images_path): - if not file.endswith(self.acdcOutputEndname): - continue - - filepath = os.path.join(images_path, file) - acdc_df = pd.read_csv(filepath, index_col=['frame_i', 'Cell_ID']) - return acdc_df - - def copyCcaColsFromAcdcDf(self, df, acdc_df, debug=False): - if acdc_df is None: - return df - - if debug: - printl(acdc_df.columns.to_list(), pretty=True) - - idx = df.index.intersection(acdc_df.index) - for col in cca_df_colnames: - if col not in acdc_df.columns: - continue - - if col not in self.selectedColumns: - continue - - df.loc[idx, col] = acdc_df.loc[idx, col] - - for col in lineage_tree_cols: - if col not in acdc_df.columns: - continue - - if col not in self.selectedColumns: - continue - - df.loc[idx, col] = acdc_df.loc[idx, col] - - for col in default_annot_df.keys(): - if col not in acdc_df.columns: - continue - - if col not in self.selectedColumns: - continue - - df.loc[idx, col] = acdc_df.loc[idx, col] - - for col in self.selectedColumns: - if col not in acdc_df.columns: - continue - - df.loc[idx, col] = acdc_df.loc[idx, col] - - if debug and col == 'cell_vol_fl': - printl(df[[col]]) - - return df - - def emitAskFolderWhereToSaveMultiExp(self): - self.mutex.lock() - self.sigAskFolder.emit('') - self.waitCond.wait(self.mutex) - self.mutex.unlock() - if self.abort: - self.sigAborted.emit() - return - - return self.allExpSaveFolder - - def askSelectMeasurements(self, exp_path, posFoldernames): - acdc_dfs = [] - keys = [] - for p, pos in enumerate(posFoldernames): - if self.abort: - self.sigAborted.emit() - return False - - images_path = os.path.join(exp_path, pos, 'Images') - acdc_df = self.getAcdcDf(images_path) - if acdc_df is None: - continue - - acdc_dfs.append(acdc_df) - keys.append(pos) - - if not acdc_dfs: - return True - - acdc_df_allpos = pd.concat( - acdc_dfs, keys=keys, names=['Position_n', 'frame_i', 'Cell_ID'] - ) - acdc_df_allpos['experiment_folderpath'] = exp_path - basename, chNames = myutils.getBasenameAndChNames( - images_path, useExt=('.tif', '.h5') - ) - df_metadata = load.load_metadata_df(images_path) - SizeZ = df_metadata.at['SizeZ', 'values'] - SizeZ = int(float(SizeZ)) - existing_colnames = acdc_df_allpos.columns - isSegm3D = any([col.endswith('3D') for col in existing_colnames]) - - kwargs = { - 'loadedChNames': chNames, - 'notLoadedChNames': [], - 'isZstack': SizeZ > 1, - 'isSegm3D': isSegm3D, - 'existing_colnames': existing_colnames - } - self.emitSetMeasurements(kwargs) - if self.abort: - self.sigAborted.emit() - return False - - return True - - @worker_exception_handler - def run(self): - from spotmax import DFs_FILENAMES, DF_REF_CH_FILENAME - from spotmax.utils import get_runs_num_and_desc - import spotmax.io - - self.selectedColumns = None - debugging = False - expPaths = self.mainWin.expPaths - tot_exp = len(expPaths) - self.signals.initProgressBar.emit(0) - spotmax_dfs_spots_allexp = defaultdict(lambda: defaultdict(list)) - spotmax_dfs_aggr_allexp = defaultdict(lambda: defaultdict(list)) - ref_ch_dfs_allexp = defaultdict(lambda: defaultdict(list)) - runNumberAlreadyAsked = False - copyFromCcaAlreadyAsked = False - for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): - self.errors = {} - tot_pos = len(pos_foldernames) - - all_runs = get_runs_num_and_desc( - exp_path, pos_foldernames=pos_foldernames - ) - if not all_runs: - self.logger.log( - '[WARNING] The following experiment does not contain ' - f'valid spotMAX output files. Skipping it. "{exp_path}"' - ) - continue - - if not runNumberAlreadyAsked: - abort = self.emitSelectSpotmaxRun( - exp_path, pos_foldernames, all_runs, - infoText=' to combine', - allowSingleSelection=True, - multiSelection=False - ) - if abort: - self.sigAborted.emit() - return - runNumberAlreadyAsked = True - - selectedSpotmaxRuns = self.mainWin.selectedSpotmaxRuns - - self.signals.initProgressBar.emit(len(pos_foldernames)) - dfs_spots = defaultdict(list) - dfs_aggr = defaultdict(list) - dfs_ref_ch = defaultdict(list) - pos_runs = defaultdict(list) - pos_runs_ref_ch = defaultdict(list) - pos_ini_filepaths = {} - for p, pos in enumerate(pos_foldernames): - if self.abort: - self.sigAborted.emit() - return - - pos_path = os.path.join(exp_path, pos) - spotmax_output_path = os.path.join(pos_path, 'spotMAX_output') - - if not os.path.exists(spotmax_output_path): - self.logger.log( - '[WARNING] The following Position folder does not contain ' - f'valid spotMAX output files. Skipping it. "{pos_path}"' - ) - continue - - images_path = os.path.join(exp_path, pos, 'Images') - - if not copyFromCcaAlreadyAsked: - self.emitAskCopyCca(images_path) - if self.abort: - self.sigAborted.emit() - return - - self.askSelectMeasurements(exp_path, pos_foldernames) - if self.abort: - return - copyFromCcaAlreadyAsked = True - - acdc_df = self.getAcdcDf(images_path) - - self.logger.log( - f'Processing experiment n. {i+1}/{tot_exp}, ' - f'{pos} ({p+1}/{tot_pos})' - ) - - - for run_desc in selectedSpotmaxRuns: - run, desc = run_desc.split('_...') - ini_filename = f'{run}_analysis_parameters{desc}.ini' - ini_filepath = os.path.join( - spotmax_output_path, ini_filename - ) - if not os.path.exists(ini_filepath): - self.logger.log( - '[WARNING] The following Position folder does not contain ' - f'the spotMAX output file for run number {run}. ' - f'Skipping it. "{pos_path}"' - ) - continue - - pos_ini_filepaths[(run, desc)] = ini_filepath - for _, pattern_filename in DFs_FILENAMES.items(): - run_filename = pattern_filename.replace('*rn*', run) - run_filename = run_filename.replace('*desc*', desc) - aggr_filename = f'{run_filename}_aggregated.csv' - aggr_filepath = os.path.join( - spotmax_output_path, aggr_filename - ) - if not os.path.exists(aggr_filepath): - continue - - df_spots_filename = f'{run_filename}.h5' - spots_filepath = os.path.join( - spotmax_output_path, df_spots_filename - ) - ext_spots = '.h5' - if not os.path.exists(spots_filepath): - df_spots_filename = f'{run_filename}.csv' - spots_filepath = os.path.join( - spotmax_output_path, df_spots_filename - ) - ext_spots = '.csv' - - if not os.path.exists(spots_filepath): - continue - - analysis_step = re.findall( - r'\*rn\*(.*)\*desc\*', pattern_filename - )[0] - key = (run, analysis_step, desc, ext_spots) - try: - df_spots = spotmax.io.load_spots_table( - spotmax_output_path, df_spots_filename - ).reset_index().set_index(['frame_i', 'Cell_ID']) - df_spots = self.copyCcaColsFromAcdcDf( - df_spots, acdc_df, debug=False - ) - df_spots = ( - df_spots.reset_index() - .set_index(['frame_i', 'Cell_ID', 'spot_id']) - ) - dfs_spots[key].append(df_spots) - except Exception as err: - self.logger.log(str(err), level='ERROR') - self.logger.log( - 'WARNING: Error when reading single-spots ' - 'tables (possibly because there are no spots). ' - 'Skipping this Position.', - level='WARNING' - ) - pass - - df_aggregated = pd.read_csv( - aggr_filepath, index_col=['frame_i', 'Cell_ID'] - ) - df_aggregated = self.copyCcaColsFromAcdcDf( - df_aggregated, acdc_df - ) - dfs_aggr[key].append(df_aggregated) - pos_runs[key].append(pos) - - ref_ch_id_text = re.findall( - r'\*rn\*(.*)\*desc\*', DF_REF_CH_FILENAME - )[0] - ref_ch_filename = ( - DF_REF_CH_FILENAME.replace('*rn*', run) - ) - ref_ch_filename = ( - ref_ch_filename.replace('*desc*', desc) - ) - ref_ch_filepath = os.path.join( - spotmax_output_path, ref_ch_filename - ) - if not os.path.exists(ref_ch_filepath): - continue - - df_ref_ch = pd.read_csv( - ref_ch_filepath, index_col=['frame_i', 'Cell_ID'] - ) - df_ref_ch = self.copyCcaColsFromAcdcDf(df_ref_ch, acdc_df) - ref_ch_key = (run, ref_ch_id_text, desc) - dfs_ref_ch[ref_ch_key].append(df_ref_ch) - pos_runs_ref_ch[ref_ch_key].append(pos) - - self.signals.progressBar.emit(1) - - self.signals.initProgressBar.emit(0) - - self.logger.log('Saving concantenated files...') - - allpos_folderpath = os.path.join(exp_path, 'spotMAX_multipos_output') - os.makedirs(allpos_folderpath, exist_ok=True) - - exp_name = os.path.basename(exp_path) - for key, dfs in dfs_spots.items(): - pos_keys = pos_runs[key] - run, analysis_step, desc, ext_spots = key - - if ext_spots == '.csv': - ext_spots = self._final_ext - filename = f'multipos_{run}{analysis_step}{desc}{ext_spots}' - all_exp_key = filename - df_spots_concat = spotmax.io.save_concat_dfs( - dfs, pos_keys, allpos_folderpath, filename, ext_spots, - names=['Position_n'], return_concat_df=True - ) - df_spots_concat['experiment_foldername'] = exp_name - df_spots_concat['experiment_folderpath'] = exp_path - spotmax_dfs_spots_allexp[all_exp_key]['dfs'].append( - df_spots_concat - ) - spotmax_dfs_spots_allexp[all_exp_key]['keys'].append( - exp_path - ) - ini_filepath = pos_ini_filepaths[(run, desc)] - ini_filename = os.path.basename(ini_filepath) - dst_ini_filepath = os.path.join(allpos_folderpath, ini_filename) - if not os.path.exists(dst_ini_filepath): - shutil.copy2(ini_filepath, dst_ini_filepath) - - spotmax_dfs_spots_allexp[all_exp_key]['ini_filepath'].append( - dst_ini_filepath - ) - - for key, dfs in dfs_aggr.items(): - pos_keys = pos_runs[key] - run, analysis_step, desc, _ = key - filename = ( - f'multipos_{run}{analysis_step}{desc}' - f'_aggregated{self._final_ext}' - ) - all_exp_aggr_key = filename - df_aggr_concat = spotmax.io.save_concat_dfs( - dfs, pos_keys, allpos_folderpath, filename, self._final_ext, - names=['Position_n'], return_concat_df=True - ) - spotmax_dfs_aggr_allexp[all_exp_aggr_key]['dfs'].append( - df_aggr_concat - ) - spotmax_dfs_aggr_allexp[all_exp_aggr_key]['keys'].append( - (exp_path, exp_name) - ) - - for key, dfs in dfs_ref_ch.items(): - run, ref_ch_id_text, desc = key - pos_keys = pos_runs_ref_ch[key] - filename = ( - f'multipos_{run}{ref_ch_id_text}{desc}{self._final_ext}' - ) - all_exp_ref_ch_key = filename - df_ref_ch_concat = spotmax.io.save_concat_dfs( - dfs, pos_keys, allpos_folderpath, filename, self._final_ext, - names=['Position_n'], return_concat_df=True - ) - ref_ch_dfs_allexp[all_exp_ref_ch_key]['dfs'].append( - df_ref_ch_concat - ) - ref_ch_dfs_allexp[all_exp_ref_ch_key]['keys'].append( - (exp_path, exp_name) - ) - - multiexp_dst_folderpath = '' - if len(expPaths) == 1: - self.signals.finished.emit(self) - return - - multiexp_dst_folderpath = self.emitAskFolderWhereToSaveMultiExp() - printl(multiexp_dst_folderpath) - if multiexp_dst_folderpath is None: - return - - self.logger.log( - f'Saving multi-experiment files to "{multiexp_dst_folderpath}"...' - ) - names = ['experiment_folderpath', 'experiment_foldername'] - for filename, items in spotmax_dfs_spots_allexp.items(): - keys = items['keys'] - dfs = items['dfs'] - multiexp_filename = f'multiexp_{filename}' - extension = os.path.splitext(filename)[-1] - spotmax.io.save_concat_dfs( - dfs, keys, multiexp_dst_folderpath, - multiexp_filename, - extension, - names=['experiment_folderpath'] - ) - ini_filepath = items['ini_filepath'][0] - ini_filename = os.path.basename(ini_filepath) - dst_ini_filepath = os.path.join( - multiexp_dst_folderpath, ini_filename - ) - if not os.path.exists(dst_ini_filepath): - shutil.copy2(ini_filepath, dst_ini_filepath) - - for filename, items in spotmax_dfs_aggr_allexp.items(): - keys = items['keys'] - dfs = items['dfs'] - printl(keys, pretty=True) - multiexp_filename = f'multiexp_{filename}' - extension = os.path.splitext(filename)[-1] - spotmax.io.save_concat_dfs( - dfs, keys, multiexp_dst_folderpath, - multiexp_filename, - extension, - names=names - ) - - for filename, items in ref_ch_dfs_allexp.items(): - keys = items['keys'] - dfs = items['dfs'] - multiexp_filename = f'multiexp_{filename}' - extension = os.path.splitext(filename)[-1] - spotmax.io.save_concat_dfs( - dfs, keys, multiexp_dst_folderpath, - multiexp_filename, - extension, - names=names - ) - - self.signals.finished.emit(self) - -class FilterObjsFromCoordsTable(BaseWorkerUtil): - sigAskAppendName = Signal(str, list) - sigAborted = Signal() - sigSetColumnsNames = Signal(object, object, object) - - def __init__(self, mainWin): - super().__init__(mainWin) - - def emitSetColumnsNames(self, columns, categories, optionalCategories): - self.mutex.lock() - self.sigSetColumnsNames.emit(columns, categories, optionalCategories) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.abort - - def getColumnsCategories( - self, df_coords, exp_path, pos_foldernames, endFilenameSegm - ): - columns = df_coords.columns.to_list() - categories = ['X coord. column', 'Y coord. column'] - optionalCategories = [] - - images_path = os.path.join(exp_path, pos_foldernames[0], 'Images') - metadata_df = load.load_metadata_df(images_path) - SizeT = float(metadata_df.at['SizeT', 'values']) - SizeZ = float(metadata_df.at['SizeZ', 'values']) - - segmData = load.load_segm_file( - images_path, end_name_segm_file=endFilenameSegm - ) - - if segmData.ndim == 4: - categories.append('Z coord. column') - categories.append('Frame index column') - elif segmData.ndim == 3: - if SizeZ > 1 and SizeT == 1: - # 3D z-stack data - categories.append('Z coord. column') - else: - optionalCategories.append('Z coord. column') - - if SizeT > 1: - # 3D time-lapse - categories.append('Frame index column') - else: - optionalCategories.append('Frame index column') - else: - optionalCategories.append('Z coord. column') - optionalCategories.append('Frame index column') - - if len(pos_foldernames) > 1: - categories.append('Position_n') - else: - optionalCategories.append('Position_n') - - return columns, categories, optionalCategories - - def getDfCoords( - self, df_coords, selectedColumnsPerCategory, pos_foldername, frame_i - ): - pos_col = selectedColumnsPerCategory.get('Position_n', 'None') - frame_i_col = selectedColumnsPerCategory.get( - 'Frame index column', 'None' - ) - x_col = selectedColumnsPerCategory['X coord. column'] - y_col = selectedColumnsPerCategory['Y coord. column'] - if pos_col != 'None': - df_coords = df_coords[df_coords[pos_col] == pos_foldername] - if frame_i_col != 'None': - df_coords = df_coords[df_coords[frame_i_col] == frame_i] - - xy_cols = [x_col, y_col] - - df_out = pd.DataFrame( - index=df_coords.index, - data=df_coords[xy_cols].values, - columns=['x', 'y'] - ) - z_col = selectedColumnsPerCategory.get('Z coord. column', 'None') - if z_col != 'None': - df_out['z'] = df_coords[z_col] - - return df_out - - @worker_exception_handler - def run(self): - debugging = False - expPaths = self.mainWin.expPaths - tot_exp = len(expPaths) - self.signals.initProgressBar.emit(0) - for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): - self.errors = {} - tot_pos = len(pos_foldernames) - - self.mainWin.infoText = f'Select segmentation file to filter' - abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) - if abort: - self.sigAborted.emit() - return - endFilenameSegm = self.mainWin.endFilenameSegm - - self.logger.log('Asking to select the CSV table file...') - - abort = self.emitSelectFile( - exp_path, 'Select CSV table file with coordinates to filter', - 'CSV (*.csv)' - ) - if abort: - self.sigAborted.emit() - return - - self.logger.log( - f'Loading table file `{self.mainWin.selectedFilepath}`..' - ) - df_coords = pd.read_csv(self.mainWin.selectedFilepath) - - columns, categories, optionalCategories = self.getColumnsCategories( - df_coords, exp_path, pos_foldernames, endFilenameSegm - ) - - abort = self.emitSetColumnsNames( - columns, categories, optionalCategories - ) - if abort: - self.sigAborted.emit() - return - - selectedColumnsPerCategory = self.mainWin.selectedColumnsPerCategory - - # Ask appendend name - self.mutex.lock() - self.sigAskAppendName.emit( - self.mainWin.endFilenameSegm, - self.mainWin.existingSegmEndNames - ) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - if self.abort: - self.sigAborted.emit() - return - - appendedName = self.appendedName - self.signals.initProgressBar.emit(len(pos_foldernames)) - for p, pos in enumerate(pos_foldernames): - if self.abort: - self.sigAborted.emit() - return - - self.logger.log( - f'Processing experiment n. {i+1}/{tot_exp}, ' - f'{pos} ({p+1}/{tot_pos})' - ) - - images_path = os.path.join(exp_path, pos, 'Images') - ls = myutils.listdir(images_path) - file_path = [ - os.path.join(images_path, f) for f in ls - if f.endswith(f'{endFilenameSegm}.npz') - ][0] - - posData = load.loadData(file_path, '') - - self.signals.sigUpdatePbarDesc.emit(f'Processing {posData.pos_path}') - - posData.getBasenameAndChNames() - posData.buildPaths() - - posData.loadOtherFiles( - load_segm_data=True, - load_acdc_df=True, - load_metadata=True, - end_filename_segm=endFilenameSegm - ) - if posData.SizeT == 1: - posData.segm_data = posData.segm_data[np.newaxis] - - self.logger.log('Filtering objects...') - - numFrames = len(posData.segm_data) - self.signals.sigInitInnerPbar.emit(numFrames) - filteredSegmData = np.zeros_like(posData.segm_data) - for frame_i, lab in enumerate(posData.segm_data): - df_coords_frame_i = self.getDfCoords( - df_coords, selectedColumnsPerCategory, pos, frame_i - ) - if df_coords_frame_i.empty: - num_frames_missing = len(posData.segm_data[frame_i:]) - self.signals.sigUpdateInnerPbar.emit(num_frames_missing) - filteredSegmData = filteredSegmData[:frame_i] - break - - filtered_lab = core.filter_segm_objs_from_table_coords( - lab, df_coords_frame_i - ) - filteredSegmData[frame_i] = filtered_lab - - self.signals.sigUpdateInnerPbar.emit(1) - - self.logger.log('Saving filtered segmentation file...') - segmFilename, ext = os.path.splitext(posData.segm_npz_path) - newSegmFilepath = f'{segmFilename}_{appendedName}.npz' - filteredSegmData = np.squeeze(filteredSegmData) - io.savez_compressed(newSegmFilepath, filteredSegmData) - - self.signals.progressBar.emit(1) - - self.signals.finished.emit(self) - -class ScreenRecorderWorker(QObject): - sigGrabScreen = Signal() - finished = Signal() - - def __init__(self, screenRecorderWin, folder_path): - QObject.__init__(self) - self.screenRecorderWin = screenRecorderWin - self.folder_path = folder_path - - def run(self): - for i in range(4): - fn = f'shot_{i:03}.jpg' - grab_path = os.path.join(self.folder_path, fn) - screen = self.screenRecorderWin.screen() - screenshot = screen.grabWindow(self.screenRecorderWin.winId()) - screenshot.save(grab_path, 'jpg') - print(grab_path) - time.sleep(0.2) - - self.finished.emit() - -class CcaIntegrityCheckerWorker(QObject): - finished = Signal(object) - critical = Signal(object) - progress = Signal(str, object) - sigDone = Signal() - sigWarning = Signal(str, str) - sigFixWillDivide = Signal(str, list) - - def __init__(self, mutex, waitCond): - QObject.__init__(self) - self.logger = workerLogger(self.progress) - self.mutex = mutex - self.waitCond = waitCond - self.exit = False - self.isFinished = False - self.abortChecking = False - self.isChecking = False - self.isPaused = False - self.debug = False - self.dataQ = deque(maxlen=10) - - def pause(self): - if self.debug: - self.logger.log('Cell cycle annotations checker is idle.') - self.mutex.lock() - self.isPaused = True - self.waitCond.wait(self.mutex) - self.mutex.unlock() - self.isPaused = False - - def enqueue(self, posData): - # First stop previous checking - if self.isChecking: - self.abortChecking = True - self._enqueue(posData) - - def _enqueue(self, posData): - if self.debug: - self.logger.log('Enqueing posData...') - self.dataQ.append(posData) - if len(self.dataQ) == 1: - # Wake worker upon inserting first element - self.abortChecking = False - self.waitCond.wakeAll() - - def clearQueue(self): - self.dataQ.clear() - - def _stop(self): - self.exit = True - self.waitCond.wakeAll() - - def abort(self): - self.abortChecking = True - while not len(self.dataQ) == 0: - data = self.dataQ.pop() - del data - self._stop() - - def _check_equality_num_mothers_buds_in_S(self, checker, frame_i): - num_moth_S, num_buds = checker.get_num_mothers_and_buds_in_S() - - if num_moth_S == num_buds: - return True - - category = 'number of buds different from number of mothers in S phase' - ul_items = [ - f'Number of buds = {num_buds}', - f'Number of mothers in S phase = {num_moth_S}' - ] - txt = html_utils.paragraph( - f'At frame n. {frame_i+1} the number of buds and number of ' - 'mother cells in S phase are different!' - f'{html_utils.to_list(ul_items)}' - ) - self.sigWarning.emit(txt, category) - return False - - def _check_mothers_multiple_buds(self, checker, frame_i): - mother_IDs_with_multiple_buds = ( - checker.get_mother_IDs_with_multiple_buds() - ) - if len(mother_IDs_with_multiple_buds) == 0: - return True - - category = 'mother cells with multiple buds' - txt = html_utils.paragraph( - f'At frame n. {frame_i+1} ' - 'the following mother cells have multiple buds assigned to it' - f'

    {mother_IDs_with_multiple_buds}' - ) - self.sigWarning.emit(txt, category) - return False - - def _check_cells_without_G1(self, checker, global_cca_df): - IDs_cycles_without_G1 = ( - checker.get_IDs_cycles_without_G1(global_cca_df) - ) - if len(IDs_cycles_without_G1) == 0: - return True - - category = 'cell cycles without G1' - txt = html_utils.paragraph( - 'Cell-ACDC requires that every cell cycle has at least ' - 'one frame in G1.
    ' - 'The following pairs of (ID, generation number) ' - 'do not satisfy this condition:

    ' - f'{IDs_cycles_without_G1}' - ) - self.sigWarning.emit(txt, category) - return False - - def _check_will_divide_is_true(self, checker, global_cca_df): - # NOTE: unfortunately this function performs pandas manipulations - # that are either not thread-safe or in any case are freezing the - # GUI. For now we don't run this until we find a solution - return True - - IDs_will_divide_wrong = ( - checker.get_IDs_gen_num_will_divide_wrong(global_cca_df) - ) - if len(IDs_will_divide_wrong) == 0: - return True - - txt = html_utils.paragraph( - 'Cell-ACDC found that `will_divide` is annotated as True on the ' - 'following (ID, generation number) cell
    ' - 'despite the fact that division is still not annotated on ' - 'these cells

    :' - f'{IDs_will_divide_wrong}' - ) - self.sigFixWillDivide.emit(txt, IDs_will_divide_wrong) - return False - - def _check_buds_gen_num_zero(self, checker, frame_i): - bud_IDs_gen_num_nonzero = ( - checker.get_bud_IDs_gen_num_nonzero() - ) - if len(bud_IDs_gen_num_nonzero) == 0: - return True - - category = 'buds whose generation number is not zero' - txt = html_utils.paragraph( - f'At frame n. {frame_i+1} ' - 'the following bud IDs have generation number different from 0:' - f'

    {bud_IDs_gen_num_nonzero}' - ) - self.sigWarning.emit(txt, category) - return False - - def _check_mothers_gen_num_greater_one(self, checker, frame_i): - moth_IDs_gen_num_non_greater_one = ( - checker.get_moth_IDs_gen_num_non_greater_one() - ) - if len(moth_IDs_gen_num_non_greater_one) == 0: - return True - - category = 'mothers whose generation number is < 1' - txt = html_utils.paragraph( - f'At frame n. {frame_i+1} ' - 'the following mother cells have generation number < 1:' - f'

    {moth_IDs_gen_num_non_greater_one}' - ) - self.sigWarning.emit(txt, category) - return False - - def _check_buds_G1(self, checker, frame_i): - buds_G1 = ( - checker.get_buds_G1() - ) - if len(buds_G1) == 0: - return True - - category = 'buds in G1' - txt = html_utils.paragraph( - f'At frame n. {frame_i+1} ' - 'the following bud IDs are in G1 (buds must be in S):' - f'

    {buds_G1}' - ) - self.sigWarning.emit(txt, category) - return False - - def _check_cell_S_rel_ID_zero(self, checker, frame_i): - cell_S_rel_ID_zero = ( - checker.get_cell_S_rel_ID_zero() - ) - if len(cell_S_rel_ID_zero) == 0: - return True - - category = 'buds in G1' - txt = html_utils.paragraph( - f'At frame n. {frame_i+1} ' - 'the following cell IDs in S phase do not have ' - 'relative_ID > 0:' - f'

    {cell_S_rel_ID_zero}' - ) - self.sigWarning.emit(txt, category) - return False - - def _check_ID_rel_ID_mismatches(self, checker, frame_i): - ID_rel_ID_mismatches = checker.get_ID_rel_ID_mismatches() - if len(ID_rel_ID_mismatches) == 0: - return True - - items = [ - f'Cell ID {ID} has relative ID = {relID}, ' - f'while cell ID {relID} has relative ID = {relID_of_relID}' - for ID, relID, relID_of_relID in ID_rel_ID_mismatches - ] - category = '`ID-relative_ID` mismatches' - txt = html_utils.paragraph( - f'At frame n. {frame_i+1} ' - 'there are the following `ID-relative_ID` mismatches:' - f'{html_utils.to_list(items)}' - ) - self.sigWarning.emit(txt, category) - return False - - def _check_lonely_cells_in_S(self, checker, frame_i): - lonely_cells_in_S = checker.get_lonely_cells_in_S() - if len(lonely_cells_in_S) == 0: - return True - - category = 'Lovely cells in S phase' - txt = html_utils.paragraph( - f'At frame n. {frame_i+1} ' - 'the following cell IDs are in `S` phase but their `relative_ID` ' - f'does not exist:

    ' - f'{lonely_cells_in_S}' - ) - self.sigWarning.emit(txt, category) - return False - - def _get_cca_df_copy(self, acdc_df): - try: - cca_df = pd.DataFrame( - data=acdc_df[cca_df_colnames].values, - columns=cca_df_colnames, - index=acdc_df.index - ) - return cca_df - except KeyError as error: - return - - def check(self, posData): - self.isChecking = True - checkpoints = ( - '_check_lonely_cells_in_S', - '_check_equality_num_mothers_buds_in_S', - '_check_mothers_multiple_buds', - '_check_buds_gen_num_zero', - '_check_mothers_gen_num_greater_one', - '_check_buds_G1', - '_check_cell_S_rel_ID_zero', - '_check_ID_rel_ID_mismatches' - ) - cca_dfs = [] - keys = [] - check_integrity_globally = True - for frame_i, data_dict in enumerate(posData.allData_li): - if self.abortChecking: - check_integrity_globally = False - break - - lab = data_dict['labels'] - if lab is None: - break - - cca_df = data_dict.get('cca_df_checker') - if cca_df is None: - # There are no annotations at frame_i --> stop - break - - IDs = data_dict['IDs'] - checker = core.CcaIntegrityChecker(cca_df, lab, IDs) - - for checkpoint in checkpoints: - proceed = getattr(self, checkpoint)(checker, frame_i) - if not proceed: - break - - if not proceed: - check_integrity_globally = False - break - - cca_dfs.append(cca_df) - keys.append(frame_i) - - if check_integrity_globally and len(cca_dfs)>1: - global_checkpoints = [ - '_check_cells_without_G1', - # '_check_will_divide_is_true' - ] - # Check integrity globally - global_cca_df = pd.concat(cca_dfs, keys=keys, names=['frame_i']) - for checkpoint in global_checkpoints: - proceed = getattr(self, checkpoint)(checker, global_cca_df) - if not proceed: - break - - self.abortChecking = False - self.isChecking = False - time.sleep(1) - - @worker_exception_handler - def run(self): - while True: - if self.exit: - self.logger.log('Closing cell cycle integrity checker worker...') - break - elif not len(self.dataQ) == 0: - if self.debug: - self.logger.log( - 'Checking integrity of cell cycle annotations ' - f'({len(self.dataQ)})...' - ) - data = self.dataQ.pop() - self.check(data) - if len(self.dataQ) == 0: - self.sigDone.emit() - else: - self.pause() - self.isFinished = True - self.finished.emit(self) - -class ApplyImageFilterWorker(QObject): - finished = Signal(object) - critical = Signal(object) - progress = Signal(str) - - def __init__(self, filter_func, input_data): - QObject.__init__(self) - self.filter_func = filter_func - self.input_data = input_data - - @worker_exception_handler - def run(self): - self.progress.emit('Filtering image...') - filtered_data = self.filter_func(self.input_data) - self.finished.emit(filtered_data) - -class MoveTempFilesWorker(QObject): - def __init__(self, temp_files_to_move: Dict[os.PathLike, os.PathLike]): - QObject.__init__(self) - self.signals = signals() - self.logger = workerLogger(self.signals.progress) - self.temp_files_to_move = temp_files_to_move - - @worker_exception_handler - def run(self): - for src, dst in self.temp_files_to_move.items(): - self.logger.log(f'Saving channel data to: {dst}...') - shutil.move(src, dst) - tempDir = os.path.dirname(src) - shutil.rmtree(tempDir) - self.signals.progressBar.emit(1) - self.signals.finished.emit(self) - -class ResizeUtilWorker(BaseWorkerUtil): - sigSetResizeProps = Signal(str) - - def emitSetResizeProps(self, input_path): - self.mutex.lock() - self.sigSetResizeProps.emit(input_path) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.abort - - def __init__(self, mainWin): - super().__init__(mainWin) - - def validateOutputPath(self, path): - if path is None: - return - - images_path = myutils.validate_images_path(path, create_dirs_tree=True) - return images_path - - @worker_exception_handler - def run(self): - expPaths = self.mainWin.expPaths - tot_exp = len(expPaths) - - self.signals.initProgressBar.emit(0) - for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): - abort = self.emitSetResizeProps(exp_path) - if abort: - self.signals.finished.emit(self) - return - - tot_pos = len(pos_foldernames) - for p, pos in enumerate(pos_foldernames): - if self.abort: - self.signals.finished.emit(self) - return - - self.logger.log( - f'Processing experiment n. {i+1}/{tot_exp}, ' - f'{pos} ({p+1}/{tot_pos})' - ) - images_path = os.path.join(exp_path, pos, 'Images') - - - rf = self.resizeFactor - text_to_append = self.textToAppend - images_path_out = self.validateOutputPath(self.expFolderpathOut) - if images_path_out is None: - images_path_out = images_path - resize.run( - images_path, rf, - text_to_append=text_to_append, - images_path_out=images_path_out - ) - - self.signals.finished.emit(self) - -class FucciPreprocessWorker(BaseWorkerUtil): - sigAskAppendName = Signal(str) - sigAskParams = Signal(object, object) - sigAborted = Signal() - - def __init__(self, mainWin): - super().__init__(mainWin) - - def emitAskParams(self, exp_path, pos_foldernames): - self.mutex.lock() - self.sigAskParams.emit(exp_path, pos_foldernames) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.abort - - def applyPipeline(self, first_ch_data, second_ch_data, filter_kwargs): - processed_data = np.zeros(first_ch_data.shape, dtype=np.uint8) - pbar = tqdm(total=len(processed_data), ncols=100) - with concurrent.futures.ThreadPoolExecutor() as executor: - iterable = enumerate(zip(first_ch_data, second_ch_data)) - func = partial( - core.fucci_pipeline_executor_map, **filter_kwargs - ) - result = executor.map(func, iterable) - for frame_i, processed_img in result: - processed_img = skimage.exposure.rescale_intensity(processed_img, out_range=(0, 255)) - processed_img = processed_img.astype(np.uint8) - processed_data[frame_i] = processed_img - pbar.update() - pbar.close() - - return processed_data - - @worker_exception_handler - def run(self): - debugging = False - expPaths = self.mainWin.expPaths - tot_exp = len(expPaths) - self.signals.initProgressBar.emit(0) - for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): - self.errors = {} - tot_pos = len(pos_foldernames) - - self.mainWin.infoText = f'Setup parameters' - - if i == 0: - abort = self.emitAskParams(exp_path, pos_foldernames) - if abort: - self.sigAborted.emit() - return - - # Ask appendend name - self.mutex.lock() - self.sigAskAppendName.emit(self.basename) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - if self.abort: - self.sigAborted.emit() - return - - appendedName = self.appendedName - self.signals.initProgressBar.emit(len(pos_foldernames)) - for p, pos in enumerate(pos_foldernames): - if self.abort: - self.sigAborted.emit() - return - - self.logger.log( - f'Processing experiment n. {i+1}/{tot_exp}, ' - f'{pos} ({p+1}/{tot_pos})' - ) - - images_path = os.path.join(exp_path, pos, 'Images') - - self.logger.log( - f'Loading {self.firstChannelName} channel data...' - ) - first_ch_filepath = load.get_filename_from_channel( - images_path, self.firstChannelName - ) - first_ch_data = load.load_image_file(first_ch_filepath) - - self.logger.log( - f'Loading {self.secondChannelName} channel data...' - ) - second_ch_filepath = load.get_filename_from_channel( - images_path, self.secondChannelName - ) - second_ch_data = load.load_image_file(second_ch_filepath) - - self.logger.log( - 'Applying FUCCI pre-processing pipeline...\n' - ) - processed_data = self.applyPipeline( - first_ch_data, second_ch_data, self.fucciFilterKwargs - ) - - basename, chNames = myutils.getBasenameAndChNames(images_path) - _, ext = os.path.splitext(first_ch_filepath) - processed_filename = f'{basename}{appendedName}{ext}' - processed_filepath = os.path.join( - images_path, processed_filename - ) - self.logger.log( - f'Saving pre-processed images to "{processed_filepath}"...' - ) - io.save_image_data(processed_filepath, processed_data) - - self.signals.progressBar.emit(1) - - self.signals.finished.emit(self) - -class SimpleWorker(QObject): - def __init__(self, posData, func, func_args=None, func_kwargs=None): - QObject.__init__(self) - self.posData = posData - self.signals = signals() - self.output = {} - - if func_args is None: - func_args = [] - - if func_kwargs is None: - func_kwargs = {} - - self.func = func - self.func_args = func_args - self.func_kwargs = func_kwargs - self.posData = posData - - @worker_exception_handler - def run(self): - self.result = self.func( - self.posData, *self.func_args, **self.func_kwargs - ) - self.signals.finished.emit(self.output) - -class CopyAllLostObjectsWorker(QObject): - navigateToFrame = Signal(int) - returnToFrame = Signal(int) - copyLostObjectMask = Signal(int) - refreshRp = Signal() - progressBar = Signal(int) - finished = Signal(object) - critical = Signal(object) - - def __init__(self, gui, posData, for_future_frame_n, max_overlap_perc): - super().__init__() - self.gui = gui - self.posData = posData - self.for_future_frame_n = for_future_frame_n - self.max_overlap_perc = max_overlap_perc - - @worker_exception_handler - def run(self): - current_frame_i = self.posData.frame_i - last_visited_frame_i = self.gui.get_last_tracked_i() - last_copied_frame_i = current_frame_i + self.for_future_frame_n + 1 - frames_range = (current_frame_i, last_copied_frame_i) - overlap_warning = False - output = {} - - for frame_i in range(*frames_range): - if frame_i == self.posData.SizeT: - break - - if frame_i > self.posData.frame_i: - # Main thread navigates, runs tracking, updates rp/IDs, etc - self.navigateToFrame.emit(frame_i) - - for lostObj in skimage.measure.regionprops(self.gui.lostObjImage): - overlap = np.count_nonzero( - self.gui.currentLab2D[lostObj.slice][lostObj.image] - ) - overlap_perc = overlap / lostObj.area * 100 - if overlap_perc > self.max_overlap_perc: - overlap_warning = True - continue - - self.copyLostObjectMask.emit(lostObj.label) - - # Refresh rp so the next frame's updateLostNewCurrentIDs sees the - # copied IDs as belonging to this frame and marks them lost there. - self.refreshRp.emit() - - self.progressBar.emit(1) - - if self.for_future_frame_n == 0: - output['overlap_warning'] = overlap_warning - self.finished.emit(output) - return - - # Back to current frame - self.returnToFrame.emit(current_frame_i) - - if last_visited_frame_i < last_copied_frame_i: - output['doReinitLastSegmFrame'] = True - output['last_visited_frame_i'] = last_visited_frame_i - - output['overlap_warning'] = overlap_warning - self.finished.emit(output) - -class SaveProcessedDataWorker(QObject): - def __init__( - self, - allPosData: Iterable['load.loadData'], - appended_text_filename: str, - ext: str = None - ): - QObject.__init__(self) - self.allPosData = allPosData - self.signals = signals() - self.logger = workerLogger(self.signals.progress) - self.appended_text_filename = appended_text_filename - self.ext = ext - - @worker_exception_handler - def run(self): - self.signals.initProgressBar.emit(0) - for posData in self.allPosData: - ext_loc = self.ext if self.ext is not None else posData.ext - processed_filename = ( - f'{posData.basename}{posData.user_ch_name}_' - f'{self.appended_text_filename}{ext_loc}' - ) - processed_filepath = os.path.join( - posData.images_path, processed_filename - ) - self.logger.log(f'Saving {processed_filepath}...') - processed_data = posData.preprocessedDataArray() - if processed_data is None: - self.logger.log( - f'[WARNING]: {posData.pos_foldername} does not have ' - 'preprocessed data. Skipping it.' - ) - continue - - io.save_image_data(processed_filepath, processed_data) - - self.signals.finished.emit(self) - -class SaveCombinedChannelsWorker(QObject): - sigDebugShowImg = Signal(object) - def __init__( - self, - allPosData: Iterable['load.loadData'], - filename: str, - debug: bool = False - ): - QObject.__init__(self) - self.allPosData = allPosData - self.signals = signals() - self.logger = workerLogger(self.signals.progress) - self.filename = filename - self.debug = debug - - @worker_exception_handler - def run(self): - self.signals.initProgressBar.emit(0) - for posData in self.allPosData: - processed_filepath = os.path.join( - posData.images_path, self.filename - ) - self.logger.log(f'Saving {processed_filepath}...') - processed_data = posData.combinedChannelsDataArray() - if processed_data is None: - self.logger.log( - f'[WARNING]: {posData.pos_foldername} does not have ' - 'combined channels data. Skipping it.' - ) - continue - if self.debug: - printl(processed_data.shape) - printl(processed_data.dtype) - printl(processed_data.min()) - printl(processed_data.max()) - printl(processed_filepath) - self.sigDebugShowImg.emit(processed_data) - # cellacdc.plot.imshow(processed_data) - io.save_image_data(processed_filepath, processed_data) - - self.signals.finished.emit(self) - -class CustomPreprocessWorkerGUI(QObject): - sigDone = Signal(object, str) - sigPreviewDone = Signal(object, tuple) - sigIsQueueEmpty = Signal(bool) - - def __init__(self, mutex, waitCond): - QObject.__init__(self) - self.signals = signals() - self.mutex = mutex - self.waitCond = waitCond - self.logger = workerLogger(self.signals.progress) - self.dataQ = deque(maxlen=2) - self.exit = False - self.wait = True - self._abort = False - - def enqueue( - self, - func: Callable, - image: np.ndarray, - recipe: Dict[str, Any], - key: Tuple[int, int, Union[int, str]] - ): - self.dataQ.append((func, image, recipe, key)) - if len(self.dataQ) == 1: - self.sigIsQueueEmpty.emit(False) - # Wake up worker upon inserting first element - self.wakeUp() - - def wakeUp(self): - self.wait = False - self.waitCond.wakeAll() - - def pause(self): - self.wait = True - self.mutex.lock() - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - def abort(self): - self._abort = True - - def stop(self): - self.abort() - self.exit = True - self.waitCond.wakeAll() - self.signals.finished.emit(self) - - def setupJob( - self, - func: Callable, - image_data: np.ndarray, - recipe: Dict[str, Any], - how: str - ): - self._func = func - self._image_data = image_data - self._recipe = recipe - self._how = how - - def runJob(self, image=None, recipe=None): - if image is None: - image = self._image_data.copy() - if recipe is None: - recipe = self._recipe - - return self.applyRecipe(self._func, image, recipe) - - def applyRecipe( - self, - func: Callable, - image: np.ndarray, - recipe: List[Dict[str, Any]] - ): - preprocessed_data = func(image, recipe) - - keep_input_data_type = recipe[0].get('keep_input_data_type', True) - if not keep_input_data_type: - return preprocessed_data - - try: - preprocessed_data = myutils.convert_to_dtype( - preprocessed_data, image.dtype - ) - except Exception as err: - preprocessed_data = preprocessed_data.astype(image.dtype) - return preprocessed_data - - @worker_exception_handler - def run(self): - while True: - if self.exit: - self.logger.log('Closing pre-processing worker...') - break - elif self.wait: - self.logger.log('Pre-processing worker paused.') - self.pause() - elif len(self.dataQ) > 0: - func, image, recipe, key = self.dataQ.pop() - processed_data = self.applyRecipe(func, image, recipe) - self.sigPreviewDone.emit(processed_data, key) - if len(self.dataQ) == 0: - self.wait = True - self.sigIsQueueEmpty.emit(True) - else: - self.logger.log('Pre-processing worker resumed.') - processed_data = self.runJob() - self.sigDone.emit(processed_data, self._how) - self.wait = True - - self.signals.finished.emit(self) - -class CombineChannelsWorkerGUI(CustomPreprocessWorkerGUI): - sigDone = Signal(object, list) - sigPreviewDone = Signal(object, list) - sigAskLoadChannels = Signal(set, object) - - def __init__(self, mutex, waitCond, logger_func: Callable,): -# signals_parent=None): - super().__init__(mutex, waitCond) - - self.waitCondLoadFluoChannels = QWaitCondition() - self.logger_func = logger_func - - # if not signals_parent: - # signals_parent = signals() - - # self.signals = signals_parent - - def enqueue( - self, - data, - steps: Dict[str, Any], - key: Tuple[int, int, Union[int, str]], - keep_input_data_type: bool, - output_as_segm: bool, - formula: str, - ): - self.dataQ.append((data, steps, key, keep_input_data_type,output_as_segm, formula)) - if len(self.dataQ) == 1: - self.sigIsQueueEmpty.emit(False) - # Wake up worker upon inserting first element - self.wakeUp() - - def setupJob( - self, - data: Dict[str, np.ndarray], - steps: Dict[str, Any], - keep_input_data_type: bool, - key: Tuple[Union[int, None], Union[int, None], Union[int, None]], - output_as_segm: bool, - formula: str, - ): - self._key = key - self._steps = steps - self._data = data - self._keep_input_data_type = keep_input_data_type - self._output_as_segm = output_as_segm - self._formula = formula - - def runJob(self, data=None, steps=None, keep_input_data_type=None, key=None, - output_as_segm=None, formula=None): - if data is None: - data = self._data - if steps is None: - steps = self._steps - if keep_input_data_type is None: - keep_input_data_type = self._keep_input_data_type - if key is None: - key = self._key - if output_as_segm is None: - output_as_segm = self._output_as_segm - if formula is None: - formula = self._formula - - if not steps and formula is None: - return - - return self.applySteps(data, steps, keep_input_data_type, key, output_as_segm, formula=formula) - - def applySteps( - self, - data: Dict[str, np.ndarray], - steps: List[Dict[str, Any]], - keep_input_data_type: bool, - key: Tuple[Union[int, None], Union[int, None], Union[int, None]], - output_as_segm: bool, - formula: str, - ): - - new_keys = [] - key = list(key) - if key[0] is None: - pos_number = len(data) - key[0] = list(range(pos_number)) - else: - key[0] = [key[0]] - - for pos_i in key[0]: - new_keys_per_pos = [[pos_i]] - if key[1] is None: - frames = data[pos_i].SizeT - new_keys_per_pos.append(list(range(frames))) - else: - new_keys_per_pos.append([key[1]]) - - if key[2] is None: - z_slices = data[pos_i].SizeZ - if not z_slices: - z_slices = 1 - new_keys_per_pos.append(list(range(z_slices))) - else: - new_keys_per_pos.append([key[2]]) - - new_keys_per_pos = list(itertools.product(*new_keys_per_pos)) - new_keys.extend(new_keys_per_pos) - - output_imgs, out_keys = core.combine_channels_multithread_return_imgs( - steps=steps, - data=data, - keep_input_data_type=keep_input_data_type, - keys=new_keys, - logger_func=self.logger, - signals=self.signals, - output_as_segm=output_as_segm, - formula=formula, - - ) - return output_imgs, out_keys - - def requiredChannels(self, steps=None, pos_i=None): - if steps is None: - steps = self._steps - - required_channels = core.get_selected_channels(steps) - if pos_i is None: - pos_i = self._key[0] - - return required_channels, pos_i - - @worker_exception_handler - def run(self): - while True: - if self.exit: - self.logger.log('Closing combining channels worker...') - break - elif self.wait: - self.logger.log('Combining channels worker paused.') - self.pause() - elif len(self.dataQ) > 0: - data, steps, key, keep_input_data_type, output_as_segm, formula = self.dataQ.pop() - requ_steps, pos_i = self.requiredChannels(steps, key[0]) - self.emitsigAskLoadChannels(requ_steps, pos_i) - output_imgs, out_keys = self.applySteps( - data, steps, keep_input_data_type, key, - output_as_segm=output_as_segm, formula=formula - ) - self.sigPreviewDone.emit(output_imgs, out_keys) - if len(self.dataQ) == 0: - self.wait = True - self.sigIsQueueEmpty.emit(True) - else: - self.logger.log('Combining channels worker resumed.') - requ_steps, pos_i = self.requiredChannels() - self.emitsigAskLoadChannels(requ_steps, pos_i) - output_imgs, out_keys = self.runJob() - self.sigDone.emit(output_imgs, out_keys) - self.wait = True - - self.signals.finished.emit(self) - - def emitsigAskLoadChannels(self, requChannels, pos_i): - self.mutex.lock() - self.sigAskLoadChannels.emit(requChannels, pos_i) - self.waitCondLoadFluoChannels.wait(self.mutex) - self.mutex.unlock() - return self.abort - - def wake_waitCondLoadFluoChannels(self): - self.mutex.lock() - self.waitCondLoadFluoChannels.wakeAll() - self.mutex.unlock() - -class CustomPreprocessWorkerUtil(BaseWorkerUtil): - sigAskAppendName = Signal(str) - sigAskSetupRecipe = Signal(object, object) - sigAborted = Signal() - - def __init__(self, mainWin): - super().__init__(mainWin) - - def emitAskSetupRecipe(self, exp_path, pos_foldernames): - self.mutex.lock() - self.sigAskSetupRecipe.emit(exp_path, pos_foldernames) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.abort - - def applyPipeline( - self, - images_path: os.PathLike, - channel_names: Iterable[str], - recipe: List[Dict[str, Any]], - appended_text_filename: str - ): - posData = None - preprocessed_data = {} - for channel in channel_names: - self.logger.log(f'Loading {channel} channel data...') - ch_filepath = load.get_filename_from_channel(images_path, channel) - ch_image_data = load.load_image_file(ch_filepath) - if posData is None: - posData = load.loadData(ch_filepath, channel) - posData.getBasenameAndChNames() - posData.buildPaths() - posData.loadOtherFiles( - load_segm_data=False, - load_metadata=True, - ) - if posData.SizeT == 1: - ch_image_data = (ch_image_data,) - - preprocessed_ch_data = core.preprocess_image_from_recipe_multithread( - ch_image_data, recipe - ) - - keep_input_data_type = recipe[0].get('keep_input_data_type', True) - if keep_input_data_type: - preprocessed_ch_data = myutils.convert_to_dtype( - preprocessed_ch_data, ch_image_data.dtype - ) - - _, ext = os.path.splitext(ch_filepath) - basename = posData.basename - processed_filename = ( - f'{basename}{channel}_{appended_text_filename}{ext}' - ) - preprocessed_data[processed_filename] = preprocessed_ch_data - - return preprocessed_data - - @worker_exception_handler - def run(self): - debugging = False - expPaths = self.mainWin.expPaths - tot_exp = len(expPaths) - self.signals.initProgressBar.emit(0) - for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): - self.errors = {} - tot_pos = len(pos_foldernames) - - self.mainWin.infoText = 'Setup recipe' - - if i == 0: - abort = self.emitAskSetupRecipe(exp_path, pos_foldernames) - if abort: - self.sigAborted.emit() - return - - # Ask append name - self.mutex.lock() - basename = f'{self.basename}{self.selectedChannels[0]}_' - self.sigAskAppendName.emit(basename) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - if self.abort: - self.sigAborted.emit() - return - - appendedName = self.appendedName - self.signals.initProgressBar.emit(len(pos_foldernames)) - for p, pos in enumerate(pos_foldernames): - if self.abort: - self.sigAborted.emit() - return - - self.logger.log( - f'Processing experiment n. {i+1}/{tot_exp}, ' - f'{pos} ({p+1}/{tot_pos})' - ) - - images_path = os.path.join(exp_path, pos, 'Images') - self.logger.log( - 'Applying custom pre-processing recipe...\n' - ) - processed_data = self.applyPipeline( - images_path, self.selectedChannels, - self.recipe, appendedName - ) - - for filename, preprocessed_ch_data in processed_data.items(): - preprocessed_filepath = os.path.join(images_path, filename) - self.logger.log( - f'Saving pre-processed images to ' - f'"{preprocessed_filepath}"...' - ) - - io.save_image_data( - preprocessed_filepath, preprocessed_ch_data - ) - self.signals.progressBar.emit(1) - - self.signals.finished.emit(self) - -class CombineChannelsWorkerUtil(BaseWorkerUtil): - sigAskAppendName = Signal(str) - sigAskSetup = Signal(object) - sigAborted = Signal() - - def __init__(self, mainWin, mutex=None, waitCond=None): - super().__init__(mainWin) - - def emitAskSetup(self, expPaths): - self.mutex.lock() - self.sigAskSetup.emit(expPaths) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.abort - - def applyPipeline( - self, - image_paths: os.PathLike, - steps: Dict[str, Dict[str, Any]], - appended_text_filename: str, - keep_input_data_type: bool, - n_threads: int = None, - formula: str = None, - ): - save_filepaths = [] - images_path_to_process = [] - if self.saveAsSegm: - out_ext = '.npz' - basename_ext = 'segm_' - else: - out_ext = '.tif' - basename_ext = '' - for images_path in image_paths: - basename, channels = myutils.getBasenameAndChNames(images_path) - - savename = ( - f'{basename}{basename_ext}{appended_text_filename}{out_ext}' - ) - - images_path_to_process.append(images_path) - save_filepaths.append(os.path.join(images_path, savename)) - - core.combine_channels_multithread( - steps=steps, - images_paths=images_path_to_process, - keep_input_data_type=keep_input_data_type, - save_filepaths=save_filepaths, - signals=self.signals, - logger_func=self.logger.log, - n_threads=n_threads, - output_as_segm=self.saveAsSegm, - formula=formula, - ) - - @worker_exception_handler - def run(self): - - self.signals.initProgressBar.emit(0) - - expPaths = self.mainWin.expPaths - abort = self.emitAskSetup(expPaths) - if abort: - self.sigAborted.emit() - return - - # Ask append name - self.mutex.lock() - basename = f'{self.basename}' - self.sigAskAppendName.emit(basename) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - if self.abort: - self.sigAborted.emit() - return - - appendedName = self.appendedName - - selectedSteps = self.selectedSteps - - self.logger.log('Applying pipeline...') - self.logger.log('Selected steps:') - for step in selectedSteps.values(): - self.logger.log(step) - - image_paths = [] - for exp_path, pos_foldernames in expPaths.items(): - image_paths += [os.path.join(exp_path, pos, 'Images') for pos in pos_foldernames] - - self.signals.initProgressBar.emit(len(pos_foldernames)) - formula = self.formula - self.applyPipeline( - image_paths, - selectedSteps, - appendedName, - self.keepInputDataType, - n_threads=self.nThreads, - formula=formula, - ) - - self.signals.finished.emit(self) - -class saveDataWorker(QObject): - finished = Signal() - progress = Signal(str) - sigLog = Signal(str) - progressBar = Signal(int, int, float) - critical = Signal(object) - addMetricsCritical = Signal(str, str) - regionPropsCritical = Signal(str, str) - criticalPermissionError = Signal(str) - metricsPbarProgress = Signal(int, int) - askZsliceAbsent = Signal(str, object) - customMetricsCritical = Signal(str, str) - sigCombinedMetricsMissingColumn = Signal(str, str) - sigDebug = Signal(object) - - def __init__(self, mainWin): - QObject.__init__(self) - self.mainWin = mainWin - self.saveWin = mainWin.saveWin - self.mutex = mainWin.mutex - self.waitCond = mainWin.waitCond - self.customMetricsErrors = {} - self.addMetricsErrors = {} - self.regionPropsErrors = {} - self.abort = False - - def checkAbort(self): - if self.saveWin.aborted: - self.finished.emit() - return True - return False - - def saveManualBackgroundData(self, posData, frame_i): - data_dict = posData.allData_li[frame_i] - if 'manualBackgroundLab' not in data_dict: - return - - manualBackgrData = data_dict['manualBackgroundLab'] - posData.saveManualBackgroundData(manualBackgrData) - - def emitSigPermissionErrorAndSave( - self, all_frames_acdc_df, acdc_output_csv_path, - custom_annot_columns - ): - err_msg = ( - 'The below file is open in another app ' - '(Excel maybe?).\n\n' - f'{acdc_output_csv_path}\n\n' - 'Close file and then press "Ok".' - ) - self.mutex.lock() - self.criticalPermissionError.emit(err_msg) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - # Save segmentation metadata - load.save_acdc_df_file( - all_frames_acdc_df, acdc_output_csv_path, - custom_annot_columns=custom_annot_columns, - last_cca_frame_i=self.mainWin.save_cca_until_frame_i - ) - - def _emitSigDebug(self, stuff_to_debug): - self.mutex.lock() - self.sigDebug.emit(stuff_to_debug) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - - def emitUpdateProgressBar(self): - t = time.perf_counter() - exec_time = t - self.time_last_pbar_update - self.progressBar.emit(1, -1, exec_time) - self.time_last_pbar_update = t - - def saveAcdcDf(self, posData: load.loadData, end_i): - acdc_dfs_li = [] - keys = [] - self.progress.emit(f'Saving annotations for {posData.relPath}...') - for frame_i, data_dict in enumerate(posData.allData_li[:end_i+1]): - if self.saveWin.aborted: - self.finished.emit() - return - - # Build saved_segm_data - lab = data_dict['labels'] - if lab is None: - break - - acdc_df = posData.allData_li[frame_i]['acdc_df'] - if acdc_df is None: - continue - - acdc_dfs_li.append(acdc_df) - keys.append((frame_i, posData.TimeIncrement*frame_i)) - - if not acdc_dfs_li: - return - - self.mainWin._measurements_kernel._concat_and_save_acdc_df( - acdc_dfs_li, keys, posData, self.mainWin.save_metrics, - saveDataWorker=self, - last_cca_frame_i=self.mainWin.save_cca_until_frame_i - ) - - def saveSegmData(self, posData, end_i, saved_segm_data): - self.progress.emit(f'Saving segmentation data for {posData.relPath}...') - - - # extend saved_segm_data if needed - if posData.SizeT > 1: - missing_frames_number = end_i + 1 - len(saved_segm_data) - if missing_frames_number > 0: - saved_segm_data = np.concatenate( - ( - saved_segm_data, - np.zeros( - (missing_frames_number, *saved_segm_data.shape[1:]), - dtype=saved_segm_data.dtype - ) - ), - ) - - - for frame_i, data_dict in enumerate(posData.allData_li[:end_i+1]): - if self.saveWin.aborted: - self.finished.emit() - return - - # Build saved_segm_data - lab = data_dict['labels'] - if lab is None: - break - - posData.lab = lab - - if posData.SizeT > 1: - saved_segm_data[frame_i] = lab - else: - saved_segm_data = lab - if 'manualBackgroundLab' in data_dict: - manualBackgrData = data_dict['manualBackgroundLab'] - posData.saveManualBackgroundData(manualBackgrData) - - # Save segmentation file - io.savez_compressed( - posData.segm_npz_path, np.squeeze(saved_segm_data) - ) - posData.segm_data = saved_segm_data - # Allow single 2D/3D image - if posData.SizeT == 1: - posData.segm_data = posData.segm_data[np.newaxis] - - try: - os.remove(posData.segm_npz_temp_path) - except Exception as e: - pass - - @worker_exception_handler - def run(self): - posToSave = self.mainWin.posToSave - if posToSave is None: - numPosToSave = 1 - else: - numPosToSave = len(posToSave) - save_metrics = self.mainWin.save_metrics - if self.isQuickSave: - save_metrics = False - self.time_last_pbar_update = time.perf_counter() - mode = self.mode - for p, posData in enumerate(self.mainWin.data): - if self.saveWin.aborted: - self.finished.emit() - return - - if posToSave is not None: - if posData.pos_foldername not in posToSave: - self.progress.emit(f'Skipping {posData.relPath}') - continue - - last_tracked_i_path = posData.last_tracked_i_path - end_i = self.mainWin.save_until_frame_i - self.saveSegmData(posData, end_i, posData.segm_data) - - posData.saveCustomAnnotationParams() - current_frame_i = posData.frame_i - - posData.saveTrackedLostCentroids() - - if not self.mainWin.isSnapshot: - last_tracked_i = self.mainWin.last_tracked_i - if last_tracked_i is None: - self.mainWin.saveWin.aborted = True - self.finished.emit() - return - elif self.mainWin.isSnapshot: - last_tracked_i = 0 - - if p == 0: - self.progressBar.emit(0, numPosToSave*(last_tracked_i+1), 0) - - acdc_output_csv_path = posData.acdc_output_csv_path - delROIs_info_path = posData.delROIs_info_path - - # Add segmented channel data for calc metrics if requested - add_user_channel_data = True - for chName in self.mainWin._measurements_kernel.chNamesToSkip: - skipUserChannel = ( - posData.filename.endswith(chName) - or posData.filename.endswith(f'{chName}_aligned') - ) - if skipUserChannel: - add_user_channel_data = False - - if add_user_channel_data and not self.isQuickSave: - posData.fluo_data_dict[posData.filename] = posData.img_data - - if not self.isQuickSave: - posData.fluo_bkgrData_dict[posData.filename] = posData.bkgrData - - posData.setLoadedChannelNames() - - if not self.isQuickSave: - self.mainWin.initMetricsToSave(posData) - self.mainWin._measurements_kernel.run( - posData=posData, - stop_frame_n=end_i+1, - saveDataWorker=self, - save_metrics=self.mainWin.save_metrics, - last_cca_frame_i=self.mainWin.save_cca_until_frame_i - ) - else: - self.saveAcdcDf(posData, end_i) - - self.progress.emit(f'Saving {posData.relPath}') - - if not self.do_not_save_og_whitelist: - og_save_path = os.path.join( - posData.images_path, self.append_name_og_whitelist - ) - posData.whitelist.saveOGLabs(og_save_path) - - if posData.whitelist: - whitelistIDs_path = posData.segm_npz_path.replace( - '.npz', '_whitelistIDs.json' - ) - new_centroids_path = posData.segm_npz_path.replace( - '.npz', '_new_centroids.json' - ) - posData.whitelist.save( - whitelistIDs_path, new_centroids_path=new_centroids_path - ) - - if posData.segmInfo_df is not None: - try: - posData.segmInfo_df.to_csv(posData.segmInfo_df_csv_path) - except PermissionError: - err_msg = ( - 'The below file is open in another app ' - '(Excel maybe?).\n\n' - f'{posData.segmInfo_df_csv_path}\n\n' - 'Close file and then press "Ok".' - ) - self.mutex.lock() - self.criticalPermissionError.emit(err_msg) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - posData.segmInfo_df.to_csv(posData.segmInfo_df_csv_path) - - posData.fluo_data_dict.pop(posData.filename, None) - - if not self.isQuickSave: - posData.fluo_bkgrData_dict.pop(posData.filename) - - if posData.SizeT > 1: - self.progress.emit('Almost done...') - self.progressBar.emit(0, 0, 0) - - if self.isQuickSave: - # Go back to current frame - posData.frame_i = current_frame_i - self.mainWin.get_data() - continue - - with open(last_tracked_i_path, 'w+') as txt: - txt.write(str(end_i)) - - # Save combined metrics equations - posData.saveCombineMetrics() - self.mainWin.pointsLayerDataToDf(posData) - posData.saveClickEntryPointsDfs() - - posData.last_tracked_i = last_tracked_i - - # Go back to current frame - posData.frame_i = current_frame_i - self.mainWin.get_data() - - if mode == 'Segmentation and Tracking' or mode == 'Viewer': - self.progress.emit( - f'Saved data until frame number {end_i+1}' - ) - elif mode == 'Cell cycle analysis': - self.progress.emit( - 'Saved cell cycle annotations until frame ' - f'number {self.mainWin.last_cca_frame_i+1}' - ) - # self.progressBar.emit(1) - if self.mainWin.isSnapshot: - self.progress.emit(f'Saved all {p+1} Positions!') - - self.finished.emit() - -class relabelSequentialWorker(QObject): - finished = Signal() - critical = Signal(object) - progress = Signal(str) - sigRemoveItemsGUI = Signal(int) - debug = Signal(object) - - def __init__(self, mainWin, posFoldernames): - QObject.__init__(self) - self.mainWin = mainWin - self.data = mainWin.data - self.posFoldernames = posFoldernames - self.mutex = QMutex() - self.waitCond = QWaitCondition() - - def progressNewIDs(self, oldIDs, newIDs): - li = list(zip(oldIDs, newIDs)) - s = '\n'.join([str(pair).replace(',', ' -->') for pair in li]) - s = f'IDs relabelled as follows:\n{s}' - self.progress.emit(s) - - @worker_exception_handler - def run(self): - self.mutex.lock() - - self.progress.emit('Relabelling process started...') - mainWin = self.mainWin - - current_pos_i = mainWin.pos_i - - for p, posData in enumerate(self.data): - if posData.pos_foldername not in self.posFoldernames: - continue - - mainWin.pos_i = p - current_lab = mainWin.get_2Dlab(posData.lab).copy() - current_frame_i = posData.frame_i - segm_data = [] - for frame_i, data_dict in enumerate(posData.allData_li): - lab = data_dict['labels'] - if lab is None: - break - segm_data.append(lab) - # if frame_i == current_frame_i: - # break - - if not segm_data: - segm_data = np.array([current_lab]) - - segm_data = np.array(segm_data) - segm_data, oldIDs, newIDs = core.relabel_sequential( - segm_data, is_timelapse=posData.SizeT>1 - ) - self.progressNewIDs(oldIDs, newIDs) - self.sigRemoveItemsGUI.emit(np.max(segm_data)) - - self.progress.emit( - 'Updating stored data and cell cycle annotations ' - '(if present)...' - ) - - mainWin.updateAnnotatedIDs(oldIDs, newIDs, logger=self.progress.emit) - mainWin.store_data(mainThread=False) - - for frame_i, lab in enumerate(segm_data): - posData.frame_i = frame_i - posData.lab = lab - mainWin.get_cca_df() - if posData.cca_df is not None: - mainWin.update_cca_df_relabelling( - posData, oldIDs, newIDs - ) - mainWin.update_rp(draw=False) - mainWin.store_data(mainThread=False) - - # Go back to current frame - mainWin.pos_i = current_pos_i - posData = self.data[mainWin.pos_i] - posData.frame_i = current_frame_i - mainWin.get_data() - - self.mutex.unlock() - self.finished.emit() - -class MagicPromptsWorker(QObject): - def __init__( - self, posData, image, df_points, model, model_segment_kwargs, - image_origin=(0, 0, 0), global_image=None - ): - QObject.__init__(self) - - self.signals = signals() - self.posData = posData - self.image = image - if global_image is not None: - self.global_image = global_image - else: - self.global_image = image - self.df_points = df_points - self.image_origin = image_origin - self.model = model - self.model_segment_kwargs = model_segment_kwargs - - @worker_exception_handler - def run(self): - from cellacdc.promptable_models import utils - - for row in self.df_points.itertuples(): - prompt_id = row.id - point = (row.z, row.y, row.x) - print(f'Adding point prompt {point} with id = {prompt_id}...') - parent_obj_id = row.Cell_ID if row.Cell_ID == prompt_id else 0 - self.model.add_prompt( - prompt=point, - prompt_id=prompt_id, - parent_obj_id=parent_obj_id, - image=self.image, - image_origin=self.image_origin, - prompt_type='point' - ) - - lab_out = self.model.segment( - self.global_image, - lab=self.posData.lab, - **self.model_segment_kwargs - ) - edited_IDs = self.df_points['Cell_ID'].unique() - - lab_new, lab_union, lab_interesection = ( - utils.insert_model_output_into_labels( - self.posData.lab, - lab_out, - edited_IDs=edited_IDs - ) - ) - - self.signals.finished.emit((lab_new, lab_union, lab_interesection)) - -class FillHolesInSegWorker(BaseWorkerUtil): - sigAskAppendName = Signal(str) - sigAborted = Signal() - sigSelectSegmFiles = Signal(str, list) - - - def __init__(self, mainWin): - super().__init__(mainWin) - - def emitSelectSegmFiles(self, exp_path, pos_foldernames): - self.mutex.lock() - self.sigSelectSegmFiles.emit(exp_path, pos_foldernames) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.abort - - def emitAskAppendName(self, basename): - self.mutex.lock() - self.sigAskAppendName.emit(basename) - self.waitCond.wait(self.mutex) - self.mutex.unlock() - return self.abort - - @worker_exception_handler - def run(self): - expPaths = self.mainWin.expPaths - lab_paths_dict = dict() - unique_segm_files = set() - tot_segm_files = 0 - for exp_path, pos_foldernames in expPaths.items(): - abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) - if abort: - self.sigAborted.emit() - return - for pos_folder in pos_foldernames: - imgs_path = os.path.join(exp_path, - pos_folder, - "Images") - lab_paths_dict[imgs_path] = self.endFilenameSegmTemp - tot_segm_files += len(self.endFilenameSegmTemp) - unique_segm_files.update(self.endFilenameSegmTemp) - - self.logger.info('Filling holes in segmentation masks...') - abort = self.emitAskAppendName("/".join(unique_segm_files)) - if abort: - self.sigAborted.emit() - return - self.signals.initProgressBar.emit(tot_segm_files) - for images_path, segm_file_names in lab_paths_dict.items(): - for segm_file_name in segm_file_names: - segm_data, segm_data_path = load.load_segm_file( - images_path, end_name_segm_file=segm_file_name, return_path=True - ) - segm_data_shape = segm_data.shape - segm_data_ndim = len(segm_data_shape) - if segm_data_ndim == 2: - segm_data = segm_data[np.newaxis, np.newaxis, ...] - elif segm_data_ndim == 3: - segm_data = segm_data[np.newaxis, ...] - elif segm_data_ndim == 4: - segm_data = segm_data - else: - raise NotImplementedError( - "This ndim is not supported!" - ) - for i, stack in enumerate(segm_data): - for j, lab in enumerate(stack): - segm_data[i, j] = core.fill_holes_in_segmentation(lab) - - segm_data_save_path = (segm_data_path - .replace(segm_file_name, - f"{segm_file_name}{self.appendedName}")) - io.savez_compressed(segm_data_save_path, segm_data) - self.signals.progressBar.emit(1) - self.signals.finished.emit(self) - -class GenerateMotherBudTotalTableWorker(BaseWorkerUtil): - def __init__( - self, parentWin, input_csv_filepath, selected_options, - out_csv_filepath - ): - super().__init__(parentWin) - self.input_csv_filepath = input_csv_filepath - self.selected_options = selected_options - self.out_csv_filepath = out_csv_filepath - - @worker_exception_handler - def run(self): - self.logger.log(f'Loading table "{self.input_csv_filepath}"...') - self.signals.initProgressBar.emit(0) - - input_df = pd.read_csv(self.input_csv_filepath) - - self.logger.log('Generating output table...') - out_df = cca_functions.generate_mother_bud_total_df( - input_df, **self.selected_options - ) - - self.logger.log(f'Saving output table to "{self.out_csv_filepath}"...') - - out_df.to_csv(self.out_csv_filepath) - - self.signals.finished.emit(self) - -class CountObjectsInSegm(BaseWorkerUtil): - sigAskAppendName = Signal(str, list) - sigAborted = Signal() - - def __init__(self, mainWin): - super().__init__(mainWin) - - @worker_exception_handler - def run(self): - debugging = False - expPaths = self.mainWin.expPaths - tot_exp = len(expPaths) - self.signals.initProgressBar.emit(0) - for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): - self.errors = {} - tot_pos = len(pos_foldernames) - - self.mainWin.infoText = f'Select segmentation file to count' - abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) - if abort: - self.sigAborted.emit() - return - - self.signals.initProgressBar.emit(len(pos_foldernames)) - for p, pos in enumerate(pos_foldernames): - if self.abort: - self.sigAborted.emit() - return - - self.logger.log( - f'Processing experiment n. {i+1}/{tot_exp}, ' - f'{pos} ({p+1}/{tot_pos})' - ) - - images_path = os.path.join(exp_path, pos, 'Images') - endFilenameSegm = self.mainWin.endFilenameSegm - ls = myutils.listdir(images_path) - file_path = [ - os.path.join(images_path, f) for f in ls - if f.endswith(f'{endFilenameSegm}.npz') - ][0] - - posData = load.loadData(file_path, '') - - self.signals.sigUpdatePbarDesc.emit( - f'Processing {posData.pos_path}') - - posData.getBasenameAndChNames() - posData.buildPaths() - - posData.loadOtherFiles( - load_segm_data=True, - load_acdc_df=False, - load_metadata=True, - end_filename_segm=endFilenameSegm - ) - if posData.segm_data.ndim == 3: - posData.segm_data = posData.segm_data[np.newaxis] - - self.logger.log('Counting objects...') - - countMapper = posData.countObjectsInSegm() - countMapper.pop('In current frame', None) - df_count_endname = posData.saveObjCounts(countMapper) - - self.logger.log( - 'Saved object counts table to file ending with: ' - f'"{df_count_endname}"' - ) - - self.signals.progressBar.emit(1) - - self.signals.finished.emit(self) diff --git a/cellacdc/workers/__init__.py b/cellacdc/workers/__init__.py new file mode 100644 index 000000000..aec650a87 --- /dev/null +++ b/cellacdc/workers/__init__.py @@ -0,0 +1,149 @@ +"""Background Qt workers.""" + +from ._base import ( + BaseWorkerUtil, + SimpleWorker, + signals, + workerLogger, + worker_exception_handler, +) + +from .alignment import ( + AlignDataWorker, + AlignWorker, +) + +from .data_prep import ( + CombineChannelsWorkerGUI, + CombineChannelsWorkerUtil, + CustomPreprocessWorkerGUI, + CustomPreprocessWorkerUtil, + DataPrepCropWorker, + DataPrepSaveBkgrDataWorker, + FucciPreprocessWorker, + ImagesToPositionsWorker, + RestructMultiPosWorker, + RestructMultiTimepointsWorker, + SaveCombinedChannelsWorker, + SaveProcessedDataWorker, + reapplyDataPrepWorker, +) + +from .gui import ( + AutoPilotWorker, + FindNextNewIdWorker, +) + +from .io import ( + AutoSaveWorker, + LazyLoader, + MigrateUserProfileWorker, + MoveTempFilesWorker, + StoreGuiStateWorker, + loadDataWorker, + relabelSequentialWorker, + saveDataWorker, +) + +from .metrics import ( + CcaIntegrityCheckerWorker, + ComputeMetricsMultiChannelWorker, + ComputeMetricsWorker, + ConcatAcdcDfsWorker, + ConcatSpotmaxDfsWorker, + CountObjectsInSegm, + GenerateMotherBudTotalTableWorker, +) + +from .segm import ( + CreateConnected3Dsegm, + DelObjectsOutsideSegmROIWorker, + FillHolesInSegWorker, + LabelRoiWorker, + MagicPromptsWorker, + PostProcessSegmWorker, + SegForLostIDsWorker, + segmVideoWorker, + segmWorker, +) + +from .tracking import ( + ApplyTrackInfoWorker, + CopyAllLostObjectsWorker, + ToSymDivWorker, + TrackSubCellObjectsWorker, + trackingWorker, +) + +from .util import ( + ApplyImageFilterWorker, + FilterObjsFromCoordsTable, + FromImajeJroiToSegmNpzWorker, + ResizeUtilWorker, + ScreenRecorderWorker, + Stack2DsegmTo3Dsegm, + ToImajeJroiWorker, + ToObjCoordsWorker, +) + +__all__ = [ + "BaseWorkerUtil", + "SimpleWorker", + "signals", + "workerLogger", + "worker_exception_handler", + "AlignDataWorker", + "AlignWorker", + "CombineChannelsWorkerGUI", + "CombineChannelsWorkerUtil", + "CustomPreprocessWorkerGUI", + "CustomPreprocessWorkerUtil", + "DataPrepCropWorker", + "DataPrepSaveBkgrDataWorker", + "FucciPreprocessWorker", + "ImagesToPositionsWorker", + "RestructMultiPosWorker", + "RestructMultiTimepointsWorker", + "SaveCombinedChannelsWorker", + "SaveProcessedDataWorker", + "reapplyDataPrepWorker", + "AutoPilotWorker", + "FindNextNewIdWorker", + "AutoSaveWorker", + "LazyLoader", + "MigrateUserProfileWorker", + "MoveTempFilesWorker", + "StoreGuiStateWorker", + "loadDataWorker", + "relabelSequentialWorker", + "saveDataWorker", + "CcaIntegrityCheckerWorker", + "ComputeMetricsMultiChannelWorker", + "ComputeMetricsWorker", + "ConcatAcdcDfsWorker", + "ConcatSpotmaxDfsWorker", + "CountObjectsInSegm", + "GenerateMotherBudTotalTableWorker", + "CreateConnected3Dsegm", + "DelObjectsOutsideSegmROIWorker", + "FillHolesInSegWorker", + "LabelRoiWorker", + "MagicPromptsWorker", + "PostProcessSegmWorker", + "SegForLostIDsWorker", + "segmVideoWorker", + "segmWorker", + "ApplyTrackInfoWorker", + "CopyAllLostObjectsWorker", + "ToSymDivWorker", + "TrackSubCellObjectsWorker", + "trackingWorker", + "ApplyImageFilterWorker", + "FilterObjsFromCoordsTable", + "FromImajeJroiToSegmNpzWorker", + "ResizeUtilWorker", + "ScreenRecorderWorker", + "Stack2DsegmTo3Dsegm", + "ToImajeJroiWorker", + "ToObjCoordsWorker", +] diff --git a/cellacdc/workers/_base.py b/cellacdc/workers/_base.py new file mode 100644 index 000000000..532c54ecb --- /dev/null +++ b/cellacdc/workers/_base.py @@ -0,0 +1,239 @@ +"""Background Qt workers: _base.""" + +import re +import os +import shutil +import time +import json +import concurrent.futures +from functools import partial +from collections import defaultdict, deque +import itertools + +from typing import Union, List, Dict, Callable, Any, Tuple, Iterable + +from functools import wraps +import numpy as np +import pandas as pd +import h5py +import traceback + +import skimage.io +import skimage.measure +import skimage.exposure + +import queue + +from tqdm import tqdm + +from qtpy.QtCore import Signal, QObject, QMutex, QWaitCondition + +from cellacdc import html_utils + +from .. import load, utils, core, prompts, printl, config, segm_re_pattern, io +from .. import transformation, measurements, cca_functions +from ..path import copy_or_move_tree +from .. import features, plot +from .. import core +from .. import cca_df_colnames, lineage_tree_cols, default_annot_df +from .. import cca_df_colnames_with_tree +from .. import cli +from ..tools import resize +from .. import segm_utils + +DEBUG = False + +def worker_exception_handler(func): + @wraps(func) + def run(self): + try: + func(self) + except Exception as error: + printl(traceback.format_exc()) + try: + self.dataQ.clear() + except Exception as err: + pass + + # Some workers have both self.critical and self.signals.critical + # errors but only one of them is connected --> emit both just + # in case + try: + self.critical.emit((self, error)) + except Exception as err: + self.signals.critical.emit((self, error)) + + try: + self.signals.critical.emit((self, error)) + except Exception as err: + self.critical.emit((self, error)) + + try: + self.mutex.unlock() + except Exception as err: + pass + + return run + + +class workerLogger: + def __init__(self, sigProcess): + self.sigProcess = sigProcess + + def log(self, message, level="INFO"): + try: + self.sigProcess.emit(str(message), level) + except Exception as err: + print(message, level) + try: + traceback_format = traceback.format_exc() + print(traceback_format) + except Exception as err: + pass + printl(err) + finally: + pass + + def info(self, message): + self.log(message, level="INFO") + + def warning(self, message): + self.log(message, level="WARNING") + + def exception(self, message): + self.log(message, level="EXCEPTION") + + +class signals(QObject): + progress = Signal(str, object) + finished = Signal(object) + initProgressBar = Signal(int) + progressBar = Signal(int) + critical = Signal(object) + dataIntegrityWarning = Signal(str) + dataIntegrityCritical = Signal() + sigLoadingFinished = Signal() + sigLoadingNewChunk = Signal(object) + resetInnerPbar = Signal(int) + progress_tqdm = Signal(int) + signal_close_tqdm = Signal() + create_tqdm = Signal(int) + innerProgressBar = Signal(int) + sigPermissionError = Signal(str, object) + sigSelectSegmFiles = Signal(object, object) + sigSelectAcdcOutputFiles = Signal(object, object, str, bool, bool) + sigSelectSpotmaxRun = Signal(object, object, object, str, bool, bool) + sigSetMeasurements = Signal(object) + sigInitAddMetrics = Signal(object, object) + sigUpdatePbarDesc = Signal(str) + sigComputeVolume = Signal(int, object) + sigAskStopFrame = Signal(object) + sigWarnMismatchSegmDataShape = Signal(object) + sigErrorsReport = Signal(dict, dict, dict) + sigMissingAcdcAnnot = Signal(dict) + sigRecovery = Signal(object) + sigInitInnerPbar = Signal(int) + sigUpdateInnerPbar = Signal(int) + sigSelectFile = Signal(str, str, str) + sigAskCopyCca = Signal(str) + sigSelectFilesWithText = Signal(str, object, str, object) + sigAskRunNow = Signal(object) + + +class BaseWorkerUtil(QObject): + progressBar = Signal(int, int, float) + + def __init__(self, mainWin): + QObject.__init__(self) + self.signals = signals() + self.abort = False + self.skipExp = False + self.logger = workerLogger(self.signals.progress) + self.mutex = QMutex() + self.waitCond = QWaitCondition() + self.mainWin = mainWin + + def emitSelectSegmFiles(self, exp_path, pos_foldernames): + self.mutex.lock() + self.signals.sigSelectSegmFiles.emit(exp_path, pos_foldernames) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.abort + + def emitSelectFilesWithText(self, exp_path, pos_foldernames, with_text, ext=None): + self.mutex.lock() + self.signals.sigSelectFilesWithText.emit( + exp_path, pos_foldernames, with_text, ext + ) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.abort + + def emitSelectFile(self, start_dir, caption="", filters="All files (*.)"): + self.mutex.lock() + self.signals.sigSelectFile.emit(start_dir, caption, filters) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.abort + + def emitSelectAcdcOutputFiles( + self, + exp_path, + pos_foldernames, + infoText="", + allowSingleSelection=False, + multiSelection=True, + ): + self.mutex.lock() + self.signals.sigSelectAcdcOutputFiles.emit( + exp_path, pos_foldernames, infoText, allowSingleSelection, multiSelection + ) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.abort + + def emitSelectSpotmaxRun( + self, + exp_path, + pos_foldernames, + all_runs, + infoText="", + allowSingleSelection=True, + multiSelection=True, + ): + self.mutex.lock() + self.signals.sigSelectSpotmaxRun.emit( + exp_path, + pos_foldernames, + all_runs, + infoText, + allowSingleSelection, + multiSelection, + ) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.abort + + +class SimpleWorker(QObject): + def __init__(self, posData, func, func_args=None, func_kwargs=None): + QObject.__init__(self) + self.posData = posData + self.signals = signals() + self.output = {} + + if func_args is None: + func_args = [] + + if func_kwargs is None: + func_kwargs = {} + + self.func = func + self.func_args = func_args + self.func_kwargs = func_kwargs + self.posData = posData + + @worker_exception_handler + def run(self): + self.result = self.func(self.posData, *self.func_args, **self.func_kwargs) + self.signals.finished.emit(self.output) diff --git a/cellacdc/workers/alignment.py b/cellacdc/workers/alignment.py new file mode 100644 index 000000000..3138f0949 --- /dev/null +++ b/cellacdc/workers/alignment.py @@ -0,0 +1,476 @@ +"""Background Qt workers: alignment.""" + +import re +import os +import shutil +import time +import json +import concurrent.futures +from functools import partial +from collections import defaultdict, deque +import itertools + +from typing import Union, List, Dict, Callable, Any, Tuple, Iterable + +from functools import wraps +import numpy as np +import pandas as pd +import h5py +import traceback + +import skimage.io +import skimage.measure +import skimage.exposure + +import queue + +from tqdm import tqdm + +from qtpy.QtCore import Signal, QObject, QMutex, QWaitCondition + +from cellacdc import html_utils + +from .. import load, utils, core, prompts, printl, config, segm_re_pattern, io +from .. import transformation, measurements, cca_functions +from ..path import copy_or_move_tree +from .. import features, plot +from .. import core +from .. import cca_df_colnames, lineage_tree_cols, default_annot_df +from .. import cca_df_colnames_with_tree +from .. import cli +from ..tools import resize +from .. import segm_utils + +DEBUG = False + +from ._base import ( + BaseWorkerUtil, +) + +class AlignDataWorker(QObject): + sigWarnTifAligned = Signal(object, object, object) + sigAskAlignSegmData = Signal() + + def __init__(self, posData, dataPrepWin, mutex, waitCond): + QObject.__init__(self) + self.signals = signals() + self.logger = workerLogger(self.signals.progress) + self.posData = posData + self.dataPrepWin = dataPrepWin + self.mutex = mutex + self.waitCond = waitCond + self.doNotAlignSegmData = False + self.doAbort = False + + def set_attr(self, align, user_ch_name): + self.align = align + self.user_ch_name = user_ch_name + + def pause(self): + self.mutex.lock() + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + def restart(self): + self.waitCond.wakeAll() + + def emitWarnTifAligned(self, numFramesWith0s, tif, posData): + self.sigWarnTifAligned.emit(numFramesWith0s, tif, posData) + self.pause() + + def emitSigAskAlignSegmData(self): + self.sigAskAlignSegmData.emit() + self.pause() + + def _align_data(self): + _zip = zip(self.posData.tif_paths, self.posData.npz_paths) + aligned = False + self.posData.all_npz_paths = [ + tif.replace(".tif", "_aligned.npz") for tif in self.posData.tif_paths + ] + for i, (tif, npz) in enumerate(_zip): + doAlign = npz is None or self.posData.loaded_shifts is None + + filename_tif = os.path.basename(tif) + user_ch_filename = f"{self.posData.basename}{self.user_ch_name}.tif" + + if not doAlign: + _npz = f"{os.path.splitext(tif)[0]}_aligned.npz" + if os.path.exists(_npz): + self.posData.all_npz_paths[i] = _npz + continue + + if filename_tif != user_ch_filename: + continue + + if not self.align: + continue + + # Align based on user_ch_name + aligned = True + self.logger.log(f"Aligning: {tif}") + + tif_data = load.imread(tif) + numFramesWith0s = self.dataPrepWin.detectTifAlignment( + tif_data, self.posData + ) + if self.align: + self.emitWarnTifAligned(numFramesWith0s, tif, self.posData) + if self.doAbort: + return + + # Alignment routine + if self.posData.SizeZ > 1: + align_func = core.align_frames_3D + df = self.posData.segmInfo_df.loc[self.posData.filename] + zz = df["z_slice_used_dataPrep"].to_list() + if not self.posData.filename.endswith("aligned") and self.align: + # Add aligned channel to segmInfo + df_aligned = self.posData.segmInfo_df.rename( + index={ + self.posData.filename: f"{self.posData.filename}_aligned" + } + ) + self.posData.segmInfo_df = pd.concat( + [self.posData.segmInfo_df, df_aligned] + ) + self.posData.segmInfo_df.to_csv(self.posData.segmInfo_df_csv_path) + else: + align_func = core.align_frames_2D + zz = None + + if self.align: + self.signals.initProgressBar.emit(len(tif_data)) + aligned_frames, shifts = align_func( + tif_data, + slices=zz, + user_shifts=self.posData.loaded_shifts, + sigPyqt=self.signals.progressBar, + ) + self.posData.loaded_shifts = shifts + else: + aligned_frames = tif_data + + if self.align: + self.signals.initProgressBar.emit(0) + _npz = f"{os.path.splitext(tif)[0]}_aligned.npz" + self.logger.log(f"Storing temporary file: {_npz}") + temp_npz = self.dataPrepWin.getTempfilePath(_npz) + io.savez_compressed(temp_npz, aligned_frames) + self.dataPrepWin.storeTempFileMove(temp_npz, _npz) + np.save(self.posData.align_shifts_path, self.posData.loaded_shifts) + self.posData.all_npz_paths[i] = _npz + + self.logger.log(f"Storing temporary file: {tif}") + temp_tif = self.dataPrepWin.getTempfilePath(tif) + utils.to_tiff(temp_tif, aligned_frames) + self.dataPrepWin.storeTempFileMove(temp_tif, tif) + self.posData.img_data = load.imread(temp_tif) + + _zip = zip(self.posData.tif_paths, self.posData.npz_paths) + for i, (tif, npz) in enumerate(_zip): + doAlign = npz is None or aligned + + if not doAlign: + continue + + if tif.endswith(f"{self.user_ch_name}.tif"): + continue + + if not self.align: + continue + + # Align the other channels + if self.posData.loaded_shifts is None: + break + + if self.align: + self.logger.log(f"Aligning: {tif}") + tif_data = load.imread(tif) + + # Alignment routine + if self.posData.SizeZ > 1: + align_func = core.align_frames_3D + df = self.posData.segmInfo_df.loc[self.posData.filename] + zz = df["z_slice_used_dataPrep"].to_list() + else: + align_func = core.align_frames_2D + zz = None + if self.align: + self.signals.initProgressBar.emit(len(tif_data)) + aligned_frames, shifts = align_func( + tif_data, + slices=zz, + user_shifts=self.posData.loaded_shifts, + sigPyqt=self.signals.progressBar, + ) + else: + aligned_frames = tif_data + + _npz = f"{os.path.splitext(tif)[0]}_aligned.npz" + + if self.align: + self.signals.initProgressBar.emit(0) + self.logger.log(f"Saving: {_npz}") + temp_npz = self.dataPrepWin.getTempfilePath(_npz) + io.savez_compressed(temp_npz, aligned_frames) + self.dataPrepWin.storeTempFileMove(temp_npz, _npz) + self.posData.all_npz_paths[i] = _npz + + self.logger.log(f"Saving: {tif}") + temp_tif = self.dataPrepWin.getTempfilePath(tif) + utils.to_tiff(temp_tif, aligned_frames) + self.dataPrepWin.storeTempFileMove(temp_tif, tif) + + if not aligned: + return + + if not self.posData.segmFound: + return + + # Align segmentation data accordingly + self.segmAligned = False + if self.posData.loaded_shifts is None or not self.align: + return + + self.emitSigAskAlignSegmData() + if self.doNotAlignSegmData: + return + + self.dataPrepWin.segmAligned = True + self.logger.log(f"Aligning: {self.posData.segm_npz_path}") + self.posData.segm_data, shifts = core.align_frames_2D( + self.posData.segm_data, slices=None, user_shifts=self.posData.loaded_shifts + ) + self.logger.log(f"Saving: {self.posData.segm_npz_path}") + temp_npz = self.dataPrepWin.getTempfilePath(self.posData.segm_npz_path) + io.savez_compressed(temp_npz, self.posData.segm_data) + self.dataPrepWin.storeTempFileMove(temp_npz, self.posData.segm_npz_path) + + @worker_exception_handler + def run(self): + self._align_data() + self.signals.finished.emit(self) + + +class AlignWorker(BaseWorkerUtil): + sigAborted = Signal() + sigAskUseSavedShifts = Signal(str, str) + sigAskSelectChannel = Signal(list) + + def __init__(self, mainWin): + super().__init__(mainWin) + + def emitAskUseSavedShifts(self, expPath, basename): + self.mutex.lock() + self.sigAskUseSavedShifts.emit(expPath, basename) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.abort + + def emitAskSelectChannel(self, channels): + self.mutex.lock() + self.sigAskSelectChannel.emit(channels) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.abort + + @worker_exception_handler + def run(self): + expPaths = self.mainWin.expPaths + tot_exp = len(expPaths) + self.signals.initProgressBar.emit(0) + for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): + self.errors = {} + tot_pos = len(pos_foldernames) + + shiftsFound = False + for pos in pos_foldernames: + images_path = os.path.join(exp_path, pos, "Images") + ls = utils.listdir(images_path) + for file in ls: + if file.endswith("align_shift.npy"): + shiftsFound = True + basename, chNames = utils.getBasenameAndChNames( + images_path, useExt=(".tif", ".h5") + ) + break + if shiftsFound: + break + + savedShiftsHow = None + if shiftsFound: + basename_ch0 = f"{basename}{chNames[0]}_" + abort = self.emitAskUseSavedShifts(exp_path, basename_ch0) + if abort: + self.sigAborted.emit() + return + + savedShiftsHow = self.savedShiftsHow + + self.signals.initProgressBar.emit(len(pos_foldernames)) + for p, pos in enumerate(pos_foldernames): + if self.abort: + self.sigAborted.emit() + return + + self.logger.log("*" * 40) + self.logger.log( + f"Processing experiment n. {i + 1}/{tot_exp}, " + f"{pos} ({p + 1}/{tot_pos})" + ) + + pos_path = os.path.join(exp_path, pos) + images_path = os.path.join(pos_path, "Images") + basename, chNames = utils.getBasenameAndChNames( + images_path, useExt=(".tif", ".h5") + ) + + self.signals.sigUpdatePbarDesc.emit(f"Loading {pos_path}...") + + if p == 0: + self.logger.log(f"Asking to select reference channel...") + abort = self.emitAskSelectChannel(chNames) + if abort: + self.sigAborted.emit() + return + chName = self.chName + + file_path = utils.getChannelFilePath(images_path, chName) + + # Load data + posData = load.loadData(file_path, chName) + posData.getBasenameAndChNames(useExt=(".tif", ".h5")) + posData.buildPaths() + posData.loadImgData() + + posData.loadOtherFiles( + load_segm_data=False, load_shifts=True, loadSegmInfo=True + ) + + if posData.img_data.ndim == 4: + align_func = core.align_frames_3D + if posData.segmInfo_df is None: + raise FileNotFoundError( + "To align 4D data you need to select which z-slice " + "you want to use for alignment. Please run the module " + "`1. Launch data prep module...` before aligning the " + "frames. (z-slice info MISSING from position " + f'"{posData.relPath}")' + ) + df = posData.segmInfo_df.loc[posData.filename] + zz = df["z_slice_used_dataPrep"].to_list() + elif posData.img_data.ndim == 3: + align_func = core.align_frames_2D + zz = None + + useSavedShifts = ( + savedShiftsHow == "use_saved_shifts" + and posData.loaded_shifts is not None + ) + if useSavedShifts: + user_shifts = posData.loaded_shifts + else: + user_shifts = None + + if savedShiftsHow == "rever_alignment": + if posData.loaded_shifts is None: + self.logger.log( + f'WARNING: Cannot revert alignment in "{posData.relPath}" ' + "since it is missing previously computed shifts. " + "Skipping this positon." + ) + continue + + # Revert alignment and save selected channel + for chName in chNames: + self.logger.log(f'Reverting alignment on "{chName}"...') + if chName == posData.user_ch_name: + data = posData.img_data + else: + file_path = utils.getChannelFilePath(images_path, chName) + data = load.load_image_file(file_path) + + self.signals.sigInitInnerPbar.emit(len(data) - 1) + revertedData = core.revert_alignment( + posData.loaded_shifts, + data, + sigPyqt=self.signals.sigUpdateInnerPbar, + ) + self.logger.log(f'Saving "{chName}"...') + self.signals.sigInitInnerPbar.emit(0) + self.saveAlignedData( + revertedData, + images_path, + posData.basename, + chName, + self.revertedAlignEndname, + ext=posData.ext, + ) + del revertedData, data + else: + for chName in chNames: + self.logger.log(f'Aligning "{chName}"...') + if chName == posData.user_ch_name: + data = posData.img_data + else: + file_path = utils.getChannelFilePath(images_path, chName) + data = load.load_image_file(file_path) + self.signals.sigInitInnerPbar.emit(len(data) - 1) + + alignedImgData, shifts = align_func( + data, + slices=zz, + user_shifts=user_shifts, + sigPyqt=self.signals.sigUpdateInnerPbar, + ) + self.logger.log(f'Saving "{chName}"...') + np.save(posData.align_shifts_path, shifts) + + self.signals.sigInitInnerPbar.emit(0) + self.saveAlignedData( + alignedImgData, + images_path, + posData.basename, + chName, + "", + ext=posData.non_aligned_ext, + ) + self.saveAlignedData( + alignedImgData, + images_path, + posData.basename, + chName, + "aligned", + ext=".npz", + ) + del alignedImgData, data + + self.signals.finished.emit(self) + + def saveAlignedData(self, data, imagesPath, basename, chName, endname, ext=".tif"): + if endname: + newFilename = f"{basename}{chName}_{endname}{ext}" + else: + newFilename = f"{basename}{chName}{ext}" + + filePath = os.path.join(imagesPath, newFilename) + + if ext == ".tif": + SizeT = data.shape[0] + SizeZ = 1 + if data.ndim == 4: + SizeZ = data.shape[1] + utils.to_tiff(filePath, data) + elif ext == ".npz": + io.savez_compressed(filePath, data) + elif ext == ".h5": + load.save_to_h5(filePath, data) + +# Sibling imports (deferred to avoid import cycles) +from ._base import ( + signals, + workerLogger, + worker_exception_handler, +) + diff --git a/cellacdc/workers/data_prep.py b/cellacdc/workers/data_prep.py new file mode 100644 index 000000000..828889aba --- /dev/null +++ b/cellacdc/workers/data_prep.py @@ -0,0 +1,1276 @@ +"""Background Qt workers: data_prep.""" + +import re +import os +import shutil +import time +import json +import concurrent.futures +from functools import partial +from collections import defaultdict, deque +import itertools + +from typing import Union, List, Dict, Callable, Any, Tuple, Iterable + +from functools import wraps +import numpy as np +import pandas as pd +import h5py +import traceback + +import skimage.io +import skimage.measure +import skimage.exposure + +import queue + +from tqdm import tqdm + +from qtpy.QtCore import Signal, QObject, QMutex, QWaitCondition + +from cellacdc import html_utils + +from .. import load, utils, core, prompts, printl, config, segm_re_pattern, io +from .. import transformation, measurements, cca_functions +from ..path import copy_or_move_tree +from .. import features, plot +from .. import core +from .. import cca_df_colnames, lineage_tree_cols, default_annot_df +from .. import cca_df_colnames_with_tree +from .. import cli +from ..tools import resize +from .. import segm_utils + +DEBUG = False + +from ._base import ( + BaseWorkerUtil, +) + +class reapplyDataPrepWorker(QObject): + finished = Signal() + debug = Signal(object) + critical = Signal(object) + progress = Signal(str) + initPbar = Signal(int) + updatePbar = Signal() + sigCriticalNoChannels = Signal(str) + sigSelectChannels = Signal(object, object, object, str) + + def __init__(self, expPath, posFoldernames): + super().__init__() + self.expPath = expPath + self.posFoldernames = posFoldernames + self.abort = False + self.mutex = QMutex() + self.waitCond = QWaitCondition() + + def raiseSegmInfoNotFound(self, path): + raise FileNotFoundError( + "The following file is required for the alignment of 4D data " + f'but it was not found: "{path}"' + ) + + def saveBkgrData(self, imageData, posData, isAligned=False): + bkgrROI_data = {} + for r, roi in enumerate(posData.bkgrROIs): + xl, yt = [int(round(c)) for c in roi.pos()] + w, h = [int(round(c)) for c in roi.size()] + if not yt + h > yt or not xl + w > xl: + # Prevent 0 height or 0 width roi + continue + is4D = posData.SizeT > 1 and posData.SizeZ > 1 + is3Dz = posData.SizeT == 1 and posData.SizeZ > 1 + is3Dt = posData.SizeT > 1 and posData.SizeZ == 1 + is2D = posData.SizeT == 1 and posData.SizeZ == 1 + if is4D: + bkgr_data = imageData[:, :, yt : yt + h, xl : xl + w] + elif is3Dz or is3Dt: + bkgr_data = imageData[:, yt : yt + h, xl : xl + w] + elif is2D: + bkgr_data = imageData[yt : yt + h, xl : xl + w] + bkgrROI_data[f"roi{r}_data"] = bkgr_data + + if not bkgrROI_data: + return + + if isAligned: + bkgr_data_fn = f"{posData.filename}_aligned_bkgrRoiData.npz" + else: + bkgr_data_fn = f"{posData.filename}_bkgrRoiData.npz" + bkgr_data_path = os.path.join(posData.images_path, bkgr_data_fn) + self.progress.emit("Saving background data to:") + self.progress.emit(bkgr_data_path) + io.savez_compressed(bkgr_data_path, **bkgrROI_data) + + def run(self): + ch_name_selector = prompts.select_channel_name( + which_channel="segm", allow_abort=False + ) + for p, pos in enumerate(self.posFoldernames): + if self.abort: + break + + self.progress.emit(f"Processing {pos}...") + + posPath = os.path.join(self.expPath, pos) + imagesPath = os.path.join(posPath, "Images") + + ls = utils.listdir(imagesPath) + if p == 0: + ch_names, basenameNotFound = ch_name_selector.get_available_channels( + ls, imagesPath + ) + if not ch_names: + self.sigCriticalNoChannels.emit(imagesPath) + break + self.mutex.lock() + if len(self.posFoldernames) == 1: + # User selected only one pos --> allow selecting and adding + # and external .tif file that will be renamed with the basename + basename = ch_name_selector.basename + else: + basename = None + self.sigSelectChannels.emit( + ch_name_selector, ch_names, imagesPath, basename + ) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + if self.abort: + break + + self.progress.emit(f"Selected channels: {self.selectedChannels}") + + for chName in self.selectedChannels: + filePath = load.get_filename_from_channel(imagesPath, chName) + posData = load.loadData(filePath, chName) + posData.getBasenameAndChNames() + posData.buildPaths() + posData.loadImgData() + posData.loadOtherFiles( + load_segm_data=False, + getTifPath=True, + load_metadata=True, + load_shifts=True, + load_dataPrep_ROIcoords=True, + loadBkgrROIs=True, + ) + + imageData = posData.img_data + + prepped = False + isAligned = False + # Align + if posData.loaded_shifts is not None: + self.progress.emit("Aligning frames...") + shifts = posData.loaded_shifts + if imageData.ndim == 4: + align_func = core.align_frames_3D + else: + align_func = core.align_frames_2D + imageData, _ = align_func(imageData, user_shifts=shifts) + prepped = True + isAligned = True + + # Crop and save background + if posData.dataPrep_ROIcoords is not None: + df = posData.dataPrep_ROIcoords + isCropped = int(df.at["cropped", "value"]) == 1 + if isCropped: + self.saveBkgrData(imageData, posData, isAligned) + self.progress.emit("Cropping...") + x0 = int(df.at["x_left", "value"]) + y0 = int(df.at["y_top", "value"]) + x1 = int(df.at["x_right", "value"]) + y1 = int(df.at["y_bottom", "value"]) + if imageData.ndim == 4: + imageData = imageData[:, :, y0:y1, x0:x1] + elif imageData.ndim == 3: + imageData = imageData[:, y0:y1, x0:x1] + elif imageData.ndim == 2: + imageData = imageData[y0:y1, x0:x1] + prepped = True + else: + filename = os.path.basename(posData.dataPrepBkgrROis_path) + self.progress.emit( + f'WARNING: the file "{filename}" was not found. ' + "I cannot crop the data." + ) + + if prepped: + self.progress.emit("Saving prepped data...") + io.savez_compressed(posData.align_npz_path, imageData) + if hasattr(posData, "tif_path"): + utils.to_tiff(posData.tif_path, imageData) + + self.updatePbar.emit() + if self.abort: + break + self.finished.emit() + + +class ImagesToPositionsWorker(QObject): + finished = Signal() + debug = Signal(object) + critical = Signal(object) + progress = Signal(str) + initPbar = Signal(int) + updatePbar = Signal() + + def __init__(self, folderPath, targetFolderPath, appendText): + super().__init__() + self.abort = False + self.folderPath = folderPath + self.targetFolderPath = targetFolderPath + self.appendText = appendText + + @worker_exception_handler + def run(self): + self.progress.emit(f'Selected folder: "{self.folderPath}"') + self.progress.emit(f'Target folder: "{self.targetFolderPath}"') + self.progress.emit(" ") + ls = utils.listdir(self.folderPath) + numFiles = len(ls) + self.initPbar.emit(numFiles) + numPosDigits = len(str(numFiles)) + if numPosDigits == 1: + numPosDigits = 2 + pos = 1 + for file in ls: + if self.abort: + break + + filePath = os.path.join(self.folderPath, file) + if os.path.isdir(filePath): + # Skip directories + self.updatePbar.emit() + continue + + self.progress.emit(f"Loading file: {file}") + filename, ext = os.path.splitext(file) + s0p = str(pos).zfill(numPosDigits) + try: + data = load.imread(filePath) + if data.ndim == 3 and (data.shape[-1] == 3 or data.shape[-1] == 4): + self.progress.emit("Converting RGB image to grayscale...") + data = skimage.color.rgb2gray(data) + data = skimage.img_as_ubyte(data) + + posName = f"Position_{pos}" + posPath = os.path.join(self.targetFolderPath, posName) + imagesPath = os.path.join(posPath, "Images") + if not os.path.exists(imagesPath): + os.makedirs(imagesPath, exist_ok=True) + newFilename = f"s{s0p}_{filename}_{self.appendText}.tif" + relPath = os.path.join(posName, "Images", newFilename) + tifFilePath = os.path.join(imagesPath, newFilename) + self.progress.emit(f"Saving to file: ...{os.sep}{relPath}") + utils.to_tiff(tifFilePath, data) + pos += 1 + except Exception as e: + self.progress.emit( + f"WARNING: {file} is not a valid image file. Skipping it." + ) + + self.progress.emit(" ") + self.updatePbar.emit() + + if self.abort: + break + self.finished.emit() + + +class DataPrepSaveBkgrDataWorker(QObject): + def __init__(self, posData, dataPrepWin): + QObject.__init__(self) + self.signals = signals() + self.logger = workerLogger(self.signals.progress) + self.posData = posData + self.dataPrepWin = dataPrepWin + + @worker_exception_handler + def run(self): + self.dataPrepWin.saveBkgrData(self.posData) + self.signals.finished.emit(self) + + +class DataPrepCropWorker(QObject): + def __init__(self, posData, dataPrepWin, dstPath): + QObject.__init__(self) + self.signals = signals() + self.logger = workerLogger(self.signals.progress) + self.posData = posData + self.dataPrepWin = dataPrepWin + self.dstPath = dstPath + + @worker_exception_handler + def run(self): + self.dataPrepWin.saveSingleCrop( + self.posData, self.posData.cropROIs[0], self.dstPath + ) + self.signals.finished.emit(self) + + +class RestructMultiPosWorker(BaseWorkerUtil): + sigSaveTiff = Signal(str, object, object) + + def __init__(self, rootFolderPath, dstFolderPath, action="copy"): + super().__init__(None) + self.rootFolderPath = rootFolderPath + self.dstFolderPath = dstFolderPath + self.mutex = QMutex() + self.waitCond = QWaitCondition() + self.action = action + + @worker_exception_handler + def run(self): + load._restructure_multi_files_multi_pos( + self.rootFolderPath, + self.dstFolderPath, + signals=self.signals, + logger=self.logger.log, + action=self.action, + ) + self.signals.finished.emit(self) + + +class RestructMultiTimepointsWorker(BaseWorkerUtil): + sigSaveTiff = Signal(str, object, object) + + def __init__( + self, + allChannels, + frame_name_pattern, + basename, + validFilenames, + rootFolderPath, + dstFolderPath, + segmFolderPath="", + ): + super().__init__(None) + self.allChannels = allChannels + self.frame_name_pattern = frame_name_pattern + self.basename = basename + self.validFilenames = validFilenames + self.rootFolderPath = rootFolderPath + self.dstFolderPath = dstFolderPath + self.segmFolderPath = segmFolderPath + self.mutex = QMutex() + self.waitCond = QWaitCondition() + + @worker_exception_handler + def run(self): + allChannels = self.allChannels + frame_name_pattern = self.frame_name_pattern + rootFolderPath = self.rootFolderPath + dstFolderPath = self.dstFolderPath + segmFolderPath = self.segmFolderPath + filesInfo = {} + self.signals.initProgressBar.emit(len(self.validFilenames) + 1) + for file in self.validFilenames: + try: + # Determine which channel is this file + for ch in allChannels: + m = re.findall(rf"(.*)_{ch}{frame_name_pattern}", file) + if m: + break + else: + raise FileNotFoundError( + f'The file name "{file}" does not contain any channel name' + ) + posName, _, frameName = m[0] + frameNumber = int(frameName) + if posName not in filesInfo: + filesInfo[posName] = {ch: [(file, frameNumber)]} + elif ch not in filesInfo[posName]: + filesInfo[posName][ch] = [(file, frameNumber)] + else: + filesInfo[posName][ch].append((file, frameNumber)) + except Exception as e: + self.logger.log(traceback.format_exc()) + self.logger.log( + f'WARNING: File "{file}" does not contain valid pattern. ' + "Skipping it." + ) + continue + + self.signals.progressBar.emit(1) + + df_metadata = None + partial_basename = self.basename + allPosDataInfo = [] + for p, (posName, channelInfo) in enumerate(filesInfo.items()): + self.logger.log(f"=" * 40) + self.logger.log(f'Processing position "{posName}"...') + + for _, filesList in channelInfo.items(): + # Get info from first file + filePath = os.path.join(rootFolderPath, filesList[0][0]) + try: + img = load.imread(filePath) + break + except Exception as e: + self.logger.log(traceback.format_exc()) + continue + else: + self.logger.log( + f"WARNING: No valid image files found for position {posName}" + ) + continue + + # Get basename + if partial_basename: + basename = f"{partial_basename}_{posName}_" + else: + basename = f"{posName}_" + + # Get SizeT from first file + SizeT = len(filesList) + + # Save metadata.csv + df_metadata = pd.DataFrame( + {"SizeT": SizeT, "basename": basename}, index=["values"] + ) + + # Iterate channels + for c, (channelName, filesList) in enumerate(channelInfo.items()): + self.logger.log(f' Processing channel "{channelName}"...') + # Sort by frame number + sortedFilesList = sorted(filesList, key=lambda t: t[1]) + + df_metadata[f"channel_{c}_name"] = [channelName] + + imagesPath = os.path.join(dstFolderPath, f"Position_{p + 1}", "Images") + if not os.path.exists(imagesPath): + os.makedirs(imagesPath, exist_ok=True) + + # Iterate frames + videoData = None + srcSegmPaths = [""] * SizeT + frameNumbers = [] + for frame_i, fileInfo in enumerate(sortedFilesList): + file, _ = fileInfo + ext = os.path.splitext(file)[1] + srcImgFilePath = os.path.join(rootFolderPath, file) + try: + img = load.imread(srcImgFilePath) + if videoData is None: + shape = (SizeT, *img.shape) + videoData = np.zeros(shape, dtype=img.dtype) + videoData[frame_i] = img + pattern = self.frame_name_pattern + frameNumberMatch = re.findall(pattern, file)[0][1] + frameNumber = int(frameNumberMatch) + frameNumbers.append(frameNumber) + except Exception as e: + self.logger.log(traceback.format_exc()) + continue + + if segmFolderPath and c == 0: + srcSegmFilePath = os.path.join(segmFolderPath, file) + srcSegmPaths[frame_i] = srcSegmFilePath + + SizeZ = 1 + if img.ndim == 3: + SizeZ = len(img) + + df_metadata["SizeZ"] = [SizeZ] + + self.signals.progressBar.emit(1) + + if videoData is None: + self.logger.log( + f"WARNING: No valid image files found for position " + f'"{posName}", channel "{channelName}"' + ) + continue + else: + imgFileName = f"{basename}{channelName}.tif" + dstImgFilePath = os.path.join(imagesPath, imgFileName) + dstSegmFileName = f"{basename}segm_{channelName}.npz" + dstSegmPath = os.path.join(imagesPath, dstSegmFileName) + imgDataInfo = { + "path": dstImgFilePath, + "SizeT": SizeT, + "SizeZ": SizeZ, + "data": videoData, + "frameNumbers": frameNumbers, + "dst_segm_path": dstSegmPath, + "src_segm_paths": srcSegmPaths, + } + allPosDataInfo.append(imgDataInfo) + + if df_metadata is not None: + metadata_csv_path = os.path.join(imagesPath, f"{basename}metadata.csv") + df_metadata = df_metadata.T + df_metadata.index.name = "Description" + df_metadata.to_csv(metadata_csv_path) + + self.logger.log(f"*" * 40) + + if not allPosDataInfo: + self.signals.finished.emit(self) + return + + self.signals.initProgressBar.emit(len(allPosDataInfo)) + self.logger.log("Saving image files...") + maxSizeT = max([d["SizeT"] for d in allPosDataInfo]) + minFrameNumber = min([d["frameNumbers"][0] for d in allPosDataInfo]) + # Pad missing frames in video files according to frame number + for p, imgDataInfo in enumerate(allPosDataInfo): + SizeT = imgDataInfo["SizeT"] + SizeZ = imgDataInfo["SizeZ"] + dstImgFilePath = imgDataInfo["path"] + videoData = imgDataInfo["data"] + frameNumbers = imgDataInfo["frameNumbers"] + paddedShape = (maxSizeT, *videoData.shape[1:]) + imgDataInfo["paddedShape"] = paddedShape + dtype = videoData.dtype + paddedVideoData = np.zeros(paddedShape, dtype=dtype) + for n, img in zip(frameNumbers, videoData): + frame_i = n - minFrameNumber + paddedVideoData[frame_i] = img + + del videoData + imgDataInfo["data"] = None + + self.mutex.lock() + self.sigSaveTiff.emit(dstImgFilePath, paddedVideoData, self.waitCond) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + self.signals.progressBar.emit(1) + + if not segmFolderPath: + self.signals.finished.emit(self) + return + + self.signals.initProgressBar.emit(len(allPosDataInfo)) + self.logger.log("Saving segmentation files...") + for p, imgDataInfo in enumerate(allPosDataInfo): + SizeT = imgDataInfo["SizeT"] + frameNumbers = imgDataInfo["frameNumbers"] + SizeT = imgDataInfo["SizeT"] + SizeZ = imgDataInfo["SizeZ"] + frameNumbers = imgDataInfo["frameNumbers"] + paddedShape = imgDataInfo["paddedShape"] + segmData = np.zeros(paddedShape, dtype=np.uint32) + for n, segmFilePath in zip(frameNumbers, imgDataInfo["src_segm_paths"]): + frame_i = n - minFrameNumber + try: + lab = load.imread(segmFilePath).astype(np.uint32) + segmData[frame_i] = lab + except Exception as e: + self.logger.log(traceback.format_exc()) + self.logger.log( + "WARNING: The following segmentation file does not " + f'exist, saving empty masks: "{srcSegmFilePath}"' + ) + + io.savez_compressed(imgDataInfo["dst_segm_path"], segmData) + del segmData + + self.signals.finished.emit(self) + + +class FucciPreprocessWorker(BaseWorkerUtil): + sigAskAppendName = Signal(str) + sigAskParams = Signal(object, object) + sigAborted = Signal() + + def __init__(self, mainWin): + super().__init__(mainWin) + + def emitAskParams(self, exp_path, pos_foldernames): + self.mutex.lock() + self.sigAskParams.emit(exp_path, pos_foldernames) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.abort + + def applyPipeline(self, first_ch_data, second_ch_data, filter_kwargs): + processed_data = np.zeros(first_ch_data.shape, dtype=np.uint8) + pbar = tqdm(total=len(processed_data), ncols=100) + with concurrent.futures.ThreadPoolExecutor() as executor: + iterable = enumerate(zip(first_ch_data, second_ch_data)) + func = partial(core.fucci_pipeline_executor_map, **filter_kwargs) + result = executor.map(func, iterable) + for frame_i, processed_img in result: + processed_img = skimage.exposure.rescale_intensity( + processed_img, out_range=(0, 255) + ) + processed_img = processed_img.astype(np.uint8) + processed_data[frame_i] = processed_img + pbar.update() + pbar.close() + + return processed_data + + @worker_exception_handler + def run(self): + debugging = False + expPaths = self.mainWin.expPaths + tot_exp = len(expPaths) + self.signals.initProgressBar.emit(0) + for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): + self.errors = {} + tot_pos = len(pos_foldernames) + + self.mainWin.infoText = f"Setup parameters" + + if i == 0: + abort = self.emitAskParams(exp_path, pos_foldernames) + if abort: + self.sigAborted.emit() + return + + # Ask appendend name + self.mutex.lock() + self.sigAskAppendName.emit(self.basename) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + if self.abort: + self.sigAborted.emit() + return + + appendedName = self.appendedName + self.signals.initProgressBar.emit(len(pos_foldernames)) + for p, pos in enumerate(pos_foldernames): + if self.abort: + self.sigAborted.emit() + return + + self.logger.log( + f"Processing experiment n. {i + 1}/{tot_exp}, " + f"{pos} ({p + 1}/{tot_pos})" + ) + + images_path = os.path.join(exp_path, pos, "Images") + + self.logger.log(f"Loading {self.firstChannelName} channel data...") + first_ch_filepath = load.get_filename_from_channel( + images_path, self.firstChannelName + ) + first_ch_data = load.load_image_file(first_ch_filepath) + + self.logger.log(f"Loading {self.secondChannelName} channel data...") + second_ch_filepath = load.get_filename_from_channel( + images_path, self.secondChannelName + ) + second_ch_data = load.load_image_file(second_ch_filepath) + + self.logger.log("Applying FUCCI pre-processing pipeline...\n") + processed_data = self.applyPipeline( + first_ch_data, second_ch_data, self.fucciFilterKwargs + ) + + basename, chNames = utils.getBasenameAndChNames(images_path) + _, ext = os.path.splitext(first_ch_filepath) + processed_filename = f"{basename}{appendedName}{ext}" + processed_filepath = os.path.join(images_path, processed_filename) + self.logger.log( + f'Saving pre-processed images to "{processed_filepath}"...' + ) + io.save_image_data(processed_filepath, processed_data) + + self.signals.progressBar.emit(1) + + self.signals.finished.emit(self) + + +class SaveProcessedDataWorker(QObject): + def __init__( + self, + allPosData: Iterable["load.loadData"], + appended_text_filename: str, + ext: str = None, + ): + QObject.__init__(self) + self.allPosData = allPosData + self.signals = signals() + self.logger = workerLogger(self.signals.progress) + self.appended_text_filename = appended_text_filename + self.ext = ext + + @worker_exception_handler + def run(self): + self.signals.initProgressBar.emit(0) + for posData in self.allPosData: + ext_loc = self.ext if self.ext is not None else posData.ext + processed_filename = ( + f"{posData.basename}{posData.user_ch_name}_" + f"{self.appended_text_filename}{ext_loc}" + ) + processed_filepath = os.path.join(posData.images_path, processed_filename) + self.logger.log(f"Saving {processed_filepath}...") + processed_data = posData.preprocessedDataArray() + if processed_data is None: + self.logger.log( + f"[WARNING]: {posData.pos_foldername} does not have " + "preprocessed data. Skipping it." + ) + continue + + io.save_image_data(processed_filepath, processed_data) + + self.signals.finished.emit(self) + + +class SaveCombinedChannelsWorker(QObject): + sigDebugShowImg = Signal(object) + + def __init__( + self, allPosData: Iterable["load.loadData"], filename: str, debug: bool = False + ): + QObject.__init__(self) + self.allPosData = allPosData + self.signals = signals() + self.logger = workerLogger(self.signals.progress) + self.filename = filename + self.debug = debug + + @worker_exception_handler + def run(self): + self.signals.initProgressBar.emit(0) + for posData in self.allPosData: + processed_filepath = os.path.join(posData.images_path, self.filename) + self.logger.log(f"Saving {processed_filepath}...") + processed_data = posData.combinedChannelsDataArray() + if processed_data is None: + self.logger.log( + f"[WARNING]: {posData.pos_foldername} does not have " + "combined channels data. Skipping it." + ) + continue + if self.debug: + printl(processed_data.shape) + printl(processed_data.dtype) + printl(processed_data.min()) + printl(processed_data.max()) + printl(processed_filepath) + self.sigDebugShowImg.emit(processed_data) + # cellacdc.plot.imshow(processed_data) + io.save_image_data(processed_filepath, processed_data) + + self.signals.finished.emit(self) + + +class CustomPreprocessWorkerGUI(QObject): + sigDone = Signal(object, str) + sigPreviewDone = Signal(object, tuple) + sigIsQueueEmpty = Signal(bool) + + def __init__(self, mutex, waitCond): + QObject.__init__(self) + self.signals = signals() + self.mutex = mutex + self.waitCond = waitCond + self.logger = workerLogger(self.signals.progress) + self.dataQ = deque(maxlen=2) + self.exit = False + self.wait = True + self._abort = False + + def enqueue( + self, + func: Callable, + image: np.ndarray, + recipe: Dict[str, Any], + key: Tuple[int, int, Union[int, str]], + ): + self.dataQ.append((func, image, recipe, key)) + if len(self.dataQ) == 1: + self.sigIsQueueEmpty.emit(False) + # Wake up worker upon inserting first element + self.wakeUp() + + def wakeUp(self): + self.wait = False + self.waitCond.wakeAll() + + def pause(self): + self.wait = True + self.mutex.lock() + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + def abort(self): + self._abort = True + + def stop(self): + self.abort() + self.exit = True + self.waitCond.wakeAll() + self.signals.finished.emit(self) + + def setupJob( + self, func: Callable, image_data: np.ndarray, recipe: Dict[str, Any], how: str + ): + self._func = func + self._image_data = image_data + self._recipe = recipe + self._how = how + + def runJob(self, image=None, recipe=None): + if image is None: + image = self._image_data.copy() + if recipe is None: + recipe = self._recipe + + return self.applyRecipe(self._func, image, recipe) + + def applyRecipe( + self, func: Callable, image: np.ndarray, recipe: List[Dict[str, Any]] + ): + preprocessed_data = func(image, recipe) + + keep_input_data_type = recipe[0].get("keep_input_data_type", True) + if not keep_input_data_type: + return preprocessed_data + + try: + preprocessed_data = utils.convert_to_dtype(preprocessed_data, image.dtype) + except Exception as err: + preprocessed_data = preprocessed_data.astype(image.dtype) + return preprocessed_data + + @worker_exception_handler + def run(self): + while True: + if self.exit: + self.logger.log("Closing pre-processing worker...") + break + elif self.wait: + self.logger.log("Pre-processing worker paused.") + self.pause() + elif len(self.dataQ) > 0: + func, image, recipe, key = self.dataQ.pop() + processed_data = self.applyRecipe(func, image, recipe) + self.sigPreviewDone.emit(processed_data, key) + if len(self.dataQ) == 0: + self.wait = True + self.sigIsQueueEmpty.emit(True) + else: + self.logger.log("Pre-processing worker resumed.") + processed_data = self.runJob() + self.sigDone.emit(processed_data, self._how) + self.wait = True + + self.signals.finished.emit(self) + + +class CombineChannelsWorkerGUI(CustomPreprocessWorkerGUI): + sigDone = Signal(object, list) + sigPreviewDone = Signal(object, list) + sigAskLoadChannels = Signal(set, object) + + def __init__( + self, + mutex, + waitCond, + logger_func: Callable, + ): + # signals_parent=None): + super().__init__(mutex, waitCond) + + self.waitCondLoadFluoChannels = QWaitCondition() + self.logger_func = logger_func + + # if not signals_parent: + # signals_parent = signals() + + # self.signals = signals_parent + + def enqueue( + self, + data, + steps: Dict[str, Any], + key: Tuple[int, int, Union[int, str]], + keep_input_data_type: bool, + output_as_segm: bool, + formula: str, + ): + self.dataQ.append( + (data, steps, key, keep_input_data_type, output_as_segm, formula) + ) + if len(self.dataQ) == 1: + self.sigIsQueueEmpty.emit(False) + # Wake up worker upon inserting first element + self.wakeUp() + + def setupJob( + self, + data: Dict[str, np.ndarray], + steps: Dict[str, Any], + keep_input_data_type: bool, + key: Tuple[Union[int, None], Union[int, None], Union[int, None]], + output_as_segm: bool, + formula: str, + ): + self._key = key + self._steps = steps + self._data = data + self._keep_input_data_type = keep_input_data_type + self._output_as_segm = output_as_segm + self._formula = formula + + def runJob( + self, + data=None, + steps=None, + keep_input_data_type=None, + key=None, + output_as_segm=None, + formula=None, + ): + if data is None: + data = self._data + if steps is None: + steps = self._steps + if keep_input_data_type is None: + keep_input_data_type = self._keep_input_data_type + if key is None: + key = self._key + if output_as_segm is None: + output_as_segm = self._output_as_segm + if formula is None: + formula = self._formula + + if not steps and formula is None: + return + + return self.applySteps( + data, steps, keep_input_data_type, key, output_as_segm, formula=formula + ) + + def applySteps( + self, + data: Dict[str, np.ndarray], + steps: List[Dict[str, Any]], + keep_input_data_type: bool, + key: Tuple[Union[int, None], Union[int, None], Union[int, None]], + output_as_segm: bool, + formula: str, + ): + + new_keys = [] + key = list(key) + if key[0] is None: + pos_number = len(data) + key[0] = list(range(pos_number)) + else: + key[0] = [key[0]] + + for pos_i in key[0]: + new_keys_per_pos = [[pos_i]] + if key[1] is None: + frames = data[pos_i].SizeT + new_keys_per_pos.append(list(range(frames))) + else: + new_keys_per_pos.append([key[1]]) + + if key[2] is None: + z_slices = data[pos_i].SizeZ + if not z_slices: + z_slices = 1 + new_keys_per_pos.append(list(range(z_slices))) + else: + new_keys_per_pos.append([key[2]]) + + new_keys_per_pos = list(itertools.product(*new_keys_per_pos)) + new_keys.extend(new_keys_per_pos) + + output_imgs, out_keys = core.combine_channels_multithread_return_imgs( + steps=steps, + data=data, + keep_input_data_type=keep_input_data_type, + keys=new_keys, + logger_func=self.logger, + signals=self.signals, + output_as_segm=output_as_segm, + formula=formula, + ) + return output_imgs, out_keys + + def requiredChannels(self, steps=None, pos_i=None): + if steps is None: + steps = self._steps + + required_channels = core.get_selected_channels(steps) + if pos_i is None: + pos_i = self._key[0] + + return required_channels, pos_i + + @worker_exception_handler + def run(self): + while True: + if self.exit: + self.logger.log("Closing combining channels worker...") + break + elif self.wait: + self.logger.log("Combining channels worker paused.") + self.pause() + elif len(self.dataQ) > 0: + data, steps, key, keep_input_data_type, output_as_segm, formula = ( + self.dataQ.pop() + ) + requ_steps, pos_i = self.requiredChannels(steps, key[0]) + self.emitsigAskLoadChannels(requ_steps, pos_i) + output_imgs, out_keys = self.applySteps( + data, + steps, + keep_input_data_type, + key, + output_as_segm=output_as_segm, + formula=formula, + ) + self.sigPreviewDone.emit(output_imgs, out_keys) + if len(self.dataQ) == 0: + self.wait = True + self.sigIsQueueEmpty.emit(True) + else: + self.logger.log("Combining channels worker resumed.") + requ_steps, pos_i = self.requiredChannels() + self.emitsigAskLoadChannels(requ_steps, pos_i) + output_imgs, out_keys = self.runJob() + self.sigDone.emit(output_imgs, out_keys) + self.wait = True + + self.signals.finished.emit(self) + + def emitsigAskLoadChannels(self, requChannels, pos_i): + self.mutex.lock() + self.sigAskLoadChannels.emit(requChannels, pos_i) + self.waitCondLoadFluoChannels.wait(self.mutex) + self.mutex.unlock() + return self.abort + + def wake_waitCondLoadFluoChannels(self): + self.mutex.lock() + self.waitCondLoadFluoChannels.wakeAll() + self.mutex.unlock() + + +class CustomPreprocessWorkerUtil(BaseWorkerUtil): + sigAskAppendName = Signal(str) + sigAskSetupRecipe = Signal(object, object) + sigAborted = Signal() + + def __init__(self, mainWin): + super().__init__(mainWin) + + def emitAskSetupRecipe(self, exp_path, pos_foldernames): + self.mutex.lock() + self.sigAskSetupRecipe.emit(exp_path, pos_foldernames) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.abort + + def applyPipeline( + self, + images_path: os.PathLike, + channel_names: Iterable[str], + recipe: List[Dict[str, Any]], + appended_text_filename: str, + ): + posData = None + preprocessed_data = {} + for channel in channel_names: + self.logger.log(f"Loading {channel} channel data...") + ch_filepath = load.get_filename_from_channel(images_path, channel) + ch_image_data = load.load_image_file(ch_filepath) + if posData is None: + posData = load.loadData(ch_filepath, channel) + posData.getBasenameAndChNames() + posData.buildPaths() + posData.loadOtherFiles( + load_segm_data=False, + load_metadata=True, + ) + if posData.SizeT == 1: + ch_image_data = (ch_image_data,) + + preprocessed_ch_data = core.preprocess_image_from_recipe_multithread( + ch_image_data, recipe + ) + + keep_input_data_type = recipe[0].get("keep_input_data_type", True) + if keep_input_data_type: + preprocessed_ch_data = utils.convert_to_dtype( + preprocessed_ch_data, ch_image_data.dtype + ) + + _, ext = os.path.splitext(ch_filepath) + basename = posData.basename + processed_filename = f"{basename}{channel}_{appended_text_filename}{ext}" + preprocessed_data[processed_filename] = preprocessed_ch_data + + return preprocessed_data + + @worker_exception_handler + def run(self): + debugging = False + expPaths = self.mainWin.expPaths + tot_exp = len(expPaths) + self.signals.initProgressBar.emit(0) + for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): + self.errors = {} + tot_pos = len(pos_foldernames) + + self.mainWin.infoText = "Setup recipe" + + if i == 0: + abort = self.emitAskSetupRecipe(exp_path, pos_foldernames) + if abort: + self.sigAborted.emit() + return + + # Ask append name + self.mutex.lock() + basename = f"{self.basename}{self.selectedChannels[0]}_" + self.sigAskAppendName.emit(basename) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + if self.abort: + self.sigAborted.emit() + return + + appendedName = self.appendedName + self.signals.initProgressBar.emit(len(pos_foldernames)) + for p, pos in enumerate(pos_foldernames): + if self.abort: + self.sigAborted.emit() + return + + self.logger.log( + f"Processing experiment n. {i + 1}/{tot_exp}, " + f"{pos} ({p + 1}/{tot_pos})" + ) + + images_path = os.path.join(exp_path, pos, "Images") + self.logger.log("Applying custom pre-processing recipe...\n") + processed_data = self.applyPipeline( + images_path, self.selectedChannels, self.recipe, appendedName + ) + + for filename, preprocessed_ch_data in processed_data.items(): + preprocessed_filepath = os.path.join(images_path, filename) + self.logger.log( + f'Saving pre-processed images to "{preprocessed_filepath}"...' + ) + + io.save_image_data(preprocessed_filepath, preprocessed_ch_data) + self.signals.progressBar.emit(1) + + self.signals.finished.emit(self) + + +class CombineChannelsWorkerUtil(BaseWorkerUtil): + sigAskAppendName = Signal(str) + sigAskSetup = Signal(object) + sigAborted = Signal() + + def __init__(self, mainWin, mutex=None, waitCond=None): + super().__init__(mainWin) + + def emitAskSetup(self, expPaths): + self.mutex.lock() + self.sigAskSetup.emit(expPaths) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.abort + + def applyPipeline( + self, + image_paths: os.PathLike, + steps: Dict[str, Dict[str, Any]], + appended_text_filename: str, + keep_input_data_type: bool, + n_threads: int = None, + formula: str = None, + ): + save_filepaths = [] + images_path_to_process = [] + if self.saveAsSegm: + out_ext = ".npz" + basename_ext = "segm_" + else: + out_ext = ".tif" + basename_ext = "" + for images_path in image_paths: + basename, channels = utils.getBasenameAndChNames(images_path) + + savename = f"{basename}{basename_ext}{appended_text_filename}{out_ext}" + + images_path_to_process.append(images_path) + save_filepaths.append(os.path.join(images_path, savename)) + + core.combine_channels_multithread( + steps=steps, + images_paths=images_path_to_process, + keep_input_data_type=keep_input_data_type, + save_filepaths=save_filepaths, + signals=self.signals, + logger_func=self.logger.log, + n_threads=n_threads, + output_as_segm=self.saveAsSegm, + formula=formula, + ) + + @worker_exception_handler + def run(self): + + self.signals.initProgressBar.emit(0) + + expPaths = self.mainWin.expPaths + abort = self.emitAskSetup(expPaths) + if abort: + self.sigAborted.emit() + return + + # Ask append name + self.mutex.lock() + basename = f"{self.basename}" + self.sigAskAppendName.emit(basename) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + if self.abort: + self.sigAborted.emit() + return + + appendedName = self.appendedName + + selectedSteps = self.selectedSteps + + self.logger.log("Applying pipeline...") + self.logger.log("Selected steps:") + for step in selectedSteps.values(): + self.logger.log(step) + + image_paths = [] + for exp_path, pos_foldernames in expPaths.items(): + image_paths += [ + os.path.join(exp_path, pos, "Images") for pos in pos_foldernames + ] + + self.signals.initProgressBar.emit(len(pos_foldernames)) + formula = self.formula + self.applyPipeline( + image_paths, + selectedSteps, + appendedName, + self.keepInputDataType, + n_threads=self.nThreads, + formula=formula, + ) + + self.signals.finished.emit(self) + +# Sibling imports (deferred to avoid import cycles) +from ._base import ( + signals, + workerLogger, + worker_exception_handler, +) + diff --git a/cellacdc/workers/gui.py b/cellacdc/workers/gui.py new file mode 100644 index 000000000..3649b4cf9 --- /dev/null +++ b/cellacdc/workers/gui.py @@ -0,0 +1,110 @@ +"""Background Qt workers: gui.""" + +import re +import os +import shutil +import time +import json +import concurrent.futures +from functools import partial +from collections import defaultdict, deque +import itertools + +from typing import Union, List, Dict, Callable, Any, Tuple, Iterable + +from functools import wraps +import numpy as np +import pandas as pd +import h5py +import traceback + +import skimage.io +import skimage.measure +import skimage.exposure + +import queue + +from tqdm import tqdm + +from qtpy.QtCore import Signal, QObject, QMutex, QWaitCondition + +from cellacdc import html_utils + +from .. import load, utils, core, prompts, printl, config, segm_re_pattern, io +from .. import transformation, measurements, cca_functions +from ..path import copy_or_move_tree +from .. import features, plot +from .. import core +from .. import cca_df_colnames, lineage_tree_cols, default_annot_df +from .. import cca_df_colnames_with_tree +from .. import cli +from ..tools import resize +from .. import segm_utils + +DEBUG = False + +class AutoPilotWorker(QObject): + finished = Signal() + critical = Signal(object) + progress = Signal(str, object) + sigStarted = Signal() + sigStopTimer = Signal() + + def __init__(self, guiWin): + QObject.__init__(self) + self.logger = workerLogger(self.progress) + self.guiWin = guiWin + self.app = guiWin.app + # self.timer = timer + + def timerCallback(self): + pass + + def stop(self): + self.sigStopTimer.emit() + self.finished.emit() + + def run(self): + self.sigStarted.emit() + + +class FindNextNewIdWorker(QObject): + def __init__(self, posData, guiWin): + QObject.__init__(self) + self.signals = signals() + self.logger = workerLogger(self.signals.progress) + self.posData = posData + self.guiWin = guiWin + + @worker_exception_handler + def run(self): + prev_IDs = None + next_frame_i = -1 + for frame_i, data_dict in enumerate(self.posData.allData_li): + lab = data_dict["labels"] + rp = data_dict["regionprops"] + IDs = data_dict["IDs"] + if lab is None: + lab = self.posData.segm_data[frame_i] + rp = skimage.measure.regionprops(lab) + IDs = [obj.label for obj in rp] + + if prev_IDs is None: + prev_IDs = IDs + continue + + newIDs = [ID for ID in IDs if ID not in prev_IDs] + if newIDs: + next_frame_i = frame_i + break + prev_IDs = IDs + + self.signals.finished.emit(next_frame_i) + +# Sibling imports (deferred to avoid import cycles) +from ._base import ( + signals, + workerLogger, + worker_exception_handler, +) + diff --git a/cellacdc/workers/io.py b/cellacdc/workers/io.py new file mode 100644 index 000000000..60e92a114 --- /dev/null +++ b/cellacdc/workers/io.py @@ -0,0 +1,1117 @@ +"""Background Qt workers: io.""" + +import re +import os +import shutil +import time +import json +import concurrent.futures +from functools import partial +from collections import defaultdict, deque +import itertools + +from typing import Union, List, Dict, Callable, Any, Tuple, Iterable + +from functools import wraps +import numpy as np +import pandas as pd +import h5py +import traceback + +import skimage.io +import skimage.measure +import skimage.exposure + +import queue + +from tqdm import tqdm + +from qtpy.QtCore import Signal, QObject, QMutex, QWaitCondition + +from cellacdc import html_utils + +from .. import load, utils, core, prompts, printl, config, segm_re_pattern, io +from .. import transformation, measurements, cca_functions +from ..path import copy_or_move_tree +from .. import features, plot +from .. import core +from .. import cca_df_colnames, lineage_tree_cols, default_annot_df +from .. import cca_df_colnames_with_tree +from .. import cli +from ..tools import resize +from .. import segm_utils + +DEBUG = False + +class StoreGuiStateWorker(QObject): + finished = Signal(object) + sigDone = Signal() + progress = Signal(str, object) + + def __init__(self, mutex, waitCond): + QObject.__init__(self) + self.mutex = mutex + self.waitCond = waitCond + self.exit = False + self.isFinished = False + self.q = queue.Queue() + self.logger = workerLogger(self.progress) + + def pause(self): + self.mutex.lock() + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + def enqueue(self, posData, img1): + self.q.put((posData, img1)) + self.waitCond.wakeAll() + + def _stop(self): + self.exit = True + self.waitCond.wakeAll() + + def run(self): + while True: + if self.exit: + self.logger.log("Closing store state worker...") + break + elif not self.q.empty(): + posData, img1 = self.q.get() + # self.logger.log('Storing state...') + if posData.cca_df is not None: + cca_df = posData.cca_df.copy() + else: + cca_df = None + + state = { + "image": img1.copy(), + "labels": posData.storedLab.copy(), + "editID_info": posData.editID_info.copy(), + "binnedIDs": posData.binnedIDs.copy(), + "ripIDs": posData.ripIDs.copy(), + "cca_df": cca_df, + } + posData.UndoRedoStates[posData.frame_i].insert(0, state) + if self.q.empty(): + # self.logger.log('State stored...') + self.sigDone.emit() + else: + self.pause() + + self.isFinished = True + self.finished.emit(self) + + +class AutoSaveWorker(QObject): + finished = Signal(object) + sigDone = Signal() + critical = Signal(object) + progress = Signal(str, object) + sigStartTimer = Signal(object, object) + sigStopTimer = Signal() + sigAutoSaveCannotProceed = Signal() + + def __init__(self, mutex, waitCond, savedSegmData): + QObject.__init__(self) + self.savedSegmData = savedSegmData + self.logger = workerLogger(self.progress) + self.mutex = mutex + self.waitCond = waitCond + self.exit = False + self.isFinished = False + self.stopSaving = False + self.isSaving = False + self.isPaused = False + self.dataQ = deque(maxlen=5) + self.isAutoSaveON = False + self.isAutoSaveAnnotON = True + self.debug = False + + def pause(self): + if self.debug: + self.logger.log("Autosaving is idle.") + self.mutex.lock() + self.isPaused = True + self.waitCond.wait(self.mutex) + self.mutex.unlock() + self.isPaused = False + + def enqueue(self, posData): + # First stop previously saving data + if self.isSaving: + self.stopSaving = True + self._enqueue(posData) + + def _enqueue(self, posData): + if self.debug: + self.logger.log("Enqueing posData autosave...") + self.dataQ.append(posData) + if len(self.dataQ) == 1: + # Wake up worker upon inserting first element + self.stopSaving = False + self.waitCond.wakeAll() + + def _stop(self): + self.exit = True + self.waitCond.wakeAll() + + def stop(self): + self.stopSaving = True + while not len(self.dataQ) == 0: + data = self.dataQ.pop() + del data + self._stop() + + def cancelSaving(self): ... + + @worker_exception_handler + def run(self): + while True: + if self.exit: + self.logger.log("Closing autosaving worker...") + break + elif not len(self.dataQ) == 0: + if self.debug: + self.logger.log("Autosaving...") + data = self.dataQ.pop() + self.isSaving = True + try: + self.saveData(data) + except Exception as e: + error = traceback.format_exc() + print("*" * 40) + self.logger.log(error) + print("=" * 40) + self.isSaving = False + + if len(self.dataQ) == 0: + self.sigDone.emit() + else: + self.pause() + self.isFinished = True + self.finished.emit(self) + if self.debug: + self.logger.log("Autosave finished signal emitted") + + def getLastTrackedFrame(self, posData): + last_tracked_i = 0 + for frame_i, data_dict in enumerate(posData.allData_li): + lab = data_dict["labels"] + if lab is None: + frame_i -= 1 + break + if frame_i > 0: + return frame_i + else: + return last_tracked_i + + def saveData(self, posData): + if self.debug: + self.logger.log("Started autosaving...") + + if not self.isAutoSaveON and not self.isAutoSaveAnnotON: + return + + try: + posData.setTempPaths() + except Exception as e: + self.logger.log( + "[WARNING]: Cell-ACDC cannot create the recovery folder for " + "the autosaving process. Autosaving will be turned off." + ) + self.sigAutoSaveCannotProceed.emit() + return + segm_npz_path = posData.segm_npz_temp_path + + end_i = self.getLastTrackedFrame(posData) + + saved_segm_data = None + if self.isAutoSaveON: + if end_i < len(posData.segm_data): + saved_segm_data = posData.segm_data + else: + frame_shape = posData.segm_data.shape[1:] + segm_shape = (end_i + 1, *frame_shape) + saved_segm_data = np.zeros(segm_shape, dtype=np.uint32) + + keys = [] + acdc_df_li = [] + + for frame_i, data_dict in enumerate(posData.allData_li[: end_i + 1]): + if self.stopSaving: + break + + # Build saved_segm_data + lab = data_dict["labels"] + if lab is None: + break + + if self.isAutoSaveON and saved_segm_data is not None: + if posData.SizeT > 1: + saved_segm_data[frame_i] = lab + else: + saved_segm_data = lab + + if self.isAutoSaveAnnotON: + acdc_df = data_dict["acdc_df"] + + if acdc_df is None: + continue + + if not np.any(lab): + continue + + if self.isAutoSaveAnnotON: + acdc_df = load.pd_bool_and_float_to_int_to_str( + acdc_df, inplace=False, colsToCastInt=[] + ) + + acdc_df_li.append(acdc_df) + key = (frame_i, posData.TimeIncrement * frame_i) + keys.append(key) + + if self.stopSaving: + break + + if not self.stopSaving: + if self.isAutoSaveON: + segm_data = np.squeeze(saved_segm_data) + self._saveSegm(segm_npz_path, segm_data) + + if acdc_df_li: + all_frames_acdc_df = pd.concat( + acdc_df_li, keys=keys, names=["frame_i", "time_seconds", "Cell_ID"] + ) + self._save_acdc_df(all_frames_acdc_df, posData) + + if self.debug: + self.logger.log(f"Autosaving done.") + self.logger.log(f"Stopped autosaving {self.stopSaving}.") + + self.stopSaving = False + + def _saveSegm(self, recovery_path, data): + try: + equalToSavedSegm = np.all(self.savedSegmData == data) + except Exception as err: + return + + if equalToSavedSegm: + return + else: + io.savez_compressed(recovery_path, np.squeeze(data)) + + def _save_acdc_df(self, recovery_acdc_df: pd.DataFrame, posData): + recovery_folderpath = posData.recoveryFolderpath() + if not os.path.exists(posData.acdc_output_csv_path): + load.store_unsaved_acdc_df(recovery_folderpath, recovery_acdc_df) + return + + saved_acdc_df_path = posData.acdc_output_csv_path + saved_acdc_df = pd.read_csv( + saved_acdc_df_path, dtype=load.acdc_df_str_cols + ).set_index(["frame_i", "Cell_ID"]) + + recovery_acdc_df = recovery_acdc_df.reset_index( + allow_duplicates=True + ).set_index(["frame_i", "Cell_ID"]) + recovery_acdc_df = recovery_acdc_df.loc[ + :, ~recovery_acdc_df.columns.duplicated() + ] + try: + # Try to insert into the recovery_acdc_df any column that was saved + # but is not in the recovered df (e.g., metrics) + df_left = recovery_acdc_df + existing_cols = df_left.columns.intersection(saved_acdc_df.columns) + df_right = saved_acdc_df.drop(columns=existing_cols) + recovery_acdc_df = df_left.join(df_right, how="left") + except Exception as error: + self.logger.log(f"[WARNING]: {error}") + + # Check if last saved acdc_df is equal + last_unsaved_csv_path = load.get_last_stored_unsaved_acdc_df_filepath( + recovery_folderpath + ) + if last_unsaved_csv_path is None: + reference_acdc_df = saved_acdc_df + else: + try: + reference_acdc_df = pd.read_csv( + last_unsaved_csv_path, dtype=load.acdc_df_str_cols + ).set_index(["frame_i", "Cell_ID"]) + except Exception as e: + self.logger.log(f"[WARNING]: {e}") + reference_acdc_df = saved_acdc_df + + if utils.are_acdc_dfs_equal(recovery_acdc_df, reference_acdc_df): + return + + load.store_unsaved_acdc_df(recovery_folderpath, recovery_acdc_df) + + +class loadDataWorker(QObject): + def __init__(self, mainWin, user_ch_file_paths, user_ch_name, firstPosData): + QObject.__init__(self) + self.signals = signals() + self.mainWin = mainWin + self.user_ch_file_paths = user_ch_file_paths + self.user_ch_name = user_ch_name + self.logger = workerLogger(self.signals.progress) + self.mutex = self.mainWin.loadDataMutex + self.waitCond = self.mainWin.loadDataWaitCond + self.firstPosData = firstPosData + self.abort = False + self.loadUnsaved = False + self.recoveryAsked = False + self.loadSafeOverwriteNpz = False + + def pause(self): + self.mutex.lock() + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + def checkSelectedDataShape(self, posData, numPos): + skipPos = False + abort = False + emitWarning = ( + not posData.segmFound and posData.SizeT > 1 and not self.mainWin.isNewFile + ) + if emitWarning: + self.signals.dataIntegrityWarning.emit(posData.pos_foldername) + self.pause() + abort = self.abort + return skipPos, abort + + def warnMismatchSegmDataShape(self, posData): + self.skipPos = False + self.mutex.lock() + self.signals.sigWarnMismatchSegmDataShape.emit(posData) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.skipPos + + @worker_exception_handler + def run(self): + data = [] + user_ch_file_paths = self.user_ch_file_paths + numPos = len(self.user_ch_file_paths) + user_ch_name = self.user_ch_name + self.signals.initProgressBar.emit(len(user_ch_file_paths)) + for i, file_path in enumerate(user_ch_file_paths): + if i == 0: + posData = self.firstPosData + segmFound = self.firstPosData.segmFound + loadSegm = False + else: + posData = load.loadData(file_path, user_ch_name) + loadSegm = True + + self.logger.log(f"Loading {posData.relPath}...") + + posData.loadSizeS = self.mainWin.loadSizeS + posData.loadSizeT = self.mainWin.loadSizeT + posData.loadSizeZ = self.mainWin.loadSizeZ + posData.SizeT = self.mainWin.SizeT + posData.SizeZ = self.mainWin.SizeZ + posData.isSegm3D = self.mainWin.isSegm3D + + if i > 0: + # First pos was already loaded in the main thread + # see loadSelectedData function in gui.py + posData.getBasenameAndChNames() + posData.buildPaths() + if not self.firstPosData.onlyEditMetadata: + posData.loadImgData() + + if self.firstPosData.onlyEditMetadata: + loadSegm = False + + posData.loadOtherFiles( + load_segm_data=loadSegm, + load_acdc_df=True, + load_shifts=True, + loadSegmInfo=True, + load_delROIsInfo=True, + load_bkgr_data=True, + loadBkgrROIs=True, + load_dataPrep_ROIcoords=True, + load_last_tracked_i=True, + load_metadata=True, + load_customAnnot=True, + load_customCombineMetrics=True, + end_filename_segm=self.mainWin.selectedSegmEndName, + create_new_segm=self.mainWin.isNewFile, + new_endname=self.mainWin.newSegmEndName, + labelBoolSegm=self.mainWin.labelBoolSegm, + ) + posData.labelSegmData() + + if i == 0: + posData.segmFound = segmFound + + posData.addYXcentroidColsIfMissing(show_progress=True) + + isPosSegm3D = posData.getIsSegm3D() + isMismatch = ( + isPosSegm3D != self.mainWin.isSegm3D + and isPosSegm3D is not None + and not self.mainWin.isNewFile + ) + if isMismatch: + skipPos = self.warnMismatchSegmDataShape(posData) + if skipPos: + self.logger.log( + f'Skipping "{posData.relPath}" because segmentation ' + "data shape different from first Position loaded." + ) + continue + else: + data = "abort" + break + + self.logger.log( + "Loaded paths:\n" + f"Segmentation file name: {os.path.basename(posData.segm_npz_path)}\n" + f"ACDC output file name {os.path.basename(posData.acdc_output_csv_path)}" + ) + + posData.SizeT = self.mainWin.SizeT + if self.mainWin.SizeZ > 1: + SizeZ = posData.img_data_shape[-3] + posData.SizeZ = SizeZ + else: + posData.SizeZ = 1 + posData.TimeIncrement = self.mainWin.TimeIncrement + posData.PhysicalSizeZ = self.mainWin.PhysicalSizeZ + posData.PhysicalSizeY = self.mainWin.PhysicalSizeY + posData.PhysicalSizeX = self.mainWin.PhysicalSizeX + posData.isSegm3D = self.mainWin.isSegm3D + posData.saveMetadata( + signals=self.signals, + mutex=self.mutex, + waitCond=self.waitCond, + additionalMetadata=self.firstPosData._additionalMetadataValues, + ) + if hasattr(posData, "img_data_shape"): + SizeY, SizeX = posData.img_data_shape[-2:] + + if posData.SizeZ > 1 and posData.img_data.ndim < 3: + posData.SizeZ = 1 + posData.segmInfo_df = None + try: + os.remove(posData.segmInfo_df_csv_path) + except FileNotFoundError: + pass + + posData.setBlankSegmData(posData.SizeT, posData.SizeZ, SizeY, SizeX) + if not self.firstPosData.onlyEditMetadata: + skipPos, abort = self.checkSelectedDataShape(posData, numPos) + else: + skipPos, abort = False, False + + if skipPos: + continue + elif abort: + data = "abort" + break + + posData.setTempPaths(createFolder=False) + isRecoveredDataPresent = ( + os.path.exists(posData.segm_npz_temp_path) + or posData.isRecoveredAcdcDfPresent() + or posData.isSafeNpzOverwritePresent() + ) + if isRecoveredDataPresent and not self.mainWin.newSegmEndName: + if not self.recoveryAsked: + self.mutex.lock() + self.signals.sigRecovery.emit(posData) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + self.recoveryAsked = True + if self.abort: + data = "abort" + break + if self.loadUnsaved: + self.logger.log("Loading unsaved data...") + if os.path.exists(posData.segm_npz_temp_path): + segm_npz_path = posData.segm_npz_temp_path + posData.segm_data = np.load(segm_npz_path)["arr_0"] + segm_filename = os.path.basename(segm_npz_path) + posData.segm_npz_path = os.path.join( + posData.images_path, segm_filename + ) + + posData.loadMostRecentUnsavedAcdcDf() + elif self.loadSafeOverwriteNpz: + self.logger.log("Loading safe npz overwrite...") + segm_safe_npz_path = posData.getSafeNpzOverwritePath() + posData.segm_data = np.load(segm_safe_npz_path)["arr_0"] + + # Allow single 2D/3D image + if posData.SizeT == 1: + posData.img_data = posData.img_data[np.newaxis] + posData.segm_data = posData.segm_data[np.newaxis] + if hasattr(posData, "img_data_shape"): + img_shape = posData.img_data_shape + img_shape = "Not Loaded" + if hasattr(posData, "img_data_shape"): + datasetShape = posData.img_data.shape + else: + datasetShape = "Not Loaded" + if posData.segm_data is not None: + posData.segmSizeT = len(posData.segm_data) + SizeT = posData.SizeT + SizeZ = posData.SizeZ + self.logger.log(f"Full dataset shape = {img_shape}") + self.logger.log(f"Loaded dataset shape = {datasetShape}") + self.logger.log(f"Number of frames = {SizeT}") + self.logger.log(f"Number of z-slices per frame = {SizeZ}") + data.append(posData) + self.signals.progressBar.emit(1) + + if not data: + data = None + self.signals.dataIntegrityCritical.emit() + + self.signals.finished.emit(data) + + +class LazyLoader(QObject): + sigLoadingFinished = Signal() + + def __init__(self, mutex, waitCond, readH5mutex, waitReadH5cond): + QObject.__init__(self) + self.signals = signals() + self.mutex = mutex + self.waitCond = waitCond + self.exit = False + self.salute = True + self.sender = None + self.H5readWait = False + self.waitReadH5cond = waitReadH5cond + self.readH5mutex = readH5mutex + + def setArgs(self, posData, current_idx, axis, updateImgOnFinished): + self.wait = False + self.updateImgOnFinished = updateImgOnFinished + self.posData = posData + self.current_idx = current_idx + self.axis = axis + + def pauseH5read(self): + self.readH5mutex.lock() + self.waitReadH5cond.wait(self.mutex) + self.readH5mutex.unlock() + + def pause(self): + self.mutex.lock() + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + @worker_exception_handler + def run(self): + while True: + if self.exit: + self.signals.progress.emit("Closing lazy loader...", "INFO") + break + elif self.wait: + self.signals.progress.emit("Lazy loader paused.", "INFO") + self.pause() + else: + self.signals.progress.emit("Lazy loader resumed.", "INFO") + self.posData.loadChannelDataChunk( + self.current_idx, axis=self.axis, worker=self + ) + self.sigLoadingFinished.emit() + self.wait = True + + self.signals.finished.emit(None) + + +class MigrateUserProfileWorker(QObject): + finished = Signal(object) + critical = Signal(object) + progress = Signal(str) + debug = Signal(object) + + def __init__(self, src_path, dst_path, acdc_folders): + QObject.__init__(self) + self.signals = signals() + self.src_path = src_path + self.dst_path = dst_path + self.acdc_folders = acdc_folders + + @worker_exception_handler + def run(self): + import shutil + from . import models_path + + self.progress.emit( + "Migrating user profile data from " + f'"{self.src_path}" to "{self.dst_path}"...' + ) + acdc_folders = self.acdc_folders + self.signals.initProgressBar.emit(2 * len(acdc_folders)) + dst_folder = os.path.basename(self.dst_path) + folders_to_remove = [] + for acdc_folder in acdc_folders: + if acdc_folder == dst_folder: + # Skip the destination folder that would be picked up if the + # user called it with acdc at the start of the name + self.signals.progressBar.emit(2) + continue + src = os.path.join(self.src_path, acdc_folder) + dst = os.path.join(self.dst_path, acdc_folder) + self.progress.emit(f"Copying {src} to {dst}...") + files_failed_move = copy_or_move_tree( + src, + dst, + copy=False, + sigInitPbar=self.signals.sigInitInnerPbar, + sigUpdatePbar=self.signals.sigUpdateInnerPbar, + ) + folders_to_remove.append(src) + self.signals.progressBar.emit(1) + + for to_remove in folders_to_remove: + try: + self.progress.emit(f'Removing "{to_remove}"...') + shutil.rmtree(to_remove) + except Exception as err: + self.progress.emit( + "--------------------------------------------------------\n" + f'[WARNING]: Removal of the folder "{to_remove}" failed. ' + "Please remove manually.\n" + "--------------------------------------------------------" + ) + finally: + self.signals.progressBar.emit(1) + + # Update model's paths + load.migrate_models_paths(self.dst_path) + + # Store user profile data folder path + from . import user_profile_path_txt + + os.makedirs(os.path.dirname(user_profile_path_txt), exist_ok=True) + with open(user_profile_path_txt, "w") as txt: + txt.write(self.dst_path) + + self.finished.emit(self) + + +class MoveTempFilesWorker(QObject): + def __init__(self, temp_files_to_move: Dict[os.PathLike, os.PathLike]): + QObject.__init__(self) + self.signals = signals() + self.logger = workerLogger(self.signals.progress) + self.temp_files_to_move = temp_files_to_move + + @worker_exception_handler + def run(self): + for src, dst in self.temp_files_to_move.items(): + self.logger.log(f"Saving channel data to: {dst}...") + shutil.move(src, dst) + tempDir = os.path.dirname(src) + shutil.rmtree(tempDir) + self.signals.progressBar.emit(1) + self.signals.finished.emit(self) + + +class saveDataWorker(QObject): + finished = Signal() + progress = Signal(str) + sigLog = Signal(str) + progressBar = Signal(int, int, float) + critical = Signal(object) + addMetricsCritical = Signal(str, str) + regionPropsCritical = Signal(str, str) + criticalPermissionError = Signal(str) + metricsPbarProgress = Signal(int, int) + askZsliceAbsent = Signal(str, object) + customMetricsCritical = Signal(str, str) + sigCombinedMetricsMissingColumn = Signal(str, str) + sigDebug = Signal(object) + + def __init__(self, mainWin): + QObject.__init__(self) + self.mainWin = mainWin + self.saveWin = mainWin.saveWin + self.mutex = mainWin.mutex + self.waitCond = mainWin.waitCond + self.customMetricsErrors = {} + self.addMetricsErrors = {} + self.regionPropsErrors = {} + self.abort = False + + def checkAbort(self): + if self.saveWin.aborted: + self.finished.emit() + return True + return False + + def saveManualBackgroundData(self, posData, frame_i): + data_dict = posData.allData_li[frame_i] + if "manualBackgroundLab" not in data_dict: + return + + manualBackgrData = data_dict["manualBackgroundLab"] + posData.saveManualBackgroundData(manualBackgrData) + + def emitSigPermissionErrorAndSave( + self, all_frames_acdc_df, acdc_output_csv_path, custom_annot_columns + ): + err_msg = ( + "The below file is open in another app " + "(Excel maybe?).\n\n" + f"{acdc_output_csv_path}\n\n" + 'Close file and then press "Ok".' + ) + self.mutex.lock() + self.criticalPermissionError.emit(err_msg) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + # Save segmentation metadata + load.save_acdc_df_file( + all_frames_acdc_df, + acdc_output_csv_path, + custom_annot_columns=custom_annot_columns, + last_cca_frame_i=self.mainWin.save_cca_until_frame_i, + ) + + def _emitSigDebug(self, stuff_to_debug): + self.mutex.lock() + self.sigDebug.emit(stuff_to_debug) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + def emitUpdateProgressBar(self): + t = time.perf_counter() + exec_time = t - self.time_last_pbar_update + self.progressBar.emit(1, -1, exec_time) + self.time_last_pbar_update = t + + def saveAcdcDf(self, posData: load.loadData, end_i): + acdc_dfs_li = [] + keys = [] + self.progress.emit(f"Saving annotations for {posData.relPath}...") + for frame_i, data_dict in enumerate(posData.allData_li[: end_i + 1]): + if self.saveWin.aborted: + self.finished.emit() + return + + # Build saved_segm_data + lab = data_dict["labels"] + if lab is None: + break + + acdc_df = posData.allData_li[frame_i]["acdc_df"] + if acdc_df is None: + continue + + acdc_dfs_li.append(acdc_df) + keys.append((frame_i, posData.TimeIncrement * frame_i)) + + if not acdc_dfs_li: + return + + self.mainWin._measurements_kernel._concat_and_save_acdc_df( + acdc_dfs_li, + keys, + posData, + self.mainWin.save_metrics, + saveDataWorker=self, + last_cca_frame_i=self.mainWin.save_cca_until_frame_i, + ) + + def saveSegmData(self, posData, end_i, saved_segm_data): + self.progress.emit(f"Saving segmentation data for {posData.relPath}...") + + # extend saved_segm_data if needed + if posData.SizeT > 1: + missing_frames_number = end_i + 1 - len(saved_segm_data) + if missing_frames_number > 0: + saved_segm_data = np.concatenate( + ( + saved_segm_data, + np.zeros( + (missing_frames_number, *saved_segm_data.shape[1:]), + dtype=saved_segm_data.dtype, + ), + ), + ) + + for frame_i, data_dict in enumerate(posData.allData_li[: end_i + 1]): + if self.saveWin.aborted: + self.finished.emit() + return + + # Build saved_segm_data + lab = data_dict["labels"] + if lab is None: + break + + posData.lab = lab + + if posData.SizeT > 1: + saved_segm_data[frame_i] = lab + else: + saved_segm_data = lab + if "manualBackgroundLab" in data_dict: + manualBackgrData = data_dict["manualBackgroundLab"] + posData.saveManualBackgroundData(manualBackgrData) + + # Save segmentation file + io.savez_compressed(posData.segm_npz_path, np.squeeze(saved_segm_data)) + posData.segm_data = saved_segm_data + # Allow single 2D/3D image + if posData.SizeT == 1: + posData.segm_data = posData.segm_data[np.newaxis] + + try: + os.remove(posData.segm_npz_temp_path) + except Exception as e: + pass + + @worker_exception_handler + def run(self): + posToSave = self.mainWin.posToSave + if posToSave is None: + numPosToSave = 1 + else: + numPosToSave = len(posToSave) + save_metrics = self.mainWin.save_metrics + if self.isQuickSave: + save_metrics = False + self.time_last_pbar_update = time.perf_counter() + mode = self.mode + for p, posData in enumerate(self.mainWin.data): + if self.saveWin.aborted: + self.finished.emit() + return + + if posToSave is not None: + if posData.pos_foldername not in posToSave: + self.progress.emit(f"Skipping {posData.relPath}") + continue + + last_tracked_i_path = posData.last_tracked_i_path + end_i = self.mainWin.save_until_frame_i + self.saveSegmData(posData, end_i, posData.segm_data) + + posData.saveCustomAnnotationParams() + current_frame_i = posData.frame_i + + posData.saveTrackedLostCentroids() + + if not self.mainWin.isSnapshot: + last_tracked_i = self.mainWin.last_tracked_i + if last_tracked_i is None: + self.mainWin.saveWin.aborted = True + self.finished.emit() + return + elif self.mainWin.isSnapshot: + last_tracked_i = 0 + + if p == 0: + self.progressBar.emit(0, numPosToSave * (last_tracked_i + 1), 0) + + acdc_output_csv_path = posData.acdc_output_csv_path + delROIs_info_path = posData.delROIs_info_path + + # Add segmented channel data for calc metrics if requested + add_user_channel_data = True + for chName in self.mainWin._measurements_kernel.chNamesToSkip: + skipUserChannel = posData.filename.endswith( + chName + ) or posData.filename.endswith(f"{chName}_aligned") + if skipUserChannel: + add_user_channel_data = False + + if add_user_channel_data and not self.isQuickSave: + posData.fluo_data_dict[posData.filename] = posData.img_data + + if not self.isQuickSave: + posData.fluo_bkgrData_dict[posData.filename] = posData.bkgrData + + posData.setLoadedChannelNames() + + if not self.isQuickSave: + self.mainWin.initMetricsToSave(posData) + self.mainWin._measurements_kernel.run( + posData=posData, + stop_frame_n=end_i + 1, + saveDataWorker=self, + save_metrics=self.mainWin.save_metrics, + last_cca_frame_i=self.mainWin.save_cca_until_frame_i, + ) + else: + self.saveAcdcDf(posData, end_i) + + self.progress.emit(f"Saving {posData.relPath}") + + if not self.do_not_save_og_whitelist: + og_save_path = os.path.join( + posData.images_path, self.append_name_og_whitelist + ) + posData.whitelist.saveOGLabs(og_save_path) + + if posData.whitelist: + whitelistIDs_path = posData.segm_npz_path.replace( + ".npz", "_whitelistIDs.json" + ) + new_centroids_path = posData.segm_npz_path.replace( + ".npz", "_new_centroids.json" + ) + posData.whitelist.save( + whitelistIDs_path, new_centroids_path=new_centroids_path + ) + + if posData.segmInfo_df is not None: + try: + posData.segmInfo_df.to_csv(posData.segmInfo_df_csv_path) + except PermissionError: + err_msg = ( + "The below file is open in another app " + "(Excel maybe?).\n\n" + f"{posData.segmInfo_df_csv_path}\n\n" + 'Close file and then press "Ok".' + ) + self.mutex.lock() + self.criticalPermissionError.emit(err_msg) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + posData.segmInfo_df.to_csv(posData.segmInfo_df_csv_path) + + posData.fluo_data_dict.pop(posData.filename, None) + + if not self.isQuickSave: + posData.fluo_bkgrData_dict.pop(posData.filename) + + if posData.SizeT > 1: + self.progress.emit("Almost done...") + self.progressBar.emit(0, 0, 0) + + if self.isQuickSave: + # Go back to current frame + posData.frame_i = current_frame_i + self.mainWin.get_data() + continue + + with open(last_tracked_i_path, "w+") as txt: + txt.write(str(end_i)) + + # Save combined metrics equations + posData.saveCombineMetrics() + self.mainWin.pointsLayerDataToDf(posData) + posData.saveClickEntryPointsDfs() + + posData.last_tracked_i = last_tracked_i + + # Go back to current frame + posData.frame_i = current_frame_i + self.mainWin.get_data() + + if mode == "Segmentation and Tracking" or mode == "Viewer": + self.progress.emit(f"Saved data until frame number {end_i + 1}") + elif mode == "Cell cycle analysis": + self.progress.emit( + "Saved cell cycle annotations until frame " + f"number {self.mainWin.last_cca_frame_i + 1}" + ) + # self.progressBar.emit(1) + if self.mainWin.isSnapshot: + self.progress.emit(f"Saved all {p + 1} Positions!") + + self.finished.emit() + + +class relabelSequentialWorker(QObject): + finished = Signal() + critical = Signal(object) + progress = Signal(str) + sigRemoveItemsGUI = Signal(int) + debug = Signal(object) + + def __init__(self, mainWin, posFoldernames): + QObject.__init__(self) + self.mainWin = mainWin + self.data = mainWin.data + self.posFoldernames = posFoldernames + self.mutex = QMutex() + self.waitCond = QWaitCondition() + + def progressNewIDs(self, oldIDs, newIDs): + li = list(zip(oldIDs, newIDs)) + s = "\n".join([str(pair).replace(",", " -->") for pair in li]) + s = f"IDs relabelled as follows:\n{s}" + self.progress.emit(s) + + @worker_exception_handler + def run(self): + self.mutex.lock() + + self.progress.emit("Relabelling process started...") + mainWin = self.mainWin + + current_pos_i = mainWin.pos_i + + for p, posData in enumerate(self.data): + if posData.pos_foldername not in self.posFoldernames: + continue + + mainWin.pos_i = p + current_lab = mainWin.get_2Dlab(posData.lab).copy() + current_frame_i = posData.frame_i + segm_data = [] + for frame_i, data_dict in enumerate(posData.allData_li): + lab = data_dict["labels"] + if lab is None: + break + segm_data.append(lab) + # if frame_i == current_frame_i: + # break + + if not segm_data: + segm_data = np.array([current_lab]) + + segm_data = np.array(segm_data) + segm_data, oldIDs, newIDs = core.relabel_sequential( + segm_data, is_timelapse=posData.SizeT > 1 + ) + self.progressNewIDs(oldIDs, newIDs) + self.sigRemoveItemsGUI.emit(np.max(segm_data)) + + self.progress.emit( + "Updating stored data and cell cycle annotations (if present)..." + ) + + mainWin.updateAnnotatedIDs(oldIDs, newIDs, logger=self.progress.emit) + mainWin.store_data(mainThread=False) + + for frame_i, lab in enumerate(segm_data): + posData.frame_i = frame_i + posData.lab = lab + mainWin.get_cca_df() + if posData.cca_df is not None: + mainWin.update_cca_df_relabelling(posData, oldIDs, newIDs) + mainWin.update_rp(draw=False) + mainWin.store_data(mainThread=False) + + # Go back to current frame + mainWin.pos_i = current_pos_i + posData = self.data[mainWin.pos_i] + posData.frame_i = current_frame_i + mainWin.get_data() + + self.mutex.unlock() + self.finished.emit() + +# Sibling imports (deferred to avoid import cycles) +from ._base import ( + signals, + workerLogger, + worker_exception_handler, +) + diff --git a/cellacdc/workers/metrics.py b/cellacdc/workers/metrics.py new file mode 100644 index 000000000..dcac024fe --- /dev/null +++ b/cellacdc/workers/metrics.py @@ -0,0 +1,1520 @@ +"""Background Qt workers: metrics.""" + +import re +import os +import shutil +import time +import json +import concurrent.futures +from functools import partial +from collections import defaultdict, deque +import itertools + +from typing import Union, List, Dict, Callable, Any, Tuple, Iterable + +from functools import wraps +import numpy as np +import pandas as pd +import h5py +import traceback + +import skimage.io +import skimage.measure +import skimage.exposure + +import queue + +from tqdm import tqdm + +from qtpy.QtCore import Signal, QObject, QMutex, QWaitCondition + +from cellacdc import html_utils + +from .. import load, utils, core, prompts, printl, config, segm_re_pattern, io +from .. import transformation, measurements, cca_functions +from ..path import copy_or_move_tree +from .. import features, plot +from .. import core +from .. import cca_df_colnames, lineage_tree_cols, default_annot_df +from .. import cca_df_colnames_with_tree +from .. import cli +from ..tools import resize +from .. import segm_utils + +DEBUG = False + +from ._base import ( + BaseWorkerUtil, +) + +class ComputeMetricsWorker(QObject): + progressBar = Signal(int, int, float) + + def __init__(self, mainWin): + QObject.__init__(self) + self.signals = signals() + self.abort = False + self.setup_done = False + self.logger = workerLogger(self.signals.progress) + self.mutex = QMutex() + self.waitCond = QWaitCondition() + self.mainWin = mainWin + + def emitSelectSegmFiles(self, exp_path, pos_foldernames): + self.mutex.lock() + self.signals.sigSelectSegmFiles.emit(exp_path, pos_foldernames) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + if self.abort: + return True + else: + return False + + @worker_exception_handler + def run(self): + np.seterr(invalid="ignore") + debugging = False + expPaths = self.mainWin.expPaths + tot_exp = len(expPaths) + self.signals.initProgressBar.emit(0) + for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): + self.standardMetricsErrors = {} + self.customMetricsErrors = {} + self.regionPropsErrors = {} + tot_pos = len(pos_foldernames) + self.allPosDataInputs = [] + posDatas = [] + self.logger.log("-" * 30) + expFoldername = os.path.basename(exp_path) + + if i == 0: + abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) + if abort: + self.signals.finished.emit(self) + return + + for p, pos in enumerate(pos_foldernames): + if self.abort: + self.signals.finished.emit(self) + return + + self.logger.log( + f"Processing experiment n. {i + 1}/{tot_exp}, " + f"{pos} ({p + 1}/{tot_pos})" + ) + + pos_path = os.path.join(exp_path, pos) + images_path = os.path.join(pos_path, "Images") + basename, chNames = utils.getBasenameAndChNames( + images_path, useExt=(".tif", ".h5") + ) + + self.signals.sigUpdatePbarDesc.emit(f"Loading {pos_path}...") + + # Use first found channel, it doesn't matter for metrics + chName = chNames[0] + file_path = utils.getChannelFilePath(images_path, chName) + + # Load data + posData = load.loadData(file_path, chName) + posData.getBasenameAndChNames(useExt=(".tif", ".h5")) + posData.buildPaths() + + posData.loadOtherFiles( + load_segm_data=False, + load_acdc_df=True, + load_metadata=True, + loadSegmInfo=True, + load_customCombineMetrics=True, + ) + + posDatas.append(posData) + + self.allPosDataInputs.append( + { + "file_path": file_path, + "chName": chName, + "combineMetricsConfig": posData.combineMetricsConfig, + "combineMetricsPath": posData.custom_combine_metrics_path, + } + ) + + if any([posData.SizeT > 1 for posData in posDatas]): + self.mutex.lock() + self.signals.sigAskStopFrame.emit(posDatas) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + if self.abort: + self.signals.finished.emit(self) + return + for p, posData in enumerate(posDatas): + self.allPosDataInputs[p]["stopFrameNum"] = posData.stopFrameNum + else: + for p, posData in enumerate(posDatas): + self.allPosDataInputs[p]["stopFrameNum"] = 1 + + self.kernel = cli.ComputeMeasurementsKernel( + self.logger, + self.mainWin.log_path, + False, + ) + + from cellacdc.workflow.pipelines.batch import run_gui_measurements_batch + from cellacdc.workflow.runnable import RunnableConfig + + run_gui_measurements_batch( + kernel=self.kernel, + paths=[inp["file_path"] for inp in self.allPosDataInputs], + stop_frame_numbers=[ + inp["stopFrameNum"] for inp in self.allPosDataInputs + ], + end_filename_segm=self.mainWin.endFilenameSegm, + compute_metrics_worker=self, + config=RunnableConfig(logger_func=self.logger.log), + ) + + if self.kernel.setup_done or self.abort: + return + + self.logger.log("*" * 30) + + self.mutex.lock() + self.signals.sigErrorsReport.emit( + self.standardMetricsErrors, + self.customMetricsErrors, + self.regionPropsErrors, + ) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + self.signals.finished.emit(self) + + def emitSigComputeVolume(self, posData, stop_frame_n): + # Recreate allData_li attribute of the gui + posData.allData_li = [] + for frame_i, lab in enumerate(posData.segm_data[:stop_frame_n]): + data_dict = {"labels": lab, "regionprops": skimage.measure.regionprops(lab)} + posData.allData_li.append(data_dict) + self.mutex.lock() + self.signals.sigComputeVolume.emit(stop_frame_n, posData) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + def emitSigPermissionErrorAndSave( + self, posData, traceback_str, all_frames_acdc_df, custom_annot_columns + ): + self.mutex.lock() + self.signals.sigPermissionError.emit( + traceback_str, posData.acdc_output_csv_path + ) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + load.save_acdc_df_file( + all_frames_acdc_df, + posData.acdc_output_csv_path, + custom_annot_columns=custom_annot_columns, + ) + + def emitSigInitMetricsDialog(self, posData): + self.mainWin.gui.data = [posData] + self.mainWin.gui.pos_i = 0 + self.mainWin.gui.isSegm3D = posData.getIsSegm3D() + self.mutex.lock() + self.signals.sigInitAddMetrics.emit(posData, self.allPosDataInputs) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + def emitSigAskRunNow(self): + self.mutex.lock() + self.signals.sigAskRunNow.emit(self) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + +class ComputeMetricsMultiChannelWorker(BaseWorkerUtil): + sigAskAppendName = Signal(str, list, list) + sigCriticalNotEnoughSegmFiles = Signal(str) + sigAborted = Signal() + sigHowCombineMetrics = Signal(str, list, list, list) + + def __init__(self, mainWin): + super().__init__(mainWin) + + def emitHowCombineMetrics( + self, + imagesPath, + selectedAcdcOutputEndnames, + existingAcdcOutputEndnames, + allChNames, + ): + self.mutex.lock() + self.sigHowCombineMetrics.emit( + imagesPath, + selectedAcdcOutputEndnames, + existingAcdcOutputEndnames, + allChNames, + ) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.abort + + def loadAcdcDfs(self, imagesPath, selectedAcdcOutputEndnames): + for end in selectedAcdcOutputEndnames: + filePath, _ = load.get_path_from_endname(end, imagesPath) + acdc_df = pd.read_csv(filePath) + yield acdc_df + + def run_iter_exp(self, exp_path, pos_foldernames, i, tot_exp): + tot_pos = len(pos_foldernames) + + abort = self.emitSelectAcdcOutputFiles( + exp_path, + pos_foldernames, + infoText=" to combine", + allowSingleSelection=False, + ) + if abort: + self.sigAborted.emit() + return + + # Ask appendend name + self.mutex.lock() + self.sigAskAppendName.emit( + f"{self.mainWin.basename_pos1}acdc_output", + self.mainWin.existingAcdcOutputEndnames, + self.mainWin.selectedAcdcOutputEndnames, + ) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + if self.abort: + self.sigAborted.emit() + return + + selectedAcdcOutputEndnames = self.mainWin.selectedAcdcOutputEndnames + existingAcdcOutputEndnames = self.mainWin.existingAcdcOutputEndnames + appendedName = self.appendedName + + self.signals.initProgressBar.emit(len(pos_foldernames)) + for p, pos in enumerate(pos_foldernames): + if self.abort: + self.sigAborted.emit() + return + + self.logger.log( + f"Processing experiment n. {i + 1}/{tot_exp}, {pos} ({p + 1}/{tot_pos})" + ) + + imagesPath = os.path.join(exp_path, pos, "Images") + basename, chNames = utils.getBasenameAndChNames( + imagesPath, useExt=(".tif", ".h5") + ) + + if p == 0: + abort = self.emitHowCombineMetrics( + imagesPath, + selectedAcdcOutputEndnames, + existingAcdcOutputEndnames, + chNames, + ) + if abort: + self.sigAborted.emit() + return + acdcDfs = self.acdcDfs.values() + # Update selected acdc_dfs since the user could have + # loaded additional ones inside the emitHowCombineMetrics + # dialog + selectedAcdcOutputEndnames = self.acdcDfs.keys() + else: + acdcDfs = self.loadAcdcDfs(imagesPath, selectedAcdcOutputEndnames) + + dfs = [] + for i, acdc_df in enumerate(acdcDfs): + dfs.append(acdc_df.add_suffix(f"_table{i + 1}")) + combined_df = pd.concat(dfs, axis=1) + + newAcdcDf = pd.DataFrame(index=combined_df.index) + for newColname, equation in self.equations.items(): + newAcdcDf[newColname] = combined_df.eval(equation) + + newAcdcDfPath = os.path.join( + imagesPath, f"{basename}acdc_output_{appendedName}.csv" + ) + newAcdcDf.to_csv(newAcdcDfPath) + + equationsIniPath = os.path.join( + imagesPath, f"{basename}equations_{appendedName}.ini" + ) + equationsConfig = config.ConfigParser() + if os.path.exists(equationsIniPath): + equationsConfig.read(equationsIniPath) + equationsConfig = self.addEquationsToConfigPars( + equationsConfig, selectedAcdcOutputEndnames, self.equations + ) + with open(equationsIniPath, "w") as configfile: + equationsConfig.write(configfile) + + self.signals.progressBar.emit(1) + + return True + + def addEquationsToConfigPars(self, cp, selectedAcdcOutputEndnames, equations): + section = [ + f"df{i + 1}:{end}" for i, end in enumerate(selectedAcdcOutputEndnames) + ] + section = ";".join(section) + if section not in cp: + cp[section] = {} + + for metricName, expression in equations.items(): + cp[section][metricName] = expression + + return cp + + @worker_exception_handler + def run(self): + debugging = False + expPaths = self.mainWin.expPaths + tot_exp = len(expPaths) + self.signals.initProgressBar.emit(0) + self.errors = {} + for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): + try: + result = self.run_iter_exp(exp_path, pos_foldernames, i, tot_exp) + if result is None: + return + except Exception as e: + traceback_str = traceback.format_exc() + self.errors[e] = traceback_str + self.logger.log(traceback_str) + + self.signals.finished.emit(self) + + +class ConcatAcdcDfsWorker(BaseWorkerUtil): + sigAborted = Signal() + sigAskFolder = Signal(str) + sigSetMeasurements = Signal(object) + sigAskAppendName = Signal(str, list) + + def __init__(self, mainWin, format="CSV"): + super().__init__(mainWin) + if format.startswith("CSV"): + self._to_format = "to_csv" + elif format.startswith("XLS"): + self._to_format = "to_excel" + + def emitSetMeasurements(self, kwargs): + self.mutex.lock() + self.sigSetMeasurements.emit(kwargs) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + def emitAskAppendName(self, allPos_acdc_df_basename): + # Ask appendend name + self.mutex.lock() + self.sigAskAppendName.emit(allPos_acdc_df_basename, []) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + @worker_exception_handler + def run(self): + debugging = False + expPaths = self.mainWin.expPaths + tot_exp = len(expPaths) + + self.signals.initProgressBar.emit(0) + acdc_dfs_allexp = [] + acdc_objs_count_dfs_allexp = {} + keys_exp = [] + for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): + self.errors = {} + tot_pos = len(pos_foldernames) + + if i == 0: + abort = self.emitSelectAcdcOutputFiles( + exp_path, + pos_foldernames, + infoText=" to combine", + allowSingleSelection=True, + multiSelection=False, + ) + if abort: + self.sigAborted.emit() + return + + selectedAcdcOutputEndname = self.mainWin.selectedAcdcOutputEndnames[0] + selectedAcdcObjsCountEndname = selectedAcdcOutputEndname.replace( + "acdc_output", "acdc_objects_count" + ) + + self.signals.initProgressBar.emit(len(pos_foldernames)) + acdc_dfs = [] + acdc_objs_count_dfs = {} + keys = [] + for p, pos in enumerate(pos_foldernames): + if self.abort: + self.sigAborted.emit() + return + + self.logger.log( + f"Processing experiment n. {i + 1}/{tot_exp}, " + f"{pos} ({p + 1}/{tot_pos})" + ) + + images_path = os.path.join(exp_path, pos, "Images") + + ls = utils.listdir(images_path) + + acdc_output_file = [ + f for f in ls if f.endswith(f"{selectedAcdcOutputEndname}.csv") + ] + if not acdc_output_file: + self.logger.log( + f"{pos} does not contain any " + f"{selectedAcdcOutputEndname}.csv file. " + "Skipping it." + ) + self.signals.progressBar.emit(1) + continue + + acdc_objs_count_file = [ + f for f in ls if f.endswith(f"{selectedAcdcObjsCountEndname}.csv") + ] + if acdc_objs_count_file: + df_count_filepath = os.path.join( + images_path, acdc_objs_count_file[0] + ) + df_count = pd.read_csv(df_count_filepath) + acdc_objs_count_dfs[pos] = df_count + + acdc_df_filepath = os.path.join(images_path, acdc_output_file[0]) + acdc_df = pd.read_csv(acdc_df_filepath).set_index("Cell_ID") + acdc_dfs.append(acdc_df) + keys.append(pos) + + self.signals.progressBar.emit(1) + + self.signals.initProgressBar.emit(0) + acdc_df_allpos = pd.concat( + acdc_dfs, keys=keys, names=["Position_n", "Cell_ID"] + ) + acdc_df_allpos["experiment_folderpath"] = exp_path + + basename, chNames = utils.getBasenameAndChNames( + images_path, useExt=(".tif", ".h5") + ) + df_metadata = load.load_metadata_df(images_path) + SizeZ = df_metadata.at["SizeZ", "values"] + SizeZ = int(float(SizeZ)) + existing_colnames = acdc_df_allpos.columns + isSegm3D = any([col.endswith("3D") for col in existing_colnames]) + + if i == 0: + kwargs = { + "loadedChNames": chNames, + "notLoadedChNames": [], + "isZstack": SizeZ > 1, + "isSegm3D": isSegm3D, + "existing_colnames": existing_colnames, + } + self.emitSetMeasurements(kwargs) + if self.abort: + self.sigAborted.emit() + return + + selected_cols = [ + col for col in self.selectedColumns if col in acdc_df_allpos.columns + ] + acdc_df_allpos = acdc_df_allpos[selected_cols] + acdc_dfs_allexp.append(acdc_df_allpos) + exp_name = os.path.basename(exp_path) + keys_exp.append((exp_path, exp_name)) + + allpos_dir = os.path.join(exp_path, "AllPos_acdc_output") + if not os.path.exists(allpos_dir): + os.mkdir(allpos_dir) + + allPos_acdc_df_basename = f"AllPos_{selectedAcdcOutputEndname}" + if i == 0: + self.emitAskAppendName(allPos_acdc_df_basename) + if self.abort: + self.sigAborted.emit() + return + + acdc_objs_count_df_allpos_filename = self.concat_df_filename.replace( + "acdc_output", "acdc_objects_count" + ) + + acdc_dfs_allpos_filepath = os.path.join(allpos_dir, self.concat_df_filename) + + self.logger.log( + "Saving all positions concatenated file to " + f'"{acdc_dfs_allpos_filepath}"' + ) + to_format_func = getattr(acdc_df_allpos, self._to_format) + to_format_func(acdc_dfs_allpos_filepath) + self.acdc_dfs_allpos_filepath = acdc_dfs_allpos_filepath + + if not acdc_objs_count_dfs: + continue + + acdc_objs_count_df_allpos = pd.concat( + acdc_objs_count_dfs, names=["Position_n"] + ) + acdc_objs_count_df_allpos["experiment_folderpath"] = exp_path + + acdc_objs_count_df_allpos_filepath = os.path.join( + allpos_dir, acdc_objs_count_df_allpos_filename + ) + + self.logger.log( + "Saving all positions objects count file to " + f'"{acdc_objs_count_df_allpos_filepath}"' + ) + to_format_func = getattr(acdc_objs_count_df_allpos, self._to_format) + to_format_func(acdc_objs_count_df_allpos_filepath) + + acdc_objs_count_dfs_allexp[(exp_path, exp_name)] = acdc_objs_count_df_allpos + + if len(keys_exp) <= 1: + self.signals.finished.emit(self) + return + + allExp_filename = f"multiExp_{self.concat_df_filename}" + self.mutex.lock() + self.sigAskFolder.emit(allExp_filename) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + if self.abort: + self.sigAborted.emit() + return + + acdc_df_allexp = pd.concat( + acdc_dfs_allexp, + keys=keys_exp, + names=["experiment_folderpath", "experiment_foldername"], + ) + acdc_dfs_allexp_filepath = os.path.join(self.allExpSaveFolder, allExp_filename) + self.logger.log( + "Saving multiple experiments concatenated file to " + f'"{acdc_dfs_allexp_filepath}"' + ) + to_format_func = getattr(acdc_df_allexp, self._to_format) + to_format_func(acdc_dfs_allexp_filepath) + + if acdc_objs_count_dfs_allexp: + allexp_count_df_filename = f"multiExp_{acdc_objs_count_df_allpos_filename}" + acdc_objs_count_df_allexp = pd.concat( + acdc_objs_count_dfs_allexp, + names=["experiment_folderpath", "experiment_foldername"], + ) + acdc_objs_count_df_allexp_filepath = os.path.join( + self.allExpSaveFolder, allexp_count_df_filename + ) + self.logger.log( + "Saving multiple experiments concatenated file to " + f'"{acdc_objs_count_df_allexp_filepath}"' + ) + to_format_func = getattr(acdc_objs_count_df_allexp, self._to_format) + to_format_func(acdc_objs_count_df_allexp_filepath) + + self.signals.finished.emit(self) + + +class ConcatSpotmaxDfsWorker(BaseWorkerUtil): + sigAborted = Signal() + sigAskFolder = Signal(str) + sigSetMeasurements = Signal(object) + sigAskAppendName = Signal(str, list) + + def __init__(self, mainWin, format="CSV"): + super().__init__(mainWin) + if format.startswith("CSV"): + self._final_ext = ".csv" + elif format.startswith("XLS"): + self._final_ext = ".xlsx" + self.acdcOutputEndname = None + + def emitSetMeasurements(self, kwargs): + self.mutex.lock() + self.sigSetMeasurements.emit(kwargs) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + def emitAskAppendName(self, allPos_spotmax_df_basename): + # Ask appendend name + self.mutex.lock() + self.sigAskAppendName.emit(allPos_spotmax_df_basename, []) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + def emitAskCopyCca(self, images_path): + self.mutex.lock() + self.signals.sigAskCopyCca.emit(images_path) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + def setAcdcOutputEndname(self, acdcOutputEndname): + self.acdcOutputEndname = acdcOutputEndname + + def getAcdcDf(self, images_path): + if self.acdcOutputEndname is None: + return + + for file in utils.listdir(images_path): + if not file.endswith(self.acdcOutputEndname): + continue + + filepath = os.path.join(images_path, file) + acdc_df = pd.read_csv(filepath, index_col=["frame_i", "Cell_ID"]) + return acdc_df + + def copyCcaColsFromAcdcDf(self, df, acdc_df, debug=False): + if acdc_df is None: + return df + + if debug: + printl(acdc_df.columns.to_list(), pretty=True) + + idx = df.index.intersection(acdc_df.index) + for col in cca_df_colnames: + if col not in acdc_df.columns: + continue + + if col not in self.selectedColumns: + continue + + df.loc[idx, col] = acdc_df.loc[idx, col] + + for col in lineage_tree_cols: + if col not in acdc_df.columns: + continue + + if col not in self.selectedColumns: + continue + + df.loc[idx, col] = acdc_df.loc[idx, col] + + for col in default_annot_df.keys(): + if col not in acdc_df.columns: + continue + + if col not in self.selectedColumns: + continue + + df.loc[idx, col] = acdc_df.loc[idx, col] + + for col in self.selectedColumns: + if col not in acdc_df.columns: + continue + + df.loc[idx, col] = acdc_df.loc[idx, col] + + if debug and col == "cell_vol_fl": + printl(df[[col]]) + + return df + + def emitAskFolderWhereToSaveMultiExp(self): + self.mutex.lock() + self.sigAskFolder.emit("") + self.waitCond.wait(self.mutex) + self.mutex.unlock() + if self.abort: + self.sigAborted.emit() + return + + return self.allExpSaveFolder + + def askSelectMeasurements(self, exp_path, posFoldernames): + acdc_dfs = [] + keys = [] + for p, pos in enumerate(posFoldernames): + if self.abort: + self.sigAborted.emit() + return False + + images_path = os.path.join(exp_path, pos, "Images") + acdc_df = self.getAcdcDf(images_path) + if acdc_df is None: + continue + + acdc_dfs.append(acdc_df) + keys.append(pos) + + if not acdc_dfs: + return True + + acdc_df_allpos = pd.concat( + acdc_dfs, keys=keys, names=["Position_n", "frame_i", "Cell_ID"] + ) + acdc_df_allpos["experiment_folderpath"] = exp_path + basename, chNames = utils.getBasenameAndChNames( + images_path, useExt=(".tif", ".h5") + ) + df_metadata = load.load_metadata_df(images_path) + SizeZ = df_metadata.at["SizeZ", "values"] + SizeZ = int(float(SizeZ)) + existing_colnames = acdc_df_allpos.columns + isSegm3D = any([col.endswith("3D") for col in existing_colnames]) + + kwargs = { + "loadedChNames": chNames, + "notLoadedChNames": [], + "isZstack": SizeZ > 1, + "isSegm3D": isSegm3D, + "existing_colnames": existing_colnames, + } + self.emitSetMeasurements(kwargs) + if self.abort: + self.sigAborted.emit() + return False + + return True + + @worker_exception_handler + def run(self): + from spotmax import DFs_FILENAMES, DF_REF_CH_FILENAME + from spotmax.utils import get_runs_num_and_desc + import spotmax.io + + self.selectedColumns = None + debugging = False + expPaths = self.mainWin.expPaths + tot_exp = len(expPaths) + self.signals.initProgressBar.emit(0) + spotmax_dfs_spots_allexp = defaultdict(lambda: defaultdict(list)) + spotmax_dfs_aggr_allexp = defaultdict(lambda: defaultdict(list)) + ref_ch_dfs_allexp = defaultdict(lambda: defaultdict(list)) + runNumberAlreadyAsked = False + copyFromCcaAlreadyAsked = False + for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): + self.errors = {} + tot_pos = len(pos_foldernames) + + all_runs = get_runs_num_and_desc(exp_path, pos_foldernames=pos_foldernames) + if not all_runs: + self.logger.log( + "[WARNING] The following experiment does not contain " + f'valid spotMAX output files. Skipping it. "{exp_path}"' + ) + continue + + if not runNumberAlreadyAsked: + abort = self.emitSelectSpotmaxRun( + exp_path, + pos_foldernames, + all_runs, + infoText=" to combine", + allowSingleSelection=True, + multiSelection=False, + ) + if abort: + self.sigAborted.emit() + return + runNumberAlreadyAsked = True + + selectedSpotmaxRuns = self.mainWin.selectedSpotmaxRuns + + self.signals.initProgressBar.emit(len(pos_foldernames)) + dfs_spots = defaultdict(list) + dfs_aggr = defaultdict(list) + dfs_ref_ch = defaultdict(list) + pos_runs = defaultdict(list) + pos_runs_ref_ch = defaultdict(list) + pos_ini_filepaths = {} + for p, pos in enumerate(pos_foldernames): + if self.abort: + self.sigAborted.emit() + return + + pos_path = os.path.join(exp_path, pos) + spotmax_output_path = os.path.join(pos_path, "spotMAX_output") + + if not os.path.exists(spotmax_output_path): + self.logger.log( + "[WARNING] The following Position folder does not contain " + f'valid spotMAX output files. Skipping it. "{pos_path}"' + ) + continue + + images_path = os.path.join(exp_path, pos, "Images") + + if not copyFromCcaAlreadyAsked: + self.emitAskCopyCca(images_path) + if self.abort: + self.sigAborted.emit() + return + + self.askSelectMeasurements(exp_path, pos_foldernames) + if self.abort: + return + copyFromCcaAlreadyAsked = True + + acdc_df = self.getAcdcDf(images_path) + + self.logger.log( + f"Processing experiment n. {i + 1}/{tot_exp}, " + f"{pos} ({p + 1}/{tot_pos})" + ) + + for run_desc in selectedSpotmaxRuns: + run, desc = run_desc.split("_...") + ini_filename = f"{run}_analysis_parameters{desc}.ini" + ini_filepath = os.path.join(spotmax_output_path, ini_filename) + if not os.path.exists(ini_filepath): + self.logger.log( + "[WARNING] The following Position folder does not contain " + f"the spotMAX output file for run number {run}. " + f'Skipping it. "{pos_path}"' + ) + continue + + pos_ini_filepaths[(run, desc)] = ini_filepath + for _, pattern_filename in DFs_FILENAMES.items(): + run_filename = pattern_filename.replace("*rn*", run) + run_filename = run_filename.replace("*desc*", desc) + aggr_filename = f"{run_filename}_aggregated.csv" + aggr_filepath = os.path.join(spotmax_output_path, aggr_filename) + if not os.path.exists(aggr_filepath): + continue + + df_spots_filename = f"{run_filename}.h5" + spots_filepath = os.path.join( + spotmax_output_path, df_spots_filename + ) + ext_spots = ".h5" + if not os.path.exists(spots_filepath): + df_spots_filename = f"{run_filename}.csv" + spots_filepath = os.path.join( + spotmax_output_path, df_spots_filename + ) + ext_spots = ".csv" + + if not os.path.exists(spots_filepath): + continue + + analysis_step = re.findall( + r"\*rn\*(.*)\*desc\*", pattern_filename + )[0] + key = (run, analysis_step, desc, ext_spots) + try: + df_spots = ( + spotmax.io.load_spots_table( + spotmax_output_path, df_spots_filename + ) + .reset_index() + .set_index(["frame_i", "Cell_ID"]) + ) + df_spots = self.copyCcaColsFromAcdcDf( + df_spots, acdc_df, debug=False + ) + df_spots = df_spots.reset_index().set_index( + ["frame_i", "Cell_ID", "spot_id"] + ) + dfs_spots[key].append(df_spots) + except Exception as err: + self.logger.log(str(err), level="ERROR") + self.logger.log( + "WARNING: Error when reading single-spots " + "tables (possibly because there are no spots). " + "Skipping this Position.", + level="WARNING", + ) + pass + + df_aggregated = pd.read_csv( + aggr_filepath, index_col=["frame_i", "Cell_ID"] + ) + df_aggregated = self.copyCcaColsFromAcdcDf( + df_aggregated, acdc_df + ) + dfs_aggr[key].append(df_aggregated) + pos_runs[key].append(pos) + + ref_ch_id_text = re.findall( + r"\*rn\*(.*)\*desc\*", DF_REF_CH_FILENAME + )[0] + ref_ch_filename = DF_REF_CH_FILENAME.replace("*rn*", run) + ref_ch_filename = ref_ch_filename.replace("*desc*", desc) + ref_ch_filepath = os.path.join(spotmax_output_path, ref_ch_filename) + if not os.path.exists(ref_ch_filepath): + continue + + df_ref_ch = pd.read_csv( + ref_ch_filepath, index_col=["frame_i", "Cell_ID"] + ) + df_ref_ch = self.copyCcaColsFromAcdcDf(df_ref_ch, acdc_df) + ref_ch_key = (run, ref_ch_id_text, desc) + dfs_ref_ch[ref_ch_key].append(df_ref_ch) + pos_runs_ref_ch[ref_ch_key].append(pos) + + self.signals.progressBar.emit(1) + + self.signals.initProgressBar.emit(0) + + self.logger.log("Saving concantenated files...") + + allpos_folderpath = os.path.join(exp_path, "spotMAX_multipos_output") + os.makedirs(allpos_folderpath, exist_ok=True) + + exp_name = os.path.basename(exp_path) + for key, dfs in dfs_spots.items(): + pos_keys = pos_runs[key] + run, analysis_step, desc, ext_spots = key + + if ext_spots == ".csv": + ext_spots = self._final_ext + filename = f"multipos_{run}{analysis_step}{desc}{ext_spots}" + all_exp_key = filename + df_spots_concat = spotmax.io.save_concat_dfs( + dfs, + pos_keys, + allpos_folderpath, + filename, + ext_spots, + names=["Position_n"], + return_concat_df=True, + ) + df_spots_concat["experiment_foldername"] = exp_name + df_spots_concat["experiment_folderpath"] = exp_path + spotmax_dfs_spots_allexp[all_exp_key]["dfs"].append(df_spots_concat) + spotmax_dfs_spots_allexp[all_exp_key]["keys"].append(exp_path) + ini_filepath = pos_ini_filepaths[(run, desc)] + ini_filename = os.path.basename(ini_filepath) + dst_ini_filepath = os.path.join(allpos_folderpath, ini_filename) + if not os.path.exists(dst_ini_filepath): + shutil.copy2(ini_filepath, dst_ini_filepath) + + spotmax_dfs_spots_allexp[all_exp_key]["ini_filepath"].append( + dst_ini_filepath + ) + + for key, dfs in dfs_aggr.items(): + pos_keys = pos_runs[key] + run, analysis_step, desc, _ = key + filename = ( + f"multipos_{run}{analysis_step}{desc}_aggregated{self._final_ext}" + ) + all_exp_aggr_key = filename + df_aggr_concat = spotmax.io.save_concat_dfs( + dfs, + pos_keys, + allpos_folderpath, + filename, + self._final_ext, + names=["Position_n"], + return_concat_df=True, + ) + spotmax_dfs_aggr_allexp[all_exp_aggr_key]["dfs"].append(df_aggr_concat) + spotmax_dfs_aggr_allexp[all_exp_aggr_key]["keys"].append( + (exp_path, exp_name) + ) + + for key, dfs in dfs_ref_ch.items(): + run, ref_ch_id_text, desc = key + pos_keys = pos_runs_ref_ch[key] + filename = f"multipos_{run}{ref_ch_id_text}{desc}{self._final_ext}" + all_exp_ref_ch_key = filename + df_ref_ch_concat = spotmax.io.save_concat_dfs( + dfs, + pos_keys, + allpos_folderpath, + filename, + self._final_ext, + names=["Position_n"], + return_concat_df=True, + ) + ref_ch_dfs_allexp[all_exp_ref_ch_key]["dfs"].append(df_ref_ch_concat) + ref_ch_dfs_allexp[all_exp_ref_ch_key]["keys"].append( + (exp_path, exp_name) + ) + + multiexp_dst_folderpath = "" + if len(expPaths) == 1: + self.signals.finished.emit(self) + return + + multiexp_dst_folderpath = self.emitAskFolderWhereToSaveMultiExp() + printl(multiexp_dst_folderpath) + if multiexp_dst_folderpath is None: + return + + self.logger.log( + f'Saving multi-experiment files to "{multiexp_dst_folderpath}"...' + ) + names = ["experiment_folderpath", "experiment_foldername"] + for filename, items in spotmax_dfs_spots_allexp.items(): + keys = items["keys"] + dfs = items["dfs"] + multiexp_filename = f"multiexp_{filename}" + extension = os.path.splitext(filename)[-1] + spotmax.io.save_concat_dfs( + dfs, + keys, + multiexp_dst_folderpath, + multiexp_filename, + extension, + names=["experiment_folderpath"], + ) + ini_filepath = items["ini_filepath"][0] + ini_filename = os.path.basename(ini_filepath) + dst_ini_filepath = os.path.join(multiexp_dst_folderpath, ini_filename) + if not os.path.exists(dst_ini_filepath): + shutil.copy2(ini_filepath, dst_ini_filepath) + + for filename, items in spotmax_dfs_aggr_allexp.items(): + keys = items["keys"] + dfs = items["dfs"] + printl(keys, pretty=True) + multiexp_filename = f"multiexp_{filename}" + extension = os.path.splitext(filename)[-1] + spotmax.io.save_concat_dfs( + dfs, + keys, + multiexp_dst_folderpath, + multiexp_filename, + extension, + names=names, + ) + + for filename, items in ref_ch_dfs_allexp.items(): + keys = items["keys"] + dfs = items["dfs"] + multiexp_filename = f"multiexp_{filename}" + extension = os.path.splitext(filename)[-1] + spotmax.io.save_concat_dfs( + dfs, + keys, + multiexp_dst_folderpath, + multiexp_filename, + extension, + names=names, + ) + + self.signals.finished.emit(self) + + +class CcaIntegrityCheckerWorker(QObject): + finished = Signal(object) + critical = Signal(object) + progress = Signal(str, object) + sigDone = Signal() + sigWarning = Signal(str, str) + sigFixWillDivide = Signal(str, list) + + def __init__(self, mutex, waitCond): + QObject.__init__(self) + self.logger = workerLogger(self.progress) + self.mutex = mutex + self.waitCond = waitCond + self.exit = False + self.isFinished = False + self.abortChecking = False + self.isChecking = False + self.isPaused = False + self.debug = False + self.dataQ = deque(maxlen=10) + + def pause(self): + if self.debug: + self.logger.log("Cell cycle annotations checker is idle.") + self.mutex.lock() + self.isPaused = True + self.waitCond.wait(self.mutex) + self.mutex.unlock() + self.isPaused = False + + def enqueue(self, posData): + # First stop previous checking + if self.isChecking: + self.abortChecking = True + self._enqueue(posData) + + def _enqueue(self, posData): + if self.debug: + self.logger.log("Enqueing posData...") + self.dataQ.append(posData) + if len(self.dataQ) == 1: + # Wake worker upon inserting first element + self.abortChecking = False + self.waitCond.wakeAll() + + def clearQueue(self): + self.dataQ.clear() + + def _stop(self): + self.exit = True + self.waitCond.wakeAll() + + def abort(self): + self.abortChecking = True + while not len(self.dataQ) == 0: + data = self.dataQ.pop() + del data + self._stop() + + def _check_equality_num_mothers_buds_in_S(self, checker, frame_i): + num_moth_S, num_buds = checker.get_num_mothers_and_buds_in_S() + + if num_moth_S == num_buds: + return True + + category = "number of buds different from number of mothers in S phase" + ul_items = [ + f"Number of buds = {num_buds}", + f"Number of mothers in S phase = {num_moth_S}", + ] + txt = html_utils.paragraph( + f"At frame n. {frame_i + 1} the number of buds and number of " + "mother cells in S phase are different!" + f"{html_utils.to_list(ul_items)}" + ) + self.sigWarning.emit(txt, category) + return False + + def _check_mothers_multiple_buds(self, checker, frame_i): + mother_IDs_with_multiple_buds = checker.get_mother_IDs_with_multiple_buds() + if len(mother_IDs_with_multiple_buds) == 0: + return True + + category = "mother cells with multiple buds" + txt = html_utils.paragraph( + f"At frame n. {frame_i + 1} " + "the following mother cells have multiple buds assigned to it" + f"

    {mother_IDs_with_multiple_buds}" + ) + self.sigWarning.emit(txt, category) + return False + + def _check_cells_without_G1(self, checker, global_cca_df): + IDs_cycles_without_G1 = checker.get_IDs_cycles_without_G1(global_cca_df) + if len(IDs_cycles_without_G1) == 0: + return True + + category = "cell cycles without G1" + txt = html_utils.paragraph( + "Cell-ACDC requires that every cell cycle has at least " + "one frame in G1.
    " + "The following pairs of (ID, generation number) " + "do not satisfy this condition:

    " + f"{IDs_cycles_without_G1}" + ) + self.sigWarning.emit(txt, category) + return False + + def _check_will_divide_is_true(self, checker, global_cca_df): + # NOTE: unfortunately this function performs pandas manipulations + # that are either not thread-safe or in any case are freezing the + # GUI. For now we don't run this until we find a solution + return True + + IDs_will_divide_wrong = checker.get_IDs_gen_num_will_divide_wrong(global_cca_df) + if len(IDs_will_divide_wrong) == 0: + return True + + txt = html_utils.paragraph( + "Cell-ACDC found that `will_divide` is annotated as True on the " + "following (ID, generation number) cell
    " + "despite the fact that division is still not annotated on " + "these cells

    :" + f"{IDs_will_divide_wrong}" + ) + self.sigFixWillDivide.emit(txt, IDs_will_divide_wrong) + return False + + def _check_buds_gen_num_zero(self, checker, frame_i): + bud_IDs_gen_num_nonzero = checker.get_bud_IDs_gen_num_nonzero() + if len(bud_IDs_gen_num_nonzero) == 0: + return True + + category = "buds whose generation number is not zero" + txt = html_utils.paragraph( + f"At frame n. {frame_i + 1} " + "the following bud IDs have generation number different from 0:" + f"

    {bud_IDs_gen_num_nonzero}" + ) + self.sigWarning.emit(txt, category) + return False + + def _check_mothers_gen_num_greater_one(self, checker, frame_i): + moth_IDs_gen_num_non_greater_one = ( + checker.get_moth_IDs_gen_num_non_greater_one() + ) + if len(moth_IDs_gen_num_non_greater_one) == 0: + return True + + category = "mothers whose generation number is < 1" + txt = html_utils.paragraph( + f"At frame n. {frame_i + 1} " + "the following mother cells have generation number < 1:" + f"

    {moth_IDs_gen_num_non_greater_one}" + ) + self.sigWarning.emit(txt, category) + return False + + def _check_buds_G1(self, checker, frame_i): + buds_G1 = checker.get_buds_G1() + if len(buds_G1) == 0: + return True + + category = "buds in G1" + txt = html_utils.paragraph( + f"At frame n. {frame_i + 1} " + "the following bud IDs are in G1 (buds must be in S):" + f"

    {buds_G1}" + ) + self.sigWarning.emit(txt, category) + return False + + def _check_cell_S_rel_ID_zero(self, checker, frame_i): + cell_S_rel_ID_zero = checker.get_cell_S_rel_ID_zero() + if len(cell_S_rel_ID_zero) == 0: + return True + + category = "buds in G1" + txt = html_utils.paragraph( + f"At frame n. {frame_i + 1} " + "the following cell IDs in S phase do not have " + "relative_ID > 0:" + f"

    {cell_S_rel_ID_zero}" + ) + self.sigWarning.emit(txt, category) + return False + + def _check_ID_rel_ID_mismatches(self, checker, frame_i): + ID_rel_ID_mismatches = checker.get_ID_rel_ID_mismatches() + if len(ID_rel_ID_mismatches) == 0: + return True + + items = [ + f"Cell ID {ID} has relative ID = {relID}, " + f"while cell ID {relID} has relative ID = {relID_of_relID}" + for ID, relID, relID_of_relID in ID_rel_ID_mismatches + ] + category = "`ID-relative_ID` mismatches" + txt = html_utils.paragraph( + f"At frame n. {frame_i + 1} " + "there are the following `ID-relative_ID` mismatches:" + f"{html_utils.to_list(items)}" + ) + self.sigWarning.emit(txt, category) + return False + + def _check_lonely_cells_in_S(self, checker, frame_i): + lonely_cells_in_S = checker.get_lonely_cells_in_S() + if len(lonely_cells_in_S) == 0: + return True + + category = "Lovely cells in S phase" + txt = html_utils.paragraph( + f"At frame n. {frame_i + 1} " + "the following cell IDs are in `S` phase but their `relative_ID` " + f"does not exist:

    " + f"{lonely_cells_in_S}" + ) + self.sigWarning.emit(txt, category) + return False + + def _get_cca_df_copy(self, acdc_df): + try: + cca_df = pd.DataFrame( + data=acdc_df[cca_df_colnames].values, + columns=cca_df_colnames, + index=acdc_df.index, + ) + return cca_df + except KeyError as error: + return + + def check(self, posData): + self.isChecking = True + checkpoints = ( + "_check_lonely_cells_in_S", + "_check_equality_num_mothers_buds_in_S", + "_check_mothers_multiple_buds", + "_check_buds_gen_num_zero", + "_check_mothers_gen_num_greater_one", + "_check_buds_G1", + "_check_cell_S_rel_ID_zero", + "_check_ID_rel_ID_mismatches", + ) + cca_dfs = [] + keys = [] + check_integrity_globally = True + for frame_i, data_dict in enumerate(posData.allData_li): + if self.abortChecking: + check_integrity_globally = False + break + + lab = data_dict["labels"] + if lab is None: + break + + cca_df = data_dict.get("cca_df_checker") + if cca_df is None: + # There are no annotations at frame_i --> stop + break + + IDs = data_dict["IDs"] + checker = core.CcaIntegrityChecker(cca_df, lab, IDs) + + for checkpoint in checkpoints: + proceed = getattr(self, checkpoint)(checker, frame_i) + if not proceed: + break + + if not proceed: + check_integrity_globally = False + break + + cca_dfs.append(cca_df) + keys.append(frame_i) + + if check_integrity_globally and len(cca_dfs) > 1: + global_checkpoints = [ + "_check_cells_without_G1", + # '_check_will_divide_is_true' + ] + # Check integrity globally + global_cca_df = pd.concat(cca_dfs, keys=keys, names=["frame_i"]) + for checkpoint in global_checkpoints: + proceed = getattr(self, checkpoint)(checker, global_cca_df) + if not proceed: + break + + self.abortChecking = False + self.isChecking = False + time.sleep(1) + + @worker_exception_handler + def run(self): + while True: + if self.exit: + self.logger.log("Closing cell cycle integrity checker worker...") + break + elif not len(self.dataQ) == 0: + if self.debug: + self.logger.log( + "Checking integrity of cell cycle annotations " + f"({len(self.dataQ)})..." + ) + data = self.dataQ.pop() + self.check(data) + if len(self.dataQ) == 0: + self.sigDone.emit() + else: + self.pause() + self.isFinished = True + self.finished.emit(self) + + +class GenerateMotherBudTotalTableWorker(BaseWorkerUtil): + def __init__( + self, parentWin, input_csv_filepath, selected_options, out_csv_filepath + ): + super().__init__(parentWin) + self.input_csv_filepath = input_csv_filepath + self.selected_options = selected_options + self.out_csv_filepath = out_csv_filepath + + @worker_exception_handler + def run(self): + self.logger.log(f'Loading table "{self.input_csv_filepath}"...') + self.signals.initProgressBar.emit(0) + + input_df = pd.read_csv(self.input_csv_filepath) + + self.logger.log("Generating output table...") + out_df = cca_functions.generate_mother_bud_total_df( + input_df, **self.selected_options + ) + + self.logger.log(f'Saving output table to "{self.out_csv_filepath}"...') + + out_df.to_csv(self.out_csv_filepath) + + self.signals.finished.emit(self) + + +class CountObjectsInSegm(BaseWorkerUtil): + sigAskAppendName = Signal(str, list) + sigAborted = Signal() + + def __init__(self, mainWin): + super().__init__(mainWin) + + @worker_exception_handler + def run(self): + debugging = False + expPaths = self.mainWin.expPaths + tot_exp = len(expPaths) + self.signals.initProgressBar.emit(0) + for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): + self.errors = {} + tot_pos = len(pos_foldernames) + + self.mainWin.infoText = f"Select segmentation file to count" + abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) + if abort: + self.sigAborted.emit() + return + + self.signals.initProgressBar.emit(len(pos_foldernames)) + for p, pos in enumerate(pos_foldernames): + if self.abort: + self.sigAborted.emit() + return + + self.logger.log( + f"Processing experiment n. {i + 1}/{tot_exp}, " + f"{pos} ({p + 1}/{tot_pos})" + ) + + images_path = os.path.join(exp_path, pos, "Images") + endFilenameSegm = self.mainWin.endFilenameSegm + ls = utils.listdir(images_path) + file_path = [ + os.path.join(images_path, f) + for f in ls + if f.endswith(f"{endFilenameSegm}.npz") + ][0] + + posData = load.loadData(file_path, "") + + self.signals.sigUpdatePbarDesc.emit(f"Processing {posData.pos_path}") + + posData.getBasenameAndChNames() + posData.buildPaths() + + posData.loadOtherFiles( + load_segm_data=True, + load_acdc_df=False, + load_metadata=True, + end_filename_segm=endFilenameSegm, + ) + if posData.segm_data.ndim == 3: + posData.segm_data = posData.segm_data[np.newaxis] + + self.logger.log("Counting objects...") + + countMapper = posData.countObjectsInSegm() + countMapper.pop("In current frame", None) + df_count_endname = posData.saveObjCounts(countMapper) + + self.logger.log( + "Saved object counts table to file ending with: " + f'"{df_count_endname}"' + ) + + self.signals.progressBar.emit(1) + + self.signals.finished.emit(self) + +# Sibling imports (deferred to avoid import cycles) +from ._base import ( + signals, + workerLogger, + worker_exception_handler, +) + diff --git a/cellacdc/workers/segm.py b/cellacdc/workers/segm.py new file mode 100644 index 000000000..0a7f08b53 --- /dev/null +++ b/cellacdc/workers/segm.py @@ -0,0 +1,894 @@ +"""Background Qt workers: segm.""" + +import re +import os +import shutil +import time +import json +import concurrent.futures +from functools import partial +from collections import defaultdict, deque +import itertools + +from typing import Union, List, Dict, Callable, Any, Tuple, Iterable + +from functools import wraps +import numpy as np +import pandas as pd +import h5py +import traceback + +import skimage.io +import skimage.measure +import skimage.exposure + +import queue + +from tqdm import tqdm + +from qtpy.QtCore import Signal, QObject, QMutex, QWaitCondition + +from cellacdc import html_utils + +from .. import load, utils, core, prompts, printl, config, segm_re_pattern, io +from .. import transformation, measurements, cca_functions +from ..path import copy_or_move_tree +from .. import features, plot +from .. import core +from .. import cca_df_colnames, lineage_tree_cols, default_annot_df +from .. import cca_df_colnames_with_tree +from .. import cli +from ..tools import resize +from .. import segm_utils + +DEBUG = False + +from ._base import ( + BaseWorkerUtil, +) + +class SegForLostIDsWorker(QObject): + sigAskInit = Signal() + sigAskInstallModel = Signal(str) + sigshowImageDebug = Signal(object) + sigStoreData = Signal(bool) + sigUpdateRP = Signal(bool, bool) + # sigGetData = Signal() + # sigGet2Dlab = Signal() + # sigGetTrackedLostIDs = Signal() + # sigGetBrushID = Signal() + sigSegForLostIDsWorkerAskInstallGPU = Signal(str, bool) + sigTrackManuallyAddedObject = Signal(object, object, bool, bool) + + def __init__(self, guiWin, mutex, waitCond, debug=False): + QObject.__init__(self) + self.signals = signals() + self.logger = workerLogger(self.signals.progress) + self.guiWin = guiWin + self.mutex = mutex + self.waitCond = waitCond + self._debug = debug + + def emitSigAskInit(self): + self.mutex.lock() + self.sigAskInit.emit() + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + def emitSigShowImageDebug(self, img): + # self.mutex.lock() + self.sigshowImageDebug.emit(img) + # self.waitCond.wait(self.mutex) + # self.mutex.unlock() + + def emitSigStoreData(self, autosave): + self.mutex.lock() + self.sigStoreData.emit(autosave) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + def emitSigUpdateRP(self, wl_track_og_curr, wl_update): + self.mutex.lock() + self.sigUpdateRP.emit(wl_track_og_curr, wl_update) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + # def emitSigGetData(self): + # self.mutex.lock() + # self.sigGetData.emit() + # self.waitCond.wait(self.mutex) + # self.mutex.unlock() + + def emitSigAskInstallModel(self, model_name): + self.mutex.lock() + self.sigAskInstallModel.emit(model_name) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + def emitSigAskInstallGPU(self, base_model_name, use_gpu): + self.mutex.lock() + self.sigSegForLostIDsWorkerAskInstallGPU.emit(base_model_name, use_gpu) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + # def emitGet2Dlab(self): + # self.mutex.lock() + # self.sigGet2Dlab.emit() + # self.waitCond.wait(self.mutex) + # self.mutex.unlock() + + # def emitGetTrackedLostIDs(self): + # self.mutex.lock() + # self.sigGetTrackedLostIDs.emit() + # self.waitCond.wait(self.mutex) + # self.mutex.unlock() + + # def emitGetBrushID(self): + # self.mutex.lock() + # self.sigGetBrushID.emit() + # self.waitCond.wait(self.mutex) + # self.mutex.unlock() + + def emitTrackManuallyAddedObject(self, IDs, isLost, wl_update, wl_track_og_curr): + self.mutex.lock() + self.sigTrackManuallyAddedObject.emit(IDs, isLost, wl_update, wl_track_og_curr) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + @worker_exception_handler + def run(self): + posData = self.guiWin.data[self.guiWin.pos_i] + frame_i = posData.frame_i + + if not self.guiWin.SegForLostIDsSettings: + self.emitSigAskInit() + + if not self.guiWin.SegForLostIDsSettings: + self.signals.finished.emit(self) + return + + self.logger.info("Segmentation for lost IDs started.") + model_name = "local_seg" + base_model_name = self.guiWin.SegForLostIDsSettings["base_model_name"] + idx = self.guiWin.modelNames.index(model_name) + acdcSegment = self.guiWin.acdcSegment_li[idx] + + init_kwargs = self.guiWin.SegForLostIDsSettings["win"].init_kwargs + + use_gpu = init_kwargs.get("device_type", "cpu") != "cpu" + use_gpu = use_gpu or init_kwargs.get("use_gpu", False) + + self.emitSigAskInstallGPU(base_model_name, use_gpu) + + if not self.gpu_go: + self.signals.finished.emit(self) + return + + if not self.dont_force_cpu: + if "device" in init_kwargs: + init_kwargs["device"] = "cpu" + if "use_gpu" in init_kwargs: + init_kwargs["use_gpu"] = False + + if ( + acdcSegment is None + or base_model_name != self.guiWin.local_seg_base_model_name + ): + try: + self.logger.info(f"Importing {base_model_name}...") + self.emitSigAskInstallModel(base_model_name) + acdcSegment = utils.import_segment_module(base_model_name) + self.guiWin.acdcSegment_li[idx] = acdcSegment + self.guiWin.local_seg_base_model_name = base_model_name + except (IndexError, ImportError, KeyError) as e: + self.logger.warning( + f"Cannot import {base_model_name} model. Please install it first." + ) + self.signals.critical.emit( + ( + self, + f"Cannot import {base_model_name} model. " + "Please install it first.", + ) + ) + self.signals.finished.emit(self) + return + + win = self.guiWin.SegForLostIDsSettings["win"] + init_kwargs_new = self.guiWin.SegForLostIDsSettings["init_kwargs_new"] + args_new = self.guiWin.SegForLostIDsSettings["args_new"] + + model = utils.init_segm_model(acdcSegment, posData, init_kwargs_new) + if model is None: + self.logger.info("Segmentation model was not initialized correctly!") + self.signals.critical.emit( + (self, "Segmentation model was not initialized correctly!") + ) + self.signals.finished.emit(self) + return + if self._debug: + try: + model.setupLogger(self.guiwin.logger) + except Exception as e: + pass + + assigned_IDs = [] + missing_IDs_global = set() + original_lab = posData.lab.copy() + IDs_bboxs_list = [] + bboxs_list = [] + + curr_img = self.guiWin.getDisplayedImg1() + prev_lab = self.guiWin.get_2Dlab(posData.allData_li[frame_i - 1]["labels"]) + prev_IDs = set(posData.allData_li[frame_i - 1]["IDs"]) + + # should probably not paly so much with posData.lab, instead handle stuff myself + self.signals.initProgressBar.emit(2 * args_new["max_iterations"]) + new_labs = np.zeros( + [args_new["max_iterations"], *posData.lab.shape], dtype=np.uint32 + ) + for i in range(args_new["max_iterations"]): + curr_lab = self.guiWin.get_2Dlab(posData.lab) + tracked_lost_IDs = self.guiWin.getTrackedLostIDs() + new_unique_ID = self.guiWin.setBrushID(useCurrentLab=True, return_val=True) + + missing_IDs = prev_IDs - set(posData.IDs) - set(tracked_lost_IDs) + missing_IDs_global.update(missing_IDs) + + assigned_IDs_prev = assigned_IDs.copy() + out = segm_utils.single_cell_seg( + model, + prev_lab, + curr_lab, + curr_img, + missing_IDs, + new_unique_ID, + win, + posData, + distance_filler_growth=args_new["distance_filler_growth"], + overlap_threshold=args_new["overlap_threshold"], + padding=args_new["padding"], + ) + new_lab, assigned_IDs, IDs_bboxs, bboxs = out + + IDs_bboxs_list.append(IDs_bboxs) + bboxs_list.append(bboxs) + posData.lab = new_lab + self.emitSigUpdateRP(wl_update=True, wl_track_og_curr=False) + newly_assigned_IDs = set(assigned_IDs) - set(assigned_IDs_prev) + self.emitTrackManuallyAddedObject(newly_assigned_IDs, True, False, False) + new_labs[i] = posData.lab.copy() + self.signals.progressBar.emit(1) + + if self._debug: + originals = [] + models = [] + + posData.lab = original_lab.copy() + + global_area_mean = np.mean([obj.area for obj in posData.rp]) + for IDs_bboxs, bboxs in zip(IDs_bboxs_list, bboxs_list): + model_lab = new_labs[i] + if self._debug: + originals.append(original_lab.copy()) + models.append(posData.lab.copy()) + + for IDs, bbox in zip(IDs_bboxs, bboxs): + box_x_min, box_x_max, box_y_min, box_y_max = bbox + original_bbox_lab = original_lab[ + box_x_min:box_x_max, box_y_min:box_y_max + ] + original_bbox_lab_cleared_borders = skimage.segmentation.clear_border( + original_bbox_lab + ) + box_model_lab = model_lab[box_x_min:box_x_max, box_y_min:box_y_max] + + # original_bbox_lab[np.isin(original_bbox_lab, IDs)] = 0 should be a given. If not seg for lost IDs this recommended + + box_model_lab = skimage.segmentation.clear_border( + box_model_lab, buffer_size=1 + ) + + rp_model_lab = skimage.measure.regionprops(box_model_lab) + rp_original_lab = skimage.measure.regionprops(original_bbox_lab) + rp_original_lab_cleared = skimage.measure.regionprops( + original_bbox_lab_cleared_borders + ) + + original_IDs = [obj.label for obj in rp_original_lab] + areas = [obj.area for obj in rp_original_lab_cleared] + if len(areas) > 0: + area_mean = np.mean(areas) + else: + area_mean = global_area_mean + if args_new["allow_only_tracked_cells"]: + filtered_IDs = [ + obj.label + for obj in rp_model_lab + if obj.area > (1 - args_new["size_perc_diff"]) * area_mean + and obj.area < (1 + args_new["size_perc_diff"]) * area_mean + and obj.label not in original_IDs + and obj.label in missing_IDs_global + ] + else: + filtered_IDs = [ + obj.label + for obj in rp_model_lab + if obj.area > (1 - args_new["size_perc_diff"]) * area_mean + and obj.area < (1 + args_new["size_perc_diff"]) * area_mean + and obj.label not in original_IDs + ] + + if self._debug or DEBUG: + filtered_sizes = [ + (obj.label, obj.area) + for obj in rp_model_lab + if obj.label in filtered_IDs + ] + self.logger.info(f"Filtered sizes: {filtered_sizes}") + for label in filtered_IDs: + original_bbox_lab[box_model_lab == label] = ( + label # here the stuff should be tracked, so we keep the ID! + ) + + # original_lab[box_x_min:box_x_max, box_y_min:box_y_max] = original_bbox_lab + + self.signals.progressBar.emit(1) + + posData.lab = original_lab + + # if self._debug: + # originals = np.concatenate(originals, axis=0) + # models = np.concatenate(models, axis=0) + # self.emitSigShowImageDebug(originals) + # self.emitSigShowImageDebug(models) + + self.emitSigUpdateRP(wl_track_og_curr=True, wl_update=True) + self.emitSigStoreData(autosave=True) + + self.logger.info("Segmentation for lost IDs done.") + + self.signals.finished.emit(self) + + +class LabelRoiWorker(QObject): + finished = Signal() + critical = Signal(object) + progress = Signal(str, object) + sigProgressBar = Signal(int) + sigLabellingDone = Signal(object, bool) + + def __init__(self, Gui): + QObject.__init__(self) + self.logger = workerLogger(self.progress) + self.Gui = Gui + self.mutex = Gui.labelRoiMutex + self.waitCond = Gui.labelRoiWaitCond + self.exit = False + self.started = False + + def pause(self): + self.logger.log("Draw box around object to start magic labeller.") + self.mutex.lock() + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + def start(self, roiImg, posData, roiSecondChannel=None, isTimelapse=False): + self.posData = posData + self.isTimelapse = isTimelapse + self.imageData = roiImg + self.roiSecondChannel = roiSecondChannel + self.restart() + + def restart(self, log=True): + if log: + self.logger.log("Magic labeller started...") + self.started = True + self.waitCond.wakeAll() + + def _stop(self): + self.logger.log("Magic labeller backend process done. Closing it...") + self.exit = True + self.waitCond.wakeAll() + + def _segment_image(self, img, secondChannelImg): + if secondChannelImg is not None: + img = self.Gui.labelRoiModel.second_ch_img_to_stack(img, secondChannelImg) + + lab = core.segm_model_segment( + self.Gui.labelRoiModel, + img, + self.Gui.model_kwargs, + preproc_recipe=self.Gui.preproc_recipe, + posData=self.posData, + ) + if self.Gui.applyPostProcessing: + from cellacdc.workflow.pipelines.postprocess_nodes import apply_postprocess + + lab = apply_postprocess( + lab, + img, + self.posData, + self.posData.frame_i, + apply_postprocessing=True, + standard_postprocess_kwargs=self.Gui.standardPostProcessKwargs, + custom_postprocess_features=self.Gui.customPostProcessFeatures, + custom_postprocess_grouped_features=self.Gui.customPostProcessGroupedFeatures, + ) + return lab + + @worker_exception_handler + def run(self): + while not self.exit: + if self.exit: + break + elif self.started: + self.logger.log("Magic labeller is doing its magic...") + if self.isTimelapse: + segmData = np.zeros(self.imageData.shape, dtype=np.uint32) + for frame_i, img in enumerate(self.imageData): + if self.roiSecondChannel is not None: + secondChannelImg = self.roiSecondChannel[frame_i] + else: + secondChannelImg = None + lab = self._segment_image(img, secondChannelImg) + segmData[frame_i] = lab + self.sigProgressBar.emit(1) + else: + img = self.imageData + secondChannelImg = self.roiSecondChannel + segmData = self._segment_image(img, secondChannelImg) + + self.sigLabellingDone.emit(segmData, self.isTimelapse) + self.started = False + self.pause() + self.finished.emit() + + +class segmWorker(QObject): + finished = Signal(np.ndarray, float) + debug = Signal(object) + critical = Signal(object) + + def __init__( + self, + mainWin, + secondChannelData=None, + mutex: QWaitCondition = None, + waitCond: QMutex = None, + ): + QObject.__init__(self) + self.mainWin = mainWin + self.logger = self.mainWin.logger + self.z_range = None + self.secondChannelData = secondChannelData + self.mutex = mutex + self.waitCond = waitCond + + def emitDebug(self, to_debug): + if self.mutex is None: + return + + self.mutex.lock() + self.debug.emit(to_debug) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + + @worker_exception_handler + def run(self): + from cellacdc.workflow.adapters import ( + interactive_segm_context_from_main_win, + runnable_config_from_main_win, + ) + from cellacdc.workflow.pipelines.interactive_segm import ( + build_interactive_segm_graph, + ) + from cellacdc.workflow.state import InteractiveSegmState + + t0 = time.perf_counter() + ctx = interactive_segm_context_from_main_win( + self.mainWin, + second_channel_data=self.secondChannelData, + z_range=self.z_range, + ) + graph = build_interactive_segm_graph(ctx).compile() + state = graph.invoke( + InteractiveSegmState(main_win=self.mainWin), + runnable_config_from_main_win(self.mainWin), + ) + t1 = time.perf_counter() + self.finished.emit(state.lab, t1 - t0) + + +class segmVideoWorker(QObject): + finished = Signal(float) + debug = Signal(object) + critical = Signal(object) + progressBar = Signal(int) + progress = Signal(str, object) + + def __init__(self, posData, paramWin, model, startFrameNum, stopFrameNum): + QObject.__init__(self) + self.standardPostProcessKwargs = paramWin.standardPostProcessKwargs + self.applyPostProcessing = paramWin.applyPostProcessing + self.customPostProcessFeatures = paramWin.customPostProcessFeatures + self.customPostProcessGroupedFeatures = ( + paramWin.customPostProcessGroupedFeatures + ) + self.model_kwargs = paramWin.model_kwargs + self.preproc_recipe = paramWin.preproc_recipe + self.secondChannelName = paramWin.secondChannelName + self.model = model + self.posData = posData + self.startFrameNum = startFrameNum + self.stopFrameNum = stopFrameNum + self.logger = workerLogger(self.progress) + + @worker_exception_handler + def run(self): + from cellacdc.workflow.adapters import interactive_video_segm_context_from_worker + from cellacdc.workflow.pipelines.interactive_video_segm import ( + build_interactive_video_segm_graph, + ) + from cellacdc.workflow.state import InteractiveVideoSegmState + + t0 = time.perf_counter() + ctx = interactive_video_segm_context_from_worker(self) + graph = build_interactive_video_segm_graph(ctx).compile() + graph.invoke( + InteractiveVideoSegmState(pos_data=self.posData), + ) + t1 = time.perf_counter() + self.finished.emit(t1 - t0) + + +class PostProcessSegmWorker(QObject): + def __init__( + self, + postProcessKwargs, + customPostProcessGroupedFeatures, + customPostProcessFeatures, + mainWin, + ): + super().__init__() + self.signals = signals() + self.logger = workerLogger(self.signals.progress) + self.kwargs = postProcessKwargs + self.customPostProcessGroupedFeatures = customPostProcessGroupedFeatures + self.customPostProcessFeatures = customPostProcessFeatures + self.mainWin = mainWin + + @worker_exception_handler + def run(self): + mainWin = self.mainWin + data = mainWin.data + posData = data[mainWin.pos_i] + if len(data) > 1: + self.signals.initProgressBar.emit(len(data)) + else: + current_frame_i = posData.frame_i + self.signals.initProgressBar.emit(posData.SizeT - current_frame_i) + + self.logger.log("Post-process segmentation process started.") + self._run() + self.signals.finished.emit(None) + + def _run(self): + kwargs = self.kwargs + mainWin = self.mainWin + data = mainWin.data + + for posData in data: + current_frame_i = posData.frame_i + data_li = posData.allData_li[current_frame_i:] + for i, data_dict in enumerate(data_li): + frame_i = current_frame_i + i + visited = True + lab = data_dict["labels"] + if lab is None: + visited = False + try: + lab = posData.segm_data[frame_i] + except Exception as e: + return + + image = posData.img_data[frame_i] + + processed_lab = core.post_process_segm( + lab, return_delIDs=False, **kwargs + ) + if self.customPostProcessFeatures: + processed_lab = features.custom_post_process_segm( + posData, + self.customPostProcessGroupedFeatures, + processed_lab, + image, + posData.frame_i, + posData.filename, + posData.user_ch_name, + self.customPostProcessFeatures, + ) + if visited: + posData.allData_li[frame_i]["labels"] = processed_lab + # Get the rest of the stored metadata based on the new lab + posData.frame_i = frame_i + mainWin.get_data() + mainWin.store_data(autosave=False) + else: + posData.segm_data[frame_i] = lab + + self.signals.progressBar.emit(1) + + posData.frame_i = current_frame_i + + +class CreateConnected3Dsegm(BaseWorkerUtil): + sigAskAppendName = Signal(str, list) + sigAborted = Signal() + + def __init__(self, mainWin): + super().__init__(mainWin) + + def criticalSegmIsNot3D(self): + raise TypeError( + "Input segmentation masks are not 3D. You can use this utility " + "only on 3D z-stack data or 4D z-stack over time data." + ) + + @worker_exception_handler + def run(self): + debugging = False + expPaths = self.mainWin.expPaths + tot_exp = len(expPaths) + self.signals.initProgressBar.emit(0) + for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): + self.errors = {} + tot_pos = len(pos_foldernames) + + self.mainWin.infoText = f"Select 3D segmentation file to connect" + abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) + if abort: + self.sigAborted.emit() + return + + # Ask appendend name + self.mutex.lock() + self.sigAskAppendName.emit( + self.mainWin.endFilenameSegm, self.mainWin.existingSegmEndNames + ) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + if self.abort: + self.sigAborted.emit() + return + + appendedName = self.appendedName + self.signals.initProgressBar.emit(len(pos_foldernames)) + for p, pos in enumerate(pos_foldernames): + if self.abort: + self.sigAborted.emit() + return + + self.logger.log( + f"Processing experiment n. {i + 1}/{tot_exp}, " + f"{pos} ({p + 1}/{tot_pos})" + ) + + images_path = os.path.join(exp_path, pos, "Images") + endFilenameSegm = self.mainWin.endFilenameSegm + ls = utils.listdir(images_path) + file_path = [ + os.path.join(images_path, f) + for f in ls + if f.endswith(f"{endFilenameSegm}.npz") + ][0] + + posData = load.loadData(file_path, "") + + self.signals.sigUpdatePbarDesc.emit(f"Processing {posData.pos_path}") + + posData.getBasenameAndChNames() + posData.buildPaths() + + posData.loadOtherFiles( + load_segm_data=True, + load_acdc_df=True, + load_metadata=True, + end_filename_segm=endFilenameSegm, + ) + if posData.segm_data.ndim == 3: + posData.segm_data = posData.segm_data[np.newaxis] + + self.logger.log("Connecting 3D objects...") + + numFrames = len(posData.segm_data) + self.signals.sigInitInnerPbar.emit(numFrames) + connectedSegmData = np.zeros_like(posData.segm_data) + for frame_i, lab in enumerate(posData.segm_data): + if lab.ndim != 3: + self.criticalSegmIsNot3D() + + connected_lab = core.connect_3Dlab_zboundaries(lab) + connectedSegmData[frame_i] = connected_lab + + self.signals.sigUpdateInnerPbar.emit(1) + + self.logger.log("Saving connected 3D segmentation file...") + segmFilename, ext = os.path.splitext(posData.segm_npz_path) + newSegmFilepath = f"{segmFilename}_{appendedName}.npz" + connectedSegmData = np.squeeze(connectedSegmData) + io.savez_compressed(newSegmFilepath, connectedSegmData) + + self.signals.progressBar.emit(1) + + self.signals.finished.emit(self) + + +class DelObjectsOutsideSegmROIWorker(QObject): + finished = Signal(object) + critical = Signal(object) + progress = Signal(str) + debug = Signal(object) + + def __init__( + self, + segm_roi_endname: os.PathLike, + segm_data: np.ndarray, + images_path: os.PathLike, + ): + QObject.__init__(self) + self.signals = signals() + self.segm_roi_endname = segm_roi_endname + self.segm_data = segm_data + self.images_path = images_path + + @worker_exception_handler + def run(self): + segm_roi_endname = self.segm_roi_endname + segm_roi_filepath, _ = load.get_path_from_endname( + segm_roi_endname, self.images_path + ) + self.progress.emit(f'Loading segmentation file "{segm_roi_filepath}"...') + segm_roi_data = load.load_image_file(segm_roi_filepath) + + self.progress.emit(f"Deleting objects outside of selected ROIs...") + cleared_segm_data, delIDs = transformation.del_objs_outside_segm_roi( + segm_roi_data, self.segm_data + ) + + self.finished.emit((self, cleared_segm_data, delIDs)) + + +class MagicPromptsWorker(QObject): + def __init__( + self, + posData, + image, + df_points, + model, + model_segment_kwargs, + image_origin=(0, 0, 0), + global_image=None, + ): + QObject.__init__(self) + + self.signals = signals() + self.posData = posData + self.image = image + if global_image is not None: + self.global_image = global_image + else: + self.global_image = image + self.df_points = df_points + self.image_origin = image_origin + self.model = model + self.model_segment_kwargs = model_segment_kwargs + + @worker_exception_handler + def run(self): + from cellacdc.segmenters_promptable import utils + + for row in self.df_points.itertuples(): + prompt_id = row.id + point = (row.z, row.y, row.x) + print(f"Adding point prompt {point} with id = {prompt_id}...") + parent_obj_id = row.Cell_ID if row.Cell_ID == prompt_id else 0 + self.model.add_prompt( + prompt=point, + prompt_id=prompt_id, + parent_obj_id=parent_obj_id, + image=self.image, + image_origin=self.image_origin, + prompt_type="point", + ) + + lab_out = self.model.segment( + self.global_image, lab=self.posData.lab, **self.model_segment_kwargs + ) + edited_IDs = self.df_points["Cell_ID"].unique() + + lab_new, lab_union, lab_interesection = utils.insert_model_output_into_labels( + self.posData.lab, lab_out, edited_IDs=edited_IDs + ) + + self.signals.finished.emit((lab_new, lab_union, lab_interesection)) + + +class FillHolesInSegWorker(BaseWorkerUtil): + sigAskAppendName = Signal(str) + sigAborted = Signal() + sigSelectSegmFiles = Signal(str, list) + + def __init__(self, mainWin): + super().__init__(mainWin) + + def emitSelectSegmFiles(self, exp_path, pos_foldernames): + self.mutex.lock() + self.sigSelectSegmFiles.emit(exp_path, pos_foldernames) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.abort + + def emitAskAppendName(self, basename): + self.mutex.lock() + self.sigAskAppendName.emit(basename) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.abort + + @worker_exception_handler + def run(self): + expPaths = self.mainWin.expPaths + lab_paths_dict = dict() + unique_segm_files = set() + tot_segm_files = 0 + for exp_path, pos_foldernames in expPaths.items(): + abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) + if abort: + self.sigAborted.emit() + return + for pos_folder in pos_foldernames: + imgs_path = os.path.join(exp_path, pos_folder, "Images") + lab_paths_dict[imgs_path] = self.endFilenameSegmTemp + tot_segm_files += len(self.endFilenameSegmTemp) + unique_segm_files.update(self.endFilenameSegmTemp) + + self.logger.info("Filling holes in segmentation masks...") + abort = self.emitAskAppendName("/".join(unique_segm_files)) + if abort: + self.sigAborted.emit() + return + self.signals.initProgressBar.emit(tot_segm_files) + for images_path, segm_file_names in lab_paths_dict.items(): + for segm_file_name in segm_file_names: + segm_data, segm_data_path = load.load_segm_file( + images_path, end_name_segm_file=segm_file_name, return_path=True + ) + segm_data_shape = segm_data.shape + segm_data_ndim = len(segm_data_shape) + if segm_data_ndim == 2: + segm_data = segm_data[np.newaxis, np.newaxis, ...] + elif segm_data_ndim == 3: + segm_data = segm_data[np.newaxis, ...] + elif segm_data_ndim == 4: + segm_data = segm_data + else: + raise NotImplementedError("This ndim is not supported!") + for i, stack in enumerate(segm_data): + for j, lab in enumerate(stack): + segm_data[i, j] = core.fill_holes_in_segmentation(lab) + + segm_data_save_path = segm_data_path.replace( + segm_file_name, f"{segm_file_name}{self.appendedName}" + ) + io.savez_compressed(segm_data_save_path, segm_data) + self.signals.progressBar.emit(1) + self.signals.finished.emit(self) + +# Sibling imports (deferred to avoid import cycles) +from ._base import ( + signals, + workerLogger, + worker_exception_handler, +) + diff --git a/cellacdc/workers/tracking.py b/cellacdc/workers/tracking.py new file mode 100644 index 000000000..9db3c602b --- /dev/null +++ b/cellacdc/workers/tracking.py @@ -0,0 +1,778 @@ +"""Background Qt workers: tracking.""" + +import re +import os +import shutil +import time +import json +import concurrent.futures +from functools import partial +from collections import defaultdict, deque +import itertools + +from typing import Union, List, Dict, Callable, Any, Tuple, Iterable + +from functools import wraps +import numpy as np +import pandas as pd +import h5py +import traceback + +import skimage.io +import skimage.measure +import skimage.exposure + +import queue + +from tqdm import tqdm + +from qtpy.QtCore import Signal, QObject, QMutex, QWaitCondition + +from cellacdc import html_utils + +from .. import load, utils, core, prompts, printl, config, segm_re_pattern, io +from .. import transformation, measurements, cca_functions +from ..path import copy_or_move_tree +from .. import features, plot +from .. import core +from .. import cca_df_colnames, lineage_tree_cols, default_annot_df +from .. import cca_df_colnames_with_tree +from .. import cli +from ..tools import resize +from .. import segm_utils + +DEBUG = False + +from ._base import ( + BaseWorkerUtil, +) + +class trackingWorker(QObject): + finished = Signal() + critical = Signal(object) + progress = Signal(str) + debug = Signal(object) + + def __init__(self, posData, mainWin, video_to_track): + QObject.__init__(self) + self.mainWin = mainWin + self.posData = posData + self.mutex = QMutex() + self.signals = signals() + self.waitCond = QWaitCondition() + self.tracker = self.mainWin.tracker + self.track_params = self.mainWin.track_params + self.video_to_track = video_to_track + + def _get_first_untracked_lab(self): + start_frame_i = self.mainWin.start_n - 1 + frameData = self.posData.allData_li[start_frame_i] + lab = frameData["labels"] + if lab is not None: + return lab + else: + return self.posData.segm_data[start_frame_i] + + def _relabel_first_frame_labels(self, tracked_video): + first_untracked_lab = self._get_first_untracked_lab() + self.mainWin.setAllIDs() + max_allIDs = max(self.posData.allIDs, default=0) + max_tracked_video = tracked_video.max() + overall_max = max(max_allIDs, max_tracked_video) + uniqueID = overall_max + 1 + + tracked_video = transformation.retrack_based_on_untracked_first_frame( + tracked_video, first_untracked_lab, uniqueID=uniqueID + ) + return tracked_video + + def _setProgressBarIndefiniteWait(self): + try: + if hasattr(self.signals, "innerPbar_available"): + if self.signals.innerPbar_available: + # Use inner pbar of the GUI widget (top pbar is for positions) + self.signals.sigInitInnerPbar.emit(1) + return + else: + self.signals.initProgressBar.emit(1) + except Exception as err: + pass + + @worker_exception_handler + def run(self): + self.mutex.lock() + self.progress.emit("Tracking process started (more details in the terminal)...") + + trackerInputImage = None + self.track_params["signals"] = self.signals + if "image" in self.track_params: + trackerInputImage = self.track_params.pop("image") + start_frame_i = self.mainWin.start_n - 1 + stop_frame_n = self.mainWin.stop_n + + trackerInputImage = trackerInputImage[start_frame_i:stop_frame_n] + + tracked_video = core.tracker_track( + self.video_to_track, + self.tracker, + self.track_params, + intensity_img=trackerInputImage, + logger_func=self.progress.emit, + ) + + self._setProgressBarIndefiniteWait() + + # self.debug.emit((tracked_video, self)) + # self.waitCond.wait(self.mutex) + + self.progress.emit("Re-tracking first frame to ensure continuity...") + # Relabel first frame objects back to IDs they had before tracking + # (to ensure continuity with past untracked frames) + tracked_video = self._relabel_first_frame_labels(tracked_video) + + print("") + self.progress.emit("Generating annotations...") + acdc_df = self.posData.fromTrackerToAcdcDf( + self.tracker, tracked_video, start_frame_i=self.mainWin.start_n - 1 + ) + # Store new tracked video + current_frame_i = self.posData.frame_i + self.trackingOnNeverVisitedFrames = False + print("") + self.progress.emit("Storing tracked video...") + pbar = tqdm(total=len(tracked_video), ncols=100) + for rel_frame_i, lab in enumerate(tracked_video): + frame_i = rel_frame_i + self.mainWin.start_n - 1 + + if acdc_df is not None: + cca_cols = acdc_df.columns.intersection(cca_df_colnames_with_tree) + # Store cca_df if it is an output of the tracker + cca_df = acdc_df.loc[frame_i][cca_cols] + self.mainWin.store_cca_df( + frame_i=frame_i, cca_df=cca_df, mainThread=False, autosave=False + ) + + if self.posData.allData_li[frame_i]["labels"] is None: + # repeating tracking on a never visited frame + # --> modify only raw data and ask later what to do + self.posData.segm_data[frame_i] = lab + self.trackingOnNeverVisitedFrames = True + else: + # Get the rest of the stored metadata based on the new lab + self.posData.allData_li[frame_i]["labels"] = lab + self.posData.frame_i = frame_i + self.mainWin.get_data() + self.mainWin.store_data(autosave=False) + + pbar.update() + pbar.close() + + # Back to current frame + self.posData.frame_i = current_frame_i + self.mainWin.get_data() + self.mainWin.store_data(autosave=True) + self.mutex.unlock() + self.finished.emit() + + +class TrackSubCellObjectsWorker(BaseWorkerUtil): + sigAskAppendName = Signal(str, list) + sigCriticalNotEnoughSegmFiles = Signal(str) + sigAborted = Signal() + + def __init__(self, mainWin): + super().__init__(mainWin) + if mainWin.trackingMode.find("Delete both") != -1: + self.trackingMode = "delete_both" + elif mainWin.trackingMode.find("Delete sub-cellular") != -1: + self.trackingMode = "delete_sub" + elif mainWin.trackingMode.find("Delete cells") != -1: + self.trackingMode = "delete_cells" + elif mainWin.trackingMode.find("Only track") != -1: + self.trackingMode = "only_track" + + self.relabelSubObjLab = mainWin.relabelSubObjLab + self.IoAthresh = mainWin.IoAthresh + self.createThirdSegm = mainWin.createThirdSegm + self.thirdSegmAppendedText = mainWin.thirdSegmAppendedText + + @worker_exception_handler + def run(self): + debugging = False + expPaths = self.mainWin.expPaths + tot_exp = len(expPaths) + self.signals.initProgressBar.emit(0) + for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): + self.errors = {} + tot_pos = len(pos_foldernames) + + red_text = html_utils.span("OF THE CELLs") + self.mainWin.infoText = f"Select segmentation file {red_text}" + abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) + if abort: + self.sigAborted.emit() + return + + # Critical --> there are not enough segm files + if len(self.mainWin.existingSegmEndNames) < 2: + self.mutex.lock() + self.sigCriticalNotEnoughSegmFiles.emit(exp_path) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + self.sigAborted.emit() + return + + self.cellsSegmEndFilename = self.mainWin.endFilenameSegm + + red_text = html_utils.span("OF THE SUB-CELLULAR OBJECTS") + self.mainWin.infoText = f"Select segmentation file {red_text}" + abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) + if abort: + self.sigAborted.emit() + return + + # Ask appendend name + self.mutex.lock() + self.sigAskAppendName.emit( + self.mainWin.endFilenameSegm, self.mainWin.existingSegmEndNames + ) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + if self.abort: + self.sigAborted.emit() + return + + appendedName = self.appendedName + self.signals.initProgressBar.emit(len(pos_foldernames)) + for p, pos in enumerate(pos_foldernames): + if self.abort: + self.sigAborted.emit() + return + + self.logger.log( + f"Processing experiment n. {i + 1}/{tot_exp}, " + f"{pos} ({p + 1}/{tot_pos})" + ) + + images_path = os.path.join(exp_path, pos, "Images") + endFilenameSegm = self.mainWin.endFilenameSegm + ls = utils.listdir(images_path) + file_path = [ + os.path.join(images_path, f) + for f in ls + if f.endswith(f"{endFilenameSegm}.npz") + ][0] + + posData = load.loadData(file_path, "") + + self.signals.sigUpdatePbarDesc.emit(f"Processing {posData.pos_path}") + + posData.getBasenameAndChNames() + posData.buildPaths() + + posData.loadOtherFiles( + load_segm_data=True, + load_acdc_df=True, + load_metadata=True, + end_filename_segm=endFilenameSegm, + ) + + # Load cells segmentation file + segmDataCells, segmCellsPath = load.load_segm_file( + images_path, + end_name_segm_file=self.cellsSegmEndFilename, + return_path=True, + ) + acdc_df_cells_endname = self.cellsSegmEndFilename.replace( + "_segm", "_acdc_output" + ) + acdc_df_cell, acdc_df_cells_path = load.load_acdc_df_file( + images_path, + end_name_acdc_df_file=acdc_df_cells_endname, + return_path=True, + ) + + if posData.SizeT > 1: + numFrames = min((len(segmDataCells), len(posData.segm_data))) + segmDataCells = segmDataCells[:numFrames] + posData.segm_data = posData.segm_data[:numFrames] + else: + numFrames = 1 + + self.signals.sigInitInnerPbar.emit(numFrames * 2) + + self.logger.log("Tracking sub-cellular objects...") + tracked = core.track_sub_cell_objects( + segmDataCells, + posData.segm_data, + self.IoAthresh, + how=self.trackingMode, + SizeT=numFrames, + sigProgress=self.signals.sigUpdateInnerPbar, + relabel_sub_obj_lab=self.relabelSubObjLab, + ) + ( + trackedSubSegmData, + trackedCellsSegmData, + numSubObjPerCell, + replacedSubIds, + ) = tracked + + self.logger.log("Saving tracked segmentation files...") + subSegmFilename, ext = os.path.splitext(posData.segm_npz_path) + trackedSubPath = f"{subSegmFilename}_{appendedName}.npz" + io.savez_compressed(trackedSubPath, trackedSubSegmData) + posData.saveIsSegm3Dmetadata(trackedSubPath) + + if trackedCellsSegmData is not None: + cellsSegmFilename, ext = os.path.splitext(segmCellsPath) + trackedCellsPath = f"{cellsSegmFilename}_{appendedName}.npz" + io.savez_compressed(trackedCellsPath, trackedCellsSegmData) + + if self.createThirdSegm: + self.logger.log( + f"Generating segmentation from " + f'"{self.cellsSegmEndFilename} - {appendedName}" ' + "difference..." + ) + if trackedCellsSegmData is not None: + parentSegmData = trackedCellsSegmData + else: + parentSegmData = segmDataCells + diffSegmData = parentSegmData.copy() + diffSegmData[trackedSubSegmData != 0] = 0 + + self.logger.log("Saving difference segmentation file...") + diffSegmPath = ( + f"{subSegmFilename}_{appendedName}" + f"_{self.thirdSegmAppendedText}.npz" + ) + io.savez_compressed(diffSegmPath, diffSegmData) + posData.saveIsSegm3Dmetadata(diffSegmPath) + del diffSegmData + + if self.relabelSubObjLab: + # When we relabel the sub-cell objs acdc_df is not valid anymore + # because IDs could be different + posData.acdc_df = None + + self.logger.log("Generating acdc_output tables...") + # Update or create acdc_df for sub-cellular objects + acdc_dfs_tracked = core.track_sub_cell_objects_acdc_df( + trackedSubSegmData, + posData.acdc_df, + replacedSubIds, + numSubObjPerCell, + tracked_cells_segm_data=trackedCellsSegmData, + cells_acdc_df=acdc_df_cell, + SizeT=posData.SizeT, + sigProgress=self.signals.sigUpdateInnerPbar, + ) + subTrackedAcdcDf, trackedAcdcDf = acdc_dfs_tracked + + self.logger.log("Saving acdc_output tables...") + subAcdcDfFilename, _ = os.path.splitext(posData.acdc_output_csv_path) + subTrackedAcdcDfPath = f"{subAcdcDfFilename}_{appendedName}.csv" + subTrackedAcdcDf.to_csv(subTrackedAcdcDfPath) + + if trackedAcdcDf is not None: + basen = posData.basename + cellsSegmFilename = os.path.basename(segmCellsPath) + cellsSegmFilename, ext = os.path.splitext(cellsSegmFilename) + cellsSegmEndname = cellsSegmFilename[len(basen) :] + trackedAcdcDfEndname = cellsSegmEndname.replace( + "segm", "acdc_output" + ) + trackedAcdcDfFilename = f"{basen}{trackedAcdcDfEndname}" + trackedAcdcDfFilename = ( + f"{trackedAcdcDfFilename}_{appendedName}.csv" + ) + trackedAcdcDfPath = os.path.join( + posData.images_path, trackedAcdcDfFilename + ) + trackedAcdcDf.to_csv(trackedAcdcDfPath) + + if self.createThirdSegm: + if posData.SizeT == 1: + parentSegmData = parentSegmData[np.newaxis] + subAcdcDfFilename = subSegmFilename.replace( + ".npz", ".csv" + ).replace("segm", "acdc_output") + diffAcdcDfPath = ( + f"{subAcdcDfFilename}_{appendedName}" + f"_{self.thirdSegmAppendedText}.csv" + ) + third_segm_acdc_df = ( + core.track_sub_cell_objects_third_segm_acdc_df( + parentSegmData, trackedAcdcDf + ) + ) + third_segm_acdc_df.to_csv(diffAcdcDfPath) + + self.signals.progressBar.emit(1) + + self.signals.finished.emit(self) + + +class ApplyTrackInfoWorker(BaseWorkerUtil): + def __init__( + self, + parentWin, + endFilenameSegm, + trackInfoCsvPath, + trackedSegmFilename, + trackColsInfo, + posPath, + ): + super().__init__(parentWin) + self.endFilenameSegm = endFilenameSegm + self.trackInfoCsvPath = trackInfoCsvPath + self.trackedSegmFilename = trackedSegmFilename + self.trackColsInfo = trackColsInfo + self.posPath = posPath + + @worker_exception_handler + def run(self): + self.logger.log("Loading segmentation file...") + self.signals.initProgressBar.emit(0) + imagesPath = os.path.join(self.posPath, "Images") + segmFilename = [ + f + for f in utils.listdir(imagesPath) + if f.endswith(f"{self.endFilenameSegm}.npz") + ][0] + segmFilePath = os.path.join(imagesPath, segmFilename) + segmData = np.load(segmFilePath)["arr_0"] + + self.logger.log("Loading table containing tracking info...") + df = pd.read_csv(self.trackInfoCsvPath) + + frameIndexCol = self.trackColsInfo["frameIndexCol"] + + parentIDcol = self.trackColsInfo["parentIDcol"] + pbarMax = len(df[frameIndexCol].unique()) + self.signals.initProgressBar.emit(pbarMax) + + # Apply tracking info + result = core.apply_tracking_from_table( + segmData, + self.trackColsInfo, + df, + signal=self.signals.progressBar, + logger=self.logger.log, + pbarMax=pbarMax, + ) + trackedData, trackedIDsMapper, deleteIDsMapper = result + + if self.trackedSegmFilename: + trackedSegmFilepath = os.path.join(imagesPath, self.trackedSegmFilename) + else: + trackedSegmFilepath = os.path.join(segmFilePath) + + self.signals.initProgressBar.emit(0) + self.logger.log("Saving tracked segmentation file...") + io.savez_compressed(trackedSegmFilepath, trackedData) + + mapperPath = os.path.splitext(trackedSegmFilepath)[0] + mapperJsonPath = f"{mapperPath}_deletedIDs_mapper.json" + mapperJsonName = os.path.basename(mapperJsonPath) + self.logger.log(f"Saving deleted IDs to {mapperJsonName}...") + with open(mapperJsonPath, "w") as file: + file.write(json.dumps(deleteIDsMapper)) + + mapperPath = os.path.splitext(trackedSegmFilepath)[0] + mapperJsonPath = f"{mapperPath}_replacedIDs_mapper.json" + mapperJsonName = os.path.basename(mapperJsonPath) + self.logger.log(f"Saving IDs replacements to {mapperJsonName}...") + with open(mapperJsonPath, "w") as file: + file.write(json.dumps(trackedIDsMapper)) + + self.logger.log("Generating acdc_output table...") + acdc_df = None + if not self.trackedSegmFilename: + # Fix existing acdc_df + acdcEndname = self.endFilenameSegm.replace("_segm", "_acdc_output") + acdcFilename = [ + f + for f in utils.listdir(imagesPath) + if f.endswith(f"{acdcEndname}.csv") + ] + if acdcFilename: + acdcFilePath = os.path.join(imagesPath, acdcFilename[0]) + acdc_df = pd.read_csv(acdcFilePath, index_col=["frame_i", "Cell_ID"]) + + if acdc_df is not None: + acdc_df = core.apply_trackedIDs_mapper_to_acdc_df( + trackedIDsMapper, deleteIDsMapper, acdc_df + ) + else: + acdc_dfs = [] + keys = [] + for frame_i, lab in enumerate(trackedData): + rp = skimage.measure.regionprops(lab) + acdc_df_frame_i = utils.getBaseAcdcDf(rp) + acdc_dfs.append(acdc_df_frame_i) + keys.append(frame_i) + + acdc_df = pd.concat(acdc_dfs, keys=keys, names=["frame_i", "Cell_ID"]) + segmFilename = os.path.basename(trackedSegmFilepath) + acdcFilename = re.sub(segm_re_pattern, "_acdc_output", segmFilename) + acdcFilePath = os.path.join(imagesPath, acdcFilename) + + self.signals.initProgressBar.emit(pbarMax) + parentIDcol = self.trackColsInfo["parentIDcol"] + trackIDsCol = self.trackColsInfo["trackIDsCol"] + if parentIDcol != "None": + self.logger.log(f'Adding lineage info from "{parentIDcol}" column...') + acdc_df = core.add_cca_info_from_parentID_col( + df, + acdc_df, + frameIndexCol, + trackIDsCol, + parentIDcol, + len(segmData), + signal=self.signals.progressBar, + maskID_colname=self.trackColsInfo["maskIDsCol"], + x_colname=self.trackColsInfo["xCentroidCol"], + y_colname=self.trackColsInfo["yCentroidCol"], + ) + + self.logger.log("Saving acdc_output table...") + acdc_df.to_csv(acdcFilePath) + + self.signals.finished.emit(self) + + +class ToSymDivWorker(QObject): + progressBar = Signal(int, int, float) + + def __init__(self, mainWin): + QObject.__init__(self) + self.signals = signals() + self.abort = False + self.logger = workerLogger(self.signals.progress) + self.mutex = QMutex() + self.waitCond = QWaitCondition() + self.mainWin = mainWin + + def emitSelectSegmFiles(self, exp_path, pos_foldernames): + self.mutex.lock() + self.signals.sigSelectSegmFiles.emit(exp_path, pos_foldernames) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.abort + + @worker_exception_handler + def run(self): + debugging = False + expPaths = self.mainWin.expPaths + tot_exp = len(expPaths) + self.signals.initProgressBar.emit(0) + for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): + self.errors = {} + self.missingAnnotErrors = {} + tot_pos = len(pos_foldernames) + self.allPosDataInputs = [] + posDatas = [] + self.logger.log("-" * 30) + expFoldername = os.path.basename(exp_path) + + abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) + if abort: + self.signals.finished.emit(self) + return + + self.signals.initProgressBar.emit(len(pos_foldernames)) + for p, pos in enumerate(pos_foldernames): + if self.abort: + self.signals.finished.emit(self) + return + + self.logger.log( + f"Processing experiment n. {i + 1}/{tot_exp}, " + f"{pos} ({p + 1}/{tot_pos})" + ) + + pos_path = os.path.join(exp_path, pos) + images_path = os.path.join(pos_path, "Images") + basename, chNames = utils.getBasenameAndChNames( + images_path, useExt=(".tif", ".h5") + ) + + self.signals.sigUpdatePbarDesc.emit(f"Loading {pos_path}...") + + # Use first found channel, it doesn't matter for metrics + for chName in chNames: + file_path = utils.getChannelFilePath(images_path, chName) + if file_path: + break + else: + raise FileNotFoundError( + f'None of the channels "{chNames}" were found in the path ' + f'"{images_path}".' + ) + + # Load data + posData = load.loadData(file_path, chName) + posData.getBasenameAndChNames(useExt=(".tif", ".h5")) + + posData.loadOtherFiles( + load_segm_data=False, + load_acdc_df=True, + load_metadata=True, + loadSegmInfo=True, + ) + + posDatas.append(posData) + + self.allPosDataInputs.append({"file_path": file_path, "chName": chName}) + + # Iterate pos and calculate metrics + numPos = len(self.allPosDataInputs) + for p, posDataInputs in enumerate(self.allPosDataInputs): + file_path = posDataInputs["file_path"] + chName = posDataInputs["chName"] + + posData = load.loadData(file_path, chName) + + self.signals.sigUpdatePbarDesc.emit(f"Processing {posData.pos_path}") + + posData.getBasenameAndChNames(useExt=(".tif", ".h5")) + posData.buildPaths() + posData.loadImgData() + + posData.loadOtherFiles( + load_segm_data=False, + load_acdc_df=True, + end_filename_segm=self.mainWin.endFilenameSegm, + ) + if not posData.acdc_df_found: + relPath = ( + f"...{os.sep}{expFoldername}{os.sep}{posData.pos_foldername}" + ) + self.logger.log( + f'WARNING: Skipping "{relPath}" ' + f"because acdc_output.csv file was not found." + ) + self.missingAnnotErrors[relPath] = ( + f'
    FileNotFoundError: the Positon "{relPath}" ' + "does not have the acdc_output.csv file.
    " + ) + + continue + + acdc_df_filename = os.path.basename(posData.acdc_output_csv_path) + self.logger.log( + f'Loaded path:\nACDC output file name: "{acdc_df_filename}"' + ) + + self.logger.log("Building tree...") + try: + tree = core.LineageTree(posData.acdc_df) + error = tree.build() + if isinstance(error, KeyError): + self.logger.log(str(error)) + + self.logger.log( + "WARNING: Annotations missing in " + f'"{posData.acdc_output_csv_path}"' + ) + self.missingAnnotErrors[acdc_df_filename] = str(error) + continue + elif error is not None: + raise error + posData.acdc_df = tree.df + except Exception as error: + traceback_format = traceback.format_exc() + self.logger.log(traceback_format) + self.errors[error] = traceback_format + + try: + posData.acdc_df.to_csv(posData.acdc_output_csv_path) + except PermissionError: + traceback_str = traceback.format_exc() + self.mutex.lock() + self.signals.sigPermissionError.emit( + traceback_str, posData.acdc_output_csv_path + ) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + posData.acdc_df.to_csv(posData.acdc_output_csv_path) + + self.signals.progressBar.emit(1) + + self.signals.finished.emit(self) + + +class CopyAllLostObjectsWorker(QObject): + navigateToFrame = Signal(int) + returnToFrame = Signal(int) + copyLostObjectMask = Signal(int) + refreshRp = Signal() + progressBar = Signal(int) + finished = Signal(object) + critical = Signal(object) + + def __init__(self, gui, posData, for_future_frame_n, max_overlap_perc): + super().__init__() + self.gui = gui + self.posData = posData + self.for_future_frame_n = for_future_frame_n + self.max_overlap_perc = max_overlap_perc + + @worker_exception_handler + def run(self): + current_frame_i = self.posData.frame_i + last_visited_frame_i = self.gui.get_last_tracked_i() + last_copied_frame_i = current_frame_i + self.for_future_frame_n + 1 + frames_range = (current_frame_i, last_copied_frame_i) + overlap_warning = False + output = {} + + for frame_i in range(*frames_range): + if frame_i == self.posData.SizeT: + break + + if frame_i > self.posData.frame_i: + # Main thread navigates, runs tracking, updates rp/IDs, etc + self.navigateToFrame.emit(frame_i) + + for lostObj in skimage.measure.regionprops(self.gui.lostObjImage): + overlap = np.count_nonzero( + self.gui.currentLab2D[lostObj.slice][lostObj.image] + ) + overlap_perc = overlap / lostObj.area * 100 + if overlap_perc > self.max_overlap_perc: + overlap_warning = True + continue + + self.copyLostObjectMask.emit(lostObj.label) + + # Refresh rp so the next frame's updateLostNewCurrentIDs sees the + # copied IDs as belonging to this frame and marks them lost there. + self.refreshRp.emit() + + self.progressBar.emit(1) + + if self.for_future_frame_n == 0: + output["overlap_warning"] = overlap_warning + self.finished.emit(output) + return + + # Back to current frame + self.returnToFrame.emit(current_frame_i) + + if last_visited_frame_i < last_copied_frame_i: + output["doReinitLastSegmFrame"] = True + output["last_visited_frame_i"] = last_visited_frame_i + + output["overlap_warning"] = overlap_warning + self.finished.emit(output) + +# Sibling imports (deferred to avoid import cycles) +from ._base import ( + signals, + workerLogger, + worker_exception_handler, +) + diff --git a/cellacdc/workers/util.py b/cellacdc/workers/util.py new file mode 100644 index 000000000..249b56c57 --- /dev/null +++ b/cellacdc/workers/util.py @@ -0,0 +1,709 @@ +"""Background Qt workers: util.""" + +import re +import os +import shutil +import time +import json +import concurrent.futures +from functools import partial +from collections import defaultdict, deque +import itertools + +from typing import Union, List, Dict, Callable, Any, Tuple, Iterable + +from functools import wraps +import numpy as np +import pandas as pd +import h5py +import traceback + +import skimage.io +import skimage.measure +import skimage.exposure + +import queue + +from tqdm import tqdm + +from qtpy.QtCore import Signal, QObject, QMutex, QWaitCondition + +from cellacdc import html_utils + +from .. import load, utils, core, prompts, printl, config, segm_re_pattern, io +from .. import transformation, measurements, cca_functions +from ..path import copy_or_move_tree +from .. import features, plot +from .. import core +from .. import cca_df_colnames, lineage_tree_cols, default_annot_df +from .. import cca_df_colnames_with_tree +from .. import cli +from ..tools import resize +from .. import segm_utils + +DEBUG = False + +from ._base import ( + BaseWorkerUtil, +) + +class FromImajeJroiToSegmNpzWorker(BaseWorkerUtil): + sigSelectRoisProps = Signal(str, object, bool) + + def __init__(self, mainWin): + super().__init__(mainWin) + + def emitSelectRoisProps(self, roi_filepath, TZYX_shape, is_multi_pos): + self.mutex.lock() + self.sigSelectRoisProps.emit(roi_filepath, TZYX_shape, is_multi_pos) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.abort + + @worker_exception_handler + def run(self): + import roifile + + expPaths = self.mainWin.expPaths + tot_exp = len(expPaths) + self.signals.initProgressBar.emit(0) + for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): + self.errors = {} + tot_pos = len(pos_foldernames) + + abort = self.emitSelectFilesWithText( + exp_path, pos_foldernames, "imagej_rois", ext=".zip" + ) + if abort: + self.signals.finished.emit(self) + return + + self.askRoiPreferences = True + for p, pos in enumerate(pos_foldernames): + if self.abort: + self.signals.finished.emit(self) + return + + self.logger.log( + f"Processing experiment n. {i + 1}/{tot_exp}, " + f"{pos} ({p + 1}/{tot_pos})" + ) + + images_path = os.path.join(exp_path, pos, "Images") + endFilenameRoi = self.mainWin.endFilenameWithText + ls = utils.listdir(images_path) + rois_filepaths = [ + os.path.join(images_path, f) + for f in ls + if f.endswith(f"{endFilenameRoi}.zip") + ] + + if not rois_filepaths: + self.logger.log( + "[WARNING]: The following Position folder does not " + f"contain any file ending with {endFilenameRoi}. " + f'Skipping it. "{os.path.join(exp_path, pos)}")' + ) + continue + + rois_filepath = rois_filepaths[0] + + if self.askRoiPreferences: + is_multi_pos = len(pos_foldernames) > 1 + self.logger.log("Loading image data to get image shape...") + TZYX_shape = load.get_tzyx_shape(images_path) + abort = self.emitSelectRoisProps( + rois_filepath, TZYX_shape, is_multi_pos + ) + if abort: + self.signals.finished.emit(self) + return + + self.askRoiPreferences = not self.useSamePropsForNextPos + elif self.areAllRoisSelected: + rois = roifile.roiread(rois_filepath) + self.IDsToRoisMapper = {i + i: roi for roi in enumerate(rois)} + else: + # Use same ID of previous position + rois = roifile.roiread(rois_filepath) + IDsToRoisMapper = {i + i: roi for i, roi in enumerate(rois)} + self.IDsToRoisMapper = { + ID: IDsToRoisMapper[ID] for ID in self.IDsToRoisMapper.keys() + } + + self.logger.log("Generating segm mask from ROIs...") + segm_data = utils.from_imagej_rois_to_segm_data( + TZYX_shape, + self.IDsToRoisMapper, + self.rescaleRoisSizes, + self.repeatRoisZslicesRange, + ) + + segm_filepath = rois_filepath.replace("imagej_rois", "segm").replace( + ".zip", ".npz" + ) + self.logger.log(f'Saving segm mask to "{segm_filepath}"...') + io.savez_compressed(segm_filepath, segm_data) + + self.signals.finished.emit(self) + + +class ToImajeJroiWorker(BaseWorkerUtil): + def __init__(self, mainWin): + super().__init__(mainWin) + + @worker_exception_handler + def run(self): + from roifile import ImagejRoi, roiwrite + + debugging = False + expPaths = self.mainWin.expPaths + tot_exp = len(expPaths) + self.signals.initProgressBar.emit(0) + for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): + self.errors = {} + tot_pos = len(pos_foldernames) + + abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) + if abort: + self.signals.finished.emit(self) + return + + for p, pos in enumerate(pos_foldernames): + if self.abort: + self.signals.finished.emit(self) + return + + self.logger.log( + f"Processing experiment n. {i + 1}/{tot_exp}, " + f"{pos} ({p + 1}/{tot_pos})" + ) + + images_path = os.path.join(exp_path, pos, "Images") + endFilenameSegm = self.mainWin.endFilenameSegm + ls = utils.listdir(images_path) + + files_path = [ + os.path.join(images_path, f) + for f in ls + if f.endswith(f"{endFilenameSegm}.npz") + ] + + if not files_path: + self.logger.log( + "[WARNING]: The following Position folder does not " + f"contain any file ending with {endFilenameSegm}. " + f'Skipping it. "{os.path.join(exp_path, pos)}")' + ) + continue + + file_path = files_path[0] + + posData = load.loadData(file_path, "") + + self.signals.sigUpdatePbarDesc.emit(f"Processing {posData.pos_path}") + + posData.getBasenameAndChNames() + posData.buildPaths() + + posData.loadOtherFiles( + load_segm_data=True, + load_metadata=True, + end_filename_segm=endFilenameSegm, + ) + + if posData.SizeT > 1: + rois = [] + max_ID = posData.segm_data.max() + for t, lab in enumerate(posData.segm_data): + rois_t = utils.from_lab_to_imagej_rois( + lab, ImagejRoi, t=t, SizeT=posData.SizeT, max_ID=max_ID + ) + rois.extend(rois_t) + else: + rois = utils.from_lab_to_imagej_rois(posData.segm_data, ImagejRoi) + + roi_filepath = posData.segm_npz_path.replace(".npz", ".zip") + roi_filepath = roi_filepath.replace("_segm", "_imagej_rois") + + try: + os.remove(roi_filepath) + except Exception as e: + pass + + roiwrite(roi_filepath, rois) + + self.signals.finished.emit(self) + + +class ToObjCoordsWorker(BaseWorkerUtil): + def __init__(self, mainWin): + super().__init__(mainWin) + + @worker_exception_handler + def run(self): + debugging = False + expPaths = self.mainWin.expPaths + tot_exp = len(expPaths) + self.signals.initProgressBar.emit(0) + for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): + self.errors = {} + tot_pos = len(pos_foldernames) + + abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) + if abort: + self.signals.finished.emit(self) + return + + for p, pos in enumerate(pos_foldernames): + if self.abort: + self.signals.finished.emit(self) + return + + self.logger.log( + f"Processing experiment n. {i + 1}/{tot_exp}, " + f"{pos} ({p + 1}/{tot_pos})" + ) + + images_path = os.path.join(exp_path, pos, "Images") + endFilenameSegm = self.mainWin.endFilenameSegm + ls = utils.listdir(images_path) + file_path = [ + os.path.join(images_path, f) + for f in ls + if f.endswith(f"{endFilenameSegm}.npz") + ][0] + + posData = load.loadData(file_path, "") + + self.signals.sigUpdatePbarDesc.emit(f"Processing {posData.pos_path}") + + posData.getBasenameAndChNames() + posData.buildPaths() + + posData.loadOtherFiles( + load_segm_data=True, + load_metadata=True, + end_filename_segm=endFilenameSegm, + ) + + if posData.SizeT == 1: + posData.segm_data = posData.segm_data[np.newaxis] + + dfs = [] + n_frames = len(posData.segm_data) + self.signals.initProgressBar.emit(n_frames) + for frame_i, lab in enumerate(posData.segm_data): + df_coords_i = utils.from_lab_to_obj_coords(lab) + dfs.append(df_coords_i) + self.signals.progressBar.emit(1) + df_filepath = posData.segm_npz_path.replace(".npz", ".csv") + df_filepath = df_filepath.replace("_segm", "_objects_coordinates") + + keys = list(range(len(posData.segm_data))) + df = pd.concat(dfs, keys=keys, names=["frame_i"]) + + self.signals.initProgressBar.emit(0) + df.to_csv(df_filepath) + + self.signals.finished.emit(self) + + +class Stack2DsegmTo3Dsegm(BaseWorkerUtil): + sigAskAppendName = Signal(str, list) + sigAborted = Signal() + + def __init__(self, mainWin, SizeZ): + super().__init__(mainWin) + self.SizeZ = SizeZ + + @worker_exception_handler + def run(self): + debugging = False + expPaths = self.mainWin.expPaths + tot_exp = len(expPaths) + self.signals.initProgressBar.emit(0) + for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): + self.errors = {} + tot_pos = len(pos_foldernames) + + self.mainWin.infoText = f"Select 2D segmentation file to stack" + abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) + if abort: + self.sigAborted.emit() + return + + # Ask appendend name + self.mutex.lock() + self.sigAskAppendName.emit( + self.mainWin.endFilenameSegm, self.mainWin.existingSegmEndNames + ) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + if self.abort: + self.sigAborted.emit() + return + + appendedName = self.appendedName + self.signals.initProgressBar.emit(len(pos_foldernames)) + for p, pos in enumerate(pos_foldernames): + if self.abort: + self.sigAborted.emit() + return + + self.logger.log( + f"Processing experiment n. {i + 1}/{tot_exp}, " + f"{pos} ({p + 1}/{tot_pos})" + ) + + images_path = os.path.join(exp_path, pos, "Images") + endFilenameSegm = self.mainWin.endFilenameSegm + ls = utils.listdir(images_path) + file_path = [ + os.path.join(images_path, f) + for f in ls + if f.endswith(f"{endFilenameSegm}.npz") + ][0] + + posData = load.loadData(file_path, "") + + self.signals.sigUpdatePbarDesc.emit(f"Processing {posData.pos_path}") + + posData.getBasenameAndChNames() + posData.buildPaths() + + posData.loadOtherFiles( + load_segm_data=True, + load_acdc_df=True, + load_metadata=True, + end_filename_segm=endFilenameSegm, + ) + if posData.segm_data.ndim == 2: + posData.segm_data = posData.segm_data[np.newaxis] + + self.logger.log("Stacking 2D into 3D objects...") + + numFrames = len(posData.segm_data) + self.signals.sigInitInnerPbar.emit(numFrames) + T, Y, X = posData.segm_data.shape + newShape = (T, self.SizeZ, Y, X) + segmData2D = np.zeros(newShape, dtype=np.uint32) + for frame_i, lab in enumerate(posData.segm_data): + stacked_lab = core.stack_2Dlab_to_3D(lab, self.SizeZ) + segmData2D[frame_i] = stacked_lab + + self.signals.sigUpdateInnerPbar.emit(1) + + self.logger.log("Saving stacked 3D segmentation file...") + segmFilename, ext = os.path.splitext(posData.segm_npz_path) + newSegmFilepath = f"{segmFilename}_{appendedName}.npz" + segmData2D = np.squeeze(segmData2D) + io.savez_compressed(newSegmFilepath, segmData2D) + + self.signals.progressBar.emit(1) + + self.signals.finished.emit(self) + + +class FilterObjsFromCoordsTable(BaseWorkerUtil): + sigAskAppendName = Signal(str, list) + sigAborted = Signal() + sigSetColumnsNames = Signal(object, object, object) + + def __init__(self, mainWin): + super().__init__(mainWin) + + def emitSetColumnsNames(self, columns, categories, optionalCategories): + self.mutex.lock() + self.sigSetColumnsNames.emit(columns, categories, optionalCategories) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.abort + + def getColumnsCategories( + self, df_coords, exp_path, pos_foldernames, endFilenameSegm + ): + columns = df_coords.columns.to_list() + categories = ["X coord. column", "Y coord. column"] + optionalCategories = [] + + images_path = os.path.join(exp_path, pos_foldernames[0], "Images") + metadata_df = load.load_metadata_df(images_path) + SizeT = float(metadata_df.at["SizeT", "values"]) + SizeZ = float(metadata_df.at["SizeZ", "values"]) + + segmData = load.load_segm_file(images_path, end_name_segm_file=endFilenameSegm) + + if segmData.ndim == 4: + categories.append("Z coord. column") + categories.append("Frame index column") + elif segmData.ndim == 3: + if SizeZ > 1 and SizeT == 1: + # 3D z-stack data + categories.append("Z coord. column") + else: + optionalCategories.append("Z coord. column") + + if SizeT > 1: + # 3D time-lapse + categories.append("Frame index column") + else: + optionalCategories.append("Frame index column") + else: + optionalCategories.append("Z coord. column") + optionalCategories.append("Frame index column") + + if len(pos_foldernames) > 1: + categories.append("Position_n") + else: + optionalCategories.append("Position_n") + + return columns, categories, optionalCategories + + def getDfCoords( + self, df_coords, selectedColumnsPerCategory, pos_foldername, frame_i + ): + pos_col = selectedColumnsPerCategory.get("Position_n", "None") + frame_i_col = selectedColumnsPerCategory.get("Frame index column", "None") + x_col = selectedColumnsPerCategory["X coord. column"] + y_col = selectedColumnsPerCategory["Y coord. column"] + if pos_col != "None": + df_coords = df_coords[df_coords[pos_col] == pos_foldername] + if frame_i_col != "None": + df_coords = df_coords[df_coords[frame_i_col] == frame_i] + + xy_cols = [x_col, y_col] + + df_out = pd.DataFrame( + index=df_coords.index, data=df_coords[xy_cols].values, columns=["x", "y"] + ) + z_col = selectedColumnsPerCategory.get("Z coord. column", "None") + if z_col != "None": + df_out["z"] = df_coords[z_col] + + return df_out + + @worker_exception_handler + def run(self): + debugging = False + expPaths = self.mainWin.expPaths + tot_exp = len(expPaths) + self.signals.initProgressBar.emit(0) + for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): + self.errors = {} + tot_pos = len(pos_foldernames) + + self.mainWin.infoText = f"Select segmentation file to filter" + abort = self.emitSelectSegmFiles(exp_path, pos_foldernames) + if abort: + self.sigAborted.emit() + return + endFilenameSegm = self.mainWin.endFilenameSegm + + self.logger.log("Asking to select the CSV table file...") + + abort = self.emitSelectFile( + exp_path, + "Select CSV table file with coordinates to filter", + "CSV (*.csv)", + ) + if abort: + self.sigAborted.emit() + return + + self.logger.log(f"Loading table file `{self.mainWin.selectedFilepath}`..") + df_coords = pd.read_csv(self.mainWin.selectedFilepath) + + columns, categories, optionalCategories = self.getColumnsCategories( + df_coords, exp_path, pos_foldernames, endFilenameSegm + ) + + abort = self.emitSetColumnsNames(columns, categories, optionalCategories) + if abort: + self.sigAborted.emit() + return + + selectedColumnsPerCategory = self.mainWin.selectedColumnsPerCategory + + # Ask appendend name + self.mutex.lock() + self.sigAskAppendName.emit( + self.mainWin.endFilenameSegm, self.mainWin.existingSegmEndNames + ) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + if self.abort: + self.sigAborted.emit() + return + + appendedName = self.appendedName + self.signals.initProgressBar.emit(len(pos_foldernames)) + for p, pos in enumerate(pos_foldernames): + if self.abort: + self.sigAborted.emit() + return + + self.logger.log( + f"Processing experiment n. {i + 1}/{tot_exp}, " + f"{pos} ({p + 1}/{tot_pos})" + ) + + images_path = os.path.join(exp_path, pos, "Images") + ls = utils.listdir(images_path) + file_path = [ + os.path.join(images_path, f) + for f in ls + if f.endswith(f"{endFilenameSegm}.npz") + ][0] + + posData = load.loadData(file_path, "") + + self.signals.sigUpdatePbarDesc.emit(f"Processing {posData.pos_path}") + + posData.getBasenameAndChNames() + posData.buildPaths() + + posData.loadOtherFiles( + load_segm_data=True, + load_acdc_df=True, + load_metadata=True, + end_filename_segm=endFilenameSegm, + ) + if posData.SizeT == 1: + posData.segm_data = posData.segm_data[np.newaxis] + + self.logger.log("Filtering objects...") + + numFrames = len(posData.segm_data) + self.signals.sigInitInnerPbar.emit(numFrames) + filteredSegmData = np.zeros_like(posData.segm_data) + for frame_i, lab in enumerate(posData.segm_data): + df_coords_frame_i = self.getDfCoords( + df_coords, selectedColumnsPerCategory, pos, frame_i + ) + if df_coords_frame_i.empty: + num_frames_missing = len(posData.segm_data[frame_i:]) + self.signals.sigUpdateInnerPbar.emit(num_frames_missing) + filteredSegmData = filteredSegmData[:frame_i] + break + + filtered_lab = core.filter_segm_objs_from_table_coords( + lab, df_coords_frame_i + ) + filteredSegmData[frame_i] = filtered_lab + + self.signals.sigUpdateInnerPbar.emit(1) + + self.logger.log("Saving filtered segmentation file...") + segmFilename, ext = os.path.splitext(posData.segm_npz_path) + newSegmFilepath = f"{segmFilename}_{appendedName}.npz" + filteredSegmData = np.squeeze(filteredSegmData) + io.savez_compressed(newSegmFilepath, filteredSegmData) + + self.signals.progressBar.emit(1) + + self.signals.finished.emit(self) + + +class ScreenRecorderWorker(QObject): + sigGrabScreen = Signal() + finished = Signal() + + def __init__(self, screenRecorderWin, folder_path): + QObject.__init__(self) + self.screenRecorderWin = screenRecorderWin + self.folder_path = folder_path + + def run(self): + for i in range(4): + fn = f"shot_{i:03}.jpg" + grab_path = os.path.join(self.folder_path, fn) + screen = self.screenRecorderWin.screen() + screenshot = screen.grabWindow(self.screenRecorderWin.winId()) + screenshot.save(grab_path, "jpg") + print(grab_path) + time.sleep(0.2) + + self.finished.emit() + + +class ApplyImageFilterWorker(QObject): + finished = Signal(object) + critical = Signal(object) + progress = Signal(str) + + def __init__(self, filter_func, input_data): + QObject.__init__(self) + self.filter_func = filter_func + self.input_data = input_data + + @worker_exception_handler + def run(self): + self.progress.emit("Filtering image...") + filtered_data = self.filter_func(self.input_data) + self.finished.emit(filtered_data) + + +class ResizeUtilWorker(BaseWorkerUtil): + sigSetResizeProps = Signal(str) + + def emitSetResizeProps(self, input_path): + self.mutex.lock() + self.sigSetResizeProps.emit(input_path) + self.waitCond.wait(self.mutex) + self.mutex.unlock() + return self.abort + + def __init__(self, mainWin): + super().__init__(mainWin) + + def validateOutputPath(self, path): + if path is None: + return + + images_path = utils.validate_images_path(path, create_dirs_tree=True) + return images_path + + @worker_exception_handler + def run(self): + expPaths = self.mainWin.expPaths + tot_exp = len(expPaths) + + self.signals.initProgressBar.emit(0) + for i, (exp_path, pos_foldernames) in enumerate(expPaths.items()): + abort = self.emitSetResizeProps(exp_path) + if abort: + self.signals.finished.emit(self) + return + + tot_pos = len(pos_foldernames) + for p, pos in enumerate(pos_foldernames): + if self.abort: + self.signals.finished.emit(self) + return + + self.logger.log( + f"Processing experiment n. {i + 1}/{tot_exp}, " + f"{pos} ({p + 1}/{tot_pos})" + ) + images_path = os.path.join(exp_path, pos, "Images") + + rf = self.resizeFactor + text_to_append = self.textToAppend + images_path_out = self.validateOutputPath(self.expFolderpathOut) + if images_path_out is None: + images_path_out = images_path + resize.run( + images_path, + rf, + text_to_append=text_to_append, + images_path_out=images_path_out, + ) + + self.signals.finished.emit(self) + +# Sibling imports (deferred to avoid import cycles) +from ._base import ( + worker_exception_handler, +) + diff --git a/cellacdc/workflow/__init__.py b/cellacdc/workflow/__init__.py new file mode 100644 index 000000000..b745b2263 --- /dev/null +++ b/cellacdc/workflow/__init__.py @@ -0,0 +1,60 @@ +"""LangGraph-style workflow modeling for Cell-ACDC pipelines.""" + +from .adapters import ( + configure_measurements_kernel_for_cli, + runnable_config_from_segm_kernel, + sync_segm_kernel_from_context, + update_workflow_context_from_segm_kernel, + workflow_context_from_ini, + workflow_context_from_segm_kernel, +) +from .constants import END, START +from .graph import CompiledStateGraph, StateGraph +from .runnable import Runnable, RunnableConfig, RunnableLambda, RunnableSequence +from .state import ( + BatchState, + FullWorkflowState, + InteractiveSegmContext, + InteractiveSegmState, + InteractiveVideoSegmContext, + InteractiveVideoSegmState, + MeasurementsBatchContext, + MeasurementsContext, + MeasurementsGuiBatchContext, + MeasurementsGuiContext, + MeasurementsGuiState, + MeasurementsState, + PositionState, + WorkflowContext, +) + +__all__ = [ + "BatchState", + "CompiledStateGraph", + "configure_measurements_kernel_for_cli", + "END", + "FullWorkflowState", + "InteractiveSegmContext", + "InteractiveSegmState", + "InteractiveVideoSegmContext", + "InteractiveVideoSegmState", + "MeasurementsBatchContext", + "MeasurementsContext", + "MeasurementsGuiBatchContext", + "MeasurementsGuiContext", + "MeasurementsGuiState", + "MeasurementsState", + "PositionState", + "Runnable", + "RunnableConfig", + "RunnableLambda", + "RunnableSequence", + "START", + "StateGraph", + "WorkflowContext", + "runnable_config_from_segm_kernel", + "sync_segm_kernel_from_context", + "update_workflow_context_from_segm_kernel", + "workflow_context_from_ini", + "workflow_context_from_segm_kernel", +] diff --git a/cellacdc/workflow/adapters.py b/cellacdc/workflow/adapters.py new file mode 100644 index 000000000..3e783ea40 --- /dev/null +++ b/cellacdc/workflow/adapters.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +from typing import Any + +from .runnable import RunnableConfig +from .state import WorkflowContext + + +def workflow_context_from_segm_kernel(kernel: Any) -> WorkflowContext: + """Build a mutable workflow context from a SegmKernel instance.""" + return WorkflowContext( + user_ch_name=kernel.user_ch_name, + segm_endname=kernel.segm_endname, + model_name=kernel.model_name, + tracker_name=kernel.tracker_name, + do_tracking=kernel.do_tracking, + do_postprocess=kernel.do_postprocess, + do_save=kernel.do_save, + is_segm_3d=kernel.isSegm3D, + use_roi=kernel.use_ROI, + use_freehand_roi=kernel.use_freehand_ROI, + use_3d_data_for_2d_segm=kernel.use3DdataFor2Dsegm, + second_channel_name=kernel.second_channel_name, + image_channel_tracker=kernel.image_channel_tracker, + size_t=kernel.SizeT, + size_z=kernel.SizeZ, + model_kwargs=dict(kernel.model_kwargs or {}), + init_model_kwargs=dict(kernel.init_model_kwargs or {}), + track_params=dict(kernel.track_params or {}), + init_tracker_kwargs=dict(kernel.init_tracker_kwargs or {}), + standard_postprocess_kwargs=dict(kernel.standard_postrocess_kwargs or {}), + custom_postprocess_features=dict(kernel.custom_postproc_features or {}), + custom_postprocess_grouped_features=dict( + kernel.custom_postproc_grouped_features or {} + ), + preproc_recipe=kernel.preproc_recipe, + reduce_memory_usage=getattr(kernel, "reduce_memory_usage", False), + model=kernel.model, + tracker=kernel.tracker, + is_segment3dt_available=kernel.is_segment3DT_available, + inner_pbar_available=kernel.innerPbar_available, + signals=kernel.signals, + ) + + +def sync_segm_kernel_from_context(kernel: Any, ctx: WorkflowContext) -> None: + """Copy mutable pipeline resources back onto the kernel after a graph run.""" + kernel.model = ctx.model + kernel.tracker = ctx.tracker + kernel.is_segment3DT_available = ctx.is_segment3dt_available + kernel.model_kwargs = ctx.model_kwargs + kernel.track_params = ctx.track_params + kernel.init_model_kwargs = ctx.init_model_kwargs + + +def runnable_config_from_segm_kernel(kernel: Any) -> RunnableConfig: + return RunnableConfig( + logger_func=kernel.logger_func, + signals=kernel.signals, + metadata={"model_name": kernel.model_name}, + ) + + +def update_workflow_context_from_segm_kernel( + ctx: WorkflowContext, kernel: Any +) -> WorkflowContext: + """Refresh context fields that may change between batch positions.""" + ctx.model = kernel.model + ctx.tracker = kernel.tracker + ctx.is_segment3dt_available = kernel.is_segment3DT_available + ctx.model_kwargs = dict(kernel.model_kwargs or {}) + ctx.track_params = dict(kernel.track_params or {}) + ctx.signals = kernel.signals + return ctx + + +def _parse_custom_postproc_features_grouped(workflow_params: dict[str, Any]) -> dict: + custom_postproc_grouped_features: dict[str, Any] = {} + for section, options in workflow_params.items(): + if not section.startswith("postprocess_features."): + continue + category = section.split(".")[-1] + for option, value in options.items(): + if option == "names": + values = value.strip("\n").strip().split("\n") + custom_postproc_grouped_features[category] = values + continue + channel = option + if category not in custom_postproc_grouped_features: + custom_postproc_grouped_features[category] = {channel: [value]} + elif channel not in custom_postproc_grouped_features[category]: + custom_postproc_grouped_features[category][channel] = [value] + else: + custom_postproc_grouped_features[category][channel].append(value) + return custom_postproc_grouped_features + + +def workflow_context_from_ini(workflow_params: dict[str, Any]) -> WorkflowContext: + """Build a workflow context directly from parsed INI workflow parameters.""" + from cellacdc import config + + initialization = workflow_params["initialization"] + return WorkflowContext( + user_ch_name=initialization["user_ch_name"], + segm_endname=initialization.get("segm_endname", "segm.npz"), + model_name=initialization.get("model_name", ""), + tracker_name=initialization.get("tracker_name", ""), + do_tracking=initialization.get("do_tracking", False), + do_postprocess=initialization.get("do_postprocess", True), + do_save=initialization.get("do_save", True), + is_segm_3d=initialization.get("isSegm3D", False), + use_roi=initialization.get("use_ROI", True), + use_freehand_roi=initialization.get("use_freehand_ROI", True), + use_3d_data_for_2d_segm=initialization.get("use3DdataFor2Dsegm", False), + second_channel_name=initialization.get("second_channel_name"), + image_channel_tracker=initialization.get("image_channel_tracker"), + size_t=workflow_params["metadata"]["SizeT"], + size_z=workflow_params["metadata"]["SizeZ"], + model_kwargs=dict(workflow_params.get("segmentation_model_params", {})), + init_model_kwargs=dict( + workflow_params.get("init_segmentation_model_params", {}) + ), + track_params=dict(workflow_params.get("tracker_params", {})), + init_tracker_kwargs=dict(workflow_params.get("init_tracker_params", {})), + standard_postprocess_kwargs=dict( + workflow_params.get("standard_postprocess_features", {}) + ), + custom_postprocess_features=dict( + workflow_params.get("custom_postprocess_features", {}) + ), + custom_postprocess_grouped_features=_parse_custom_postproc_features_grouped( + workflow_params + ), + preproc_recipe=config.preprocess_ini_items_to_recipe(workflow_params), + reduce_memory_usage=initialization.get("reduce_memory_usage", False), + ) + + +def interactive_segm_context_from_main_win(main_win, second_channel_data=None, z_range=None): + from cellacdc.workflow.state import InteractiveSegmContext + + return InteractiveSegmContext( + model=main_win.model, + model_kwargs=main_win.model_kwargs, + apply_postprocessing=main_win.applyPostProcessing, + standard_postprocess_kwargs=main_win.standardPostProcessKwargs, + custom_postprocess_features=main_win.customPostProcessFeatures, + custom_postprocess_grouped_features=main_win.customPostProcessGroupedFeatures, + segment_3d=main_win.segment3D, + second_channel_data=second_channel_data, + z_range=z_range, + ) + + +def runnable_config_from_main_win(main_win): + return RunnableConfig(logger_func=main_win.logger.info) + + +def interactive_video_segm_context_from_worker(worker) -> InteractiveSegmContext: + from cellacdc.workflow.state import InteractiveVideoSegmContext + + return InteractiveVideoSegmContext( + model=worker.model, + model_kwargs=worker.model_kwargs, + apply_postprocessing=worker.applyPostProcessing, + standard_postprocess_kwargs=worker.standardPostProcessKwargs, + custom_postprocess_features=worker.customPostProcessFeatures, + custom_postprocess_grouped_features=worker.customPostProcessGroupedFeatures, + preproc_recipe=worker.preproc_recipe, + second_channel_data=getattr(worker, "secondChannelData", None), + start_frame_num=worker.startFrameNum, + stop_frame_num=worker.stopFrameNum, + progress_callback=worker.progressBar, + logger_func=worker.logger.log, + ) + + +def configure_measurements_kernel_for_cli( + kernel: Any, + channels: list[str] | str, + end_filename_segm: str = "segm", + *, + channels_to_skip: list[str] | None = None, + channels_to_process: list[str] | None = None, + is_segm_3d: bool = False, + is_timelapse: bool = False, + is_zstack: bool = False, +) -> Any: + """Configure a ComputeMeasurementsKernel for headless graph runs (no GUI/INI).""" + from cellacdc import measurements + + if isinstance(channels, str): + channels = [name.strip() for name in channels.split("\n") if name.strip()] + + kernel.init_args(channels, end_filename_segm) + kernel.chNamesToSkip = list(channels_to_skip or []) + kernel.chNamesToProcess = list(channels_to_process or channels) + kernel.metricsToSave = None + kernel.metricsToSkip = {ch: [] for ch in channels} + kernel.calc_for_each_zslice_mapper = {ch: False for ch in channels} + kernel.calc_size_for_each_zslice = False + kernel.save_object_counts_table = False + kernel.mixedChCombineMetricsToSkip = [] + kernel.regionPropsToSave = ( + measurements.get_props_names_3D() + if is_segm_3d + else measurements.get_props_names() + ) + kernel.sizeMetricsToSave = list( + measurements.get_size_metrics_desc(is_segm_3d, is_timelapse).keys() + ) + kernel.chIndipendCustomMetricsToSave = list( + measurements.ch_indipend_custom_metrics_desc( + is_zstack, + isSegm3D=is_segm_3d, + ).keys() + ) + return kernel diff --git a/cellacdc/workflow/constants.py b/cellacdc/workflow/constants.py new file mode 100644 index 000000000..5448156f6 --- /dev/null +++ b/cellacdc/workflow/constants.py @@ -0,0 +1,2 @@ +START = "__start__" +END = "__end__" diff --git a/cellacdc/workflow/graph.py b/cellacdc/workflow/graph.py new file mode 100644 index 000000000..8cfd29d4f --- /dev/null +++ b/cellacdc/workflow/graph.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Generic, TypeVar + +from .constants import END, START +from .runnable import RunnableConfig +from .state import merge_state + +StateT = TypeVar("StateT") +ContextT = TypeVar("ContextT") + +NodeFn = Callable[[StateT, ContextT, RunnableConfig], dict[str, Any]] +RouteFn = Callable[[StateT, ContextT], str] + + +@dataclass(slots=True) +class CompiledStateGraph(Generic[StateT, ContextT]): + """Executable graph returned by StateGraph.compile().""" + + nodes: dict[str, NodeFn[StateT, ContextT]] + edges: dict[str, str] + conditional_edges: dict[str, tuple[RouteFn[StateT, ContextT], dict[str, str]]] + entrypoint: str + state_type: type[StateT] + context: ContextT + + def invoke( + self, + state: StateT, + config: RunnableConfig | None = None, + ) -> StateT: + config = config or RunnableConfig() + node_name = self.entrypoint + while node_name != END: + node = self.nodes[node_name] + update = node(state, self.context, config) + state = merge_state(state, update) + if node_name in self.conditional_edges: + route_fn, mapping = self.conditional_edges[node_name] + node_name = mapping[route_fn(state, self.context)] + continue + node_name = self.edges[node_name] + return state + + +@dataclass +class StateGraph(Generic[StateT, ContextT]): + """Declarative workflow graph (LangGraph StateGraph analogue).""" + + state_type: type[StateT] + context: ContextT + _nodes: dict[str, NodeFn[StateT, ContextT]] = field(default_factory=dict) + _edges: dict[str, str] = field(default_factory=dict) + _conditional_edges: dict[str, tuple[RouteFn[StateT, ContextT], dict[str, str]]] = ( + field(default_factory=dict) + ) + _entrypoint: str | None = None + + def add_node(self, name: str, fn: NodeFn[StateT, ContextT]) -> StateGraph: + self._nodes[name] = fn + return self + + def set_entry_point(self, name: str) -> StateGraph: + self._entrypoint = name + return self + + def add_edge(self, start: str, end: str) -> StateGraph: + self._edges[start] = end + return self + + def add_conditional_edges( + self, + start: str, + route: RouteFn[StateT, ContextT], + mapping: dict[str, str], + ) -> StateGraph: + self._conditional_edges[start] = (route, mapping) + return self + + def compile(self) -> CompiledStateGraph[StateT, ContextT]: + if self._entrypoint is None: + raise ValueError("Graph has no entry point. Call set_entry_point().") + if self._entrypoint not in self._nodes: + raise ValueError(f"Unknown entry point node: {self._entrypoint}") + return CompiledStateGraph( + nodes=dict(self._nodes), + edges=dict(self._edges), + conditional_edges=dict(self._conditional_edges), + entrypoint=self._entrypoint, + state_type=self.state_type, + context=self.context, + ) + + def get_graph(self) -> dict[str, Any]: + """Return a serializable graph description for tests and debugging.""" + return { + "nodes": sorted(self._nodes), + "edges": dict(self._edges), + "conditional_edges": { + name: sorted(mapping) + for name, (_, mapping) in self._conditional_edges.items() + }, + "entrypoint": self._entrypoint, + } diff --git a/cellacdc/workflow/pipelines/__init__.py b/cellacdc/workflow/pipelines/__init__.py new file mode 100644 index 000000000..efe19ac83 --- /dev/null +++ b/cellacdc/workflow/pipelines/__init__.py @@ -0,0 +1,23 @@ +from .batch import batch_state_from_workflow_params, run_measurements_batch, run_segm_batch +from .batch_graph import build_segm_batch_graph +from .full_workflow import build_full_workflow_graph +from .interactive_segm import build_interactive_segm_graph +from .interactive_video_segm import build_interactive_video_segm_graph +from .measurements import build_measurements_position_graph +from .measurements_batch_graph import build_measurements_batch_graph +from .measurements_gui import build_gui_measurements_graph +from .segm import build_position_segm_graph + +__all__ = [ + "batch_state_from_workflow_params", + "build_full_workflow_graph", + "build_gui_measurements_graph", + "build_interactive_segm_graph", + "build_interactive_video_segm_graph", + "build_measurements_batch_graph", + "build_measurements_position_graph", + "build_position_segm_graph", + "build_segm_batch_graph", + "run_measurements_batch", + "run_segm_batch", +] diff --git a/cellacdc/workflow/pipelines/batch.py b/cellacdc/workflow/pipelines/batch.py new file mode 100644 index 000000000..8f48e162d --- /dev/null +++ b/cellacdc/workflow/pipelines/batch.py @@ -0,0 +1,102 @@ +"""Batch execution helpers for workflow graphs.""" + +from __future__ import annotations + +from typing import Any + +from tqdm import tqdm + +from ..runnable import RunnableConfig +from ..state import BatchState, BatchWorkflowContext, PositionState, WorkflowContext +from .batch_graph import build_segm_batch_graph +from .measurements_batch_graph import build_measurements_batch_graph +from ..state import ( + MeasurementsBatchContext, + MeasurementsContext, + MeasurementsGuiBatchContext, + MeasurementsGuiState, +) +from .measurements_gui_batch_graph import build_gui_measurements_batch_graph + + +def run_segm_batch( + ctx: WorkflowContext, + paths: list[str], + stop_frame_numbers: list[int], + config: RunnableConfig | None = None, + progress: tqdm | None = None, +) -> list[PositionState]: + """Run the position segmentation graph for each path.""" + config = config or RunnableConfig() + if progress is not None: + config.metadata["progress"] = progress + + batch_ctx = BatchWorkflowContext(position_ctx=ctx) + graph = build_segm_batch_graph(batch_ctx).compile() + batch_state = graph.invoke( + BatchState(paths=paths, stop_frame_numbers=stop_frame_numbers), + config, + ) + return batch_state.results + + +def run_measurements_batch( + kernel: Any, + paths: list[str], + stop_frame_numbers: list[int], + end_filename_segm: str, + config: RunnableConfig | None = None, + progress: tqdm | None = None, +) -> list[Any]: + config = config or RunnableConfig(logger_func=kernel.log) + if progress is not None: + config.metadata["progress"] = progress + + measurements_ctx = MeasurementsContext( + end_filename_segm=end_filename_segm, + kernel=kernel, + ) + batch_ctx = MeasurementsBatchContext(measurements_ctx=measurements_ctx) + graph = build_measurements_batch_graph(batch_ctx).compile() + batch_state = graph.invoke( + BatchState(paths=paths, stop_frame_numbers=stop_frame_numbers), + config, + ) + return batch_state.results + + +def run_gui_measurements_batch( + kernel: Any, + paths: list[str], + stop_frame_numbers: list[int], + end_filename_segm: str, + *, + compute_metrics_worker: Any | None = None, + save_data_worker: Any | None = None, + save_metrics: bool = True, + config: RunnableConfig | None = None, + progress: tqdm | None = None, +) -> list[MeasurementsGuiState]: + config = config or RunnableConfig(logger_func=kernel.log) + if progress is not None: + config.metadata["progress"] = progress + + batch_ctx = MeasurementsGuiBatchContext( + kernel=kernel, + compute_metrics_worker=compute_metrics_worker, + save_data_worker=save_data_worker, + save_metrics=save_metrics, + end_filename_segm=end_filename_segm, + ) + graph = build_gui_measurements_batch_graph(batch_ctx).compile() + batch_state = graph.invoke( + BatchState(paths=paths, stop_frame_numbers=stop_frame_numbers), + config, + ) + return batch_state.results + + +def batch_state_from_workflow_params(workflow_params: dict[str, Any]) -> BatchState: + paths = workflow_params["paths_info"]["paths"] + stop_frames = [int(n) for n in workflow_params["paths_info"]["stop_frame_numbers"]] + return BatchState(paths=paths, stop_frame_numbers=stop_frames) diff --git a/cellacdc/workflow/pipelines/batch_graph.py b/cellacdc/workflow/pipelines/batch_graph.py new file mode 100644 index 000000000..f33c0c12d --- /dev/null +++ b/cellacdc/workflow/pipelines/batch_graph.py @@ -0,0 +1,57 @@ +"""Parent graph for batch segmentation over many positions.""" + +from __future__ import annotations + +from typing import Any + +from ..constants import END +from ..graph import StateGraph +from ..runnable import RunnableConfig +from ..state import BatchState, BatchWorkflowContext, PositionState +from .segm import build_position_segm_graph + + +def _position_graph(ctx: BatchWorkflowContext): + if ctx.position_graph is None: + ctx.position_graph = build_position_segm_graph(ctx.position_ctx).compile() + return ctx.position_graph + + +def process_position( + state: BatchState, + ctx: BatchWorkflowContext, + config: RunnableConfig, +) -> dict[str, Any]: + path = state.paths[state.current_index] + stop_frame_n = state.stop_frame_numbers[state.current_index] + config.logger_func(f'\nProcessing "{path}"...') + result = _position_graph(ctx).invoke( + PositionState(img_path=path, stop_frame_n=stop_frame_n), + config, + ) + results = list(state.results) + results.append(result) + progress = config.metadata.get("progress") + if progress is not None: + progress.update(1) + return {"results": results, "current_index": state.current_index + 1} + + +def _route_batch(state: BatchState, _ctx: BatchWorkflowContext) -> str: + if state.current_index >= len(state.paths): + return END + return "process_position" + + +def build_segm_batch_graph( + ctx: BatchWorkflowContext, +) -> StateGraph[BatchState, BatchWorkflowContext]: + graph = StateGraph(BatchState, ctx) + graph.add_node("process_position", process_position) + graph.set_entry_point("process_position") + graph.add_conditional_edges( + "process_position", + _route_batch, + {"process_position": "process_position", END: END}, + ) + return graph diff --git a/cellacdc/workflow/pipelines/full_workflow.py b/cellacdc/workflow/pipelines/full_workflow.py new file mode 100644 index 000000000..030bdae67 --- /dev/null +++ b/cellacdc/workflow/pipelines/full_workflow.py @@ -0,0 +1,60 @@ +"""Top-level INI workflow orchestration graph.""" + +from __future__ import annotations + +from typing import Any + +from ..constants import END +from ..graph import StateGraph +from ..runnable import RunnableConfig +from ..state import FullWorkflowState + + +def run_segm_phase( + state: FullWorkflowState, + ctx: Any, + config: RunnableConfig, +) -> dict[str, Any]: + if not state.run_segm: + return {"segm_done": True} + + from cellacdc._run import run_segm_workflow + + run_segm_workflow(state.segm_params, ctx.logger, ctx.log_path) + return {"segm_done": True} + + +def run_measurements_phase( + state: FullWorkflowState, + ctx: Any, + config: RunnableConfig, +) -> dict[str, Any]: + if not state.run_measurements or state.measurements_params is None: + return {"measurements_done": True} + + from cellacdc._run import run_measurements_workflow + + run_measurements_workflow(state.measurements_params, ctx.logger, ctx.log_path) + return {"measurements_done": True} + + +def _route_after_segm(state: FullWorkflowState, _ctx: Any) -> str: + if state.run_measurements: + return "run_measurements_phase" + return END + + +def build_full_workflow_graph( + ctx: Any, +) -> StateGraph[FullWorkflowState, Any]: + graph = StateGraph(FullWorkflowState, ctx) + graph.add_node("run_segm_phase", run_segm_phase) + graph.add_node("run_measurements_phase", run_measurements_phase) + graph.set_entry_point("run_segm_phase") + graph.add_conditional_edges( + "run_segm_phase", + _route_after_segm, + {"run_measurements_phase": "run_measurements_phase", END: END}, + ) + graph.add_edge("run_measurements_phase", END) + return graph diff --git a/cellacdc/workflow/pipelines/interactive_segm.py b/cellacdc/workflow/pipelines/interactive_segm.py new file mode 100644 index 000000000..53fa8df94 --- /dev/null +++ b/cellacdc/workflow/pipelines/interactive_segm.py @@ -0,0 +1,28 @@ +"""Interactive single-frame segmentation graph.""" + +from __future__ import annotations + +from ..constants import END +from ..graph import StateGraph +from ..state import InteractiveSegmContext, InteractiveSegmState +from . import interactive_segm_nodes as nodes + + +def build_interactive_segm_graph( + ctx: InteractiveSegmContext, +) -> StateGraph[InteractiveSegmState, InteractiveSegmContext]: + graph = StateGraph(InteractiveSegmState, ctx) + graph.add_node("prepare_frame", nodes.prepare_frame) + graph.add_node("segment_frame", nodes.segment_frame) + graph.add_node("postprocess_frame", nodes.postprocess_frame) + graph.add_node("merge_result", nodes.merge_result) + graph.set_entry_point("prepare_frame") + graph.add_edge("prepare_frame", "segment_frame") + graph.add_conditional_edges( + "segment_frame", + nodes._route_postprocess, + {"postprocess_frame": "postprocess_frame", "merge_result": "merge_result"}, + ) + graph.add_edge("postprocess_frame", "merge_result") + graph.add_edge("merge_result", END) + return graph diff --git a/cellacdc/workflow/pipelines/interactive_segm_nodes.py b/cellacdc/workflow/pipelines/interactive_segm_nodes.py new file mode 100644 index 000000000..fc857f1d7 --- /dev/null +++ b/cellacdc/workflow/pipelines/interactive_segm_nodes.py @@ -0,0 +1,112 @@ +"""Interactive single-frame segmentation nodes for the main viewer.""" + +from __future__ import annotations + +import time +from typing import Any + +from cellacdc import core + +from ..runnable import RunnableConfig +from ..state import InteractiveSegmContext, InteractiveSegmState +from .postprocess_nodes import apply_postprocess + + +def prepare_frame( + state: InteractiveSegmState, + ctx: InteractiveSegmContext, + config: RunnableConfig, +) -> dict[str, Any]: + main_win = state.main_win + pos_data = main_win.data[main_win.pos_i] + + if ctx.segment_3d: + img = main_win.getDisplayedZstack() + if ctx.z_range is not None: + start_z, stop_z = ctx.z_range + img = img[start_z : stop_z + 1] + else: + img = main_win.getDisplayedImg1() + + lab = __import__("numpy").zeros_like(pos_data.segm_data[0]) + start_z_slice = 0 + if ctx.z_range is not None: + start_z_slice, _ = ctx.z_range + elif not ctx.segment_3d and pos_data.isSegm3D: + idx = (pos_data.filename, pos_data.frame_i) + start_z_slice = pos_data.segmInfo_df.at[idx, "z_slice_used_gui"] + + return { + "pos_data": pos_data, + "img": img, + "lab": lab, + "start_z_slice": start_z_slice, + } + + +def segment_frame( + state: InteractiveSegmState, + ctx: InteractiveSegmContext, + config: RunnableConfig, +) -> dict[str, Any]: + img = state.img + if ctx.second_channel_data is not None: + img = ctx.model.second_ch_img_to_stack(img, ctx.second_channel_data) + + lab = core.segm_model_segment( + ctx.model, + img, + ctx.model_kwargs, + frame_i=state.pos_data.frame_i, + posData=state.pos_data, + start_z_slice=state.start_z_slice, + ) + state.pos_data.saveSamEmbeddings(logger_func=config.logger_func) + return {"img": img, "segmented_lab": lab} + + +def postprocess_frame( + state: InteractiveSegmState, + ctx: InteractiveSegmContext, + config: RunnableConfig, +) -> dict[str, Any]: + if not ctx.apply_postprocessing: + return {} + + lab = apply_postprocess( + state.segmented_lab, + state.img, + state.pos_data, + state.pos_data.frame_i, + apply_postprocessing=True, + standard_postprocess_kwargs=ctx.standard_postprocess_kwargs, + custom_postprocess_features=ctx.custom_postprocess_features, + custom_postprocess_grouped_features=ctx.custom_postprocess_grouped_features, + ) + return {"segmented_lab": lab} + + +def merge_result( + state: InteractiveSegmState, + ctx: InteractiveSegmContext, + config: RunnableConfig, +) -> dict[str, Any]: + pos_data = state.pos_data + lab = state.lab + segmented = state.segmented_lab + + if ctx.z_range is not None: + start_z, stop_z = ctx.z_range + lab[start_z : stop_z + 1] = segmented + elif not ctx.segment_3d and pos_data.isSegm3D: + idx = (pos_data.filename, pos_data.frame_i) + z = pos_data.segmInfo_df.at[idx, "z_slice_used_gui"] + lab[z] = segmented + else: + lab = segmented + + return {"lab": lab} + + +def _route_postprocess(_state: InteractiveSegmState, ctx: InteractiveSegmContext) -> str: + return "postprocess_frame" if ctx.apply_postprocessing else "merge_result" diff --git a/cellacdc/workflow/pipelines/interactive_video_segm.py b/cellacdc/workflow/pipelines/interactive_video_segm.py new file mode 100644 index 000000000..042634ec1 --- /dev/null +++ b/cellacdc/workflow/pipelines/interactive_video_segm.py @@ -0,0 +1,24 @@ +"""Interactive timelapse segmentation graph.""" + +from __future__ import annotations + +from ..constants import END +from ..graph import StateGraph +from ..state import InteractiveVideoSegmContext, InteractiveVideoSegmState +from . import interactive_video_segm_nodes as nodes + + +def build_interactive_video_segm_graph( + ctx: InteractiveVideoSegmContext, +) -> StateGraph[InteractiveVideoSegmState, InteractiveVideoSegmContext]: + graph = StateGraph(InteractiveVideoSegmState, ctx) + graph.add_node("extend_segm_data", nodes.extend_segm_data) + graph.add_node("prepare_video_stack", nodes.prepare_video_stack) + graph.add_node("segment_video_frames", nodes.segment_video_frames) + graph.add_node("finalize_video_run", nodes.finalize_video_run) + graph.set_entry_point("extend_segm_data") + graph.add_edge("extend_segm_data", "prepare_video_stack") + graph.add_edge("prepare_video_stack", "segment_video_frames") + graph.add_edge("segment_video_frames", "finalize_video_run") + graph.add_edge("finalize_video_run", END) + return graph diff --git a/cellacdc/workflow/pipelines/interactive_video_segm_nodes.py b/cellacdc/workflow/pipelines/interactive_video_segm_nodes.py new file mode 100644 index 000000000..81b7ff8d4 --- /dev/null +++ b/cellacdc/workflow/pipelines/interactive_video_segm_nodes.py @@ -0,0 +1,118 @@ +"""Interactive timelapse segmentation nodes for the main viewer.""" + +from __future__ import annotations + +import time +from typing import Any + +import numpy as np +import pandas as pd + +from cellacdc import core + +from ..runnable import RunnableConfig +from ..state import InteractiveVideoSegmContext, InteractiveVideoSegmState +from .postprocess_nodes import apply_postprocess + + +def extend_segm_data( + state: InteractiveVideoSegmState, + ctx: InteractiveVideoSegmContext, + config: RunnableConfig, +) -> dict[str, Any]: + pos_data = state.pos_data + segm_data = pos_data.segm_data + stop_frame_num = ctx.stop_frame_num + + if stop_frame_num <= len(segm_data): + return {"segm_data": segm_data} + + extended_shape = (stop_frame_num, *segm_data.shape[1:]) + extended_segm_data = np.zeros(extended_shape, dtype=segm_data.dtype) + extended_segm_data[: len(segm_data)] = segm_data + + if len(extended_shape) == 4 or pos_data.SizeZ == 1: + pos_data.segm_data = extended_segm_data + return {"segm_data": extended_segm_data} + + num_added_frames = len(extended_segm_data) - len(segm_data) + half_z = int(pos_data.SizeZ / 2) + segm_info_extended = pd.DataFrame( + { + "filename": [pos_data.filename] * num_added_frames, + "frame_i": list(range(len(segm_data), len(extended_segm_data))), + "z_slice_used_gui": [half_z] * num_added_frames, + "which_z_proj_gui": ["single z-slice"] * num_added_frames, + } + ).set_index(["filename", "frame_i"]) + pos_data.segmInfo_df = pd.concat([pos_data.segmInfo_df, segm_info_extended]) + pos_data.segmInfo_df.to_csv(pos_data.segmInfo_df_csv_path) + pos_data.segm_data = extended_segm_data + return {"segm_data": extended_segm_data} + + +def prepare_video_stack( + state: InteractiveVideoSegmState, + ctx: InteractiveVideoSegmContext, + config: RunnableConfig, +) -> dict[str, Any]: + pos_data = state.pos_data + img_data = pos_data.img_data[ctx.start_frame_num - 1 : ctx.stop_frame_num] + is_4d = img_data.ndim == 4 + is_2d_segm = pos_data.segm_data.ndim == 3 + z_slices = None + if is_4d and is_2d_segm: + z_slices = pos_data.segmInfo_df.loc[pos_data.filename, "z_slice_used_gui"] + return {"img_data": img_data, "z_slices": z_slices} + + +def segment_video_frames( + state: InteractiveVideoSegmState, + ctx: InteractiveVideoSegmContext, + config: RunnableConfig, +) -> dict[str, Any]: + pos_data = state.pos_data + progress = ctx.progress_callback + + for i, img in enumerate(state.img_data): + frame_i = i + ctx.start_frame_num - 1 + if ctx.second_channel_data is not None: + img = ctx.model.second_ch_img_to_stack(img, ctx.second_channel_data) + if state.z_slices is not None: + img = img[state.z_slices.loc[frame_i]] + + lab = core.segm_model_segment( + ctx.model, + img, + ctx.model_kwargs, + frame_i=frame_i, + preproc_recipe=ctx.preproc_recipe, + posData=pos_data, + ) + pos_data.saveSamEmbeddings(logger_func=ctx.logger_func) + + if ctx.apply_postprocessing: + lab = apply_postprocess( + lab, + img, + pos_data, + frame_i, + apply_postprocessing=True, + standard_postprocess_kwargs=ctx.standard_postprocess_kwargs, + custom_postprocess_features=ctx.custom_postprocess_features, + custom_postprocess_grouped_features=ctx.custom_postprocess_grouped_features, + ) + + pos_data.segm_data[frame_i] = lab + if progress is not None: + progress.emit(1) + + return {} + + +def finalize_video_run( + state: InteractiveVideoSegmState, + ctx: InteractiveVideoSegmContext, + config: RunnableConfig, +) -> dict[str, Any]: + return {} diff --git a/cellacdc/workflow/pipelines/measurements.py b/cellacdc/workflow/pipelines/measurements.py new file mode 100644 index 000000000..2c6a8da23 --- /dev/null +++ b/cellacdc/workflow/pipelines/measurements.py @@ -0,0 +1,30 @@ +"""Measurements position pipeline graph.""" + +from __future__ import annotations + +from ..constants import END +from ..graph import StateGraph +from ..state import MeasurementsContext, MeasurementsState +from . import measurements_nodes as nodes + + +def build_measurements_position_graph( + ctx: MeasurementsContext, +) -> StateGraph[MeasurementsState, MeasurementsContext]: + graph = StateGraph(MeasurementsState, ctx) + graph.add_node("load_position", nodes.load_position) + graph.add_node("validate_segm", nodes.validate_segm) + graph.add_node("compute_and_save", nodes.compute_and_save) + graph.set_entry_point("load_position") + graph.add_conditional_edges( + "load_position", + nodes._route_after_load, + {"validate_segm": "validate_segm", END: END}, + ) + graph.add_conditional_edges( + "validate_segm", + nodes._route_after_validate, + {"compute_and_save": "compute_and_save", END: END}, + ) + graph.add_edge("compute_and_save", END) + return graph diff --git a/cellacdc/workflow/pipelines/measurements_batch_graph.py b/cellacdc/workflow/pipelines/measurements_batch_graph.py new file mode 100644 index 000000000..74af5e5d5 --- /dev/null +++ b/cellacdc/workflow/pipelines/measurements_batch_graph.py @@ -0,0 +1,57 @@ +"""Measurements batch parent graph.""" + +from __future__ import annotations + +from ..constants import END +from ..graph import StateGraph +from ..runnable import RunnableConfig +from ..state import BatchState, MeasurementsBatchContext, MeasurementsState +from .measurements import build_measurements_position_graph + + +def _position_graph(ctx: MeasurementsBatchContext): + if ctx.position_graph is None: + ctx.position_graph = build_measurements_position_graph( + ctx.measurements_ctx + ).compile() + return ctx.position_graph + + +def process_position( + state: BatchState, + ctx: MeasurementsBatchContext, + config: RunnableConfig, +) -> dict[str, Any]: + path = state.paths[state.current_index] + stop_frame_n = state.stop_frame_numbers[state.current_index] + config.logger_func(f'\nProcessing "{path}"...') + result = _position_graph(ctx).invoke( + MeasurementsState(img_path=path, stop_frame_n=stop_frame_n), + config, + ) + results = list(state.results) + results.append(result) + progress = config.metadata.get("progress") + if progress is not None: + progress.update(1) + return {"results": results, "current_index": state.current_index + 1} + + +def _route_batch(state: BatchState, _ctx: MeasurementsBatchContext) -> str: + if state.current_index >= len(state.paths): + return END + return "process_position" + + +def build_measurements_batch_graph( + ctx: MeasurementsBatchContext, +) -> StateGraph[BatchState, MeasurementsBatchContext]: + graph = StateGraph(BatchState, ctx) + graph.add_node("process_position", process_position) + graph.set_entry_point("process_position") + graph.add_conditional_edges( + "process_position", + _route_batch, + {"process_position": "process_position", END: END}, + ) + return graph diff --git a/cellacdc/workflow/pipelines/measurements_gui.py b/cellacdc/workflow/pipelines/measurements_gui.py new file mode 100644 index 000000000..3fa959d2f --- /dev/null +++ b/cellacdc/workflow/pipelines/measurements_gui.py @@ -0,0 +1,35 @@ +"""GUI measurements position pipeline graph.""" + +from __future__ import annotations + +from ..constants import END +from ..graph import StateGraph +from ..state import MeasurementsGuiContext, MeasurementsGuiState +from . import measurements_gui_nodes as nodes + + +def build_gui_measurements_graph( + ctx: MeasurementsGuiContext, + *, + pos_data_loaded: bool = False, +) -> StateGraph[MeasurementsGuiState, MeasurementsGuiContext]: + graph = StateGraph(MeasurementsGuiState, ctx) + graph.add_node("load_position", nodes.load_position) + graph.add_node("prepare_gui_run", nodes.prepare_gui_run) + graph.add_node("compute_metrics_frames", nodes.compute_metrics_frames) + graph.add_node("save_metrics_results", nodes.save_metrics_results) + + if pos_data_loaded: + graph.set_entry_point("prepare_gui_run") + else: + graph.set_entry_point("load_position") + graph.add_edge("load_position", "prepare_gui_run") + + graph.add_conditional_edges( + "prepare_gui_run", + nodes._route_after_prepare, + {"compute_metrics_frames": "compute_metrics_frames", END: END}, + ) + graph.add_edge("compute_metrics_frames", "save_metrics_results") + graph.add_edge("save_metrics_results", END) + return graph diff --git a/cellacdc/workflow/pipelines/measurements_gui_batch_graph.py b/cellacdc/workflow/pipelines/measurements_gui_batch_graph.py new file mode 100644 index 000000000..f7da6d3fe --- /dev/null +++ b/cellacdc/workflow/pipelines/measurements_gui_batch_graph.py @@ -0,0 +1,70 @@ +"""GUI measurements batch parent graph.""" + +from __future__ import annotations + +from typing import Any + +from ..constants import END +from ..graph import StateGraph +from ..runnable import RunnableConfig +from ..state import BatchState, MeasurementsGuiBatchContext, MeasurementsGuiContext, MeasurementsGuiState +from .measurements_gui import build_gui_measurements_graph + + +def process_position( + state: BatchState, + ctx: MeasurementsGuiBatchContext, + config: RunnableConfig, +) -> dict[str, Any]: + path = state.paths[state.current_index] + stop_frame_n = state.stop_frame_numbers[state.current_index] + config.logger_func(f'\nProcessing "{path}"...') + + gui_ctx = MeasurementsGuiContext( + kernel=ctx.kernel, + compute_metrics_worker=ctx.compute_metrics_worker, + save_data_worker=ctx.save_data_worker, + save_metrics=ctx.save_metrics, + do_init_metrics=state.current_index == 0, + end_filename_segm=ctx.end_filename_segm, + ) + graph = build_gui_measurements_graph(gui_ctx, pos_data_loaded=False).compile() + result = graph.invoke( + MeasurementsGuiState(img_path=path, stop_frame_n=stop_frame_n), + config, + ) + + results = list(state.results) + results.append(result) + progress = config.metadata.get("progress") + if progress is not None: + progress.update(1) + + aborted = bool(getattr(result, "aborted", False)) + return { + "results": results, + "current_index": state.current_index + 1, + "aborted": aborted or state.aborted, + } + + +def _route_batch(state: BatchState, ctx: MeasurementsGuiBatchContext) -> str: + if state.aborted or ctx.kernel.setup_done: + return END + if state.current_index >= len(state.paths): + return END + return "process_position" + + +def build_gui_measurements_batch_graph( + ctx: MeasurementsGuiBatchContext, +) -> StateGraph[BatchState, MeasurementsGuiBatchContext]: + graph = StateGraph(BatchState, ctx) + graph.add_node("process_position", process_position) + graph.set_entry_point("process_position") + graph.add_conditional_edges( + "process_position", + _route_batch, + {"process_position": "process_position", END: END}, + ) + return graph diff --git a/cellacdc/workflow/pipelines/measurements_gui_nodes.py b/cellacdc/workflow/pipelines/measurements_gui_nodes.py new file mode 100644 index 000000000..1aeb35d09 --- /dev/null +++ b/cellacdc/workflow/pipelines/measurements_gui_nodes.py @@ -0,0 +1,135 @@ +"""GUI measurements pipeline nodes.""" + +from __future__ import annotations + +import os +import traceback + +import numpy as np +import skimage.measure + +from cellacdc import load, utils + +from ..constants import END +from ..runnable import RunnableConfig +from ..state import MeasurementsGuiContext, MeasurementsGuiState + + +def load_position( + state: MeasurementsGuiState, + ctx: MeasurementsGuiContext, + config: RunnableConfig, +) -> dict[str, Any]: + end_name = ctx.end_filename_segm or ctx.kernel.end_filename_segm + pos_data = ctx.kernel._load_posData(state.img_path, end_name) + return {"pos_data": pos_data, "skipped": False, "aborted": False} + + +def prepare_gui_run( + state: MeasurementsGuiState, + ctx: MeasurementsGuiContext, + config: RunnableConfig, +) -> dict[str, Any]: + kernel = ctx.kernel + pos_data = state.pos_data + worker = ctx.compute_metrics_worker + save_worker = ctx.save_data_worker + exp_foldername = os.path.basename(pos_data.exp_path) + + kernel._set_metrics_func_from_posData(pos_data) + + if worker is not None and ctx.do_init_metrics: + worker.emitSigInitMetricsDialog(pos_data) + if worker.abort: + worker.signals.finished.emit(worker) + return {"aborted": True} + if kernel.setup_done: + worker.signals.finished.emit(worker) + return {"aborted": True} + worker.emitSigAskRunNow() + if worker.abort or worker.savedToWorkflow: + worker.signals.finished.emit(worker) + return {"aborted": True} + + if not pos_data.segmFound: + rel_path = f"...{os.sep}{exp_foldername}{os.sep}{pos_data.pos_foldername}" + kernel.log(f'Skipping "{rel_path}" because segm. file was not found.') + return {"skipped": True} + + kernel.init_signals(worker, save_worker) + kernel.log( + "Loading the following files:\n" + f"Segmentation file name: {os.path.basename(pos_data.segm_npz_path)}\n" + f"ACDC output file name: {os.path.basename(pos_data.acdc_output_csv_path)}" + ) + pos_data.init_segmInfo_df() + + if worker is not None: + worker.emitSigComputeVolume(pos_data, state.stop_frame_n) + + kernel._init_metrics_to_save(pos_data) + + if worker is not None: + worker.signals.initProgressBar.emit(state.stop_frame_n) + + channels_to_load = [ + ch + for ch in pos_data.chNames + if ch not in kernel.chNamesToSkip and ch in kernel.chNamesToProcess + ] + kernel.log(f"Loading channels {channels_to_load}...") + kernel._load_image_data(pos_data, channels_to_load) + return {} + + +def compute_metrics_frames( + state: MeasurementsGuiState, + ctx: MeasurementsGuiContext, + config: RunnableConfig, +) -> dict[str, Any]: + acdc_df_li, keys = ctx.kernel._compute_metrics_gui_frames( + state.pos_data, + state.stop_frame_n, + save_metrics=ctx.save_metrics, + compute_metrics_worker=ctx.compute_metrics_worker, + save_data_worker=ctx.save_data_worker, + ) + return {"acdc_df_li": acdc_df_li, "keys": keys} + + +def save_metrics_results( + state: MeasurementsGuiState, + ctx: MeasurementsGuiContext, + config: RunnableConfig, +) -> dict[str, Any]: + if not state.acdc_df_li: + exp_foldername = os.path.basename(state.pos_data.exp_path) + print("-" * 30) + ctx.kernel.log( + "All selected positions in the experiment folder " + f"{exp_foldername} have EMPTY segmentation mask. " + "Metrics will not be saved." + ) + print("-" * 30) + return {} + + ctx.kernel._concat_and_save_acdc_df( + state.acdc_df_li, + state.keys, + state.pos_data, + ctx.save_metrics, + computeMetricsWorker=ctx.compute_metrics_worker, + saveDataWorker=ctx.save_data_worker, + last_cca_frame_i=ctx.last_cca_frame_i, + ) + return {} + + +def _route_entry(state: MeasurementsGuiState, _ctx: MeasurementsGuiContext) -> str: + return "prepare_gui_run" if state.pos_data is not None else "load_position" + + +def _route_after_prepare(state: MeasurementsGuiState, _ctx: MeasurementsGuiContext) -> str: + if state.aborted or state.skipped: + return END + return "compute_metrics_frames" diff --git a/cellacdc/workflow/pipelines/measurements_nodes.py b/cellacdc/workflow/pipelines/measurements_nodes.py new file mode 100644 index 000000000..5dc83445c --- /dev/null +++ b/cellacdc/workflow/pipelines/measurements_nodes.py @@ -0,0 +1,61 @@ +"""Measurements position pipeline nodes.""" + +from __future__ import annotations + +import os +from typing import Any + +from ..constants import END +from ..runnable import RunnableConfig +from ..state import MeasurementsContext, MeasurementsState + + +def load_position( + state: MeasurementsState, + ctx: MeasurementsContext, + config: RunnableConfig, +) -> dict[str, Any]: + kernel = ctx.kernel + pos_data = kernel._load_posData(state.img_path, ctx.end_filename_segm) + return {"pos_data": pos_data, "skipped": False, "aborted": False, "error": None} + + +def validate_segm( + state: MeasurementsState, + ctx: MeasurementsContext, + config: RunnableConfig, +) -> dict[str, Any]: + pos_data = state.pos_data + if pos_data.segmFound: + return {} + + exp_foldername = os.path.basename(pos_data.exp_path) + rel_path = f"...{os.sep}{exp_foldername}{os.sep}{pos_data.pos_foldername}" + ctx.kernel.log(f'Skipping "{rel_path}" because segm. file was not found.') + return {"skipped": True} + + +def compute_and_save( + state: MeasurementsState, + ctx: MeasurementsContext, + config: RunnableConfig, +) -> dict[str, Any]: + ctx.kernel._run_metrics_cli( + state.pos_data, + state.stop_frame_n, + save_metrics=ctx.save_metrics, + last_cca_frame_i=ctx.last_cca_frame_i, + ) + return {} + + +def _route_after_validate(state: MeasurementsState, _ctx: MeasurementsContext) -> str: + if state.skipped: + return END + return "compute_and_save" + + +def _route_after_load(state: MeasurementsState, _ctx: MeasurementsContext) -> str: + if state.aborted: + return END + return "validate_segm" diff --git a/cellacdc/workflow/pipelines/postprocess_nodes.py b/cellacdc/workflow/pipelines/postprocess_nodes.py new file mode 100644 index 000000000..0b150e34c --- /dev/null +++ b/cellacdc/workflow/pipelines/postprocess_nodes.py @@ -0,0 +1,39 @@ +"""Shared postprocess helpers for workflow nodes.""" + +from __future__ import annotations + +from typing import Any + +from cellacdc import core, features + + +def apply_postprocess( + lab: Any, + img: Any, + pos_data: Any, + frame_i: int, + *, + apply_postprocessing: bool, + standard_postprocess_kwargs: dict[str, Any], + custom_postprocess_features: dict[str, Any], + custom_postprocess_grouped_features: dict[str, Any], + user_ch_name: str | None = None, +) -> Any: + if not apply_postprocessing: + return lab + + lab = core.post_process_segm(lab, **standard_postprocess_kwargs) + if not custom_postprocess_features: + return lab + + ch_name = user_ch_name or pos_data.user_ch_name + return features.custom_post_process_segm( + pos_data, + custom_postprocess_grouped_features, + lab, + img, + frame_i, + pos_data.filename, + ch_name, + custom_postprocess_features, + ) diff --git a/cellacdc/workflow/pipelines/segm.py b/cellacdc/workflow/pipelines/segm.py new file mode 100644 index 000000000..aee4676bd --- /dev/null +++ b/cellacdc/workflow/pipelines/segm.py @@ -0,0 +1,67 @@ +"""Segmentation position pipeline graph.""" + +from __future__ import annotations + +from ..constants import END +from ..graph import StateGraph +from ..state import PositionState, WorkflowContext +from . import segm_nodes as nodes + + +def build_position_segm_graph( + ctx: WorkflowContext, +) -> StateGraph[PositionState, WorkflowContext]: + """Build the per-position segmentation graph.""" + graph = StateGraph(PositionState, ctx) + for name, fn in ( + ("load_position", nodes.load_position), + ("prepare_stack", nodes.prepare_stack), + ("ensure_model", nodes.ensure_model), + ("segment", nodes.segment), + ("filter_freehand_roi", nodes.filter_freehand_roi), + ("postprocess", nodes.postprocess), + ("before_track", nodes.passthrough), + ("track", nodes.track), + ("skip_track", nodes.skip_track_progress), + ("before_pad", nodes.passthrough), + ("pad_roi", nodes.pad_roi), + ("before_save", nodes.passthrough), + ("save", nodes.save), + ): + graph.add_node(name, fn) + + graph.set_entry_point("load_position") + graph.add_edge("load_position", "prepare_stack") + graph.add_edge("prepare_stack", "ensure_model") + graph.add_conditional_edges( + "ensure_model", + nodes._route_after_model, + {"segment": "segment", END: END}, + ) + graph.add_edge("segment", "filter_freehand_roi") + graph.add_conditional_edges( + "filter_freehand_roi", + nodes._route_postprocess, + {"postprocess": "postprocess", "before_track": "before_track"}, + ) + graph.add_edge("postprocess", "before_track") + graph.add_conditional_edges( + "before_track", + nodes._route_track, + {"track": "track", "skip_track": "skip_track"}, + ) + graph.add_edge("track", "before_pad") + graph.add_edge("skip_track", "before_pad") + graph.add_conditional_edges( + "before_pad", + nodes._route_pad_roi, + {"pad_roi": "pad_roi", "before_save": "before_save"}, + ) + graph.add_edge("pad_roi", "before_save") + graph.add_conditional_edges( + "before_save", + nodes._route_save, + {"save": "save", END: END}, + ) + graph.add_edge("save", END) + return graph diff --git a/cellacdc/workflow/pipelines/segm_nodes.py b/cellacdc/workflow/pipelines/segm_nodes.py new file mode 100644 index 000000000..68e48dec3 --- /dev/null +++ b/cellacdc/workflow/pipelines/segm_nodes.py @@ -0,0 +1,528 @@ +"""Segmentation pipeline node implementations.""" + +from __future__ import annotations + +import os +import time +from typing import Any + +import numpy as np +from tqdm import tqdm + +from cellacdc import core, features, io, load, utils + +from ..constants import END +from ..runnable import RunnableConfig +from ..state import PositionState, WorkflowContext + + +def passthrough( + state: PositionState, + ctx: WorkflowContext, + config: RunnableConfig, +) -> dict[str, Any]: + return {} + + +def load_position( + state: PositionState, + ctx: WorkflowContext, + config: RunnableConfig, +) -> dict[str, Any]: + pos_data = load.loadData(state.img_path, ctx.user_ch_name) + config.logger_func(f"Loading {pos_data.relPath}...") + + pos_data.getBasenameAndChNames() + pos_data.buildPaths() + pos_data.loadImgData() + pos_data.loadOtherFiles( + load_segm_data=False, + load_acdc_df=False, + load_shifts=True, + loadSegmInfo=True, + load_delROIsInfo=False, + load_dataPrep_ROIcoords=True, + load_bkgr_data=True, + load_last_tracked_i=False, + load_metadata=True, + load_dataprep_free_roi=True, + end_filename_segm=ctx.segm_endname, + ) + + end_name = ( + ctx.segm_endname.replace("segm", "", 1).replace("_", "", 1).split(".")[0] + ) + if end_name: + pos_data.setFilePaths(end_name) + + if ctx.do_save: + segm_filename = os.path.basename(pos_data.segm_npz_path) + config.logger_func(f"\nSegmentation file {segm_filename}...") + + pos_data.SizeT = ctx.size_t + if ctx.size_z > 1: + pos_data.SizeZ = pos_data.img_data.shape[-3] + else: + pos_data.SizeZ = 1 + + pos_data.isSegm3D = ctx.is_segm_3d + pos_data.saveMetadata() + + is_roi_active = False + roi_bounds = None + if pos_data.dataPrep_ROIcoords is not None and ctx.use_roi: + df_roi = pos_data.dataPrep_ROIcoords.loc[0] + is_roi_active = df_roi.at["cropped", "value"] == 0 + x0, x1, y0, y1 = df_roi["value"].astype(int)[:4] + y_shape, x_shape = pos_data.img_data.shape[-2:] + x0 = x0 if x0 > 0 else 0 + y0 = y0 if y0 > 0 else 0 + x1 = x1 if x1 < x_shape else x_shape + y1 = y1 if y1 < y_shape else y_shape + roi_bounds = (x0, x1, y0, y1) + + return { + "pos_data": pos_data, + "is_roi_active": is_roi_active, + "roi_bounds": roi_bounds, + "stop_i": state.stop_frame_n, + "t0": 0, + "aborted": False, + "error": None, + } + + +def prepare_stack( + state: PositionState, + ctx: WorkflowContext, + config: RunnableConfig, +) -> dict[str, Any]: + pos_data = state.pos_data + stop_i = state.stop_i + is_roi_active = state.is_roi_active + pad_info = None + second_ch_data = None + + if ctx.second_channel_name is not None: + config.logger_func(f'Loading second channel "{ctx.second_channel_name}"...') + second_ch_filepath = load.get_filename_from_channel( + pos_data.images_path, ctx.second_channel_name + ) + second_ch_img_data = load.load_image_file(second_ch_filepath) + else: + second_ch_img_data = None + + x0 = x1 = y0 = y1 = 0 + if state.roi_bounds is not None: + x0, x1, y0, y1 = state.roi_bounds + + if pos_data.SizeT > 1: + t0 = state.t0 + if pos_data.SizeZ > 1 and not ctx.is_segm_3d and not ctx.use_3d_data_for_2d_segm: + img_data = pos_data.img_data + if ctx.second_channel_name is not None: + second_ch_data_slice = second_ch_img_data[t0:stop_i] + if is_roi_active: + y_shape, x_shape = img_data.shape[-2:] + img_data = img_data[:, :, y0:y1, x0:x1] + if ctx.second_channel_name is not None: + second_ch_data_slice = second_ch_data_slice[:, :, y0:y1, x0:x1] + pad_info = ((0, 0), (y0, y_shape - y1), (x0, x_shape - x1)) + + img_data_slice = img_data[t0:stop_i] + postprocess_img = img_data + y_shape, x_shape = img_data.shape[-2:] + new_shape = (stop_i, y_shape, x_shape) + img_data = np.zeros(new_shape, img_data.dtype) + if ctx.second_channel_name is not None: + second_ch_data = np.zeros(new_shape, second_ch_img_data.dtype) + + df = pos_data.segmInfo_df.loc[pos_data.filename] + for z_info in df[:stop_i].itertuples(): + i = z_info.Index + z = z_info.z_slice_used_dataPrep + z_proj_how = z_info.which_z_proj + img = img_data_slice[i] + if ctx.second_channel_name is not None: + second_ch_img = second_ch_data_slice[i] + if z_proj_how == "single z-slice": + img_data[i] = img[z] + if ctx.second_channel_name is not None: + second_ch_data[i] = second_ch_img[z] + elif z_proj_how == "max z-projection": + img_data[i] = img.max(axis=0) + if ctx.second_channel_name is not None: + second_ch_data[i] = second_ch_img.max(axis=0) + elif z_proj_how == "mean z-projection": + img_data[i] = img.mean(axis=0) + if ctx.second_channel_name is not None: + second_ch_data[i] = second_ch_img.mean(axis=0) + elif z_proj_how == "median z-proj.": + img_data[i] = np.median(img, axis=0) + if ctx.second_channel_name is not None: + second_ch_data[i] = np.median(second_ch_img, axis=0) + elif pos_data.SizeZ > 1 and (ctx.is_segm_3d or ctx.use_3d_data_for_2d_segm): + img_data = pos_data.img_data[t0:stop_i] + postprocess_img = img_data + if ctx.second_channel_name is not None: + second_ch_data = second_ch_img_data[t0:stop_i] + if is_roi_active: + y_shape, x_shape = img_data.shape[-2:] + img_data = img_data[:, :, y0:y1, x0:x1] + if ctx.second_channel_name is not None: + second_ch_data = second_ch_data[:, :, y0:y1, x0:x1] + pad_info = ((0, 0), (0, 0), (y0, y_shape - y1), (x0, x_shape - x1)) + else: + img_data = pos_data.img_data[t0:stop_i] + postprocess_img = img_data + if ctx.second_channel_name is not None: + second_ch_data = second_ch_img_data[t0:stop_i] + if is_roi_active: + y_shape, x_shape = img_data.shape[-2:] + img_data = img_data[:, y0:y1, x0:x1] + if ctx.second_channel_name is not None: + second_ch_data = second_ch_data[:, :, y0:y1, x0:x1] + pad_info = ((0, 0), (y0, y_shape - y1), (x0, x_shape - x1)) + elif pos_data.SizeZ > 1 and not ctx.is_segm_3d and not ctx.use_3d_data_for_2d_segm: + img_data = pos_data.img_data + if ctx.second_channel_name is not None: + second_ch_data = second_ch_img_data + if is_roi_active: + y_shape, x_shape = img_data.shape[-2:] + pad_info = ((y0, y_shape - y1), (x0, x_shape - x1)) + img_data = img_data[:, y0:y1, x0:x1] + if ctx.second_channel_name is not None: + second_ch_data = second_ch_data[:, :, y0:y1, x0:x1] + + postprocess_img = img_data + z_info = pos_data.segmInfo_df.loc[pos_data.filename].iloc[0] + z = z_info.z_slice_used_dataPrep + z_proj_how = z_info.which_z_proj + if z_proj_how == "single z-slice": + img_data = img_data[z] + if ctx.second_channel_name is not None: + second_ch_data = second_ch_data[z] + elif z_proj_how == "max z-projection": + img_data = img_data.max(axis=0) + if ctx.second_channel_name is not None: + second_ch_data = second_ch_data.max(axis=0) + elif z_proj_how == "mean z-projection": + img_data = img_data.mean(axis=0) + if ctx.second_channel_name is not None: + second_ch_data = second_ch_data.mean(axis=0) + elif z_proj_how == "median z-proj.": + img_data = np.median(img_data, axis=0) + if ctx.second_channel_name is not None: + second_ch_data = np.median(second_ch_data, axis=0) + elif pos_data.SizeZ > 1 and (ctx.is_segm_3d or ctx.use_3d_data_for_2d_segm): + img_data = pos_data.img_data + if ctx.second_channel_name is not None: + second_ch_data = second_ch_img_data + if is_roi_active: + y_shape, x_shape = img_data.shape[-2:] + pad_info = ((0, 0), (y0, y_shape - y1), (x0, x_shape - x1)) + img_data = img_data[:, y0:y1, x0:x1] + if ctx.second_channel_name is not None: + second_ch_data = second_ch_data[:, y0:y1, x0:x1] + postprocess_img = img_data + else: + img_data = pos_data.img_data + if ctx.second_channel_name is not None: + second_ch_data = second_ch_img_data + if is_roi_active: + y_shape, x_shape = img_data.shape[-2:] + pad_info = ((y0, y_shape - y1), (x0, x_shape - x1)) + img_data = img_data[y0:y1, x0:x1] + if ctx.second_channel_name is not None: + second_ch_data = second_ch_data[y0:y1, x0:x1] + postprocess_img = img_data + + config.logger_func(f"\nImage shape = {img_data.shape}") + return { + "img_data": img_data, + "second_ch_data": second_ch_data, + "postprocess_img": postprocess_img, + "pad_info": pad_info, + } + + +def ensure_model( + state: PositionState, + ctx: WorkflowContext, + config: RunnableConfig, +) -> dict[str, Any]: + if ctx.model is not None: + return {} + + if ctx.signals is not None: + ctx.signals.progress.emit( + f"\nInitializing {ctx.model_name} segmentation model..." + ) + else: + config.logger_func(f"\nInitializing {ctx.model_name} segmentation model...") + + acdc_segment = utils.import_segment_module(ctx.model_name) + init_argspecs, segment_argspecs = utils.getModelArgSpec(acdc_segment) + ctx.init_model_kwargs = utils.parse_model_params( + init_argspecs, ctx.init_model_kwargs + ) + ctx.model_kwargs = utils.parse_model_params(segment_argspecs, ctx.model_kwargs) + if ctx.second_channel_name is not None: + ctx.init_model_kwargs["is_rgb"] = True + + ctx.model = utils.init_segm_model( + acdc_segment, state.pos_data, ctx.init_model_kwargs + ) + if ctx.model is None: + message = f"Segmentation model {ctx.model_name} was not initialized!" + config.logger_func(f"\n{message}") + return {"aborted": True, "error": message} + + ctx.is_segment3dt_available = any( + name == "segment3DT" for name in dir(ctx.model) + ) and not ctx.reduce_memory_usage + return {"model": ctx.model} + + +def segment( + state: PositionState, + ctx: WorkflowContext, + config: RunnableConfig, +) -> dict[str, Any]: + pos_data = state.pos_data + img_data = state.img_data + second_ch_data = state.second_ch_data + + config.logger_func(f"\nSegmenting with {ctx.model_name}...") + time.perf_counter() + + if pos_data.SizeT > 1: + if ctx.inner_pbar_available and ctx.signals is not None: + ctx.signals.resetInnerPbar.emit(len(img_data)) + + if ctx.is_segment3dt_available and img_data.ndim == 3: + ctx.model_kwargs["signals"] = (ctx.signals, ctx.inner_pbar_available) + if ctx.second_channel_name is not None: + img_data = ctx.model.second_ch_img_to_stack(img_data, second_ch_data) + lab_stack = core.segm_model_segment( + ctx.model, + img_data, + ctx.model_kwargs, + is_timelapse_model_and_data=True, + preproc_recipe=ctx.preproc_recipe, + posData=pos_data, + ) + if ctx.inner_pbar_available and ctx.signals is not None: + ctx.signals.progressBar.emit(1) + else: + lab_stack = [] + pbar = tqdm(total=len(img_data), ncols=100) + for t, img in enumerate(img_data): + if ctx.second_channel_name is not None: + img = ctx.model.second_ch_img_to_stack(img, second_ch_data[t]) + lab = core.segm_model_segment( + ctx.model, + img, + ctx.model_kwargs, + frame_i=t, + preproc_recipe=ctx.preproc_recipe, + posData=pos_data, + ) + lab_stack.append(lab) + if ctx.signals is not None: + if ctx.inner_pbar_available: + ctx.signals.innerProgressBar.emit(1) + else: + ctx.signals.progressBar.emit(1) + pbar.update() + pbar.close() + lab_stack = np.array(lab_stack, dtype=np.uint32) + if ctx.inner_pbar_available and ctx.signals is not None: + ctx.signals.progressBar.emit(1) + else: + if ctx.second_channel_name is not None: + img_data = ctx.model.second_ch_img_to_stack(img_data, second_ch_data) + lab_stack = core.segm_model_segment( + ctx.model, + img_data, + ctx.model_kwargs, + frame_i=0, + preproc_recipe=ctx.preproc_recipe, + posData=pos_data, + ) + if ctx.signals is not None: + ctx.signals.progressBar.emit(1) + + pos_data.saveSamEmbeddings(logger_func=config.logger_func) + return {"lab_stack": lab_stack, "img_data": img_data} + + +def filter_freehand_roi( + state: PositionState, + ctx: WorkflowContext, + config: RunnableConfig, +) -> dict[str, Any]: + pos_data = state.pos_data + lab_stack = state.lab_stack + if len(pos_data.dataPrepFreeRoiPoints) > 0 and ctx.use_freehand_roi: + config.logger_func("Removing objects outside the dataprep free-hand ROI...") + lab_stack = pos_data.clearSegmObjsDataPrepFreeRoi( + lab_stack, is_timelapse=pos_data.SizeT > 1 + ) + return {"lab_stack": lab_stack} + + +def postprocess( + state: PositionState, + ctx: WorkflowContext, + config: RunnableConfig, +) -> dict[str, Any]: + pos_data = state.pos_data + lab_stack = state.lab_stack + postprocess_img = state.postprocess_img + + if pos_data.SizeT > 1: + pbar = tqdm(total=len(lab_stack), ncols=100) + for t, lab in enumerate(lab_stack): + lab_cleaned = core.post_process_segm( + lab, **ctx.standard_postprocess_kwargs + ) + lab_stack[t] = lab_cleaned + if ctx.custom_postprocess_features: + lab_filtered = features.custom_post_process_segm( + pos_data, + ctx.custom_postprocess_grouped_features, + lab_cleaned, + postprocess_img, + t, + pos_data.filename, + pos_data.user_ch_name, + ctx.custom_postprocess_features, + ) + lab_stack[t] = lab_filtered + pbar.update() + pbar.close() + else: + lab_stack = core.post_process_segm( + lab_stack, **ctx.standard_postprocess_kwargs + ) + if ctx.custom_postprocess_features: + lab_stack = features.custom_post_process_segm( + pos_data, + ctx.custom_postprocess_grouped_features, + lab_stack, + postprocess_img, + 0, + pos_data.filename, + pos_data.user_ch_name, + ctx.custom_postprocess_features, + ) + return {"lab_stack": lab_stack} + + +def track( + state: PositionState, + ctx: WorkflowContext, + config: RunnableConfig, +) -> dict[str, Any]: + pos_data = state.pos_data + lab_stack = state.lab_stack + + config.logger_func(f"\nTracking with {ctx.tracker_name} tracker...") + if ctx.do_save: + config.logger_func(f"Saving NON-tracked masks of {pos_data.relPath}...") + io.savez_compressed(pos_data.segm_npz_path, lab_stack) + + if ctx.signals is not None: + ctx.signals.innerPbar_available = ctx.inner_pbar_available + ctx.track_params["signals"] = ctx.signals + + tracker_input_img = None + if ctx.image_channel_tracker is not None: + if "image" in ctx.track_params: + tracker_input_img = ctx.track_params.pop("image") + else: + config.logger_func( + f'Loading image data of channel "{ctx.image_channel_tracker}"' + ) + tracker_input_img = pos_data.loadChannelData(ctx.image_channel_tracker) + + tracked_stack = core.tracker_track( + lab_stack, + ctx.tracker, + ctx.track_params, + intensity_img=tracker_input_img, + logger_func=config.logger_func, + ) + pos_data.fromTrackerToAcdcDf(ctx.tracker, tracked_stack, save=True) + return {"tracked_stack": tracked_stack, "lab_stack": lab_stack} + + +def skip_track_progress( + state: PositionState, + ctx: WorkflowContext, + config: RunnableConfig, +) -> dict[str, Any]: + if ctx.signals is None: + return {"tracked_stack": state.lab_stack} + + try: + if ctx.inner_pbar_available: + ctx.signals.innerProgressBar.emit(state.stop_i) + else: + ctx.signals.progressBar.emit(state.stop_i) + except AttributeError: + if ctx.inner_pbar_available: + ctx.signals.innerProgressBar.emit(1) + else: + ctx.signals.progressBar.emit(1) + return {"tracked_stack": state.lab_stack} + + +def pad_roi( + state: PositionState, + ctx: WorkflowContext, + config: RunnableConfig, +) -> dict[str, Any]: + tracked_stack = state.tracked_stack + if state.pad_info is not None: + config.logger_func(f"Padding with zeros {state.pad_info}...") + tracked_stack = np.pad(tracked_stack, state.pad_info, mode="constant") + return {"tracked_stack": tracked_stack} + + +def save( + state: PositionState, + ctx: WorkflowContext, + config: RunnableConfig, +) -> dict[str, Any]: + pos_data = state.pos_data + config.logger_func(f"Saving {pos_data.relPath}...") + io.savez_compressed(pos_data.segm_npz_path, state.tracked_stack) + config.logger_func(f"\n{pos_data.relPath} done.") + return {} + + +def _route_after_model(state: PositionState, ctx: WorkflowContext) -> str: + if state.aborted or ctx.model is None: + return END + return "segment" + + +def _route_postprocess(_state: PositionState, ctx: WorkflowContext) -> str: + return "postprocess" if ctx.do_postprocess else "before_track" + + +def _route_track(state: PositionState, ctx: WorkflowContext) -> str: + if not ctx.do_tracking: + return "skip_track" + size_t = getattr(state.pos_data, "SizeT", ctx.size_t) + return "track" if size_t > 1 else "skip_track" + + +def _route_pad_roi(state: PositionState, _ctx: WorkflowContext) -> str: + return "pad_roi" if state.is_roi_active else "before_save" + + +def _route_save(_state: PositionState, ctx: WorkflowContext) -> str: + return "save" if ctx.do_save else END diff --git a/cellacdc/workflow/runnable.py b/cellacdc/workflow/runnable.py new file mode 100644 index 000000000..a41ca4ced --- /dev/null +++ b/cellacdc/workflow/runnable.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Protocol, runtime_checkable + + +@dataclass(slots=True) +class RunnableConfig: + """Per-run callbacks and metadata (LangChain RunnableConfig analogue).""" + + logger_func: Callable[[str], None] = print + tags: list[str] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + signals: Any | None = None + + +@runtime_checkable +class Runnable(Protocol): + """Minimal composable step interface.""" + + def invoke(self, input: Any, config: RunnableConfig | None = None) -> Any: ... + + +@dataclass(slots=True) +class RunnableLambda(Runnable): + """Wrap a plain callable as a Runnable.""" + + func: Callable[..., Any] + name: str | None = None + + def invoke(self, input: Any, config: RunnableConfig | None = None) -> Any: + if config is None: + return self.func(input) + return self.func(input, config) + + +@dataclass(slots=True) +class RunnableSequence(Runnable): + """Linear chain of runnables (LangChain RunnableSequence analogue).""" + + steps: tuple[Runnable, ...] + + def invoke(self, input: Any, config: RunnableConfig | None = None) -> Any: + value = input + for step in self.steps: + value = step.invoke(value, config) + return value + + def __or__(self, other: Runnable) -> RunnableSequence: + if isinstance(other, RunnableSequence): + return RunnableSequence(self.steps + other.steps) + return RunnableSequence(self.steps + (other,)) diff --git a/cellacdc/workflow/state.py b/cellacdc/workflow/state.py new file mode 100644 index 000000000..e1101985a --- /dev/null +++ b/cellacdc/workflow/state.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +from dataclasses import asdict, dataclass, field, fields, replace +from typing import Any + + +def merge_state(state: Any, update: dict[str, Any] | None) -> Any: + if not update: + return state + if isinstance(state, dict): + return {**state, **update} + return replace(state, **update) + + +@dataclass(slots=True) +class WorkflowContext: + """Immutable workflow configuration (LangGraph context_schema analogue).""" + + user_ch_name: str + segm_endname: str = "segm.npz" + model_name: str = "" + tracker_name: str = "" + do_tracking: bool = False + do_postprocess: bool = True + do_save: bool = True + is_segm_3d: bool = False + use_roi: bool = True + use_freehand_roi: bool = True + use_3d_data_for_2d_segm: bool = False + second_channel_name: str | None = None + image_channel_tracker: str | None = None + size_t: int = 1 + size_z: int = 1 + model_kwargs: dict[str, Any] = field(default_factory=dict) + init_model_kwargs: dict[str, Any] = field(default_factory=dict) + track_params: dict[str, Any] = field(default_factory=dict) + init_tracker_kwargs: dict[str, Any] = field(default_factory=dict) + standard_postprocess_kwargs: dict[str, Any] = field(default_factory=dict) + custom_postprocess_features: dict[str, Any] = field(default_factory=dict) + custom_postprocess_grouped_features: dict[str, Any] = field(default_factory=dict) + preproc_recipe: list[dict[str, Any]] | None = None + reduce_memory_usage: bool = False + model: Any | None = None + tracker: Any | None = None + is_segment3dt_available: bool = False + inner_pbar_available: bool = False + signals: Any | None = None + + +@dataclass(slots=True) +class PositionState: + """Mutable per-position pipeline state (LangGraph state_schema analogue).""" + + img_path: str + stop_frame_n: int = 1 + pos_data: Any | None = None + img_data: Any | None = None + second_ch_data: Any | None = None + postprocess_img: Any | None = None + lab_stack: Any | None = None + tracked_stack: Any | None = None + model: Any | None = None + tracker: Any | None = None + is_roi_active: bool = False + pad_info: tuple | None = None + roi_bounds: tuple[int, int, int, int] | None = None + stop_i: int = 1 + t0: int = 0 + aborted: bool = False + error: str | None = None + + def as_update(self) -> dict[str, Any]: + return asdict(self) + + +@dataclass(slots=True) +class BatchState: + """Outer loop over many positions.""" + + paths: list[str] = field(default_factory=list) + stop_frame_numbers: list[int] = field(default_factory=list) + current_index: int = 0 + results: list[Any] = field(default_factory=list) + aborted: bool = False + + @property + def done(self) -> bool: + return self.current_index >= len(self.paths) + + @property + def current_path(self) -> str | None: + if self.done: + return None + return self.paths[self.current_index] + + @property + def current_stop_frame(self) -> int: + if not self.stop_frame_numbers: + return 1 + index = min(self.current_index, len(self.stop_frame_numbers) - 1) + return int(self.stop_frame_numbers[index]) + + +@dataclass(slots=True) +class BatchWorkflowContext: + """Context for batch parent graphs.""" + + position_ctx: WorkflowContext + position_graph: Any | None = None + + +@dataclass(slots=True) +class MeasurementsContext: + """Context for measurements position pipeline.""" + + end_filename_segm: str + kernel: Any + save_metrics: bool = True + last_cca_frame_i: Any | None = None + + +@dataclass(slots=True) +class MeasurementsBatchContext: + measurements_ctx: MeasurementsContext + position_graph: Any | None = None + + +@dataclass(slots=True) +class MeasurementsState: + img_path: str = "" + stop_frame_n: int = 1 + pos_data: Any | None = None + skipped: bool = False + aborted: bool = False + error: str | None = None + + +@dataclass(slots=True) +class InteractiveSegmContext: + """Context for in-viewer single-frame segmentation.""" + + model: Any + model_kwargs: dict[str, Any] + apply_postprocessing: bool = False + standard_postprocess_kwargs: dict[str, Any] = field(default_factory=dict) + custom_postprocess_features: dict[str, Any] = field(default_factory=dict) + custom_postprocess_grouped_features: dict[str, Any] = field(default_factory=dict) + segment_3d: bool = False + second_channel_data: Any | None = None + z_range: tuple[int, int] | None = None + + +@dataclass(slots=True) +class InteractiveSegmState: + main_win: Any + pos_data: Any | None = None + img: Any | None = None + lab: Any | None = None + segmented_lab: Any | None = None + start_z_slice: int = 0 + exec_time: float = 0.0 + + +@dataclass(slots=True) +class InteractiveVideoSegmContext: + """Context for in-viewer timelapse segmentation.""" + + model: Any + model_kwargs: dict[str, Any] + apply_postprocessing: bool = False + standard_postprocess_kwargs: dict[str, Any] = field(default_factory=dict) + custom_postprocess_features: dict[str, Any] = field(default_factory=dict) + custom_postprocess_grouped_features: dict[str, Any] = field(default_factory=dict) + preproc_recipe: list[dict[str, Any]] | None = None + second_channel_data: Any | None = None + start_frame_num: int = 1 + stop_frame_num: int = 1 + progress_callback: Any | None = None + logger_func: Any = print + + +@dataclass(slots=True) +class InteractiveVideoSegmState: + pos_data: Any + segm_data: Any | None = None + img_data: Any | None = None + z_slices: Any | None = None + exec_time: float = 0.0 + + +@dataclass(slots=True) +class MeasurementsGuiContext: + """Context for GUI-driven measurements runs.""" + + kernel: Any + compute_metrics_worker: Any | None = None + save_data_worker: Any | None = None + save_metrics: bool = True + do_init_metrics: bool = True + last_cca_frame_i: Any | None = None + end_filename_segm: str = "" + + +@dataclass(slots=True) +class MeasurementsGuiState: + img_path: str = "" + stop_frame_n: int = 1 + pos_data: Any | None = None + skipped: bool = False + aborted: bool = False + acdc_df_li: list[Any] = field(default_factory=list) + keys: list[Any] = field(default_factory=list) + + +@dataclass(slots=True) +class MeasurementsGuiBatchContext: + kernel: Any + compute_metrics_worker: Any | None = None + save_data_worker: Any | None = None + save_metrics: bool = True + end_filename_segm: str = "" + + +@dataclass(slots=True) +class FullWorkflowState: + """Top-level INI workflow state.""" + + segm_params: dict[str, Any] = field(default_factory=dict) + measurements_params: dict[str, Any] | None = None + run_segm: bool = True + run_measurements: bool = False + segm_done: bool = False + measurements_done: bool = False + + +def state_field_names(state_type: type) -> set[str]: + if hasattr(state_type, "__dataclass_fields__"): + return set(state_type.__dataclass_fields__) + return {field.name for field in fields(state_type)} diff --git a/examples/run_headless_workflow.py b/examples/run_headless_workflow.py new file mode 100644 index 000000000..c6c7978b2 --- /dev/null +++ b/examples/run_headless_workflow.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +"""Run Cell-ACDC segmentation and measurements without GUI or INI files. + +Edit USER CONFIG below, then: + + python examples/run_headless_workflow.py + +Data must follow the usual ACDC layout: + + /path/to/MyExperiment/Position_001/Images/phase.tif + /path/to/MyExperiment/Position_001/Images/GFP.tif + ... + +You can also build or extend graphs directly — this script shows the +stock pipelines with plain Python configuration. +""" + +from __future__ import annotations + +import os +import sys + +from tqdm import tqdm + +# --------------------------------------------------------------------------- +# USER CONFIG — edit these +# --------------------------------------------------------------------------- + +EXPERIMENT_PATH = "/path/to/MyExperiment" +USER_CH_NAME = "phase" +FLUOR_CHANNELS = ["GFP"] # channels to quantify after segmentation +SEGM_ENDNAME = "segm.npz" # output segm file basename +STOP_FRAME = 10 # frames to process per position (use 1 for single frame) + +# Segmentation +MODEL_NAME = "cellpose" +MODEL_KWARGS = {"diameter": 30} +INIT_MODEL_KWARGS: dict = {} +DO_TRACKING = False +TRACKER_NAME = "" +TRACK_PARAMS: dict = {} +DO_POSTPROCESS = True +DO_SAVE = True +IS_SEGM_3D = False +USE_ROI = True + +# Postprocess (empty dicts = defaults / no custom features) +STANDARD_POSTPROCESS_KWARGS: dict = {} +CUSTOM_POSTPROCESS_FEATURES: dict = {} +CUSTOM_POSTPROCESS_GROUPED_FEATURES: dict = {} + +RUN_SEGMENTATION = True +RUN_MEASUREMENTS = True + +# --------------------------------------------------------------------------- + + +def collect_position_paths(exp_path: str, user_ch: str) -> list[str]: + from cellacdc import utils + + paths: list[str] = [] + for pos in utils.get_pos_foldernames(exp_path): + images_path = os.path.join(exp_path, pos, "Images") + paths.append(utils.getChannelFilePath(images_path, user_ch)) + return paths + + +def build_segm_context(): + """Pure-Python workflow context — no kernel, no INI.""" + from cellacdc.workflow.state import WorkflowContext + + return WorkflowContext( + user_ch_name=USER_CH_NAME, + segm_endname=SEGM_ENDNAME, + model_name=MODEL_NAME, + tracker_name=TRACKER_NAME, + do_tracking=DO_TRACKING, + do_postprocess=DO_POSTPROCESS, + do_save=DO_SAVE, + is_segm_3d=IS_SEGM_3D, + use_roi=USE_ROI, + model_kwargs=dict(MODEL_KWARGS), + init_model_kwargs=dict(INIT_MODEL_KWARGS), + track_params=dict(TRACK_PARAMS), + standard_postprocess_kwargs=dict(STANDARD_POSTPROCESS_KWARGS), + custom_postprocess_features=dict(CUSTOM_POSTPROCESS_FEATURES), + custom_postprocess_grouped_features=dict(CUSTOM_POSTPROCESS_GROUPED_FEATURES), + size_t=STOP_FRAME, + size_z=1, + ) + + +def run_segmentation(logger, log_path, paths: list[str]) -> None: + from cellacdc.workflow.pipelines.batch import run_segm_batch + from cellacdc.workflow.runnable import RunnableConfig + + ctx = build_segm_context() + stops = [STOP_FRAME] * len(paths) + pbar = tqdm(total=len(paths), desc="Segmentation", ncols=100) + results = run_segm_batch( + ctx, + paths, + stops, + config=RunnableConfig(logger_func=logger.info), + progress=pbar, + ) + pbar.close() + aborted = [r for r in results if getattr(r, "aborted", False)] + if aborted: + logger.warning(f"{len(aborted)} position(s) aborted during segmentation.") + + +def run_measurements(logger, log_path, paths: list[str]) -> None: + from cellacdc import cli + from cellacdc.workflow.adapters import configure_measurements_kernel_for_cli + from cellacdc.workflow.pipelines.batch import run_measurements_batch + from cellacdc.workflow.runnable import RunnableConfig + + kernel = cli.ComputeMeasurementsKernel(logger, log_path, is_cli=True) + configure_measurements_kernel_for_cli( + kernel, + channels=[USER_CH_NAME, *FLUOR_CHANNELS], + end_filename_segm=SEGM_ENDNAME.replace(".npz", ""), + is_timelapse=STOP_FRAME > 1, + ) + + stops = [STOP_FRAME] * len(paths) + pbar = tqdm(total=len(paths), desc="Measurements", ncols=100) + run_measurements_batch( + kernel, + paths, + stops, + end_filename_segm=kernel.end_filename_segm, + config=RunnableConfig(logger_func=logger.info), + progress=pbar, + ) + pbar.close() + + +def run_single_position_graph_example(path: str) -> None: + """Minimal example: build one graph and invoke it once.""" + from cellacdc.workflow.pipelines.segm import build_position_segm_graph + from cellacdc.workflow.runnable import RunnableConfig + from cellacdc.workflow.state import PositionState + + graph = build_position_segm_graph(build_segm_context()).compile() + result = graph.invoke( + PositionState(img_path=path, stop_frame_n=STOP_FRAME), + RunnableConfig(logger_func=print), + ) + print("done:", result.aborted, result.error) + + +def main() -> int: + if EXPERIMENT_PATH.startswith("/path/to"): + print("Edit USER CONFIG in examples/run_headless_workflow.py first.", file=sys.stderr) + return 1 + + from cellacdc import utils + + logger, _, log_path, _ = utils.setupLogger(module="headless", logs_path=None) + paths = collect_position_paths(EXPERIMENT_PATH, USER_CH_NAME) + if not paths: + logger.error(f"No positions found under {EXPERIMENT_PATH}") + return 1 + + logger.info(f"Found {len(paths)} position(s)") + + if RUN_SEGMENTATION: + run_segmentation(logger, log_path, paths) + + if RUN_MEASUREMENTS: + run_measurements(logger, log_path, paths) + + logger.info("Finished.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/notebooks/acdc_paper_plots.ipynb b/notebooks/acdc_paper_plots.ipynb index 645f633ae..596ef911e 100755 --- a/notebooks/acdc_paper_plots.ipynb +++ b/notebooks/acdc_paper_plots.ipynb @@ -44,11 +44,13 @@ "\n", "cwd_path = os.getcwd()\n", "Cell_ACDC_path = os.path.dirname(cwd_path)\n", - "data_dir = os.path.join('..', 'tables', 'paper_plot_data')\n", - "plot_data3a_path = os.path.join(data_dir, 'plot_data3a.csv')\n", - "plot_data3d_path = os.path.join(data_dir, 'p38_AB_AllPos_BF_manual_cell_vol_VS_nucl_vol.csv')\n", - "plot_data4d_left_path = os.path.join(data_dir, 'plot_data4c.csv')\n", - "plot_data4d_right_path = os.path.join(data_dir, 'plot_data4d.csv')\n", + "data_dir = os.path.join(\"..\", \"tables\", \"paper_plot_data\")\n", + "plot_data3a_path = os.path.join(data_dir, \"plot_data3a.csv\")\n", + "plot_data3d_path = os.path.join(\n", + " data_dir, \"p38_AB_AllPos_BF_manual_cell_vol_VS_nucl_vol.csv\"\n", + ")\n", + "plot_data4d_left_path = os.path.join(data_dir, \"plot_data4c.csv\")\n", + "plot_data4d_right_path = os.path.join(data_dir, \"plot_data4d.csv\")\n", "os.path.exists(plot_data3d_path)" ] }, @@ -168,23 +170,23 @@ ], "source": [ "# Discard cells that are larger than 2.5*mean cell_vol_fl\n", - "col = 'cell_vol_fl'\n", + "col = \"cell_vol_fl\"\n", "\n", "df_3d = pd.read_csv(plot_data3d_path)\n", "max_vol_3d = df_3d[col].mean() * 2.5\n", - "df_3d['discard'] = 0\n", - "df_3d.loc[df_3d[col] >= max_vol_3d, 'discard'] = 1\n", + "df_3d[\"discard\"] = 0\n", + "df_3d.loc[df_3d[col] >= max_vol_3d, \"discard\"] = 1\n", "\n", "df_4d = pd.read_csv(plot_data4d_left_path)\n", "max_vol_4d = df_4d[col].mean() * 2.5\n", - "df_4d['discard'] = 0\n", - "df_4d.loc[df_4d[col] >= max_vol_4d, 'discard'] = 1\n", + "df_4d[\"discard\"] = 0\n", + "df_4d.loc[df_4d[col] >= max_vol_4d, \"discard\"] = 1\n", "\n", "df_4d_box = pd.read_csv(plot_data4d_right_path)\n", - "df_4d_box['discard'] = 0\n", - "df_4d_box.loc[df_4d_box[col] >= max_vol_4d, 'discard'] = 1\n", + "df_4d_box[\"discard\"] = 0\n", + "df_4d_box.loc[df_4d_box[col] >= max_vol_4d, \"discard\"] = 1\n", "\n", - "df_4d_box.sort_values(col)[[col, 'discard']]" + "df_4d_box.sort_values(col)[[col, \"discard\"]]" ] }, { @@ -250,78 +252,88 @@ "plot_data3b = pd.read_csv(plot_data3d_path)\n", "\n", "# Drop discarded\n", - "plot_data3b = plot_data3b[plot_data3b['discard'] == 0]\n", + "plot_data3b = plot_data3b[plot_data3b[\"discard\"] == 0]\n", "\n", - "sns.set_theme(context='talk', font_scale=1.6)\n", + "sns.set_theme(context=\"talk\", font_scale=1.6)\n", "sns.set_style(\"whitegrid\", {\"grid.color\": \".95\"})\n", - "fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(20,10))\n", + "fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(20, 10))\n", "sns.regplot(\n", - " data=plot_data3a[plot_data3a.relationship_cellpose=='mother'],\n", - " x='cell_vol_fl_yeaz',\n", - " y='cell_vol_fl_cellpose',\n", - " ax = ax[0],\n", - " color = sns.color_palette()[0]\n", + " data=plot_data3a[plot_data3a.relationship_cellpose == \"mother\"],\n", + " x=\"cell_vol_fl_yeaz\",\n", + " y=\"cell_vol_fl_cellpose\",\n", + " ax=ax[0],\n", + " color=sns.color_palette()[0],\n", ")\n", "sns.regplot(\n", - " data=plot_data3a[plot_data3a.relationship_cellpose=='bud'],\n", - " x='cell_vol_fl_yeaz',\n", - " y='cell_vol_fl_cellpose',\n", - " ax = ax[0],\n", - " color = sns.color_palette()[1]\n", + " data=plot_data3a[plot_data3a.relationship_cellpose == \"bud\"],\n", + " x=\"cell_vol_fl_yeaz\",\n", + " y=\"cell_vol_fl_cellpose\",\n", + " ax=ax[0],\n", + " color=sns.color_palette()[1],\n", ")\n", - "labels = [\n", - " 'Mother cells',\n", - " 'Buds'\n", - "]\n", + "labels = [\"Mother cells\", \"Buds\"]\n", "handles = [\n", - " mlines.Line2D([], [], color=sns.color_palette()[0], marker='o', linestyle='None',\n", - " markersize=10),\n", - " mlines.Line2D([], [], color=sns.color_palette()[1], marker='o', linestyle='None',\n", - " markersize=10)\n", + " mlines.Line2D(\n", + " [],\n", + " [],\n", + " color=sns.color_palette()[0],\n", + " marker=\"o\",\n", + " linestyle=\"None\",\n", + " markersize=10,\n", + " ),\n", + " mlines.Line2D(\n", + " [],\n", + " [],\n", + " color=sns.color_palette()[1],\n", + " marker=\"o\",\n", + " linestyle=\"None\",\n", + " markersize=10,\n", + " ),\n", "]\n", - "ax[0].legend(\n", - " handles=handles,\n", - " labels=labels, \n", - " loc='upper left',\n", - " framealpha=0.5\n", - ")\n", - "scatter_plot_max = max(plot_data3a.cell_vol_fl_yeaz.max(), plot_data3a.cell_vol_fl_cellpose.max())\n", - "ax[0].set_xlabel('Cell vol. (phase contrast + YeaZ) [fL]')\n", - "ax[0].set_ylabel('Cell vol. ($\\it{ACT1pr}$ signal + Cellpose) [fL]')\n", - "#ax[0].set_title('A', fontsize=50, loc='left', pad=30, x=-0.175)\n", - "ax[0].set_ylim(0, int(scatter_plot_max)+10)\n", - "ax[0].set_xlim(0, int(scatter_plot_max)+10)\n", - "ax[0].set_xticks(np.arange(0, scatter_plot_max+10, 20))\n", - "ax[0].set_yticks(np.arange(0, scatter_plot_max+10, 20))\n", + "ax[0].legend(handles=handles, labels=labels, loc=\"upper left\", framealpha=0.5)\n", + "scatter_plot_max = max(\n", + " plot_data3a.cell_vol_fl_yeaz.max(), plot_data3a.cell_vol_fl_cellpose.max()\n", + ")\n", + "ax[0].set_xlabel(\"Cell vol. (phase contrast + YeaZ) [fL]\")\n", + "ax[0].set_ylabel(\"Cell vol. ($\\it{ACT1pr}$ signal + Cellpose) [fL]\")\n", + "# ax[0].set_title('A', fontsize=50, loc='left', pad=30, x=-0.175)\n", + "ax[0].set_ylim(0, int(scatter_plot_max) + 10)\n", + "ax[0].set_xlim(0, int(scatter_plot_max) + 10)\n", + "ax[0].set_xticks(np.arange(0, scatter_plot_max + 10, 20))\n", + "ax[0].set_yticks(np.arange(0, scatter_plot_max + 10, 20))\n", "\n", "sns.regplot(\n", " data=plot_data3b,\n", - " x='nucleus_vol_fl',\n", - " y='cell_vol_fl',\n", + " x=\"nucleus_vol_fl\",\n", + " y=\"cell_vol_fl\",\n", " robust=False,\n", - " ax = ax[1],\n", - " color = sns.color_palette()[2]\n", + " ax=ax[1],\n", + " color=sns.color_palette()[2],\n", ")\n", "\n", "scatter_plot_max = plot_data3b.nucleus_vol_fl.max()\n", - "ax[1].set_ylabel('HSCs vol. (Bright-field + YeaZ) [fL]')\n", - "ax[1].set_xlabel('HSCs nucl. vol. (DAPI signal + StarDist) [fL]')\n", - "ax[1].set_xticks(np.arange(0, scatter_plot_max+10, 100))\n", - "ax[1].set_yticks(np.arange(0, plot_data3b.cell_vol_fl.max()+10, 100))\n", + "ax[1].set_ylabel(\"HSCs vol. (Bright-field + YeaZ) [fL]\")\n", + "ax[1].set_xlabel(\"HSCs nucl. vol. (DAPI signal + StarDist) [fL]\")\n", + "ax[1].set_xticks(np.arange(0, scatter_plot_max + 10, 100))\n", + "ax[1].set_yticks(np.arange(0, plot_data3b.cell_vol_fl.max() + 10, 100))\n", "\n", "\n", "plt.tight_layout()\n", - "plt.savefig('../figures/new_fig3/fig3_final.svg')\n", - "#plt.savefig('../figures/new_fig3/fig3.png', dpi=300)\n", + "plt.savefig(\"../figures/new_fig3/fig3_final.svg\")\n", + "# plt.savefig('../figures/new_fig3/fig3.png', dpi=300)\n", "plt.show()\n", "\n", - "print(f'Sample size Fig. 3A: {len(plot_data3a)//2}')\n", - "pearson_r, p_value = scipy.stats.pearsonr(plot_data3a.cell_vol_fl_yeaz, plot_data3a.cell_vol_fl_cellpose)\n", - "print(f'Pearson Correlation and p-value for non-correlation 3A: {pearson_r, p_value}')\n", + "print(f\"Sample size Fig. 3A: {len(plot_data3a) // 2}\")\n", + "pearson_r, p_value = scipy.stats.pearsonr(\n", + " plot_data3a.cell_vol_fl_yeaz, plot_data3a.cell_vol_fl_cellpose\n", + ")\n", + "print(f\"Pearson Correlation and p-value for non-correlation 3A: {pearson_r, p_value}\")\n", "\n", - "print(f'Sample size Fig. 3B: {len(plot_data3b)}')\n", - "pearson_r, p_value = scipy.stats.pearsonr(plot_data3b.nucleus_vol_fl, plot_data3b.cell_vol_fl)\n", - "print(f'Pearson Correlation and p-value for non-correlation 3A: {pearson_r, p_value}')" + "print(f\"Sample size Fig. 3B: {len(plot_data3b)}\")\n", + "pearson_r, p_value = scipy.stats.pearsonr(\n", + " plot_data3b.nucleus_vol_fl, plot_data3b.cell_vol_fl\n", + ")\n", + "print(f\"Pearson Correlation and p-value for non-correlation 3A: {pearson_r, p_value}\")" ] }, { @@ -401,209 +413,305 @@ ], "source": [ "# load data from csv\n", - "plot_data4a = pd.read_csv(os.path.join(data_dir, 'plot_data4a.csv'))\n", - "plot_data4b = pd.read_csv(os.path.join(data_dir, 'plot_data4b.csv'))\n", + "plot_data4a = pd.read_csv(os.path.join(data_dir, \"plot_data4a.csv\"))\n", + "plot_data4b = pd.read_csv(os.path.join(data_dir, \"plot_data4b.csv\"))\n", "plot_data4c = pd.read_csv(plot_data4d_left_path)\n", "plot_data4d = pd.read_csv(plot_data4d_right_path)\n", "\n", "# Drop discarded\n", - "plot_data4c = plot_data4c[plot_data4c['discard'] == 0]\n", - "plot_data4d = plot_data4d[plot_data4d['discard'] == 0]\n", + "plot_data4c = plot_data4c[plot_data4c[\"discard\"] == 0]\n", + "plot_data4d = plot_data4d[plot_data4d[\"discard\"] == 0]\n", "\n", - "sns.set_theme(context='talk', font_scale=1.725)\n", + "sns.set_theme(context=\"talk\", font_scale=1.725)\n", "sns.set_style(\"whitegrid\", {\"grid.color\": \".95\"})\n", - "fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(24,20))\n", + "fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(24, 20))\n", "\n", "# subplot 1\n", "sns.scatterplot(\n", " data=plot_data4a,\n", - " x='cell_vol_fl',\n", - " y='FITC_concentration',\n", - " ax = axs[0,0],\n", - " color=sns.color_palette('pastel')[2],\n", + " x=\"cell_vol_fl\",\n", + " y=\"FITC_concentration\",\n", + " ax=axs[0, 0],\n", + " color=sns.color_palette(\"pastel\")[2],\n", " s=11,\n", - " legend=False\n", - " #scatter_kws={'s':10},\n", - " #x_bins=20\n", - " #hue='size_category'\n", + " legend=False,\n", + " # scatter_kws={'s':10},\n", + " # x_bins=20\n", + " # hue='size_category'\n", ")\n", "nbins = 12\n", "bins_min_count = 10\n", - "xe, ye, std = cca_functions.binned_mean_stats(plot_data4a.cell_vol_fl, plot_data4a.FITC_concentration, nbins, bins_min_count)\n", - "axs[0,0].errorbar(xe, ye, yerr=std, capsize=6, lw=3, c=sns.color_palette()[2])\n", - "axs[0,0].set_xlabel('Cell volume [fL]')\n", - "axs[0,0].set_ylabel('mTOR activity [a.u.]')\n", - "axs[0,0].ticklabel_format(axis='y', style='sci', scilimits=(0,0), useMathText=True)\n", - "axs[0,0].get_yaxis().get_offset_text().set_position((-0.07, 0))\n", - "#lower_y_border, upper_y_border = plot_data4b.FITC_concentration.min()-10, plot_data4b.FITC_concentration.max()+10\n", - "lower_y_border, upper_y_border = -200, plot_data4a.FITC_concentration.max()+10\n", + "xe, ye, std = cca_functions.binned_mean_stats(\n", + " plot_data4a.cell_vol_fl, plot_data4a.FITC_concentration, nbins, bins_min_count\n", + ")\n", + "axs[0, 0].errorbar(xe, ye, yerr=std, capsize=6, lw=3, c=sns.color_palette()[2])\n", + "axs[0, 0].set_xlabel(\"Cell volume [fL]\")\n", + "axs[0, 0].set_ylabel(\"mTOR activity [a.u.]\")\n", + "axs[0, 0].ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0, 0), useMathText=True)\n", + "axs[0, 0].get_yaxis().get_offset_text().set_position((-0.07, 0))\n", + "# lower_y_border, upper_y_border = plot_data4b.FITC_concentration.min()-10, plot_data4b.FITC_concentration.max()+10\n", + "lower_y_border, upper_y_border = -200, plot_data4a.FITC_concentration.max() + 10\n", "height = upper_y_border - lower_y_border\n", "# configure borders for \"size blocks\"\n", "xs_borders = 0, np.percentile(plot_data4a.cell_vol_fl, 15)\n", - "m_borders = np.percentile(plot_data4a.cell_vol_fl, 35), np.percentile(plot_data4a.cell_vol_fl, 65)\n", - "xl_borders = np.percentile(plot_data4a.cell_vol_fl, 85), np.percentile(plot_data4a.cell_vol_fl, 85)*2 + 20\n", - "xs_width = xs_borders[1]-xs_borders[0]\n", - "m_width = m_borders[1]-m_borders[0]\n", - "xl_width = xl_borders[1]-xl_borders[0]\n", + "m_borders = (\n", + " np.percentile(plot_data4a.cell_vol_fl, 35),\n", + " np.percentile(plot_data4a.cell_vol_fl, 65),\n", + ")\n", + "xl_borders = (\n", + " np.percentile(plot_data4a.cell_vol_fl, 85),\n", + " np.percentile(plot_data4a.cell_vol_fl, 85) * 2 + 20,\n", + ")\n", + "xs_width = xs_borders[1] - xs_borders[0]\n", + "m_width = m_borders[1] - m_borders[0]\n", + "xl_width = xl_borders[1] - xl_borders[0]\n", "# add gray rectangles for size categories\n", - "axs[0,0].add_patch(\n", - " patches.Rectangle((xs_borders[0], lower_y_border), xs_width, height, color='black', alpha=0.1)\n", + "axs[0, 0].add_patch(\n", + " patches.Rectangle(\n", + " (xs_borders[0], lower_y_border), xs_width, height, color=\"black\", alpha=0.1\n", + " )\n", ")\n", - "axs[0,0].text(0.5*sum(xs_borders)-10, upper_y_border-(upper_y_border//10), 'XS', fontdict={'fontsize':30})\n", - "axs[0,0].add_patch(\n", - " patches.Rectangle((m_borders[0], lower_y_border), m_width, height, color='black', alpha=0.1)\n", + "axs[0, 0].text(\n", + " 0.5 * sum(xs_borders) - 10,\n", + " upper_y_border - (upper_y_border // 10),\n", + " \"XS\",\n", + " fontdict={\"fontsize\": 30},\n", ")\n", - "axs[0,0].text(0.5*sum(m_borders)-10, upper_y_border-(upper_y_border//10), 'M', fontdict={'fontsize':30})\n", - "axs[0,0].add_patch(\n", - " patches.Rectangle((xl_borders[0], lower_y_border), xl_width, height, color='black', alpha=0.1)\n", + "axs[0, 0].add_patch(\n", + " patches.Rectangle(\n", + " (m_borders[0], lower_y_border), m_width, height, color=\"black\", alpha=0.1\n", + " )\n", + ")\n", + "axs[0, 0].text(\n", + " 0.5 * sum(m_borders) - 10,\n", + " upper_y_border - (upper_y_border // 10),\n", + " \"M\",\n", + " fontdict={\"fontsize\": 30},\n", + ")\n", + "axs[0, 0].add_patch(\n", + " patches.Rectangle(\n", + " (xl_borders[0], lower_y_border), xl_width, height, color=\"black\", alpha=0.1\n", + " )\n", + ")\n", + "axs[0, 0].text(\n", + " 0.5 * sum(xl_borders) - 10,\n", + " upper_y_border - (upper_y_border // 10),\n", + " \"XL\",\n", + " fontdict={\"fontsize\": 30},\n", ")\n", - "axs[0,0].text(0.5*sum(xl_borders)-10, upper_y_border-(upper_y_border//10), 'XL', fontdict={'fontsize':30})\n", "# set x and y limits manually\n", - "axs[0,0].set_xlim(0, xl_borders[1])\n", - "axs[0,0].set_ylim(lower_y_border, upper_y_border)\n", - "#axs[1].set_title('B', fontsize=40, loc='left', pad=10)\n", - "#axs[1].set_yscale('log')\n", + "axs[0, 0].set_xlim(0, xl_borders[1])\n", + "axs[0, 0].set_ylim(lower_y_border, upper_y_border)\n", + "# axs[1].set_title('B', fontsize=40, loc='left', pad=10)\n", + "# axs[1].set_yscale('log')\n", "\n", "# subplot 2\n", "sns.boxplot(\n", " data=plot_data4b,\n", - " x='size_category',\n", - " y='FITC_concentration',\n", + " x=\"size_category\",\n", + " y=\"FITC_concentration\",\n", " order=[\"Control\", \"All\", \"XS\", \"M\", \"XL\"],\n", - " palette=['lightgray', sns.color_palette()[2]]+ [sns.color_palette('pastel')[2]]*3,\n", - " ax=axs[0,1],\n", - " #size=1\n", - " #inner='quartile'\n", - ")\n", - "axs[0,1].set_xlabel('Size category')\n", - "axs[0,1].set_ylabel('mTOR activity [a.u.]')\n", - "#axs[0,1].set_title('C', fontsize=40, loc='left', pad=10)\n", - "axs[0,1].set_ylim(lower_y_border, upper_y_border)\n", - "axs[0,1].ticklabel_format(axis='y', style='sci', scilimits=(0,0), useMathText=True)\n", - "axs[0,1].get_yaxis().get_offset_text().set_position((-0.07, 0))\n", - "#axs[0,1].set_yscale('log')\n", - "\n", - "\n", - "print(f'Sample size Fig. 4A&B: {len(plot_data4a)}')\n", - "print(f'Sample size control Fig. 4B: {len(plot_data4b[plot_data4b.size_category==\"Control\"])}')\n", - "print(f'Sample size XS: {len(plot_data4b[plot_data4b.size_category==\"XS\"])}')\n", - "print(f'Sample size M: {len(plot_data4b[plot_data4b.size_category==\"M\"])}')\n", - "print(f'Sample size XL: {len(plot_data4b[plot_data4b.size_category==\"XL\"])}')\n", + " palette=[\"lightgray\", sns.color_palette()[2]]\n", + " + [sns.color_palette(\"pastel\")[2]] * 3,\n", + " ax=axs[0, 1],\n", + " # size=1\n", + " # inner='quartile'\n", + ")\n", + "axs[0, 1].set_xlabel(\"Size category\")\n", + "axs[0, 1].set_ylabel(\"mTOR activity [a.u.]\")\n", + "# axs[0,1].set_title('C', fontsize=40, loc='left', pad=10)\n", + "axs[0, 1].set_ylim(lower_y_border, upper_y_border)\n", + "axs[0, 1].ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0, 0), useMathText=True)\n", + "axs[0, 1].get_yaxis().get_offset_text().set_position((-0.07, 0))\n", + "# axs[0,1].set_yscale('log')\n", "\n", + "\n", + "print(f\"Sample size Fig. 4A&B: {len(plot_data4a)}\")\n", "print(\n", - " f'ommitted {len(plot_data4a[plot_data4a.FITC_concentration>upper_y_border])} cells with FITC concentration '\n", - " f'higher than {upper_y_border}'\n", - " )\n", + " f\"Sample size control Fig. 4B: {len(plot_data4b[plot_data4b.size_category == 'Control'])}\"\n", + ")\n", + "print(f\"Sample size XS: {len(plot_data4b[plot_data4b.size_category == 'XS'])}\")\n", + "print(f\"Sample size M: {len(plot_data4b[plot_data4b.size_category == 'M'])}\")\n", + "print(f\"Sample size XL: {len(plot_data4b[plot_data4b.size_category == 'XL'])}\")\n", "\n", - "print(f'**Effect sizes FITC amount per volume:**')\n", - "print(f'Effect size (cohen) All vs. Control: {round(cca_functions.calculate_effect_size_cohen(plot_data4b, \"All\", \"Control\",val_column=\"FITC_concentration\"), 2)}')\n", - "print(f'Effect size (cohen) XS vs. Control: {round(cca_functions.calculate_effect_size_cohen(plot_data4b, \"XS\", \"Control\",val_column=\"FITC_concentration\"), 2)}')\n", - "print(f'Effect size (cohen) M vs. Control: {round(cca_functions.calculate_effect_size_cohen(plot_data4b, \"M\", \"Control\",val_column=\"FITC_concentration\"), 2)}')\n", - "print(f'Effect size (cohen) XL vs. Control: {round(cca_functions.calculate_effect_size_cohen(plot_data4b, \"XL\", \"Control\",val_column=\"FITC_concentration\"), 2)}')\n", + "print(\n", + " f\"ommitted {len(plot_data4a[plot_data4a.FITC_concentration > upper_y_border])} cells with FITC concentration \"\n", + " f\"higher than {upper_y_border}\"\n", + ")\n", "\n", - "print(f'Effect size (glass) All vs. Control: {round(cca_functions.calculate_effect_size_glass(plot_data4b, \"All\", \"Control\",val_column=\"FITC_concentration\"), 2)}')\n", - "print(f'Effect size (glass) XS vs. Control: {round(cca_functions.calculate_effect_size_glass(plot_data4b, \"XS\", \"Control\",val_column=\"FITC_concentration\"), 2)}')\n", - "print(f'Effect size (glass) M vs. Control: {round(cca_functions.calculate_effect_size_glass(plot_data4b, \"M\", \"Control\",val_column=\"FITC_concentration\"), 2)}')\n", - "print(f'Effect size (glass) XL vs. Control: {round(cca_functions.calculate_effect_size_glass(plot_data4b, \"XL\", \"Control\",val_column=\"FITC_concentration\"), 2)}')\n", + "print(f\"**Effect sizes FITC amount per volume:**\")\n", + "print(\n", + " f\"Effect size (cohen) All vs. Control: {round(cca_functions.calculate_effect_size_cohen(plot_data4b, 'All', 'Control', val_column='FITC_concentration'), 2)}\"\n", + ")\n", + "print(\n", + " f\"Effect size (cohen) XS vs. Control: {round(cca_functions.calculate_effect_size_cohen(plot_data4b, 'XS', 'Control', val_column='FITC_concentration'), 2)}\"\n", + ")\n", + "print(\n", + " f\"Effect size (cohen) M vs. Control: {round(cca_functions.calculate_effect_size_cohen(plot_data4b, 'M', 'Control', val_column='FITC_concentration'), 2)}\"\n", + ")\n", + "print(\n", + " f\"Effect size (cohen) XL vs. Control: {round(cca_functions.calculate_effect_size_cohen(plot_data4b, 'XL', 'Control', val_column='FITC_concentration'), 2)}\"\n", + ")\n", + "\n", + "print(\n", + " f\"Effect size (glass) All vs. Control: {round(cca_functions.calculate_effect_size_glass(plot_data4b, 'All', 'Control', val_column='FITC_concentration'), 2)}\"\n", + ")\n", + "print(\n", + " f\"Effect size (glass) XS vs. Control: {round(cca_functions.calculate_effect_size_glass(plot_data4b, 'XS', 'Control', val_column='FITC_concentration'), 2)}\"\n", + ")\n", + "print(\n", + " f\"Effect size (glass) M vs. Control: {round(cca_functions.calculate_effect_size_glass(plot_data4b, 'M', 'Control', val_column='FITC_concentration'), 2)}\"\n", + ")\n", + "print(\n", + " f\"Effect size (glass) XL vs. Control: {round(cca_functions.calculate_effect_size_glass(plot_data4b, 'XL', 'Control', val_column='FITC_concentration'), 2)}\"\n", + ")\n", "\n", "\n", "# subplot 3\n", "sns.scatterplot(\n", " data=plot_data4c,\n", - " x='cell_vol_fl',\n", - " y='Pp38_concentration',\n", - " ax = axs[1,0],\n", - " color=sns.color_palette('pastel')[0],\n", + " x=\"cell_vol_fl\",\n", + " y=\"Pp38_concentration\",\n", + " ax=axs[1, 0],\n", + " color=sns.color_palette(\"pastel\")[0],\n", " s=11,\n", - " legend=False\n", - " #scatter_kws={'s':10},\n", - " #x_bins=20\n", - " #hue='size_category'\n", + " legend=False,\n", + " # scatter_kws={'s':10},\n", + " # x_bins=20\n", + " # hue='size_category'\n", ")\n", "nbins = 8\n", "bins_min_count = 10\n", - "xe, ye, std = cca_functions.binned_mean_stats(plot_data4c.cell_vol_fl, plot_data4c.Pp38_concentration, nbins, bins_min_count)\n", - "axs[1,0].errorbar(xe, ye, yerr=std, capsize=6, lw=3, c=sns.color_palette()[0])\n", - "axs[1,0].set_xlabel('Nuclear volume (DAPI) [fL]')\n", - "axs[1,0].set_ylabel('p38 activity [a.u.]')\n", - "#lower_y_border, upper_y_border = plot_data3b.FITC_concentration.min()-10, plot_data3b.FITC_concentration.max()+10\n", - "lower_y_border, upper_y_border = -200, plot_data4c.Pp38_concentration.max()+10\n", + "xe, ye, std = cca_functions.binned_mean_stats(\n", + " plot_data4c.cell_vol_fl, plot_data4c.Pp38_concentration, nbins, bins_min_count\n", + ")\n", + "axs[1, 0].errorbar(xe, ye, yerr=std, capsize=6, lw=3, c=sns.color_palette()[0])\n", + "axs[1, 0].set_xlabel(\"Nuclear volume (DAPI) [fL]\")\n", + "axs[1, 0].set_ylabel(\"p38 activity [a.u.]\")\n", + "# lower_y_border, upper_y_border = plot_data3b.FITC_concentration.min()-10, plot_data3b.FITC_concentration.max()+10\n", + "lower_y_border, upper_y_border = -200, plot_data4c.Pp38_concentration.max() + 10\n", "height = upper_y_border - lower_y_border\n", "# configure borders for \"size blocks\"\n", "xs_borders = 0, np.percentile(plot_data4c.cell_vol_fl, 15)\n", - "m_borders = np.percentile(plot_data4c.cell_vol_fl, 35), np.percentile(plot_data4c.cell_vol_fl, 65)\n", - "xl_borders = np.percentile(plot_data4c.cell_vol_fl, 85), np.max(plot_data4c.cell_vol_fl) + 20\n", - "xs_width = xs_borders[1]-xs_borders[0]\n", - "m_width = m_borders[1]-m_borders[0]\n", - "xl_width = xl_borders[1]-xl_borders[0]\n", + "m_borders = (\n", + " np.percentile(plot_data4c.cell_vol_fl, 35),\n", + " np.percentile(plot_data4c.cell_vol_fl, 65),\n", + ")\n", + "xl_borders = (\n", + " np.percentile(plot_data4c.cell_vol_fl, 85),\n", + " np.max(plot_data4c.cell_vol_fl) + 20,\n", + ")\n", + "xs_width = xs_borders[1] - xs_borders[0]\n", + "m_width = m_borders[1] - m_borders[0]\n", + "xl_width = xl_borders[1] - xl_borders[0]\n", "# add gray rectangles for size categories\n", - "axs[1,0].add_patch(\n", - " patches.Rectangle((xs_borders[0], lower_y_border), xs_width, height, color='black', alpha=0.1)\n", + "axs[1, 0].add_patch(\n", + " patches.Rectangle(\n", + " (xs_borders[0], lower_y_border), xs_width, height, color=\"black\", alpha=0.1\n", + " )\n", ")\n", - "axs[1,0].text(0.5*sum(xs_borders)-9, upper_y_border-(upper_y_border//10), 'XS', fontdict={'fontsize':30})\n", - "axs[1,0].add_patch(\n", - " patches.Rectangle((m_borders[0], lower_y_border), m_width, height, color='black', alpha=0.1)\n", + "axs[1, 0].text(\n", + " 0.5 * sum(xs_borders) - 9,\n", + " upper_y_border - (upper_y_border // 10),\n", + " \"XS\",\n", + " fontdict={\"fontsize\": 30},\n", ")\n", - "axs[1,0].text(0.5*sum(m_borders)-6, upper_y_border-(upper_y_border//10), 'M', fontdict={'fontsize':30})\n", - "axs[1,0].add_patch(\n", - " patches.Rectangle((xl_borders[0], lower_y_border), xl_width, height, color='black', alpha=0.1)\n", + "axs[1, 0].add_patch(\n", + " patches.Rectangle(\n", + " (m_borders[0], lower_y_border), m_width, height, color=\"black\", alpha=0.1\n", + " )\n", + ")\n", + "axs[1, 0].text(\n", + " 0.5 * sum(m_borders) - 6,\n", + " upper_y_border - (upper_y_border // 10),\n", + " \"M\",\n", + " fontdict={\"fontsize\": 30},\n", + ")\n", + "axs[1, 0].add_patch(\n", + " patches.Rectangle(\n", + " (xl_borders[0], lower_y_border), xl_width, height, color=\"black\", alpha=0.1\n", + " )\n", + ")\n", + "axs[1, 0].text(\n", + " 0.5 * sum(xl_borders) - 5,\n", + " upper_y_border - (upper_y_border // 10),\n", + " \"XL\",\n", + " fontdict={\"fontsize\": 30},\n", ")\n", - "axs[1,0].text(0.5*sum(xl_borders)-5, upper_y_border-(upper_y_border//10), 'XL', fontdict={'fontsize':30})\n", "# set x and y limits manually\n", - "axs[1,0].set_xlim(0, xl_borders[1])\n", - "axs[1,0].set_ylim(lower_y_border, upper_y_border)\n", - "axs[1,0].ticklabel_format(axis='y', style='sci', scilimits=(0,0), useMathText=True)\n", - "axs[1,0].get_yaxis().get_offset_text().set_position((-0.07, 0))\n", - "#axs[1].set_title('B', fontsize=40, loc='left', pad=10)\n", - "#axs[1].set_yscale('log')\n", + "axs[1, 0].set_xlim(0, xl_borders[1])\n", + "axs[1, 0].set_ylim(lower_y_border, upper_y_border)\n", + "axs[1, 0].ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0, 0), useMathText=True)\n", + "axs[1, 0].get_yaxis().get_offset_text().set_position((-0.07, 0))\n", + "# axs[1].set_title('B', fontsize=40, loc='left', pad=10)\n", + "# axs[1].set_yscale('log')\n", "\n", "# subplot 4\n", "sns.boxplot(\n", " data=plot_data4d,\n", - " x='size_category',\n", - " y='Pp38_concentration',\n", + " x=\"size_category\",\n", + " y=\"Pp38_concentration\",\n", " order=[\"Control\", \"All\", \"XS\", \"M\", \"XL\"],\n", - " palette=['lightgray', sns.color_palette()[0]]+ [sns.color_palette('pastel')[0]]*3,\n", - " ax=axs[1,1],\n", - " #size=1\n", - " #inner='quartile'\n", - ")\n", - "axs[1,1].set_xlabel('Size category')\n", - "axs[1,1].set_ylabel('p38 activity [a.u.]')\n", - "#axs[1,1].set_title('C', fontsize=40, loc='left', pad=10)\n", - "axs[1,1].set_ylim(lower_y_border, upper_y_border)\n", - "axs[1,1].ticklabel_format(axis='y', style='sci', scilimits=(0,0), useMathText=True)\n", - "axs[1,1].get_yaxis().get_offset_text().set_position((-0.07, 0))\n", - "#axs[1,1].set_yscale('log')\n", + " palette=[\"lightgray\", sns.color_palette()[0]]\n", + " + [sns.color_palette(\"pastel\")[0]] * 3,\n", + " ax=axs[1, 1],\n", + " # size=1\n", + " # inner='quartile'\n", + ")\n", + "axs[1, 1].set_xlabel(\"Size category\")\n", + "axs[1, 1].set_ylabel(\"p38 activity [a.u.]\")\n", + "# axs[1,1].set_title('C', fontsize=40, loc='left', pad=10)\n", + "axs[1, 1].set_ylim(lower_y_border, upper_y_border)\n", + "axs[1, 1].ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0, 0), useMathText=True)\n", + "axs[1, 1].get_yaxis().get_offset_text().set_position((-0.07, 0))\n", + "# axs[1,1].set_yscale('log')\n", "\n", "plt.tight_layout()\n", "\n", - "plt.savefig('../figures/new_fig4/combined_fig4_v4.svg')\n", - "plt.savefig('../figures/new_fig4/combined_fig4_v4.png', dpi=300)\n", + "plt.savefig(\"../figures/new_fig4/combined_fig4_v4.svg\")\n", + "plt.savefig(\"../figures/new_fig4/combined_fig4_v4.png\", dpi=300)\n", "\n", "plt.show()\n", "\n", "\n", - "print(f'Sample size Fig. 4C&D: {len(plot_data4c)}')\n", - "print(f'Sample size control Fig. 4D: {len(plot_data4d[plot_data4d.size_category==\"Control\"])}')\n", - "print(f'Sample size XS: {len(plot_data4d[plot_data4d.size_category==\"XS\"])}')\n", - "print(f'Sample size M: {len(plot_data4d[plot_data4d.size_category==\"M\"])}')\n", - "print(f'Sample size XL: {len(plot_data4d[plot_data4d.size_category==\"XL\"])}')\n", + "print(f\"Sample size Fig. 4C&D: {len(plot_data4c)}\")\n", + "print(\n", + " f\"Sample size control Fig. 4D: {len(plot_data4d[plot_data4d.size_category == 'Control'])}\"\n", + ")\n", + "print(f\"Sample size XS: {len(plot_data4d[plot_data4d.size_category == 'XS'])}\")\n", + "print(f\"Sample size M: {len(plot_data4d[plot_data4d.size_category == 'M'])}\")\n", + "print(f\"Sample size XL: {len(plot_data4d[plot_data4d.size_category == 'XL'])}\")\n", + "\n", + "print(\n", + " f\"ommitted {len(plot_data4c[plot_data4c.Pp38_concentration > upper_y_border])} cells with Pp38 concentration \"\n", + " f\"higher than {upper_y_border}\"\n", + ")\n", "\n", "print(\n", - " f'ommitted {len(plot_data4c[plot_data4c.Pp38_concentration>upper_y_border])} cells with Pp38 concentration '\n", - " f'higher than {upper_y_border}'\n", - " )\n", - "\n", - "print(f'Effect size (cohen) All vs. Control: {round(cca_functions.calculate_effect_size_cohen(plot_data4d, \"All\", \"Control\"), 2)}')\n", - "print(f'Effect size (cohen) XS vs. Control: {round(cca_functions.calculate_effect_size_cohen(plot_data4d, \"XS\", \"Control\"), 2)}')\n", - "print(f'Effect size (cohen) M vs. Control: {round(cca_functions.calculate_effect_size_cohen(plot_data4d, \"M\", \"Control\"), 2)}')\n", - "print(f'Effect size (cohen) XL vs. Control: {round(cca_functions.calculate_effect_size_cohen(plot_data4d, \"XL\", \"Control\"), 2)}')\n", - "\n", - "print(f'Effect size (glass) All vs. Control: {round(cca_functions.calculate_effect_size_glass(plot_data4d, \"All\", \"Control\"), 2)}')\n", - "print(f'Effect size (glass) XS vs. Control: {round(cca_functions.calculate_effect_size_glass(plot_data4d, \"XS\", \"Control\"), 2)}')\n", - "print(f'Effect size (glass) M vs. Control: {round(cca_functions.calculate_effect_size_glass(plot_data4d, \"M\", \"Control\"), 2)}')\n", - "print(f'Effect size (glass) XL vs. Control: {round(cca_functions.calculate_effect_size_glass(plot_data4d, \"XL\", \"Control\"), 2)}')" + " f\"Effect size (cohen) All vs. Control: {round(cca_functions.calculate_effect_size_cohen(plot_data4d, 'All', 'Control'), 2)}\"\n", + ")\n", + "print(\n", + " f\"Effect size (cohen) XS vs. Control: {round(cca_functions.calculate_effect_size_cohen(plot_data4d, 'XS', 'Control'), 2)}\"\n", + ")\n", + "print(\n", + " f\"Effect size (cohen) M vs. Control: {round(cca_functions.calculate_effect_size_cohen(plot_data4d, 'M', 'Control'), 2)}\"\n", + ")\n", + "print(\n", + " f\"Effect size (cohen) XL vs. Control: {round(cca_functions.calculate_effect_size_cohen(plot_data4d, 'XL', 'Control'), 2)}\"\n", + ")\n", + "\n", + "print(\n", + " f\"Effect size (glass) All vs. Control: {round(cca_functions.calculate_effect_size_glass(plot_data4d, 'All', 'Control'), 2)}\"\n", + ")\n", + "print(\n", + " f\"Effect size (glass) XS vs. Control: {round(cca_functions.calculate_effect_size_glass(plot_data4d, 'XS', 'Control'), 2)}\"\n", + ")\n", + "print(\n", + " f\"Effect size (glass) M vs. Control: {round(cca_functions.calculate_effect_size_glass(plot_data4d, 'M', 'Control'), 2)}\"\n", + ")\n", + "print(\n", + " f\"Effect size (glass) XL vs. Control: {round(cca_functions.calculate_effect_size_glass(plot_data4d, 'XL', 'Control'), 2)}\"\n", + ")" ] }, { @@ -661,219 +769,224 @@ } ], "source": [ - "plot_data5a = pd.read_csv(os.path.join(data_dir, 'plot_data5a_v2.csv'))\n", - "plot_data5a_melted = pd.read_csv(os.path.join(data_dir, 'plot_data5a_melted_v2.csv'))\n", - "plot_data5b = pd.read_csv(os.path.join(data_dir, 'plot_data5b_v2.csv'))\n", - "plot_data5c = pd.read_csv(os.path.join(data_dir, 'plot_data5c.csv'))\n", - "sns.set_theme(context='talk', font_scale=1.6)\n", + "plot_data5a = pd.read_csv(os.path.join(data_dir, \"plot_data5a_v2.csv\"))\n", + "plot_data5a_melted = pd.read_csv(os.path.join(data_dir, \"plot_data5a_melted_v2.csv\"))\n", + "plot_data5b = pd.read_csv(os.path.join(data_dir, \"plot_data5b_v2.csv\"))\n", + "plot_data5c = pd.read_csv(os.path.join(data_dir, \"plot_data5c.csv\"))\n", + "sns.set_theme(context=\"talk\", font_scale=1.6)\n", "sns.set_style(\"whitegrid\", {\"grid.color\": \".95\"})\n", - "fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(20,20))#, sharey='row')\n", + "fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(20, 20)) # , sharey='row')\n", "\n", - "shared_y_max = plot_data5b.relevant_amount.max()+0.2e5\n", + "shared_y_max = plot_data5b.relevant_amount.max() + 0.2e5\n", "split_by_gen = True\n", "\n", "# subplot 1\n", "if split_by_gen:\n", - " style='Generation'\n", + " style = \"Generation\"\n", "else:\n", - " style=None\n", + " style = None\n", "sns.lineplot(\n", - " data=plot_data5a_melted[plot_data5a_melted.centered_time_in_minutes>=0].sort_values('Generation', ascending=False),\n", - " x=\"centered_time_in_minutes\", \n", + " data=plot_data5a_melted[\n", + " plot_data5a_melted.centered_time_in_minutes >= 0\n", + " ].sort_values(\"Generation\", ascending=False),\n", + " x=\"centered_time_in_minutes\",\n", " y=\"value\",\n", - " hue='Method of calculation',\n", - " palette=[sns.color_palette('dark')[0],sns.color_palette('dark')[1]],\n", + " hue=\"Method of calculation\",\n", + " palette=[sns.color_palette(\"dark\")[0], sns.color_palette(\"dark\")[1]],\n", " style=style,\n", " ci=95,\n", - " ax=axs[0,0],\n", - " legend=False\n", + " ax=axs[0, 0],\n", + " legend=False,\n", ")\n", "sns.lineplot(\n", " data=plot_data5a_melted[\n", - " (plot_data5a_melted.centered_time_in_minutes<=0) &\n", - " (plot_data5a_melted['Method of calculation'] == \"Combined signal\")\n", - " ].sort_values('Generation', ascending=False),\n", - " x=\"centered_time_in_minutes\", \n", + " (plot_data5a_melted.centered_time_in_minutes <= 0)\n", + " & (plot_data5a_melted[\"Method of calculation\"] == \"Combined signal\")\n", + " ].sort_values(\"Generation\", ascending=False),\n", + " x=\"centered_time_in_minutes\",\n", " y=\"value\",\n", - " hue='Method of calculation',\n", - " palette=[sns.color_palette('pastel')[1]],\n", + " hue=\"Method of calculation\",\n", + " palette=[sns.color_palette(\"pastel\")[1]],\n", " style=style,\n", " ci=95,\n", - " ax=axs[0,0],\n", - " legend=False\n", + " ax=axs[0, 0],\n", + " legend=False,\n", ")\n", "\n", - "axs[0,0].axvline(x=0, color='red')#, label='Time of Bud Emergence')\n", - "axs[0,0].text(\n", - " 0.7, 1.5e5, \"Time of \\nbud emerg.\", horizontalalignment='left', \n", - " size='medium', color='red', weight='normal'\n", + "axs[0, 0].axvline(x=0, color=\"red\") # , label='Time of Bud Emergence')\n", + "axs[0, 0].text(\n", + " 0.7,\n", + " 1.5e5,\n", + " \"Time of \\nbud emerg.\",\n", + " horizontalalignment=\"left\",\n", + " size=\"medium\",\n", + " color=\"red\",\n", + " weight=\"normal\",\n", ")\n", "# custom legend\n", - "labels = [\n", - " 'Combined signal',\n", - " 'Bud signal',\n", - " 'Division 1',\n", - " 'Divisions 2+'\n", - "]\n", + "labels = [\"Combined signal\", \"Bud signal\", \"Division 1\", \"Divisions 2+\"]\n", "handles = [\n", - " mpatches.Patch(color=sns.color_palette('pastel')[1]),\n", - " mpatches.Patch(color=sns.color_palette('dark')[0]),\n", - " mlines.Line2D([], [], color='gray', linestyle='-'),\n", - " mlines.Line2D([], [], color='gray', linestyle='--')\n", + " mpatches.Patch(color=sns.color_palette(\"pastel\")[1]),\n", + " mpatches.Patch(color=sns.color_palette(\"dark\")[0]),\n", + " mlines.Line2D([], [], color=\"gray\", linestyle=\"-\"),\n", + " mlines.Line2D([], [], color=\"gray\", linestyle=\"--\"),\n", "]\n", "handles2 = [\n", - " mpatches.Patch(color=sns.color_palette('dark')[1]),\n", - " mpatches.Patch(color='white'),\n", - " mpatches.Patch(color='white'),\n", - " mpatches.Patch(color='white'),\n", + " mpatches.Patch(color=sns.color_palette(\"dark\")[1]),\n", + " mpatches.Patch(color=\"white\"),\n", + " mpatches.Patch(color=\"white\"),\n", + " mpatches.Patch(color=\"white\"),\n", "]\n", - "axs[0,0].legend(\n", - " handles=handles+handles2,\n", + "axs[0, 0].legend(\n", + " handles=handles + handles2,\n", " ncol=2,\n", - " labels=['']*4+labels,\n", + " labels=[\"\"] * 4 + labels,\n", " columnspacing=-0.5,\n", - " loc='upper left',\n", - " bbox_to_anchor = (-.01,1),\n", + " loc=\"upper left\",\n", + " bbox_to_anchor=(-0.01, 1),\n", " framealpha=0.5,\n", - " handlelength=1\n", - ")\n", - "#plt.setp(axs[0,0].get_legend().get_title(), fontsize='20') \n", - "axs[0,0].set_ylabel(\"Amount of Htb1-mCitrine [a.u.]\")\n", - "axs[0,0].set_xlabel(\"Time since bud emergence [minutes]\")\n", - "#axs[0,0].set_title('A', fontsize=30, loc='left', pad=10)\n", - "axs[0,0].ticklabel_format(axis='y', style='sci', scilimits=(0,0), useMathText=True)\n", - "axs[0,0].get_yaxis().get_offset_text().set_position((-0.07, 0))\n", - "axs[0,0].set_ylim(-0.2e5, shared_y_max)\n", - "axs[0,0].set_xlim(\n", + " handlelength=1,\n", + ")\n", + "# plt.setp(axs[0,0].get_legend().get_title(), fontsize='20')\n", + "axs[0, 0].set_ylabel(\"Amount of Htb1-mCitrine [a.u.]\")\n", + "axs[0, 0].set_xlabel(\"Time since bud emergence [minutes]\")\n", + "# axs[0,0].set_title('A', fontsize=30, loc='left', pad=10)\n", + "axs[0, 0].ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0, 0), useMathText=True)\n", + "axs[0, 0].get_yaxis().get_offset_text().set_position((-0.07, 0))\n", + "axs[0, 0].set_ylim(-0.2e5, shared_y_max)\n", + "axs[0, 0].set_xlim(\n", " plot_data5a_melted.centered_time_in_minutes.min(),\n", - " plot_data5a_melted.centered_time_in_minutes.max()\n", + " plot_data5a_melted.centered_time_in_minutes.max(),\n", ")\n", - "#axs[0,0].legend().get_texts()[0].set_text(matplotlib.text.Text(text='test', fontweight=1000))\n", + "# axs[0,0].legend().get_texts()[0].set_text(matplotlib.text.Text(text='test', fontweight=1000))\n", "\n", - "#subplot 2\n", + "# subplot 2\n", "# Initialize the figure\n", "custom_colors = [\n", - " sns.color_palette('dark')[1],\n", - " sns.color_palette('pastel')[1],\n", + " sns.color_palette(\"dark\")[1],\n", + " sns.color_palette(\"pastel\")[1],\n", " sns.color_palette()[4],\n", - " sns.color_palette()[2]\n", + " sns.color_palette()[2],\n", "]\n", "sns.scatterplot(\n", - " x=\"relevant_volume\", \n", - " y=\"relevant_amount\", \n", - " data=plot_data5b[plot_data5b.generation_num==1].sort_values(\n", - " 'Kind of Measurement new', ascending=False\n", + " x=\"relevant_volume\",\n", + " y=\"relevant_amount\",\n", + " data=plot_data5b[plot_data5b.generation_num == 1].sort_values(\n", + " \"Kind of Measurement new\", ascending=False\n", " ),\n", " palette=custom_colors,\n", " hue=\"Kind of Measurement new\",\n", - " marker='x',\n", - " ax=axs[0,1]\n", + " marker=\"x\",\n", + " ax=axs[0, 1],\n", ")\n", "\n", "sns.scatterplot(\n", - " x=\"relevant_volume\", \n", - " y=\"relevant_amount\", \n", - " data=plot_data5b[plot_data5b.generation_num>1].sort_values(\n", - " 'Kind of Measurement new', ascending=False\n", + " x=\"relevant_volume\",\n", + " y=\"relevant_amount\",\n", + " data=plot_data5b[plot_data5b.generation_num > 1].sort_values(\n", + " \"Kind of Measurement new\", ascending=False\n", " ),\n", " palette=custom_colors,\n", " hue=\"Kind of Measurement new\",\n", " legend=False,\n", - " marker='o',\n", - " ax=axs[0,1]\n", + " marker=\"o\",\n", + " ax=axs[0, 1],\n", ")\n", "measurements = [\n", - " 'Mother+bud at cytokinesis',\n", - " 'At G1-entry',\n", - " 'AF control, m+b at cytokinesis',\n", - " 'AF control at G1-entry'\n", + " \"Mother+bud at cytokinesis\",\n", + " \"At G1-entry\",\n", + " \"AF control, m+b at cytokinesis\",\n", + " \"AF control at G1-entry\",\n", "]\n", "\n", "# add regplots in for loop\n", - "print(pd.unique(plot_data5b['Kind of Measurement new']))\n", + "print(pd.unique(plot_data5b[\"Kind of Measurement new\"]))\n", "for idx, measure in enumerate(measurements):\n", " sns.regplot(\n", - " x=\"relevant_volume\", \n", - " y=\"relevant_amount\", \n", - " data=plot_data5b[plot_data5b['Kind of Measurement new']==measure],\n", + " x=\"relevant_volume\",\n", + " y=\"relevant_amount\",\n", + " data=plot_data5b[plot_data5b[\"Kind of Measurement new\"] == measure],\n", " color=custom_colors[idx],\n", " scatter=False,\n", - " ax=axs[0,1]\n", + " ax=axs[0, 1],\n", " )\n", - "labels = [\n", - " 'Division 1',\n", - " 'Divisions 2+'\n", - "]\n", + "labels = [\"Division 1\", \"Divisions 2+\"]\n", "handles = [\n", - " mpatches.Patch(color=sns.color_palette('dark')[1]),\n", - " mpatches.Patch(color=sns.color_palette('pastel')[1]),\n", + " mpatches.Patch(color=sns.color_palette(\"dark\")[1]),\n", + " mpatches.Patch(color=sns.color_palette(\"pastel\")[1]),\n", " mpatches.Patch(color=sns.color_palette()[2]),\n", " mpatches.Patch(color=sns.color_palette()[4]),\n", - " mlines.Line2D([], [], color='gray', marker='x', linestyle='None',\n", - " markersize=10),\n", - " mlines.Line2D([], [], color='gray', marker='o', linestyle='None',\n", - " markersize=10)\n", + " mlines.Line2D([], [], color=\"gray\", marker=\"x\", linestyle=\"None\", markersize=10),\n", + " mlines.Line2D([], [], color=\"gray\", marker=\"o\", linestyle=\"None\", markersize=10),\n", "]\n", - "axs[0,1].legend(\n", + "axs[0, 1].legend(\n", " handles=handles,\n", - " labels=measurements+labels, \n", - " loc='lower right',\n", - " #bbox_to_anchor = (1,0),\n", + " labels=measurements + labels,\n", + " loc=\"lower right\",\n", + " # bbox_to_anchor = (1,0),\n", " framealpha=0.5,\n", - " handlelength=0.75\n", + " handlelength=0.75,\n", ")\n", - "axs[0,1].set_ylabel(\"Amount of Htb1-mCitrine [a.u.]\")\n", - "axs[0,1].set_xlabel('Cell volume at G1-entry / before cytokinesis [fL]')\n", + "axs[0, 1].set_ylabel(\"Amount of Htb1-mCitrine [a.u.]\")\n", + "axs[0, 1].set_xlabel(\"Cell volume at G1-entry / before cytokinesis [fL]\")\n", "# format y-axis\n", - "axs[0,1].ticklabel_format(axis='y', style='sci', scilimits=(0,0), useMathText=True)\n", - "axs[0,1].get_yaxis().get_offset_text().set_position((-0.07, 0))\n", + "axs[0, 1].ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0, 0), useMathText=True)\n", + "axs[0, 1].get_yaxis().get_offset_text().set_position((-0.07, 0))\n", "# format x-axis\n", - "axs[0,1].set_xlim(0, plot_data5b.relevant_volume.max()+2)\n", - "axs[0,1].set_ylim(-0.2e5, shared_y_max)\n", - "#axs[0,1].set_title('B', fontsize=30, loc='left', pad=10)\n", + "axs[0, 1].set_xlim(0, plot_data5b.relevant_volume.max() + 2)\n", + "axs[0, 1].set_ylim(-0.2e5, shared_y_max)\n", + "# axs[0,1].set_title('B', fontsize=30, loc='left', pad=10)\n", "\n", "sns.boxplot(\n", " data=plot_data5c,\n", - " x='x_label',\n", - " y='mCitrine_corrected_concentration',\n", - " palette='vlag',\n", + " x=\"x_label\",\n", + " y=\"mCitrine_corrected_concentration\",\n", + " palette=\"vlag\",\n", " fliersize=0,\n", - " ax=axs[1,0]\n", + " ax=axs[1, 0],\n", ")\n", "\n", - "#add stripplot on top\n", + "# add stripplot on top\n", "sns.stripplot(\n", " data=plot_data5c,\n", - " x='x_label',\n", - " y='mCitrine_corrected_concentration',\n", + " x=\"x_label\",\n", + " y=\"mCitrine_corrected_concentration\",\n", " color=\".3\",\n", - " ax=axs[1,0]\n", + " ax=axs[1, 0],\n", ")\n", "\n", "# switch to scientific number format on y-Axis and move text\n", - "axs[1,0].ticklabel_format(axis='y', style='sci', scilimits=(0,0), useMathText=True)\n", - "axs[1,0].get_yaxis().get_offset_text().set_position((-0.07,0))\n", + "axs[1, 0].ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0, 0), useMathText=True)\n", + "axs[1, 0].get_yaxis().get_offset_text().set_position((-0.07, 0))\n", "\n", "# Rename axes and set title\n", - "axs[1,0].set_ylabel(\"Htb1-mCitrine amount per volume\\nin mother cell at division [a.u.]\")#, fontsize=20)\n", - "axs[1,0].set_xlabel(\"Division\")\n", - "#axs[1,0].set_title(f\"Concentration by Generation (n={len(plot_data5c)})\", fontsize=25) # changed this from 30 to 25 compared to 5B\n", - "axs[1,0].set_ylim(0, plot_data5c.mCitrine_corrected_concentration.max()+0.1e4)\n", + "axs[1, 0].set_ylabel(\n", + " \"Htb1-mCitrine amount per volume\\nin mother cell at division [a.u.]\"\n", + ") # , fontsize=20)\n", + "axs[1, 0].set_xlabel(\"Division\")\n", + "# axs[1,0].set_title(f\"Concentration by Generation (n={len(plot_data5c)})\", fontsize=25) # changed this from 30 to 25 compared to 5B\n", + "axs[1, 0].set_ylim(0, plot_data5c.mCitrine_corrected_concentration.max() + 0.1e4)\n", "\n", "plt.tight_layout()\n", "\n", - "plt.savefig(os.path.join('..', 'figures', 'new_fig5', 'combined_fig5_v4.png'), dpi=300)\n", - "plt.savefig(os.path.join('..', 'figures', 'new_fig5', 'combined_fig5_v4.svg'))\n", + "plt.savefig(os.path.join(\"..\", \"figures\", \"new_fig5\", \"combined_fig5_v4.png\"), dpi=300)\n", + "plt.savefig(os.path.join(\"..\", \"figures\", \"new_fig5\", \"combined_fig5_v4.svg\"))\n", "\n", "plt.show()\n", - "sample_size5a = len(plot_data5a[['position', 'Cell_ID', 'file', 'generation_num']].drop_duplicates())\n", + "sample_size5a = len(\n", + " plot_data5a[[\"position\", \"Cell_ID\", \"file\", \"generation_num\"]].drop_duplicates()\n", + ")\n", "sample_size5b = len(plot_data5b)\n", - "print(f'Fig 5A sample size: {sample_size5a}')\n", - "print(f'Fig 5A sample sizes by generation: {plot_data5a_melted.Generation.unique()}')\n", - "print(f\"Fig 5B sample size: {int(sample_size5b/2)}\")\n", - "print(f'Fig 5B sample size flu-control: {len(plot_data5b[plot_data5b.selection_subset==1])//2}')\n", - "print(f'Fig 5B sample size tagged strain: {len(plot_data5b[plot_data5b.selection_subset==0])//2}')\n", - "print(f'Fig 5C sample size: {len(plot_data5c)}')" + "print(f\"Fig 5A sample size: {sample_size5a}\")\n", + "print(f\"Fig 5A sample sizes by generation: {plot_data5a_melted.Generation.unique()}\")\n", + "print(f\"Fig 5B sample size: {int(sample_size5b / 2)}\")\n", + "print(\n", + " f\"Fig 5B sample size flu-control: {len(plot_data5b[plot_data5b.selection_subset == 1]) // 2}\"\n", + ")\n", + "print(\n", + " f\"Fig 5B sample size tagged strain: {len(plot_data5b[plot_data5b.selection_subset == 0]) // 2}\"\n", + ")\n", + "print(f\"Fig 5C sample size: {len(plot_data5c)}\")" ] }, { @@ -883,7 +996,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot_data5a_melted[plot_data5a_melted.centered_time_in_minutes<=0]" + "plot_data5a_melted[plot_data5a_melted.centered_time_in_minutes <= 0]" ] }, { @@ -921,13 +1034,17 @@ }, "outputs": [], "source": [ - "phase_contr_yeaz_data = pd.read_csv(os.path.join(data_dir, 'SegmPhaseContr_YeaZ_AllPos_acdc_output.csv'))\n", - "cellpose_act1_data = pd.read_csv(os.path.join(data_dir, 'SegmACT1_Cellpose_AllPos_acdc_output.csv'))\n", + "phase_contr_yeaz_data = pd.read_csv(\n", + " os.path.join(data_dir, \"SegmPhaseContr_YeaZ_AllPos_acdc_output.csv\")\n", + ")\n", + "cellpose_act1_data = pd.read_csv(\n", + " os.path.join(data_dir, \"SegmACT1_Cellpose_AllPos_acdc_output.csv\")\n", + ")\n", "merged_vol_data = pd.merge(\n", " phase_contr_yeaz_data,\n", " cellpose_act1_data,\n", - " on=['Position_n', 'Cell_ID'],\n", - " suffixes=('_yeaz', '_cellpose')\n", + " on=[\"Position_n\", \"Cell_ID\"],\n", + " suffixes=(\"_yeaz\", \"_cellpose\"),\n", ")" ] }, @@ -944,37 +1061,46 @@ }, "outputs": [], "source": [ - "plt.figure(figsize=(10,10))\n", + "plt.figure(figsize=(10, 10))\n", "fig = sns.lmplot(\n", " data=merged_vol_data,\n", - " x='cell_vol_fl_yeaz',\n", - " y='cell_vol_fl_cellpose',\n", - " hue='relationship_cellpose',\n", + " x=\"cell_vol_fl_yeaz\",\n", + " y=\"cell_vol_fl_cellpose\",\n", + " hue=\"relationship_cellpose\",\n", " height=7.5,\n", - " legend=False\n", + " legend=False,\n", ")\n", "ax = plt.gca()\n", - "labels = [\n", - " 'Mother cells',\n", - " 'Buds & daughter cells'\n", - "]\n", + "labels = [\"Mother cells\", \"Buds & daughter cells\"]\n", "handles = [\n", - " mlines.Line2D([], [], color=sns.color_palette()[0], marker='o', linestyle='None',\n", - " markersize=10),\n", - " mlines.Line2D([], [], color=sns.color_palette()[1], marker='o', linestyle='None',\n", - " markersize=10)\n", + " mlines.Line2D(\n", + " [],\n", + " [],\n", + " color=sns.color_palette()[0],\n", + " marker=\"o\",\n", + " linestyle=\"None\",\n", + " markersize=10,\n", + " ),\n", + " mlines.Line2D(\n", + " [],\n", + " [],\n", + " color=sns.color_palette()[1],\n", + " marker=\"o\",\n", + " linestyle=\"None\",\n", + " markersize=10,\n", + " ),\n", "]\n", "ax.legend(\n", " handles=handles,\n", - " labels=labels, \n", - " loc='center right',\n", - " bbox_to_anchor = (1,0.2),\n", - " framealpha=0.5\n", + " labels=labels,\n", + " loc=\"center right\",\n", + " bbox_to_anchor=(1, 0.2),\n", + " framealpha=0.5,\n", ")\n", - "ax.set_xlabel('Cell Volume Phase Contrast + YeaZ [fL]')\n", - "ax.set_ylabel('Cell Volume Act1 signal + cellpose [fL]')\n", + "ax.set_xlabel(\"Cell Volume Phase Contrast + YeaZ [fL]\")\n", + "ax.set_ylabel(\"Cell Volume Act1 signal + cellpose [fL]\")\n", "plt.show()\n", - "#merged_vol_data.to_csv(os.path.join(data_dir, 'plot_data3a.csv'), index=False)" + "# merged_vol_data.to_csv(os.path.join(data_dir, 'plot_data3a.csv'), index=False)" ] }, { @@ -1011,12 +1137,20 @@ }, "outputs": [], "source": [ - "stem_data = pd.read_csv(os.path.join(data_dir, 'p38_AB_AllPos_acdc_output.csv'))\n", + "stem_data = pd.read_csv(os.path.join(data_dir, \"p38_AB_AllPos_acdc_output.csv\"))\n", "# configure borders for \"size blocks\"\n", "xs_borders = 0, np.percentile(stem_data.cell_vol_fl, 15)\n", - "m_borders = np.percentile(stem_data.cell_vol_fl, 35), np.percentile(stem_data.cell_vol_fl, 65)\n", - "xl_borders = np.percentile(stem_data.cell_vol_fl, 85), np.max(stem_data.cell_vol_fl) + 20\n", - "stem_data['Pp38_concentration'] = stem_data['Pp38_amount_autoBkgr_zSlice'] / stem_data['cell_vol_fl']" + "m_borders = (\n", + " np.percentile(stem_data.cell_vol_fl, 35),\n", + " np.percentile(stem_data.cell_vol_fl, 65),\n", + ")\n", + "xl_borders = (\n", + " np.percentile(stem_data.cell_vol_fl, 85),\n", + " np.max(stem_data.cell_vol_fl) + 20,\n", + ")\n", + "stem_data[\"Pp38_concentration\"] = (\n", + " stem_data[\"Pp38_amount_autoBkgr_zSlice\"] / stem_data[\"cell_vol_fl\"]\n", + ")" ] }, { @@ -1043,33 +1177,46 @@ }, "outputs": [], "source": [ - "plt.subplots(figsize=(10,10))\n", - "sns.set_theme(context='talk', style='darkgrid')\n", + "plt.subplots(figsize=(10, 10))\n", + "sns.set_theme(context=\"talk\", style=\"darkgrid\")\n", "ax = sns.scatterplot(\n", " data=stem_data,\n", - " x='cell_vol_fl',\n", - " y='Pp38_concentration',\n", - " #hue='size_category'\n", + " x=\"cell_vol_fl\",\n", + " y=\"Pp38_concentration\",\n", + " # hue='size_category'\n", + ")\n", + "ax.set_xlabel(\"HSC Nuclear Volume [fL]\")\n", + "ax.set_ylabel(\"Mean intensity Pp38 [a.u.]\")\n", + "lower_y_border, upper_y_border = (\n", + " stem_data.Pp38_concentration.min() - 10,\n", + " stem_data.Pp38_concentration.max() + 10,\n", ")\n", - "ax.set_xlabel('HSC Nuclear Volume [fL]')\n", - "ax.set_ylabel('Mean intensity Pp38 [a.u.]')\n", - "lower_y_border, upper_y_border = stem_data.Pp38_concentration.min()-10, stem_data.Pp38_concentration.max()+10\n", "height = upper_y_border - lower_y_border\n", - "xs_width = xs_borders[1]-xs_borders[0]\n", - "m_width = m_borders[1]-m_borders[0]\n", - "xl_width = xl_borders[1]-xl_borders[0]\n", + "xs_width = xs_borders[1] - xs_borders[0]\n", + "m_width = m_borders[1] - m_borders[0]\n", + "xl_width = xl_borders[1] - xl_borders[0]\n", "ax.add_patch(\n", - " patches.Rectangle((xs_borders[0], lower_y_border), xs_width, height, color='black', alpha=0.2)\n", + " patches.Rectangle(\n", + " (xs_borders[0], lower_y_border), xs_width, height, color=\"black\", alpha=0.2\n", + " )\n", + ")\n", + "plt.text(\n", + " 0.5 * sum(xs_borders) - 20, upper_y_border - 50, \"XS\", fontdict={\"fontsize\": 30}\n", ")\n", - "plt.text(0.5*sum(xs_borders)-20, upper_y_border-50, 'XS', fontdict={'fontsize':30})\n", "ax.add_patch(\n", - " patches.Rectangle((m_borders[0], lower_y_border), m_width, height, color='black', alpha=0.2)\n", + " patches.Rectangle(\n", + " (m_borders[0], lower_y_border), m_width, height, color=\"black\", alpha=0.2\n", + " )\n", ")\n", - "plt.text(0.5*sum(m_borders)-20, upper_y_border-50, 'M', fontdict={'fontsize':30})\n", + "plt.text(0.5 * sum(m_borders) - 20, upper_y_border - 50, \"M\", fontdict={\"fontsize\": 30})\n", "ax.add_patch(\n", - " patches.Rectangle((xl_borders[0], lower_y_border), xl_width, height, color='black', alpha=0.2)\n", + " patches.Rectangle(\n", + " (xl_borders[0], lower_y_border), xl_width, height, color=\"black\", alpha=0.2\n", + " )\n", + ")\n", + "plt.text(\n", + " 0.5 * sum(xl_borders) - 20, upper_y_border - 50, \"XL\", fontdict={\"fontsize\": 30}\n", ")\n", - "plt.text(0.5*sum(xl_borders)-20, upper_y_border-50, 'XL', fontdict={'fontsize':30})\n", "\"\"\"\n", "plt.savefig(\n", " '../figures/stemcell_scatter_v1.pdf',\n", @@ -1077,7 +1224,7 @@ ")\n", "\"\"\"\n", "plt.show()\n", - "#stem_data.to_csv(os.path.join(data_dir, 'plot_data4c.csv'), index=False)" + "# stem_data.to_csv(os.path.join(data_dir, 'plot_data4c.csv'), index=False)" ] }, { @@ -1104,22 +1251,33 @@ }, "outputs": [], "source": [ - "stem_bkgr_data = pd.read_csv(os.path.join(data_dir, 'p38_control_AllPos_acdc_output.csv'))\n", - "stem_bkgr_data['Pp38_concentration'] = stem_bkgr_data['Pp38_amount_autoBkgr_zSlice'] / stem_bkgr_data['cell_vol_fl']\n", + "stem_bkgr_data = pd.read_csv(\n", + " os.path.join(data_dir, \"p38_control_AllPos_acdc_output.csv\")\n", + ")\n", + "stem_bkgr_data[\"Pp38_concentration\"] = (\n", + " stem_bkgr_data[\"Pp38_amount_autoBkgr_zSlice\"] / stem_bkgr_data[\"cell_vol_fl\"]\n", + ")\n", + "\n", + "\n", "def generate_size_str(x):\n", - " if x>=0 and x<=xs_borders[1]:\n", - " return 'XS'\n", - " elif x>=m_borders[0] and x<=m_borders[1]:\n", - " return 'M'\n", - " elif x>=xl_borders[0]:\n", - " return 'XL'\n", + " if x >= 0 and x <= xs_borders[1]:\n", + " return \"XS\"\n", + " elif x >= m_borders[0] and x <= m_borders[1]:\n", + " return \"M\"\n", + " elif x >= xl_borders[0]:\n", + " return \"XL\"\n", " else:\n", - " return 'rest'\n", + " return \"rest\"\n", + "\n", + "\n", "all_data = stem_data.copy()\n", - "all_data['size_category'] = 'All'\n", - "stem_data['size_category'] = stem_data.cell_vol_fl.apply(generate_size_str)\n", - "stem_bkgr_data['size_category'] = 'Control'\n", - "box_data = pd.concat([all_data, stem_data[stem_data.size_category!='rest'], stem_bkgr_data], ignore_index=True)" + "all_data[\"size_category\"] = \"All\"\n", + "stem_data[\"size_category\"] = stem_data.cell_vol_fl.apply(generate_size_str)\n", + "stem_bkgr_data[\"size_category\"] = \"Control\"\n", + "box_data = pd.concat(\n", + " [all_data, stem_data[stem_data.size_category != \"rest\"], stem_bkgr_data],\n", + " ignore_index=True,\n", + ")" ] }, { @@ -1135,18 +1293,18 @@ }, "outputs": [], "source": [ - "sns.set_theme(context='talk', style='darkgrid')\n", - "plt.figure(figsize=(10,10))\n", + "sns.set_theme(context=\"talk\", style=\"darkgrid\")\n", + "plt.figure(figsize=(10, 10))\n", "ax = sns.boxplot(\n", " data=box_data,\n", - " x='size_category',\n", - " y='Pp38_concentration',\n", + " x=\"size_category\",\n", + " y=\"Pp38_concentration\",\n", " order=[\"Control\", \"All\", \"XS\", \"M\", \"XL\"],\n", - " color=sns.color_palette()[0]\n", - " #inner='quartile'\n", + " color=sns.color_palette()[0],\n", + " # inner='quartile'\n", ")\n", - "ax.set_xlabel('Size Category')\n", - "ax.set_ylabel('Mean intensity Pp38 [a.u.]')\n", + "ax.set_xlabel(\"Size Category\")\n", + "ax.set_ylabel(\"Mean intensity Pp38 [a.u.]\")\n", "\"\"\"\n", "plt.savefig(\n", " '../figures/stemcell_violin_v1.pdf',\n", @@ -1154,7 +1312,7 @@ ")\n", "\"\"\"\n", "plt.show()\n", - "#box_data.to_csv(os.path.join(data_dir, 'plot_data4d.csv'), index=False)" + "# box_data.to_csv(os.path.join(data_dir, 'plot_data4d.csv'), index=False)" ] }, { @@ -1181,16 +1339,26 @@ }, "outputs": [], "source": [ - "stem_data = pd.read_csv(os.path.join(data_dir, 'stemcell_data.csv'))\n", + "stem_data = pd.read_csv(os.path.join(data_dir, \"stemcell_data.csv\"))\n", "# configure borders for \"size blocks\"\n", "xs_borders = 0, np.percentile(stem_data.cell_vol_fl, 15)\n", - "m_borders = np.percentile(stem_data.cell_vol_fl, 35), np.percentile(stem_data.cell_vol_fl, 65)\n", - "xl_borders = np.percentile(stem_data.cell_vol_fl, 85), np.percentile(stem_data.cell_vol_fl, 85) * 2 + 20\n", + "m_borders = (\n", + " np.percentile(stem_data.cell_vol_fl, 35),\n", + " np.percentile(stem_data.cell_vol_fl, 65),\n", + ")\n", + "xl_borders = (\n", + " np.percentile(stem_data.cell_vol_fl, 85),\n", + " np.percentile(stem_data.cell_vol_fl, 85) * 2 + 20,\n", + ")\n", "# In Fig. 3B very small cells are assumed to be imaging fragments, very large cells missed Segmentation errors\n", - "min_vol, max_vol = 0, xl_borders[0]*2\n", - "stem_selection_indices = np.logical_and(stem_data.cell_vol_fl>min_vol, stem_data.cell_vol_fl min_vol, stem_data.cell_vol_fl < max_vol\n", + ")\n", "stem_data = stem_data[stem_selection_indices]\n", - "stem_data['FITC_concentration'] = stem_data['FITC_amount_autoBkgr_zSlice'] / stem_data['cell_vol_fl']" + "stem_data[\"FITC_concentration\"] = (\n", + " stem_data[\"FITC_amount_autoBkgr_zSlice\"] / stem_data[\"cell_vol_fl\"]\n", + ")" ] }, { @@ -1217,35 +1385,48 @@ }, "outputs": [], "source": [ - "plt.subplots(figsize=(10,10))\n", - "sns.set_theme(context='talk', style='darkgrid')\n", + "plt.subplots(figsize=(10, 10))\n", + "sns.set_theme(context=\"talk\", style=\"darkgrid\")\n", "ax = sns.scatterplot(\n", " data=stem_data,\n", - " x='cell_vol_fl',\n", - " y='FITC_concentration',\n", - " color=sns.color_palette()[2]\n", - " #hue='size_category'\n", - ")\n", - "ax.set_xlabel('HSC Volume [fL]')\n", - "ax.set_ylabel('Mean intensity FITC [a.u.]')\n", - "lower_y_border, upper_y_border = stem_data.FITC_concentration.min()-10, stem_data.FITC_concentration.max()+10\n", + " x=\"cell_vol_fl\",\n", + " y=\"FITC_concentration\",\n", + " color=sns.color_palette()[2],\n", + " # hue='size_category'\n", + ")\n", + "ax.set_xlabel(\"HSC Volume [fL]\")\n", + "ax.set_ylabel(\"Mean intensity FITC [a.u.]\")\n", + "lower_y_border, upper_y_border = (\n", + " stem_data.FITC_concentration.min() - 10,\n", + " stem_data.FITC_concentration.max() + 10,\n", + ")\n", "height = upper_y_border - lower_y_border\n", - "xs_width = xs_borders[1]-xs_borders[0]\n", - "m_width = m_borders[1]-m_borders[0]\n", - "xl_width = xl_borders[1]-xl_borders[0]\n", + "xs_width = xs_borders[1] - xs_borders[0]\n", + "m_width = m_borders[1] - m_borders[0]\n", + "xl_width = xl_borders[1] - xl_borders[0]\n", "ax.add_patch(\n", - " patches.Rectangle((xs_borders[0], lower_y_border), xs_width, height, color='black', alpha=0.2)\n", + " patches.Rectangle(\n", + " (xs_borders[0], lower_y_border), xs_width, height, color=\"black\", alpha=0.2\n", + " )\n", + ")\n", + "plt.text(\n", + " 0.5 * sum(xs_borders) - 20, upper_y_border - 50, \"XS\", fontdict={\"fontsize\": 30}\n", ")\n", - "plt.text(0.5*sum(xs_borders)-20, upper_y_border-50, 'XS', fontdict={'fontsize':30})\n", "ax.add_patch(\n", - " patches.Rectangle((m_borders[0], lower_y_border), m_width, height, color='black', alpha=0.2)\n", + " patches.Rectangle(\n", + " (m_borders[0], lower_y_border), m_width, height, color=\"black\", alpha=0.2\n", + " )\n", ")\n", - "plt.text(0.5*sum(m_borders)-20, upper_y_border-50, 'M', fontdict={'fontsize':30})\n", + "plt.text(0.5 * sum(m_borders) - 20, upper_y_border - 50, \"M\", fontdict={\"fontsize\": 30})\n", "ax.add_patch(\n", - " patches.Rectangle((xl_borders[0], lower_y_border), xl_width, height, color='black', alpha=0.2)\n", + " patches.Rectangle(\n", + " (xl_borders[0], lower_y_border), xl_width, height, color=\"black\", alpha=0.2\n", + " )\n", ")\n", "ax.set_xlim(0, xl_borders[1])\n", - "plt.text(0.5*sum(xl_borders)-20, upper_y_border-50, 'XL', fontdict={'fontsize':30})\n", + "plt.text(\n", + " 0.5 * sum(xl_borders) - 20, upper_y_border - 50, \"XL\", fontdict={\"fontsize\": 30}\n", + ")\n", "\"\"\"\n", "plt.savefig(\n", " '../figures/stemcell_scatter_v1.pdf',\n", @@ -1253,7 +1434,7 @@ ")\n", "\"\"\"\n", "plt.show()\n", - "#stem_data.to_csv(os.path.join(data_dir, 'plot_data4a.csv'), index=False)" + "# stem_data.to_csv(os.path.join(data_dir, 'plot_data4a.csv'), index=False)" ] }, { @@ -1280,22 +1461,31 @@ }, "outputs": [], "source": [ - "stem_bkgr_data = pd.read_csv(os.path.join(data_dir, 'stemcell_bkgr_data.csv'))\n", - "stem_bkgr_data['FITC_concentration'] = stem_bkgr_data['FITC_amount_autoBkgr_zSlice'] / stem_bkgr_data['cell_vol_fl']\n", + "stem_bkgr_data = pd.read_csv(os.path.join(data_dir, \"stemcell_bkgr_data.csv\"))\n", + "stem_bkgr_data[\"FITC_concentration\"] = (\n", + " stem_bkgr_data[\"FITC_amount_autoBkgr_zSlice\"] / stem_bkgr_data[\"cell_vol_fl\"]\n", + ")\n", + "\n", + "\n", "def generate_size_str(x):\n", - " if x>=0 and x<=xs_borders[1]:\n", - " return 'XS'\n", - " elif x>=m_borders[0] and x<=m_borders[1]:\n", - " return 'M'\n", - " elif x>=xl_borders[0]:\n", - " return 'XL'\n", + " if x >= 0 and x <= xs_borders[1]:\n", + " return \"XS\"\n", + " elif x >= m_borders[0] and x <= m_borders[1]:\n", + " return \"M\"\n", + " elif x >= xl_borders[0]:\n", + " return \"XL\"\n", " else:\n", - " return 'rest'\n", + " return \"rest\"\n", + "\n", + "\n", "all_data = stem_data.copy()\n", - "all_data['size_category'] = 'All'\n", - "stem_data['size_category'] = stem_data.cell_vol_fl.apply(generate_size_str)\n", - "stem_bkgr_data['size_category'] = 'Control'\n", - "box_data = pd.concat([all_data, stem_data[stem_data.size_category!='rest'], stem_bkgr_data], ignore_index=True)" + "all_data[\"size_category\"] = \"All\"\n", + "stem_data[\"size_category\"] = stem_data.cell_vol_fl.apply(generate_size_str)\n", + "stem_bkgr_data[\"size_category\"] = \"Control\"\n", + "box_data = pd.concat(\n", + " [all_data, stem_data[stem_data.size_category != \"rest\"], stem_bkgr_data],\n", + " ignore_index=True,\n", + ")" ] }, { @@ -1311,18 +1501,18 @@ }, "outputs": [], "source": [ - "sns.set_theme(context='talk', style='darkgrid')\n", - "plt.figure(figsize=(10,10))\n", + "sns.set_theme(context=\"talk\", style=\"darkgrid\")\n", + "plt.figure(figsize=(10, 10))\n", "ax = sns.boxplot(\n", " data=box_data,\n", - " x='size_category',\n", - " y='FITC_concentration',\n", + " x=\"size_category\",\n", + " y=\"FITC_concentration\",\n", " order=[\"Control\", \"All\", \"XS\", \"M\", \"XL\"],\n", - " color=sns.color_palette()[2]\n", - " #inner='quartile'\n", + " color=sns.color_palette()[2],\n", + " # inner='quartile'\n", ")\n", - "ax.set_xlabel('Size Category')\n", - "ax.set_ylabel('Mean intensity FITC [a.u.]')\n", + "ax.set_xlabel(\"Size Category\")\n", + "ax.set_ylabel(\"Mean intensity FITC [a.u.]\")\n", "\"\"\"\n", "plt.savefig(\n", " '../figures/stemcell_violin_v1.pdf',\n", @@ -1330,7 +1520,7 @@ ")\n", "\"\"\"\n", "plt.show()\n", - "#box_data.to_csv(os.path.join(data_dir, 'plot_data4b.csv'), index=False)" + "# box_data.to_csv(os.path.join(data_dir, 'plot_data4b.csv'), index=False)" ] }, { @@ -1369,19 +1559,22 @@ "source": [ "data_dirs, positions = (\n", " [\n", - " '../data/acdc_test_data/TimeLapse_2D/MIA_KC_htb1_mCitrine_labeled',\n", - " '../data/acdc_test_data/TimeLapse_2D/MIA_KC_htb1_mCitrine_flu_control_labeled'\n", + " \"../data/acdc_test_data/TimeLapse_2D/MIA_KC_htb1_mCitrine_labeled\",\n", + " \"../data/acdc_test_data/TimeLapse_2D/MIA_KC_htb1_mCitrine_flu_control_labeled\",\n", " ],\n", " [\n", - " ['Position_2', 'Position_3', 'Position_4', 'Position_5', 'Position_8'],\n", - " ['Position_1', 'Position_3']\n", - " ]\n", + " [\"Position_2\", \"Position_3\", \"Position_4\", \"Position_5\", \"Position_8\"],\n", + " [\"Position_1\", \"Position_3\"],\n", + " ],\n", ")\n", "file_names = [os.path.split(path)[-1] for path in data_dirs]\n", - "image_folders = [[os.path.join(data_dir, pos_str, 'Images') for pos_str in pos_list] for pos_list, data_dir in zip(positions, data_dirs)]\n", + "image_folders = [\n", + " [os.path.join(data_dir, pos_str, \"Images\") for pos_str in pos_list]\n", + " for pos_list, data_dir in zip(positions, data_dirs)\n", + "]\n", "# determine available channels based on first(!) position.\n", "# Warn user if one or more of the channels are not available for some positions\n", - "first_pos_dir = os.path.join(data_dirs[0], positions[0][0], 'Images')\n", + "first_pos_dir = os.path.join(data_dirs[0], positions[0][0], \"Images\")\n", "first_pos_files = myutils.listdir(first_pos_dir)\n", "channels, warn = cca_functions.find_available_channels(first_pos_files, first_pos_dir)" ] @@ -1401,13 +1594,9 @@ "outputs": [], "source": [ "overall_df, is_timelapse_data, is_zstack_data = cca_functions.calculate_downstream_data(\n", - " file_names,\n", - " image_folders,\n", - " positions,\n", - " channels, \n", - " force_recalculation=False\n", + " file_names, image_folders, positions, channels, force_recalculation=False\n", ")\n", - "#overall_df.to_csv(os.path.join(data_dir, 'raw_downstream_data_fig4_v2.csv'), index=False)" + "# overall_df.to_csv(os.path.join(data_dir, 'raw_downstream_data_fig4_v2.csv'), index=False)" ] }, { @@ -1433,8 +1622,8 @@ }, "outputs": [], "source": [ - "data_dir = os.path.join('..', 'data', 'paper_plot_data')\n", - "overall_df = pd.read_csv(os.path.join(data_dir, 'raw_downstream_data_fig5_v2.csv'))" + "data_dir = os.path.join(\"..\", \"data\", \"paper_plot_data\")\n", + "overall_df = pd.read_csv(os.path.join(data_dir, \"raw_downstream_data_fig5_v2.csv\"))" ] }, { @@ -1463,17 +1652,24 @@ "overall_df_with_rel = cca_functions.calculate_relatives_data(overall_df, channels)\n", "# If working with timelapse data build dataframe grouped by phases\n", "group_cols = [\n", - " 'Cell_ID', 'generation_num', 'cell_cycle_stage', 'relationship', 'position', 'file', \n", - " 'max_frame_pos', 'selection_subset', 'max_t'\n", + " \"Cell_ID\",\n", + " \"generation_num\",\n", + " \"cell_cycle_stage\",\n", + " \"relationship\",\n", + " \"position\",\n", + " \"file\",\n", + " \"max_frame_pos\",\n", + " \"selection_subset\",\n", + " \"max_t\",\n", "]\n", "# calculate data grouped by phase only in the case, that timelapse data is available\n", "if is_timelapse_data:\n", - " phase_grouped = cca_functions.calculate_per_phase_quantities(overall_df_with_rel, group_cols, channels)\n", + " phase_grouped = cca_functions.calculate_per_phase_quantities(\n", + " overall_df_with_rel, group_cols, channels\n", + " )\n", " # append phase-grouped data to overall_df\n", " overall_df_with_rel = overall_df_with_rel.merge(\n", - " phase_grouped,\n", - " how='left',\n", - " on=group_cols\n", + " phase_grouped, how=\"left\", on=group_cols\n", " )" ] }, @@ -1518,7 +1714,7 @@ "# some configurations\n", "# frame interval of video\n", "frame_interval_minutes = 3\n", - "# quantiles of complete cell cycles (wrt phase lengths) to exclude from analysis \n", + "# quantiles of complete cell cycles (wrt phase lengths) to exclude from analysis\n", "# (not used, keep this for potential later use)\n", "down_q, upper_q = 0, 1\n", "# minimum number of cell cycles contributing to the mean+CI curve:\n", @@ -1530,104 +1726,131 @@ "\n", "# select needed cols from overall_df_with_rel to not end up with too many columns\n", "needed_cols = [\n", - " 'selection_subset', 'position', 'Cell_ID', 'cell_cycle_stage', 'generation_num', 'frame_i',\n", - " 'mCitrine_corrected_amount', 'mCitrine_corrected_amount_rel', \n", - " 'file', 'relationship', 'relative_ID', 'phase_length', 'phase_begin', 'gui_mCitrine_amount_autoBkgr'\n", + " \"selection_subset\",\n", + " \"position\",\n", + " \"Cell_ID\",\n", + " \"cell_cycle_stage\",\n", + " \"generation_num\",\n", + " \"frame_i\",\n", + " \"mCitrine_corrected_amount\",\n", + " \"mCitrine_corrected_amount_rel\",\n", + " \"file\",\n", + " \"relationship\",\n", + " \"relative_ID\",\n", + " \"phase_length\",\n", + " \"phase_begin\",\n", + " \"gui_mCitrine_amount_autoBkgr\",\n", "]\n", - "filter_idx = np.logical_and(overall_df_with_rel['complete_cycle'] == 1, overall_df_with_rel.selection_subset==0)\n", + "filter_idx = np.logical_and(\n", + " overall_df_with_rel[\"complete_cycle\"] == 1,\n", + " overall_df_with_rel.selection_subset == 0,\n", + ")\n", "plot_data5a = overall_df_with_rel.loc[filter_idx, needed_cols].copy()\n", "# calculate the time the cell already spent in the current frame at the current timepoint\n", - "plot_data5a['frames_in_phase'] = plot_data5a['frame_i'] - plot_data5a['phase_begin'] + 1\n", - "# calculate the time to the next (for G1 cells) and from the last (for S cells) G1/S transition \n", - "plot_data5a['centered_frames_in_phase'] = plot_data5a.apply(\n", - " lambda x: x.loc['frames_in_phase'] if\\\n", - " x.loc['cell_cycle_stage']=='S' else\\\n", - " x.loc['frames_in_phase']-1-x.loc['phase_length'],\n", - " axis=1\n", + "plot_data5a[\"frames_in_phase\"] = plot_data5a[\"frame_i\"] - plot_data5a[\"phase_begin\"] + 1\n", + "# calculate the time to the next (for G1 cells) and from the last (for S cells) G1/S transition\n", + "plot_data5a[\"centered_frames_in_phase\"] = plot_data5a.apply(\n", + " lambda x: (\n", + " x.loc[\"frames_in_phase\"]\n", + " if x.loc[\"cell_cycle_stage\"] == \"S\"\n", + " else x.loc[\"frames_in_phase\"] - 1 - x.loc[\"phase_length\"]\n", + " ),\n", + " axis=1,\n", ")\n", "# calculate combined signal and the \"Pool, Phase ID\" for the legend\n", - "# plot_data5a at this point only contains relationship==mother, \n", + "# plot_data5a at this point only contains relationship==mother,\n", "# as generation_num==0 and relationship==bud are filtered out (incomplete cycle, cycles start with G1)\n", - "plot_data5a['Combined signal'] = plot_data5a.apply(\n", - " lambda x: x.loc['mCitrine_corrected_amount']+x.loc['mCitrine_corrected_amount_rel'] if\\\n", - " x.loc['cell_cycle_stage']=='S' and x.loc['relationship'] == 'mother' else\\\n", - " x.loc['mCitrine_corrected_amount'],\n", - " axis=1\n", + "plot_data5a[\"Combined signal\"] = plot_data5a.apply(\n", + " lambda x: (\n", + " x.loc[\"mCitrine_corrected_amount\"] + x.loc[\"mCitrine_corrected_amount_rel\"]\n", + " if x.loc[\"cell_cycle_stage\"] == \"S\" and x.loc[\"relationship\"] == \"mother\"\n", + " else x.loc[\"mCitrine_corrected_amount\"]\n", + " ),\n", + " axis=1,\n", ")\n", - "plot_data5a['Bud signal'] = plot_data5a.apply(\n", - " lambda x: x.loc['mCitrine_corrected_amount_rel'] if\\\n", - " x.loc['cell_cycle_stage']=='S' and x.loc['relationship'] == 'mother' else 0,\n", - " axis=1\n", + "plot_data5a[\"Bud signal\"] = plot_data5a.apply(\n", + " lambda x: (\n", + " x.loc[\"mCitrine_corrected_amount_rel\"]\n", + " if x.loc[\"cell_cycle_stage\"] == \"S\" and x.loc[\"relationship\"] == \"mother\"\n", + " else 0\n", + " ),\n", + " axis=1,\n", ")\n", "# scale data if needed\n", "if scale_data:\n", - " maximum = max(\n", - " plot_data5a['Combined signal'].max(), \n", - " plot_data5a['Bud signal'].max()\n", - " )\n", - " plot_data5a['Combined signal'] /= maximum\n", - " plot_data5a['Bud signal'] /= maximum\n", + " maximum = max(plot_data5a[\"Combined signal\"].max(), plot_data5a[\"Bud signal\"].max())\n", + " plot_data5a[\"Combined signal\"] /= maximum\n", + " plot_data5a[\"Bud signal\"] /= maximum\n", "# calculate min and max centered times per generation to eliminate up to a percentile\n", "# (not used, as upper_q and lower_q are set to 100/0 respectively)\n", - "plot_data5a['min_centered_frames'] = plot_data5a.groupby(\n", - " ['position', 'file', 'Cell_ID', 'generation_num']\n", - ")['centered_frames_in_phase'].transform(\n", - " 'min'\n", - ")\n", - "plot_data5a['max_centered_frames'] = plot_data5a.groupby(\n", - " ['position', 'file', 'Cell_ID', 'generation_num']\n", - ")['centered_frames_in_phase'].transform(\n", - " 'max'\n", - ")\n", - "min_and_max = plot_data5a.groupby(\n", - " ['Cell_ID', 'generation_num', 'position', 'file']\n", - ").agg(\n", - " min_centered = ('min_centered_frames', 'first'),\n", - " max_centered = ('max_centered_frames', 'first')\n", - ").reset_index()\n", - "min_val, max_val = np.quantile(\n", - " min_and_max.min_centered, down_q\n", - ") * frame_interval_minutes, np.quantile(\n", - " min_and_max.max_centered, upper_q\n", - ") * frame_interval_minutes\n", + "plot_data5a[\"min_centered_frames\"] = plot_data5a.groupby(\n", + " [\"position\", \"file\", \"Cell_ID\", \"generation_num\"]\n", + ")[\"centered_frames_in_phase\"].transform(\"min\")\n", + "plot_data5a[\"max_centered_frames\"] = plot_data5a.groupby(\n", + " [\"position\", \"file\", \"Cell_ID\", \"generation_num\"]\n", + ")[\"centered_frames_in_phase\"].transform(\"max\")\n", + "min_and_max = (\n", + " plot_data5a.groupby([\"Cell_ID\", \"generation_num\", \"position\", \"file\"])\n", + " .agg(\n", + " min_centered=(\"min_centered_frames\", \"first\"),\n", + " max_centered=(\"max_centered_frames\", \"first\"),\n", + " )\n", + " .reset_index()\n", + ")\n", + "min_val, max_val = (\n", + " np.quantile(min_and_max.min_centered, down_q) * frame_interval_minutes,\n", + " np.quantile(min_and_max.max_centered, upper_q) * frame_interval_minutes,\n", + ")\n", "# perform selection (won't change anything if upper and lower are 100 and 0 respectively)\n", "selection_indices = np.logical_and(\n", - " plot_data5a.min_centered_frames*frame_interval_minutes>=min_val, \n", - " plot_data5a.max_centered_frames*frame_interval_minutes<=max_val\n", + " plot_data5a.min_centered_frames * frame_interval_minutes >= min_val,\n", + " plot_data5a.max_centered_frames * frame_interval_minutes <= max_val,\n", ")\n", "plot_data5a = plot_data5a[selection_indices]\n", "\n", "# calculate centered time in minutes\n", - "plot_data5a['centered_time_in_minutes'] = plot_data5a.centered_frames_in_phase * frame_interval_minutes\n", + "plot_data5a[\"centered_time_in_minutes\"] = (\n", + " plot_data5a.centered_frames_in_phase * frame_interval_minutes\n", + ")\n", "\n", "# group dataframe to calculate sample sizes per generation\n", - "standard_grouped = plot_data5a.groupby(\n", - " ['position', 'file', 'Cell_ID', 'generation_num']\n", - ").agg('count').reset_index()\n", - "plot_data5a['Generation'] = plot_data5a.apply(\n", - " lambda x: f'1st ($n_1$={len(standard_grouped[standard_grouped.generation_num==1])})' if\\\n", - " x.loc['generation_num']==1 else f'2+ ($n_2$={len(standard_grouped[standard_grouped.generation_num>1])})',\n", - " axis=1\n", + "standard_grouped = (\n", + " plot_data5a.groupby([\"position\", \"file\", \"Cell_ID\", \"generation_num\"])\n", + " .agg(\"count\")\n", + " .reset_index()\n", + ")\n", + "plot_data5a[\"Generation\"] = plot_data5a.apply(\n", + " lambda x: (\n", + " f\"1st ($n_1$={len(standard_grouped[standard_grouped.generation_num == 1])})\"\n", + " if x.loc[\"generation_num\"] == 1\n", + " else f\"2+ ($n_2$={len(standard_grouped[standard_grouped.generation_num > 1])})\"\n", + " ),\n", + " axis=1,\n", ")\n", "if split_by_gen:\n", - " g_cols = ['centered_frames_in_phase', 'Generation']\n", + " g_cols = [\"centered_frames_in_phase\", \"Generation\"]\n", "else:\n", - " g_cols = 'centered_frames_in_phase'\n", - "plot_data5a['contributing_ccs_at_time'] = plot_data5a.groupby(g_cols).transform('count')['selection_subset']\n", + " g_cols = \"centered_frames_in_phase\"\n", + "plot_data5a[\"contributing_ccs_at_time\"] = plot_data5a.groupby(g_cols).transform(\n", + " \"count\"\n", + ")[\"selection_subset\"]\n", "plot_data5a = plot_data5a[plot_data5a.contributing_ccs_at_time >= min_no_of_ccs]\n", "\n", "# finally prepare data for plot (use melt for multiple lines)\n", "sample_size_5a = len(standard_grouped)\n", - "avg_cell_cycle_length = round(standard_grouped.loc[:,'centered_time_in_minutes'].mean())*frame_interval_minutes\n", - "cols_to_plot = ['Bud signal', 'Combined signal']\n", + "avg_cell_cycle_length = (\n", + " round(standard_grouped.loc[:, \"centered_time_in_minutes\"].mean())\n", + " * frame_interval_minutes\n", + ")\n", + "cols_to_plot = [\"Bud signal\", \"Combined signal\"]\n", "index_cols = [col for col in plot_data5a.columns if col not in cols_to_plot]\n", "plot_data5a_melted = pd.melt(\n", - " plot_data5a, index_cols, var_name='Method of calculation'\n", - ").sort_values('Method of calculation')\n", - "data_dir = os.path.join('..', 'data', 'paper_plot_data')\n", + " plot_data5a, index_cols, var_name=\"Method of calculation\"\n", + ").sort_values(\"Method of calculation\")\n", + "data_dir = os.path.join(\"..\", \"data\", \"paper_plot_data\")\n", "# save preprocessed data for Fig. 5A\n", - "#plot_data5a_melted.to_csv(os.path.join(data_dir, 'plot_data5a_melted_v2.csv'), index=False)\n", - "#plot_data5a.to_csv(os.path.join(data_dir, 'plot_data5a_v2.csv'), index=False)" + "# plot_data5a_melted.to_csv(os.path.join(data_dir, 'plot_data5a_melted_v2.csv'), index=False)\n", + "# plot_data5a.to_csv(os.path.join(data_dir, 'plot_data5a_v2.csv'), index=False)" ] }, { @@ -1656,33 +1879,42 @@ "sns.set_theme(style=\"darkgrid\", font_scale=1.6)\n", "f, ax = plt.subplots(figsize=(15, 12))\n", "if split_by_gen:\n", - " style='Generation'\n", + " style = \"Generation\"\n", "else:\n", - " style=None\n", + " style = None\n", "ax = sns.lineplot(\n", - " data=plot_data5a_melted,#.sort_values('Pool, Phase'),\n", - " x=\"centered_time_in_minutes\", \n", + " data=plot_data5a_melted, # .sort_values('Pool, Phase'),\n", + " x=\"centered_time_in_minutes\",\n", " y=\"value\",\n", - " hue='Method of calculation',\n", + " hue=\"Method of calculation\",\n", " style=style,\n", - " #style='position',\n", - " ci=95\n", + " # style='position',\n", + " ci=95,\n", ")\n", - "ax.axvline(x=0, color='red')#, label='Time of Bud Emergence')\n", + "ax.axvline(x=0, color=\"red\") # , label='Time of Bud Emergence')\n", "ax.text(\n", - " 0.5, 100000, \"Time of \\nBud Emergence\", horizontalalignment='left', \n", - " size='medium', color='red', weight='normal'\n", + " 0.5,\n", + " 100000,\n", + " \"Time of \\nBud Emergence\",\n", + " horizontalalignment=\"left\",\n", + " size=\"medium\",\n", + " color=\"red\",\n", + " weight=\"normal\",\n", ")\n", "ax.legend(\n", - " #title=f'Avg CC Length: {avg_cell_cycle_length} min, n = {sample_size_5a}', \n", + " # title=f'Avg CC Length: {avg_cell_cycle_length} min, n = {sample_size_5a}',\n", " fancybox=True,\n", " labelspacing=0.5,\n", " handlelength=1.5,\n", - " loc = 'upper left'\n", + " loc=\"upper left\",\n", + ")\n", + "ax.set_ylabel(\n", + " \"Total amount of Htb1-mCitrine corrected by background [a.u.]\", fontsize=20\n", ")\n", - "ax.set_ylabel(\"Total amount of Htb1-mCitrine corrected by background [a.u.]\", fontsize=20)\n", "ax.set_xlabel(\"Time in phase relative to G1/S transition [minutes]\", fontsize=20)\n", - "ax.set_title(\"Corrected Htb1-mCitrine Amount during Cell Cycle Progression\", fontsize=30)\n", + "ax.set_title(\n", + " \"Corrected Htb1-mCitrine Amount during Cell Cycle Progression\", fontsize=30\n", + ")\n", "plt.tight_layout()\n", "\"\"\"\n", "plt.savefig(os.path.join('..', 'figures', 'new_fig5', 'mCitrine_over_time_by_gen_v6.svg'))\n", @@ -1713,14 +1945,22 @@ }, "outputs": [], "source": [ - "# obtain table where one cell cycle is represented by one row: \n", + "# obtain table where one cell cycle is represented by one row:\n", "# first set of columns (like phase_length, growth...) for G1, second set of cols for S\n", "needed_cols = [\n", - " 'Cell_ID', 'generation_num', 'position', 'file', 'cell_cycle_stage', 'selection_subset', \n", - " 'phase_volume_at_beginning', 'phase_volume_at_end', 'phase_mCitrine_amount_at_beginning',\n", - " 'phase_mCitrine_combined_amount_at_end', 'phase_combined_volume_at_end'\n", + " \"Cell_ID\",\n", + " \"generation_num\",\n", + " \"position\",\n", + " \"file\",\n", + " \"cell_cycle_stage\",\n", + " \"selection_subset\",\n", + " \"phase_volume_at_beginning\",\n", + " \"phase_volume_at_end\",\n", + " \"phase_mCitrine_amount_at_beginning\",\n", + " \"phase_mCitrine_combined_amount_at_end\",\n", + " \"phase_combined_volume_at_end\",\n", "]\n", - "plot_data5b = phase_grouped.loc[phase_grouped.all_complete==1, needed_cols]\n", + "plot_data5b = phase_grouped.loc[phase_grouped.all_complete == 1, needed_cols]\n", "scale_data = False" ] }, @@ -1736,47 +1976,54 @@ }, "outputs": [], "source": [ - "plot_data5b['relevant_volume'] = plot_data5b.apply(\n", - " lambda x: x.loc['phase_volume_at_beginning'] if\\\n", - " x.loc['cell_cycle_stage']=='G1' else\\\n", - " x.loc['phase_combined_volume_at_end'],\n", - " axis=1\n", - ")\n", - "plot_data5b['relevant_amount'] = plot_data5b.apply(\n", - " lambda x: x.loc['phase_mCitrine_amount_at_beginning'] if\\\n", - " x.loc['cell_cycle_stage']=='G1' else\\\n", - " x.loc['phase_mCitrine_combined_amount_at_end'],\n", - " axis=1\n", - ")\n", - "# defining a function to generate entries for the figure legend \n", + "plot_data5b[\"relevant_volume\"] = plot_data5b.apply(\n", + " lambda x: (\n", + " x.loc[\"phase_volume_at_beginning\"]\n", + " if x.loc[\"cell_cycle_stage\"] == \"G1\"\n", + " else x.loc[\"phase_combined_volume_at_end\"]\n", + " ),\n", + " axis=1,\n", + ")\n", + "plot_data5b[\"relevant_amount\"] = plot_data5b.apply(\n", + " lambda x: (\n", + " x.loc[\"phase_mCitrine_amount_at_beginning\"]\n", + " if x.loc[\"cell_cycle_stage\"] == \"G1\"\n", + " else x.loc[\"phase_mCitrine_combined_amount_at_end\"]\n", + " ),\n", + " axis=1,\n", + ")\n", + "\n", + "\n", + "# defining a function to generate entries for the figure legend\n", "# (assuming that selection_subset>0 is the autofluorescence control of the experiment)\n", "def calc_legend_entry(x):\n", - " if x.loc['selection_subset'] == 0:\n", - " if x.loc['cell_cycle_stage']=='G1':\n", - " return 'At G1-entry'\n", + " if x.loc[\"selection_subset\"] == 0:\n", + " if x.loc[\"cell_cycle_stage\"] == \"G1\":\n", + " return \"At G1-entry\"\n", " else:\n", - " return 'Mother+bud at cytokinesis'\n", + " return \"Mother+bud at cytokinesis\"\n", " else:\n", - " if x.loc['cell_cycle_stage']=='G1':\n", - " return 'AF control at G1-entry'\n", + " if x.loc[\"cell_cycle_stage\"] == \"G1\":\n", + " return \"AF control at G1-entry\"\n", " else:\n", - " return 'AF control, m+b at cytokinesis'\n", - " \n", - "plot_data5b['Kind of Measurement new'] = plot_data5b.apply(\n", - " calc_legend_entry,\n", - " axis=1\n", - ")\n", - "plot_data5b['Generation'] = plot_data5b.apply(\n", - " lambda x: f'1st ($n_1$={int(len(plot_data5b[plot_data5b.generation_num==1])/2)})' if\\\n", - " x.loc['generation_num']==1 else f'2+ ($n_2$={int(len(plot_data5b[plot_data5b.generation_num>1])/2)})',\n", - " axis=1\n", + " return \"AF control, m+b at cytokinesis\"\n", + "\n", + "\n", + "plot_data5b[\"Kind of Measurement new\"] = plot_data5b.apply(calc_legend_entry, axis=1)\n", + "plot_data5b[\"Generation\"] = plot_data5b.apply(\n", + " lambda x: (\n", + " f\"1st ($n_1$={int(len(plot_data5b[plot_data5b.generation_num == 1]) / 2)})\"\n", + " if x.loc[\"generation_num\"] == 1\n", + " else f\"2+ ($n_2$={int(len(plot_data5b[plot_data5b.generation_num > 1]) / 2)})\"\n", + " ),\n", + " axis=1,\n", ")\n", "if scale_data:\n", - " maximum = plot_data5b['relevant_amount'].max()\n", - " plot_data5b['relevant_amount'] /= maximum\n", + " maximum = plot_data5b[\"relevant_amount\"].max()\n", + " plot_data5b[\"relevant_amount\"] /= maximum\n", "sample_size_5b = len(plot_data5b)\n", - "data_dir = os.path.join('..', 'data', 'paper_plot_data')\n", - "#plot_data5b.to_csv(os.path.join(data_dir, 'plot_data5b_v2.csv'), index=False)" + "data_dir = os.path.join(\"..\", \"data\", \"paper_plot_data\")\n", + "# plot_data5b.to_csv(os.path.join(data_dir, 'plot_data5b_v2.csv'), index=False)" ] }, { @@ -1801,93 +2048,98 @@ }, "outputs": [], "source": [ - "#plot_data5b = plot_data5b[plot_data5b.selection_subset==1]\n", + "# plot_data5b = plot_data5b[plot_data5b.selection_subset==1]\n", "sns.set_theme(style=\"whitegrid\", font_scale=1.6)\n", "# Initialize the figure\n", "sns.lmplot(\n", - " x=\"relevant_volume\", \n", - " y=\"relevant_amount\", \n", - " data=plot_data5b.sort_values(\n", - " 'Kind of Measurement new', ascending=False\n", - " ),\n", + " x=\"relevant_volume\",\n", + " y=\"relevant_amount\",\n", + " data=plot_data5b.sort_values(\"Kind of Measurement new\", ascending=False),\n", " hue=\"Kind of Measurement new\",\n", " legend=False,\n", - " #style=\"generation_num\",\n", - " #row=\"selection_subset\",\n", - " #sharex=False,\n", + " # style=\"generation_num\",\n", + " # row=\"selection_subset\",\n", + " # sharex=False,\n", " height=10,\n", " aspect=1.1,\n", - " scatter=False\n", + " scatter=False,\n", ")\n", "\n", "sns.scatterplot(\n", - " x=\"relevant_volume\", \n", - " y=\"relevant_amount\", \n", - " data=plot_data5b[plot_data5b.generation_num==1].sort_values(\n", - " 'Kind of Measurement new', ascending=False\n", + " x=\"relevant_volume\",\n", + " y=\"relevant_amount\",\n", + " data=plot_data5b[plot_data5b.generation_num == 1].sort_values(\n", + " \"Kind of Measurement new\", ascending=False\n", " ),\n", " hue=\"Kind of Measurement new\",\n", " legend=False,\n", - " marker='x'\n", + " marker=\"x\",\n", ")\n", "\n", "sns.scatterplot(\n", - " x=\"relevant_volume\", \n", - " y=\"relevant_amount\", \n", - " data=plot_data5b[plot_data5b.generation_num>1].sort_values(\n", - " 'Kind of Measurement new', ascending=False\n", + " x=\"relevant_volume\",\n", + " y=\"relevant_amount\",\n", + " data=plot_data5b[plot_data5b.generation_num > 1].sort_values(\n", + " \"Kind of Measurement new\", ascending=False\n", " ),\n", " hue=\"Kind of Measurement new\",\n", " legend=False,\n", - " marker='o'\n", + " marker=\"o\",\n", ")\n", "\n", - "#g._legend.set_title('Kind of Measurement')\n", + "# g._legend.set_title('Kind of Measurement')\n", "ax = plt.gca()\n", - "#ax.set(yscale=\"log2\")\n", - "#ax.set_yscale('log', basey=2)\n", - "#ax.set_xscale('log', basex=10)\n", + "# ax.set(yscale=\"log2\")\n", + "# ax.set_yscale('log', basey=2)\n", + "# ax.set_xscale('log', basex=10)\n", "labels = [\n", - " 'Single cell at G1-entry',\n", - " 'Mother&bud at cytokinesis',\n", - " 'Af control, single cell at G1-entry',\n", - " 'Af control, combined mother&bud at cytokinesis',\n", - " 'Generation 1',\n", - " 'Generation 2+'\n", + " \"Single cell at G1-entry\",\n", + " \"Mother&bud at cytokinesis\",\n", + " \"Af control, single cell at G1-entry\",\n", + " \"Af control, combined mother&bud at cytokinesis\",\n", + " \"Generation 1\",\n", + " \"Generation 2+\",\n", "]\n", "handles = [\n", " mpatches.Patch(color=sns.color_palette()[0]),\n", " mpatches.Patch(color=sns.color_palette()[1]),\n", " mpatches.Patch(color=sns.color_palette()[2]),\n", " mpatches.Patch(color=sns.color_palette()[3]),\n", - " mlines.Line2D([], [], color='gray', marker='x', linestyle='None',\n", - " markersize=10),\n", - " mlines.Line2D([], [], color='gray', marker='o', linestyle='None',\n", - " markersize=10)\n", + " mlines.Line2D([], [], color=\"gray\", marker=\"x\", linestyle=\"None\", markersize=10),\n", + " mlines.Line2D([], [], color=\"gray\", marker=\"o\", linestyle=\"None\", markersize=10),\n", "]\n", "ax.legend(\n", " handles=handles,\n", - " labels=labels, \n", - " loc='center right',\n", - " bbox_to_anchor = (1,0.2),\n", - " framealpha=0.5\n", + " labels=labels,\n", + " loc=\"center right\",\n", + " bbox_to_anchor=(1, 0.2),\n", + " framealpha=0.5,\n", ")\n", "ax.set_ylabel(\"Amount of Htb1-mCitrine in Cell(s) [a.u.]\", fontsize=20)\n", - "ax.set_xlabel(\"Volume at G1-entry / Combined Volume Before Cytokinesis [fL]\", fontsize=20)\n", - "ax.set_title(f\"Volume at G1-entry vs Htb1-mCitrine Amount (n={int(sample_size_5b/2)})\", fontsize=30)\n", + "ax.set_xlabel(\n", + " \"Volume at G1-entry / Combined Volume Before Cytokinesis [fL]\", fontsize=20\n", + ")\n", + "ax.set_title(\n", + " f\"Volume at G1-entry vs Htb1-mCitrine Amount (n={int(sample_size_5b / 2)})\",\n", + " fontsize=30,\n", + ")\n", "# format y-axis\n", - "plt.ticklabel_format(axis='y', style='sci', scilimits=(0,0), useMathText=True)\n", - "ax.get_yaxis().get_offset_text().set_position((-0.05,0))\n", + "plt.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0, 0), useMathText=True)\n", + "ax.get_yaxis().get_offset_text().set_position((-0.05, 0))\n", "# format x-axis\n", - "ax.set_xlim(0, plot_data5b.relevant_volume.max()+20)\n", + "ax.set_xlim(0, plot_data5b.relevant_volume.max() + 20)\n", "plt.tight_layout()\n", "\"\"\"\n", "plt.savefig(os.path.join('..', 'figures', 'new_fig5', 'mCitrine_at_birth_and_cytokinesis_v6.png'), dpi=300)\n", "plt.savefig(os.path.join('..', 'figures', 'new_fig5', 'mCitrine_at_birth_and_cytokinesis_v6.svg'))\n", "\"\"\"\n", "plt.show()\n", - "print(f'sample size flu-control: {len(plot_data5b[plot_data5b.selection_subset==1])//2}')\n", - "print(f'sample size tagged strain: {len(plot_data5b[plot_data5b.selection_subset==0])//2}')" + "print(\n", + " f\"sample size flu-control: {len(plot_data5b[plot_data5b.selection_subset == 1]) // 2}\"\n", + ")\n", + "print(\n", + " f\"sample size tagged strain: {len(plot_data5b[plot_data5b.selection_subset == 0]) // 2}\"\n", + ")" ] }, { @@ -1909,21 +2161,21 @@ "source": [ "# will show up at x=1 --> later mother cells at their own birth\n", "mothers_at_birth = overall_df_with_rel[\n", - " (overall_df_with_rel.generation_num==1) & \n", - " (overall_df_with_rel.cell_cycle_stage=='G1') & \n", - " (overall_df_with_rel.frame_i==overall_df_with_rel.phase_begin) & \n", - " (overall_df_with_rel.is_history_known) &\n", - " (overall_df_with_rel.file=='MIA_KC_htb1_mCitrine_labeled') &\n", - " (~overall_df_with_rel.is_cell_excluded)\n", + " (overall_df_with_rel.generation_num == 1)\n", + " & (overall_df_with_rel.cell_cycle_stage == \"G1\")\n", + " & (overall_df_with_rel.frame_i == overall_df_with_rel.phase_begin)\n", + " & (overall_df_with_rel.is_history_known)\n", + " & (overall_df_with_rel.file == \"MIA_KC_htb1_mCitrine_labeled\")\n", + " & (~overall_df_with_rel.is_cell_excluded)\n", "]\n", "# will show up at x>1 --> mother cells now dividing from their own daughter cell the first (gen=2), second (gen=3),... time\n", "mothers_at_division = overall_df_with_rel[\n", - " (overall_df_with_rel.generation_num>1) & \n", - " (overall_df_with_rel.cell_cycle_stage=='G1') & \n", - " (overall_df_with_rel.frame_i==overall_df_with_rel.division_frame_i) & \n", - " (overall_df_with_rel.is_history_known) &\n", - " (overall_df_with_rel.file=='MIA_KC_htb1_mCitrine_labeled') &\n", - " (~overall_df_with_rel.is_cell_excluded)\n", + " (overall_df_with_rel.generation_num > 1)\n", + " & (overall_df_with_rel.cell_cycle_stage == \"G1\")\n", + " & (overall_df_with_rel.frame_i == overall_df_with_rel.division_frame_i)\n", + " & (overall_df_with_rel.is_history_known)\n", + " & (overall_df_with_rel.file == \"MIA_KC_htb1_mCitrine_labeled\")\n", + " & (~overall_df_with_rel.is_cell_excluded)\n", "]" ] }, @@ -1934,13 +2186,27 @@ "metadata": {}, "outputs": [], "source": [ - "mothers_df = pd.concat([mothers_at_division,mothers_at_birth], ignore_index=True)\n", - "mothers_df['pos_cell_id'] = mothers_df.apply(lambda x: f'cell_{x.loc[\"Cell_ID\"]}_{x.loc[\"position\"]}', axis=1)\n", + "mothers_df = pd.concat([mothers_at_division, mothers_at_birth], ignore_index=True)\n", + "mothers_df[\"pos_cell_id\"] = mothers_df.apply(\n", + " lambda x: f\"cell_{x.loc['Cell_ID']}_{x.loc['position']}\", axis=1\n", + ")\n", "# calculate number of cells per generation\n", "gen_counter = Counter(mothers_df.generation_num)\n", - "mothers_df['x_label'] = mothers_df.generation_num.apply(lambda x: f'{int(x)} (n={gen_counter[x]})')\n", - "mothers_df = mothers_df[['frame_i', 'Cell_ID', 'file', 'position', 'x_label', 'mCitrine_corrected_amount', 'mCitrine_corrected_concentration']].sort_values('x_label')\n", - "#mothers_df.to_csv(os.path.join(data_dir, 'plot_data5c.csv'), index=False)" + "mothers_df[\"x_label\"] = mothers_df.generation_num.apply(\n", + " lambda x: f\"{int(x)} (n={gen_counter[x]})\"\n", + ")\n", + "mothers_df = mothers_df[\n", + " [\n", + " \"frame_i\",\n", + " \"Cell_ID\",\n", + " \"file\",\n", + " \"position\",\n", + " \"x_label\",\n", + " \"mCitrine_corrected_amount\",\n", + " \"mCitrine_corrected_concentration\",\n", + " ]\n", + "].sort_values(\"x_label\")\n", + "# mothers_df.to_csv(os.path.join(data_dir, 'plot_data5c.csv'), index=False)" ] }, { @@ -1961,39 +2227,43 @@ "outputs": [], "source": [ "sns.set_theme(style=\"whitegrid\", font_scale=1.6)\n", - "fig, ax = plt.subplots(figsize=(10,10))\n", + "fig, ax = plt.subplots(figsize=(10, 10))\n", "sns.boxplot(\n", " data=mothers_df,\n", - " x='x_label',\n", - " y='mCitrine_corrected_concentration',\n", - " palette='vlag',\n", + " x=\"x_label\",\n", + " y=\"mCitrine_corrected_concentration\",\n", + " palette=\"vlag\",\n", " fliersize=0,\n", - " ax=ax\n", + " ax=ax,\n", ")\n", "\n", - "#add stripplot on top\n", + "# add stripplot on top\n", "sns.stripplot(\n", " data=mothers_df,\n", - " x='x_label',\n", - " y='mCitrine_corrected_concentration',\n", + " x=\"x_label\",\n", + " y=\"mCitrine_corrected_concentration\",\n", " color=\".3\",\n", - " ax=ax\n", + " ax=ax,\n", ")\n", "\n", "# switch to scientific number format on y-Axis and move text\n", - "ax.ticklabel_format(axis='y', style='sci', scilimits=(0,0), useMathText=True)\n", - "ax.get_yaxis().get_offset_text().set_position((-0.05,0))\n", + "ax.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0, 0), useMathText=True)\n", + "ax.get_yaxis().get_offset_text().set_position((-0.05, 0))\n", "\n", "# Rename axes and set title\n", - "ax.set_ylabel(\"Htb1-mCitrine amount per volume in mother cell at division [a.u.]\", fontsize=20)\n", + "ax.set_ylabel(\n", + " \"Htb1-mCitrine amount per volume in mother cell at division [a.u.]\", fontsize=20\n", + ")\n", "ax.set_xlabel(\"Generation\", fontsize=20)\n", - "ax.set_title(f\"Amount per Volume by Generation (n={len(mothers_df)})\", fontsize=25) # changed this from 30 to 25 compared to 5B\n", - "ax.set_ylim(0, mothers_df.mCitrine_corrected_concentration.max()+0.1e4)\n", + "ax.set_title(\n", + " f\"Amount per Volume by Generation (n={len(mothers_df)})\", fontsize=25\n", + ") # changed this from 30 to 25 compared to 5B\n", + "ax.set_ylim(0, mothers_df.mCitrine_corrected_concentration.max() + 0.1e4)\n", "\n", "# save and show\n", "plt.tight_layout()\n", - "#plt.savefig('../figures/generation_plot_v5.svg')#, dpi=300)\n", - "#plt.savefig('../figures/generation_plot_v5.png', dpi=300)\n", + "# plt.savefig('../figures/generation_plot_v5.svg')#, dpi=300)\n", + "# plt.savefig('../figures/generation_plot_v5.png', dpi=300)\n", "plt.show()" ] }, @@ -2014,19 +2284,23 @@ "metadata": {}, "outputs": [], "source": [ - "outliers = mothers_df.loc[mothers_df['mCitrine_corrected_concentration'] > 1e4][['Cell_ID', 'frame_i', 'file', 'position']]\n", - "data_path = f'../data/acdc_test_data/TimeLapse_2D/MIA_KC_htb1_mCitrine_labeled'\n", + "outliers = mothers_df.loc[mothers_df[\"mCitrine_corrected_concentration\"] > 1e4][\n", + " [\"Cell_ID\", \"frame_i\", \"file\", \"position\"]\n", + "]\n", + "data_path = f\"../data/acdc_test_data/TimeLapse_2D/MIA_KC_htb1_mCitrine_labeled\"\n", "for idx, line in outliers.iterrows():\n", - " print(line)# if 'is' in str(v)])\n", - " pos_dir = f'{data_path}/{line[\"position\"]}/Images'\n", - " channel_data, seg_mask, cc_data, metadata, cc_props = cca_functions._load_files(pos_dir, ['phase_contr'])\n", - " plt.figure(figsize=(12,5))\n", + " print(line) # if 'is' in str(v)])\n", + " pos_dir = f\"{data_path}/{line['position']}/Images\"\n", + " channel_data, seg_mask, cc_data, metadata, cc_props = cca_functions._load_files(\n", + " pos_dir, [\"phase_contr\"]\n", + " )\n", + " plt.figure(figsize=(12, 5))\n", " plt.subplot(121)\n", - " plt.title('Phase Contrast')\n", + " plt.title(\"Phase Contrast\")\n", " plt.imshow(channel_data[line[\"frame_i\"]])\n", " plt.subplot(122)\n", - " plt.title('Outlier_cell')\n", - " plt.imshow(seg_mask[line[\"frame_i\"]]==line[\"Cell_ID\"])\n", + " plt.title(\"Outlier_cell\")\n", + " plt.imshow(seg_mask[line[\"frame_i\"]] == line[\"Cell_ID\"])\n", " plt.show()" ] }, diff --git a/notebooks/cell_cycle_analysis.ipynb b/notebooks/cell_cycle_analysis.ipynb index ff0c11706..d4ab46d45 100755 --- a/notebooks/cell_cycle_analysis.ipynb +++ b/notebooks/cell_cycle_analysis.ipynb @@ -20,13 +20,15 @@ "import glob\n", "import numpy as np\n", "import pandas as pd\n", + "\n", "pd.set_option(\"display.max_columns\", 200)\n", "pd.set_option(\"display.max_rows\", 50)\n", - "pd.set_option('display.max_colwidth', 150)\n", + "pd.set_option(\"display.max_colwidth\", 150)\n", "import matplotlib.pyplot as plt\n", "import matplotlib.patches as mpatches\n", "import matplotlib.lines as mlines\n", "import seaborn as sns\n", + "\n", "sns.set_theme()\n", "try:\n", " from cellacdc import cca_functions\n", @@ -34,7 +36,7 @@ "except FileNotFoundError:\n", " # Check if user has developer version --> add the Cell_ACDC/cellacdc\n", " # folder to path and import from thre\n", - " sys.path.insert(0, '../cellacdc/')\n", + " sys.path.insert(0, \"../cellacdc/\")\n", " from cellacdc import cca_functions\n", " from cellacdc import myutils" ] @@ -83,12 +85,17 @@ "source": [ "data_dirs, positions, app = cca_functions.configuration_dialog()\n", "file_names = [os.path.split(path)[-1] for path in data_dirs]\n", - "image_folders = [[os.path.join(data_dir, pos_str, 'Images') for pos_str in pos_list] for pos_list, data_dir in zip(positions, data_dirs)]\n", + "image_folders = [\n", + " [os.path.join(data_dir, pos_str, \"Images\") for pos_str in pos_list]\n", + " for pos_list, data_dir in zip(positions, data_dirs)\n", + "]\n", "# determine available channels based on first(!) position.\n", "# Warn user if one or more of the channels are not available for some positions\n", - "first_pos_dir = os.path.join(data_dirs[0], positions[0][0], 'Images')\n", + "first_pos_dir = os.path.join(data_dirs[0], positions[0][0], \"Images\")\n", "first_pos_files = myutils.listdir(first_pos_dir)\n", - "channels, basename = cca_functions.find_available_channels(first_pos_files, first_pos_dir)\n", + "channels, basename = cca_functions.find_available_channels(\n", + " first_pos_files, first_pos_dir\n", + ")\n", "segm_endname = cca_functions.get_segm_endname(first_pos_dir, basename)" ] }, @@ -143,9 +150,9 @@ " file_names,\n", " image_folders,\n", " positions,\n", - " channels, \n", + " channels,\n", " segm_endname,\n", - " force_recalculation=False\n", + " force_recalculation=False,\n", ")\n", "\"\"\"\n", "overall_df = cca_functions.load_acdc_output_only(\n", @@ -181,24 +188,38 @@ "outputs": [], "source": [ "# if cell cycle annotations were performed in ACDC, extend the dataframe by a join on each cells relative cell\n", - "if 'cell_cycle_stage' in overall_df.columns:\n", + "if \"cell_cycle_stage\" in overall_df.columns:\n", " overall_df_with_rel = cca_functions.calculate_relatives_data(overall_df, channels)\n", "# If working with timelapse data build dataframe grouped by phases\n", "group_cols = [\n", - " 'Cell_ID', 'generation_num', 'cell_cycle_stage', 'relationship', 'position', 'file', \n", - " 'max_frame_pos', 'selection_subset', 'max_t'\n", + " \"Cell_ID\",\n", + " \"generation_num\",\n", + " \"cell_cycle_stage\",\n", + " \"relationship\",\n", + " \"position\",\n", + " \"file\",\n", + " \"max_frame_pos\",\n", + " \"selection_subset\",\n", + " \"max_t\",\n", "]\n", "# calculate data grouped by phase only in the case, that timelapse data is available\n", - "if is_timelapse_data and 'max_t' in overall_df_with_rel.columns:\n", - " phase_grouped = cca_functions.calculate_per_phase_quantities(overall_df_with_rel, group_cols, channels)\n", + "if is_timelapse_data and \"max_t\" in overall_df_with_rel.columns:\n", + " phase_grouped = cca_functions.calculate_per_phase_quantities(\n", + " overall_df_with_rel, group_cols, channels\n", + " )\n", " # append phase-grouped data to overall_df_with_rel\n", " overall_df_with_rel = overall_df_with_rel.merge(\n", - " phase_grouped,\n", - " how='left',\n", - " on=group_cols\n", + " phase_grouped, how=\"left\", on=group_cols\n", " )\n", - " overall_df_with_rel['time_in_phase'] = overall_df_with_rel['frame_i'] - overall_df_with_rel['phase_begin'] + 1\n", - " overall_df_with_rel['time_in_cell_cycle'] = overall_df_with_rel.groupby(['Cell_ID', 'generation_num', 'position', 'file'])['frame_i'].transform('cumcount') + 1" + " overall_df_with_rel[\"time_in_phase\"] = (\n", + " overall_df_with_rel[\"frame_i\"] - overall_df_with_rel[\"phase_begin\"] + 1\n", + " )\n", + " overall_df_with_rel[\"time_in_cell_cycle\"] = (\n", + " overall_df_with_rel.groupby([\"Cell_ID\", \"generation_num\", \"position\", \"file\"])[\n", + " \"frame_i\"\n", + " ].transform(\"cumcount\")\n", + " + 1\n", + " )" ] }, { @@ -258,30 +279,26 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "symm_plot_data = overall_df.copy()\n", - "grouping_cols = [\n", - " 'Cell_ID', \n", - " 'Cell_ID_tree', \n", - " 'parent_ID_tree', \n", - " 'file', \n", - " 'position'\n", - "]\n", - "symm_plot_data['rolling_avg_area'] = (\n", - " symm_plot_data.groupby(grouping_cols)['cell_area_pxl']\n", - " .transform(lambda x: x.rolling(window=3, center=True).mean())\n", - ")\n", - "symm_plot_data = symm_plot_data.sort_values('frame_i')\n", - "cc_frames = symm_plot_data.groupby(grouping_cols).agg(\n", - " root_ID = ('root_ID_tree', lambda x: x.iloc[0]),\n", - " birth_frame = ('frame_i', min),\n", - " cc_end_frame = ('frame_i', max),\n", - " birth_area = ('cell_area_pxl', lambda x: x.iloc[0]),\n", - " mean_area_first_5 = ('cell_area_pxl', lambda x: x.iloc[:5].mean()),\n", - " mean_area_first_10 = ('cell_area_pxl', lambda x: x.iloc[:10].mean()),\n", - " avg_area_first3 = ('rolling_avg_area', lambda x: try_find_entry(x, 1)),\n", - " avg_area_9to11 = ('rolling_avg_area', lambda x: try_find_entry(x, 9))\n", - ").reset_index()\n" + "grouping_cols = [\"Cell_ID\", \"Cell_ID_tree\", \"parent_ID_tree\", \"file\", \"position\"]\n", + "symm_plot_data[\"rolling_avg_area\"] = symm_plot_data.groupby(grouping_cols)[\n", + " \"cell_area_pxl\"\n", + "].transform(lambda x: x.rolling(window=3, center=True).mean())\n", + "symm_plot_data = symm_plot_data.sort_values(\"frame_i\")\n", + "cc_frames = (\n", + " symm_plot_data.groupby(grouping_cols)\n", + " .agg(\n", + " root_ID=(\"root_ID_tree\", lambda x: x.iloc[0]),\n", + " birth_frame=(\"frame_i\", min),\n", + " cc_end_frame=(\"frame_i\", max),\n", + " birth_area=(\"cell_area_pxl\", lambda x: x.iloc[0]),\n", + " mean_area_first_5=(\"cell_area_pxl\", lambda x: x.iloc[:5].mean()),\n", + " mean_area_first_10=(\"cell_area_pxl\", lambda x: x.iloc[:10].mean()),\n", + " avg_area_first3=(\"rolling_avg_area\", lambda x: try_find_entry(x, 1)),\n", + " avg_area_9to11=(\"rolling_avg_area\", lambda x: try_find_entry(x, 9)),\n", + " )\n", + " .reset_index()\n", + ")" ] }, { @@ -291,20 +308,26 @@ "metadata": {}, "outputs": [], "source": [ - "cc_frames['cc_length'] = cc_frames['cc_end_frame'] - cc_frames['birth_frame']\n", + "cc_frames[\"cc_length\"] = cc_frames[\"cc_end_frame\"] - cc_frames[\"birth_frame\"]\n", "# filter for birth_frame>0 (birth of cells present at beginning of experiment cannot be observed)\n", - "cc_frames = cc_frames.loc[cc_frames['birth_frame']>0]\n", + "cc_frames = cc_frames.loc[cc_frames[\"birth_frame\"] > 0]\n", "# calculate last frames per experiment and join this information with data\n", - "last_frames_per_experiment = overall_df.groupby(['file', 'position']).agg(\n", - " last_frame = ('frame_i', max)\n", - ").reset_index()\n", - "cc_frames = pd.merge(cc_frames, last_frames_per_experiment, how='left', on=['position', 'file'])\n", + "last_frames_per_experiment = (\n", + " overall_df.groupby([\"file\", \"position\"])\n", + " .agg(last_frame=(\"frame_i\", max))\n", + " .reset_index()\n", + ")\n", + "cc_frames = pd.merge(\n", + " cc_frames, last_frames_per_experiment, how=\"left\", on=[\"position\", \"file\"]\n", + ")\n", "# filter out rows where cell cycle \"ends\" on last frame (could just be bc of end of experiment)\n", - "cc_frames = cc_frames[cc_frames.last_frame!=cc_frames.cc_end_frame]\n", + "cc_frames = cc_frames[cc_frames.last_frame != cc_frames.cc_end_frame]\n", "# filter out rows with cell cycle length 0 (those are rows representing non-observable S phases)\n", - "cc_frames = cc_frames[cc_frames.cc_length>0]\n", + "cc_frames = cc_frames[cc_frames.cc_length > 0]\n", "# calculate growth at beginning of cell cycle by subtracting sliding avg of first 3 frames from sliding avg within frames 9 to 11\n", - "cc_frames['growth_in_first_10_frames'] = cc_frames.avg_area_9to11 - cc_frames.avg_area_first3" + "cc_frames[\"growth_in_first_10_frames\"] = (\n", + " cc_frames.avg_area_9to11 - cc_frames.avg_area_first3\n", + ")" ] }, { @@ -333,15 +356,11 @@ "metadata": {}, "outputs": [], "source": [ - "plt.figure(figsize=(8,8))\n", - "plt.title('Correlation of cell size at birth and cell cycle length')\n", - "sns.regplot(\n", - " data=cc_frames,\n", - " x = 'birth_area',\n", - " y = 'cc_length'\n", - ")\n", - "plt.ylabel('Length of cell cycle [frames]')\n", - "plt.xlabel('Area of cell at birth [pixels]')\n", + "plt.figure(figsize=(8, 8))\n", + "plt.title(\"Correlation of cell size at birth and cell cycle length\")\n", + "sns.regplot(data=cc_frames, x=\"birth_area\", y=\"cc_length\")\n", + "plt.ylabel(\"Length of cell cycle [frames]\")\n", + "plt.xlabel(\"Area of cell at birth [pixels]\")\n", "plt.show()" ] }, @@ -352,15 +371,11 @@ "metadata": {}, "outputs": [], "source": [ - "plt.figure(figsize=(8,8))\n", - "plt.title('Correlation of cell size at birth and growth in the first 10 frames')\n", - "sns.regplot(\n", - " data=cc_frames,\n", - " x = 'birth_area',\n", - " y = 'growth_in_first_10_frames'\n", - ")\n", - "plt.ylabel('Change of area during 10 first frames [pixels]')\n", - "plt.xlabel('Area of cell at birth [pixels]')\n", + "plt.figure(figsize=(8, 8))\n", + "plt.title(\"Correlation of cell size at birth and growth in the first 10 frames\")\n", + "sns.regplot(data=cc_frames, x=\"birth_area\", y=\"growth_in_first_10_frames\")\n", + "plt.ylabel(\"Change of area during 10 first frames [pixels]\")\n", + "plt.xlabel(\"Area of cell at birth [pixels]\")\n", "plt.show()" ] }, @@ -380,14 +395,10 @@ "metadata": {}, "outputs": [], "source": [ - "plt.figure(figsize=(8,8))\n", - "plt.title('Distribution of cell cycle lengths')\n", - "sns.histplot(\n", - " data=cc_frames,\n", - " x='cc_length',\n", - " bins=10\n", - " )\n", - "plt.xlabel('Length of cell cycle [frames]')\n", + "plt.figure(figsize=(8, 8))\n", + "plt.title(\"Distribution of cell cycle lengths\")\n", + "sns.histplot(data=cc_frames, x=\"cc_length\", bins=10)\n", + "plt.xlabel(\"Length of cell cycle [frames]\")\n", "plt.show()" ] }, @@ -429,8 +440,10 @@ } ], "source": [ - "complete_cc_data = overall_df_with_rel[overall_df_with_rel.complete_cycle==1]\n", - "cc_lengths = complete_cc_data.groupby(['Cell_ID', 'generation_num', 'file', 'position'])['time_in_cell_cycle'].max()\n", + "complete_cc_data = overall_df_with_rel[overall_df_with_rel.complete_cycle == 1]\n", + "cc_lengths = complete_cc_data.groupby(\n", + " [\"Cell_ID\", \"generation_num\", \"file\", \"position\"]\n", + ")[\"time_in_cell_cycle\"].max()\n", "sns.histplot(cc_lengths)\n", "plt.show()" ] @@ -457,22 +470,35 @@ "outputs": [], "source": [ "# set this to match with channel of interest\n", - "ch_name = 'mCitrine'\n", + "ch_name = \"mCitrine\"\n", "# filter for relevant rows (first gen G1 cells)\n", "plot_data7 = overall_df_with_rel[\n", - " (overall_df_with_rel.cell_cycle_stage=='G1') &\n", - " (overall_df_with_rel.generation_num==1) &\n", - " (overall_df_with_rel.is_history_known) &\n", - " (overall_df_with_rel.complete_phase)\n", + " (overall_df_with_rel.cell_cycle_stage == \"G1\")\n", + " & (overall_df_with_rel.generation_num == 1)\n", + " & (overall_df_with_rel.is_history_known)\n", + " & (overall_df_with_rel.complete_phase)\n", "]\n", "# select columns of interest for the plot\n", "plot_data7 = plot_data7[\n", - " ['file', 'position', 'frame_i', 'Cell_ID', 'phase_begin', 'generation_num', f'{ch_name}_corrected_concentration']\n", + " [\n", + " \"file\",\n", + " \"position\",\n", + " \"frame_i\",\n", + " \"Cell_ID\",\n", + " \"phase_begin\",\n", + " \"generation_num\",\n", + " f\"{ch_name}_corrected_concentration\",\n", + " ]\n", "]\n", "# calculate \"time in phase\" column for x-axis\n", - "plot_data7['time_in_phase'] = plot_data7['frame_i'] - plot_data7['phase_begin']\n", + "plot_data7[\"time_in_phase\"] = plot_data7[\"frame_i\"] - plot_data7[\"phase_begin\"]\n", "# calculate a unique cell id accross files by just appending file, pos & cell id\n", - "plot_data7['f_pos_cell_id'] = plot_data7.apply(lambda x: f'{x[\"file\"]}_{x[\"position\"]}_Cell_{x[\"Cell_ID\"]}_Gen_{int(x[\"generation_num\"])}', axis=1)" + "plot_data7[\"f_pos_cell_id\"] = plot_data7.apply(\n", + " lambda x: (\n", + " f\"{x['file']}_{x['position']}_Cell_{x['Cell_ID']}_Gen_{int(x['generation_num'])}\"\n", + " ),\n", + " axis=1,\n", + ")" ] }, { @@ -494,47 +520,47 @@ ], "source": [ "# Generate figures (aggregated, single traces, combined\n", - "sns.set_theme(context='talk', font_scale=1.15)\n", + "sns.set_theme(context=\"talk\", font_scale=1.15)\n", "sns.set_style(\"whitegrid\", {\"grid.color\": \".95\"})\n", - "fig, axs = plt.subplots(ncols=3, figsize=(30,10), sharey=False)\n", + "fig, axs = plt.subplots(ncols=3, figsize=(30, 10), sharey=False)\n", "sns.lineplot(\n", " data=plot_data7,\n", - " x=\"time_in_phase\", \n", - " y=f'{ch_name}_corrected_concentration',\n", + " x=\"time_in_phase\",\n", + " y=f\"{ch_name}_corrected_concentration\",\n", " ci=95,\n", - " ax=axs[0]\n", + " ax=axs[0],\n", ")\n", "sns.lineplot(\n", " data=plot_data7,\n", - " x=\"time_in_phase\", \n", - " y=f'{ch_name}_corrected_concentration',\n", + " x=\"time_in_phase\",\n", + " y=f\"{ch_name}_corrected_concentration\",\n", " estimator=None,\n", - " units='f_pos_cell_id',\n", + " units=\"f_pos_cell_id\",\n", " ax=axs[1],\n", " lw=0.5,\n", - " #alpha=0.5\n", + " # alpha=0.5\n", ")\n", "sns.lineplot(\n", " data=plot_data7,\n", - " x=\"time_in_phase\", \n", - " y=f'{ch_name}_corrected_concentration',\n", + " x=\"time_in_phase\",\n", + " y=f\"{ch_name}_corrected_concentration\",\n", " ci=95,\n", - " ax=axs[2]\n", + " ax=axs[2],\n", ")\n", "sns.lineplot(\n", " data=plot_data7,\n", - " x=\"time_in_phase\", \n", - " y=f'{ch_name}_corrected_concentration',\n", + " x=\"time_in_phase\",\n", + " y=f\"{ch_name}_corrected_concentration\",\n", " estimator=None,\n", - " units='f_pos_cell_id',\n", + " units=\"f_pos_cell_id\",\n", " ax=axs[2],\n", " lw=0.5,\n", - " #alpha=0.5\n", + " # alpha=0.5\n", ")\n", "axs[0].set_ylabel(f\"Amount per Volume of {ch_name} [a.u.]\")\n", "axs[1].set_ylabel(f\"Amount per Volume of {ch_name} [a.u.]\")\n", "axs[2].set_ylabel(f\"Amount per Volume of {ch_name} [a.u.]\")\n", - "#plt.savefig('../figures/firstgen_g1_concentration.png', dpi=300)\n", + "# plt.savefig('../figures/firstgen_g1_concentration.png', dpi=300)\n", "plt.show()" ] }, @@ -563,17 +589,24 @@ }, "outputs": [], "source": [ - "# obtain table where one cell cycle is represented by one row: \n", + "# obtain table where one cell cycle is represented by one row:\n", "# first set of columns (like phase_length, growth...) for G1, second set of cols for S\n", - "complete_cc_data = phase_grouped[phase_grouped.all_complete==1]\n", - "s_data = complete_cc_data[complete_cc_data.cell_cycle_stage==\"S\"]\n", - "g1_data = complete_cc_data[complete_cc_data.cell_cycle_stage==\"G1\"]\n", + "complete_cc_data = phase_grouped[phase_grouped.all_complete == 1]\n", + "s_data = complete_cc_data[complete_cc_data.cell_cycle_stage == \"S\"]\n", + "g1_data = complete_cc_data[complete_cc_data.cell_cycle_stage == \"G1\"]\n", "plot_data2 = g1_data.merge(\n", - " s_data, on=['Cell_ID', 'generation_num', 'position'], how='inner', suffixes=('_g1','_s')\n", + " s_data,\n", + " on=[\"Cell_ID\", \"generation_num\", \"position\"],\n", + " how=\"inner\",\n", + " suffixes=(\"_g1\", \"_s\"),\n", + ")\n", + "plot_data2 = plot_data2[plot_data2.generation_num == 1]\n", + "plot_data2[\"combined_motherbud_growth\"] = (\n", + " plot_data2[\"phase_area_growth_s\"] + plot_data2[\"phase_daughter_area_growth_s\"]\n", ")\n", - "plot_data2 = plot_data2[plot_data2.generation_num==1]\n", - "plot_data2['combined_motherbud_growth'] = plot_data2['phase_area_growth_s'] + plot_data2['phase_daughter_area_growth_s']\n", - "plot_data2['combined_motherbud_vol_growth'] = plot_data2['phase_volume_growth_s'] + plot_data2['phase_daughter_volume_growth_s']" + "plot_data2[\"combined_motherbud_vol_growth\"] = (\n", + " plot_data2[\"phase_volume_growth_s\"] + plot_data2[\"phase_daughter_volume_growth_s\"]\n", + ")" ] }, { @@ -591,9 +624,14 @@ "source": [ "sns.set_theme(style=\"darkgrid\", font_scale=2)\n", "# Initialize the figure\n", - "g = sns.lmplot(x=\"phase_volume_growth_g1\", y=\"combined_motherbud_vol_growth\", data=plot_data2,\n", - " hue=\"selection_subset_g1\", height=10)\n", - "g._legend.set_title('Position Pool')\n", + "g = sns.lmplot(\n", + " x=\"phase_volume_growth_g1\",\n", + " y=\"combined_motherbud_vol_growth\",\n", + " data=plot_data2,\n", + " hue=\"selection_subset_g1\",\n", + " height=10,\n", + ")\n", + "g._legend.set_title(\"Position Pool\")\n", "ax = plt.gca()\n", "ax.set_ylabel(\"Combined Mother+Bud S growth [fL]\", fontsize=20)\n", "ax.set_xlabel(\"G1 growth [fL]\", fontsize=20)\n", @@ -627,17 +665,22 @@ }, "outputs": [], "source": [ - "# obtain table where one cell cycle is represented by one row: \n", + "# obtain table where one cell cycle is represented by one row:\n", "# first set of columns (like phase_length, growth...) for G1, second set of cols for S\n", - "plot_data3 = phase_grouped[phase_grouped.cell_cycle_stage==\"G1\"]\n", - "plot_data3 = plot_data3[plot_data3.complete_phase==1]\n", - "plot_data3 = plot_data3[plot_data3.generation_num==1]\n", + "plot_data3 = phase_grouped[phase_grouped.cell_cycle_stage == \"G1\"]\n", + "plot_data3 = plot_data3[plot_data3.complete_phase == 1]\n", + "plot_data3 = plot_data3[plot_data3.generation_num == 1]\n", "\n", "sns.set_theme(style=\"darkgrid\", font_scale=2)\n", "# Initialize the figure\n", - "g = sns.lmplot(x=\"phase_volume_at_beginning\", y=\"phase_length\", data=plot_data3,\n", - " hue=\"selection_subset\", height=10)\n", - "g._legend.set_title('Position Pool')\n", + "g = sns.lmplot(\n", + " x=\"phase_volume_at_beginning\",\n", + " y=\"phase_length\",\n", + " data=plot_data3,\n", + " hue=\"selection_subset\",\n", + " height=10,\n", + ")\n", + "g._legend.set_title(\"Position Pool\")\n", "ax = plt.gca()\n", "ax.set_ylabel(\"Duration of first G1 phase [no of frames]\", fontsize=20)\n", "ax.set_xlabel(\"Volume at birth (first cytokinesis) [fL]\", fontsize=20)\n", @@ -671,23 +714,30 @@ "outputs": [], "source": [ "# set channel name here:\n", - "ch_name = 'mCitrine'\n", - "# obtain table where one cell cycle is represented by one row: \n", + "ch_name = \"mCitrine\"\n", + "# obtain table where one cell cycle is represented by one row:\n", "# first set of columns (like phase_length, growth...) for G1, second set of cols for S\n", - "plot_data4 = phase_grouped[phase_grouped.cell_cycle_stage==\"G1\"]\n", - "plot_data4 = plot_data4[plot_data4.complete_phase==1]\n", - "plot_data4 = plot_data4[plot_data4.generation_num==1]\n", + "plot_data4 = phase_grouped[phase_grouped.cell_cycle_stage == \"G1\"]\n", + "plot_data4 = plot_data4[plot_data4.complete_phase == 1]\n", + "plot_data4 = plot_data4[plot_data4.generation_num == 1]\n", "\n", "sns.set_theme(style=\"darkgrid\", font_scale=2)\n", "# Initialize the figure\n", - "g = sns.lmplot(x=\"phase_volume_at_beginning\", y=f\"phase_{ch_name}_concentration_at_beginning\", data=plot_data4,\n", - " hue=\"selection_subset\", height=10, )\n", - "g._legend.set_title('Position Pool')\n", + "g = sns.lmplot(\n", + " x=\"phase_volume_at_beginning\",\n", + " y=f\"phase_{ch_name}_concentration_at_beginning\",\n", + " data=plot_data4,\n", + " hue=\"selection_subset\",\n", + " height=10,\n", + ")\n", + "g._legend.set_title(\"Position Pool\")\n", "g.set(yscale=\"log\")\n", "ax = plt.gca()\n", "ax.set_ylabel(\"mCitrine signal amount per volume in cell [a.u.]\", fontsize=20)\n", "ax.set_xlabel(\"Volume at birth (first cytokinesis) [fL]\", fontsize=20)\n", - "ax.set_title(\"Volume at birth vs mCitrine signal amount per volume (1st generation)\", fontsize=30)\n", + "ax.set_title(\n", + " \"Volume at birth vs mCitrine signal amount per volume (1st generation)\", fontsize=30\n", + ")\n", "plt.show()" ] }, @@ -716,19 +766,26 @@ }, "outputs": [], "source": [ - "# obtain table where one cell cycle is represented by one row: \n", + "# obtain table where one cell cycle is represented by one row:\n", "# first set of columns (like phase_length, growth...) for G1, second set of cols for S\n", - "complete_cc_data = phase_grouped[phase_grouped.all_complete==1]\n", - "s_data = complete_cc_data[complete_cc_data.cell_cycle_stage==\"S\"]\n", - "g1_data = complete_cc_data[complete_cc_data.cell_cycle_stage==\"G1\"]\n", - "plot_data1 = g1_data.merge(s_data, on=['Cell_ID', 'generation_num', 'position', 'file'], how='inner')\n", - "plot_data1 = plot_data1[plot_data1.generation_num==1]\n", + "complete_cc_data = phase_grouped[phase_grouped.all_complete == 1]\n", + "s_data = complete_cc_data[complete_cc_data.cell_cycle_stage == \"S\"]\n", + "g1_data = complete_cc_data[complete_cc_data.cell_cycle_stage == \"G1\"]\n", + "plot_data1 = g1_data.merge(\n", + " s_data, on=[\"Cell_ID\", \"generation_num\", \"position\", \"file\"], how=\"inner\"\n", + ")\n", + "plot_data1 = plot_data1[plot_data1.generation_num == 1]\n", "\n", "sns.set_theme(style=\"darkgrid\", font_scale=2)\n", "# Initialize the figure\n", - "g = sns.lmplot(x=\"phase_length_x\", y=\"phase_length_y\", data=plot_data1,\n", - " hue=\"selection_subset_x\", height=10)\n", - "g._legend.set_title('Position Pool')\n", + "g = sns.lmplot(\n", + " x=\"phase_length_x\",\n", + " y=\"phase_length_y\",\n", + " data=plot_data1,\n", + " hue=\"selection_subset_x\",\n", + " height=10,\n", + ")\n", + "g._legend.set_title(\"Position Pool\")\n", "ax = plt.gca()\n", "ax.set_ylabel(\"S duration same cycle [frames]\", fontsize=20)\n", "ax.set_xlabel(\"G1 duration [frames]\", fontsize=20)\n", @@ -773,36 +830,29 @@ "sns.set_theme(style=\"ticks\", font_scale=2)\n", "\n", "# Initialize the figure\n", - "plt.figure(figsize=(10,10))\n", + "plt.figure(figsize=(10, 10))\n", "sns.histplot(\n", - " x='cell_vol_fl', \n", - " data=overall_df,\n", - " hue='relationship',\n", - " bins=20,\n", - " legend=False\n", + " x=\"cell_vol_fl\", data=overall_df, hue=\"relationship\", bins=20, legend=False\n", ")\n", "ax = plt.gca()\n", - "labels = [\n", - " 'Mother cells',\n", - " 'Buds'\n", - "]\n", + "labels = [\"Mother cells\", \"Buds\"]\n", "handles = [\n", - " mpatches.Patch(color=sns.color_palette('pastel')[0]),\n", - " mpatches.Patch(color=sns.color_palette('pastel')[1])\n", + " mpatches.Patch(color=sns.color_palette(\"pastel\")[0]),\n", + " mpatches.Patch(color=sns.color_palette(\"pastel\")[1]),\n", "]\n", "ax.legend(\n", " handles=handles,\n", - " labels=labels, \n", - " loc='upper right',\n", - " #bbox_to_anchor = (1,0.2),\n", - " framealpha=0.5\n", + " labels=labels,\n", + " loc=\"upper right\",\n", + " # bbox_to_anchor = (1,0.2),\n", + " framealpha=0.5,\n", ")\n", "\n", "# Tweak the visual presentation\n", "ax = plt.gca()\n", "ax.set_xlabel(\"Cell volume [fL]\", fontsize=20)\n", "ax.set_title(f\"Volume distribution, n: {overall_df.shape[0]}\", fontsize=30)\n", - "#sns.despine(trim=True, left=True)\n", + "# sns.despine(trim=True, left=True)\n", "plt.show()" ] }, @@ -830,20 +880,25 @@ "outputs": [], "source": [ "# set channel name here:\n", - "ch_name = 'act1'\n", + "ch_name = \"act1\"\n", "sns.set_theme(style=\"darkgrid\", font_scale=2)\n", "# Initialize the figure\n", "g = sns.lmplot(\n", - " x=\"act1_amount_autoBkgr_meanProj\", \n", - " y=\"act1_amount_autoBkgr_maxProj\", \n", + " x=\"act1_amount_autoBkgr_meanProj\",\n", + " y=\"act1_amount_autoBkgr_maxProj\",\n", " data=overall_df,\n", - " hue=\"relationship\", \n", + " hue=\"relationship\",\n", " # hue='selection_subset', # try this if you selected multiple position pools\n", - " height=10)\n", - "g._legend.set_title('Cell type')\n", + " height=10,\n", + ")\n", + "g._legend.set_title(\"Cell type\")\n", "ax = plt.gca()\n", - "ax.set_ylabel(f\"{ch_name} signal amount (max projection of z-slices) [a.u.]\", fontsize=20)\n", - "ax.set_xlabel(f\"{ch_name} signal amount (mean projection of z-slices) [a.u.]\", fontsize=20)\n", + "ax.set_ylabel(\n", + " f\"{ch_name} signal amount (max projection of z-slices) [a.u.]\", fontsize=20\n", + ")\n", + "ax.set_xlabel(\n", + " f\"{ch_name} signal amount (mean projection of z-slices) [a.u.]\", fontsize=20\n", + ")\n", "ax.set_title(\"Comparing mean projection with max projection\", fontsize=30)\n", "plt.show()" ] @@ -872,18 +927,18 @@ "outputs": [], "source": [ "# set channel name here:\n", - "ch_name = 'act1'\n", + "ch_name = \"act1\"\n", "sns.set_theme(style=\"darkgrid\", font_scale=2)\n", "# Initialize the figure\n", "g = sns.lmplot(\n", - " x=\"cell_vol_fl\", \n", - " y=f\"{ch_name}_amount_autoBkgr_meanProj\", \n", + " x=\"cell_vol_fl\",\n", + " y=f\"{ch_name}_amount_autoBkgr_meanProj\",\n", " data=overall_df,\n", " hue=\"relationship\",\n", " # hue='selection_subset', # try this if you selected multiple position pools\n", - " height=10\n", + " height=10,\n", ")\n", - "g._legend.set_title('Position Pool')\n", + "g._legend.set_title(\"Position Pool\")\n", "ax = plt.gca()\n", "ax.set_ylabel(f\"{ch_name} signal amount in cell [a.u.]\", fontsize=20)\n", "ax.set_xlabel(\"Cell volume [fL]\", fontsize=20)\n", @@ -915,17 +970,17 @@ "outputs": [], "source": [ "# set channel name here:\n", - "ch_name = 'act1'\n", + "ch_name = \"act1\"\n", "sns.set_theme(style=\"darkgrid\", font_scale=2)\n", "# Initialize the figure\n", "g = sns.lmplot(\n", - " x=\"cell_vol_fl\", \n", - " y=f\"{ch_name}_mean_meanProj\", \n", + " x=\"cell_vol_fl\",\n", + " y=f\"{ch_name}_mean_meanProj\",\n", " data=overall_df,\n", - " hue=\"relationship\", \n", - " height=10, \n", + " hue=\"relationship\",\n", + " height=10,\n", ")\n", - "g._legend.set_title('Cell type')\n", + "g._legend.set_title(\"Cell type\")\n", "ax = plt.gca()\n", "ax.set_ylabel(f\"{ch_name} mean signal strength in cell [a.u.]\", fontsize=20)\n", "ax.set_xlabel(\"Cell volume [fL]\", fontsize=20)\n", @@ -975,7 +1030,7 @@ "# some configurations\n", "# frame interval of video\n", "frame_interval_minutes = 3\n", - "# quantiles of complete cell cycles (wrt phase lengths) to exclude from analysis \n", + "# quantiles of complete cell cycles (wrt phase lengths) to exclude from analysis\n", "# (not used, keep this for potential later use)\n", "down_q, upper_q = 0, 1\n", "# minimum number of cell cycles contributing to the mean+CI curve:\n", @@ -985,7 +1040,7 @@ "# wether to scale to 0/1 or not\n", "scale_data = False\n", "# name of channel the signal of which should be plotted\n", - "ch_name = 'mCitrine'" + "ch_name = \"mCitrine\"" ] }, { @@ -1003,104 +1058,133 @@ "source": [ "# select needed cols from overall_df_with_rel to not end up with too many columns\n", "needed_cols = [\n", - " 'selection_subset', 'position', 'Cell_ID', 'cell_cycle_stage', 'generation_num', 'frame_i',\n", - " f'{ch_name}_corrected_amount', f'{ch_name}_corrected_amount_rel', \n", - " 'file', 'relationship', 'relative_ID', 'phase_length', 'phase_begin', f'gui_{ch_name}_amount_autoBkgr'\n", + " \"selection_subset\",\n", + " \"position\",\n", + " \"Cell_ID\",\n", + " \"cell_cycle_stage\",\n", + " \"generation_num\",\n", + " \"frame_i\",\n", + " f\"{ch_name}_corrected_amount\",\n", + " f\"{ch_name}_corrected_amount_rel\",\n", + " \"file\",\n", + " \"relationship\",\n", + " \"relative_ID\",\n", + " \"phase_length\",\n", + " \"phase_begin\",\n", + " f\"gui_{ch_name}_amount_autoBkgr\",\n", "]\n", - "filter_idx = np.logical_and(overall_df_with_rel['complete_cycle'] == 1, overall_df_with_rel.selection_subset==0)\n", + "filter_idx = np.logical_and(\n", + " overall_df_with_rel[\"complete_cycle\"] == 1,\n", + " overall_df_with_rel.selection_subset == 0,\n", + ")\n", "plot_data4a = overall_df_with_rel.loc[filter_idx, needed_cols].copy()\n", "# calculate the time the cell already spent in the current frame at the current timepoint\n", - "plot_data4a['frames_in_phase'] = plot_data4a['frame_i'] - plot_data4a['phase_begin'] + 1\n", - "# calculate the time to the next (for G1 cells) and from the last (for S cells) G1/S transition \n", - "plot_data4a['centered_frames_in_phase'] = plot_data4a.apply(\n", - " lambda x: x.loc['frames_in_phase'] if\\\n", - " x.loc['cell_cycle_stage']=='S' else\\\n", - " x.loc['frames_in_phase']-1-x.loc['phase_length'],\n", - " axis=1\n", + "plot_data4a[\"frames_in_phase\"] = plot_data4a[\"frame_i\"] - plot_data4a[\"phase_begin\"] + 1\n", + "# calculate the time to the next (for G1 cells) and from the last (for S cells) G1/S transition\n", + "plot_data4a[\"centered_frames_in_phase\"] = plot_data4a.apply(\n", + " lambda x: (\n", + " x.loc[\"frames_in_phase\"]\n", + " if x.loc[\"cell_cycle_stage\"] == \"S\"\n", + " else x.loc[\"frames_in_phase\"] - 1 - x.loc[\"phase_length\"]\n", + " ),\n", + " axis=1,\n", ")\n", "# calculate combined signal and the \"Pool, Phase ID\" for the legend\n", - "# plot_data4a at this point only contains relationship==mother, \n", + "# plot_data4a at this point only contains relationship==mother,\n", "# as generation_num==0 and relationship==bud are filtered out (incomplete cycle, cycles start with G1)\n", - "plot_data4a['Combined signal m&b'] = plot_data4a.apply(\n", - " lambda x: x.loc[f'{ch_name}_corrected_amount']+x.loc[f'{ch_name}_corrected_amount_rel'] if\\\n", - " x.loc['cell_cycle_stage']=='S' and x.loc['relationship'] == 'mother' else\\\n", - " x.loc[f'{ch_name}_corrected_amount'],\n", - " axis=1\n", + "plot_data4a[\"Combined signal m&b\"] = plot_data4a.apply(\n", + " lambda x: (\n", + " x.loc[f\"{ch_name}_corrected_amount\"] + x.loc[f\"{ch_name}_corrected_amount_rel\"]\n", + " if x.loc[\"cell_cycle_stage\"] == \"S\" and x.loc[\"relationship\"] == \"mother\"\n", + " else x.loc[f\"{ch_name}_corrected_amount\"]\n", + " ),\n", + " axis=1,\n", ")\n", - "plot_data4a['Bud signal'] = plot_data4a.apply(\n", - " lambda x: x.loc[f'{ch_name}_corrected_amount_rel'] if\\\n", - " x.loc['cell_cycle_stage']=='S' and x.loc['relationship'] == 'mother' else 0,\n", - " axis=1\n", + "plot_data4a[\"Bud signal\"] = plot_data4a.apply(\n", + " lambda x: (\n", + " x.loc[f\"{ch_name}_corrected_amount_rel\"]\n", + " if x.loc[\"cell_cycle_stage\"] == \"S\" and x.loc[\"relationship\"] == \"mother\"\n", + " else 0\n", + " ),\n", + " axis=1,\n", ")\n", "# scale data if needed\n", "if scale_data:\n", " maximum = max(\n", - " plot_data4a['Combined signal m&b'].max(), \n", - " plot_data4a['Bud signal'].max()\n", + " plot_data4a[\"Combined signal m&b\"].max(), plot_data4a[\"Bud signal\"].max()\n", " )\n", - " plot_data4a['Combined signal m&b'] /= maximum\n", - " plot_data4a['Bud signal'] /= maximum\n", + " plot_data4a[\"Combined signal m&b\"] /= maximum\n", + " plot_data4a[\"Bud signal\"] /= maximum\n", "# calculate min and max centered times per generation to eliminate up to a percentile\n", "# (not used, as upper_q and lower_q are set to 100/0 respectively)\n", - "plot_data4a['min_centered_frames'] = plot_data4a.groupby(\n", - " ['position', 'file', 'Cell_ID', 'generation_num']\n", - ")['centered_frames_in_phase'].transform(\n", - " 'min'\n", + "plot_data4a[\"min_centered_frames\"] = plot_data4a.groupby(\n", + " [\"position\", \"file\", \"Cell_ID\", \"generation_num\"]\n", + ")[\"centered_frames_in_phase\"].transform(\"min\")\n", + "plot_data4a[\"max_centered_frames\"] = plot_data4a.groupby(\n", + " [\"position\", \"file\", \"Cell_ID\", \"generation_num\"]\n", + ")[\"centered_frames_in_phase\"].transform(\"max\")\n", + "min_and_max = (\n", + " plot_data4a.groupby([\"Cell_ID\", \"generation_num\", \"position\", \"file\"])\n", + " .agg(\n", + " min_centered=(\"min_centered_frames\", \"first\"),\n", + " max_centered=(\"max_centered_frames\", \"first\"),\n", + " )\n", + " .reset_index()\n", ")\n", - "plot_data4a['max_centered_frames'] = plot_data4a.groupby(\n", - " ['position', 'file', 'Cell_ID', 'generation_num']\n", - ")['centered_frames_in_phase'].transform(\n", - " 'max'\n", + "min_val, max_val = (\n", + " np.quantile(min_and_max.min_centered, down_q) * frame_interval_minutes,\n", + " np.quantile(min_and_max.max_centered, upper_q) * frame_interval_minutes,\n", ")\n", - "min_and_max = plot_data4a.groupby(\n", - " ['Cell_ID', 'generation_num', 'position', 'file']\n", - ").agg(\n", - " min_centered = ('min_centered_frames', 'first'),\n", - " max_centered = ('max_centered_frames', 'first')\n", - ").reset_index()\n", - "min_val, max_val = np.quantile(\n", - " min_and_max.min_centered, down_q\n", - ") * frame_interval_minutes, np.quantile(\n", - " min_and_max.max_centered, upper_q\n", - ") * frame_interval_minutes\n", "# perform selection (won't change anything if upper and lower are 100 and 0 respectively)\n", "selection_indices = np.logical_and(\n", - " plot_data4a.min_centered_frames*frame_interval_minutes>=min_val, \n", - " plot_data4a.max_centered_frames*frame_interval_minutes<=max_val\n", + " plot_data4a.min_centered_frames * frame_interval_minutes >= min_val,\n", + " plot_data4a.max_centered_frames * frame_interval_minutes <= max_val,\n", ")\n", "plot_data4a = plot_data4a[selection_indices]\n", "\n", "# calculate centered time in minutes\n", - "plot_data4a['centered_time_in_minutes'] = plot_data4a.centered_frames_in_phase * frame_interval_minutes\n", + "plot_data4a[\"centered_time_in_minutes\"] = (\n", + " plot_data4a.centered_frames_in_phase * frame_interval_minutes\n", + ")\n", "\n", "# group dataframe to calculate sample sizes per generation\n", - "standard_grouped = plot_data4a.groupby(\n", - " ['position', 'file', 'Cell_ID', 'generation_num']\n", - ").agg('count').reset_index()\n", - "plot_data4a['Generation'] = plot_data4a.apply(\n", - " lambda x: f'1st ($n_1$={len(standard_grouped[standard_grouped.generation_num==1])})' if\\\n", - " x.loc['generation_num']==1 else f'2+ ($n_2$={len(standard_grouped[standard_grouped.generation_num>1])})',\n", - " axis=1\n", + "standard_grouped = (\n", + " plot_data4a.groupby([\"position\", \"file\", \"Cell_ID\", \"generation_num\"])\n", + " .agg(\"count\")\n", + " .reset_index()\n", + ")\n", + "plot_data4a[\"Generation\"] = plot_data4a.apply(\n", + " lambda x: (\n", + " f\"1st ($n_1$={len(standard_grouped[standard_grouped.generation_num == 1])})\"\n", + " if x.loc[\"generation_num\"] == 1\n", + " else f\"2+ ($n_2$={len(standard_grouped[standard_grouped.generation_num > 1])})\"\n", + " ),\n", + " axis=1,\n", ")\n", "if split_by_gen:\n", - " g_cols = ['centered_frames_in_phase', 'Generation']\n", + " g_cols = [\"centered_frames_in_phase\", \"Generation\"]\n", "else:\n", - " g_cols = 'centered_frames_in_phase'\n", - "plot_data4a['contributing_ccs_at_time'] = plot_data4a.groupby(g_cols).transform('count')['selection_subset']\n", + " g_cols = \"centered_frames_in_phase\"\n", + "plot_data4a[\"contributing_ccs_at_time\"] = plot_data4a.groupby(g_cols).transform(\n", + " \"count\"\n", + ")[\"selection_subset\"]\n", "plot_data4a = plot_data4a[plot_data4a.contributing_ccs_at_time >= min_no_of_ccs]\n", "\n", "# finally prepare data for plot (use melt for multiple lines)\n", "sample_size_4a = len(standard_grouped)\n", - "avg_cell_cycle_length = round(standard_grouped.loc[:,'centered_time_in_minutes'].mean())*frame_interval_minutes\n", - "cols_to_plot = ['Bud signal', 'Combined signal m&b']\n", + "avg_cell_cycle_length = (\n", + " round(standard_grouped.loc[:, \"centered_time_in_minutes\"].mean())\n", + " * frame_interval_minutes\n", + ")\n", + "cols_to_plot = [\"Bud signal\", \"Combined signal m&b\"]\n", "index_cols = [col for col in plot_data4a.columns if col not in cols_to_plot]\n", "plot_data4a_melted = pd.melt(\n", - " plot_data4a, index_cols, var_name='Method of calculation'\n", - ").sort_values('Method of calculation')\n", - "data_dir = os.path.join('..', 'data', 'paper_plot_data')\n", + " plot_data4a, index_cols, var_name=\"Method of calculation\"\n", + ").sort_values(\"Method of calculation\")\n", + "data_dir = os.path.join(\"..\", \"data\", \"paper_plot_data\")\n", "# save preprocessed data for Fig. 4A\n", - "#plot_data4a_melted.to_csv(os.path.join(data_dir, 'plot_data4a_melted.csv'), index=False)\n", - "#plot_data4a.to_csv(os.path.join(data_dir, 'plot_data4a.csv'), index=False)" + "# plot_data4a_melted.to_csv(os.path.join(data_dir, 'plot_data4a_melted.csv'), index=False)\n", + "# plot_data4a.to_csv(os.path.join(data_dir, 'plot_data4a.csv'), index=False)" ] }, { @@ -1120,29 +1204,34 @@ "sns.set_theme(style=\"darkgrid\", font_scale=1.6)\n", "f, ax = plt.subplots(figsize=(15, 12))\n", "if split_by_gen:\n", - " style='Generation'\n", + " style = \"Generation\"\n", "else:\n", - " style=None\n", + " style = None\n", "ax = sns.lineplot(\n", - " data=plot_data6_melted,#.sort_values('Pool, Phase'),\n", - " x=\"centered_time_in_minutes\", \n", + " data=plot_data6_melted, # .sort_values('Pool, Phase'),\n", + " x=\"centered_time_in_minutes\",\n", " y=\"value\",\n", - " hue='Method of Calculation',\n", - " #hue='position',\n", + " hue=\"Method of Calculation\",\n", + " # hue='position',\n", " style=style,\n", - " ci=95\n", + " ci=95,\n", ")\n", - "ax.axvline(x=0, color='red')#, label='Time of Bud Emergence')\n", + "ax.axvline(x=0, color=\"red\") # , label='Time of Bud Emergence')\n", "ax.text(\n", - " 0.5, 0.21, \"Time of \\nBud Emergence\", horizontalalignment='left', \n", - " size='medium', color='red', weight='normal'\n", + " 0.5,\n", + " 0.21,\n", + " \"Time of \\nBud Emergence\",\n", + " horizontalalignment=\"left\",\n", + " size=\"medium\",\n", + " color=\"red\",\n", + " weight=\"normal\",\n", ")\n", "ax.legend(\n", - " title=f'Avg CC Length: {avg_cell_cycle_length} min, n = {sample_size}', \n", + " title=f\"Avg CC Length: {avg_cell_cycle_length} min, n = {sample_size}\",\n", " fancybox=True,\n", " labelspacing=0.5,\n", " handlelength=1.5,\n", - " loc = 'upper left'\n", + " loc=\"upper left\",\n", ")\n", "ax.set_ylabel(\"Total amount of Signal corrected by background [a.u.]\", fontsize=20)\n", "ax.set_xlabel(\"Time in phase relative to G1/S transition [minutes]\", fontsize=20)\n", @@ -1179,14 +1268,22 @@ "outputs": [], "source": [ "# configure channel the signal of which should be plotted\n", - "ch_name = 'mCitrine'\n", + "ch_name = \"mCitrine\"\n", "# first set of columns (like phase_length, growth...) for G1, second set of cols for S\n", "needed_cols = [\n", - " 'Cell_ID', 'generation_num', 'position', 'file', 'cell_cycle_stage', 'selection_subset', \n", - " 'phase_volume_at_beginning', 'phase_volume_at_end', f'phase_{ch_name}_amount_at_beginning',\n", - " f'phase_{ch_name}_combined_amount_at_end','phase_combined_volume_at_end'\n", + " \"Cell_ID\",\n", + " \"generation_num\",\n", + " \"position\",\n", + " \"file\",\n", + " \"cell_cycle_stage\",\n", + " \"selection_subset\",\n", + " \"phase_volume_at_beginning\",\n", + " \"phase_volume_at_end\",\n", + " f\"phase_{ch_name}_amount_at_beginning\",\n", + " f\"phase_{ch_name}_combined_amount_at_end\",\n", + " \"phase_combined_volume_at_end\",\n", "]\n", - "plot_data4 = phase_grouped.loc[phase_grouped.complete_cycle==1, needed_cols]\n", + "plot_data4 = phase_grouped.loc[phase_grouped.complete_cycle == 1, needed_cols]\n", "scale_data = False" ] }, @@ -1203,51 +1300,60 @@ }, "outputs": [], "source": [ - "plot_data4['relevant_volume'] = plot_data4.apply(\n", - " lambda x: x.loc['phase_volume_at_beginning'] if\\\n", - " x.loc['cell_cycle_stage']=='G1' else\\\n", - " x.loc['phase_combined_volume_at_end'],\n", - " axis=1\n", + "plot_data4[\"relevant_volume\"] = plot_data4.apply(\n", + " lambda x: (\n", + " x.loc[\"phase_volume_at_beginning\"]\n", + " if x.loc[\"cell_cycle_stage\"] == \"G1\"\n", + " else x.loc[\"phase_combined_volume_at_end\"]\n", + " ),\n", + " axis=1,\n", ")\n", - "plot_data4['relevant_amount'] = plot_data4.apply(\n", - " lambda x: x.loc[f'phase_{ch_name}_amount_at_beginning'] if\\\n", - " x.loc['cell_cycle_stage']=='G1' else\\\n", - " x.loc[f'phase_{ch_name}_combined_amount_at_end'],\n", - " axis=1\n", + "plot_data4[\"relevant_amount\"] = plot_data4.apply(\n", + " lambda x: (\n", + " x.loc[f\"phase_{ch_name}_amount_at_beginning\"]\n", + " if x.loc[\"cell_cycle_stage\"] == \"G1\"\n", + " else x.loc[f\"phase_{ch_name}_combined_amount_at_end\"]\n", + " ),\n", + " axis=1,\n", ")\n", - "# defining a function to generate entries for the figure legend \n", + "\n", + "\n", + "# defining a function to generate entries for the figure legend\n", "# (assuming that selection_subset>0 is the autofluorescence control of the experiment)\n", "def calc_legend_entry(x):\n", - " if x.loc['selection_subset'] == 0:\n", - " if x.loc['cell_cycle_stage']=='G1':\n", - " return 'Single cell at birth'\n", + " if x.loc[\"selection_subset\"] == 0:\n", + " if x.loc[\"cell_cycle_stage\"] == \"G1\":\n", + " return \"Single cell at birth\"\n", " else:\n", - " return 'Combined mother&bud at cytokinesis'\n", + " return \"Combined mother&bud at cytokinesis\"\n", " else:\n", - " if x.loc['cell_cycle_stage']=='G1':\n", - " return 'Af control, single cell at birth'\n", + " if x.loc[\"cell_cycle_stage\"] == \"G1\":\n", + " return \"Af control, single cell at birth\"\n", " else:\n", - " return 'Af control, combined mother&bud at cytokinesis'\n", - " \n", - "plot_data4['Kind of Measurement'] = plot_data4.apply(\n", - " lambda x: 'Single Cell in G1 (Frame after Cytokinesis)' if\\\n", - " x.loc['cell_cycle_stage']=='G1' else\\\n", - " 'Combined Mother & Bud in S (Frame before Cytokinesis)',\n", - " axis=1\n", - ")\n", - "plot_data4['Kind of Measurement new'] = plot_data4.apply(\n", - " calc_legend_entry,\n", - " axis=1\n", + " return \"Af control, combined mother&bud at cytokinesis\"\n", + "\n", + "\n", + "plot_data4[\"Kind of Measurement\"] = plot_data4.apply(\n", + " lambda x: (\n", + " \"Single Cell in G1 (Frame after Cytokinesis)\"\n", + " if x.loc[\"cell_cycle_stage\"] == \"G1\"\n", + " else \"Combined Mother & Bud in S (Frame before Cytokinesis)\"\n", + " ),\n", + " axis=1,\n", ")\n", - "plot_data4['Generation'] = plot_data4.apply(\n", - " lambda x: f'1st ($n_1$={int(len(plot_data4[plot_data4.generation_num==1])/2)})' if\\\n", - " x.loc['generation_num']==1 else f'2+ ($n_2$={int(len(plot_data4[plot_data4.generation_num>1])/2)})',\n", - " axis=1\n", + "plot_data4[\"Kind of Measurement new\"] = plot_data4.apply(calc_legend_entry, axis=1)\n", + "plot_data4[\"Generation\"] = plot_data4.apply(\n", + " lambda x: (\n", + " f\"1st ($n_1$={int(len(plot_data4[plot_data4.generation_num == 1]) / 2)})\"\n", + " if x.loc[\"generation_num\"] == 1\n", + " else f\"2+ ($n_2$={int(len(plot_data4[plot_data4.generation_num > 1]) / 2)})\"\n", + " ),\n", + " axis=1,\n", ")\n", "if scale_data:\n", - " maximum = plot_data4['relevant_amount'].max()\n", - " plot_data4['relevant_amount'] /= maximum\n", - "sample_size = len(plot_data4)\n" + " maximum = plot_data4[\"relevant_amount\"].max()\n", + " plot_data4[\"relevant_amount\"] /= maximum\n", + "sample_size = len(plot_data4)" ] }, { @@ -1263,79 +1369,81 @@ }, "outputs": [], "source": [ - "#plot_data4 = plot_data4[plot_data4.selection_subset==1]\n", + "# plot_data4 = plot_data4[plot_data4.selection_subset==1]\n", "sns.set_theme(style=\"darkgrid\", font_scale=1.6)\n", "# create lmplot. Don't scatter and ommit legend to customize scatterplot and legend\n", "sns.lmplot(\n", - " x=\"relevant_volume\", \n", - " y=\"relevant_amount\", \n", - " data=plot_data4.sort_values(\n", - " 'Kind of Measurement new', ascending=False\n", - " ),\n", + " x=\"relevant_volume\",\n", + " y=\"relevant_amount\",\n", + " data=plot_data4.sort_values(\"Kind of Measurement new\", ascending=False),\n", " hue=\"Kind of Measurement new\",\n", " legend=False,\n", " height=10,\n", " aspect=1.1,\n", - " scatter=False\n", + " scatter=False,\n", ")\n", "sns.scatterplot(\n", - " x=\"relevant_volume\", \n", - " y=\"relevant_amount\", \n", - " data=plot_data4[plot_data4.generation_num==1].sort_values(\n", - " 'Kind of Measurement new', ascending=False\n", + " x=\"relevant_volume\",\n", + " y=\"relevant_amount\",\n", + " data=plot_data4[plot_data4.generation_num == 1].sort_values(\n", + " \"Kind of Measurement new\", ascending=False\n", " ),\n", " hue=\"Kind of Measurement new\",\n", " legend=False,\n", - " marker='x'\n", + " marker=\"x\",\n", ")\n", "sns.scatterplot(\n", - " x=\"relevant_volume\", \n", - " y=\"relevant_amount\", \n", - " data=plot_data4[plot_data4.generation_num>1].sort_values(\n", - " 'Kind of Measurement new', ascending=False\n", + " x=\"relevant_volume\",\n", + " y=\"relevant_amount\",\n", + " data=plot_data4[plot_data4.generation_num > 1].sort_values(\n", + " \"Kind of Measurement new\", ascending=False\n", " ),\n", " hue=\"Kind of Measurement new\",\n", " legend=False,\n", - " marker='o'\n", + " marker=\"o\",\n", ")\n", "ax = plt.gca()\n", "labels = [\n", - " 'Single cell at birth',\n", - " 'Combined mother&bud at cytokinesis',\n", - " 'Af control, single cell at birth',\n", - " 'Af control, combined mother&bud at cytokinesis',\n", - " 'Generation 1',\n", - " 'Generation 2+'\n", + " \"Single cell at birth\",\n", + " \"Combined mother&bud at cytokinesis\",\n", + " \"Af control, single cell at birth\",\n", + " \"Af control, combined mother&bud at cytokinesis\",\n", + " \"Generation 1\",\n", + " \"Generation 2+\",\n", "]\n", "handles = [\n", " mpatches.Patch(color=sns.color_palette()[0]),\n", " mpatches.Patch(color=sns.color_palette()[1]),\n", " mpatches.Patch(color=sns.color_palette()[2]),\n", " mpatches.Patch(color=sns.color_palette()[3]),\n", - " mlines.Line2D([], [], color='gray', marker='x', linestyle='None',\n", - " markersize=10),\n", - " mlines.Line2D([], [], color='gray', marker='o', linestyle='None',\n", - " markersize=10)\n", + " mlines.Line2D([], [], color=\"gray\", marker=\"x\", linestyle=\"None\", markersize=10),\n", + " mlines.Line2D([], [], color=\"gray\", marker=\"o\", linestyle=\"None\", markersize=10),\n", "]\n", "ax.legend(\n", " handles=handles,\n", - " labels=labels, \n", - " loc='center right',\n", - " bbox_to_anchor = (1,0.2),\n", - " framealpha=0.5\n", + " labels=labels,\n", + " loc=\"center right\",\n", + " bbox_to_anchor=(1, 0.2),\n", + " framealpha=0.5,\n", ")\n", "ax.set_ylabel(\"Amount of Signal in Cell(s) [a.u.]\", fontsize=20)\n", "ax.set_xlabel(\"Volume at Birth / Combined Volume Before Cytokinesis [fL]\", fontsize=20)\n", - "ax.set_title(f\"Volume at birth vs Signal Amount (n={int(sample_size/2)})\", fontsize=30)\n", + "ax.set_title(\n", + " f\"Volume at birth vs Signal Amount (n={int(sample_size / 2)})\", fontsize=30\n", + ")\n", "# format y-axis\n", - "plt.ticklabel_format(axis='y', style='sci', scilimits=(0,0), useMathText=True)\n", - "ax.get_yaxis().get_offset_text().set_position((-0.05,0))\n", + "plt.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0, 0), useMathText=True)\n", + "ax.get_yaxis().get_offset_text().set_position((-0.05, 0))\n", "# format x-axis\n", - "ax.set_xlim(0, plot_data4.relevant_volume.max()+20)\n", + "ax.set_xlim(0, plot_data4.relevant_volume.max() + 20)\n", "plt.tight_layout()\n", "plt.show()\n", - "print(f'sample size flu-control: {len(plot_data4[plot_data4.selection_subset==1])//2}')\n", - "print(f'sample size tagged strain: {len(plot_data4[plot_data4.selection_subset==0])//2}')" + "print(\n", + " f\"sample size flu-control: {len(plot_data4[plot_data4.selection_subset == 1]) // 2}\"\n", + ")\n", + "print(\n", + " f\"sample size tagged strain: {len(plot_data4[plot_data4.selection_subset == 0]) // 2}\"\n", + ")" ] } ], diff --git a/notebooks/workshop_analyses.ipynb b/notebooks/workshop_analyses.ipynb index 3fdcc94a5..a6661f65b 100644 --- a/notebooks/workshop_analyses.ipynb +++ b/notebooks/workshop_analyses.ipynb @@ -12,11 +12,13 @@ "import numpy as np\n", "import pandas as pd\n", "from scipy.spatial import distance_matrix\n", + "\n", "pd.set_option(\"display.max_columns\", 200)\n", "pd.set_option(\"display.max_rows\", 50)\n", - "pd.set_option('display.max_colwidth', 150)\n", + "pd.set_option(\"display.max_colwidth\", 150)\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", + "\n", "sns.set_theme()\n", "try:\n", " from cellacdc import cca_functions\n", @@ -24,7 +26,7 @@ "except FileNotFoundError:\n", " # Check if user has developer version --> add the Cell_ACDC/cellacdc\n", " # folder to path and import from there\n", - " sys.path.insert(0, '../cellacdc/')\n", + " sys.path.insert(0, \"../cellacdc/\")\n", " from cellacdc import cca_functions\n", " from cellacdc import myutils" ] @@ -73,14 +75,28 @@ "source": [ "data_dirs, positions, app = cca_functions.configuration_dialog()\n", "file_names = [os.path.split(path)[-1] for path in data_dirs]\n", - "image_folders = [[os.path.join(data_dir, pos_str, 'Images') for pos_str in pos_list] for pos_list, data_dir in zip(positions, data_dirs)]\n", + "image_folders = [\n", + " [os.path.join(data_dir, pos_str, \"Images\") for pos_str in pos_list]\n", + " for pos_list, data_dir in zip(positions, data_dirs)\n", + "]\n", "# determine available channels based on first(!) position.\n", "# Warn user if one or more of the channels are not available for some positions\n", - "first_pos_dirs = [os.path.join(data_dir, positions[0][0], 'Images') for data_dir in data_dirs]\n", + "first_pos_dirs = [\n", + " os.path.join(data_dir, positions[0][0], \"Images\") for data_dir in data_dirs\n", + "]\n", "first_pos_files = [myutils.listdir(first_pos_dir) for first_pos_dir in first_pos_dirs]\n", - "channels = [cca_functions.find_available_channels(fpf, fpd)[0] for fpf, fpd in zip(first_pos_files, first_pos_dirs)]\n", - "basenames = [cca_functions.find_available_channels(fpf, fpd)[1] for fpf, fpd in zip(first_pos_files, first_pos_dirs)]\n", - "segm_endnames = [cca_functions.get_segm_endname(fpd, bn) for fpd, bn in zip(first_pos_dirs, basenames)]\n" + "channels = [\n", + " cca_functions.find_available_channels(fpf, fpd)[0]\n", + " for fpf, fpd in zip(first_pos_files, first_pos_dirs)\n", + "]\n", + "basenames = [\n", + " cca_functions.find_available_channels(fpf, fpd)[1]\n", + " for fpf, fpd in zip(first_pos_files, first_pos_dirs)\n", + "]\n", + "segm_endnames = [\n", + " cca_functions.get_segm_endname(fpd, bn)\n", + " for fpd, bn in zip(first_pos_dirs, basenames)\n", + "]" ] }, { @@ -116,12 +132,9 @@ "outputs": [], "source": [ "overall_df = cca_functions.load_acdc_output_only(\n", - " file_names,\n", - " image_folders,\n", - " positions,\n", - " segm_endnames\n", + " file_names, image_folders, positions, segm_endnames\n", ")\n", - "is_timelapse_data = True # Maybe not needed" + "is_timelapse_data = True # Maybe not needed" ] }, { @@ -148,24 +161,38 @@ "outputs": [], "source": [ "# if cell cycle annotations were performed in ACDC, extend the dataframe by a join on each cells relative cell\n", - "if 'cell_cycle_stage' in overall_df.columns:\n", + "if \"cell_cycle_stage\" in overall_df.columns:\n", " overall_df_with_rel = cca_functions.calculate_relatives_data(overall_df, channels)\n", " # If working with timelapse data build dataframe grouped by phases\n", " group_cols = [\n", - " 'Cell_ID', 'generation_num', 'cell_cycle_stage', 'relationship', 'position', 'file', \n", - " 'max_frame_pos', 'selection_subset', 'max_t'\n", + " \"Cell_ID\",\n", + " \"generation_num\",\n", + " \"cell_cycle_stage\",\n", + " \"relationship\",\n", + " \"position\",\n", + " \"file\",\n", + " \"max_frame_pos\",\n", + " \"selection_subset\",\n", + " \"max_t\",\n", " ]\n", " # calculate data grouped by phase only in the case, that timelapse data is available\n", - " if is_timelapse_data and 'max_t' in overall_df_with_rel.columns:\n", - " phase_grouped = cca_functions.calculate_per_phase_quantities(overall_df_with_rel, group_cols, channels)\n", + " if is_timelapse_data and \"max_t\" in overall_df_with_rel.columns:\n", + " phase_grouped = cca_functions.calculate_per_phase_quantities(\n", + " overall_df_with_rel, group_cols, channels\n", + " )\n", " # append phase-grouped data to overall_df_with_rel\n", " overall_df_with_rel = overall_df_with_rel.merge(\n", - " phase_grouped,\n", - " how='left',\n", - " on=group_cols\n", + " phase_grouped, how=\"left\", on=group_cols\n", + " )\n", + " overall_df_with_rel[\"time_in_phase\"] = (\n", + " overall_df_with_rel[\"frame_i\"] - overall_df_with_rel[\"phase_begin\"] + 1\n", " )\n", - " overall_df_with_rel['time_in_phase'] = overall_df_with_rel['frame_i'] - overall_df_with_rel['phase_begin'] + 1\n", - " overall_df_with_rel['time_in_cell_cycle'] = overall_df_with_rel.groupby(['Cell_ID', 'generation_num', 'position', 'file'])['frame_i'].transform('cumcount') + 1" + " overall_df_with_rel[\"time_in_cell_cycle\"] = (\n", + " overall_df_with_rel.groupby(\n", + " [\"Cell_ID\", \"generation_num\", \"position\", \"file\"]\n", + " )[\"frame_i\"].transform(\"cumcount\")\n", + " + 1\n", + " )" ] }, { @@ -207,22 +234,40 @@ "outputs": [], "source": [ "fig, axs = plt.subplots(1, 3, figsize=(15, 5))\n", - "sns.lineplot(data=overall_df, x='frame_i', y='cell_area_um2', hue='selection_subset', ci='sd', ax=axs[0])\n", "sns.lineplot(\n", - " data=overall_df.groupby(['frame_i', 'selection_subset']).size().reset_index(drop=False), \n", - " x='frame_i', \n", - " y=0, \n", - " hue='selection_subset', \n", - " ci='sd', \n", - " ax=axs[1]\n", - " )\n", - "track_lengths = overall_df.groupby(\n", - " ['selection_subset', 'Cell_ID']\n", - " )['frame_i'].apply(lambda x: x.max() - x.min()).reset_index(drop=False)\n", - "sns.histplot(data=track_lengths, x='frame_i', kde=True, ax=axs[2], hue='selection_subset', multiple='dodge')\n", - "axs[0].set_title('Mean cell area over time')\n", - "axs[1].set_title('Number of cells over time')\n", - "axs[2].set_title('Track length distribution')" + " data=overall_df,\n", + " x=\"frame_i\",\n", + " y=\"cell_area_um2\",\n", + " hue=\"selection_subset\",\n", + " ci=\"sd\",\n", + " ax=axs[0],\n", + ")\n", + "sns.lineplot(\n", + " data=overall_df.groupby([\"frame_i\", \"selection_subset\"])\n", + " .size()\n", + " .reset_index(drop=False),\n", + " x=\"frame_i\",\n", + " y=0,\n", + " hue=\"selection_subset\",\n", + " ci=\"sd\",\n", + " ax=axs[1],\n", + ")\n", + "track_lengths = (\n", + " overall_df.groupby([\"selection_subset\", \"Cell_ID\"])[\"frame_i\"]\n", + " .apply(lambda x: x.max() - x.min())\n", + " .reset_index(drop=False)\n", + ")\n", + "sns.histplot(\n", + " data=track_lengths,\n", + " x=\"frame_i\",\n", + " kde=True,\n", + " ax=axs[2],\n", + " hue=\"selection_subset\",\n", + " multiple=\"dodge\",\n", + ")\n", + "axs[0].set_title(\"Mean cell area over time\")\n", + "axs[1].set_title(\"Number of cells over time\")\n", + "axs[2].set_title(\"Track length distribution\")" ] }, { @@ -236,10 +281,11 @@ { "cell_type": "code", "execution_count": null, + "id": "7fb27b941602401d91542211134fc71a", "metadata": {}, "outputs": [], "source": [ - "plot_data = overall_df.loc[overall_df['selection_subset'] == 0]" + "plot_data = overall_df.loc[overall_df[\"selection_subset\"] == 0]" ] }, { @@ -260,27 +306,27 @@ "plt.figure(figsize=(18, 6))\n", "# First Panel: Number of Cells per Frame\n", "plt.subplot(1, 3, 1)\n", - "plot_data.groupby('frame_i').size().plot(kind='line')\n", - "plt.xlabel('Frame')\n", - "plt.ylabel('Number of Cells')\n", - "plt.title('Number of Cells per Frame')\n", + "plot_data.groupby(\"frame_i\").size().plot(kind=\"line\")\n", + "plt.xlabel(\"Frame\")\n", + "plt.ylabel(\"Number of Cells\")\n", + "plt.title(\"Number of Cells per Frame\")\n", "\n", "# Second Panel: Mean Cell Volume over Time\n", "plt.subplot(1, 3, 2)\n", - "sns.lineplot(data=plot_data, x='frame_i', y='cell_area_um2', ci='sd')\n", - "plt.xlabel('Frame')\n", - "plt.ylabel('Mean Cell area (µm²)')\n", - "plt.title('Mean Cell area over Time')\n", + "sns.lineplot(data=plot_data, x=\"frame_i\", y=\"cell_area_um2\", ci=\"sd\")\n", + "plt.xlabel(\"Frame\")\n", + "plt.ylabel(\"Mean Cell area (µm²)\")\n", + "plt.title(\"Mean Cell area over Time\")\n", "\n", "# Third Panel: Total Area of All Cells over Time\n", "plt.subplot(1, 3, 3)\n", - "plot_data.groupby('frame_i')['cell_area_um2'].sum().plot(kind='line')\n", - "plt.xlabel('Frame')\n", - "plt.ylabel('Total Cell area (µm²)')\n", - "plt.title('Total Cell area over Time')\n", + "plot_data.groupby(\"frame_i\")[\"cell_area_um2\"].sum().plot(kind=\"line\")\n", + "plt.xlabel(\"Frame\")\n", + "plt.ylabel(\"Total Cell area (µm²)\")\n", + "plt.title(\"Total Cell area over Time\")\n", "\n", "plt.tight_layout()\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -299,28 +345,36 @@ "outputs": [], "source": [ "# Filter the DataFrame for the first frame\n", - "first_frame_df = plot_data[plot_data['frame_i'] == 0]\n", + "first_frame_df = plot_data[plot_data[\"frame_i\"] == 0]\n", "\n", "# Filter the DataFrame for the last frame\n", - "last_frame_df = plot_data[plot_data['frame_i'] == plot_data['frame_i'].max()]\n", + "last_frame_df = plot_data[plot_data[\"frame_i\"] == plot_data[\"frame_i\"].max()]\n", "# Calculate the total number of cells in each frame\n", "first_frame_total_cells = len(first_frame_df)\n", "last_frame_total_cells = len(last_frame_df)\n", "\n", "# Plot the volume distributions\n", "plt.figure(figsize=(10, 6))\n", - "sns.histplot(data=first_frame_df, x='cell_area_um2', kde=True, label='First Frame', stat='density')\n", - "sns.histplot(data=last_frame_df, x='cell_area_um2', kde=True, label='Last Frame', stat='density')\n", - "plt.xlabel('Cell Area (µm²)')\n", - "plt.ylabel('Density')\n", - "plt.title('Relative Volume Distribution of Cells')\n", + "sns.histplot(\n", + " data=first_frame_df,\n", + " x=\"cell_area_um2\",\n", + " kde=True,\n", + " label=\"First Frame\",\n", + " stat=\"density\",\n", + ")\n", + "sns.histplot(\n", + " data=last_frame_df, x=\"cell_area_um2\", kde=True, label=\"Last Frame\", stat=\"density\"\n", + ")\n", + "plt.xlabel(\"Cell Area (µm²)\")\n", + "plt.ylabel(\"Density\")\n", + "plt.title(\"Relative Volume Distribution of Cells\")\n", "plt.legend()\n", "\n", "# Add text annotations for the relative counts\n", - "print(f'Cell count first frame: {first_frame_total_cells}')\n", - "print(f'Cell count last frame: {last_frame_total_cells}')\n", + "print(f\"Cell count first frame: {first_frame_total_cells}\")\n", + "print(f\"Cell count last frame: {last_frame_total_cells}\")\n", "\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -339,15 +393,17 @@ "outputs": [], "source": [ "# Calculate track lengths\n", - "track_lengths = plot_data.groupby('Cell_ID')['frame_i'].apply(lambda x: x.max() - x.min())\n", + "track_lengths = plot_data.groupby(\"Cell_ID\")[\"frame_i\"].apply(\n", + " lambda x: x.max() - x.min()\n", + ")\n", "\n", "# Plot track length distribution\n", "plt.figure(figsize=(10, 6))\n", "sns.histplot(data=track_lengths, kde=True)\n", - "plt.xlabel('Track Length')\n", - "plt.ylabel('Count')\n", - "plt.title('Distribution of Track Lengths')\n", - "plt.show()\n" + "plt.xlabel(\"Track Length\")\n", + "plt.ylabel(\"Count\")\n", + "plt.title(\"Distribution of Track Lengths\")\n", + "plt.show()" ] }, { @@ -365,31 +421,30 @@ "metadata": {}, "outputs": [], "source": [ - "filtered_df = plot_data[plot_data['Cell_ID'].map(track_lengths) > 20]\n", + "filtered_df = plot_data[plot_data[\"Cell_ID\"].map(track_lengths) > 20]\n", "plt.figure(figsize=(21, 7))\n", "# First Panel: Volume over time lineplot\n", "plt.subplot(1, 2, 1)\n", - "for cell_id, cell_data in filtered_df.groupby('Cell_ID'):\n", - " plt.plot(cell_data['frame_i'], cell_data['cell_area_um2'], label=f'Cell {cell_id}')\n", - "plt.xlabel('Frame')\n", - "plt.ylabel('Cell Area (µm²)')\n", - "plt.title('Volume over Time')\n", + "for cell_id, cell_data in filtered_df.groupby(\"Cell_ID\"):\n", + " plt.plot(cell_data[\"frame_i\"], cell_data[\"cell_area_um2\"], label=f\"Cell {cell_id}\")\n", + "plt.xlabel(\"Frame\")\n", + "plt.ylabel(\"Cell Area (µm²)\")\n", + "plt.title(\"Volume over Time\")\n", "plt.legend().set_visible(False) # Hide the legend\n", "\n", "# Second Panel: Traces of all cells\n", "plt.subplot(1, 2, 2)\n", - "for cell_id, cell_data in filtered_df.groupby('Cell_ID'):\n", - " plt.plot(cell_data['centroid-1'], cell_data['centroid-0'], label=f'Cell {cell_id}')\n", - "plt.xlabel('X-coordinate')\n", - "plt.title('Traces of Cells')\n", + "for cell_id, cell_data in filtered_df.groupby(\"Cell_ID\"):\n", + " plt.plot(cell_data[\"centroid-1\"], cell_data[\"centroid-0\"], label=f\"Cell {cell_id}\")\n", + "plt.xlabel(\"X-coordinate\")\n", + "plt.title(\"Traces of Cells\")\n", "plt.legend().set_visible(False) # Hide the legend\n", - "maxCentroidAll = filtered_df[['centroid-0', 'centroid-1']].max().max()\n", - "plt.xlim(0, maxCentroidAll+50)\n", - "plt.ylim(0, maxCentroidAll+50)\n", + "maxCentroidAll = filtered_df[[\"centroid-0\", \"centroid-1\"]].max().max()\n", + "plt.xlim(0, maxCentroidAll + 50)\n", + "plt.ylim(0, maxCentroidAll + 50)\n", "\n", "plt.tight_layout()\n", - "plt.show()\n", - "\n" + "plt.show()" ] }, { @@ -411,8 +466,8 @@ " \"\"\"\n", " Calculate the frame-by-frame distance of a centroid series\n", " \"\"\"\n", - " xSeries = centroid_series['centroid-1']\n", - " ySeries = centroid_series['centroid-0']\n", + " xSeries = centroid_series[\"centroid-1\"]\n", + " ySeries = centroid_series[\"centroid-0\"]\n", " # Calculate the distance between each frame\n", " dists = np.sqrt((xSeries.diff() ** 2) + (ySeries.diff() ** 2))\n", " return dists" @@ -425,28 +480,39 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "# Left panel: Total traveled distance vs. mean volume\n", "plt.figure(figsize=(14, 7))\n", "plt.subplot(1, 2, 1)\n", - "for cell_id, cell_data in filtered_df.groupby('Cell_ID'):\n", - " plt.scatter(cell_data['cell_area_um2'].mean(), np.max(distance_matrix(cell_data[['centroid-0', 'centroid-1']], cell_data[['centroid-0', 'centroid-1']])))\n", + "for cell_id, cell_data in filtered_df.groupby(\"Cell_ID\"):\n", + " plt.scatter(\n", + " cell_data[\"cell_area_um2\"].mean(),\n", + " np.max(\n", + " distance_matrix(\n", + " cell_data[[\"centroid-0\", \"centroid-1\"]],\n", + " cell_data[[\"centroid-0\", \"centroid-1\"]],\n", + " )\n", + " ),\n", + " )\n", "\n", - "plt.xlabel('Mean Area [µm²]')\n", - "plt.ylabel('Total Traveled Distance')\n", - "plt.title('Total Traveled Distance vs. Mean Area')\n", + "plt.xlabel(\"Mean Area [µm²]\")\n", + "plt.ylabel(\"Total Traveled Distance\")\n", + "plt.title(\"Total Traveled Distance vs. Mean Area\")\n", "\n", "# Right panel: Frame-by-frame distance vs. frame-by-frame growth\n", "plt.subplot(1, 2, 2)\n", - "for cell_id, cell_data in filtered_df.groupby('Cell_ID'):\n", - " plt.scatter(frame_by_frame_dist(cell_data[['centroid-0', 'centroid-1']])[1:], np.diff(cell_data['cell_area_um2']), alpha=0.4)\n", + "for cell_id, cell_data in filtered_df.groupby(\"Cell_ID\"):\n", + " plt.scatter(\n", + " frame_by_frame_dist(cell_data[[\"centroid-0\", \"centroid-1\"]])[1:],\n", + " np.diff(cell_data[\"cell_area_um2\"]),\n", + " alpha=0.4,\n", + " )\n", "\n", - "plt.xlabel('Frame-by-Frame Distance')\n", - "plt.ylabel('Frame-by-Frame Growth [Area in µm²]')\n", - "plt.title('Frame-by-Frame Distance vs. Frame-by-Frame Growth')\n", + "plt.xlabel(\"Frame-by-Frame Distance\")\n", + "plt.ylabel(\"Frame-by-Frame Growth [Area in µm²]\")\n", + "plt.title(\"Frame-by-Frame Distance vs. Frame-by-Frame Growth\")\n", "\n", "plt.tight_layout()\n", - "plt.show()\n" + "plt.show()" ] }, { diff --git a/scripts/fix_split_imports.py b/scripts/fix_split_imports.py new file mode 100644 index 000000000..38ae95a5a --- /dev/null +++ b/scripts/fix_split_imports.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +"""Fix parent-package imports in split submodules (from . -> from ..).""" + +from __future__ import annotations + +import re +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] / "cellacdc" + +PACKAGES: dict[str, set[str]] = { + "utils": { + "dataframe", + "install", + "io", + "logging", + "misc", + "models", + "paths", + "qt", + "text", + "version", + }, + "workers": { + "_base", + "alignment", + "data_prep", + "gui", + "io", + "metrics", + "segm", + "tracking", + "util", + }, + "widgets": {"canvas", "controls", "toolbars"}, + "dialogs": { + "_base", + "export", + "general", + "measurements", + "metadata", + "models", + "preprocess", + "tracking", + }, +} + + +def fix_line(line: str, siblings: set[str]) -> str: + m = re.match(r"^(\s*)from \. import (.+)$", line) + if m: + indent, rest = m.groups() + return f"{indent}from .. import {rest}" + + m = re.match(r"^(\s*)from \.(\S+) import (.+)$", line) + if not m: + return line + indent, module, rest = m.groups() + top = module.split(".", 1)[0] + if top in siblings: + return line + return f"{indent}from ..{module} import {rest}" + + +def fix_file(path: Path, siblings: set[str]) -> bool: + lines = path.read_text().splitlines(keepends=True) + new_lines = [fix_line(line, siblings) for line in lines] + if new_lines != lines: + path.write_text("".join(new_lines)) + return True + return False + + +def main() -> None: + for pkg, siblings in PACKAGES.items(): + pkg_dir = ROOT / pkg + changed = 0 + for path in sorted(pkg_dir.glob("*.py")): + if path.name == "__init__.py": + continue + if fix_file(path, siblings): + changed += 1 + print(f"{pkg}: fixed {changed} files") + + +if __name__ == "__main__": + main() diff --git a/scripts/rename_utils_tools.py b/scripts/rename_utils_tools.py new file mode 100644 index 000000000..99c6d87cc --- /dev/null +++ b/scripts/rename_utils_tools.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +"""Rename cellacdc.utils -> tools and cellacdc.utils -> utils in source files.""" + +from __future__ import annotations + +import re +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] + +SKIP_DIRS = {".git", "__pycache__", ".venv", "venv", "node_modules"} + +# Phase 1: batch-tool package (old utils -> tools). Run before utils -> utils. +TOOLS_PATTERNS: list[tuple[str, str]] = [ + (r"\bfrom \.\.utils import resize\b", "from ..tools import resize"), + (r"\bfrom \.\.utils import base\b", "from ..tools import base"), + (r"\bfrom \.\.utils\.", "from ..tools."), + (r"\bfrom \.utils\.", "from .tools."), + (r"\bfrom \.utils import", "from .tools import"), + (r"\bfrom cellacdc\.utils\.", "from cellacdc.tools."), + (r"\bfrom cellacdc\.utils import", "from cellacdc.tools import"), + (r'"cellacdc/tools/', '"cellacdc/tools/'), + (r"'cellacdc/tools/", "'cellacdc/tools/"), + (r"\bcellacdc/utils/", "cellacdc/tools/"), +] + +# Phase 2: helper package (utils -> utils). +UTILS_PATTERNS: list[tuple[str, str]] = [ + (r"\bmyutils\b", "utils"), + (r'"cellacdc/utils/', '"cellacdc/tools/'), + (r"'cellacdc/utils/", "'cellacdc/tools/"), + (r"\bcellacdc/utils/", "cellacdc/tools/"), + (r"\bcellacdc\.utils\b", "cellacdc.utils"), +] + +# Phase 3: same-package imports inside tools/ +TOOLS_INTERNAL: list[tuple[str, str]] = [ + (r"\bfrom \.\.tools import base\b", "from . import base"), +] + + +def iter_files() -> list[Path]: + files: list[Path] = [] + for path in ROOT.rglob("*"): + if path.suffix != ".py": + continue + if any(part in SKIP_DIRS for part in path.parts): + continue + files.append(path) + return files + + +def apply_patterns(text: str, patterns: list[tuple[str, str]]) -> str: + for pattern, repl in patterns: + text = re.sub(pattern, repl, text) + return text + + +def fix_tools_package() -> None: + tools_dir = ROOT / "cellacdc" / "tools" + if not tools_dir.is_dir(): + return + for path in tools_dir.rglob("*.py"): + text = path.read_text() + updated = apply_patterns(text, TOOLS_INTERNAL) + if updated != text: + path.write_text(updated) + + +def main() -> None: + for path in iter_files(): + text = path.read_text() + updated = apply_patterns(text, TOOLS_PATTERNS) + updated = apply_patterns(updated, UTILS_PATTERNS) + if updated != text: + path.write_text(updated) + fix_tools_package() + print("Import rewrites complete.") + + +if __name__ == "__main__": + main() diff --git a/scripts/split_god_files.py b/scripts/split_god_files.py new file mode 100644 index 000000000..818694acb --- /dev/null +++ b/scripts/split_god_files.py @@ -0,0 +1,480 @@ +#!/usr/bin/env python3 +"""Split god files into packages while preserving public import paths.""" + +from __future__ import annotations + +import ast +import re +import shutil +import textwrap +from collections import defaultdict +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +CELLACDC = ROOT / "cellacdc" + + +def extract_nodes(source: str) -> tuple[str, list[tuple[str, str, int, int]]]: + """Return preamble and (name, kind, start, end) for each top-level def/class.""" + lines = source.splitlines(keepends=True) + tree = ast.parse(source) + nodes: list[tuple[str, str, int, int]] = [] + first_start = None + for node in tree.body: + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + end = getattr(node, "end_lineno", node.lineno) + kind = "class" if isinstance(node, ast.ClassDef) else "function" + nodes.append((node.name, kind, node.lineno, end)) + if first_start is None: + first_start = node.lineno + preamble_end = first_start - 1 if first_start else len(lines) + preamble = "".join(lines[:preamble_end]) + return preamble, nodes + + +def slice_nodes(source: str, nodes: list[tuple[str, str, int, int]], names: set[str]) -> str: + lines = source.splitlines(keepends=True) + chunks: list[str] = [] + for name, _kind, start, end in nodes: + if name in names: + chunks.append("".join(lines[start - 1 : end])) + return "\n\n".join(chunks) + + +def write_module( + path: Path, + doc: str, + preamble: str, + body: str, + *, + siblings: set[str] | None = None, +) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + if siblings is not None: + preamble = fix_preamble_imports(preamble, siblings) + content = f'"""{doc}"""\n\n{preamble.rstrip()}\n\n{body.rstrip()}\n' + path.write_text(content) + + +def assign_utils(name: str) -> str: + rules: list[tuple[str, str]] = [ + ("logging", r"log|Logger"), + ("paths", r"path|folder|dir|recent|trim_path|explorer|filemaneger|gdrive|acdc_data|pos_folder|images_folder|PosStatus|pos_status"), + ("install", r"install|gpu|pytorch|torch|java|javabridge|conda|pip|package|mamba|upgrade_javabridge|java_exists|download_java|check_install"), + ("dataframe", r"df_|dataframe|acdc_df|ctc|reset_index|format_ID|cca_col|are_acdc_dfs|fix_acdc_df"), + ("version", r"version|git|branch|date_from|info_version|salute|cellpose.*version|second_version"), + ("models", r"model|download|Tracker|tracker|segm_params|init_tracker|ArgSpec|parse_model|insertModel|getModel|promptable|ModelArg|IntensityImgRequired"), + ("qt", r"widget|Qt|Q[A-Z]|retain|cli_multi_choice|testQcore"), + ("io", r"bytes|Memory|browse_docs|save_response|read_|write_|open_url|browse_url"), + ("text", r"tooltip|instruction|html|text|string|trim|annot|elided|fstring|append_text|show_in_file"), + ] + for module, pat in rules: + if re.search(pat, name): + return module + return "misc" + + +def assign_worker(name: str) -> str: + mapping = { + "worker_exception_handler": "_base", + "workerLogger": "_base", + "signals": "_base", + "BaseWorkerUtil": "_base", + "SimpleWorker": "_base", + "AutoPilotWorker": "gui", + "FindNextNewIdWorker": "gui", + "StoreGuiStateWorker": "io", + "AutoSaveWorker": "io", + "LazyLoader": "io", + "loadDataWorker": "io", + "saveDataWorker": "io", + "MoveTempFilesWorker": "io", + "MigrateUserProfileWorker": "io", + "relabelSequentialWorker": "io", + "segmWorker": "segm", + "segmVideoWorker": "segm", + "SegForLostIDsWorker": "segm", + "PostProcessSegmWorker": "segm", + "MagicPromptsWorker": "segm", + "FillHolesInSegWorker": "segm", + "DelObjectsOutsideSegmROIWorker": "segm", + "LabelRoiWorker": "segm", + "CreateConnected3Dsegm": "segm", + "trackingWorker": "tracking", + "TrackSubCellObjectsWorker": "tracking", + "ApplyTrackInfoWorker": "tracking", + "ToSymDivWorker": "tracking", + "CopyAllLostObjectsWorker": "tracking", + "ComputeMetricsWorker": "metrics", + "ComputeMetricsMultiChannelWorker": "metrics", + "ConcatAcdcDfsWorker": "metrics", + "ConcatSpotmaxDfsWorker": "metrics", + "CountObjectsInSegm": "metrics", + "GenerateMotherBudTotalTableWorker": "metrics", + "CcaIntegrityCheckerWorker": "metrics", + "reapplyDataPrepWorker": "data_prep", + "DataPrepSaveBkgrDataWorker": "data_prep", + "DataPrepCropWorker": "data_prep", + "RestructMultiPosWorker": "data_prep", + "RestructMultiTimepointsWorker": "data_prep", + "ImagesToPositionsWorker": "data_prep", + "CustomPreprocessWorkerGUI": "data_prep", + "CombineChannelsWorkerGUI": "data_prep", + "CustomPreprocessWorkerUtil": "data_prep", + "CombineChannelsWorkerUtil": "data_prep", + "SaveProcessedDataWorker": "data_prep", + "SaveCombinedChannelsWorker": "data_prep", + "FucciPreprocessWorker": "data_prep", + "AlignDataWorker": "alignment", + "AlignWorker": "alignment", + "FromImajeJroiToSegmNpzWorker": "util", + "ToImajeJroiWorker": "util", + "ToObjCoordsWorker": "util", + "Stack2DsegmTo3Dsegm": "util", + "ResizeUtilWorker": "util", + "FilterObjsFromCoordsTable": "util", + "ApplyImageFilterWorker": "util", + "ScreenRecorderWorker": "util", + } + return mapping.get(name, "util") + + +def assign_widget(name: str) -> str: + if "Toolbar" in name or "ToolButton" in name or name in { + "ToolBarSeparator", + "ToolBar", + "rightClickToolButton", + }: + return "toolbars" + canvas_markers = ( + "pg.", + "Plot", + "ImageItem", + "ROI", + "Histogram", + "Gradient", + "Scatter", + "Contour", + "ScaleBar", + "ImShow", + "Ghost", + "Ruler", + "RectItem", + "ColorButton", + "LabelItem", + "MainPlot", + "MouseCursor", + "ScrollBar", + "sliderWithSpinBox", + ) + if any(m in name for m in canvas_markers) or name in { + "ContourItem", + "BaseScatterPlotItem", + "CustomAnnotationScatterPlotItem", + "ScatterPlotItem", + "myLabelItem", + "PolyLineROI", + "ZoomROI", + "DelROI", + "PlotCurveItem", + "BaseGradientEditorItemImage", + "BaseGradientEditorItemLabels", + "baseHistogramLUTitem", + "myHistogramLUTitem", + "overlayLabelsGradientWidget", + "labelsGradientWidget", + "BaseImageItem", + "BaseLabelsImageItem", + "OverlayImageItem", + "ParentImageItem", + "ChildImageItem", + "labImageItem", + "labelledQScrollbar", + "navigateQScrollBar", + "linkedQScrollbar", + "myColorButton", + "ScrollBarWithNumericControl", + "PointsScatterPlotItem", + "LabelRoiCircularItem", + }: + return "canvas" + return "controls" + + +def assign_dialog(name: str) -> str: + if name in {"QBaseDialog", "ArgWidget"}: + return "_base" + if name in {"addCustomModelMessages", "addCustomPromptModelMessages"}: + return "models" + rules: list[tuple[str, str]] = [ + ("tracking", r"Tracker|Track|Cca|cca|editCca|lineage|ApplyTrack|MotherBud|SymDiv|manualSeparate|FindID|EditID|NumericEntry|swap|merge"), + ("metadata", r"Metadata|metadata|XML|QDialogMetadata|filenameDialog|AppendText|EntriesWidget|ColumnNames|CropZ|CropTrange|CropT|Zslice|MultiTimePoint|TreeSelector|TreesSelector|MultiList|selectPositions|OrderableList|SelectFolders|OverlayLabels|AutoSaveInterval"), + ("preprocess", r"PreProcess|CombineChannels|Fucci|ResizeUtil|InitFiji|ImageJRois|randomWalker|PostProcess|Threshold|Crop|Formula|DataPrepSubCrops|stopFrame|startStop|FutureFrames|FunctionParams|TestSegm|wandTolerance"), + ("measurements", r"Metric|Measurement|combineMetrics|SetMeasurements|ComputeMetrics|GenerateMother|ApplyTrackTable|SelectFeatures|CombineFeatures"), + ("export", r"Export|Video|Timestamp|ScaleBar|ViewText|pdDataFrame|ViewCcaTable|ObjectCount|Screen|Logo|ShortcutEditor"), + ("models", r"Model|downloadModel|SelectPromptable|SelectModel|InstallPyTorch|Bayesian|DeltaTracker|CellACDCTracker|Promptable|QDialogModelParams|QInput|ChangeUserProfile|SelectAcdcDf|Restore"), + ] + for module, pat in rules: + if re.search(pat, name): + return module + return "general" + + +def fix_preamble_imports(preamble: str, siblings: set[str]) -> str: + """Rewrite cellacdc-root imports for package submodules.""" + out: list[str] = [] + for line in preamble.splitlines(keepends=True): + newline = "\n" if line.endswith("\n") else "" + stripped = line.rstrip("\n") + m = re.match(r"^(\s*)from \. import (.+)$", stripped) + if m: + indent, rest = m.groups() + out.append(f"{indent}from .. import {rest}{newline}") + continue + m = re.match(r"^(\s*)from \.(\S+) import (.+)$", stripped) + if m: + indent, module, rest = m.groups() + top = module.split(".", 1)[0] + if top in siblings: + out.append(line) + else: + out.append(f"{indent}from ..{module} import {rest}{newline}") + continue + out.append(line) + return "".join(out) + + +def inject_cross_imports(pkg_dir: Path) -> None: + """Wire sibling symbols: top imports for bases, trailing imports for calls.""" + assign: dict[str, str] = {} + for p in pkg_dir.glob("*.py"): + if p.name == "__init__.py": + continue + for node in ast.parse(p.read_text()).body: + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + assign[node.name] = p.stem + + for p in sorted(pkg_dir.glob("*.py")): + if p.name == "__init__.py": + continue + mod = p.stem + source = p.read_text() + tree = ast.parse(source) + + top_needed: dict[str, set[str]] = defaultdict(set) + for node in tree.body: + if not isinstance(node, ast.ClassDef): + continue + for base in node.bases: + for sub in ast.walk(base): + if isinstance(sub, ast.Name) and sub.id in assign and assign[sub.id] != mod: + top_needed[assign[sub.id]].add(sub.id) + + trailing_needed: dict[str, set[str]] = defaultdict(set) + for node in ast.walk(tree): + if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load): + if node.id not in assign or assign[node.id] == mod: + continue + src = assign[node.id] + if node.id in top_needed.get(src, set()): + continue + trailing_needed[src].add(node.id) + + if not top_needed and not trailing_needed: + continue + + def render_imports(needed: dict[str, set[str]], prefix: str) -> str: + lines: list[str] = [] + for src_mod in sorted(needed): + names = sorted(needed[src_mod]) + lines.append(f"{prefix}from .{src_mod} import (") + for name in names: + lines.append(f"{prefix} {name},") + lines.append(f"{prefix})") + return "\n".join(lines) + ("\n\n" if lines else "") + + lines = source.splitlines(keepends=True) + first_def = next( + node.lineno + for node in tree.body + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) + ) + top_block = render_imports(top_needed, "") + trailing_block = render_imports(trailing_needed, "") + if trailing_block: + trailing_block = "\n# Sibling imports (deferred to avoid import cycles)\n" + trailing_block + + new_source = ( + "".join(lines[: first_def - 1]) + + top_block + + "".join(lines[first_def - 1 :]) + + trailing_block + ) + p.write_text(new_source) + + +def split_package( + src_file: Path, + pkg_dir: Path, + assign_fn, + module_doc: str, + *, + delete_src: bool = True, + shim_file: Path | None = None, + shim_import_from: str | None = None, +) -> dict[str, list[str]]: + source = src_file.read_text() + preamble, nodes = extract_nodes(source) + groups: dict[str, set[str]] = defaultdict(set) + for name, _kind, _s, _e in nodes: + groups[assign_fn(name)].add(name) + + if pkg_dir.exists(): + shutil.rmtree(pkg_dir) + pkg_dir.mkdir(parents=True) + + exported: dict[str, list[str]] = {} + sibling_stems = set(groups) + for module, names in sorted(groups.items()): + body = slice_nodes(source, nodes, names) + if not body.strip(): + continue + out = pkg_dir / f"{module}.py" + write_module( + out, + f"{module_doc}: {module}.", + preamble, + body, + siblings=sibling_stems, + ) + exported[module] = sorted(names) + + init_lines = [ + f'"""{module_doc}."""', + "", + ] + all_names: list[str] = [] + for module in sorted(exported): + names = exported[module] + init_lines.append(f"from .{module} import (") + for name in names: + init_lines.append(f" {name},") + all_names.append(name) + init_lines.append(")") + init_lines.append("") + init_lines.append("__all__ = [") + for name in all_names: + init_lines.append(f' "{name}",') + init_lines.append("]") + (pkg_dir / "__init__.py").write_text("\n".join(init_lines) + "\n") + + inject_cross_imports(pkg_dir) + + if delete_src: + src_file.unlink() + + if shim_file is not None: + imp = shim_import_from or pkg_dir.name + shim = textwrap.dedent( + f'''\ + """Compatibility shim; implementation lives in {imp}/.""" + + from .{imp} import * # noqa: F403 + ''' + ) + shim_file.write_text(shim) + + return exported + + +def split_widgets(src_file: Path, pkg_dir: Path) -> None: + """widgets.py becomes a package that also re-exports components/.""" + source = src_file.read_text() + # Keep import block through component re-exports as package preamble. + marker = "\n\n\n\nclass ContourItem" + idx = source.find(marker) + if idx == -1: + raise RuntimeError("Could not locate widgets split marker") + header = source[: idx + 2] + body_source = source[idx + 2 :] + _empty, nodes = extract_nodes(body_source) + + groups: dict[str, set[str]] = defaultdict(set) + for name, _kind, _s, _e in nodes: + groups[assign_widget(name)].add(name) + + if pkg_dir.exists(): + shutil.rmtree(pkg_dir) + pkg_dir.mkdir(parents=True) + + exported: dict[str, list[str]] = {} + sibling_stems = set(groups) + for module, names in sorted(groups.items()): + chunk = slice_nodes(body_source, nodes, names) + if not chunk.strip(): + continue + write_module( + pkg_dir / f"{module}.py", + f"GUI widgets: {module}.", + header, + chunk, + siblings=sibling_stems, + ) + exported[module] = sorted(names) + + init_parts = [ + '"""GUI widgets package (controls, canvas, toolbars) + components re-exports."""', + "", + "from ..components.palette import * # noqa: F403", + "from ..components.progress import * # noqa: F403", + "from ..components.buttons import * # noqa: F403", + "from ..components.layout import * # noqa: F403", + "from ..components.inputs_basic import * # noqa: F403", + "from ..components.path_controls import * # noqa: F403", + "from ..components.lists import * # noqa: F403", + "from ..components.base import QBaseWindow, QBaseDialog # noqa: F401", + "", + ] + all_names: list[str] = [] + for module in sorted(exported): + names = exported[module] + init_parts.append(f"from .{module} import (") + for name in names: + init_parts.append(f" {name},") + all_names.append(name) + init_parts.append(")") + init_parts.append("") + + init_parts.append("__all__ = [") + for name in all_names: + init_parts.append(f' "{name}",') + init_parts.append("]") + (pkg_dir / "__init__.py").write_text("\n".join(init_parts) + "\n") + inject_cross_imports(pkg_dir) + src_file.unlink() + + +def main() -> None: + # utils/ is already split from the former myutils.py monolith. + split_package( + CELLACDC / "workers.py", + CELLACDC / "workers", + assign_worker, + "Background Qt workers", + delete_src=True, + ) + split_widgets(CELLACDC / "widgets.py", CELLACDC / "widgets") + split_package( + CELLACDC / "apps.py", + CELLACDC / "dialogs", + assign_dialog, + "Cell-ACDC dialog windows", + delete_src=False, + shim_file=CELLACDC / "apps.py", + shim_import_from="dialogs", + ) + print("Split complete.") + + +if __name__ == "__main__": + main() diff --git a/scripts/split_widgets_subpackages.py b/scripts/split_widgets_subpackages.py new file mode 100644 index 000000000..049d4313e --- /dev/null +++ b/scripts/split_widgets_subpackages.py @@ -0,0 +1,446 @@ +#!/usr/bin/env python3 +"""Split widgets/canvas.py, controls.py, and toolbars.py into subpackages.""" + +from __future__ import annotations + +import ast +import re +import shutil +from collections import defaultdict +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +WIDGETS = ROOT / "cellacdc" / "widgets" + +CANVAS_MODULES: dict[str, set[str]] = { + "histogram": { + "BaseGradientEditorItemImage", + "BaseGradientEditorItemLabels", + "baseHistogramLUTitem", + "myHistogramLUTitem", + "overlayLabelsGradientWidget", + "labelsGradientWidget", + "myColorButton", + }, + "rois": {"ROI", "ZoomROI", "DelROI", "PolyLineROI"}, + "plot_items": { + "ContourItem", + "BaseScatterPlotItem", + "CustomAnnotationScatterPlotItem", + "ScatterPlotItem", + "myLabelItem", + "LabelRoiCircularItem", + "PlotCurveItem", + "MainPlotItem", + "GhostContourItem", + "RulerPlotItem", + "PointsScatterPlotItem", + "RectItem", + "LabelItem", + "ScaleBar", + }, + "images": { + "BaseImageItem", + "BaseLabelsImageItem", + "OverlayImageItem", + "ParentImageItem", + "ChildImageItem", + "labImageItem", + "GhostMaskItem", + "_ImShowImageItem", + }, + "imshow": {"ImShow", "ImShowPlotItem"}, + "scrollbars": { + "MouseCursor", + "labelledQScrollbar", + "navigateQScrollBar", + "linkedQScrollbar", + "sliderWithSpinBox", + "ScrollBarWithNumericControl", + }, +} + +CONTROLS_MODULES: dict[str, set[str]] = { + "dialogs": { + "QDialogListbox", + "myMessageBox", + "view_visualcpp_screenshot", + "installJavaDialog", + "selectTrackerGUI", + "warnVisualCppRequired", + }, + "inputs": { + "ExpandableListBox", + "QClickableLabel", + "QCenteredComboBox", + "AlphaNumericComboBox", + "mySpinBox", + "ShortcutLineEdit", + "CenteredDoubleSpinbox", + "readOnlyDoubleSpinbox", + "readOnlySpinbox", + "DoubleSpinBox", + "SpinBox", + "ReadOnlyLineEdit", + "FloatLineEdit", + "IntLineEdit", + "LineEdit", + "SearchLineEdit", + "VectorLineEdit", + "OddSpinBox", + "KeySequenceFromText", + "ComboBox", + "WhitelistLineEdit", + "highlightableQWidgetAction", + }, + "metrics": { + "_metricsQGBox", + "channelMetricsQGBox", + "PixelSizeGroupbox", + "objPropsQGBox", + "objIntesityMeasurQGBox", + "SetMeasurementsGroupBox", + }, + "forms": { + "selectStartStopFrames", + "formWidget", + "CheckboxesGroupBox", + "guiTabControl", + "CopiableCommandWidget", + "LabelsWidget", + "SamInputPointsWidget", + "FontSizeWidget", + "RangeSelector", + "PreProcessingSelector", + "RescaleImageJroisGroupbox", + "TimeWidget", + "YeazV2SelectModelNameCombobox", + "AutoSaveIntervalWidget", + "CheckableWidget", + "PostProcessSegmSlider", + "PostProcessSegmSpinbox", + }, + "panels": { + "statusBarPermanentLabel", + "listWidget", + "OrderableListWidget", + "KeptObjectIDsList", + "Toggle", + "ToggleTerminalButton", + "expandCollapseButton", + "ToggleVisibilityButton", + "ToggleVisibilityCheckBox", + "FeatureSelectorButton", + "CheckableSpinBoxWidgets", + "Label", + "LatexLabel", + "SwitchPlaneCombobox", + "TimestampItem", + "CheckableAction", + }, +} + +TOOLBARS_MODULES: dict[str, set[str]] = { + "_base": { + "ToolBarSeparator", + "ToolBar", + "rightClickToolButton", + "ToolButtonCustomColor", + "GradientToolButton", + "ToolButtonTextIcon", + "customAnnotToolButton", + "PointsLayerToolButton", + "OverlayChannelToolButton", + "SavePointsLayerButton", + "ManualTrackingToolBar", + "ManualBackgroundToolBar", + }, + "feature": { + "CopyLostObjectToolbar", + "DrawClearRegionToolbar", + "WhitelistIDsToolbar", + "MagicPromptsToolbar", + "PointsLayersToolbar", + "PromptableModelPointsLayerToolbar", + "OverlayToolbar", + "HighlightedIDToolbar", + "WandControlsToolbar", + }, +} + + +def extract_nodes(source: str) -> tuple[str, list[tuple[str, str, int, int]]]: + lines = source.splitlines(keepends=True) + tree = ast.parse(source) + nodes: list[tuple[str, str, int, int]] = [] + first_start = None + for node in tree.body: + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + end = getattr(node, "end_lineno", node.lineno) + kind = "class" if isinstance(node, ast.ClassDef) else "function" + nodes.append((node.name, kind, node.lineno, end)) + if first_start is None: + first_start = node.lineno + preamble_end = first_start - 1 if first_start else len(lines) + return "".join(lines[:preamble_end]), nodes + + +def slice_nodes(source: str, nodes: list[tuple[str, str, int, int]], names: set[str]) -> str: + lines = source.splitlines(keepends=True) + chunks: list[str] = [] + for name, _kind, start, end in nodes: + if name in names: + chunks.append("".join(lines[start - 1 : end])) + return "\n\n".join(chunks) + + +def clean_preamble(preamble: str) -> str: + """Drop sibling-package imports; they are regenerated after the split.""" + out: list[str] = [] + skip = False + for line in preamble.splitlines(keepends=True): + stripped = line.rstrip("\n") + if re.match(r"^from \.(canvas|controls|toolbars) import ", stripped): + skip = stripped.rstrip().endswith("(") + continue + if skip: + if ")" in stripped: + skip = False + continue + if stripped.startswith("# Sibling imports"): + break + out.append(line) + return "".join(out) + + +def deepen_imports(preamble: str) -> str: + """widgets//.py needs one more parent level than widgets/.py.""" + out: list[str] = [] + for line in preamble.splitlines(keepends=True): + newline = "\n" if line.endswith("\n") else "" + stripped = line.rstrip("\n") + m = re.match(r"^(\s*)from \.\. import (.+)$", stripped) + if m: + indent, rest = m.groups() + out.append(f"{indent}from ... import {rest}{newline}") + continue + m = re.match(r"^(\s*)from \.\.(\S+) import (.+)$", stripped) + if m: + indent, module, rest = m.groups() + out.append(f"{indent}from ...{module} import {rest}{newline}") + continue + out.append(line) + return "".join(out) + + +def write_module(path: Path, doc: str, preamble: str, body: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + content = f'"""{doc}"""\n\n{preamble.rstrip()}\n\n{body.rstrip()}\n' + path.write_text(content) + + +def split_area( + src: Path, + dest: Path, + modules: dict[str, set[str]], + doc: str, +) -> dict[str, list[str]]: + source = src.read_text() + preamble, nodes = extract_nodes(source) + preamble = deepen_imports(clean_preamble(preamble)) + + if dest.exists(): + shutil.rmtree(dest) + dest.mkdir(parents=True) + + exported: dict[str, list[str]] = {} + for module, names in sorted(modules.items()): + body = slice_nodes(source, nodes, names) + if not body.strip(): + raise RuntimeError(f"No body extracted for {dest.name}/{module}.py") + write_module(dest / f"{module}.py", f"{doc}: {module}.", preamble, body) + exported[module] = sorted(names) + + init_lines = [f'"""{doc}."""', ""] + all_names: list[str] = [] + for module in sorted(exported): + names = exported[module] + init_lines.append(f"from .{module} import (") + for name in names: + init_lines.append(f" {name},") + all_names.append(name) + init_lines.append(")") + init_lines.append("") + init_lines.append("__all__ = [") + for name in all_names: + init_lines.append(f' "{name}",') + init_lines.append("]") + (dest / "__init__.py").write_text("\n".join(init_lines) + "\n") + return exported + + +def inject_widget_imports() -> None: + assign: dict[str, tuple[str, str]] = {} + for area in ("canvas", "controls", "toolbars"): + for path in (WIDGETS / area).glob("*.py"): + if path.name == "__init__.py": + continue + for node in ast.parse(path.read_text()).body: + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + assign[node.name] = (area, path.stem) + + for area in ("canvas", "controls", "toolbars"): + for path in sorted((WIDGETS / area).glob("*.py")): + if path.name == "__init__.py": + continue + mod = path.stem + source = path.read_text() + tree = ast.parse(source) + + top_needed: dict[tuple[str, str], set[str]] = defaultdict(set) + for node in tree.body: + if not isinstance(node, ast.ClassDef): + continue + for base in node.bases: + for sub in ast.walk(base): + if ( + isinstance(sub, ast.Name) + and sub.id in assign + and assign[sub.id] != (area, mod) + ): + top_needed[assign[sub.id]].add(sub.id) + + trailing_needed: dict[tuple[str, str], set[str]] = defaultdict(set) + for node in ast.walk(tree): + if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load): + if node.id not in assign or assign[node.id] == (area, mod): + continue + loc = assign[node.id] + if node.id in top_needed.get(loc, set()): + continue + trailing_needed[loc].add(node.id) + + if not top_needed and not trailing_needed: + continue + + def render(needed: dict[tuple[str, str], set[str]], prefix: str) -> str: + lines: list[str] = [] + for sub_area, sub_mod in sorted(needed): + names = sorted(needed[(sub_area, sub_mod)]) + if sub_area == area: + import_from = f".{sub_mod}" + else: + import_from = f"..{sub_area}.{sub_mod}" + lines.append(f"{prefix}from {import_from} import (") + for name in names: + lines.append(f"{prefix} {name},") + lines.append(f"{prefix})") + return "\n".join(lines) + ("\n\n" if lines else "") + + lines = source.splitlines(keepends=True) + first_def = next( + n.lineno + for n in tree.body + if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) + ) + top_block = render(top_needed, "") + trailing_block = render(trailing_needed, "") + if trailing_block: + trailing_block = ( + "\n# Cross-module imports (deferred to avoid import cycles)\n" + + trailing_block + ) + updated = ( + "".join(lines[: first_def - 1]) + + top_block + + "".join(lines[first_def - 1 :]) + + trailing_block + ) + path.write_text(updated) + + +def rebuild_widgets_init() -> None: + """Keep widgets/__init__.py as compatibility barrel.""" + header = '''"""GUI widgets package (canvas, controls, toolbars) + components re-exports.""" + +from ..components.palette import * # noqa: F403 +from ..components.progress import * # noqa: F403 +from ..components.buttons import * # noqa: F403 +from ..components.layout import * # noqa: F403 +from ..components.inputs_basic import * # noqa: F403 +from ..components.path_controls import * # noqa: F403 +from ..components.lists import * # noqa: F403 +from ..components.base import QBaseWindow, QBaseDialog # noqa: F401 + +''' + all_names: list[str] = [] + import_blocks: list[str] = [] + for area in ("canvas", "controls", "toolbars"): + init_path = WIDGETS / area / "__init__.py" + tree = ast.parse(init_path.read_text()) + names = [ + node.id + for node in tree.body + if isinstance(node, ast.ImportFrom) and node.module == area + for alias in node.names + ] + # parse from __all__ + for node in tree.body: + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == "__all__": + if isinstance(node.value, ast.List): + names = [ + elt.value + for elt in node.value.elts + if isinstance(elt, ast.Constant) + ] + import_blocks.append(f"from .{area} import (") + for name in names: + import_blocks.append(f" {name},") + all_names.append(name) + import_blocks.append(")") + import_blocks.append("") + + body = header + "\n".join(import_blocks) + "\n__all__ = [\n" + for name in all_names: + body += f' "{name}",\n' + body += "]\n" + (WIDGETS / "__init__.py").write_text(body) + + +def main() -> None: + # Pull toolbar classes that were left in controls.py into toolbars/_base. + controls_src = (WIDGETS / "controls.py").read_text() + toolbars_src = (WIDGETS / "toolbars.py").read_text() + _, controls_nodes = extract_nodes(controls_src) + _, toolbars_nodes = extract_nodes(toolbars_src) + controls_names = {n for n, _, _, _ in controls_nodes} + for name in ("ManualTrackingToolBar", "ManualBackgroundToolBar", "SavePointsLayerButton"): + if name in controls_names and name not in TOOLBARS_MODULES["_base"]: + TOOLBARS_MODULES["_base"].add(name) + + split_area(WIDGETS / "canvas.py", WIDGETS / "canvas", CANVAS_MODULES, "Canvas widgets") + split_area( + WIDGETS / "controls.py", + WIDGETS / "controls", + CONTROLS_MODULES, + "Composite controls", + ) + split_area( + WIDGETS / "toolbars.py", + WIDGETS / "toolbars", + TOOLBARS_MODULES, + "Toolbars", + ) + + for fname in ("canvas.py", "controls.py", "toolbars.py"): + (WIDGETS / fname).unlink() + + inject_widget_imports() + rebuild_widgets_init() + print("widgets/ subpackage split complete.") + + +if __name__ == "__main__": + main() diff --git a/tests/prompt_segm/test_sam.py b/tests/prompt_segm/test_sam.py index 70b07c9d5..55093dc6f 100644 --- a/tests/prompt_segm/test_sam.py +++ b/tests/prompt_segm/test_sam.py @@ -4,7 +4,7 @@ import pytest -from cellacdc import myutils +from cellacdc import utils from tests.utils import ( ensure_sam, get_test_dataset, @@ -20,7 +20,7 @@ class TestPromptableSAM: @pytest.fixture(scope="class", autouse=True) def download_models(self): """Download SAM models if not present.""" - myutils.download_model("segment_anything") + utils.download_model("segment_anything") @pytest.fixture def test_data(self): @@ -46,7 +46,7 @@ def test_promptable_segmentation_with_ground_truth_centroids(self, test_data): centroids = get_ground_truth_centroids(gt_mask) assert len(centroids) > 0, "No objects found in ground truth" - acdcPromptSegment = myutils.import_promptable_segment_module("segment_anything") + acdcPromptSegment = utils.import_promptable_segment_module("segment_anything") model = acdcPromptSegment.Model(model_type="Large", gpu=True) # Add prompts for each ground truth centroid @@ -70,7 +70,9 @@ def test_promptable_segmentation_with_ground_truth_centroids(self, test_data): plots_dir = Path(__file__).parent.parent / "_plots" / "prompt_segm" / "sam" save_segmentation_overlay( - labels, frame, frame_index, + labels, + frame, + frame_index, plots_dir / f"test_promptable_sam_frame_{frame_index:04d}.png", prompt_points=centroids, ) diff --git a/tests/prompt_segm/test_sam2.py b/tests/prompt_segm/test_sam2.py index 97ba4dc9e..bf71db5d4 100644 --- a/tests/prompt_segm/test_sam2.py +++ b/tests/prompt_segm/test_sam2.py @@ -4,7 +4,7 @@ import pytest -from cellacdc import myutils +from cellacdc import utils from tests.utils import ( ensure_sam2, get_test_dataset, @@ -20,7 +20,7 @@ class TestPromptableSAM2: @pytest.fixture(scope="class", autouse=True) def download_models(self): """Download SAM2 models if not present.""" - myutils.download_model("sam2") + utils.download_model("sam2") @pytest.fixture def test_data(self): @@ -46,7 +46,7 @@ def test_promptable_segmentation_with_ground_truth_centroids(self, test_data): centroids = get_ground_truth_centroids(gt_mask) assert len(centroids) > 0, "No objects found in ground truth" - acdcPromptSegment = myutils.import_promptable_segment_module("sam2") + acdcPromptSegment = utils.import_promptable_segment_module("sam2") model = acdcPromptSegment.Model(model_type="Large", gpu=True) # Add prompts for each ground truth centroid @@ -70,7 +70,9 @@ def test_promptable_segmentation_with_ground_truth_centroids(self, test_data): plots_dir = Path(__file__).parent.parent / "_plots" / "prompt_segm" / "sam2" save_segmentation_overlay( - labels, frame, frame_index, + labels, + frame, + frame_index, plots_dir / f"test_promptable_sam2_frame_{frame_index:04d}.png", prompt_points=centroids, ) diff --git a/tests/segm/test_cellsam.py b/tests/segm/test_cellsam.py index 17ce7da96..b28b67d6c 100644 --- a/tests/segm/test_cellsam.py +++ b/tests/segm/test_cellsam.py @@ -4,7 +4,7 @@ import pytest -from cellacdc import myutils +from cellacdc import utils from tests.utils import ( ensure_cellsam, get_test_posdata, @@ -34,7 +34,7 @@ def test_automatic_segmentation_sampled_frames(self, test_frames, posData): """Test CellSAM automatic segmentation on sampled frames (every 20th).""" frames, frame_indices = test_frames - acdcSegment = myutils.import_segment_module("cellsam") + acdcSegment = utils.import_segment_module("cellsam") model = acdcSegment.Model( model_type="General", @@ -59,6 +59,8 @@ def test_automatic_segmentation_sampled_frames(self, test_frames, posData): validate_labels(labels, frame.shape) print_segmentation_results(labels, frame, frame_i) save_segmentation_overlay( - labels, frame, frame_i, + labels, + frame, + frame_i, plots_dir / f"test_cellsam_segmentation_frame_{frame_i:04d}.png", ) diff --git a/tests/segm/test_sam.py b/tests/segm/test_sam.py index 40eb1507d..c3ca4ea1e 100644 --- a/tests/segm/test_sam.py +++ b/tests/segm/test_sam.py @@ -4,7 +4,7 @@ import pytest -from cellacdc import myutils +from cellacdc import utils from tests.utils import ( ensure_sam, get_test_posdata, @@ -23,7 +23,7 @@ class TestSAMAutomaticSegmentation: @pytest.fixture(scope="class", autouse=True) def download_models(self): """Download SAM models if not present.""" - myutils.download_model("segment_anything") + utils.download_model("segment_anything") @pytest.fixture def posData(self): @@ -39,7 +39,7 @@ def test_automatic_segmentation_sampled_frames(self, test_frames, posData): """Test SAM automatic segmentation on sampled frames.""" frames, frame_indices = test_frames - acdcSegment = myutils.import_segment_module("segment_anything") + acdcSegment = utils.import_segment_module("segment_anything") model = acdcSegment.Model( model_type="Small", @@ -67,6 +67,8 @@ def test_automatic_segmentation_sampled_frames(self, test_frames, posData): validate_labels(labels, frame.shape) print_segmentation_results(labels, frame, frame_i) save_segmentation_overlay( - labels, frame, frame_i, + labels, + frame, + frame_i, plots_dir / f"test_sam_segmentation_frame_{frame_i:04d}.png", ) diff --git a/tests/segm/test_sam2.py b/tests/segm/test_sam2.py index 78ecfc3e2..280c2e45f 100644 --- a/tests/segm/test_sam2.py +++ b/tests/segm/test_sam2.py @@ -4,7 +4,7 @@ import pytest -from cellacdc import myutils +from cellacdc import utils from tests.utils import ( ensure_sam2, get_test_posdata, @@ -23,7 +23,7 @@ class TestSAM2AutomaticSegmentation: @pytest.fixture(scope="class", autouse=True) def download_models(self): """Download SAM2 models if not present.""" - myutils.download_model("sam2") + utils.download_model("sam2") @pytest.fixture def posData(self): @@ -39,7 +39,7 @@ def test_automatic_segmentation_sampled_frames(self, test_frames, posData): """Test SAM2 automatic segmentation on sampled frames.""" frames, frame_indices = test_frames - acdcSegment = myutils.import_segment_module("sam2") + acdcSegment = utils.import_segment_module("sam2") model = acdcSegment.Model( model_type="Tiny", @@ -67,6 +67,8 @@ def test_automatic_segmentation_sampled_frames(self, test_frames, posData): validate_labels(labels, frame.shape) print_segmentation_results(labels, frame, frame_i) save_segmentation_overlay( - labels, frame, frame_i, + labels, + frame, + frame_i, plots_dir / f"test_sam2_segmentation_frame_{frame_i:04d}.png", ) diff --git a/tests/test_components_imports.py b/tests/test_components_imports.py new file mode 100644 index 000000000..4cd85f766 --- /dev/null +++ b/tests/test_components_imports.py @@ -0,0 +1,27 @@ +"""Smoke tests for component module imports.""" + +import importlib +import unittest + + +COMPONENT_MODULES = [ + "cellacdc.components.palette", + "cellacdc.components.base", + "cellacdc.components.inputs_basic", +] + + +class TestComponentImports(unittest.TestCase): + def test_leaf_component_modules_import(self): + for module_name in COMPONENT_MODULES: + with self.subTest(module=module_name): + importlib.import_module(module_name) + + def test_widgets_module_compiles(self): + import py_compile + + py_compile.compile("cellacdc/widgets/__init__.py", doraise=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_data_source.py b/tests/test_data_source.py new file mode 100644 index 000000000..6dfc23f03 --- /dev/null +++ b/tests/test_data_source.py @@ -0,0 +1,160 @@ +"""Tests for in-memory data loading and array-based viewer API.""" + +from __future__ import annotations + +import importlib +import os +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from cellacdc.data_source import ( + ArrayDataSource, + ExperimentData, + normalize_volume, + pos_data_from_arrays, + pos_data_from_kwargs, +) + + +def test_normalize_volume_shapes(): + image = np.zeros((4, 32, 32), dtype=np.uint8) + arr, size_t, size_z = normalize_volume(image, axes="tyx") + assert arr.shape == (4, 32, 32) + assert size_t == 4 + assert size_z == 1 + + stack = np.zeros((5, 8, 16, 16), dtype=np.uint8) + arr, size_t, size_z = normalize_volume(stack, axes="tzyx") + assert arr.shape == (5, 8, 16, 16) + assert size_t == 5 + assert size_z == 8 + + +def test_experiment_data_from_arrays(tmp_path): + class DummyPosData: + def __init__(self, img_path, channel_name, **kwargs): + self.imgPath = img_path + self.user_ch_name = channel_name + self.images_path = str(tmp_path / "Images") + self.exp_path = str(tmp_path) + + def buildPaths(self): + self.metadata_csv_path = str(tmp_path / "metadata.csv") + self.segm_npz_path = str(tmp_path / "segm.npz") + + def loadOtherFiles(self, **kwargs): + pass + + def setBlankSegmData(self, size_t, size_z, size_y, size_x): + self.segm_data = np.zeros((size_y, size_x), dtype=np.uint32) + + def extractMetadata(self): + pass + + image = np.arange(32 * 32, dtype=np.uint16).reshape(32, 32) + data = ExperimentData.from_arrays( + image, + name="test", + channel_name="cells", + axes="yx", + workspace=tmp_path, + _load_data_cls=DummyPosData, + ) + + assert data.is_materialized + assert data.source == "memory" + pos = data.positions[0] + assert pos.SizeT == 1 + assert pos.img_data.shape == (1, 32, 32) + + +def test_experiment_data_from_path(tmp_path): + exp_path = tmp_path / "my_experiment" + exp_path.mkdir() + data = ExperimentData.from_path(exp_path) + + assert data.source == "path" + assert data.path == str(exp_path) + assert not data.is_materialized + + +def test_pos_data_from_arrays_without_labels(tmp_path): + class DummyPosData: + def __init__(self, img_path, channel_name, **kwargs): + self.imgPath = img_path + self.user_ch_name = channel_name + self.images_path = str(tmp_path / "Images") + self.exp_path = str(tmp_path) + + def buildPaths(self): + self.metadata_csv_path = str(tmp_path / "metadata.csv") + self.segm_npz_path = str(tmp_path / "segm.npz") + + def loadOtherFiles(self, **kwargs): + pass + + def setBlankSegmData(self, size_t, size_z, size_y, size_x): + self.segm_data = np.zeros((size_y, size_x), dtype=np.uint32) + + def extractMetadata(self): + pass + + image = np.arange(32 * 32, dtype=np.uint16).reshape(32, 32) + pos = pos_data_from_kwargs( + image, + name="test", + channel_name="cells", + axes="yx", + workspace=tmp_path, + _load_data_cls=DummyPosData, + ) + + assert pos.SizeT == 1 + assert pos.segmFound is False + + +def test_viewer_accepts_experiment_data(): + viewer_mod = importlib.import_module("cellacdc.viewer") + viewer_mod = importlib.reload(viewer_mod) + mock_win = MagicMock() + data = MagicMock() + data.is_materialized = True + + with ( + patch("cellacdc._event_loop.get_qapp", return_value=MagicMock()), + patch.object(viewer_mod, "_read_version", return_value="test"), + patch.object(viewer_mod, "_create_gui_window", return_value=mock_win), + patch.object(viewer_mod, "_check_gui_installed"), + ): + viewer = viewer_mod.Viewer(data, show=False) + + data.load_into.assert_called_once_with(mock_win) + assert viewer.data is data + + +def test_imshow_returns_viewer_and_experiment_data(tmp_path): + viewer_mod = importlib.import_module("cellacdc.viewer") + viewer_mod = importlib.reload(viewer_mod) + data = ExperimentData.from_path(tmp_path) + mock_viewer = MagicMock() + mock_viewer.data = data + + with patch.object(viewer_mod, "Viewer", return_value=mock_viewer) as mock_viewer_cls: + viewer, returned = viewer_mod.imshow(data) + + mock_viewer_cls.assert_called_once_with( + data, + show=True, + mode="Segmentation and Tracking", + ) + assert viewer is mock_viewer + assert returned is data + + +def test_lazy_exports_include_experiment_data(): + import cellacdc + + assert cellacdc.ExperimentData.__name__ == "ExperimentData" + assert cellacdc.imshow.__name__ == "imshow" diff --git a/tests/test_import_cellacdc.py b/tests/test_import_cellacdc.py index dc77a8623..5a87736de 100755 --- a/tests/test_import_cellacdc.py +++ b/tests/test_import_cellacdc.py @@ -7,6 +7,7 @@ from cellacdc import segm from cellacdc import dataPrep + def test_placeholder(): # Add test test_placeholder because we are only testing import pass diff --git a/tests/test_split_packages.py b/tests/test_split_packages.py new file mode 100644 index 000000000..3b46f3084 --- /dev/null +++ b/tests/test_split_packages.py @@ -0,0 +1,83 @@ +"""Smoke tests for split god-file packages.""" + +import py_compile +import unittest +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] + +PACKAGES = { + "cellacdc.tools": [ + "base", + "concat", + "align", + ], + "cellacdc.utils": [ + "logging", + "paths", + "install", + "dataframe", + ], + "cellacdc.workers": [ + "_base", + "segm", + "tracking", + "io", + ], + "cellacdc.widgets": [ + "canvas.histogram", + "canvas.imshow", + "controls.dialogs", + "controls.inputs", + "toolbars._base", + ], + "cellacdc.dialogs": [ + "_base", + "general", + "tracking", + "measurements", + ], +} + +SHIMS = [ + "cellacdc/apps.py", +] + + +class TestSplitPackages(unittest.TestCase): + def _module_path(self, module_name: str, leaf: str) -> Path: + base = ROOT / module_name.replace(".", "/") + return base / f"{leaf.replace('.', '/')}.py" + + def test_leaf_modules_compile(self): + for module_name in PACKAGES: + for leaf in PACKAGES[module_name]: + path = self._module_path(module_name, leaf) + with self.subTest(path=str(path)): + py_compile.compile(path, doraise=True) + + def test_package_init_modules_compile(self): + checked = set() + for module_name in PACKAGES: + base = ROOT / module_name.replace(".", "/") + paths = [base / "__init__.py"] + for leaf in PACKAGES[module_name]: + if "." in leaf: + subpkg = leaf.split(".", 1)[0] + paths.append(base / subpkg / "__init__.py") + for path in paths: + key = str(path) + if key in checked: + continue + checked.add(key) + with self.subTest(path=key): + py_compile.compile(path, doraise=True) + + def test_shim_modules_compile(self): + for rel_path in SHIMS: + with self.subTest(path=rel_path): + py_compile.compile(ROOT / rel_path, doraise=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_viewer_api.py b/tests/test_viewer_api.py new file mode 100644 index 000000000..a93c8d19b --- /dev/null +++ b/tests/test_viewer_api.py @@ -0,0 +1,77 @@ +"""Tests for the napari-style script API.""" + +from __future__ import annotations + +import importlib +from unittest.mock import MagicMock, patch + +import pytest + + +def _reload_viewer_module(): + import cellacdc.viewer + + return importlib.reload(cellacdc.viewer) + + +def test_viewer_sets_segmentation_and_tracking_mode(): + mock_win = MagicMock() + viewer_mod = _reload_viewer_module() + + with ( + patch("cellacdc._event_loop.get_qapp", return_value=MagicMock()), + patch.object(viewer_mod, "_read_version", return_value="test"), + patch.object(viewer_mod, "_create_gui_window", return_value=mock_win), + patch.object(viewer_mod, "_check_gui_installed"), + ): + viewer = viewer_mod.Viewer(show=False) + + mock_win.modeComboBox.setCurrentText.assert_called_once_with( + "Segmentation and Tracking" + ) + assert viewer.data is None + + +def test_run_warns_without_top_level_widgets(): + mock_app = MagicMock() + mock_app.topLevelWidgets.return_value = [] + mock_app.thread.return_value.loopLevel.return_value = 0 + + with ( + patch("cellacdc._event_loop._ipython_has_eventloop", return_value=False), + patch("qtpy.QtWidgets.QApplication") as mock_qapp_cls, + pytest.warns(UserWarning, match="Refusing to run a QApplication"), + ): + mock_qapp_cls.instance.return_value = mock_app + from cellacdc._event_loop import run + + run() + + +def test_run_starts_event_loop_when_widgets_exist(): + mock_app = MagicMock() + mock_app.topLevelWidgets.return_value = [MagicMock()] + mock_app.thread.return_value.loopLevel.return_value = 0 + + with ( + patch("cellacdc._event_loop._ipython_has_eventloop", return_value=False), + patch("qtpy.QtWidgets.QApplication") as mock_qapp_cls, + ): + mock_qapp_cls.instance.return_value = mock_app + from cellacdc._event_loop import run + + run() + + mock_app.exec_.assert_called_once() + + +def test_lazy_exports_from_package(): + import cellacdc + + assert cellacdc.Viewer.__name__ == "Viewer" + assert cellacdc.ExperimentData.__name__ == "ExperimentData" + assert cellacdc.current_viewer.__name__ == "current_viewer" + assert cellacdc.run.__name__ == "run" + assert cellacdc.get_qapp.__name__ == "get_qapp" + assert cellacdc.quit_app.__name__ == "quit_app" + assert cellacdc.imshow.__name__ == "imshow" diff --git a/tests/test_workflow_graph.py b/tests/test_workflow_graph.py new file mode 100644 index 000000000..0ccb66001 --- /dev/null +++ b/tests/test_workflow_graph.py @@ -0,0 +1,234 @@ +"""Tests for workflow graph modeling.""" + +import importlib.util +import sys +import types +import unittest +from pathlib import Path + + +def _bootstrap_workflow_package(): + root = Path(__file__).resolve().parents[1] + workflow_root = root / "cellacdc" / "workflow" + + cellacdc_pkg = sys.modules.get("cellacdc") + if cellacdc_pkg is None: + cellacdc_pkg = types.ModuleType("cellacdc") + cellacdc_pkg.__path__ = [str(root / "cellacdc")] + sys.modules["cellacdc"] = cellacdc_pkg + + workflow_pkg = types.ModuleType("cellacdc.workflow") + workflow_pkg.__path__ = [str(workflow_root)] + sys.modules["cellacdc.workflow"] = workflow_pkg + + for name in ("constants", "state", "runnable", "graph"): + module_name = f"cellacdc.workflow.{name}" + if module_name in sys.modules: + continue + spec = importlib.util.spec_from_file_location( + module_name, workflow_root / f"{name}.py" + ) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + setattr(workflow_pkg, name, module) + + return sys.modules["cellacdc.workflow"] + + +workflow = _bootstrap_workflow_package() +END = workflow.constants.END +PositionState = workflow.state.PositionState +WorkflowContext = workflow.state.WorkflowContext +RunnableConfig = workflow.runnable.RunnableConfig +StateGraph = workflow.graph.StateGraph + + +class TestSegmWorkflowGraph(unittest.TestCase): + def test_graph_structure_without_heavy_imports(self): + ctx = WorkflowContext(user_ch_name="phase", model_name="cellpose") + graph = StateGraph(PositionState, ctx) + graph.add_node("load_position", lambda s, c, cfg: {}) + graph.add_node("segment", lambda s, c, cfg: {}) + graph.set_entry_point("load_position") + graph.add_edge("load_position", "segment") + graph.add_edge("segment", END) + structure = graph.get_graph() + self.assertEqual(structure["entrypoint"], "load_position") + self.assertEqual(structure["edges"]["segment"], END) + + def test_compiled_graph_routes_to_end(self): + ctx = WorkflowContext(user_ch_name="phase", do_save=False) + logs: list[str] = [] + + def load_node(state, workflow_ctx, config): + logs.append("load") + return {} + + def save_node(state, workflow_ctx, config): + logs.append("save") + return {} + + graph = StateGraph(PositionState, ctx) + graph.add_node("load_position", load_node) + graph.add_node("save", save_node) + graph.set_entry_point("load_position") + graph.add_conditional_edges( + "load_position", + lambda _s, workflow_ctx: END if not workflow_ctx.do_save else "save", + {"save": "save", END: END}, + ) + graph.add_edge("save", END) + + compiled = graph.compile() + compiled.invoke( + PositionState(img_path="/tmp/test.tif"), + RunnableConfig(logger_func=logs.append), + ) + self.assertEqual(logs, ["load"]) + + def test_batch_graph_loops_over_paths(self): + invoked: list[str] = [] + + class _PositionGraph: + def invoke(self, state, config): + invoked.append(state.img_path) + return state + + BatchWorkflowContext = workflow.state.BatchWorkflowContext + BatchState = workflow.state.BatchState + + position_ctx = WorkflowContext(user_ch_name="phase") + batch_ctx = BatchWorkflowContext(position_ctx=position_ctx) + batch_ctx.position_graph = _PositionGraph() + + def process_position(state, ctx, config): + path = state.paths[state.current_index] + stop_frame_n = state.stop_frame_numbers[state.current_index] + result = ctx.position_graph.invoke( + PositionState(img_path=path, stop_frame_n=stop_frame_n), + config, + ) + return { + "results": [*state.results, result], + "current_index": state.current_index + 1, + } + + def route_batch(state, _ctx): + return END if state.current_index >= len(state.paths) else "process_position" + + graph = StateGraph(BatchState, batch_ctx) + graph.add_node("process_position", process_position) + graph.set_entry_point("process_position") + graph.add_conditional_edges( + "process_position", + route_batch, + {"process_position": "process_position", END: END}, + ) + graph.compile().invoke( + BatchState(paths=["/a.tif", "/b.tif"], stop_frame_numbers=[1, 2]), + RunnableConfig(), + ) + self.assertEqual(invoked, ["/a.tif", "/b.tif"]) + + def test_gui_measurements_batch_loops_and_stops_on_abort(self): + invoked: list[str] = [] + + class _GuiMeasurementsGraph: + def __init__(self, abort_on: str | None = None): + self.abort_on = abort_on + + def invoke(self, state, config): + invoked.append(state.img_path) + aborted = state.img_path == self.abort_on + return workflow.state.MeasurementsGuiState( + img_path=state.img_path, + aborted=aborted, + ) + + MeasurementsGuiBatchContext = workflow.state.MeasurementsGuiBatchContext + BatchState = workflow.state.BatchState + MeasurementsGuiContext = workflow.state.MeasurementsGuiContext + MeasurementsGuiState = workflow.state.MeasurementsGuiState + + class _Kernel: + setup_done = False + + @staticmethod + def log(msg): + pass + + batch_ctx = MeasurementsGuiBatchContext(kernel=_Kernel()) + + def build_graph(ctx, pos_data_loaded=False): + del pos_data_loaded + + class _Builder: + def compile(self): + return _GuiMeasurementsGraph(abort_on="/b.tif") + + return _Builder() + + def process_position(state, ctx, config): + path = state.paths[state.current_index] + stop_frame_n = state.stop_frame_numbers[state.current_index] + gui_ctx = MeasurementsGuiContext(kernel=ctx.kernel) + graph = build_graph(gui_ctx, pos_data_loaded=False).compile() + result = graph.invoke( + MeasurementsGuiState(img_path=path, stop_frame_n=stop_frame_n), + config, + ) + results = [*state.results, result] + aborted = bool(getattr(result, "aborted", False)) + return { + "results": results, + "current_index": state.current_index + 1, + "aborted": aborted or state.aborted, + } + + def route_batch(state, ctx): + if state.aborted or ctx.kernel.setup_done: + return END + if state.current_index >= len(state.paths): + return END + return "process_position" + + graph = StateGraph(BatchState, batch_ctx) + graph.add_node("process_position", process_position) + graph.set_entry_point("process_position") + graph.add_conditional_edges( + "process_position", + route_batch, + {"process_position": "process_position", END: END}, + ) + final = graph.compile().invoke( + BatchState(paths=["/a.tif", "/b.tif", "/c.tif"], stop_frame_numbers=[1, 1, 1]), + RunnableConfig(), + ) + self.assertEqual(invoked, ["/a.tif", "/b.tif"]) + self.assertTrue(final.aborted) + + def test_video_graph_structure(self): + graph = StateGraph( + workflow.state.InteractiveVideoSegmState, + None, + ) + steps = [ + "extend_segm_data", + "prepare_video_stack", + "segment_video_frames", + "finalize_video_run", + ] + for step in steps: + graph.add_node(step, lambda s, c, cfg: {}) + graph.set_entry_point("extend_segm_data") + for left, right in zip(steps, steps[1:]): + graph.add_edge(left, right) + graph.add_edge(steps[-1], END) + structure = graph.get_graph() + self.assertEqual(structure["entrypoint"], "extend_segm_data") + self.assertEqual(structure["edges"]["finalize_video_run"], END) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/segmentation.py b/tests/utils/segmentation.py index 0bdee1ca4..e34735a5a 100644 --- a/tests/utils/segmentation.py +++ b/tests/utils/segmentation.py @@ -95,18 +95,14 @@ def validate_labels(labels: np.ndarray, expected_shape: tuple): If validation fails. """ assert labels is not None, "Segmentation returned None" - assert isinstance(labels, np.ndarray), ( - f"Expected numpy array, got {type(labels)}" - ) + assert isinstance(labels, np.ndarray), f"Expected numpy array, got {type(labels)}" assert labels.shape == expected_shape, ( f"Shape mismatch: {labels.shape} != {expected_shape}" ) assert np.issubdtype(labels.dtype, np.integer), ( f"Expected integer dtype, got {labels.dtype}" ) - assert labels.min() >= 0, ( - f"Labels should be non-negative, got min={labels.min()}" - ) + assert labels.min() >= 0, f"Labels should be non-negative, got min={labels.min()}" def print_segmentation_results(labels: np.ndarray, frame: np.ndarray, frame_i: int): @@ -178,20 +174,27 @@ def save_segmentation_overlay( closest_idx = np.argmin(distances) y, x = coords[closest_idx] ax.text( - x, y, str(region.label), - color="white", fontsize=8, fontweight="bold", - ha="center", va="center", - path_effects=[ - patheffects.withStroke(linewidth=2, foreground="black") - ], + x, + y, + str(region.label), + color="white", + fontsize=8, + fontweight="bold", + ha="center", + va="center", + path_effects=[patheffects.withStroke(linewidth=2, foreground="black")], ) # Plot prompt points if provided if prompt_points: for label_id, y, x in prompt_points: ax.plot( - x, y, 'x', - color='red', markersize=8, markeredgewidth=2, + x, + y, + "x", + color="red", + markersize=8, + markeredgewidth=2, ) ax.set_title(f"Frame {frame_i} ({num_objects} objects)") @@ -235,6 +238,7 @@ def ensure_sam(): sys.path.insert(0, str(candidate)) import pytest + pytest.importorskip("segment_anything") @@ -253,18 +257,21 @@ def ensure_sam2(): sys.path.insert(0, str(candidate)) import pytest + pytest.importorskip("sam2") def ensure_cellsam(): """Ensure cellSAM is importable.""" import pytest + pytest.importorskip("cellSAM") def get_test_posdata(): """Get posData for the standard test dataset.""" from cellacdc import data + return data.MIA_KC_htb1_mCitrine().posData() @@ -277,6 +284,7 @@ def get_test_dataset(): Dataset object with access to images, segmentation, and metadata. """ from cellacdc import data + return data.MIA_KC_htb1_mCitrine()